In [None]:
import os
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split

from torch_geometric.nn import global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torch_geometric.nn import GCNConv, VGAE

import os
import kagglehub
from kagglehub import KaggleDatasetAdapter

import pandas as pd

from tqdm import tqdm
from tqdm.contrib import tmap
from tqdm.contrib.concurrent import process_map

from torchvision import transforms

from concurrent.futures import ProcessPoolExecutor

from lib.lib import SiameseSignatureDataset

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import roc_curve, auc, precision_recall_fscore_support
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Data Preparation

## Prepare the data from the Kaggle dataset `mallapraveen/signature-matching`
## and build the dataset from the signer list in `data.csv`

In [None]:
# Table read from data.csv; it is handed to SiameseSignatureDataset as `signer_folders`.
df = pd.read_csv(filepath_or_buffer='data.csv')

def dataset_path():
    """Fetch the signature dataset through kagglehub and locate its image root.

    Returns:
        str: the ``custom/full`` sub-directory inside the path returned by
        ``kagglehub.dataset_download``.
    """
    path = kagglehub.dataset_download("mallapraveen/signature-matching")
    # Join the components one by one. A literal 'custom\\full' only resolves on
    # Windows; on other systems the backslash becomes part of a directory name.
    return os.path.join(path, 'custom', 'full')

def transform(**kwargs):
    """Assemble the image preprocessing pipeline.

    Required keyword arguments (a missing one raises ``KeyError``):
        num_output_channels: forwarded to ``transforms.Grayscale``.
        resize: forwarded to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` applying Grayscale, then Resize, then ToTensor.
    """
    channels = kwargs['num_output_channels']
    target_size = kwargs['resize']

    pipeline = [
        transforms.Grayscale(num_output_channels=channels),
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline)
    
# Build the dataset: images under the downloaded root, the table from data.csv,
# and a single-channel 32x32 tensor transform.
signature_root = dataset_path()
signature_transform = transform(num_output_channels=1, resize=(32, 32))

dataset = SiameseSignatureDataset(
    root_dir=signature_root,
    signer_folders=df,
    transform=signature_transform,
)

## split the data
### train dataset & validation dataset

In [None]:
# 80/20 train/validation split; the seeded generator keeps the partition
# identical between runs.
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size

split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_rng
)

print(f"Dataset sizes - Train: {train_size}, Validation: {val_size}")

In [None]:
# Peek at the first training sample to see what the dataset yields.
train_dataset[0]

## load the data using dataloader

In [None]:
# Training batches are reshuffled each epoch; validation keeps a fixed order
# and uses a larger batch size.
train_loader_options = dict(batch_size=32, shuffle=True, num_workers=4)
val_loader_options = dict(batch_size=256, shuffle=False, num_workers=4)

train_loader = DataLoader(train_dataset, **train_loader_options)
val_loader = DataLoader(val_dataset, **val_loader_options)

In [None]:
# Pull a single batch to check what the loader produces.
next(iter(train_loader))

# Model Preparation

