In [None]:
import collections
import operator
import pathlib
import warnings
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2 as cv
import pandas as pd
import matplotlib.pyplot as plt
from torchsummary import summary
from matplotlib import rc

%matplotlib inline

%load_ext autoreload
%autoreload 2

# Global matplotlib font settings applied to every figure drawn after this cell.
font = dict(family='Times New Roman', weight='bold', size=12)
rc('font', **font)

In [None]:
def collect_metric_results_for_iters(
    results_dir_path,
    metrics=('mota', 'motp'),
    results_file_patt='eval_results.csv',
    overall_row='OVERALL',
):
    """Gather overall tracker metrics for every evaluated model checkpoint.

    Recursively searches ``results_dir_path`` for files matching
    ``results_file_patt``. Each match is expected to sit two levels below a
    directory whose name is the checkpoint's iteration count, i.e.
    ``.../<n_iters>/<subdir>/<results_file_patt>``. Each file is read as a CSV
    whose first column is the row index; the row labelled ``overall_row``
    supplies the metric values.

    Args:
        results_dir_path: Root directory (str or path-like) to search.
        metrics: Column name(s) to extract from the overall row. A single
            string is treated as one metric name.
        results_file_patt: Glob pattern passed to ``Path.rglob``.
        overall_row: Index label of the aggregate row in each CSV.

    Returns:
        ``collections.defaultdict(list)`` mapping each metric name to a list of
        ``(n_iters, value)`` tuples sorted by ``n_iters``. If no result files
        are found the mapping is empty (lookups return ``[]``).

    Raises:
        ValueError: if the directory two levels above a results file is not
            named with an integer.
        KeyError: if a CSV lacks the ``overall_row`` row or a requested column.
    """
    # Without this, a bare string such as 'mota' would be iterated per character.
    if isinstance(metrics, str):
        metrics = (metrics,)

    metric_vals = collections.defaultdict(list)

    for file in pathlib.Path(results_dir_path).rglob(results_file_patt):
        model_dir = file.parent.parent
        try:
            n_model_iters = int(model_dir.stem)
        except ValueError as exc:
            raise ValueError(
                f"Expected directory '{model_dir}' (two levels above '{file}') "
                f"to be named with an integer iteration count"
            ) from exc

        results = pd.read_csv(file, index_col=0)
        overall = results.loc[overall_row]
        for metric in metrics:
            metric_vals[metric].append((n_model_iters, overall[metric]))

    # rglob yields files in arbitrary filesystem order; sort so every series is
    # deterministic and ordered by training iteration.
    for pairs in metric_vals.values():
        pairs.sort(key=operator.itemgetter(0))

    return metric_vals

In [None]:
def plot_1d_eval_results_multiple(
    dir_path_label_pairs,
    metric,
    metric_label,
    title='Comparison Plot of Overall Tracker Metrics',
    figsize=(10, 5),
):
    """Plot one overall tracker metric against training iteration for several runs.

    Args:
        dir_path_label_pairs: Iterable of ``(results_dir_path, label)`` pairs.
            Each directory is scanned with ``collect_metric_results_for_iters``
            and drawn as one line named ``label`` in the legend.
        metric: Column name of the metric in the results CSVs (e.g. ``'mota'``).
        metric_label: Y-axis label for the metric.
        title: Axes title.
        figsize: Figure size in inches, passed to ``plt.subplots``.

    Returns:
        The created ``matplotlib.figure.Figure``.

    Warns:
        UserWarning: for every directory in which no results were found; that
            run is skipped instead of being drawn as an empty line.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    ax.set_title(title)
    ax.set_xlabel('Iteration')
    ax.set_ylabel(metric_label)

    for results_dir_path, label in dir_path_label_pairs:
        metric_vals = collect_metric_results_for_iters(results_dir_path, (metric,))
        vals = sorted(metric_vals[metric], key=operator.itemgetter(0))
        if not vals:
            # A mistyped or empty directory would otherwise produce an invisible
            # line that still shows up in the legend.
            warnings.warn(
                f"No '{metric}' results found under {results_dir_path!r}; "
                f"skipping '{label}'."
            )
            continue
        xs = np.asarray([n_iters for n_iters, _ in vals])
        ys = np.asarray([value for _, value in vals])
        ax.plot(xs, ys, label=label)

    # ax.legend() warns when there are no labelled artists (every run skipped).
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='best')
    fig.tight_layout()

    return fig

# Runs to compare: (results directory, legend label).
# NOTE(review): './eval_dsa' is labelled 'Original' — confirm this label matches
# the run stored in that directory.
dir_path_label_pairs = (
    ('./eval_orig_mini', 'Original (mini)'),
    ('./eval_dsa', 'Original'),
)

fig_mota = plot_1d_eval_results_multiple(dir_path_label_pairs, 'mota', 'MOTA')
# plt.show() renders on any backend; Figure.show() warns on non-GUI backends.
plt.show()

In [None]:
fig_motp = plot_1d_eval_results_multiple(dir_path_label_pairs, 'motp', 'MOTP')
# plt.show() renders on any backend; Figure.show() warns on non-GUI backends.
plt.show()