# Pipeline test

In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

# ResNet18 Feature Extractor Block
class ResNet18Block(nn.Module):
    def __init__(self):
        super(ResNet18Block, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-2])  # Remove FC layer and AvgPool
    
    def forward(self, x):
        return self.feature_extractor(x)

# 2D Attention Module
class AttentionModule(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(AttentionModule, self).__init__()
        self.w_f = nn.Conv2d(input_dim, hidden_dim, kernel_size=1)
        self.w_h = nn.Linear(input_dim, hidden_dim)
        self.w_att = nn.Linear(hidden_dim, 1)
        
    def forward(self, feature_map, pooled_feature):
        batch_size, channels, height, width = feature_map.size()
        reshaped_feature_map = feature_map.view(batch_size, channels, height * width).permute(0, 2, 1)  # (N, H*W, C)
        
        feature_out = self.w_f(feature_map).view(batch_size, -1, height * width).permute(0, 2, 1)  # (N, H*W, hidden_dim)
        pooled_out = self.w_h(pooled_feature).unsqueeze(1)  # (N, 1, hidden_dim)
        
        attention_scores = F.tanh(feature_out + pooled_out)
        attention_weights = F.softmax(self.w_att(attention_scores), dim=1)
        
        attended_output = torch.bmm(attention_weights.permute(0, 2, 1), feature_out)  # (N, 1, hidden_dim)
        attended_output = attended_output.squeeze(1)
        
        return attended_output

# U-Net-style Decoder
class UNetDecoder(nn.Module):
    def __init__(self, input_dim):
        super(UNetDecoder, self).__init__()
        
        self.upconv1 = nn.ConvTranspose2d(input_dim, 256, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.upconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.upconv4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.final_conv = nn.Conv2d(32, 1, kernel_size=1)
        
    def forward(self, x):
        x = F.relu(self.conv1(self.upconv1(x)))
        x = F.relu(self.conv2(self.upconv2(x)))
        x = F.relu(self.conv3(self.upconv3(x)))
        x = F.relu(self.conv4(self.upconv4(x)))
        x = self.final_conv(x)
        return x

# Complete Model
class SignatureVerificationModel(nn.Module):
    def __init__(self):
        super(SignatureVerificationModel, self).__init__()
        self.resnet = ResNet18Block()
        self.attention = AttentionModule(input_dim=512, hidden_dim=256)
        self.decoder = UNetDecoder(input_dim=512)
    
    def forward(self, input_image, comparison_image):
        # Feature extraction
        feature_map = self.resnet(input_image)  # (N, 512, 16, 16)
        pooled_feature = self.resnet(comparison_image).mean(dim=[2, 3])  # (N, 512)
        
        # Attention
        attended_feature = self.attention(feature_map, pooled_feature)
        
        # Decode
        output = self.decoder(feature_map)
        
        return output, attended_feature

# Instantiate model
model = SignatureVerificationModel()
input_image = torch.randn(1, 3, 256, 256)  # Input size 256x256
comparison_image = torch.randn(1, 3, 256, 256)
output, attended_feature = model(input_image, comparison_image)

print(output.shape)  # Expected output shape: (1, 1, 256, 256)
print(attended_feature.shape)  # Expected attended feature shape: (1, 256)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/luca/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████████████████████████████████| 44.7M/44.7M [00:04<00:00, 11.5MB/s]


torch.Size([1, 1, 128, 128])
torch.Size([1, 256])


In [2]:
total_params = sum(p.numel() for p in model.parameters())

print(f"Numero totale di parametri nel modello: {total_params}")

Numero totale di parametri nel modello: 12920098
