## ResNet-18

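State-of-the-art tattoo verification using a Siamese-style pair classifier. Note: rather than two weight-sharing branches, the model below feeds both images, stacked along the channel axis, through a single ResNet-18 backbone (early fusion).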

#### References
- [ResNet-18](https://pytorch.org/hub/pytorch_vision_resnet/)

### Load dependencies

In [19]:
# This will help us to measure the time it took for the whole
# notebook to execute
import time
start_time = time.time()

import os
import re
import pandas as pd
from PIL import Image
from pathlib import Path

import importlib
import sys
sys.path.append('../../utils')
import datasets
importlib.reload(datasets)
import helpers
importlib.reload(helpers)
import annotations
importlib.reload(annotations)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

### Get dataset

In [20]:
bound_box_path = Path("../../datasets/BIVTatt-Dataset/bounding_boxes")
data_path = Path("../../datasets/BIVTatt-Dataset/images")
# Base images are named "<subject>_<tattoo>.JPG"; their variants add "_<suffix>" before
# the extension. The trailing "$" keeps names such as "1_1.JPG.bak" from counting as bases.
pattern = r'^\d+_\d+\.JPG$'
total_bound_boxes = [file.name for file in bound_box_path.iterdir() if file.is_file()]
all_images = [file.name for file in data_path.iterdir() if file.is_file()]
# Reuse the listing above instead of walking the images folder a second time.
base_images = [name for name in all_images if re.match(pattern, name)]

print("Base images in data folder: ")
print("     Total of bounding boxes: ", len(total_bound_boxes))
print("     Total of images: ", len(all_images))
print("     Total of base images: ", len(base_images))
print('')
print("Base images and their variants")


def _base_key(image_name):
    """Return the '<subject>_<tattoo>' part of an image name, e.g. '1_1_a1.JPG' -> '1_1'."""
    return "_".join(Path(image_name).stem.split("_")[:2])


# Count every image (the base itself included) under the base image it belongs to.
# Matching on the full '<subject>_<tattoo>' key -- instead of a bare startswith('1_1') --
# keeps e.g. '1_10_a1.JPG' from being counted as a variant of '1_1.JPG'. One pass with a
# dict lookup also replaces the images x base_images nested loop.
base_by_key = {Path(base_image).stem: base_image for base_image in base_images}
base_image_variant_counts = {base_image: 0 for base_image in base_images}
for image in all_images:
    owner = base_by_key.get(_base_key(image))
    if owner is not None:
        base_image_variant_counts[owner] += 1

for base_image, count in base_image_variant_counts.items():
    print(f"    Base image '{base_image}' has {count} variants.")

# Define dataset base path
# NOTE(review): data_path above is hard-coded while later cells use this value;
# presumably both point at the same BIVTatt-Dataset folder -- confirm.
dataset_path = annotations.bivtatt_dataset_path

Base images in data folder: 
     Total of bounding boxes:  4410
     Total of images:  4410
     Total of base images:  161

Base images and their variants
    Base image '118_1.JPG' has 21 variants.
    Base image '103_2.JPG' has 21 variants.
    Base image '77_1.JPG' has 21 variants.
    Base image '32_1.JPG' has 21 variants.
    Base image '14_2.JPG' has 21 variants.
    Base image '53_1.JPG' has 21 variants.
    Base image '16_1.JPG' has 21 variants.
    Base image '144_1.JPG' has 21 variants.
    Base image '93_1.JPG' has 21 variants.
    Base image '91_1.JPG' has 21 variants.
    Base image '146_1.JPG' has 21 variants.
    Base image '103_1.JPG' has 21 variants.
    Base image '51_1.JPG' has 21 variants.
    Base image '14_1.JPG' has 21 variants.
    Base image '88_1.JPG' has 21 variants.
    Base image '75_1.JPG' has 21 variants.
    Base image '16_2.JPG' has 21 variants.
    Base image '127_1.JPG' has 21 variants.
    Base image '48_1.JPG' has 21 variants.
    Base image '10_4

### Define dataset

In [21]:
class BIVTattDataset(Dataset):
    """Ordered image pairs from a BIVTatt images folder, labelled for same-tattoo verification.

    File names are assumed to look like ``<subject>_<tattoo>.JPG`` for a base image and
    ``<subject>_<tattoo>_<suffix>.JPG`` for its variants (inferred from the naming used in
    this notebook -- confirm against the dataset documentation). Every ordered pair of
    distinct files with the same subject id is emitted; its label is 1 when both files share
    the same ``<subject>_<tattoo>`` key and 0 otherwise.

    Args:
        data_dir: directory containing the image files.
        transform: optional callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.pairs = []
        self.labels = []

        # List the directory once (it used to be re-listed for every file) and sort it so
        # the pair order is reproducible across filesystems.
        files = sorted(name for name in os.listdir(data_dir) if "_" in name)

        # Group by the exact subject id; a bare startswith() would put "10_*" under "1".
        by_subject = {}
        for name in files:
            by_subject.setdefault(self._subject_id(name), []).append(name)

        for name in files:
            tattoo = self._tattoo_id(name)
            for other in by_subject[self._subject_id(name)]:
                if other != name:
                    self.pairs.append((name, other))
                    # Positive pair: both files come from the same base tattoo image.
                    self.labels.append(1 if self._tattoo_id(other) == tattoo else 0)

    @staticmethod
    def _subject_id(file_name):
        """Return the text before the first underscore, e.g. '10_1_a1.JPG' -> '10'."""
        return file_name.split("_", 1)[0]

    @staticmethod
    def _tattoo_id(file_name):
        """Return the '<subject>_<tattoo>' key with the extension dropped, e.g. '1_1_a1.JPG' -> '1_1'."""
        stem = os.path.splitext(file_name)[0]
        return "_".join(stem.split("_")[:2])

    def _load_rgb(self, file_name):
        """Open ``file_name`` inside ``data_dir`` as RGB and close the file handle afterwards."""
        with Image.open(os.path.join(self.data_dir, file_name)) as img:
            # convert() returns a new, fully loaded image, so it survives the close.
            return img.convert("RGB")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(img1, img2, label)`` for pair ``idx``; ``label`` is a float32 scalar tensor."""
        img1_name, img2_name = self.pairs[idx]
        img1 = self._load_rgb(img1_name)
        img2 = self._load_rgb(img2_name)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, label

### Define model

In [22]:
class ImagePairClassifier(nn.Module):
    """ResNet-18 that produces one logit scoring whether two images show the same tattoo.

    The two RGB images are concatenated along the channel axis into a 6-channel input
    and passed through a single ResNet-18 backbone (early fusion, not two weight-sharing
    branches). Apply a sigmoid to the logit, or train with ``BCEWithLogitsLoss``.

    Args:
        pretrained: load torchvision's pretrained ResNet-18 weights (default True, as
            before). Pass False to build the architecture without downloading weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.backbone = models.resnet18(pretrained=pretrained)

        # Widen conv1 from 3 to 6 input channels for the concatenated pair.
        old_conv1 = self.backbone.conv1
        new_conv1 = nn.Conv2d(
            6,
            old_conv1.out_channels,
            kernel_size=old_conv1.kernel_size,
            stride=old_conv1.stride,
            padding=old_conv1.padding,
            bias=False,
        )
        if pretrained:
            # Reuse the pretrained filters for both images instead of discarding them.
            # Halving them keeps the response to two identical images equal to the
            # original 3-channel layer's response to one image.
            with torch.no_grad():
                new_conv1.weight.copy_(
                    torch.cat([old_conv1.weight, old_conv1.weight], dim=1) * 0.5
                )
        self.backbone.conv1 = new_conv1

        # Replace the classification head with a single logit for binary classification.
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, 1)

    def forward(self, img1, img2):
        """Return logits of shape [B, 1] for two image batches of matching shape."""
        # Concatenate the two images along the channel dimension
        x = torch.cat((img1, img2), dim=1)  # Shape: [B, 6, H, W]
        return self.backbone(x)

### Training and Evaluation

In [None]:
# Training configurations
# cuda - NVIDIA GPUs
# mps  - Apple-silicon Macs (Metal backend)
# otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else
                      "cpu")

print(f"Using device: {device}")

SEED = 42
# Seed torch so the freshly initialised layers and the DataLoader shuffle order are
# reproducible from run to run.
torch.manual_seed(SEED)

batch_size = 32
learning_rate = 0.001
num_epochs = 1

# Data transformations: resize to 224x224 and normalise with the ImageNet mean/std
# that torchvision's pretrained ResNet weights expect.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dataset and DataLoader. os.path.join works whether or not dataset_path ends with a
# separator (plain string concatenation silently produced ".../BIVTatt-Datasetimages").
dataset = BIVTattDataset(os.path.join(dataset_path, "images"), transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model, Loss, and Optimizer
model = ImagePairClassifier().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)



In [24]:
# Training loop
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for img1, img2, labels in train_loader:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        # Forward pass: the model emits [B, 1] logits; drop the singleton dim to match labels
        outputs = model(img1, img2).squeeze(1)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # epoch_loss is a sum of per-batch mean losses, so it grows with the number of
    # batches; report the average to keep the number comparable across runs.
    mean_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Mean loss: {mean_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), f"{dataset_path}bivtatt_resnet_18_model.pth")

KeyboardInterrupt: 

### Testing and Inference

In [None]:
# Load the model for inference.
# map_location remaps tensors onto the current `device`, so a checkpoint saved on
# CUDA still loads on an MPS- or CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you produced yourself.
model.load_state_dict(
    torch.load(f"{dataset_path}bivtatt_resnet_18_model.pth", map_location=device)
)
model.eval()

# Function for testing a single pair
def verify_tattoo(image1_path, image2_path, threshold=0.5):
    """Decide whether two tattoo images show the same tattoo.

    Relies on the notebook-level ``model``, ``transform`` and ``device``; ``model`` is
    expected to already be in eval mode.

    Args:
        image1_path: path to the first image.
        image2_path: path to the second image.
        threshold: probability above which the pair counts as a match (default 0.5).

    Returns:
        A ``(result, similarity)`` tuple: ``"Match"`` or ``"No Match"``, and the sigmoid
        of the model's logit.
    """
    def _load(path):
        # Close the file handle once the pixels have been decoded and transformed.
        with Image.open(path) as img:
            return transform(img.convert("RGB")).unsqueeze(0).to(device)  # Add batch dimension

    img1 = _load(image1_path)
    img2 = _load(image2_path)

    with torch.no_grad():
        logit = model(img1, img2)
        # Sigmoid on the output tensor directly; no Python-float round-trip needed.
        similarity = torch.sigmoid(logit).item()

    return "Match" if similarity > threshold else "No Match", similarity

# Example usage: a base image and one of its variants. The images live in the
# "images" sub-folder (the same folder BIVTattDataset reads), not directly under
# dataset_path.
result, similarity = verify_tattoo(
    os.path.join(dataset_path, "images", "1_1.JPG"),
    os.path.join(dataset_path, "images", "1_1_a1.JPG"),
)
print(f"Result: {result}, Similarity: {similarity:.4f}")

---

## Total Time

This shows the total execution time of the notebook.

In [None]:
# Sets the total time of execution.
# `end_time` pairs with the `start_time` captured in the first code cell; the project
# helper presumably reports the elapsed interval (see helpers.calculate_execution_time).
end_time = time.time()
helpers.calculate_execution_time(start_time, end_time)