In [51]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.ops import roi_align
import numpy as np
import os
from got10k.trackers import Tracker
from got10k.datasets import GOT10k, UAV123
from got10k.utils.viz import show_frame

In [54]:
# --- optimisation ---
LR = 1e-4        # Adam learning rate used by TrackerA
EPOCHS = 10      # passes over the training dataset
BATCHSIZE = 8    # (query, search) pairs per optimizer step
LAMBDA = 1       # weight of the regression loss relative to the classification loss

# --- expected input image size, in pixels ---
W, H = 512, 512

In [49]:
# feature extractor
class ResnetFeature(nn.Module):
    """Multi-scale feature extractor on top of an ImageNet-pretrained ResNet-50.

    ``forward(x)`` returns the outputs of ``layer2``, ``layer3`` and ``layer4``
    (in that order). Only the stages up to ``layer4`` are evaluated. The
    classification head (``avgpool``/``fc``) stays inside ``self.f`` so the
    state-dict keys are unchanged, but it is never run.
    """

    def __init__(self, detach_features: bool = True):
        """
        Args:
            detach_features: if True (default, same as the previous hook-based
                version), the returned features carry no autograd history, so the
                backbone receives no gradients. Pass False to fine-tune it.
        """
        super(ResnetFeature, self).__init__()
        self.detach_features = detach_features
        self.f = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    def forward(self, x):
        """Return the ``(layer2, layer3, layer4)`` feature maps of ``x``.

        NOTE(review): assumes ``x`` is an NCHW image batch normalised the way the
        torchvision ResNet-50 weights expect -- TODO confirm in ``prepro``.
        """
        f = self.f
        # When features are detached anyway, don't build an autograd graph at all
        # (but never turn grad mode *on* inside an outer no_grad context).
        track_grad = torch.is_grad_enabled() and not self.detach_features
        with torch.set_grad_enabled(track_grad):
            x = f.maxpool(f.relu(f.bn1(f.conv1(x))))
            feat1 = f.layer2(f.layer1(x))
            feat2 = f.layer3(feat1)
            feat3 = f.layer4(feat2)
        return feat1, feat2, feat3

In [61]:
class FullNet(nn.Module):
    """Siamese tracking network (work in progress).

    The query (template) image and the search image both go through the same
    ``ResnetFeature`` backbone. The target box is pooled with ``roi_align`` from
    each of the three query feature levels, at every size in ``roi_sizes``.

    TODO: the classification/regression heads are not written yet. For now
    ``forward`` returns only the level-1 (layer2) ROI at ``roi_sizes[0]``.
    """

    def __init__(self, roi_sizes=(3, 5, 7)):
        """
        Args:
            roi_sizes: square output sizes used for ``roi_align`` on every level.
        """
        super(FullNet, self).__init__()
        self.feature_net = ResnetFeature()
        self.roi_sizes = tuple(roi_sizes)

    def forward(self, query_image, target_box, search_image):
        """
        Args:
            query_image: image batch the target boxes refer to.
            target_box: boxes in a format accepted by torchvision ``roi_align``
                (Tensor[K, 5] or a list of Tensor[L, 4] in (x1, y1, x2, y2)),
                in pixel coordinates of ``query_image``.
            search_image: image batch in which the target is searched.
        """
        # qf = query features, sf = search features
        qf = self.feature_net(query_image)
        sf = self.feature_net(search_image)  # unused until the heads exist

        # rois[level][k]: query feature level `level` pooled to roi_sizes[k]^2.
        rois = [
            [
                roi_align(
                    input=feat,
                    boxes=target_box,
                    output_size=size,
                    # feature-map / input-image ratio, taken from the actual
                    # input width rather than assuming it equals the global W
                    spatial_scale=feat.shape[-1] / query_image.shape[-1],
                    aligned=True,
                )
                for size in self.roi_sizes
            ]
            for feat in qf
        ]

        #return cls_output, reg_output
        return rois[0][0]

In [None]:
# implement tracker to conduct got10k experiments

#resize image and box, TODO
def prepro(image, box=None):
    return image, box

def find_box(cls_output, reg_output):
    """Select the best box from the classifier and regressor outputs -- TODO.

    Raises:
        NotImplementedError: always, until implemented. Previously the stub
            returned the undefined name ``found_box``, which surfaced as a
            misleading NameError.
    """
    raise NotImplementedError("find_box is not implemented yet")

def anno_to_labels(anno):
    """Build the (classification, regression) training labels from ground truth -- TODO.

    Raises:
        NotImplementedError: always, until implemented. Previously the stub
            returned the undefined names ``cls_label``/``reg_label``, which
            surfaced as a misleading NameError.
    """
    raise NotImplementedError("anno_to_labels is not implemented yet")

def read_image(image_file):
    """Load the image stored at ``image_file`` -- TODO.

    Raises:
        NotImplementedError: always, until implemented. Previously the stub
            returned the undefined name ``image``, which surfaced as a
            misleading NameError.
    """
    raise NotImplementedError("read_image is not implemented yet")


