In [1]:
from collections import defaultdict
import time
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score

import torch
import torchvision
from torch import nn
import datasets
from torch.nn import functional as F
from torch.optim import AdamW


transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))]
)

mnist_train = torchvision.datasets.MNIST(root="./mnist", train=True,  download=True, transform=transform)
mnist_test  = torchvision.datasets.MNIST(root="./mnist", train=False, download=True, transform=transform)

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 256 + 64
lr = 3e-3
embd_n = 128

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,    32, (3, 3))
        self.conv2 = nn.Conv2d(32,   64, (3, 3))
        self.max_pool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc1   = nn.Linear(9216, 2048)
        self.fc2   = nn.Linear(2048, embd_n)
    
    def forward(self, x):
        x = self.conv1(x) # B, 1, 28, 28 ==> B, 32, 26, 26
        x = F.relu(x)
        x = self.conv2(x) # B, 32, 26, 26 ==> B, 64, 24, 24
        x = F.relu(x)
        x = self.max_pool(x) # B, 64, 12, 12
        x = self.flatten(x) # B, 9216
        x = self.fc1(x) # B, 2048
        x = self.fc2(x) # B, embed_n
        # Normalize using L2 norm
        x = F.normalize(x, p=2, dim=1)
        return x

model = Encoder()
model = model.to(device)
optim = AdamW(model.parameters(), lr=lr)
ce_loss = nn.CrossEntropyLoss()

train_dl = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_dl  = torch.utils.data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True)
margin = torch.tensor(0.5, device=device)

for batch_idx, (data, labels) in enumerate(train_dl):
    t0 = time.time()
    data = data.to(device)
    data_map = defaultdict(list)
    for idx, class_idx in enumerate(labels.tolist()):
        data_map[class_idx].append(idx)
    embeddings = model(data)
    optim.zero_grad()
    triplets_idx = []
    for class_idx in data_map:
        n = len(data_map[class_idx])
        neg_idx_list = [l for l in labels.tolist() if l not in data_map[class_idx]]
        for i in range(n):
            anc_idx = data_map[class_idx][i]
            for j in range(i+1, n):
                pos_idx = data_map[class_idx][j]
                for neg_idx in neg_idx_list:
                    triplets_idx.append(torch.tensor([anc_idx, pos_idx, neg_idx]))
    triplets_tensor = embeddings[torch.stack(triplets_idx, dim=0)]
    loss = torch.relu(torch.multiply(triplets_tensor[:, 0, :], triplets_tensor[:, 2, :]).sum(dim=1) - torch.multiply(triplets_tensor[:, 0, :], triplets_tensor[:, 1, :]).sum(dim=1) + margin).mean()
    #print(f"loss: {loss.item()}")
    loss.backward()
    optim.step()
    print(f"batch_idx: {batch_idx:2d}  loss: {loss.item():.3f} time : {(time.time() - t0):.2f}s")
    if batch_idx == 50:
        break

batch_idx:  0  loss: 0.389 time : 11.59s
batch_idx:  1  loss: 0.442 time : 10.72s
batch_idx:  2  loss: 0.304 time : 11.49s
batch_idx:  3  loss: 0.279 time : 10.97s
batch_idx:  4  loss: 0.274 time : 11.14s
batch_idx:  5  loss: 0.241 time : 11.86s
batch_idx:  6  loss: 0.226 time : 10.73s
batch_idx:  7  loss: 0.212 time : 11.16s
batch_idx:  8  loss: 0.210 time : 10.72s
batch_idx:  9  loss: 0.224 time : 10.98s
batch_idx: 10  loss: 0.191 time : 10.72s
batch_idx: 11  loss: 0.196 time : 11.26s
batch_idx: 12  loss: 0.182 time : 11.16s
batch_idx: 13  loss: 0.179 time : 10.91s
batch_idx: 14  loss: 0.170 time : 11.16s
batch_idx: 15  loss: 0.158 time : 10.81s
batch_idx: 16  loss: 0.179 time : 10.62s
batch_idx: 17  loss: 0.136 time : 10.57s
batch_idx: 18  loss: 0.127 time : 11.07s
batch_idx: 19  loss: 0.139 time : 9.97s
batch_idx: 20  loss: 0.109 time : 11.39s
batch_idx: 21  loss: 0.142 time : 10.38s
batch_idx: 22  loss: 0.120 time : 10.84s
batch_idx: 23  loss: 0.122 time : 11.05s
batch_idx: 24  lo

In [2]:
model.eval()

Encoder(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=9216, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=128, bias=True)
)

In [4]:
model.compile()

In [5]:
model.to("cpu")

Encoder(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=9216, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=128, bias=True)
)

In [7]:
train_embeddings_arr = []
train_labels_arr = []
for idx, (data, labels) in enumerate(train_dl):
    embeddings = model(data)
    train_embeddings_arr.extend(embeddings)
    train_labels_arr.extend(labels)
    del data
    del embeddings

In [8]:
train_emebddings_tensors = torch.stack(train_embeddings_arr)

In [9]:
train_emebddings_tensors.size()

torch.Size([60000, 128])

In [None]:
train_embeddings_arr = []
train_labels_arr = []
i = 0
for idx, (data, labels) in enumerate(train_dl):
    data = data.to(device)
    embeddings = model(data)
    train_embeddings_arr.extend(embeddings)
    train_labels_arr.extend(labels)
    if idx == 50:
        break