In [1]:
# Mount Google Drive at /content/drive (requires the google.colab package);
# the data paths configured in the next cell live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Root folder of the experiment; every other location is a sub-folder of it.
file_path = '/content/drive/MyDrive/UEF_SummerSchool_2024/SASV_BS2/'

# Pre-extracted embeddings, training speaker metadata and result/checkpoint
# folders, in that order (each keeps its trailing slash so later code can
# append file names by plain string concatenation).
embedding_dir, spk_meta_dir, output_dir = (
    file_path + sub_folder
    for sub_folder in ('embeddings/', 'spk_meta/', 'exp_result/')
)

In [8]:
import random
import pickle as pk
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import roc_curve
from scipy.interpolate import interp1d
from scipy.optimize import brentq
import torch
import sys
import time
import os
import argparse
from torch import optim
import pickle as pk
from torch import nn

In [15]:

class SASV_Dataset(Dataset):
    """Trials built from pre-extracted ASV and CM embeddings.

    Every item is a 5-tuple
    ``(enrolment ASV embedding, test ASV embedding, test CM embedding,
    ans_type, ans)`` where ``ans`` is one of ``'target'``, ``'nontarget'``,
    ``'spoof'`` and ``ans_type`` is ``1`` for ``'target'`` and ``0`` otherwise.

    * ``partition == "trn"``: trials are sampled at random from the speaker
      metadata on every access; the requested index is ignored.
    * any other partition (``"dev"`` / ``"eval"``): trials are read from the
      protocol file ``protocols/ASVspoof2019.LA.asv.<partition>.gi.trl.txt``.

    The class reads the module-level ``embedding_dir``, ``spk_meta_dir`` and
    ``file_path`` variables at construction time.

    NOTE: all inputs are loaded with ``pickle``; unpickling can execute
    arbitrary code, so only point this class at files you trust.
    """

    def __init__(self, partition):
        """Load metadata/protocol and embeddings for ``partition``."""
        self.part = partition
        self.embedding_dir = embedding_dir
        if self.part == "trn":
            self.spk_meta_dir = spk_meta_dir
            self.load_meta_information()
        else:
            sasv_trial = file_path + "protocols/ASVspoof2019.LA.asv." + self.part + ".gi.trl.txt"
            with open(sasv_trial, "r") as f:
                # Skip blank lines (e.g. a trailing newline at end of file):
                # they would otherwise break the 4-field unpacking later on.
                self.utt_list = [line for line in f if line.strip()]
        self.load_embeddings()

    def load_meta_information(self):
        """Load the training speaker metadata into ``self.spk_meta``.

        The rest of the class indexes it as
        ``spk_meta[speaker]["bonafide" | "spoof"] -> list of utterance keys``.
        """
        with open(self.spk_meta_dir + "spk_meta_trn.pk", "rb") as f:
            self.spk_meta = pk.load(f)

    def load_embeddings(self):
        """Load the CM/ASV embeddings (and speaker models for dev/eval)."""
        # load saved countermeasures(CM) related preparations
        with open(self.embedding_dir + "cm_embd_" + self.part + ".pk", "rb") as f:
            self.cm_embd = pk.load(f)
        # load saved automatic speaker verification(ASV) related preparations
        with open(self.embedding_dir + "asv_embd_" + self.part + ".pk", "rb") as f:
            self.asv_embd = pk.load(f)
        if self.part in ["dev", "eval"]:
            # load speaker models for development and evaluation sets
            with open(self.embedding_dir + "spk_model_" + self.part + ".pk", "rb") as f:
                self.spk_model = pk.load(f)

    def __len__(self):
        """Number of CM embeddings for ``trn``; number of trials for dev/eval.

        :raises ValueError: for any other partition name.
        """
        if self.part == "trn":
            return len(self.cm_embd)
        if self.part in ("dev", "eval"):
            return len(self.utt_list)
        raise ValueError(f"unknown partition: {self.part!r}")

    def __getitem__(self, idx):
        # Dispatch to getitem_trn / getitem_dev / getitem_eval by partition.
        return getattr(self, 'getitem_'+self.part)(idx)

    def getitem_trn(self, index):
        """Sample one random training trial (``index`` is ignored).

        With probability 1/2 a target trial is drawn (two distinct bonafide
        utterances of one speaker). Otherwise a zero-effort nontarget trial
        (bonafide utterances of two different speakers) or a spoof trial
        (bonafide enrolment + spoofed test utterance of the same speaker) is
        drawn, each with probability 1/2.

        :raises ValueError: if a spoof trial is requested but no speaker has
            any spoofed utterance (this previously looped forever).
        """
        speakers = list(self.spk_meta.keys())
        ans_type = random.randint(0, 1)

        if ans_type == 1:  # target
            spk = random.choice(speakers)
            enr, tst = random.sample(self.spk_meta[spk]["bonafide"], 2)
            ans = 'target'
        else:  # nontarget
            nontarget_type = random.randint(1, 2)
            if nontarget_type == 1:  # zero-effort nontarget
                spk, ze_spk = random.sample(speakers, 2)
                enr = random.choice(self.spk_meta[spk]["bonafide"])
                tst = random.choice(self.spk_meta[ze_spk]["bonafide"])
                ans = 'nontarget'
            else:  # spoof nontarget
                if not any(self.spk_meta[s]["spoof"] for s in speakers):
                    raise ValueError(
                        "cannot sample a spoof trial: no speaker has spoofed utterances"
                    )
                # Rejection-sample a speaker that owns at least one spoofed
                # utterance; guaranteed to terminate thanks to the check above.
                spk = random.choice(speakers)
                while len(self.spk_meta[spk]["spoof"]) == 0:
                    spk = random.choice(speakers)
                enr = random.choice(self.spk_meta[spk]["bonafide"])
                tst = random.choice(self.spk_meta[spk]["spoof"])
                ans = 'spoof'

        return self.asv_embd[enr], self.asv_embd[tst], \
               self.cm_embd[tst], ans_type, ans

    def _getitem_trial(self, index):
        """Build the item for protocol line ``index`` (shared by dev/eval).

        A protocol line holds four space-separated fields: speaker-model id,
        test utterance key, a field that is not used here, and the label.
        """
        spkmd, key, _, ans = self.utt_list[index].strip().split(" ")
        ans_type = int(ans == "target")

        return self.spk_model[spkmd], self.asv_embd[key], \
               self.cm_embd[key], ans_type, ans

    def getitem_dev(self, index):
        """Return the development trial at ``index``."""
        return self._getitem_trial(index)

    def getitem_eval(self, index):
        """Return the evaluation trial at ``index``."""
        return self._getitem_trial(index)



