In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
import torch.nn as nn
import math
import torch
import torch.nn.functional as F
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install webrtcvad

## Preprocess and Dataset

In [None]:
train_path = '/kaggle/input/vietnam-celeb-dataset/vietnam-celeb-t.txt'  # Training list file; each line is parsed as "<speaker_id> <audio_filename>"
root_dir = '/kaggle/input/vietnam-celeb-dataset/full-dataset/data/'  # Root folder that contains the audio files


import os
import webrtcvad
import torchaudio
import numpy as np
# MFCC front-end shared by every utterance: 80 cepstral coefficients computed from an
# 80-band mel spectrogram restricted to 20-7600 Hz. At the configured 16 kHz sample rate
# the 400-sample Hamming window spans 25 ms and the 160-sample hop spans 10 ms.
transform = torchaudio.transforms.MFCC(sample_rate=16000,
            n_mfcc=80,  # Number of MFCC coefficients kept per frame
            melkwargs={
                "n_fft": 512,
                "win_length": 400,
                "hop_length": 160,
                "f_min": 20,
                "f_max": 7600,
                "window_fn": torch.hamming_window,
                "n_mels": 80,
            }
        )
def mfcc_transform(waveform):
    """Run ``waveform`` through the module-level MFCC ``transform`` and return its output."""
    mfcc = transform(waveform)
    return mfcc

def remove_silence(waveform, sample_rate=16000, frame_duration_ms=30):
    """Drop non-speech frames from a waveform using WebRTC voice activity detection.

    Args:
        waveform: Float tensor of audio samples, either ``[num_samples]`` or
            ``[channels, num_samples]``. Multi-channel audio is averaged to mono first,
            because the VAD works on a single 16-bit PCM stream.
        sample_rate: Sample rate in Hz handed to the VAD.
        frame_duration_ms: Length of each analysed frame in milliseconds (per the
            webrtcvad documentation only 10, 20 and 30 ms frames are accepted).

    Returns:
        A float32 tensor of shape ``[1, num_voiced_samples]`` holding the concatenated
        voiced frames. When no complete frame is classified as speech, the (mono)
        waveform is returned unchanged so callers always receive some audio.
    """
    vad = webrtcvad.Vad(2)  # Moderate aggressiveness (0-3)

    # Without this, squeeze() on a stereo tensor yields a 2-D array, no frame ever has
    # the expected length, and the audio silently bypasses the VAD.
    if waveform.dim() > 1 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    samples = waveform.reshape(-1).numpy()
    # Clip before scaling: values outside [-1, 1] would wrap around in the int16 cast.
    pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

    frame_length = int(sample_rate * frame_duration_ms / 1000)
    voiced_frames = []
    # Only complete frames are analysed; a trailing partial frame is dropped.
    for start in range(0, len(pcm) - frame_length + 1, frame_length):
        frame = pcm[start:start + frame_length]
        if vad.is_speech(frame.tobytes(), sample_rate):
            voiced_frames.append(frame)

    if not voiced_frames:
        return waveform

    voiced = np.concatenate(voiced_frames).astype(np.float32) / 32767
    return torch.from_numpy(voiced).unsqueeze(0)

In [None]:
def preprocess_audio(file_path, max_frames = 1000):
    """Load an audio file and turn it into a fixed-length MFCC feature map.

    The file is loaded, resampled to 16 kHz when it was stored at another rate,
    stripped of non-speech frames, converted to MFCCs, and finally zero-padded or
    truncated along the time axis.

    Args:
        file_path: Path of the audio file to load.
        max_frames: Number of MFCC frames in the returned tensor.

    Returns:
        Tensor whose last dimension is exactly ``max_frames``
        (``[1, 80, max_frames]`` with the MFCC configuration used in this notebook).
    """
    target_sample_rate = 16000  # rate the VAD and the MFCC transform are configured for

    waveform, sample_rate = torchaudio.load(file_path)
    # The rate reported by the loader used to be discarded; a file stored at another
    # rate would be analysed as if it were 16 kHz.
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)

    waveform = remove_silence(waveform, sample_rate=target_sample_rate)
    mfcc = mfcc_transform(waveform)  # [1, 80, num_frames]

    # Pad or truncate to max_frames
    num_frames = mfcc.size(2)
    if num_frames < max_frames:
        # Take shape, dtype and device from the features instead of hard-coding (1, 80).
        padding = mfcc.new_zeros(mfcc.size(0), mfcc.size(1), max_frames - num_frames)
        mfcc = torch.cat([mfcc, padding], dim=2)
    elif num_frames > max_frames:
        mfcc = mfcc[:, :, :max_frames]

    return mfcc # [1, 80, max_frames]

In [None]:
import os
import numpy as np
from torch.utils.data import Dataset


