##### Authors: Luca Barbati (5082540), Roberto Lazzarini (4937188)

# Imports

In [None]:
import os
import random
import re
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, Input, Model
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import math
from enum import Enum
import cv2
import matplotlib.patches as patches

# Classes

| Class          | Depends On                        |
|----------------|-----------------------------------|
| DatasetManager | cfg                               |
| SiameseModel   | cfg                               |
| Visualizer     | cfg, DatasetManager                              |
| Controller     | cfg, DatasetManager, SiameseModel, Visualizer |


## 1. Config and enums

The `Config` class stores every global setting that must remain accessible for the whole duration of the experiment.

In [None]:
class ModelType(Enum):
    """Strategy used to turn a pair of embeddings into a same/different decision."""
    # The network outputs a distance; a threshold on it yields the decision.
    DISTANCE_BASED = "distance_based"
    # The network outputs a sigmoid probability that the pair matches.
    BINARY_CLASSIFICATION = "binary_classification"

class DistanceMetric(Enum):
    """Distance function applied to the two embeddings of a distance-based model."""
    EUCLIDEAN = "euclidean"
    COSINE = "cosine"

In [None]:
class Config:
    """Global settings shared by every stage of an experiment.

    Args:
        model_type: "distance_euclidean", "distance_cosine" or
            "binary_classification".
        N_SHOTS: Number of positive pairs generated per author. Must lie in
            [1, MAX_SHOTS_PER_TRAIN].
        FORGERIES_ALPHA: Value in [0, 1] controlling how many of the negative
            pairs are forgeries rather than cross-author pairs.
        ENVIRONMENT: "kaggle" or "local"; selects dataset and output paths.
        GIANT_EXTRA_SET: Whether the extra set should be sized from all
            unused pairs instead of from the test-set size. Booleans and
            the usual string spellings ("True", "false", "0", ...) are accepted.

    Raises:
        ValueError: On an out-of-range N_SHOTS / FORGERIES_ALPHA, an unknown
            ENVIRONMENT or an unknown model_type.
        RuntimeError: If the dataset directories do not exist.
    """

    # String spellings of GIANT_EXTRA_SET that must be interpreted as False.
    _FALSE_STRINGS = frozenset({"", "false", "0", "no", "n", "off"})

    def __init__(self, model_type, N_SHOTS=5, FORGERIES_ALPHA=0.5, ENVIRONMENT="kaggle", GIANT_EXTRA_SET="True"):
        # N_SHOTS is validated first so a bad value fails before any
        # filesystem access or seeding happens.
        self.MAX_SHOTS_PER_AUTHOR = 24
        # Half of each author's samples is reserved for the extra validation set.
        self.MAX_SHOTS_PER_TRAIN = self.MAX_SHOTS_PER_AUTHOR // 2
        n_shots = int(N_SHOTS)
        if n_shots > self.MAX_SHOTS_PER_TRAIN:
            raise ValueError(
                f"N_SHOTS cannot be greater than {self.MAX_SHOTS_PER_TRAIN} (got {N_SHOTS})"
            )
        if n_shots < 1:
            raise ValueError(f"N_SHOTS must be at least 1 (got {N_SHOTS})")
        self.N_SHOTS = n_shots

        # Share of forgeries among the negative pairs.
        alpha = float(FORGERIES_ALPHA)
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"FORGERIES_ALPHA must be within [0, 1] (got {FORGERIES_ALPHA})")
        self.FORGERIES_ALPHA = alpha

        # Directories (the output directory is created by _set_environment).
        self.ENVIRONMENT = ENVIRONMENT
        self.ORG_DIR, self.FORG_DIR, self.OUTPUT_DIR = self._set_environment(ENVIRONMENT)

        # Fixed parameters
        self.EPOCHS = 50
        self.TRAIN_RATIO = 0.7
        self.VAL_RATIO = 0.3
        self.IMG_SIZE = (155, 220)
        self.BATCH_SIZE = 8
        # Stored as a real bool: a plain truthiness test on the raw argument
        # would treat the string "False" as True.
        self.GIANT_EXTRA_SET = self._as_bool(GIANT_EXTRA_SET)

        # Reproducibility
        self.SEED = 42
        tf.keras.utils.set_random_seed(self.SEED)

        # Model type + metric
        self.MODEL_TYPE, self.DISTANCE_METRIC, self.MODEL_TYPE_STRING = self._choose_model_type(model_type)

    @staticmethod
    def _as_bool(value):
        """Interpret a bool-like value, reading "false"/"0"/"no"/"off"/"" as False."""
        if isinstance(value, str):
            return value.strip().lower() not in Config._FALSE_STRINGS
        return bool(value)

    def _set_environment(self, environment):
        """Return (org_dir, forg_dir, output_dir) for "kaggle" or "local".

        Verifies that the dataset directories exist and creates the output
        directory when it is missing.
        """
        dataset_name = "cedardataset"
        locations = {
            "kaggle": (f"/kaggle/input/{dataset_name}/signatures", "/kaggle/working/output"),
            "local": ("./signatures", "./output"),
        }
        if environment not in locations:
            raise ValueError(f"Unknown ENVIRONMENT: {environment}")
        root, output_dir = locations[environment]
        org_dir = f"{root}/full_org"
        forg_dir = f"{root}/full_forg"
        self._am_i_in_the_right_environment(org_dir, forg_dir)
        os.makedirs(output_dir, exist_ok=True)
        return org_dir, forg_dir, output_dir

    def _choose_model_type(self, model_type):
        """Map the model_type string to (ModelType, DistanceMetric or None, model_type)."""
        if model_type == "distance_euclidean":
            return ModelType.DISTANCE_BASED, DistanceMetric.EUCLIDEAN, model_type
        if model_type == "distance_cosine":
            return ModelType.DISTANCE_BASED, DistanceMetric.COSINE, model_type
        if model_type == "binary_classification":
            return ModelType.BINARY_CLASSIFICATION, None, model_type
        raise ValueError(f"Unknown model_type: {model_type}")

    @staticmethod
    def _am_i_in_the_right_environment(org_dir, forg_dir):
        """Raise RuntimeError when either dataset directory is missing."""
        for directory in (org_dir, forg_dir):
            if not os.path.exists(directory):
                raise RuntimeError("[ERROR] Wrong environment. Try switching it the other way around.")

