## Copy Data Into Folds

In [None]:
import os
import json
import random
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Minimum minutes of detected song per fold (compared against summed segment durations below).
mins_per_fold = 50
# Output root; files are copied to <fold_data_dir>/<bird_id>/fold<N>/.
fold_data_dir = "/media/george-vengrovski/disk1/decoder_data"

# Directories holding each bird's .wav files (each is listed non-recursively below).
birds_wav_paths = [
    "/media/george-vengrovski/disk2/canary/yarden_data/llb3_data/llb3_songs",
    "/media/george-vengrovski/disk2/canary/yarden_data/llb11_data/llb11_songs",
    "/media/george-vengrovski/disk2/canary/yarden_data/llb16_data/llb16_songs"
]

# Song-detection output: a JSON list of per-file entries read below via the keys
# 'filename', 'song_present' and 'segments' (each segment has 'onset_ms'/'offset_ms').
song_detection_json = "files/contains_llb.json"

# Index every .wav file across all bird directories: filename -> full path.
# Missing directories are skipped; when two directories contain the same file
# name, the directory listed later in birds_wav_paths wins.
wav_file_to_path = {
    wav_name: os.path.join(bird_dir, wav_name)
    for bird_dir in birds_wav_paths
    if os.path.isdir(bird_dir)
    for wav_name in os.listdir(bird_dir)
    if wav_name.endswith('.wav')
}

# Parse the song-detection JSON and collect, per bird, every song file together
# with its total detected-song duration in seconds (sum of segment lengths).
# An entry is skipped when no song is present, when its wav file was not found
# in wav_file_to_path, or when the summed duration is not positive. The bird id
# is the filename prefix before the first underscore.
# JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
with open(song_detection_json, 'r', encoding='utf-8') as f:
    data = json.load(f)

bird_song_files = {}  # bird_id -> list of (filename, duration_seconds)
for entry in data:
    if not entry.get('song_present', False):
        continue
    filename = entry['filename']
    if filename not in wav_file_to_path:
        continue
    total_duration = sum(
        (seg.get('offset_ms', 0) - seg.get('onset_ms', 0)) / 1000.0
        for seg in entry.get('segments', [])
    )
    if total_duration <= 0:
        continue
    bird_id = filename.split('_')[0]
    bird_song_files.setdefault(bird_id, []).append((filename, total_duration))

# Randomly partition each bird's song files into folds. A fold is closed as soon
# as its summed song duration reaches mins_per_fold minutes.
# The RNG is seeded so the split is reproducible: the copy step below never
# deletes or overwrites files. With an unseeded shuffle, a rerun could leave the
# same file in several fold directories, leaking data between folds.
fold_seed = 0             # change to draw a different (but reproducible) split
merge_short_tail = True   # append leftover (< mins_per_fold) files to the last full fold
fold_rng = random.Random(fold_seed)
min_fold_seconds = mins_per_fold * 60

folds_info = {}  # bird_id -> list of folds, each fold is list of (filename, duration)
for bird_id, files in bird_song_files.items():
    shuffled = list(files)  # shuffle a copy; bird_song_files keeps its order
    fold_rng.shuffle(shuffled)
    folds = []
    current_fold = []
    current_fold_duration = 0.0
    for fname, dur in shuffled:
        current_fold.append((fname, dur))
        current_fold_duration += dur
        if current_fold_duration >= min_fold_seconds:
            folds.append(current_fold)
            current_fold = []
            current_fold_duration = 0.0
    if current_fold:
        if merge_short_tail and folds:
            # The leftovers cannot reach the minimum on their own; merging them
            # keeps every fold at >= mins_per_fold minutes.
            folds[-1].extend(current_fold)
        else:
            # Not enough song for even one full fold (or merging disabled):
            # keep the leftovers as their own, shorter fold.
            folds.append(current_fold)
    folds_info[bird_id] = folds

# Copy every fold's files into <fold_data_dir>/<bird_id>/fold<N>/ (N starts at 1).
# Destination files that already exist are left untouched, which makes reruns
# cheap. Files from an earlier, different fold assignment are NOT removed, so
# clear fold_data_dir by hand if the assignment changes.
for bird_id, folds in folds_info.items():
    for fold_number, fold in enumerate(folds, start=1):
        fold_dir = os.path.join(fold_data_dir, bird_id, f"fold{fold_number}")
        os.makedirs(fold_dir, exist_ok=True)
        print(f"Copying files for {bird_id} fold {fold_number} ({len(fold)} files)...")
        progress = tqdm(fold, desc=f"{bird_id} fold{fold_number}", leave=False)
        for fname, _ in progress:
            src = wav_file_to_path[fname]
            dst = os.path.join(fold_dir, fname)
            if os.path.exists(dst):
                continue
            shutil.copy2(src, dst)

# Flatten folds_info into three parallel lists (bird id, fold name, minutes of
# song) in bird order, then fold order; these drive the bar plot.
fold_rows = [
    (bird_id, f"fold{fold_number}", sum(dur for _, dur in fold) / 60)
    for bird_id, folds in folds_info.items()
    for fold_number, fold in enumerate(folds, start=1)
]
plot_bird_ids = [row[0] for row in fold_rows]
plot_fold_names = [row[1] for row in fold_rows]
plot_fold_minutes = [row[2] for row in fold_rows]

# Bar chart: one bar per "<bird>-<fold>", with its minutes of song printed on top.
plt.figure(figsize=(12, 6))
sns.set_style("whitegrid")
bar_labels = list(map("-".join, zip(plot_bird_ids, plot_fold_names)))
bars = plt.bar(bar_labels, plot_fold_minutes, color='skyblue')
for rect in bars:
    minutes = rect.get_height()
    x_center = rect.get_x() + rect.get_width() / 2.
    plt.text(x_center, minutes, f'{minutes:.1f}', ha='center', va='bottom')
