## Siamese Network

State-of-the-art tattoo verification using a Siamese network

#### References
- [Siamese Network](https://builtin.com/machine-learning/siamese-network)

### Load dependencies

In [40]:
# This will help us to measure the time it took for the whole
# notebook to execute
import time
start_time = time.time()

import os
import importlib
import pandas as pd
from PIL import Image

import sys
sys.path.append('../../utils')
import datasets
importlib.reload(datasets)
import helpers
importlib.reload(helpers)
import annotations
importlib.reload(annotations)

import re
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

### Get dataset

In [41]:
# Folders holding the BIVTatt bounding-box files and the tattoo images
bound_box_path = Path("../../datasets/BIVTatt-Dataset/bounding_boxes")
data_path = Path("../../datasets/BIVTatt-Dataset/images")
# Base images are the files whose name starts with "<digits>_<digits>.JPG"
pattern = r'^\d+_\d+\.JPG'

# File names (not full paths) found directly inside each folder
total_bound_boxes = [entry.name for entry in bound_box_path.iterdir() if entry.is_file()]
all_images = [entry.name for entry in data_path.iterdir() if entry.is_file()]
base_images = [
    entry.name
    for entry in data_path.iterdir()
    if entry.is_file() and re.match(pattern, entry.name)
]

print ("Base images in data folder: ")
for caption, names in (
    ("     Total of bounding boxes: ", total_bound_boxes),
    ("     Total of images: ", all_images),
    ("     Total of base images: ", base_images),
):
    print(caption, len(names))
print('')
print("Base images and their variants")

# For each base image, count every image whose file name starts with the base
# name minus its last 4 characters (the ".JPG" extension). The base image
# itself matches its own prefix, so it is part of the count.
base_image_variant_counts = {
    base_image: sum(1 for image in all_images if image.startswith(base_image[:-4]))
    for base_image in base_images
}

for base_image, count in base_image_variant_counts.items():
    print(f"    Base image '{base_image}' has {count} variants.")

# Define dataset base path
dataset_path = annotations.bivtatt_dataset_path

print("")
annotations.bivtatt_dataset()

Base images in data folder: 
     Total of bounding boxes:  4410
     Total of images:  4410
     Total of base images:  161

Base images and their variants
    Base image '118_1.JPG' has 21 variants.
    Base image '103_2.JPG' has 21 variants.
    Base image '77_1.JPG' has 21 variants.
    Base image '32_1.JPG' has 21 variants.
    Base image '14_2.JPG' has 21 variants.
    Base image '53_1.JPG' has 21 variants.
    Base image '16_1.JPG' has 21 variants.
    Base image '144_1.JPG' has 21 variants.
    Base image '93_1.JPG' has 21 variants.
    Base image '91_1.JPG' has 21 variants.
    Base image '146_1.JPG' has 21 variants.
    Base image '103_1.JPG' has 21 variants.
    Base image '51_1.JPG' has 21 variants.
    Base image '14_1.JPG' has 21 variants.
    Base image '88_1.JPG' has 21 variants.
    Base image '75_1.JPG' has 21 variants.
    Base image '16_2.JPG' has 21 variants.
    Base image '127_1.JPG' has 21 variants.
    Base image '48_1.JPG' has 21 variants.
    Base image '10_4

### Define Needed classes

In [42]:
# Dataset class
class BIVTattDataset(Dataset):
    """Dataset of image pairs described by an annotations CSV file.

    Every CSV row is read through its ``image1``, ``image2`` and ``label``
    columns: two image file names (relative to ``img_dir``) and the label
    attached to that pair.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        """
        Args:
            annotations_file (str): Path to the annotations CSV file.
            img_dir (str): Directory containing tattoo images.
            transform (callable, optional): Transformation applied to each
                image of a pair after it has been converted to RGB.
        """
        self.annotations = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated pairs (rows of the CSV)."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(img1, img2, label)`` for the pair at position ``index``.

        Both images are opened from ``img_dir`` and converted to RGB; the
        transform, when one was given, is applied to each of them. The label
        is returned as a ``torch.float32`` tensor.
        """
        # Fetch the row once and read all three columns from it
        pair = self.annotations.iloc[index]
        first_name, second_name, pair_label = pair['image1'], pair['image2'], pair['label']

        # Open both files before applying any transformation
        first = Image.open(os.path.join(self.img_dir, first_name)).convert("RGB")
        second = Image.open(os.path.join(self.img_dir, second_name)).convert("RGB")

        if self.transform:
            first, second = self.transform(first), self.transform(second)

        return first, second, torch.tensor(pair_label, dtype=torch.float32)



### Training Dataset

In [43]:
# Define Transformations: resize to 128x128, convert to a tensor and
# normalise every channel with mean 0.5 / std 0.5.
# NOTE(review): the ResNet-18 backbone used by SiameseNetwork is loaded with
# pre-trained weights; confirm whether the ImageNet mean/std would be a better
# match for it than 0.5/0.5.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Dataset location (dataset_path is concatenated as-is, so it is expected to
# already end with a path separator)
annotations_path = f'{dataset_path}annotations.csv'
images_path = f'{dataset_path}images'

# Split dataset
# The DataFrame is deliberately NOT called `annotations`: that name is the
# helper module imported in the first cell, and rebinding it would break a
# re-run of the "Get dataset" cell in the same kernel.
annotations_df = pd.read_csv(annotations_path)
train_annotations, test_annotations = train_test_split(annotations_df, test_size=0.2, random_state=42)

# Save the split annotations
train_annotations.to_csv(f'{dataset_path}train_annotations.csv', index=False)
test_annotations.to_csv(f'{dataset_path}test_annotations.csv', index=False)

train_dataset = BIVTattDataset(annotations_file=f'{dataset_path}train_annotations.csv', img_dir=images_path, transform=transform)
test_dataset = BIVTattDataset(annotations_file=f'{dataset_path}test_annotations.csv', img_dir=images_path, transform=transform)

# Only the training pairs are shuffled; the test order stays fixed
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

### Define Siamese Network

A siamese neural network (SNN) is a class of neural network architectures that contain two or more identical sub-networks.

“Identical” here means they have the same configuration with the same parameters and weights.

Parameter updates are mirrored across both sub-networks, and the network finds similarities between inputs by comparing their feature vectors.

In [44]:
class SiameseNetwork(nn.Module):
    """Siamese network that embeds two images with one shared backbone.

    The backbone is a ResNet-18 initialised with the ImageNet (IMAGENET1K_V1)
    weights whose final fully connected layer is replaced by a linear layer
    producing a 128-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        # Use a pre-trained ResNet backbone. `weights=` replaces the deprecated
        # `pretrained=True` flag and selects the same checkpoint.
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Replace the final layer with a smaller output for embeddings
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, 128)

    def forward_once(self, x):
        """Return the backbone embedding of a single batch of images."""
        return self.backbone(x)

    def forward(self, img1, img2):
        """Embed both inputs with the same backbone.

        Returns:
            tuple: ``(feat1, feat2)``, the embeddings of ``img1`` and ``img2``.
        """
        # Both branches go through the same module, so the weights are shared
        feat1 = self.forward_once(img1)
        feat2 = self.forward_once(img2)
        return feat1, feat2

### Define Contrastive Loss

Contrastive loss is a distance-based loss, as opposed to the more conventional error-prediction losses.

This loss function is used to learn embeddings in which two similar points have a low Euclidean distance and two dissimilar points have a large Euclidean distance.



In [4]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over the Euclidean distance of two embedding batches.

    With ``d`` the per-row L2 distance between the two outputs, the loss is
    ``mean((1 - label) * d**2 + label * max(margin - d, 0)**2)``.
    A label of 0 therefore penalises any distance, while a label of 1 only
    penalises distances smaller than ``margin``.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the batch-averaged contrastive loss as a scalar tensor."""
        # Row-wise Euclidean distance between the paired embeddings
        dist = torch.norm(output1 - output2, p=2, dim=1)

        # Term active where label == 0: squared distance
        pull_term = (1 - label) * dist.pow(2)
        # Term active where label == 1: squared shortfall below the margin
        shortfall = (self.margin - dist).clamp(min=0.0)
        push_term = label * shortfall.pow(2)

        return (pull_term + push_term).mean()

### Train model

In [45]:
# Build the network, its loss function and the optimizer that updates it
model = SiameseNetwork()
criterion = ContrastiveLoss(margin=1.0)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/administrator/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:02<00:00, 19.2MB/s]


