In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor, ToPILImage
from torchvision import datasets
from torchvision import models
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
import random

In [3]:
class cnn(nn.Module):
    """Small convolutional classifier: two conv/pool stages and two linear layers.

    The first convolution takes a single input channel, and the first linear
    layer expects 7*7*8 flattened features, which corresponds to a 28x28 input
    after the two 2x2 max-pools. The final layer emits 10 raw scores per
    sample (no softmax is applied here).
    """

    def __init__(self) -> None:
        super().__init__()
        # Stage 1: 1 -> 8 channels; padding=1 with a 3x3 kernel keeps the
        # spatial size, the 2x2 max-pool then halves height and width.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        # Stage 2: same layout, 8 -> 8 channels.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        self.flatten = nn.Flatten()
        # Hidden layer on the flattened 8 x 7 x 7 feature map.
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=7 * 7 * 8, out_features=28),
            nn.ReLU(),
        )
        # Output layer: one score per class.
        self.fc2 = nn.Linear(in_features=28, out_features=10)

    def forward(self, inp):
        """Run a batch through both conv stages and the two linear layers.

        Returns the output of ``fc2``: a tensor with 10 values per sample.
        """
        feature_map = self.conv2(self.conv1(inp))
        hidden = self.fc1(self.flatten(feature_map))
        return self.fc2(hidden)
        


# MNIST train / test splits, converted to float tensors by ToTensor().
train_dataset = datasets.MNIST(root='../data/', train=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='../data/', train=False, transform=transforms.ToTensor())

train = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Accuracy does not depend on sample order, so the test loader is not shuffled.
test = DataLoader(test_dataset, batch_size=64, shuffle=False)

model = cnn()
optimizer = optim.Adam(params=model.parameters())