plt.title('Minutes of Song Data per Fold (per Bird)', fontsize=14, pad=20)
plt.xlabel('Bird-Fold', fontsize=12)
plt.ylabel('Total Duration (minutes)', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()


## Create Embeddings for Each Fold

In [None]:
import sys
import importlib

import os
# Force MKL's GNU threading layer -- presumably to avoid the MKL/libgomp threading
# conflict that can arise when numpy and PyTorch are mixed (TODO confirm).
# NOTE(review): MKL reads this variable when it is first loaded. An earlier cell
# already imported matplotlib/seaborn (which import numpy), so this may have no
# effect here. Set it before any numeric import, or in the kernel's environment.
os.environ['MKL_THREADING_LAYER'] = 'GNU'  # Add this before running your code

# Import decoding.py as a module so its main(args) can be called in-process below
# instead of launching it from the command line.
decoding_module = importlib.import_module("decoding")

class Args:
    """Plain attribute bag passed to ``decoding.main()`` in place of parsed CLI args.

    The six constructor arguments correspond to the command-line flags echoed
    by the loop below; every other option is pinned to a fixed default.
    """

    def __init__(self, mode, bird_name, model_name, wav_folder, song_detection_json_path, num_samples_umap):
        # Caller-supplied options, stored verbatim (in the original attribute order).
        self.mode, self.bird_name, self.model_name = mode, bird_name, model_name
        self.wav_folder = wav_folder
        self.song_detection_json_path = song_detection_json_path
        self.num_samples_umap = num_samples_umap
        # Options this wrapper does not expose.
        fixed_defaults = (
            ("num_random_files_spec", 1000),
            ("single_threaded_spec", False),
            ("nfft", 1024),
            ("raw_spectrogram_umap", False),  # store_true-style flag, left off
            ("state_finding_algorithm_umap", "HDBSCAN"),
            ("context_umap", 1000),
        )
        for option, value in fixed_defaults:
            setattr(self, option, value)

# Call decoding.main() in "single" mode once for every directory under
# fold_data_dir whose name contains "fold". Each run is named
# "<parent dir>_<fold dir>", e.g. llb3_fold1.
for root, dirs, _ in os.walk(fold_data_dir):
    fold_dirs = [name for name in dirs if "fold" in name]
    for fold_name in fold_dirs:
        bird = os.path.basename(root)
        bird_name_fold = f"{bird}_{fold_name}"
        wav_folder = os.path.join(root, fold_name)
        args = Args(
            mode="single",
            bird_name=bird_name_fold,
            model_name="BF_Canary_Joint_Run",
            wav_folder=wav_folder,
            song_detection_json_path=song_detection_json,
            num_samples_umap="1e6"
        )
        print(f"Running decoding.py --mode single --bird_name {bird_name_fold} --model_name BF_Canary_Joint_Run --wav_folder {wav_folder} --song_detection_json_path {song_detection_json} --num_samples_umap 1e6")
        decoding_module.main(args)


## Evaluate TweetyBERT, Parametric, Load and Transform on Folds

In [8]:
import sys, pathlib, shutil, time, gc
from collections import defaultdict, Counter
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
# ── we only need classic UMAP for logits viz ───────────────────────────────
import umap
# from umap.parametric_umap import ParametricUMAP
# import hdbscan
# from hdbscan import approximate_predict
import matplotlib.pyplot as plt


# ── constants ────────────────────────────────────────────────────────────────
# Current working directory, made absolute; all paths below hang off it.
ROOT       = pathlib.Path().resolve()
# Directory holding the per-fold .npz files named in RAW_FILES.
NPZ_DIR    = ROOT / "files"
# Output directory for results, plots and the timing log (created when this cell runs).
SAVE_DIR   = ROOT / "results" / "decoder_eval"
SAVE_DIR.mkdir(parents=True, exist_ok=True)
CSV_PATH   = SAVE_DIR / "timings.csv"      # central log for resumption

# Fold files (relative to NPZ_DIR) to benchmark. The bird id is the part of the
# name before the first underscore. Only uncommented entries are processed.
RAW_FILES = [
    # "llb3_fold1.npz",
    # "llb3_fold2.npz",
    # "llb3_fold3.npz",
    # "llb3_fold4.npz",
    # "llb3_fold5.npz",
    # "llb3_fold6.npz",
    # "llb3_fold7.npz",
    # "llb3_fold8.npz",
    # "llb3_fold9.npz",
    # "llb11_fold1.npz",
    # "llb11_fold2.npz",
    # "llb11_fold3.npz",
    # "llb11_fold4.npz",
    # "llb11_fold5.npz",
    # "llb11_fold6.npz",
    # "llb11_fold7.npz",
    # "llb11_fold8.npz",
    # "llb16_fold1.npz",
    # "llb16_fold2.npz",
    # "llb16_fold3.npz",
    "llb16_fold4.npz",
    "llb16_fold5.npz"
]

# Every array read from a fold .npz is truncated to its first MAX_FRAMES rows.
MAX_FRAMES = 1_000_000
# Decoder context length: spectrograms are cut into CTX-frame segments for decoding.
CTX        = 1_000
DEV        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained model directory passed to TweetyBertClassifier as model_dir.
MODEL_DIR  = ROOT / "experiments" / "BF_Canary_Joint_Run"

# ── helpers: i/o + tiny utils ────────────────────────────────────────────────
def group_by_bird(fnames: List[str]) -> Dict[str, List[str]]:
    """Group file names by bird id, i.e. the prefix before the first underscore.

    Returns a dict mapping bird id -> that bird's file names sorted
    lexicographically; birds appear in order of first occurrence.
    """
    buckets: Dict[str, List[str]] = {}
    for name in fnames:
        bird_id, _, _ = name.partition("_")
        buckets.setdefault(bird_id, []).append(name)
    return {bird_id: sorted(names) for bird_id, names in buckets.items()}

def load_npz(fp: pathlib.Path) -> Tuple[np.ndarray, ...]:
    """Read one fold file and return its four arrays, each cut to MAX_FRAMES rows.

    Order: (predictions, s, hdbscan_labels, ground_truth_labels) -- i.e.
    embeddings, spectrograms, HDBSCAN labels, and ground-truth labels. All
    arrays are read before the archive is closed.
    """
    keys = ("predictions", "s", "hdbscan_labels", "ground_truth_labels")
    with np.load(fp) as archive:
        return tuple(archive[key][:MAX_FRAMES] for key in keys)

# ── helpers: decoder forward ────────────────────────────────────────────────
sys.path.insert(0, str(pathlib.Path("src").resolve()))
from src.decoder import TweetyBertClassifier, SongDataSet_Image, CollateFunction
def run_decoder(model_torch: torch.nn.Module,
                specs: np.ndarray,
                gt: np.ndarray,
                tmp_dir: pathlib.Path) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]:
    """Run a trained decoder over a whole spectrogram, CTX frames at a time.

    The spectrogram is cut into CTX-frame segments, written as .npz files to
    ``tmp_dir``, and fed back through the project's ``SongDataSet_Image`` /
    ``CollateFunction`` pipeline.

    Args:
        model_torch: trained classifier, called as ``model_torch(batch)``;
            its output is treated as logits of shape (1, S, C).
        specs: spectrogram whose first axis is time (segments are cut along it).
        gt: per-frame ground-truth labels aligned with ``specs``.
        tmp_dir: scratch directory; wiped before use and removed afterwards,
            even on error.

    Returns:
        ``(preds, seconds, logits, gts)``: per-frame argmax predictions,
        wall-clock seconds spent in the inference loop, per-frame logits, and
        the ground-truth labels as produced by the data loader. All are
        truncated to ``len(specs)`` frames.

    Raises:
        RuntimeError: if ``specs`` is empty (nothing to decode).
    """
    # Start from an empty directory. Leftover segments from an interrupted
    # earlier call would otherwise be loaded as extra data.
    shutil.rmtree(tmp_dir, ignore_errors=True)
    tmp_dir.mkdir(parents=True, exist_ok=True)
    try:
        if _write_decoder_segments(specs, gt, tmp_dir) == 0:
            raise RuntimeError("no segments saved – tmp_dir empty")

        ds = SongDataSet_Image(tmp_dir,
                               num_classes=int(gt.max()) + 1,
                               segment_length=CTX,
                               infinite_loader=False)
        dl = DataLoader(ds, batch_size=1, shuffle=False,
                        collate_fn=CollateFunction(segment_length=CTX))

        # NOTE(review): output order is the dataset's file order. If
        # SongDataSet_Image lists files unsorted, segments come back out of time
        # order, and the final [:len(specs)] trim may drop real frames instead of
        # padding. preds/logits/gts stay mutually aligned either way.
        preds, logits_accum, gts_accum = [], [], []
        t0 = time.perf_counter()
        model_torch.eval()
        with torch.no_grad():
            for batch in dl:
                spec_batch = batch[0].to(DEV)   # first element = spectrogram tensor
                gt_batch = batch[1].to(DEV)     # labels; argmax over dim 2 recovers class ids
                logits = model_torch(spec_batch)          # (1, S, C)
                preds.append(torch.argmax(logits, 2).cpu().numpy())
                logits_accum.append(logits.cpu().numpy())
                gts_accum.append(torch.argmax(gt_batch, 2).cpu().numpy())
        t_elapsed = time.perf_counter() - t0
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

    n_frames = len(specs)  # drop the zero padding added to the last segment
    preds = np.concatenate([p.squeeze(0) for p in preds])[:n_frames]
    logits = np.concatenate([l.squeeze(0) for l in logits_accum])[:n_frames]
    gts = np.concatenate([g.squeeze(0) for g in gts_accum])[:n_frames]

    return preds, t_elapsed, logits, gts


