In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, VGAE
from torch_geometric.data import Data

from tqdm import tqdm

from torchvision import transforms

from lib.lib import SignatureDataset, image_to_graph

In [2]:
class GNNEncoder(torch.nn.Module):
    """
    Graph-convolutional encoder with one shared layer and two output heads.

    The shared ``conv1`` layer is followed by a ReLU; its result feeds two
    independent GCN layers, ``conv_mu`` and ``conv_logvar``, whose outputs are
    returned as a pair.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of the shared hidden representation.
        latent_dim: Width of each of the two outputs.
    """

    def __init__(self, in_channels, hidden_channels, latent_dim):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable
        # for saved checkpoints to keep loading.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, latent_dim)
        self.conv_logvar = GCNConv(hidden_channels, latent_dim)

    def forward(self, x, edge_index):
        """
        Encode node features into the two latent-distribution outputs.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity passed to every GCN layer.

        Returns:
            Tuple ``(mu, logvar)`` produced by the two heads.
        """
        # Shared neighbourhood aggregation followed by a non-linearity.
        hidden = self.conv1(x, edge_index).relu()

        # Both heads read the same hidden representation.
        return self.conv_mu(hidden, edge_index), self.conv_logvar(hidden, edge_index)

In [39]:
# class SiameseNetwork(nn.Module):
#     """
#     Siamese Network for Signature Verification
#     Takes two signature graphs and predicts if they're from the same person
#     """
#     def __init__(self, gnn_vae_model, latent_dim=128, fc_hidden=256):
#         super().__init__()
        
#         # Pre-trained GNN-VAE for feature extraction
#         self.gnn_vae = gnn_vae_model
#         self.latent_dim = latent_dim
        
#         # Freeze GNN-VAE weights (optional - for fine-tuning, set to False)
#         for param in self.gnn_vae.parameters():
#             param.requires_grad = False
        
#         # Similarity computation layers
#         self.fc1 = nn.Linear(latent_dim * 2, fc_hidden)
#         self.bn1 = nn.BatchNorm1d(fc_hidden)
        
#         self.fc2 = nn.Linear(fc_hidden, fc_hidden // 2)
#         self.bn2 = nn.BatchNorm1d(fc_hidden // 2)
        
#         self.fc3 = nn.Linear(fc_hidden // 2, 1)
        
#         self.dropout = nn.Dropout(0.3)
        
#     def forward_one(self, x, edge_index, batch=None):
#         """
#         Extract features from one signature graph
#         """
#         # Use the GNN-VAE's encode method
#         with torch.no_grad():  # Don't update GNN-VAE weights
#             embedding = self.gnn_vae.encode(x, edge_index)
        
#         return embedding
    
#     def forward(self, x1, edge_index1, x2, edge_index2, batch1=None, batch2=None):
#         """
#         Forward pass for two signatures
        
#         Args:
#             x1, edge_index1: First signature graph
#             x2, edge_index2: Second signature graph
#             batch1, batch2: Batch indices (if processing multiple pairs)
        
#         Returns:
#             similarity: Similarity score [0, 1]
#             emb1, emb2: Embeddings for analysis
#         """
#         # Extract embeddings from both signatures
#         emb1 = self.forward_one(x1, edge_index1, batch1)  # [batch_size, latent_dim]
#         emb2 = self.forward_one(x2, edge_index2, batch2)  # [batch_size, latent_dim]
        
#         # Concatenate embeddings
#         combined = torch.cat([emb1, emb2], dim=1)  # [batch_size, latent_dim * 2]
        
#         # Pass through similarity network
#         x = self.fc1(combined)
#         if x.size(0) > 1:
#             x = self.bn1(x)
#         x = F.relu(x)
#         x = self.dropout(x)
        
#         x = self.fc2(x)
#         if x.size(0) > 1:
#             x = self.bn2(x)
#         x = F.relu(x)
#         x = self.dropout(x)
        
#         # Output similarity score
#         similarity = torch.sigmoid(self.fc3(x))  # [batch_size, 1]
        
#         return similarity, emb1, emb2
    
#     def predict(self, x1, edge_index1, x2, edge_index2):
#         """
#         Predict if two signatures are from the same person
        
#         Returns:
#             is_same_person: Boolean prediction
#             similarity_score: Confidence score
#         """
#         self.eval()
#         with torch.no_grad():
#             similarity, _, _ = self.forward(x1, edge_index1, x2, edge_index2)
#             similarity = similarity.mean()
#             is_same_person = similarity > 0.5
#             return is_same_person.item(), similarity.item()
    
#     def get_embedding(self, x, edge_index, batch=None):
#         """
#         Get embedding for a single signature
#         """
#         self.eval()
#         with torch.no_grad():
#             return self.forward_one(x, edge_index, batch)

