# Training notebook/script 2
#### Starting with some imports

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
import cv2 as cv
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
from torchvision.utils import make_grid
from sklearn.metrics import f1_score
from tqdm.notebook import tqdm
%matplotlib inline

#### Defining directories and instantiating the main dataframe

In [None]:
# Location of the RGB training images and the CSV index that describes them.
root_dir = "./training/rgb/"
csv_file = os.path.join(root_dir, "training_xy.csv")
# Keep only the first 32560 rows (the same limit FreihandDataset applies).
main_df = pd.read_csv(csv_file, header=None).head(32560)

## The Dataset
#### Defining the blueprint of the dataset

In [None]:
class FreihandDataset(Dataset):
    """Image/keypoint dataset backed by a headerless CSV index file.

    Each CSV row holds an image file name in column 0 (relative to
    ``root_dir``) followed by that image's keypoint values in the remaining
    columns.

    Args:
        root_dir: Directory that the file names in column 0 are relative to.
        csv_file: Path to the headerless CSV index.
        transforms: Optional callable applied to each PIL image
            (e.g. ``torchvision.transforms.ToTensor()``).
        max_rows: Use only the first ``max_rows`` CSV rows. Defaults to
            32560, the limit this dataset has always applied; pass ``None``
            to use every row.
    """

    def __init__(self, root_dir, csv_file, transforms=None, max_rows=32560):
        df = pd.read_csv(csv_file, header=None)
        self.main_df = df if max_rows is None else df.iloc[:max_rows, :]
        self.root_dir = root_dir
        self.transforms = transforms

    def __len__(self):
        return len(self.main_df)

    def __getitem__(self, idx):
        """Return ``(image, keypoints)`` for row ``idx`` of the index.

        ``image`` is the PIL image, or whatever ``transforms`` returns for it.
        ``keypoints`` is a 1-D tensor built from the row's remaining columns.
        """
        img_path = os.path.join(self.root_dir, self.main_df.iloc[idx, 0])
        # Use a context manager so the file handle is released immediately;
        # copy() forces the pixel data to be read before the file is closed.
        with Image.open(img_path) as img:
            image = img.copy()
        if self.transforms:
            image = self.transforms(image)
        # NOTE(review): astype("float") yields float64 keypoints; confirm the
        # model's loss casts them if its outputs are float32.
        keypoints = torch.from_numpy(
            self.main_df.iloc[idx, 1:].astype("float").to_numpy()
        )
        return image, keypoints

#### Initializing an instance of the dataset

In [None]:
# Full dataset; ToTensor converts each PIL image into a (C, H, W) tensor.
dataset = FreihandDataset(root_dir, csv_file, transforms=T.ToTensor())
len(dataset)

## Splitting the dataset
#### Defining the training and validation set sizes

In [None]:
# Hold out 20% of the samples for validation.
val_set_len = int(0.2 * len(dataset))
train_set_len = len(dataset) - val_set_len
# Seed the split so it is reproducible: a checkpoint reloaded in a later
# session is then validated on the same held-out images, not ones it trained on.
train_set, val_set = random_split(
    dataset,
    [train_set_len, val_set_len],
    generator=torch.Generator().manual_seed(42),
)
# Last expression of the cell, so the notebook actually displays the sizes.
train_set_len, val_set_len

## Creating the DataLoaders
#### Using torch.utils.data.DataLoader

In [None]:
batch_size = 32
# Pinned host memory only speeds up host-to-GPU copies; without CUDA it has no
# benefit (and may emit a warning), so enable it only when a GPU is present.
train_dl = DataLoader(train_set, batch_size, shuffle=True, num_workers=0,
                      pin_memory=torch.cuda.is_available())
val_dl = DataLoader(val_set, batch_size, num_workers=0,
                    pin_memory=torch.cuda.is_available())

# Sanity check: fetch one batch and report the shape of each component
# rather than dumping every tensor value.
batch = next(iter(train_dl))
print([t.shape for t in batch])

## Defining utility functions
#### Function to show the image with keypoints

In [None]:
def show_keypoints(img_name, df=None, image_dir=None):
    """Plot image ``img_name`` with its annotated keypoints overlaid.

    The current matplotlib figure is cleared first. The keypoints come from
    the row of ``df`` whose column 0 equals ``img_name``; its remaining columns
    are reshaped into (x, y) pairs.

    Args:
        img_name: File name as stored in column 0 of ``df``.
        df: Keypoint table; defaults to the module-level ``main_df``.
        image_dir: Directory containing the images; defaults to the
            module-level ``root_dir``.

    Raises:
        KeyError: if ``img_name`` has no row in ``df``.
    """
    df = main_df if df is None else df
    image_dir = root_dir if image_dir is None else image_dir
    plt.clf()
    matches = df[df[0] == img_name]
    if matches.empty:
        raise KeyError(f"no keypoints found for image {img_name!r}")
    keypoints = matches.iloc[0, 1:].astype("float").to_numpy().reshape(-1, 2)
    # copy() loads the pixels so the file handle can be closed right away.
    with Image.open(os.path.join(image_dir, img_name)) as img:
        image = img.copy()
    plt.imshow(image)
    plt.scatter(keypoints[:, 0], keypoints[:, 1], s=20, marker='.', c='m')

In [None]:
# Spot-check the annotations on one arbitrary sample (row 321).
show_keypoints(img_name=main_df.iloc[321, 0])

#### Function to get the default device

In [None]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(name)

In [None]:
# Resolve the compute device once; later cells move the data and model to it.
device = get_default_device()
device

#### Function to move data to the GPU