def _write_decoder_segments(specs: np.ndarray, gt: np.ndarray, tmp_dir: pathlib.Path) -> int:
    """Write specs/gt as CTX-frame segment files ``<id>.npz`` into tmp_dir; return the count."""
    n_written = 0
    for seg_id, start in enumerate(range(0, len(specs), CTX)):
        seg = specs[start:start + CTX]
        # Prepend 20 zero bins on the second (frequency) axis.
        # NOTE(review): presumably matches the input height the decoder expects -- confirm.
        seg = np.pad(seg, ((0, 0), (20, 0)), constant_values=0)
        if seg.shape[0] < CTX:  # zero-pad the final, partial segment in time
            seg = np.pad(seg, ((0, CTX - seg.shape[0]), (0, 0)))
        gt_seg = gt[start:start + CTX]
        if gt_seg.shape[0] < CTX:
            gt_seg = np.pad(gt_seg, (0, CTX - gt_seg.shape[0]))
        np.savez(
            tmp_dir / f"{seg_id}.npz",
            labels=gt_seg,
            s=seg.T,                        # freq × time, per decoder expectations
            vocalization=np.zeros(CTX, dtype=np.int8)
        )
        n_written += 1
    return n_written

# ── per-bird benchmark ───────────────────────────────────────────────────────
def benchmark_bird(bird_id: str, fold_files: List[str]) -> None:
    """Cross-fold linear-probe decoder benchmark for one bird.

    Each fold file is used in turn to train a ``TweetyBertClassifier`` probe,
    which is then evaluated on every other fold of the same bird. For each
    (fit, eval) pair this writes ``<fit>__<eval>.results.npz`` and a UMAP plot
    of the decoder logits to SAVE_DIR, and appends one timing row to CSV_PATH
    right away, so a crash loses at most the pair in progress.

    Resumable: a pair counts as done if it is logged in CSV_PATH or its results
    file already exists. A fit fold whose pairs are all done is not retrained.
    UMAP / parametric-UMAP labelling is disabled: their labels are saved as
    zeros and their timings are logged as 0.
    """
    fold_paths = [NPZ_DIR / f for f in fold_files]
    done_keys = _load_done_keys()

    for fit_path in fold_paths:
        fit_name = fit_path.name

        remaining_eval_paths = [
            p for p in fold_paths
            if p != fit_path
            and (bird_id, fit_name, p.name) not in done_keys
            # Results saved by a run that died before logging also count as done;
            # otherwise the probe would be retrained just to skip every pair.
            and not _result_path(fit_name, p.name).exists()
        ]
        if not remaining_eval_paths:
            print(f"[resume] {bird_id} · {fit_name}: all evals finished – skipping")
            continue

        class_weights = _fit_class_weights(fit_path)
        dec_dir = SAVE_DIR / f"{bird_id}__{fit_name}__decoder"
        probe = None
        probe_model = None
        try:
            # train linear probe
            probe = TweetyBertClassifier(model_dir=str(MODEL_DIR),
                                         linear_decoder_dir=str(dec_dir),
                                         context_length=CTX,
                                         weight=class_weights)
            probe.prepare_data(str(fit_path), test_train_split=0.8)
            probe.create_dataloaders(batch_size=42)
            probe.create_classifier()
            probe.train_classifier(lr=1e-3,
                                   desired_total_batches=450,
                                   batches_per_eval=25,
                                   patience=4)
            probe_model = probe.classifier_model.to(DEV)

            for eval_path in remaining_eval_paths:
                _evaluate_pair(bird_id, fit_name, eval_path, probe_model)
        finally:
            # Always clean up, even if training or evaluation raised.
            shutil.rmtree(dec_dir, ignore_errors=True)
            _release_probe(probe)
            probe = probe_model = class_weights = None
            torch.cuda.empty_cache()
            gc.collect()