class VietnamCelebDatasetTrain(Dataset):
    """Training dataset yielding MFCC features and an integer speaker label.

    The list file is read once at construction time. Each usable line has the form
    ``<speaker_id> <audio_filename>`` and refers to the file
    ``root_dir/<speaker_id>/<audio_filename>``; entries whose file does not exist are
    dropped. Speaker ids are mapped to contiguous integers in sorted order
    (see ``label_map``).

    Args:
        train_path: Path of the training list file.
        root_dir: Directory containing one sub-folder per speaker.
        sr: Accepted for interface compatibility; not used by this class.
        duration: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, train_path, root_dir, sr=16000, duration=10):
        self.root_dir = root_dir
        self.filepaths = []
        self.labels = []

        with open(train_path, 'r') as f:
            for line in f:
                parts = line.split()
                # Blank lines (e.g. a trailing newline) and malformed lines are skipped
                # instead of failing the tuple unpacking below with a ValueError; this
                # mirrors how the validation dataset treats its list file.
                if len(parts) != 2:
                    continue
                speaker_id, audio_filename = parts
                audio_path = os.path.join(self.root_dir, speaker_id, audio_filename)
                if os.path.exists(audio_path):
                    self.filepaths.append(audio_path)
                    self.labels.append(speaker_id)

        # Sorted order keeps the id -> index mapping stable across runs.
        self.label_map = {label: idx for idx, label in enumerate(sorted(set(self.labels)))}
        self.labels = [self.label_map[label] for label in self.labels]

    def __len__(self):
        """Return the number of utterances whose audio file was found."""
        return len(self.filepaths)

    def __getitem__(self, idx):
        """Return the features and speaker label of utterance ``idx`` as a dict."""
        # Audio is decoded and converted to MFCCs lazily, on access.
        wav_path = self.filepaths[idx]
        mfcc = preprocess_audio(wav_path)

        label = self.labels[idx]

        # Dictionary keys match what the collator and the trainer read.
        return {
            'input_values': mfcc,
            'speaker_labels': torch.tensor(label, dtype=torch.long),
        }

In [None]:
class DataCollatorVietnamCeleb:
    """Collate training samples (dicts of tensors) into one batched dict."""

    def __init__(self):
        pass

    def __call__(self, batch):
        """Stack the per-sample tensors along a new leading batch dimension.

        Args:
            batch (List[Dict]): Samples carrying ``'input_values'`` and
                ``'speaker_labels'`` tensors.

        Returns:
            dict: The same two keys, each mapped to the stacked tensor.
        """
        features = [sample['input_values'] for sample in batch]
        targets = [sample['speaker_labels'] for sample in batch]

        collated = {}
        collated['input_values'] = torch.stack(features)
        collated['speaker_labels'] = torch.stack(targets)
        return collated

In [None]:
# Preferred compute device: the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Collator that batches training samples (features + speaker labels).
train_collator = DataCollatorVietnamCeleb()

In [None]:
# Builds the file/label index from the training list; audio is only decoded when an item is fetched.
train_dataset = VietnamCelebDatasetTrain(train_path, root_dir)

In [None]:
# Number of training utterances whose audio file was found on disk.
print(len(train_dataset))

In [None]:
class VietnamCelebDatasetValidation(Dataset):
    """Verification trials: pairs of utterances with a same/different-speaker label."""

    def __init__(self, val_path, root_dir, sr=16000, duration=10):
        """Index the trial pairs listed in ``val_path``.

        Every usable line holds three whitespace-separated fields: the pair label and
        the two utterance paths relative to ``root_dir``. Lines with a different field
        count, and pairs with a missing audio file, are ignored.

        Args:
            val_path: Path to the validation file with pre-defined pairs
            root_dir: Root directory containing the audio files
            sr: Sample rate
            duration: Max duration in seconds
        """
        self.root_dir = root_dir
        self.sr = sr
        self.duration = duration
        self.max_length = sr * duration

        # Parallel lists: pairs[i] is (path1, path2) and labels[i] its integer label.
        self.pairs = []
        self.labels = []

        with open(val_path, 'r') as trial_file:
            for raw_line in trial_file:
                fields = raw_line.strip().split()
                if len(fields) != 3:
                    continue
                label, first_rel, second_rel = fields

                first_path = os.path.join(self.root_dir, first_rel)
                second_path = os.path.join(self.root_dir, second_rel)

                # Keep the trial only when both recordings are present.
                if not (os.path.exists(first_path) and os.path.exists(second_path)):
                    continue
                self.pairs.append((first_path, second_path))
                self.labels.append(int(label))

    def __len__(self):
        """Return the number of usable trial pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the features of both utterances of trial ``idx`` plus the pair label."""
        first_path, second_path = self.pairs[idx]
        pair_label = self.labels[idx]

        first_features = preprocess_audio(first_path)
        second_features = preprocess_audio(second_path)

        return {
            'input_values': first_features,
            'input_values2': second_features,
            'pair_labels': torch.tensor(pair_label, dtype=torch.long)
        }

In [None]:
class ValidationDataCollator:
    """Collate verification trials into batched tensors."""

    def __call__(self, batch):
        """Stack both utterances' features and the pair labels along a new batch axis."""
        collated = {}
        for key in ('input_values', 'input_values2', 'pair_labels'):
            collated[key] = torch.stack([sample[key] for sample in batch])
        return collated

