In [5]:
import sys
from pathlib import Path

# The project root is taken to be two directory levels above the current
# working directory; it is appended to sys.path only if it is not listed yet.
project_root = Path.cwd().parent.parent
if not any(entry == str(project_root) for entry in sys.path):
    sys.path.append(str(project_root))


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.utils.data.sampler import Sampler
import numpy as np
import torchaudio
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
import os
import librosa

import sys
from pathlib import Path
import argparse


# Also expose the parent of the working directory on sys.path (added once).
if not any(entry == str(Path.cwd().parent) for entry in sys.path):
    sys.path.append(str(Path.cwd().parent))
    
# Assume the ECAPA_TDNN model code is imported or defined here.
# With the project root on sys.path, the import would be:
# from voice_cloning.speaker_encoder.ecapa_tdnn import ECAPA_TDNN_SMALL

class GE2ELoss(nn.Module):
    """Softmax-style GE2E loss over a speaker-grouped batch of embeddings.

    Every utterance embedding is scored against every speaker centroid with
    cosine similarity. For the utterance's own speaker the centroid is
    recomputed without that utterance. The scores are scaled by the learnable
    ``w``, shifted by the learnable ``b`` and fed to cross-entropy with the
    true speaker as target.

    Args:
        init_w: Initial value of the similarity scale parameter.
        init_b: Initial value of the similarity offset parameter.
    """

    def __init__(self, init_w=10.0, init_b=-5.0):
        super(GE2ELoss, self).__init__()
        self.w = nn.Parameter(torch.tensor(init_w))
        self.b = nn.Parameter(torch.tensor(init_b))

    def forward(self, embeddings):
        """Compute the loss for one batch.

        Args:
            embeddings: Tensor of shape ``(N, M, D)`` - N speakers, M
                utterances per speaker, D embedding dimensions.

        Returns:
            Scalar tensor holding the mean cross-entropy over all N*M
            utterances.

        Raises:
            ValueError: If ``embeddings`` is not 3-dimensional.
        """
        if embeddings.dim() != 3:
            raise ValueError(
                f"embeddings must have shape (N, M, D), got {tuple(embeddings.shape)}"
            )
        N, M, D = embeddings.shape

        centroids = torch.mean(embeddings, dim=1)  # (N, D)
        # Leave-one-out centroid: subtract each utterance from its speaker's
        # sum. The 1e-6 keeps the division finite when M == 1.
        sum_centroids_excl = (centroids * M).unsqueeze(1) - embeddings
        centroids_excl = sum_centroids_excl / (M - 1 + 1e-6)

        embeddings_flat = embeddings.reshape(N * M, D)
        centroids_excl_flat = centroids_excl.reshape(N * M, D)

        # (N*M, N): similarity of every utterance to every speaker centroid.
        sim_matrix = F.cosine_similarity(
            embeddings_flat.unsqueeze(1),
            centroids.unsqueeze(0),
            dim=2,
        )
        # (N*M,): similarity of every utterance to its own leave-one-out centroid.
        sim_self = F.cosine_similarity(embeddings_flat, centroids_excl_flat, dim=1)

        # Build the targets directly on the embeddings' device so no index
        # tensor has to be transferred on each call.
        speaker_indices = torch.arange(N, device=embeddings.device).repeat_interleave(M)

        # Out-of-place scatter replaces the own-speaker column without
        # mutating a tensor that autograd is tracking.
        sim_matrix = sim_matrix.scatter(
            1, speaker_indices.unsqueeze(1), sim_self.unsqueeze(1)
        )

        logits = sim_matrix * self.w + self.b
        return F.cross_entropy(logits, speaker_indices)