## 2. Data layer

The `DatasetManager` class is responsible for:
- retrieving the data samples from the dataset
- building train, validation, test and extra sets on the basis of the chosen parameters
- transforming the raw elements of each set into their corresponding pre-processed counterparts, according to the pre-processing parameters

In [None]:
class DatasetManager:
    # regex patterns to extract author IDs and forgeries
    AUTHOR_RE = re.compile(r"original_(\d+)_\d+\.png")
    FORGERY_RE_TEMPLATE = r"forgeries_{}_\d+\.png"

    def __init__(self, config, blur_kernel_size=3, binarize=True, center=True):
        self.cfg = config
        
        # setting pre-processing parameters
        self.blur_kernel_size = blur_kernel_size
        self.binarize = binarize
        self.center = center
        
        # initialize directories and file lists
        self.file_list = self._load_file_list()
        self.author_ids = self._extract_all_author_ids()
        self.forg_dir_files = os.listdir(self.cfg.FORG_DIR)

        # initialize empty DataFrames for positive and negative pairs
        empty_cols = ["image1", "image2", "label"]
        self.pos_df = pd.DataFrame(columns=empty_cols)
        self.neg_df = pd.DataFrame(columns=empty_cols)
        
        # initialize empty DataFrames for train, validation, test, and extra sets
        self.train_df = pd.DataFrame(columns=empty_cols)
        self.val_df = pd.DataFrame(columns=empty_cols)
        self.test_df = pd.DataFrame(columns=empty_cols)
        self.extra_df = pd.DataFrame(columns=empty_cols)

    def prepare_datasets(self, verbose=True):
        """Build the train/val/test/extra sets as pre-processed numpy data.

        Returns:
            (train, val, test, extra_set). The first three are
            ((images1, images2), labels) tuples; extra_set has the same
            layout, or is an empty list when it could not be built.
        """
        used_positives = self._generate_positive_pairs()
        used_forgeries = self._generate_negative_pairs()
        self._split_data()

        train, val, test = (
            self._make_numpy_dataset(frame)
            for frame in (self.train_df, self.val_df, self.test_df)
        )

        # Pairs already assigned to train/val/test are excluded from the extra set.
        extra_set = self._prepare_extra_set(used_positives, used_forgeries, len(self.test_df))

        if verbose:
            n_train, n_val, n_test = (split[0][0].shape[0] for split in (train, val, test))
            print(f"[INFO] Dataset prepared with {n_train} training samples, {n_val} validation samples, and {n_test} test samples.")
            extra_size = extra_set[0][0].shape[0] if extra_set else 0
            print(f"[INFO] Extra set has size {extra_size}")

        return train, val, test, extra_set

    def _prepare_extra_set(self, positives_path_to_exclude, forgeries_path_to_exclude, len_test):
        all_forgeries_path = self._generate_negatives_forg(self.cfg.MAX_SHOTS_PER_AUTHOR, self.forg_dir_files)
        forgeries_path_to_use = [x for x in all_forgeries_path if x not in forgeries_path_to_exclude]

        all_positives_path = self._generate_all_positive_pairs()
        positives_path_to_use = [x for x in all_positives_path if x not in positives_path_to_exclude]

        if self.cfg.GIANT_EXTRA_SET:
            len_test = min(len(positives_path_to_use), len(forgeries_path_to_use))
        
        extra_set = []
        if len(forgeries_path_to_use) > 0 and len(positives_path_to_use) > 0:
            extra_set = forgeries_path_to_use[:len_test // 2] + positives_path_to_use[:len_test // 2]
            extra_set = pd.DataFrame(extra_set, columns=["image1", "image2", "label"])
            self.extra_df = extra_set.copy()
            extra_set = self._make_numpy_dataset(extra_set)
        return extra_set

    def load_and_preprocess(self, path):
        """Load one image and run the pre-processing pipeline on it.

        Steps: grayscale read, resize to cfg.IMG_SIZE, Gaussian blur,
        optional Otsu binarization, scaling to [0, 1] float32 with a trailing
        channel axis, optional centring.

        Raises:
            FileNotFoundError: If OpenCV could not read the file.
        """
        gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            raise FileNotFoundError(f"File not found: {path}")

        # cfg.IMG_SIZE is (height, width) while cv2.resize expects (width, height).
        height, width = self.cfg.IMG_SIZE[0], self.cfg.IMG_SIZE[1]
        gray = cv2.resize(gray, (width, height))

        kernel = (self.blur_kernel_size, self.blur_kernel_size)
        gray = cv2.GaussianBlur(gray, kernel, 1.0)

        if self.binarize:
            gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

        scaled = np.expand_dims(gray.astype(np.float32) / 255.0, axis=-1)
        return self._center_signature(scaled) if self.center else scaled

    def _load_file_list(self):
        return [
            f for f in os.listdir(self.cfg.ORG_DIR) if f.lower().endswith('.png')
        ]

    def _extract_all_author_ids(self):
        return sorted({
            self._extract_author_id(f)
            for f in self.file_list
            if self._extract_author_id(f) is not None
        })

    def _extract_author_id(self, filename):
        match = self.AUTHOR_RE.match(filename)
        return match.group(1) if match else None

    def _generate_positive_pairs(self):
        pos_pairs = []
        n_pos = self.cfg.N_SHOTS
        for aut in self.author_ids:
            ref = os.path.join(self.cfg.ORG_DIR, f"original_{aut}_1.png")
            for idx in range(2, n_pos + 2):
                gen = os.path.join(self.cfg.ORG_DIR, f"original_{aut}_{idx}.png")
                if os.path.exists(ref) and os.path.exists(gen):
                    pos_pairs.append((ref, gen, 1))
        self.pos_df = pd.DataFrame(pos_pairs, columns=["image1", "image2", "label"])
        return pos_pairs

    def _generate_all_positive_pairs(self):
        """Generate all possible positive pairs for all authors, with no sampling."""
        pos_pairs = []
        for aut in self.author_ids:
            ref = os.path.join(self.cfg.ORG_DIR, f"original_{aut}_1.png")
            for idx in range(1, self.cfg.MAX_SHOTS_PER_AUTHOR + 1):
                gen = os.path.join(self.cfg.ORG_DIR, f"original_{aut}_{idx}.png")
                if os.path.exists(ref) and os.path.exists(gen):
                    pos_pairs.append((ref, gen, 1))
        return pos_pairs

    def _generate_negative_pairs(self):
        n_shots = self.cfg.N_SHOTS
        forgeries_alpha = self.cfg.FORGERIES_ALPHA

        if forgeries_alpha == 0.5:
            n_neg_cross = int(math.ceil(n_shots * 0.5))
            n_neg_forg = int(math.ceil(n_shots * 0.5))
        else:
            n_neg_cross = int(math.floor(n_shots * (1 - forgeries_alpha)))
            n_neg_forg = int(n_shots - n_neg_cross)
        
    
        neg_pairs = []
        neg_pairs += self._generate_negatives_auth(n_neg_cross)
        forgeries_to_exclude = self._generate_negatives_forg(n_neg_forg, self.forg_dir_files)
        neg_pairs += forgeries_to_exclude

        self.neg_df = pd.DataFrame(neg_pairs, columns=["image1", "image2", "label"])
        print(f"N_POS: {n_shots}, N_NEG_FORG: {n_neg_forg}, N_NEG_CROSS: {n_neg_cross}")
        return forgeries_to_exclude

    def _generate_negatives_auth(self, n_neg_cross):
        neg_pairs = []
        for aut in self.author_ids:
            ref = os.path.join(self.cfg.ORG_DIR, f"original_{aut}_1.png")
            other_authors = [a for a in self.author_ids if a != aut]
            sampled_others = random.sample(other_authors, min(n_neg_cross, len(other_authors))) if other_authors else []
            for neg_aut in sampled_others:
                neg_img = os.path.join(self.cfg.ORG_DIR, f"original_{neg_aut}_2.png")
                if os.path.exists(neg_img):
                    neg_pairs.append((ref, neg_img, 0))
        return neg_pairs

    def _generate_negatives_forg(self, n_neg_forg, forg_dir_files):
        neg_pairs = []
        for aut in self.author_ids:
            ref = os.path.join(self.cfg.ORG_DIR, f"original_{aut}_1.png")
            forgery_pat = re.compile(self.FORGERY_RE_TEMPLATE.format(aut))
            forger_list = [f for f in forg_dir_files if forgery_pat.match(f)]
            sampled_forgers = random.sample(forger_list, min(n_neg_forg, len(forger_list))) if forger_list else []
            for forg_file in sampled_forgers:
                forg_path = os.path.join(self.cfg.FORG_DIR, forg_file)
                if os.path.exists(forg_path):
                    neg_pairs.append((ref, forg_path, 0))
        return neg_pairs

    def _split_data(self):
        """Split all pairs into train/val/test with label-stratified sampling.

        First cfg.TRAIN_RATIO of the pairs goes to train+validation and the
        rest to test; then cfg.VAL_RATIO of train+validation becomes the
        validation set. Both splits use cfg.SEED.
        """
        all_pairs = pd.concat([self.pos_df, self.neg_df], ignore_index=True)
        seed = self.cfg.SEED

        trainval_df, held_out = train_test_split(
            all_pairs,
            train_size=self.cfg.TRAIN_RATIO,
            stratify=all_pairs["label"],
            random_state=seed,
        )
        self.test_df = held_out

        train_part, val_part = train_test_split(
            trainval_df,
            test_size=self.cfg.VAL_RATIO,
            stratify=trainval_df["label"],
            random_state=seed,
        )
        self.train_df = train_part
        self.val_df = val_part

    def _center_signature(self, img):
        """Re-centre the signature on a white canvas of the same size.

        Takes a float image with a trailing channel axis (it is squeezed on
        the last axis and rescaled by 255) and returns a float32 image in
        [0, 1] with the same trailing axis.
        """
        # Back to 8-bit so OpenCV's geometry helpers can be used.
        img_bin = (img.squeeze(-1) * 255).astype(np.uint8)
        canvas_h, canvas_w = img_bin.shape
        # Inverting makes every pixel that is not pure white non-zero.
        coords = cv2.findNonZero(255 - img_bin)
        if coords is None:
            # Nothing but white pixels: return a blank white canvas.
            canvas = np.ones((canvas_h, canvas_w), dtype=np.uint8) * 255
        else:
            x, y, w, h = cv2.boundingRect(coords)
            signature_crop = img_bin[y:y + h, x:x + w]
            # Never upscale (capped at 1.0).
            # NOTE(review): the crop comes from this same canvas, so the
            # scale presumably always evaluates to 1.0 - confirm.
            scale = min(canvas_w / w, canvas_h / h, 1.0)
            new_w, new_h = int(w * scale), int(h * scale)
            signature_resized = cv2.resize(signature_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)
            canvas = np.ones((canvas_h, canvas_w), dtype=np.uint8) * 255
            # Top-left corner that centres the crop on the canvas.
            start_x = (canvas_w - new_w) // 2
            start_y = (canvas_h - new_h) // 2
            canvas[start_y:start_y + new_h, start_x:start_x + new_w] = signature_resized
        centered_img = (canvas.astype(np.float32) / 255.0)[..., np.newaxis]
        return centered_img

    def _preprocess_pair(self, p1, p2, label):
        img1 = self.load_and_preprocess(p1)
        img2 = self.load_and_preprocess(p2)
        return (img1, img2), float(label)

    def _make_numpy_dataset(self, df):
        images1, images2, labels = [], [], []
        for _, row in df.iterrows():
            try:
                (img1, img2), label = self._preprocess_pair(row['image1'], row['image2'], row['label'])
                images1.append(img1)
                images2.append(img2)
                labels.append(label)
            except Exception as e:
                print(f"[WARN] Could not process pair {row['image1']} {row['image2']}: {e}")
        images1 = np.stack(images1)
        images2 = np.stack(images2)
        labels = np.array(labels)
        return (images1, images2), labels

## 3. Model layer

The `SiameseModel` class is responsible for:
- building the model, depending on: 
    - the chosen strategy (either distance-based or binary classification) 
    - in case of distance-based models, the distance metric (cosine or euclidean). 
- training the model
- getting the predictions
- evaluating the predictions

In [None]:
class SiameseModel:
    def __init__(self, config):
        """Store the configuration, pick the strategy and prepare the training callbacks."""
        self.cfg = config
        self.model = None
        self.history = None
        self.best_threshold = None

        # Distance-based or binary-classification behaviour.
        self.strategy = self._pick_strategy()

        # Training callbacks, all driven by the validation loss.
        keras_callbacks = tf.keras.callbacks
        checkpoint_path = os.path.join(self.cfg.OUTPUT_DIR, "best_siamese_model.keras")
        checkpoint = keras_callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            monitor="val_loss",
            mode="min",
            save_best_only=True,
            verbose=0,
        )
        early_stopping = keras_callbacks.EarlyStopping(
            monitor="val_loss",
            mode="min",
            patience=10,
            restore_best_weights=True,
            verbose=0,
        )
        lr_schedule = keras_callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            patience=3,
            min_lr=1e-7,
            verbose=0,
        )
        self.callbacks = [checkpoint, early_stopping, lr_schedule]
    def _pick_strategy(self):
        """Return the strategy object matching cfg.MODEL_TYPE."""
        distance_based = self.cfg.MODEL_TYPE == ModelType.DISTANCE_BASED
        strategy_cls = DistanceBasedStrategy if distance_based else BinaryClassificationStrategy
        return strategy_cls(self.cfg)

    # builds the single branch that will compose the Siamese network
    @staticmethod
    def signet_cnn_simple(input_shape=(155, 220, 1)):
        """Build the embedding branch shared by both inputs of the Siamese network.

        Four Conv -> LeakyReLU -> BatchNorm -> MaxPool stages (dropout after
        the first three), a 1x1 convolution, then two dense layers ending in
        a 32-dimensional ReLU embedding.
        """
        # (filters, pooling size, dropout after the stage?)
        conv_stages = (
            (32, 2, True),
            (64, 2, True),
            (128, 2, True),
            (256, 4, False),
        )

        inputs = Input(shape=input_shape)
        features = inputs
        for filters, pool, with_dropout in conv_stages:
            features = layers.Conv2D(filters, (3, 3), padding='same')(features)
            features = layers.LeakyReLU()(features)
            features = layers.BatchNormalization()(features)
            features = layers.MaxPooling2D((pool, pool))(features)
            if with_dropout:
                features = layers.Dropout(0.3)(features)

        features = layers.Conv2D(256, (1, 1))(features)
        features = layers.LeakyReLU()(features)
        features = layers.Flatten()(features)
        for units in (128, 32):
            features = layers.Dense(units)(features)
            features = layers.ReLU()(features)
        return Model(inputs, features, name="SigNetCNN_Simple")

    # builds the Siamese model using the previously established single branch and strategy
    def build(self):
        """Assemble and compile the two-input Siamese model in self.model.

        The loss and metrics come from the active strategy; the optimizer is
        Adam with a learning rate of 1e-4.
        """
        input_shape = self.cfg.IMG_SIZE + (1,)
        embedding_net = self.signet_cnn_simple(input_shape)

        # One shared embedding network is applied to both inputs.
        input_a = layers.Input(shape=input_shape)
        input_b = layers.Input(shape=input_shape)
        head = self.strategy.build_model(embedding_net, input_a, input_b)
        self.model = models.Model([input_a, input_b], head)

        adam = tf.keras.optimizers.Adam(learning_rate=1e-4)
        compile_spec = self.strategy.get_model_metrics()
        self.model.compile(
            optimizer=adam,
            loss=compile_spec['loss'],
            metrics=compile_spec['metrics'],
        )
        
    def train(self, x_train, y_train, x_val, y_val, epochs, callbacks=None):
        self.history = self.model.fit(
            x=[x_train[0], x_train[1]],
            y=y_train,
            validation_data=([x_val[0], x_val[1]], y_val),
            epochs=epochs,
            callbacks=self.callbacks,
            verbose=1,
            batch_size=self.cfg.BATCH_SIZE
        )
        return self.history

    def predict(self, x1, x2):
        return self.model.predict([x1, x2], verbose=0)

    def get_predictions_and_labels(self, dataset):
        (X1, X2), Y = dataset
        preds = self.model.predict([X1, X2], verbose=0)
        return preds.flatten(), Y.flatten()

    def evaluate_predictions(self, predictions, labels):
        results = self.strategy.evaluate_predictions(predictions, labels)
        # update best_threshold attribute if present
        if 'best_threshold' in results:
            self.best_threshold = results['best_threshold']
            #print(f"I am updating the threshold to {self.best_threshold}.")
        else:
            self.best_threshold = None
        return results

