In [None]:
# Install SoX and libsox so torchaudio.sox_effects (used by the dataset's augment_wave) can run.
!apt-get install -y sox libsox-dev libsox-fmt-all

In [None]:
from torchvision import datasets
import torchvision
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchaudio
import os
import random
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import copy


In [None]:
import os
from pathlib import Path

def collect_audio_list(root_fpt, root_edge, root_cut):
    """Gather ``(path, label)`` pairs for the training pool.

    Label 1 marks a "true" sample; everything else is 0. Sources:
      * ``root_fpt/{false_no,false_yes,true}/<speaker>/*.wav`` — label 1 only
        for the ``true`` folder; missing sub-folders are skipped;
      * ``root_edge/<speaker>/<sub>/*.mp3`` — label 1 when ``<sub>`` is named
        "true" (case-insensitive), otherwise 0;
      * ``root_cut/*.mp3`` — always label 0.
    Non-directory entries at the speaker / sub-folder levels are ignored.

    Every directory listing is sorted: ``iterdir``/``glob`` order is
    filesystem-dependent, and the caller's seeded shuffle/split is only
    reproducible if the input order is deterministic.

    Args:
        root_fpt, root_edge, root_cut: paths (str or Path) of the three roots.

    Returns:
        list of ``(str_path, int_label)``.
    """
    data = []

    for sub in ["false_no", "false_yes", "true"]:
        sub_path = Path(root_fpt) / sub
        if not sub_path.exists():
            continue

        label = 1 if sub == "true" else 0

        for speaker_dir in sorted(sub_path.iterdir()):
            if not speaker_dir.is_dir():
                continue
            for wav in sorted(speaker_dir.glob("*.wav")):
                data.append((str(wav), label))

    edge_root = Path(root_edge)
    for speaker_dir in sorted(edge_root.iterdir()):
        if not speaker_dir.is_dir():
            continue
        for sub in sorted(speaker_dir.iterdir()):
            if not sub.is_dir():
                continue
            label = 1 if sub.name.lower() == "true" else 0
            for mp3 in sorted(sub.glob("*.mp3")):
                data.append((str(mp3), label))

    cut_root = Path(root_cut)
    for mp3 in sorted(cut_root.glob("*.mp3")):
        data.append((str(mp3), 0))

    # Summary line ("Tổng cộng" = "Total"): file count, positives, negatives.
    n_pos = sum(l for _, l in data)
    print(f"Tổng cộng {len(data)} file — {n_pos} label=1, {len(data)-n_pos} label=0")
    return data

# Kaggle dataset roots for the three training sources (layouts: see collect_audio_list).
root_fpt = "/kaggle/input/voice-fpt-aip491/Data_voices/Data_voices/FPT.AI"
root_edge = "/kaggle/input/voice-fpt-aip491/Data_voices/Data_voices/edge_voices_16k"
root_cut = "/kaggle/input/voice-fpt-aip491/Data_voices/Data_voices/cut_sound"

data_train = collect_audio_list(root_fpt=root_fpt, root_edge=root_edge, root_cut=root_cut)

# Seed the global RNG, shuffle the pool in place, then hold out the first 5%
# as the validation split; the remainder becomes the training split.
random.seed(35)
random.shuffle(data_train)

n_total = len(data_train)
n_valid = int(0.05 * n_total)
data_valid, data_train = data_train[:n_valid], data_train[n_valid:]

In [None]:
def collect_test_audio_list(root_test):
    """Gather ``(path, label)`` pairs for the held-out test set.

    Expected layout: ``root_test/<speaker>/<sub>/*.{mp3,wav}``. Files under a
    sub-folder named "true" (case-insensitive) get label 1; every other
    sub-folder gets label 0. Non-directory entries at either level are ignored.

    Listings are sorted so the returned order is reproducible across
    filesystems (``iterdir``/``glob`` give no ordering guarantee).

    Args:
        root_test: path (str or Path) of the test root.

    Returns:
        list of ``(str_path, int_label)``; within each sub-folder the ``.mp3``
        files come before the ``.wav`` files.
    """
    data = []
    root_test = Path(root_test)

    for speaker_dir in sorted(root_test.iterdir()):
        if not speaker_dir.is_dir():
            continue
        for sub_dir in sorted(speaker_dir.iterdir()):
            if not sub_dir.is_dir():
                continue

            label = 1 if sub_dir.name.lower() == "true" else 0

            for pattern in ("*.mp3", "*.wav"):
                for audio_file in sorted(sub_dir.glob(pattern)):
                    data.append((str(audio_file), label))

    # Summary line ("Tổng cộng" = "Total"): file count, positives, negatives.
    n_pos = sum(l for _, l in data)
    print(f"Tổng cộng {len(data)} file — {n_pos} label=1, {len(data)-n_pos} label=0")
    return data


