In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch

from detection.testing import _get_checkpoint_filename, _get_results_folder
from detection.visualization import show_loss_history

from train_configurations.utils import get_standard_test_dataset
from train_configurations import (regnet_y_400mf, regnet_y_800mf,
                                  tracknet_v2, tracknet_v2_2f,
                                  tracknet_v2_4f, tracknet_v2_6f,
                                  tracknet_v2_rnn, tracknet_v2_rnn_scheduler)

plt.style.use('default')

# Output folders for the generated figures (relative to the notebook's working directory).
detection_folder = '../TFM/figures/detection'
trajectories_folder = '../TFM/figures/trajectories'

# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(detection_folder, exist_ok=True)
os.makedirs(trajectories_folder, exist_ok=True)

%load_ext autoreload
%autoreload 2

# Sequence of train configurations. Later cells rely on this exact order:
# they slice it by position ([2:] for the heatmap models, [:-2] for the
# non-recurrent ones, [-2:] for the recurrent ones) and pair it by index with
# the per-configuration lists of the architecture table (n, h, epochs, batch size).
training_configurations = [regnet_y_400mf, regnet_y_800mf,
                           tracknet_v2_2f, tracknet_v2,
                           tracknet_v2_4f, tracknet_v2_6f,
                           tracknet_v2_rnn, tracknet_v2_rnn_scheduler]


# Display name of a train configuration (used in figure titles and file names).
def tc_name(training_configuration):
    """Return the short display name of a train configuration module.

    The name is the last dotted component of the module's ``__name__``.
    Two modules are renamed for display: ``tracknet_v2`` becomes
    ``tracknet_v2_3f`` and ``tracknet_v2_rnn_scheduler`` becomes
    ``tracknet_v2_rnn_s``.
    """
    short_name = training_configuration.__name__.rsplit('.', 1)[-1]
    renames = {
        'tracknet_v2': 'tracknet_v2_3f',
        'tracknet_v2_rnn_scheduler': 'tracknet_v2_rnn_s',
    }
    return renames.get(short_name, short_name)

# Detection

## Architectures and train configuration explanations

Table of the architectures and their training settings.

In [None]:
# Hyper-parameters of each configuration, listed in the same order as
# `training_configurations` (they are paired by position):
#   n: presumably the number of input frames (it matches the "Nf" suffix of
#      the tracknet configuration names) — confirm for regnet / rnn
#   h: 0 for the feed-forward models, 3 for the recurrent ones
#      (presumably the recurrent history length — confirm)
table = []
n = [3, 3, 2, 3, 4, 6, 1, 1]
h = [0, 0, 0, 0, 0, 0, 3, 3]
epochs = [40, 40, 20, 42, 23, 20, 35, 35]
batch_size = [16, 16, 8, 8, 8, 8, 8, 8]

# With positional pairing, an extra entry would be silently ignored and a
# missing one would only fail mid-loop, so check the lengths up front.
for column in (n, h, epochs, batch_size):
    assert len(column) == len(training_configurations), \
        'hyper-parameter lists must have one entry per train configuration'

for ti, tc in enumerate(training_configurations):
    d = {'configuration': tc_name(tc),
         'n': n[ti],
         'h': h[ti],
         'epochs': epochs[ti],
         'batch size': batch_size[ti]
         }
    table.append(d)

output_df = pd.DataFrame(table)
display(output_df)
print(output_df.to_latex(index=False))


## Examples

First let's define some utility functions:

In [None]:
def get_trained_model(training_configuration, training_phase=None):
    """Build a configuration's model, load its final weights on CPU and fetch its test dataset.

    Args:
        training_configuration: train configuration module exposing ``Config``.
        training_phase: optional sub-folder of the checkpoint folder holding the
            weights; when given, the dataset is requested with ``is_rnn=True``.

    Returns:
        ``(dataset, model)``: the standard test dataset (requested with the
        ``'prova'`` argument — presumably a split name, confirm) and the model
        in eval mode with the checkpoint loaded.
    """
    config = training_configuration.Config()

    base_folder = config._checkpoint_folder
    checkpoint_folder = (base_folder if training_phase is None
                         else os.path.join(base_folder, training_phase))

    # results folder first, then the checkpoint file name (same call order as before)
    checkpoint_path = os.path.join(_get_results_folder(checkpoint_folder, None),
                                   _get_checkpoint_filename(checkpoint_folder))

    model = config.get_model()
    model.eval()
    model.load(checkpoint_path, device='cpu')

    is_rnn = training_phase is not None
    dataset = get_standard_test_dataset(training_configuration, 'prova', is_rnn=is_rnn)
    return dataset, model


def get_output(model, dataset, i):
    """Run ``model`` on element ``i`` of ``dataset`` and return the output as a NumPy array.

    The element's input is cast to float32 and given a leading batch
    dimension. Gradients are disabled. All singleton dimensions are squeezed
    out of the result.
    """
    frame, _ = dataset[i]
    batch = frame.to(torch.float32)[None]

    with torch.no_grad():
        prediction = model(batch).numpy()
    return prediction.squeeze()


def get_frame(dataset, i):
    """Return the last three input channels of dataset element ``i`` as an (H, W, 3) float32 array.

    Assumes the input is channel-first with the frames stacked along the
    channel axis, 3 channels per frame, so the last three channels are the
    most recent frame — TODO confirm against the dataset.
    """
    frames, _ = dataset[i]
    last_rgb = frames[-3:].to(torch.float32)
    # channel-first -> channel-last for imshow
    return last_rgb.numpy().transpose(1, 2, 0)

Here we visualize an example of input frames and target heatmap. In this example, we will use the ``tracknet_v2_2f`` training configuration.

The model's predicted heatmap is also shown.

In [None]:
# Load the trained tracknet_v2_2f model together with its standard test dataset.
dataset, model = get_trained_model(tracknet_v2_2f)

dataset_element = 15
frames, heatmap = dataset[dataset_element]

# channel-first tensor -> channel-last array, then split the channel axis into
# the two RGB input frames (3 channels each)
frames = frames.to(torch.float32).numpy().transpose(1, 2, 0)
frames = [frames[:, :, 3 * k:3 * k + 3] for k in range(2)]

heatmap = heatmap.to(torch.float32).squeeze().numpy()
heatmap_pred = get_output(model, dataset, dataset_element)

In [None]:
w, h, dpi = 2*640, 2*360, 100
fig, axs = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi, nrows=2, ncols=2)
axs = axs.ravel()

# Top row: the input frames, with the labelled position circled.
# Assumes label row (dataset_element + k) belongs to input frame k — TODO confirm.
labels_df = dataset._label_df
for ti, (frame, ax) in enumerate(zip(frames, axs)):
    row = ti + dataset_element
    ax.set_title(f'Frame {ti+1}')
    ax.imshow(frame)

    x, y = labels_df['x'][row], labels_df['y'][row]
    ax.scatter(x*frame.shape[1], y*frame.shape[0], zorder=100, facecolors='none', edgecolors='y', linewidths=3, s=150)

# Bottom row: target and predicted heatmaps overlaid on the last frame.
overlays = ((axs[2], heatmap, 'Target heatmap'),
            (axs[3], heatmap_pred, 'Predicted heatmap'))
for ax, overlay, title in overlays:
    ax.imshow(frames[-1])
    ax.imshow(overlay, alpha=0.5, cmap='gray', vmin=0, vmax=1)
    ax.set_title(title)

fig.tight_layout()

fig.savefig(os.path.join(detection_folder, 'sample_input_a.png'))

plt.show()

In [None]:
w, h, dpi = 640, 360, 100

# One borderless figure per file. Each entry is a list of (image, imshow kwargs)
# layers drawn in order.
heatmap_style = dict(cmap='gray', vmin=0, vmax=1)
panels = [
    ('sample_input_frame.png', [(frames[-1], {})]),
    ('sample_input_overlay.png', [(frames[-1], {}),
                                  (heatmap_pred, dict(alpha=0.5, **heatmap_style))]),
    ('sample_input_heatmap.png', [(heatmap_pred, heatmap_style)]),
]

for panel_file, layers in panels:
    fig, ax = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)

    for image, imshow_kwargs in layers:
        ax.imshow(image, **imshow_kwargs)

    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.savefig(os.path.join(detection_folder, panel_file))

    plt.show()

### Output for all models

Here we visualize the output of all the various models.

Visualize the output heatmap of each TrackNet variant, superimposed on the last frame of its input sequence.

In [None]:
for tc in training_configurations[2:]:
    name = tc_name(tc)
    print(name)

    # Recurrent models are loaded from their 'phase_3_0' checkpoint. The other
    # models are shifted back by their input sequence length (presumably so
    # every model is shown on the same frame — TODO confirm).
    if 'rnn' in name:
        training_phase = 'phase_3_0'
        sample = 100
    else:
        training_phase = None
        sample = 100 - tc.Config()._sequence_length

    dataset, model = get_trained_model(tc, training_phase)
    frame = get_frame(dataset, sample)
    output = get_output(model, dataset, sample)

    w, h, dpi = 640, 360, 120
    fig, ax = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)

    # predicted heatmap drawn semi-transparently over the last input frame
    ax.imshow(frame)
    ax.imshow(output, alpha=0.5, cmap='gray', vmin=0, vmax=1)

    ax.set_axis_off()
    fig.tight_layout(pad=0)
    fig.savefig(os.path.join(detection_folder, f"{name}_example.png"))

    plt.show()

## Loss history

Regnet and Standard configurations

In [None]:
from detection.visualization import get_loss_history

# Minimum of the first element returned by get_loss_history (presumably the
# training-loss history — confirm) for the last configuration in the list.
# The configuration is selected explicitly: the original relied on the `tc`
# left over from an earlier loop, which breaks on out-of-order execution.
get_loss_history(training_configurations[-1])[0].min()


In [None]:
w, h, dpi = 640, 420, 130

for i, tc in enumerate(training_configurations[:-2]):
    # the first two configurations are the regnet ones; the tracknet ones are
    # shown over epochs 1-20 only
    zoom_epochs = i >= 2
    epoch_range = (1, 20) if zoom_epochs else None

    fig, ax = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)
    ax.set_yscale('log')
    ax = show_loss_history(tc, ax=ax, epoch_range=epoch_range)

    if zoom_epochs:
        ax.set_xticks(np.arange(0, 21, 5))
    ax.set_ylim(2e-5, 5e-2)

    ax.set_title(tc_name(tc))

    fig.tight_layout()
    fig.savefig(os.path.join(detection_folder, f'{tc_name(tc)}_loss.pdf'))
    plt.show()

Recurrent architecture

In [None]:
# Epoch boundaries (offset by half an epoch). Every interior boundary gets a
# dashed line; the even-indexed ones delimit the three shaded phases.
phase_edges = np.array([0, 10, 15, 20, 25, 30, 35]) + 0.5
label_box = {'boxstyle': 'round',
             'facecolor': 'w',
             'edgecolor': 'None',
             'alpha': 1}

for tc in training_configurations[-2:]:
    w, h, dpi = 640, 420, 130
    fig, ax = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)

    ax = show_loss_history(tc, ax=ax, epoch_range=(1, 35))

    ax.set_ylim(2e-5, 5e-2)
    y_low, y_high = ax.get_ybound()

    # dashed separators with alternating line width
    for k, edge in enumerate(phase_edges[1:-1]):
        ax.plot([edge, edge], [y_low, y_high], 'k--', linewidth=0.7 if k % 2 else 0.5)

    # label and shade each of the three phases
    for k, color in enumerate(['r', 'g', 'b']):
        frac = 5e-2
        x_text = 10*k + 6 if k == 0 else 10*k + 8
        ax.annotate(f'Phase {k+1}', [x_text, frac*y_high + (1-frac)*y_low], bbox=label_box)
        ax.fill_between([phase_edges[2*k], phase_edges[2*(k+1)]], y_low, y_high, color=color, alpha=0.15)

    ax.legend(loc='upper right', framealpha=1)
    ax.set_title(tc_name(tc))

    fig.tight_layout()
    fig.savefig(os.path.join(detection_folder, f'{tc_name(tc)}_loss.pdf'))
    plt.show()

## Analyze results

Compute the positioning error from the CSV files written during model testing.

### Compute the optimal detection threshold

First define the utility function ``get_output_df`` to get a dataframe from the csv file:

In [None]:
def get_output_df(training_configuration, training_phase=None):
    """Load the validation outputs of a train configuration as a DataFrame.

    Reads ``output_val.csv`` from the results folder of the configuration's
    checkpoint folder (or of its ``training_phase`` sub-folder). Keeps only the
    rows whose ``x_true`` and ``y_true`` are both non-zero; these are treated
    as the labelled frames.

    If nothing is found and no phase was requested, the ``'phase_3_0'``
    sub-folder (the one used for the recurrent configurations) is tried.

    Args:
        training_configuration: train configuration module exposing ``Config``.
        training_phase: optional sub-folder of the checkpoint folder.

    Returns:
        A new DataFrame containing only the labelled rows.

    Raises:
        FileNotFoundError: if no output csv is found. The underlying error is
            chained as ``__cause__``.
    """
    config = training_configuration.Config()

    checkpoint_folder = config._checkpoint_folder
    if training_phase is not None:
        checkpoint_folder = os.path.join(checkpoint_folder, training_phase)

    try:
        results_folder = _get_results_folder(checkpoint_folder, None)
        df = pd.read_csv(os.path.join(results_folder, 'output_val.csv'))
    except FileNotFoundError as err:
        if training_phase is None:
            return get_output_df(training_configuration, training_phase='phase_3_0')
        raise FileNotFoundError(f"Output not found in {checkpoint_folder}") from err

    # extract only the labelled frames
    labelled = (df['x_true'] != 0) & (df['y_true'] != 0)
    return df.loc[labelled].copy()

Utility functions to compute the error distribution and the mean error as a function of the detection threshold.

In [None]:
def compute_error_distribution(output_df, threshold=None, image_size=(360, 640)):
    """Per-row positioning error in pixels.

    Args:
        output_df: DataFrame with ``x_true``, ``x_pred``, ``y_true``, ``y_pred``
            (presumably normalised to [0, 1]; they are scaled by the image size
            here) and ``max_values``.
        threshold: if given, only rows with ``max_values >= threshold`` are used.
        image_size: ``(height, width)`` in pixels; x is scaled by the width and
            y by the height.

    Returns:
        1-D array with the Euclidean distance between true and predicted positions.
    """
    df = output_df if threshold is None else output_df.loc[output_df['max_values'] >= threshold]
    x_true, x_pred, y_true, y_pred = [df[k].values for k in ['x_true', 'x_pred', 'y_true', 'y_pred']]
    return np.sqrt((image_size[1] * (x_true-x_pred))**2 + (image_size[0] * (y_true-y_pred))**2)


def detection_error_curve(output_df, num_thresholds=501):
    """Detection rate and error statistics as a function of the detection threshold.

    A row counts as detected at threshold ``t`` when ``max_values >= t``.

    Args:
        output_df: DataFrame as accepted by ``compute_error_distribution``.
        num_thresholds: number of thresholds evenly spaced over [0, 1].

    Returns:
        ``(thresholds, detection_rate, mean_errors, median_errors, std_errors)``,
        all arrays of length ``num_thresholds``. At thresholds with no detections
        the three error statistics are 0; callers use this as a "no detections"
        marker. The detection rate is 0 for an empty ``output_df``.
    """
    thresholds = np.linspace(0, 1, num_thresholds)
    max_values = output_df['max_values'].to_numpy()
    n_total = len(max_values)

    detection_rate = []
    mean_errors = []
    median_errors = []
    std_errors = []
    for t in thresholds:
        n_detected = int(np.count_nonzero(max_values >= t))
        detection_rate.append(n_detected / n_total if n_total else 0.0)

        error_distribution = compute_error_distribution(output_df, t)
        if len(error_distribution) == 0:
            # Append the sentinel to every statistic so that all returned
            # arrays stay aligned with `thresholds`.
            mean_errors.append(0)
            median_errors.append(0)
            std_errors.append(0)
            continue
        mean_errors.append(np.mean(error_distribution))
        median_errors.append(np.median(error_distribution))
        std_errors.append(np.std(error_distribution))

    return thresholds, np.asarray(detection_rate), np.asarray(mean_errors), np.asarray(median_errors), np.asarray(std_errors)

Detection rate and positioning error as a function of the detection threshold (all configurations in one figure).

In [None]:
mean_error = []
median_error = []
std_error = []
dr = []

# detection threshold v_th at which each configuration's operating point is read
chosen_threshold = 0.1

w, h, dpi = 1280, 720, 120
fig, axs = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi, nrows=2, ncols=3)

# axs.T.ravel() walks the grid column by column: odd i is the bottom row,
# i < 2 the left column and i >= 4 the right column
for i, (tc, ax) in enumerate(zip(training_configurations[2:], axs.T.ravel())):
    val_outputs = get_output_df(tc)
    curve_thresholds, detection_rates, mean_errors, _, _ = detection_error_curve(val_outputs, 501)

    # drop thresholds without detections (detection_error_curve reports a mean error of 0 there)
    valid = mean_errors > 0
    curve_thresholds = curve_thresholds[valid]
    detection_rates = detection_rates[valid]
    mean_errors = mean_errors[valid]

    # Operating point closest to the chosen threshold. Its median/std come from
    # the same error sample as its mean, so the three statistics always match.
    ti = np.argmin(np.abs(chosen_threshold - curve_thresholds))
    errors_at_ti = compute_error_distribution(val_outputs, curve_thresholds[ti])

    # positioning error
    ax.plot([chosen_threshold, chosen_threshold], [0, max(mean_errors)], '--', color='#000000', linewidth=1,
            label=f'chosen $v_{{th}}$ = {chosen_threshold}')
    ax.plot(curve_thresholds, mean_errors, 'C0-', label=f'Positioning error = {mean_errors[ti]:.3g} px')
    ax.plot([0, 1], [mean_errors[ti], mean_errors[ti]], '-', color='C0', linewidth=1)

    # detection rate
    ax2 = ax.twinx()
    ax2.plot(curve_thresholds, detection_rates, 'C1-', label=f'Detection rate = {100*detection_rates[ti]:.0f}%')
    ax2.plot([0, 1], [detection_rates[ti], detection_rates[ti]], '-', color='C1', linewidth=1)

    # axis limits
    ax.set_xlim(0, 1)
    ax.set_ylim(0, max(mean_errors))
    ax2.set_ylim(0, 1)

    ax.set_title(tc_name(tc))

    # x-labels only on the bottom row
    if i % 2 == 1:
        ax.set_xlabel('$v_{th}$')
    else:
        ax.set_xticklabels([])

    # positioning-error label on the left column, detection-rate label on the right column
    if i < 2:
        ax.set_ylabel('Positioning error')
    if i >= 4:
        ax2.set_ylabel('Detection rate')
    else:
        ax2.set_yticklabels([])

    # combined legend of both y-axes
    handles, labels = ax.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(handles + handles2, labels + labels2, loc='center', fontsize='small')

    mean_error.append(mean_errors[ti])
    median_error.append(np.median(errors_at_ti))
    std_error.append(np.std(errors_at_ti))
    dr.append(detection_rates[ti])

fig.suptitle("Detection rate and positioning error as a function of the detection threshold")

fig.tight_layout()
# file name kept unchanged in case it is referenced elsewhere
fig.savefig(os.path.join(detection_folder, 'threhsold_curves.pdf'))

plt.show()

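Detection rate and mean error as a function of the threshold, and mean error as a function of the detection rate (one figure per configuration).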

In [None]:
mean_error = []
median_error = []
std_error = []
dr = []

# detection threshold at which each configuration's operating point is read
chosen_threshold = 0.1
# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid
trapezoid = getattr(np, 'trapezoid', None) or np.trapz

for tc in training_configurations[2:]:
    w, h, dpi = 1280, 480, 120
    fig, axs = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi, ncols=2)

    val_outputs = get_output_df(tc)
    curve_thresholds, detection_rates, mean_errors, _, _ = detection_error_curve(val_outputs, 501)

    # drop thresholds without detections (detection_error_curve reports a mean error of 0 there)
    valid = mean_errors > 0
    curve_thresholds = curve_thresholds[valid]
    detection_rates = detection_rates[valid]
    mean_errors = mean_errors[valid]

    ti = np.argmin(np.abs(chosen_threshold - curve_thresholds))
    errors_at_ti = compute_error_distribution(val_outputs, curve_thresholds[ti])

    # left panel, left axis: detection rate as a function of the threshold
    axs[0].plot(curve_thresholds, detection_rates, 'C0-', label=f'detection rate = {100*detection_rates[ti]:.0f}%')
    axs[0].plot([0, 1], [detection_rates[ti], detection_rates[ti]], '-', color='C0', linewidth=1)
    axs[0].set_xlabel('Threshold')
    axs[0].set_ylabel('Detection rate')
    axs[0].set_ylim(0, 1)

    # left panel, right axis: positioning error as a function of the threshold
    ax2 = axs[0].twinx()
    ax2.plot([chosen_threshold, chosen_threshold], [0, max(mean_errors)], '--', color='#000000', linewidth=1,
             label=f'chosen threshold = {chosen_threshold}')
    ax2.plot(curve_thresholds, mean_errors, 'C1-', label=rf'$\Delta$ p = {mean_errors[ti]:.3g} px')
    ax2.plot([0, 1], [mean_errors[ti], mean_errors[ti]], '-', color='C1', linewidth=1)
    ax2.set_ylabel(r'$\Delta$ p')
    ax2.set_ylim(0, max(mean_errors))

    handles, labels = axs[0].get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(handles2 + handles, labels2 + labels, loc='center')

    # Right panel: mean error vs detection rate. The detection rate decreases
    # as the threshold grows, so the trapezoid integral is negated.
    auc = -trapezoid(mean_errors, x=detection_rates)

    axs[1].plot(detection_rates, mean_errors, color='#000000', label=f'auc = {auc:.3g}')
    axs[1].scatter(detection_rates[ti], mean_errors[ti], c='#000000', s=50)
    axs[1].set_xlabel('Detection rate')
    axs[1].set_ylabel('Mean error')
    axs[1].legend()

    for ax in axs:
        ax.set_xlim(0, 1.02)
    # Fixed error range for the right panel only. The original loop also set it
    # on axs[0], which overrode the detection-rate axis's (0, 1) range.
    axs[1].set_ylim(0, 12)

    fig.suptitle(tc_name(tc))

    fig.tight_layout()
    fig.savefig(os.path.join(detection_folder, f'{tc_name(tc)}.pdf'))

    plt.show()

    mean_error.append(mean_errors[ti])
    median_error.append(np.median(errors_at_ti))
    std_error.append(np.std(errors_at_ti))
    dr.append(detection_rates[ti])


In [None]:
w, h, dpi = 1280, 720, 150
fig, ax = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)

# Operating point: the threshold whose detection rate is closest to this target.
# It has its own name so it does not clobber the `dr` list built by the previous cells.
target_detection_rate = 0.97

for i, tc in enumerate(training_configurations[2:5]):
    curve_thresholds, detection_rates, mean_errors, _, _ = detection_error_curve(get_output_df(tc), 1001)

    # keep only thresholds with at least one detection (mean error 0 marks none)
    valid = mean_errors > 0
    curve_thresholds = curve_thresholds[valid]
    detection_rates = detection_rates[valid]
    mean_errors = mean_errors[valid]

    ti = np.argmin(np.abs(detection_rates - target_detection_rate))

    # error-detection curve, labelled with the operating point's statistics
    curve_label = (f'{tc_name(tc)}:\n  dr = {detection_rates[ti]:.3g}\n'
                   f'  t = {curve_thresholds[ti]:.3g}\n  e = {mean_errors[ti]:.3g}')
    ax.plot(detection_rates, mean_errors, label=curve_label, color=f'C{i}')

    # guide lines through the operating point (target detection rate)
    ax.plot([detection_rates[ti], detection_rates[ti]], [0, 100], '--', color=f'C{i}', linewidth=1)
    ax.plot([0, 2], [mean_errors[ti], mean_errors[ti]], '--', color=f'C{i}', linewidth=1)

ax.legend()

ax.set_xlabel('Detection rate')
ax.set_ylabel('Mean error')

ax.set_xlim(0.8, 1)
ax.set_ylim(0, 3)

fig.suptitle("Mean error as function of the detection rate")

fig.tight_layout()
fig.savefig(os.path.join(detection_folder, 'error_detection_curve.pdf'))

plt.show()

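Build a dataframe for each train configuration. For the heatmap-based (non-regnet) configurations, only the frames whose maximum heatmap value exceeds 0.1 are kept. The fraction of frames retained is stored in a ``detection_rate`` column.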

In [None]:
df_out = []
for tc in training_configurations:
    df = get_output_df(tc)

    if 'reg' in tc_name(tc):
        # regnet configurations: keep every labelled frame as is
        df_out.append(df)
        continue

    # heatmap configurations: keep the frames whose max heatmap value exceeds
    # 0.1 and record the fraction kept as the detection rate
    detected = df.loc[df['max_values'] > 0.1].copy()
    detected['detection_rate'] = len(detected) / len(df)
    df_out.append(detected)

Compute the position error as the magnitude of the difference between the true position and the estimated position

In [None]:
image_size = (360, 640)  # (height, width) in pixels

# Store the per-frame positioning error (in px) as a new column of each output
# dataframe. The leftover `bbb` debug computation, which doubled the work, is removed.
for output_df in df_out:
    output_df['error'] = compute_error_distribution(output_df, image_size=image_size)

### Error histograms

Here we visualize the error distributions for the various models

Summary dataframe with the most important statistics of the error:
 - mean
 - standard deviation
 - median (currently disabled in the code below)
 - percentage with error < 1 px
 - percentage with error < 5 px
 - percentage with error < 10 px

In [None]:
# Leftover scratch cell (a throwaway dict display) removed; nothing below used it.

In [None]:
thresholds = [1, 5, 10]  # error strictly smaller than the threshold

summary = []
for tc, df in zip(training_configurations, df_out):
    name = tc_name(tc)
    n_errors = df['error'].size

    result = {
        'train_configuration': name,
        # only the tracknet dataframes carry a detection_rate column
        'detection rate': np.mean(df['detection_rate']) if 'tracknet' in name else 1,
    }

    if n_errors == 0:
        result['mean'] = np.nan
        result['std'] = np.nan
        for t in thresholds:
            result[f'{t} px'] = np.nan
    else:
        result['mean'] = np.mean(df['error'])
        result['std'] = np.std(df['error'])
        errors = np.asarray(df['error'])
        for t in thresholds:
            # fraction of frames whose error is below t pixels
            result[f'{t} px'] = np.flatnonzero(errors < t).size / n_errors

    summary.append(result)

df_summary = pd.DataFrame(summary)
df_summary

In [None]:
# format the mean, median and std of the errors as {:.3g}
# get the error rate as a percentage and format that as {:.2f}
df_s2 = df_summary.copy()

for k in df_s2.columns:
    if 'px' in k or 'detection' in k:
        df_s2[k] = df_s2[k].apply(lambda x: f'{100*x:.3g}%')
    else:
        df_s2[k] = df_s2[k].apply(lambda x: f'{x:.3g}' if type(x) is float else x)

print(df_s2.to_latex(index=False))

Visualize error distribution histograms (one figure per configuration).

In [None]:
def print_summary(summary: dict):
    """Print each summary entry on its own ``key: value`` line.

    Keys containing ``'px'`` hold fractions and are shown as percentages.
    The ``'train_configuration'`` value is printed verbatim. All other values
    are shown with two significant digits.
    """
    for key, value in summary.items():
        if 'px' in key:
            text = f'{100*value:.2g}%'
        elif key == 'train_configuration':
            text = value
        else:
            text = f'{value:.2g}'
        print(f"{key}: {text}")


hist_range = thresholds[-1]

for tc, df, s in zip(training_configurations, df_out, summary):
    print_summary(s)

    w, h, dpi = 640, 480, 120
    fig, ax = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)

    # unit-width bins from 0 to hist_range + 1; larger errors are clipped into the last bin
    clipped_errors = df['error'].clip(upper=hist_range)
    ax.hist(clipped_errors, bins=np.arange(hist_range+2), density=True)

    # statistics shown as invisible legend entries (monospace so the columns line up)
    stat_rows = [f'{k}:'.ljust(14) + f'{s[k]:.1f}'.rjust(5) + ' ' for k in ['mean', 'std']]
    stat_rows += ['error < ' + f'{k}:'.rjust(6) + f'{100*s[k]:.1f}%'.rjust(6) + ' '
                  for k in s.keys() if 'px' in k]
    for row in stat_rows:
        ax.plot([], [], ' ', label=row)

    ax.legend(handlelength=0, prop={'family': 'monospace'}, loc='upper center')

    ax.set_xticks(np.arange(0, 11, 2))

    ax.set_xlim(-0.5, 11.5)
    ax.set_ylim(0, 1)

    ax.set_title(tc_name(tc))
    ax.set_xlabel('Error magnitude')
    ax.set_ylabel('Occurrence')

    fig.tight_layout()
    fig.savefig(os.path.join(detection_folder, f'{tc_name(tc)}_error.pdf'))

    plt.show()

Visualize error distribution histograms (all in one figure)

In [None]:
hist_range = thresholds[-1]
bin_edges = np.arange(hist_range+2)

w, h, dpi = 1200, 600, 110
fig, axs = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi, nrows=2, ncols=4)

# column-major order: ti < 2 is the left column, odd ti is the bottom row
axs = axs.T.ravel()

for ti, (tc, df, s, ax) in enumerate(zip(training_configurations, df_out, summary, axs)):
    ax.hist(df['error'].clip(upper=hist_range), bins=bin_edges, density=True)

    # statistics shown as invisible legend entries
    for k in ['mean', 'std']:
        ax.plot([], [], ' ', label=f'{k}:'.ljust(14) + f'{s[k]:.1f}'.rjust(5) + ' ')
    for k in filter(lambda key: 'px' in key, s.keys()):
        ax.plot([], [], ' ', label='error < ' + f'{k}:'.rjust(6) + f'{100*s[k]:.1f}%'.rjust(6) + ' ')

    ax.legend(handlelength=0, prop={'family': 'monospace', 'size': 'small'}, loc='upper center')

    ax.set_title(tc_name(tc))

    left_column = ti < 2
    bottom_row = ti % 2 == 1

    if left_column:
        ax.set_ylabel('Occurrence')
    else:
        ax.set_yticklabels([])

    ax.set_xticks(np.arange(0, 11, 2))
    if bottom_row:
        ax.set_xlabel('Positioning error in px')
    else:
        ax.set_xticklabels([])

    ax.set_xlim(-0.5, 11.5)
    ax.set_ylim(0, 1)

fig.tight_layout()
fig.savefig(os.path.join(detection_folder, 'errors.pdf'))

plt.show()

### Error as function of maximum heatmap value

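Here we plot the positioning error against the maximum heatmap value, together with linear (black) and quadratic (red) fits over the detections whose maximum value exceeds 0.2.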

In [None]:
from scipy.optimize import curve_fit

def linear(x, a, b):
    """Straight line ``a*x + b`` (model function for ``curve_fit``)."""
    return a * x + b


def quadratic(x, a, b, c):
    """Parabola ``a*x**2 + b*x + c`` (model function for ``curve_fit``)."""
    return a * x * x + b * x + c

for tc, df in zip(training_configurations[2:], df_out[2:]):
    w, h, dpi = 640, 360, 120
    fig, ax = plt.subplots(figsize=(w/dpi, h/dpi), dpi=dpi)

    ax.scatter(x=df['max_values'], y=df['error'], alpha=0.5)

    # fit only the confident detections
    idx = df['max_values'] > 0.2
    n_fit = int(idx.values.sum())
    x_fit, y_fit = df['max_values'][idx], df['error'][idx]

    # curve_fit needs at least as many points as free parameters; the original
    # `> 0` guard raised for 1 (linear) or 1-2 (quadratic) points
    if n_fit >= 2:
        params, _ = curve_fit(linear, x_fit, y_fit)
        x = np.linspace(0, 1, 2)
        ax.plot(x, linear(x, *params), 'k-')
    if n_fit >= 3:
        params, _ = curve_fit(quadratic, x_fit, y_fit)
        x = np.linspace(0, 1, 101)
        ax.plot(x, quadratic(x, *params), 'r-')

    ax.set_xlabel('max heatmap value')
    ax.set_ylabel('position error')

    ax.set_xlim(-0.05, 1.05)
    ax.set_title(tc_name(tc))

    plt.show()

# Trajectories fitting

In [None]:
# The trajectories package's `get_frame` is imported under an alias so it does
# not shadow the notebook's own `get_frame(dataset, i)` helper defined above.
# Re-running the detection cells after this one would otherwise call the wrong function.
from trajectories.data_reading import get_candidates, get_heatmap
from trajectories.data_reading import get_frame as get_trajectory_frame

from trajectories.fitting import fit_trajectories
from trajectories.filtering import build_trajectory_graph, find_shortest_paths, build_path_mapping
from trajectories.visualization import create_trajectory_video, visualize_trajectory_graph, show_neighboring_trajectories

# candidate detections for tracknet_v2, starting at `starting_frame`
starting_frame, candidates, n_candidates, values = get_candidates(tracknet_v2)
frame_sequence = list(range(starting_frame, starting_frame + len(candidates)))

# fit trajectories, link them into a graph, and map the shortest paths back onto the fits
fitting_info = fit_trajectories(candidates, n_candidates, starting_frame)
trajectory_graph = build_trajectory_graph(fitting_info)
shortest_paths = find_shortest_paths(trajectory_graph)
path_mapping = build_path_mapping(fitting_info, shortest_paths)

## Trajectory graph example

In [None]:
# Render the trajectory graph and save it. The arguments are the first entry
# of the first shortest path and 889 (presumably a frame index — confirm).
ax = visualize_trajectory_graph(trajectory_graph, shortest_paths[0][0], 889)
fig = ax.figure
graph_path = os.path.join(trajectories_folder, 'graph_example.pdf')
fig.savefig(graph_path)
plt.close(fig)

## Fit examples

In [None]:
from IPython.display import clear_output

# rendering options shared by the trajectory figures below
dpi, dark_mode = 150, False

Large examples of a successful and a failed trajectory fit.

In [None]:
# One success example and one failure example:
# (starting frame, num_prev, num_next, output file)
examples = [
    (996, 4, 3, 'example.png'),        # success example
    (1502, 2, 3, 'example_fail.png'),  # fail example
]

for sf, num_prev, num_next, filename in examples:
    fig, ax = create_trajectory_video(tracknet_v2, os.path.join(trajectories_folder, filename),
                                      fitting_info=fitting_info, path_mapping=path_mapping,
                                      starting_frame=sf,
                                      dark_mode=dark_mode, dpi=dpi,
                                      num_prev=num_prev, num_next=num_next,
                                      num_frames=0)
    plt.close(fig)

clear_output()

Shot example (near team)

In [None]:
for sf in [983, 991, 992, 996]:
    # window of previous / next frames shown around the starting frame
    num_prev, num_next = (3, 4) if sf <= 991 else (4, 3)
    output_path = os.path.join(trajectories_folder, f'{sf}.png')
    fig, ax = create_trajectory_video(tracknet_v2, output_path,
                                      fitting_info=fitting_info, path_mapping=path_mapping,
                                      starting_frame=sf,
                                      dark_mode=dark_mode, dpi=dpi,
                                      num_prev=num_prev, num_next=num_next,
                                      num_frames=0)
    plt.close(fig)
    clear_output()

Bounce example (floor)

In [None]:
# (num_prev, num_next) window for each starting frame
bounce_windows = {1006: (4, 3), 1007: (5, 2), 1016: (5, 2)}

for sf in [1006, 1007, 1016]:
    num_prev, num_next = bounce_windows[sf]
    output_path = os.path.join(trajectories_folder, f'{sf}.png')
    fig, ax = create_trajectory_video(tracknet_v2, output_path,
                                      fitting_info=fitting_info, path_mapping=path_mapping,
                                      starting_frame=sf,
                                      dark_mode=dark_mode, dpi=dpi,
                                      num_prev=num_prev, num_next=num_next,
                                      num_frames=0)
    plt.close(fig)
    clear_output()

Shot example (far team)

In [None]:
for sf in [1024, 1025]:
    # shift the window one frame later for the second starting frame
    if sf > 1024:
        num_prev, num_next = 6, 1
    else:
        num_prev, num_next = 5, 2
    output_path = os.path.join(trajectories_folder, f'{sf}.png')
    fig, ax = create_trajectory_video(tracknet_v2, output_path,
                                      fitting_info=fitting_info, path_mapping=path_mapping,
                                      starting_frame=sf,
                                      dark_mode=dark_mode, dpi=dpi,
                                      num_prev=num_prev, num_next=num_next,
                                      num_frames=0)
    plt.close(fig)
    clear_output()

for sf in [1538, 1547, 1549, 1555]:
    # (num_prev, num_next) window around the starting frame
    if sf > 1547:
        num_prev, num_next = 4, 3
    else:
        num_prev = 2 if sf == 1538 else 3
        num_next = 4
    output_path = os.path.join(trajectories_folder, f'{sf}.png')
    fig, ax = create_trajectory_video(tracknet_v2, output_path,
                                      fitting_info=fitting_info, path_mapping=path_mapping,
                                      starting_frame=sf,
                                      dark_mode=dark_mode, dpi=dpi,
                                      num_prev=num_prev, num_next=num_next,
                                      num_frames=0)
    plt.close(fig)
    clear_output()

Failure example

In [None]:
for sf in [1495, 1496, 1499, 1501, 1502]:
    filename = f'fail_{sf}.png'

    # Named `display_mode`, not `display`: the old name shadowed IPython's
    # display(), which the architecture-table cell calls.
    display_mode = 'k_min k_max params'
    if sf == 1495:
        # no overlay of the previous fits and a window shifted one frame forward
        display_prev = None
        num_prev, num_next = 1, 4
    else:
        display_prev = 'k_max'
        num_prev, num_next = 2, 3

    fig, ax = create_trajectory_video(tracknet_v2, os.path.join(trajectories_folder, filename),
                                      fitting_info=fitting_info, path_mapping=path_mapping,
                                      starting_frame=sf,
                                      dark_mode=dark_mode, dpi=dpi,
                                      num_next=num_next, num_prev=num_prev,
                                      alpha_prev=1, alpha_next=1,
                                      display=display_mode, display_prev=display_prev, display_next='k_min',
                                      show_outside_range=True,
                                      num_frames=0)
    clear_output()
    plt.close(fig)