In [None]:
class DistanceBasedStrategy:
    """Siamese head that outputs a distance between the two embeddings.

    Small distances mean "same author"; the decision threshold is searched
    on the predictions being evaluated.
    """

    def __init__(self, config):
        self.cfg = config

    def build_model(self, embedding_net, input_a, input_b):
        """Return the distance tensor computed from the two embedded inputs."""
        embedded = [embedding_net(input_a), embedding_net(input_b)]
        if self.cfg.DISTANCE_METRIC == DistanceMetric.EUCLIDEAN:
            return layers.Lambda(self.euclidean_distance)(embedded)
        return layers.Lambda(self.cosine_distance)(embedded)

    def get_model_metrics(self):
        """Return the loss and metrics to compile the model with."""
        return {
            'loss': self.contrastive_loss_improved,
            'metrics': [self.siamese_accuracy],
        }

    @staticmethod
    def euclidean_distance(vects):
        """Euclidean distance per sample, floored at sqrt(epsilon)."""
        left, right = vects
        squared_sum = tf.reduce_sum(tf.square(left - right), axis=1, keepdims=True)
        # The floor keeps the square root away from an exact zero.
        return tf.sqrt(tf.maximum(squared_sum, tf.keras.backend.epsilon()))

    @staticmethod
    def cosine_distance(vects):
        """Cosine distance (1 - cosine similarity) per sample."""
        left, right = vects
        left_unit = tf.nn.l2_normalize(left, axis=1)
        right_unit = tf.nn.l2_normalize(right, axis=1)
        similarity = tf.reduce_sum(left_unit * right_unit, axis=1, keepdims=True)
        return 1 - similarity

    @staticmethod
    def contrastive_loss_improved(y_true, y_pred, margin=1.0):
        """Contrastive loss: pull matching pairs together, push others beyond the margin."""
        targets = tf.cast(y_true, tf.float32)
        pull_term = targets * tf.square(y_pred)
        push_term = (1 - targets) * tf.square(tf.maximum(margin - y_pred, 0))
        return tf.reduce_mean(pull_term + push_term)

    @staticmethod
    def siamese_accuracy(y_true, y_pred, threshold=1.0):
        """Accuracy when pairs with a distance below the threshold count as matches."""
        targets = tf.cast(y_true, tf.float32)
        decided = tf.cast(y_pred < threshold, tf.float32)
        return tf.keras.metrics.binary_accuracy(targets, decided)

    def evaluate_predictions(self, predictions, labels):
        """Score distances against labels at the best threshold found on them.

        Returns:
            Dict with 'accuracy', 'auc' (None when it cannot be computed),
            'best_threshold' and the binary 'predictions'.
        """
        best_threshold, _ = self.find_best_threshold(predictions, labels)
        decided = (predictions < best_threshold).astype(int)
        accuracy = (decided == labels).mean()
        try:
            # Scores are inverted because a smaller distance means "match".
            auc_score = roc_auc_score(labels, 1 - predictions)
        except Exception:
            auc_score = None
        return {
            'accuracy': accuracy,
            'auc': auc_score,
            'best_threshold': best_threshold,
            'predictions': decided,
        }

    @staticmethod
    def find_best_threshold(distances, labels):
        """Scan 100 evenly spaced thresholds between the min and max distance.

        Returns:
            (threshold, accuracy) of the first threshold reaching the highest
            accuracy, or (0, 0) when no threshold scores above zero.
        """
        candidates = np.linspace(distances.min(), distances.max(), 100)
        best_thresh, best_acc = 0, 0
        for candidate in candidates:
            decided = (distances < candidate).astype(int)
            candidate_acc = (decided == labels).mean()
            if not candidate_acc > best_acc:
                continue
            best_thresh, best_acc = candidate, candidate_acc
        return best_thresh, best_acc