class TrackerA(Tracker):
    """got10k tracker wrapping ``FullNet``; also provides the per-sequence training step.

    ``init``/``update`` implement the got10k ``Tracker`` interface used by the
    experiments. ``step`` trains on one annotated sequence.
    """

    def __init__(self, name='TrackerA'):
        # got10k's Tracker.__init__ requires a tracker name.
        super(TrackerA, self).__init__(name=name)
        self.net = FullNet()
        self.optimizer = optim.Adam(self.net.parameters(), lr=LR)

    #this one we use for our training
    def step(self, image_files, annos):
        """Train on one sequence.

        Frame 0 gives the query image and target box. Every later frame is a
        search sample. An optimizer step runs every BATCHSIZE pairs, and any
        leftover pairs at the end of the sequence form a final, smaller batch.
        """
        self.net.train()
        self.qs, self.ss, self.bs, self.search_boxes = [], [], [], []
        for i in range(annos.shape[0]):
            image = read_image(image_files[i])
            anno = annos[i]

            #if first frame: it defines the query for the whole sequence
            if i == 0:
                self.query_image, self.target_box = prepro(image, anno)
                continue

            # prepro returns (image, box). Keep the search frame's own box: it is
            # the ground truth the training labels must come from.
            self.search_image, search_box = prepro(image, anno)

            self.qs.append(self.query_image)
            self.ss.append(self.search_image)
            self.bs.append(self.target_box)
            self.search_boxes.append(search_box)

            # train when batch is full
            if len(self.ss) == BATCHSIZE:
                self._train_batch()

        # don't silently drop the last frames of the sequence
        if self.ss:
            self._train_batch()

    def _train_batch(self):
        """Run one optimisation step on the accumulated pairs, then clear them."""
        cls_output, reg_output = self.net(torch.cat(self.qs, 0), self.bs, torch.cat(self.ss, 0))

        # labels describe where the target is in the *search* frames
        cls_label, reg_label = anno_to_labels(self.search_boxes)
        # TODO: get_losses is not defined yet
        cls_loss, reg_loss = get_losses(cls_output, reg_output, cls_label, reg_label)
        loss = cls_loss + LAMBDA * reg_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        #reset batch lists
        self.qs, self.ss, self.bs, self.search_boxes = [], [], [], []

    #save model while training
    def save(self, epoch, save_dir="./modelsave/"):
        """Write the network weights to ``<save_dir>/model_<epoch>.pth``.

        Creates ``save_dir`` if needed and returns the file path.
        """
        os.makedirs(save_dir, exist_ok=True)
        path = os.path.join(save_dir, "model_" + str(epoch) + ".pth")
        torch.save(self.net.state_dict(), path)
        return path

    #these two are used in experiments, TODO
    def init(self, image, box):
        """got10k hook: remember the first frame and its box as the query."""
        self.query_image, self.target_box = prepro(image, box)

    def update(self, image):
        """got10k hook: locate the target in ``image`` and return the found box."""
        # prepro returns (image, box); a search frame has no box
        self.search_image, _ = prepro(image)
        self.net.eval()
        with torch.no_grad():
            # one box list entry per image, exactly as step() batches them
            cls_output, reg_output = self.net(self.query_image, [self.target_box], self.search_image)
        found_box = find_box(cls_output, reg_output)
        return found_box

In [62]:
# Build the full network once for the smoke test below (this constructs the
# pretrained ResNet-50 backbone inside ResnetFeature).
f = FullNet()

In [63]:
# Smoke-test inputs: BATCHSIZE random query/search images of size H x W, plus
# one valid (x1, y1, x2, y2) pixel box per query image, as a list of (1, 4)
# tensors. torch.randn boxes would be tiny and often inverted (x2 < x1),
# which is not a valid roi_align box.
torch.manual_seed(0)
q = torch.randn(BATCHSIZE, 3, H, W)
s = torch.randn(BATCHSIZE, 3, H, W)
half = torch.tensor([W / 2, H / 2])
top_left = torch.rand(BATCHSIZE, 1, 2) * half
box_size = torch.rand(BATCHSIZE, 1, 2) * (half - 1) + 1   # >= 1 px, stays inside the image
b = list(torch.cat([top_left, top_left + box_size], dim=-1))

In [65]:
# Smoke test: run in eval mode so random inputs don't overwrite the pretrained
# BatchNorm running statistics, and skip autograd since nothing is trained.
f.eval()
with torch.no_grad():
    roi_out = f(q, b, s)
roi_out.shape

RuntimeError: Couldn't load custom C++ ops. This can happen if your PyTorch and torchvision versions are incompatible, or if you had errors while compiling torchvision from source. For further information on the compatible versions, check https://github.com/pytorch/vision#installation for the compatibility matrix. Please check your PyTorch version with torch.__version__ and your torchvision version with torchvision.__version__ and verify if they are compatible, and if not please reinstall torchvision so that it matches your PyTorch install.

In [None]:
# Train on every UAV123 sequence for EPOCHS epochs, checkpointing after each.
dataset = UAV123(root_dir='./data/UAV123/')
tracker = TrackerA()
for epoch in range(EPOCHS):
    for image_files, annos in dataset:
        tracker.step(image_files, annos)
    tracker.save(epoch)