In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
import torch.nn.functional as F
from torchvision.transforms import ToTensor

from utils.dataset import VideoDataset, MyConcatDataset
from utils.training import train_model
from utils.models import TrackNetV2MSE, TrackNetV2NLL

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Enable IPython's autoreload extension in mode 2, so edited modules (e.g. the local
# utils.* helpers imported above) are re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

# Dataset

Load dataset

In [None]:
# Configuration shared by every dataset and by the models created further down.
sequence_length = 3
one_output_frame = True

# (y extent, x extent): later cells divide x by image_size[1] and y by image_size[0].
image_size = (360, 640)

dataset_params = dict(image_size=image_size,
                      sigma=5,
                      sequence_length=sequence_length,
                      heatmap_mode='image',
                      duplicate_equality_threshold=0.97,
                      one_output_frame=one_output_frame,
                      drop_duplicate_frames=True)

finals_root = "../datasets/dataset_finales_2020_en/"


def _make_tensor_dataset(root, **extra):
    """Build a VideoDataset whose frames and targets are both converted with ToTensor."""
    return VideoDataset(root=root,
                        transform=ToTensor(),
                        target_transform=ToTensor(),
                        **extra,
                        **dataset_params)


# Untransformed copy, used only by the visual sanity check below.
dataset_demo = VideoDataset(root=finals_root, **dataset_params)

roots = [f'../datasets/dataset_lluis/game{i+1}' for i in range(5)]

# The last game is reserved for validation and therefore excluded from training.
# Building the same game with identical arguments into both sets would put every
# one of its samples in train AND val, so the validation loss would no longer
# measure generalisation.
roots_val = roots[-1:]
roots_train = roots[:-1]

# training dataset
dataset_train_list = [_make_tensor_dataset(finals_root, split='train')]
dataset_train_list.extend(_make_tensor_dataset(root) for root in roots_train)

dataset_train = MyConcatDataset(dataset_train_list)


# validation dataset
dataset_val_list = [_make_tensor_dataset(finals_root, split='val')]
dataset_val_list.extend(_make_tensor_dataset(root) for root in roots_val)

dataset_val = MyConcatDataset(dataset_val_list)

In [None]:
# Canvas size in pixels (converted to inches via dpi). The frames are landscape
# (image_size is 360 x 640, y extent first), so the canvas is landscape as well;
# a portrait canvas would leave large empty bands above and below the image.
w, h, dpi = 853, 480, 50

fig, ax = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)

# presumably frames[-1] is the most recent frame of the sequence and labels its
# heatmap — TODO confirm against VideoDataset
frames, labels = dataset_demo[57]
ax.imshow(frames[-1])
# Semi-transparent heatmap drawn on top of the frame.
ax.imshow(labels, alpha=0.6, cmap='gray')
ax.set_axis_off()
fig.tight_layout(pad=0)

plt.show()

# Training

Load model and set checkpoint folder

In [None]:
# MSE configuration: checkpoints go to their own folder and train_model is
# pre-bound to the mean-squared-error loss.
checkpoint_folder = './checkpoints/checkpoints_360_640_mse'
train_model_partial = partial(train_model, loss_function=F.mse_loss)

model = TrackNetV2MSE(one_output_frame=one_output_frame,
                      sequence_length=sequence_length)

In [None]:
# NLL configuration. This cell rebinds the same three names as the MSE cell
# above (model, checkpoint_folder, train_model_partial), so whichever of the
# two cells ran last decides what gets trained.
checkpoint_folder = './checkpoints/checkpoints_360_640_nll'
train_model_partial = partial(train_model, loss_function=F.nll_loss)

model = TrackNetV2NLL(sequence_length=sequence_length)

## Training loop

In [None]:
#TODO: check the training recipe

batch_size = 2

# data loaders: only the training batches are shuffled
data_loader_train = DataLoader(dataset_train, shuffle=True, batch_size=batch_size)
data_loader_val = DataLoader(dataset_val, batch_size=batch_size)

# Dataset descriptions handed to train_model as additional_info.
run_info = {
    'dataset_train': dataset_train.get_info(),
    'dataset_val': dataset_val.get_info(),
}

checkpoint_dict = train_model_partial(
    model,
    data_loader_train,
    data_loader_val,
    epochs=10,
    device=device,
    checkpoint_folder=checkpoint_folder,
    additional_info=run_info,
)

Plot train and validation loss

In [None]:
# Read both loss histories from the current checkpoint folder.
train_loss, val_loss = (
    np.loadtxt(os.path.join(checkpoint_folder, filename))
    for filename in ("loss_history.csv", "loss_history_val.csv")
)

fig, ax = plt.subplots()
ax.set_yscale('log')