class GE2EBatchSampler(Sampler):
    """Batch sampler that yields speaker-grouped index lists.

    Each batch contains ``n_speakers`` distinct speakers with ``n_utterances``
    dataset indices per speaker, laid out speaker by speaker. Speakers that
    own fewer than ``n_utterances`` items are sampled with replacement.

    Args:
        dataset: Dataset whose items are ``(sample, speaker_id)`` pairs. If it
            exposes a ``speaker_ids`` attribute (optionally with a
            ``spk_to_id`` mapping), labels are read from there instead of
            loading every item.
        n_speakers: Number of distinct speakers per batch.
        n_utterances: Number of indices drawn per speaker.
        num_batches: Number of batches produced per iteration.

    Raises:
        ValueError: If ``n_speakers`` exceeds the number of speakers found.
    """

    def __init__(self, dataset, n_speakers, n_utterances, num_batches):
        self.n_speakers = n_speakers
        self.n_utterances = n_utterances
        self.num_batches = num_batches

        self.speaker_to_indices = defaultdict(list)
        for idx, spk_id in enumerate(self._speaker_labels(dataset)):
            self.speaker_to_indices[spk_id].append(idx)

        self.speakers = list(self.speaker_to_indices.keys())

        # Fail at construction time rather than on the first batch draw.
        if self.n_speakers > len(self.speakers):
            raise ValueError(
                f"n_speakers={self.n_speakers} exceeds the "
                f"{len(self.speakers)} speakers available in the dataset"
            )

    @staticmethod
    def _speaker_labels(dataset):
        """Return one speaker label per dataset index, avoiding item loads when possible."""
        raw_labels = getattr(dataset, "speaker_ids", None)
        if raw_labels is None:
            # No label metadata: fall back to reading each item's label,
            # which loads every sample once.
            return [spk_id for _, spk_id in dataset]

        label_map = getattr(dataset, "spk_to_id", None)
        if label_map is None:
            return list(raw_labels)
        # Map through spk_to_id so keys equal what dataset[idx] would return.
        return [label_map[label] for label in raw_labels]

    def __iter__(self):
        for _ in range(self.num_batches):
            # Draw positions instead of labels so numpy never coerces the
            # label values themselves into an array dtype.
            speaker_positions = np.random.choice(
                len(self.speakers), self.n_speakers, replace=False
            )
            batch = []
            for pos in speaker_positions:
                indices = self.speaker_to_indices[self.speakers[pos]]
                # Too few utterances for this speaker: allow repeats.
                with_replacement = len(indices) < self.n_utterances
                selected = np.random.choice(
                    indices, self.n_utterances, replace=with_replacement
                )
                # Emit plain Python ints rather than numpy scalars.
                batch.extend(int(i) for i in selected)
            yield batch

    def __len__(self):
        return self.num_batches

class VoxCeleb2Dataset(Dataset):
    """Dataset of fixed-length waveforms paired with integer speaker labels.

    Args:
        audio_paths: Sequence of audio file paths.
        speaker_ids: Speaker label for each entry of ``audio_paths``.
        sr: Target sample rate in Hz.
        duration: Clip length in seconds; fractional values are allowed.
    """

    def __init__(self, audio_paths, speaker_ids, sr=16000, duration=3):
        self.sr = sr
        self.duration = duration
        self.audio_paths = audio_paths
        self.speaker_ids = speaker_ids
        # Sort the unique labels so the label -> integer mapping is the same
        # on every run; plain set iteration order is not guaranteed to be.
        try:
            ordered_speakers = sorted(set(speaker_ids))
        except TypeError:
            # Labels that cannot be ordered: fall back to first-seen order.
            ordered_speakers = list(dict.fromkeys(speaker_ids))
        self.spk_to_id = {spk: idx for idx, spk in enumerate(ordered_speakers)}

    def __len__(self):
        return len(self.audio_paths)

    def load_audio(self, path: str):
        """Load ``path`` as a ``(1, num_samples)`` float tensor at ``self.sr``."""
        # sr=None keeps the file's native rate so the audio is resampled at
        # most once (librosa's default would first convert to 22050 Hz).
        wav_ref, sr = librosa.load(path, sr=None)
        wav_ref = torch.FloatTensor(wav_ref).unsqueeze(0)
        if sr != self.sr:
            resample_fn = torchaudio.transforms.Resample(sr, self.sr)
            wav_ref = resample_fn(wav_ref)
        return wav_ref

    def __getitem__(self, idx):
        """Return ``(waveform, speaker_id)`` for item ``idx``."""
        waveform = self.load_audio(self.audio_paths[idx])
        waveform = self.process_waveform(waveform)
        speaker_id = self.spk_to_id[self.speaker_ids[idx]]
        return waveform, speaker_id

    def process_waveform(self, waveform):
        """Randomly crop or right-pad ``waveform`` to ``sr * duration`` samples.

        Longer inputs are cropped at a random offset (drawn from numpy's
        global RNG); shorter inputs are zero-padded at the end. Singleton
        dimensions are squeezed from the result.
        """
        # Slicing and padding need an integer length even for fractional durations.
        target_len = int(round(self.sr * self.duration))
        n_samples = waveform.shape[-1]
        if n_samples > target_len:
            # randint's upper bound is exclusive; +1 makes the last full
            # window (start == n_samples - target_len) reachable.
            start = np.random.randint(0, n_samples - target_len + 1)
            waveform = waveform[..., start:start + target_len]
        else:
            waveform = F.pad(waveform, (0, target_len - n_samples))
        return waveform.squeeze()

