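# Subnet Evaluation

Evaluate subnets extracted from a trained dynamic-TDNN supernet on VoxCeleb1-O (EER / minDCF), after re-estimating each subnet's batch-norm statistics on a subset of the VoxCeleb2 training list.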

In [1]:
import os, sys, random, warnings, time
warnings.filterwarnings("ignore")
import torch
import pandas

sys.path.append('/workspace/projects')

from torch.utils.data import DataLoader
from sugar.transforms import LogMelFbanks
from sugar.models.dynamictdnn import tdnn8m2g
from sugar.models import SpeakerModel, WrappedModel, veri_validate, batch_forward
from sugar.database import Utterance, AugmentedUtterance
from sugar.data.voxceleb1 import veriset
from sugar.data.voxceleb2 import veritrain
from sugar.data.augmentation import augset
from sugar.scores import score_cohorts, asnorm
from sugar.vectors import extract_vectors
from sugar.metrics import calculate_mindcf, calculate_eer
from sugar.utils.utility import bn_state_dict, load_bn_state_dict

def eval_veri(test_loader, network, p_target=0.01, device="cpu", vectors=None):
    """Score a verification trial list with ``network`` and report EER / minDCF.

    Args:
        test_loader: DataLoader over a verification set. Its ``dataset`` must expose
            ``enrolls`` and ``tests`` sequences aligned with the returned scores.
        network: speaker model to evaluate.
        p_target: forwarded to ``veri_validate`` (used for the minDCF computation).
        device: device the forward passes run on.
        vectors: forwarded unchanged to ``veri_validate`` (presumably precomputed
            embeddings to reuse — confirm against ``veri_validate``).

    Returns:
        ``(eer, dcf, vec, scs)``: ``eer`` is the first element of the EER result
        scaled to percent, ``dcf`` is the first element of the minDCF result,
        ``vec`` is passed through from ``veri_validate``, and ``scs`` is a
        DataFrame with columns ``score``, ``enroll`` and ``test`` (one row per trial).
    """
    eer_info, dcf_info, vec, scores = veri_validate(
        test_loader, network, p_target=p_target, device=device, ret_info=True, vectors=vectors)
    dataset = test_loader.dataset
    scs = pandas.DataFrame({'score': scores, 'enroll': dataset.enrolls, 'test': dataset.tests})
    # Only the first element of each metric result is kept; EER is reported in percent.
    eer = eer_info[0] * 100
    dcf = dcf_info[0]
    return eer, dcf, vec, scs

def eval_asnorm(labs, vec, scs, cohorts, p_target=0.01):
    """Normalize trial scores against a cohort and recompute EER / minDCF.

    The cohort is scored against ``vec`` with ``score_cohorts``. The result is
    handed to ``asnorm`` together with ``scs`` (presumably adaptive S-norm).

    Args:
        labs: trial labels passed to the metric functions.
        vec: embeddings passed to ``score_cohorts``.
        scs: raw trial scores to normalize.
        cohorts: cohort data passed to ``score_cohorts``.
        p_target: forwarded to ``calculate_mindcf``.

    Returns:
        ``(eer_percent, min_dcf)`` computed on the normalized scores.
    """
    normalized = asnorm(scs, score_cohorts(cohorts, vec))
    metrics = (
        calculate_eer(labs, normalized)[0] * 100,
        calculate_mindcf(labs, normalized, p_target=p_target)[0],
    )
    return metrics

# Evaluation device: the second GPU when present (the original hard-coded choice),
# otherwise the only GPU, otherwise CPU, so the notebook still runs on other machines.
_num_gpus = torch.cuda.device_count()
device = 'cuda:1' if _num_gpus > 1 else ('cuda:0' if _num_gpus == 1 else 'cpu')

################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################



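## Load datasets

- Train set: a random 6,000-utterance subset of the VoxCeleb2 training list, used only to re-estimate batch-norm statistics for each subnet
- Test set: the VoxCeleb1-O trial list (`veri_test2.txt`)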

In [2]:
# Root directories of the VoxCeleb corpora. The defaults are this machine's paths;
# set VOX1_ROOT / VOX2_ROOT in the environment to point at your own copies
# (e.g. "/path/to/voxceleb1/"). Keep the trailing slash, as in the defaults.
vox1_root = os.environ.get("VOX1_ROOT", "/workspace/datasets/voxceleb/voxceleb1/")
vox2_root = os.environ.get("VOX2_ROOT", "/workspace/datasets/voxceleb/voxceleb2/")

In [3]:
# VoxCeleb2 training list (e.g. '/path/to/train_list.txt'); override with VOX2_TRAIN_LIST.
vox2_train = os.environ.get('VOX2_TRAIN_LIST', '/workspace/datasets/voxceleb/Vox2/train_list.txt')

NUM_SAMPLES = 64000     # num_samples passed to the loaders (presumably 4 s at 16 kHz — confirm)
BN_SUBSET_SIZE = 6000   # utterances kept; train_loader is only used for batch-norm re-estimation
SEED = 0                # fixes which utterances form the subset and the batch order

train, spks = veritrain(vox2_train, rootdir=vox2_root, num_samples=NUM_SAMPLES)

# Seeded shuffle: an unseeded one picked a different subset on every run, and the
# batch-norm statistics computed from it are cached to disk by the evaluation loop.
random.Random(SEED).shuffle(train.datalst)
train.datalst = train.datalst[:BN_SUBSET_SIZE]

aug_wav = augset(num_samples=NUM_SAMPLES)
trainset = AugmentedUtterance(train, spks, augment=aug_wav, mode='v2+')
train_loader = DataLoader(
    trainset, batch_size=32, shuffle=True, num_workers=5, drop_last=True,
    generator=torch.Generator().manual_seed(SEED),  # reproducible batch order / dropped tail
)

The number of speakers is 5994


In [4]:
# Only the VoxCeleb1-O list is requested; the all2 / hard2 lists are passed as None.
veritesto = "veri_test2.txt"
veri_testo, veri_teste, veri_testh, wav_files = veriset(
    test2=veritesto,
    all2=None,
    hard2=None,
    rootdir=vox1_root,
    num_samples=64000,
    num_eval=2,
)
# One utterance per batch, in file order, loaded in the main process.
testo_loader = DataLoader(veri_testo, shuffle=False, batch_size=1, num_workers=0)

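## Evaluate different subnets

Each subnet is specified by the configuration tuple passed to `clone()`. Starting from the largest subnet $a_\text{max}$, each successive $a_\text{*min}$ additionally shrinks the kernel sizes (K), then the depth (D), then the widths (C1, C2):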

- $a_\text{max}$: (4, [512, 512, 512, 512, 512], [5, 5, 5, 5, 5], 1536)
- $a_\text{Kmin}$: (4, [512, 512, 512, 512, 512], [1, 1, 1, 1, 1], 1536)
- $a_\text{Dmin}$: (2, [512, 512, 512], [1, 1, 1], 1536)
- $a_\text{C1min}$: (2, [256, 256, 256], [1, 1, 1], 768)
- $a_\text{C2min}$: (2, [128, 128, 128], [1, 1, 1], 384)

In [7]:
import copy

# Dynamic TDNN supernet with an 80-band log-mel front end, wrapped for evaluation.
transform = LogMelFbanks(80)
modelarch = tdnn8m2g(80, 192)
model = SpeakerModel(modelarch, transform=transform)
model = WrappedModel(model)

# Point this at your own checkpoint (e.g. '/path/to/supernet_checkpoint'); the
# checkpoints for the other progressive orders are listed under "Supernets" below.
supernet_path = '/workspace/projects/sugar/examples/nas/exps/exp3/supernet_depth_kernel_width1_width2/checkpoint000064.pth.tar'
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
state_dict = torch.load(supernet_path, map_location='cpu')
load_result = model.load_state_dict(state_dict['state_dict'], strict=False)
print(load_result)
# strict=False tolerates extra checkpoint entries (e.g. 'module.__L__.W'), but a
# missing key would leave that weight at its initial value and silently skew every
# result below, so fail loudly instead.
assert not load_result.missing_keys, f"checkpoint is missing weights: {load_result.missing_keys}"

model = model.to(device)
model.eval()

# Untouched copy of the supernet: every subnet below is cloned from it, so each
# starts from the loaded supernet weights (assuming `clone` copies parameters).
model_bak = copy.deepcopy(model)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['module.__L__.W'])


In [10]:
# Subnet configurations in the order listed above:
# a_max, a_Kmin, a_Dmin, a_C1min, a_C2min.
configs = [
    (4, [512, 512, 512, 512, 512], [5, 5, 5, 5, 5], 1536),
    (4, [512, 512, 512, 512, 512], [1, 1, 1, 1, 1], 1536),
    (2, [512, 512, 512], [1, 1, 1], 1536),
    (2, [256, 256, 256], [1, 1, 1], 768),
    (2, [128, 128, 128], [1, 1, 1], 384),
]
# Only a_Kmin is evaluated in this run; set EVAL_CONFIGS = configs to fill a whole results row.
EVAL_CONFIGS = configs[1:2]

subnet_results = []
for config in EVAL_CONFIGS:
    # Derive the subnet from the untouched supernet copy.
    model.module.__S__ = model_bak.module.__S__.clone(config)

    # Batch-norm state is cached per (supernet directory, config). The file name is the
    # tuple's repr, so changing this format would orphan existing cache files.
    bn_path = os.path.join(os.path.dirname(supernet_path), f"{config}.bn.pth")
    if os.path.exists(bn_path):
        load_bn_state_dict(model.module.__S__, torch.load(bn_path, map_location="cpu"))
        print(f"loaded state dict from saved batch norm {bn_path}")
    else:
        # Run the training subset through the subnet, then persist its batch-norm state.
        batch_forward(train_loader, model, device=device)
        torch.save(bn_state_dict(model.module.__S__), bn_path)
        print(f"saved batch norm state dict {bn_path}")
    time.sleep(1)  # presumably lets the message flush before the next tqdm bar starts

    eero, dcfo, veco, scso = eval_veri(testo_loader, model, device=device)
    print(f'subnet: {config}\nEvaluate on Vox1-O: * EER / DCF {eero:.2f}% / {dcfo:.3f}')
    subnet_results.append({'subnet': str(config), 'eer_vox1_o': eero, 'mindcf_vox1_o': dcfo})

pandas.DataFrame(subnet_results)

Forward Model: 100%|██████████| 187/187 [00:08<00:00, 21.22it/s]


saved batch norm state dict /workspace/projects/sugar/examples/nas/exps/exp3/supernet_depth_kernel_width1_width2/(4, [512, 512, 512, 512, 512], [1, 1, 1, 1, 1], 1536).bn.pth


Extract Vectors: 100%|██████████| 4708/4708 [01:20<00:00, 58.41it/s]
Compute Scores: 100%|██████████| 37611/37611 [00:06<00:00, 6011.82it/s]


subnet: (4, [512, 512, 512, 512, 512], [1, 1, 1, 1, 1], 1536)
Evaluate on Vox1-O: * EER / DCF 3.37% / 0.326


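## Results for different progressive orders

Each supernet below corresponds to a different order in which the search dimensions (kernel, depth, width) were made elastic. Each table cell reports EER (%) / minDCF on Vox1-O. Re-running the evaluation can give slightly different numbers from the table, since batch-norm statistics are re-estimated on a randomly sampled training subset. For example, the run above gives 3.37 / 0.326 for $a_\text{Kmin}$ of the depth->kernel->width supernet, versus 3.43 / 0.325 in the table.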

### Subnets

- $a_\text{max}$: (4, [512, 512, 512, 512, 512], [5, 5, 5, 5, 5], 1536)
- $a_\text{Kmin}$: (4, [512, 512, 512, 512, 512], [1, 1, 1, 1, 1], 1536)
- $a_\text{Dmin}$: (2, [512, 512, 512], [1, 1, 1], 1536)
- $a_\text{C1min}$: (2, [256, 256, 256], [1, 1, 1], 768)
- $a_\text{C2min}$: (2, [128, 128, 128], [1, 1, 1], 384)

### Supernets

- kernel->depth->width: /workspace/projects/sugar/examples/nas/exps/exp3/width2/phase2/width2.torchparams
- kernel->width->depth: /workspace/projects/sugar/examples/nas/exps/exp3/supernet_kernel_width1_width2_depth/checkpoint000064.pth.tar
- depth->kernel->width: /workspace/projects/sugar/examples/nas/exps/exp3/supernet_depth_kernel_width1_width2/checkpoint000064.pth.tar

| Progressive Order | $a_\text{max}$ | $a_\text{Kmin}$ | $a_\text{Dmin}$ | $a_\text{C1min}$ | $a_\text{C2min}$ |
|:---|:---:|:---:|:---:|:---:|:---:|
| Table V kernel->depth->width | 1.44 / 0.163 | 3.54 / 0.344 | 3.58 / 0.334 | 3.98 / 0.360 | 5.29 / 0.478 |
| kernel->width->depth | 1.49 / 0.153 | 3.52 / 0.330 | 3.82 / 0.369 | 3.99 / 0.373 | 5.32 / 0.463 |
| depth->kernel->width | 1.48 / 0.144 | 3.43 / 0.325 | 3.56 / 0.328 | 3.99 / 0.389 | 5.39 / 0.474 |