In [None]:
# Reload edited project modules automatically so library changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
import dotenv
# Populate os.environ from a dotenv file (python-dotenv looks for `.env` by default);
# later cells read os.environ['RESULTS_FOLDER'].
dotenv.load_dotenv()

# Load trainer

In [None]:
from experimentator import build_experiment, find

# Locate the ball-state config with experimentator's `find` helper, then build the
# experiment from it without loading previously saved weights.
config_path = find("../configs/ballstate.py")
exp = build_experiment(config_path, load_weights=False)

# Show dataset

In [None]:
from matplotlib import pyplot as plt
from dataset_utilities.ds.raw_sequences_dataset import BallState

# Preview one batch per subset: one row per subset, one column per sample,
# each panel titled with the sample's ball state.
batch_size = 7
# squeeze=False keeps `axes` 2-D even with a single subset (or batch_size == 1),
# so the `axes[i, j]` indexing below always works.
fig, axes = plt.subplots(len(exp.subsets), batch_size, figsize=(20, 9), squeeze=False)
for i, subset in enumerate(exp.subsets):
    keys, batch = next(exp.batch_generator(subset, batch_size=batch_size)) # balances batches wrt ball state
    for j in range(len(keys)):
        ax = axes[i, j]
        ax.imshow(batch["batch_input_image"][j])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(str(BallState(batch["batch_ball_state"][j])))
    axes[i, 0].set_ylabel(subset.name)

# Train for 10 epochs

In [None]:
# NOTE(review): the section header says "10 epochs" but 11 is passed here — confirm whether
# `train()` expects a number of epochs or a final epoch index, and align header and value.
exp.train(11)

# Load experiment

In [None]:
import os
from experimentator import build_experiment

# Rebuild an experiment from the config.py stored in
# $RESULTS_FOLDER/ballstate/<experiment_id>.
experiment_id = 'latest'   # update with experiment ID
results_root = os.environ['RESULTS_FOLDER']
folder = os.path.join(results_root, "ballstate", experiment_id)
config_file = os.path.join(folder, 'config.py')
exp = build_experiment(config_file)

# Load metrics

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from experimentator import DataCollector

# Training history stored in the experiment folder chosen in the previous cell.
dc = DataCollector(os.path.join(folder, "history.dcp"))

# One panel per (title, metric key, extractor); the extractor maps a stored
# history entry to the scalar that gets plotted.
panels = [
    ("Loss",             "loss",                   lambda value: value),
    ("FLYING Precision", "classification_metrics", lambda value: value['precision'].iloc[1]),
]
fig, axes = plt.subplots(2, 1, sharex=True)
for ax, (title, metric, extract) in zip(axes, panels):
    ax.set_title(title)
    for subset in ("training", "validation"):
        label = f"{subset}_{metric}"
        # Entries stored as None become NaN and are then left out of the curve.
        values = np.array([np.nan if entry is None else extract(entry) for entry in dc[label, :]])
        epochs = np.flatnonzero(np.isfinite(values))
        ax.plot(epochs, values[epochs], label=label, markersize=5, marker='.')
    ax.legend()

In [None]:
import numpy as np
from dataset_utilities.ds.raw_sequences_dataset import BallState

# For every entry of the training history, print the testing classification metrics
# and a per-class precision/recall summary computed from the testing confusion matrix.
# TODO: fix line 2 and 3 inverted
def _precision_recall_report(m):
    """Format per-class precision and recall from confusion matrix `m`, one class per line.

    Precision divides the diagonal by the column sum and recall by the row sum, i.e. this
    assumes rows are true classes and columns are predicted classes (TODO confirm).
    A class with an empty row/column prints `nan` instead of emitting a warning.
    """
    lines = []
    with np.errstate(divide="ignore", invalid="ignore"):
        for i in range(len(m)):
            precision = m[i, i] / np.sum(m[:, i])
            recall = m[i, i] / np.sum(m[i, :])
            lines.append(f"{str(BallState(i))}: \tprecision = {precision:.3f} \trecall = {recall:.3f}")
    return "\n".join(lines)


