# Data

0-15 done

15-40 done

40-65 done

65-100 done

100-150 done

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')


In [None]:

from datasets import load_dataset, concatenate_datasets, Dataset
import os

# Destination folder for the merged Arrow dataset.
# NOTE: the folder name is historical ("0_349") and does not describe the shard
# window below; it is kept unchanged because the extraction cell reads this
# exact folder.
SAVE_PATH = "vsasv_subset_0_349"
os.makedirs(SAVE_PATH, exist_ok=True)

# Shard window to fetch: FIRST_SHARD inclusive, LAST_SHARD exclusive.
# Edit these two values to pull a different slice of the dataset.
FIRST_SHARD = 400
LAST_SHARD = 403
TOTAL_SHARDS = 432  # appears in every parquet shard file name ("-of-00432")

# Collect datasets one shard at a time
shard_datasets = []

for i in range(FIRST_SHARD, LAST_SHARD):
    shard = f"data/train-{i:05d}-of-{TOTAL_SHARDS:05d}.parquet"
    print(f"Downloading shard {shard} ...")

    # Load in streaming mode
    streaming_data = load_dataset(
        "hustep-lab/VSASV-Dataset",
        split="train",
        data_files=[shard],
        streaming=True
    )

    # Materialize this shard only. The lambda is consumed right here, inside
    # the same iteration, so it always reads the current `streaming_data`.
    ds = Dataset.from_generator(lambda: (ex for ex in streaming_data))
    shard_datasets.append(ds)

# Merge everything into one dataset
full_dataset = concatenate_datasets(shard_datasets)

# Save to disk
full_dataset.save_to_disk(SAVE_PATH)

# Report the shards that were actually downloaded (the previous message always
# claimed "0-349" regardless of the range used).
print(f"‚úÖ Saved shards {FIRST_SHARD}-{LAST_SHARD - 1} to {SAVE_PATH}")


Downloading shard data/train-00400-of-00432.parquet ...


Generating train split: 0 examples [00:00, ? examples/s]

Downloading shard data/train-00401-of-00432.parquet ...


Generating train split: 0 examples [00:00, ? examples/s]

Downloading shard data/train-00402-of-00432.parquet ...


