## Copy Data Into Folds

In [None]:
import os
import json
import random
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

mins_per_fold = 50  # target minutes of song per fold
fold_data_dir = "/media/george-vengrovski/disk1/decoder_data"

birds_wav_paths = [
    "/media/george-vengrovski/disk2/canary/yarden_data/llb3_data/llb3_songs",
    "/media/george-vengrovski/disk2/canary/yarden_data/llb11_data/llb11_songs",
    "/media/george-vengrovski/disk2/canary/yarden_data/llb16_data/llb16_songs"
]

song_detection_json = "files/contains_llb.json"

# filename -> full path of every ".wav" file in the bird directories above.
# Directories that do not exist are skipped; a later directory wins if the
# same filename occurs twice.
wav_file_to_path = {
    fname: os.path.join(bird_path, fname)
    for bird_path in birds_wav_paths
    if os.path.isdir(bird_path)
    for fname in os.listdir(bird_path)
    if fname.endswith('.wav')
}

# Collect, per bird, every file flagged as containing song that also exists on
# disk, paired with its total song time in seconds (sum over segments of
# offset_ms - onset_ms). Files with no positive song time are dropped.
with open(song_detection_json, 'r') as f:
    data = json.load(f)

bird_song_files = {}  # bird_id -> list of (filename, duration_seconds)
for entry in data:
    if not entry.get('song_present', False):
        continue
    filename = entry['filename']
    if filename not in wav_file_to_path:
        continue
    total_duration = sum(
        ((seg.get('offset_ms', 0) - seg.get('onset_ms', 0)) / 1000.0
         for seg in entry.get('segments', [])),
        0.0,
    )
    if total_duration <= 0:
        continue
    # The bird id is the filename prefix before the first underscore.
    bird_id = filename.split('_')[0]
    bird_song_files.setdefault(bird_id, []).append((filename, total_duration))

# For each bird, randomly split its files into consecutive folds. Files are
# added to the current fold until it holds at least mins_per_fold minutes of
# song, then a new fold starts. Leftover files form a final fold that may be
# SHORTER than mins_per_fold.
#
# The shuffle is seeded (per bird) so re-running this cell reproduces the same
# assignment. This matters because the copy step skips destination files that
# already exist. With an unseeded shuffle, a re-run would copy recordings into
# different folds while the previous run's copies stay in place, leaking the
# same recording into several folds. Changing fold_seed therefore requires
# clearing fold_data_dir first.
fold_seed = 0
min_fold_seconds = mins_per_fold * 60

folds_info = {}  # bird_id -> list of folds, each fold is list of (filename, duration)
for bird_id, files in bird_song_files.items():
    # Seeding with a string is deterministic across interpreter runs. Keying by
    # bird keeps each bird's split independent of the other birds' order.
    fold_rng = random.Random(f"{fold_seed}:{bird_id}")
    # Sort first so the split depends only on the seed, not on JSON order.
    shuffled = sorted(files)
    fold_rng.shuffle(shuffled)

    folds = []
    current_fold = []
    current_fold_duration = 0.0
    for fname, dur in shuffled:
        current_fold.append((fname, dur))
        current_fold_duration += dur
        if current_fold_duration >= min_fold_seconds:
            folds.append(current_fold)
            current_fold = []
            current_fold_duration = 0.0
    if current_fold:  # leftover files: final, possibly short, fold
        folds.append(current_fold)
    folds_info[bird_id] = folds

# Materialise the folds on disk as <fold_data_dir>/<bird_id>/fold<k>/<file>.
# A destination file that already exists is not copied again.
for bird_id, folds in folds_info.items():
    for fold_no, fold in enumerate(folds, start=1):
        fold_dir = os.path.join(fold_data_dir, bird_id, f"fold{fold_no}")
        os.makedirs(fold_dir, exist_ok=True)
        print(f"Copying files for {bird_id} fold {fold_no} ({len(fold)} files)...")
        progress = tqdm(fold, desc=f"{bird_id} fold{fold_no}", leave=False)
        for fname, _ in progress:
            src, dst = wav_file_to_path[fname], os.path.join(fold_dir, fname)
            if not os.path.exists(dst):
                shutil.copy2(src, dst)

