### Imports

In [None]:
import os
import argparse
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from speechbrain.inference import EncoderClassifier

from voicestudio.datasets import LIBRITTS_P

### Configuration

In [None]:
# --- Dataset -----------------------------------------------------------------
# Root directory handed to the dataset class and used to locate the metadata CSV.
DATASET_ROOT = "./data"
# Subset name to process; alternatives include dev-clean, test-clean, etc.
DATASET_URL = "train-clean-100"
# Speaker prompt annotator identifier passed to LIBRITTS_P.
ANNOTATOR = "df1"

# --- Speaker encoder ---------------------------------------------------------
# Source identifier and save directory handed to EncoderClassifier.from_hparams.
ENCODER_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"
ENCODER_SAVEDIR = "tmp/ecapa"

# --- Processing --------------------------------------------------------------
# Number of dataset items per DataLoader batch and number of loader workers.
BATCH_SIZE = 512
NUM_WORKERS = 4
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Output ------------------------------------------------------------------
# Directory that receives the saved embeddings and metadata; created up front
# (including missing parents) so the later save step cannot fail on it.
OUTPUT_DIR = "./results/embeddings"
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Echo the effective settings, one per line.
print(
    "Configuration:",
    f"  Dataset: {DATASET_URL}",
    f"  Device: {DEVICE}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Output: {OUTPUT_DIR}",
    sep="\n",
)

### Dataset Preparation

In [None]:
# Constructor arguments for the LibriTTS-P dataset, taken from the
# configuration constants. download=True is forwarded as-is (presumably it
# fetches the data when it is not already under DATASET_ROOT — confirm against
# the LIBRITTS_P implementation).
dataset_kwargs = {
    "root": DATASET_ROOT,
    "url": DATASET_URL,
    "annotator": ANNOTATOR,
    "download": True,
}
dataset = LIBRITTS_P(**dataset_kwargs)

# Report how many items the dataset exposes.
dataset_size = len(dataset)
print(f"Dataset size: {dataset_size}")

In [None]:
# The style-prompt metadata CSV is expected directly under the dataset root.
metadata_filename = "metadata_w_style_prompt_tags_v230922.csv"
metadata_path = os.path.join(DATASET_ROOT, metadata_filename)

print(f"Loading metadata from: {metadata_path}")
metadata_df = pd.read_csv(metadata_path)

# Build a plain dict from the 'item_name' column to the 'style_prompt_key'
# column; collate_fn consults it to tag each utterance.
style_prompt_keys = metadata_df.set_index('item_name')['style_prompt_key']
style_key_map = style_prompt_keys.to_dict()

print(f"Loaded {len(style_key_map)} style prompt keys")

### Model Setup

In [None]:
# Runtime options forwarded to SpeechBrain; only the target device is set.
encoder_run_opts = {"device": DEVICE}

# Load the speaker encoder identified by ENCODER_SOURCE. ENCODER_SAVEDIR is
# passed as `savedir` (presumably the local cache for the fetched model files
# — confirm in the SpeechBrain documentation).
encoder = EncoderClassifier.from_hparams(
    source=ENCODER_SOURCE,
    savedir=ENCODER_SAVEDIR,
    run_opts=encoder_run_opts,
)
# Put the model in evaluation mode before extracting embeddings.
encoder.eval()

### Dataloader Setup

