In [None]:
%matplotlib inline

import os

import numpy as np

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision import transforms

from eurus.utils import Box
from eurus.track.pytorch import ForwardTracker
from eurus.track.pytorch.train import Alov300, Uav123, Vot2016

## Data

In [None]:
# Root directory of the VOT2016 data. Defaults to the original location; set
# the VOT2016_DIR environment variable to run the notebook on another machine.
VOT2016_DIR = os.environ.get('VOT2016_DIR', '/data1/joan/eurus/data/vot2016/')

dataset = Vot2016(
    VOT2016_DIR,
    transform=transforms.ToTensor(),
)

In [None]:
# Show the dataset's string summary (its __str__, as print would use).
print(str(dataset))

In [None]:
# Display the dataset via its own viewer; any return value becomes the cell output.
dataset.view_original()

## Tracker

In [None]:
seq_ind = 1  # index of the dataset sequence to track
img_ind = 0  # frame of that sequence on which the tracker is initialised


def _load_image(img_path):
    """Open an image file and decode it immediately.

    ``Image.open`` is lazy: it keeps the file open until the pixel data is
    read. Opening every frame of a long sequence that way can exhaust the
    process's open-file limit, so the data is loaded up front instead.
    """
    img = Image.open(img_path)
    img.load()
    return img


img_sequence = [_load_image(img_path) for img_path in dataset.img_list[seq_ind]]
ann_sequence = dataset.ann_list[seq_ind]

tracker = ForwardTracker('/data1/joan/eurus/model3.pth')

image = img_sequence[img_ind]
# The annotation row is unpacked positionally into Box's leading arguments.
initial_box = Box(*ann_sequence[img_ind], timestamp=0)

tracker.initialize(image, initial_box)

Check that the tracker's `context` was set up correctly by `initialize`:

In [None]:
# Strip the autograd wrapper, drop singleton dims and move to CPU so that
# ToPILImage (which expects a 2-D or 3-D tensor) can render the context.
transforms.ToPILImage()(tracker.context.data.squeeze().cpu())

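Tracking loop: run the tracker frame by frame, recording the context before each step together with the predicted box and response it returns: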

In [None]:
# Per-frame results. `boxes[i]` is the point drawn on `img_sequence[i]` when
# visualising, so frames before the initialisation frame get None and the
# initialisation frame gets the centre (x + w/2, y + h/2) of its annotation.
boxes = [None] * img_ind + [np.array([initial_box.x + initial_box.w / 2,
                                      initial_box.y + initial_box.h / 2])]
contexts = []   # tracker context, as a PIL image, captured before each step
responses = []  # second value returned by `tracker.track` for each frame

# The tracker was initialised on frame `img_ind`, so tracking starts on the
# next frame. Tracking the initialisation frame again would add a second entry
# for it and shift every later box one frame out of line with `img_sequence`.
for img in img_sequence[img_ind + 1:]:
    contexts.append(transforms.ToPILImage()(tracker.context.squeeze().data.cpu()))
    # NOTE(review): the second argument is always 0; if it is a timestamp (as
    # `Box(..., timestamp=0)` suggests) it may need to be the frame index.
    box, response = tracker.track(img, 0)
    boxes.append(box)
    responses.append(response)

Inspect the tracker's context after the last tracked frame:

In [None]:
# Render the context the tracker holds after the last processed frame
# (unwrap, squeeze singleton dims, move to CPU, convert to a PIL image).
transforms.ToPILImage()(tracker.context.data.squeeze().cpu())

Define helper functions for browsing a sequence of images with an interactive slider, optionally marking a point on each image:

In [None]:
import matplotlib.pyplot as plt
from ipywidgets import interact
import ipywidgets as widgets


def visualize_datum(img, box=None):
    """Display a single image, optionally marking one point on it.

    Parameters
    ----------
    img : PIL.Image.Image or array-like
        Image to show. Inputs that convert to a 2-D array (single-channel
        data such as score maps) are drawn with the ``'jet'`` colormap;
        anything else uses matplotlib's default handling.
    box : sequence, optional
        If given, ``box[0]`` and ``box[1]`` are drawn as a red ``+`` at
        (x, y) = (column, row) in image coordinates.
    """
    fig, ax = plt.subplots(1, figsize=(10, 10))
    ax.set_axis_off()
    # np.ndim converts PIL images / array-likes the same way np.array does.
    cmap = 'jet' if np.ndim(img) == 2 else None
    ax.imshow(img, cmap=cmap)
    if box is not None:
        ax.scatter(box[0], box[1], c='r', marker='+')
    plt.show()


def visualize_data(img_sequence, box_sequence=None):
    """Browse a sequence of images with an interactive index slider.

    Parameters
    ----------
    img_sequence : sequence
        Images accepted by ``visualize_datum``.
    box_sequence : sequence, optional
        Per-image points passed as ``box`` to ``visualize_datum``: entry ``i``
        is drawn on ``img_sequence[i]`` and may be ``None``. Must have at
        least as many entries as ``img_sequence``.

    Raises
    ------
    ValueError
        If ``img_sequence`` is empty (the slider would get ``max=-1``) or
        ``box_sequence`` is shorter than ``img_sequence``.
    """
    n_images = len(img_sequence)
    if n_images == 0:
        raise ValueError('img_sequence is empty; nothing to visualize')
    if box_sequence is None:
        # Built once here instead of on every slider change.
        box_sequence = [None] * n_images
    elif len(box_sequence) < n_images:
        raise ValueError(
            'box_sequence has {} entries but img_sequence has {}'.format(
                len(box_sequence), n_images))

    def _view_data(index):
        visualize_datum(img_sequence[index], box_sequence[index])

    slider = widgets.IntSlider(
        value=0,
        min=0,
        max=n_images - 1,
        step=1,
        description='index:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        # 'd' is the d3-format integer type; the previous `readout_format='i'`
        # and `slider_color` arguments came from the ipywidgets 6 examples
        # (`slider_color` no longer exists in ipywidgets 7+).
        readout_format='d',
    )

    interact(_view_data, index=slider)

In [None]:
# Step through the frames with entry i of `boxes` marked on frame i.
visualize_data(img_sequence, box_sequence=boxes)

Browse the context captured before each tracking step:

In [None]:
# Contexts are shown on their own, with no point marked.
visualize_data(contexts, box_sequence=None)

Browse the response returned by the tracker for each frame:

In [None]:
# Shallow copy of the per-frame responses, used as the images to browse below.
score_maps = list(responses)

In [None]:
# 2-D responses are rendered with the 'jet' colormap by visualize_datum.
visualize_data(img_sequence=score_maps)