In [None]:
import glob
import os
import random
import sys
import time
from collections import OrderedDict, Counter

import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import tqdm.notebook
from IPython.display import Audio, display
from ipywidgets import widgets
from sklearn.metrics import roc_curve, accuracy_score, confusion_matrix, roc_auc_score
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import Dataset

from SpeakerNet import ModelTrainer, SpeakerNet,WrappedModel
from DatasetLoader import train_dataset_loader,test_dataset_loader,train_dataset_sampler

In [None]:
import yaml

In [None]:
# Report CUDA availability, the installed PyTorch version and whether cuDNN is enabled.
for env_fact in (torch.cuda.is_available(), torch.__version__, torch.backends.cudnn.enabled):
    print(env_fact)

In [None]:
# Check which device is currently set up.
# The CUDA-only queries are guarded: torch.cuda.current_device() and
# torch.cuda.set_device() raise on a machine without a usable GPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print ('Available devices ', torch.cuda.device_count())
if cuda_available:
    print ('Current cuda device ', torch.cuda.current_device())
    print(torch.cuda.get_device_name(device))

# Change the GPU allocation.
GPU_NUM = 1 # index of the GPU to use
if cuda_available:
    if GPU_NUM >= torch.cuda.device_count():
        # The requested index does not exist on this machine; fall back to the first GPU.
        print('Requested GPU', GPU_NUM, 'is not available; using GPU 0 instead')
        GPU_NUM = 0
    device = torch.device(f'cuda:{GPU_NUM}')
    torch.cuda.set_device(device) # change allocation of current GPU
    print ('Current cuda device ', torch.cuda.current_device()) # check
else:
    device = torch.device('cpu')

# Additional info about the selected GPU
if device.type == 'cuda':
    print(torch.cuda.get_device_name(GPU_NUM))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(GPU_NUM)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(GPU_NUM)/1024**3,1), 'GB')


# 1. Training

In [None]:
# Count the visible GPUs and set the `distributed` flag.
# The stray expression `0,None, args` that used to end this cell was removed:
# `args` is not defined anywhere in the notebook, so the cell raised NameError.
n_gpus = torch.cuda.device_count()
print('n_gpus',n_gpus)
distributed = True

In [None]:
# Fast ResNet-34 settings
gpu = 0
max_frames = 200
nPerSpeaker = 2
batch_size = 400
trainfunc = 'angleproto'
model = 'ResNetSE34L'
encoder_type = 'SAP'

# Build the network and move the wrapped model to GPU `gpu`.
# NOTE(review): the constructor receives the literal nPerSpeaker=1, not the
# `nPerSpeaker` variable above; `max_frames`, `batch_size` and `encoder_type`
# are not passed either. Confirm this is intended.
s = WrappedModel(
    SpeakerNet(
        model=model,
        optimizer='adam',
        trainfunc=trainfunc,
        nPerSpeaker=1,
        nOut=512,
        nClasses=5994,
    )
).cuda(gpu)

# Iteration counter and EER history
it = 1
eers = [100]

In [None]:
# Locations of the training list, the training audio and the augmentation corpora.
train_list = './pairlist/train_list.txt'
train_path = '/direct/scp/voxceleb2'
musan_path = '/shared/musan'
rir_path = '/shared/rir'

# Rebinds max_frames (an earlier cell set it to 200).
max_frames = 400

In [None]:
# Training dataset without augmentation; the MUSAN/RIR paths are still passed to the loader.
train_dataset = train_dataset_loader(
    train_list=train_list,
    augment=False,
    musan_path=musan_path,
    rir_path=rir_path,
    max_frames=max_frames,
    train_path=train_path,
)

In [None]:
# Sampler over the training dataset (non-distributed, fixed seed).
sampler_options = dict(nPerSpeaker=1, max_seg_per_spk=10, batch_size=1, distributed=False, seed=100)
train_sampler = train_dataset_sampler(train_dataset, **sampler_options)

In [None]:
def worker_init_fn(worker_id):
    """Seed NumPy inside a DataLoader worker.

    `worker_init_fn` was referenced by this cell but never defined or imported,
    so the cell raised NameError. PyTorch gives every worker a distinct
    `torch.initial_seed()`; reusing it for NumPy keeps workers from sharing
    one NumPy random stream.

    Args:
        worker_id: index of the worker process (supplied by the DataLoader).
    """
    np.random.seed(torch.initial_seed() % 2**32)


# Batches of 200 items drawn through `train_sampler`; incomplete final batches are dropped.
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=200,
        num_workers=5,
        sampler=train_sampler,
        pin_memory=False,
        worker_init_fn=worker_init_fn,
        drop_last=True,
    )

