In [None]:
import torch
from lib.dataset.crop.lasot import parser_lasot
from lib.dataset.crop.lasot import par_crop_lasot
from lib.dataset.crop.lasot import gen_json_lasot
import matplotlib.pyplot as plt
from tracking.basic_model.et_tracker import ET_Tracker
from tracking.basic_model.cmt_et_tracker import CMT_ET_Tracker
from tracking.basic_model.wavelet_et_tracker import WAVE_ET_Tracker
import torch.optim as optim
from torch.utils.data import DataLoader
from __future__ import division
import torchvision.transforms as transforms
from scipy.ndimage.filters import gaussian_filter
from easydict import EasyDict as edict
from torch.utils.data import Dataset
from lib.utils.utils import *
from tqdm import tqdm

In [None]:
from tracking.basic_model.et_tracker import ET_Tracker
from tracking.basic_model.cmt_et_tracker import CMT_ET_Tracker
from tracking.basic_model.wavelet_et_tracker import WAVE_ET_Tracker

In [None]:
# videos = ["kangaroo"]
# for j in videos:
#     for i in tqdm(range(1,21)):
#         curr = f"kangaroo-{i}"
#         par_crop_lasot.crop_video("C:\\Users\\tommy\\ettrack\\data\\LaSOT", j, curr, "C:\\Users\\tommy\\ettrack\\data\\LaSOT_cropped\\crop511", 511)

In [None]:
# parser_lasot.main(dataDir='C:\\Users\\tommy\\ettrack\\data\\LaSOT', dataCropDir='C:\\Users\\tommy\\ettrack\\data\\LaSOT_cropped')

In [None]:
# gen_json_lasot.main(dataCropDir='C:\\Users\\tommy\\ettrack\\data\\LaSOT_cropped')

In [None]:
# Training configuration, declared as one nested literal. EasyDict converts the
# nested plain dicts recursively, so attribute access (config.ET.TRAIN.LR, ...)
# and item access (config.ET.DATASET['LASOT']) both keep working.
config = edict({
    # ------config for general parameters------
    'GPUS': "0,1,2,3",
    'WORKERS': 32,
    'PRINT_FREQ': 10,
    'OUTPUT_DIR': 'logs',
    'CHECKPOINT_DIR': 'snapshot',

    'ET': {
        # train
        'TRAIN': {
            'SCRATCH': False,
            'EMA': 0.9998,
            'NEG_WEIGHT': 0.1,
            'GROUP': "resrchvc",
            'EXID': "setting1",
            'MODEL': "ET",
            'RESUME': False,
            'START_EPOCH': 0,
            'END_EPOCH': 50,
            'TEMPLATE_SIZE': 128,
            'SEARCH_SIZE': 256,
            'STRIDE': 16,
            'BATCH': 32,
            'PRETRAIN': 'pretrain.model',
            'LR_POLICY': 'log',
            'LR': 0.001,
            'LR_END': 0.00001,
            'MOMENTUM': 0.9,
            'WEIGHT_DECAY': 0.0001,
            'WHICH_USE': ['LASOT'],  # VID or 'GOT10K'
            'FREEZE_LAYER': [],
        },
        # sections that are declared but left empty
        'TEST': {},
        'REID': {},
        'TUNE': {},
        'DATASET': {
            # per-dataset sections; only LaSOT is filled in
            'VID': {},
            'GOT10K': {},
            'COCO': {},
            'DET': {},
            # LaSOT
            'LASOT': {
                'PATH': 'C:\\Users\\tommy\\ettrack\\data\\LaSOT_cropped\\crop511',
                'ANNOTATION': 'C:\\Users\\tommy\\ettrack\\data\\LaSOT_cropped\\train.json',
                'RANGE': 100,
                'USE': 34887,
            },
            'YTB': {},
            'VISDRONE': {},
            'MIX': {
                'DIST': 'beta',
                'ALPHA': 1.0,
                'BETA': 1.0,
                'MIN': 0,
                'MAX': 1,
                'PROB': 1,
            },
            # augmentation (template branch)
            'SHIFT': 4,
            'SCALE': 0.05,
            'COLOR': 1,
            'FLIP': 0,
            'BLUR': 0,
            'GRAY': 0,
            'MIXUP': 0,
            'CUTOUT': 0,
            'CHANNEL6': 0,
            'LABELSMOOTH': 0,
            'ROTATION': 0,
            # augmentation (search branch)
            'SHIFTs': 64,
            'SCALEs': 0.18,
        },
        # own parameters
        'DEVICE': 'cuda',
    },
})


