In [1]:
# !pip install -r requiresment.txt

In [2]:
!pip install transformers[torch]

[0m

In [3]:
import math, torch, torchaudio
import torch.nn as nn
import torch.nn.functional as F

# ECAPA - TDNN

## SE Module

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SEModule(nn.Module):
    def __init__(self, channels, bottleneck=128):
        """
        Squeeze-Excitation Module for channel-wise attention
        
        Args:
            channels: Number of input channels
            bottleneck: Dimension of the bottleneck representation
        """
        super(SEModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        se = self.se(x)
        return x * se

## Res2Block

In [5]:
class Res2Block(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=1, scale=8):
        """
        Res2Net Block for multi-scale feature extraction
        
        Args:
            channels: Number of input/output channels
            kernel_size: Size of the convolutional kernel
            dilation: Dilation rate for the convolutions
            scale: Number of scales for feature extraction
        """
        super(Res2Block, self).__init__()
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if channels % scale == 0 else scale - 1

        self.convs = nn.ModuleList()
        for i in range(self.nums):
            self.convs.append(
                nn.Conv1d(
                    self.width,
                    self.width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=dilation * (kernel_size - 1) // 2,
                )
            )

        self.relu = nn.ReLU()

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.relu(self.convs[i](sp))
            out.append(sp)
        if self.scale - self.nums > 0:
            out.append(spx[self.nums])
        out = torch.cat(out, dim=1)
        return out

## SE-Res2Block

In [6]:
class SERes2Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, scale=8):
        """
        SE-Res2Block: Combines Res2Net with Squeeze-Excitation
        
        Args:
            in_channels: Number of input channels
            out_channels: Number of output channels
            kernel_size: Size of the convolutional kernel
            dilation: Dilation rate for the convolutions
            scale: Number of scales for Res2Net
        """
        super(SERes2Block, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.res2block = Res2Block(out_channels, kernel_size, dilation, scale)
        self.se = SEModule(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.res2block(x)
        x = self.se(x)
        x = self.conv2(x)
        return x + residual

## Attentive Stats Pool

In [7]:
class AttentiveStatsPooling(nn.Module):
    def __init__(self, in_dim, bottleneck_dim=128):
        """
        Attentive Statistics Pooling
        
        Args:
            in_dim: Number of input channels
            bottleneck_dim: Dimension of the bottleneck in the attention mechanism
        """
        super(AttentiveStatsPooling, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1),
            nn.Softmax(dim=2),
        )

    def forward(self, x):
        # x is (batch, channels, time)
        attention_weights = self.attention(x)
        
        # Weighted mean
        mean = torch.sum(x * attention_weights, dim=2)
        
        # Weighted standard deviation
        var = torch.sum(x**2 * attention_weights, dim=2) - mean**2
        std = torch.sqrt(var.clamp(min=1e-5))
        
        # Concatenate mean and standard deviation
        pooled = torch.cat([mean, std], dim=1)
        return pooled

## ECAPA_TDNN

In [8]:
class ECAPA_TDNN(nn.Module):
    def __init__(
        self,
        input_dim=80,         # Input feature dimension (e.g., 80 filterbank features)
        channels=512,         # Number of channels in the convolutional blocks
        embedding_dim=192,    # Final embedding dimension
        res2net_scale=8,      # Scale parameter for Res2Net blocks
    ):
        """
        ECAPA-TDNN model for speaker verification
        
        Args:
            input_dim: Dimension of the input features
            channels: Number of channels in the convolutional blocks
            embedding_dim: Dimension of the final speaker embedding
            res2net_scale: Scale parameter for Res2Net blocks
        """
        super(ECAPA_TDNN, self).__init__()
        
        # Initial Conv1D + ReLU + BN
        self.conv1 = nn.Conv1d(input_dim, channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(channels)
        
        # Three SE-Res2Blocks with different dilations (2, 3, 4)
        self.se_res2block1 = SERes2Block(
            channels, channels, kernel_size=3, dilation=2, scale=res2net_scale
        )
        self.se_res2block2 = SERes2Block(
            channels, channels, kernel_size=3, dilation=3, scale=res2net_scale
        )
        self.se_res2block3 = SERes2Block(
            channels, channels, kernel_size=3, dilation=4, scale=res2net_scale
        )
        
        # Multi-Layer Feature Aggregation + Conv1D
        self.conv_agg = nn.Conv1d(3 * channels, 1536, kernel_size=1)
        self.bn_agg = nn.BatchNorm1d(1536)
        
        # Attentive Statistics Pooling + BN
        self.asp = AttentiveStatsPooling(1536)
        self.bn_asp = nn.BatchNorm1d(3072)  # 1536*2 because we concatenate mean and std
        
        # Final FC layer + BN for embedding
        self.fc = nn.Linear(3072, embedding_dim)
        self.bn_emb = nn.BatchNorm1d(embedding_dim)

    def forward(self, input_values=None, labels=None, attention_mask=None, **kwargs):
        """
        Forward pass of the ECAPA-TDNN model
        
        Args:
            x: Input tensor of shape (batch_size, time_steps, input_dim)
               or (batch_size, input_dim, time_steps)
        
        Returns:
            Speaker embedding of shape (batch_size, embedding_dim)
        """
        # Check input shape and transpose if needed
        x = input_values
        if x.size(1) == self.conv1.in_channels:
            # Input is already in the format (batch, channels, time)
            pass
        else:
            # Input is in the format (batch, time, channels)
            x = x.transpose(1, 2)
        
        # Initial Conv1D + ReLU + BN
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)
        
        # Three SE-Res2Blocks
        x1 = self.se_res2block1(x)
        x2 = self.se_res2block2(x1)
        x3 = self.se_res2block3(x2)
        
        # Multi-Layer Feature Aggregation + Conv1D + ReLU + BN
        x = torch.cat([x1, x2, x3], dim=1)
        x = self.conv_agg(x)
        x = self.relu(x)
        x = self.bn_agg(x)
        
        # Attentive Statistics Pooling + BN
        x = self.asp(x)
        x = self.bn_asp(x)
        
        # Final FC layer + BN for embedding
        x = self.fc(x)
        x = self.bn_emb(x)
        
        return x

### AAMSoftmaxLoss

In [9]:
class AAMSoftmaxLoss(nn.Module):
    def __init__(self, embedding_dim=192, num_speakers=815, margin=0.2, scale=30):
        """
        Additive Angular Margin Softmax Loss
        
        Args:
            embedding_dim: Dimension of the embeddings
            num_speakers: Number of speakers in the training set
            margin: Angular margin penalty
            scale: Scale factor for the cosine values
        """
        super(AAMSoftmaxLoss, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_speakers = num_speakers
        self.margin = margin
        self.scale = scale
        
        # Weight for the speaker classification
        self.weight = nn.Parameter(torch.FloatTensor(num_speakers, embedding_dim))
        nn.init.xavier_normal_(self.weight, gain=1)
        
        # Pre-compute constants for efficiency
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin
        
    def forward(self, embeddings, labels):
        # Input validation
        # Ensure labels are proper long integers and within range
        labels = labels.long()
        if torch.any(labels < 0) or torch.any(labels >= self.num_speakers):
            raise ValueError(f"Labels must be in range [0, {self.num_speakers-1}], got: min={labels.min().item()}, max={labels.max().item()}")
        
        # Normalize embeddings and weights
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)
        weight_norm = F.normalize(self.weight, p=2, dim=1)
        
        # Compute cosine similarity
        cosine = F.linear(embeddings_norm, weight_norm)
        
        # Add angular margin penalty
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        
        # Apply one-hot encoding for the target labels
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1)
        
        # Apply margin to the target classes only
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        
        # Scale the output
        output = output * self.scale
        
        # Cross entropy loss - ensure labels are in the expected format
        loss = F.cross_entropy(output, labels)
        
        return loss, output

## Datasets & Data Collator

In [10]:
!pip install --no-cache-dir webrtcvad
!pip show webrtcvad


[0mName: webrtcvad
Version: 2.0.10
Summary: Python interface to the Google WebRTC Voice Activity Detector (VAD)
Home-page: https://github.com/wiseman/py-webrtcvad
Author: John Wiseman
Author-email: jjwiseman@gmail.com
License: MIT
Location: /usr/local/lib/python3.10/dist-packages
Requires: 
Required-by: 


In [11]:
import webrtcvad
def remove_silence(waveform, sample_rate=16000, frame_duration_ms=30):
    vad = webrtcvad.Vad(2)  # Moderate aggressiveness (0-3)
    waveform_np = waveform.squeeze().numpy()
    waveform_int16 = (waveform_np * 32767).astype(np.int16)
    frame_length = int(sample_rate * frame_duration_ms / 1000)
    frames = [waveform_int16[i:i+frame_length] for i in range(0, len(waveform_int16), frame_length)]
    
    voiced_frames = []
    for frame in frames:
        if len(frame) == frame_length and vad.is_speech(frame.tobytes(), sample_rate):
            voiced_frames.append(frame)
    
    if voiced_frames:
        voiced_waveform = np.concatenate(voiced_frames).astype(np.float32) / 32767
        return torch.tensor(voiced_waveform, dtype=torch.float32).unsqueeze(0)
    return waveform

In [12]:
import os
import numpy as np
from torch.utils.data import Dataset
import librosa
import torchaudio
import torchaudio.transforms as T


class VSASV_Train(Dataset):
    def __init__(self, train_path, root_dir, sr=16000, duration=3):
        self.root_dir = root_dir
        self.filepaths = []
        self.labels = []
        self.meta = []
        self.sr = sr
        self.duration = duration
        self.max_length = sr * duration
        self.mfbe_transform = T.MelSpectrogram(
            sample_rate=16000,
            n_fft=400,
            win_length=400,
            hop_length=160,
            n_mels=80,
            power=1.0  # dùng năng lượng
        )
        self.target_frames = int(self.duration * self.sr / 160)  # hop_length = 160

        # Read file metadata
        with open(train_path, 'r') as f:
            for line in f:
                speaker_id, path, _ = line.strip().split()
                if "spoof" in path:
                    continue
                audio_path = os.path.join(self.root_dir, path)
                if os.path.exists(audio_path):
                    self.filepaths.append(audio_path)
                    self.labels.append(speaker_id)
                else:
                    print(audio_path)

        self.label_map = {label: idx for idx, label in enumerate(sorted(set(self.labels)))}
        self.labels = [self.label_map[label] for label in self.labels]


    def __len__(self):
        return len(self.filepaths)
        
    def __getitem__(self, idx):
        # Load the waveform
        wav_path = self.filepaths[idx]
        waveform, sample_rate = torchaudio.load(wav_path)
        waveform = remove_silence(waveform)
        
        if sample_rate != self.sr:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)
            waveform = resampler(waveform)
        #waveform = waveform[0]  # Convert [1, N] to [N]

        mfbe = self.mfbe_transform(waveform).squeeze(0)
        if mfbe.shape[-1] < self.target_frames:
            mfbe = torch.nn.functional.pad(mfbe, (0, self.target_frames - mfbe.shape[-1]))
        else:
            mfbe = mfbe[:, :self.target_frames]        

        # Get label 
        label = self.labels[idx]

        # Return as dictionary to match the format expected by the trainer
        return {
            'input_values': mfbe,
            'speaker_labels': torch.tensor(label, dtype=torch.long)
        }

In [13]:
from torch.nn.utils.rnn import pad_sequence

class DataCollatorVietnamCeleb:
    def __call__(self, batch):
        input_values = [item['input_values'] for item in batch]  # shape: [80, T]
        labels = [item['speaker_labels'] for item in batch]  # adjust to your label format

        # Pad MFCCs to the same time length (dim=2)
        # Transpose to [T, 80] to pad along time dimension
        input_values = [x.T for x in input_values]  # [T, 80]
        input_padded = pad_sequence(input_values, batch_first=True)  # [B, T_max, 80]
        input_padded = input_padded.transpose(1, 2)  # Back to [B, 80, T_max]

        labels = torch.tensor(labels)  # or handle text/tokenized labels accordingly

        return {
            'input_values': input_padded,
            'speaker_labels': labels
        }

In [14]:
from torch.utils.data import  DataLoader
import torch
train_path = 'train.txt'  # Folder lead to the Train Path
root_dir = 'data'  # The folder to contain the audio file
train_dataset = VSASV_Train(train_path, root_dir)
# Data Collator
train_collator = DataCollatorVietnamCeleb()

In [15]:
len(train_dataset)

292005

In [21]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset
import torchaudio.transforms as T


class VSASV_Validation(Dataset):
    def __init__(self, val_path, root_dir, sr=16000, duration=10):
        """
        Dataset for speaker verification validation using MFBE (Mel-Frequency Band Energies)
        """
        self.root_dir = root_dir
        self.sr = sr
        self.duration = duration
        self.max_length = sr * duration

        # MelSpectrogram to compute MFBE
        self.mfbe_transform = T.MelSpectrogram(
            sample_rate=self.sr,
            n_fft=512,
            win_length=400,
            hop_length=160,
            n_mels=80,
            power=1.0  # Use energy instead of power^2 (which is default)
        )

        self.pairs = []
        self.labels = []

        # Read val pairs
        with open(val_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 3:
                    utt_path1, utt_path2, label = parts
                    path1 = os.path.join(self.root_dir, utt_path1)
                    path2 = os.path.join(self.root_dir, utt_path2)
                    if os.path.exists(path1) and os.path.exists(path2):
                        self.pairs.append((path1, path2))
                        self.labels.append(int(label))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        path1, path2 = self.pairs[idx]
        label = self.labels[idx]

        mfbe1 = self._load_mfbe(path1)
        mfbe2 = self._load_mfbe(path2)

        return {
            "input_values": mfbe1,
            "input_values2": mfbe2,
            "pair_labels": torch.tensor(label, dtype=torch.long)
        }

    def _load_mfbe(self, path):
        waveform, sr = torchaudio.load(path)
        if sr != self.sr:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=self.sr)
        waveform = waveform[0]  # [1, T] → [T]
        waveform = remove_silence(waveform)
        mfbe = self.mfbe_transform(waveform).squeeze(0)  # [1, 80, T] → [80, T]
        return mfbe



In [22]:
# Data Collator for Validation
class ValidationDataCollator:
    def __call__(self, batch):
        input_values = [item['input_values'].squeeze(-1).T for item in batch]   # [T, 80]
        input_values2 = [item['input_values2'].squeeze(-1).T for item in batch] # [T, 80]
        pair_labels = torch.stack([item['pair_labels'] for item in batch])  # assume fixed-size

        # Pad both input sets
        input_values_padded = pad_sequence(input_values, batch_first=True)   # [B, T_max, 80]
        input_values2_padded = pad_sequence(input_values2, batch_first=True) # [B, T_max, 80]

        # Transpose back to [B, 80, T_max]
        input_values_padded = input_values_padded.transpose(1, 2)
        input_values2_padded = input_values2_padded.transpose(1, 2)

        return {
            'input_values': input_values_padded,
            'input_values2': input_values2_padded,
            'pair_labels': pair_labels
        }

In [23]:
import random
from torch.utils.data import Subset
random.seed(42)
# Validation data
val_path = "val_pairs.txt"
val_dataset = VSASV_Validation(val_path, root_dir)
val_collator = ValidationDataCollator()

In [24]:
len(val_dataset)

50000

In [25]:
from torch.utils.data import DataLoader

val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=val_collator)

