In [1]:
import os
import gc

import torch
import torch.nn as nn
from torch.utils.data import random_split

from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torch_geometric.nn import GCNConv, VGAE

import os
import kagglehub
from kagglehub import KaggleDatasetAdapter

import pandas as pd

from tqdm import tqdm
from tqdm.contrib import tmap
from tqdm.contrib.concurrent import process_map

from torchvision import transforms

from concurrent.futures import ProcessPoolExecutor

from lib.lib import SiameseSignatureDataset, image_to_graph

In [2]:
# Hyperparameters
learning_rate = 1e-3
w_d = 1e-5
batch_size = 32
epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
df = pd.read_csv('data.csv')

def dataset_path():
    path = kagglehub.dataset_download("mallapraveen/signature-matching")
    return os.path.join(path, 'custom\\full')

def transform(**kwargs):
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=kwargs['num_output_channels']),
        transforms.Resize(kwargs['resize']),
        transforms.ToTensor(),
    ])
    
dataset = SiameseSignatureDataset(
    root_dir=dataset_path(),
    signer_folders=df,
    transform=transform(num_output_channels=1, resize=(150, 150))
)

Loaded 85246 signature images (genuine + forged)


In [5]:
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)
print(f"Dataset sizes - Train: {train_size}, Validation: {val_size}")

Dataset sizes - Train: 68196, Validation: 17050


In [18]:
train_dataset[0]

(Data(x=[1024, 3], edge_index=[2, 3968]),
 Data(x=[1024, 3], edge_index=[2, 3968]),
 0)

In [31]:
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0
)

In [32]:
next(iter(train_loader))

[DataBatch(x=[16384, 3], edge_index=[2, 63488], batch=[16384], ptr=[17]),
 DataBatch(x=[16384, 3], edge_index=[2, 63488], batch=[16384], ptr=[17]),
 tensor([1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0])]

## Above is the data preperation
# Now let's proceed to the creation of the model

In [33]:
class GNNEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, latent_dim):
        super(GNNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, latent_dim)
        self.conv_logvar = GCNConv(hidden_channels, latent_dim)

    def forward(self, x, edge_index):
        # Step 1: Aggregate node features from neighbors
        x = F.relu(self.conv1(x, edge_index))

        # Step 2: Output mean and log variance
        mu = self.conv_mu(x, edge_index, batch)
        logvar = self.conv_logvar(x, edge_index, batch)

        return mu, logvar

In [47]:
class ContrastiveSiameseNetwork(nn.Module):
    """
    Siamese Network using Contrastive Loss
    Better for signature verification with distance-based similarity
    """
    def __init__(self, gnn_vae_model, latent_dim=32, margin=2.0):
        super().__init__()
        
        self.gnn_vae = gnn_vae_model
        self.latent_dim = latent_dim
        self.margin = margin  # Margin for contrastive loss
        
        # Freeze GNN-VAE
        for param in self.gnn_vae.parameters():
            param.requires_grad = False
        
        # Optional: Additional projection layer
        self.projection = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )
    
    def forward_one(self, x, edge_index, batch=None):
        """Extract and project features"""
        with torch.no_grad():
            embedding = self.gnn_vae.encode(x, edge_index, batch)
        
        # Optional projection
        embedding = self.projection(embedding)
        
        return embedding
    
    def forward(self, x1, edge_index1, x2, edge_index2, batch1=None, batch2=None):
        """
        Returns embeddings and Euclidean distance
        """
        emb1 = self.forward_one(x1, edge_index1, batch1)
        emb2 = self.forward_one(x2, edge_index2, batch2)
        
        # Compute Euclidean distance
        distance = F.pairwise_distance(emb1, emb2)
        
        return distance, emb1, emb2
    
    def predict(self, x1, edge_index1, x2, edge_index2, threshold=1.0):
        """
        Predict based on distance threshold
        """
        self.eval()
        with torch.no_grad():
            distance, _, _ = self.forward(x1, edge_index1, x2, edge_index2)
            is_same_person = distance < threshold
            return is_same_person.item(), distance.item()

In [48]:
def contrastive_loss(distance, label, margin=2.0):
    """
    Contrastive loss for Siamese networks
    
    Args:
        distance: Euclidean distance between embeddings
        label: 1 if same person, 0 if different
        margin: Margin for negative pairs
    
    Returns:
        loss: Contrastive loss value
    """
    loss = torch.mean(
        label * torch.pow(distance, 2) +  # Same person: minimize distance
        (1 - label) * torch.pow(torch.clamp(margin - distance, min=0.0), 2)  # Different: maximize distance
    )
    return loss

In [49]:
img1, _, _ = next(iter(train_loader))

input_dim = img1.x.shape[1]
hidden_dim = 64
latent_dim = 128
# epochs = 500
epochs = 5

In [50]:
# Load your trained GNN-VAE
checkpoint = torch.load('VGAE_Model.pt')
vgae = VGAE(GNNEncoder(in_channels=input_dim, hidden_channels=hidden_dim, latent_dim=latent_dim))
vgae.load_state_dict(checkpoint)
vgae.eval()

VGAE(
  (encoder): GNNEncoder(
    (conv1): GCNConv(3, 64)
    (conv_mu): GCNConv(64, 128)
    (conv_logvar): GCNConv(64, 128)
  )
  (decoder): InnerProductDecoder()
)

In [51]:
contrastive_model = ContrastiveSiameseNetwork(
    gnn_vae_model=vgae,
    latent_dim=128,
    margin=2.0
).to(device)