def _result_path(fit_name: str, eval_name: str) -> pathlib.Path:
    """Path of the saved results for one (fit fold, eval fold) pair."""
    return SAVE_DIR / f"{fit_name}__{eval_name}.results.npz"


def _load_done_keys() -> set:
    """(bird, fit_fold, eval_fold) triples already logged in CSV_PATH (empty if no log)."""
    if not CSV_PATH.exists():
        return set()
    done_df = pd.read_csv(CSV_PATH)
    return set(zip(done_df["bird"], done_df["fit_fold"], done_df["eval_fold"]))


def _append_timing_rows(rows: List[dict]) -> None:
    """Append rows to CSV_PATH, writing the header only when the file is new."""
    if not rows:
        return
    pd.DataFrame(rows).to_csv(
        CSV_PATH,
        mode="a",
        header=not CSV_PATH.exists(),
        index=False
    )


def _fit_class_weights(fit_path: pathlib.Path) -> torch.Tensor:
    """Inverse-frequency class weights from the fit fold's ``hdbscan_labels``.

    Negative labels are ignored. Classes in 0..max label that never occur get a
    count of 1 to avoid division by zero. Weights are scaled to sum to the
    number of classes and returned as a float32 tensor on DEV.
    """
    with np.load(fit_path) as f:
        hdbscan_labels = f["hdbscan_labels"][:MAX_FRAMES]
    valid_labels = hdbscan_labels[hdbscan_labels >= 0]
    label_counts = Counter(valid_labels)
    num_classes = int(valid_labels.max()) + 1 if valid_labels.size > 0 else 1
    class_counts = np.array([label_counts.get(i, 0) for i in range(num_classes)])
    class_counts[class_counts == 0] = 1
    class_weights = 1.0 / class_counts
    class_weights = class_weights / class_weights.sum() * num_classes
    return torch.tensor(class_weights, dtype=torch.float32, device=DEV)


def _evaluate_pair(bird_id: str, fit_name: str, eval_path: pathlib.Path,
                   probe_model: torch.nn.Module) -> None:
    """Decode one eval fold with a trained probe; save results and plot, log timing."""
    k = f"{fit_name}__{eval_path.name}"
    res_path = _result_path(fit_name, eval_path.name)

    emb_eval, specs_eval, _, gt_eval = load_npz(eval_path)

    # --- decoder forward pass (gt_eval is replaced by the loader's labels) ---
    lbl_d, t_d, logits_eval, gt_eval = run_decoder(
        probe_model, specs_eval, gt_eval, SAVE_DIR / "tmp_eval"
    )

    _save_logit_umap(logits_eval, gt_eval,
                     f"UMAP of decoder logits – {bird_id} · {k}",
                     SAVE_DIR / f"{k}_logits_umap.png")

    # --- UMAP / PUMAP labels remain disabled ---
    lbl_u = np.zeros(len(emb_eval), dtype=int)
    lbl_p = np.zeros(len(emb_eval), dtype=int)

    np.savez_compressed(res_path,
                        umap_labels=lbl_u,
                        pumap_labels=lbl_p,
                        decoder_labels=lbl_d,
                        ground_truth_labels=gt_eval)

    # Log immediately so a crash on a later pair cannot orphan this result.
    _append_timing_rows([dict(bird=bird_id,
                              fit_fold=fit_name,
                              eval_fold=eval_path.name,
                              umap_s_per_row=0,
                              pumap_s_per_row=0,
                              decoder_s_per_row=t_d / len(specs_eval),
                              results_path=str(res_path))])


def _save_logit_umap(logits: np.ndarray, labels: np.ndarray, title: str,
                     out_path: pathlib.Path, max_points: int = 500_000) -> None:
    """Save a 2-D UMAP scatter of the first ``max_points`` logit rows coloured by label."""
    n_umap = min(max_points, logits.shape[0])
    if n_umap <= 1:
        return
    reducer = umap.UMAP(
        n_neighbors=30,
        min_dist=0.0,
        n_components=2,
        metric="cosine",
        n_jobs=-1,
        low_memory=True
    )
    Z = reducer.fit_transform(logits[:n_umap])
    plt.figure(figsize=(10, 8))
    try:
        plt.scatter(
            Z[:, 0], Z[:, 1],
            c=labels[:n_umap],
            cmap="tab20",
            s=5,
            alpha=0.6
        )
        plt.title(title)
        plt.xlabel("UMAP-1")
        plt.ylabel("UMAP-2")
        plt.grid(True, linestyle="--", alpha=0.4)
        plt.tight_layout()
        plt.savefig(out_path, dpi=150)
    finally:
        plt.close()


def _release_probe(probe) -> None:
    """Move a probe's classifier to CPU and drop its optimizer so GPU memory can be freed."""
    if probe is None:
        return
    model = getattr(probe, "classifier_model", None)
    if model is not None:
        model.to('cpu')
    # drop optimizer buffers if present
    if hasattr(probe, 'optimizer'):
        del probe.optimizer

# ── driver ───────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    for bird, files in group_by_bird(RAW_FILES).items():
        benchmark_bird(bird, files)

    # Summarise from the same log that benchmark_bird appends to (CSV_PATH). The
    # file only exists once at least one (fit, eval) pair has been logged.
    if CSV_PATH.exists():
        df = pd.read_csv(CSV_PATH)
        print("\n--- mean sec / frame ---")
        print(df.groupby("bird")[["umap_s_per_row",
                                  "pumap_s_per_row",
                                  "decoder_s_per_row"]].mean())
    else:
        print(f"no timings logged yet ({CSV_PATH} does not exist)")


[resume] llb16 · llb16_fold4.npz: all evals finished – skipping
[resume] llb16 · llb16_fold5.npz: all evals finished – skipping

--- mean sec / frame ---
       umap_s_per_row  pumap_s_per_row  decoder_s_per_row
bird                                                     
llb11             0.0              0.0           0.000005
llb16             0.0              0.0           0.000005
llb3              0.0              0.0           0.000005


## Evaluation

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import v_measure_score
from typing import Dict, Any, Optional, List, Tuple
from collections import Counter
import os
import pathlib

