Для начала зададим все необходимые константы (ключ для wandb, название проекта, пути к файлам, гиперпараметры и так далее).

In [None]:
# --- wandb ---------------------------------------------------------------
# NOTE: keep the real API key out of version control.
KEY = ""
PROJECT_NAME = "antispoof"

# --- optimisation --------------------------------------------------------
BATCH_SIZE = 32
NUM_EPOCHS = 20
LEARNING_RATE = 3e-4
MAX_LR = 1e-3
SEED = 31

# --- audio front-end -----------------------------------------------------
# At 16 kHz a 400-sample window is 25 ms and a 160-sample hop is 10 ms.
SAMPLE_RATE = 16000
N_FFT = 512
WIN_LENGTH = 400
HOP_LENGTH = 160
MAX_FRAMES = 200  # every spectrogram is cropped / zero-padded to this many frames

# --- DataLoader ----------------------------------------------------------
NUM_WORKERS = 2
PIN_MEMORY = True

# --- ASVspoof 2019 LA layout on Kaggle -----------------------------------
_LA_ROOT = "/kaggle/input/asvpoof-2019-dataset/LA/LA"
_PROTOCOL_DIR = f"{_LA_ROOT}/ASVspoof2019_LA_cm_protocols"

TRAIN_AUDIO_DIR = f"{_LA_ROOT}/ASVspoof2019_LA_train/flac"
DEV_AUDIO_DIR = f"{_LA_ROOT}/ASVspoof2019_LA_dev/flac"
EVAL_AUDIO_DIR = f"{_LA_ROOT}/ASVspoof2019_LA_eval/flac"
TRAIN_PROTOCOL = f"{_PROTOCOL_DIR}/ASVspoof2019.LA.cm.train.trn.txt"
DEV_PROTOCOL = f"{_PROTOCOL_DIR}/ASVspoof2019.LA.cm.dev.trl.txt"
EVAL_PROTOCOL = f"{_PROTOCOL_DIR}/ASVspoof2019.LA.cm.eval.trl.txt"
CACHE_DIR = "/tmp/spec_cache"

Импортируем необходимые библиотеки

In [None]:
import os
import random
import math
import numpy as np
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchaudio import transforms as T
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import wandb
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import pathlib
import hashlib

Подключимся к системе wandb и зададим исходные параметры. Зафиксируем сид, чтобы результаты были воспроизводимы и разные гипотезы можно было сравнивать между собой, не опасаясь, что разница в метриках вызвана случайностью.

In [None]:
# Authenticate and open the wandb run. The STFT / cropping settings are logged
# alongside the optimisation settings: they change the model input, so runs that
# differ only in features must be distinguishable from their config.
wandb.login(key=KEY)

wandb.init(project=PROJECT_NAME, config={
    "batch_size": BATCH_SIZE,
    "num_epochs": NUM_EPOCHS,
    "learning_rate": LEARNING_RATE,
    "max_lr": MAX_LR,
    "seed": SEED,
    "sample_rate": SAMPLE_RATE,
    "n_fft": N_FFT,
    "win_length": WIN_LENGTH,
    "hop_length": HOP_LENGTH,
    "max_frames": MAX_FRAMES,
})

# On-disk spectrogram cache used by AudioDataSet when augment=False.
os.makedirs(CACHE_DIR, exist_ok=True)

def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also puts cuDNN into deterministic mode with autotuning disabled, trading
    some speed for run-to-run reproducibility.

    Args:
        seed: integer seed passed to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every RNG once, before any model, dataset or DataLoader is created below.