# Inference

In [None]:
def loadParameters(model, path, map_location="cpu"):
    """Copy matching weights from a checkpoint file into `model`'s state dict.

    Args:
        model: module whose ``state_dict()`` receives the weights.
        path: checkpoint file readable by ``torch.load``.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved on a GPU that does not exist on this machine can
            still be read; ``copy_`` then moves the values to wherever the
            model's tensors live.

    Returns:
        The model's state dict, with every compatible entry overwritten in place.
        Entries that are unknown to the model or have a different size are
        reported on stdout and skipped.
    """
    print('(loadParameters)')
    self_state = model.state_dict()
    loaded_state = torch.load(path, map_location=map_location)

    # A checkpoint of the form {"model": state_dict} gets every key prefixed
    # with "__S__." (presumably the submodule name used by SpeakerNet - confirm).
    if len(loaded_state.keys()) == 1 and "model" in loaded_state:
        loaded_state = {"__S__." + name: param for name, param in loaded_state["model"].items()}

    for origname, param in loaded_state.items():
        name = origname
        if name not in self_state:
            # Retry without the "module." prefix.
            name = name.replace("module.", "")

            if name not in self_state:
                print("{} is not in the model.".format(origname))
                continue

        if self_state[name].size() != param.size():
            print("Wrong parameter length: {}, model: {}, loaded: {}".format(origname, self_state[name].size(), param.size()))
            continue

        self_state[name].copy_(param)
    return self_state

## ResNet

In [None]:
# Keyword arguments used to construct the ResNet-based SpeakerNet.
conf = dict(
    model='ResNetSE34V2',
    optimizer='adam',
    trainfunc='softmaxproto',
    nPerSpeaker=1,
    num_eval=10,
    eval_frames=400,
    n_mels=64,
    nOut=512,
    nClasses=5994,
    encoder_type='ASP',
)

In [None]:
# Instantiate the network from `conf`, restore the pretrained ResNet weights
# and move the model to the GPU.
resnet_checkpoint = './pretrained/baseline_v2_smproto.model'
spn = SpeakerNet(**conf)
params = loadParameters(spn, resnet_checkpoint)
spn.load_state_dict(params)

spn.cuda()

## RawNet3

In [None]:
# Read the RawNet3 settings from YAML; this rebinds `conf`.
# NOTE(review): yaml.FullLoader can build more than plain data types; if the
# file only holds scalars/lists/dicts, yaml.safe_load would be the safer call.
param_path = './configs'
rawnet3_config_file = os.path.join(param_path, 'RawNet3_AAM.yaml')
with open(rawnet3_config_file) as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)

In [None]:
# Instantiate the network from the current `conf`, restore the RawNet3 weights
# and move the model to the GPU. This rebinds `spn`.
rawnet3_checkpoint = './models/weights/RawNet3/model.pt'
spn = SpeakerNet(**conf)
params = loadParameters(spn, rawnet3_checkpoint)
spn.load_state_dict(params)

spn.cuda()

In [None]:
# Trial list to evaluate and the directory the listed audio paths are relative to.
test_path = '/direct/scp/voxceleb1'
test_list = './pairlist/veri_test.txt'

In [None]:
# Total number of elements across all model parameters.
pytorch_total_params = 0
for weight in spn.parameters():
    pytorch_total_params += weight.numel()

print('Total parameters: ',pytorch_total_params)
print('Test list',test_list)

In [None]:
import itertools

In [None]:
# Put the model in evaluation mode and reset the evaluation state.
spn.eval()

lines, files, feats = [], [], {}
tstart = time.time()

# Read every trial line.
with open(test_list) as f:
    lines = f.readlines()

# The last two fields of each line are file names; collect them, then
# de-duplicate and sort.
files = []
for trial in lines:
    files.extend(trial.strip().split()[-2:])
setfiles = sorted(set(files))

In [None]:
## Define test data loader
test_dataset = test_dataset_loader(setfiles, test_path, num_eval=10, eval_frames=400)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=5,
    drop_last=False,
    sampler = None
)

In [None]:
# Compute an embedding for every unique test file and store it in `feats`,
# keyed by file name.
print_interval = 100
# Restart the timer here: `tstart` was last set in an earlier cell, so the
# reported rate would otherwise include the idle time between cell runs.
tstart = time.time()
for idx, data in tqdm.notebook.tqdm(enumerate(test_loader), total=len(test_loader)):
    inp1 = data[0][0].cuda()
    with torch.no_grad():
        ref_feat = spn(inp1).detach().cpu()
    feats[data[1][0]] = ref_feat

    if idx % print_interval == 0:
        telapsed = time.time() - tstart
        sys.stdout.write("\rReading {:d} of {:d}: {:.2f} Hz, embedding size {:d}".format(idx, len(test_loader), idx / telapsed, ref_feat.size()[1]))
        sys.stdout.flush()