class ContrastiveSiameseNetwork(nn.Module):
    """
    Siamese network that compares two signature graphs by embedding distance.

    Both graphs are encoded by the same frozen, pre-trained GNN-VAE, passed
    through a small trainable projection head, and compared with the Euclidean
    distance. A smaller distance means "more likely the same writer".

    Args:
        gnn_vae_model: Pre-trained ``nn.Module`` exposing
            ``encode(x, edge_index)`` whose output has ``latent_dim`` features
            in its last dimension. Its parameters are frozen in place and it
            is kept in eval mode for the lifetime of this wrapper.
        latent_dim: Feature width of the encoder output and of the projection.
        margin: Margin intended for a contrastive loss. It is only stored on
            the instance; no method of this class uses it.
    """
    def __init__(self, gnn_vae_model, latent_dim=32, margin=2.0):
        super().__init__()

        self.gnn_vae = gnn_vae_model
        self.latent_dim = latent_dim
        self.margin = margin  # Margin for contrastive loss

        # Freeze GNN-VAE: no gradients, and eval mode so that encoders which
        # sample in training mode (e.g. torch_geometric's VGAE) stay
        # deterministic.
        for param in self.gnn_vae.parameters():
            param.requires_grad = False
        self.gnn_vae.eval()

        # Trainable projection head applied on top of the frozen embeddings
        self.projection = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )

    def train(self, mode=True):
        """
        Set the training mode of the trainable parts only.

        ``nn.Module.train`` propagates the mode to every child, which would
        put the frozen encoder back into training mode. The encoder is
        therefore forced back to eval mode after every mode change.

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.gnn_vae.eval()
        return self

    def forward_one(self, x, edge_index, batch=None):
        """
        Embed one graph: frozen encoder followed by the projection head.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity.
            batch: Accepted for interface compatibility but not used; the
                result has one row per row returned by the encoder
                (presumably one per node -- no per-graph pooling happens here).

        Returns:
            Projected embeddings with ``latent_dim`` features per row.
        """
        # The encoder is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            embedding = self.gnn_vae.encode(x, edge_index)

        # Gradients flow through the projection only.
        embedding = self.projection(embedding)

        return embedding

    def forward(self, x1, edge_index1, x2, edge_index2, batch1=None, batch2=None):
        """
        Embed both graphs and measure how far apart they are.

        Returns:
            Tuple ``(distance, emb1, emb2)`` where ``distance`` is the
            row-wise Euclidean distance between the two embeddings. Both
            embeddings must therefore have compatible shapes.
        """
        emb1 = self.forward_one(x1, edge_index1, batch1)
        emb2 = self.forward_one(x2, edge_index2, batch2)

        # Compute Euclidean distance
        distance = F.pairwise_distance(emb1, emb2)

        return distance, emb1, emb2

    def predict(self, x1, edge_index1, x2, edge_index2, threshold=1.0):
        """
        Decide whether two signature graphs belong to the same person.

        The row-wise distances are averaged into a single score, which is
        compared against ``threshold``. Inference runs in eval mode without
        gradients; the mode the model was in before the call is restored
        afterwards, so calling this in the middle of training is safe.

        Returns:
            Tuple ``(is_same_person, distance)``: a Python bool that is True
            when the mean distance is below ``threshold``, and that mean
            distance as a Python float.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                distance, _, _ = self.forward(x1, edge_index1, x2, edge_index2)
                distance = distance.mean()
                is_same_person = distance < threshold
                return is_same_person.item(), distance.item()
        finally:
            # Undo the temporary switch to eval mode.
            self.train(was_training)

In [79]:
def transform(**kwargs):
    """
    Build the image preprocessing pipeline: grayscale, resize, to-tensor.

    Keyword Args:
        num_output_channels: Passed to ``transforms.Grayscale``.
        resize: Passed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` applying the three steps in that order.

    Raises:
        KeyError: If either keyword is missing.
    """
    to_gray = transforms.Grayscale(num_output_channels=kwargs['num_output_channels'])
    to_size = transforms.Resize(kwargs['resize'])
    pipeline = [to_gray, to_size, transforms.ToTensor()]
    return transforms.Compose(pipeline)

# Load the signature images under "test_image", converting each one to a
# single-channel 150x150 tensor via the pipeline built by transform().
dataset = SignatureDataset(
    root_dir="test_image",
    transform=transform(num_output_channels=1, resize=(150, 150))
)

Loaded 4 signature images (genuine + forged)


In [80]:
# Graph representation of every sample in `dataset`, in iteration order.
test_graph = []

# Convert each dataset sample into a graph.
# NOTE(review): the progress-bar and print labels say "Train", but `dataset`
# is loaded from "test_image" and the results are stored in `test_graph` --
# the labels look like copy-paste leftovers.
for t in tqdm(dataset, desc="Train Graphs", leave=False):
    t_graph = image_to_graph(t)
    test_graph.append(t_graph)

print("Train graphs:", len(test_graph))

                                                                                                                       

Train graphs: 4




In [81]:
# Node feature width, read from the second dimension of the first graph's `x`.
input_dim = next(iter(test_graph)).x.shape[1]
# Encoder sizes. Presumably these must match the values used when the
# checkpoint loaded below was trained -- TODO confirm.
hidden_dim = 64
latent_dim = 128

In [82]:
# Restore the pre-trained VGAE weights.
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine; the complete model is moved to `device` further down.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
checkpoint = torch.load('VGAE_Model.pt', map_location='cpu')
gnn_vae = VGAE(GNNEncoder(in_channels=input_dim, hidden_channels=hidden_dim, latent_dim=latent_dim))
gnn_vae.load_state_dict(checkpoint)
gnn_vae.eval()

# Create Siamese Network
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# siamese_model = SiameseNetwork(
#     gnn_vae_model=gnn_vae,
#     latent_dim=128,
#     fc_hidden=256
# ).to(device)

# The projection head must match the encoder's output width, so reuse
# `latent_dim` instead of repeating the literal.
contrastive_model = ContrastiveSiameseNetwork(
    gnn_vae_model=gnn_vae,
    latent_dim=latent_dim,
    margin=2.0
).to(device)

In [89]:
# Test with two signatures
graph1 = test_graph[1].to(device)
graph2 = test_graph[0].to(device)

# Predict similarity
# is_same, score = siamese_model.predict(
#     graph1.x, graph1.edge_index,
#     graph2.x, graph2.edge_index
# )

is_same, score = contrastive_model.predict(
    graph1.x, graph1.edge_index,
    graph2.x, graph2.edge_index
)

In [90]:
# Verdict from predict(): True when the mean distance is below the threshold.
is_same

True

In [91]:
# Mean embedding distance returned by predict(); predict() treats values
# below its threshold as "same person".
score

0.09018753468990326