In [None]:
def to_device(data, device):
    """Move ``data`` to ``device``, recursing into containers.

    Lists and tuples are returned as lists, as before. Dicts are returned as
    dicts with the same keys. Any other object must provide
    ``.to(device, non_blocking=...)``, as tensors and modules do.

    ``non_blocking=True`` only makes the copy asynchronous when the source
    tensor lives in pinned memory; otherwise it behaves like a normal copy.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    return data.to(device, non_blocking=True)


## Device Dataloaders
#### The class defined below wraps a DataLoader so that every batch it yields is moved to a given device (the GPU when one is available, otherwise the CPU)

In [None]:
class DeviceDataLoader:
    """Iterable wrapper that moves each batch of ``dl`` to ``device``.

    Batches are transferred lazily, one at a time as iteration proceeds,
    via ``to_device``; the wrapped loader itself is left untouched.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Number of batches, as reported by the wrapped loader."""
        return len(self.dl)

#### Instantiating device Dataloaders

In [None]:
# Wrap both loaders so every batch arrives already on `device`. Re-running
# this cell wraps them a second time; re-run the DataLoader cell first.
train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)

## The Model
#### Importing the required architecture (`Net2` from the local `models` module, reportedly resnet34-based) and creating a model instance

In [None]:
# Project-local `models` module, imported under the alias `archs`; the name
# `models` itself still refers to torchvision.models from the imports cell.
import models as archs

model = archs.Net2()
model

#### Uncomment the cell below to load weights from a previously saved checkpoint (e.g. `last.pth`).

In [None]:
# Optional: resume from saved weights, e.g. "last.pth" written by the save cell.
# map_location=device lets a checkpoint saved on a GPU load on a CPU-only machine.
# pth_filename=???
# model.load_state_dict(torch.load(pth_filename, map_location=device))

#### Moving the model to the default device (the GPU when one is available, otherwise the CPU)

In [None]:
# Presumably archs.Net2 is an nn.Module, whose .to() moves parameters in place
# and returns the module itself (shown as this cell's output) — confirm.
model.to(device)

## Helper functions to train and evaluate the model
#### Function to run validation passes

In [None]:
def validate(model, val_dl):
    """Run one gradient-free pass over ``val_dl`` and aggregate the results.

    Switches ``model`` to eval mode, collects ``model.validation_step(batch)``
    for every batch, and returns ``model.validation_per_epoch`` of that list.
    """
    model.eval()
    with torch.no_grad():
        step_outputs = []
        for batch in val_dl:
            step_outputs.append(model.validation_step(batch))
        return model.validation_per_epoch(step_outputs)

#### Function to retrieve current learning rate

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group (None if there are none)."""
    return next((group['lr'] for group in optimizer.param_groups), None)

#### Function to train the model for a specified number of epochs

In [None]:
def train(epochs, max_lr, model, train_dl, val_dl, weight_decay=0, opt_func=torch.optim.SGD, print_loss=True):
    """Train ``model`` for ``epochs`` epochs under a one-cycle LR schedule.

    Every call builds a fresh optimizer and ``OneCycleLR`` scheduler, so
    repeated calls each run one complete learning-rate cycle.

    Args:
        epochs: Number of passes over ``train_dl``.
        max_lr: Peak learning rate of the one-cycle schedule.
        model: Object exposing ``train()``, ``parameters()`` and
            ``training_step(batch)`` (returning a scalar loss tensor), plus
            whatever ``validate`` requires.
        train_dl: Iterable of training batches that supports ``len()``.
        val_dl: Validation batches, passed to ``validate``.
        weight_decay: Passed through to ``opt_func``.
        opt_func: Optimizer class, constructed as
            ``opt_func(params, lr=max_lr, weight_decay=weight_decay)``.
        print_loss: When True, print the epoch's training and validation loss.

    Returns:
        A list with one entry per epoch: the result of ``validate`` extended
        with ``"train_loss"`` (mean training loss over the epoch) and
        ``"lrs"`` (the learning rate in effect at each step).
    """
    torch.cuda.empty_cache()
    history = []
    optimizer = opt_func(model.parameters(), lr=max_lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, epochs=epochs, steps_per_epoch=len(train_dl)
    )

    for epoch in range(epochs):
        model.train()

        train_losses = []
        lrs = []
        for batch in tqdm(train_dl):
            loss = model.training_step(batch)
            # Store a detached copy: keeping the graph-attached tensor would
            # keep every step's autograd graph referenced until the epoch ends.
            train_losses.append(loss.detach())

            # Clear gradients before backward so stale ones (e.g. left by an
            # interrupted earlier run) can never leak into the first step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lrs.append(get_lr(optimizer))
            sched.step()

        res = validate(model, val_dl)
        res["train_loss"] = torch.stack(train_losses).mean().item()
        res["lrs"] = lrs

        if print_loss:
            print(f"Epoch[{epoch}] -> Training Loss: {res['train_loss']}, Validation Loss: {res['val_loss']}")

        history.append(res)
    return history

## Training the model
#### Start off by testing how long a single epoch takes to train

In [None]:
# One short run first, to measure how long a single epoch takes.
history = train(epochs=1, max_lr=0.001, model=model, train_dl=train_dl, val_dl=val_dl)

#### Add more cells and train for more epochs with varying learning rates

In [None]:
# Append further runs to `history`; each train() call creates a new optimizer
# and restarts the one-cycle LR schedule from the beginning.
# history+=train(???, ???, model, train_dl, val_dl)

#### After training, save the model to disk, to prevent retraining from scratch

In [None]:
# Persist only the weights (state_dict); restore them with model.load_state_dict.
torch.save(obj=model.state_dict(), f="last.pth")
