In [1]:
from dataset import YouCookII
from dataset import collate_fn
from torch.utils.data import DataLoader
from loss import loss_RA_MIL
from transformers import get_linear_schedule_with_warmup

import numpy as np
import torch
import matplotlib.pyplot as plt

def train(model, num_actions, batch_size, epochs=25, lr=0.001, y=0.5):
    dataset = YouCookII(num_actions, "/h/sagar/ece496-capstone/datasets/ycii")
    train_size = int(len(dataset) * (2/3))
    valid_size = int(len(dataset) - train_size)
    
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size])
    
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(optimizer, int(0.2*epochs), epochs)

    train_loss = np.zeros(epochs)
    valid_loss = np.zeros(epochs)
    
    model.train()
    
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        for data in train_dataloader:
            _, bboxes_tensor, features_tensor, steps_list, _, entity_count_list, _, _ = data
            batch_size = len(data[0])
            
            # Zero out any gradients.
            optimizer.zero_grad()

            # Run inference (forward pass).
            loss_E, loss_V, loss_R, _, _, _, _, _ = model(batch_size, num_actions + 1, steps_list, features_tensor, bboxes_tensor, entity_count_list)

            # Loss from alignment.
            loss_ = loss_RA_MIL(y, loss_R, loss_E, loss_V)

            # Backpropagation (backward pass).
            loss_.backward()

            # Update parameters.
            optimizer.step()
            
            epoch_loss += loss_
            num_batches += 1
        
        #learning rate schedule
        #update after each epoch
        scheduler.step()
        epoch_loss = epoch_loss / num_batches
        
        # Save loss and accuracy at each epoch, plot (and checkpoint).
        train_loss[epoch] = epoch_loss
        valid_loss[epoch] = get_validation_loss(num_actions, y, valid_dataloader)
        
        #after epoch completes
        print("Epoch {} - Train Loss: {}, Validation Loss: {}".format(epoch, train_loss[epoch], valid_loss[epoch]))
    
    plt.plot(train_loss, valid_loss)
        
    return train_loss, valid_loss

def get_validation_loss(num_actions, y, valid_dataloader):
    epoch_loss = 0.0
    num_batches = 0
        
    with torch.no_grad():
        for data in valid_dataloader:
            _, bboxes_tensor, features_tensor, steps_list, _, entity_count_list, _, _ = data
            batch_size = len(data[0])

            # Run inference (forward pass).
            loss_E, loss_V, loss_R, _, _, _, _, _ = model(batch_size, num_actions + 1, steps_list, features_tensor, bboxes_tensor, entity_count_list)

            # Loss from alignment.
            loss_ = loss_RA_MIL(y, loss_R, loss_E, loss_V)
            
            epoch_loss += loss_
            num_batches += 1
            
    epoch_loss = epoch_loss / num_batches
    
    return epoch_loss

In [None]:
from model import Model

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(device)
train_loss = train(model, 8, 2, epochs=50, lr=1e-4)

Epoch 0 - Train Loss: 1984.5281982421875, Validation Loss: 1601.7069091796875
Epoch 1 - Train Loss: 1756.751220703125, Validation Loss: 1400.0838623046875
Epoch 2 - Train Loss: 1506.93994140625, Validation Loss: 1183.3402099609375
Epoch 3 - Train Loss: 1374.0240478515625, Validation Loss: 1397.35498046875
Epoch 4 - Train Loss: 1236.9931640625, Validation Loss: 1048.4273681640625
Epoch 5 - Train Loss: 1381.4388427734375, Validation Loss: 1177.0594482421875
Epoch 6 - Train Loss: 1465.8040771484375, Validation Loss: 1674.5927734375
Epoch 7 - Train Loss: 1784.80712890625, Validation Loss: 1846.4990234375
Epoch 8 - Train Loss: 1567.5465087890625, Validation Loss: 879.2543334960938
Epoch 9 - Train Loss: 1180.1571044921875, Validation Loss: 785.67724609375
Epoch 10 - Train Loss: 1000.3878784179688, Validation Loss: 735.6682739257812
Epoch 11 - Train Loss: 1185.2210693359375, Validation Loss: 724.5114135742188
Epoch 12 - Train Loss: 786.1958618164062, Validation Loss: 501.7469787597656
Epoch 1