Generating train split: 0 examples [00:00, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/1533 [00:00<?, ? examples/s]

‚úÖ Saved subset (0‚Äì349) to vsasv_subset_0_349


In [None]:
import os
import glob
import soundfile as sf
from datasets import Dataset
from tqdm import tqdm

# Input: folder holding the saved Arrow shards. Output: root of the WAV tree,
# laid out as <EXPORT_PATH>/<utt_type>/<label>/<file name>.
DATA_PATH = "/content/vsasv_subset_0_349"
EXPORT_PATH = "/content/vsasv_export"
os.makedirs(EXPORT_PATH, exist_ok=True)

# Every Arrow shard in the dataset folder, in sorted order
shard_pattern = os.path.join(DATA_PATH, "data-*.arrow")
arrow_files = sorted(glob.glob(shard_pattern))

for shard_path in arrow_files:
    print(f"üìÇ Processing {shard_path} ...")

    # Open just this one shard
    shard_ds = Dataset.from_file(shard_path)

    # Write each example of the shard out as a WAV file
    progress = tqdm(shard_ds, desc=f"Extracting {os.path.basename(shard_path)}")
    for example in progress:
        utt_kind = example["utt_type"]                   # e.g. "bonafide"
        speaker = example["label"]                       # e.g. "id00037"
        wav_name = os.path.basename(example["file"])     # e.g. "00025.wav"

        target_dir = os.path.join(EXPORT_PATH, utt_kind, speaker)
        os.makedirs(target_dir, exist_ok=True)

        # NOTE(review): only the base name of "file" is kept, so two clips with
        # the same base name under one utt_type/label would overwrite each
        # other - confirm base names are unique per speaker.
        clip = example["audio"]
        target_wav = os.path.join(target_dir, wav_name)
        sf.write(target_wav, clip["array"], clip["sampling_rate"])

    # The shard is no longer needed once its audio has been exported
    os.remove(shard_path)
    print(f"üóëÔ∏è Deleted {shard_path}")

print("‚úÖ Extraction complete, all shards deleted.")


üìÇ Processing /content/vsasv_subset_0_349/data-00000-of-00003.arrow ...


Extracting data-00000-of-00003.arrow: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 511/511 [00:36<00:00, 14.14it/s]


üóëÔ∏è Deleted /content/vsasv_subset_0_349/data-00000-of-00003.arrow
üìÇ Processing /content/vsasv_subset_0_349/data-00001-of-00003.arrow ...


Extracting data-00001-of-00003.arrow: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 511/511 [00:17<00:00, 28.97it/s]


üóëÔ∏è Deleted /content/vsasv_subset_0_349/data-00001-of-00003.arrow
üìÇ Processing /content/vsasv_subset_0_349/data-00002-of-00003.arrow ...


Extracting data-00002-of-00003.arrow: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 511/511 [00:33<00:00, 15.05it/s]

üóëÔ∏è Deleted /content/vsasv_subset_0_349/data-00002-of-00003.arrow
‚úÖ Extraction complete, all shards deleted.





In [None]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split

# Root of the exported WAV tree (<att_type>/<speaker_id>/<file>.wav) and the
# folder that receives the metadata CSV files.
EXPORT_PATH = "/content/vsasv_export"
META_PATH = "/content/metadatas"
os.makedirs(META_PATH, exist_ok=True)

records = []

# Walk through export folder. Listings are sorted so the row order - and
# therefore the seeded split below - does not depend on filesystem order.
for att_type in sorted(os.listdir(EXPORT_PATH)):
    att_type_dir = os.path.join(EXPORT_PATH, att_type)
    if not os.path.isdir(att_type_dir):
        continue

    for speaker_id in sorted(os.listdir(att_type_dir)):
        speaker_dir = os.path.join(att_type_dir, speaker_id)
        if not os.path.isdir(speaker_dir):
            continue

        for file_name in sorted(os.listdir(speaker_dir)):
            if not file_name.endswith(".wav"):
                continue
            # Paths are stored relative to EXPORT_PATH
            rel_path = os.path.join(att_type, speaker_id, file_name)
            records.append([rel_path, att_type, speaker_id])

# Build DataFrame
df = pd.DataFrame(records, columns=["path", "att_type", "speaker_id"])

# Encode att_type + speaker_id as integer category codes. This is done on the
# full frame, before splitting, so both splits share the same id mapping.
df["att_type_id"] = df["att_type"].astype("category").cat.codes
df["speaker_id_num"] = df["speaker_id"].astype("category").cat.codes

# Keep the complete, unsplit listing in its own file
all_csv = os.path.join(META_PATH, "metadata_all.csv")
df.to_csv(all_csv, index=False)
print(f"üìë Saved {all_csv}")

# Split into train/val (80/20), stratified by speaker
train_df, val_df = train_test_split(
    df, test_size=0.2, stratify=df["speaker_id"], random_state=42
)

# Write only the training portion here. Previously the full frame was written
# to metadata_train.csv, so every validation row was also a training row.
train_csv = os.path.join(META_PATH, "metadata_train.csv")
train_df.to_csv(train_csv, index=False)
print(f"üìë Saved {train_csv}")

val_csv = os.path.join(META_PATH, "metadata_val.csv")
val_df.to_csv(val_csv, index=False)
print(f"üìë Saved {val_csv}")


üìë Saved /content/metadatas/metadata_train.csv
üìë Saved /content/metadatas/metadata_val.csv


# Data?

In [None]:
# Work from /content so that the repository clone and the checkpoint download
# in the following cells land in that directory.
%cd  /content/

/content


In [None]:
# Remove any previous checkout, then clone a fresh copy of the model code.
# NOTE(review): the clone follows the repository's default branch; consider
# checking out a fixed commit so the notebook stays reproducible.
!rm -rf /content/EcapaTdnn_revisited
!git clone https://github.com/daoanhkhoa123/EcapaTdnn_revisited.git

Cloning into 'EcapaTdnn_revisited'...
remote: Enumerating objects: 257, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 257 (delta 14), reused 29 (delta 9), pack-reused 222 (from 1)[K
Receiving objects: 100% (257/257), 38.67 KiB | 628.00 KiB/s, done.
Resolving deltas: 100% (165/165), done.


In [None]:
# Download the starting checkpoint from Google Drive by file ID. The recorded
# output shows it saved as /content/best_model_epoch7_20250930_123823.pt,
# which is the path loaded into the model later in this notebook.
!gdown 1q_uq_iovbe4t9TX_NtScXxJGmmVgjOxs

Downloading...
From (original): https://drive.google.com/uc?id=1q_uq_iovbe4t9TX_NtScXxJGmmVgjOxs
From (redirected): https://drive.google.com/uc?id=1q_uq_iovbe4t9TX_NtScXxJGmmVgjOxs&confirm=t&uuid=e22411b0-dc58-4c94-a185-ee9ea7e2d6e6
To: /content/best_model_epoch7_20250930_123823.pt
100% 62.1M/62.1M [00:00<00:00, 127MB/s]


# Continue!

In [None]:
# import torch
# import gc

# # Clear cache
# torch.cuda.empty_cache()

# # Collect garbage
# gc.collect()

# # Optional: reset default device memory
# if torch.cuda.is_available():
#     torch.cuda.reset_peak_memory_stats()
#     print("GPU memory cleared.")

# import torch

# if torch.cuda.is_available():
#     total, free = torch.cuda.mem_get_info()  # returns bytes
#     print(f"Free GPU memory: {free / 1024**2:.2f} MB")
#     print(f"Total GPU memory: {total / 1024**2:.2f} MB")



In [None]:
# Switch into the package folder of the cloned repository so that the next cell
# can import `model`, `aamloss`, `dataloader` and `ultils` as top-level modules.
%cd /content/EcapaTdnn_revisited/real_ecapatdnn

/content/EcapaTdnn_revisited/real_ecapatdnn


In [None]:
import torch
from model import ECAPA_TDNN
from aamloss import AAMsoftmax
from dataloader import VSAVSDataloader_config, VSAVSDataset_SpkerEmbed, collate_fn
from ultils import load_parameters, save_parameters, evaluate
from ultils import *

In [None]:
from torch.utils.data import DataLoader

# Loader settings and the root folder that the CSV paths are relative to
BATCH_SIZE = 16
NUM_WORKERS = 2
data_prefix_path = r"/content/vsasv_export"

# Keyword arguments common to both loaders; only `shuffle` differs per split
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)

# --- Train: config -> dataset -> shuffled loader ---
train_config = VSAVSDataloader_config(
    df_path="/content/metadatas/metadata_train.csv",
    data_prefix_path=data_prefix_path,
    bonafide_only=True,
)
train_dataset = VSAVSDataset_SpkerEmbed(train_config)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)

