In [1]:
import os
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split

from torch_geometric.nn import global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torch_geometric.nn import GCNConv, VGAE

import os
import kagglehub
from kagglehub import KaggleDatasetAdapter

import pandas as pd

from tqdm import tqdm
from tqdm.contrib import tmap
from tqdm.contrib.concurrent import process_map

from torchvision import transforms

from concurrent.futures import ProcessPoolExecutor

from lib.lib import SiameseSignatureDataset

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import numpy as np

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

In [2]:
# Hyperparameters
learning_rate = 1e-3
w_d = 1e-5
epochs = 50
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [3]:
df = pd.read_csv('data.csv')

def dataset_path():
    path = kagglehub.dataset_download("mallapraveen/signature-matching")
    return os.path.join(path, 'custom\\full')

def transform(**kwargs):
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=kwargs['num_output_channels']),
        transforms.Resize(kwargs['resize']),
        transforms.ToTensor(),
    ])
    
dataset = SiameseSignatureDataset(
    root_dir=dataset_path(),
    signer_folders=df,
    transform=transform(num_output_channels=1, resize=(32, 32)
))

Loaded 85246 signature images (genuine + forged)


In [4]:
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)
print(f"Dataset sizes - Train: {train_size}, Validation: {val_size}")

Dataset sizes - Train: 68196, Validation: 17050


In [5]:
train_dataset[0]

(Data(x=[1024, 3], edge_index=[2, 3968]),
 Data(x=[1024, 3], edge_index=[2, 3968]),
 0)

In [6]:
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4
)

In [7]:
next(iter(train_loader))

[DataBatch(x=[32768, 3], edge_index=[2, 126976], batch=[32768], ptr=[33]),
 DataBatch(x=[32768, 3], edge_index=[2, 126976], batch=[32768], ptr=[33]),
 tensor([1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0,
         0, 1, 1, 0, 1, 0, 1, 0])]

## Above is the data preperation
# Now let's proceed to the creation of the model

In [8]:
class GNNEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, latent_dim):
        super(GNNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, latent_dim)
        self.conv_logvar = GCNConv(hidden_channels, latent_dim)

    def forward(self, x, edge_index):
        # Step 1: Aggregate node features from neighbors
        x = F.relu(self.conv1(x, edge_index))

        # Step 2: Output mean and log variance
        mu = self.conv_mu(x, edge_index)
        logvar = self.conv_logvar(x, edge_index)

        return mu, logvar

In [9]:
class SiameseModel(nn.Module):
    def __init__(self, model, latent_dim, output_dim=2):
        super(SiameseModel, self).__init__()
        
        # GNN Encoder (shared between both inputs)
        self.encoder = model
        
        # Fully connected layers for final embedding
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            
            nn.Linear(64, output_dim)
        )
    
    def reparameterize(self, mu, logvar):
        """Reparameterization trick for VAE"""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu
    
    def forward_once(self, x, edge_index, batch):
        """Forward pass for one graph"""
        # Get mu and logvar from encoder
        mu, logvar = self.encoder(x, edge_index)
    
        # Reparameterize to get latent representation
        z = self.reparameterize(mu, logvar)
    
        # ✅ Pool node embeddings into per-graph embeddings
        z_graph = global_mean_pool(z, batch)  # [num_graphs, latent_dim]
    
        # Pass through fully connected layers
        output = self.fc(z_graph)  # [num_graphs, output_dim]
    
        return output, mu, logvar
    
    def forward(self, x1, edge_index1, batch1, x2, edge_index2, batch2):
        output1, mu1, logvar1 = self.forward_once(x1, edge_index1, batch1)
        output2, mu2, logvar2 = self.forward_once(x2, edge_index2, batch2)
        return output1, output2, (mu1, logvar1), (mu2, logvar2)

In [10]:
class ContrastiveLoss(nn.Module):
    "Contrastive loss function"

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + (label)
            * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )

        return loss_contrastive

In [11]:
img1, _, _ = next(iter(train_loader))
input_dim = img1.x.shape[1]
hidden_dim = 64
latent_dim = 128

