# Classification Model Performance

Use this notebook to track validation metrics across epochs for classifier fine-tuning runs. Set up the experiment list below to point at your TensorBoard logs (each directory must contain the `fold-*` subfolders written by `Trainer`).

In [None]:
from pathlib import Path
import sys
# Expose the directory one level above the notebook's working directory to
# the import system (presumably so project packages import without being
# installed -- confirm). The entry is added only when it is not already there,
# so re-running this cell does not duplicate it.
repo_root = Path.cwd().resolve().parents[0]
if sys.path.count(str(repo_root)) == 0:
    sys.path.append(str(repo_root))

In [None]:
import os
# Set CUDA_VISIBLE_DEVICES to '-1' for this process. NOTE(review): presumably
# this hides all GPUs so evaluation runs on CPU, and it likely has to run
# before the deep-learning framework is imported -- confirm.
os.environ.update({'CUDA_VISIBLE_DEVICES': '-1'})

In [None]:
from xai.models.SimplifiedClinicalTransformer.Topologies.BertLikeTransformer.Explainer.ClassificationEvaluator import compute_performance_folds

In [None]:
from samecode.plot.pyplot import subplots
import seaborn as sns
import numpy as np
import pandas as pd

In [None]:
# Parent directory of all experiment runs. Each run lives in its own
# subdirectory (which holds the fold-* logs). The path is relative, so it is
# interpreted from the notebook's working directory.
results_root = Path('..', 'results', 'runs')
results_root

In [None]:
# Which logged scalar to read and from which data split.
# Other metric names follow the same pattern, e.g. 'epoch_loss'.
split = 'validation'
metric = 'epoch_sparse_categorical_accuracy'

# One tuple per run to compare: (run_id, checkpoint_folder, label).
#   run_id            -- subdirectory of `results_root`
#   checkpoint_folder -- extra path component below the run directory;
#                        leave as '' to use the run directory itself
#   label             -- display name handed to the loader as `label`
experiments = [
    (
        'ClassifierBaseline',
        '',
        'Baseline',
    ),
    (
        'ClassifierFinetune',
        'model.E000050.h5',
        'Finetuned',
    ),
]

In [None]:
# Load the per-fold metric for every configured experiment and build
#   perf_df    -- the concatenated frames returned by the loader
#   summary_df -- mean / std of `metric` per (epoch, Model)
#
# All run directories are resolved and checked *before* anything is loaded,
# so a mistyped entry fails immediately instead of after earlier runs have
# already been parsed.
if not experiments:
    # Without this guard pd.concat([]) fails with the much less helpful
    # "No objects to concatenate".
    raise ValueError(
        '`experiments` is empty; add at least one '
        '(run_id, checkpoint_folder, label) entry.'
    )

resolved_runs = []
for run_id, checkpoint, label in experiments:
    run_path = results_root / run_id
    if checkpoint:
        # Optional sub-path below the run directory.
        run_path = run_path / checkpoint
    run_path = run_path.resolve()
    if not run_path.exists():
        raise FileNotFoundError(f'Run path not found: {run_path}')
    resolved_runs.append((run_path, label))

perf_frames = []
summary_frames = []

for run_path, label in resolved_runs:
    # NOTE(review): the loader is expected to return a DataFrame with
    # 'epoch', 'Model' and `metric` columns (they are grouped/aggregated
    # below); presumably one row per fold and epoch -- confirm.
    perf = compute_performance_folds(
        path=str(run_path),
        label=label,
        metric=metric,
        split=split,
    )

    summary = (
        perf.groupby(['epoch', 'Model'])
        .agg({metric: ['mean', 'std']})
        .reset_index()
    )
    # Flatten the two-level columns produced by the multi-function agg.
    summary.columns = ['epoch', 'Model', f'{metric}_mean', f'{metric}_std']

    perf_frames.append(perf)
    summary_frames.append(summary)

perf_df = pd.concat(perf_frames).reset_index(drop=True)
summary_df = pd.concat(summary_frames).reset_index(drop=True)
summary_df.head()

In [None]:
# Line plot of the mean metric per epoch, one line per model label.
axs = subplots(cols=1, w=6, h=4)
ax = axs[0]

# Collapse to a single row per (epoch, Model); rows sharing both keys are
# reduced with the median.
plot_df = (
    summary_df
    .groupby(['epoch', 'Model'])
    .median()
    .reset_index()
)

sns.lineplot(x='epoch', y=f'{metric}_mean', hue='Model', data=plot_df, ax=ax)

# Human-readable metric name: drop the 'epoch_' prefix and title-case it.
pretty_metric = metric.replace('epoch_', '').replace('_', ' ').title()
ax.set_xlabel('Epoch', weight='bold')
ax.set_ylabel(f'Mean {pretty_metric}', weight='bold')
ax.set_title(f'Classification Performance ({split} set)')
sns.despine(offset=10, trim=True)

In [None]:
# Optional sanity check: first rows of the raw per-fold values, ordered by
# model label and then by epoch.
perf_df.sort_values(by=['Model', 'epoch']).head(5)