In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import torchvision.models as models
import time


In [2]:
# Base directory holding the dataset and pretrained weights.  ## subject to change
rootPath = "../data"
# Checkpoints written during training go into this sub-directory.
resultPath = os.path.join(rootPath, "results")

In [3]:
import torch
import torchvision.transforms as trn
from torch.utils.data import Dataset
import glob
import os
from PIL import Image

def buildIndexLabelMapping(root=None) :
    """
    Build the class-index <-> label-name mapping for Moments-in-Time Mini.

    Class names are the sub-directory names of the validation split. They are
    sorted so that the index assigned to each label is deterministic: plain
    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order,
    which would silently change label indices (and invalidate saved
    checkpoints) between machines or runs.

    :param root: dataset root directory; defaults to the global ``rootPath``.
    :returns: ``(idx2label, label2idx)`` where ``idx2label`` is a list of label
              names and ``label2idx`` maps each name back to its index.
    """
    if root is None :
        root = rootPath
    label_dir = os.path.join(root, 'Moments_in_Time_Mini/jpg/validation')
    # Only directories are classes; ignore stray files (e.g. .DS_Store).
    idx2label = sorted(name for name in os.listdir(label_dir)
                       if os.path.isdir(os.path.join(label_dir, name)))
    label2idx = {label: i for i, label in enumerate(idx2label)}
    return idx2label, label2idx

# Global class mapping: idx2label[i] is the label (directory) name of class i,
# and label2idx is its inverse. Moments uses label2idx to turn each video's
# parent-directory name into an integer target.
idx2label, label2idx = buildIndexLabelMapping()
    

class Moments(Dataset) :
    """
    A customized data loader for the Moments-In-Time (Mini) dataset.

    Each sample is one video stored as a directory of JPEG frames named
    ``image_00001.jpg``, ``image_00002.jpg``, ... under
    ``<rootPath>/Moments_in_Time_Mini/jpg/<subset>/<label>/<video>/``.
    ``__getitem__`` returns ``(video, label)``. ``video`` is a float tensor of
    shape ``(3, use_frames, 224, 224)`` built from ``use_frames`` evenly spaced
    frames. ``label`` is the integer class index taken from the global
    ``label2idx`` mapping.
    """
    def __init__(self, subset='validation', use_frames=16) :
        super().__init__()
        if use_frames < 1 :
            raise ValueError('use_frames must be >= 1, got {}'.format(use_frames))
        root = os.path.join(rootPath, 'Moments_in_Time_Mini/jpg', subset)
        self.use_frames = use_frames

        # glob order is filesystem-dependent. Sort it so that a given sample
        # index always refers to the same video, keeping seeded shuffling
        # reproducible.
        self.filenames = []
        for video_path in sorted(glob.glob(os.path.join(root, '*', '*'))) :
            # The parent directory name is the class label. Use os.path rather
            # than splitting on '/' so this also works on Windows.
            label = os.path.basename(os.path.dirname(video_path))
            self.filenames.append((video_path, label2idx[label]))
        self.len = len(self.filenames)

        self.tf = trn.Compose([trn.Resize((224, 224)),
                               trn.ToTensor(),
                               trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ## subject to change
                              ])

    @staticmethod
    def _sample_frame_indices(tot_frames, use_frames) :
        """
        Pick ``use_frames`` 1-based frame numbers spread evenly over a video.

        Frames are taken every ``(tot_frames - 1) // (use_frames - 1)`` frames,
        starting at frame 1. A video shorter than ``use_frames`` gets a
        spacing of 1, and its last frame is repeated to pad the clip, instead
        of a zero spacing that makes ``range`` fail.

        :param tot_frames: number of frames available (numbered 1..tot_frames).
        :param use_frames: number of frame numbers to return.
        :returns: list of ``use_frames`` frame numbers in ``[1, tot_frames]``.
        :raises ValueError: if the video has no frames.
        """
        if tot_frames < 1 :
            raise ValueError('video has no frames')
        if use_frames > 1 :
            spacing = max((tot_frames - 1) // (use_frames - 1), 1)
        else :
            spacing = 1
        return [min(1 + i * spacing, tot_frames) for i in range(use_frames)]

    def __getitem__(self, index) :
        video_path, label = self.filenames[index]
        # Count the frame images themselves rather than every directory entry.
        # Non-frame files (the old code assumed exactly one and subtracted it;
        # TODO confirm what it is) then cannot skew the count.
        tot_frames = len(glob.glob(os.path.join(video_path, 'image_*.jpg')))
        video = []
        for frame_no in self._sample_frame_indices(tot_frames, self.use_frames) :
            frame_path = os.path.join(video_path, 'image_{:05d}.jpg'.format(frame_no))
            # The context manager closes the file handle once the frame is decoded.
            with Image.open(frame_path) as img :
                video.append(self.tf(img.convert('RGB')))
        # Stack along dim 1 -> (C, T, H, W).
        return torch.stack(video, dim=1), label

    def __len__(self) :
        return self.len

In [4]:
# Set to True to run the debug cells: print one dataset sample, print the
# first batches of each loader, and run the model shape/loss smoke test.
debug = False

In [5]:
# Build both splits (training first), then report their sizes.
trainset, valset = Moments(subset='training'), Moments(subset='validation')
print("Number of training videos:", len(trainset))
print("Number of validation videos:", len(valset))

if debug :
    # Inspect one sample: (video tensor, integer label).
    sample_video, sample_label = trainset[3]
    print(sample_video.shape)
    print(sample_label)

Number of training videos: 100000
Number of validation videos: 10000


In [6]:
# Both loaders shuffle. test() stops after test_size samples, so a shuffled
# validation loader means each evaluation sees a random subset.
trainset_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=8)
valset_loader = DataLoader(valset, batch_size=4, shuffle=True, num_workers=4)