In [None]:
class BinaryClassificationStrategy:
    """Siamese head that classifies the concatenated embeddings as match / no match."""

    def __init__(self, config):
        self.cfg = config

    def build_model(self, embedding_net, input_a, input_b):
        """Return the sigmoid output computed from the two embedded inputs."""
        embedded = [embedding_net(input_a), embedding_net(input_b)]
        hidden = layers.Concatenate(axis=1)(embedded)
        # Two hidden ReLU layers followed by a single sigmoid unit.
        for units, activation in ((128, 'relu'), (64, 'relu'), (1, 'sigmoid')):
            hidden = layers.Dense(units, activation=activation)(hidden)
        return hidden

    def get_model_metrics(self):
        """Return the loss and metrics to compile the model with."""
        return {'loss': 'binary_crossentropy', 'metrics': ['accuracy']}

    def evaluate_predictions(self, predictions, labels):
        """Score probabilities against labels with a fixed 0.5 cut-off.

        Returns:
            Dict with 'accuracy', 'auc' (None when it cannot be computed),
            'best_threshold' (always 0.5) and the binary 'predictions'.
        """
        decided = (predictions >= 0.5).astype(int)
        accuracy = (decided == labels).mean()
        try:
            auc_score = roc_auc_score(labels, predictions)
        except Exception:
            auc_score = None
        return {
            'accuracy': accuracy,
            'auc': auc_score,
            'best_threshold': 0.5,
            'predictions': decided,
        }