set_random_seed(SEED)

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mfamaxth[0m ([33mfamaxth-hse-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Напишем основную архитектуру сети (LightCNN + MFM)

In [None]:
class MFM(nn.Module):
    """Max-Feature-Map activation.

    Splits dimension 1 into two halves and returns their element-wise maximum,
    so the size of dimension 1 is halved. This covers conv channels and linear
    features alike.
    """

    def forward(self, x):
        first_half, second_half = x.chunk(2, dim=1)
        return torch.maximum(first_half, second_half)

class LightCNN(nn.Module):
    """Light CNN (LCNN) classifier over single-channel spectrograms.

    Every convolution is followed by BatchNorm and an MFM. Each conv doubles the
    channel count and the MFM halves it again, so the widths after the MFMs are
    32, 48, 48, 64, 64, 128, 128. Four 2x2 max-pools shrink both spatial axes.
    The head is Linear -> MFM -> Dropout(0.7) -> Linear.

    Args:
        num_classes: number of output logits.
        freq_bins: input height; defaults to ``N_FFT // 2 + 1``.
        time_frames: input width; defaults to ``MAX_FRAMES``.

    Input shape is ``(batch, 1, freq_bins, time_frames)``, matching the dummy
    tensor used below to size the first linear layer.
    """

    def __init__(self, num_classes=2, freq_bins=None, time_frames=MAX_FRAMES):
        super().__init__()

        if freq_bins is None:
            freq_bins = N_FFT // 2 + 1

        self.net = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(64), MFM(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 96, kernel_size=1), nn.BatchNorm2d(96), MFM(),
            nn.Conv2d(48, 96, kernel_size=3, padding=1), nn.BatchNorm2d(96), MFM(), nn.MaxPool2d(2, 2),
            nn.Conv2d(48, 128, kernel_size=1), nn.BatchNorm2d(128), MFM(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), MFM(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 256, kernel_size=1), nn.BatchNorm2d(256), MFM(),
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), MFM(), nn.MaxPool2d(2, 2),
        )

        # Infer the flattened feature size with a dummy forward pass. Run it in
        # eval mode: in train mode the BatchNorm layers would fold this all-zero
        # batch into running_mean / running_var / num_batches_tracked before
        # training has even started.
        self.net.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 1, freq_bins, time_frames)
            flatten_size = self.net(dummy).numel()  # batch of 1 -> C * H * W
        self.net.train()

        self.fc1 = nn.Linear(flatten_size, 256)
        self.mfm_fc = MFM()  # 256 -> 128 features
        self.dropout = nn.Dropout(0.7)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, num_classes)``."""
        x = self.net(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.mfm_fc(x)
        x = self.dropout(x)
        return self.fc2(x)

Этот фрагмент кода отвечает за подготовку и аугментацию звуковых данных перед подачей в нейросеть.


In [None]:
def cmvn(tensor, eps=1e-6):
    """Mean/variance-normalise ``tensor`` along dimension 1.

    Each slice along dim 0 is shifted to zero mean and divided by its
    (unbiased) standard deviation over dim 1 plus ``eps``. The ``eps`` keeps
    constant rows finite (they map to zeros).
    """
    centred = tensor - tensor.mean(dim=1, keepdim=True)
    scale = tensor.std(dim=1, keepdim=True) + eps
    return centred / scale

def get_log_spec(tmp, sample_rate, n_fft=N_FFT, win_length=WIN_LENGTH, hop_length=HOP_LENGTH):
    """Return the normalised log-magnitude STFT of a waveform.

    Args:
        tmp: waveform tensor. A 1-D signal is first given a leading batch axis.
        sample_rate: accepted but unused; no resampling happens here.
        n_fft, win_length, hop_length: STFT parameters. A Hann window of
            ``win_length`` samples is used.

    Returns:
        For a 1-D (or ``(1, samples)``) input, a ``(n_fft // 2 + 1, frames)``
        tensor. Each frequency bin is normalised over time by ``cmvn``.
    """
    signal = tmp.unsqueeze(0) if tmp.dim() == 1 else tmp
    hann = torch.hann_window(win_length).to(signal.device)

    magnitude = torch.stft(
        signal,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=hann,
        return_complex=True,
    ).abs().squeeze(0)

    # Small floor so log() stays finite on silent bins.
    return cmvn(torch.log(magnitude + 1e-6))

def get_cache_path(cache_dir, audio_path, n_fft, win_length, hop_length, max_frames):
    """Return a deterministic ``.pt`` cache path inside ``cache_dir``.

    The file name is the MD5 hex digest of
    ``"audio_path|n_fft|win_length|hop_length|max_frames"``. Changing the audio
    file or any of the spectrogram settings therefore selects a different entry.
    """
    fingerprint = "|".join(str(part) for part in (audio_path, n_fft, win_length, hop_length, max_frames))
    digest = hashlib.md5(fingerprint.encode()).hexdigest()
    return os.path.join(cache_dir, digest + ".pt")

class AudioAugment:
    """Random waveform-level augmentation applied to training audio.

    On each call, and each with probability 0.5:
      * scale the waveform by a gain drawn uniformly from [-3, 3] dB
        (``torchaudio.transforms.Vol``);
      * add Gaussian noise with standard deviation 0.0035.
    """

    def __init__(self, sample_rate=SAMPLE_RATE):
        self.sample_rate = sample_rate
        # NOTE(review): __call__ never uses these two maskers. They are kept
        # only so the instance attributes stay unchanged.
        self.time_mask = T.TimeMasking(time_mask_param=35)
        self.freq_mask = T.FrequencyMasking(freq_mask_param=15)

    def __call__(self, tmp):
        waveform = tmp

        if random.random() < 0.5:
            waveform = T.Vol(gain=random.uniform(-3.0, 3.0), gain_type='db')(waveform)

        if random.random() < 0.5:
            waveform = waveform + torch.randn_like(waveform) * 0.0035

        return waveform

Теперь напишем Dataset для преобразования данных в удобный для нас формат

In [None]:
class AudioDataSet(Dataset):
    """Dataset of ``(spectrogram, label)`` pairs for a list of audio files.

    Each item is the ``get_log_spec`` spectrogram of ``root_dir/file_list[idx]``,
    down-mixed to mono. Its time axis is cropped or right-padded with zeros to
    ``max_frames``, and it is returned as a float tensor of shape
    ``(1, freq_bins, max_frames)`` together with ``labels[idx]``.

    With ``augment=False`` the spectrogram is cached on disk under ``cache_dir``
    and reused on later accesses. With ``augment=True`` it is recomputed on
    every access, after waveform augmentation, and frequency/time masking is
    applied. In that mode the cache is neither read nor written.
    """

    def __init__(self, file_list, labels, root_dir, augment=False, cache_dir=CACHE_DIR, n_fft=N_FFT, win_length=WIN_LENGTH, hop_length=HOP_LENGTH, max_frames=MAX_FRAMES):
        self.file_list = file_list
        self.labels = labels
        self.root_dir = root_dir
        self.augment = augment
        self.augmenter = AudioAugment() if augment else None
        self.cache_dir = cache_dir
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.max_frames = max_frames
        self.time_mask = T.TimeMasking(time_mask_param=35)
        self.freq_mask = T.FrequencyMasking(freq_mask_param=15)

    def __len__(self):
        return len(self.file_list)

    def _compute_spec(self, path):
        """Load ``path``, optionally augment it, and return a ``(freq, max_frames)`` spectrogram."""
        tmp, sr = torchaudio.load(path)
        if tmp.dim() > 1:
            tmp = tmp.mean(dim=0, keepdim=True)  # down-mix to mono
        if self.augment and self.augmenter is not None:
            tmp = self.augmenter(tmp)

        spec = get_log_spec(tmp.squeeze(0), sr, n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length)
        if self.augment:
            spec = self.freq_mask(spec)
            spec = self.time_mask(spec)

        if spec.size(1) < self.max_frames:
            spec = F.pad(spec, (0, self.max_frames - spec.size(1)))
        else:
            spec = spec[:, :self.max_frames]
        return spec

    @staticmethod
    def _load_cache(cache_path):
        """Return the cached spectrogram, or ``None`` if it is missing or unreadable."""
        import pickle

        if not os.path.exists(cache_path):
            return None
        try:
            return torch.load(cache_path)
        except (OSError, RuntimeError, EOFError, pickle.UnpicklingError):
            # A truncated or corrupt entry (e.g. from an interrupted earlier run)
            # is recomputed and overwritten instead of failing on every access.
            return None

    @staticmethod
    def _write_cache(spec, cache_path):
        """Best-effort, atomic write of ``spec`` to ``cache_path``."""
        tmp_path = f"{cache_path}.{os.getpid()}.tmp"
        try:
            # clone(): torch.save serialises a view's entire underlying storage,
            # so the cropped slice would otherwise be stored at full length.
            torch.save(spec.clone(), tmp_path)
            # Publish by rename so no reader ever sees a half-written file.
            os.replace(tmp_path, cache_path)
        except (OSError, RuntimeError):
            # Caching is only an optimisation; a failed write just means the
            # spectrogram is recomputed next time.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def __getitem__(self, idx):
        filename = self.file_list[idx]
        label = self.labels[idx]
        path = os.path.join(self.root_dir, filename)

        cache_path = get_cache_path(self.cache_dir, path, self.n_fft, self.win_length, self.hop_length, self.max_frames)

        spec = None if self.augment else self._load_cache(cache_path)
        if spec is None:
            spec = self._compute_spec(path)
            if not self.augment:
                self._write_cache(spec, cache_path)

        return spec.unsqueeze(0).float(), label

Функции для подсчёта EER (Equal Error Rate — точки, в которой доли ложных отказов и ложных принятий равны).

In [None]:
def compute_det_curve(target_scores, nontarget_scores):
    """Return the DET curve ``(frr, far, thresholds)`` for two score arrays.

    Args:
        target_scores: 1-D NumPy array of target (here: bonafide) scores.
        nontarget_scores: 1-D NumPy array of non-target (spoof) scores.

    Returns:
        Three arrays of length ``target.size + nontarget.size + 1``. Entry ``i``
        is the operating point that rejects the ``i`` lowest-scoring trials:
        - ``frr[i]`` is the fraction of targets rejected;
        - ``far[i]`` is the fraction of non-targets still accepted;
        - ``thresholds[i]`` is the highest rejected score.
        Entry 0 rejects nothing and uses a threshold just below the minimum score.
    """
    n_scores = target_scores.size + nontarget_scores.size
    all_scores = np.concatenate((target_scores, nontarget_scores))
    labels = np.concatenate(
        (np.ones(target_scores.size), np.zeros(nontarget_scores.size)))
    # mergesort is stable: among tied scores, targets (concatenated first) stay first.
    indices = np.argsort(all_scores, kind='mergesort')
    labels = labels[indices]
    tar_trial_sums = np.cumsum(labels)
    nontarget_trial_sums = nontarget_scores.size - \
        (np.arange(1, n_scores + 1) - tar_trial_sums)
    frr = np.concatenate(
        (np.atleast_1d(0), tar_trial_sums / target_scores.size))
    far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums /
                          nontarget_scores.size))
    thresholds = np.concatenate(
        (np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))
    return frr, far, thresholds

def compute_eer(bonafide_scores, other_scores):
    """Return ``(eer, threshold)`` at the DET point where FRR and FAR are closest.

    Args:
        bonafide_scores: bonafide (target) scores as a 1-D array-like; higher
            means "more bonafide".
        other_scores: spoof (non-target) scores as a 1-D array-like.

    Returns:
        The mean of FRR and FAR at that point, and the threshold there.

    Raises:
        ValueError: if either score set is empty. The EER is undefined in that
            case, and the DET computation would otherwise return NaN (which
            silently defeats "best EER" model selection) or raise IndexError.
    """
    bonafide_scores = np.asarray(bonafide_scores)
    other_scores = np.asarray(other_scores)
    if bonafide_scores.size == 0 or other_scores.size == 0:
        raise ValueError(
            "compute_eer needs at least one bonafide and one spoof score "
            f"(got {bonafide_scores.size} and {other_scores.size})")

    frr, far, thresholds = compute_det_curve(bonafide_scores, other_scores)
    min_index = np.argmin(np.abs(frr - far))
    eer = np.mean((frr[min_index], far[min_index]))
    return eer, thresholds[min_index]

Теперь зададим функцию, которая подсчитывает итоговый EER модели на выборке eval.

In [None]:
def load_protocol_file(path):
    """Read an ASVspoof CM protocol file.

    Each non-empty line is whitespace-separated. The second field is the
    utterance id and the last field is the key: ``bonafide``, or anything else,
    which is treated as spoof. Blank lines (e.g. a trailing empty line) are
    skipped.

    Returns:
        ``(files, labels)``: the ``"<utt_id>.flac"`` file names and their labels
        (1 = bonafide, 0 = spoof), in file order.
    """
    files = []
    labels = []

    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            files.append(parts[1] + '.flac')
            labels.append(1 if parts[-1] == 'bonafide' else 0)

    return files, labels

def evaluate(model_path="best_model.pth"):
    """Compute and log the EER of a saved LightCNN checkpoint on the eval protocol.

    Every file listed in ``EVAL_PROTOCOL`` goes through these steps:
    - down-mix to mono, as in AudioDataSet;
    - convert with ``get_log_spec``;
    - crop or right-pad to ``MAX_FRAMES`` frames;
    - score on its own.
    The score is the softmax probability of class 1 (bonafide).

    Args:
        model_path: checkpoint to load; defaults to ``"best_model.pth"``, the
            file ``train()`` writes.

    Returns:
        ``(eer, threshold)`` from ``compute_eer``. ``eer`` is also logged to
        wandb as ``final_eval_eer``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    freq_bins = N_FFT // 2 + 1
    model = LightCNN(num_classes=2, freq_bins=freq_bins, time_frames=MAX_FRAMES).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    eval_files, eval_labels = load_protocol_file(EVAL_PROTOCOL)
    spoof_scores, bona_scores = [], []

    # One no_grad context for the whole scoring loop.
    with torch.no_grad():
        for filename, label in tqdm(zip(eval_files, eval_labels), desc="Evaluating (eval set)", total=len(eval_files)):
            tmp, sr = torchaudio.load(os.path.join(EVAL_AUDIO_DIR, filename))

            if tmp.dim() > 1:
                tmp = tmp.mean(dim=0, keepdim=True)

            # (1, 1, freq_bins, frames): a batch of one single-channel spectrogram.
            spec = get_log_spec(tmp.squeeze(0), sr).unsqueeze(0).unsqueeze(0)

            # Crop / right-pad the time axis to the width the model was built for.
            if spec.size(3) < MAX_FRAMES:
                spec = F.pad(spec, (0, MAX_FRAMES - spec.size(3)))
            else:
                spec = spec[:, :, :, :MAX_FRAMES]

            prob = torch.softmax(model(spec.to(device)), dim=1)[0, 1].item()
            (bona_scores if label == 1 else spoof_scores).append(prob)

    eer, thr = compute_eer(np.array(bona_scores), np.array(spoof_scores))
    print(f"✅ FINAL EER on EVAL: {eer*100:.4f}% (thr={thr:.5f})")
    wandb.log({"final_eval_eer": eer})
    return eer, thr

Функция для обучения

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, scheduler, device, epoch):
    """Run one training pass over ``loader`` and return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    pbar = tqdm(loader, desc=f"Epoch {epoch} [Train]")
    for step, (x, y) in enumerate(pbar, start=1):
        x = x.to(device, non_blocking=PIN_MEMORY)
        y = y.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        # OneCycleLR was built with steps_per_epoch, so it advances once per batch.
        scheduler.step()
        total_loss += loss.item()
        # Use our own counter: tqdm only syncs ``pbar.n`` when it redraws, so
        # ``pbar.n + 1`` can lag behind the true number of batches seen.
        pbar.set_postfix(loss=total_loss / step, lr=optimizer.param_groups[0]['lr'])
    return total_loss / max(1, len(loader))


def _score_dev_set(model, loader, device, epoch):
    """Score ``loader`` in eval mode.

    Returns:
        ``(bona_scores, spoof_scores, all_labels, all_preds)``. The scores are
        softmax probabilities of class 1 (bonafide), split by true label.
        ``all_preds`` holds argmax predictions aligned with ``all_labels``.
    """
    model.eval()
    spoof_scores, bona_scores = [], []
    all_labels, all_preds = [], []
    with torch.no_grad():
        for x, y in tqdm(loader, desc=f"Epoch {epoch} [Eval]"):
            x = x.to(device, non_blocking=PIN_MEMORY)
            outputs = model(x)
            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            preds = outputs.argmax(dim=1).cpu().numpy()
            batch_labels = y.numpy()
            for score, label in zip(probs, batch_labels):
                (bona_scores if label == 1 else spoof_scores).append(score)
            all_labels.extend(batch_labels)
            all_preds.extend(preds)
    return bona_scores, spoof_scores, all_labels, all_preds


def _save_confusion_matrix(labels, preds, epoch):
    """Write the row-normalised confusion matrix for one epoch to a PNG file."""
    # labels=[0, 1] keeps the matrix 2x2 even if a class is missing from
    # labels/preds, so it always matches the two display labels.
    cm = confusion_matrix(labels, preds, labels=[0, 1], normalize='true')
    disp = ConfusionMatrixDisplay(cm, display_labels=["spoof", "bona"])
    disp.plot(cmap=plt.cm.Blues)
    plt.title(f"Confusion Matrix (Epoch {epoch})")
    plt.savefig(f"conf_matrix_epoch_{epoch}.png")
    plt.close(disp.figure_)


def train():
    """Train LightCNN on the train split, keeping the best checkpoint by dev EER.

    After each epoch:
    - the mean training loss and the dev EER are logged to wandb;
    - a row-normalised confusion matrix of argmax predictions is saved to
      ``conf_matrix_epoch_<epoch>.png``;
    - the weights are written to ``best_model.pth`` if the dev EER improved.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    train_files, train_labels = load_protocol_file(TRAIN_PROTOCOL)
    dev_files, dev_labels = load_protocol_file(DEV_PROTOCOL)

    train_set = AudioDataSet(train_files, train_labels, TRAIN_AUDIO_DIR, augment=True, cache_dir=CACHE_DIR)
    dev_set = AudioDataSet(dev_files, dev_labels, DEV_AUDIO_DIR, augment=False, cache_dir=CACHE_DIR)

    train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    dev_loader = DataLoader(dev_set, BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    freq_bins = N_FFT // 2 + 1
    model = LightCNN(num_classes=2, freq_bins=freq_bins, time_frames=MAX_FRAMES).to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
    # NOTE: OneCycleLR overwrites this lr (it starts from max_lr / div_factor and
    # follows its own schedule), so LEARNING_RATE does not affect training.
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

    steps_per_epoch = max(1, len(train_loader))
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=MAX_LR, steps_per_epoch=steps_per_epoch, epochs=NUM_EPOCHS)

    best_eer = float("inf")

    for epoch in range(NUM_EPOCHS):
        avg_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, epoch)
        wandb.log({"train_loss": avg_train_loss, "epoch": epoch, "lr": optimizer.param_groups[0]['lr']})

        bona_scores, spoof_scores, all_labels, all_preds = _score_dev_set(model, dev_loader, device, epoch)
        eer, thr = compute_eer(np.array(bona_scores), np.array(spoof_scores))
        print(f"Epoch {epoch}: DEV EER = {eer*100:.4f}%, thr={thr:.5f}")
        wandb.log({"dev_eer": eer, "epoch": epoch})

        _save_confusion_matrix(all_labels, all_preds, epoch)

        if eer < best_eer:
            best_eer = eer
            torch.save(model.state_dict(), "best_model.pth")
            print(f"New best DEV EER: {best_eer*100:.4f}%. Model saved.")
            wandb.log({"best_dev_eer": best_eer})

    print("Training finished. Best DEV EER: %.4f%%" % (best_eer * 100))