In [None]:
class GNNEncoder(torch.nn.Module):
    """Graph-convolutional encoder with a shared first layer and two output heads.

    Args:
        in_channels: size of each input node feature vector.
        hidden_channels: width of the shared hidden layer.
        latent_dim: width of each output head.

    NOTE(review): the second head is named ``logvar`` here; when this module is
    wrapped in torch_geometric's ``VGAE`` that output appears to be consumed as
    a log standard deviation -- confirm the intended meaning.
    """

    def __init__(self, in_channels, hidden_channels, latent_dim):
        super().__init__()
        # Shared layer followed by one head per distribution parameter.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, latent_dim)
        self.conv_logvar = GCNConv(hidden_channels, latent_dim)

    def forward(self, x, edge_index):
        """Return ``(mu, logvar)`` computed per node from ``x`` and ``edge_index``."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv_mu(hidden, edge_index), self.conv_logvar(hidden, edge_index)

In [None]:
class SiameseNetwork(nn.Module):
    """Siamese classifier over pairs of graphs.

    Both graphs go through the same encoder, their node embeddings are
    mean-pooled per graph, and the pooled pair is classified by an MLP head
    that outputs two logits.

    Args:
        fe_model: encoder module; called as ``fe_model(x, edge_index)`` and
            expected to return a pair whose first item is the node embedding.
        latent_dim: width of the node embedding produced by ``fe_model``.
    """

    def __init__(self, fe_model, latent_dim):
        super(SiameseNetwork, self).__init__()
        self.encoder = fe_model
        self.embedding_dim = latent_dim

        # The head sees [emb1, emb2, |emb1 - emb2|, emb1 * emb2], hence 4x the width.
        self.projector = nn.Sequential(
            nn.Linear(self.embedding_dim * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),

            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),

            nn.Linear(64, 2)
        )

    def forward_once(self, x, edge_index, batch):
        """Encode one batch of graphs and mean-pool node embeddings per graph.

        Args:
            x: node feature matrix.
            edge_index: edge list of the batched graphs.
            batch: vector assigning each node (row of ``x``) to its graph.

        Returns:
            One pooled embedding per graph.
        """
        mu, _ = self.encoder(x, edge_index)
        graph_emb = global_mean_pool(mu, batch)
        return graph_emb

    def forward(self, x1, x2,
               edge_index1, edge_index2,
               batch, batch2=None):
        """Return two logits for every pair of graphs.

        Args:
            x1, x2: node features of the first and second graph of each pair.
            edge_index1, edge_index2: the matching edge lists.
            batch: node-to-graph assignment for ``x1``.
            batch2: node-to-graph assignment for ``x2``. Defaults to ``batch``,
                which is only correct when both sides have the same node
                layout; pass it explicitly whenever they can differ.

        Returns:
            Tensor of shape ``[num_pairs, 2]`` with the classifier logits.
        """
        if batch2 is None:
            batch2 = batch

        emb1 = self.forward_once(x1, edge_index1, batch)
        # Pool the second graph with its own assignment vector, not the first graph's.
        emb2 = self.forward_once(x2, edge_index2, batch2)

        # Combine embeddings (abs difference works well for verification)
        combined = torch.cat([
            emb1,
            emb2,
            torch.abs(emb1 - emb2),
            emb1 * emb2
        ], dim=1)

        # Predict same/forged
        out = self.projector(combined)
        return out

# Hyperparameters

In [None]:
# Optimisation settings. `w_d` is presumably a weight-decay coefficient --
# TODO confirm where it is meant to be applied.
epochs = 50
learning_rate = 1e-3
w_d = 1e-5

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

# Training Preparation

In [None]:
# Read the node-feature width from one real batch so the encoder input size
# matches the data.
first_batch = next(iter(train_loader))
img1, _, _ = first_batch

input_dim = img1.x.size(1)
hidden_dim, latent_dim = 64, 128

In [None]:
# Restore the pre-trained VGAE weights from 'VGAE_Model.pt'.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load('VGAE_Model.pt', map_location=device)

pretrained_encoder = GNNEncoder(
    in_channels=input_dim,
    hidden_channels=hidden_dim,
    latent_dim=latent_dim,
)
vgae = VGAE(pretrained_encoder).to(device)
vgae.load_state_dict(checkpoint)
vgae.eval()

In [None]:
# Siamese classifier on top of the pre-trained VGAE. Reuse the shared
# `latent_dim` so the head width always matches the encoder output.
model = SiameseNetwork(vgae, latent_dim=latent_dim).to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): the `learning_rate` and `w_d` hyperparameters are not used
# here; the step size is fixed at 1e-4 and no weight decay is applied --
# confirm which values are intended.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Training step

In [None]:
def train_step(model, dataloader, criterion, optimizer, device):
    """Run one training epoch.

    Args:
        model: network called as ``model(x1.x, x2.x, x1.edge_index,
            x2.edge_index, x1.batch)`` and returning logits of shape [pairs, 2].
        dataloader: iterable yielding ``(x1, x2, label)`` batches.
        criterion: loss taking ``(logits, label)`` and returning the batch mean.
        optimizer: optimizer updating ``model``'s parameters.
        device: device the batch is moved to.

    Returns:
        tuple: ``(avg_loss, accuracy)`` averaged over all pairs seen.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for x1, x2, label in tqdm(dataloader, desc="Training", leave=False):
        # Moving the batch objects also moves their x / edge_index / batch tensors.
        x1, x2, label = x1.to(device), x2.to(device), label.to(device)

        # Forward.
        # NOTE(review): only x1.batch is forwarded, so pooling of the second
        # graph assumes x2 has the same node-to-graph layout -- confirm.
        output = model(x1.x,
                       x2.x,
                       x1.edge_index,
                       x2.edge_index,
                       x1.batch)  # logits shape: [batch, 2]

        loss = criterion(output, label)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics. Weight the batch-mean loss by the number of pairs. On a
        # torch_geometric batch, x1.size(0) is the node count, which inflated
        # the reported average because the denominator counts pairs.
        num_pairs = label.size(0)
        total_loss += loss.item() * num_pairs
        preds = torch.argmax(output, dim=1)
        correct += (preds == label).sum().item()
        total += num_pairs

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