for i in tqdm(range(10)):
    # --- one training pass over the training set ---
    model.train()
    running_loss = 0.0
    for x, y in train:
        # Call the module itself (not .forward) so nn.Module hooks run.
        pred = model(x)
        loss = F.cross_entropy(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; accumulating the tensor itself would
        # keep an autograd history attached to the running total.
        running_loss += loss.item()
    running_loss /= len(train)
    print(f'Loss: {running_loss}')

    # --- evaluation on the held-out test set ---
    model.eval()
    correct = 0
    # No gradients are needed for evaluation; skip building the graph.
    with torch.no_grad():
        for x, y in test:
            pred = model(x)
            correct += (pred.argmax(1) == y).sum().item()
    print(f'Testing accuracy: {correct/len(test_dataset)}')

  0%|          | 0/10 [00:00<?, ?it/s]

Loss: 0.48468831181526184


 10%|█         | 1/10 [00:11<01:42, 11.42s/it]

Testing accuracy: 0.934
Loss: 0.14353947341442108


 20%|██        | 2/10 [00:19<01:16,  9.58s/it]

Testing accuracy: 0.969
Loss: 0.09697817265987396


 30%|███       | 3/10 [00:27<01:01,  8.80s/it]

Testing accuracy: 0.9693
Loss: 0.07919647544622421


 40%|████      | 4/10 [00:35<00:50,  8.37s/it]

Testing accuracy: 0.9787
Loss: 0.06733802706003189


 50%|█████     | 5/10 [00:43<00:40,  8.14s/it]

Testing accuracy: 0.9803
Loss: 0.05966944247484207


 60%|██████    | 6/10 [00:50<00:32,  8.05s/it]

Testing accuracy: 0.9811
Loss: 0.05532391741871834


 70%|███████   | 7/10 [00:58<00:23,  7.89s/it]

Testing accuracy: 0.9837
Loss: 0.04929293319582939


 80%|████████  | 8/10 [01:06<00:15,  7.90s/it]

Testing accuracy: 0.9822
Loss: 0.046428218483924866


 90%|█████████ | 9/10 [01:14<00:07,  7.84s/it]

Testing accuracy: 0.9853
Loss: 0.04253356158733368


100%|██████████| 10/10 [01:21<00:00,  8.18s/it]

Testing accuracy: 0.9847





In [4]:
# FashionMNIST training split (train=True is the torchvision default),
# with each image converted to a tensor by ToTensor().
dataset = datasets.FashionMNIST(
    root='../data/',
    train=True,
    transform=transforms.ToTensor(),
)

In [5]:
# Collect the first two samples of each of the 10 labels from `dataset`.
# images[label] -> list of image tensors for that label.
images = defaultdict(list)

# counts[label] = how many more images of that label are still wanted.
counts = [2] * 10
for img, label in dataset:
    if counts[label] > 0:
        images[label].append(img)
        counts[label] -= 1
    # Stop scanning as soon as every label has its two images.
    if not any(counts):
        break

In [6]:

class siamese(nn.Module):
    """Linear projection head stacked on top of an embedding network.

    Args:
        embedding: module whose output is fed to the projection layer; it must
            produce ``embedding_dim`` features per sample.
        embedding_dim: number of features produced by ``embedding``
            (default 10, the previously hard-coded value).
        out_dim: size of the projected output vector (default 5, the
            previously hard-coded value).
    """

    def __init__(self, embedding, embedding_dim: int = 10, out_dim: int = 5) -> None:
        super().__init__()
        self.embedding = embedding
        self.fc = nn.Linear(embedding_dim, out_dim)

    def forward(self, inp):
        """Embed ``inp`` with the wrapped network and project it with ``fc``."""
        # Call the module rather than its .forward method so that any hooks
        # registered on the embedding network are executed.
        x = self.embedding(inp)
        return self.fc(x)

# Freeze the previously trained CNN; only the new projection layer is trained.
model.requires_grad_(False)
sim_model = siamese(model)
# Hand the optimizer only the parameters that can actually receive gradients.
optimizer = optim.Adam(params=[p for p in sim_model.parameters() if p.requires_grad])
EPOCHS = 10000


for i in tqdm(range(1, EPOCHS+1)):
    # Anchor and positive: the two stored images of one randomly chosen class.
    true_class = random.randint(0, 9)
    true_img1_idx = random.randint(0, 1)
    true_img1 = images[true_class][true_img1_idx]
    true_img2 = images[true_class][1-true_img1_idx]

    # Negative: an image from a DIFFERENT class. The class index must be
    # compared against `true_class` (not against the 0/1 image index), else
    # the negative can share the anchor's class.
    false_class = random.choice([c for c in range(10) if c != true_class])
    false_img = images[false_class][random.randint(0, 1)]

    out1 = sim_model(true_img1.unsqueeze(0))
    out2 = sim_model(true_img2.unsqueeze(0))
    out3 = sim_model(false_img.unsqueeze(0))

    # Triplet-style hinge: push the anchor-negative distance at least a margin
    # of 10 above the anchor-positive distance.
    # NOTE(review): mse_loss is already a mean of squared differences, so the
    # extra **2 squares it again -- confirm this is intended.
    d_plus = F.mse_loss(out1, out2)**2
    d_minus = F.mse_loss(out1, out3)**2
    d = d_plus + 10 - d_minus
    loss = F.relu(d)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


100%|██████████| 10000/10000 [00:17<00:00, 577.10it/s]


In [8]:
# Sanity check: compare one same-class pair (class 2) against one
# cross-class pair (class 2 vs class 5).
img1 = images[2][0].unsqueeze(0)
img2 = images[2][1].unsqueeze(0)
img3 = images[5][0].unsqueeze(0)

with torch.no_grad():
    # Call the module directly (as the training loop does) rather than
    # .forward, so nn.Module hooks are honoured.
    out1 = sim_model(img1)
    out2 = sim_model(img2)
    out3 = sim_model(img3)
    print(f'distance between True: {F.mse_loss(out1, out2)}')
    print(f'distance between False: {F.mse_loss(out1, out3)}')

distance between True: 1.2505754232406616
distance between False: 56.155799865722656