# ─────────────────────────────────────────────────────────────────────────────
# ClusteringMetrics Class (calculates metrics and generates dashboard plot)
# ─────────────────────────────────────────────────────────────────────────────
class ClusteringMetrics:
    """Evaluate a frame-level labelling against ground-truth phrase labels.

    On construction the ground truth is silence-merged (see ``_merge_silence``)
    and a GT x prediction confusion matrix is built. GT labels are then matched
    one-to-one to predicted labels with the Hungarian algorithm on the
    column-normalised confusion matrix. All FER values are percentages.

    Args:
        gt: 1-D array of ground-truth labels, one per frame.
        pred: 1-D array of predicted labels, one per frame. If its length
            differs from ``gt`` by at most 100 frames, both are truncated to the
            shorter length (with a printed warning).
        silence: label value in ``gt`` treated as silence.

    Raises:
        ValueError: if ``gt`` and ``pred`` lengths differ by more than 100.
    """

    def __init__(self, gt: np.ndarray, pred: np.ndarray, silence: int = 0):
        if gt.shape != pred.shape:
            # Small length differences happen when predictions were produced from
            # a spectrogram slightly longer/shorter than the label array.
            min_len = min(len(gt), len(pred))
            if abs(len(gt) - len(pred)) > 100:  # arbitrary threshold for "small margin"
                raise ValueError(f"gt (shape {gt.shape}) and pred (shape {pred.shape}) arrays must have identical shape or be very close.")
            print(f"Warning: GT and Pred shapes differ ({gt.shape} vs {pred.shape}). Truncating to shortest length: {min_len}")
            gt = gt[:min_len]
            pred = pred[:min_len]

        self.gt_raw = gt.astype(int)
        self.pred = pred.astype(int)
        self.gt = self._merge_silence(self.gt_raw, silence)  # processed GT used by all metrics

        self.gt_types = np.unique(self.gt)      # sorted unique GT labels
        self.pred_types = np.unique(self.pred)  # sorted unique predicted labels

        self._build_confusion()
        self.mapping = self._hungarian()  # GT label -> matched predicted label

    @staticmethod
    def _merge_silence(arr: np.ndarray, silence: int) -> np.ndarray:
        """Replace every run of ``silence`` with an adjacent non-silence label.

        A silence run takes the label immediately to its LEFT. A run at the very
        start takes the first non-silence label to its right. An all-silence
        array is returned unchanged (as a copy), and empty input is returned as-is.
        """
        if arr.size == 0:
            return arr
        non_silent = arr != silence
        if not non_silent.any():
            return arr.copy()
        # For each frame: index of the latest non-silence frame at or before it
        # (-1 inside a leading silence run, where none exists yet).
        src = np.where(non_silent, np.arange(arr.size), -1)
        np.maximum.accumulate(src, out=src)
        # A leading silence run borrows the first non-silence label instead.
        src[src < 0] = np.argmax(non_silent)
        return arr[src]

    def _build_confusion(self) -> None:
        """Build GT x Pred frame counts ``C`` and the column-normalised ``C_norm``.

        Rows follow ``gt_types`` and columns follow ``pred_types``.
        ``C_norm[:, j]`` is the distribution of GT labels among the frames
        predicted as ``pred_types[j]``; all-zero columns stay zero.
        """
        if not self.gt_types.size or not self.pred_types.size:
            self.C = np.array([], dtype=int)
            self.C_norm = np.array([], dtype=float)
            return

        # Both type arrays come from np.unique (sorted), so searchsorted yields
        # each frame's exact row/column index without a per-frame Python loop.
        rows = np.searchsorted(self.gt_types, self.gt)
        cols = np.searchsorted(self.pred_types, self.pred)
        self.C = np.zeros((len(self.gt_types), len(self.pred_types)), dtype=int)
        np.add.at(self.C, (rows, cols), 1)

        col_sum = self.C.sum(axis=0, keepdims=True)
        with np.errstate(divide='ignore', invalid='ignore'):
            self.C_norm = np.divide(self.C, col_sum, where=col_sum != 0,
                                    out=np.zeros_like(self.C, dtype=float))

    def _hungarian(self) -> Dict[int, int]:
        """One-to-one GT -> predicted label mapping maximising total ``C_norm`` overlap."""
        if not hasattr(self, 'C_norm') or self.C_norm.size == 0:
            return {}
        # linear_sum_assignment minimises cost, so negate the overlap.
        row_ind, col_ind = linear_sum_assignment(-self.C_norm)
        mapping = {}
        for r, c in zip(row_ind, col_ind):
            if r < len(self.gt_types) and c < len(self.pred_types):
                mapping[self.gt_types[r]] = self.pred_types[c]
        return mapping

    def v_measure(self) -> float:
        """sklearn V-measure of (processed) GT vs predictions; 0.0 for empty input."""
        if self.gt.size == 0 or self.pred.size == 0:
            return 0.0
        try:
            return v_measure_score(self.gt, self.pred)
        except ValueError:
            # Fallback for degenerate inputs: a single identical label on both
            # sides scores 1, anything else 0.
            if len(self.gt_types) == 1 and len(self.pred_types) == 1:
                return 1.0 if self.gt_types[0] == self.pred_types[0] else 0.0
            return 0.0

    # ------------------------------------------------------------------
    # three flavours of frame-error-rate
    # ------------------------------------------------------------------
    def _fer_generic(
        self,
        use_mask: Optional[np.ndarray] = None
    ) -> float:
        """Frame error rate (%) over all frames, or over the frames in ``use_mask``.

        A frame is correct when its GT label is mapped and the prediction equals
        the mapped label. Frames whose GT label is unmapped always count as
        errors. Returns 0.0 when no frames are scored.
        """
        if self.gt.size == 0:
            return 0.0

        if use_mask is None:
            use_mask = np.ones_like(self.gt, dtype=bool)

        masked_gt = self.gt[use_mask]
        masked_pred = self.pred[use_mask]

        if masked_gt.size == 0:
            return 0.0

        # Each GT label maps to at most one prediction, so these matches are disjoint.
        correct = np.zeros(masked_gt.shape, dtype=bool)
        for g, p in self.mapping.items():
            correct |= (masked_gt == g) & (masked_pred == p)
        return float(100.0 * (1.0 - np.count_nonzero(correct) / masked_gt.size))

    def total_fer(self) -> float:
        """Error rate across *all* frames."""
        return self._fer_generic()

    def matched_fer(self) -> float:
        """Error rate restricted to frames whose GT label is mapped."""
        mapped_mask = np.isin(self.gt, list(self.mapping.keys()))
        return self._fer_generic(use_mask=mapped_mask)

    # keep old name as alias for backwards compat
    def frame_error_rate(self) -> float:
        return self.total_fer()

    def macro_fer(self) -> float:
        """Mean over GT label types of each type's error rate (%); unmapped types count as 100%."""
        if not self.gt_types.size:
            return 0.0

        per_type_fer = []
        for gt_label_type in self.gt_types:
            type_mask = (self.gt == gt_label_type)
            if not np.any(type_mask):
                continue
            if gt_label_type not in self.mapping:
                per_type_fer.append(1.0)  # every frame of an unmapped type is an error
                continue
            mapped_pred_label = self.mapping[gt_label_type]
            errors_for_type = np.sum(self.pred[type_mask] != mapped_pred_label)
            total_for_type = np.sum(type_mask)
            per_type_fer.append(errors_for_type / total_for_type if total_for_type > 0 else 0.0)

        return 100.0 * np.mean(per_type_fer) if per_type_fer else 0.0

    def stats(self) -> Dict[str, Any]:
        """Mapping coverage: % of GT types / frames mapped, plus per-label frame counts."""
        if self.gt.size == 0:
            return dict(
                pct_types_mapped=0, pct_frames_mapped=0,
                mapped_counts={}, unmapped_counts={},
            )
        counts = {g: (self.gt == g).sum() for g in self.gt_types}
        mapped_gt_types = set(self.mapping.keys())

        mapped_frames = 0
        for gt_label in mapped_gt_types:
            if gt_label in counts:
                mapped_frames += counts[gt_label]

        total_frames = self.gt.size

        return dict(
            pct_types_mapped=100 * len(mapped_gt_types) / len(self.gt_types) if self.gt_types.size else 0,
            pct_frames_mapped=100 * mapped_frames / total_frames if total_frames else 0,
            mapped_counts={k: v for k, v in counts.items() if k in mapped_gt_types},
            unmapped_counts={k: v for k, v in counts.items() if k not in mapped_gt_types},
        )

    @staticmethod
    def _category_percentages(data: Dict[Any, Any]) -> Tuple[tuple, np.ndarray]:
        """Return (labels sorted, each label's share of the summed counts in %).

        Counts may be Python or NumPy numbers. The old ``isinstance(v, (int, float))``
        filter rejected ``np.int64`` (what ``stats()`` produces), so every bar
        was drawn at 0%.
        """
        if not data:
            return (), np.zeros(0, dtype=float)
        labels, vals = zip(*sorted(data.items()))
        counts = np.asarray(vals, dtype=float)
        total = counts.sum()
        if total == 0:
            return labels, np.zeros(len(labels), dtype=float)
        return labels, 100.0 * counts / total

    def plot(self, title: str = "Clustering Evaluation", figsize=(18, 10)) -> "plt.Figure":
        """Dashboard: metric summary, mapping-coverage pies, and per-label bar charts."""
        st = self.stats()
        fig = plt.figure(figsize=figsize)
        gs = GridSpec(2, 3, figure=fig, height_ratios=[1, 1.2])

        def _annot_bar(ax, data, color, ttl):
            if not data:
                ax.text(0.5, 0.5, "No data to display", ha="center", va="center")
                ax.set_axis_off()
                return

            labels, perc = self._category_percentages(data)

            bars = ax.bar(range(len(labels)), perc, color=color)
            ax.set_xticks(range(len(labels)))
            ax.set_xticklabels([str(l) for l in labels], rotation=45, ha="right", fontsize=8)

            for bar_obj, p_val in zip(bars, perc):
                if p_val > 0.5:  # skip labels on near-empty bars
                    ax.text(bar_obj.get_x() + bar_obj.get_width() / 2,
                            p_val + 0.3,
                            f"{p_val:.1f}%",
                            ha="center", va="bottom", fontsize=7)
            ax.set_title(ttl)
            ax.set_ylabel("% frames within this category")
            ax.tick_params(axis="x", rotation=45, labelsize=8)
            ax.set_ylim(0, max(10, np.max(perc) * 1.1 if len(perc) > 0 and np.max(perc) > 0 else 10))

        # Summary box
        ax0 = fig.add_subplot(gs[0, 0]); ax0.axis("off")
        txt = (
            f"Overall FER : {self.frame_error_rate():.1f}%\n"
            f"Macro FER   : {self.macro_fer():.1f}%\n"
            f"V‑measure   : {self.v_measure():.3f}\n\n"
            f"GT types    : {len(self.gt_types)}\n"
            f"Pred types  : {len(self.pred_types)}\n"
            f"Mapped GT types: {st['pct_types_mapped']:.1f}%"
        )
        ax0.text(0.05, 0.95, txt, va="top", ha="left", fontsize=10, bbox=dict(fc="whitesmoke", alpha=.8, boxstyle="round,pad=0.5"))

        # Pies for types and frames mapped
        ax1 = fig.add_subplot(gs[0, 1])
        if self.gt_types.size > 0:
            ax1.pie([st['pct_types_mapped'], 100 - st['pct_types_mapped']], labels=["Mapped", "Unmapped"],
                    autopct="%.1f%%", colors=["#8fd175", "#f28e8e"], startangle=90)
        else:
            ax1.text(0.5, 0.5, "No GT types", ha="center", va="center")
        ax1.set_title("GT Label Types")

        ax2 = fig.add_subplot(gs[0, 2])
        if self.gt.size > 0:
            ax2.pie([st['pct_frames_mapped'], 100 - st['pct_frames_mapped']], labels=["Mapped", "Unmapped"],
                    autopct="%.1f%%", colors=["#71b3ff", "#ffb471"], startangle=90)
        else:
            ax2.text(0.5, 0.5, "No GT frames", ha="center", va="center")
        ax2.set_title("GT Frames")

        # Bars for mapped and unmapped counts
        ax3 = fig.add_subplot(gs[1, 0:2])  # spans two columns
        _annot_bar(ax3, st['mapped_counts'], "#4daf4a", "Mapped GT Labels (% of Mapped Frames)")

        ax4 = fig.add_subplot(gs[1, 2])
        _annot_bar(ax4, st['unmapped_counts'], "#d73027", "Unmapped GT Labels (% of Unmapped Frames)")

        fig.suptitle(title, fontsize=16, y=0.98)
        fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the suptitle
        return fig

