In [1]:
!pip install -r requirements.txt

[0m

In [2]:
!pip install transformers[torch]

[0m

## ResNet 48 + MLCC + AAM 

### BasicBlock

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)


### StatsPooling

In [4]:
class StatsPooling(nn.Module):
    def forward(self, x):
        # x: [B, C, F, T]  (sau stage4: [B, 256, 10, T/8])
        B, C, F, T = x.size()
        x = x.view(B, C * F, T)         # [B, 2560, T/8]
        mean = x.mean(dim=2)            # [B, 2560]
        std = x.std(dim=2)              # [B, 2560]
        return torch.cat([mean, std], dim=1)  # [B, 5120]

### AAM Softmax

In [5]:
class AAMSoftmax(nn.Module):
    def __init__(self, embedding_dim=256, num_classes=1000, s=30.0, m=0.2):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, embedding_dim))
        nn.init.xavier_uniform_(self.weight)
        self.s = s
        self.m = m

    def forward(self, x, labels):
        x = F.normalize(x, dim=1)
        W = F.normalize(self.weight, dim=1)
        cosine = F.linear(x, W)  # [B, num_classes]
        theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
        target_logits = torch.cos(theta + self.m)
        one_hot = F.one_hot(labels, num_classes=cosine.size(1)).float().to(x.device)
        output = self.s * (one_hot * target_logits + (1.0 - one_hot) * cosine)
        loss = F.cross_entropy(output, labels)
        return loss, output

### ResNet-48

