<a href="https://colab.research.google.com/github/jjbmsda/EnsembleModel/blob/main/EnsembleModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive so the Drive-backed paths used below resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# soundfile 설치
!pip -q install soundfile

In [3]:
# torchcodec 설치
!pip -q install torchcodec

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━[0m [32m1.4/2.1 MB[0m [31m40.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m40.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
import torchaudio
from torchaudio.datasets import LIBRISPEECH

# LibriSpeech "dev-clean" subset read from the Drive-backed dataset root.
# download=False: nothing is fetched here, so the data must already be under `root`.
# NOTE(review): `train_raw` is not referenced by the later cells visible in this notebook
# (their main() functions build their own dataset under ./data) - confirm it is still needed.
train_raw = LIBRISPEECH(
    root="/content/drive/MyDrive/datasets",
    url="dev-clean",
    download=False
)

# LibriSpeech "test-clean" subset from the same root, also without downloading.
# NOTE(review): `test_raw` is likewise unused by the cells visible here - confirm.
test_raw = LIBRISPEECH(
    root="/content/drive/MyDrive/datasets",
    url="test-clean",
    download=False
)

In [None]:
import os, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.models import resnet18, densenet121

import torchaudio
from torchaudio.datasets import LIBRISPEECH
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics import top_k_accuracy_score


# =========================
# Colab RAM-safe settings
# =========================
# Hyperparameter settings
BATCH_SIZE = 8  # batch size handed to every DataLoader in main()
EPOCHS = 1  # number of training epochs run by main()
MAX_SPEAKERS = 50  # at most this many speaker ids (lowest ids first) are kept as classes
TOPK = (1, 3)  # k values for which top-k accuracy is reported


def seed_everything(seed=42):
    """Seed Python's `random`, NumPy's global RNG, and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Applied in this fixed order, one library after another.
    for apply_seed in seeders:
        apply_seed(seed)

seed_everything(42)


def pad_trim_2d(spec: torch.Tensor, target_frames: int, pad_value: float = 0.0) -> torch.Tensor:
    """Make the last axis of `spec` exactly `target_frames` long.

    Longer inputs are truncated at the end, shorter inputs are right-padded
    with `pad_value`, and inputs of the right length are returned unchanged.
    """
    # Positive: frames to append; negative: frames to drop.
    missing = target_frames - spec.size(-1)
    if missing < 0:
        return spec[..., :target_frames]
    if missing > 0:
        return F.pad(spec, (0, missing), value=pad_value)
    return spec


class LibriSpeechSpeakerDataset(Dataset):
    """Turn items of a wrapped dataset into (dB mel-spectrogram, speaker class index) pairs.

    Each wrapped item is unpacked as
    (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id).
    Items whose speaker id is missing from `spk2idx` are returned as None.
    """

    def __init__(self, dataset, spk2idx, sample_rate=16000, n_mels=64, target_frames=256):
        self.dataset = dataset
        self.spk2idx = spk2idx
        self.sample_rate = sample_rate
        self.target_frames = target_frames
        # Feature pipeline: mel spectrogram followed by conversion to dB.
        self.melspec = MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)
        self.to_db = AmplitudeToDB(stype="power")

    def __len__(self):
        return len(self.dataset)

    def _mono_at_target_rate(self, waveform, sr):
        """Average multiple channels into one and resample when `sr` differs from the target rate."""
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        return waveform

    def __getitem__(self, idx):
        waveform, sr, _transcript, speaker_id, _chapter_id, _utterance_id = self.dataset[idx]
        speaker = int(speaker_id)

        # Speakers outside the mapping are skipped by returning None.
        if speaker not in self.spk2idx:
            return None

        waveform = self._mono_at_target_rate(waveform, sr)

        # mel -> dB, then fix the length of the last (time) axis.
        features = self.to_db(self.melspec(waveform))
        features = pad_trim_2d(features, self.target_frames)

        return features, torch.tensor(self.spk2idx[speaker], dtype=torch.long)


def collate_skip_none(batch):
    """Collate (x, y) samples into stacked tensors, ignoring None entries.

    Returns None when no sample remains after filtering.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    inputs, targets = zip(*kept)
    stacked_inputs = torch.stack(inputs, dim=0)
    stacked_targets = torch.stack(targets, dim=0)
    return stacked_inputs, stacked_targets


class ResNetModel(nn.Module):
    """ResNet-18 classifier (no pretrained weights) with a single-channel 3x3 stem and no max-pool."""

    def __init__(self, num_classes):
        super().__init__()
        network = resnet18(weights=None, num_classes=num_classes)
        # Replace the stem with a 1-input-channel, stride-1 3x3 convolution.
        network.conv1 = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False
        )
        # Turn the stem max-pool into a no-op.
        network.maxpool = nn.Identity()
        self.net = network

    def forward(self, x):
        logits = self.net(x)
        return logits


class DenseNetModel(nn.Module):
    """DenseNet-121 classifier (no pretrained weights) with a single-channel 3x3 stem and no stem pooling."""

    def __init__(self, num_classes):
        super().__init__()
        network = densenet121(weights=None)
        stem = network.features
        # Replace the first convolution with a 1-input-channel, stride-1 3x3 convolution.
        stem.conv0 = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False
        )
        # Turn the first pooling layer into a no-op.
        stem.pool0 = nn.Identity()
        # New output layer sized for `num_classes`.
        network.classifier = nn.Linear(network.classifier.in_features, num_classes)
        self.net = network

    def forward(self, x):
        logits = self.net(x)
        return logits


def train_one_epoch(model, loader, criterion, optimizer, device, use_amp, epoch, model_name="model"):
    """Run one training pass over `loader` and return the mean loss per processed batch.

    Batches that are None are skipped. With `use_amp` the forward pass runs
    under CUDA autocast and the backward/step go through a GradScaler.
    A running mean loss is printed every 50th loader position.
    Returns 0.0 when no batch was processed.
    """
    model.train()
    scaler = torch.amp.GradScaler("cuda") if use_amp else None
    loss_sum = 0.0
    n_steps = 0

    for position, batch in enumerate(loader, start=1):
        if batch is None:
            continue
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad(set_to_none=True)

        if not use_amp:
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
        else:
            with torch.amp.autocast("cuda"):
                loss = criterion(model(inputs), targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        loss_sum += float(loss.item())
        n_steps += 1

        if position % 50 == 0:
            print(
                f"[Epoch {epoch}][{model_name}] "
                f"step {position}/{len(loader)} "
                f"loss={loss_sum/n_steps:.4f}"
            )

    return loss_sum / max(1, n_steps)


@torch.no_grad()
def eval_model(model, loader, device, num_classes, topk=TOPK):
    """Evaluate a single classifier over `loader`.

    Returns a dict with "acc", "macro_f1", "top1" and one "top{k}" entry for
    every k in `topk` that does not exceed `num_classes`, or None when the
    loader yields no usable batch (None batches are skipped).
    """
    model.eval()
    prob_chunks, target_chunks = [], []

    for batch in loader:
        if batch is None:
            continue
        inputs, targets = batch
        scores = model(inputs.to(device))
        prob_chunks.append(scores.softmax(dim=1).cpu().numpy())
        target_chunks.append(targets.numpy())

    if not target_chunks:
        return None

    probs = np.concatenate(prob_chunks, axis=0)
    targets = np.concatenate(target_chunks, axis=0)
    preds = probs.argmax(axis=1)

    def top_k(k):
        # The full label list is passed explicitly on every call.
        return top_k_accuracy_score(targets, probs, k=k, labels=list(range(num_classes)))

    metrics = {
        "acc": accuracy_score(targets, preds),
        "macro_f1": f1_score(targets, preds, average="macro"),
        "top1": top_k(1),
    }
    for k in topk:
        if k > num_classes:
            continue
        metrics[f"top{k}"] = top_k(k)
    return metrics


@torch.no_grad()
def eval_ensemble(rnet, dnet, loader, device, num_classes, topk=TOPK, alpha=0.8):
    """Evaluate a weighted logit ensemble of `rnet` and `dnet` over `loader`.

    The ensemble logits are alpha * rnet(x) + (1 - alpha) * dnet(x); softmax is
    applied afterwards. Returns the same metric dict shape as the single-model
    evaluation ("acc", "macro_f1", "top1", "top{k}"), or None when the loader
    yields no usable batch.
    """
    for net in (rnet, dnet):
        net.eval()
    prob_chunks, target_chunks = [], []

    for batch in loader:
        if batch is None:
            continue
        inputs, targets = batch
        inputs = inputs.to(device)

        # Both sets of logits are computed before they are mixed.
        logits_r = rnet(inputs)
        logits_d = dnet(inputs)

        # Weighted ensemble in logit space.
        mixed = alpha * logits_r + (1 - alpha) * logits_d
        prob_chunks.append(mixed.softmax(dim=1).cpu().numpy())
        target_chunks.append(targets.numpy())

    if not target_chunks:
        return None

    probs = np.concatenate(prob_chunks, axis=0)
    targets = np.concatenate(target_chunks, axis=0)
    preds = probs.argmax(axis=1)

    def top_k(k):
        return top_k_accuracy_score(targets, probs, k=k, labels=list(range(num_classes)))

    metrics = {
        "acc": accuracy_score(targets, preds),
        "macro_f1": f1_score(targets, preds, average="macro"),
        "top1": top_k(1),
    }
    for k in topk:
        if k > num_classes:
            continue
        metrics[f"top{k}"] = top_k(k)
    return metrics


def split_indices(n, seed=42, train=0.8, val=0.1):
    """Shuffle range(n) with a private seeded RNG and cut it into (train, val, test) index lists.

    The train and val sizes are floor(n * train) and floor(n * val); the test
    list receives whatever is left.
    """
    order = list(range(n))
    random.Random(seed).shuffle(order)
    train_end = int(n * train)
    val_end = train_end + int(n * val)
    return order[:train_end], order[train_end:val_end], order[val_end:]


def main():
    """Train a ResNet and a DenseNet speaker classifier on dev-clean and report metrics.

    Downloads LibriSpeech dev-clean into ./data, maps up to MAX_SPEAKERS speakers
    to class indices, splits utterances 80/10/10, trains both models for EPOCHS
    epochs, and prints validation and test metrics for each model and for their
    weighted (alpha=0.8) logit ensemble.
    """
    os.makedirs("./data", exist_ok=True)

    # Use only the dev-clean subset
    raw = LIBRISPEECH("./data", url="dev-clean", download=True)

    # speaker mapping (keep at most MAX_SPEAKERS speakers)
    # NOTE(review): this iterates every item of `raw` just to read the speaker id;
    # each item presumably includes the loaded waveform, so this pass may be slow - confirm.
    all_speakers = sorted({int(spk) for _, _, _, spk, *_ in raw})
    speakers = all_speakers[:MAX_SPEAKERS]
    spk2idx = {spk: i for i, spk in enumerate(speakers)}
    num_classes = len(speakers)

    print(f"Using speakers: {num_classes}/{len(all_speakers)} (MAX_SPEAKERS={MAX_SPEAKERS})")

    full_ds = LibriSpeechSpeakerDataset(raw, spk2idx, n_mels=64, target_frames=256)

    # split (per utterance, so the same speaker can appear in train, val and test)
    tr, va, te = split_indices(len(full_ds), seed=42, train=0.8, val=0.1)
    train_ds = Subset(full_ds, tr)
    val_ds   = Subset(full_ds, va)
    test_ds  = Subset(full_ds, te)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision is only enabled when running on CUDA.
    use_amp = (device.type == "cuda")

    # collate_skip_none drops None samples (speakers outside the mapping) from each batch.
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=collate_skip_none)
    val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=collate_skip_none)
    test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=collate_skip_none)

    rnet = ResNetModel(num_classes).to(device)
    dnet = DenseNetModel(num_classes).to(device)

    # The two models share the loss but have separate Adam optimizers.
    crit = nn.CrossEntropyLoss()
    opt_r = optim.Adam(rnet.parameters(), lr=1e-3)
    opt_d = optim.Adam(dnet.parameters(), lr=1e-3)

    for ep in range(1, EPOCHS + 1):
        # Each epoch trains ResNet over the full loader first, then DenseNet.
        loss_r = train_one_epoch(
            rnet, train_loader, crit, opt_r, device, use_amp,
            epoch=ep, model_name="ResNet"
        )

        loss_d = train_one_epoch(
            dnet, train_loader, crit, opt_d, device, use_amp,
            epoch=ep, model_name="DenseNet"
        )

        val_r = eval_model(rnet, val_loader, device, num_classes, topk=TOPK)
        val_d = eval_model(dnet, val_loader, device, num_classes, topk=TOPK)
        val_e = eval_ensemble(rnet, dnet, val_loader, device, num_classes, topk=TOPK, alpha=0.8)

        print(f"\n[Epoch {ep}/{EPOCHS}]")
        print(f"  ResNet   loss={loss_r:.4f}  val={val_r}")
        print(f"  DenseNet loss={loss_d:.4f}  val={val_d}")
        print(f"  Ensemble           val={val_e}")

    test_r = eval_model(rnet, test_loader, device, num_classes, topk=TOPK)
    test_d = eval_model(dnet, test_loader, device, num_classes, topk=TOPK)
    test_e = eval_ensemble(rnet, dnet, test_loader, device, num_classes, topk=TOPK, alpha=0.8)

    # The eval helpers return None when no usable batch was seen.
    if test_r is None or test_d is None or test_e is None:
        print("\nTEST set이 비어있거나 모두 None으로 필터링됐어. (split/데이터 로딩 확인 필요)")
        return

    print("\n=== TEST RESULTS (dev-clean split) ===")
    print(f"ResNet   {test_r}")
    print(f"DenseNet {test_d}")
    print(f"Ensemble {test_e}")