# ─────────────────────────────────────────────────────────────────────────────
# Smoothing Helper Functions
# ─────────────────────────────────────────────────────────────────────────────
def basic_majority_vote(labels: np.ndarray, window_size: int) -> np.ndarray:
    """Smooth a 1-D label sequence with a sliding-window majority vote.

    Position ``i`` takes the most frequent label in
    ``labels[i - window_size//2 : i + window_size//2 + 1]`` (clipped at the
    edges). For an even ``window_size`` the window therefore spans
    ``window_size + 1`` frames. On a tie for first place, the position keeps
    its own label if it is among the tied labels. Otherwise the tied label
    encountered first in the window wins.

    Returns a new array (``labels`` is not modified). ``window_size <= 1`` or
    empty input returns an unchanged copy.
    """
    if window_size <= 1 or len(labels) == 0:
        return labels.copy()

    n = len(labels)
    half_window = window_size // 2
    smoothed_labels = np.copy(labels)
    for i in range(n):
        window = labels[max(0, i - half_window):min(n, i + half_window + 1)]
        counts = Counter(window)
        best_label, best_count = counts.most_common(1)[0]
        # Compare against the top count directly. Only inspecting the top two
        # entries missed 3+-way ties and replaced a label that was actually tied.
        smoothed_labels[i] = labels[i] if counts[labels[i]] == best_count else best_label

    return smoothed_labels

