In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision import transforms

from eurus.track.pytorch.train import ForwardTrackingModel, Alov300, Uav123, Vot2016

## Data

Create a dataset instance:

In [None]:
# Root directory of the VOT2016 data on the training machine.
VOT2016_ROOT = '/data1/joan/eurus/data/vot2016/'

# Frames are converted with ToTensor(); sequence_length=2 is passed straight
# to Vot2016 (presumably frames per sample — confirm against Vot2016).
dataset = Vot2016(
    VOT2016_ROOT,
    transform=transforms.ToTensor(),
    sequence_length=2,
)

Visualize the dataset:

In [None]:
# Visual sanity check of the raw data. NOTE(review): what is rendered is
# defined by Vot2016.view_original (presumably frames before ToTensor) — confirm.
dataset.view_original()

Dataset length:

In [None]:
# Print the dataset's string form.
# NOTE(review): the heading says "length"; this only shows it if Vot2016's
# __str__/__repr__ includes it — `len(dataset)` would report it directly.
print(dataset)

View the histogram of sequence lengths:

In [None]:
# Distribution of sequence lengths in the dataset (rendering is implemented by
# Vot2016.view_sequence_length_histogram — presumably a plot).
dataset.view_sequence_length_histogram()

Create a `DataLoader` for the dataset:

In [None]:
batch_size = 8

# Batches are drawn in dataset order (shuffle=False) by 4 worker processes;
# pin_memory=True returns batches in page-locked host memory, which speeds up
# the later host-to-GPU copies.
# NOTE(review): training without shuffling is unusual — confirm it is intended.
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

## Train

Create an instance of the model:

In [None]:
# Instantiate the tracker network with its constructor defaults.
model = ForwardTrackingModel()

Move it to the GPU:

In [None]:
# Move all parameters and buffers onto the current CUDA device.
# Module.to modifies the module in place and returns it, so rebinding is harmless.
model = model.to('cuda')

Define the loss function:

In [None]:
# Mean-squared-error loss averaged over all elements. `reduction='mean'` is the
# supported spelling of the deprecated `size_average=True` argument, which
# triggers a deprecation warning on current PyTorch.
criterion = nn.MSELoss(reduction='mean')
# MSELoss has no parameters or buffers, so this moves nothing; kept so every
# training component is explicitly placed on the GPU.
criterion = criterion.cuda()

Define the optimizer:

In [None]:
# Plain SGD (default momentum and weight decay of 0) over every model parameter.
learning_rate = 1e-6
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)

In [None]:
# Number of batches the DataLoader yields per epoch.
len(data_loader)

Define the training loop:

In [None]:
def train(epoch, log_interval=1):
    """Run one training epoch over the notebook-level ``data_loader``.

    Relies on the globals ``model``, ``optimizer``, ``criterion`` and
    ``data_loader`` defined in earlier cells. Each item yielded by
    ``data_loader`` is iterated as a sequence of frame tuples
    ``(x1, x2, _, search_box, context_box, t2)``. For each tuple the model is
    called as ``model(x1, x2)``, ``criterion`` is evaluated against ``t2`` and
    back-propagated, and the optimizer takes one step.

    Args:
        epoch: Epoch index; only used in the progress line.
        log_interval: Print a progress line every ``log_interval`` steps of the
            inner (per-frame) loop. The default of 1 logs every step.
    """
    model.train()

    # Send inputs to wherever the model lives instead of hard-coding
    # ``.cuda()``. This is identical on a single-GPU setup and also works on CPU.
    device = next(model.parameters()).device

    num_batches = len(data_loader)
    num_samples = len(data_loader.dataset)
    # Exact number of samples in batches already processed. The old
    # ``i * x1.size()[0]`` under-reported progress on a short final batch.
    samples_seen = 0

    for i, data_sequence in enumerate(data_loader):
        batch_samples = 0

        for j, data in enumerate(data_sequence):
            optimizer.zero_grad()

            # Only the two input frames and the regression target are used here.
            x1, x2, _, _search_box, _context_box, t2 = data
            x1 = x1.to(device)
            x2 = x2.to(device)
            t2 = t2.to(device)
            batch_samples = x1.size(0)

            y2 = model(x1, x2)

            # NOTE(review): dividing by the number of batches scales both the
            # gradient and the logged value (an implicit learning-rate cut, not
            # an average over anything). Kept to preserve training dynamics.
            loss = criterion(y2, t2.unsqueeze(1)) / num_batches

            loss.backward()
            optimizer.step()

            if j % log_interval == 0:
                # ``.item()`` replaces ``.data[0]``, which raises IndexError on
                # 0-dim tensors in PyTorch >= 0.4.
                print('Train Epoch: {0:03d} [{1:06d}/{2:05d} ({3:2.0f}%)]\t'
                      'Loss: {4:4.4f}'.format(
                          epoch,
                          samples_seen,
                          num_samples,
                          100. * i / num_batches,
                          loss.item()))

        samples_seen += batch_samples
            

In [None]:
# Train for a fixed number of epochs, checkpointing the weights periodically.
NUM_EPOCHS = 100000
CHECKPOINT_EVERY = 25  # epochs between checkpoints
CHECKPOINT_PATH = "/data1/joan/eurus/model.pth"

for epoch in range(NUM_EPOCHS):
    train(epoch)
    # Also save after the final epoch; otherwise up to CHECKPOINT_EVERY - 1
    # trailing epochs of training would never be written out.
    if epoch % CHECKPOINT_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        # Overwrites the same file each time: only the latest weights are kept.
        torch.save(model.state_dict(), CHECKPOINT_PATH)

In [None]:
# Restore the weights written by the training loop into the existing model.
# The trailing expression displays load_state_dict's missing/unexpected keys.
checkpoint_path = "/data1/joan/eurus/model.pth"
state_dict = torch.load(checkpoint_path)
model.load_state_dict(state_dict)