if __name__ == "__main__":
    main()

Using speakers: 40/40 (MAX_SPEAKERS=50)
[Epoch 1][ResNet] step 50/271 loss=3.3546
[Epoch 1][ResNet] step 100/271 loss=2.9880
[Epoch 1][ResNet] step 150/271 loss=2.7235
[Epoch 1][ResNet] step 200/271 loss=2.4949
[Epoch 1][ResNet] step 250/271 loss=2.2784
[Epoch 1][DenseNet] step 50/271 loss=3.6242
[Epoch 1][DenseNet] step 100/271 loss=3.3382
[Epoch 1][DenseNet] step 150/271 loss=3.0942
[Epoch 1][DenseNet] step 200/271 loss=2.9085
[Epoch 1][DenseNet] step 250/271 loss=2.7663

[Epoch 1/1]
  ResNet   loss=2.2080  val={'acc': 0.5185185185185185, 'macro_f1': 0.471012906867878, 'top1': np.float64(0.5185185185185185), 'top3': np.float64(0.8037037037037037)}
  DenseNet loss=2.7120  val={'acc': 0.43703703703703706, 'macro_f1': 0.35431933093603574, 'top1': np.float64(0.43703703703703706), 'top3': np.float64(0.7)}
  Ensemble           val={'acc': 0.5555555555555556, 'macro_f1': 0.4918604140691743, 'top1': np.float64(0.5555555555555556), 'top3': np.float64(0.8333333333333334)}