def smooth_labels_per_sequence(
    raw_labels: np.ndarray, 
    dataset_indices: np.ndarray, 
    window_size: int
) -> np.ndarray:
    """Majority-vote smooth labels without letting a window cross sequences.

    Frames that share a value in ``dataset_indices`` are gathered (via a
    boolean mask) and smoothed together with ``basic_majority_vote``. Each
    distinct value is handled independently of the others.

    Args:
        raw_labels: Per-frame labels.
        dataset_indices: Sequence id per frame. It is compared element-wise
            and used as a boolean mask on ``raw_labels``, so it presumably
            has the same length (TODO confirm with callers).
        window_size: Majority-vote window. Values <= 1 return an unchanged copy.

    Returns:
        A new array shaped like ``raw_labels``. It is initialized with zeros
        and filled sequence by sequence.
    """
    if window_size <= 1:
        return raw_labels.copy()

    result = np.zeros_like(raw_labels)
    for sequence_id in np.unique(dataset_indices):
        in_sequence = dataset_indices == sequence_id
        segment = raw_labels[in_sequence]
        if not len(segment):
            # Nothing selected for this id; leave those positions untouched.
            continue
        result[in_sequence] = basic_majority_vote(segment, window_size)
    return result

# ─────────────────────────────────────────────────────────────────────────────
# Main Evaluation Function (Simplified)
# ─────────────────────────────────────────────────────────────────────────────
def _smooth_for_evaluation(
    predicted_labels_raw: np.ndarray,
    dataset_indices: np.ndarray,
    window_size: int,
) -> np.ndarray:
    """Return predictions smoothed with ``window_size``; 0 means raw predictions."""
    if window_size == 0:
        print("Using raw predictions (no smoothing).")
        return predicted_labels_raw.copy()
    if dataset_indices.size == 0:
        print("Warning: dataset_indices is empty. Applying smoothing globally instead of per-sequence.")
        return basic_majority_vote(predicted_labels_raw, window_size)
    if len(np.unique(dataset_indices)) == 1 and len(dataset_indices) == len(predicted_labels_raw):
        print(f"Applying smoothing with window {window_size} to the whole sequence (single dataset index).")
        return basic_majority_vote(predicted_labels_raw, window_size)
    print(f"Applying per-sequence smoothing with window {window_size}.")
    return smooth_labels_per_sequence(predicted_labels_raw, dataset_indices, window_size)


def _write_metric_summary(f, cm: "ClusteringMetrics", stats_data: dict, detailed: bool) -> None:
    """Write score, label-type and mapped/unmapped count lines for one window.

    ``detailed=True`` produces the per-window report layout, with horizontal
    rules between sections and a blank line before the unmapped section.
    ``detailed=False`` produces the compact layout used in the combined dump.
    """
    rule = "-------------------------------------------------\n" if detailed else ""
    f.write(rule)
    f.write(f"V-measure Score          : {cm.v_measure():.4f}\n")
    f.write(f"Total FER               : {cm.total_fer():.2f}%\n")
    f.write(f"Matched-only FER        : {cm.matched_fer():.2f}%\n")
    f.write(f"Macro Frame Error Rate   : {cm.macro_fer():.2f}%\n")
    f.write(rule)
    f.write(f"GT Label Types           : {len(cm.gt_types)}\n")
    f.write(f"Predicted Label Types    : {len(cm.pred_types)}\n")
    f.write(f"% GT Types Mapped        : {stats_data['pct_types_mapped']:.2f}%\n")
    f.write(f"% GT Frames Mapped       : {stats_data['pct_frames_mapped']:.2f}%\n")
    f.write(rule)
    f.write("Mapped GT Label Counts:\n")
    for label, count in sorted(stats_data['mapped_counts'].items()):
        f.write(f"  Label {label}: {count}\n")
    f.write(("\n" if detailed else "") + "Unmapped GT Label Counts:\n")
    for label, count in sorted(stats_data['unmapped_counts'].items()):
        f.write(f"  Label {label}: {count}\n")