# --- Validation: same settings, fixed order ---
val_config = VSAVSDataloader_config(
    df_path="/content/metadatas/metadata_val.csv",
    data_prefix_path=data_prefix_path,
    bonafide_only=True,
)
val_dataset = VSAVSDataset_SpkerEmbed(val_config)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# # --- Test DataLoader ---
# test_config = VSAVSDataloader_config(
#     df_path="/content/metadatas/metadata_test.csv",
#     data_prefix_path=data_prefix_path,
#     bonafide_only=True,
#     augument=False
# )
# test_dataset = VSAVSDataset_SpkerEmbed(test_config)
# test_loader = DataLoader(
#     test_dataset,
#     batch_size=BATCH_SIZE,
#     shuffle=False,  # no shuffle for test
#     num_workers=NUM_WORKERS,
#     collate_fn=collate_fn
# )


In [None]:
# Spot-check one training item; the recorded output is a (tensor, integer) pair.
train_dataset[5]

(tensor([-0.0283, -0.0217, -0.0237,  ...,  0.0129,  0.0184,  0.0236]),
 np.int64(0))

In [None]:
# Speaker count reported by the training dataset; the training cell passes this
# value to AAMsoftmax as its first argument.
train_dataset.n_speakers

127

In [None]:
# Use the GPU when one is visible, otherwise stay on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the ECAPA-TDNN network (C=1024) and move it to the chosen device
model = ECAPA_TDNN(C=1024)
model = model.to(device)

# Hand the downloaded checkpoint to the repository's loader
# (presumably restores the saved weights into `model` - see ultils.load_parameters)
load_parameters(model, r"/content/best_model_epoch7_20250930_123823.pt")