In [6]:
def _compute_eer(labels, scores):
    """Return the equal error rate of ``scores`` against binary ``labels``.

    The EER is located as the root of ``1 - x - TPR(x)`` on ``[0, 1]``, where
    ``TPR`` is the ROC curve linearly interpolated over the false-positive
    rate ``x``.
    """
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    # Build the interpolant once; brentq evaluates the objective many times.
    roc = interp1d(fpr, tpr)
    return brentq(lambda x: 1.0 - x - roc(x), 0.0, 1.0)


def get_all_EERs(preds, keys):
    """
    Calculate all three EERs used in the SASV Challenge 2022.
    preds and keys should be pre-calculated using dev or eval protocol in
    either 'protocols/ASVspoof2019.LA.asv.dev.gi.trl.txt' or
    'protocols/ASVspoof2019.LA.asv.eval.gi.trl.txt'

    :param preds: sequence of scores (higher means "more likely target"),
        one per trial
    :param keys: list of keys where each element should be one of
    ['target', 'nontarget', 'spoof']
    :return: tuple ``(sasv_eer, sv_eer, spf_eer)``:
        SASV-EER uses all trials (target vs. nontarget + spoof),
        SV-EER uses target vs. nontarget trials only,
        SPF-EER uses target vs. spoof trials only
    :raises ValueError: if ``preds`` and ``keys`` differ in length or a key
        is not one of the three accepted values
    """
    if len(preds) != len(keys):
        # zip() would silently drop the surplus entries otherwise.
        raise ValueError(
            f"preds and keys must have the same length, got {len(preds)} and {len(keys)}"
        )

    sasv_labels, sv_labels, spf_labels = [], [], []
    sv_preds, spf_preds = [], []

    for pred, key in zip(preds, keys):
        if key == "target":
            # Target trials are the positive class of all three metrics.
            sasv_labels.append(1)
            sv_labels.append(1)
            spf_labels.append(1)
            sv_preds.append(pred)
            spf_preds.append(pred)

        elif key == "nontarget":
            sasv_labels.append(0)
            sv_labels.append(0)
            sv_preds.append(pred)

        elif key == "spoof":
            sasv_labels.append(0)
            spf_labels.append(0)
            spf_preds.append(pred)
        else:
            raise ValueError(
                f"should be one of 'target', 'nontarget', 'spoof', got:{key}"
            )

    sasv_eer = _compute_eer(sasv_labels, preds)
    sv_eer = _compute_eer(sv_labels, sv_preds)
    spf_eer = _compute_eer(spf_labels, spf_preds)

    return sasv_eer, sv_eer, spf_eer

