In [13]:
import torch
from torch import Tensor
from torchvision import transforms
from lib.lib import SignatureDataset, image_to_graph
from typing import Tuple, Optional, Union
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import GAE


In [14]:
class SignatureGNN(torch.nn.Module):
    """Three-layer GCN encoder that maps node features to per-node embeddings.

    Intended as the encoder of a ``GAE``: it returns node-level embeddings and
    deliberately performs no graph-level pooling.
    """

    def __init__(self, in_channels, hidden_channels, embedding_dim):
        super().__init__()
        # The attribute names become the state_dict key prefixes (conv1.*, conv2.*,
        # conv3.*), so they must stay stable for saved checkpoints to load.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, embedding_dim)

    def forward(self, x: Tensor, edge_index: Tensor, batch: Optional[Tensor] = None) -> Tensor:
        """Encode every node of the graph.

        Args:
            x: Node features, [num_nodes, in_channels].
            edge_index: Graph connectivity, [2, num_edges].
            batch: Per-node graph ids. Accepted for interface compatibility but
                ignored; the GAE needs node embeddings, not pooled ones.

        Returns:
            Node embeddings, [num_nodes, embedding_dim].
        """
        h = x
        # The two hidden GCN layers are each followed by a ReLU.
        for conv in (self.conv1, self.conv2):
            h = torch.relu(conv(h, edge_index))
        # The output layer stays linear; no pooling, since GAE works per node.
        return self.conv3(h, edge_index)

In [30]:
def transform(num_output_channels=1, resize=(150, 150), **kwargs):
    """Build the preprocessing pipeline applied to every signature image.

    The pipeline converts to grayscale, then resizes, then converts to a tensor.

    Args:
        num_output_channels: Channel count produced by ``transforms.Grayscale``
            (torchvision accepts 1 or 3).
        resize: Size passed to ``transforms.Resize``, either an ``(h, w)`` tuple
            or an int.
        **kwargs: Accepted and ignored. This keeps callers that pass extra keys
            working, as the previous ``**kwargs``-only signature also ignored them.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=num_output_channels),
        transforms.Resize(resize),
        transforms.ToTensor(),
    ])

# Load the test signatures and convert each image tensor into a graph.
dataset = SignatureDataset(
    root_dir="test_image",
    transform=transform(num_output_channels=1, resize=(150, 150)),
)

# NOTE(review): every element of every dataset item is passed to image_to_graph.
# This assumes each item is a sequence of image tensors (e.g. a genuine/forged
# group). If an item also carries a label, the label would be converted too;
# confirm against SignatureDataset.__getitem__.
test_graph = [
    image_to_graph(test_tensor_image)
    for item in dataset
    for test_tensor_image in item
]

if not test_graph:
    raise ValueError("No graphs were built from 'test_image'; is the directory empty?")

# The first graph is the one encoded below.
sample = test_graph[0]

Loaded 3 signature images (genuine + forged)


In [31]:
# Model hyperparameters. The input width is read from the data (the node-feature
# columns of the first graph), so it always matches what image_to_graph produces.
input_dim = sample.x.shape[1]
hidden_dim, embedding_dim = 64, 128
epochs = 1500  # NOTE(review): not used by the inference cells below; presumably a leftover training setting

In [32]:
# Rebuild the architecture with the same dimensions used for training, then
# restore the trained weights.
encoder = SignatureGNN(input_dim, hidden_dim, embedding_dim)
model = GAE(encoder)

# map_location="cpu": lets a checkpoint saved from a GPU session load on a
# CPU-only kernel. The input graphs here are presumably CPU tensors; confirm
# if image_to_graph ever moves data to a device.
# weights_only=True: restricts unpickling to tensors and plain containers, so a
# tampered model.pth cannot run arbitrary code. This is only the default in
# newer PyTorch releases.
model.load_state_dict(torch.load("model.pth", map_location="cpu", weights_only=True))
# Switch to eval mode. No dropout/batch-norm layers exist in SignatureGNN today,
# but this keeps inference correct if they are added.
model.eval()

with torch.no_grad():
    # Node-level embeddings for the sample graph.
    z = model.encode(sample.x, sample.edge_index)

In [33]:
# Node embeddings produced by the encoder: one row per graph node, embedding_dim
# columns. torch truncates the repr with "..."; use z.shape for the exact size.
z

tensor([[ 0.1388, -0.1934,  0.1532,  ...,  0.1042,  0.1565,  0.0883],
        [ 0.1617, -0.2071,  0.1523,  ...,  0.1083,  0.1765,  0.0973],
        [ 0.1558, -0.2031,  0.1537,  ...,  0.1067,  0.1720,  0.0956],
        ...,
        [ 0.1222, -0.2086,  0.1686,  ...,  0.0904,  0.1480,  0.0965],
        [ 0.1231, -0.2146,  0.1706,  ...,  0.0883,  0.1509,  0.0998],
        [ 0.1050, -0.2000,  0.1693,  ...,  0.0864,  0.1346,  0.0907]])