In [6]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Author: Milan Ondrasovic <milan.ondrasovic@gmail.com>

import math
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import cv2 as cv
import numpy as np

import matplotlib.pyplot as plt

from PIL import Image, ImageStat, ImageOps

from sot.bbox import BBox
from sot.utils import Size

# Global matplotlib styling applied to every figure in this notebook.
plt.style.use('tableau-colorblind10')
# NOTE(review): presumably matplotlib falls back to another font (with a
# warning) when Times New Roman is not installed on the machine.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 11

# Automatically reload edited modules (e.g. the local `sot` package) before
# each cell executes. Re-running this cell only prints an "already loaded"
# notice, which is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Load two checkpoints onto the CPU so their entries can be compared below.
# Later cells iterate them with .items(), i.e. they are expected to be
# mapping-like state dicts.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
cpu_device = torch.device('cpu')

my_model_file_path = "../model_siamfc.pth"
new_model_file_path = "../siamfc_alexnet_e50.pth"

my_model = torch.load(my_model_file_path, map_location=cpu_device)
new_model = torch.load(new_model_file_path, map_location=cpu_device)

# Number of entries in each checkpoint (displayed as the cell output).
len(my_model), len(new_model)

In [None]:
# Print both state dicts side by side, pairing entries by position, so that
# mismatching parameter names or tensor shapes are easy to spot.
# zip() stops at the shorter of the two dicts.
separator = "*" * 80
for (my_key, my_tensor), (new_key, new_tensor) in zip(my_model.items(), new_model.items()):
    print(separator)
    print(f"M: {my_key} {list(my_tensor.shape)}")
    print(f"N: {new_key} {list(new_tensor.shape)}")

In [None]:
# Use collections.OrderedDict directly; the bare name otherwise only resolves
# through `from typing import *` (the deprecated typing.OrderedDict alias).
from collections import OrderedDict

# Build a state dict that stores new_model's tensors under my_model's
# parameter names. Entries are paired purely by position, so both state dicts
# must list their tensors in the same order. Fail loudly instead of letting
# zip() silently drop trailing entries or saving a checkpoint whose tensor
# shapes would not match the target model.
if len(my_model) != len(new_model):
    raise ValueError(
        f"state dicts differ in size: {len(my_model)} vs {len(new_model)} entries")

new_model_data = OrderedDict()
for (my_key, my_tensor), (new_key, new_tensor) in zip(my_model.items(), new_model.items()):
    if my_tensor.shape != new_tensor.shape:
        raise ValueError(
            f"shape mismatch pairing {my_key!r} {list(my_tensor.shape)} "
            f"with {new_key!r} {list(new_tensor.shape)}")
    new_model_data[my_key] = new_tensor

for key, tensor in new_model_data.items():
    print(f"{key} {list(tensor.shape)}")

torch.save(new_model_data, "../model_combined.pth")

In [11]:
from torch.utils.data import random_split

# Partition the indices 0..9 into two disjoint subsets of sizes 3 and 7.
# Seeding the generator makes the split reproducible across runs.
split_rng = torch.Generator().manual_seed(42)
a, b = random_split(range(10), lengths=[3, 7], generator=split_rng)

# One index per line, same as looping over `a` and printing each item.
print(*a, sep="\n")

2
6
1


In [16]:
from sot.utils import concat_imgs, ColorT

class SiameseTrackingVisualizer:
    """OpenCV preview of a Siamese tracker's per-frame state.

    Each call to ``show_curr_state`` composes the exemplar, instance and
    response-map images together with the current frame via ``concat_imgs``,
    shows the result in a window, waits for a key press and then destroys
    the window.
    """

    def __init__(
            self, exemplar_img: np.ndarray, *,
            border_value: ColorT = (0, 0, 0),
            win_name: str = "Siamese Tracking Preview") -> None:
        """Store the fixed exemplar image and the display settings.

        Args:
            exemplar_img: Exemplar image included in every preview.
            border_value: Forwarded to ``concat_imgs`` as ``border_value``.
            win_name: Name of the OpenCV window used for the preview.
        """
        self.exemplar_img: np.ndarray = exemplar_img
        self.border_value: ColorT = border_value
        self.win_name: str = win_name

    def _compose_preview(
            self, curr_frame: np.ndarray, instance_img: np.ndarray,
            response_map: np.ndarray) -> np.ndarray:
        """Join exemplar | instance | response with ``row=True``, then combine
        that strip with ``curr_frame`` using ``row=False``."""
        # NOTE(review): row=True / row=False presumably mean horizontal /
        # vertical concatenation — confirm against sot.utils.concat_imgs.
        patches = (self.exemplar_img, instance_img, response_map)
        patch_strip = concat_imgs(
            patches, row=True, border_value=self.border_value)
        return concat_imgs(
            (curr_frame, patch_strip), row=False,
            border_value=self.border_value)

    def show_curr_state(
            self, curr_frame: np.ndarray, instance_img: np.ndarray,
            response_map: np.ndarray) -> None:
        """Display the composite preview and block until a key is pressed.

        Args:
            curr_frame: Full current video frame.
            instance_img: Instance (search-region) image for this frame.
            response_map: Response map, rendered next to the other patches.
        """
        composite = self._compose_preview(
            curr_frame, instance_img, response_map)
        window = self.win_name
        cv.imshow(window, composite)
        cv.waitKey(0)  # delay 0 => wait indefinitely for a key press
        cv.destroyWindow(window)

# Smoke-test the visualizer with synthetic images. Every image is uint8 so
# the composite has a single dtype. The original `np.ones(...) * (255, 0, 0)`
# produced float64 arrays, and cv.imshow scales float images by 255, so
# mixing them with uint8 frames would not display as intended.
# With the default BGR channel order, (255, 0, 0) is blue and (0, 255, 0) is green.
exemplar_img = np.full((127, 127, 3), (255, 0, 0), dtype=np.uint8)
instance_img = np.full((255, 255, 3), (0, 255, 0), dtype=np.uint8)

visualizer = SiameseTrackingVisualizer(exemplar_img)

# Seeded generator so the random previews are reproducible. Integer draws
# need no rounding, so the redundant `.round()` calls are gone.
rng = np.random.default_rng(42)

for _ in range(5):
    response_map = rng.integers(0, 256, size=(272, 272, 1), dtype=np.uint8)
    response_map = cv.cvtColor(response_map, cv.COLOR_GRAY2BGR)
    curr_frame = rng.integers(0, 256, size=(600, 800, 3), dtype=np.uint8)
    visualizer.show_curr_state(curr_frame, instance_img, response_map)