## 4. View Layer

The `Visualizer` class is responsible for presenting the model's results visually: it shows example predictions and plots the various metrics.

In [None]:
class Visualizer:
    def __init__(self, config):
        self.cfg = config
        self.dm = None # not explicitly a class-inherited parameter, but will be set later when the DatasetManager is created

    def plot_history(self, history):
        """Plot the loss curves and every accuracy-like metric of a training history."""
        logs = history.history
        plt.figure(figsize=(12, 4))

        # Left panel: training and validation loss.
        plt.subplot(1, 2, 1)
        for key, label in (('loss', 'Training Loss'), ('val_loss', 'Validation Loss')):
            plt.plot(logs[key], label=label)
        plt.title('Loss')
        plt.legend()

        # Right panel: every logged metric whose name contains 'accuracy'.
        plt.subplot(1, 2, 2)
        accuracy_keys = [name for name in logs if 'accuracy' in name]
        for name in accuracy_keys:
            plt.plot(logs[name], label=name)
        plt.title('Accuracy')
        plt.legend()

        plt.tight_layout()
        plt.show()

    def plot_accuracy_vs_shots(self, shots_list, test_accs, forgery_accs):
        """Plot test and forgery accuracy against the number of shots."""
        curves = (
            (test_accs, '-o', 'Test accuracy'),
            (forgery_accs, '-s', 'Forgery accuracy'),
        )
        plt.figure(figsize=(8,5))
        for values, line_style, label in curves:
            plt.plot(shots_list, values, line_style, label=label)
        plt.xlabel('N_SHOTS')
        plt.ylabel('Accuracy')
        plt.title('Accuracy vs. N_SHOTS')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    def plot_accuracy_vs_alpha(self, alpha_list, test_accs, forgery_accs):
        """Plot test and forgery accuracy against FORGERIES_ALPHA."""
        curves = (
            (test_accs, '-o', 'Test accuracy'),
            (forgery_accs, '-s', 'Forgery accuracy'),
        )
        plt.figure(figsize=(8,5))
        for values, line_style, label in curves:
            plt.plot(alpha_list, values, line_style, label=label)
        plt.xlabel('FORGERIES_ALPHA')
        plt.ylabel('Accuracy')
        plt.title('Accuracy vs. FORGERIES_ALPHA')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    def print_test_results_table(self, preds, labels, threshold):
        """Print one "prediction / label / correct" row per sample.

        Args:
            preds: Raw model outputs (distances or probabilities).
            labels: Ground-truth labels (1 = matching pair).
            threshold: Decision threshold applied to preds.
        """
        # A small distance means "match", whereas a large probability means
        # "match", so the comparison direction depends on the model type.
        if self.cfg.MODEL_TYPE == ModelType.DISTANCE_BASED:
            pred_labels = (preds < threshold).astype(int)
        else:
            pred_labels = (preds >= threshold).astype(int)
        print("Prediction\tLabel\tCorrect")
        for pred, label in zip(pred_labels, labels):
            print(f"{pred}\t\t{int(label)}\t{pred == label}")

    def plot_test_examples(self, model, test_df, evaluation_results, outpath, num_examples=4):
        """Show random pairs with the model's prediction and the true label.

        Args:
            model: Wrapper exposing the Keras model as `model.model`.
            test_df: Pair table with 'image1', 'image2' and 'label' columns.
            evaluation_results: Dict providing 'best_threshold' (used for
                distance-based models).
            outpath: File name for the saved figure, resolved against
                cfg.OUTPUT_DIR (an absolute path is used as is). A falsy
                value disables saving.
            num_examples: Number of columns / pairs to display.
        """
        test_df_sample = test_df.sample(n=min(num_examples, len(test_df)), random_state=self.cfg.SEED)
        # squeeze=False keeps `axes` two-dimensional even when
        # num_examples == 1, so axes[row, col] indexing always works.
        fig, axes = plt.subplots(2, num_examples, figsize=(4*num_examples, 8), squeeze=False)
        # Hide every axis up front so skipped or unused columns stay blank.
        for ax in axes.ravel():
            ax.axis('off')

        def draw_bordered(ax, image):
            # Show the image with a thin black frame around it.
            ax.imshow(image, cmap='gray')
            ax.axis('off')
            ax.add_patch(patches.Rectangle((0, 0), image.shape[1]-1, image.shape[0]-1,
                                           linewidth=2, edgecolor='black', facecolor='none'))

        for i, (_, row) in enumerate(test_df_sample.iterrows()):
            if not os.path.exists(row['image1']) or not os.path.exists(row['image2']):
                print(f"[WARN] Missing file: {row['image1']} or {row['image2']}. Skipping.")
                continue

            try:
                img1_disp = self.dm.load_and_preprocess(row['image1']).squeeze()
                img2_disp = self.dm.load_and_preprocess(row['image2']).squeeze()
            except Exception as e:
                print(f"[WARN] Error loading images: {e}. Skipping sample.")
                continue

            try:
                # Restore the batch (and, if squeezed away, channel) axes.
                img1_for_model = img1_disp[None, ..., None] if img1_disp.ndim == 2 else img1_disp[None, ...]
                img2_for_model = img2_disp[None, ..., None] if img2_disp.ndim == 2 else img2_disp[None, ...]
                pred_value = model.model.predict([img1_for_model, img2_for_model], verbose=0)[0][0]
            except Exception as e:
                print(f"[WARN] Error in model prediction: {e}. Skipping sample.")
                continue

            # Determine prediction
            if self.cfg.MODEL_TYPE == ModelType.DISTANCE_BASED:
                prediction = "AUTHENTIC" if pred_value < evaluation_results['best_threshold'] else "DIFFERENT"
                pred_label = f'Dist: {pred_value:.2f}'
            else:
                prediction = "AUTHENTIC" if pred_value >= 0.5 else "DIFFERENT"
                pred_label = f'Prob: {pred_value:.2f}'

            actual = "AUTHENTIC" if row['label'] == 1 else "DIFFERENT"

            draw_bordered(axes[0, i], img1_disp)
            axes[0, i].set_title('Reference')

            draw_bordered(axes[1, i], img2_disp)
            axes[1, i].set_title(f'Query\n{pred_label}\nPred: {prediction}\nReal: {actual}')

        plt.tight_layout()
        if outpath:
            # Save before show(): some backends clear the figure once shown.
            fig.savefig(os.path.join(self.cfg.OUTPUT_DIR, outpath), bbox_inches='tight')
        plt.show()

