<i>Copyright (c) Microsoft Corporation. All rights reserved.</i>

<i>Licensed under the MIT License.</i>

# Training a Multi-Object Tracking Model

## 00 Initialization

In [None]:
# Ensure edits to libraries are loaded and plotting is shown in the notebook.
# `%autoreload 2` reloads imported modules (e.g. utils_cv) before each cell runs,
# so library edits take effect without restarting the kernel.
# `%matplotlib inline` renders matplotlib figures directly below the cell.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#Regular Python Libraries
import os
import os.path as osp
import sys

# Third party tools
from ipywidgets import Video
import matplotlib.pyplot as plt
import torch
import torchvision

# Computer Vision repository
sys.path.append("../../")
from utils_cv.common.data import data_path, download, unzip_url
from utils_cv.common.gpu import which_processor, is_windows
from utils_cv.tracking.data import Urls
from utils_cv.tracking.dataset import TrackingDataset 
from utils_cv.tracking.model import TrackingLearner, write_video 

# Change matplotlib backend so that plots are shown for windows
# (on Windows, switch to the TkAgg backend before any plotting happens).
if is_windows():
    plt.switch_backend("TkAgg")

# Report the installed TorchVision version, then the machine's GPUs (if any)
# and the compute device torch/torchvision is using.
print(f"TorchVision: {torchvision.__version__}")
which_processor()

This prints the TorchVision version, your machine's GPUs (if it has any), and the compute device that `torch`/`torchvision` is using.

Next, set the training, inference, and data parameters used throughout this notebook.

In [None]:
# ---- Training ----
EPOCHS = 1
LEARNING_RATE = 1e-4
BATCH_SIZE = 1
# Where the finetuned model is saved.
# NOTE(review): MODEL_PATH is also passed to TrackingLearner in section 01,
# whose text says the baseline FairMOT model is loaded from ./models —
# confirm whether this is a load path, a save path, or both.
MODEL_PATH = "./models/fairmot_ft.pth"

# ---- Inference (passed to tracker.predict as conf_thres / track_buffer) ----
CONF_THRES = 0.3
TRACK_BUFFER = 300

# ---- Data ----
# Download and extract the datasets; exist_ok=True presumably reuses an
# already-extracted copy instead of failing — TODO confirm in utils_cv.
TRAIN_DATA_PATH = unzip_url(Urls.fridge_objects_path, exist_ok=True)
EVAL_DATA_PATH = unzip_url(Urls.carcans_annotations_path, exist_ok=True)

# Prefer the GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using torch device: {device}")

## 01 Finetune a Pretrained Model

Initialize the training dataset.

In [None]:
# Wrap the extracted training data in a TrackingDataset batched by BATCH_SIZE.
data_train = TrackingDataset(TRAIN_DATA_PATH, batch_size=BATCH_SIZE)

Initialize and load the model. We use the baseline FairMOT model, which must be downloaded [here](https://drive.google.com/file/d/1udpOPum8fJdoEQm6n0jsIgMMViOMFinu/view) and saved to `./models` before running the next cell.

In [None]:
# Create the learner from MODEL_PATH and the training dataset.
# NOTE(review): per the text above this presumably loads the baseline FairMOT
# weights — confirm MODEL_PATH points at the downloaded model.
tracker = TrackingLearner(MODEL_PATH, data_train)
# Show which model class the learner wraps.
print(f"Model: {type(tracker.model)}")

In [None]:
# Finetune the model for EPOCHS epochs with learning rate LEARNING_RATE.
tracker.fit(num_epochs=EPOCHS, lr=LEARNING_RATE)

In [None]:
# Plot the training losses (presumably those recorded by tracker.fit above).
tracker.plot_training_losses()

## 02 Evaluate

Run the finetuned tracker on the evaluation data and compute tracking metrics. Note that the data in `EVAL_DATA_PATH` follows the FairMOT input format.

In [None]:
# Track objects in the FairMOT-format evaluation data with the finetuned model.
eval_results = tracker.predict(
    EVAL_DATA_PATH, conf_thres=CONF_THRES, track_buffer=TRACK_BUFFER
)

In [None]:
# Score the tracking results against the data in EVAL_DATA_PATH
# (presumably its ground-truth annotations — see Urls.carcans_annotations_path).
eval_metrics = tracker.evaluate(eval_results, EVAL_DATA_PATH)

## 03 Predict

In [None]:
# Fetch the sample car-cans video into the shared data directory.
carcans_video_dest = osp.join(data_path(), "carcans.mp4")
input_video = download(Urls.carcans_video_path, carcans_video_dest)

In [None]:
# Run the finetuned tracker over every frame of the downloaded video.
test_results = tracker.predict(
    input_video, conf_thres=CONF_THRES, track_buffer=TRACK_BUFFER
)

In [None]:
# Destination for the rendered tracking video, next to the input video.
output_video = osp.join(data_path(), "carcans_output.mp4")

In [None]:
# Write a video of the tracking results for input_video to output_video
# (presumably with the tracks drawn on each frame — TODO confirm in utils_cv).
write_video(test_results, input_video, output_video)

In [None]:
# Display the output video inline as an ipywidgets Video widget.
Video.from_file(output_video)