In [9]:
class Baseline2(nn.Module):
    """Fully-connected fusion network producing 2-class logits per trial.

    The enrolment ASV embedding, test ASV embedding and test CM embedding are
    concatenated, passed through a stack of ``Linear`` + ``LeakyReLU(0.3)``
    layers and projected to two logits by a bias-free linear layer.

    :param num_nodes: widths of the hidden layers, in order (non-empty).
    :param in_dim: width of the concatenated input; the default 544 equals
        192 + 192 + 160, the embedding sizes noted in ``forward``.
    :raises ValueError: if ``num_nodes`` is empty.
    """

    def __init__(self, num_nodes=(256, 128, 64), in_dim=544):
        super().__init__()
        # Copy into a list so callers may pass any sequence and so that no
        # mutable default object is shared between instances.
        num_nodes = list(num_nodes)
        if not num_nodes:
            raise ValueError("num_nodes must contain at least one layer width")
        self.name = "Baseline2"
        self.enh_DNN = self._make_layers(in_dim, num_nodes)
        self.fc_out = torch.nn.Linear(num_nodes[-1], 2, bias = False)
        # Class-weighted cross entropy: class 0 weighs 0.1, class 1 weighs 0.9.
        self.loss = torch.nn.CrossEntropyLoss(
            weight=torch.FloatTensor([0.1, 0.9])
        )

    def forward(self, embd_asv_enr, embd_asv_tst, embd_cm_tst):
        """Return logits of shape ``(bs, 2)`` for a batch of trials.

        Each input has its dimension 1 squeezed (if it has size 1) before the
        three are concatenated along dimension 1.
        """
        asv_enr = torch.squeeze(embd_asv_enr, 1) # shape: (bs, 192)
        asv_tst = torch.squeeze(embd_asv_tst, 1) # shape: (bs, 192)
        cm_tst = torch.squeeze(embd_cm_tst, 1) # shape: (bs, 160)

        fused = torch.cat([asv_enr, asv_tst, cm_tst], dim = 1)  # (bs, in_dim)
        x = self.enh_DNN(fused)  # shape: (bs, num_nodes[-1])
        x = self.fc_out(x)  # (bs, 2)

        return x

    def _make_layers(self, in_dim, l_nodes):
        """Build ``Linear -> LeakyReLU(0.3)`` pairs mapping ``in_dim`` through ``l_nodes``."""
        dims = [in_dim] + list(l_nodes)
        l_fc = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            l_fc.append(torch.nn.Linear(in_features = d_in, out_features = d_out))
            l_fc.append(torch.nn.LeakyReLU(negative_slope = 0.3))
        return torch.nn.Sequential(*l_fc)

    def calc_loss(self, preds, labels):
        """Return the weighted cross-entropy between logits ``preds`` and class indices ``labels``."""
        return self.loss(preds, labels)

