In [55]:
# Enable IPython's autoreload extension so edits to the project's own modules
# (loaders, generator, augmentation, models) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
from PIL import Image
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchvision.datasets import KMNIST
from torchvision.models import resnet18
from einops import rearrange
from einops.layers.torch import Rearrange

# Data

Need to add classes

In [57]:
import numpy as np
import loaders 
from torch.utils.data import DataLoader
import torch
from generator import Generator, fonts
import string
from augmentation import aug

In [58]:
# Binarization cutoff; currently only referenced by the disabled Lambda step below.
threshold = 50

# TODO: normalize with the dataset mean and std.
transforms = T.Compose(
    [
        T.Resize(size=(64, 64)),
        T.ToTensor(),
        # Disabled steps, kept for reference:
        #   before Resize:  Rearrange("h w -> () h w")
        #   after ToTensor: T.Lambda(lambda x: (x < threshold).float().mean(0, keepdim=True))
    ]
)

In [59]:
# Character set to recognise: a-z followed by A-Z (52 symbols).
chars = string.ascii_lowercase + string.ascii_uppercase

In [60]:
# Both generators share the same characters and fonts; `aug` is passed only
# to the training generator.
train_gen, test_gen = Generator(chars, fonts, aug), Generator(chars, fonts)

In [61]:
# Wrap each generator in a dataset: 60k items for training, 10k for testing,
# both passed the same transform pipeline.
train_ds = loaders.FontDs(train_gen, 60_000, transforms)
test_ds = loaders.FontDs(test_gen, 10_000, transforms)

In [62]:
# Earlier KMNIST experiment, kept for reference:
# train_ds = loaders.KMNIST10("data/10", tfms=transforms)
# test_ds = loaders.KMNIST10("data/10", train=False, tfms=transforms)

# Batch both datasets identically (32 items per batch, 8 worker processes).
# NOTE(review): the training loader is not shuffled — presumably fine if FontDs
# yields randomly generated samples; confirm against loaders.FontDs.
train_dl = DataLoader(dataset=train_ds, batch_size=32, num_workers=8)
test_dl = DataLoader(dataset=test_ds, batch_size=32, num_workers=8)

In [63]:
# a = T.ToPILImage()(make_grid((train_ds.data[:40].unsqueeze(1) < threshold) * 255, pad_value=125))
# a

# imgs = []
# for label in train_ds.targets.unique():
#     data = train_ds.data[train_ds.targets == label][:40]
#     img = make_grid((data.unsqueeze(1) < threshold) * 255, pad_value=125)
#     imgs.append(img)

# T.ToPILImage()(torch.cat(imgs,1))

# Model

In [64]:
from models import create_model

In [65]:
# Prefer the GPU but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# One output class per entry in the dataset's character-to-index mapping.
n_classes = len(train_ds.char2idx)

In [66]:
# Build the network sized for `n_classes` outputs.
# NOTE(review): the leading 3 is presumably the number of input channels —
# confirm against models.create_model.
model = create_model(3, n_classes)

In [67]:
# Move the model to the selected device; the trailing ';' suppresses the cell's output.
model.to(device);

# Loss and Optimizer



In [68]:
import torch.nn.functional as F
from torch.optim import SGD

In [69]:
# Plain SGD (no momentum or weight decay configured) over all model parameters.
lr = 1e-4
opt = SGD(params=model.parameters(), lr=lr)

In [70]:
# Cross-entropy applied directly to the model outputs; the training loop calls
# it as loss_fn(y_hat, y).
loss_fn = F.cross_entropy

# Training Loop

In [71]:
from torch.utils.tensorboard import FileWriter, SummaryWriter 
from tqdm.auto import tqdm
from sklearn.metrics import classification_report, accuracy_score
from pathlib import Path

In [72]:
# TensorBoard writer; the training loop logs per-epoch loss and accuracy through it.
summary = SummaryWriter()

In [73]:
# Global epoch counter. It is incremented inside the training loop, so
# re-running the loop cell continues the count instead of restarting at 1.
epoch = 0

