In [41]:
import os
import glob
import time
import copy

import numpy as np
import cv2
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision import datasets, models, transforms

### Data Loader

In [36]:
class MotionVectorDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, mode):
        self.mode = mode
        motion_vectors_file = os.path.join(root_dir, "preprocessed/motion_vectors_{}.pkl".format(mode))
        self.motion_vectors = pickle.load(open(motion_vectors_file, "rb"))   
        
        bounding_boxes_file = os.path.join(root_dir, "preprocessed/bounding_boxes_{}.pkl".format(mode))
        self.bounding_boxes = pickle.load(open(bounding_boxes_file, "rb"))
        
        if mode == "train" or mode == "val":
            box_velocities_files = os.path.join(root_dir, "preprocessed/box_velocities_{}.pkl".format(mode)) 
            self.box_velocities = pickle.load(open(box_velocities_files, "rb"))        
               
    def __len__(self):
        return len(self.motion_vectors)
        
    def __getitem__(self, idx):
        mvs = self.motion_vectors[idx]
        mvs = mvs.permute(2, 0, 1).float()
        
        boxes = self.bounding_boxes[idx].float()
        
        item = mvs, boxes, None
        
        if self.mode == "train" or self.mode == "val":
            y = self.box_velocities[idx].long()
            item = mvs, y
        
        return item

### Propagation Network

In [42]:
class PropagationNetwork(nn.Module):
    def __init__(self):
        super(PropagationNetwork, self).__init__()
        self.base = torchvision.models.resnet18(pretrained=True)
        
        # change number of input channels from 3 to 2
        #self.base.conv1.in_channels = 2
        self.base.conv1 = nn.Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
        # remove fully connected and avg pool layers
        self.base = nn.Sequential(*list(self.base.children())[:-2])
        
        self.conv1 = nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        
        #print(list(self.base.modules()))
        
    def forward(self, x):
        
        print(x.shape)
        x = self.base(x)
        print(x.shape)
        x = self.conv1(x)
        y_pred = F.relu(x)
        print(y_pred.shape)
        
        return y_pred
    
model = PropagationNetwork()

### Training loop

In [43]:
def train(model, criterion, optimizer, scheduler, num_epochs=2):
    tstart = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    pickle.dump(best_model_wts, open("models/best_model.pkl", "wb"))
    best_loss = 999999.0

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs-1))
        
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()
            else:
                model.eval()
                
            running_loss = 0.0

            for x, y in dataloaders[phase]:
                x = x.to(device)
                y = y.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = model(x)
                    loss = criterion(y_pred, y)
                    
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                        scheduler.step()
                        
                running_loss += loss.item() * x.size(0)
                
            epoch_loss = running_loss / len(datasets[phase])
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            
            if phase == "val":
                model_wts = copy.deepcopy(model.state_dict())
                pickle.dump(best_model_wts, open("models/model_{:04d}.pkl".format(epoch), "wb"))
            
            if phase == "val" and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                pickle.dump(best_model_wts, open("models/best_model.pkl", "wb"))
                
    time_elapsed = time.time() - tstart
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Lowest validation loss: {:4f}'.format(best_loss))
    
    model.load_state_dict(best_model_wts)
    return model

In [44]:
datasets = {x: MotionVectorDataset(root_dir='../benchmark/MOT17', mode=x) for x in ["train", "val", "test"]}
dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=8, shuffle=False) for x in ["train", "val", "test"]}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = PropagationNetwork()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)
train(model, criterion, optimizer, scheduler, num_epochs=2)

Epoch 0/1
torch.Size([8, 2, 68, 120])
torch.Size([8, 512, 3, 4])
torch.Size([8, 512, 3, 4])


RuntimeError: input and target batch or spatial sizes don't match: target [8 x 22 x 4], input [8 x 512 x 3 x 4] at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:23

In [46]:
next(iter(dataloaders["train"]))[1]

tensor([[[  4,  -1,   1,   1],
         [  0,  -1,   0,   1],
         [  6,   0,  -1,   1],
         [ -1,  -1,   0,   0],
         [ -1,   0,   1,  -1],
         [  0,   0,   0,   0],
         [ -1,   0,   0,   0],
         [ -1,  -1,   0,   0],
         [ -1,  -1,   0,   0],
         [  0,   0,   2,  -1],
         [  0,   0,   0,   0],
         [  1,   0,   1,   0],
         [  2,  -1,  -1,   1],
         [  0,   0,   0,   0],
         [ -1,  -1,   0,   0],
         [  0,   0,   0,   0],
         [  0,  -1,  -1,   1],
         [  0,   0,   0,   0],
         [  0,  -1,   0,   0],
         [ -1,  -1,   0,   0],
         [ -1,  -1,   0,   0],
         [  1,   0,   0,  -1]],

        [[  4,   0,   2,   0],
         [  0,   0,   0,   0],
         [  6,   0,  -1,   1],
         [  0,   0,   0,   0],
         [  0,   0,   0,   0],
         [  0,   0,   0,   0],
         [  0,   0,   0,   0],
         [  0,   0,   0,   0],
         [  0,   0,   0,   0],
         [  0,   0,   2,   0],
      