# Flatten folds_info into (bird id, fold name, minutes of song) rows, then
# draw one labelled bar per bird/fold.
fold_rows = [
    (bird_id, f"fold{i+1}", sum(dur for _, dur in fold) / 60)
    for bird_id, folds in folds_info.items()
    for i, fold in enumerate(folds)
]
plot_bird_ids = [bird for bird, _, _ in fold_rows]
plot_fold_names = [name for _, name, _ in fold_rows]
plot_fold_minutes = [minutes for _, _, minutes in fold_rows]

plt.figure(figsize=(12, 6))
sns.set_style("whitegrid")
bar_labels = [f"{bird}-{fold}" for bird, fold in zip(plot_bird_ids, plot_fold_names)]
bars = plt.bar(bar_labels, plot_fold_minutes, color='skyblue')
for rect in bars:
    top = rect.get_height()
    x_center = rect.get_x() + rect.get_width() / 2.0
    # Annotate each bar with its duration, just above the bar top.
    plt.text(x_center, top, f'{top:.1f}', ha='center', va='bottom')
plt.title('Minutes of Song Data per Fold (per Bird)', fontsize=14, pad=20)
plt.xlabel('Bird-Fold', fontsize=12)
plt.ylabel('Total Duration (minutes)', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()


## We want to create embeddings for each fold

In [None]:
import sys
import importlib

import os
# Select the GNU threading layer for MKL.
# NOTE(review): this presumably only takes effect if set before any MKL-backed
# library is first loaded in this kernel. The previous cell already imported
# matplotlib/seaborn, which pull in numpy, so confirm it is still honoured here.
# Otherwise move it to the top of the first cell.
os.environ.update(MKL_THREADING_LAYER="GNU")

# Load the project's decoding module (decoding.py) by name.
decoding_module = importlib.import_module(name="decoding")

class Args:
    """Attribute bag passed to ``decoding.main`` in place of parsed CLI args.

    Presumably mirrors the command-line flags of decoding.py (the
    ``raw_spectrogram_umap`` default stands in for a store_true flag). The six
    positional parameters are required. The remaining settings keep the
    notebook's previous hard-coded defaults and can now be overridden by
    keyword.
    """

    def __init__(self, mode, bird_name, model_name, wav_folder,
                 song_detection_json_path, num_samples_umap, *,
                 num_random_files_spec=1000,
                 single_threaded_spec=False,
                 nfft=1024,
                 raw_spectrogram_umap=False,
                 state_finding_algorithm_umap="HDBSCAN",
                 context_umap=1000):
        self.mode = mode
        self.bird_name = bird_name
        self.model_name = model_name
        self.wav_folder = wav_folder
        self.song_detection_json_path = song_detection_json_path
        self.num_samples_umap = num_samples_umap
        self.num_random_files_spec = num_random_files_spec
        self.single_threaded_spec = single_threaded_spec
        self.nfft = nfft
        self.raw_spectrogram_umap = raw_spectrogram_umap  # store_true-style flag
        self.state_finding_algorithm_umap = state_finding_algorithm_umap
        self.context_umap = context_umap

# Run decoding.main once for every directory under fold_data_dir whose name
# contains "fold". The parent directory's name is used as the bird id.
# Loop names avoid `dir`: in a notebook this loop runs at global scope, so the
# old name shadowed the dir() builtin for the rest of the session.
for walk_root, subdirs, _ in os.walk(fold_data_dir):
    # Sorting in place makes os.walk visit (and process) folders in a stable order.
    subdirs.sort()
    for dir_name in subdirs:
        if "fold" not in dir_name:
            continue
        bird = os.path.basename(walk_root)
        bird_name_fold = f"{bird}_{dir_name}"
        wav_folder = os.path.join(walk_root, dir_name)
        args = Args(
            mode="single",
            bird_name=bird_name_fold,
            model_name="BF_Canary_Joint_Run",
            wav_folder=wav_folder,
            song_detection_json_path=song_detection_json,
            num_samples_umap="1e6"
        )
        # Echo the equivalent CLI invocation, built from the args actually passed.
        print(f"Running decoding.py --mode {args.mode} --bird_name {args.bird_name} "
              f"--model_name {args.model_name} --wav_folder {args.wav_folder} "
              f"--song_detection_json_path {args.song_detection_json_path} "
              f"--num_samples_umap {args.num_samples_umap}")
        decoding_module.main(args)


## Evaluate TweetyBERT, Parametric UMAP, and Load-and-Transform UMAP on Folds

In [None]:
import sys, pathlib, os, shutil, time
from collections import defaultdict
from typing import List, Dict

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import umap
from umap.parametric_umap import ParametricUMAP # Keep for now, but logic will be commented
import hdbscan
from hdbscan import approximate_predict

sys.path.insert(0, str(pathlib.Path("src").resolve()))
from src.decoder import TweetyBertClassifier, SongDataSet_Image, CollateFunction

# ── I/O paths ────────────────────────────────────────────────────────────────
# All paths are anchored at the current working directory:
# fold .npz inputs live in ./files, benchmark outputs go to ./results/decoder_eval.
root = pathlib.Path().resolve()
npz_dir = root.joinpath("files")
save_dir = root.joinpath("results", "decoder_eval")
save_dir.mkdir(parents=True, exist_ok=True)

# Fold files to benchmark (looked up under npz_dir). Every array read from them
# is truncated to MAX_FRAMES frames. List each fold only once: a duplicate is
# fitted twice and adds duplicate timing rows to the CSV.
raw_files = [
    "llb3_fold1.npz",
    "llb3_fold2.npz",
    "llb3_fold8.npz",
    "llb3_fold4.npz",
    "llb3_fold3.npz",
    "llb3_fold7.npz",
    "llb3_fold9.npz",
    "llb3_fold5.npz",
    "llb3_fold6.npz",
    "llb16_fold2.npz",
    "llb16_fold4.npz",
    "llb16_fold1.npz",
    "llb16_fold3.npz",
    "llb16_fold5.npz",
    "llb11_fold2.npz",
    "llb11_fold8.npz",
    "llb11_fold4.npz",
    "llb11_fold1.npz",
    "llb11_fold3.npz",
    "llb11_fold7.npz",
    "llb11_fold5.npz",
    "llb11_fold6.npz",
]
MAX_FRAMES = 1_000_000

# ── helpers ──────────────────────────────────────────────────────────────────
def group_by_bird(fnames: List[str]) -> Dict[str, List[str]]:
    """Group fold file names by bird id (the text before the first "_").

    Each bird's list is sorted and free of duplicates. A repeated name would
    otherwise be fitted twice and evaluated twice, adding redundant timing rows.
    Birds appear in the order they are first seen in *fnames*.
    """
    groups: Dict[str, set] = defaultdict(set)
    for name in fnames:
        groups[name.split("_")[0]].add(name)
    return {bird: sorted(names) for bird, names in groups.items()}

def load_states_labels(fp: pathlib.Path):
    with np.load(fp) as f:
        # Return embeddings, spectrograms, hdbscan_labels from file, ground_truth_labels from file
        # Use f.get for 's' to be robust if 's' key is missing, though it should be present.
        return (f["predictions"][:MAX_FRAMES], # Embeddings
                f['s'][:MAX_FRAMES],    # Spectrograms
                f["hdbscan_labels"][:MAX_FRAMES],       # HDBSCAN labels from the NPZ file
                f["ground_truth_labels"][:MAX_FRAMES])  # Ground truth labels from the NPZ file

def time_and_keep(reducer, X):
    """Apply ``reducer.transform`` to *X*; return ``(result, elapsed_seconds)``."""
    start = time.perf_counter()
    transformed = reducer.transform(X)
    elapsed = time.perf_counter() - start
    return transformed, elapsed

def time_decoder_forward(model, loader, dev):
    """Return wall-clock seconds for one no-grad forward pass over *loader*.

    *model* is switched to eval mode. Each batch may be a dict holding a
    ``"spec"`` tensor or a sequence whose first element is the spectrogram
    tensor. Outputs are discarded.

    On CUDA devices the device is synchronised before starting and before
    stopping the clock. Kernel launches are asynchronous, so without this the
    measurement would miss work still queued on the GPU.
    """
    model.eval()
    on_cuda = str(dev).startswith("cuda")
    if on_cuda:
        torch.cuda.synchronize(dev)
    t0 = time.perf_counter()
    with torch.no_grad():
        for batch in loader:
            spec = batch["spec"] if isinstance(batch, dict) else batch[0]
            _ = model(spec.to(dev))
    if on_cuda:
        torch.cuda.synchronize(dev)
    return time.perf_counter() - t0

# ── benchmark ────────────────────────────────────────────────────────────────
# Number of zero bins prepended to the frequency axis of every evaluation
# spectrogram segment (the value previously hard-coded inline).
# NOTE(review): presumably matches the padding TweetyBertClassifier applies to
# its own training data in prepare_data — confirm against src/decoder.py.
EVAL_FREQ_PAD = 20


def _append_timing_row(csv_path: pathlib.Path, row: dict, header: bool) -> bool:
    """Append *row* to the timings CSV; return the header flag for the next write."""
    pd.DataFrame([row]).to_csv(csv_path, mode="a", header=header, index=False)
    return False


def _decode_frames(dec_model, num_classes: int, specs_eval: np.ndarray,
                   n_frames: int, ctx: int, dev: str,
                   tmp_eval_dir: pathlib.Path):
    """Predict one decoder class per frame of *specs_eval*.

    The first *n_frames* rows of *specs_eval* are cut into ctx-frame segments.
    Each segment gets EVAL_FREQ_PAD zero bins prepended on axis 1 and the last
    one is zero-padded to ctx rows. Segments are written transposed into
    *tmp_eval_dir* (the layout SongDataSet_Image reads, mirroring
    TweetyBertClassifier._save_data) and run through *dec_model* one at a time.

    Returns ``(predictions truncated to n_frames, seconds spent in the
    inference loop)``. Returns ``(empty array, 0.0)`` when n_frames is 0.
    *tmp_eval_dir* is recreated on entry and removed on exit.
    """
    shutil.rmtree(tmp_eval_dir, ignore_errors=True)
    tmp_eval_dir.mkdir(parents=True, exist_ok=True)
    try:
        if n_frames == 0:
            return np.array([]), 0.0

        seg_id = 0
        for start_idx in range(0, n_frames, ctx):
            seg_len = min(ctx, n_frames - start_idx)
            segment = specs_eval[start_idx:start_idx + seg_len, :]
            segment = np.pad(segment, ((0, 0), (EVAL_FREQ_PAD, 0)),
                             mode="constant", constant_values=0)
            if seg_len < ctx:  # last, short segment: pad to full context length
                segment = np.pad(segment, ((0, ctx - seg_len), (0, 0)),
                                 mode="constant", constant_values=0)
            # Labels/vocalization are dummies the loader requires.
            np.savez(tmp_eval_dir / f"{seg_id}.npz",
                     labels=np.zeros(ctx, dtype=np.int64),
                     s=segment.T,
                     vocalization=np.zeros(ctx, dtype=np.int8))
            seg_id += 1

        eval_set = SongDataSet_Image(tmp_eval_dir,
                                     num_classes=num_classes,  # classes seen when fitting
                                     segment_length=ctx,
                                     infinite_loader=False)
        eval_loader = DataLoader(eval_set, batch_size=1, shuffle=False,  # keep segment order
                                 collate_fn=CollateFunction(segment_length=ctx))

        preds_list = []
        dec_model.eval()
        t0 = time.perf_counter()
        with torch.no_grad():
            for batch in eval_loader:
                spec_tensor = (batch["spec"] if isinstance(batch, dict) else batch[0]).to(dev)
                logits = dec_model(spec_tensor)  # (B, S, C) per the original code comment
                # .cpu() forces any pending GPU work to finish inside the timed loop.
                preds_list.append(torch.argmax(logits, dim=2).cpu().numpy())
        elapsed = time.perf_counter() - t0

        if not preds_list:
            return np.array([]), elapsed
        # batch_size=1: drop the batch axis, stitch segments, trim the padding.
        pred_full = np.concatenate([p.squeeze(0) for p in preds_list])
        return pred_full[:n_frames], elapsed
    finally:
        shutil.rmtree(tmp_eval_dir, ignore_errors=True)


def benchmark_bird(bird_id: str,
                   fold_files: List[str],
                   ctx: int = 1_000,
                   pumap_epochs: int = 200,          # fewer epochs for test
                   dev: str = ("cuda" if torch.cuda.is_available() else "cpu"),
                   model_dir: str | None = None) -> None:
    """Fit on each fold of *bird_id* and evaluate on every other fold.

    For each fit fold, a TweetyBERT linear decoder is trained from
    *model_dir*. Its per-frame predictions on each other fold are saved to
    ``save_dir/<fit>__<eval>.results.npz`` together with the eval fold's
    ground-truth labels. A row of seconds-per-frame timings is appended to
    ``save_dir/timings.csv``. The classic / parametric UMAP + HDBSCAN paths
    are currently disabled; they store empty arrays and zero timings.

    A fit/eval pair whose results file already exists is not recomputed. Its
    CSV row gets NaN timings, so re-runs do not pull the per-bird means toward
    zero. A fit fold whose pairs are all done skips decoder training entirely.

    Raises:
        ValueError: if *model_dir* is None (checked before any work is done).
    """
    if model_dir is None:
        raise ValueError("need model_dir")

    fold_paths = [npz_dir / f for f in fold_files]
    csv_path = save_dir / "timings.csv"
    header = not csv_path.exists()          # write header only once
    nan = float("nan")

    for fit_path in fold_paths:
        fit_name = fit_path.name

        # Other folds (each once, first occurrence wins) and which still need work.
        eval_paths = list({p.name: p for p in fold_paths if p.name != fit_name}.values())
        pending = []
        for eval_path in eval_paths:
            results_npz_path = save_dir / f"{fit_name}__{eval_path.name}.results.npz"
            if results_npz_path.exists():
                print(f"Skipping {results_npz_path} (already exists)")
                header = _append_timing_row(csv_path, {
                    "bird": bird_id, "fit_fold": fit_name, "eval_fold": eval_path.name,
                    "umap_s_per_row": nan,
                    "pumap_s_per_row": nan,
                    "decoder_s_per_row": nan,
                    "results_path": str(results_npz_path),
                }, header)
            else:
                pending.append((eval_path, results_npz_path))
        if not pending:
            continue  # nothing to evaluate: skip loading and decoder training

        # 1) classic UMAP + HDBSCAN(predict) and 2) parametric UMAP +
        #    HDBSCAN(predict) are disabled via these placeholders. To re-enable,
        #    load the fit-fold embeddings and replace the placeholders with:
        umap_red = hdb_u = pumap = hdb_p = None
        # X_fit, _, _, _ = load_states_labels(fit_path)
        # umap_red = umap.UMAP(n_neighbors=30, min_dist=0., n_components=6,
        #                      metric="euclidean", low_memory=True,
        #                      n_jobs=-1).fit(X_fit)
        # hdb_u = hdbscan.HDBSCAN(min_samples=1, min_cluster_size=1_000,
        #                         prediction_data=True).fit(umap_red.embedding_)
        # pumap = ParametricUMAP(n_neighbors=30, min_dist=0., n_components=6,
        #                        metric="euclidean", n_epochs=pumap_epochs,
        #                        batch_size=2_048, verbose=False).fit(X_fit)
        # hdb_p = hdbscan.HDBSCAN(min_samples=1, min_cluster_size=1_000,
        #                         prediction_data=True).fit(pumap.embedding_)

        # 3) TweetyBert decoder (loads its own specs/labels from fit_path).
        dec_dir = save_dir / f"{bird_id}__{fit_name}__decoder"
        try:
            dec = TweetyBertClassifier(model_dir=model_dir,
                                       linear_decoder_dir=str(dec_dir),
                                       context_length=ctx)
            dec.prepare_data(str(fit_path), test_train_split=0.8)
            dec.create_dataloaders(batch_size=256)
            dec.create_classifier()
            dec.train_classifier(lr=1e-3, desired_total_batches=300,
                                 batches_per_eval=25, patience=4)
            dec_model = dec.classifier_model.to(dev)

            for eval_path, results_npz_path in pending:
                X_eval_embeddings, specs_eval, _, ground_truth_labels_eval = \
                    load_states_labels(eval_path)
                n_rows_embeddings = len(X_eval_embeddings)
                n_frames_spec = (specs_eval.shape[0]
                                 if specs_eval.ndim > 0 and specs_eval.size > 0 else 0)

                # Disabled UMAP / parametric-UMAP transforms; when re-enabled:
                # if umap_red and hdb_u:
                #     Z_u, t_u = time_and_keep(umap_red, X_eval_embeddings)
                #     lbl_u, _ = approximate_predict(hdb_u, Z_u)
                # if pumap and hdb_p:
                #     Z_p, t_p = time_and_keep(pumap, X_eval_embeddings)
                #     lbl_p, _ = approximate_predict(hdb_p, Z_p)
                Z_u, t_u, lbl_u = np.array([]), 0.0, np.array([])
                Z_p, t_p, lbl_p = np.array([]), 0.0, np.array([])

                pred_dec, t_d = _decode_frames(dec_model, dec.num_classes, specs_eval,
                                               n_frames_spec, ctx, dev,
                                               dec_dir / "eval_tmp_actual")

                np.savez_compressed(
                    results_npz_path,
                    umap_labels=lbl_u,
                    umap_embeddings=Z_u,
                    pumap_labels=lbl_p,
                    pumap_embeddings=Z_p,
                    decoder_labels=pred_dec,
                    ground_truth_labels=ground_truth_labels_eval,
                )

                header = _append_timing_row(csv_path, {
                    "bird": bird_id, "fit_fold": fit_name, "eval_fold": eval_path.name,
                    "umap_s_per_row": t_u / n_rows_embeddings if n_rows_embeddings > 0 else 0,
                    "pumap_s_per_row": t_p / n_rows_embeddings if n_rows_embeddings > 0 else 0,
                    "decoder_s_per_row": t_d / n_frames_spec if n_frames_spec > 0 else 0,
                    "results_path": str(results_npz_path),
                }, header)
        finally:
            # Decoder train/test folders and temp eval segments are scratch space.
            shutil.rmtree(dec_dir, ignore_errors=True)

        del dec, dec_model, umap_red, hdb_u, pumap, hdb_p
        if dev.startswith("cuda"):
            torch.cuda.empty_cache()

# ── tiny run ─────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Benchmark every bird in raw_files, then print mean seconds per frame.
    bird_groups = group_by_bird(raw_files)
    model_dir_param = str(root / "experiments" / "BF_Canary_Joint_Run")
    for bird, flist in bird_groups.items():
        benchmark_bird(bird, flist, model_dir=model_dir_param)
    print("\n--- mean sec / frame ---")
    timing_df = pd.read_csv(save_dir / "timings.csv")
    summary = timing_df.groupby("bird")[["umap_s_per_row",
                                         "pumap_s_per_row",
                                         "decoder_s_per_row"]].mean()
    # `display` is an IPython/Jupyter builtin; fall back to print when this
    # cell is exported and run as a plain script.
    try:
        show = display
    except NameError:
        show = print
    show(summary)


2025-05-23 17:51:15.367380: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-23 17:51:15.391529: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Number of classes: 18
Step 25: Train Loss 0.7733 FER = 16.29%, Val Loss = 0.6695
Step 50: Train Loss 0.2792 FER = 6.85%, Val Loss = 0.2819
Step 75: Train Loss 0.1820 FER = 4.60%, Val Loss = 0.1793
