### 1 - Imports

In [None]:
import torch
import os
import time
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from data_loader import VideoFolder
from torchvision.transforms import Compose
from RT3D_16F import FullModel
import transforms as t
import utils
from tensorboardX import SummaryWriter
from IPython.core.display import HTML

%matplotlib inline

In [None]:
import json
# TensorBoard writer; train() logs its loss/accuracy scalars through it.
writer = SummaryWriter()

# Dataset locations (data folders and CSV files) are read from configs.json.
# JSON text is UTF-8 by specification, so pass the encoding explicitly rather
# than relying on the platform's locale-dependent default.
with open('./configs.json', encoding='utf-8') as data_file:
    config = json.load(data_file)

In [None]:
# Folder that holds the .ckp checkpoints. exist_ok=True avoids the
# check-then-create race and makes re-running this cell a no-op.
curr_folder = 'full_net_10'
os.makedirs(curr_folder, exist_ok=True)

In [None]:
# Training hyper-parameters.
batch_size = 14             # clips per batch (FullModel is also built with it)
steps_before_print = 1000   # run validation + logging every N training batches
num_workers = 0             # DataLoader workers; 0 loads in the main process

# Frame sampling: step_size is handed to VideoFolder, and the clip length is
# 32 // step_size frames (16 with the values here).
step_size = 2
num_frames = 32 // step_size

### 2 - Setting up Data Loaders

In [None]:
# Per-channel normalization statistics passed to GroupNormalize.
std = [0.2674, 0.2676, 0.2648]
mean = [0.4377, 0.4047, 0.3925]

# Training-time pipeline built from the project's Group* transforms:
# resize, random crop, random rotation (18), tensor conversion, normalization.
# NOTE(review): the crop size (140, 100) is compared against a resize to
# (100, 160); confirm the (h, w) vs (w, h) conventions used by `transforms`
# so the crop actually fits inside the resized frames.
transform = Compose(
    [
        t.GroupResize((100, 160)),
        t.GroupRandomCrop((140, 100)),
        t.GroupRandomRotation(18),
        t.GroupToTensor(),
        t.GroupNormalize(mean=mean, std=std),
    ]
)

In [None]:
# Evaluation pipeline: the training pipeline without the random rotation.
# NOTE(review): GroupRandomCrop still makes validation crops vary between
# runs; a deterministic (center) crop is the usual choice for evaluation —
# check whether `transforms` provides one.
transform_validation = Compose(
    [
        t.GroupResize((100, 160)),
        t.GroupRandomCrop((140, 100)),
        t.GroupToTensor(),
        t.GroupNormalize(mean=mean, std=std),
    ]
)

In [None]:
# Training split: videos listed in the full training CSV, labels from the
# shared labels CSV, augmented with `transform`.
train_data = VideoFolder(
    root=config['train_data_folder'],
    csv_file_input=config['full_train_data_csv'],
    csv_file_labels=config['full_labels_csv'],
    clip_size=num_frames, nclips=1, step_size=step_size,
    is_val=False,
    transform=transform,
)

In [None]:
# Shuffled training batches. drop_last=True guarantees every batch has exactly
# `batch_size` clips — presumably required because FullModel is constructed
# with a fixed batch size.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=False,  # was changed from the default by the author
    drop_last=True,
)

In [None]:
# Validation split: same root folder as training, videos listed in the
# validation CSV, non-rotating `transform_validation`.
# NOTE(review): is_val=False is passed here exactly as for training — confirm
# whether VideoFolder expects is_val=True for the validation split.
validation_data = VideoFolder(
    root=config['train_data_folder'],
    csv_file_input=config['full_validation_data_csv'],
    csv_file_labels=config['full_labels_csv'],
    clip_size=num_frames, nclips=1, step_size=step_size,
    is_val=False,
    transform=transform_validation,
)

In [None]:
# Validation batches. Shuffling means capped evaluations (stop_at=...) see a
# different subset each time; drop_last keeps batches at `batch_size`.
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=False,  # was changed from the default by the author
    drop_last=True,
)

In [None]:
def save_model(model, use_ts=False, folder=None):
    """Write ``model.state_dict()`` to a ``.ckp`` file with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is saved.
        use_ts: if True, name the file after the current UTC time using the
            format ``%d_%b_%Y_%Hh%Mm``. The resolution is one minute, so two
            timestamped saves within the same minute overwrite each other.
            If False, write ``best_model.ckp``.
        folder: destination directory; defaults to the notebook-level
            ``curr_folder``.
    """
    if folder is None:
        folder = curr_folder
    if use_ts:
        name = time.strftime("%d_%b_%Y_%Hh%Mm", time.gmtime())
    else:
        name = 'best_model'
    torch.save(model.state_dict(), os.path.join(folder, name + '.ckp'))

### 3 - Model definition

In [None]:
# Build the network; FullModel takes the batch size at construction time.
model = FullModel(
    batch_size=batch_size,
)

In [None]:
# Resume from the most recently written checkpoint in `curr_folder`, if any.
# Recency is decided by file modification time. The timestamped names written
# by save_model ("%d_%b_%Y_%Hh%Mm") do not sort chronologically as strings,
# and "best_model.ckp" would always win a plain string comparison.
checkpoint_paths = [
    os.path.join(curr_folder, name)
    for name in os.listdir(curr_folder)
    if name.endswith(".ckp")
]
most_recent_file = max(checkpoint_paths, key=os.path.getmtime, default='')
if most_recent_file:
    print('Model LOADED: ', most_recent_file)
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; load_state_dict copies the tensors onto the model's device.
    loaded_dict = torch.load(most_recent_file, map_location='cpu')
    model.load_state_dict(loaded_dict)