In [None]:
def collate_fn(batch):
    """Split a batch of dataset items into waveforms and metadata records.

    Each item must be a 9-field tuple of
    (waveform, sample rate, original text, normalized text, speaker id,
    chapter id, utterance id, style list, speaker list). Waveforms are left
    unpadded so they may differ in length.

    Args:
        batch: Iterable of dataset items as described above.

    Returns:
        A ``(waveforms, metadata_list)`` pair: the waveforms in batch order,
        and one dict per item with the keys ``utterance_id``, ``speaker_id``,
        ``chapter_id``, ``style_prompt_key``, ``normalized_text`` and
        ``duration``.
    """
    waveforms = []
    metadata_list = []

    # Unpacking in the loop target keeps the strict 9-field requirement.
    for (waveform, sr, _orig_text, norm_text,
         spk_id, ch_id, utt_id, _style_list, _speaker_list) in batch:
        waveforms.append(waveform)

        # The sample count sits on the last axis for multi-dimensional
        # waveforms (e.g. [1, samples]); a 1-D waveform is just its length.
        if waveform.dim() > 1:
            num_samples = waveform.shape[-1]
        else:
            num_samples = len(waveform)

        # Utterances missing from the style-prompt table are tagged 'unknown'.
        record = dict(
            utterance_id=utt_id,
            speaker_id=spk_id,
            chapter_id=ch_id,
            style_prompt_key=style_key_map.get(utt_id, 'unknown'),
            normalized_text=norm_text,
            duration=num_samples / sr,  # samples divided by sample rate
        )
        metadata_list.append(record)

    return waveforms, metadata_list

In [None]:
# Request pinned host memory only when the target device is CUDA.
use_pinned_memory = DEVICE == "cuda"

# shuffle=False so batches are drawn in dataset order; the custom collate_fn
# returns variable-length waveforms together with their metadata dicts.
dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    pin_memory=use_pinned_memory,
)

In [None]:
def _embed_waveform(waveform):
    """Return the encoder embedding of a single waveform as a NumPy array.

    A 1-D waveform gets a leading axis added so the encoder receives
    [1, samples]. The encoder output has all singleton axes squeezed out
    (presumably leaving a 1-D embedding vector — confirm against the
    encoder's output shape) and is moved to the CPU before conversion.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    encoded = encoder.encode_batch(waveform.to(DEVICE))
    return encoded.squeeze().cpu().numpy()


# Accumulators kept in lockstep: entry i of each list describes the same item.
embeddings_list = []
metadata_records = []

# Gradients are not needed for embedding extraction.
with torch.no_grad():
    for batch_idx, (waveforms, metadata_batch) in enumerate(tqdm(dataloader, desc="Processing")):
        # Waveforms are encoded one at a time. The list is built completely
        # before extending, so a failure mid-batch adds nothing to the
        # accumulators.
        batch_embeddings = [_embed_waveform(waveform) for waveform in waveforms]

        embeddings_list.extend(batch_embeddings)
        metadata_records.extend(metadata_batch)

        # Emit a running total after every tenth batch.
        if (batch_idx + 1) % 10 == 0:
            processed = len(embeddings_list)
            print(f"Processed {processed} samples")

In [None]:
# Column order used for the exported metadata table.
metadata_columns = [
    'utterance_id',
    'speaker_id',
    'chapter_id',
    'style_prompt_key',
    'normalized_text',
    'duration',
]

# Stack the per-item embeddings along a new leading axis.
embeddings_array = np.stack(embeddings_list, axis=0)
print(f"\nEmbeddings shape: {embeddings_array.shape}")
print(f"Embedding dimension: {embeddings_array.shape[1]}")

# NOTE: this rebinds metadata_df, which previously held the style-prompt CSV.
metadata_df = pd.DataFrame(metadata_records, columns=metadata_columns)

print(f"\nMetadata shape: {metadata_df.shape}")
print(metadata_df.head())

# Print both counts so a mismatch between embeddings and metadata is visible.
print(f"len(embeddings_list): {len(embeddings_list)}")
print(f"len(metadata_records): {len(metadata_records)}")

In [None]:
# Shared file stem: the subset name with dashes replaced by underscores.
output_prefix = "speaker_embeddings_" + DATASET_URL.replace('-', '_')
embeddings_path = os.path.join(OUTPUT_DIR, output_prefix + ".npy")
# NOTE: this rebinds metadata_path, which previously pointed at the input CSV.
metadata_path = os.path.join(OUTPUT_DIR, output_prefix + "_metadata.csv")

# Write the embedding matrix in NumPy's .npy format.
print(f"\nSaving embeddings to: {embeddings_path}")
np.save(embeddings_path, embeddings_array)

# Write the metadata table as CSV without the DataFrame index column.
print(f"Saving metadata to: {metadata_path}")
metadata_df.to_csv(metadata_path, index=False)