## Copy Data Into Folds

In [None]:
import os
import json
import random
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Amount of song, in minutes, at which a fold is considered full.
mins_per_fold = 50
# Root directory that receives the <bird_id>/fold<N>/ sub-directories.
fold_data_dir = "/media/george-vengrovski/disk1/decoder_data"

# One directory of .wav recordings per bird.
birds_wav_paths = [
    "/media/george-vengrovski/disk2/canary/yarden_data/llb3_data/llb3_songs",
    "/media/george-vengrovski/disk2/canary/yarden_data/llb11_data/llb11_songs",
    "/media/george-vengrovski/disk2/canary/yarden_data/llb16_data/llb16_songs"
]

# JSON file with one entry per recording (song presence and song segments).
song_detection_json = "files/contains_llb.json"

# Build a mapping from filename to its full path for all birds.
wav_file_to_path = {}
for bird_path in birds_wav_paths:
    if not os.path.isdir(bird_path):
        # A missing directory (for example an unmounted disk) would otherwise
        # make every later step silently produce nothing, so report it.
        print(f"WARNING: wav directory not found, skipping: {bird_path}")
        continue
    # Sorted so the mapping is built in the same order on every run.
    for fname in sorted(os.listdir(bird_path)):
        if not fname.endswith('.wav'):
            continue
        full_path = os.path.join(bird_path, fname)
        previous_path = wav_file_to_path.get(fname)
        if previous_path is not None:
            # The mapping is keyed by bare filename, so the later directory
            # wins; make that visible instead of overwriting silently.
            print(f"WARNING: duplicate filename {fname}: {previous_path} replaced by {full_path}")
        wav_file_to_path[fname] = full_path

# Load the song-detection results and, for every recording that contains song
# and was found on disk, store its total song duration under its bird.
with open(song_detection_json, 'r') as f:
    data = json.load(f)

bird_song_files = {}  # bird_id -> list of (filename, duration_seconds)
for record in data:
    # Recordings without detected song are not used.
    if not record.get('song_present', False):
        continue
    filename = record['filename']
    # Skip entries whose wav file was not found in any bird directory.
    if filename not in wav_file_to_path:
        continue
    # The bird id is the part of the filename before the first underscore.
    bird_id = filename.split('_')[0]
    # Add up the length of every song segment, converting ms to seconds.
    song_seconds = 0.0
    for segment in record.get('segments', []):
        segment_ms = segment.get('offset_ms', 0) - segment.get('onset_ms', 0)
        song_seconds += segment_ms / 1000.0
    # Entries with no (or non-positive) song duration contribute nothing.
    if song_seconds <= 0:
        continue
    bird_song_files.setdefault(bird_id, []).append((filename, song_seconds))

# For each bird, shuffle its files and pack them greedily into folds: a fold is
# closed as soon as it holds at least mins_per_fold minutes of song. Whatever is
# left over becomes a final fold, which may be shorter than mins_per_fold.
#
# The shuffle is seeded so that re-running this cell reproduces the same
# assignment. The copy step only ever adds files to the fold directories, so a
# different shuffle on every run would leave the same recording in several folds.
fold_shuffle_seed = 0
min_fold_seconds = mins_per_fold * 60

folds_info = {}  # bird_id -> list of folds, each fold is list of (filename, duration)
for bird_id, files in bird_song_files.items():
    # Sort first so the result does not depend on the order of the JSON
    # entries; seed per bird so one bird's folds do not depend on the others.
    shuffled_files = sorted(files)
    random.Random(f"{fold_shuffle_seed}:{bird_id}").shuffle(shuffled_files)

    folds = []
    current_fold = []
    current_fold_duration = 0.0
    for fname, dur in shuffled_files:
        current_fold.append((fname, dur))
        current_fold_duration += dur
        if current_fold_duration >= min_fold_seconds:
            folds.append(current_fold)
            current_fold = []
            current_fold_duration = 0.0
    if current_fold:  # Add any remaining files to a final (possibly short) fold
        folds.append(current_fold)
    folds_info[bird_id] = folds

# Copy files to their respective fold directories with progress bar.
# Files that already exist at the destination are not copied again.
for bird_id, folds in folds_info.items():
    for i, fold in enumerate(folds):
        fold_dir = os.path.join(fold_data_dir, bird_id, f"fold{i+1}")
        os.makedirs(fold_dir, exist_ok=True)

        # Nothing in this step deletes files, so recordings left behind by an
        # earlier run with a different fold assignment would stay in this
        # directory and be treated as part of this fold. Report them rather
        # than removing anything.
        expected_names = {fname for fname, _ in fold}
        stale_names = [
            name for name in os.listdir(fold_dir)
            if name.endswith('.wav') and name not in expected_names
        ]
        if stale_names:
            print(f"WARNING: {fold_dir} contains {len(stale_names)} .wav file(s) that are not part of the current fold assignment")

        print(f"Copying files for {bird_id} fold {i+1} ({len(fold)} files)...")
        for fname, _ in tqdm(fold, desc=f"{bird_id} fold{i+1}", leave=False):
            src = wav_file_to_path[fname]
            dst = os.path.join(fold_dir, fname)
            if not os.path.exists(dst):
                shutil.copy2(src, dst)