In [None]:
# Inspect the shape of the most recently assigned embedding tensor.
ref_feat.shape

In [None]:
# Score every trial: the score is the negated mean pairwise distance between
# the segment embeddings of the two files.
all_scores = []
all_labels = []
all_trials = []
tstart = time.time()

## Read files and compute all scores
for idx, line in tqdm.notebook.tqdm(enumerate(lines), total=len(lines)):

    data = line.split()

    ## Prepend a random label if the line only holds the two file names
    if len(data) == 2:
        data = [random.randint(0,1)] + data

    ref_feat = feats[data[1]].cuda()
    com_feat = feats[data[2]].cuda()

    if spn.__L__.test_normalize:
        ref_feat = F.normalize(ref_feat, p=2, dim=1)
        com_feat = F.normalize(com_feat, p=2, dim=1)

    # `num_eval` was never defined in this notebook (NameError). The number of
    # segments is taken from the first axis of each stored embedding instead.
    # The unused `dist2` computation was dropped.
    dist1 = torch.cdist(ref_feat.reshape(ref_feat.size(0), -1), com_feat.reshape(com_feat.size(0), -1)).detach().cpu().numpy()
    score = -1 * np.mean(dist1)

    all_scores.append(score)
    all_labels.append(int(data[0]))
    all_trials.append(data[1]+" "+data[2])

    if idx % print_interval == 0:
        telapsed = time.time() - tstart
        sys.stdout.write("\rComputing {:d} of {:d}: {:.2f} Hz".format(idx,len(lines),idx/telapsed))
        sys.stdout.flush()

In [None]:
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve

# ROC curve of the trial scores, treating label 1 as the positive class.
fpr, tpr, thresholds = roc_curve(all_labels, all_scores, pos_label=1)
fnr = 1 - tpr

# The equal error rate is the false-positive rate x at which 1 - x equals the
# interpolated true-positive rate; `thresh` is the score threshold at that point.
tpr_at = interp1d(fpr, tpr)
eer = brentq(lambda x : 1. - x - tpr_at(x), 0., 1.)
threshold_at = interp1d(fpr, thresholds)
thresh = threshold_at(eer)
print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "VEER {:2.4f}".format(eer*100))

In [None]:
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve

In [None]:
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline
fig = plt.figure(figsize=(5,2))
plt.hist(all_scores,bins=50, color='royalblue')
plt.xlabel('score')
plt.ylabel('freq')
plt.vlines(thresh,0,3000,color='black')


In [None]:
# Convert labels to an array and threshold the scores:
# a score below `thresh` is predicted 0, anything else 1.
all_labels = np.array(all_labels)
score_array = np.asarray(all_scores)
all_preds = np.where(score_array < thresh, 0, 1)

In [None]:
# Confusion matrix of the thresholded predictions against the trial labels.
confusion_matrix(y_true=all_labels, y_pred=all_preds)

# Analysis


In [None]:
# False negatives: trials labelled 1 that were predicted 0.
false_negative_mask = (all_labels == 1) & (all_preds == 0)
indices = np.flatnonzero(false_negative_mask)
FN = [all_trials[idx] for idx in indices]
FN.sort()
len(FN)

In [None]:
# Trial lines of the misclassified pairs, split by true label.
# Note: `tn_lines` holds label-0 trials whose prediction differs from the label
# (i.e. predicted 1); `fn_lines` holds label-1 trials predicted 0.
misclassified = all_labels != all_preds
tn_lines = [lines[i] for i in np.flatnonzero(misclassified & (all_labels == 0))]
fn_lines = [lines[i] for i in np.flatnonzero(misclassified & (all_labels == 1))]

In [None]:
# For every pair in `tn_lines`, store the element-wise absolute difference
# between the two files' embeddings.
print_interval = 100
tn_rawdist = []
n_skipped = 0

tstart = time.time()
## Read files and compute all scores
for idx, line in enumerate(tn_lines):
    data = line.split()
    try:
        ref_feat = feats[data[1]].cuda()
        com_feat = feats[data[2]].cuda()
    except KeyError:
        # Only a missing embedding is skipped. The previous bare `except`
        # also hid CUDA and indexing errors.
        n_skipped += 1
        continue

    if spn.__L__.test_normalize:
        ref_feat = F.normalize(ref_feat, p=2, dim=1)
        com_feat = F.normalize(com_feat, p=2, dim=1)

    rawdist = (ref_feat - com_feat).abs().cpu().numpy()
    tn_rawdist.append(rawdist)

    if idx % print_interval == 0:
        telapsed = time.time() - tstart
        sys.stdout.write("\rComputing {:d} of {:d}: {:.2f} Hz".format(idx,len(tn_lines),idx/telapsed))
        sys.stdout.flush()