## 5. Controller

The `ExperimentController` class organizes and scripts three experiment pipelines:
- Single experiment
- performance_analysis_n_shots
- performance_analysis_forgeries_alpha

In [None]:
class ExperimentController:
    def __init__(self, config, model_cls, data_cls, visualizer):
        self.cfg = config
        self.ModelCls = model_cls
        self.DataCls = data_cls
        self.visualizer = visualizer

    @staticmethod
    def compute_forgery_accuracy(model, extra):
        preds_extra, labels_extra = model.get_predictions_and_labels(extra)
        results_extra = model.evaluate_predictions(preds_extra, labels_extra)
        return results_extra['accuracy']

    @staticmethod
    def perc_false_positives(preds, labels):
        positive_mask = (labels == 1)
        if positive_mask.sum() == 0:
            return 0.0
        false_positives = ((labels == 1) & (preds == 0)).sum()
        return 100.0 * false_positives / positive_mask.sum()

    @staticmethod
    def perc_false_negatives(preds, labels):
        negative_mask = (labels == 0)
        if negative_mask.sum() == 0:
            return 0.0
        false_negatives = ((labels == 0) & (preds == 1)).sum()
        return 100.0 * false_negatives / negative_mask.sum()

    def single_experiment(self, blur_kernel_size=3, binarize=True, center=True):
        dm = self.DataCls(self.cfg, blur_kernel_size, binarize, center)
        train, val, test, extra = dm.prepare_datasets()
        model = self.ModelCls(self.cfg)
        model.build()
        history = model.train(train[0], train[1], val[0], val[1], self.cfg.EPOCHS)
        self.visualizer.plot_history(history)

        preds, labels = model.get_predictions_and_labels(test)
        results = model.evaluate_predictions(preds, labels)
        forgery_acc = self.compute_forgery_accuracy(model, extra)
        preds_extra, labels_extra = model.get_predictions_and_labels(extra)
        predictions_bin = model.strategy.evaluate_predictions(preds_extra, labels_extra)['predictions']
        perc_fp = self.perc_false_positives(predictions_bin, labels_extra)
        perc_fn = self.perc_false_negatives(predictions_bin, labels_extra)
        
        print(f"\nTest accuracy: {results['accuracy']:.3f} (AUC: {results['auc']:.3f})")
        print(f"\nForgery accuracy: {forgery_acc:.3f}")
        print(f"Percentage False Positive: {perc_fp:.2f}%")
        print(f"Percentage False Negative: {perc_fn:.2f}%")

        self.visualizer.dm = dm
        self.visualizer.plot_test_examples(model=model,test_df=dm.test_df,evaluation_results=results,outpath="test_examples.png",num_examples=4)
        self.visualizer.plot_test_examples(model=model,test_df=dm.extra_df,evaluation_results=results,outpath="extra_examples.png",num_examples=4)
        # return results, forgery_acc

    def performance_analysis_n_shots(self, shots_list, blur_kernel_size=3, binarize=True, center=True):
        """Train one model per N_SHOTS value and plot accuracy and error rates.

        Args:
            shots_list: N_SHOTS values to evaluate, in plotting order.
            blur_kernel_size: Side of the Gaussian blur kernel.
            binarize: Whether images are binarized.
            center: Whether signatures are centred.
        """
        # Settings of the base config that must survive the per-run rebuild;
        # without them every run would silently fall back to Config's defaults.
        inherited = {"GIANT_EXTRA_SET": self.cfg.GIANT_EXTRA_SET}
        if hasattr(self.cfg, "ENVIRONMENT"):
            inherited["ENVIRONMENT"] = self.cfg.ENVIRONMENT

        test_accs, forgery_accs = [], []
        perc_fp_list, perc_fn_list = [], []
        for n_shots in shots_list:
            cfg = Config(
                self.cfg.MODEL_TYPE_STRING,
                N_SHOTS=n_shots,
                FORGERIES_ALPHA=self.cfg.FORGERIES_ALPHA,
                **inherited,
            )
            dm = self.DataCls(cfg, blur_kernel_size, binarize, center)
            train, val, test, extra = dm.prepare_datasets()
            model = self.ModelCls(cfg)
            model.build()
            model.train(train[0], train[1], val[0], val[1], cfg.EPOCHS)
            preds, labels = model.get_predictions_and_labels(test)
            results = model.evaluate_predictions(preds, labels)

            # Predict and evaluate the extra set once; the same outcome gives
            # both the accuracy and the binary predictions.
            preds_extra, labels_extra = model.get_predictions_and_labels(extra)
            results_extra = model.evaluate_predictions(preds_extra, labels_extra)
            forgery_acc = results_extra['accuracy']
            predictions_bin = results_extra['predictions']
            perc_fp = self.perc_false_positives(predictions_bin, labels_extra)
            perc_fn = self.perc_false_negatives(predictions_bin, labels_extra)
            perc_fp_list.append(perc_fp)
            perc_fn_list.append(perc_fn)

            test_accs.append(results['accuracy'])
            forgery_accs.append(forgery_acc)
            print(f"[n_shots={n_shots}] Test Acc: {results['accuracy']:.3f}\tForgery Acc: {forgery_acc:.3f}\tFP%: {perc_fp:.2f}\tFN%: {perc_fn:.2f}")

        # Plot accuracy
        self.visualizer.plot_accuracy_vs_shots(shots_list, test_accs, forgery_accs)
        # Plot FP/FN percentage
        plt.figure(figsize=(8,5))
        plt.plot(shots_list, perc_fp_list, '-o', label='False Positive (%)', color='#6A5ACD')
        plt.plot(shots_list, perc_fn_list, '-s', label='False Negative (%)', color='#228B22')
        plt.xlabel('N_SHOTS')
        plt.ylabel('Error Percentage (%)')
        plt.title('False Positive and False Negative (%) vs N_SHOTS')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    def performance_analysis_forgeries_alpha(self, alpha_list, blur_kernel_size=3, binarize=True, center=True):
        """Train one model per FORGERIES_ALPHA value and plot accuracy and error rates.

        Args:
            alpha_list: FORGERIES_ALPHA values to evaluate, in plotting order.
            blur_kernel_size: Side of the Gaussian blur kernel.
            binarize: Whether images are binarized.
            center: Whether signatures are centred.
        """
        # Settings of the base config that must survive the per-run rebuild;
        # without them every run would silently fall back to Config's defaults.
        inherited = {"GIANT_EXTRA_SET": self.cfg.GIANT_EXTRA_SET}
        if hasattr(self.cfg, "ENVIRONMENT"):
            inherited["ENVIRONMENT"] = self.cfg.ENVIRONMENT

        test_accs, forgery_accs = [], []
        perc_fp_list, perc_fn_list = [], []
        for alpha in alpha_list:
            cfg = Config(
                self.cfg.MODEL_TYPE_STRING,
                N_SHOTS=self.cfg.N_SHOTS,
                FORGERIES_ALPHA=alpha,
                **inherited,
            )
            dm = self.DataCls(cfg, blur_kernel_size, binarize, center)
            train, val, test, extra = dm.prepare_datasets()
            model = self.ModelCls(cfg)
            model.build()
            model.train(train[0], train[1], val[0], val[1], cfg.EPOCHS)
            preds, labels = model.get_predictions_and_labels(test)
            results = model.evaluate_predictions(preds, labels)

            # Predict and evaluate the extra set once; the same outcome gives
            # both the accuracy and the binary predictions.
            preds_extra, labels_extra = model.get_predictions_and_labels(extra)
            results_extra = model.evaluate_predictions(preds_extra, labels_extra)
            forgery_acc = results_extra['accuracy']
            predictions_bin = results_extra['predictions']
            perc_fp = self.perc_false_positives(predictions_bin, labels_extra)
            perc_fn = self.perc_false_negatives(predictions_bin, labels_extra)
            perc_fp_list.append(perc_fp)
            perc_fn_list.append(perc_fn)

            test_accs.append(results['accuracy'])
            forgery_accs.append(forgery_acc)
            print(f"[alpha={alpha}] Test Acc: {results['accuracy']:.3f}\tForgery Acc: {forgery_acc:.3f}\tFP%: {perc_fp:.2f}\tFN%: {perc_fn:.2f}")

        # Plot accuracy
        self.visualizer.plot_accuracy_vs_alpha(alpha_list, test_accs, forgery_accs)
        # Plot FP/FN percentage
        plt.figure(figsize=(8,5))
        plt.plot(alpha_list, perc_fp_list, '-o', label='False Positive (%)', color='#6A5ACD')
        plt.plot(alpha_list, perc_fn_list, '-s', label='False Negative (%)', color='#228B22')
        plt.xlabel('FORGERIES_ALPHA')
        plt.ylabel('Error Percentage (%)')
        plt.title('False Positive and False Negative (%) vs FORGERIES_ALPHA')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

# Main

### Configuration

In [None]:
# Base configuration for the experiments below: Euclidean distance model,
# 9 shots per author, FORGERIES_ALPHA of 0.75, Kaggle paths.
cfg = Config(model_type="distance_euclidean", N_SHOTS=9, FORGERIES_ALPHA=0.75, ENVIRONMENT="kaggle")
visualizer = Visualizer(cfg)
# The controller receives the model and dataset classes (not instances): it builds fresh ones per run.
controller = ExperimentController(cfg, SiameseModel, DatasetManager, visualizer)

### Performing the experiments

In [None]:
# Train and evaluate one model with the base configuration.
controller.single_experiment(blur_kernel_size=13, binarize=True, center=True)

In [None]:
# Sweep FORGERIES_ALPHA while keeping the base N_SHOTS; one model is trained per value.
controller.performance_analysis_forgeries_alpha([0, 0.1, 0.25, 0.5, 0.6, 0.75, 1.0], blur_kernel_size=13, binarize=True, center=True)

In [None]:
# Sweep N_SHOTS while keeping the base FORGERIES_ALPHA; one model is trained per value.
controller.performance_analysis_n_shots([1,2,3,4,5,6,7,8,9,10], blur_kernel_size=13, binarize=True, center=True)