In [18]:

# ---------------------------------------------------------------------------
# Hyper-parameters and device
# ---------------------------------------------------------------------------
batch_size = 1024
lr = 0.001
epoch_size = 50
weight_decay = 1e-5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _score_trials(model, loader, device):
    """Score every trial served by ``loader`` with ``model``.

    Puts the model in eval mode and runs without gradient tracking.

    :return: ``(scores, keys)`` where ``scores`` is a 1-D numpy array holding
        the softmax probability of class 1 for each trial and ``keys`` is the
        list of trial labels in the same order.
    """
    model.eval()
    scores, keys = [], []
    with torch.no_grad():
        for asv1, asv2, cm1, _, key in loader:
            logits = model(asv1.to(device), asv2.to(device), cm1.to(device))
            scores.append(torch.softmax(logits, dim=-1)[:, 1])
            keys.extend(list(key))
    return torch.cat(scores, dim=0).cpu().numpy(), keys


# ---------------------------------------------------------------------------
# Data: build each dataset once (every construction re-reads the pickles).
# ---------------------------------------------------------------------------
trn_set = SASV_Dataset("trn")
trn_loader = DataLoader(trn_set, batch_size=batch_size, shuffle=True,
                        drop_last=False, pin_memory=True)
dev_set = SASV_Dataset("dev")
# EERs do not depend on trial order, so the dev set is not shuffled.
dev_loader = DataLoader(dev_set, batch_size=batch_size, shuffle=False,
                        drop_last=False, pin_memory=True)

# ---------------------------------------------------------------------------
# Model and optimiser
# ---------------------------------------------------------------------------
model = Baseline2()
model.to(device)
params = list(model.parameters())
optimizer = optim.Adam(params, lr=lr, weight_decay=weight_decay)

# Make sure the checkpoint folder exists before the first torch.save.
os.makedirs(output_dir, exist_ok=True)
best_ckpt_path = os.path.join(output_dir, "%s_best.pt" % (model.name))

# ---------------------------------------------------------------------------
# Training loop: keep the checkpoint with the lowest dev SASV-EER.
# ---------------------------------------------------------------------------
min_eer = 1e4
for epoch in range(epoch_size):
    # Re-enter train mode every epoch: the dev pass below switches to eval.
    model.train()
    preds, keys = [], []
    trn_loss = 0.0
    tot_batch_trn = len(trn_loader)
    for asv1, asv2, cm1, ans, key in trn_loader:
        asv1 = asv1.to(device)
        asv2 = asv2.to(device)
        cm1 = cm1.to(device)
        ans = ans.to(device)

        pred = model(asv1, asv2, cm1)
        nloss = model.calc_loss(pred, ans)
        optimizer.zero_grad()
        nloss.backward()
        optimizer.step()

        # Accumulate plain numbers / detached tensors so the autograd graph
        # of each mini-batch can be released immediately.
        trn_loss += nloss.item()
        preds.append(torch.softmax(pred.detach(), dim=-1))
        keys.extend(list(key))

    preds = torch.cat(preds, dim=0)[:, 1].cpu().numpy()

    trn_loss = trn_loss / tot_batch_trn
    sasv_eer_trn, sv_eer_trn, spf_eer_trn = get_all_EERs(
        preds=preds, keys=keys)

    print("\nEpoch-%d Trn: Loss: %0.5f, sasv_eer_trn: %0.3f, sv_eer_trn: %0.3f, spf_eer_trn: %0.3f" % (
        epoch+1,  trn_loss, 100 * sasv_eer_trn, 100 * sv_eer_trn, 100 * spf_eer_trn))

    preds, keys = _score_trials(model, dev_loader, device)
    sasv_eer_dev, sv_eer_dev, spf_eer_dev = get_all_EERs(
        preds=preds, keys=keys)

    print("Epoch-%d Dev: sasv_eer_dev: %0.3f, sv_eer_dev: %0.3f, spf_eer_dev: %0.3f" % (
        epoch+1, 100 * sasv_eer_dev, 100 * sv_eer_dev, 100 * spf_eer_dev))

    if sasv_eer_dev < min_eer:
        torch.save(model.state_dict(), best_ckpt_path)
        min_eer = sasv_eer_dev
        best_epoch = epoch