# Held-out test set root (layout parsed by collect_test_audio_list).
root_test = "/kaggle/input/test-aip419/Datatest"
data_test = collect_test_audio_list(root_test=root_test)


In [None]:
def swish(x):
    """Swish / SiLU activation, applied element-wise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate

class FeedForward(nn.Module):
    """Position-wise feed-forward branch for ``[B, T, H]`` inputs.

    Linear(dim -> dim * expansion) -> swish -> dropout -> Linear(back to dim)
    -> dropout. The output shape equals the input shape.
    """

    def __init__(self, dim, expansion=4, dropout=0.1):
        super().__init__()
        inner_dim = dim * expansion
        self.fc1 = nn.Linear(dim, inner_dim)
        self.fc2 = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.act = swish

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class ConvModule(nn.Module):
    """Conformer convolution branch for ``[B, T, H]`` inputs.

    LayerNorm -> pointwise expand (H -> 2H) -> GLU (back to H) -> depthwise
    Conv1d over time -> BatchNorm -> swish -> pointwise projection -> dropout.
    The output has the same ``[B, T, H]`` shape as the input so it can be
    added to the residual stream.
    """

    def __init__(self, dim, kernel_size=31, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        self.pointwise1 = nn.Linear(dim, dim*2)
        # padding="same" keeps the time length unchanged for any kernel size.
        # For odd kernels it equals (kernel_size - 1) // 2 on each side; an
        # explicit (kernel_size - 1) // 2 would drop one frame for even kernels
        # and break the residual addition in the enclosing block.
        self.depthwise = nn.Conv1d(dim, dim, kernel_size, padding="same", groups=dim)
        self.bn = nn.BatchNorm1d(dim)
        self.pointwise2 = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.act = swish

    def forward(self, x):
        # x: [B, T, H]
        out = self.layer_norm(x)
        out = self.pointwise1(out)           # [B, T, 2H]
        out = F.glu(out, dim=-1)             # [B, T, H]
        out = out.transpose(1, 2)            # Conv1d / BatchNorm1d want [B, H, T]
        out = self.depthwise(out)
        out = self.bn(out)
        out = out.transpose(1, 2)            # back to [B, T, H]
        out = self.act(out)
        out = self.pointwise2(out)
        out = self.dropout(out)
        return out

class BinaryGate(nn.Module):
    """Per-example keep/skip gate computed from a pooled ``[B, H]`` feature.

    ``forward`` returns ``(keep, probs)``: ``keep`` is ``[B, 1]`` with values
    in {0, 1} and ``probs`` is the ``[B, 2]`` softmax over ``[keep, skip]``.
    In training mode ``keep`` comes from a hard (straight-through)
    Gumbel-softmax sample at temperature ``tau``. Otherwise it is the
    deterministic test ``P(keep) > thresh``.
    """

    def __init__(self, H):
        super().__init__()
        self.logits = nn.Linear(H, 2)  # column 0 = keep, column 1 = skip

    def forward(self, x_mean, tau=1.0, training=True, thresh=0.5):
        scores = self.logits(x_mean)
        probs = F.softmax(scores, dim=-1)
        if not training:
            keep = (probs[:, 0] > thresh).float()
        else:
            keep = F.gumbel_softmax(scores, tau=tau, hard=True, dim=-1)[:, 0]
        return keep.unsqueeze(-1), probs


class ConformerBlockWithGates(nn.Module):
    """Conformer block whose four residual branches are switched by gates.

    Order: half-step FF -> pre-norm MHSA -> conv module -> half-step FF ->
    final LayerNorm. Before each branch, a ``BinaryGate`` looks at the
    time-averaged current state and emits a per-example keep in {0, 1}. The
    branch output is added to the residual stream only where keep == 1.
    Branch outputs are computed either way, so gating saves no compute here.
    """

    def __init__(self, dim, nhead=4, ff_expansion=4, conv_kernel=31, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.ff1 = FeedForward(dim, expansion=ff_expansion, dropout=dropout)
        self.attn_ln = nn.LayerNorm(dim)
        self.mhsa = nn.MultiheadAttention(embed_dim=dim, num_heads=nhead, batch_first=True, dropout=dropout)
        self.conv = ConvModule(dim, kernel_size=conv_kernel, dropout=dropout)
        self.ff2 = FeedForward(dim, expansion=ff_expansion, dropout=dropout)

        self.g_ff1 = BinaryGate(dim)
        self.g_attn = BinaryGate(dim)
        self.g_conv = BinaryGate(dim)
        self.g_ff2 = BinaryGate(dim)

        self.ff_scale = 0.5  # Conformer "half-step" feed-forward residual
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim)

    def _attend(self, h):
        """Pre-norm multi-head self-attention branch followed by dropout."""
        q = self.attn_ln(h)
        attn_out, _ = self.mhsa(q, q, q, need_weights=False)
        return self.dropout(attn_out)

    def _gated_residual(self, out, gate, branch, tau, training):
        """Return ``(out + keep * branch(out), gate_probs)``.

        The gate is evaluated before the branch runs, as in the original
        statement order. ``keep`` is ``[B, 1]`` and is broadcast over time.
        """
        keep, probs = gate(out.mean(dim=1), tau=tau, training=training)
        branch_out = branch(out)
        return out + keep.unsqueeze(1) * branch_out, probs

    def forward(self, x, tau=1.0, training=None):
        """Run the gated block.

        Args:
            x: ``[B, T, H]`` sequence.
            tau: Gumbel-softmax temperature used when gates are sampled.
            training: True -> sampled (Gumbel) gates; False -> deterministic
                thresholded gates; None (default) -> follow ``self.training``,
                so ``model.eval()`` switches the gates to deterministic mode.

        Returns:
            ``(out, probs)``: ``out`` is ``[B, T, H]``; ``probs`` maps
            "ff1"/"attn"/"conv"/"ff2" to the ``[B, 2]`` gate probabilities.

        Raises:
            ValueError: if ``x`` is not 3-D.
        """
        if x.dim() != 3:
            raise ValueError(f"expected x of shape [B, T, H], got {tuple(x.shape)}")
        if training is None:
            training = self.training

        out = x
        out, p_ff1 = self._gated_residual(
            out, self.g_ff1, lambda h: self.ff_scale * self.ff1(h), tau, training)
        out, p_attn = self._gated_residual(out, self.g_attn, self._attend, tau, training)
        out, p_conv = self._gated_residual(out, self.g_conv, self.conv, tau, training)
        out, p_ff2 = self._gated_residual(
            out, self.g_ff2, lambda h: self.ff_scale * self.ff2(h), tau, training)

        out = self.layer_norm(out)
        probs = {"ff1": p_ff1, "attn": p_attn, "conv": p_conv, "ff2": p_ff2}
        return out, probs


class SubsampleConv(nn.Module):
    """Convolutional front-end: ``[B, in_channels, n_mels, T]`` -> ``[B, T', out_dim]``.

    Two 3x3 conv + BatchNorm + ReLU stages stride the time axis by 4 and then
    by 2 while keeping the mel axis (stride 1, padding 1). A final
    ``(n_mels x 1)`` conv + BatchNorm + ReLU collapses the mel axis to size 1,
    which is then squeezed away.
    """

    def __init__(self, in_channels=1, out_dim=80, n_mels=40):
        super().__init__()
        square = dict(kernel_size=(3, 3), padding=(1, 1), bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_dim, stride=(1, 4), **square)
        self.bn1 = nn.BatchNorm2d(out_dim)

        self.conv2 = nn.Conv2d(out_dim, out_dim, stride=(1, 2), **square)
        self.bn2 = nn.BatchNorm2d(out_dim)

        self.conv_collapse = nn.Conv2d(out_dim, out_dim, kernel_size=(n_mels, 1),
                                       stride=(1, 1), bias=False)
        self.bn_collapse = nn.BatchNorm2d(out_dim)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv_collapse, self.bn_collapse),
        )
        h = x
        for conv, norm in stages:
            h = self.act(norm(conv(h)))
        # [B, C, 1, T'] -> [B, C, T'] -> [B, T', C]
        return h.squeeze(2).transpose(1, 2)


class ConformerEncoderWithGates(nn.Module):
    """``SubsampleConv`` front-end followed by a stack of gated Conformer blocks.

    Args:
        input_n_mels: mel bands of the input features (height collapsed by the front-end).
        hidden_dim: model width H.
        num_blocks, nhead, ff_expansion, conv_kernel, dropout: block hyper-parameters.
        agg: stored as ``self.agg`` (a set) but not used by ``forward``.
    """

    def __init__(self, input_n_mels=40, hidden_dim=80, num_blocks=8,
                 nhead=4, ff_expansion=4, conv_kernel=31, dropout=0.1,
                 agg=(2, 4, 6)):
        super().__init__()
        self.agg = set(agg)
        self.subsample = SubsampleConv(in_channels=1, out_dim=hidden_dim, n_mels=input_n_mels)
        self.blocks = nn.ModuleList([
            ConformerBlockWithGates(hidden_dim, nhead=nhead, ff_expansion=ff_expansion,
                                    conv_kernel=conv_kernel, dropout=dropout)
            for _ in range(num_blocks)
        ])

    def forward(self, x, tau=1.0, training=None):
        """Encode ``x`` (``[B, 1, n_mels, T]``) into ``[B, T', H]``.

        Args:
            tau: Gumbel-softmax temperature passed to every block.
            training: gate mode forwarded to every block; None (default)
                follows ``self.training`` so ``.eval()`` yields deterministic gates.

        Returns:
            ``(features, gate_probs)``, where ``gate_probs`` holds one dict
            per block (see ``ConformerBlockWithGates.forward``).
        """
        if training is None:
            training = self.training
        out = self.subsample(x)               # [B, T', H]
        gate_probs = []

        for blk in self.blocks:
            out, probs = blk(out, tau=tau, training=training)
            gate_probs.append(probs)

        return out, gate_probs


class ConformerWithGates(nn.Module):
    """Gated Conformer classifier.

    Pipeline: encoder -> mean over time -> Linear(hidden_dim -> emb_dim) ->
    BatchNorm1d -> NaN/inf scrub -> L2 normalize -> dropout -> Linear(emb_dim ->
    num_classes).

    Args:
        num_classes: number of output logits. Required in practice, because
            ``nn.Linear`` needs an int; this notebook uses 1 with BCEWithLogitsLoss.
        n_mels, hidden_dim, num_blocks, nhead, ff_expansion, conv_kernel, dropout:
            forwarded to ``ConformerEncoderWithGates``.
        emb_dim: size of the pooled embedding.
    """

    def __init__(self, num_classes=None, n_mels=40, hidden_dim=80, num_blocks=8,
                 nhead=4, ff_expansion=4, conv_kernel=31, dropout=0.1, emb_dim=192):
        super().__init__()
        self.encoder = ConformerEncoderWithGates(
            input_n_mels=n_mels,
            hidden_dim=hidden_dim,
            num_blocks=num_blocks,
            nhead=nhead,
            ff_expansion=ff_expansion,
            conv_kernel=conv_kernel,
            dropout=dropout
        )
        self.emb_dim = emb_dim
        self.post = nn.Linear(hidden_dim, emb_dim)
        self.bn = nn.BatchNorm1d(emb_dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(emb_dim, num_classes)

    def forward(self, x, tau=1.0, training=None):
        """Classify a batch of features.

        Args:
            x: ``[B, 1, n_mels, T]`` log-mel batch.
            tau: Gumbel-softmax temperature for the gates.
            training: gate mode. True samples gates with Gumbel noise; False
                uses deterministic thresholded gates. None (default) follows
                ``self.training``, so ``model.eval()`` no longer leaves
                stochastic gates active during validation/test.

        Returns:
            ``(emb, logits, gate_probs)``: ``emb`` is ``[B, emb_dim]``
            (unit-norm before dropout), ``logits`` is ``[B, num_classes]``,
            and ``gate_probs`` is the per-block list from the encoder.
        """
        if training is None:
            training = self.training
        feats, gate_probs = self.encoder(x, tau=tau, training=training)
        mean = feats.mean(1)                  # pool over time -> [B, H]

        emb = self.bn(self.post(mean))
        # Replace NaN/inf before normalizing so they don't propagate into the norm.
        emb = F.normalize(torch.nan_to_num(emb), p=2, dim=1)
        emb = self.dropout(emb)

        logits = self.classifier(emb)
        return emb, logits, gate_probs

In [None]:
import torch as pt
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from tqdm import tqdm
import time

def train_one_epo(loader, model, optimizer, criterion, device):
    """Run one training epoch.

    Each batch is moved to ``device``, labels become ``[B, 1]`` floats, and the
    model's logits (second element of its 3-tuple output) are fed to
    ``criterion``, followed by backprop and an optimizer step.

    Returns:
        ``(mean_loss_per_sample, accuracy)``; accuracy uses a 0.5 sigmoid threshold.
    """
    model.train()
    running_loss = 0
    pred_chunks, label_chunks = [], []

    for features, targets in tqdm(loader, desc="Train", leave=False):
        features = features.to(device)
        targets = targets.to(device).float().unsqueeze(1)

        optimizer.zero_grad()
        _emb, logits, _gates = model(features)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item() * features.size(0)

        batch_probs = pt.sigmoid(logits).detach().cpu().numpy()
        pred_chunks.extend((batch_probs >= 0.5).astype(int))
        label_chunks.extend(targets.cpu().numpy())

    epoch_acc = accuracy_score(label_chunks, pred_chunks)
    return running_loss / len(loader.dataset), epoch_acc


def valid_at_epo(loader, model, criterion, device, threshold=0.5):
    """Evaluate ``model`` on ``loader`` at a fixed decision threshold.

    Args:
        loader: yields ``(features, labels)`` batches with 0/1 labels.
        model: returns ``(emb, logits, gate_probs)`` and accepts a ``training``
            keyword (as ``ConformerWithGates`` does); logits must match the
            ``[B, 1]`` float labels given to ``criterion``.
        criterion: loss on ``(logits, labels)``.
        device: device the batches are moved to.
        threshold: sigmoid score at or above which a sample is predicted positive.

    Returns:
        ``(mean_loss, accuracy, roc_auc, f1, avg_forward_seconds_per_batch)``;
        ``roc_auc`` is NaN when ``roc_auc_score`` rejects the labels
        (e.g. a single class).
    """
    model.eval()
    total_loss = 0
    all_preds, all_labels = [], []
    all_probs = []
    batch_times = []

    with pt.no_grad():
        for xb, yb in tqdm(loader, desc="Valid", leave=False):
            xb, yb = xb.to(device), yb.to(device).float().unsqueeze(1)

            # CUDA kernels launch asynchronously: synchronize on both sides of
            # the forward pass so the measured time covers the actual compute.
            if xb.is_cuda:
                pt.cuda.synchronize()
            start = time.perf_counter()
            # training=False -> deterministic thresholded gates instead of
            # Gumbel-sampled ones, whatever the model's default is.
            _, logits, _ = model(xb, training=False)
            if xb.is_cuda:
                pt.cuda.synchronize()
            batch_times.append(time.perf_counter() - start)

            loss = criterion(logits, yb)
            total_loss += loss.item() * xb.size(0)

            # Flatten [B, 1] -> [B] so the metric inputs are plain 1-D sequences.
            probs = pt.sigmoid(logits).cpu().numpy().reshape(-1)
            preds = (probs >= threshold).astype(int)

            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(yb.cpu().numpy().reshape(-1))

    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    try:
        auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        auc = float('nan')

    avg_forward_time = sum(batch_times) / len(batch_times) if batch_times else 0.0
    return total_loss / len(loader.dataset), acc, auc, f1, avg_forward_time


def valid_at_epo_t(loader, model, criterion, device, step=0.01):
    """Evaluate ``model`` and pick the threshold that maximizes F1 * accuracy.

    Thresholds ``0, step, 2*step, ...`` up to ``int(1 / step) * step`` are
    swept over the collected sigmoid scores. Ties keep the lowest threshold.

    NOTE: the threshold is selected on the same data it is scored on. When
    ``loader`` is the test set, the returned accuracy/F1 are an optimistic
    (oracle) upper bound rather than an unbiased estimate.

    Args:
        loader, model, criterion, device: as in ``valid_at_epo``.
        step: threshold grid spacing.

    Returns:
        ``(mean_loss, best_acc, roc_auc, best_f1, best_thr, avg_forward_seconds_per_batch)``.
        ``best_acc`` and ``best_f1`` are always the metrics at ``best_thr``.
    """
    model.eval()
    total_loss = 0
    all_labels, all_probs = [], []
    batch_times = []

    with pt.no_grad():
        for xb, yb in tqdm(loader, desc="Test", leave=False):
            xb, yb = xb.to(device), yb.to(device).float().unsqueeze(1)

            # Synchronize around the forward pass so asynchronous CUDA
            # execution does not make the timing measure only kernel launches.
            if xb.is_cuda:
                pt.cuda.synchronize()
            start = time.perf_counter()
            # Deterministic (thresholded) gates for evaluation.
            _, logits, _ = model(xb, training=False)
            if xb.is_cuda:
                pt.cuda.synchronize()
            batch_times.append(time.perf_counter() - start)

            loss = criterion(logits, yb)
            total_loss += loss.item() * xb.size(0)

            all_probs.extend(pt.sigmoid(logits).cpu().numpy().reshape(-1))
            all_labels.extend(yb.cpu().numpy().reshape(-1))

    # Convert once rather than rebuilding a tensor from the list at every threshold.
    probs_arr = np.asarray(all_probs)
    labels_arr = np.asarray(all_labels)

    thresholds = [i * step for i in range(int(1 / step) + 1)]
    # Start below any achievable score (f1 * acc >= 0) so the returned
    # threshold/acc/f1 are always a consistent triple, even if every score is 0.
    best_s, best_f1, best_thr, best_acc = -1.0, 0, 0.5, 0

    for thr in thresholds:
        preds = (probs_arr >= thr).astype(int)
        f1 = f1_score(labels_arr, preds)
        acc = accuracy_score(labels_arr, preds)
        score = f1 * acc
        if score > best_s:
            best_s, best_f1, best_thr, best_acc = score, f1, thr, acc

    try:
        auc = roc_auc_score(labels_arr, probs_arr)
    except ValueError:
        auc = float('nan')

    avg_forward_time = sum(batch_times) / len(batch_times) if batch_times else 0.0
    return total_loss / len(loader.dataset), best_acc, auc, best_f1, best_thr, avg_forward_time


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import random
from torch.utils.data import Dataset


def make_log_mel_transform(sample_rate=16000,
                           n_fft=400,        # 25 ms window at 16 kHz
                           hop_length=160,   # 10 ms hop at 16 kHz
                           n_mels=40):
    """Return a closure mapping a waveform to a standardized log-mel spectrogram.

    The closure accepts a 1-D ``[N]`` waveform (promoted to ``[1, N]``) or an
    already channel-first tensor. It computes a power mel spectrogram (Hann
    window, ``win_length == n_fft``, centered frames) and adds a 1e-4 floor.
    It then converts to dB with ``amplitude_to_DB`` and standardizes with the
    mean and population std of the whole spectrogram, computed per call,
    i.e. per sample.

    Returns:
        callable ``waveform -> log_mel``; for a 1-D input the result is
        ``[1, n_mels, frames]``.
    """
    floor = 1e-4
    db_reference = torch.log10(torch.tensor(1e-4))
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        window_fn=torch.hann_window,
        power=2.0,
        center=True,
    )

    def transform(waveform):
        signal = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
        power_mel = to_mel(signal) + floor
        log_mel = torchaudio.functional.amplitude_to_DB(
            power_mel, multiplier=10, amin=floor, db_multiplier=db_reference
        )
        # Per-sample standardization; the tiny epsilon guards against /0 on silence.
        mean, std = log_mel.mean(), log_mel.std(unbiased=False)
        return (log_mel - mean) / (std + 1e-9)

    return transform



class SpeakerClassificationFeatureDataset(Dataset):
    """Audio-file dataset yielding fixed-length log-mel features and labels.

    Each item is ``(log_mel, label)``. ``log_mel`` is ``[1, n_mels, frames]``,
    computed from a waveform of exactly ``fixed_len`` samples at ``sample_rate``.

    Args:
        file_list: sequence of ``(path, int_label)`` pairs.
        fixed_len: clip length in samples at ``sample_rate``.
        one_hot: return a float one-hot vector of length ``num_speakers``
            instead of a long scalar label.
        augment: when True, apply random waveform augmentation (noise, gain,
            SoX tempo/pitch/reverb) and random frequency/time masking, and
            place crops/padding at random offsets. When False, crops and
            padding are centered so evaluation is deterministic.
        sample_rate: target rate; files at other rates are resampled.
        n_mels: number of mel bands.
    """

    def __init__(self, file_list, fixed_len=40000, one_hot=False,
                 augment=False, sample_rate=16000, n_mels=40):

        self.fixed_len = fixed_len
        self.one_hot = one_hot
        self.augment = augment
        self.sample_rate = sample_rate
        self.n_mels = n_mels

        self.files = [f for f, _ in file_list]
        self.labels = [lbl for _, lbl in file_list]
        # Number of distinct labels; used as the one-hot vector length.
        self.num_speakers = len(set(self.labels))

        # Log-mel transform (400 / 160 samples = 25 ms / 10 ms at 16 kHz).
        self.logmel_fn = make_log_mel_transform(
            sample_rate=sample_rate,
            n_fft=400,
            hop_length=160,
            n_mels=n_mels
        )

    def __len__(self):
        return len(self.files)

    def augment_wave(self, wav):
        """Randomly perturb a 1-D waveform and return a 1-D waveform.

        Steps: additive Gaussian noise, random gain, then SoX tempo, pitch and
        reverb. Tempo changes the length; ``__getitem__`` pads or crops the
        result back to ``fixed_len``. SoX output channels are averaged down to 1-D.
        """
        sr = self.sample_rate

        # noise
        noise_level = random.uniform(0.00, 0.015)
        wav = wav + torch.randn_like(wav) * noise_level

        # gain
        gain = random.uniform(0.5, 1.25)
        wav = wav * gain

        # time-stretch & pitch
        speed = random.uniform(0.85, 1.75)
        pitch = random.uniform(-3, 3)  # semitones
        reverb_wet = random.uniform(0.01, 0.6) * 100
        room_size = random.uniform(0, 100)

        effects = [
            ['tempo', str(speed)],
            ['pitch', str(pitch * 100)],  # Sox expects cents
            ['reverb', str(reverb_wet), str(reverb_wet), str(room_size)]
        ]
        wav, _ = torchaudio.sox_effects.apply_effects_tensor(wav.unsqueeze(0), sr, effects)
        wav = wav.mean(dim=0)

        return wav

    def augment_spec(self, spec):
        """SpecAugment-style masking with randomly sized frequency and time masks.

        ``spec`` is ``[1, n_mels, T]``.
        """
        freq_mask_param = random.randint(2, 12)
        time_mask_param = random.randint(4, 28)
        freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=freq_mask_param)
        time_mask = torchaudio.transforms.TimeMasking(time_mask_param=time_mask_param)
        spec = freq_mask(spec)
        spec = time_mask(spec)
        return spec

    def _pick_offset(self, slack):
        """Offset in ``[0, slack]``: random when augmenting, centered otherwise."""
        return random.randint(0, slack) if self.augment else slack // 2

    def __getitem__(self, idx):
        path = self.files[idx]
        label = self.labels[idx]

        info = torchaudio.info(path)
        total_len = info.num_frames
        # fixed_len is measured at self.sample_rate, but frame offsets/counts
        # for loading are at the file's native rate. Convert with ceiling
        # division so a crop still gives >= fixed_len samples after resampling.
        native_rate = info.sample_rate or self.sample_rate
        native_len = (self.fixed_len * native_rate + self.sample_rate - 1) // self.sample_rate

        if total_len > native_len:
            start = self._pick_offset(total_len - native_len)
            wav, sr = torchaudio.load(path, frame_offset=start, num_frames=native_len)
        else:
            wav, sr = torchaudio.load(path)

        wav = wav.mean(dim=0)  # downmix channels -> [N]
        if sr != self.sample_rate:
            wav = torchaudio.transforms.Resample(sr, self.sample_rate)(wav.unsqueeze(0)).squeeze(0)

        # aug
        if self.augment:
            wav = self.augment_wave(wav)

        # pad / crop to exactly fixed_len samples
        cur_len = wav.size(0)
        if cur_len < self.fixed_len:
            pad_total = self.fixed_len - cur_len
            pad_left = self._pick_offset(pad_total)
            pad_right = pad_total - pad_left
            wav = F.pad(wav, (pad_left, pad_right))
        else:
            start = self._pick_offset(cur_len - self.fixed_len)
            wav = wav[start:start+self.fixed_len]

        log_mel = self.logmel_fn(wav)  # [1, n_mels, T]
        # aug spec
        if self.augment:
            log_mel = self.augment_spec(log_mel)

        if self.one_hot:
            lbl = torch.zeros(self.num_speakers)
            lbl[label] = 1.0
        else:
            lbl = torch.tensor(label, dtype=torch.long)

        return log_mel, lbl



def collate_fn_classification(batch):
    """Stack ``(features, label)`` items into batch tensors.

    Features are stacked along a new leading dimension. Scalar labels
    (Python ints or 0-d tensors) become a 1-D ``torch.long`` tensor, as
    before. Vector labels (e.g. the dataset's ``one_hot=True`` float
    vectors) are stacked to ``[B, C]`` with their dtype preserved; the old
    ``torch.tensor(list_of_tensors)`` call raised on them.

    Returns:
        ``(features, labels)``.
    """
    feats = torch.stack([b[0] for b in batch], dim=0)
    labels = torch.stack([torch.as_tensor(b[1]) for b in batch], dim=0)
    if labels.dim() == 1:
        labels = labels.long()
    return feats, labels


In [None]:
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_s = 256

# Training set: random augmentation (and random crop placement).
tra_ds = SpeakerClassificationFeatureDataset(data_train, fixed_len=40000, one_hot=False, augment=True)
tra_dl = DataLoader(tra_ds, batch_size=batch_s, shuffle=True, collate_fn=collate_fn_classification)
# Validation and test use augment=False so their metrics are computed on
# undistorted audio and stay comparable from epoch to epoch.
val_ds = SpeakerClassificationFeatureDataset(data_valid, fixed_len=40000, one_hot=False, augment=False)
val_dl = DataLoader(val_ds, batch_size=batch_s, shuffle=False, collate_fn=collate_fn_classification)
tes_ds = SpeakerClassificationFeatureDataset(data_test, fixed_len=40000, one_hot=False, augment=False)
tes_dl = DataLoader(tes_ds, batch_size=batch_s, shuffle=False, collate_fn=collate_fn_classification)

start_epoch = 0

In [None]:
# Small gated Conformer with a single logit for binary classification.
model = ConformerWithGates(
    num_classes=1,
    hidden_dim=40,
    num_blocks=3,
    ff_expansion=4,
    emb_dim=200,
).to(device)
criterion = pt.nn.BCEWithLogitsLoss()
optimizer = pt.optim.Adam(model.parameters(), lr=4e-4)
# LR x0.4 every 5 scheduler steps.
scheduler = StepLR(optimizer, step_size=5, gamma=0.4)

In [None]:
# Train up to epoch 25, validating each epoch, stepping the LR schedule and
# writing a checkpoint igc_xs_<epoch>.pt after every epoch.
for epoch in range(start_epoch, 25):

    print(epoch)

    train_loss, train_acc = train_one_epo(tra_dl, model, optimizer, criterion, device)
    valid_loss, valid_acc, valid_auc, valid_f1, valid_time = valid_at_epo(val_dl, model, criterion, device, 0.5)

    print(f"  Train -> loss: {train_loss:.4f}, acc: {train_acc:.4f}")
    print(f"  Valid -> loss: {valid_loss:.4f}, acc: {valid_acc:.4f}, auc: {valid_auc:.4f}, f1: {valid_f1:.4f}, time: {valid_time:.4f}s")

    # Hyper-parameters used during this epoch (captured before the LR step).
    optimizer_info = {
        'param_groups': [
            {k: v for k, v in group.items() if k in ['lr', 'betas', 'weight_decay']}
            for group in optimizer.param_groups
        ]
    }

    # Advance the StepLR schedule once per epoch; without this call the
    # scheduler created with the model never changes the learning rate.
    scheduler.step()

    ckpt_path = f"igc_xs_{epoch}.pt"
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "criterion_state": criterion.state_dict(),
        "optimizer_info": optimizer_info,
        # Full optimizer/scheduler states (after the step) so training can
        # resume at epoch + 1 with the correct moments and learning rate.
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
    }, ckpt_path)

In [None]:
# 1) Fixed 0.5 threshold.
_, test_acc, test_auc, test_f1, _ = valid_at_epo(tes_dl, model, criterion, device, 0.5)
# 2) Oracle: threshold swept on the test set itself. This is an optimistic
#    upper bound, not an unbiased estimate.
_, test_b_acc, test_b_auc, test_b_f1, test_b_thr, _ = valid_at_epo_t(tes_dl, model, criterion, device)
# 3) Unbiased alternative: choose the threshold on validation, then apply it to test.
_, _, _, _, val_thr, _ = valid_at_epo_t(val_dl, model, criterion, device)
_, test_v_acc, test_v_auc, test_v_f1, _ = valid_at_epo(tes_dl, model, criterion, device, val_thr)
print(f"  Test -> acc: {test_acc:.4f}, auc: {test_auc:.4f}, f1: {test_f1:.4f}")
print(f"  Test best  -> acc: {test_b_acc:.4f}, auc: {test_b_auc:.4f}, f1: {test_b_f1:.4f}, th: {test_b_thr:.2f}")
print(f"  Test @val-thr -> acc: {test_v_acc:.4f}, auc: {test_v_auc:.4f}, f1: {test_v_f1:.4f}, th: {val_thr:.2f}")