batch = next(iter(val_loader))
for k, v in batch.items():
    print(f"{k}: shape = {v.shape}")


input_values: shape = torch.Size([8, 80, 400])
input_values2: shape = torch.Size([8, 80, 400])
pair_labels: shape = torch.Size([8])


In [26]:
# Combined Data Collator to put both train and val Collator in one class
class CombinedDataCollator:
    def __init__(self, train_collator, val_collator):
        self.train_collator = train_collator
        self.val_collator = val_collator
        
    def __call__(self, batch):
        # Check if this is validation data by looking for input_values2
        if isinstance(batch[0], dict) and 'input_values2' in batch[0]:
            return self.val_collator(batch)
        else:
            return self.train_collator(batch)
combined_collator = CombinedDataCollator(train_collator, val_collator)

### WanDB setup

In [27]:
!pip install safetensors


[0m

In [28]:
from safetensors.torch import load_file
# Log model architecture to wandb
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ECAPA_TDNN().to(device)

In [29]:
import wandb
import os
import torch
from safetensors.torch import load_file

os.environ["WANDB_KEY"] = "4b8af864ea6d5ec9af172b42a4c40e4444e20cf7"


# Initialize wandb
wandb.login(key=os.getenv("WANDB_KEY"))
wandb.init(
    project="ecapa-tdnn1",
    name="ecapa-tdnn",
    config={
        "learning_rate": 3e-5,
        "architecture": "ecapa-tdnn",
        "dataset": "vsasv",
        "epochs": 50,
    }
)

wandb.watch(model, log="all")

# Define batch size and other hyperparameters
 


# Update wandb config with additional hyperparameters
wandb.config.update({
    "batch_size": 256,
    "total_epochs": 50,
    "aam_margin": 0.2,
    "aam_scale": 30,
    "num_speakers": 815
})

# Load the model and move to device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = ECAPA_HuBERT().to(device)
# state_dict = load_file("/kaggle/input/hubert_ecapa_v1/transformers/default/2/model_v2.safetensors")
# model.load_state_dict(state_dict)

# # Log model architecture to wandb
# wandb.watch(model, log="all")

# # Define batch size and other hyperparameters
# batch_size = 8  
# total_epochs = 5

# # Update wandb config with additional hyperparameters
# wandb.config.update({
#     "batch_size": batch_size,
#     "total_epochs": total_epochs,
#     "aam_margin": 0.2,
#     "aam_scale": 30,
#     "num_speakers": 1000
# })

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhddat2k4[0m ([33mhddat2k4-uit[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [30]:
def compute_eer(fnr, fpr, scores=None):
    """
    Compute Equal Error Rate (EER) from false negative and false positive rates.
    Returns: (eer, threshold)
    """
    # Make sure fnr and fpr are numpy arrays
    fnr = np.array(fnr)
    fpr = np.array(fpr)
    
    # In case arrays are empty
    if len(fnr) == 0 or len(fpr) == 0:
        print("WARNING: Empty FNR or FPR arrays")
        return 0.5, 0.0  # Return default values
    
    # Calculate difference between FNR and FPR
    diff = fnr - fpr
    
    # Find where the difference changes sign
    # If diff changes sign, find the crossing point
    if np.any(diff >= 0) and np.any(diff <= 0):
        # Find indices where diff changes sign
        positive_indices = np.flatnonzero(diff >= 0)
        negative_indices = np.flatnonzero(diff <= 0)
        
        if len(positive_indices) > 0 and len(negative_indices) > 0:
            # Get the boundary indices
            idx1 = positive_indices[0]
            idx2 = negative_indices[-1]
            
            # Check if indices are out of bounds
            if idx1 >= len(fnr) or idx2 >= len(fnr):
                print("WARNING: Index out of bounds")
                # Find closest points
                abs_diff = np.abs(fnr - fpr)
                min_idx = np.argmin(abs_diff)
                eer = (fnr[min_idx] + fpr[min_idx]) / 2
                threshold = scores[min_idx] if scores is not None else min_idx
                return eer, threshold
            
            # Linear interpolation to find the EER
            if fnr[idx1] == fpr[idx1]:
                # Exactly equal at this point
                eer = fnr[idx1]
                threshold = scores[idx1] if scores is not None else idx1
            else:
                # Interpolate between idx1 and idx2
                x = [fpr[idx2], fpr[idx1]]
                y = [fnr[idx2], fnr[idx1]]
                eer = np.mean(y)  # Approximate EER
                
                # If scores are provided, interpolate threshold
                if scores is not None and idx1 < len(scores) and idx2 < len(scores):
                    threshold = (scores[idx1] + scores[idx2]) / 2
                else:
                    threshold = (idx1 + idx2) / 2
        else:
            # Fallback if indices are not found
            abs_diff = np.abs(fnr - fpr)
            min_idx = np.argmin(abs_diff)
            eer = (fnr[min_idx] + fpr[min_idx]) / 2
            threshold = scores[min_idx] if scores is not None else min_idx
    else:
        # Fallback if no sign change - find the closest points
        abs_diff = np.abs(fnr - fpr)
        min_idx = np.argmin(abs_diff)
        eer = (fnr[min_idx] + fpr[min_idx]) / 2
        threshold = scores[min_idx] if scores is not None else min_idx
    
    return eer, threshold

In [31]:
def compute_speaker_metrics(eval_pred):
    """Compute EER metrics for speaker verification."""
    # Extract embeddings and labels
    embeddings1 = eval_pred.embeddings1
    embeddings2 = eval_pred.embeddings2
    pair_labels = eval_pred.labels
    
    # Compute similarity scores
    similarity_scores = np.array([
        np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2) + 1e-10)
        for e1, e2 in zip(embeddings1, embeddings2)
    ])
    
    # Compute FPR and FNR
    thresholds = np.sort(similarity_scores)
    fpr = np.zeros(len(thresholds))
    fnr = np.zeros(len(thresholds))
    
    for i, threshold in enumerate(thresholds):
        # Predictions based on threshold
        pred = (similarity_scores >= threshold).astype(int)
        
        # True positives, false positives, true negatives, false negatives
        tp = np.sum((pred == 1) & (pair_labels == 1))
        fp = np.sum((pred == 1) & (pair_labels == 0))
        tn = np.sum((pred == 0) & (pair_labels == 0))
        fn = np.sum((pred == 0) & (pair_labels == 1))
        
        # FPR and FNR
        fpr[i] = fp / (fp + tn) if (fp + tn) > 0 else 0
        fnr[i] = fn / (fn + tp) if (fn + tp) > 0 else 0
    
    # Calculate EER
    eer, threshold = compute_eer(fnr, fpr, similarity_scores)
    result = {
        "eer": eer,
        "eer_threshold": threshold
    }
    # Log EER to wandb directly
    wandb.log({"eer": eer})
    
    # Create and log DET curve to wandb
    if len(fpr) > 10:  # Only log if we have enough points
        
        # Log histogram of similarity scores
        try:
            wandb.log({
                "similarity_scores": wandb.Histogram(similarity_scores),
                "same_speaker_scores": wandb.Histogram(similarity_scores[pair_labels == 1]),
                "diff_speaker_scores": wandb.Histogram(similarity_scores[pair_labels == 0])
            })
        except Exception as e:
            print(f"Error logging histograms: {e}")
    
    return result

In [32]:
from transformers import Trainer
import torch
import numpy as np

# Custom Trainer implementation for speaker verification
class SpeakerVerificationTrainer(Trainer):
    def __init__(self, *args, total_epochs=10, margin=0.2, scale=30, num_speakers=815, **kwargs):
        super().__init__(*args, **kwargs)
        
        self.total_epochs = total_epochs
        self.margin = margin
        self.scale = scale
        self.num_speakers = num_speakers
        embedding_dim = 192  # This gets the embedding dimension
        print(embedding_dim)
        
        # Initialize AAMSoftmax criterion
        self.criterion = AAMSoftmaxLoss(
            embedding_dim=embedding_dim,  # Get embedding dim from model
            num_speakers=num_speakers,
            margin=margin,
            scale=scale
        ).to(self.args.device)
        
        # For storing embeddings during evaluation
        self.pairs_embeddings1 = []
        self.pairs_embeddings2 = []
        self.pairs_labels = []
        
        # Log criterion parameters to wandb
        wandb.config.update({
            "embedding_dim": embedding_dim,
            "aam_margin": margin,
            "aam_scale": scale,
            "num_speakers": num_speakers
        })

    def get_train_dataloader(self):
        """Create a working dataloader for training"""
        # Create a simple dataloader that we know works
        return DataLoader(
            self.train_dataset, 
            batch_size=self.args.train_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            num_workers=4,  # Critical: use single process
            pin_memory=True,
            drop_last=True
        )
    
    def get_eval_dataloader(self, eval_dataset=None):
        """Create a working dataloader for evaluation"""
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        return DataLoader(
            eval_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator,
            num_workers=4,  # Critical: use single process
            pin_memory=True,
            drop_last=True
        )
        
    def scheduling(self, total_training_epoch, current_epoch, threshold=0.3):
        """Calculate the alpha value for the current epoch."""
        if total_training_epoch <= 1:
            return threshold
        alpha = (current_epoch - 1) / (total_training_epoch - 1)
        return min(max(alpha, threshold), 1 - threshold)

    def training_step(self, model, inputs, num_items=None):
        """Override training step to update alpha parameter."""
        # Get current epoch as integer (HF stores fractional)
        current_epoch = int(self.state.epoch) + 1
        new_alpha = self.scheduling(self.total_epochs, current_epoch)
     
        # Safely update alpha depending on model wrapping
        if hasattr(model, 'module') and hasattr(model.module, 'alpha'):
            model.module.alpha = new_alpha
            alpha_value = model.module.alpha
        elif hasattr(model, 'alpha'):
            model.alpha = new_alpha
            alpha_value = model.alpha
        else:
            alpha_value = None  # fallback
    
        # Print alpha update only at logging steps
        if alpha_value is not None and self.state.global_step % self.args.logging_steps == 0:
            self.log({"alpha": new_alpha})
            print(f"🔁 Epoch {current_epoch}: Alpha set to {alpha_value:.4f}")
            
            # Also log to wandb
            wandb.log({"alpha": new_alpha})
    
        # Let Trainer handle rest (loss computation, backprop, etc.)
        return super().training_step(model, inputs, num_items)
        
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute AAMSoftmax loss for the speaker embeddings."""
        # Extract inputs
        input_values = inputs.get('input_values')
        labels = inputs.get('speaker_labels')
    
        device = next(model.parameters()).device
        input_values = input_values.to(device)
        labels = labels.to(device) if labels is not None else None
        
        # Handle evaluation inputs with pairs for EER computation
        is_eval_with_pairs = False
        if not model.training and inputs.get('input_values2') is not None:
            is_eval_with_pairs = True
            input_values2 = inputs.get('input_values2').to(device)
            pair_labels = inputs.get('pair_labels').to(device)
        
        # Forward pass to get speaker embeddings
        embeddings = model(input_values)
        
        # Handle evaluation with pairs for EER
        if is_eval_with_pairs:
            # Get embeddings for second utterance in pairs
            embeddings2 = model(input_values2)
            
            # Store pairs for EER calculation
            self.pairs_embeddings1.append(embeddings.detach().cpu())
            self.pairs_embeddings2.append(embeddings2.detach().cpu())
            self.pairs_labels.append(pair_labels.detach().cpu())
        
        # Use AAMSoftmax loss for training
        if labels is not None:
            loss, outputs = self.criterion(embeddings, labels)
            
            # Log loss to wandb during training
            if self.model.training and self.state.global_step % self.args.logging_steps == 0:
                wandb.log({"train/aam_loss": loss.item()})
        else:
            loss = None
            outputs = None
        torch.cuda.empty_cache()
        if return_outputs:
            return loss, {"loss": loss, "logits": outputs, "embeddings": embeddings}
        else:
            return loss
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # Call the parent class method to get regular outputs
        outputs = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
        
        # During evaluation, collect embeddings for pairs
        if not prediction_loss_only:
            with torch.no_grad():
                # Get embeddings from model (adjust based on your model's output structure)
                device = next(model.parameters()).device
                embeddings1 = model(inputs["input_values"].to(device))
                embeddings2 = model(inputs["input_values2"].to(device))
                
                # Store embeddings and labels
                self.pairs_embeddings1.append(embeddings1.detach().cpu())
                self.pairs_embeddings2.append(embeddings2.detach().cpu())
                self.pairs_labels.append(inputs["pair_labels"].detach().cpu())
        
        return outputs    
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Override evaluate to compute EER at the end of evaluation."""
        # Reset storage for pairs
        self.pairs_embeddings1 = []
        self.pairs_embeddings2 = []
        self.pairs_labels = []
        
        # Run standard evaluation
        metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
        
        # Calculate EER if we have collected pairs
        if len(self.pairs_embeddings1) > 0:
            # Prepare data for compute metrics function
            embeddings1 = torch.cat(self.pairs_embeddings1, dim=0).numpy()
            embeddings2 = torch.cat(self.pairs_embeddings2, dim=0).numpy()
            pair_labels = torch.cat(self.pairs_labels, dim=0).numpy()
            
            # Create a container class to hold the embeddings
            class EmbeddingPairs:
                def __init__(self, embeddings1, embeddings2, labels):
                    self.embeddings1 = embeddings1
                    self.embeddings2 = embeddings2
                    self.labels = labels
            
            eval_pairs = EmbeddingPairs(embeddings1, embeddings2, pair_labels)
            
            # Compute EER metrics
            eer_metrics = compute_speaker_metrics(eval_pairs)
            
            # Add EER metrics to the overall metrics
            for key, value in eer_metrics.items():
                if key not in metrics:
                    metrics[f"{metric_key_prefix}_{key}"] = value
            
            # Log to wandb with correct prefix
            wandb_metrics = {
                f"{metric_key_prefix}/{key}": value 
                for key, value in metrics.items() 
                if key.startswith(metric_key_prefix)
            }
            wandb.log(wandb_metrics)
            
            print(f"\n{metric_key_prefix.capitalize()} EER: {metrics.get(f'{metric_key_prefix}_eer', 0):.4f}")
            
            # Log embedding visualizations to wandb (t-SNE of random subset)
            if len(embeddings1) > 100:
                try:
                    from sklearn.manifold import TSNE
                    # Sample a subset for visualization (for efficiency)
                    max_samples = min(500, len(embeddings1))
                    indices = np.random.choice(len(embeddings1), max_samples, replace=False)
                    
                    # Apply t-SNE
                    tsne = TSNE(n_components=2, random_state=42)
                    embeddings_combined = np.vstack([embeddings1[indices], embeddings2[indices]])
                    embeddings_2d = tsne.fit_transform(embeddings_combined)
                    
                    # Split back into two sets
                    n_samples = len(indices)
                    embeddings1_2d = embeddings_2d[:n_samples]
                    embeddings2_2d = embeddings_2d[n_samples:]
                    
                    # Create scatter plot data
                    data = []
                    for i in range(n_samples):
                        data.append([
                            embeddings1_2d[i, 0], embeddings1_2d[i, 1], 
                            "Embedding 1", int(pair_labels[indices[i]])
                        ])
                        data.append([
                            embeddings2_2d[i, 0], embeddings2_2d[i, 1], 
                            "Embedding 2", int(pair_labels[indices[i]])
                        ])
                    
                    # Log to wandb
                    wandb.log({
                        f"{metric_key_prefix}/embeddings_tsne": wandb.Table(
                            data=data,
                            columns=["x", "y", "embedding_type", "same_speaker"]
                        )
                    })
                except ImportError:
                    print("sklearn not available for t-SNE visualization")
                except Exception as e:
                    print(f"Error creating t-SNE visualization: {e}")
        
        return metrics
    

In [33]:
from transformers import TrainingArguments
# Define training arguments
training_args = TrainingArguments(
    output_dir="./checkpoint1",
    per_device_train_batch_size=256,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=500,
    per_device_eval_batch_size=256,
    learning_rate=3e-5,
    gradient_accumulation_steps=1,
    save_total_limit=51,
    num_train_epochs=50,
    dataloader_num_workers=4,
    report_to=["wandb"],  # Enable logging to wandb
    metric_for_best_model="eval_eer",
    greater_is_better=False,  # Lower EER is better
    fp16=True,
)

In [None]:
# Initialize trainer
trainer = SpeakerVerificationTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset, 
    eval_dataset=val_dataset,    
    data_collator=combined_collator,  
    compute_metrics=compute_speaker_metrics,
    total_epochs=int(training_args.num_train_epochs),
    num_speakers=815,  # Adjust based on your dataset
    margin=0.2,
    scale=30
)

# Start training
trainer.train()

# After training completes, save and log the best model to wandb
trainer.save_model(training_args.output_dir + "/best_model")
wandb.save(training_args.output_dir + "/best_model/*")
    

# Finish the wandb run
wandb.finish()

192


  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]


Epoch,Training Loss,Validation Loss
1,7.9974,No log
2,6.2862,No log
3,5.2815,No log
4,4.3536,No log



Eval EER: 0.2149


  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]



Eval EER: 0.2280


  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]



Eval EER: 0.2190


  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]



Eval EER: 0.2123


  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
  input_values = [torch.tensor(item['input_values']) for item in batch]  # shape: [80, T]