In [None]:
# Return to /content now that the repository modules have been imported.
%cd /content/

/content


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from datetime import datetime

device = "cuda" if torch.cuda.is_available() else "cpu"

# AAM-softmax head sized by the dataset's speaker count. Its parameters are
# optimised together with the encoder's.
# NOTE(review): only `model` is checkpointed below; the head's weights are not
# saved, so a resumed run starts with a freshly initialised head.
loss_fn = AAMsoftmax(train_dataset.n_speakers, m=0.2, s=30).to(device)
optimizer = optim.Adam(
    list(model.parameters()) + list(loss_fn.parameters()),
    lr=1e-3, weight_decay=1e-5
)

EPOCHS = 16
PATIENCE = 3           # consecutive epochs without a better EER before stopping
best_val_eer = 100.0   # start with a high EER
epochs_no_improve = 0

for epoch in range(EPOCHS):
    # ---- Training ----
    model.train()
    total_loss = 0.0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
    for step, (audios, labels) in enumerate(train_bar, start=1):
        audios = audios.to(device)   # (B, T)
        labels = labels.to(device)   # (B,)

        embeds = model(audios, aug=True)  # (B, D)
        loss = loss_fn(embeds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Running mean over the batches seen so far. Dividing by the total
        # batch count (as before) under-reports the loss until the last step.
        avg_loss = total_loss / step
        train_bar.set_postfix(loss=f"{avg_loss:.4f}")

    avg_train_loss = total_loss / max(len(train_loader), 1)

    # ---- Validation / Evaluation ----
    # NOTE: trial pairs are formed inside each validation batch only, so the
    # EER depends on how utterances are grouped into batches.
    model.eval()
    all_scores, all_labels = [], []

    with torch.no_grad():
        for audios, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
            audios = audios.to(device)
            labels = labels.to(device)

            embeds = model(audios, aug=False)  # (B, D)

            # Cosine similarity matrix. normalize() clamps the norm with a
            # small epsilon, so a zero embedding cannot produce NaN scores.
            unit_embeds = nn.functional.normalize(embeds, dim=1)
            sim_matrix = torch.matmul(unit_embeds, unit_embeds.T)

            # Collect every pair (i, j) with i < j in one shot. triu_indices
            # yields the pairs in row-major order, the same order the nested
            # loops produced, without a device sync per pair.
            n_items = labels.shape[0]
            rows, cols = torch.triu_indices(
                n_items, n_items, offset=1, device=sim_matrix.device
            )
            all_scores.extend(sim_matrix[rows, cols].tolist())
            all_labels.extend((labels[rows] == labels[cols]).int().tolist())

    # ---- Compute metrics ----
    tunedThreshold, eer, fpr, fnr = tuneThresholdfromScore(all_scores, all_labels, [0.01, 0.05])
    fnrs, fprs, thresholds = ComputeErrorRates(all_scores, all_labels)
    minDCF, _ = ComputeMinDcf(fnrs, fprs, thresholds, p_target=0.01, c_miss=1, c_fa=1)

    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"Train Loss: {avg_train_loss:.4f} | "
          f"EER: {eer:.2f}% | MinDCF: {minDCF:.4f}")

    # ---- Save best checkpoint or apply patience ----
    if eer < best_val_eer:
        best_val_eer = eer
        epochs_no_improve = 0  # reset counter
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # `ckpt_path` is read again by later cells, so it always names the
        # most recent best checkpoint.
        ckpt_path = f"/content/best_model_epoch{epoch+1}_{timestamp}.pt"
        save_parameters(model, ckpt_path)
        print(f"‚úÖ Saved best model: {ckpt_path} | EER {eer:.2f}%")
    else:
        epochs_no_improve += 1
        print(f"‚ö†Ô∏è No improvement in EER for {epochs_no_improve} epoch(s).")

    # ---- Early stopping ----
    if epochs_no_improve >= PATIENCE:
        # Also keep the final (non-best) weights before stopping
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        last_ckpt_path = f"/content/last_model_epoch{epoch+1}_{timestamp}.pt"
        save_parameters(model, last_ckpt_path)
        print(f"‚èπÔ∏è Early stopping triggered (patience {PATIENCE}).")
        print(f"üíæ Saved last model: {last_ckpt_path}")
        print(f"Best EER achieved: {best_val_eer:.2f}%")
        break