In [None]:
# Dedicated RNG instance used by the `_shuffle` methods below to shuffle index lists.
# It is a separate `random.Random()` object, so calls to `random.seed(...)` on the
# module-level generator do not affect it.
sample_random = random.Random()


class OceanDataset(Dataset):
    """Training dataset yielding augmented (template, search) pairs with labels.

    The dataset wraps one `subData` object per entry of ``cfg.ET.TRAIN.WHICH_USE``.
    Each item is produced by picking a video, sampling a template/search frame pair
    from it, augmenting both crops, and building the classification label and the
    (l, t, r, b) regression targets on the ``size x size`` score-map grid, where
    ``size = round(SEARCH_SIZE / STRIDE)``.
    """

    def __init__(self, cfg):
        """Read sizes and augmentation settings from ``cfg`` and load all sub-datasets.

        :param cfg: EasyDict-style config exposing ``cfg.ET.TRAIN`` and ``cfg.ET.DATASET``.
        """
        super(OceanDataset, self).__init__()
        # pair information
        self.template_size = cfg.ET.TRAIN.TEMPLATE_SIZE
        self.search_size = cfg.ET.TRAIN.SEARCH_SIZE

        # side length of the label / score map
        self.stride = cfg.ET.TRAIN.STRIDE
        self.size = round(self.search_size / self.stride)

        # aug information
        self.color = cfg.ET.DATASET.COLOR
        self.flip = cfg.ET.DATASET.FLIP
        self.rotation = cfg.ET.DATASET.ROTATION
        self.blur = cfg.ET.DATASET.BLUR
        self.shift = cfg.ET.DATASET.SHIFT
        self.scale = cfg.ET.DATASET.SCALE
        self.gray = cfg.ET.DATASET.GRAY
        self.label_smooth = cfg.ET.DATASET.LABELSMOOTH
        self.mixup = cfg.ET.DATASET.MIXUP
        self.cutout = cfg.ET.DATASET.CUTOUT

        # aug for search image
        self.shift_s = cfg.ET.DATASET.SHIFTs
        self.scale_s = cfg.ET.DATASET.SCALEs

        self.grids()

        # The `random.random()` draws below happen once, here, so the set of extra
        # transforms is fixed for the lifetime of the dataset (a probability of 0
        # never enables a transform, a probability of 1 always does).
        # NOTE(review): `transforms.Cutout` does not look like a torchvision transform;
        # that branch is only evaluated when CUTOUT > 0 -- confirm before enabling it.
        self.transform_extra = transforms.Compose(
            [transforms.ToPILImage(), ] +
            ([transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), ] if self.color > random.random() else [])
            + ([transforms.RandomHorizontalFlip(), ] if self.flip > random.random() else [])
            + ([transforms.RandomRotation(degrees=10), ] if self.rotation > random.random() else [])
            + ([transforms.Grayscale(num_output_channels=3), ] if self.gray > random.random() else [])
            + ([transforms.Cutout(n_holes=1, length=16)] if self.cutout > random.random() else [])
        )

        # train data information
        print('train datas: {}'.format(cfg.ET.TRAIN.WHICH_USE))
        self.train_datas = []  # all train dataset
        start = 0
        self.num = 0
        for data_name in cfg.ET.TRAIN.WHICH_USE:
            dataset = subData(cfg, data_name, start)
            self.train_datas.append(dataset)
            start += dataset.num  # real video number
            self.num += dataset.num_use  # the number used for subset shuffle

        self._shuffle()
        print(cfg)

        # normalization
        mean = [0.485, 0.456, 0.406]  # IMAGENET MEAN (RGB)
        std = [0.229, 0.224, 0.225]  # IMAGENET STD (RGB)
        self.transform_norm = transforms.Normalize(mean=mean, std=std)

    def __len__(self):
        """Number of samples per epoch (sum of the ``USE`` values of all sub-datasets)."""
        return self.num

    def __getitem__(self, index):
        """
        pick a video/frame --> pairs --> data aug --> label

        :return: (template, search, out_label, reg_label, reg_weight, bbox) where
                 template/search are normalised channel-first float tensors and
                 bbox is the augmented search box as a float32 array [x1, y1, x2, y2]
        """

        index = self.pick[index]
        dataset, index = self._choose_dataset(index)

        template, search = dataset._get_pairs(index, dataset.data_name)
        template, search = self.check_exists(index, dataset, template, search)

        template_image = cv2.cvtColor(cv2.imread(template[0]), cv2.COLOR_BGR2RGB)  # numpy array
        search_image = cv2.cvtColor(cv2.imread(search[0]), cv2.COLOR_BGR2RGB)  # numpy array

        template_box = self._toBBox(template_image, template[1])
        search_box = self._toBBox(search_image, search[1])

        template, tbox, _ = self._augmentation(template_image, template_box, self.template_size)
        search, bbox, dag_param = self._augmentation(search_image, search_box, self.search_size, search=True)

        # from PIL image to numpy
        template = np.array(template)
        search = np.array(search)
        # 2020.08.13 positive region is adaptive to the stride
        out_label = self._dynamic_label([self.size, self.size], dag_param.shift,
                                        rPos=16 / self.stride)

        reg_label, reg_weight = self.reg_label(bbox)

        # HWC -> CHW
        template, search = map(lambda x: np.transpose(x, (2, 0, 1)).astype(np.float32), [template, search])
        # Normalization: scale to [0, 1] then apply the ImageNet mean/std
        template, search = map(lambda x: self.transform_norm(torch.tensor(x) / 255.0), [template, search])
        return template, search, out_label, reg_label, reg_weight, np.array(bbox, np.float32)

    # ------------------------------------
    # function groups for selecting pairs
    # ------------------------------------
    def grids(self):
        """
        each element of feature map on input search image
        :return: None; stores ``grid_to_search_x`` / ``grid_to_search_y`` (each size x size),
                 the search-image pixel coordinate of every score-map cell
        """
        sz = self.size

        sz_x = sz // 2
        sz_y = sz // 2

        x, y = np.meshgrid(np.arange(0, sz) - np.floor(float(sz_x)),
                           np.arange(0, sz) - np.floor(float(sz_y)))

        self.grid_to_search = {}
        self.grid_to_search_x = x * self.stride + self.search_size // 2
        self.grid_to_search_y = y * self.stride + self.search_size // 2

    def reg_label(self, bbox):
        """
        generate regression label
        :param bbox: [x1, y1, x2, y2]
        :return: (reg_label, weight): reg_label is [size, size, 4] holding the distances
                 [l, t, r, b] from every grid location to the box sides; weight is
                 [size, size] with 1.0 where the location lies strictly inside the box
        """
        x1, y1, x2, y2 = bbox
        l = self.grid_to_search_x - x1
        t = self.grid_to_search_y - y1
        r = x2 - self.grid_to_search_x
        b = y2 - self.grid_to_search_y
        l, t, r, b = map(lambda x: np.expand_dims(x, axis=-1), [l, t, r, b])
        reg_label = np.concatenate((l, t, r, b), axis=-1)
        reg_label_min = np.min(reg_label, axis=-1)
        # all four distances positive <=> the grid location is inside the box
        inds_nonzero = (reg_label_min > 0).astype(float)

        return reg_label, inds_nonzero

    def check_exists(self, index, dataset, template, search):
        """Return a (template, search) pair whose image files exist on disk.

        While a file of the current pair is missing, a new pair is drawn from a random
        video of ``dataset``.

        :param index: video index of the incoming pair (only replaced when resampling)
        :param dataset: sub-dataset providing ``data_name``, ``num`` and ``_get_pairs``
        :param template: (image_path, annotation) of the template frame
        :param search: (image_path, annotation) of the search frame
        :return: (template, search) with existing image paths
        """
        name = dataset.data_name
        # Parenthesised on purpose: without the parentheses `and` binds tighter than
        # `or`, so every name containing 'RGBT' (including the excluded 'RGBTRGB' and
        # 'RGBTT') was treated as a two-path multi-modal dataset.
        multi_modal = ('RGBT' in name or 'GTOT' in name) \
            and 'RGBTRGB' not in name and 'RGBTT' not in name

        while True:
            if multi_modal:
                # here the first element is indexed again, i.e. a sequence of paths
                found = os.path.exists(template[0][0]) and os.path.exists(search[0][0])
            else:
                found = os.path.exists(template[0]) and os.path.exists(search[0])
                if not found:
                    print(f'paths do not exist: {template[0]}, {search[0]}')

            if found:
                return template, search

            # Resample over the videos that actually exist in this sub-dataset; the
            # previous hard-coded range (0..100) raised IndexError for small datasets
            # and never reached videos beyond the first 101 in large ones.
            index = random.randint(0, dataset.num - 1)
            template, search = dataset._get_pairs(index, name)

    def _shuffle(self):
        """
        random shuffle: concatenate the pick lists of all sub-datasets, shuffled,
        until at least ``self.num`` indices are collected
        """
        pick = []
        m = 0
        while m < self.num:
            p = []
            for subset in self.train_datas:
                sub_p = subset.pick
                p += sub_p
            sample_random.shuffle(p)

            pick += p
            m = len(pick)
        self.pick = pick
        print("dataset length {}".format(self.num))

    def _choose_dataset(self, index):
        """Map a global video index to ``(sub-dataset, index local to that sub-dataset)``.

        :raises IndexError: if ``index`` lies beyond the last sub-dataset (the original
                            fell through and returned None, failing later on unpacking)
        """
        for dataset in self.train_datas:
            if dataset.start + dataset.num > index:
                return dataset, index - dataset.start
        raise IndexError('video index {} does not belong to any loaded dataset'.format(index))

    def _get_image_anno(self, video, track, frame, RGBT_FLAG=False):
        """
        get image and annotation

        NOTE(review): reads ``self.root`` / ``self.labels``, which are not set anywhere
        in this class (they exist on `subData`); nothing in this class calls it.
        """

        frame = "{:06d}".format(frame)
        if not RGBT_FLAG:
            image_path = join(self.root, video, "{}.{}.x.jpg".format(frame, track))
            image_anno = self.labels[video][track][frame]
            return image_path, image_anno
        else:  # two image paths per frame; the two annotations are averaged
            in_image_path = join(self.root, video, "{}.{}.in.x.jpg".format(frame, track))
            rgb_image_path = join(self.root, video, "{}.{}.rgb.x.jpg".format(frame, track))
            image_anno = self.labels[video][track][frame]
            in_anno = np.array(image_anno[-1][0])
            rgb_anno = np.array(image_anno[-1][1])

            return [in_image_path, rgb_image_path], (in_anno + rgb_anno) / 2

    def _get_pairs(self, index):
        """
        get training pairs

        NOTE(review): relies on ``self.videos`` / ``self.labels`` / ``self.frame_range``,
        which this class never sets; `__getitem__` uses the sub-dataset's `_get_pairs`.
        """
        video_name = self.videos[index]
        video = self.labels[video_name]
        track = random.choice(list(video.keys()))
        track_info = video[track]
        try:
            frames = track_info['frames']
        except KeyError:
            frames = list(track_info.keys())

        template_frame = random.randint(0, len(frames) - 1)

        left = max(template_frame - self.frame_range, 0)
        right = min(template_frame + self.frame_range, len(frames) - 1) + 1
        search_range = frames[left:right]
        template_frame = int(frames[template_frame])
        search_frame = int(random.choice(search_range))

        return self._get_image_anno(video_name, track, template_frame), \
            self._get_image_anno(video_name, track, search_frame)

    def _posNegRandom(self):
        """
        uniform random number from [-1, 1)
        """
        return random.random() * 2 - 1.0

    def _toBBox(self, image, shape):
        '''
        Build the target box centred in ``image``.

        image: input image
        shape: bounding box, either [x1, y1, x2, y2] or (w, h)
        return: corner-format box centred at the image centre, with the target size
                rescaled by template_size / sqrt((w + c) * (h + c)), c = 0.5 * (w + h)
        '''
        imh, imw = image.shape[:2]
        if len(shape) == 4:
            w, h = shape[2] - shape[0], shape[3] - shape[1]
        else:
            w, h = shape

        context_amount = 0.5
        exemplar_size = self.template_size

        wc_z = w + context_amount * (w + h)
        hc_z = h + context_amount * (w + h)

        s_z = np.sqrt(wc_z * hc_z)
        scale_z = exemplar_size / s_z
        w = w * scale_z
        h = h * scale_z
        cx, cy = imw // 2, imh // 2
        bbox = center2corner(Center(cx, cy, w, h))
        return bbox

    def _crop_hwc(self, image, bbox, out_sz, padding=(0, 0, 0)):
        """
        crop ``bbox`` out of ``image`` and resize it to out_sz x out_sz with one affine
        warp; regions outside the image are filled with ``padding``
        """
        bbox = [float(x) for x in bbox]
        a = (out_sz - 1) / (bbox[2] - bbox[0])
        b = (out_sz - 1) / (bbox[3] - bbox[1])
        c = -a * bbox[0]
        d = -b * bbox[1]
        mapping = np.array([[a, 0, c],
                            [0, b, d]]).astype(np.float64)
        crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding)
        return crop

    def _draw(self, image, box, name):
        """
        draw image for debugging: box, its centre and the centre coordinates are drawn
        on a copy of ``image`` which is written to the file ``name``
        """
        draw_image = np.array(image.copy())
        x1, y1, x2, y2 = map(lambda x: int(round(x)), box)
        cv2.rectangle(draw_image, (x1, y1), (x2, y2), (0, 255, 0))
        cv2.circle(draw_image, (int(round(x1 + x2) / 2), int(round(y1 + y2) / 2)), 3, (0, 0, 255))
        cv2.putText(draw_image, '[x: {}, y: {}]'.format(int(round(x1 + x2) / 2), int(round(y1 + y2) / 2)),
                    (int(round(x1 + x2) / 2) - 3, int(round(y1 + y2) / 2) - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.3,
                    (255, 255, 255), 1)
        cv2.imwrite(name, draw_image)

    def _draw_reg(self, image, grid_x, grid_y, reg_label, reg_weight, save_path, index):
        """
        visualization: for every grid location with positive weight, draw the box
        implied by its regression target and save the result as ``<index>.jpg``
        reg_label: [l, t, r, b]
        """
        draw_image = image.copy()
        save_name = join(save_path, '{:06d}.jpg'.format(index))
        h, w = reg_weight.shape
        for i in range(h):
            for j in range(w):
                if not reg_weight[i, j] > 0:
                    continue
                else:
                    x1 = int(grid_x[i, j] - reg_label[i, j, 0])
                    y1 = int(grid_y[i, j] - reg_label[i, j, 1])
                    x2 = int(grid_x[i, j] + reg_label[i, j, 2])
                    y2 = int(grid_y[i, j] + reg_label[i, j, 3])

                    draw_image = cv2.rectangle(draw_image, (x1, y1), (x2, y2), (0, 255, 0))

        cv2.imwrite(save_name, draw_image)

    def _mixupRandom(self):
        """
        uniform random number from [0.3, 0.7)
        """
        return random.random() * 0.4 + 0.3

    # ------------------------------------
    # function for data augmentation
    # ------------------------------------
    def _augmentation(self, image, bbox, size, search=False):
        """
        data augmentation for input pairs

        :param image: HWC numpy image
        :param bbox: target box in ``image`` (corner format with x1/y1/x2/y2 attributes)
        :param size: output crop side length
        :param search: use the search-branch shift/scale ranges instead of the template ones
        :return: (augmented image, box in crop coordinates, sampled shift/scale params)
        """
        shape = image.shape
        # NOTE(review): the centre is built from (rows // 2, cols // 2) while `_toBBox`
        # uses (width // 2, height // 2); the two agree only for square inputs -- confirm.
        crop_bbox = center2corner((shape[0] // 2, shape[1] // 2, size, size))
        param = edict()

        if search:
            param.shift = (self._posNegRandom() * self.shift_s, self._posNegRandom() * self.shift_s)  # shift
            param.scale = (
            (1.0 + self._posNegRandom() * self.scale_s), (1.0 + self._posNegRandom() * self.scale_s))  # scale change
        else:
            param.shift = (self._posNegRandom() * self.shift, self._posNegRandom() * self.shift)  # shift
            param.scale = (
            (1.0 + self._posNegRandom() * self.scale), (1.0 + self._posNegRandom() * self.scale))  # scale change

        crop_bbox, _ = aug_apply(Corner(*crop_bbox), param, shape)

        # express the target box relative to the crop origin, then undo the crop scaling
        x1, y1 = crop_bbox.x1, crop_bbox.y1
        bbox = BBox(bbox.x1 - x1, bbox.y1 - y1, bbox.x2 - x1, bbox.y2 - y1)

        scale_x, scale_y = param.scale
        bbox = Corner(bbox.x1 / scale_x, bbox.y1 / scale_y, bbox.x2 / scale_x, bbox.y2 / scale_y)

        image = self._crop_hwc(image, crop_bbox, size)  # shift and scale

        if self.blur > random.random():
            image = gaussian_filter(image, sigma=(1, 1, 0))

        image = self.transform_extra(image)  # other data augmentation
        return image, bbox, param

    def _mixupShift(self, image, size):
        """
        random shift mixed-up image (shift of up to 64 pixels per axis, no scaling)
        """
        shape = image.shape
        crop_bbox = center2corner((shape[0] // 2, shape[1] // 2, size, size))
        param = edict()

        param.shift = (self._posNegRandom() * 64, self._posNegRandom() * 64)  # shift
        crop_bbox, _ = aug_apply(Corner(*crop_bbox), param, shape)

        image = self._crop_hwc(image, crop_bbox, size)  # shift and scale

        return image

    # ------------------------------------
    # function for creating training label
    # ------------------------------------
    def _dynamic_label(self, fixedLabelSize, c_shift, rPos=2, rNeg=0):
        """Classification label whose positive region follows the augmentation shift.

        :param fixedLabelSize: label side length (int) or [h, w]
        :param c_shift: (shift_x, shift_y) in pixels applied to the search crop
        :param rPos: cells within this block distance of the target centre get 1
        :param rNeg: cells closer than this (and not positive) get 0.5
        """
        if isinstance(fixedLabelSize, int):
            fixedLabelSize = [fixedLabelSize, fixedLabelSize]

        d_label = self._create_dynamic_logisticloss_label(fixedLabelSize, c_shift, rPos, rNeg)

        return d_label

    def _create_dynamic_logisticloss_label(self, label_size, c_shift, rPos=2, rNeg=0):
        """Build the square label map; see `_dynamic_label` for the parameters."""
        if isinstance(label_size, int):
            sz = label_size
        else:
            sz = label_size[0]

        # label centre, moved by the crop shift converted from pixels to grid cells
        sz_x = sz // 2 + int(-c_shift[0] / self.stride)
        sz_y = sz // 2 + int(-c_shift[1] / self.stride)

        x, y = np.meshgrid(np.arange(0, sz) - np.floor(float(sz_x)),
                           np.arange(0, sz) - np.floor(float(sz_y)))

        dist_to_center = np.abs(x) + np.abs(y)  # Block metric
        label = np.where(dist_to_center <= rPos,
                         np.ones_like(y),
                         np.where(dist_to_center < rNeg,
                                  0.5 * np.ones_like(y),
                                  np.zeros_like(y)))
        return label




# ---------------------
# for a single dataset
# ---------------------

class subData(object):
    """
    for training with multi dataset

    Holds the annotations of a single dataset (one entry of ``cfg.ET.DATASET``) and
    samples template/search frame pairs from its videos.
    """

    def __init__(self, cfg, data_name, start):
        """
        :param cfg: config exposing ``cfg.ET.DATASET[data_name]`` with RANGE, USE, PATH
                    and ANNOTATION entries
        :param data_name: key of the dataset section, e.g. 'LASOT'
        :param start: global index of this dataset's first video
        """
        self.data_name = data_name
        self.start = start

        info = cfg.ET.DATASET[data_name]
        self.frame_range = info.RANGE
        self.num_use = info.USE
        self.root = info.PATH

        # only the JSON parsing needs the file handle; clean up after it is closed
        with open(info.ANNOTATION) as fin:
            self.labels = json.load(fin)
        self._clean()
        self.num = len(self.labels)  # video number

        self._shuffle()

    def _clean(self):
        """
        remove empty videos/frames/annos in dataset

        Also stores the sorted integer frame ids of every track under the extra key
        'frames' and fills ``self.videos`` with the remaining video names.
        """
        # no frames
        to_del = []
        for video in self.labels:
            for track in self.labels[video]:
                frames = self.labels[video][track]
                frames = list(map(int, frames.keys()))
                frames.sort()
                self.labels[video][track]['frames'] = frames
                if len(frames) <= 0:
                    print("warning {}/{} has no frames.".format(video, track))
                    to_del.append((video, track))

        for video, track in to_del:
            # pop with a default: a missing entry is simply ignored
            self.labels.get(video, {}).pop(track, None)

        # no track/annos
        to_del = []

        if self.data_name == 'YTB':
            to_del.append('train/1/YyE0clBPamU')  # This video has no bounding box.
        print(self.data_name)

        for video in self.labels:
            if len(self.labels[video]) <= 0:
                print("warning {} has no tracks".format(video))
                to_del.append(video)

        for video in to_del:
            self.labels.pop(video, None)

        self.videos = list(self.labels.keys())
        print('{} loaded.'.format(self.data_name))

    def _shuffle(self):
        """
        shuffle to get random pairs index (video)

        :return: list of ``num_use`` global video indices, built from repeated
                 shuffles of ``range(start, start + num)``
        :raises ValueError: if samples are requested but the dataset has no videos
                            (the loop below could never terminate in that case)
        """
        if self.num_use > 0 and self.num <= 0:
            raise ValueError(
                'dataset {} has no videos but {} samples were requested'.format(
                    self.data_name, self.num_use))

        lists = list(range(self.start, self.start + self.num))
        m = 0
        pick = []
        while m < self.num_use:
            sample_random.shuffle(lists)
            pick += lists
            m += self.num

        self.pick = pick[:self.num_use]
        return self.pick

    def _get_image_anno(self, video, track, frame):
        """
        get image and annotation

        :return: (path ``<root>/<video>/<frame:06d>.<track>.x.jpg``, annotation stored
                 for that frame)
        """

        frame = "{:06d}".format(frame)

        image_path = join(self.root, video, "{}.{}.x.jpg".format(frame, track))
        image_anno = self.labels[video][track][frame]
        return image_path, image_anno

    def _get_pairs(self, index, data_name):
        """
        get training pairs

        :param index: index into ``self.videos`` (local to this dataset)
        :param data_name: unused here; kept for the caller's signature
        :return: ((template_path, template_anno), (search_path, search_anno)); the
                 search frame lies within ``frame_range`` positions of the template
        """
        video_name = self.videos[index]
        video = self.labels[video_name]
        track = random.choice(list(video.keys()))
        track_info = video[track]
        try:
            frames = track_info['frames']
        except KeyError:
            # tracks that did not go through `_clean` carry no 'frames' entry
            frames = list(track_info.keys())

        template_frame = random.randint(0, len(frames) - 1)

        left = max(template_frame - self.frame_range, 0)
        right = min(template_frame + self.frame_range, len(frames) - 1) + 1
        search_range = frames[left:right]

        template_frame = int(frames[template_frame])
        search_frame = int(random.choice(search_range))

        return self._get_image_anno(video_name, track, template_frame), \
            self._get_image_anno(video_name, track, search_frame)

    def _get_negative_target(self, index=-1):
        """
        dasiam neg: one random frame of a video (a random video when index is -1)
        """
        if index == -1:
            index = random.randint(0, self.num - 1)
        video_name = self.videos[index]
        video = self.labels[video_name]
        track = random.choice(list(video.keys()))
        track_info = video[track]

        frames = track_info['frames']
        frame = random.choice(frames)

        return self._get_image_anno(video_name, track, frame)

In [None]:
# Build the training dataset from `config`; the constructor loads the annotation
# JSON of every dataset listed in config.ET.TRAIN.WHICH_USE and prints the config.
lasot = OceanDataset(config)

In [None]:
def seed_everything(seed):
    """Seed NumPy, Python's `random`, and PyTorch (CPU and every CUDA device).

    Also exports PYTHONHASHSEED; that variable only influences Python processes
    started after this call, not the hashing of the running kernel.

    :param seed: integer seed shared by all generators
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU (manual_seed only seeds the current one)
    # and is silently ignored when CUDA is not available
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


seed_everything(2000)
# Check gpu: `is_built()` only says PyTorch was compiled with MPS support, while
# `is_available()` also requires a usable MPS device on this machine.
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
# Instantiate the tracker and move its parameters to the device selected above.
ettracker = ET_Tracker(linear_reg=True)
ettracker.to(device)

In [None]:
# SGD over all tracker parameters; the hyper-parameters are hard-coded here and are
# not read from `config.ET.TRAIN` (which lists LR = 0.001).
optimizer = optim.SGD(ettracker.parameters(), lr=0.02, momentum=0.9, weight_decay=1e-4)
# Multiply the learning rate by 0.1 every 5 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

In [None]:
# Batches of 32 training samples, reshuffled by the loader each epoch.
train_loader = DataLoader(lasot, batch_size=32, shuffle=True)

In [None]:
losses = []  # one float per optimisation step
min_loss = float('inf')  # lowest batch loss seen so far, kept as a plain float
PATH = './ettracker_min.pth'  # weights saved right after the lowest-loss batch

# Put the model in training mode explicitly so that re-running this cell after an
# evaluation cell (which calls .eval()) does not silently train in eval mode.
ettracker.train()
for epoch in tqdm(range(10)):
    for i, data in enumerate(train_loader):
        # the sixth element (the bbox array returned by the dataset) is not used here
        template, search, out_label, reg_label, reg_weight, some_array = data
        template = template.to(device)
        search = search.to(device)
        out_label = out_label.to(device)
        reg_label = reg_label.to(device)
        reg_weight = reg_weight.to(device)

        optimizer.zero_grad()
        output = ettracker(template, search, out_label, reg_label, reg_weight)
        cls, reg = output
        loss = cls + reg
        # Convert once to a Python float: comparing and storing the tensor itself kept
        # a reference to the batch's autograd graph alive in `min_loss`.
        loss_value = loss.item()
        losses.append(loss_value)
        loss.backward()
        optimizer.step()

        if loss_value < min_loss:
            min_loss = loss_value
            print(f'Epoch: {epoch}, Batch: {i}, MIN Loss: {loss_value}')
            torch.save(ettracker.state_dict(), PATH)  # save model to path
            continue
        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Loss: {loss_value}')
    # one scheduler step per epoch
    scheduler.step()
    
        

In [None]:
# Save the weights as they are after the last training step (the lowest-loss
# checkpoint was written separately to './ettracker_min.pth' by the training loop).
PATH = './ettracker.pth'

torch.save(ettracker.state_dict(), PATH)  # save model to path

In [None]:
# Persist the per-batch training losses collected in `losses`.
import pickle
with open('losses.pkl', 'wb') as f:
    pickle.dump(losses, f)

In [None]:
# Run the LaSOT parser on the test-split directories (same helper as the
# commented-out call for the training split near the top of the notebook).
parser_lasot.main(dataDir='C:\\Users\\tommy\\ettrack\\data\\test\\LaSOT', dataCropDir='C:\\Users\\tommy\\ettrack\\data\\test\\LaSOT_cropped')

In [None]:
# Generate the annotation JSON for the cropped test split; the test config below
# reads 'train.json' from this directory -- presumably the file written here, confirm.
gen_json_lasot.main(dataCropDir='C:\\Users\\tommy\\ettrack\\data\\test\\LaSOT_cropped')

In [None]:
# The evaluation configuration is the training configuration with a different LaSOT
# split. Deriving it from `config` (defined in the training-config cell above)
# replaces an 80-line copy that differed in exactly the three values overridden
# below, so the two configurations can no longer drift apart by accident.
import copy

config_test = copy.deepcopy(config)  # deep copy: edits here never touch `config`

# LaSOT (test split)
config_test.ET.DATASET.LASOT.PATH = 'C:\\Users\\tommy\\ettrack\\data\\test\\LaSOT_cropped\\crop511'
config_test.ET.DATASET.LASOT.ANNOTATION = 'C:\\Users\\tommy\\ettrack\\data\\test\\LaSOT_cropped\\train.json'
config_test.ET.DATASET.LASOT.USE = 2002


In [None]:
# Dataset over the test split described by `config_test`.
lasot_test = OceanDataset(config_test)

In [None]:
# Batches of 32 test samples in the dataset's own order (no loader-level shuffling).
test_loader = DataLoader(lasot_test, batch_size=32, shuffle=False)

In [None]:
# Per-batch loss of the final model on the TEST split.
# Fix: the loop used to iterate `train_loader`, so the "test" losses were in fact
# computed on training data; `test_loader` was built but never used.
losses_test = []
ettracker.eval()  # set once; the mode does not change inside the loop
with torch.no_grad():  # no gradients are needed for evaluation
    for i, data in enumerate(test_loader):
        template, search, out_label, reg_label, reg_weight, some_array = data
        template = template.to(device)
        search = search.to(device)
        out_label = out_label.to(device)
        reg_label = reg_label.to(device)
        reg_weight = reg_weight.to(device)
        output = ettracker(template, search, out_label, reg_label, reg_weight)
        cls, reg = output
        loss = cls + reg
        losses_test.append(loss.item())

In [None]:
# Persist the per-batch evaluation losses of the final model.
with open('test_losses.pkl', 'wb') as f:
    pickle.dump(losses_test, f)

In [None]:
# Display the collected per-batch evaluation losses as the cell output.
losses_test

In [None]:
# Rebuild the tracker and restore the weights saved at the lowest training batch loss.
min_loss_model = ET_Tracker(linear_reg=True)
# map_location remaps the stored tensors onto the device selected in this session, so
# a checkpoint written from a GPU run can still be loaded when that GPU is absent.
min_loss_model.load_state_dict(torch.load('./ettracker_min.pth', map_location=device))
min_loss_model.to(device)
min_loss_model.eval()

In [None]:
# Per-batch loss of the lowest-training-loss checkpoint on the TEST split.
# Fix: the loop used to iterate `train_loader`, so these numbers were computed on
# training data rather than on the held-out split.
min_losses_test = []
min_loss_model.eval()  # set once; the mode does not change inside the loop
with torch.no_grad():  # no gradients are needed for evaluation
    for i, data in enumerate(test_loader):
        template, search, out_label, reg_label, reg_weight, some_array = data
        template = template.to(device)
        search = search.to(device)
        out_label = out_label.to(device)
        reg_label = reg_label.to(device)
        reg_weight = reg_weight.to(device)
        output = min_loss_model(template, search, out_label, reg_label, reg_weight)
        cls, reg = output
        loss = cls + reg
        min_losses_test.append(loss.item())

In [None]:
# Persist the per-batch evaluation losses of the lowest-training-loss checkpoint.
with open('min_loss_test_losses.pkl', 'wb') as f:
    pickle.dump(min_losses_test, f)

In [None]:
# with open("min_loss_test_losses.pkl", "rb") as fp:   # Unpickling
#     b = pickle.load(fp)