In [6]:
class ResNet48_ASV(nn.Module):
    def __init__(self, embedding_dim=256, num_speakers=None):
        super().__init__()
        # Conv0
        self.conv1 = nn.Conv2d(1, 96, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(96)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages
        self.layer1 = self._make_layer(96, 96, 6, stride=1)    # ResBlock-1
        self.layer2 = self._make_layer(96, 128, 8, stride=2)   # ResBlock-2
        self.layer3 = self._make_layer(128, 160, 6, stride=2)  # ResBlock-3
        self.layer4 = self._make_layer(160, 256, 3, stride=2)  # ResBlock-4

        # Pooling + FC
        self.pooling = StatsPooling()
        self.fc = nn.Linear(5120, embedding_dim)

        # Classifier (optional)
        if num_speakers:
            self.classifier = nn.Linear(embedding_dim, num_speakers)
        else:
            self.classifier = None

    def _make_layer(self, in_c, out_c, num_blocks, stride):
        layers = [BasicBlock(in_c, out_c, stride)]
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_c, out_c))
        return nn.Sequential(*layers)

    def forward(self, input_values, labels=None):
        # Backbone
        x = self.relu(self.bn1(self.conv1(input_values)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # StatsPooling
        x = self.pooling(x)

        # Dense → embedding
        embeddings = F.normalize(self.fc(x), dim=1)  # [B, 256]
        return embeddings

## Train & Validation Data Loader

In [7]:
!pip install webrtcvad

[0m

  Downloading webrtcvad-2.0.10.tar.gz (66 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/66.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.2/66.2 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h

  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: webrtcvad


  Building wheel for webrtcvad (setup.py) ... [?25l[?25hdone
  Created wheel for webrtcvad: filename=webrtcvad-2.0.10-cp311-cp311-linux_x86_64.whl size=73495 sha256=d53b693bef75f5a8e3dd786ed97de9d63f0eafb4eb29f26d2d5e22f5fce54738
  Stored in directory: /root/.cache/pip/wheels/94/65/3f/292d0b656be33d1c801831201c74b5f68f41a2ae465ff2ee2f
Successfully built webrtcvad


Installing collected packages: webrtcvad


Successfully installed webrtcvad-2.0.10


In [8]:
import webrtcvad
def remove_silence(waveform, sample_rate=16000, frame_duration_ms=30):
    vad = webrtcvad.Vad(2)  # Moderate aggressiveness (0-3)
    waveform_np = waveform.squeeze().numpy()
    waveform_int16 = (waveform_np * 32767).astype(np.int16)
    frame_length = int(sample_rate * frame_duration_ms / 1000)
    frames = [waveform_int16[i:i+frame_length] for i in range(0, len(waveform_int16), frame_length)]
    
    voiced_frames = []
    for frame in frames:
        if len(frame) == frame_length and vad.is_speech(frame.tobytes(), sample_rate):
            voiced_frames.append(frame)
    
    if voiced_frames:
        voiced_waveform = np.concatenate(voiced_frames).astype(np.float32) / 32767
        return torch.tensor(voiced_waveform, dtype=torch.float32).unsqueeze(0)
    return waveform

### Train Data Loader


In [9]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset


class VSASV_Train(Dataset):
    def __init__(self, list_file, root_dir="/kaggle/input/vsasv-train/vlsp_train/home4/vuhl/VSASV-Dataset",
                 sample_rate=16000, max_duration=2, n_mels=80, n_fft=400, hop_length=160, f_min=20, f_max=None):
        self.sample_rate = sample_rate
        self.max_samples = int(sample_rate * max_duration)
        self.root_dir = root_dir

        # MFBE transform (torchaudio MelSpectrogram)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            power=2.0,  # power=2.0 để ra năng lượng
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB(stype="power")

        # metadata
        self.samples = []
        self.speaker_map = {}
        self.speaker_to_indices = {}
        label_counter = 0

        with open(list_file, "r") as f:
            for line in f:
                spk, rel_path, label = line.strip().split()
                if label != "bonafide":   # chỉ train trên bonafide
                    continue

                full_path = os.path.join(self.root_dir, rel_path)

                if spk not in self.speaker_map:
                    self.speaker_map[spk] = label_counter
                    label_counter += 1

                spk_id = self.speaker_map[spk]
                self.samples.append((full_path, spk_id))

                if spk_id not in self.speaker_to_indices:
                    self.speaker_to_indices[spk_id] = []
                self.speaker_to_indices[spk_id].append(len(self.samples) - 1)

        self.speakers = list(self.speaker_to_indices.keys())

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        waveform, sr = torchaudio.load(path)   # [1, T]
    
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        waveform = waveform.squeeze(0)  # [T]
    
        # Cắt ngẫu nhiên 2s
        if waveform.shape[0] > self.max_samples:
            max_start = waveform.shape[0] - self.max_samples
            start = torch.randint(0, max_start + 1, (1,)).item()
            waveform = waveform[start:start + self.max_samples]
        else:
            pad = self.max_samples - waveform.shape[0]
            waveform = torch.nn.functional.pad(waveform, (0, pad))
    
        # MFBE transform
        mel_spec = self.mel_transform(waveform)   # [n_mels, time]
        mel_db = self.amplitude_to_db(mel_spec)   # log-scale

        return {
            "input_values": mel_db,  # [n_mels, time]
            "speaker_labels": torch.tensor(label, dtype=torch.long)
        }


### Train Data Collator

In [10]:
from torch.nn.utils.rnn import pad_sequence

class TrainDataCollator:
    def __call__(self, batch):
        # Giả sử input là [80, T] hoặc [80, T, 1], ta cần đưa về [T, 80]
        input_values = [
            item['input_values'].squeeze()  # loại bỏ chiều dư nếu có
            .transpose(0, 1)                # [80, T] → [T, 80]
            for item in batch
        ]

        # Pad theo T → [B, T_max, 80]
        input_padded = pad_sequence(input_values, batch_first=True)

        # Transpose lại về [B, 80, T_max]
        input_padded = input_padded.transpose(1, 2)

        # Thêm channel: [B, 1, 80, T_max]
        input_padded = input_padded.unsqueeze(1)

        labels = torch.tensor([item['speaker_labels'] for item in batch], dtype=torch.long)

        return {
            'input_values': input_padded,      # ✅ [B, 1, 80, T]
            'speaker_labels': labels
        }


In [11]:
train_path = "train_asv.txt"
root_dir = "data"

In [12]:
from torch.utils.data import DataLoader
train_dataset = VSASV_Train(
    list_file=train_path,
    root_dir=root_dir,
)
train_collator = TrainDataCollator()

### Validation Data Loader

In [13]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset
import torch.nn.functional as F

class VSASV_Validation(Dataset):
    def __init__(self, val_path, root_dir, sr=16000, duration=5.0):
        """
        Dataset for speaker verification validation using pre-defined utterance pairs

        Args:
            val_path: Path to validation file with pre-defined pairs
            root_dir: Root directory containing the audio files
            sr: Sample rate
            duration: Max duration in seconds
        """
        self.root_dir = root_dir
        self.sr = sr
        self.duration = duration
        self.max_samples = int(sr * duration)

        # Feature extractor
        self.feature_extractor = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sr,
            n_fft=400,
            win_length=400,
            hop_length=160,
            window_fn=torch.hamming_window,
            n_mels=80,
            f_min=20.0,
            f_max=7600.0,
            power=2.0
        )
        self.to_db = torchaudio.transforms.AmplitudeToDB(stype="power")

        self.pairs = []
        self.labels = []

        with open(val_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 3:
                    utt1, utt2, label = parts
                    path1 = os.path.join(self.root_dir, utt1)
                    path2 = os.path.join(self.root_dir, utt2)
                    if os.path.exists(path1) and os.path.exists(path2):
                        self.pairs.append((path1, path2))
                        self.labels.append(int(label))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        path1, path2 = self.pairs[idx]
        label = self.labels[idx]

        feat1 = self._load_and_process(path1)  # [1, 80, T]
        feat2 = self._load_and_process(path2)  # [1, 80, T]

        return {
            'input_values': feat1,       # [1, 80, T]
            'input_values2': feat2,      # [1, 80, T]
            'pair_labels': torch.tensor(label, dtype=torch.long)
        }

    def _load_and_process(self, path):
        waveform, sr = torchaudio.load(path)              # [1, T]
        waveform = remove_silence(waveform, sample_rate=sr)  # [1, T]

        # Pad or truncate
        if waveform.shape[1] > self.max_samples:
            waveform = waveform[:, :self.max_samples]
        else:
            pad = self.max_samples - waveform.shape[1]
            waveform = F.pad(waveform, (0, pad))

        mel = self.feature_extractor(waveform)
        mel_db = self.to_db(mel)
        mel_db = mel_db - mel_db.mean(dim=-1, keepdim=True)  # normalize
        return mel_db


### Validation Collator

In [14]:
class ValidationDataCollator:
    def __call__(self, batch):
        input_values = [
            item['input_values'].squeeze()  # loại bỏ chiều dư nếu có
            .transpose(0, 1)                # [80, T] → [T, 80]
            for item in batch
        ]
        input_values2 = [
            item['input_values2'].squeeze()  # loại bỏ chiều dư nếu có
            .transpose(0, 1)                # [80, T] → [T, 80]
            for item in batch
        ]
        pair_labels = torch.stack([item['pair_labels'] for item in batch])  # assume fixed-size

        # Pad both input sets
        input_values_padded = pad_sequence(input_values, batch_first=True)   # [B, T_max, 80]
        input_values2_padded = pad_sequence(input_values2, batch_first=True) # [B, T_max, 80]

        # Transpose back to [B, 1, 80, T_max]
        input_values_padded = input_values_padded.transpose(1, 2).unsqueeze(1)
        input_values2_padded = input_values2_padded.transpose(1, 2).unsqueeze(1)

        return {
            'input_values': input_values_padded,
            'input_values2': input_values2_padded,
            'pair_labels': pair_labels
        }

In [15]:
import random 
from torch.utils.data import Subset


val_path = "val_pairs.txt"
root_dir = "data"
val_dataset = VSASV_Validation(val_path, root_dir)
val_collator = ValidationDataCollator()


In [16]:
class CombinedDataCollator:
    def __init__(self, train_collator, val_collator):
        self.train_collator = train_collator
        self.val_collator = val_collator
        
    def __call__(self, batch):
        # Check if this is validation data by looking for input_values2
        if isinstance(batch[0], dict) and 'input_values2' in batch[0]:
            return self.val_collator(batch)
        else:
            return self.train_collator(batch)
combined_collator = CombinedDataCollator(train_collator, val_collator)

## Set Up Trainer

### Compute EER

In [17]:
def compute_eer(fnr, fpr, scores=None):
    """
    Compute Equal Error Rate (EER) from false negative and false positive rates.
    Returns: (eer, threshold)
    """
    # Make sure fnr and fpr are numpy arrays
    fnr = np.array(fnr)
    fpr = np.array(fpr)
    
    # In case arrays are empty
    if len(fnr) == 0 or len(fpr) == 0:
        print("WARNING: Empty FNR or FPR arrays")
        return 0.5, 0.0  # Return default values
    
    # Calculate difference between FNR and FPR
    diff = fnr - fpr
    
    # Find where the difference changes sign
    # If diff changes sign, find the crossing point
    if np.any(diff >= 0) and np.any(diff <= 0):
        # Find indices where diff changes sign
        positive_indices = np.flatnonzero(diff >= 0)
        negative_indices = np.flatnonzero(diff <= 0)
        
        if len(positive_indices) > 0 and len(negative_indices) > 0:
            # Get the boundary indices
            idx1 = positive_indices[0]
            idx2 = negative_indices[-1]
            
            # Check if indices are out of bounds
            if idx1 >= len(fnr) or idx2 >= len(fnr):
                print("WARNING: Index out of bounds")
                # Find closest points
                abs_diff = np.abs(fnr - fpr)
                min_idx = np.argmin(abs_diff)
                eer = (fnr[min_idx] + fpr[min_idx]) / 2
                threshold = scores[min_idx] if scores is not None else min_idx
                return eer, threshold
            
            # Linear interpolation to find the EER
            if fnr[idx1] == fpr[idx1]:
                # Exactly equal at this point
                eer = fnr[idx1]
                threshold = scores[idx1] if scores is not None else idx1
            else:
                # Interpolate between idx1 and idx2
                x = [fpr[idx2], fpr[idx1]]
                y = [fnr[idx2], fnr[idx1]]
                eer = np.mean(y)  # Approximate EER
                
                # If scores are provided, interpolate threshold
                if scores is not None and idx1 < len(scores) and idx2 < len(scores):
                    threshold = (scores[idx1] + scores[idx2]) / 2
                else:
                    threshold = (idx1 + idx2) / 2
        else:
            # Fallback if indices are not found
            abs_diff = np.abs(fnr - fpr)
            min_idx = np.argmin(abs_diff)
            eer = (fnr[min_idx] + fpr[min_idx]) / 2
            threshold = scores[min_idx] if scores is not None else min_idx
    else:
        # Fallback if no sign change - find the closest points
        abs_diff = np.abs(fnr - fpr)
        min_idx = np.argmin(abs_diff)
        eer = (fnr[min_idx] + fpr[min_idx]) / 2
        threshold = scores[min_idx] if scores is not None else min_idx
    
    return eer, threshold

In [18]:
def compute_speaker_metrics(eval_pred):
    """Compute EER metrics for speaker verification."""
    # Extract embeddings and labels
    embeddings1 = eval_pred.embeddings1
    embeddings2 = eval_pred.embeddings2
    pair_labels = eval_pred.labels
    
    # Compute similarity scores
    similarity_scores = np.array([
        np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2) + 1e-10)
        for e1, e2 in zip(embeddings1, embeddings2)
    ])
    
    # Compute FPR and FNR
    thresholds = np.sort(similarity_scores)
    fpr = np.zeros(len(thresholds))
    fnr = np.zeros(len(thresholds))
    
    for i, threshold in enumerate(thresholds):
        # Predictions based on threshold
        pred = (similarity_scores >= threshold).astype(int)
        
        # True positives, false positives, true negatives, false negatives
        tp = np.sum((pred == 1) & (pair_labels == 1))
        fp = np.sum((pred == 1) & (pair_labels == 0))
        tn = np.sum((pred == 0) & (pair_labels == 0))
        fn = np.sum((pred == 0) & (pair_labels == 1))
        
        # FPR and FNR
        fpr[i] = fp / (fp + tn) if (fp + tn) > 0 else 0
        fnr[i] = fn / (fn + tp) if (fn + tp) > 0 else 0
    
    # Calculate EER
    eer, threshold = compute_eer(fnr, fpr, similarity_scores)
    result = {
        "eer": eer,
        "eer_threshold": threshold
    }
    # Log EER to wandb directly
    wandb.log({"eer": eer})
    
    # Create and log DET curve to wandb
    if len(fpr) > 10:  # Only log if we have enough points
        
        # Log histogram of similarity scores
        try:
            wandb.log({
                "similarity_scores": wandb.Histogram(similarity_scores),
                "same_speaker_scores": wandb.Histogram(similarity_scores[pair_labels == 1]),
                "diff_speaker_scores": wandb.Histogram(similarity_scores[pair_labels == 0])
            })
        except Exception as e:
            print(f"Error logging histograms: {e}")
    
    return result

In [19]:
import random
from torch.utils.data import Sampler

class RandomSpeakerBatchSampler(Sampler):
    def __init__(self, dataset, speakers_per_batch=128, seed=42, infinite=True):
        """
        dataset: Dataset phải có dataset.speakers (list speaker_id)
                                 dataset.speaker_to_indices (dict: spk_id -> list index)
        speakers_per_batch: số speaker mỗi batch (mỗi speaker 1 file)
        seed: random seed
        infinite: nếu True thì sampler chạy vô hạn (thường dùng với max_steps)
        """
        self.dataset = dataset
        self.speakers = dataset.speakers
        self.speaker_to_indices = dataset.speaker_to_indices
        self.speakers_per_batch = speakers_per_batch
        self.infinite = infinite
        self.rng = random.Random(seed)

        assert len(self.speakers) >= speakers_per_batch, \
            f"Số speaker ({len(self.speakers)}) nhỏ hơn {speakers_per_batch}"

    def __iter__(self):
        while True:
            # chọn ngẫu nhiên 256 speaker
            batch_speakers = self.rng.sample(self.speakers, self.speakers_per_batch)
            batch_indices = []
            for spk in batch_speakers:
                # chọn ngẫu nhiên 1 utterance cho mỗi speaker
                idx = self.rng.choice(self.speaker_to_indices[spk])
                batch_indices.append(idx)
            yield batch_indices

            if not self.infinite:
                break

    def __len__(self):
        # nếu muốn tính epoch dựa trên số speaker duyệt hết một vòng
        return len(self.speakers) // self.speakers_per_batch


## WANDB Setup

In [20]:
!pip install wandb safetensors

[0m





In [21]:
from safetensors.torch import load_file
# Log model architecture to wandb
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ResNet48_ASV().to(device)

In [22]:
import wandb
import os
import torch
from safetensors.torch import load_file

In [23]:
config = {
    "model": "ResNet48_ASV_WithAAM",
    "embedding_dim": 256,
    "num_classes": 652,
    "features": "MFCC",
    "loss": "AAMSoftmax",
    
    # Optimizer & Learning
    "optimizer": "AdamW",
    "learning_rate": 3e-5,
    "gradient_accumulation_steps": 1,
    
    # Batch & Epochs
    "batch_size_per_device": 128,
    "global_batch_size": 128,
    "num_train_epochs": 50,
    
    # Evaluation & Saving
    "eval_strategy": "steps",
    "save_strategy": "steps",
    "save_total_limit": 45,
    "metric_for_best_model": "eval_eer",
    "greater_is_better": False,  # Lower EER is better
    
    # Logging
    "logging_steps": 500,
    "report_to": "wandb",
    
    # Hardware
    "fp16": True,
    "dataloader_num_workers": 4,
    
    # Misc
    "load_best_model_at_end": True
}


In [24]:
import os

os.environ["WANDB_KEY"] = "4b8af864ea6d5ec9af172b42a4c40e4444e20cf7"
wandb.login(key=os.getenv("WANDB_KEY"))


wandb.init(
    project="ResNet48-ASV-final",
    name="ResNet48-AAM-MFBE-60ksteps",
    config=config
)



[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhddat2k4[0m ([33mhddat2k4-uit[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Trainer


In [25]:
from transformers import Trainer
import torch
import numpy as np
from torch.optim.lr_scheduler import OneCycleLR

class VSASV_Trainer(Trainer):
    def __init__(self, *args, total_steps=50000, margin=0.2, scale=30, num_speakers=815, embedding_dim=256, **kwargs):
        super().__init__(*args, **kwargs)

        self.total_steps = total_steps
        self.margin = margin
        self.scale = scale
        self.num_speakers = num_speakers
        self.use_amp = self.args.fp16
        ##Initialize AAMSoftmax Loss
        self.criterion = AAMSoftmax(
            embedding_dim=embedding_dim,
            num_classes=num_speakers,
            m=margin,
            s=scale
        ).to(self.args.device)

        self.pairs_embeddings1 = []
        self.pairs_embeddings2 = []
        self.pairs_labels = []

        wandb.config.update({
            "embedding_dim": embedding_dim,
            "aam_margin": margin,
            "aam_scale": scale,
            "num_speakers": num_speakers
        })
    def create_optimizer(self):
        if self.optimizer is None:
            decay_params = []
            head_params = []

            for name, param in self.model.named_parameters():
                if not param.requires_grad:
                    continue
                if "classifier" in name or "am_softmax" in name:
                    head_params.append(param)   # AM-Softmax head
                else:
                    decay_params.append(param)  # Backbone

            optimizer_grouped_parameters = [
                {"params": decay_params, "weight_decay": 1e-5},
                {"params": head_params, "weight_decay": 1e-4},
            ]

            self.optimizer = torch.optim.AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(0.9, 0.98),
                eps=1e-8
            )
        return self.optimizer

    def create_scheduler(self, num_training_steps, optimizer=None):
        if self.lr_scheduler is None:
            def lr_lambda(step):
                if step < 10000:  # giữ nguyên trong 10k step đầu
                    return 1.0
                return 0.5 ** ((step - 10000) // 4000)
            
            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=lr_lambda
            )
        return self.lr_scheduler


    def get_margin(self, epoch):
        """Schedule margin theo mô tả paper"""
        if epoch <= 10:
            return 0.0
        elif 4 <= epoch <= 13:
            # tăng tuyến tính từ 0 → 0.3 trong 10 epoch
            progress = (epoch - 3) / 10.0
            return 0.3 * progress
        else:
            return 0.3
    def get_train_dataloader(self):
        from torch.utils.data import DataLoader
    
        # tạo batch_sampler
        batch_sampler = RandomSpeakerBatchSampler(
            self.train_dataset, 
            speakers_per_batch=self.args.train_batch_size,  # dùng train_batch_size làm số speaker/batch
            infinite=True
        )
    
        return DataLoader(
            self.train_dataset,
            batch_sampler=batch_sampler,   # thay batch_size + shuffle
            collate_fn=self.data_collator,
            num_workers=4,
            pin_memory=True
        )
    def get_eval_dataloader(self, eval_dataset=None):
        """Create a working dataloader for evaluation"""
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        return DataLoader(
            eval_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator,
            num_workers=4,  # Critical: use single process
            pin_memory=True,
            drop_last=True
        )
        

    def training_step(self, model, inputs, num_items=None):
        # Lấy epoch hiện tại
        current_epoch = int(self.state.global_step // 1000) + 1  # vì 1 epoch = 5000 steps
        new_margin = self.get_margin(current_epoch)

        # update margin trong loss
        if hasattr(self.criterion, "margin"):
            self.criterion.margin = new_margin

        # logging
        if self.state.global_step % self.args.logging_steps == 0:
            self.log({"margin": new_margin})
            wandb.log({"margin": new_margin})
            print(f"Epoch {current_epoch}: margin = {new_margin:.4f}")

        return super().training_step(model, inputs, num_items)
        
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute AAMSoftmax loss for the speaker embeddings."""
        # Extract inputs
        input_values = inputs.get('input_values')
        labels = inputs.get('speaker_labels')
    
        device = next(model.parameters()).device
        input_values = input_values.to(device)
        labels = labels.to(device) if labels is not None else None
        
        # Handle evaluation inputs with pairs for EER computation
        is_eval_with_pairs = False
        if not model.training and inputs.get('input_values2') is not None:
            is_eval_with_pairs = True
            input_values2 = inputs.get('input_values2').to(device)
            pair_labels = inputs.get('pair_labels').to(device)
        
        # Forward pass to get speaker embeddings
        embeddings = model(input_values)
        
        # Handle evaluation with pairs for EER
        if is_eval_with_pairs:
            # Get embeddings for second utterance in pairs
            embeddings2 = model(input_values2)
            
            # Store pairs for EER calculation
            self.pairs_embeddings1.append(embeddings.detach().cpu())
            self.pairs_embeddings2.append(embeddings2.detach().cpu())
            self.pairs_labels.append(pair_labels.detach().cpu())
        
        # Use AAMSoftmax loss for training
        if labels is not None:
            loss, outputs = self.criterion(embeddings, labels)
            
            # Log loss to wandb during training
            if self.model.training and self.state.global_step % self.args.logging_steps == 0:
                wandb.log({"train/aam_loss": loss.item()})
        else:
            loss = None
            outputs = None
        torch.cuda.empty_cache()
        if return_outputs:
            return loss, {"loss": loss, "logits": outputs, "embeddings": embeddings}
        else:
            return loss
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        with torch.no_grad():
            device = next(model.parameters()).device
    
            # Forward input1
            embeddings1 = model(inputs["input_values"].to(device))
            
            # Forward input2 (nếu có)
            if "input_values2" in inputs and "pair_labels" in inputs:
                embeddings2 = model(inputs["input_values2"].to(device))
    
                # Lưu embedding & nhãn để tính EER
                self.pairs_embeddings1.append(embeddings1.detach().cpu())
                self.pairs_embeddings2.append(embeddings2.detach().cpu())
                self.pairs_labels.append(inputs["pair_labels"].detach().cpu())
    
            # Nếu có label chính (ví dụ speaker_labels) thì tính loss cho HuggingFace Trainer
            if "speaker_labels" in inputs:
                outputs = model(inputs["input_values"].to(device))
                loss = self.criterion(outputs, inputs["speaker_labels"].to(device))
            else:
                loss = None
    
        return (loss, None, None)

      
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Override evaluate to compute EER at the end of evaluation."""
        # Reset storage for pairs
        self.pairs_embeddings1 = []
        self.pairs_embeddings2 = []
        self.pairs_labels = []
        
        # Run standard evaluation
        metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
        
        # Calculate EER if we have collected pairs
        if len(self.pairs_embeddings1) > 0:
            # Prepare data for compute metrics function
            embeddings1 = torch.cat(self.pairs_embeddings1, dim=0).numpy()
            embeddings2 = torch.cat(self.pairs_embeddings2, dim=0).numpy()
            pair_labels = torch.cat(self.pairs_labels, dim=0).numpy()
            
            # Create a container class to hold the embeddings
            class EmbeddingPairs:
                def __init__(self, embeddings1, embeddings2, labels):
                    self.embeddings1 = embeddings1
                    self.embeddings2 = embeddings2
                    self.labels = labels
            
            eval_pairs = EmbeddingPairs(embeddings1, embeddings2, pair_labels)
            
            # Compute EER metrics
            eer_metrics = compute_speaker_metrics(eval_pairs)
            
            # Add EER metrics to the overall metrics
            for key, value in eer_metrics.items():
                if key not in metrics:
                    metrics[f"{metric_key_prefix}_{key}"] = value
            
            # Log to wandb with correct prefix
            wandb_metrics = {
                f"{metric_key_prefix}/{key}": value 
                for key, value in metrics.items() 
                if key.startswith(metric_key_prefix)
            }
            wandb.log(wandb_metrics)
            
            print(f"\n{metric_key_prefix.capitalize()} EER: {metrics.get(f'{metric_key_prefix}_eer', 0):.4f}")
            
            # Log embedding visualizations to wandb (t-SNE of random subset)
            if len(embeddings1) > 100:
                try:
                    from sklearn.manifold import TSNE
                    # Sample a subset for visualization (for efficiency)
                    max_samples = min(500, len(embeddings1))
                    indices = np.random.choice(len(embeddings1), max_samples, replace=False)
                    
                    # Apply t-SNE
                    tsne = TSNE(n_components=2, random_state=42)
                    embeddings_combined = np.vstack([embeddings1[indices], embeddings2[indices]])
                    embeddings_2d = tsne.fit_transform(embeddings_combined)
                    
                    # Split back into two sets
                    n_samples = len(indices)
                    embeddings1_2d = embeddings_2d[:n_samples]
                    embeddings2_2d = embeddings_2d[n_samples:]
                    
                    # Create scatter plot data
                    data = []
                    for i in range(n_samples):
                        data.append([
                            embeddings1_2d[i, 0], embeddings1_2d[i, 1], 
                            "Embedding 1", int(pair_labels[indices[i]])
                        ])
                        data.append([
                            embeddings2_2d[i, 0], embeddings2_2d[i, 1], 
                            "Embedding 2", int(pair_labels[indices[i]])
                        ])
                    
                    # Log to wandb
                    wandb.log({
                        f"{metric_key_prefix}/embeddings_tsne": wandb.Table(
                            data=data,
                            columns=["x", "y", "embedding_type", "same_speaker"]
                        )
                    })
                except ImportError:
                    print("sklearn not available for t-SNE visualization")
                except Exception as e:
                    print(f"Error creating t-SNE visualization: {e}")
        
        return metrics

## Training

In [26]:
from transformers import TrainingArguments
# Define training arguments
training_args = TrainingArguments(
    output_dir="./checkpoint_ResNet481",
    per_device_train_batch_size=128,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=1000,
    eval_steps=1000,
    logging_steps=500,
    per_device_eval_batch_size=128,
    learning_rate=1e-4,
    gradient_accumulation_steps=1,
    save_total_limit=101,
    max_steps=50000,
    #warmup_steps=4000,
    dataloader_num_workers=4,
    report_to=["wandb"],  # Enable logging to wandb
    metric_for_best_model="eval_eer",
    greater_is_better=False,  # Lower EER is better
    fp16=True,
    load_best_model_at_end=True,
)

In [27]:
len(train_dataset)

292005

In [None]:
trainer = VSASV_Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset, 
    eval_dataset=val_dataset,    
    data_collator=combined_collator,  
    compute_metrics=compute_speaker_metrics,
    total_steps=int(training_args.max_steps),
    num_speakers=652,  # Adjust based on your dataset
    margin=0,
    scale=30,
)

# Train model
trainer.train()
# Tiếp tục từ checkpoint step 15000
#trainer.train(resume_from_checkpoint="./checkpoint_ResNet48/checkpoint-15000")

# Save best model (automatically saved if using metric_for_best_model + load_best_model_at_end=True)
trainer.save_model(os.path.join(training_args.output_dir, "best_model"))

# Save training state (optimizer, scheduler, etc.)
trainer.save_state()

# Log to wandb
wandb.save(os.path.join(training_args.output_dir, "best_model", "*"))
wandb.save(os.path.join(training_args.output_dir, "trainer_state.json"))
wandb.save(os.path.join(training_args.output_dir, "optimizer.pt"))
wandb.save(os.path.join(training_args.output_dir, "scheduler.pt"))

# Finish wandb run
wandb.finish()




Epoch 1: margin = 0.0000


Step,Training Loss,Validation Loss
1000,2.6123,No log
2000,1.2919,No log
3000,0.824,No log
4000,0.5934,No log
5000,0.4611,No log
6000,0.3756,No log
7000,0.303,No log


Epoch 1: margin = 0.0000

Eval EER: 0.2732
Epoch 2: margin = 0.0000
Epoch 2: margin = 0.0000

Eval EER: 0.2878
Epoch 3: margin = 0.0000
Epoch 3: margin = 0.0000

Eval EER: 0.2777
Epoch 4: margin = 0.0000
Epoch 4: margin = 0.0000

Eval EER: 0.2686
Epoch 5: margin = 0.0000
Epoch 5: margin = 0.0000

Eval EER: 0.2608
Epoch 6: margin = 0.0000
Epoch 6: margin = 0.0000

Eval EER: 0.2614
Epoch 7: margin = 0.0000
Epoch 7: margin = 0.0000

Eval EER: 0.2581
Epoch 8: margin = 0.0000
Epoch 8: margin = 0.0000