In [None]:
from torch.utils.data import Subset
# Validation data: trial list whose lines are parsed as "<label> <utterance 1> <utterance 2>",
# with both utterance paths resolved relative to root_dir.
val_path = "/kaggle/input/vietnam-celeb-dataset/vietnam-celeb-e.txt"
val_dataset = VietnamCelebDatasetValidation(val_path, root_dir)

In [None]:
import random
# Evaluate on a random subset of the trial pairs to keep each evaluation pass short.
# The size is capped at the number of available pairs, because random.sample raises
# ValueError when asked for more items than the population holds. A dedicated seeded
# generator keeps the subset identical across kernel restarts, so EER values from
# different runs are comparable.
num_val_samples = min(5000, len(val_dataset))
sample_indices = random.Random(42).sample(range(len(val_dataset)), num_val_samples)
val_dataset = Subset(val_dataset, indices=sample_indices)
val_collator = ValidationDataCollator()

In [None]:
class CombinedDataCollator:
    """Route a batch to the training or the validation collator based on its contents."""

    def __init__(self, train_collator, val_collator):
        self.train_collator = train_collator
        self.val_collator = val_collator

    def __call__(self, batch):
        """Collate ``batch`` with whichever collator matches its first sample.

        A batch is treated as validation data when its first sample is a dict that
        carries an ``'input_values2'`` entry; everything else goes to the training
        collator.
        """
        first_sample = batch[0]
        is_validation = isinstance(first_sample, dict) and 'input_values2' in first_sample
        collate = self.val_collator if is_validation else self.train_collator
        return collate(batch)
# Single collator object that serves both training batches and validation (pair) batches.
combined_collator = CombinedDataCollator(train_collator, val_collator)

In [None]:
from torch.utils.data import  DataLoader

## Res2Net

