In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import SubsetRandomSampler, DataLoader
from torchvision import transforms as T

import numpy as np
import matplotlib.pyplot as plt

%cd ../
from src.data.datasets import HandwritingDataset
from src.models import HandwritingClassifier
%cd notebooks/

%matplotlib inline
%load_ext autoreload
%autoreload 2

/home/nazar/Projects/ukrainian_handwriting
/home/nazar/Projects/ukrainian_handwriting/notebooks


In [2]:
# Prefer the first CUDA GPU, but fall back to the CPU so that every later
# `.to(device)` call still works on a machine without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Loading train/validation/test data

In [3]:
# Normalization statistics are read from the classifier class instead of
# being hard-coded in the notebook.
MEAN, STD = HandwritingClassifier._mean, HandwritingClassifier._std

In [4]:
# Training-time preprocessing: two random augmentations, then conversion to a
# tensor and normalization with MEAN/STD.
tf = T.Compose([
    # Rotate by a random angle; a scalar `degrees` means the range (-30, +30).
    T.RandomRotation(degrees=30),
    # No extra rotation here (degrees=0); only a random shift of up to 10% of
    # the image width/height.
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    T.ToTensor(),
    T.Normalize(mean=MEAN, std=STD),
])

In [5]:
# Training split: samples go through the augmenting pipeline `tf`.
train_data = HandwritingDataset('../data/processed/train_data.csv', transforms=tf)

# Test split: deterministic preprocessing only (tensor conversion and
# normalization), no random augmentation.
test_data = HandwritingDataset(
    '../data/processed/test_data.csv',
    transforms=T.Compose([T.ToTensor(), T.Normalize(mean=MEAN, std=STD)]),
)

# Report the size of each split.
print('Number of samples in training data:', len(train_data))
print('Number of samples in test data:', len(test_data))

Number of samples in training data: 1281
Number of samples in test data: 300


In [6]:
BATCH_SIZE = 64
VAL_SIZE = 100

# Seeded shuffle of all positions in `train_data`, so the train/validation
# split is the same on every run.
np.random.seed(42)
indices = list(range(len(train_data)))
np.random.shuffle(indices)

# The first VAL_SIZE shuffled positions become the validation split; the
# remainder is used for training.
val_indices = indices[:VAL_SIZE]
train_indices = indices[VAL_SIZE:]

train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# NOTE(review): the validation loader draws from `train_data`, so validation
# samples also pass through that dataset's transforms — confirm this is intended.
# NOTE(review): `val_loader` and `test_loader` use DataLoader's default batch
# size of 1.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, sampler=train_sampler)
val_loader = DataLoader(train_data, sampler=val_sampler)
test_loader = DataLoader(test_data)

# Functions for training

In [7]:
def compute_accuracy(prediction, ground_truth):
    """Return the fraction of entries of `prediction` equal to `ground_truth`.

    The two arguments are compared element-wise with `==`; the number of
    matching entries is divided by `len(ground_truth)`.
    """
    matches = prediction == ground_truth
    num_correct = torch.sum(matches).item()
    return num_correct / len(ground_truth)