In [74]:
# Best test accuracy seen so far; the training loop saves a checkpoint whenever it is beaten.
best_acc = 0

In [75]:
# Directory for saved model weights (created on demand by the training loop).
checkpoints = Path("checkpoints")

In [None]:
# Train for 20 more epochs. `epoch` and `best_acc` are defined outside this cell,
# so re-running it continues the epoch count and keeps the best score so far.
# After the cell finishes, `preds` / `ans` hold the test-set predictions and
# labels of the last epoch (used by the metrics cell further down).
for _ in range(20):
    epoch += 1

    # ---- Training ----
    # Put layers such as dropout / batch-norm into training behaviour.
    model.train()
    train_loss = 0.0
    preds = []
    ans = []
    pbar = tqdm(train_dl)

    for data in pbar:
        opt.zero_grad()
        x, y = [i.to(device) for i in data]
        y_hat = model(x)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        # loss_fn returns the mean over the batch, so weight it by the batch
        # size to get a true per-sample average at the end of the epoch.
        train_loss += loss.item() * y.size(0)
        preds.extend(y_hat.argmax(1).tolist())
        ans.extend(y.tolist())

    train_acc = accuracy_score(ans, preds)
    train_loss /= max(len(ans), 1)
    pbar.set_description(f"Training {epoch} | Loss: {train_loss:.3f} | Accuracy {train_acc:.3f}")

    # ---- Testing ----
    # Switch to inference behaviour so evaluation does not use batch statistics
    # or dropout, and does not update batch-norm running averages.
    model.eval()
    test_loss = 0.0
    preds = []
    ans = []
    pbar = tqdm(test_dl)

    with torch.no_grad():
        for data in pbar:
            x, y = [i.to(device) for i in data]
            y_hat = model(x)
            # .item() keeps the accumulator a plain Python float (not a
            # device tensor), matching how the training loss is tracked.
            test_loss += loss_fn(y_hat, y).item() * y.size(0)
            preds.extend(y_hat.argmax(1).tolist())
            ans.extend(y.tolist())

    test_acc = accuracy_score(ans, preds)
    test_loss /= max(len(ans), 1)
    pbar.set_description(f"Testing {epoch} | Loss: {test_loss:.3f} | Accuracy {test_acc:.3f}")

    # ---- Tensorboard ----
    summary.add_scalars("loss", {
        "train": train_loss,
        "test": test_loss,
    }, epoch)

    summary.add_scalars("accuracy", {
        "train": train_acc,
        "test": test_acc,
    }, epoch)

    summary.file_writer.flush()

    # ---- Saving the model ----
    # Keep only the weights of the best-scoring epoch on the test set.
    if test_acc > best_acc:
        best_acc = test_acc
        checkpoints.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), checkpoints / "best_model.pth")

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

# Save Model

In [49]:
# Save the weights of the model as it stands after the last epoch (as opposed to
# "best_model.pth", which the training loop writes for the best test accuracy).
# The directory is created first so this cell also works if no checkpoint was
# written during training.
checkpoints = Path("checkpoints")
checkpoints.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoints / "model.pth")

# Metrics

In [53]:
# Per-class precision / recall / F1 for the labels and predictions left over
# from the last evaluation pass of the training loop.
print(classification_report(y_true=ans, y_pred=preds))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       172
           1       1.00      1.00      1.00       186
           2       1.00      1.00      1.00       179
           3       0.97      1.00      0.98       185
           4       1.00      1.00      1.00       191
           5       1.00      0.99      0.99       190
           6       1.00      0.99      1.00       198
           7       0.95      1.00      0.97       212
           8       0.97      0.99      0.98       200
           9       0.79      0.98      0.87       198
          10       1.00      1.00      1.00       202
          11       0.71      0.73      0.72       195
          12       1.00      0.95      0.97       215
          13       0.99      1.00      0.99       184
          14       1.00      1.00      1.00       185
          15       1.00      0.98      0.99       167
          16       0.94      0.99      0.97       189
          17       1.00    