In [1]:
%load_ext autoreload
%autoreload 2

In [22]:
from PIL import Image
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchvision.datasets import KMNIST
from torchvision.models import resnet18
from einops import rearrange
from einops.layers.torch import Rearrange

# Data

**TODO:** add more classes (the label set currently comes only from `string.ascii_letters`).

In [23]:
import numpy as np
import loaders 
from torch.utils.data import DataLoader
import torch
from generator import Generator, fonts
import string
from augmentation import aug

In [24]:
# Pixel threshold; only referenced by commented-out binarisation experiments.
threshold = 50

# Preprocessing shared by the train and test datasets.
# TODO: normalize with the dataset mean and std.
transforms = T.Compose(
    [
        T.Resize((64, 64)),  # fixed 64x64 input resolution
        T.ToTensor(),        # PIL image -> float tensor
        # Disabled alternatives:
        # Rearrange("h w -> () h w"),
        # T.Lambda(lambda x: (x < threshold).float().mean(0, keepdim=True)),
    ]
)

In [25]:
# Character set to classify: lowercase followed by uppercase ASCII letters.
chars = string.ascii_lowercase + string.ascii_uppercase

In [26]:
# The training generator receives the `aug` augmentation callable; the test
# generator is built without it (NOTE(review): presumably producing
# un-augmented samples — confirm in generator.Generator).
train_gen, test_gen = Generator(chars, fonts, aug), Generator(chars, fonts)

In [27]:
# Wrap each generator in a dataset; both splits share the same `transforms`.
# NOTE(review): the second argument appears to set the number of samples per
# split (60k train / 10k test) — confirm in loaders.FontDs.
train_ds, test_ds = (
    loaders.FontDs(train_gen, 60_000, transforms),
    loaders.FontDs(test_gen, 10_000, transforms),
)

In [28]:
# Alternative dataset kept for reference (KMNIST-10 via `loaders`):
# train_ds = loaders.KMNIST10("data/10", tfms=transforms)
# test_ds = loaders.KMNIST10("data/10", train=False, tfms=transforms)

BATCH_SIZE = 32
NUM_WORKERS = 8

# Shuffle the training set every epoch so SGD does not see samples in a fixed
# index order. Test order does not affect the metrics, so it stays unshuffled.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [29]:
# Disabled visualisation: for each label, draws up to 40 thresholded images as a
# grid and stacks the per-label grids into one preview image.
# NOTE(review): relies on `train_ds.data` / `train_ds.targets`, which presumably
# exist only on the commented-out KMNIST10 datasets, not on `loaders.FontDs` — confirm
# before re-enabling.
# a = T.ToPILImage()(make_grid((train_ds.data[:40].unsqueeze(1) < threshold) * 255, pad_value=125))
# a

# imgs = []
# for label in train_ds.targets.unique():
#     data = train_ds.data[train_ds.targets == label][:40]
#     img = make_grid((data.unsqueeze(1) < threshold) * 255, pad_value=125)
#     imgs.append(img)

# T.ToPILImage()(torch.cat(imgs,1))

# Model

In [30]:
from models import create_model

In [31]:
# Use the GPU when one is available; fall back to CPU so the notebook still runs
# on machines without CUDA instead of failing at `model.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
# One output class per entry in the training dataset's character -> index mapping.
n_classes = len(train_ds.char2idx)

In [32]:
# Build the classifier using the project's `models.create_model` factory, with one output per class.
# NOTE(review): the leading `3` is an unexplained positional argument — presumably
# the number of input channels produced by `transforms`; confirm against create_model.
model = create_model(3, n_classes)

In [33]:
# Move the parameters to the selected device. nn.Module.to works in place and
# returns the module itself, so rebinding also keeps the cell from printing a repr.
model = model.to(device)

# Loss and Optimizer



In [34]:
import torch.nn.functional as F
from torch.optim import SGD

In [35]:
# Optimizer: plain SGD (default: no momentum, no weight decay) over all model
# parameters at a small fixed learning rate.
lr = 0.0001
opt = SGD(params=model.parameters(), lr=lr)

In [36]:
# Cross-entropy on the model's raw outputs (log-softmax is applied inside the loss).
loss_fn = torch.nn.functional.cross_entropy