def collate_fn(batch):
    """Merge ``(waveform, speaker_id)`` pairs into two batched tensors.

    Args:
        batch: Iterable of ``(waveform, speaker_id)`` pairs whose waveforms
            all share one shape.

    Returns:
        Tuple ``(waveforms, speaker_ids)``: the stacked waveforms and a
        ``LongTensor`` of the speaker ids, in batch order.
    """
    wave_seq, label_seq = zip(*batch)
    return torch.stack(wave_seq), torch.LongTensor(label_seq)

# Training configuration
# Speakers and utterances per batch were both reduced from 5 to 4 to ease
# CPU memory pressure, giving 4 * 4 = 16 utterances per batch.
n_speakers, n_utterances = 4, 4
batch_size = n_speakers * n_utterances

# Batches per epoch: sized so that, with 4874 recordings, each recording is
# seen roughly once per epoch (4874 // 16 = 304).
# NOTE(review): the recorded cell output reports 1993 files in the sample
# folder - confirm that 4874 is still the intended recording count.
num_batches = 4874 // batch_size

emb_dim = 256
lr = 1e-5
num_epochs = 60

# Use the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

folder_path = "./data/output/sample"

def get_audio_paths_and_speaker_ids(vox1_test_wav_folder):
    """Collect ``.wav`` files under a folder together with their speaker labels.

    The speaker label of a file is the name of the directory one level above
    the directory that contains it (layout ``<speaker>/<session>/<file>.wav``).
    Directories and files are visited in sorted order so the returned lists
    are identical on every run and platform.

    Args:
        vox1_test_wav_folder: Root directory to scan recursively.

    Returns:
        Tuple ``(audio_paths, speaker_ids)`` of equal-length lists.

    Raises:
        ValueError: If no ``.wav`` file is found.
    """
    audio_paths = []
    speaker_ids = []

    print(f"Looking for .wav files in: {os.path.abspath(vox1_test_wav_folder)}")

    for root, dirs, files in os.walk(vox1_test_wav_folder):
        # os.walk visits entries in filesystem order; sorting the directory
        # list in place fixes the traversal order of the subdirectories.
        dirs.sort()

        # Parent of the file's directory. Resolving to an absolute path first
        # avoids an IndexError when the relative root has a single component.
        speaker_id = os.path.basename(os.path.dirname(os.path.abspath(root)))

        for file in sorted(files):
            # Accept upper-case extensions such as ".WAV" as well.
            if file.lower().endswith(".wav"):
                audio_paths.append(os.path.join(root, file))
                speaker_ids.append(speaker_id)

    print(f"Found {len(audio_paths)} audio files")
    print(f"Found {len(set(speaker_ids))} unique speakers")

    if len(audio_paths) == 0:
        raise ValueError(f"No .wav files found in {vox1_test_wav_folder}")

    return audio_paths, speaker_ids

def evaluate(model, test_csv_path, device, threshold=0.5, audio_root=None, max_pairs=1000):
    """Measure pairwise speaker-verification accuracy of ``model``.

    A fixed-seed sample of rows is drawn from the CSV. For each row both audio
    files are loaded through ``VoxCeleb2Dataset``, embedded with ``model`` and
    compared by cosine similarity; a similarity above ``threshold`` predicts
    label 1. The model is left in eval mode.

    NOTE(review): ``VoxCeleb2Dataset`` crops long clips at a random offset,
    so repeated evaluations can differ slightly unless numpy's RNG is seeded.

    Args:
        model: Module mapping a ``(1, num_samples)`` waveform batch to embeddings.
        test_csv_path: CSV with ``audio_1``, ``audio_2`` and ``label`` columns.
        device: Device the waveforms are moved to before the forward pass.
        threshold: Decision threshold on the cosine similarity.
        audio_root: Directory the CSV paths are relative to. Defaults to the
            module-level ``folder_path``.
        max_pairs: Upper bound on evaluated rows; ``None`` evaluates all rows.

    Returns:
        Fraction of rows whose prediction equals ``label``, as a float.

    Raises:
        ValueError: If the CSV contains no rows.
    """
    model.eval()
    pairs = pd.read_csv(test_csv_path)
    if len(pairs) == 0:
        raise ValueError(f"No evaluation pairs found in {test_csv_path}")

    # Cap the sample size at the number of rows; DataFrame.sample raises when
    # asked for more rows than exist (without replacement).
    n_pairs = len(pairs) if max_pairs is None else min(max_pairs, len(pairs))
    test_df = pairs.sample(n=n_pairs, random_state=42)

    root = folder_path if audio_root is None else audio_root

    correct = 0
    total = 0

    with torch.no_grad():
        for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Evaluating"):
            audio1_path = os.path.join(root, row['audio_1'])
            audio2_path = os.path.join(root, row['audio_2'])

            # A throwaway two-item dataset reuses the loading and
            # fixed-length processing; its speaker labels are placeholders.
            dataset = VoxCeleb2Dataset([audio1_path, audio2_path], ['spk1', 'spk2'])
            audio1, _ = dataset[0]
            audio2, _ = dataset[1]

            audio1 = audio1.unsqueeze(0).to(device)
            audio2 = audio2.unsqueeze(0).to(device)

            emb1 = model(audio1)
            emb2 = model(audio2)

            similarity = F.cosine_similarity(emb1, emb2)

            # similarity > threshold predicts label 1
            prediction = (similarity > threshold).int().item()

            # int() keeps the counter a plain Python integer.
            correct += int(prediction == row['label'])
            total += 1

    return correct / total