Функция для создания CSV-файла с предсказаниями модели для всех аудиофайлов из папки eval.

In [None]:
def make_csv(file_out="arantitov.csv", model_path="best_model.pth"):
    """Score every ``.flac`` file in ``EVAL_AUDIO_DIR`` and write the scores to CSV.

    Unlike ``evaluate()``, this does not read the protocol. Every ``.flac`` in
    the folder is scored, in sorted file-name order. The CSV has two columns:
    - ``utt_id``: the file name without its ``.flac`` extension;
    - ``score``: the softmax probability of class 1 (bonafide).

    Args:
        file_out: path of the CSV to write.
        model_path: checkpoint to load; defaults to ``"best_model.pth"``, the
            file ``train()`` writes.

    Returns:
        ``file_out``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    freq_bins = N_FFT // 2 + 1

    model = LightCNN(num_classes=2, freq_bins=freq_bins, time_frames=MAX_FRAMES).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    eval_files = sorted(f for f in os.listdir(EVAL_AUDIO_DIR) if f.endswith(".flac"))
    predictions = []

    # One no_grad context for the whole scoring loop.
    with torch.no_grad():
        for filename in tqdm(eval_files, desc="Predicting (eval folder)"):
            tmp, sr = torchaudio.load(os.path.join(EVAL_AUDIO_DIR, filename))
            if tmp.dim() > 1:
                tmp = tmp.mean(dim=0, keepdim=True)

            # (1, 1, freq_bins, frames): a batch of one single-channel spectrogram.
            spec = get_log_spec(tmp.squeeze(0), sr).unsqueeze(0).unsqueeze(0)

            if spec.size(3) < MAX_FRAMES:
                spec = F.pad(spec, (0, MAX_FRAMES - spec.size(3)))
            else:
                spec = spec[:, :, :, :MAX_FRAMES]

            prob = torch.softmax(model(spec.to(device)), dim=1)[0, 1].item()

            # splitext strips only the trailing extension; str.replace would also
            # delete a ".flac" that happens to occur elsewhere in the name.
            predictions.append((os.path.splitext(filename)[0], prob))

    df = pd.DataFrame(predictions, columns=["utt_id", "score"])
    df.to_csv(file_out, index=False)
    print(f"Predictions saved to {file_out}")

    return file_out

Обучим модель, оценим её на выборке eval и сохраним предсказания в CSV.

In [None]:
# Full pipeline: train (best checkpoint by dev EER) -> EER on the eval
# protocol -> CSV of scores for every file in the eval folder.
train()
evaluate()
make_csv()
print("Done.")
# Close the wandb run opened by wandb.init(); in a notebook it otherwise stays
# open after the cell finishes.
wandb.finish()

Device: cuda


Epoch 0 [Train]: 100%|██████████| 794/794 [02:33<00:00,  5.19it/s, loss=0.536, lr=0.000152]
Epoch 0 [Eval]: 100%|██████████| 777/777 [02:37<00:00,  4.93it/s]


Epoch 0: DEV EER = 1.6450%, thr=0.01690
New best DEV EER: 1.6450%. Model saved.


Epoch 1 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.21it/s, loss=0.295, lr=0.000437]
Epoch 1 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.27it/s]


Epoch 1: DEV EER = 0.7075%, thr=0.43233
New best DEV EER: 0.7075%. Model saved.


Epoch 2 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.22it/s, loss=0.194, lr=0.00076] 
Epoch 2 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.20it/s]


Epoch 2: DEV EER = 0.3936%, thr=0.12950
New best DEV EER: 0.3936%. Model saved.


Epoch 3 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.22it/s, loss=0.186, lr=0.000971]
Epoch 3 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.17it/s]


Epoch 3: DEV EER = 0.7782%, thr=0.24767


Epoch 4 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.22it/s, loss=0.168, lr=0.000994]
Epoch 4 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.20it/s]


Epoch 4: DEV EER = 0.2809%, thr=0.14520
New best DEV EER: 0.2809%. Model saved.


Epoch 5 [Train]: 100%|██████████| 794/794 [02:31<00:00,  5.23it/s, loss=0.154, lr=0.00095] 
Epoch 5 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.26it/s]


Epoch 5: DEV EER = 0.3936%, thr=0.10192


Epoch 6 [Train]: 100%|██████████| 794/794 [02:31<00:00,  5.23it/s, loss=0.149, lr=0.000866]
Epoch 6 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.26it/s]


Epoch 6: DEV EER = 0.1172%, thr=0.07743
New best DEV EER: 0.1172%. Model saved.


Epoch 7 [Train]: 100%|██████████| 794/794 [02:31<00:00,  5.23it/s, loss=0.14, lr=0.00075]  
Epoch 7 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.26it/s]


Epoch 7: DEV EER = 0.0774%, thr=0.17835
New best DEV EER: 0.0774%. Model saved.


Epoch 8 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.22it/s, loss=0.137, lr=0.000611]
Epoch 8 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.15it/s]


Epoch 8: DEV EER = 0.0796%, thr=0.17596


Epoch 9 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.22it/s, loss=0.133, lr=0.000462]
Epoch 9 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.24it/s]


Epoch 9: DEV EER = 0.0465%, thr=0.15808
New best DEV EER: 0.0465%. Model saved.


Epoch 10 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.21it/s, loss=0.129, lr=0.000317]
Epoch 10 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.19it/s]


Epoch 10: DEV EER = 0.0045%, thr=0.21449
New best DEV EER: 0.0045%. Model saved.


Epoch 11 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.22it/s, loss=0.126, lr=0.000188]
Epoch 11 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.25it/s]


Epoch 11: DEV EER = 0.0398%, thr=0.10571


Epoch 12 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.22it/s, loss=0.124, lr=8.68e-5] 
Epoch 12 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.21it/s]


Epoch 12: DEV EER = 0.0398%, thr=0.13282


Epoch 13 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.22it/s, loss=0.123, lr=2.22e-5]
Epoch 13 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.20it/s]


Epoch 13: DEV EER = 0.0398%, thr=0.09471


Epoch 14 [Train]: 100%|██████████| 794/794 [02:32<00:00,  5.21it/s, loss=0.123, lr=4.04e-9]
Epoch 14 [Eval]: 100%|██████████| 777/777 [00:36<00:00, 21.19it/s]


Epoch 14: DEV EER = 0.0376%, thr=0.08144
Training finished. Best DEV EER: 0.0045%


Evaluating (eval set): 100%|██████████| 71237/71237 [14:51<00:00, 79.95it/s]


✅ FINAL EER on EVAL: 5.6582% (thr=0.89429)


Predicting (eval folder): 100%|██████████| 71933/71933 [08:01<00:00, 149.43it/s]


Predictions saved to arantitov.csv
Done.