# Training Loop

In [43]:
from torch.utils.tensorboard import FileWriter, SummaryWriter 
from tqdm.auto import tqdm
from sklearn.metrics import classification_report, accuracy_score
from pathlib import Path

In [38]:
# TensorBoard writer; log_dir=None keeps SummaryWriter's default run directory.
summary = SummaryWriter(log_dir=None)

In [39]:
# Global epoch counter; the training cell increments it, so re-running that cell continues the count.
epoch: int = 0

In [40]:
# Best test accuracy seen so far; the training cell compares each epoch's test accuracy against it before saving.
best_acc: float = 0.0

In [46]:
# Directory (relative to the working directory) where model weights are written.
checkpoints = Path().joinpath("checkpoints")

In [None]:
# Train for N_EPOCHS more epochs, evaluating on the test set after each one.
# `epoch` and `best_acc` are initialised in earlier cells, so re-running this
# cell continues training (and checkpointing) from where it left off.
N_EPOCHS = 5

for _ in range(N_EPOCHS):
    epoch += 1

    # --- Training ---
    # train() enables training-time behaviour of layers such as dropout and
    # batch-norm; otherwise the eval() call below would leak into later epochs.
    model.train()
    train_loss = 0.0
    preds = []
    ans = []

    for data in tqdm(train_dl, desc=f"Training {epoch}"):
        x, y = [t.to(device) for t in data]
        opt.zero_grad()
        y_hat = model(x)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        # loss_fn averages over the batch. Re-weight by the batch size so the
        # epoch loss below is a true per-sample mean (the last batch may be smaller).
        train_loss += loss.item() * y.size(0)
        preds.extend(y_hat.argmax(1).tolist())
        ans.extend(y.tolist())

    train_loss /= len(ans)
    train_acc = accuracy_score(ans, preds)  # sklearn order: (y_true, y_pred)
    # Printed rather than set on the progress bar: tqdm ignores
    # set_description() once the bar has closed at the end of iteration.
    print(f"Training {epoch} | Loss: {train_loss:.3f} | Accuracy {train_acc:.3f}")

    # --- Testing ---
    # eval() disables dropout and makes batch-norm use its running statistics
    # instead of updating them with test batches.
    model.eval()
    test_loss = 0.0
    preds = []
    ans = []

    with torch.no_grad():
        for data in tqdm(test_dl, desc=f"Testing {epoch}"):
            x, y = [t.to(device) for t in data]
            y_hat = model(x)
            # .item() accumulates a Python float instead of a device tensor.
            test_loss += loss_fn(y_hat, y).item() * y.size(0)
            preds.extend(y_hat.argmax(1).tolist())
            ans.extend(y.tolist())

    test_loss /= len(ans)
    test_acc = accuracy_score(ans, preds)
    print(f"Testing {epoch} | Loss: {test_loss:.3f} | Accuracy {test_acc:.3f}")

    # --- TensorBoard ---
    summary.add_scalars("loss", {"train": train_loss, "test": test_loss}, epoch)
    summary.add_scalars("accuracy", {"train": train_acc, "test": test_acc}, epoch)
    summary.flush()

    # --- Checkpointing ---
    # Only overwrite best_model.pth when test accuracy actually improves.
    # best_acc must be updated here, or every epoch would count as a new best.
    if test_acc > best_acc:
        best_acc = test_acc
        checkpoints.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), checkpoints / "best_model.pth")

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




HBox(children=(IntProgress(value=0, max=313), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

# Save Model

In [None]:
# Save the current (last-epoch) weights; the training loop saves the
# best-on-test weights separately as best_model.pth. Create the directory first:
# otherwise it only exists if the loop recorded a new best test accuracy.
checkpoints.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoints / "model.pth")

# Metrics

In [None]:
# Per-class precision/recall/F1 on the test set. Predictions are recomputed here
# in eval mode instead of reusing `preds`/`ans` leaked from the training-loop
# cell, so this cell is self-contained and reflects the current model.
model.eval()
report_preds, report_true = [], []
with torch.no_grad():
    for x, y in test_dl:
        report_preds.extend(model(x.to(device)).argmax(1).tolist())
        report_true.extend(y.tolist())

print(classification_report(report_true, report_preds))