In [1]:
# Seed every random number generator used here so that runs are reproducible.

import random

import numpy as np
import torch

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from the current PyTorch seed.

    Passed as ``worker_init_fn`` to the ``DataLoader`` instances built in
    ``train``.

    Args:
        worker_id: Worker index supplied by the ``DataLoader``; not used.
    """
    # Reduce the torch seed modulo 2**32 before handing it to NumPy/random.
    worker_seed = torch.initial_seed() % 2**32
    # NumPy is bound as ``np`` in this notebook; the bare name ``numpy`` is
    # never imported and would raise NameError here.
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [2]:
from dataset import YouCookII
from dataset import YouCookIICollate
from torch.utils.data import DataLoader
from loss import *
from accuracy import *
from transformers import get_linear_schedule_with_warmup
from model import Model
from torch import nn

import numpy as np
import torch
import matplotlib.pyplot as plt

def train(model, num_actions_train=4, batch_size=4, epochs=25, lr=0.001, MAX_DETECTIONS=20):
    """Train ``model`` and record loss/accuracy on the train and validation sets.

    Builds the training and validation ``DataLoader`` objects, optimises the
    model with Adam under a linear warm-up schedule, prints one summary line
    per epoch, and finally plots the loss and accuracy curves.

    Args:
        model: Model to train. It is called as
            ``model(steps, features, bboxes, entities, entity_count)`` and must
            return a 3-tuple whose first item is accepted by
            ``compute_loss_batched``.
        num_actions_train: Passed to ``YouCookII`` for the training set and
            used to pick the ``ycii_<n>`` dataset directory.
        batch_size: Batch size for every ``DataLoader``.
        epochs: Number of training epochs.
        lr: Learning rate handed to Adam.
        MAX_DETECTIONS: Forwarded to ``YouCookIICollate``.

    Returns:
        Tuple ``(train_loss, valid_loss, train_accuracy, valid_accuracy, VG,
        loss_data, input_data)``. The first four are NumPy arrays of length
        ``epochs``. ``VG``, ``loss_data`` and ``input_data`` come from the last
        training batch processed, or are ``None`` if no batch was processed.
    """
    train_datasets = [YouCookII(num_actions_train, "/h/sagar/ece496-capstone/datasets/ycii_{}".format(num_actions_train))]

    # Validation set defaults to test set for now for diagnosing.
    num_actions_valid = [4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 25, 27]
    valid_datasets = [YouCookII(num_action, "/h/sagar/ece496-capstone/datasets/fi") for num_action in num_actions_valid]

    train_size = sum(len(train_dataset) for train_dataset in train_datasets)
    valid_size = sum(len(valid_dataset) for valid_dataset in valid_datasets)

    print("Training Dataset Size: {}, Validation Dataset Size: {}".format(train_size, valid_size))
    print("Effective Batch Size: {} * {} = {}".format(num_actions_train, batch_size, num_actions_train * batch_size))
    print("Learning Rate: {}, Epochs: {}".format(lr, epochs))

    collate = YouCookIICollate(MAX_DETECTIONS=MAX_DETECTIONS)

    # NOTE(review): the training loaders use shuffle=False; confirm that a fixed
    # batch order is intended rather than left over from debugging.
    train_dataloaders = [DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate, drop_last=True, worker_init_fn=seed_worker)
                         for train_dataset in train_datasets]
    valid_dataloaders = [DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate, drop_last=False, worker_init_fn=seed_worker)
                         for valid_dataset in valid_datasets]

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # The scheduler is stepped once per epoch below, so warm-up (20% of the run)
    # and decay are measured in epochs, not batches.
    # NOTE(review): presumably the warm-up multiplier starts at 0, which would
    # make the first epoch train with lr == 0 — confirm against the scheduler docs.
    scheduler = get_linear_schedule_with_warmup(optimizer, int(0.2*epochs), epochs)

    train_loss = np.zeros(epochs)
    valid_loss = np.zeros(epochs)

    train_accuracy = np.zeros(epochs)
    valid_accuracy = np.zeros(epochs)

    # Bound up front so the return statement is valid even when no training
    # batch is ever produced (epochs == 0, or every loader is empty).
    VG = loss_data = input_data = None

    for epoch in range(epochs):
        model.train()

        epoch_loss = 0.
        datapoints = 0

        for train_dataloader in train_dataloaders:
            for input_data in train_dataloader:
                _, bboxes, features, actions, steps, entities, entity_count, _ = input_data

                # Zero out any gradients.
                optimizer.zero_grad()

                # Run inference (forward pass).
                loss_data, VG, _ = model(steps, features, bboxes, entities, entity_count)

                # Loss from alignment.
                loss_ = compute_loss_batched(loss_data)

                # Backpropagation (backward pass).
                loss_.backward()

                # Update parameters.
                optimizer.step()

                # Accumulate a plain Python float: summing the tensor itself
                # would keep every batch's autograd history alive.
                epoch_loss += loss_.item()
                datapoints += len(steps) * len(actions[0])

        # Scheduler update.
        scheduler.step()

        # Save loss and accuracy at each epoch and plot.
        # An epoch with no batches has no defined mean loss; record NaN
        # instead of dividing by zero.
        train_loss[epoch] = epoch_loss / datapoints if datapoints else float("nan")
        train_accuracy[epoch] = get_alignment_accuracy(model, train_dataloaders)

        valid_loss[epoch] = get_alignment_loss(model, valid_dataloaders)
        valid_accuracy[epoch] = get_alignment_accuracy(model, valid_dataloaders)

        print("Epoch {} - Train Loss: {:.2f}, Validation Loss: {:.2f}, Train Accuracy: {:.2f}, Validation Accuracy: {:.2f}"
              .format(epoch + 1, train_loss[epoch], valid_loss[epoch], train_accuracy[epoch], valid_accuracy[epoch]))

    plt.figure()
    plt.plot(train_loss, label='train loss')
    plt.plot(valid_loss, label='valid loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Loss per epoch')
    plt.legend()

    plt.figure()
    plt.plot(train_accuracy, label='train accuracy')
    plt.plot(valid_accuracy, label='valid accuracy')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title('Accuracy per epoch')
    plt.legend()

    plt.show()

    return train_loss, valid_loss, train_accuracy, valid_accuracy, VG, loss_data, input_data

In [3]:
# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = Model(device)

In [None]:
# Trainer.

# Hyperparameters for this run.
num_actions_train, batch_size = 4, 8
epochs, lr = 50, 1e-5

# Launch training and keep the per-epoch curves plus the final-batch artefacts.
(train_loss, valid_loss, train_accuracy, valid_accuracy,
 VG, loss_data, input_data) = train(
    model, num_actions_train=num_actions_train, batch_size=batch_size, epochs=epochs, lr=lr
)

Training Dataset Size: 813, Validation Dataset Size: 54
Effective Batch Size: 4 * 8 = 32
Learning Rate: 1e-05, Epochs: 50
Epoch 1 - Train Loss: 63.62, Validation Loss: 324.58, Train Accuracy: 0.49, Validation Accuracy: 0.52
Epoch 2 - Train Loss: 62.10, Validation Loss: 305.97, Train Accuracy: 0.51, Validation Accuracy: 0.51
Epoch 3 - Train Loss: 61.45, Validation Loss: 304.60, Train Accuracy: 0.52, Validation Accuracy: 0.51
Epoch 4 - Train Loss: 60.17, Validation Loss: 299.99, Train Accuracy: 0.52, Validation Accuracy: 0.51
Epoch 5 - Train Loss: 59.36, Validation Loss: 301.89, Train Accuracy: 0.53, Validation Accuracy: 0.51
Epoch 6 - Train Loss: 58.80, Validation Loss: 297.70, Train Accuracy: 0.54, Validation Accuracy: 0.52
Epoch 7 - Train Loss: 57.12, Validation Loss: 294.78, Train Accuracy: 0.55, Validation Accuracy: 0.51


In [None]:
# Evaluation.

from eval_fi import eval_all_dataset

# Evaluate the current model on everything under the FI dataset directory.
eval_all_dataset(
    model,
    path="/h/sagar/ece496-capstone/datasets/fi",
)

In [None]:
# Visualizer.

from visualizer import inference

# Dataset root directories; only FI is used by the call below.
FI = "/h/sagar/ece496-capstone/datasets/fi"
YCII = "/h/sagar/ece496-capstone/datasets/ycii"

# NOTE(review): 27 and 0 are presumably the action count and the sample index
# within that dataset — confirm against visualizer.inference.
VG, RR = inference(model, 27, 0, FI)

In [None]:
# Saving and loading weights.

# Toggle which action this cell performs when run.
SAVE = True
LOAD = False

if SAVE:
    torch.save(model.state_dict(), "/h/sagar/ece496-capstone/weights/t3")

if LOAD:
    # Remap stored tensors onto the device chosen for this session; without
    # map_location, a checkpoint written on a CUDA machine cannot be loaded
    # when the notebook has fallen back to the CPU.
    model.load_state_dict(torch.load("/h/sagar/ece496-capstone/weights/t1", map_location=device))

In [None]:
# Reload modules.

import importlib
import visualizer
import eval_fi
import model as mdl
import loss

# Only the project-local modules are reloaded. torch is deliberately left out:
# it is a large C-extension package, and the importlib documentation warns that
# such modules are not designed to be initialised twice and may fail when reloaded.
#
# Names that were bound earlier with ``from <module> import ...`` keep pointing
# at the old objects after a reload; re-run those import cells to pick up edits.
importlib.reload(visualizer)
importlib.reload(eval_fi)
importlib.reload(mdl)
importlib.reload(loss)