In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models import resnet50
from torchvision.ops import roi_align
from torchvision.io import read_image
import numpy as np
import os
from got10k.trackers import Tracker
from got10k.datasets import DTB70, NfS, OTB, TColor128, TrackingNet, UAV123, VOT
from got10k.utils.viz import show_frame

In [None]:
# Hyper-parameters and input geometry.
LR = 1e-4          # Adam learning rate (TrackerA optimizer)
EPOCHS = 10        # passes over the training dataset
W = H = 512        # input width / height in pixels (W also sets roi_align's spatial scale)
LAMBDA = 1         # weight of the regression loss relative to the classification loss
BATCHSIZE = 4      # (query, search) pairs per optimisation step

In [None]:
# feature extractor
class ResnetFeature(nn.Module):
    """ResNet-50 backbone returning the outputs of layer2, layer3 and layer4.

    Only the stem and layer1-layer4 are run; the classification head
    (avgpool / fc) is skipped because its output is never used.

    Args:
        pretrained: forwarded to ``resnet50(pretrained=...)``.
        trainable: when False (default), the returned feature maps carry no
            autograd history, so the backbone receives no gradients. This
            matches the hook-based version, which called ``output.detach()``.
    """

    def __init__(self, pretrained=True, trainable=False):
        super(ResnetFeature, self).__init__()
        # Keep the attribute name ``f`` so existing state_dicts (keys ``f.*``) still load.
        self.f = resnet50(pretrained=pretrained)
        self.trainable = trainable

    def _extract(self, x):
        """Run the ResNet stem and layer1-layer4; return (layer2, layer3, layer4) outputs."""
        net = self.f
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        x = net.layer1(x)
        feat1 = net.layer2(x)
        feat2 = net.layer3(feat1)
        feat3 = net.layer4(feat2)
        return feat1, feat2, feat3

    def forward(self, x):
        """Return the three multi-scale feature maps (layer2, layer3, layer4) for ``x``."""
        if self.trainable:
            return self._extract(x)
        # Frozen backbone: no_grad gives the same values as detaching afterwards,
        # without first building an autograd graph that would be thrown away.
        with torch.no_grad():
            return self._extract(x)

# TODO: implement the region proposal head.
class RPN(nn.Module):
    """Region proposal head over the three per-level score maps built by FullNet.

    Placeholder: ``forward`` is not implemented yet and raises
    NotImplementedError. This keeps the notebook importable and makes the
    missing piece explicit.
    """

    def __init__(self):
        # super() must name this class; naming ResnetFeature raises TypeError
        # because RPN is not a subclass of it.
        super(RPN, self).__init__()

    def forward(self, scoremap1, scoremap2, scoremap3):
        """Map the per-level score maps to ``(cls_output, reg_output)``.

        Args:
            scoremap1, scoremap2, scoremap3: cross-correlation maps for the
                layer2/layer3/layer4 levels. NOTE(review): presumably
                (B, 48, h, w) each (3 template sizes x 16 groups) — confirm
                against FullNet.

        Returns:
            A 2-tuple ``(cls_output, reg_output)``; FullNet.forward unpacks two values.

        Raises:
            NotImplementedError: always, until the head is implemented.
        """
        raise NotImplementedError("RPN.forward is not implemented yet")