# One curve per history, train first so the legend order is unchanged.
for curve, curve_label in ((train_loss, 'train loss'), (val_loss, 'val loss')):
    ax.plot(curve, label=curve_label)

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()

plt.show()

Load weights from checkpoint

In [None]:
model.to(device)

# Build the path with os.path.join instead of a backslash-separated literal:
# backslashes only act as separators on Windows, and "\c" is an invalid escape
# sequence that newer Python versions warn about.
# NOTE(review): this is an MSE checkpoint — confirm `model` is the MSE variant
# when this cell is run.
checkpoint_path = os.path.join("checkpoints",
                               "checkpoints_360_640_mse",
                               "checkpoint_0007_best.ckpt")
model.load(checkpoint_path, device=device)
model.eval();

# Analysis

## Compute error

In [None]:
from utils.testing import compute_positions, save_outputs
from torch.utils.data import Subset

In [None]:
data_loader_test = DataLoader(dataset_val, batch_size=4)
true_positions, predicted_positions = compute_positions(model, data_loader_test, device=device)

# TODO: improve this, to be made with save_outputs from utils.testing
# Positions are stored as fractions of the image: column 0 (x) is divided by
# image_size[1] and column 1 (y) by image_size[0].
x_extent, y_extent = image_size[1], image_size[0]
df_out = pd.DataFrame({
    'x_true': true_positions[:, 0] / x_extent,
    'y_true': true_positions[:, 1] / y_extent,
    'x_pred': predicted_positions[:, 0] / x_extent,
    'y_pred': predicted_positions[:, 1] / y_extent,
})

output_csv = os.path.join(checkpoint_folder, 'val_output.csv')
df_out.to_csv(output_csv, index=False)

#save_outputs(true_positions, predicted_positions, dataset_val, output_folder=checkpoint_folder)

## Analyze error

In [None]:
df_out = pd.read_csv(os.path.join(checkpoint_folder, 'val_output.csv'))

# Scale the stored fractions back by image_size, then take the Euclidean
# distance between the true and predicted positions.
dx = image_size[1] * (df_out['x_true'] - df_out['x_pred'])
dy = image_size[0] * (df_out['y_true'] - df_out['y_pred'])
error = np.asarray(np.sqrt(dx**2 + dy**2))

In [None]:
threshold = 5

# Share of samples whose error is within the threshold. The mean of the boolean
# mask is the fraction of True entries; an empty error array reports NaN
# instead of dividing by zero.
share_within = 100 * np.mean(error <= threshold) if len(error) else float('nan')
# Fixed-point formatting: the general format ".2g" keeps only two significant
# digits and renders 100 as "1e+02".
print(f"Error smaller than {threshold} pixels: {share_within:.2f}%")

# Errors above hist_range are clipped to it, so they are counted in the last bin.
hist_range=15
plt.hist(error.clip(max=hist_range), bins=np.arange(hist_range+1), density=True, align='left', rwidth=0.8)
plt.show()

# Show example

Get maximum of the heatmap

In [None]:
def get_maximum_coordinates(heatmaps):
    """Return the (x, y) coordinates of the global maximum of a heatmap.

    Parameters
    ----------
    heatmaps : array-like
        Either a single 2-D heatmap indexed as [row, column], or a 3-D stack
        indexed as [map, row, column]. A stack may contain any number of maps.

    Returns
    -------
    tuple
        ``(x, y)``, i.e. ``(column, row)``, of the largest value. For a stack
        the maximum is searched across all maps. Ties are resolved in favour
        of the first occurrence in row-major order (earliest map first).

    Raises
    ------
    ValueError
        If ``heatmaps`` is empty or is neither 2-D nor 3-D.
    """
    heatmaps = np.asarray(heatmaps)
    if heatmaps.ndim not in (2, 3):
        raise ValueError(f"expected a 2-D or 3-D array, got {heatmaps.ndim} dimension(s)")
    if heatmaps.size == 0:
        raise ValueError("cannot locate the maximum of an empty heatmap")

    # argmax works on the flattened array and returns the first maximum in
    # row-major order; unravel_index turns it back into ([map,] row, column).
    *_, row, col = np.unravel_index(np.argmax(heatmaps), heatmaps.shape)

    # Image convention: x is the column index, y is the row index.
    return col, row

Produce output heatmap

In [None]:
frames, heatmaps = dataset_val[20]

# Last three channels, moved to channels-last so imshow can draw them.
# presumably these are the RGB planes of the most recent frame — TODO confirm
frames_np = frames.numpy()[-3:].transpose(1, 2, 0)
heatmaps_np = heatmaps.numpy()

