In [None]:
%tb
import pandas as pd
import glob
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt

# Upper bound (exclusive) on the prediction-to-ground-truth distance for a
# prediction to count as a hit; compared against in val().
DIST_THRESHOLD = 8

# Dataset root (videos + ground-truth CSVs) and the detection run being scored.
source_dir = Path("./datasets/tracknet-2/val_data")
val_dir = Path("./runs/detect/val5")

print(f"val_dir: {val_dir}")

# Force the identifier columns to str: val() compares them against path parts,
# which are always strings.
pred_json = pd.read_json(
    val_dir / "predictions.json",
    dtype={"match_name": str, "video_name": str},
)

# Video paths relative to source_dir, i.e. "<match>/video/<name>.mp4".
videos = glob.glob("*/video/*.mp4", root_dir=source_dir)

# Confidence thresholds to sweep: from 0.5 up to (excluding) 1 in 0.01 steps.
x = list(np.arange(0.5, 1, 0.01))

# One int32 counter per threshold for each confusion-matrix cell.
y_fn, y_fp, y_tn, y_tp = (
    np.zeros(len(x), dtype=np.int32) for _ in range(4)
)

def transform_coordinates_back(coord, w, h, target_size=640):
    """Map a point from the square model-input space back to the source frame.

    The inverse applied here is: scale by ``max(w, h) / target_size``, then
    subtract the padding ``(max(w, h) - min(w, h)) // 2`` from the axis that
    corresponds to the shorter side of the frame.
    NOTE(review): this presumes the frame was padded symmetrically to a square
    and then resized to ``target_size`` -- confirm against the preprocessing.

    Args:
        coord: Sequence ``(x, y)`` in model-input coordinates. It is read
            only; the caller's object is never modified.
        w: Width of the source frame.
        h: Height of the source frame.
        target_size: Side length of the square model input.

    Returns:
        A new float ``numpy.ndarray`` holding the point in source-frame
        coordinates.
    """
    # Work on a float copy: this keeps the caller's data intact and avoids the
    # casting error NumPy raises when an integer array is scaled in place by a
    # float factor.
    point = np.array(coord, dtype=float)

    longer_side = max(w, h)
    pad = (longer_side - min(w, h)) // 2

    # Undo the resize first; the padding is expressed in source-frame units.
    scale_factor = longer_side / target_size
    point[0] *= scale_factor
    point[1] *= scale_factor

    # Undo the padding on the shorter axis (Y for landscape frames, X for
    # portrait or square frames, where pad is 0 in the square case).
    if h < w:
        point[1] -= pad
    else:
        point[0] -= pad

    return point

def val(threshold):
    """Count detection outcomes over all validation videos at one threshold.

    For every frame covered by a prediction record, predictions whose
    ``confidence`` is below ``threshold`` are discarded and the frame is
    scored against the ground-truth CSV row:

    * ``Visibility == 0`` and predictions remain: each one is a false positive.
    * ``Visibility == 0`` and none remain: one true negative.
    * ``Visibility == 1`` and none remain: one false negative.
    * ``Visibility == 1`` and predictions remain: each prediction at a distance
      of ``DIST_THRESHOLD`` or more from the ground truth is a false positive;
      the frame is one true positive if any prediction is closer than that,
      otherwise one false negative.

    Frames with any other ``Visibility`` value are not counted. Predictions
    are mapped back with a hard-coded 1280x720 frame and 640 model input.

    Args:
        threshold: Minimum ``confidence`` a prediction needs to be kept.

    Returns:
        Tuple ``(fn, fp, tn, tp)`` of integer counts.

    Raises:
        ValueError: If the ground-truth row fetched for a frame carries a
            different ``Frame`` value than the frame id used to fetch it.
    """
    fn = fp = tn = tp = 0

    for v in videos:
        video_path = Path(v)
        match_name = video_path.parts[0]
        video_name = video_path.stem

        gt = pd.read_csv(source_dir / match_name / "csv" / f"{video_name}_ball.csv")

        # Prediction records belonging to this match/video only.
        in_video = (pred_json["match_name"] == match_name) & (
            pred_json["video_name"] == video_name
        )
        preds = pred_json.loc[in_video]

        # Iterate the needed columns directly instead of materialising a row
        # Series with .iloc on every access.
        records = zip(preds["frame_id_min"], preds["frame_id_max"], preds["pred"])
        for first_frame, last_frame, batch in records:
            for offset, frame_id in enumerate(range(first_frame, last_frame + 1)):
                gt_row = gt.iloc[frame_id]
                frame_preds = batch[offset]

                # Explicit check rather than `assert`, so it is not stripped
                # when Python runs with -O.
                if gt_row.Frame != frame_id:
                    raise ValueError(
                        f"Ground truth for {match_name}/{video_name} is misaligned: "
                        f"row {frame_id} has Frame={gt_row.Frame}"
                    )

                kept = [p for p in frame_preds if p["confidence"] >= threshold]
                visibility = gt_row.Visibility

                if visibility == 0:
                    if kept:
                        fp += len(kept)
                    else:
                        tn += 1
                elif visibility == 1:
                    if not kept:
                        fn += 1
                        continue

                    target = np.array([gt_row.X, gt_row.Y])
                    hit = False
                    for p in kept:
                        # dtype=float: integer JSON coordinates would otherwise
                        # yield an int array that cannot be scaled in place.
                        coord = np.array([p["x"], p["y"]], dtype=float)
                        coord = transform_coordinates_back(coord, 1280, 720, 640)

                        if np.linalg.norm(target - coord) < DIST_THRESHOLD:
                            hit = True
                        else:
                            fp += 1

                    # A frame yields at most one TP however many predictions hit.
                    if hit:
                        tp += 1
                    else:
                        fn += 1

    return fn, fp, tn, tp

# Sweep the confidence thresholds. tqdm wraps x itself (not the enumerate
# object) so it can read len(x) and display a total and ETA.
for i, threshold in enumerate(tqdm(x)):
    y_fn[i], y_fp[i], y_tn[i], y_tp[i] = val(threshold)


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator element-wise, NaN where denominator is 0."""
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    ratio = np.full(denominator.shape, np.nan)
    np.divide(numerator, denominator, out=ratio, where=denominator != 0)
    return ratio


# Metrics per threshold. A metric whose denominator is 0 is undefined and is
# kept as NaN without triggering divide-by-zero warnings.
acc = _safe_ratio(y_tn + y_tp, y_fn + y_fp + y_tn + y_tp)
precision = _safe_ratio(y_tp, y_tp + y_fp)
recall = _safe_ratio(y_tp, y_tp + y_fn)
# F1 from counts: equal to the harmonic mean of precision and recall where both
# are defined, but avoids taking reciprocals of zero or NaN values.
f1 = _safe_ratio(2 * y_tp, 2 * y_tp + y_fp + y_fn)

# Raw confusion-matrix counts against the confidence threshold.
plt.plot(x, y_fn, label="FN")
plt.plot(x, y_fp, label="FP")
plt.plot(x, y_tn, label="TN")
plt.plot(x, y_tp, label="TP")
plt.legend()
plt.xlabel("Confidence Threshold")
plt.ylabel("Count")
plt.grid()
plt.show()

# Derived metrics against the confidence threshold.
plt.plot(x, acc, label="Accuracy")
plt.plot(x, precision, label="Precision")
plt.plot(x, recall, label="Recall")
plt.plot(x, f1, label="F1-Score")
plt.legend()
plt.xlabel("Confidence Threshold")
plt.grid()
plt.show()