<i>Copyright (c) Microsoft Corporation. All rights reserved.</i>

<i>Licensed under the MIT License.</i>

# Training a Multi-Object Tracking Model

## Initialization

In [1]:
import sys

sys.path.append("../../")

import os
import time
import matplotlib.pyplot as plt
from typing import Iterator
from pathlib import Path
from PIL import Image
from random import randrange
from typing import Tuple
import torch
import torchvision
from torchvision import transforms
import scrapbook as sb

from ipywidgets import Video
from utils_cv.tracking.dataset import TrackingDataset
from utils_cv.tracking.model import TrackingLearner

from utils_cv.common.gpu import which_processor, is_windows

# Switch the matplotlib backend to TkAgg on Windows so that plots are displayed
if is_windows():
    plt.switch_backend("TkAgg")

# Report the installed TorchVision version and the processor torch is using
print(f"TorchVision: {torchvision.__version__}")
which_processor()

TorchVision: 0.4.0a0
Torch is using GPU: Tesla K80


The output above shows your machine's GPU (if it has one) and the computing device that `torch` and `torchvision` are using.

In [2]:
# Ensure edits to libraries are loaded and plotting is shown in the notebook.
# Autoreload mode 2 reloads imported modules before each cell runs, so changes
# to library source files take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline

Next, set some model runtime parameters.

In [3]:
# Runtime settings for training.
EPOCHS, BATCH_SIZE = 2, 1
LEARNING_RATE = 1e-4
# NOTE(review): SAVE_MODEL is not read by the cells shown here; presumably it
# toggles saving the trained model later in the notebook -- confirm.
SAVE_MODEL = True

# Prefer CUDA when torch can see a GPU; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using torch device: {device}")

Using torch device: cuda


## Prepare Training Dataset

In [4]:
# Root folder of the training data; listing it is a quick check of its contents
DATA_PATH_TRAIN = "./data/odFridgeObjects_FairMOTformat/"
os.listdir(DATA_PATH_TRAIN)

['labels_with_ids', '.ipynb_checkpoints', 'images']

## Load Training Dataset

In [5]:
# Build the training dataset from DATA_PATH_TRAIN using the configured batch size.
data_train = TrackingDataset(DATA_PATH_TRAIN, batch_size=BATCH_SIZE)

dataset summary
OrderedDict([('default', 4.0)])
total # identities: 5
start index
OrderedDict([('default', 0)])


## Finetune a Pretrained Model

In [6]:
# Create the learner from the training dataset and print the class of its model
tracker = TrackingLearner(data_train) 
print(f"Model: {type(tracker.model)}")

Model: <class 'utils_cv.tracking.references.fairmot.models.networks.pose_dla_dcn.DLASeg'>


In [7]:
# Train for EPOCHS epochs using LEARNING_RATE as the learning rate
tracker.fit(num_epochs=EPOCHS, lr=LEARNING_RATE)