In [46]:
# Training Loop
num_epochs = 10
model.train()  # put the network in training mode before iterating
for epoch in range(num_epochs):
    # Running SUM (not mean) of the per-batch losses of this epoch
    epoch_loss = 0.0
    for first_batch, second_batch, pair_label in train_loader:
        # Clear gradients left over from the previous step
        optimizer.zero_grad()
        emb_first, emb_second = model(first_batch, second_batch)
        batch_loss = criterion(emb_first, emb_second, pair_label)
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

Epoch [1/10], Loss: 1635.5135


KeyboardInterrupt: 

In [47]:
# Persist only the learned parameters (state dict), not the whole module
saved_model = f'{dataset_path}siamese_tattoo_model.pth'
model_weights = model.state_dict()
torch.save(model_weights, saved_model)

In [48]:
# Evaluation Function
def verify_tattoo(img1_path, img2_path, model, transform):
    """Return the Euclidean distance between the embeddings of two images.

    Args:
        img1_path: Path of the first image file.
        img2_path: Path of the second image file.
        model: Callable taking two image batches and returning their two
            embeddings (e.g. a ``SiameseNetwork``).
        transform: Callable turning an RGB image into a tensor.

    Returns:
        float: L2 norm of the difference between the two embeddings.

    The comparison always runs in evaluation mode. If ``model`` is in
    training mode when the function is called, it is switched to eval for the
    forward pass and switched back afterwards; otherwise layers such as
    BatchNorm would use single-sample batch statistics and update their
    running statistics as a side effect of a mere comparison.
    """
    img1 = Image.open(img1_path).convert("RGB")
    img2 = Image.open(img2_path).convert("RGB")
    img1 = transform(img1).unsqueeze(0)  # Add batch dimension
    img2 = transform(img2).unsqueeze(0)

    # Objects without a `training` flag are simply called as they are
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            feat1, feat2 = model(img1, img2)
            distance = torch.norm(feat1 - feat2, p=2).item()
    finally:
        # Leave the caller's model in the mode it was handed over in
        if was_training:
            model.train()
    return distance