def evaluate_linear(linear_layer, embeddings_dict, test_csv_path, device, folder_path,
                    threshold=0.5, max_pairs=1000):
    """Measure pairwise verification accuracy of a projection over cached embeddings.

    A fixed-seed sample of rows is drawn from the CSV. For each row the two
    pre-computed embeddings are looked up by ``os.path.join(folder_path, <csv
    path>)``, passed through ``linear_layer`` and compared by cosine
    similarity; a similarity above ``threshold`` predicts label 1. The
    layer's train/eval mode is restored to what it was on entry.

    Args:
        linear_layer: Module applied to each embedding.
        embeddings_dict: Mapping from audio path to embedding tensor.
        test_csv_path: CSV with ``audio_1``, ``audio_2`` and ``label`` columns.
        device: Device the embeddings are moved to.
        folder_path: Prefix joined with the CSV paths to form dictionary keys.
        threshold: Decision threshold on the cosine similarity.
        max_pairs: Upper bound on evaluated rows; ``None`` evaluates all rows.

    Returns:
        Fraction of rows whose prediction equals ``label``, as a float.

    Raises:
        ValueError: If the CSV contains no rows.
        KeyError: If a joined path is missing from ``embeddings_dict``.
    """
    pairs = pd.read_csv(test_csv_path)
    if len(pairs) == 0:
        raise ValueError(f"No evaluation pairs found in {test_csv_path}")

    # Cap the sample size at the number of rows; DataFrame.sample raises when
    # asked for more rows than exist (without replacement).
    n_pairs = len(pairs) if max_pairs is None else min(max_pairs, len(pairs))
    test_df = pairs.sample(n=n_pairs, random_state=42)

    # Remember the caller's mode so evaluation does not silently flip a layer
    # that was already in eval mode into train mode.
    was_training = linear_layer.training
    linear_layer.eval()

    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Evaluating"):
                audio1_path = os.path.join(folder_path, row['audio_1'])
                audio2_path = os.path.join(folder_path, row['audio_2'])

                # Pre-computed embeddings, projected by the linear layer.
                emb1 = linear_layer(embeddings_dict[audio1_path].to(device))
                emb2 = linear_layer(embeddings_dict[audio2_path].to(device))

                # Flatten to (1, D) so embeddings stored as (D,) or (1, D)
                # both produce a single similarity value.
                similarity = F.cosine_similarity(emb1.reshape(1, -1), emb2.reshape(1, -1))

                # similarity > threshold predicts label 1
                prediction = (similarity > threshold).int().item()

                # int() keeps the counter a plain Python integer.
                correct += int(prediction == row['label'])
                total += 1
    finally:
        linear_layer.train(was_training)

    return correct / total

# Resolve the sample folder to an absolute path (no-op if it already is one).
folder_path = os.path.abspath(folder_path)

# Scan the folder tree for utterances and their speaker labels.
audio_paths, speaker_ids = get_audio_paths_and_speaker_ids(folder_path)
print(f"Creating dataset with {len(audio_paths)} files and {len(set(speaker_ids))} speakers")

train_dataset = VoxCeleb2Dataset(audio_paths, speaker_ids)
print(f"Number of speakers in dataset: {len(train_dataset.spk_to_id)}")