In [12]:
# Load your trained GNN-VAE
checkpoint = torch.load('VGAE_Model.pt', map_location=device)
vgae = VGAE(GNNEncoder(in_channels=input_dim, hidden_channels=hidden_dim, latent_dim=latent_dim)).to(device)
vgae.load_state_dict(checkpoint)
vgae.eval()

VGAE(
  (encoder): GNNEncoder(
    (conv1): GCNConv(3, 64)
    (conv_mu): GCNConv(64, 128)
    (conv_logvar): GCNConv(64, 128)
  )
  (decoder): InnerProductDecoder()
)

In [13]:
contrastive_model = SiameseModel(
    model=vgae,
    latent_dim=128,
).to(device)

In [14]:
def train_epoch(model, optimizer, train_loader, criterion, device):
    """
    Train for one epoch.
    
    Args:
        model: Siamese network model
        optimizer: Optimizer
        train_loader: DataLoader providing (graph1, graph2, label) tuples
        criterion: ContrastiveLoss
        device: torch device
    
    Returns:
        avg_loss: Average training loss
        accuracy: Training accuracy
    """
    model.train()
    total_loss = 0
    all_predictions = []
    all_labels = []
    
    for batch_idx, (graph1, graph2, labels) in tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        desc="Training",
        leave=False
    ):
        # Move data to device
        x1 = graph1.x.to(device)
        edge_index1 = graph1.edge_index.to(device)
        x2 = graph2.x.to(device)
        edge_index2 = graph2.edge_index.to(device)
        labels = labels.to(device).float()
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        output1, output2, (mu1, logvar1), (mu2, logvar2) = model(
            graph1.x, graph1.edge_index, graph1.batch.to(device),
            graph2.x, graph2.edge_index, graph2.batch.to(device)
        )

        print(f"output1: {output1.shape}, graph1.batch: {graph1.batch.shape}")
        break  # just to inspect one batch
        
        # Compute contrastive loss
        loss = criterion(output1, output2, labels)
        
        # Optionally add KL divergence loss for VGAE
        kl_weight = 0.001
        kl_loss1 = -0.5 * torch.sum(1 + logvar1 - mu1.pow(2) - logvar1.exp())
        kl_loss2 = -0.5 * torch.sum(1 + logvar2 - mu2.pow(2) - logvar2.exp())
        kl_loss = (kl_loss1 + kl_loss2) / 2
        
        total_loss_with_kl = loss + kl_weight * kl_loss
        
        # Backward pass
        total_loss_with_kl.backward()
        optimizer.step()
        
        # Calculate predictions
        # euclidean_distance = F.pairwise_distance(output1, output2)
        # predictions = (euclidean_distance > 1.0).float()  # Threshold at 1.0

        graph_emb1 = output1
        graph_emb2 = output2
        
        euclidean_distance = F.pairwise_distance(graph_emb1, graph_emb2)
        predictions = (euclidean_distance > 1.0).float()
        
        # Store for metrics
        all_predictions.extend(predictions.cpu().detach().numpy())
        all_labels.extend(labels.cpu().detach().numpy())
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(all_labels, all_predictions)
    
    return avg_loss, accuracy