Epoch 1/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:18<00:00,  4.11it/s, loss=8.0643]
Epoch 1/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.86it/s]


Epoch 1/16 | Train Loss: 8.0643 | EER: 24.54% | MinDCF: 0.9878
‚úÖ Saved best model: /content/best_model_epoch1_20250930_141354.pt | EER 24.54%


Epoch 2/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:27<00:00,  4.04it/s, loss=6.5022]
Epoch 2/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.91it/s]


Epoch 2/16 | Train Loss: 6.5022 | EER: 25.85% | MinDCF: 0.9634
‚ö†Ô∏è No improvement in EER for 1 epoch(s).


Epoch 3/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:24<00:00,  4.06it/s, loss=5.8722]
Epoch 3/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.86it/s]


Epoch 3/16 | Train Loss: 5.8722 | EER: 22.00% | MinDCF: 0.9595
‚úÖ Saved best model: /content/best_model_epoch3_20250930_143155.pt | EER 22.00%


Epoch 4/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:26<00:00,  4.04it/s, loss=5.3956]
Epoch 4/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.86it/s]


Epoch 4/16 | Train Loss: 5.3956 | EER: 22.15% | MinDCF: 0.9515
‚ö†Ô∏è No improvement in EER for 1 epoch(s).


Epoch 5/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:22<00:00,  4.08it/s, loss=5.0867]
Epoch 5/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.88it/s]


Epoch 5/16 | Train Loss: 5.0867 | EER: 19.51% | MinDCF: 0.9717
‚úÖ Saved best model: /content/best_model_epoch5_20250930_144953.pt | EER 19.51%


Epoch 6/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:27<00:00,  4.03it/s, loss=4.7616]
Epoch 6/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.85it/s]


Epoch 6/16 | Train Loss: 4.7616 | EER: 16.69% | MinDCF: 0.9878
‚úÖ Saved best model: /content/best_model_epoch6_20250930_145856.pt | EER 16.69%


Epoch 7/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:26<00:00,  4.04it/s, loss=4.5016]
Epoch 7/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.92it/s]


Epoch 7/16 | Train Loss: 4.5016 | EER: 13.41% | MinDCF: 0.9595
‚úÖ Saved best model: /content/best_model_epoch7_20250930_150757.pt | EER 13.41%


Epoch 8/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:24<00:00,  4.06it/s, loss=4.2929]
Epoch 8/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.87it/s]


Epoch 8/16 | Train Loss: 4.2929 | EER: 11.16% | MinDCF: 0.8907
‚úÖ Saved best model: /content/best_model_epoch8_20250930_151656.pt | EER 11.16%


Epoch 9/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:28<00:00,  4.03it/s, loss=4.1133]
Epoch 9/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.86it/s]


Epoch 9/16 | Train Loss: 4.1133 | EER: 12.20% | MinDCF: 0.9512
‚ö†Ô∏è No improvement in EER for 1 epoch(s).


Epoch 10/16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [08:24<00:00,  4.06it/s, loss=3.9121]
Epoch 10/16 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 410/410 [00:34<00:00, 11.92it/s]

Epoch 10/16 | Train Loss: 3.9121 | EER: 13.41% | MinDCF: 1.0000
‚ö†Ô∏è No improvement in EER for 2 epoch(s).
‚èπÔ∏è Early stopping triggered (patience 2).
üíæ Saved last model: /content/last_model_epoch10_20250930_153458.pt
Best EER achieved: 11.16%





In [None]:
# Show which checkpoint `ckpt_path` currently points to; the training loop sets
# it each time a new best EER is reached.
print(f"‚úÖ Saved best model: {ckpt_path}")

‚úÖ Saved best model: /content/best_model_epoch8_20250930_151656.pt


In [None]:
# Mount Google Drive so the checkpoint can be copied into it.
from google.colab import drive
drive.mount('/content/drive')

# raise ValueError()
# Copy your checkpoint into your Drive
# ({ckpt_path} is IPython interpolation of the Python variable set during training.)
!cp {ckpt_path} /content/drive/MyDrive/


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
