In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

import torch

# Output folder for every figure saved by this notebook (relative to the
# current working directory). Use 'figures' instead to keep the images next
# to the notebook.
figures_folder = '../TFM/pictures'

# exist_ok=True replaces the racy "check, then create" pattern: the folder is
# created if missing and nothing happens if it already exists.
os.makedirs(figures_folder, exist_ok=True)

# Reset matplotlib to its default style so the saved figures do not depend on
# any style selected earlier in the kernel session.
plt.style.use('default')

# Examples of model input and output

**Visualization of the input frames and the target heatmap**

In [None]:
from train_configurations.utils import get_standard_test_dataset
from train_configurations import tracknet_v2
from detection.testing import _get_checkpoint_filename, _get_results_folder

# Build the TrackNetV2 training configuration, its standard test dataset and
# an (untrained) model instance, switched to evaluation mode.
config = tracknet_v2.Config()
dataset = get_standard_test_dataset(tracknet_v2, 'prova')
model = config.get_model()
model.eval()

# Resolve where training stored the checkpoint for this configuration, then
# load those weights into the model.
checkpoint_dir = config._checkpoint_folder
results_folder = _get_results_folder(checkpoint_dir, None)
checkpoint_filename = _get_checkpoint_filename(checkpoint_dir)
checkpoint_path = os.path.join(results_folder, checkpoint_filename)

model.load(checkpoint_path)

Compute the target and predicted heatmaps for one example sample of the test dataset

In [None]:
# Number of samples in the test dataset (valid values for dataset_element).
num_samples = len(dataset)
num_samples

In [None]:
# Index of the test-dataset sample to visualize.
dataset_element = 15

with torch.no_grad():
    frames, heatmap = dataset[dataset_element]
    frames = frames.to(torch.float32)

    # Add a batch dimension for the forward pass, then drop every singleton
    # dimension from the prediction before converting it to NumPy.
    heatmap_pred = model(frames.unsqueeze(0)).squeeze().numpy()

    heatmap = heatmap.to(torch.float32).squeeze().numpy()
    # Channels-first (C, H, W) -> channels-last (H, W, C), as matplotlib expects.
    frames = frames.numpy().transpose(1, 2, 0)

# The input stacks several 3-channel frames along the channel axis; split it
# back into one (H, W, 3) image per frame. The frame count is derived from the
# channel axis instead of being hard-coded, so other input lengths also work.
n_input_frames = frames.shape[2] // 3
frames = [frames[:, :, 3*i:3*(i+1)] for i in range(n_input_frames)]

In [None]:
# Figure (a): each input frame with its labelled position circled, plus the
# target heatmap in the last panel of a 2x2 grid.
w, h, dpi = 2*640, 2*360, 100
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(w/dpi, h/dpi), dpi=dpi)
axs = axs.ravel()

labels = dataset._label_df
for offset, (frame, ax) in enumerate(zip(frames, axs)):
    ax.set_title(f'Frame {offset+1}')
    ax.imshow(frame)

    # Label row for this frame; the coordinates appear to be normalized to
    # [0, 1], so they are scaled by the frame width/height to get pixels.
    row = offset + dataset_element
    frame_h, frame_w = frame.shape[0], frame.shape[1]
    ax.scatter(labels['x'][row] * frame_w, labels['y'][row] * frame_h,
               zorder=100, facecolors='none', edgecolors='y', linewidths=3, s=150)

heatmap_ax = axs[3]
heatmap_ax.imshow(heatmap, cmap='gray')
heatmap_ax.set_title('Target heatmap')

fig.tight_layout()
fig.savefig(os.path.join(figures_folder, 'sample_input_a.png'))
plt.show()

In [None]:
# Figure (b): the last input frame with the target heatmap blended on top.
w, h, dpi = 640, 360, 100
figsize = (w/dpi, h/dpi)
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

last_frame = frames[-1]
# NOTE(review): imshow ignores cmap for 3-channel (RGB) images, so 'gray'
# only matters if last_frame is single-channel.
ax.imshow(last_frame, cmap='gray')
# Half-transparent overlay keeps the underlying frame visible.
ax.imshow(heatmap, cmap='gray', alpha=0.5)

# Drop axes and padding so the saved file contains only the image.
ax.set_axis_off()
fig.tight_layout(pad=0)

fig.savefig(os.path.join(figures_folder, 'sample_input_b.png'))
plt.show()

# Training history

In [None]:
import numpy as np

import matplotlib.pyplot as plt

In [None]:
# NOTE(review): placeholder cell. It only displays [0, 1] and does not yet
# show any training history; fill in or remove.
np.arange(2)