# Never request more speakers per batch than the dataset actually contains.
batch_sampler = GE2EBatchSampler(
    train_dataset,
    n_speakers=min(n_speakers, len(train_dataset.spk_to_id)),
    n_utterances=n_utterances,
    num_batches=num_batches,
)
train_loader = DataLoader(
    train_dataset,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
)

# resume_checkpoint = "../../checkpoints/speaker_encoder.pt"
# model = ECAPA_TDNN_SMALL(
#     feat_dim=1024,
#     feat_type="fbank",
# )

# state_dict = torch.load(resume_checkpoint, map_location=lambda storage, loc: storage)
# model.load_state_dict(state_dict['model'], strict=False)
# _ = model.eval()

# criterion = GE2ELoss().to(device)
# optimizer = Adam(model.parameters(), lr=lr)

test_csv_path = "./data/output/test.csv"

# accuracy = evaluate(model, test_csv_path, device, threshold=threshold)
# print(f"Evaluation Accuracy at threshold {threshold}: {accuracy:.4f}")

# Load linear layer from checkpoint
# NOTE(review): torch.load unpickles the file contents - only load checkpoints
# from a trusted source.
linear_checkpoint_path = "../../checkpoints_finetuned/best_linear_model.pt"
linear_layer = nn.Linear(256, 64).to(device)
checkpoint = torch.load(linear_checkpoint_path, map_location=device)
linear_layer.load_state_dict(checkpoint['linear_state_dict'])
linear_layer.eval()
print("Loaded linear layer from checkpoint")

# Load pre-computed embeddings
embeddings_path = "../../checkpoints/precomputed_embeddings_test.pt"
if not os.path.isfile(embeddings_path):
    raise FileNotFoundError(f"Pre-computed embeddings not found at {embeddings_path}. Run precompute_embeddings.py first.")

print("Loading pre-computed embeddings...")
# map_location mirrors the checkpoint load above: tensors that were saved from
# a GPU session are remapped to the active device instead of failing to load
# on a CPU-only machine.
embeddings_dict = torch.load(embeddings_path, map_location=device)
print(f"Loaded {len(embeddings_dict)} embeddings")



# Training Loop


Using device: cpu
Looking for .wav files in: /Users/user/Documents/Inno/GenAI/VoiceCloning/voice_cloning/speaker_encoder/data/output/sample
Found 1993 audio files
Found 819 unique speakers
Creating dataset with 1993 files and 819 speakers
Number of speakers in dataset: 819
Loaded linear layer from checkpoint
Loading pre-computed embeddings...
Loaded 1993 embeddings


In [6]:
# Inspect the key format of the pre-computed embeddings. evaluate_linear looks
# entries up by os.path.join(folder_path, <csv path>), so the folder_path it is
# given has to reproduce the prefix used by these keys.
embeddings_dict.keys()

dict_keys(['./voice_cloning/speaker_encoder/data/output/sample/id10384/vxBFGKGXSFA/00006.wav', './voice_cloning/speaker_encoder/data/output/sample/id10370/OnghF5les5c/00001.wav', './voice_cloning/speaker_encoder/data/output/sample/id10370/YrRRDjYacTg/00001.wav', './voice_cloning/speaker_encoder/data/output/sample/id11090/BhLxC_Ypzew/00040.wav', './voice_cloning/speaker_encoder/data/output/sample/id11090/BhLxC_Ypzew/00012.wav', './voice_cloning/speaker_encoder/data/output/sample/id10142/bCN4w8Vn8V8/00004.wav', './voice_cloning/speaker_encoder/data/output/sample/id10189/2ixmqtKkg7U/00001.wav', './voice_cloning/speaker_encoder/data/output/sample/id10189/B9ZYfVHAvGM/00001.wav', './voice_cloning/speaker_encoder/data/output/sample/id10189/OgDJABfNMso/00001.wav', './voice_cloning/speaker_encoder/data/output/sample/id10519/M8xPkqcv8no/00008.wav', './voice_cloning/speaker_encoder/data/output/sample/id10941/n7_p5G5jcOA/00004.wav', './voice_cloning/speaker_encoder/data/output/sample/id10941/b8G3q

In [10]:
# Score the loaded linear layer on the sampled test pairs. The last argument is
# the literal path prefix that is joined with the CSV paths to form the
# embeddings_dict lookup keys.
evaluate_linear(
    linear_layer,
    embeddings_dict,
    test_csv_path,
    device,
    "./voice_cloning/speaker_encoder/data/output/sample",
)

Evaluating: 100%|██████████| 1000/1000 [00:00<00:00, 18221.21it/s]


0.605