In [1]:
import torch
from torch import Tensor
from torchvision import transforms
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils import negative_sampling
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from tqdm import tqdm

import os
from PIL import Image
from torch.utils.data import Dataset

import torch.nn.functional as F
import numpy as np
from sklearn.feature_extraction import image
import cv2
from typing import Tuple, Optional, Union

import kagglehub

In [2]:
# Hyperparameters
learning_rate = 1e-3
batch_size = 32  # Changed to match your DataLoader batch_size
epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class SignatureGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, embedding_dim):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, embedding_dim)
        
    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor:
        """
        x: Node features [num_nodes, in_channels]
        edge_index: Graph edges [2, num_edges]
        batch: Graph IDs for mini-batch training [num_nodes]
        """
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index)
        
        # Aggregate node embeddings into a graph-level signature embedding
        x = global_mean_pool(x, batch)  # [num_graphs, embedding_dim]
        
        return x


In [4]:
class SignatureDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # collect all signer folders
        signer_folders = sorted(os.listdir(root_dir))

        for folder in signer_folders:
            folder_path = os.path.join(root_dir, folder)
            if os.path.isdir(folder_path):
                for img_name in os.listdir(folder_path):
                    if self._is_image_file(img_name):
                        self.samples.append(os.path.join(folder_path, img_name))

        print(f"Loaded {len(self.samples)} signature images (genuine + forged)")

    def _is_image_file(self, filename):
        valid_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
        return os.path.splitext(filename.lower())[1] in valid_exts

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path = self.samples[idx]
        try:
            image = Image.open(path).convert("L")  # grayscale
            if self.transform:
                image = self.transform(image)
            return image   # only image, no label
        except Exception as e:
            print(f"Error loading {path}: {e}")
            # fallback blank image
            fallback = Image.new("L", (224, 224), 0)
            if self.transform:
                fallback = self.transform(fallback)
            return fallback

In [5]:
def image_to_graph(
    image_tensor: torch.Tensor,
    patch_size: int = 8,
    k_neighbors: int = 8,
    edge_threshold: float = 0.1,
    include_features: bool = True
) -> Data:
    """
    Convert an image to a graph representation with nodes and edges.
    
    Args:
        image_tensor: Input image tensor of shape (C, H, W) or (H, W)
        method: Graph construction method ('grid', 'knn', 'superpixel', 'region')
        patch_size: Size of patches for grid method
        k_neighbors: Number of neighbors for KNN method
        edge_threshold: Threshold for edge creation based on feature similarity
        include_features: Whether to include patch features as node features
        
    Returns:
        PyTorch Geometric Data object with node features and edge indices
    """
    
    return _image_to_grid_graph(image_tensor, patch_size, include_features)

def _image_to_grid_graph(
    image_tensor: torch.Tensor, 
    patch_size: int, 
    include_features: bool
) -> Data:
    """Convert image to grid-based graph where each patch is a node."""
    
    # Handle different input shapes
    if len(image_tensor.shape) == 2:
        image_tensor = image_tensor.unsqueeze(0)  # Add channel dimension
    
    C, H, W = image_tensor.shape
    
    # Create patches
    patches_h = H // patch_size
    patches_w = W // patch_size
    
    # Extract patch features
    node_features = []
    node_positions = []
    
    for i in range(patches_h):
        for j in range(patches_w):
            # Extract patch
            patch = image_tensor[
                :, 
                i * patch_size:(i + 1) * patch_size,
                j * patch_size:(j + 1) * patch_size
            ]
            
            if include_features:
                # Compute patch statistics as features
                mean_val = patch.mean(dim=[1, 2])  # Per channel mean
                std_val = patch.std(dim=[1, 2])    # Per channel std
                max_val = patch.max(dim=2)[0].max(dim=1)[0]  # Per channel max
                min_val = patch.min(dim=2)[0].min(dim=1)[0]  # Per channel min
                
                features = torch.cat([mean_val, std_val, max_val, min_val])
                node_features.append(features)
            
            # Store position
            node_positions.append([i, j])
    
    # Create edges (connect adjacent patches)
    edge_indices = []
    
    for i in range(patches_h):
        for j in range(patches_w):
            current_node = i * patches_w + j
            
            # Connect to neighbors (4-connectivity)
            neighbors = [
                (i-1, j), (i+1, j),  # vertical neighbors
                (i, j-1), (i, j+1)   # horizontal neighbors
            ]
            
            # Add diagonal connections for 8-connectivity
            neighbors.extend([
                (i-1, j-1), (i-1, j+1),
                (i+1, j-1), (i+1, j+1)
            ])
            
            for ni, nj in neighbors:
                if 0 <= ni < patches_h and 0 <= nj < patches_w:
                    neighbor_node = ni * patches_w + nj
                    edge_indices.append([current_node, neighbor_node])
    
    # Convert to tensors
    if include_features:
        x = torch.stack(node_features)
    else:
        x = torch.tensor(node_positions, dtype=torch.float32)
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    pos = torch.tensor(node_positions, dtype=torch.float32)
    
    return Data(x=x, edge_index=edge_index, pos=pos)

