In [40]:
import torch
from torch import Tensor
from torchvision import transforms
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data 
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from tqdm import tqdm

import os
from PIL import Image
from torch.utils.data import Dataset

import torch.nn.functional as F
import numpy as np
from sklearn.feature_extraction import image
import cv2
from typing import Tuple, Optional, Union

import kagglehub

In [41]:
# Hyperparameters
learning_rate = 1e-3
batch_size = 32  # Changed to match your DataLoader batch_size
epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [42]:
class SignatureGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, embedding_dim):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, embedding_dim)
        
    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor:
        """
        x: Node features [num_nodes, in_channels]
        edge_index: Graph edges [2, num_edges]
        batch: Graph IDs for mini-batch training [num_nodes]
        """
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        
        # Aggregate node embeddings into a graph-level signature embedding
        x = global_mean_pool(x, batch)  # [num_graphs, embedding_dim]
        
        return x


In [43]:
class SignatureDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # collect all signer folders
        signer_folders = sorted(os.listdir(root_dir))

        for folder in signer_folders:
            folder_path = os.path.join(root_dir, folder)
            if os.path.isdir(folder_path):
                for img_name in os.listdir(folder_path):
                    if self._is_image_file(img_name):
                        self.samples.append(os.path.join(folder_path, img_name))

        print(f"Loaded {len(self.samples)} signature images (genuine + forged)")

    def _is_image_file(self, filename):
        valid_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
        return os.path.splitext(filename.lower())[1] in valid_exts

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path = self.samples[idx]
        try:
            image = Image.open(path).convert("L")  # grayscale
            if self.transform:
                image = self.transform(image)
            return image   # only image, no label
        except Exception as e:
            print(f"Error loading {path}: {e}")
            # fallback blank image
            fallback = Image.new("L", (224, 224), 0)
            if self.transform:
                fallback = self.transform(fallback)
            return fallback

In [87]:
def image_to_graph(
    image_tensor: torch.Tensor,
    patch_size: int = 8,
    k_neighbors: int = 8,
    edge_threshold: float = 0.1,
    include_features: bool = True
) -> Data:
    """
    Convert an image to a graph representation with nodes and edges.
    
    Args:
        image_tensor: Input image tensor of shape (C, H, W) or (H, W)
        method: Graph construction method ('grid', 'knn', 'superpixel', 'region')
        patch_size: Size of patches for grid method
        k_neighbors: Number of neighbors for KNN method
        edge_threshold: Threshold for edge creation based on feature similarity
        include_features: Whether to include patch features as node features
        
    Returns:
        PyTorch Geometric Data object with node features and edge indices
    """
    
    return _image_to_grid_graph(image_tensor, patch_size, include_features)

def _image_to_grid_graph(
    image_tensor: torch.Tensor, 
    patch_size: int, 
    include_features: bool
) -> Data:
    """Convert image to grid-based graph where each patch is a node."""
    
    # Handle different input shapes
    if len(image_tensor.shape) == 2:
        image_tensor = image_tensor.unsqueeze(0)  # Add channel dimension
    
    C, H, W = image_tensor.shape
    
    # Create patches
    patches_h = H // patch_size
    patches_w = W // patch_size
    
    # Extract patch features
    node_features = []
    node_positions = []
    
    for i in range(patches_h):
        for j in range(patches_w):
            # Extract patch
            patch = image_tensor[
                :, 
                i * patch_size:(i + 1) * patch_size,
                j * patch_size:(j + 1) * patch_size
            ]
            
            if include_features:
                # Compute patch statistics as features
                mean_val = patch.mean(dim=[1, 2])  # Per channel mean
                std_val = patch.std(dim=[1, 2])    # Per channel std
                max_val = patch.max(dim=2)[0].max(dim=1)[0]  # Per channel max
                min_val = patch.min(dim=2)[0].min(dim=1)[0]  # Per channel min
                
                features = torch.cat([mean_val, std_val, max_val, min_val])
                node_features.append(features)
            
            # Store position
            node_positions.append([i, j])
    
    # Create edges (connect adjacent patches)
    edge_indices = []
    
    for i in range(patches_h):
        for j in range(patches_w):
            current_node = i * patches_w + j
            
            # Connect to neighbors (4-connectivity)
            neighbors = [
                (i-1, j), (i+1, j),  # vertical neighbors
                (i, j-1), (i, j+1)   # horizontal neighbors
            ]
            
            # Add diagonal connections for 8-connectivity
            neighbors.extend([
                (i-1, j-1), (i-1, j+1),
                (i+1, j-1), (i+1, j+1)
            ])
            
            for ni, nj in neighbors:
                if 0 <= ni < patches_h and 0 <= nj < patches_w:
                    neighbor_node = ni * patches_w + nj
                    edge_indices.append([current_node, neighbor_node])
    
    # Convert to tensors
    if include_features:
        x = torch.stack(node_features)
    else:
        x = torch.tensor(node_positions, dtype=torch.float32)
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    pos = torch.tensor(node_positions, dtype=torch.float32)
    
    return Data(x=x, edge_index=edge_index, pos=pos)