if n_skipped:
    print("\nSkipped {:d} pairs without a stored embedding".format(n_skipped))
tn_rawdist = np.asarray(tn_rawdist)
print(tn_rawdist.shape)

In [None]:
# Mean absolute embedding difference (averaged over the first axis of each
# pair's array) for the first two pairs in `tn_rawdist`.
# The panels come from one loop, and the cell no longer fails with IndexError
# when fewer than two pairs are available.
n_panels = min(2, len(tn_rawdist))
if n_panels == 0:
    print('No pairs available to plot.')
else:
    fig, axes = plt.subplots(1, n_panels, figsize=(10, 3), squeeze=False)
    for pair_idx, ax1 in enumerate(axes[0]):
        ax1.plot(tn_rawdist[pair_idx].mean(axis=0))
        ax1.set_ylim(0,0.3)
        ax1.set_title('pair {:d}'.format(pair_idx))
        ax1.set_xlabel('embedding dimension')
        ax1.set_ylabel('mean |difference|')
    fig.tight_layout()

In [None]:
# Per-dimension mean of the absolute differences, with error bars taken from
# the spread of the per-pair standard deviations.
fig = plt.figure(figsize=(15,3))
dim_mean = tn_rawdist.mean(axis=1).mean(axis=0)
dim_spread = tn_rawdist.std(axis=1).std(axis=0)
# The x axis follows the actual embedding size; the previous hard-coded
# range(0, 512) failed for any model whose output size is not 512.
plt.errorbar(np.arange(dim_mean.shape[0]), dim_mean, dim_spread, fmt='-o')
#plt.xlim(0,100)

In [None]:
import shap

In [None]:
## Get a list of unique file names
tnfiles = sum([x.strip().split()[-2:] for x in tn_lines],[])
tnfiles = list(set(tnfiles))
tnfiles.sort()

In [None]:
import gc

In [None]:
# Run Python's garbage collector, then ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
## Define test data loader
test_dataset = test_dataset_loader(tnfiles, test_path, num_eval=1, eval_frames=400)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=50,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)

In [None]:
# Use the first batch of audio as the SHAP background set and build the explainer.
batch = next(iter(test_loader))
tnaudios = batch[0]
_ = batch[1]
background = tnaudios.to(device)
e = shap.DeepExplainer(spn, background)

In [None]:
# Inspect the shape of the audio batch used as SHAP background.
tnaudios.shape

In [None]:
# IPython help lookup for shap.DeepExplainer (development aid).
# NOTE(review): leftover introspection; consider removing before sharing the notebook.
shap.DeepExplainer?

In [None]:
# Flatten each item of the batch to one row. The row count comes from the
# tensor itself: the hard-coded (50, 64240) raised an error whenever the batch
# held fewer than 50 files or the clip length differed.
tnaudios.reshape(tnaudios.size(0), -1)

In [None]:
# Compute SHAP values for the background batch itself.
# NOTE(review): this rebinds `s`, which an earlier cell used for the wrapped SpeakerNet model.
s = e.shap_values(background)

In [None]:
# Interactive browser for the trials in `FN`: choosing a pair loads both audio files.
# The former `display(audio1)` cell ran before `audio1` existed, so the two
# cells were merged and the display moved to the end.
select = widgets.Select(options=FN, layout=widgets.Layout(width='500px', height='500px'))

audio1 = widgets.Audio(autoplay=False, loop=False)
audio2 = widgets.Audio(autoplay=False, loop=False)

def on_change(change):
    """Load the two audio files named by the newly selected trial.

    Args:
        change: change notification whose 'new' entry is a trial string of
            the form "<file1> <file2>", with paths relative to `test_path`.
    """
    if not change['new']:
        return
    a1, a2 = change['new'].split()
    audio1.value = widgets.Audio.from_file(os.path.join(test_path, a1)).value
    audio2.value = widgets.Audio.from_file(os.path.join(test_path, a2)).value

select.observe(on_change, 'value')
if FN:
    select.value = FN[0]
    # Assigning a value that is already selected fires no change event, so the
    # first pair is loaded explicitly.
    on_change({'new': select.value})

display(select, audio1, audio2)