In [1]:
from models.resnet.resnet_se_34v2 import ResNetSE34V2
from torchaudio.transforms import MelSpectrogram
from torchaudio.functional import amplitude_to_DB
import torch
import torchaudio
import yaml
import librosa
import numpy as np
import os 
import pandas as pd 
from tqdm import tqdm

In [2]:
#Get all path from the audio directory and save it to a list
def get_all_audio_path(audio_dir):
    audio_path = []
    for root, dirs, files in os.walk(audio_dir):
        for file in files:
            if file.endswith(".wav"):
                audio_path.append(os.path.join(root, file))
    return audio_path

In [3]:
audio_paths = get_all_audio_path("audio/speaker_segments/")

In [4]:
def load_audio(file):
    EPS = 1e-8
    s, _ = librosa.load(file, sr=16000)
    amax = np.max(np.abs(s))
    factor = 1.0 / (amax + EPS)
    s = s * factor
    return s

In [5]:
with open('./models/resnet/config.yaml') as f:
    config = yaml.safe_load(f)

sd = torch.load('./models/weigths/resnetse34_epoch92_eer0.00931.pth')
model = ResNetSE34V2(nOut=256, n_mels= 80)
model.load_state_dict(sd)
model.eval()
torch.set_grad_enabled(False)

transform = MelSpectrogram(
    sample_rate= 16000,
    n_fft= 512,
    win_length= 400,
    hop_length= 160,
    window_fn=torch.hamming_window,
    n_mels= 80,
    f_min=20,
    f_max=7600,
    norm='slaney')


def embed_inference(audio_path, transform= transform, model = model):
    s = load_audio(audio_path)
    x = torch.tensor(s[None, :])
    x = transform(x)
    x = amplitude_to_DB(
        x, multiplier=10, amin= 1e-5, db_multiplier=0, top_db=75)

    feature = model(x[:, None, :, :])
    feature = torch.nn.functional.normalize(feature)
    return(feature)

Embedding size is 256, encoder SAP.


In [6]:
loss = torch.nn.CosineSimilarity(dim=1, eps=1e-8)

In [9]:

audio_paths_compar = audio_paths[1:]
ref_embed = embed_inference(audio_paths[0])

compar_tab = pd.DataFrame(columns = ['path_1', 'path_2', 'label', 'loss_man', 'loss_val'])
with tqdm(total=len(audio_paths_compar)) as pbar:
    for path_1 in audio_paths_compar :
        spk_1 = path_1.split('/')[-1][15]
        ref_embed = embed_inference(path_1)
        for path_2 in audio_paths_compar :
            spk_2 = path_2.split('/')[-1][15]
            if spk_1 == spk_2 :
                label = 1        
            else :
                label = 0

            
            embed = embed_inference(path_2)
            loss_val = loss(ref_embed, embed)
            #print(embed)
            #print(ref_embed)
            val = torch.dot(ref_embed.squeeze(0), embed.squeeze(0))
            compar_tab.loc[len(compar_tab)] = [path_1, path_2, label, val, loss_val]
            #print(label, val)
        pbar.update(1)
       

           
compar_tab.to_csv('compar_tab.csv', index=False)

100%|██████████| 29/29 [00:54<00:00,  1.87s/it]


In [79]:

audio_paths_compar = audio_paths[1:]
ref_embed = embed_inference(audio_paths[0])

compar_tab = pd.DataFrame(columns = ['path_1', 'path_2', 'score'])
for path_1 in audio_paths_compar :
    ref_embed = embed_inference(path_1)
    for path_2 in audio_paths_compar :
        embed = embed_inference(path_2)
        #print(embed)
        #print(ref_embed)
        loss_value = loss(ref_embed, embed)
       

        compar_tab.loc[len(compar_tab)] = [path_1, path_2, loss_value.item()]
        compar_tab.to_csv('compar_tab.csv', index=False)


In [69]:
test = torch.randn(1, 1)


In [70]:
print(test.item())

0.467571496963501