In [88]:
def dataset_path():
    path = kagglehub.dataset_download("akashgundu/signature-verification-dataset")
    return os.path.join(path, 'extract')

def transform(**kwargs):
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=kwargs['num_output_channels']),
        transforms.Resize(kwargs['resize']),
        transforms.ToTensor(),
    ])

dataset = SignatureDataset(
    root_dir=dataset_path(),
    transform=transform(num_output_channels=1, resize=(150, 150))
)

Loaded 14626 signature images (genuine + forged)


In [89]:
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)
print(f"Dataset sizes - Train: {train_size}, Validation: {val_size}")

Dataset sizes - Train: 11700, Validation: 2926


In [90]:
train_graph = []
val_graph = []

for t, v in tqdm(zip(train_dataset, val_loader)):
    for train_tensor_image, val_tensor_image in zip(t, v):
        t_graph = image_to_graph(train_tensor_image)
        v_graph = image_to_graph(val_tensor_image)
        train_graph.append(t_graph)
        val_graph.append(v_graph)

92it [00:07, 11.89it/s]


In [91]:
def train_loop(graphs, model, loss_fn, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    total_samples = 0
    
    for graph_idx, graph in enumerate(tqdm(graphs, desc="Training")):
        batch = next(iter(t_batch))
        optimizer.zero_grad()
        
        graph_loss = 0
        graph_correct = 0

        embedding = model(graph.x, graph.edge_index, graph.pos)

        # Calculate loss
        if embedding.dim() == 1:
            embedding = embedding.unsqueeze(0)
        loss = loss_fn(embedding, label)
        graph_loss += loss
        
        # Calculate accuracy
        if embedding.size(-1) > 1:  # Multi-class
            pred_class = embedding.argmax(dim=-1)
            graph_correct += (pred_class == label).sum().item()
        
        if graph_loss > 0:
            graph_loss.backward()
            optimizer.step()
            
            total_loss += graph_loss.item()
            correct += graph_correct
            total_samples += len(graphs)
        
        if graph_idx % 10 == 0:
            avg_loss = total_loss / (graph_idx + 1) if graph_idx > 0 else graph_loss.item()
            print(f"Graph {graph_idx}, Loss: {avg_loss:.4f}")
    
    avg_loss = total_loss / len(graphs) if len(graphs) > 0 else 0
    accuracy = correct / total_samples if total_samples > 0 else 0
    return avg_loss, accuracy

def test_loop(graphs, model, loss_fn):
    model.eval()
    total_loss = 0
    correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for graph in tqdm(graphs, desc="Testing"):
            batch = next(iter(v_batch))
            
            graph_loss = 0
            graph_correct = 0
                    
            # Get prediction
            embedding = model(graph.x, graph.edge_index, graph.pos)
            
            # Calculate loss
            if embedding.dim() == 1:
                embedding = embedding.unsqueeze(0)
            loss = loss_fn(embedding, label)
            graph_loss += loss.item()
            
            # Calculate accuracy
            if embedding.size(-1) > 1:
                pred_class = embedding.argmax(dim=-1)
                graph_correct += (pred_class == label).sum().item()
            
            total_loss += graph_loss
            correct += graph_correct
            total_samples += len(graphs)
    
    avg_loss = total_loss / total_samples if total_samples > 0 else 0
    accuracy = correct / total_samples if total_samples > 0 else 0
    
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")
    return avg_loss, accuracy

In [97]:
writer = SummaryWriter('runs/signature_gnn')

input_dim = train_graph[0].x.shape[1]
hidden_dim = 64
output_dim = 128

model = SignatureGCN(input_dim, hidden_dim, output_dim).to(device)

# Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train_loader = DataLoader(train_graph, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_graph, batch_size=batch_size, shuffle=False)

tensor([ 0,  0,  0,  ..., 31, 31, 31])

In [53]:
for epoch in range(epochs):
    # Training
    train_loss, train_accV = train_loop(train_graph, model, loss_fn, optimizer)
    
    # Validation
    val_loss, val_acc = test_loop(val_graph, model, loss_fn)

    # Logging
    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Loss/Validation', val_loss, epoch)
    writer.add_scalar('Accuracy/Train', train_acc, epoch)
    writer.add_scalar('Accuracy/Validation', val_acc, epoch)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_signature_model.pth')
        print(f"New best model saved with validation accuracy: {val_acc:.4f}")

writer.close()
print(f"\nTraining completed! Best validation accuracy: {best_val_acc:.4f}")
print("Model saved as 'best_signature_model.pth'")


Training:   0%|                                                                                 | 0/92 [00:00<?, ?it/s]


ValueError: The `index` argument must be one-dimensional (got 2 dimensions)