In [None]:
class FullNet(nn.Module):
    """Siamese tracker: ResNet features + multi-scale ROI cross-correlation + RPN.

    For each backbone level, the target box is pooled from the query features
    with ``roi_align`` at several template sizes. Each template is
    cross-correlated (grouped) with the search features of the same image.
    The resulting maps are concatenated along channels and passed to the RPN.
    """

    # ROI template sizes (k x k). Padding (k - 3) // 2 gives every template the
    # same output size as the 3x3 one, so the maps can be concatenated.
    TEMPLATE_SIZES = (3, 5, 7)
    # Channel groups per image in the cross-correlation; each group yields one map.
    XCORR_GROUPS = 16

    def __init__(self):
        super(FullNet, self).__init__()
        self.feature_net = ResnetFeature()
        self.rpn = RPN()

    @staticmethod
    def _xcorr(search_feat, template, groups=16, padding=0):
        """Grouped cross-correlation of each image's search features with its own template.

        Args:
            search_feat: (B, C, h, w) search features.
            template: (B, C, k, k) templates, exactly one per image, in batch order.
            groups: number of channel groups; C must be divisible by it.
            padding: spatial zero padding for the correlation.

        Returns:
            (B, groups, h', w') response maps, one per channel group.

        Raises:
            ValueError: if the number of templates differs from the batch size,
                or the channel counts are incompatible.
        """
        batch, channels = search_feat.shape[0], search_feat.shape[1]
        # With several templates per image, conv2d would still run but would
        # silently pair templates with the wrong image.
        if template.shape[0] != batch:
            raise ValueError(
                "expected exactly one template per image ({}), got {}".format(batch, template.shape[0]))
        if template.shape[1] != channels or channels % groups != 0:
            raise ValueError(
                "template channels ({}) must equal search channels ({}) and be divisible by {}".format(
                    template.shape[1], channels, groups))
        # Fold the batch into channels, so a single grouped conv correlates every
        # (image, channel group) pair with the matching slice of its template.
        weight = template.reshape(batch * groups, channels // groups, template.shape[-2], template.shape[-1])
        folded = search_feat.reshape(1, batch * channels, search_feat.shape[-2], search_feat.shape[-1])
        response = F.conv2d(folded, weight, groups=groups * batch, padding=padding)
        return response.reshape(batch, groups, response.shape[-2], response.shape[-1])

    def _level_scoremap(self, query_feat, search_feat, target_box):
        """Return the concatenated cross-correlation maps for one backbone level."""
        # feature width / input width; target_box is presumably in input-image pixels
        scale = query_feat.shape[-1] / W
        maps = []
        for size in self.TEMPLATE_SIZES:
            template = roi_align(input=query_feat, boxes=target_box, output_size=size,
                                 spatial_scale=scale, aligned=True)
            maps.append(self._xcorr(search_feat, template,
                                    groups=self.XCORR_GROUPS, padding=(size - 3) // 2))
        return torch.cat(maps, 1)

    def forward(self, query_image, target_box, search_image):
        """Score ``search_image`` against the target given by ``query_image`` and ``target_box``.

        Args:
            query_image: batch of template images (B, 3, H, W).
            target_box: boxes in any format ``roi_align`` accepts, with exactly
                one box per image (e.g. a list of B tensors of shape (1, 4)).
            search_image: batch of search images with the same batch size.

        Returns:
            ``(cls_output, reg_output)`` as produced by the RPN.
        """
        # one feature map per backbone level for each of the two images
        query_feats = self.feature_net(query_image)
        search_feats = self.feature_net(search_image)

        scoremap1, scoremap2, scoremap3 = (
            self._level_scoremap(qf, sf, target_box)
            for qf, sf in zip(query_feats, search_feats)
        )

        # region proposal
        cls_output, reg_output = self.rpn(scoremap1, scoremap2, scoremap3)
        return cls_output, reg_output

In [None]:
# implement tracker to conduct got10k experiments

# resize image and box, TODO
def prepro(image, box=None):
    """Preprocess an image and, optionally, its box.

    Currently the identity. It always returns a 2-tuple ``(image, box)``,
    where ``box`` is None when not supplied.
    """
    pair = image, box
    return pair

# get best box from classifier and regressor outputs, TODO
def find_box(cls_output, reg_output):
    """Decode the RPN outputs into the single best box for the current frame.

    Not implemented yet. Raising makes the gap explicit instead of failing
    on an undefined name.
    NOTE(review): presumably this should return an [x, y, w, h] box like the
    one got10k passes to ``Tracker.init`` — confirm against ``Tracker.track``.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("find_box is not implemented yet")

# get training labels from ground truth anno, TODO
def anno_to_labels(anno):
    """Build ``(cls_label, reg_label)`` training targets from ground-truth boxes.

    Not implemented yet. Raising makes the gap explicit instead of failing
    on undefined names.
    NOTE(review): TrackerA.step passes a list of per-sample boxes — confirm
    the expected box format when implementing.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("anno_to_labels is not implemented yet")

# TODO
def readimage(image_file):
    """Read an image file into a float32 tensor.

    ``torchvision.io.read_image`` returns a uint8 tensor. Torch tensors have
    no ``.astype`` (that is NumPy API), so the cast uses ``.to(torch.float32)``.
    Pixel values stay in [0, 255]; any normalisation is left to the caller.
    """
    return read_image(image_file).to(torch.float32)


class TrackerA(Tracker):
    """got10k tracker wrapping FullNet, plus a simple per-sequence training loop.

    ``init``/``update`` implement the got10k ``Tracker`` interface used by the
    experiment runners. ``step`` trains on one sequence and ``save`` writes a
    checkpoint.
    """

    def __init__(self, name='TrackerA'):
        # got10k's Tracker.__init__ takes a required ``name``; calling it with
        # no arguments raises TypeError.
        super(TrackerA, self).__init__(name=name)
        self.net = FullNet()
        self.optimizer = optim.Adam(self.net.parameters(), lr=LR)

    def step(self, image_files, annos):
        """Train on one sequence: frame 0 is the query, later frames are searches.

        Args:
            image_files: per-frame image paths, indexed like ``annos``.
            annos: per-frame ground-truth boxes with ``annos.shape[0]`` rows.
                NOTE(review): presumably got10k's (N, 4) [x, y, w, h] array;
                ``prepro`` is expected to convert it.
        """
        self.net.train()
        self._reset_batch()
        last = annos.shape[0] - 1
        for i in range(annos.shape[0]):
            image = readimage(image_files[i])
            anno = annos[i]

            # first frame defines the template (query image + target box)
            if i == 0:
                self.query_image, self.target_box = prepro(image, anno)
                continue

            # prepro returns an (image, box) pair. Keep the search frame's own
            # box, because the training labels must describe *this* frame.
            self.search_image, search_box = prepro(image, anno)

            self.qs.append(self.query_image)
            self.ss.append(self.search_image)
            self.bs.append(self.target_box)
            self.sa.append(search_box)

            # train on every full batch, and on the leftover partial batch at the end
            if len(self.qs) == BATCHSIZE or i == last:
                self._train_batch()

    def _reset_batch(self):
        """Clear the accumulated batch: queries, searches, template boxes and search boxes."""
        self.qs = []
        self.ss = []
        self.bs = []
        self.sa = []

    def _train_batch(self):
        """Run one optimisation step on the accumulated batch, then clear it."""
        cls_output, reg_output = self.net(torch.cat(self.qs, 0), self.bs, torch.cat(self.ss, 0))

        # labels come from the search frames' ground truth, not from the template box
        cls_label, reg_label = anno_to_labels(self.sa)
        # NOTE(review): get_losses is not defined anywhere in this notebook yet.
        cls_loss, reg_loss = get_losses(cls_output, reg_output, cls_label, reg_label)
        loss = cls_loss + LAMBDA * reg_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self._reset_batch()

    def save(self, epoch, save_dir="./modelsave/"):
        """Write the network's state_dict to ``<save_dir>/model_<epoch>.pth`` and return the path."""
        os.makedirs(save_dir, exist_ok=True)
        path = os.path.join(save_dir, "model_" + str(epoch) + ".pth")
        torch.save(self.net.state_dict(), path)
        return path

    def init(self, image, box):
        """got10k hook: set the template from the first frame and its box."""
        self.net.eval()
        self.query_image, self.target_box = prepro(image, box)

    def update(self, image):
        """got10k hook: locate the target in a new frame and return the found box."""
        # prepro returns (image, box); only the image is needed here
        self.search_image, _ = prepro(image)
        with torch.no_grad():
            cls_output, reg_output = self.net(self.query_image, self.target_box, self.search_image)
        return find_box(cls_output, reg_output)

In [None]:
# Smoke test: run FullNet end to end on random images with one valid box per image.
torch.manual_seed(0)
f = FullNet()
q = torch.randn(BATCHSIZE, 3, H, W)
s = torch.randn(BATCHSIZE, 3, H, W)
# roi_align expects (x1, y1, x2, y2) with x1 < x2 and y1 < y2. Boxes drawn with
# torch.randn are mostly degenerate, so use a fixed central box per image.
b = [torch.tensor([[W / 4, H / 4, 3 * W / 4, 3 * H / 4]]) for _ in range(BATCHSIZE)]
with torch.no_grad():
    o = f(q, b, s)

In [None]:
# Show the shape of each network output from the smoke test.
for output in o:
    print(output.shape)

In [None]:
dataset = UAV123(root_dir='./data/UAV123/')

# Training needs RPN.forward, prepro, anno_to_labels and get_losses, which are
# not implemented yet. Until then, this cell only inspects the data:
# it prints each sequence's first annotation, in a single pass.
RUN_TRAINING = False

if RUN_TRAINING:
    tracker = TrackerA()
    for epoch in range(EPOCHS):
        for image_files, annos in dataset:
            tracker.step(image_files, annos)
        tracker.save(epoch)
else:
    for image_files, annos in dataset:
        print(annos[0])

In [None]:
# OTB benchmark, TB-50 subset; download=True lets got10k fetch it into data/OTB.
d1 = OTB(
    root_dir='data/OTB',
    version='tb50',
    download=True,
)