# Single-sample forward pass without gradient tracking.
with torch.no_grad():
    batch = frames.unsqueeze(0).to(device)
    outputs = model(batch)

# Drop the batch dimension again.
outputs_np = outputs.cpu().numpy()[0]

In [None]:
w, h, dpi = 1280, 720, 100
fig, ax = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)

# Frame first, predicted heatmap blended on top.
ax.imshow(frames_np)
ax.imshow(outputs_np[0], cmap='magma', alpha=0.5)

# Mark the heatmap maxima: ground truth in white, prediction in yellow.
markers = (
    (heatmaps_np[0], 'w', 'True position'),
    (outputs_np[0], 'y', 'Predicted position'),
)
for heatmap, colour, marker_label in markers:
    peak_x, peak_y = get_maximum_coordinates(heatmap)
    ax.scatter(peak_x, peak_y, color=colour, label=marker_label, alpha=0.5)

ax.legend(framealpha=0.5)
ax.set_axis_off()
fig.tight_layout()

plt.show()

In [None]:
w, h, dpi = 1280, 720*2, 100
fig, axs = plt.subplots(nrows=2, figsize=(w/dpi, h/dpi), dpi=dpi)

# Top panel: the input frame. Bottom panel: the predicted heatmap with markers.
frame_ax, heatmap_ax = axs[0], axs[1]
frame_ax.imshow(frames_np)
heatmap_ax.imshow(outputs_np[0], cmap='magma')

markers = (
    (heatmaps_np[0], 'w', 'True position'),
    (outputs_np[0], 'y', 'Predicted position'),
)
for heatmap, colour, marker_label in markers:
    peak_x, peak_y = get_maximum_coordinates(heatmap)
    heatmap_ax.scatter(peak_x, peak_y, color=colour, label=marker_label)

heatmap_ax.legend(framealpha=0.5)

fig.tight_layout()
for ax in axs:
    ax.set_axis_off()

plt.show()

# Duplicate frames demo

In [None]:
import cv2
# NOTE(review): the datasets above live under "../datasets/..."; confirm that
# "../videos/..." is the intended location for the raw video.
root = "../videos/dataset_finales_2020_en/"
video_path = os.path.join(root, "video.mp4")
cap = cv2.VideoCapture(video_path)

# VideoCapture does not raise when the file cannot be opened, so fail here with
# a clear message instead of later on the first read.
if not cap.isOpened():
    raise FileNotFoundError(f"Could not open video file: {video_path}")

In [None]:
# Seek to the frame index used for the demo, then grab two consecutive frames.
start_frame = 13620
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

ok1, frame1 = cap.read()
ok2, frame2 = cap.read()

# read() reports failure through its first return value and hands back no
# image in that case; check it so the problem is reported here rather than as
# an opaque error inside cvtColor.
if not (ok1 and ok2):
    raise RuntimeError(f"Could not read two consecutive frames starting at frame {start_frame}")

# OpenCV decodes to BGR; matplotlib expects RGB.
frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2RGB)
frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2RGB)

In [None]:
# Show the two consecutive frames one after the other, each as its own figure.
for frame in (frame1, frame2):
    plt.imshow(frame)
    plt.show()

In [None]:
# Per-pixel absolute difference between the two frames. cv2.absdiff is used
# because subtracting unsigned 8-bit images directly wraps around (e.g. 3 - 5
# becomes 254), and np.abs cannot undo that on an unsigned array.
frame_diff = cv2.absdiff(frame2, frame1)

plt.imshow(frame_diff)
# NOTE(review): for an RGB image matplotlib ignores cmap/norm, so this colorbar
# does not map to the pixel values shown — confirm it is wanted.
plt.colorbar()
plt.show()

# Modify additional info for old runs

In [None]:
from utils.storage import load_checkpoint_dict, save_checkpoint

In [None]:
# Load the stored checkpoint dict of an older run so its metadata can be edited below.
d = load_checkpoint_dict('checkpoints/checkpoints_512_3f_mse')

In [None]:
# Handle to the 'additional_info' entry of the loaded checkpoint dict.
# presumably a mutable dict, so edits below are visible through `d` — TODO confirm
additional_info = d['additional_info']

In [None]:
# NOTE(review): additional_info_train / additional_info_val are not assigned anywhere
# in the visible notebook, so on a fresh kernel this cell raises NameError. The training
# cell fills the same two keys from dataset_train.get_info() / dataset_val.get_info();
# confirm where the values for this old run are supposed to come from.
additional_info['dataset_train'] = additional_info_train
additional_info['dataset_val'] = additional_info_val

In [None]:
# Hand the (edited) checkpoint dict to save_checkpoint with the folder of the old run.
save_checkpoint(d, 'checkpoints/checkpoints_512_3f_mse')