for data in dc.history:
    for name, fct in [
        ('testing_classification_metrics', lambda x: x),
        ('testing_confusion_matrix', _precision_recall_report),
    ]:
        # Testing metrics are not recorded for every entry: skip absent/None values
        # explicitly instead of hiding every possible error behind a bare `except`.
        try:
            value = data[name]
        except KeyError:
            continue
        if value is None:
            continue
        print(fct(value))
    print("")

In [None]:
import numpy as np
import seaborn as sn
from matplotlib import pyplot as plt
from dataset_utilities.ds.raw_sequences_dataset import BallState

# Show the class index -> BallState mapping, then, for each recorded testing confusion
# matrix, print per-class precision/recall and draw it as an annotated heatmap.
print([BallState(i) for i in range(4)])
for m in [a for a in dc['testing_confusion_matrix', :] if a is not None]:
    # NOTE(review): class 0 is skipped here (range starts at 1) — confirm this is intended.
    # errstate: a class with no predictions/samples prints nan instead of warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        for i in range(1, len(m)):
            precision = m[i, i] / np.sum(m[:, i])
            recall = m[i, i] / np.sum(m[i, :])
            print(f"{str(BallState(i))}: \tprecision = {precision:.3f} \trecall = {recall:.3f}")
    fig, ax = plt.subplots(figsize=(4, 3))
    # fmt="g" keeps counts readable (e.g. 1234) instead of seaborn's default ".2g" (1.2e+03).
    sn.heatmap(m, annot=True, fmt="g", ax=ax)
    ax.set_xlabel("pred class")
    ax.set_ylabel("true class")

# Evaluation

In [None]:
# One batch generator per subset, kept across cells; each `next(...)` on it in the
# evaluation cell below advances to the following batch of that subset.
batch_size = 6
its = dict(
    (subset.name, exp.batch_generator(subset, batch_size=batch_size))
    for subset in exp.subsets
)

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from experimentator.dataset import collate_fn
from dataset_utilities.ds.raw_sequences_dataset import BallState

# For each subset, evaluate the next batch from `its` and show one panel per sample:
# top text/title = true ball state, lower texts = predicted state (argmax of the model
# output) and the max output value. Mispredicted panels get a red frame.
# squeeze=False keeps `axes` 2-D even with a single subset (or batch_size == 1).
fig, axes = plt.subplots(len(exp.subsets), batch_size, figsize=(20, 10), squeeze=False)
for i, subset in enumerate(exp.subsets):
    keys, batch_input = next(its[subset.name]) # balances batches wrt ball state
    batch_output = exp.batch_eval(batch_input)
    print(batch_output['batch_output'].numpy())
    print(batch_output['batch_target'].numpy())
    # A batch may contain fewer than `batch_size` samples; only draw the ones present.
    for j in range(min(batch_size, len(keys))):
        ax = axes[i, j]
        ax.imshow(batch_input["batch_input_image"][j])
        ax.set_xticks([])
        ax.set_yticks([])
        true_class = str(BallState(batch_input['batch_ball_state'][j]))
        pred_class = str(BallState(np.argmax(batch_output['batch_output'][j])))
        if true_class != pred_class:
            for spine in ['bottom', 'top', 'right', 'left']:
                ax.spines[spine].set_color('red')
        ax.set_title(true_class)
        for y, text in [
            (.9, f"{true_class}"),
            (.2, f"{pred_class}"),
            (.1, f"{np.max(batch_output['batch_output'][j]):.2f}"),
        ]:
            ax.text(.05, y, text, horizontalalignment='left', verticalalignment='center', transform=ax.transAxes, color='white', fontsize=12)
    axes[i, 0].set_ylabel(subset.name)