=== TEST RESULTS (de

In [4]:

import os, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.models import resnet18, densenet121

import torchaudio
from torchaudio.datasets import LIBRISPEECH
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

from sklearn.metrics import accuracy_score, f1_score, roc_curve

# =========================
# ✅ Colab RAM-safe settings
# =========================
BATCH_SIZE = 8  # batch size handed to every DataLoader in main()
EPOCHS = 1  # number of training epochs run by main()
MAX_SPEAKERS = 50  # at most this many speaker ids (lowest ids first) are kept as classes

# Embedding config
EMB_DIM = 256  # embedding dimension shared by both embedders and the head
MARGIN = 0.2  # passed to AMSoftmaxHead as `m`
SCALE = 30.0  # passed to AMSoftmaxHead as `s`

# ✅ Ensemble embedding weight (share given to the ResNet embedding)
ALPHA = 0.8

def seed_everything(seed=42):
    """Apply `seed` to `random`, NumPy's global RNG, torch's CPU RNG and every CUDA RNG."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

seed_everything(42)

def pad_trim_2d(spec: torch.Tensor, target_frames: int, pad_value: float = 0.0) -> torch.Tensor:
    """Return `spec` with its last axis cut or right-padded to `target_frames` entries."""
    current = spec.size(-1)
    if current == target_frames:
        # Already the right length: hand back the same tensor.
        return spec
    if current > target_frames:
        return spec[..., :target_frames]
    return F.pad(spec, (0, target_frames - current), value=pad_value)

class LibriSpeechSpeakerDataset(Dataset):
    """Serve (dB mel-spectrogram, speaker class index) pairs from a wrapped dataset.

    Wrapped items are unpacked as
    (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id).
    An item whose speaker id is not a key of `spk2idx` yields None.
    """

    def __init__(self, dataset, spk2idx, sample_rate=16000, n_mels=64, target_frames=256):
        self.dataset, self.spk2idx = dataset, spk2idx
        self.sample_rate, self.target_frames = sample_rate, target_frames
        self.melspec = MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)
        self.to_db = AmplitudeToDB(stype="power")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        waveform, sr, speaker = record[0], record[1], int(record[3])
        # Enforce the same 6-field layout the tuple unpacking would require.
        _w, _sr, _transcript, _speaker, _chapter, _utterance = record

        # No mapping for this speaker: signal "skip" to the collate function.
        if speaker not in self.spk2idx:
            return None

        # Collapse multi-channel audio to one channel.
        multi_channel = waveform.size(0) > 1
        if multi_channel:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Bring the audio to the configured sample rate.
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        mel_db = self.to_db(self.melspec(waveform))
        mel_db = pad_trim_2d(mel_db, self.target_frames)

        label = torch.tensor(self.spk2idx[speaker], dtype=torch.long)
        return mel_db, label

def collate_skip_none(batch):
    """Stack the non-None (x, y) samples of `batch`; return None if none remain."""
    usable = list(filter(lambda sample: sample is not None, batch))
    if len(usable) == 0:
        return None
    xs, ys = zip(*usable)
    return torch.stack(xs, dim=0), torch.stack(ys, dim=0)

# -------------------------
# ✅ ResNet Embedder
# -------------------------
class ResNetEmbedder(nn.Module):
    """ResNet-18 trunk (single-channel stem, no max-pool) projecting to an L2-normalized embedding."""

    def __init__(self, emb_dim=256):
        super().__init__()
        trunk = resnet18(weights=None)
        trunk.conv1 = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False
        )
        trunk.maxpool = nn.Identity()
        # Keep every child except the last one (the original classifier layer).
        *feature_layers, _classifier = trunk.children()
        self.backbone = nn.Sequential(*feature_layers)  # [B, 512, 1, 1]
        self.fc = nn.Linear(512, emb_dim)

    def forward(self, x):
        pooled = torch.flatten(self.backbone(x), 1)  # [B, 512]
        projected = self.fc(pooled)                  # [B, D]
        return F.normalize(projected, dim=1)

# -------------------------
# ✅ DenseNet Embedder
# -------------------------
class DenseNetEmbedder(nn.Module):
    """DenseNet-121 feature stack (single-channel stem, no stem pooling) projecting to an L2-normalized embedding."""

    def __init__(self, emb_dim=256):
        super().__init__()
        densenet = densenet121(weights=None)
        densenet.features.conv0 = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False
        )
        densenet.features.pool0 = nn.Identity()

        self.features = densenet.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # The projection takes the width the original classifier expected.
        self.fc = nn.Linear(densenet.classifier.in_features, emb_dim)

    def forward(self, x):
        activations = F.relu(self.features(x), inplace=True)
        pooled = torch.flatten(self.pool(activations), 1)  # [B, C]
        projected = self.fc(pooled)                        # [B, D]
        return F.normalize(projected, dim=1)

# -------------------------
# ✅ Weighted-sum Ensemble Embedder
# -------------------------
class WeightedEnsembleEmbedder(nn.Module):
    """Blend two embedders as alpha * r(x) + (1 - alpha) * d(x), then L2-normalize the result."""

    def __init__(self, r_embedder, d_embedder, alpha=0.8):
        super().__init__()
        self.r = r_embedder
        self.d = d_embedder
        self.alpha = alpha

    def forward(self, x):
        emb_r = self.r(x)
        emb_d = self.d(x)
        blended = self.alpha * emb_r + (1.0 - self.alpha) * emb_d
        # Re-normalize after mixing so the output rows are unit length again.
        return F.normalize(blended, dim=1)

# -------------------------
# ✅ AM-Softmax Head
# -------------------------
class AMSoftmaxHead(nn.Module):
    """Additive-margin cosine head: logits = s * (cos(emb, W) - m on the target class)."""

    def __init__(self, emb_dim, num_classes, s=30.0, m=0.2):
        super().__init__()
        initial = torch.randn(num_classes, emb_dim)
        self.W = nn.Parameter(initial)
        nn.init.xavier_normal_(self.W)
        self.s, self.m = s, m

    def forward(self, emb, y):
        unit_w = F.normalize(self.W, dim=1)  # [C, D]
        cosine = F.linear(emb, unit_w)       # [B, C]
        # The margin is subtracted only at each sample's target class.
        margin = self.m * F.one_hot(y, num_classes=cosine.size(1)).float()
        return self.s * (cosine - margin)

def split_indices(n, seed=42, train=0.8, val=0.1):
    """Return (train, val, test) index lists from a seeded shuffle of range(n).

    Sizes are floor(n * train) and floor(n * val); the remainder goes to test.
    """
    shuffled = list(range(n))
    random.Random(seed).shuffle(shuffled)
    cut_train = int(n * train)
    cut_val = cut_train + int(n * val)
    bounds = ((0, cut_train), (cut_train, cut_val), (cut_val, len(shuffled)))
    train_idxs, val_idxs, test_idxs = (shuffled[lo:hi] for lo, hi in bounds)
    return train_idxs, val_idxs, test_idxs

def train_one_epoch(embedder, head, loader, optimizer, device, use_amp, epoch, model_name="Ensemble-Embed"):
    """Train `embedder` and `head` for one pass over `loader` with cross-entropy on the head's logits.

    None batches are skipped. With `use_amp` the forward pass runs under CUDA
    autocast and the update goes through a GradScaler. A running mean loss is
    printed every 50th loader position. Returns the mean loss per processed
    batch (0.0 if none was processed).
    """
    embedder.train()
    head.train()
    criterion = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler("cuda") if use_amp else None
    loss_sum, n_steps = 0.0, 0

    def compute_loss(inputs, targets):
        # The head receives the labels as well, so it can apply its margin.
        return criterion(head(embedder(inputs), targets), targets)

    for position, batch in enumerate(loader, start=1):
        if batch is None:
            continue
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad(set_to_none=True)

        if not use_amp:
            loss = compute_loss(inputs, targets)
            loss.backward()
            optimizer.step()
        else:
            with torch.amp.autocast("cuda"):
                loss = compute_loss(inputs, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        loss_sum += float(loss.item())
        n_steps += 1

        if position % 50 == 0:
            print(f"[Epoch {epoch}][{model_name}] step {position}/{len(loader)} loss={loss_sum/max(1,n_steps):.4f}")

    return loss_sum / max(1, n_steps)

@torch.no_grad()
def extract_embeddings(embedder, loader, device):
    """Run `embedder` in eval mode over `loader` and return (embeddings, labels) as NumPy arrays.

    None batches are skipped; returns None when no batch was usable.
    """
    embedder.eval()
    collected = [
        (embedder(inputs.to(device)).cpu().numpy(), targets.numpy())
        for inputs, targets in (batch for batch in loader if batch is not None)
    ]
    if not collected:
        return None
    emb_chunks, label_chunks = zip(*collected)
    return np.concatenate(emb_chunks, axis=0), np.concatenate(label_chunks, axis=0)

def nearest_centroid_identification(embs_train, y_train, embs_eval, y_eval):
    """Classify eval embeddings by the most similar per-class centroid of the train embeddings.

    Args:
        embs_train: array [N_train, D] of train embeddings.
        y_train: integer class labels for `embs_train`; classes span 0..max(y_train).
        embs_eval: array [N_eval, D] of embeddings to classify.
        y_eval: integer class labels for `embs_eval`.

    Returns:
        dict with "acc", "macro_f1" and "top1" (fraction of exact matches).

    Raises:
        ValueError: if `y_train` is empty.

    Classes in 0..max(y_train) that have no training sample are never predicted.
    Without that guard their centroid would be NaN (mean of an empty slice) and
    the NaN similarity column would win argmax for every eval sample.
    """
    if len(y_train) == 0:
        raise ValueError("nearest_centroid_identification needs at least one training sample")

    num_classes = int(y_train.max()) + 1
    centroids = np.zeros((num_classes, embs_train.shape[1]), dtype=np.float32)
    has_samples = np.zeros(num_classes, dtype=bool)

    for c in range(num_classes):
        members = embs_train[y_train == c]
        if len(members) == 0:
            # No training data for this class index: leave it masked out.
            continue
        m = members.mean(axis=0)
        # Unit-length centroid; the epsilon avoids dividing by zero.
        centroids[c] = m / (np.linalg.norm(m) + 1e-12)
        has_samples[c] = True

    sims = embs_eval @ centroids.T  # dot product against unit-length centroids
    # Classes without a centroid must never win the argmax.
    sims = np.where(has_samples[None, :], sims, -np.inf)
    preds = sims.argmax(axis=1)
    return {
        "acc": accuracy_score(y_eval, preds),
        "macro_f1": f1_score(y_eval, preds, average="macro"),
        "top1": float((preds == y_eval).mean())
    }

def compute_verification_eer(embs, y, max_pairs=30000):
    """Approximate the verification equal error rate (EER) from randomly sampled pairs.

    Up to `max_pairs` index pairs are drawn with a fixed-seed generator (draws
    with i == j are discarded). Each pair is scored by the dot product of its
    two embeddings and labelled 1 when both samples share a label, else 0.

    Args:
        embs: array [N, D] of embeddings, indexed by sample.
        y: labels aligned with `embs`.
        max_pairs: number of pair draws.

    Returns:
        {"eer": ..., "thr": ...} taken at the ROC point where |FPR - FNR| is
        smallest; "eer" is the mean of FPR and FNR at that point.

    Raises:
        ValueError: if fewer than two samples are given, no usable pair was
            drawn, or the drawn pairs are all same-label or all
            different-label (the EER is undefined in those cases).
    """
    # Sampling-based approximation of the EER.
    n = len(y)
    if n < 2:
        raise ValueError(f"compute_verification_eer needs at least 2 samples, got {n}")

    rng = np.random.default_rng(42)

    scores = []
    labels = []
    for _ in range(max_pairs):
        i, j = rng.integers(0, n, size=2)
        if i == j:
            continue
        s = float(np.dot(embs[i], embs[j]))  # cosine similarity when rows are unit length
        scores.append(s)
        labels.append(1 if y[i] == y[j] else 0)

    scores = np.array(scores)
    labels = np.array(labels)

    if labels.size == 0:
        raise ValueError("compute_verification_eer drew no usable pairs; increase max_pairs")
    if labels.min() == labels.max():
        # With one kind of pair only, FPR or FNR is undefined.
        raise ValueError(
            "compute_verification_eer needs both same-label and different-label pairs"
        )

    fpr, tpr, thr = roc_curve(labels, scores)
    fnr = 1 - tpr
    eer_idx = np.nanargmin(np.abs(fpr - fnr))
    eer = (fpr[eer_idx] + fnr[eer_idx]) / 2.0
    return {"eer": float(eer), "thr": float(thr[eer_idx])}

def main():
    """Train the weighted ResNet+DenseNet ensemble embedder and report ID / verification metrics.

    Downloads LibriSpeech dev-clean into ./data, maps up to MAX_SPEAKERS speakers
    to class indices, splits utterances 80/10/10, trains both embedders and the
    AM-Softmax head jointly for EPOCHS epochs, then prints nearest-centroid
    identification metrics and a sampled EER for the validation and test splits.
    """
    os.makedirs("./data", exist_ok=True)

    raw = LIBRISPEECH("./data", url="dev-clean", download=True)

    # NOTE(review): this iterates every item of `raw` just to read the speaker id;
    # each item presumably includes the loaded waveform, so this pass may be slow - confirm.
    all_speakers = sorted({int(spk) for _, _, _, spk, *_ in raw})
    speakers = all_speakers[:MAX_SPEAKERS]
    spk2idx = {spk: i for i, spk in enumerate(speakers)}
    num_classes = len(speakers)
    print(f"Using speakers: {num_classes}/{len(all_speakers)} | ALPHA={ALPHA}")

    full_ds = LibriSpeechSpeakerDataset(raw, spk2idx, n_mels=64, target_frames=256)

    # Utterance-level split, so the same speaker can appear in train, val and test.
    tr, va, te = split_indices(len(full_ds), seed=42, train=0.8, val=0.1)
    train_ds = Subset(full_ds, tr)
    val_ds   = Subset(full_ds, va)
    test_ds  = Subset(full_ds, te)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision is only enabled when running on CUDA.
    use_amp = (device.type == "cuda")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0, collate_fn=collate_skip_none)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=collate_skip_none)
    test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=collate_skip_none)

    r_embed = ResNetEmbedder(emb_dim=EMB_DIM).to(device)
    d_embed = DenseNetEmbedder(emb_dim=EMB_DIM).to(device)
    embedder = WeightedEnsembleEmbedder(r_embed, d_embed, alpha=ALPHA).to(device)

    head = AMSoftmaxHead(emb_dim=EMB_DIM, num_classes=num_classes, s=SCALE, m=MARGIN).to(device)

    # ✅ Default setup: both embedders and the head are trained by one optimizer
    optimizer = optim.Adam(list(r_embed.parameters()) + list(d_embed.parameters()) + list(head.parameters()), lr=1e-3)

    for ep in range(1, EPOCHS + 1):
        loss = train_one_epoch(embedder, head, train_loader, optimizer, device, use_amp, epoch=ep, model_name="WeightedEnsembleEmbed")
        print(f"\n[Epoch {ep}/{EPOCHS}] loss={loss:.4f}")

        # Train embeddings provide the class centroids; val embeddings are scored against them.
        pack_tr = extract_embeddings(embedder, train_loader, device)
        pack_va = extract_embeddings(embedder, val_loader, device)
        if pack_tr is not None and pack_va is not None:
            embs_tr, y_tr = pack_tr
            embs_va, y_va = pack_va
            val_id = nearest_centroid_identification(embs_tr, y_tr, embs_va, y_va)
            val_ver = compute_verification_eer(embs_va, y_va)
            print(f"  Val-ID  {val_id}")
            print(f"  Val-VER {val_ver}")

    # NOTE(review): the train embeddings are extracted again here even though the last epoch
    # already computed them with the same weights - confirm whether the extra pass is intended.
    pack_tr = extract_embeddings(embedder, train_loader, device)
    pack_te = extract_embeddings(embedder, test_loader, device)
    if pack_tr is None or pack_te is None:
        print("No embeddings extracted. Check data filtering.")
        return

    embs_tr, y_tr = pack_tr
    embs_te, y_te = pack_te
    test_id = nearest_centroid_identification(embs_tr, y_tr, embs_te, y_te)
    test_ver = compute_verification_eer(embs_te, y_te)

    print("\n=== TEST RESULTS (Embedding / Weighted Ensemble) ===")
    print(f"Test-ID  {test_id}")
    print(f"Test-VER {test_ver}")

if __name__ == "__main__":
    main()

100%|██████████| 322M/322M [00:18<00:00, 18.5MB/s]


Using speakers: 40/40 | ALPHA=0.8
[Epoch 1][WeightedEnsembleEmbed] step 50/271 loss=10.5423
[Epoch 1][WeightedEnsembleEmbed] step 100/271 loss=9.5721
[Epoch 1][WeightedEnsembleEmbed] step 150/271 loss=9.0353
[Epoch 1][WeightedEnsembleEmbed] step 200/271 loss=8.7153
[Epoch 1][WeightedEnsembleEmbed] step 250/271 loss=8.4324

[Epoch 1/1] loss=8.3149
  Val-ID  {'acc': 0.7148148148148148, 'macro_f1': 0.6581482200663786, 'top1': 0.7148148148148148}
  Val-VER {'eer': 0.13101455650951943, 'thr': 0.8855322599411011}

=== TEST RESULTS (Embedding / Weighted Ensemble) ===
Test-ID  {'acc': 0.6826568265682657, 'macro_f1': 0.6710281961368076, 'top1': 0.6826568265682657}
Test-VER {'eer': 0.14810861077226223, 'thr': 0.8929426670074463}


In [None]:
import os
import random
import math
from itertools import combinations

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.models import resnet18, densenet121

import torchaudio
from torchaudio.datasets import LIBRISPEECH
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

from sklearn.metrics import accuracy_score, f1_score, top_k_accuracy_score


# =========================
# Hyperparameters & settings
# =========================
BATCH_SIZE    = 8
EPOCHS        = 1
MAX_SPEAKERS  = 40      # maximum number of speakers to use from dev-clean
TOPK          = (1, 3)  # top-k values for ID evaluation
EMB_DIM       = 256     # embedding dimension
USE_TRIPLET   = False   # whether to also apply a triplet loss to DenseNet (optional)
TRIPLET_LAMBDA = 0.1    # weight of the triplet term when combining CE + triplet
SAMPLE_RATE   = 16000
N_MELS        = 64
TARGET_FRAMES = 256     # fixed length of the mel-spectrogram time axis


# =========================
# Fix random seeds
# =========================
def seed_everything(seed=42):
    """Seed `random`, NumPy's global RNG, torch's CPU RNG and all CUDA RNGs with `seed`."""
    seed_targets = {
        "python": random.seed,
        "numpy": np.random.seed,
        "torch-cpu": torch.manual_seed,
        "torch-cuda": torch.cuda.manual_seed_all,
    }
    # dicts keep insertion order, so the libraries are seeded in the order listed.
    for seeder in seed_targets.values():
        seeder(seed)

seed_everything(42)


# =========================
# Spectrogram pad/trim
# =========================
def pad_trim_2d(spec: torch.Tensor, target_frames: int, pad_value: float = 0.0) -> torch.Tensor:
    """
    spec: [C, n_mels, T]
    target_frames: desired length of the time (last) axis

    Inputs longer than `target_frames` are cut at the end, shorter ones are
    right-padded with `pad_value`, and exact matches are returned unchanged.
    """
    n_frames = spec.size(-1)
    shortfall = target_frames - n_frames
    if shortfall > 0:
        return F.pad(spec, (0, shortfall), value=pad_value)
    return spec[..., :target_frames] if shortfall < 0 else spec


# =========================
# LibriSpeech Speaker Dataset
# =========================
class LibriSpeechSpeakerDataset(Dataset):
    """Yield (dB mel-spectrogram, speaker class index) pairs from a wrapped dataset.

    Items whose speaker id is not a key of `spk2idx` are returned as None.
    """

    def __init__(self, dataset, spk2idx,
                 sample_rate=SAMPLE_RATE,
                 n_mels=N_MELS,
                 target_frames=TARGET_FRAMES):
        """
        dataset   : torchaudio.datasets.LIBRISPEECH
        spk2idx   : {speaker_id(int): class_index}
        """
        self.dataset, self.spk2idx = dataset, spk2idx
        self.sample_rate, self.target_frames = sample_rate, target_frames

        self.melspec = MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)
        self.to_db = AmplitudeToDB(stype="power")

    def __len__(self):
        return len(self.dataset)

    def _standardize_waveform(self, waveform, sr):
        """Average multi-channel audio to one channel and resample to the configured rate."""
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        return waveform

    def __getitem__(self, idx):
        # Item layout: (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)
        waveform, sr, _, speaker_id, _, _ = self.dataset[idx]
        speaker = int(speaker_id)

        # Skip speakers that are outside the selected mapping.
        if speaker not in self.spk2idx:
            return None

        waveform = self._standardize_waveform(waveform, sr)

        # mel-spectrogram -> dB -> pad/trim on the last axis
        features = pad_trim_2d(self.to_db(self.melspec(waveform)), self.target_frames)

        return features, torch.tensor(self.spk2idx[speaker], dtype=torch.long)


def collate_skip_none(batch):
    """Stack (x, y) samples into batch tensors after dropping samples the dataset returned as None.

    Returns None if nothing is left.
    """
    inputs, targets = [], []
    for sample in batch:
        if sample is None:
            continue
        x, y = sample
        inputs.append(x)
        targets.append(y)
    if not inputs:
        return None
    return torch.stack(inputs, dim=0), torch.stack(targets, dim=0)


# =========================
# Backbone definitions (ResNet / DenseNet)
# =========================
class ResNetBackbone(nn.Module):
    """ResNet-18 feature stages with a single-channel stem, pooled and projected to a unit-length embedding."""

    # Stages taken from the base network, in order; its max-pool is not among them.
    _STAGES = ("conv1", "bn1", "relu", "layer1", "layer2", "layer3", "layer4")

    def __init__(self, emb_dim=EMB_DIM):
        super().__init__()
        trunk = resnet18(weights=None)
        # Stem for 1-channel input
        trunk.conv1 = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False
        )
        trunk.maxpool = nn.Identity()
        # Feature extraction part
        self.features = nn.Sequential(*(getattr(trunk, stage) for stage in self._STAGES))
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(trunk.fc.in_features, emb_dim)

    def forward(self, x):
        pooled = torch.flatten(self.pool(self.features(x)), 1)
        embedding = self.fc(pooled)            # [B, emb_dim]
        # L2-normalize so dot products between embeddings are cosines.
        return F.normalize(embedding, dim=1)


class DenseNetBackbone(nn.Module):
    """DenseNet-121 feature stack with a single-channel stem, pooled and projected to a unit-length embedding."""

    def __init__(self, emb_dim=EMB_DIM):
        super().__init__()
        densenet = densenet121(weights=None)
        stem = densenet.features
        stem.conv0 = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False
        )
        stem.pool0 = nn.Identity()
        self.features = stem
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # The projection takes the width the original classifier expected.
        self.fc = nn.Linear(densenet.classifier.in_features, emb_dim)

    def forward(self, x):
        activations = F.relu(self.features(x), inplace=True)
        pooled = torch.flatten(self.pool(activations), 1)
        embedding = self.fc(pooled)
        return F.normalize(embedding, dim=1)


# =========================
# ArcFace-style margin product
# =========================
class ArcMarginProduct(nn.Module):
    """
    Margin-based classifier in the ArcFace / AM-Softmax style.
    - Input: embeddings [B, D] (re-normalized inside forward)
    - Output: logits [B, C]
    - With labels: the margin `m` is subtracted from the target-class cosine
    - Without labels (labels=None): plain scaled cosine logits, no margin
    """
    def __init__(self, emb_dim, num_classes, s=30.0, m=0.2):
        super().__init__()
        self.num_classes, self.s, self.m = num_classes, s, m
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, emb_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embeddings, labels=None):
        # Both the class weights and the embeddings are L2-normalized.
        unit_w = F.normalize(self.weight, dim=1)    # [C, D]
        unit_x = F.normalize(embeddings, dim=1)     # [B, D]
        cos = unit_x @ unit_w.t()                   # [B, C]

        if labels is not None:
            # AM-Softmax style: cos - m, applied to the target class only.
            target_mask = F.one_hot(labels.view(-1), num_classes=cos.size(1)).to(cos.dtype)
            cos = cos - self.m * target_mask

        return self.s * cos


# =========================
# ArcFace model wrappers (ResNet / DenseNet)
# =========================
class ResNetArcFaceModel(nn.Module):
    """ResNetBackbone followed by an ArcMarginProduct head; forward returns (logits, embedding)."""

    def __init__(self, num_classes, emb_dim=EMB_DIM, s=30.0, m=0.2):
        super().__init__()
        self.backbone = ResNetBackbone(emb_dim=emb_dim)
        self.head = ArcMarginProduct(emb_dim, num_classes, s=s, m=m)

    def forward(self, x, labels=None):
        embedding = self.backbone(x)
        # `labels` is forwarded as-is; the head decides whether a margin applies.
        return self.head(embedding, labels), embedding


class DenseNetArcFaceModel(nn.Module):
    """DenseNetBackbone followed by an ArcMarginProduct head; forward returns (logits, embedding)."""

    def __init__(self, num_classes, emb_dim=EMB_DIM, s=30.0, m=0.2):
        super().__init__()
        self.backbone = DenseNetBackbone(emb_dim=emb_dim)
        self.head = ArcMarginProduct(emb_dim, num_classes, s=s, m=m)

    def forward(self, x, labels=None):
        embedding = self.backbone(x)
        # `labels` is forwarded as-is; the head decides whether a margin applies.
        return self.head(embedding, labels), embedding


# =========================
# (Optional) Triplet loss support
# =========================
from torch.nn import TripletMarginLoss
# Module-level triplet loss (margin=0.3, p=2) that the training loop applies when `use_triplet` is set.
triplet_loss_fn = TripletMarginLoss(margin=0.3, p=2)


def sample_triplets_in_batch(embeddings, labels):
    """
    Minimal in-batch triplet sampler (illustrative; no hard/semi-hard mining).

    For every sample that has at least one other sample with the same label
    and at least one sample with a different label, one triplet is emitted:
    the anchor itself, the first same-label sample that is not the anchor,
    and the first different-label sample (both in batch order).

    Args:
        embeddings: [B, D] tensor of embeddings.
        labels: tensor with B integer labels (flattened internally).

    Returns:
        ``(anchors, positives, negatives)``, each a [K, D] tensor on the
        device of ``embeddings``, or ``(None, None, None)`` when no valid
        triplet exists in the batch.
    """
    target_device = embeddings.device
    flat_labels = labels.view(-1)

    triplets = []  # (anchor_row, positive_row, negative_row)
    for anchor_pos in range(len(flat_labels)):
        this_label = flat_labels[anchor_pos].item()
        same_mask = flat_labels == this_label

        negatives_pool = (~same_mask).nonzero(as_tuple=True)[0]
        if len(negatives_pool) == 0:
            continue

        # Same-label candidates, excluding the anchor itself.
        positives_pool = same_mask.nonzero(as_tuple=True)[0]
        positives_pool = positives_pool[positives_pool != anchor_pos]
        if len(positives_pool) == 0:
            continue

        triplets.append((
            embeddings[anchor_pos],
            embeddings[positives_pool[0]],
            embeddings[negatives_pool[0]],
        ))

    if not triplets:
        return None, None, None

    anchor_rows, positive_rows, negative_rows = zip(*triplets)
    return (
        torch.stack(anchor_rows).to(target_device),
        torch.stack(positive_rows).to(target_device),
        torch.stack(negative_rows).to(target_device),
    )


# =========================
# 학습 루프
# =========================
def _arc_batch_loss(model, x, y, ce_loss_fn, use_triplet):
    """Forward one batch and return its training loss: CE, plus the weighted triplet term when enabled and triplets exist."""
    logits, emb = model(x, y)
    loss = ce_loss_fn(logits, y)
    if use_triplet:
        a, p, n = sample_triplets_in_batch(emb, y)
        # The sampler returns None when the batch contains no valid triplet;
        # in that case the loss is the plain cross-entropy.
        if a is not None:
            loss = loss + TRIPLET_LAMBDA * triplet_loss_fn(a, p, n)
    return loss


def train_one_epoch_arc(model, loader, optimizer, device, use_amp, epoch, model_name="ResNetArc", use_triplet=False):
    """
    Train ``model`` for one pass over ``loader``.

    The loss is cross-entropy on the model's logits (the model receives the
    labels so its head can apply the training-time margin), optionally plus
    ``TRIPLET_LAMBDA`` times an in-batch triplet loss on the embeddings.
    The loss computation is shared between the AMP and full-precision paths
    through ``_arc_batch_loss`` so the two cannot diverge.

    Args:
        model: module whose ``model(x, y)`` returns ``(logits, embeddings)``.
        loader: iterable of ``(x, y)`` batches; ``None`` batches are skipped.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        use_amp: if True, run the forward pass under CUDA autocast and use a
            gradient scaler for the backward/step.
        epoch: epoch number, used only in progress messages.
        model_name: label used only in progress messages.
        use_triplet: add the triplet term to the loss.

    Returns:
        Mean loss over the batches that were actually processed
        (0.0 if every batch was skipped).
    """
    model.train()
    ce_loss_fn = nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler("cuda") if use_amp else None
    total_loss, steps = 0.0, 0

    for i, batch in enumerate(loader):
        if batch is None:
            continue
        x, y = batch
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad(set_to_none=True)

        if use_amp:
            # Only the forward/loss runs under autocast; backward and the
            # optimizer step go through the scaler outside the context.
            with torch.amp.autocast("cuda"):
                loss = _arc_batch_loss(model, x, y, ce_loss_fn, use_triplet)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = _arc_batch_loss(model, x, y, ce_loss_fn, use_triplet)
            loss.backward()
            optimizer.step()

        total_loss += float(loss.item())
        steps += 1

        # Progress line every 50 loader positions (skipped batches still count
        # toward the position `i`), showing the running mean loss.
        if (i + 1) % 50 == 0:
            print(
                f"[Epoch {epoch}][{model_name}] "
                f"step {i+1}/{len(loader)} "
                f"loss={total_loss/steps:.4f}"
            )

    return total_loss / max(1, steps)


# =========================
# ID 평가 (단일 모델)
# =========================
@torch.no_grad()
def eval_id(model, loader, device, num_classes, topk=TOPK):
    """
    Closed-set identification metrics for a single model.

    The model is called with ``labels=None`` (so its head is asked for
    logits without the training-time margin), logits are soft-maxed, and
    accuracy, macro-F1 and top-k accuracy are computed over the whole loader.

    Args:
        model: module whose ``model(x, labels=None)`` returns ``(logits, _)``.
        loader: iterable of ``(x, y)`` batches; ``None`` batches are skipped.
        device: device the inputs are moved to.
        num_classes: number of classes; also the label set given to top-k.
        topk: iterable of k values; any k larger than ``num_classes`` is ignored.

    Returns:
        Dict with ``"acc"``, ``"macro_f1"`` and ``"top{k}"`` entries, or
        ``None`` if the loader produced no usable batch.
    """
    model.eval()
    prob_chunks = []
    target_chunks = []

    for batch in loader:
        if batch is None:
            continue
        inputs, batch_targets = batch
        inputs = inputs.to(device)

        batch_logits, _ = model(inputs, labels=None)
        prob_chunks.append(batch_logits.softmax(dim=1).cpu().numpy())
        target_chunks.append(batch_targets.numpy())

    if not target_chunks:
        return None

    all_probs = np.concatenate(prob_chunks, axis=0)
    all_targets = np.concatenate(target_chunks, axis=0)
    predicted = all_probs.argmax(axis=1)

    metrics = {}
    metrics["acc"] = accuracy_score(all_targets, predicted)
    metrics["macro_f1"] = f1_score(all_targets, predicted, average="macro")

    for k in topk:
        if k > num_classes:
            continue
        metrics[f"top{k}"] = top_k_accuracy_score(
            all_targets, all_probs, k=k, labels=list(range(num_classes))
        )
    return metrics


# =========================
# ID 평가 (logit 앙상블)
# =========================
@torch.no_grad()
def eval_id_ensemble_logits(rnet, dnet, loader, device, num_classes, topk=TOPK, alpha=0.8):
    """
    Closed-set identification metrics for a two-model logit ensemble.

    Both models are called with ``labels=None``; their logits are blended as
    ``alpha * logits_r + (1 - alpha) * logits_d`` before the softmax.

    Args:
        rnet: first model, weighted by ``alpha``.
        dnet: second model, weighted by ``1 - alpha``.
        loader: iterable of ``(x, y)`` batches; ``None`` batches are skipped.
        device: device the inputs are moved to.
        num_classes: number of classes; also the label set given to top-k.
        topk: iterable of k values; any k larger than ``num_classes`` is ignored.
        alpha: blending weight of ``rnet``'s logits.

    Returns:
        Dict with ``"acc"``, ``"macro_f1"`` and ``"top{k}"`` entries, or
        ``None`` if the loader produced no usable batch.
    """
    for net in (rnet, dnet):
        net.eval()

    prob_chunks = []
    target_chunks = []

    for batch in loader:
        if batch is None:
            continue
        inputs, batch_targets = batch
        inputs = inputs.to(device)

        logits_r, _ = rnet(inputs, labels=None)
        logits_d, _ = dnet(inputs, labels=None)
        blended = alpha * logits_r + (1.0 - alpha) * logits_d

        prob_chunks.append(blended.softmax(dim=1).cpu().numpy())
        target_chunks.append(batch_targets.numpy())

    if not target_chunks:
        return None

    all_probs = np.concatenate(prob_chunks, axis=0)
    all_targets = np.concatenate(target_chunks, axis=0)
    predicted = all_probs.argmax(axis=1)

    metrics = {}
    metrics["acc"] = accuracy_score(all_targets, predicted)
    metrics["macro_f1"] = f1_score(all_targets, predicted, average="macro")

    for k in topk:
        if k > num_classes:
            continue
        metrics[f"top{k}"] = top_k_accuracy_score(
            all_targets, all_probs, k=k, labels=list(range(num_classes))
        )
    return metrics


# =========================
# 임베딩 추출
# =========================
@torch.no_grad()
def extract_embeddings(model, loader, device):
    """
    Run only ``model.backbone`` over a loader and collect the embeddings.

    Args:
        model: module exposing a ``backbone`` sub-module (the head is not used).
        loader: iterable of ``(x, y)`` batches; ``None`` batches are skipped.
        device: device the inputs are moved to for the forward pass.

    Returns:
        ``(embeddings, labels)``: embeddings concatenated along dim 0 on CPU
        ([N, D]) and the matching labels ([N]) in loader order.
    """
    model.eval()
    emb_chunks = []
    label_chunks = []

    usable_batches = (b for b in loader if b is not None)
    for inputs, targets in usable_batches:
        batch_emb = model.backbone(inputs.to(device))
        emb_chunks.append(batch_emb.cpu())
        label_chunks.append(targets)

    stacked_embs = torch.cat(emb_chunks, dim=0)      # [N, D]
    stacked_labels = torch.cat(label_chunks, dim=0)  # [N]
    return stacked_embs, stacked_labels


# =========================
# Embedding-level Fusion
# =========================
def fuse_embeddings(e_r, e_d, alpha=0.8):
    """
    Blend two embedding sets and L2-normalise the result row-wise.

    Args:
        e_r: [N, D] embeddings weighted by ``alpha``.
        e_d: [N, D] embeddings weighted by ``1 - alpha``.
        alpha: blending weight of ``e_r``.

    Returns:
        [N, D] tensor ``normalize(alpha * e_r + (1 - alpha) * e_d)``.

    Note:
        The inputs are assumed to be already normalised; they are not
        re-normalised here, only the blended result is.
    """
    blended = alpha * e_r + (1.0 - alpha) * e_d
    return F.normalize(blended, dim=1)


# =========================
# Verification: EER 계산
# =========================
def compute_eer(scores, labels):
    """
    Equal error rate (EER) of a verification system.

    A trial is accepted when its score is >= the threshold. For each distinct
    score used as threshold the function computes
        FNR = rejected genuine trials / all genuine trials
        FPR = accepted impostor trials / all impostor trials
    and reports the operating point where |FNR - FPR| is smallest.

    Args:
        scores: 1D array-like, higher means "more likely the same speaker".
        labels: 1D array-like, 1 (truthy) = same speaker, 0 = different.

    Returns:
        ``(eer, thr)`` as Python floats: the mean of FNR and FPR at the
        selected point, and the score threshold of that point.
        ``(nan, nan)`` when there is no genuine or no impostor trial, since
        the EER is undefined in that case.

    Raises:
        ValueError: if the inputs are empty or differ in length.
    """
    scores = np.asarray(scores, dtype=np.float64).ravel()
    is_genuine = np.asarray(labels).ravel().astype(bool)

    if scores.size != is_genuine.size:
        raise ValueError(
            f"compute_eer: scores ({scores.size}) and labels ({is_genuine.size}) differ in length"
        )
    if scores.size == 0:
        raise ValueError("compute_eer: at least one trial is required")

    n_genuine = int(is_genuine.sum())
    n_impostor = int(is_genuine.size) - n_genuine
    if n_genuine == 0 or n_impostor == 0:
        return float("nan"), float("nan")

    # Sweep the threshold from the highest score downwards: after position k
    # the first k+1 trials are the accepted ones.
    order = np.argsort(-scores, kind="stable")
    scores_desc = scores[order]
    genuine_desc = is_genuine[order]

    accepted_genuine = np.cumsum(genuine_desc)
    accepted_impostor = np.cumsum(~genuine_desc)

    # Tied scores are accepted or rejected together, so only the last
    # position of each run of equal scores is a real operating point.
    run_ends = np.append(np.nonzero(np.diff(scores_desc))[0], scores_desc.size - 1)

    # FNR counts the genuine trials that were NOT accepted (1 - true-accept rate).
    fnr = 1.0 - accepted_genuine[run_ends] / n_genuine
    fpr = accepted_impostor[run_ends] / n_impostor

    best = int(np.argmin(np.abs(fnr - fpr)))
    eer = (fnr[best] + fpr[best]) / 2.0
    thr = scores_desc[run_ends[best]]

    return float(eer), float(thr)


@torch.no_grad()
def build_verification_pairs(embs, labels, max_pairs_per_spk=50, max_impostor_pairs=10000):
    """
    Build cosine-scored genuine / impostor trial pairs from labelled embeddings.

    Genuine trials: for each speaker with at least two embeddings, all index
    pairs are enumerated, shuffled with the global ``random`` module, and the
    first ``max_pairs_per_spk`` are scored.
    Impostor trials: index pairs are drawn at random; pairs with equal labels
    are discarded. Sampling stops after ``max_impostor_pairs`` accepted pairs
    or ``10 * max_impostor_pairs`` draws, whichever comes first.

    Args:
        embs: [N, D] torch.Tensor of embeddings.
        labels: [N] torch.Tensor of class indices.
        max_pairs_per_spk: cap on genuine pairs per speaker.
        max_impostor_pairs: cap on impostor pairs overall.

    Returns:
        ``(scores, pair_labels)``: float32 cosine scores and int64 labels
        (1 = genuine, 0 = impostor); genuine trials come first.
    """
    embs = embs.cpu()
    labels = labels.cpu().numpy()
    num_items = embs.size(0)

    def cosine_of(first, second):
        # Cosine similarity between two single embeddings, as a Python float.
        return F.cosine_similarity(embs[first].unsqueeze(0), embs[second].unsqueeze(0)).item()

    # Group row indices by speaker, preserving first-seen speaker order.
    members_by_spk = {}
    for position, speaker in enumerate(labels):
        if speaker not in members_by_spk:
            members_by_spk[speaker] = []
        members_by_spk[speaker].append(position)

    # Genuine trials.
    genuine_scores = []
    for members in members_by_spk.values():
        if len(members) < 2:
            continue
        # The number of combinations can be large, so only a shuffled subset is used.
        candidate_pairs = list(combinations(members, 2))
        random.shuffle(candidate_pairs)
        for first, second in candidate_pairs[:max_pairs_per_spk]:
            genuine_scores.append(cosine_of(first, second))

    # Impostor trials (rejection sampling with a bounded number of draws).
    population = list(range(num_items))
    impostor_scores = []
    draws = 0
    draw_budget = max_impostor_pairs * 10
    while draws < draw_budget and len(impostor_scores) < max_impostor_pairs:
        draws += 1
        first, second = random.sample(population, 2)
        if labels[first] == labels[second]:
            continue
        impostor_scores.append(cosine_of(first, second))

    scores = np.array(genuine_scores + impostor_scores, dtype=np.float32)
    pair_labels = np.array(
        [1] * len(genuine_scores) + [0] * len(impostor_scores), dtype=np.int64
    )
    return scores, pair_labels


@torch.no_grad()
def eval_verification_from_embeddings(embs, labels):
    """
    Score verification trials built from ``embs``/``labels`` and summarise them.

    Args:
        embs: [N, D] torch.Tensor of embeddings.
        labels: [N] torch.Tensor of class indices.

    Returns:
        Dict ``{"eer": ..., "thr": ...}`` holding the two values returned by
        ``compute_eer`` for the sampled trial pairs.
    """
    trial_scores, trial_labels = build_verification_pairs(embs, labels)
    eer_value, threshold = compute_eer(trial_scores, trial_labels)
    return dict(eer=eer_value, thr=threshold)


# =========================
# Train/Val/Test split
# =========================
def split_indices(n, seed=42, train=0.8, val=0.1):
    """
    Shuffle ``range(n)`` deterministically and cut it into train/val/test.

    Args:
        n: number of items to split.
        seed: seed of the private ``random.Random`` used for shuffling
            (the global RNG state is not touched).
        train: fraction of items for the training part (rounded down).
        val: fraction of items for the validation part (rounded down).

    Returns:
        ``(train_idxs, val_idxs, test_idxs)`` lists; the test part receives
        whatever remains after the first two cuts.
    """
    shuffled = list(range(n))
    random.Random(seed).shuffle(shuffled)

    train_end = int(n * train)
    val_end = train_end + int(n * val)

    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]


# =========================
# 메인 실행
# =========================
def main():
    """
    Run the whole experiment: load LibriSpeech dev-clean, train a ResNet and a
    DenseNet margin-head model, report identification metrics (single models
    and logit ensemble) on val/test, then report verification EER from the
    test-set embeddings (single models and fused embeddings).

    Results are printed; nothing is returned.
    """
    os.makedirs("./data", exist_ok=True)

    # Only the dev-clean subset is used (download=True is passed to torchaudio).
    raw = LIBRISPEECH("./data", url="dev-clean", download=True)

    # Collect every speaker id, then keep only the first MAX_SPEAKERS (sorted by id).
    # NOTE(review): this iterates the whole dataset just to read speaker ids,
    # which presumably loads every audio file once -- confirm and consider caching.
    all_speakers = sorted({int(spk) for _, _, _, spk, *_ in raw})
    speakers = all_speakers[:MAX_SPEAKERS]
    spk2idx = {spk: i for i, spk in enumerate(speakers)}
    num_classes = len(speakers)

    print(f"Using speakers: {num_classes}/{len(all_speakers)} | MAX_SPEAKERS={MAX_SPEAKERS}")

    full_ds = LibriSpeechSpeakerDataset(
        raw, spk2idx,
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        target_frames=TARGET_FRAMES
    )

    # Utterance-level train/val/test split: indices of full_ds are shuffled and
    # cut 80/10/10, so the same speakers can appear in all three parts.
    tr_idx, va_idx, te_idx = split_indices(len(full_ds), seed=42, train=0.8, val=0.1)
    train_ds = Subset(full_ds, tr_idx)
    val_ds   = Subset(full_ds, va_idx)
    test_ds  = Subset(full_ds, te_idx)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision is enabled only when running on CUDA.
    use_amp = (device.type == "cuda")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=0, collate_fn=collate_skip_none)
    val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=0, collate_fn=collate_skip_none)
    test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=0, collate_fn=collate_skip_none)

    # Build the two models (same head hyper-parameters) and one Adam optimizer each.
    rnet = ResNetArcFaceModel(num_classes=num_classes, emb_dim=EMB_DIM, s=30.0, m=0.2).to(device)
    dnet = DenseNetArcFaceModel(num_classes=num_classes, emb_dim=EMB_DIM, s=30.0, m=0.2).to(device)

    opt_r = optim.Adam(rnet.parameters(), lr=1e-3)
    opt_d = optim.Adam(dnet.parameters(), lr=1e-3)

    # =========================
    # Training
    # =========================
    for ep in range(1, EPOCHS + 1):
        # ResNet is always trained with cross-entropy only ...
        loss_r = train_one_epoch_arc(
            rnet, train_loader, opt_r, device, use_amp,
            epoch=ep, model_name="ResNetArc", use_triplet=False
        )

        # ... while DenseNet follows the USE_TRIPLET switch.
        # NOTE(review): the two models are trained with different loss settings
        # whenever USE_TRIPLET is True -- confirm this asymmetry is intended.
        loss_d = train_one_epoch_arc(
            dnet, train_loader, opt_d, device, use_amp,
            epoch=ep, model_name="DenseNetArc", use_triplet=USE_TRIPLET
        )

        # Per-epoch validation: each model alone, then the logit ensemble
        # (alpha=0.8 weights the ResNet logits).
        val_r = eval_id(rnet, val_loader, device, num_classes, topk=TOPK)
        val_d = eval_id(dnet, val_loader, device, num_classes, topk=TOPK)
        val_e = eval_id_ensemble_logits(rnet, dnet, val_loader, device, num_classes, topk=TOPK, alpha=0.8)

        print(f"\n[Epoch {ep}/{EPOCHS}]")
        print(f"  ResNet   loss={loss_r:.4f}  val={val_r}")
        print(f"  DenseNet loss={loss_d:.4f}  val={val_d}")
        print(f"  Ensemble(logits) val={val_e}")

    # =========================
    # Final test evaluation (identification)
    # =========================
    test_r = eval_id(rnet, test_loader, device, num_classes, topk=TOPK)
    test_d = eval_id(dnet, test_loader, device, num_classes, topk=TOPK)
    test_e = eval_id_ensemble_logits(rnet, dnet, test_loader, device, num_classes, topk=TOPK, alpha=0.8)

    print("\n=== TEST RESULTS (Identification, dev-clean split) ===")
    print(f"ResNet   {test_r}")
    print(f"DenseNet {test_d}")
    print(f"Ensemble(logits) {test_e}")

    # =========================
    # Verification (EER) evaluation
    # =========================
    emb_r, lab_r = extract_embeddings(rnet, test_loader, device)
    emb_d, lab_d = extract_embeddings(dnet, test_loader, device)

    # Sanity check: both passes over the (unshuffled) test loader must yield
    # the same label sequence, otherwise the embeddings cannot be fused row-wise.
    assert torch.all(lab_r.eq(lab_d)), "ResNet/DenseNet test labels mismatch!"
    labels = lab_r

    # EER for each model alone and for the fused embedding (alpha=0.8 weights ResNet).
    ver_r = eval_verification_from_embeddings(emb_r, labels)
    ver_d = eval_verification_from_embeddings(emb_d, labels)
    emb_f = fuse_embeddings(emb_r, emb_d, alpha=0.8)
    ver_e = eval_verification_from_embeddings(emb_f, labels)

    print("\n=== TEST RESULTS (Verification, cosine EER) ===")
    print(f"ResNet   VER {ver_r}")
    print(f"DenseNet VER {ver_d}")
    print(f"Ensemble(emb) VER {ver_e}")


# Start the experiment when this code is executed directly (running the
# notebook cell also satisfies this check), but not when it is imported.
if __name__ == "__main__":
    main()

Using speakers: 40/40 | MAX_SPEAKERS=40
[Epoch 1][ResNetArc] step 50/271 loss=10.1360
[Epoch 1][ResNetArc] step 100/271 loss=9.3853
[Epoch 1][ResNetArc] step 150/271 loss=8.8575
[Epoch 1][ResNetArc] step 200/271 loss=8.5821
[Epoch 1][ResNetArc] step 250/271 loss=8.3183