def validate_epoch(model, val_loader, criterion, device):
    """
    Validate the model.
    
    Args:
        model: Siamese network model
        val_loader: Validation DataLoader
        criterion: ContrastiveLoss
        device: torch device
    
    Returns:
        Dictionary with validation metrics
    """
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []
    all_distances = []
    
    with torch.no_grad():
        for batch_idx, (graph1, graph2, labels) in tqdm(
            enumerate(val_loader),
            total=len(val_loader),
            desc="Validating",
            leave=False
        ):
            # Move data to device
            x1 = graph1.x.to(device)
            edge_index1 = graph1.edge_index.to(device)
            x2 = graph2.x.to(device)
            edge_index2 = graph2.edge_index.to(device)
            labels = labels.to(device).float()
            
            # Forward pass
            output1, output2, (mu1, logvar1), (mu2, logvar2) = model(
                graph1.x, graph1.edge_index, graph1.batch.to(device),
                graph2.x, graph2.edge_index, graph2.batch.to(device)
            )
            
            # Compute loss
            loss = criterion(output1, output2, labels)
            total_loss += loss.item()
            
            # Calculate predictions
            graph_emb1 = output1
            graph_emb2 = output2
            
            euclidean_distance = F.pairwise_distance(graph_emb1, graph_emb2)
            predictions = (euclidean_distance > 1.0).float()
            
            # Store for metrics
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_distances.extend(euclidean_distance.cpu().numpy())
    
    # Calculate metrics
    avg_loss = total_loss / len(val_loader)
    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions, zero_division=0)
    recall = recall_score(all_labels, all_predictions, zero_division=0)
    f1 = f1_score(all_labels, all_predictions, zero_division=0)
    
    # For AUC, use distances as scores (inverted since lower distance = more similar)
    # Convert labels: 0 (similar) -> 1, 1 (dissimilar) -> 0 for distance-based scoring
    inverted_labels = 1 - np.array(all_labels)
    inverted_distances = -np.array(all_distances)  # Lower distance = higher score
    
    try:
        auc = roc_auc_score(inverted_labels, inverted_distances)
    except:
        auc = 0.0
    
    return {
        'val_loss': avg_loss,
        'val_accuracy': accuracy,
        'val_precision': precision,
        'val_recall': recall,
        'val_f1': f1,
        'val_auc': auc
    }

In [15]:
criterion = ContrastiveLoss(margin=2.0)
    
# Optimizer and loss
optimizer = torch.optim.Adam(
    contrastive_model.parameters(),
    lr=0.0001,
    weight_decay=1e-5
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

# Training settings
epochs = 50
best_val_loss = float('inf')
best_val_f1 = 0.0
patience = 10
patience_counter = 0

In [16]:
# TensorBoard setup
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
writer = SummaryWriter(f'runs/siamese_{timestamp}')

print("\n" + "="*80)
print("TRAINING SIAMESE NETWORK")
print("="*80)

# Training loop
for epoch in range(1, epochs + 1):
    # Training
    train_loss, train_accuracy = train_epoch(
        contrastive_model,
        optimizer,
        train_loader,
        criterion,
        device
    )
    
    # Validation
    val_metrics = validate_epoch(
        contrastive_model,
        val_loader,
        criterion,
        device
    )
    
    # Learning rate scheduling
    scheduler.step(val_metrics['val_loss'])
    
    # Logging to tensorboard
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/val', val_metrics['val_loss'], epoch)
    writer.add_scalar('Accuracy/train', train_accuracy, epoch)
    writer.add_scalar('Accuracy/val', val_metrics['val_accuracy'], epoch)
    writer.add_scalar('Metrics/precision', val_metrics['val_precision'], epoch)
    writer.add_scalar('Metrics/recall', val_metrics['val_recall'], epoch)
    writer.add_scalar('Metrics/f1', val_metrics['val_f1'], epoch)
    writer.add_scalar('Metrics/auc', val_metrics['val_auc'], epoch)
    
    # Print progress
    print(f"\nEpoch {epoch:03d}/{epochs}")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.4f}")
    print(f"  Val Loss:   {val_metrics['val_loss']:.4f} | Val Acc:  {val_metrics['val_accuracy']:.4f}")
    print(f"  Precision: {val_metrics['val_precision']:.4f} | Recall: {val_metrics['val_recall']:.4f}")
    print(f"  F1 Score:  {val_metrics['val_f1']:.4f} | AUC:    {val_metrics['val_auc']:.4f}")
    
    # Save best model based on validation F1 score
    if val_metrics['val_f1'] > best_val_f1:
        best_val_f1 = val_metrics['val_f1']
        best_val_loss = val_metrics['val_loss']
        patience_counter = 0
        
        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': contrastive_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_metrics['val_loss'],
            'val_f1': val_metrics['val_f1'],
            'val_metrics': val_metrics
        }, f'best_siamese_{timestamp}.pt')
        
        print(f"  ✓ Best model saved! F1: {best_val_f1:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch}")
            break

writer.close()

print("\n" + "="*80)
print("✓ TRAINING COMPLETE!")
print("="*80)
print(f"Best Validation Loss: {best_val_loss:.4f}")
print(f"Best Validation F1 Score: {best_val_f1:.4f}")


TRAINING SIAMESE NETWORK


                                                                                                                       

RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA_mm)