In [52]:
def train_epoch(model, optimizer, train_loader, criterion, device):
    """Train for one epoch"""
    model.train()
    
    total_loss = 0
    all_predictions = []
    all_labels = []
    
    for batch1, batch2, labels in train_loader:
        batch1 = batch1.to(device)
        batch2 = batch2.to(device)
        labels = labels.to(device).unsqueeze(1)
        
        optimizer.zero_grad()
        
        # Forward pass
        similarity, emb1, emb2 = model(
            batch1.x, batch1.edge_index,
            batch2.x, batch2.edge_index,
            batch1.batch, batch2.batch
        )
        
        # Compute loss
        loss = criterion(similarity, labels)
        
        loss.backward()
        optimizer.step()
        
        # Collect predictions
        predictions = (similarity > 0.5).float()
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        
        total_loss += loss.item()
    
    # Compute metrics
    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(all_labels, all_predictions)
    
    return avg_loss, accuracy


def validate_epoch(model, val_loader, criterion, device):
    """Validate for one epoch"""
    model.eval()
    
    total_loss = 0
    all_predictions = []
    all_labels = []
    all_similarities = []
    
    with torch.no_grad():
        for batch1, batch2, labels in val_loader:
            batch1 = batch1.to(device)
            batch2 = batch2.to(device)
            labels = labels.to(device).unsqueeze(1)
            
            # Forward pass
            similarity, emb1, emb2 = model(
                batch1.x, batch1.edge_index,
                batch2.x, batch2.edge_index,
                batch1.batch, batch2.batch
            )
            
            # Compute loss
            loss = criterion(similarity, labels)
            
            # Collect predictions
            predictions = (similarity > 0.5).float()
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_similarities.extend(similarity.cpu().numpy())
            
            total_loss += loss.item()
    
    # Compute metrics
    avg_loss = total_loss / len(val_loader)
    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions, zero_division=0)
    recall = recall_score(all_labels, all_predictions, zero_division=0)
    f1 = f1_score(all_labels, all_predictions, zero_division=0)
    
    # AUC-ROC
    try:
        auc = roc_auc_score(all_labels, all_similarities)
    except:
        auc = 0.0
    
    metrics = {
        'val_loss': avg_loss,
        'val_accuracy': accuracy,
        'val_precision': precision,
        'val_recall': recall,
        'val_f1': f1,
        'val_auc': auc
    }
    
    return metrics

In [53]:
criterion = contrastive_loss

# Optimizer and loss
optimizer = torch.optim.Adam(
    contrastive_model.parameters(),
    lr=0.0001,
    weight_decay=1e-5
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    # verbose=True
)

# Training settings
epochs = 50
best_val_loss = float('inf')
best_val_f1 = 0.0
patience = 10
patience_counter = 0

# TensorBoard (optional)
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
writer = SummaryWriter(f'runs/siamese_{timestamp}')

print("\n" + "="*80)
print("TRAINING SIAMESE NETWORK")
print("="*80)

# Training loop
for epoch in range(1, epochs + 1):
    # Training
    train_loss, train_accuracy = train_epoch(
        contrastive_model,
        optimizer,
        train_loader,
        criterion,
        device
    )
    
    # Validation
    val_metrics = validate_epoch(
        contrastive_model,
        val_loader,
        criterion,
        device
    )
    
    # Learning rate scheduling
    scheduler.step(val_metrics['val_loss'])
    
    # Logging to tensorboard
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/val', val_metrics['val_loss'], epoch)
    writer.add_scalar('Accuracy/train', train_accuracy, epoch)
    writer.add_scalar('Accuracy/val', val_metrics['val_accuracy'], epoch)
    writer.add_scalar('Metrics/precision', val_metrics['val_precision'], epoch)
    writer.add_scalar('Metrics/recall', val_metrics['val_recall'], epoch)
    writer.add_scalar('Metrics/f1', val_metrics['val_f1'], epoch)
    writer.add_scalar('Metrics/auc', val_metrics['val_auc'], epoch)
    
    # Print progress
    if epoch % 5 == 0 or epoch == 1:
        print(f"\nEpoch {epoch:03d}/{epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.4f}")
        print(f"  Val Loss:   {val_metrics['val_loss']:.4f} | Val Acc:  {val_metrics['val_accuracy']:.4f}")
        print(f"  Precision: {val_metrics['val_precision']:.4f} | Recall: {val_metrics['val_recall']:.4f}")
        print(f"  F1 Score:  {val_metrics['val_f1']:.4f} | AUC:    {val_metrics['val_auc']:.4f}")
    
    # Save best model based on validation F1 score
    if val_metrics['val_f1'] > best_val_f1:
        best_val_f1 = val_metrics['val_f1']
        best_val_loss = val_metrics['val_loss']
        patience_counter = 0
        
        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': siamese_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_metrics['val_loss'],
            'val_f1': val_metrics['val_f1'],
            'val_metrics': val_metrics
        }, f'best_siamese_{timestamp}.pt')
        
        print(f"  ✓ Best model saved! F1: {best_val_f1:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch}")
            break

writer.close()

print("\n" + "="*80)
print("✓ TRAINING COMPLETE!")
print("="*80)
print(f"Best Validation Loss: {best_val_loss:.4f}")
print(f"Best Validation F1 Score: {best_val_f1:.4f}")


TRAINING SIAMESE NETWORK


TypeError: GNNEncoder.forward() takes 3 positional arguments but 4 were given