In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os, sys
import numpy as np
import torchvision.models as models
import time


In [2]:
from load_moments_dataset import Moments
from global_params import *
from Frame2dResNet50 import Frame2dResNet50


In [3]:
# Set to True to run the small sanity-check branches guarded by `debug` in the cells below
# (sample inspection, loader peeks, dummy forward passes, single-batch evaluation).
debug = False

In [4]:
# Build the training and validation splits.
# NOTE(review): use_frames=4 presumably selects 4 frames per video -- confirm in load_moments_dataset.
trainset = Moments(use_frames=4, subset='training')
valset = Moments(use_frames=4, subset='validation')
for label, split in (("Number of training videos:", trainset),
                     ("Number of validation videos:", valset)):
    print(label, len(split))

if debug:
    # Peek at one sample: print the shape of its first element and its second element.
    video_info = trainset[3]
    print(video_info[0].shape)
    print(video_info[1])

Number of training videos: 100000
Number of validation videos: 10000


In [5]:
# Batched, shuffled loaders over both splits (12 worker processes each).
trainset_loader = DataLoader(trainset, shuffle=True, batch_size=64, num_workers=12)
valset_loader = DataLoader(valset, shuffle=True, batch_size=16, num_workers=12)

if debug:
    # Show the first two batches of each loader, with a blank line between the two loaders.
    for loader_pos, loader in enumerate((trainset_loader, valset_loader)):
        if loader_pos > 0:
            print("")
        for batch_idx, (data, target, _) in enumerate(loader):
            print('batch', batch_idx)
            print('data.shape=', data.shape)
            print('target=', target)

            if batch_idx >= 1:
                break

In [6]:
# `device` is not assigned anywhere in this notebook; presumably it is provided by the
# star import from global_params -- confirm there.
print(device)

cuda


In [7]:
if debug:
    # Smoke-test the network on all-zero dummy input in both train and eval mode.
    model = Frame2dResNet50(use_pretrain=4).to(device)

    model.train()
    x = torch.zeros((64, 3, 16, 224, 224), device=device)
    scores = model(x)
    print(scores.shape) ## Should give torch.Size([64, 200])
    print(F.cross_entropy(scores, torch.zeros(64, dtype=torch.long, device=device)))

    model.eval()
    # Inference only: no autograd graph is needed for the eval-mode check.
    with torch.no_grad():
        x = torch.zeros((4, 3, 16, 224, 224), device=device)
        scores = model(x)
        print(scores.shape) ## Should give torch.Size([4, 200])
        print(F.cross_entropy(scores, torch.zeros(4, dtype=torch.long, device=device)))

In [8]:
def save_training_state(checkpoint_path, model, optimizer):
    """Write the model and optimizer state dicts to `checkpoint_path`.

    The checkpoint is a dict with the keys 'state_dict' (model) and
    'optimizer' (optimizer), serialized with torch.save.

    Args:
        checkpoint_path: destination file; missing parent directories are created.
        model: object exposing state_dict().
        optimizer: object exposing state_dict().
    """
    # Create the target directory up front so a long training run does not
    # lose its weights to a missing folder at save time.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    state = {'state_dict': model.state_dict(),
             'optimizer' : optimizer.state_dict()}
    torch.save(state, checkpoint_path)
    print('model saved to %s' % checkpoint_path)
    