def _peek_batches(loader, n_batches=2) :
    """Print the index, data shape and targets of the first ``n_batches`` batches of ``loader``."""
    for idx, (batch_data, batch_target) in enumerate(loader) :
        print('batch', idx)
        print('data.shape=', batch_data.shape)
        print('target=', batch_target)
        if idx + 1 >= n_batches :
            break


if debug :
    _peek_batches(trainset_loader)
    print("")
    _peek_batches(valset_loader)

In [7]:
class FrameResNet50(nn.Module) :
    """
    Frame-level ResNet-50 baseline for video classification.

    The input ``x`` has shape ``(B, C, T, H, W)``. In training mode only the
    centre frame ``T // 2`` of each clip is classified. In eval mode every
    frame is classified independently and the per-frame logits are averaged
    over time. Both modes return ``(B, num_classes)`` logits.

    :param use_pretrain: if ``>= 0``, initialise the stem and ResNet stages
        ``layer1`` .. ``layer<use_pretrain>`` from the pretrained checkpoint
        (see ``loadPretrainedParam``). ``-1`` (the default) trains from scratch.
    :param num_classes: size of the classifier output.
    """
    def __init__(self, use_pretrain=-1, num_classes=200) :
        super().__init__()
        self.frame_model = models.resnet50(num_classes=num_classes)
        if use_pretrain >= 0 :
            self.loadPretrainedParam(use_pretrain)

    def forward(self, x) :
        B, C, T, H, W = x.shape
        if self.training :
            # Cheap training signal: classify a single (centre) frame per clip.
            return self.frame_model(x[:, :, T // 2, :, :])
        # (B, C, T, H, W) -> (B*T, C, H, W): treat every frame as an image.
        frames = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, C, H, W)
        logits = self.frame_model(frames)
        # Average the per-frame logits over time to get one score per clip.
        return logits.view(B, T, -1).mean(dim=1)

    def loadPretrainedParam(self, n_levels) :
        """
        Copy pretrained (ImageNet) ResNet-50 weights into ``self.frame_model``.

        Loads the stem (all non-``layer*`` parameters) and stages ``layer1`` ..
        ``layer<n_levels>``. Deeper stages and the ``fc`` classifier (sized for
        ``num_classes``) keep their random initialisation.

        :param n_levels: number of ResNet stages to load, in ``0..4``.
        :raises ValueError: if ``n_levels`` is outside ``0..4``.
        """
        # Explicit check instead of ``assert`` so that it survives ``python -O``.
        if not 0 <= n_levels <= 4 :
            raise ValueError('n_levels must be in 0..4, got {}'.format(n_levels))
        # map_location='cpu' lets the checkpoint load on hosts without a GPU.
        # load_state_dict then copies the tensors into the model's parameters.
        checkpoint = torch.load(os.path.join(rootPath, 'models/resnet50-19c8e357.pth'),
                                map_location='cpu')
        states_to_load = {}
        for name, param in checkpoint.items() :
            prefix = name.split('.', 1)[0]
            if prefix == 'fc' :
                continue
            if prefix.startswith('layer') and int(prefix[len('layer'):]) > n_levels :
                continue
            states_to_load[name] = param
        # Merge into the current state so that keys not being loaded keep their
        # current values and strict load_state_dict still sees every key.
        model_state = self.frame_model.state_dict()
        model_state.update(states_to_load)
        self.frame_model.load_state_dict(model_state)
        

In [8]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
# (To force CPU execution, override with device = torch.device("cpu").)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cuda


In [9]:
if debug :
    # Smoke test: output shapes and a computable loss in both modes.
    model = FrameResNet50(use_pretrain=4).to(device)

    # Training mode classifies only the centre frame of each clip.
    model.train()
    x = torch.zeros((64, 3, 16, 224, 224)).to(device)
    scores = model(x)
    print(scores.shape)
    assert scores.shape == (64, 200), scores.shape
    print(F.cross_entropy(scores, torch.zeros(64, dtype=torch.long, device=device)))

    # Eval mode averages per-frame logits. No gradients are needed here.
    model.eval()
    x = torch.zeros((4, 3, 16, 224, 224)).to(device)
    with torch.no_grad() :
        scores = model(x)
    print(scores.shape)
    assert scores.shape == (4, 200), scores.shape  # batch of 4 here, not 64
    print(F.cross_entropy(scores, torch.zeros(4, dtype=torch.long, device=device)))

In [10]:
def save_checkpoint(checkpoint_path, model, optimizer):
    """
    Save model and optimizer state dicts to ``checkpoint_path``.

    The parent directory is created if it does not exist yet. This prevents
    the first save into a fresh results directory from failing inside
    ``torch.save``.
    """
    parent = os.path.dirname(checkpoint_path)
    if parent :
        os.makedirs(parent, exist_ok=True)
    state = {'state_dict': model.state_dict(),
             'optimizer' : optimizer.state_dict()}
    torch.save(state, checkpoint_path)
    print('model saved to %s' % checkpoint_path)

def load_checkpoint(checkpoint_path, model, optimizer, map_location=None):
    """
    Restore model and optimizer state written by ``save_checkpoint``.

    :param map_location: forwarded to ``torch.load``. For example, ``'cpu'``
        loads a checkpoint saved on a GPU machine on a CPU-only one. ``None``
        keeps ``torch.load``'s default behaviour.
    """
    state = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    print('model loaded from %s' % checkpoint_path)

In [11]:
def test(test_size = 1000):
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for data, target in valset_loader:
            data, target = data.to(device), target.to(device)
            scores = model(data)
            _, preds = scores.max(1)
            num_correct += (preds == target).sum()
            num_samples += preds.size(0)
            # if (num_samples%100 == 0) :
            #     print("Number of test sample examined:", str(num_samples))
            if (num_samples >= test_size) :
                break

    acc = 100.0 * num_correct / num_samples
    print('\tValidation set accuracy: {}/{} ({:.2f}%)\n'.format(num_correct, num_samples, acc))

In [22]:
def train_save(epoch, model, optimizer, mid_epoch_eval_batch=800):
    """
    Train ``model`` for ``epoch`` epochs, saving a checkpoint after each one.

    @pre: the result of @a model is the score of each category, on which we
          apply cross-entropy loss.

    :param epoch: number of epochs to run.
    :param model: network to train.
    :param optimizer: optimizer over ``model``'s parameters.
    :param mid_epoch_eval_batch: batch index at which an extra validation pass
        runs inside each epoch (default 800). ``None`` disables it.
    """
    # torch.save does not create directories; make sure the target exists.
    os.makedirs(resultPath, exist_ok=True)
    iteration = 0
    for ep in range(epoch):
        t0 = time.time()
        # (Re-)enter training mode every epoch: test() below leaves eval mode on.
        model.train()
        for batch_idx, (data, target) in enumerate(trainset_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            scores = model(data)
            loss = F.cross_entropy(scores, target)
            loss.backward()
            optimizer.step()
            if iteration % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    ep, batch_idx * len(data), len(trainset_loader.dataset),
                    100. * batch_idx / len(trainset_loader), loss.item()))
            if batch_idx == mid_epoch_eval_batch :
                test()
                # test() switches to eval mode. Without this, the rest of the
                # epoch would use the eval-mode forward and BatchNorm would
                # stop updating its running statistics.
                model.train()
            iteration += 1
        # Parenthesise ep + 1: '%' binds tighter than '+', and str + int raises TypeError.
        save_checkpoint(os.path.join(resultPath, '2d_resnet-%i.pth' % (ep + 1)), model, optimizer)
        test()
        t1 = time.time()
        print("Epoch %d done, takes %fs\n"%(ep+1, t1-t0))

    # save the final model
    save_checkpoint(os.path.join(resultPath, '2d_resnet-final.pth'), model, optimizer)

In [23]:
# Seed torch's RNG (model initialisation, DataLoader shuffling), then train for
# one epoch, starting from pretrained weights for the stem and layer1..layer3.
torch.manual_seed(123)
model = FrameResNet50(use_pretrain=3).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
train_save(epoch=1, model=model, optimizer=optimizer)



Process Process-60:
Process Process-57:
Process Process-62:
Process Process-58:
Process Process-59:
Process Process-61:
Process Process-63:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/shared/anaconda3/lib/python3.6/multi



KeyboardInterrupt
  File "/home/shared/anaconda3/lib/python3.6/site-packages/PIL/Image.py", line 1747, in resize
    return self._new(self.im.resize(size, resample, box))
KeyboardInterrupt
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/shared/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/shared/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 57, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/shared/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/shared/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/shar

KeyboardInterrupt: 