In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

In [None]:
class SiameseNetwork(nn.Module):
    """Twin network that maps two images through one shared set of weights.

    Both inputs go through the same convolutional trunk and the same
    fully-connected head, producing one 2-dimensional embedding per image.
    The trunk takes single-channel input, and the head's first layer expects
    128 feature maps of size 4x4, which is what 28x28 images produce after
    two (5x5 conv -> 2x2 max-pool) stages.
    """

    def __init__(self):
        super().__init__()

        # Convolutional trunk: two conv -> ReLU -> max-pool stages.
        conv_stack = [
            nn.Conv2d(1, 64, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        ]
        self.embedding_net = nn.Sequential(*conv_stack)

        # Fully-connected head that projects the flattened feature maps down
        # to the final embedding.
        head = [
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 2),  # 2-D embedding; distances are taken in this space
        ]
        self.fc = nn.Sequential(*head)

    def forward_one(self, x):
        """Embed a single batch of images.

        Args:
            x: batch of images accepted by the convolutional trunk.

        Returns:
            Tensor with one 2-dimensional embedding per batch element.
        """
        feature_maps = self.embedding_net(x)
        # Collapse everything except the batch dimension before the linear head.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.fc(flattened)

    def forward(self, input1, input2):
        """Embed both members of each pair with the shared weights.

        Returns:
            Tuple ``(embedding_of_input1, embedding_of_input2)``.
        """
        return self.forward_one(input1), self.forward_one(input2)

def contrastive_loss(output1, output2, label, margin=2.0):
    """Mean contrastive loss over a batch of embedding pairs.

    Args:
        output1: embeddings of the first element of each pair.
        output2: embeddings of the second element of each pair.
        label: per-pair weight; 0 selects the "pull together" term and
            1 selects the "push apart" term.
        margin: distance beyond which a label-1 pair contributes no loss.

    Returns:
        Scalar tensor: the batch mean of
        ``(1 - label) * d**2 + label * max(margin - d, 0)**2``
        where ``d`` is the pairwise Euclidean distance.
    """
    distance = nn.functional.pairwise_distance(output1, output2)

    # Pairs with label 0 are penalised by their squared distance.
    pull_term = (1 - label) * distance.pow(2)
    # Pairs with label 1 are penalised only while closer than the margin.
    push_term = label * (margin - distance).clamp(min=0.0).pow(2)

    return (pull_term + push_term).mean()

In [None]:
class SiameseDataset(Dataset):
    """Serves image pairs from a labelled dataset for contrastive training.

    Each item is ``(img1, img2, label)``. ``img1`` is the sample at the
    requested index, ``img2`` is a randomly drawn partner, and ``label`` is a
    float32 scalar tensor: ``1.0`` when the two samples have different class
    labels and ``0.0`` when they share one. Same-class and different-class
    partners are drawn with equal probability so that both terms of a
    contrastive loss receive training signal.

    Args:
        mnist_dataset: indexable dataset whose items are ``(image, label)``.
            If it exposes a ``targets`` attribute (as torchvision's MNIST
            does), labels are read from there; otherwise every item is read
            once to collect its label.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        self._build_label_index()

    def _build_label_index(self):
        """Group dataset indices by class label, once, so sampling stays cheap."""
        targets = getattr(self.mnist_dataset, "targets", None)
        if targets is not None:
            # Avoids loading every image just to learn its label.
            labels = torch.as_tensor(targets).tolist()
        else:
            labels = [
                self._as_key(self.mnist_dataset[i][1])
                for i in range(len(self.mnist_dataset))
            ]

        self._indices_by_label = {}
        for idx, lbl in enumerate(labels):
            self._indices_by_label.setdefault(lbl, []).append(idx)
        self._num_indexed = len(labels)

    @staticmethod
    def _as_key(label):
        """Convert a scalar tensor label to a plain Python value usable as a dict key."""
        return label.item() if isinstance(label, torch.Tensor) else label

    def __getitem__(self, index):
        img1, label1 = self.mnist_dataset[index]

        # Balanced sampling: half the pairs share a class, half do not.
        # With a single-class dataset only same-class pairs are possible.
        want_same = torch.rand(1).item() < 0.5
        if want_same or len(self._indices_by_label) < 2:
            partner = self._get_random_match(label1)
        else:
            partner = self._get_random_target(label1)

        img2, label2 = self.mnist_dataset[partner]
        return self.transform(img1), self.transform(img2), torch.tensor(int(label1 != label2), dtype=torch.float32)

    def __len__(self):
        return len(self.mnist_dataset)

    def _get_random_match(self, current_label):
        """Return the dataset index of a random sample labelled ``current_label``."""
        indices = self._indices_by_label[self._as_key(current_label)]
        return indices[torch.randint(0, len(indices), (1,)).item()]

    def _get_random_target(self, current_label):
        """Return the dataset index of a random sample NOT labelled ``current_label``.

        The draw is uniform over all samples with a different label.

        Raises:
            ValueError: if every sample carries ``current_label``.
        """
        current_label = self._as_key(current_label)
        num_excluded = len(self._indices_by_label.get(current_label, ()))
        num_candidates = self._num_indexed - num_excluded
        if num_candidates <= 0:
            raise ValueError(
                f"dataset has no sample with a label other than {current_label!r}"
            )

        # Pick a position among all other-class samples, then walk the
        # per-class buckets to translate it into an actual dataset index.
        offset = torch.randint(0, num_candidates, (1,)).item()
        for lbl, indices in self._indices_by_label.items():
            if lbl == current_label:
                continue
            if offset < len(indices):
                return indices[offset]
            offset -= len(indices)
        raise RuntimeError("label index is inconsistent")  # unreachable by construction


In [None]:
# Fetch the MNIST training split into ./data (downloaded if not already there).
# No transform is passed here; SiameseDataset applies its own.
mnist_dataset = MNIST(
    root="./data",
    train=True,
    download=True,
)

In [None]:
# Wrap MNIST so that every item is an image pair plus a pair label
siamese_dataset = SiameseDataset(mnist_dataset)

# Mini-batch iterator over the pairs, reshuffled each epoch
batch_size = 64
dataloader = DataLoader(
    siamese_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Model, objective and optimizer
siamese_net = SiameseNetwork()
criterion = contrastive_loss
optimizer = optim.Adam(
    siamese_net.parameters(),
    lr=0.001,
)

In [None]:

# Training loop
num_epochs = 10

# Batches are moved to whichever device the model's parameters live on, so the
# loop keeps working if the model has been placed on an accelerator.
device = next(siamese_net.parameters()).device
siamese_net.train()

for epoch in range(num_epochs):
    running_loss = 0.0  # sum of per-batch mean losses, weighted by batch size
    num_pairs = 0

    for input1, input2, label in dataloader:
        input1 = input1.to(device)
        input2 = input2.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output1, output2 = siamese_net(input1, input2)
        loss = criterion(output1, output2, label)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        running_loss += loss.item() * label.size(0)
        num_pairs += label.size(0)

    # Report the mean loss over the whole epoch rather than the last batch
    # only; NaN signals that the dataloader produced no batches.
    epoch_loss = running_loss / num_pairs if num_pairs else float('nan')
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

# Save the trained model
torch.save(siamese_net.state_dict(), 'siamese_model.pth')