In [6]:
def dataset_path():
    path = kagglehub.dataset_download("akashgundu/signature-verification-dataset")
    return os.path.join(path, 'extract')

def transform(**kwargs):
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=kwargs['num_output_channels']),
        transforms.Resize(kwargs['resize']),
        transforms.ToTensor(),
    ])

dataset = SignatureDataset(
    root_dir=dataset_path(),
    transform=transform(num_output_channels=1, resize=(150, 150))
)

Loaded 14626 signature images (genuine + forged)


In [7]:
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)
print(f"Dataset sizes - Train: {train_size}, Validation: {val_size}")

Dataset sizes - Train: 11700, Validation: 2926


In [8]:
train_graph = []
val_graph = []

for t, v in tqdm(zip(train_dataset, val_dataset)):
    for train_tensor_image, val_tensor_image in zip(t, v):
        t_graph = image_to_graph(train_tensor_image)
        v_graph = image_to_graph(val_tensor_image)
        train_graph.append(t_graph)
        val_graph.append(v_graph)

2926it [03:05, 15.78it/s]


In [9]:
train_graph = DataLoader(train_graph, batch_size=batch_size, shuffle=True)
val_graph = DataLoader(val_graph, batch_size=batch_size, shuffle=False)

In [10]:
def train_loop(graphs, model, epochs=50, optimizer):
    for epoch in range(epochs):
        model.train()

def test_loop(graphs, model, loss_type="reconstruction", **kwargs):
    """Simplified testing loop for unsupervised GNNs"""
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for graph in tqdm(graphs, desc="Testing"):
            embedding = model(graph.x, graph.edge_index, getattr(graph, "batch", None))

            if loss_type == "reconstruction":
                loss = reconstruction_loss(embedding, graph.edge_index)
            elif loss_type == "contrastive":
                loss = contrastive_loss(embedding, temperature=kwargs.get("temperature", 0.1))
            elif loss_type == "edge_prediction":
                loss = edge_prediction_loss(embedding, graph.edge_index)
            else:
                raise ValueError(f"Unknown loss type: {loss_type}")

            total_loss += loss.item()

    avg_loss = total_loss / len(graphs)
    print(f"Test Loss: {avg_loss:.4f}")
    return avg_loss


# -------- Losses -------- #

def reconstruction_loss(embeddings, edge_index):
    num_nodes = embeddings.size(0)
    sim_matrix = torch.sigmoid(embeddings @ embeddings.t())

    adj = torch.zeros(num_nodes, num_nodes, device=embeddings.device)

    # Filter edges so they don’t exceed num_nodes
    mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
    edge_index = edge_index[:, mask]

    adj[edge_index[0], edge_index[1]] = 1

    return F.binary_cross_entropy(sim_matrix, adj)


def contrastive_loss(embeddings, temperature=0.1):
    """Self-supervised contrastive loss"""
    z = F.normalize(embeddings, dim=1)
    sim = z @ z.t() / temperature
    labels = torch.arange(z.size(0), device=z.device)
    return F.cross_entropy(sim, labels)


def edge_prediction_loss(embeddings, edge_index):
    """Binary link prediction loss"""
    neg_edge_index = negative_sampling(
        edge_index, num_nodes=embeddings.size(0),
        num_neg_samples=edge_index.size(1)
    )

    pos_scores = (embeddings[edge_index[0]] * embeddings[edge_index[1]]).sum(dim=1)
    neg_scores = (embeddings[neg_edge_index[0]] * embeddings[neg_edge_index[1]]).sum(dim=1)

    pos_loss = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
    neg_loss = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
    return (pos_loss + neg_loss) / 2


In [11]:
writer = SummaryWriter('runs/signature_gnn')

input_dim = next(iter(train_graph)).x.shape[1]
hidden_dim = 64
output_dim = 128

model = SignatureGCN(input_dim, hidden_dim, output_dim).to(device)

# Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [12]:
next(iter(train_graph))

DataBatch(x=[10368, 4], edge_index=[2, 76160], pos=[10368, 2], batch=[10368], ptr=[33])

In [None]:
# Initialize best loss for unsupervised learning (lower is better)
best_val_loss = float('inf')

for epoch in range(epochs):
    # Training - corrected function call for unsupervised learning
    train_loss = train_loop(train_graph, model, optimizer, loss_type='reconstruction')
    
    # Validation - corrected function call for unsupervised learning  
    val_loss = test_loop(val_graph, model, loss_type='reconstruction')
    
    # Logging - only loss since we don't have accuracy in unsupervised learning
    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Loss/Validation', val_loss, epoch)
    
    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print("-" * 50)
    
    # Save best model based on validation loss (lower is better)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_feature_extraction_model.pth')
        print(f"✓ New best model saved with validation loss: {val_loss:.4f}")

writer.close()
print(f"\nTraining completed! Best validation loss: {best_val_loss:.4f}")
print("Model saved as 'best_feature_extraction_model.pth'")

# Alternative: If reconstruction still gives issues, use edge_prediction
# Just change 'reconstruction' to 'edge_prediction' in both function calls:
# train_loss = train_loop(train_graph, model, optimizer, loss_type='edge_prediction')
# val_loss = test_loop(val_graph, model, loss_type='edge_prediction')