In [49]:
# Load the Model for Testing
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict holds; it avoids running arbitrary pickled code
# and the FutureWarning torch.load emits when the flag is left unset.
model.load_state_dict(torch.load(saved_model, weights_only=True))
# Switch to evaluation mode for inference (the returned module is displayed)
model.eval()

  model.load_state_dict(torch.load(saved_model))


SiameseNetwork(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, t

In [52]:
# Compare one example pair of images taken from the dataset's image folder
test_img1, test_img2 = (
    os.path.join(images_path, file_name) for file_name in ('1_2.JPG', '4_2.JPG')
)
distance = verify_tattoo(test_img1, test_img2, model, transform)
print(f"Distance between test images: {distance}")

Distance between test images: 0.045044515281915665


---

## Total Time

This shows the total execution time of the notebook.

In [53]:
# Record the finish time and hand both timestamps to the project helper.
# `start_time` was captured in the first code cell, so the reported duration
# is only meaningful after a top-to-bottom run of the notebook.
# NOTE(review): calculate_execution_time is defined in ../../utils/helpers;
# judging by the cell output it reports "Total execution time: ..." — confirm
# whether it prints or returns the message.
end_time = time.time()
helpers.calculate_execution_time(start_time, end_time)

Total execution time: 124.0 minutes and 3.15 seconds
