In [None]:
%load_ext autoreload
%autoreload 2

from gorillatracker.classification.clustering import MERGED_DF, EXT_MERGED_DF
from gorillatracker.classification.metrics import analyse_embedding_space
from gorillatracker.classification.reid import run_knn_openset_recognition_cv, visualize_metrics, get_optimal_threshold
import numpy as np

# Alias for the EXT_MERGED_DF frame imported from the clustering module.
edf = EXT_MERGED_DF

# Keep only the rows whose "dataset" is "SPAC" and whose "model" is
# "EfN-Pretrained", then renumber the index from zero.
df = edf[(edf["dataset"] == "SPAC") & (edf["model"] == "EfN-Pretrained")].reset_index(drop=True)
# Method name forwarded to run_knn_openset_recognition_cv below.
method = "knn1"

# The analysis result is read by key: the global max distance bounds the
# threshold sweep; the global min distance is only reported.
analysis = analyse_embedding_space(df)
max_distance = analysis["global_max_dist"]
min_distance = analysis["global_min_dist"]
print(f"Global min distance: {min_distance}, Global max distance: {max_distance}")
# Sweep 60 evenly spaced thresholds from 0 up to max_distance (inclusive).
# Note that the sweep starts at 0, not at min_distance.
thresholds = np.linspace(0, max_distance, 60)
cv_results = run_knn_openset_recognition_cv(thresholds, df, method=method, construction_method="equal_classes")

# Visualize the cross-validation metrics over the swept thresholds, then
# report the threshold selected for the "multiclass_f1_weighted" metric.
visualize_metrics(cv_results, thresholds)
optimal_threshold_multiclass = get_optimal_threshold(cv_results, metric="multiclass_f1_weighted")
print(optimal_threshold_multiclass)

In [None]:
from gorillatracker.classification.reid import sweep_configs, configs, batch_visualize_metrics

# Hand-picked subset of configurations to sweep.
# NOTE(review): each tuple is presumably
# (dataset, model, labelling method, construction method) — confirm against
# sweep_configs.
partial = [
    ("SPAC+min3+max10", "ViT-Finetuned", "knn1centroid", "equal_classes"),
    ("SPAC+min3+max10", "ViT-Finetuned", "knn1centroid_iqr", "equal_classes"),
]
# Uncomment to sweep every entry of the imported `configs` instead of the subset above:
# partial = configs
results = sweep_configs(EXT_MERGED_DF, partial, resolution=40)
# `results` is used as a mapping here; print its keys as a quick sanity check.
print(results.keys())
# Optional plotting of the sweep (left disabled in this cell):
# batch_visualize_metrics(results, partial)

In [None]:
import pickle
from collections import defaultdict


def make_pickleable(d):
    """Rebuild nested dict/list containers as plain ``dict`` / ``list`` objects.

    Any ``dict`` instance (``defaultdict`` included, since it subclasses
    ``dict``) becomes a plain ``dict`` and any ``list`` instance becomes a
    plain ``list``; their values/items are converted recursively. Every other
    object, including tuples and dict keys, is returned unchanged. The input
    is not modified.
    """
    if isinstance(d, dict):
        plain = {}
        for key, value in d.items():
            plain[key] = make_pickleable(value)
        return plain
    if isinstance(d, list):
        return list(map(make_pickleable, d))
    return d


# Pickleability check: serialising the converted copy raises if anything in it
# cannot be pickled. The resulting bytes are discarded, nothing is written.
pickle.dumps(make_pickleable(results))
# Prints the original (unconverted) `results` values.
print(results.values())

In [None]:
# Uses `results` and `partial` as bound by earlier cells of this notebook.
batch_visualize_metrics(results, partial)

# Visualize the results of reid_sweep.py

In [None]:
import numpy as np
import pandas as pd
import pickle

def find_max_f1_weighted(cv_results):
    """Find the maximum mean weighted F1 score and its corresponding threshold.

    Args:
        cv_results: Mapping from threshold to a metrics mapping. Entries
            without a ``"multiclass_f1_weighted"`` key are ignored; the value
            stored under that key is averaged with ``np.mean``.

    Returns:
        A ``(max_f1, threshold)`` tuple. When several thresholds share the
        maximum, the one that appears first in ``cv_results`` is returned.
        ``(0, None)`` is returned when no entry yields a usable score, i.e.
        the key is missing everywhere or every mean is NaN.
    """
    best_f1, best_threshold = 0, None
    has_candidate = False
    for threshold, metrics in cv_results.items():
        if "multiclass_f1_weighted" not in metrics:
            continue
        mean_f1 = np.mean(metrics["multiclass_f1_weighted"])
        # NaN compares false against everything, so letting it take part in
        # the maximum would make the outcome depend on iteration order.
        if np.isnan(mean_f1):
            continue
        # Strict ">" keeps the first threshold on ties.
        if not has_candidate or mean_f1 > best_f1:
            best_f1, best_threshold = mean_f1, threshold
            has_candidate = True
    return best_f1, best_threshold