def evaluate_predictions_with_smoothing(
    gt_labels: np.ndarray,              # Ground truth labels
    predicted_labels_raw: np.ndarray, # Raw predicted labels (e.g., from decoder)
    dataset_indices: np.ndarray,      # Array indicating sequence/song boundaries for smoothing
    output_dir: str,                  # Directory to save reports
    smoothing_windows: "List[int] | None" = None, # Window sizes for smoothing; None -> [0, 200]
    base_title: str = "Decoder"       # Base for plot titles
) -> Dict[int, ClusteringMetrics]:
    """Evaluate predicted labels against ground truth for several smoothing windows.

    For every window size this function:
      * smooths the predictions (0 = raw; per-sequence when ``dataset_indices``
        holds more than one id, otherwise over the whole array),
      * builds a ``ClusteringMetrics`` object,
      * writes ``metrics_dashboard.png`` and ``summary_metrics.txt`` into
        ``<output_dir>/<base_title>_window_<size>/``.
    It then writes ``all_metrics_dump.txt`` covering all windows to ``output_dir``.

    Args:
        gt_labels: Ground-truth labels.
        predicted_labels_raw: Unsmoothed predicted labels.
        dataset_indices: Sequence id per frame, used to keep smoothing within
            sequences. An empty array means smoothing is applied globally.
        output_dir: Directory for reports. It is created if missing.
        smoothing_windows: Window sizes to evaluate. Defaults to ``[0, 200]``.
        base_title: Prefix for plot titles and report directory names.

    Returns:
        Mapping from window size to its ``ClusteringMetrics``.
    """
    if smoothing_windows is None:
        smoothing_windows = [0, 200]

    output_path = pathlib.Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    metrics_results_per_window: Dict[int, ClusteringMetrics] = {}
    # cm.stats() is computed once per window and reused for the final dump.
    stats_per_window: Dict[int, dict] = {}

    print(f"Starting evaluation for base title: {base_title}")
    print(f"Ground truth labels shape: {gt_labels.shape}")
    print(f"Raw predicted labels shape: {predicted_labels_raw.shape}")
    print(f"Dataset indices shape: {dataset_indices.shape}")

    for window_size in smoothing_windows:
        print(f"\n--- Evaluating with smoothing window: {window_size} ---")

        smoothed_predictions = _smooth_for_evaluation(predicted_labels_raw, dataset_indices, window_size)
        print(f"Shape of smoothed predictions: {smoothed_predictions.shape}")

        cm = ClusteringMetrics(gt_labels, smoothed_predictions)
        metrics_results_per_window[window_size] = cm

        # --- Generate and Save Report for this window ---
        window_report_dir = output_path / f"{base_title.replace(' ', '_')}_window_{window_size}"
        window_report_dir.mkdir(parents=True, exist_ok=True)

        # Dashboard plot. Always close the figure, even if saving fails, so
        # figures do not accumulate across windows.
        plot_title = f"{base_title} Evaluation (Smoothing Window: {window_size})"
        fig = cm.plot(title=plot_title)
        try:
            fig.savefig(window_report_dir / "metrics_dashboard.png", dpi=150, bbox_inches="tight")
        finally:
            plt.close(fig)

        stats_data = cm.stats()
        stats_per_window[window_size] = stats_data

        # Text summary
        with open(window_report_dir / "summary_metrics.txt", "w") as f:
            f.write(f"Metrics Summary for: {base_title}\n")
            f.write(f"Smoothing Window Size: {window_size}\n")
            _write_metric_summary(f, cm, stats_data, detailed=True)

        print(f"Saved reports for window {window_size} to: {window_report_dir}")
        # total_fer() is the FER accessor used everywhere else in this module.
        print(f"  V-measure: {cm.v_measure():.3f}, Total FER: {cm.total_fer():.1f}%, Macro FER: {cm.macro_fer():.1f}%")

    # Dump all metrics to a single txt file at the end
    all_metrics_txt = output_path / "all_metrics_dump.txt"
    with open(all_metrics_txt, "w") as f:
        for window_size, cm in metrics_results_per_window.items():
            f.write(f"==== Window Size: {window_size} ====\n")
            _write_metric_summary(f, cm, stats_per_window[window_size], detailed=False)
            f.write("\n")
    print(f"Dumped all metrics to {all_metrics_txt}")

    return metrics_results_per_window

import pathlib

if __name__ == "__main__":
    # Evaluate every decoder result (.npz) in results_dir at several smoothing
    # windows, print the metrics, and dump them all to one text file.
    results_dir = pathlib.Path("/home/george-vengrovski/Documents/projects/tweety_bert_paper/results/decoder_eval")
    smoothing_windows = [0, 50, 100]

    all_metrics = []
    # iterdir() order is filesystem-dependent; sort for reproducible reports.
    for results_file in sorted(results_dir.iterdir()):
        if results_file.suffix != ".npz":
            continue
        print(f"\nprocessing {results_file.name}")
        try:
            # The context manager closes the NpzFile handle. Arrays fetched
            # with .get() are fully read before it closes.
            with np.load(results_file) as data:
                gt = data.get("ground_truth_labels")
                pred = data.get("decoder_labels")
        except Exception as e:
            print(f"  failed to load {results_file.name}: {e}")
            continue
        if gt is None or pred is None:
            print("  missing ground_truth_labels or decoder_labels, skipping")
            continue

        for w in smoothing_windows:
            # basic_majority_vote returns an unsmoothed copy for w <= 1 (incl. 0).
            # NOTE(review): no dataset_indices are used here, so windows may span
            # song boundaries if `pred` concatenates several songs — confirm.
            sm = basic_majority_vote(pred, w)
            cm = ClusteringMetrics(gt, sm)
            # Compute each metric once and reuse it for printing and the dump.
            entry = {
                "file": results_file.name,
                "window": w,
                "v_measure": cm.v_measure(),
                "total_fer": cm.total_fer(),
                "matched_fer": cm.matched_fer(),
                "macro_fer": cm.macro_fer(),
            }
            print(
                f"  window {w}: "
                f"v-measure={entry['v_measure']:.3f}, "
                f"total_fer={entry['total_fer']:.2f}%, "
                f"matched_fer={entry['matched_fer']:.2f}%, "
                f"macro_fer={entry['macro_fer']:.2f}%"
            )
            all_metrics.append(entry)

    # Dump all metrics to a txt file
    metrics_txt_path = results_dir / "all_metrics_dump.txt"
    with open(metrics_txt_path, "w") as f:
        for entry in all_metrics:
            f.write(
                f"File: {entry['file']}, Window: {entry['window']}\n"
                f"  v-measure={entry['v_measure']:.3f}, "
                f"total_fer={entry['total_fer']:.2f}%, "
                f"matched_fer={entry['matched_fer']:.2f}%, "
                f"macro_fer={entry['macro_fer']:.2f}%\n"
            )
    print(f"Dumped all metrics to {metrics_txt_path}")



processing llb3_fold8.npz__llb3_fold3.npz.results.npz
  window 0: v-measure=0.848, total_fer=7.38%, matched_fer=6.57%, macro_fer=31.93%
  window 50: v-measure=0.859, total_fer=6.93%, matched_fer=6.11%, macro_fer=31.41%
  window 100: v-measure=0.859, total_fer=6.64%, matched_fer=5.82%, macro_fer=31.08%

processing llb11_fold4.npz__llb11_fold7.npz.results.npz
  window 0: v-measure=0.858, total_fer=14.87%, matched_fer=14.87%, macro_fer=20.43%
  window 50: v-measure=0.874, total_fer=14.88%, matched_fer=14.88%, macro_fer=19.12%
  window 100: v-measure=0.872, total_fer=13.31%, matched_fer=13.31%, macro_fer=17.66%

processing llb3_fold4.npz__llb3_fold9.npz.results.npz
  window 0: v-measure=0.835, total_fer=8.83%, matched_fer=8.74%, macro_fer=35.28%
  window 50: v-measure=0.846, total_fer=8.37%, matched_fer=8.28%, macro_fer=34.65%
  window 100: v-measure=0.846, total_fer=8.07%, matched_fer=7.98%, macro_fer=34.04%

processing llb3_fold3.npz__llb3_fold2.npz.results.npz
  window 0: v-measure=0.8