print(
    f'\nMin sasv_eer_dev: %{min_eer*100:.4f} obtained in epoch {best_epoch+1}')

# ---------------------------------------------------------------------------
# Final evaluation of the best checkpoint on dev and eval.
# ---------------------------------------------------------------------------
# map_location lets a checkpoint saved on GPU be restored on the current device.
model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
model.eval()

preds, keys = _score_trials(model, dev_loader, device)
sasv_eer_dev, sv_eer_dev, spf_eer_dev = get_all_EERs(
    preds=preds, keys=keys)
# Report the epoch the checkpoint came from, not the last epoch that ran.
print("\nEpoch-%d Dev: sasv_eer_dev: %0.3f, sv_eer_dev: %0.3f, spf_eer_dev: %0.3f" % (
    best_epoch+1, 100 * sasv_eer_dev, 100 * sv_eer_dev, 100 * spf_eer_dev))


eval_set = SASV_Dataset("eval")
eval_loader = DataLoader(eval_set, batch_size=batch_size, shuffle=False,
                         drop_last=False, pin_memory=True)

preds, keys = _score_trials(model, eval_loader, device)
sasv_eer_eval, sv_eer_eval, spf_eer_eval = get_all_EERs(
    preds=preds, keys=keys)
print("\nEpoch-%d Eval: sasv_eer_eval: %0.3f, sv_eer_eval: %0.3f, spf_eer_eval: %0.3f" % (
    best_epoch+1, 100 * sasv_eer_eval, 100 * sv_eer_eval, 100 * spf_eer_eval))



Epoch-1 Trn: Loss: 0.26636, sasv_eer_trn: 33.441, sv_eer_trn: 49.906, spf_eer_trn: 8.697
Epoch-1 Dev: sasv_eer_dev: 15.431, sv_eer_dev: 44.175, spf_eer_dev: 0.067

Epoch-2 Trn: Loss: 0.19859, sasv_eer_trn: 30.950, sv_eer_trn: 47.405, spf_eer_trn: 0.032
Epoch-2 Dev: sasv_eer_dev: 15.201, sv_eer_dev: 45.215, spf_eer_dev: 0.072

Epoch-3 Trn: Loss: 0.18470, sasv_eer_trn: 24.617, sv_eer_trn: 37.632, spf_eer_trn: 0.032
Epoch-3 Dev: sasv_eer_dev: 12.988, sv_eer_dev: 39.511, spf_eer_dev: 0.135

Epoch-4 Trn: Loss: 0.10950, sasv_eer_trn: 10.964, sv_eer_trn: 17.596, spf_eer_trn: 0.047
Epoch-4 Dev: sasv_eer_dev: 12.534, sv_eer_dev: 36.685, spf_eer_dev: 0.152

Epoch-5 Trn: Loss: 0.06055, sasv_eer_trn: 6.023, sv_eer_trn: 9.728, spf_eer_trn: 0.143
Epoch-5 Dev: sasv_eer_dev: 12.500, sv_eer_dev: 33.599, spf_eer_dev: 0.135

Epoch-6 Trn: Loss: 0.02594, sasv_eer_trn: 2.100, sv_eer_trn: 3.385, spf_eer_trn: 0.080
Epoch-6 Dev: sasv_eer_dev: 11.342, sv_eer_dev: 31.293, spf_eer_dev: 0.139

Epoch-7 Trn: Loss: 