## Validation step

In [None]:
def val_step(model, dataloader, criterion, device):
    """Evaluate the model on a validation loader without updating it.

    Args:
        model: network called as ``model(x1.x, x2.x, x1.edge_index,
            x2.edge_index, x1.batch)`` and returning logits of shape [pairs, 2].
        dataloader: iterable yielding ``(x1, x2, label)`` batches.
        criterion: loss taking ``(logits, label)`` and returning the batch mean.
        device: device the batch is moved to.

    Returns:
        tuple: ``(avg_loss, accuracy, labels, preds, probs)`` where the last
        three are NumPy arrays with one entry per pair; ``probs`` is the
        softmax probability of class 1.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for x1, x2, label in tqdm(dataloader, desc="Validating", leave=False):
            # Moving the batch objects also moves their x / edge_index / batch tensors.
            x1, x2, label = x1.to(device), x2.to(device), label.to(device)
            # NOTE(review): only x1.batch is forwarded, so pooling of the second
            # graph assumes x2 has the same node-to-graph layout -- confirm.
            output = model(x1.x,
                           x2.x,
                           x1.edge_index,
                           x2.edge_index,
                           x1.batch)

            loss = criterion(output, label)
            # Weight the batch-mean loss by the number of pairs. On a
            # torch_geometric batch, x1.size(0) is the node count, which
            # inflated the reported average.
            num_pairs = label.size(0)
            total_loss += loss.item() * num_pairs

            probs = torch.softmax(output, dim=1)[:, 1]  # Probability of class 1 ("genuine")
            preds = torch.argmax(output, dim=1)

            all_labels.extend(label.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

            correct += (preds == label).sum().item()
            total += num_pairs

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy, np.array(all_labels), np.array(all_preds), np.array(all_probs)


# Training Phase

In [None]:
# Train with per-epoch validation, TensorBoard logging, best-by-AUC
# checkpointing and early stopping on validation loss.
writer = SummaryWriter(log_dir="runs/siamese_signature_experiment")

patience = 10
best_auc = 0.0
best_val_loss = float('inf')  # start with infinity
wait = 0  # counter for early stopping

# try/finally guarantees the writer is flushed and closed even when training
# raises or the cell is interrupted.
try:
    for epoch in range(epochs):
        train_loss, train_acc = train_step(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, y_true, y_pred, y_prob = val_step(model, val_loader, criterion, device)

        # --- Confusion Matrix ---
        cm = confusion_matrix(y_true, y_pred)
        disp = ConfusionMatrixDisplay(confusion_matrix=cm)
        fig_cm, ax_cm = plt.subplots(figsize=(4, 4))
        disp.plot(ax=ax_cm, cmap="Blues", colorbar=False)
        writer.add_figure("ConfusionMatrix/val", fig_cm, global_step=epoch)
        plt.close(fig_cm)

        # --- ROC Curve and AUC ---
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)
        fig_roc, ax_roc = plt.subplots()
        ax_roc.plot(fpr, tpr, color='blue', lw=2, label=f"AUC = {roc_auc:.3f}")
        ax_roc.plot([0, 1], [0, 1], color='gray', linestyle='--')
        ax_roc.set_xlabel("False Positive Rate")
        ax_roc.set_ylabel("True Positive Rate")
        ax_roc.legend(loc="lower right")
        writer.add_figure("ROC/val", fig_roc, global_step=epoch)
        plt.close(fig_roc)

        # --- Precision, Recall, F1 ---
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average="binary"
        )

        # Log scalar metrics
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)
        writer.add_scalar("AUC/val", roc_auc, epoch)
        writer.add_scalar("Precision/val", precision, epoch)
        writer.add_scalar("Recall/val", recall, epoch)
        writer.add_scalar("F1/val", f1, epoch)

        print(f"Epoch [{epoch+1}/{epochs}] "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} "
              f"| Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} "
              f"| AUC: {roc_auc:.4f} | F1: {f1:.4f}")

        # --- Save best model by AUC ---
        if roc_auc > best_auc:
            best_auc = roc_auc
            torch.save(model.state_dict(), "best_siamese_signature.pth")

        # --- Early stopping based on validation loss ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            wait = 0  # reset counter if improved
            # NOTE(review): this stores the Siamese model's weights under a
            # file name that says "vgae" -- confirm the intended name.
            torch.save(model.state_dict(), os.path.join(writer.log_dir, "best_vgae_model.pth"))
        else:
            wait += 1
            if wait >= patience:
                print(f"⏹️ Early stopping triggered at epoch {epoch+1}!")
                break
finally:
    writer.close()
