In [3]:
from dataset import YouCookII
from dataset import collate_fn
from torch.utils.data import DataLoader
from loss import loss_RA_MIL
from transformers import get_linear_schedule_with_warmup

import numpy as np
import torch

def train(model, num_actions, batch_size, epochs=25, lr=0.001, y=0.5,
          dataset_root="/h/sagar/ece496-capstone/datasets/ycii"):
    """Train ``model`` on YouCookII and return the mean loss of every epoch.

    Args:
        model: Module to optimise. It is called as
            ``model(batch_size, num_actions + 1, steps_list, features_tensor,
            bboxes_tensor, entity_count_list)`` and must return an 8-tuple
            whose first three items are ``loss_E``, ``loss_V`` and ``loss_R``.
        num_actions: Passed to ``YouCookII``; the model receives
            ``num_actions + 1``.
        batch_size: Batch size handed to the ``DataLoader``.
        epochs: Number of passes over the dataset.
        lr: Learning rate given to the Adam optimiser.
        y: First argument forwarded to ``loss_RA_MIL``.
        dataset_root: Directory passed to ``YouCookII``. Defaults to the
            path that used to be hard-coded here.

    Returns:
        NumPy array of length ``epochs`` holding the mean batch loss of each
        epoch as plain floats.

    Raises:
        ValueError: If the dataloader yields no batches during an epoch.
    """
    dataset = YouCookII(num_actions, dataset_root)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # NOTE(review): the scheduler is stepped once per epoch, so warm-up and
    # decay are measured in epochs rather than batches. With a linear warm-up
    # the first epoch presumably runs at lr 0 whenever int(0.2 * epochs) > 0;
    # confirm this is intended.
    scheduler = get_linear_schedule_with_warmup(optimizer, int(0.2*epochs), epochs)

    train_loss = np.zeros(epochs)

    model.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        for data in dataloader:
            _, bboxes_tensor, features_tensor, steps_list, _, entity_count_list, _, _ = data
            # The last batch may be smaller than ``batch_size``; use a separate
            # local so the function argument is not overwritten.
            current_batch_size = len(data[0])

            # Zero out any gradients.
            optimizer.zero_grad()

            # Run inference (forward pass).
            loss_E, loss_V, loss_R, _, _, _, _, _ = model(
                current_batch_size,
                num_actions + 1,
                steps_list,
                features_tensor,
                bboxes_tensor,
                entity_count_list,
            )

            # Loss from alignment.
            loss_ = loss_RA_MIL(y, loss_R, loss_E, loss_V)
            print(loss_)

            # Backpropagation (backward pass).
            loss_.backward()

            # Update parameters.
            optimizer.step()

            # Accumulate a plain Python float: adding the tensor itself would
            # keep every batch's autograd graph alive until the epoch ends and
            # would later push a CUDA tensor into the NumPy array.
            epoch_loss += loss_.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader produced no batches; cannot compute an epoch loss")

        # After the epoch completes.
        print("---Epoch Completed----")

        # Learning-rate schedule: updated once per epoch.
        scheduler.step()

        # TODO: save accuracy at each epoch, plot (and checkpoint).
        train_loss[epoch] = epoch_loss / num_batches

    return train_loss

In [4]:
from model import Model

import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The model is constructed with the selected device.
model = Model(device)

# Run training; `train` returns the per-epoch loss array.
train_loss = train(
    model,
    num_actions=8,
    batch_size=2,
    epochs=10,
    lr=0.001,
)

tensor(2321.7051, device='cuda:0', grad_fn=<SumBackward0>)
tensor(1918.0331, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2539.8669, device='cuda:0', grad_fn=<SumBackward0>)
tensor(14848.9531, device='cuda:0', grad_fn=<SumBackward0>)
tensor(5964.3652, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2573.7678, device='cuda:0', grad_fn=<SumBackward0>)
tensor(9852.0645, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4632.1416, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4753.4468, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3790.0146, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2964.5046, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3543.1785, device='cuda:0', grad_fn=<SumBackward0>)
tensor(1628.5984, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2549.8374, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3639.9170, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2008.8545, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4511.3447, device='cuda:0', grad_fn=<SumBackward0>)

tensor(2995.8647, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4488.3594, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2504.9846, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3000.2341, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4176.6421, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2407.5725, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3322.2197, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2536.4731, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3112.2537, device='cuda:0', grad_fn=<SumBackward0>)
tensor(6153.1655, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2399.5186, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2707.1736, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3433.7483, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2750.5962, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4816.3101, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4360.5991, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3926.3652, device='cuda:0', grad_fn=<SumBackward0>)

tensor(2876.2786, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2883.3423, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3100.4829, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3011.5413, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3728.7766, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4434.9805, device='cuda:0', grad_fn=<SumBackward0>)
tensor(3035.2493, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4259.2559, device='cuda:0', grad_fn=<SumBackward0>)
tensor(1747.7135, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4195.6553, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2608.1416, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2580.0137, device='cuda:0', grad_fn=<SumBackward0>)