else:
    print('No model loaded.')

In [None]:
# Multi-class classification loss; it expects raw (unnormalized) class scores.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Put the model's parameters on the GPU when one is available; otherwise the
# model stays on the CPU.
if torch.cuda.is_available():
    print('Cuda is available!')
    model = model.cuda()  # Module.cuda() moves in place and returns self

In [None]:
def train(epochs):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Relies on the notebook-level globals ``model``, ``optimizer``,
    ``criterion``, ``train_loader``, ``validation_loader``, ``writer`` and
    ``steps_before_print``.

    Every ``steps_before_print`` batches it:

    * evaluates on ``validation_loader`` via
      ``utils.calculate_loss_and_accuracy`` (``stop_at=1200``);
    * logs the validation loss and accuracy to TensorBoard;
    * calls ``save_model(model)`` whenever the validation loss beats
      ``model.best_valdiation_loss``.

    ``model.epochs`` is incremented after each pass, and an epoch summary
    (average training loss and average of the validation accuracies measured
    during that epoch) is printed.

    Args:
        epochs: number of passes over ``train_loader``.
    """
    print("Trainning is about to start...")
    total_size = len(train_loader)
    use_cuda = torch.cuda.is_available()

    for epoch in range(epochs):
        step = 0
        epoch_loss = 0
        epoch_acc = 0
        times_calculated = 0
        model.train()

        for images, labels in train_loader:
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss_value = loss.item()  # one host sync per step instead of three
            # NOTE(review): model.steps is not advanced in this function; if
            # FullModel does not increment it elsewhere, every scalar is
            # logged at the same TensorBoard step — confirm.
            writer.add_scalar('trainning_loss', loss_value, model.steps)
            loss.backward()
            optimizer.step()

            step += 1
            epoch_loss += loss_value

            if step % steps_before_print == 0:
                model.eval()
                validation_loss, accuracy = utils.calculate_loss_and_accuracy(
                    validation_loader, model, criterion, stop_at=1200)
                writer.add_scalar('validation_loss', validation_loss, model.steps)
                writer.add_scalar('accuracy', accuracy, model.steps)
                epoch_acc += accuracy
                times_calculated += 1
                print('Iteration: {}/{} - ({:.2f}%). Loss: {}. Accuracy: {}'.format(
                    step, total_size, step * 100 / total_size, loss_value, accuracy))
                if validation_loss < model.best_valdiation_loss:
                    model.best_valdiation_loss = validation_loss
                    print('Saving best model')
                    save_model(model)
                del validation_loss
                # Validation switched the model to eval mode; resume training.
                model.train()
            # Drop batch references so they can be freed before the next batch.
            del loss, outputs, images, labels

        model.epochs += 1

        # Guarded averages: an epoch shorter than steps_before_print batches
        # never validates, and an empty loader runs zero steps.
        avg_loss = epoch_loss / step if step else float('nan')
        avg_acc = epoch_acc / times_calculated if times_calculated else float('nan')
        print('Epoch({}) avg loss: {} avg acc: {}'.format(epoch, avg_loss, avg_acc))

In [None]:
# Initial learning rate: 1e-3 (0.001). SGD with momentum 0.9 over all model
# parameters; train() reads this module-level `optimizer`.
learning_rate = 0.001
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

In [None]:
# Step-decay schedule: four stages of 50, 50, 50 and 20 epochs. The first
# stage uses the optimizer already built above. Before each later stage the
# learning rate is divided by 10 and a fresh SGD optimizer is created (its
# momentum buffers start from zero). A timestamped checkpoint is written
# after every stage.
for stage_index, stage_epochs in enumerate((50, 50, 50, 20)):
    if stage_index > 0:
        learning_rate = learning_rate / 10
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    train(stage_epochs)
    save_model(model, use_ts=True)

In [None]:
# Save the current weights under a UTC-timestamped file name. The timestamp
# has minute resolution, so a save within the same minute as a previous
# timestamped save overwrites that file.
save_model(model=model, use_ts=True)

In [None]:
# Evaluate every saved checkpoint on the validation and training loaders.
# The 4th argument (1500) is passed as the cap given to
# utils.calculate_loss_and_accuracy — presumably its stop_at limit; confirm.
# Files are visited in sorted order so the output is reproducible. Afterwards
# `model` holds the weights of the last checkpoint visited.
# To load a checkpoint whose 'combiner' weights don't match, filter those keys
# out of the state dict and call model.load_state_dict(..., strict=False).
for file in sorted(os.listdir(curr_folder)):
    if not file.endswith(".ckp"):
        continue
    checkpoint_path = os.path.join(curr_folder, file)
    print(file)
    print('Model LOADED: ', checkpoint_path)
    # map_location='cpu' keeps GPU-saved checkpoints loadable on CPU-only hosts;
    # load_state_dict copies the tensors onto the model's own device.
    loaded_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(loaded_dict)
    model.eval()
    validation_loss, accuracy = utils.calculate_loss_and_accuracy(validation_loader, model, criterion, 1500)
    # Separate name: previously the train loss overwrote `validation_loss`.
    train_loss, train_accuracy = utils.calculate_loss_and_accuracy(train_loader, model, criterion, 1500)
    print('Validation Acc: {} \t Train Acc: {}'.format(accuracy, train_accuracy))
    print('Validation Loss: {} \t Train Loss: {}'.format(validation_loss, train_loss))