def validate(model, loader, loss_fns=None):
    """Evaluate `model` on every batch yielded by `loader`.

    Each batch is unpacked as ``(x, *targets, _)``: the first element is fed
    to the model, the last element is ignored, and the elements in between
    are the target tensors. Model outputs, targets and loss functions are
    paired up positionally with `zip`.

    Args:
        model: the network to evaluate; it is switched to eval mode.
        loader: iterable of batches to evaluate on.
        loss_fns: sequence of loss callables, one per model output. When
            omitted, the notebook-level `losses` variable is used, so the
            original two-argument call ``validate(model, loader)`` keeps
            working.

    Returns:
        Tuple ``(mean_loss, label_accuracy, is_upper_accuracy)``; each value
        is the mean of the per-batch values.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    if loss_fns is None:
        loss_fns = losses  # legacy behaviour: fall back to the notebook-level global

    model.eval()
    lbl_acc = 0
    is_upp_acc = 0
    loss_acum = 0
    num_batches = 0
    # Evaluation needs no gradients; skipping them saves memory and time.
    with torch.no_grad():
        for x, *y, _ in loader:
            x_gpu = x.to(device)
            # `y` is a list of target tensors, so each one is moved separately.
            y_gpu = [target.to(device) for target in y]

            prediction = model(x_gpu)
            loss_value = sum(loss(out, targ) for loss, out, targ in zip(loss_fns, prediction, y_gpu))

            loss_acum += loss_value.item()
            # NOTE(review): the raw model outputs are compared to the targets
            # with `==`; if the model returns logits/scores they must first be
            # converted to predicted labels — confirm against the model.
            lbl_acc += compute_accuracy(prediction[0], y_gpu[0])
            is_upp_acc += compute_accuracy(prediction[1], y_gpu[1])
            num_batches += 1

    if num_batches == 0:
        raise ValueError('validate() received a loader that yields no batches')
    return loss_acum / num_batches, lbl_acc / num_batches, is_upp_acc / num_batches


def train_model(model, train_loader, val_loader, optimizer, losses, num_epochs, scheduler=None):
    """Train `model` for `num_epochs` epochs, validating after each epoch.

    Args:
        model: the network to train.
        train_loader: iterable of training batches shaped ``(x, *targets, _)``.
        val_loader: iterable of validation batches, passed to `validate`.
        optimizer: optimizer updating the parameters of `model`.
        losses: sequence of loss callables, paired positionally with the
            model outputs and the targets; the total loss is their sum.
        num_epochs: number of passes over `train_loader`.
        scheduler: optional learning-rate scheduler, stepped once per epoch.
            A `ReduceLROnPlateau` scheduler is stepped with the validation
            loss; any other scheduler is stepped without arguments.

    Returns:
        Tuple of four lists with one entry per epoch: training loss,
        validation loss, label accuracy and is-upper accuracy.

    Raises:
        ValueError: if `train_loader` (or `val_loader`) yields no batches.
    """
    t_loss_history = []
    v_loss_history = []
    lbl_acc_history = []
    is_upp_acc_hist = []

    for epoch in range(num_epochs):
        # `validate` leaves the model in eval mode, so re-enable training mode.
        model.train()

        loss_acum = 0
        num_batches = 0
        for x, *y, _ in train_loader:
            x_gpu = x.to(device)
            # `y` is a list of target tensors, so each one is moved separately.
            y_gpu = [target.to(device) for target in y]

            prediction = model(x_gpu)
            loss_value = sum(loss(out, targ) for loss, out, targ in zip(losses, prediction, y_gpu))

            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()

            loss_acum += loss_value.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError('train_model() received a train_loader that yields no batches')
        # Average over the number of batches (not the last batch index).
        epoch_loss = loss_acum / num_batches
        # Validate with the same loss functions that are used for training.
        val_loss, lbl_acc, is_upp_acc = validate(model, val_loader, losses)

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        t_loss_history.append(epoch_loss)
        v_loss_history.append(val_loss)
        lbl_acc_history.append(lbl_acc)
        is_upp_acc_hist.append(is_upp_acc)

        print(f'{epoch + 1}. Loss = {epoch_loss:.6f}; Val loss = {val_loss:.6f}')
        print(f'Label accuracy = {lbl_acc}; Is_upper accuracy = {is_upp_acc}')
    return t_loss_history, v_loss_history, lbl_acc_history, is_upp_acc_hist


def plot_history(t_loss_h, v_loss_h, lbl_acc_h, is_upp_acc_h):
    """Draw the training history as two stacked panels.

    The top panel shows the train and validation loss curves; the bottom
    panel shows the accuracy curves of the two outputs (label and
    is-uppercase). Each argument is a sequence of values to plot.
    """
    _, (loss_ax, acc_ax) = plt.subplots(2, 1, figsize=(15, 7))

    # One entry per panel: (axes, title, ((values, legend label), ...)).
    panels = (
        (loss_ax, 'Train/validation Loss',
         ((t_loss_h, 'Train'), (v_loss_h, 'Validation'))),
        (acc_ax, 'Accuracy for 2 outputs',
         ((lbl_acc_h, 'Label'), (is_upp_acc_h, 'Is uppercase'))),
    )
    for axis, title, curves in panels:
        axis.set_title(title)
        for values, label in curves:
            axis.plot(values, label=label)
        axis.legend()

# Model training