In [None]:
class Bottle2neck(nn.Module):
    """Res2Net bottleneck: 1x1 reduce, multi-scale 3x3 branches, 1x1 expand, plus residual.

    The output of the first 1x1 convolution is split into ``scale`` groups of ``width``
    channels. All but the last group go through their own 3x3 convolution; in 'normal'
    blocks each branch also receives the previous branch's output. The last group is
    passed through unchanged ('normal') or average-pooled ('stage').
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'):
        """Build the block.

        Args:
            inplanes: number of input channels
            planes: base channel count; the block outputs ``planes * expansion`` channels
            stride: stride of the 3x3 convolutions (and of the pooling in 'stage' blocks)
            downsample: optional module applied to the input to form the residual
            baseWidth: basic width of conv3x3
            scale: number of channel groups
            stype: 'normal' for a regular block, 'stage' for the first block of a stage
        """
        super().__init__()

        width = int(math.floor(planes * (baseWidth/64.0)))
        hidden_channels = width * scale
        out_channels = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, hidden_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_channels)

        # One 3x3 branch per group except the last; a single branch when scale == 1.
        self.nums = 1 if scale == 1 else scale - 1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)

        self.convs = nn.ModuleList([
            nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False)
            for _ in range(self.nums)
        ])
        self.bns = nn.ModuleList([nn.BatchNorm2d(width) for _ in range(self.nums)])

        self.conv3 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        groups = torch.split(out, self.width, 1)

        is_stage = self.stype == 'stage'
        branch_outputs = []
        for idx, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            # The first branch, and every branch of a 'stage' block, sees only its own
            # group; later branches of a 'normal' block add the previous branch output.
            if idx == 0 or is_stage:
                sp = groups[idx]
            else:
                sp = sp + groups[idx]
            sp = self.relu(bn(conv(sp)))
            branch_outputs.append(sp)

        # The last group bypasses the 3x3 convolutions.
        if self.scale != 1:
            if self.stype == 'normal':
                branch_outputs.append(groups[self.nums])
            elif is_stage:
                branch_outputs.append(self.pool(groups[self.nums]))

        out = torch.cat(branch_outputs, 1)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)

In [None]:
class Res2Net(nn.Module):
    """Res2Net backbone that maps a single-channel feature map to a speaker embedding.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute.
        layers: Number of blocks in each of the four stages.
        baseWidth: Base width forwarded to every block.
        scale: Scale forwarded to every block.
        embedding_size: Dimension of the returned embedding.
    """

    def __init__(self, block, layers, baseWidth=26, scale=4, embedding_size=256):
        super().__init__()
        self.inplanes = 64  # running channel count, advanced by _make_layer
        self.baseWidth = baseWidth
        self.scale = scale

        # Stem: single-channel input for MFCCs
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages registered as layer1..layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[stage_idx], stride=stride)
            setattr(self, f"layer{stage_idx + 1}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Embedding layer for speaker verification
        self.embedding = nn.Linear(512 * block.expansion, embedding_size)

        self._initialise_weights()

    def _initialise_weights(self):
        """Kaiming init for convolutions, unit/zero for batch norm, Xavier for linear layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage: a 'stage' block followed by ``blocks - 1`` regular blocks."""
        out_planes = planes * block.expansion

        # A projection shortcut is needed whenever the first block changes the
        # resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage_blocks = [
            block(self.inplanes, planes, stride, downsample=downsample,
                  stype='stage', baseWidth=self.baseWidth, scale=self.scale)
        ]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage_blocks.append(
                block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)
            )

        return nn.Sequential(*stage_blocks)

    def forward(self, input_values=None, labels=None, attention_mask=None, **kwargs):
        """Return the embedding of ``input_values``; all other arguments are ignored."""
        x = self.conv1(input_values)
        x = self.maxpool(self.relu(self.bn1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.embedding(x)  # Output embeddings

In [None]:
from torchinfo import summary
# Backbone with [3, 4, 6, 3] bottleneck blocks per stage, base width 26 and scale 8;
# the embedding size is left at the class default.
res2net50_26w_8s = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 8)
model = res2net50_26w_8s
# NOTE(review): the 1x1x224x224 input below is only a probe size for the summary table.
# The features returned by preprocess_audio are documented there as [1, 80, max_frames],
# so the activation shapes printed here will differ from those seen during training.
summary(model = model, input_size  = [1, 1,224,224], col_names = ["input_size", "output_size", "num_params", "trainable"], row_settings = ["var_names"])


## Loss Function

In [None]:
class AAMSoftmaxLoss(nn.Module):
    """Additive angular margin softmax loss over speaker classes.

    Holds one learnable weight vector per speaker. Logits are scaled cosine
    similarities between L2-normalised embeddings and weights, with an angular margin
    added to the target class before cross entropy is applied.

    Args:
        embedding_dim: Dimension of the input embeddings.
        num_speakers: Number of speaker classes.
        margin: Angular margin in radians.
        scale: Factor applied to the logits.
    """

    def __init__(self, embedding_dim=256, num_speakers=1000, margin=0.2, scale=30):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_speakers = num_speakers
        self.margin = margin
        self.scale = scale

        # One class weight vector per speaker.
        self.weight = nn.Parameter(torch.FloatTensor(num_speakers, embedding_dim))
        nn.init.xavier_normal_(self.weight, gain=1)

        # Margin constants computed once: cos(m), sin(m), and the threshold / penalty
        # used when theta + m would exceed pi.
        self.cos_m, self.sin_m = math.cos(margin), math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, embeddings, labels):
        """Return ``(loss, logits)`` for a batch of embeddings and integer labels."""
        # Cosine similarity between unit-length embeddings and unit-length class weights.
        cosine = F.linear(
            F.normalize(embeddings, p=2, dim=1),
            F.normalize(self.weight, p=2, dim=1),
        )

        # cos(theta + m) via the angle-addition formula, with a linear fallback where
        # the margin would push the angle past pi.
        sine = (1.0 - cosine.pow(2)).clamp(0, 1).sqrt()
        cos_theta_m = cosine * self.cos_m - sine * self.sin_m
        cos_theta_m = torch.where(cosine > self.th, cos_theta_m, cosine - self.mm)

        # 1.0 at each sample's target class, 0.0 elsewhere.
        target_mask = torch.zeros_like(cosine)
        target_mask.scatter_(1, labels.view(-1, 1), 1)

        # Margin-adjusted value for the target class, plain cosine for all others.
        blended = (target_mask * cos_theta_m) + ((1.0 - target_mask) * cosine)
        logits = blended * self.scale

        return F.cross_entropy(logits, labels), logits

## Metrics for evaluation

In [None]:
def compute_eer(fnr, fpr, scores=None):
    """Compute the Equal Error Rate (EER) from paired FNR / FPR curves.

    The curves are scanned in order for the first point where FNR equals FPR, or the
    first pair of adjacent points between which ``fnr - fpr`` changes sign. In the
    second case the EER is linearly interpolated between the two points, so the
    returned value lies on both (piecewise-linear) curves. Looking only at adjacent
    points makes the result independent of whether the curves are ordered by
    increasing or decreasing threshold.

    Args:
        fnr: Sequence of false negative rates, one per operating point.
        fpr: Sequence of false positive rates, aligned with ``fnr``.
        scores: Optional sequence of thresholds aligned with ``fnr`` / ``fpr``.

    Returns:
        ``(eer, threshold)``. ``threshold`` is taken (or interpolated) from ``scores``
        when they are given and long enough; otherwise it is the (possibly fractional)
        index of the operating point. Empty input yields ``(0.5, 0.0)`` after printing
        a warning. When the curves never meet or cross, the point with the smallest
        ``|fnr - fpr|`` is used and the EER is the mean of its two rates.
    """
    fnr = np.asarray(fnr, dtype=float)
    fpr = np.asarray(fpr, dtype=float)

    if fnr.size == 0 or fpr.size == 0:
        print("WARNING: Empty FNR or FPR arrays")
        return 0.5, 0.0  # Return default values

    def threshold_at(index):
        # Fall back to the index when no usable score exists for this point.
        if scores is not None and index < len(scores):
            return scores[index]
        return index

    diff = fnr - fpr
    sign = np.sign(diff)

    exact_hits = np.flatnonzero(sign == 0)
    # i is listed when points i and i + 1 lie on opposite sides of FNR == FPR.
    crossings = np.flatnonzero(sign[:-1] * sign[1:] < 0)

    first_exact = int(exact_hits[0]) if exact_hits.size else None
    first_crossing = int(crossings[0]) if crossings.size else None

    # An exact meeting point that comes no later than the first crossing wins.
    if first_exact is not None and (first_crossing is None or first_exact <= first_crossing):
        return fnr[first_exact], threshold_at(first_exact)

    if first_crossing is not None:
        lo, hi = first_crossing, first_crossing + 1
        # Fraction of the way from lo to hi at which fnr - fpr reaches zero.
        t = diff[lo] / (diff[lo] - diff[hi])
        eer = fnr[lo] + t * (fnr[hi] - fnr[lo])
        if scores is not None and hi < len(scores):
            threshold = scores[lo] + t * (scores[hi] - scores[lo])
        else:
            threshold = lo + t
        return eer, threshold

    # No meeting point and no crossing: use the closest approach.
    closest = int(np.argmin(np.abs(diff)))
    eer = (fnr[closest] + fpr[closest]) / 2
    return eer, threshold_at(closest)

In [None]:
def compute_speaker_metrics(eval_pred):
    """Compute EER metrics for speaker verification and log them to wandb.

    Args:
        eval_pred: Object exposing ``embeddings1`` and ``embeddings2`` (aligned
            sequences of embedding vectors, one pair per trial) and ``labels``
            (1 for same-speaker pairs, 0 for different-speaker pairs).

    Returns:
        dict with ``"eer"`` and ``"eer_threshold"``.
    """
    embeddings1 = eval_pred.embeddings1
    embeddings2 = eval_pred.embeddings2
    pair_labels = np.asarray(eval_pred.labels)

    # Cosine similarity per pair; the epsilon guards against zero-norm embeddings.
    similarity_scores = np.array([
        np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2) + 1e-10)
        for e1, e2 in zip(embeddings1, embeddings2)
    ])

    # Every observed score serves as a threshold, in increasing order. A pair is
    # accepted when its score is >= the threshold.
    thresholds = np.sort(similarity_scores)
    target_scores = np.sort(similarity_scores[pair_labels == 1])
    nontarget_scores = np.sort(similarity_scores[pair_labels == 0])

    # searchsorted(..., side="left") counts the scores strictly below each threshold,
    # which gives the rejected pairs without looping over every threshold.
    misses = np.searchsorted(target_scores, thresholds, side="left")
    false_alarms = len(nontarget_scores) - np.searchsorted(nontarget_scores, thresholds, side="left")

    # A rate is reported as 0 when its class has no pairs at all.
    if len(target_scores) > 0:
        fnr = misses / len(target_scores)
    else:
        fnr = np.zeros(len(thresholds))
    if len(nontarget_scores) > 0:
        fpr = false_alarms / len(nontarget_scores)
    else:
        fpr = np.zeros(len(thresholds))

    # The rates are indexed by the sorted thresholds, so the sorted array (not the
    # unsorted scores) must accompany them for the reported threshold to be right.
    eer, threshold = compute_eer(fnr, fpr, thresholds)
    result = {
        "eer": eer,
        "eer_threshold": threshold
    }
    # Log EER to wandb directly
    wandb.log({"eer": eer})

    # Score histograms are only logged when there are enough points to be informative.
    if len(fpr) > 10:
        try:
            wandb.log({
                "similarity_scores": wandb.Histogram(similarity_scores),
                "same_speaker_scores": wandb.Histogram(similarity_scores[pair_labels == 1]),
                "diff_speaker_scores": wandb.Histogram(similarity_scores[pair_labels == 0])
            })
        except Exception as e:
            # Logging is best-effort; a failure here must not discard the metrics.
            print(f"Error logging histograms: {e}")

    return result

## Wandb

In [None]:
import wandb
import os
# Never commit a real key to the notebook. setdefault keeps a key that is already present
# in the environment (e.g. exported beforehand or injected from a secrets store) instead
# of overwriting it with the placeholder below.
os.environ.setdefault("WANDB_API_KEY", "YOUR_WANDB_API_KEY")
# Start the run; the config mirrors the hyper-parameters used for training.
wandb.init(
    project = "Res2Net",
    name="Res2Net training 1",
    config={
        "learning_rate": 3e-4,
        "architecture": "Res2Net",
        "dataset": "vietnam-celeb",
        "epochs": 5,
    }
)

## Trainer

In [None]:
from transformers import Trainer
import torch
import numpy as np

# Custom Trainer implementation for speaker verification
class SpeakerVerificationTrainer(Trainer):
    """Hugging Face ``Trainer`` specialised for speaker verification.

    Training minimises an AAM-softmax loss computed from the model's embeddings and the
    ``speaker_labels`` of each batch. Evaluation embeds both utterances of every trial
    pair, stores the embeddings, and derives the Equal Error Rate from them once the
    evaluation loop has finished.

    Args:
        total_epochs: Number of epochs used by the alpha schedule.
        margin: Angular margin of the AAM-softmax loss.
        scale: Logit scale of the AAM-softmax loss.
        num_speakers: Number of speaker classes of the AAM-softmax loss.
        *args, **kwargs: Forwarded to ``Trainer``.
    """

    def __init__(self, *args, total_epochs=10, margin=0.2, scale=30, num_speakers=1000, **kwargs):
        super().__init__(*args, **kwargs)

        self.total_epochs = total_epochs
        self.margin = margin
        self.scale = scale
        self.num_speakers = num_speakers
        # Read the embedding size from the model's final ``embedding`` layer when it has
        # one, so the loss always matches the model; fall back to 256 otherwise.
        embedding_dim = getattr(getattr(self.model, "embedding", None), "out_features", 256)
        print(embedding_dim)

        # Initialize AAMSoftmax criterion
        self.criterion = AAMSoftmaxLoss(
            embedding_dim=embedding_dim,
            num_speakers=num_speakers,
            margin=margin,
            scale=scale
        ).to(self.args.device)

        # For storing embeddings during evaluation
        self.pairs_embeddings1 = []
        self.pairs_embeddings2 = []
        self.pairs_labels = []

        # Log criterion parameters to wandb
        wandb.config.update({
            "embedding_dim": embedding_dim,
            "aam_margin": margin,
            "aam_scale": scale,
            "num_speakers": num_speakers
        })

    def create_optimizer(self):
        """Create the optimizer and make it update the loss's class weights as well.

        The base class only collects the parameters of ``self.model``. The AAM-softmax
        class weights live in ``self.criterion``; unless they are added here they keep
        their random initial values for the whole run.
        """
        super().create_optimizer()
        optimizer = self.optimizer

        already_optimized = {
            id(param) for group in optimizer.param_groups for param in group["params"]
        }
        criterion_params = [
            param for param in self.criterion.parameters()
            if param.requires_grad and id(param) not in already_optimized
        ]
        if criterion_params:
            # weight_decay is given explicitly so the group follows the training
            # arguments rather than the optimizer class's own default.
            optimizer.add_param_group({
                "params": criterion_params,
                "weight_decay": self.args.weight_decay,
            })
        return optimizer

    def get_train_dataloader(self):
        """Create the training dataloader (shuffled, collated with ``data_collator``)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            # Worker count comes from the training arguments instead of a literal.
            num_workers=self.args.dataloader_num_workers,
            pin_memory=False
        )

    def get_eval_dataloader(self, eval_dataset=None):
        """Create the evaluation dataloader (not shuffled, collated with ``data_collator``)."""
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        return DataLoader(
            eval_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator,
            # Worker count comes from the training arguments instead of a literal.
            num_workers=self.args.dataloader_num_workers,
            pin_memory=False
        )

    def scheduling(self, total_training_epoch, current_epoch, threshold=0.3):
        """Calculate the alpha value for the current epoch.

        Alpha grows linearly from 0 (first epoch) to 1 (last epoch) and is clamped to
        ``[threshold, 1 - threshold]``. With a single training epoch ``threshold`` is
        returned.
        """
        if total_training_epoch <= 1:
            return threshold
        alpha = (current_epoch - 1) / (total_training_epoch - 1)
        return min(max(alpha, threshold), 1 - threshold)

    def training_step(self, model, inputs, num_items=None):
        """Update the model's ``alpha`` attribute (if it has one), then run the regular step."""
        # Get current epoch as integer (HF stores fractional)
        current_epoch = int(self.state.epoch) + 1
        new_alpha = self.scheduling(self.total_epochs, current_epoch)

        # Safely update alpha depending on model wrapping
        if hasattr(model, 'module') and hasattr(model.module, 'alpha'):
            model.module.alpha = new_alpha
            alpha_value = model.module.alpha
        elif hasattr(model, 'alpha'):
            model.alpha = new_alpha
            alpha_value = model.alpha
        else:
            alpha_value = None  # model has no alpha; nothing to schedule

        # Print alpha update only at logging steps
        if alpha_value is not None and self.state.global_step % self.args.logging_steps == 0:
            self.log({"alpha": new_alpha})
            print(f"🔁 Epoch {current_epoch}: Alpha set to {alpha_value:.4f}")

            # Also log to wandb
            wandb.log({"alpha": new_alpha})

        # Let Trainer handle rest (loss computation, backprop, etc.). The third argument
        # is only forwarded when the caller supplied it, so parent implementations that
        # take just (model, inputs) keep working.
        if num_items is None:
            return super().training_step(model, inputs)
        return super().training_step(model, inputs, num_items)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute AAMSoftmax loss for the speaker embeddings.

        Returns ``loss`` or, when ``return_outputs`` is true, ``(loss, outputs)`` where
        ``outputs`` holds the loss, the margin logits and the embeddings. ``loss`` is
        ``None`` for batches without ``speaker_labels``. Batches that carry a second
        utterance while the model is in eval mode also get their pair embeddings stored
        for the EER computation.
        """
        # Extract inputs
        input_values = inputs.get('input_values')
        labels = inputs.get('speaker_labels')

        device = next(model.parameters()).device
        input_values = input_values.to(device)
        labels = labels.to(device) if labels is not None else None

        # Handle evaluation inputs with pairs for EER computation
        is_eval_with_pairs = False
        if not model.training and inputs.get('input_values2') is not None:
            is_eval_with_pairs = True
            input_values2 = inputs.get('input_values2').to(device)
            pair_labels = inputs.get('pair_labels').to(device)

        # Forward pass to get speaker embeddings
        embeddings = model(input_values)

        # Handle evaluation with pairs for EER
        if is_eval_with_pairs:
            # Get embeddings for second utterance in pairs
            embeddings2 = model(input_values2)

            # Store pairs for EER calculation
            self.pairs_embeddings1.append(embeddings.detach().cpu())
            self.pairs_embeddings2.append(embeddings2.detach().cpu())
            self.pairs_labels.append(pair_labels.detach().cpu())

        # Use AAMSoftmax loss for training
        if labels is not None:
            loss, outputs = self.criterion(embeddings, labels)

            # Log loss to wandb during training
            if self.model.training and self.state.global_step % self.args.logging_steps == 0:
                wandb.log({"train/aam_loss": loss.item()})
        else:
            loss = None
            outputs = None
        torch.cuda.empty_cache()
        if return_outputs:
            return loss, {"loss": loss, "logits": outputs, "embeddings": embeddings}
        else:
            return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run the regular prediction step and store the pair embeddings of the batch."""
        stored_before = len(self.pairs_labels)

        # Call the parent class method to get regular outputs
        outputs = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

        # compute_loss may already have stored this batch during the parent call;
        # storing it again would count every pair twice in the EER.
        already_stored = len(self.pairs_labels) > stored_before
        # Batches without a second utterance have nothing to contribute (indexing them
        # directly would raise KeyError).
        has_pairs = (
            inputs.get("input_values2") is not None
            and inputs.get("pair_labels") is not None
        )

        if not prediction_loss_only and has_pairs and not already_stored:
            with torch.no_grad():
                device = next(model.parameters()).device
                embeddings1 = model(inputs["input_values"].to(device))
                embeddings2 = model(inputs["input_values2"].to(device))

                # Store embeddings and labels
                self.pairs_embeddings1.append(embeddings1.detach().cpu())
                self.pairs_embeddings2.append(embeddings2.detach().cpu())
                self.pairs_labels.append(inputs["pair_labels"].detach().cpu())

        return outputs

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Run evaluation, then compute and log the EER from the collected pairs."""
        # Reset storage for pairs
        self.pairs_embeddings1 = []
        self.pairs_embeddings2 = []
        self.pairs_labels = []

        # Run standard evaluation
        metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

        # Calculate EER if we have collected pairs
        if len(self.pairs_embeddings1) > 0:
            # Prepare data for compute metrics function
            embeddings1 = torch.cat(self.pairs_embeddings1, dim=0).numpy()
            embeddings2 = torch.cat(self.pairs_embeddings2, dim=0).numpy()
            pair_labels = torch.cat(self.pairs_labels, dim=0).numpy()

            # Create a container class to hold the embeddings
            class EmbeddingPairs:
                def __init__(self, embeddings1, embeddings2, labels):
                    self.embeddings1 = embeddings1
                    self.embeddings2 = embeddings2
                    self.labels = labels

            eval_pairs = EmbeddingPairs(embeddings1, embeddings2, pair_labels)

            # Compute EER metrics
            eer_metrics = compute_speaker_metrics(eval_pairs)

            # Add EER metrics to the overall metrics
            for key, value in eer_metrics.items():
                if key not in metrics:
                    metrics[f"{metric_key_prefix}_{key}"] = value

            # Log to wandb with correct prefix
            wandb_metrics = {
                f"{metric_key_prefix}/{key}": value
                for key, value in metrics.items()
                if key.startswith(metric_key_prefix)
            }
            wandb.log(wandb_metrics)

            print(f"\n{metric_key_prefix.capitalize()} EER: {metrics.get(f'{metric_key_prefix}_eer', 0):.4f}")

            # Log embedding visualizations to wandb (t-SNE of random subset)
            if len(embeddings1) > 100:
                try:
                    from sklearn.manifold import TSNE
                    # Sample a subset for visualization (for efficiency)
                    max_samples = min(500, len(embeddings1))
                    indices = np.random.choice(len(embeddings1), max_samples, replace=False)

                    # Apply t-SNE
                    tsne = TSNE(n_components=2, random_state=42)
                    embeddings_combined = np.vstack([embeddings1[indices], embeddings2[indices]])
                    embeddings_2d = tsne.fit_transform(embeddings_combined)

                    # Split back into two sets
                    n_samples = len(indices)
                    embeddings1_2d = embeddings_2d[:n_samples]
                    embeddings2_2d = embeddings_2d[n_samples:]

                    # Create scatter plot data
                    data = []
                    for i in range(n_samples):
                        data.append([
                            embeddings1_2d[i, 0], embeddings1_2d[i, 1],
                            "Embedding 1", int(pair_labels[indices[i]])
                        ])
                        data.append([
                            embeddings2_2d[i, 0], embeddings2_2d[i, 1],
                            "Embedding 2", int(pair_labels[indices[i]])
                        ])

                    # Log to wandb
                    wandb.log({
                        f"{metric_key_prefix}/embeddings_tsne": wandb.Table(
                            data=data,
                            columns=["x", "y", "embedding_type", "same_speaker"]
                        )
                    })
                except ImportError:
                    print("sklearn not available for t-SNE visualization")
                except Exception as e:
                    # Visualization is best-effort; never fail the evaluation over it.
                    print(f"Error creating t-SNE visualization: {e}")

        return metrics

In [None]:
from transformers import TrainingArguments
# Define training arguments.
# NOTE(review): output_dir is named after a different model ("hubert_ecapa_tdnn") although
# this notebook trains a Res2Net; the name is kept because checkpoints are written there.
# NOTE(review): eval_strategy is "steps" but no eval_steps is given here - confirm which
# interval the installed transformers version falls back to.
training_args = TrainingArguments(
    output_dir="./hubert_ecapa_tdnn",
    per_device_train_batch_size=8,
    eval_strategy="steps",
    save_strategy="epoch",
    logging_steps=400,
    per_device_eval_batch_size=8,
    dataloader_num_workers= 4,
    learning_rate=3e-4,
    gradient_accumulation_steps=4,
    save_total_limit=2,
    num_train_epochs=5,
    report_to=["wandb"],  # Enable logging to wandb
    metric_for_best_model="eval_eer",
    greater_is_better=False,  # Lower EER is better
    fp16=True,
)

In [None]:
# Initialize trainer
trainer = SpeakerVerificationTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=combined_collator,
    compute_metrics=compute_speaker_metrics,
    total_epochs=int(training_args.num_train_epochs),
    # One class per speaker actually present in the training list. A hard-coded count
    # smaller than the real one makes the loss index out of range; a larger one adds
    # class weights that no label ever selects.
    num_speakers=len(train_dataset.label_map),
    margin=0.2,
    scale=30
)

# Start training
trainer.train()

# Save the model held by the trainer once training has finished and upload it to wandb.
# NOTE(review): the training arguments do not enable load_best_model_at_end, so despite
# the folder name this is presumably the final checkpoint rather than the best one.
trainer.save_model(training_args.output_dir + "/best_model")
wandb.save(training_args.output_dir + "/best_model/*")

# Finish the wandb run
wandb.finish()