In [1]:
# Re-import edited local modules (loaders, generator, models, loss, ...)
# automatically before each cell executes, so changes to them are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from PIL import Image
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchvision.models import resnet18
from einops import rearrange
from einops.layers.torch import Rearrange

# Data

**TODO:** add more character classes to the dataset.

In [3]:
import numpy as np
import loaders 
from torch.utils.data import DataLoader
import torch
from generator import Generator, fonts
import string
from augmentation import aug

In [4]:
threshold = 50  # only referenced by the disabled binarisation idea noted below

# Resize every glyph to 64x64 and convert it to a float tensor.
# TODO: normalise with the dataset mean and std.
# NOTE(review): a binarisation step `(x < threshold).float().mean(0, keepdim=True)`
# was sketched to run after ToTensor. ToTensor scales uint8 PIL images to
# [0, 1], so a threshold of 50 would mark every pixel; rescale before enabling.
transforms = T.Compose(
    [
        T.Resize(size=(64, 64)),
        T.ToTensor(),
    ]
)

In [5]:
# Class vocabulary: the 52 ASCII letters, lowercase first, then uppercase.
chars = string.ascii_lowercase + string.ascii_uppercase

In [6]:
# Glyph generators over the same characters and fonts. Only the training
# generator receives `aug` (from the augmentation module); the test generator
# omits it, presumably meaning no augmentation — confirm against
# generator.Generator.
train_gen = Generator(chars, fonts, aug)
test_gen = Generator(chars, fonts)

In [7]:
# The second FontDs argument is the dataset length: the progress bars in this
# notebook's output report 1875 / 313 batches of 32, i.e. 60k training and
# 10k test samples.
train_ds = loaders.FontDs(train_gen, 60000, transforms)
test_ds = loaders.FontDs(test_gen, 10000, transforms)

# Alternative dataset:
# train_ds = loaders.KMNIST10("data/10", tfms=transforms)
# test_ds = loaders.KMNIST10("data/10", train=False, tfms=transforms)

# Shuffle the training set so batch composition changes every epoch: online
# triplet mining can only form triplets from samples that share a batch.
# The test loader keeps a fixed order so per-epoch embeddings are comparable.
# NOTE(review): shuffle=True requires a map-style dataset (FontDs exposes
# len(), which the training loop already relies on).
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=8)
test_dl = DataLoader(test_ds, batch_size=32, num_workers=8)

# Model

In [8]:
from models import create_model

In [9]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Passed to create_model as its second argument; presumably the embedding
# dimensionality — confirm against models.create_model.
vector_size = 10

In [10]:
# NOTE(review): the first argument is presumably the number of input channels
# (ToTensor on RGB images yields 3) — confirm against models.create_model.
model = create_model(3, vector_size)

In [11]:
# nn.Module.to() moves the parameters in place and returns the module itself;
# the assignment also keeps the cell from echoing the model repr.
model = model.to(device)

# Loss and Optimizer



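The triplet losses below (batch-all and batch-hard online triplet mining) follow [Olivier Moindrot's triplet-loss post](https://omoindrot.github.io/triplet-loss#a-better-implementation-with-online-triplet-mining).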

In [12]:
from loss import batch_all_triplet_loss, batch_hard_triplet_loss

In [13]:
import torch.nn.functional as F
from torch import nn
from torch.optim import SGD

In [14]:
lr = 0.0001  # learning rate
# Plain SGD (no momentum) over every model parameter.
opt = SGD(params=model.parameters(), lr=lr)

# Loop

In [15]:
from torch.utils.tensorboard import FileWriter, SummaryWriter 
from tqdm.auto import tqdm
from sklearn.metrics import classification_report, accuracy_score
from pathlib import Path

In [16]:
# TensorBoard writer; log_dir=None uses the default runs/<datetime>_<host> directory.
summary = SummaryWriter(log_dir=None)

In [17]:
# Global epoch counter; the training cell increments it, so re-running that
# cell continues the numbering instead of restarting at 1.
epoch: int = 0

In [18]:
# Lowest test fraction seen so far (the training loop treats lower as better).
# Start at +inf so the first evaluated epoch always writes a checkpoint, even
# if its fraction happens to be exactly 1.0.
best_frac = float("inf")

In [19]:
# Directory for the best-model checkpoint (created lazily by the training loop).
checkpoints = Path("checkpoints") / "siamese"

In [20]:
# Triplet loss used for both training and evaluation. It is called as
# loss_fn(labels, embeddings) and must return a (loss, frac) pair; frac is the
# metric used for checkpointing, where lower is treated as better.
# NOTE(review): batch_hard_triplet_loss is also imported; confirm it returns the
# same (loss, frac) pair before swapping it in here.
loss_fn = batch_all_triplet_loss

In [None]:
# Train for 20 more epochs. `epoch` is a global counter, so re-running this
# cell continues numbering (and TensorBoard steps) where it left off.
for _ in range(20):
    epoch += 1

    # ---- Training ----
    # Enable training-mode behaviour (e.g. BatchNorm batch statistics); the
    # evaluation pass below switches it off again with model.eval().
    model.train()
    train_loss = 0.0
    train_frac = 0.0
    pbar = tqdm(train_dl, desc=f"Training {epoch}")

    for n_batches, data in enumerate(pbar, start=1):
        opt.zero_grad()
        x, y = [i.to(device) for i in data]
        embeddings = model(x)
        loss, frac = loss_fn(y, embeddings)
        loss.backward()
        opt.step()
        train_loss += loss.item()
        train_frac += frac.item()
        # Running means; refresh=False leaves redraw timing to tqdm.
        pbar.set_postfix(
            loss=f"{train_loss / n_batches:.3f}",
            frac=f"{train_frac / n_batches:.3f}",
            refresh=False,
        )

    # loss_fn yields one scalar per batch (just like frac), so both are averaged
    # over batches. Dividing the loss by the number of samples would understate
    # it by a factor of the batch size and make it inconsistent with frac.
    train_loss /= max(len(train_dl), 1)
    train_frac /= max(len(train_dl), 1)

    # ---- Testing ----
    model.eval()
    test_loss = 0.0
    test_frac = 0.0
    test_embeddings = []
    test_imgs = []
    pbar = tqdm(test_dl, desc=f"Testing {epoch}")

    with torch.no_grad():
        for n_batches, data in enumerate(pbar, start=1):
            x, y = [i.to(device) for i in data]
            embeddings = model(x)
            loss, frac = loss_fn(y, embeddings)
            test_loss += loss.item()
            test_frac += frac.item()
            test_embeddings.append(embeddings.cpu())
            test_imgs.append(x.cpu())
            pbar.set_postfix(
                loss=f"{test_loss / n_batches:.3f}",
                frac=f"{test_frac / n_batches:.3f}",
                refresh=False,
            )

    test_loss /= max(len(test_dl), 1)
    test_frac /= max(len(test_dl), 1)

    # Send every test embedding, with its input image as a thumbnail, to
    # TensorBoard's embedding projector for this epoch.
    test_embeddings = torch.cat(test_embeddings, 0)
    test_imgs = torch.cat(test_imgs, 0)
    summary.add_embedding(test_embeddings, global_step=epoch, label_img=test_imgs, tag="embeddings")

    # tqdm closes and disables a bar once its loop finishes, so a
    # set_description() call after the loop is never rendered. Print a summary.
    print(
        f"Epoch {epoch} | Train loss: {train_loss:.3f} | Train fraction {train_frac:.3f}"
        f" | Test loss: {test_loss:.3f} | Test fraction {test_frac:.3f}"
    )

    # ---- TensorBoard ----
    summary.add_scalars("siamese_loss", {"train": train_loss, "test": test_loss}, epoch)
    summary.add_scalars("fraction", {"train": train_frac, "test": test_frac}, epoch)
    summary.flush()

    # ---- Checkpointing ----
    # Keep the weights with the lowest test fraction seen so far.
    if test_frac < best_frac:
        best_frac = test_frac
        checkpoints.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), checkpoints / "best_model.pth")

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

# Save Model

In [49]:
# Save the final weights (separate from the best checkpoint in checkpoints/siamese).
# Create the directory first: the training loop only creates it when it writes
# a best-model checkpoint, so it may not exist yet.
final_model_path = Path("checkpoints") / "model.pth"
final_model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), final_model_path)

# Metrics

In [53]:
# Per-class precision / recall / F1 (classification_report(y_true, y_pred)).
# NOTE(review): `ans` and `preds` are not defined in any visible cell. They are
# presumably the true and predicted labels left over from a deleted cell, so
# this cell fails under Restart & Run All until they are recomputed here.
print(classification_report(ans, preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       172
           1       1.00      1.00      1.00       186
           2       1.00      1.00      1.00       179
           3       0.97      1.00      0.98       185
           4       1.00      1.00      1.00       191
           5       1.00      0.99      0.99       190
           6       1.00      0.99      1.00       198
           7       0.95      1.00      0.97       212
           8       0.97      0.99      0.98       200
           9       0.79      0.98      0.87       198
          10       1.00      1.00      1.00       202
          11       0.71      0.73      0.72       195
          12       1.00      0.95      0.97       215
          13       0.99      1.00      0.99       184
          14       1.00      1.00      1.00       185
          15       1.00      0.98      0.99       167
          16       0.94      0.99      0.97       189
          17       1.00    