def process_results(results):
    """Group sweep results by dataset and model, keeping the best F1 per labelling method.

    Args:
        results: Mapping whose keys are 4-tuples unpacked as
            ``(dataset, model, labelling_method, <ignored>)`` and whose values
            are indexable by ``"cv_results"``.

    Returns:
        Nested dict of the form
        ``{dataset: {model: {labelling_method: {"max_f1": ..., "threshold": ...}}}}``
        where the two inner values come from
        ``find_max_f1_weighted(value["cv_results"])``. When several keys differ
        only in their ignored fourth element, the entry that appears last in
        ``results`` wins.
    """
    processed = {}
    # Bind dataset and model from each key directly, so no grouping step can
    # reference a name that is out of scope.
    for (dataset, model, labelling_method, _), value in results.items():
        max_f1, threshold_at_max = find_max_f1_weighted(value["cv_results"])
        processed.setdefault(dataset, {}).setdefault(model, {})[labelling_method] = {
            "max_f1": max_f1,
            "threshold": threshold_at_max,
        }
    return processed


def _format_metric(value):
    """Render a metric with two decimals, or ``-`` when it is missing (``None``)."""
    return "-" if value is None else f"{value:.2f}"


def generate_data_rows(processed_results):
    """Generate data rows for the LaTeX table, grouping models under datasets.

    Args:
        processed_results: Nested dict
            ``{dataset: {model: {labelling_method: {"max_f1": ..., "threshold": ...}}}}``.

    Returns:
        List of LaTeX row strings. Each dataset (in sorted order) contributes a
        bold ``\\multicolumn`` header row followed by one row per model (in
        sorted order). Dataset groups are separated by ``\\midrule``. A
        labelling method that is absent for a model, or whose value is
        ``None``, is rendered as ``-`` instead of raising a formatting error.
    """
    labelling_methods = ["knn1", "knn5", "knn5distance", "knn1centroid", "knn1centroid_iqr"]
    # Two leading columns plus a threshold/F1 pair for every labelling method.
    column_count = 2 + 2 * len(labelling_methods)

    data_rows = []
    for dataset, models in sorted(processed_results.items()):
        if data_rows:
            # Separator goes between dataset groups only, never after the last one.
            data_rows.append("\\midrule")
        data_rows.append(f"\\multicolumn{{{column_count}}}{{l}}{{\\textbf{{{dataset}}}}} \\\\")
        for model, methods in sorted(models.items()):
            cells = [f"& {model}"]
            for method in labelling_methods:
                entry = methods.get(method, {})
                threshold_text = _format_metric(entry.get("threshold"))
                f1_text = _format_metric(entry.get("max_f1"))
                cells.append(f" & {threshold_text} & {f1_text}")
            data_rows.append(" ".join(cells) + " \\\\")

    return data_rows


def create_latex_table(data_rows):
    """Wrap pre-formatted row strings in the fixed LaTeX ``table`` template.

    Args:
        data_rows: Iterable of row strings; they are joined with a newline
            plus four spaces and substituted for the ``DATA_ROWS`` placeholder.

    Returns:
        The complete LaTeX table source as a single string.
    """
    template = r"""
\begin{table}[H]
    \centering
    \resizebox{\textwidth}{!}{%
    \begin{tabular}{llcccccccccc}
    \toprule
    & & \multicolumn{2}{c}{knn1} & \multicolumn{2}{c}{knn5} & \multicolumn{2}{c}{knn5distance} & \multicolumn{2}{c}{knn1centroid} & \multicolumn{2}{c}{knn1centroid\_iqr} \\
    \cmidrule(lr){3-4} \cmidrule(lr){5-6} \cmidrule(lr){7-8} \cmidrule(lr){9-10} \cmidrule(lr){11-12}
    & & Thresh & F1 & Thresh & F1 & Thresh & F1 & Thresh & F1 & Thresh & F1 \\
    \midrule
    DATA_ROWS
    \bottomrule
    \end{tabular}%
    }
    \caption{Comparison of Labelling Approaches at the best weighted multiclass F1 score, given a fixed Threshold}
    \label{tab:labelling-metrics}
\end{table}
    """
    # The placeholder occurs exactly once in the template, so splitting around
    # it and splicing the rows in is the same as a replace.
    rows_text = "\n    ".join(data_rows)
    before, _, after = template.partition("DATA_ROWS")
    return before + rows_text + after


def main(results):
    """Turn a raw results mapping into the finished LaTeX table string.

    Chains the three stages in order: ``process_results`` ->
    ``generate_data_rows`` -> ``create_latex_table``, and returns the output
    of the last one.
    """
    return create_latex_table(generate_data_rows(process_results(results)))


# Usage: load a pickled results mapping from disk and print the LaTeX table built from it.
# NOTE(review): absolute, machine-specific path — this cell only runs where that file exists.
filepath = "/workspaces/gorillatracker/sep26_reid_results.pkl"
# NOTE(review): pickle.load can execute arbitrary code embedded in the file; only load trusted files.
# This also rebinds the notebook-level `results` that an earlier cell assigned from sweep_configs.
with open(filepath, "rb") as f:
    results = pickle.load(f)
latex_table = main(results)
print(latex_table)