def load_training_state(checkpoint_path, model, optimizer):
    """Restore `model` and `optimizer` in place from a checkpoint file.

    Expects the layout written by save_training_state: a dict with the keys
    'state_dict' (model) and 'optimizer' (optimizer).

    Args:
        checkpoint_path: file to read.
        model: object whose load_state_dict() receives state['state_dict'].
        optimizer: object whose load_state_dict() receives state['optimizer'].
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    # Load onto the CPU first so a checkpoint written on a GPU can be read on any
    # machine; load_state_dict() then copies the values onto the model's own device.
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    print('model loaded from %s' % checkpoint_path)

In [9]:
def test(model, test_size = 2000, print_to = sys.stdout):
    """Evaluate `model` on `valset_loader` and report top-1 / top-5 accuracy.

    The model is switched to eval mode and left in it; callers that continue
    training must call model.train() afterwards.

    Args:
        model: network returning one score per class for each sample in a batch.
        test_size: currently unused -- every batch of the validation loader is
            evaluated (only one batch when `debug` is set). Kept so existing
            callers keep working.
        print_to: file-like object that receives the final accuracy line.
            Progress lines always go to stdout.

    Returns:
        (acc1, acc5): top-1 and top-5 accuracy as floats in [0, 1];
        both are 0.0 if the loader yielded no samples.
    """
    model.eval()
    num_correct1 = 0
    num_correct5 = 0
    num_samples = 0
    # Seed inside fork_rng so the validation order is reproducible while the
    # caller's global RNG state is restored on exit. Seeding the global RNG
    # directly would make every training epoch that follows an evaluation
    # replay the same shuffle / dropout stream.
    with torch.random.fork_rng(), torch.no_grad():
        torch.manual_seed(123)
        for data, target, _ in valset_loader:
            scores = model(data.to(device)).cpu().numpy()
            # Indices of the 5 highest scores per row; the last column is the arg-max.
            top5 = scores.argsort(axis=1)[:, -5:]
            target_np = target.numpy()
            batch_sz = top5.shape[0]
            num_correct1 += int((top5[:, -1] == target_np).sum())
            num_correct5 += int((top5 == target_np[:, None]).any(axis=1).sum())
            num_samples += batch_sz
            if debug:
                break
            if num_samples % 400 == 0:
                print('Tested [{}/{} ({:.2%})]'.format(
                    num_samples, len(valset), num_samples / len(valset)), flush=True)
    # Guard against an empty loader instead of dividing by zero.
    acc1 = num_correct1 / num_samples if num_samples else 0.0
    acc5 = num_correct5 / num_samples if num_samples else 0.0
    print('\tValidation set accuracy: top-1 {}/{} ({:.2%}), top-5 {}/{} ({:.2%})'.format(
        num_correct1, num_samples, acc1, num_correct5, num_samples, acc5), flush=True, file=print_to)
    return acc1, acc5


In [13]:
def testLoad(pre_imgnet = 4, nllr = 4, pre_ours = 3) :
    """Load a saved checkpoint into a fresh model and evaluate it on the validation set.

    Args:
        pre_imgnet: value used for the 'p%d' part of the checkpoint name.
        nllr: negative log10 of the learning rate; used for the 'lr%d' part of the
            checkpoint name and for the Adam optimizer that receives the saved state.
        pre_ours: epoch suffix of the checkpoint file to load.
    """
    param_val = 'p%dlr%d'%(pre_imgnet, nllr)
    checkpoint_path = os.path.join(savedPath, '2dResNet-%s-%d.pth'%(param_val, pre_ours))

    model = Frame2dResNet50().to(device)
    # The optimizer is only built because load_training_state restores both state dicts.
    optimizer = optim.Adam(model.parameters(), lr=10**(-nllr))
    load_training_state(checkpoint_path, model, optimizer)
    torch.manual_seed(123)
    test(model)

testLoad()

model loaded from ../data/model_param_saved/2dResNet-p4lr4-3.pth
	Validation set accuracy: top-1 1806/10000 (18.06%), top-5 4071/10000 (40.71%)


In [21]:
def train_save(epoch, model, optimizer, param_val, print_to = sys.stdout, epoc_start = 0, test_at_batch = 800):
    """
    Train @a model with cross-entropy loss, saving a checkpoint and running
    validation after every epoch.

    @pre: the return of @a model is the score of each category, which we will use cross-entropy loss

    @param epoch: number of epochs to run.
    @param model: network to train (left in training mode on return).
    @param optimizer: optimizer stepping the parameters of @a model.
    @param param_val: tag inserted into checkpoint file names ('2dResNet-<tag>-<epoch>.pth').
    @param print_to: file-like object receiving the training log.
    @param epoc_start: index of the first epoch; checkpoints are numbered from epoc_start+1.
    @param test_at_batch: batch index within each epoch at which an extra validation run is done.
    """
    print(param_val)
    model.train()  # set training mode
    iteration = 0  # counts batches across all epochs; drives the every-10-batches log line
    for ep in range(epoc_start, epoc_start+epoch):
        t0 = time.time()
        for batch_idx, (data, target, _) in enumerate(trainset_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            scores = model(data)
            loss = F.cross_entropy(scores, target)
            loss.backward()
            optimizer.step()
            if iteration % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    ep, batch_idx * len(data), len(trainset_loader.dataset),
                    100. * batch_idx / len(trainset_loader), loss.item()), flush=True, file=print_to)
            if batch_idx == test_at_batch :
                # test() switches to eval mode, so training mode is restored right after.
                test(model, print_to=print_to)
                model.train()
            iteration += 1
        save_training_state(os.path.join(savedPath, '2dResNet-'+param_val+'-%d.pth'%(ep+1)), model, optimizer)
        test(model, print_to=print_to)
        model.train()
        t1 = time.time()
        epoch_msg = "Epoch %d done, takes %fs"%(ep+1, t1-t0)
        # Echo to the notebook as well as the log, but do not print the line twice
        # when the log target already is stdout.
        if print_to is not sys.stdout:
            print(epoch_msg, flush=True)
        print(epoch_msg, flush=True, file=print_to)

In [22]:
#

In [23]:
# torch.manual_seed(123)
# model = FrameResNet50(use_pretrain=3).to(device)
# optimizer = optim.Adam(model.parameters(), lr=0.001)
# train_save(5, model, optimizer)
# # load_checkpoint("../data/results/2d_resnet-2.pth", model, optimizer)

In [24]:
def finetuneparam(pre_imgnet, nllr, use_pre_ours=0, epocs=3) :
    """Train one (pre_imgnet, nllr) configuration and append its log to log/log-p<pre_imgnet>lr<nllr>.txt.

    Args:
        pre_imgnet: passed to Frame2dResNet50 as `use_pretrain` when starting fresh;
            also part of the checkpoint / log tag.
        nllr: negative log10 of the Adam learning rate (lr = 10**-nllr).
        use_pre_ours: 0 to start from a freshly built model; otherwise the epoch
            number of one of our own checkpoints to resume from.
        epocs: number of epochs to train.
    """
    torch.manual_seed(123)
    param_val = 'p%dlr%d'%(pre_imgnet, nllr)
    resume = use_pre_ours != 0
    if resume :
        model = Frame2dResNet50().to(device)
    else :
        model = Frame2dResNet50(use_pretrain=pre_imgnet).to(device)
    optimizer = optim.Adam(model.parameters(), lr=10**(-nllr))
    if resume :
        load_training_state(os.path.join(savedPath, '2dResNet-'+param_val+'-%d.pth'%(use_pre_ours)), model, optimizer)
    # Re-seed after model construction so training starts from the same RNG state
    # on both paths.
    torch.manual_seed(123)
    os.makedirs('log', exist_ok=True)
    # The context manager closes the log even if training raises or is interrupted.
    with open('log/log-%s.txt'%param_val, 'a') as log_file:
        train_save(epocs, model, optimizer, param_val, print_to=log_file, epoc_start=use_pre_ours)

In [None]:
# param_options = [(4, 4), (4, 3), (3, 4), (3, 3)]
# for pre, nllr in param_options :
#     finetuneparam(pre, nllr, use_pre_ours=0, epocs=3)

In [25]:
# Train for 3 epochs with use_pretrain=4 and Adam lr=10**-5, starting from a freshly
# built model rather than one of our own checkpoints (use_pre_ours=0).
finetuneparam(4, 5, use_pre_ours=0, epocs=3)

p4lr5
model saved to ../data/model_param_saved/2dResNet-p4lr5-1.pth
Epoch 1 done, takes 2789.668236s


Process Process-70:
Process Process-72:
Process Process-68:
Process Process-69:
Process Process-62:
Process Process-67:
Process Process-65:
Process Process-71:
Process Process-66:
Process Process-64:
Process Process-63:
Process Process-61:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap


  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/queues.py", line 347, in put
    self._writer.send_bytes(obj)
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/connection.py", l

KeyboardInterrupt: 