# Gather fold durations for plotting: build one (bird, fold name, minutes) row
# per fold, then split the rows into the three parallel lists the plot uses.
fold_rows = [
    (bird_id, f"fold{i+1}", sum(dur for _, dur in fold) / 60)
    for bird_id, folds in folds_info.items()
    for i, fold in enumerate(folds)
]
plot_bird_ids = [row[0] for row in fold_rows]
plot_fold_names = [row[1] for row in fold_rows]
plot_fold_minutes = [row[2] for row in fold_rows]

# Bar chart with one bar per bird/fold pair, each annotated with its duration
# in minutes (one decimal place).
plt.figure(figsize=(12, 6))
sns.set_style("whitegrid")
bar_labels = list(map("{}-{}".format, plot_bird_ids, plot_fold_names))
bars = plt.bar(bar_labels, plot_fold_minutes, color='skyblue')

# Write each bar's value just above it, centred horizontally on the bar.
for rect in bars:
    bar_top = rect.get_height()
    bar_center = rect.get_x() + rect.get_width()/2.
    plt.text(bar_center, bar_top, f'{bar_top:.1f}', ha='center', va='bottom')

plt.title('Minutes of Song Data per Fold (per Bird)', fontsize=14, pad=20)
plt.xlabel('Bird-Fold', fontsize=12)
plt.ylabel('Total Duration (minutes)', fontsize=12)
# Slanted tick labels so the bird-fold names do not overlap.
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()


## We want to create embeddings for each fold

In [None]:
import sys
import importlib

import os
# Set the MKL threading-layer variable in this process's environment. It is
# set before the `decoding` module is loaded below; keep the two statements in
# this order.
# NOTE(review): presumably this avoids an MKL / OpenMP runtime conflict in the
# libraries that `decoding` pulls in - confirm.
os.environ['MKL_THREADING_LAYER'] = 'GNU'  # Add this before running your code

# Load the `decoding` module by name; its `main` function is called once per
# fold further down in this cell.
decoding_module = importlib.import_module("decoding")

class Args:
    """Argument container handed to ``decoding_module.main``.

    Stores the six values supplied by the caller as attributes of the same
    name, and sets the remaining options to the fixed values below (marked as
    default values by the original author).
    """

    def __init__(self, mode, bird_name, model_name, wav_folder, song_detection_json_path, num_samples_umap):
        # Values chosen by the caller, stored unchanged.
        caller_supplied = {
            "mode": mode,
            "bird_name": bird_name,
            "model_name": model_name,
            "wav_folder": wav_folder,
            "song_detection_json_path": song_detection_json_path,
            "num_samples_umap": num_samples_umap,
        }
        # Options that are not exposed through the constructor.
        fixed_options = {
            "num_random_files_spec": 1000,
            "single_threaded_spec": False,
            "nfft": 1024,
            "raw_spectrogram_umap": False,  # store_true flag
            "state_finding_algorithm_umap": "HDBSCAN",
            "context_umap": 1000,
        }
        # Set every entry as an instance attribute, caller values first.
        for attr_name, attr_value in {**caller_supplied, **fixed_options}.items():
            setattr(self, attr_name, attr_value)

# Run the decoding pipeline once for every directory under fold_data_dir whose
# name contains "fold". The parent directory's name is taken as the bird id.
decoding_model_name = "BF_Canary_Joint_Run"
# NOTE(review): passed as a string, the way it appears on the command line -
# confirm that decoding.main converts it to a number itself.
decoding_num_samples_umap = "1e6"

for root, dirs, _files in os.walk(fold_data_dir):
    # Sorting in place makes os.walk visit (and this loop process) the
    # directories in the same order on every run.
    dirs.sort()
    # `fold_dir_name` rather than `dir`, which would shadow the builtin for
    # the rest of the notebook session.
    for fold_dir_name in dirs:
        if "fold" not in fold_dir_name:
            continue
        bird = os.path.basename(root)
        bird_name_fold = f"{bird}_{fold_dir_name}"
        wav_folder = os.path.join(root, fold_dir_name)
        args = Args(
            mode="single",
            bird_name=bird_name_fold,
            model_name=decoding_model_name,
            wav_folder=wav_folder,
            song_detection_json_path=song_detection_json,
            num_samples_umap=decoding_num_samples_umap
        )
        # Build the message from the values actually passed, so the printed
        # command line cannot drift away from the real arguments.
        print(
            f"Running decoding.py --mode {args.mode} --bird_name {args.bird_name} "
            f"--model_name {args.model_name} --wav_folder {args.wav_folder} "
            f"--song_detection_json_path {args.song_detection_json_path} "
            f"--num_samples_umap {args.num_samples_umap}"
        )
        decoding_module.main(args)


## Evaluate TweetyBERT, Parametric, Load and Transform on Folds