trained model for RawNet2_modified and RawNet2 #21

Closed
hdubey opened this issue Oct 20, 2021 · 4 comments

hdubey commented Oct 20, 2021

Hi,
Can you share the trained models for RawNet2 and RawNet2_modified for quick testing? Do you have a script for extracting speaker embeddings from a wav file?

ac-alpha commented

@hdubey you can find the weights for RawNet2 here.

You can use this script for quick testing and getting the embeddings. Make sure that you have this model definition in the directory from which you run the script.

import numpy as np
import soundfile as sf  # needed by read_wav_and_get_clip_tensor; missing from the original script

import torch

from parser import get_args
from model_RawNet2_original_code import *  # provides the RawNet class instantiated below
from pydub import AudioSegment

load_model_dir = "Pre-trained_model/rawnet2_best_weights.pt"
test_wav_path1 = "/root/host/ml-speaker-verification/data/vox1/vox1_test_wav/id10270/5r0dWxy17C8/00001.wav"
test_wav_path2 = "/root/host/ml-speaker-verification/data/vox1/vox1_test_wav/id10278/d6WJf6TOoIQ/00001.wav"

test_wav_path3 = "/root/host/ml-speaker-verification/data/vox2/vox2_test_m4a/id00017/01dfn2spqyE/00001.m4a"
test_wav_path4 = "/root/host/ml-speaker-verification/data/vox2/vox2_test_m4a/id00017/8_a6O3vdlU0/00021.m4a"

def cos_sim(a,b) :
    return np.dot(a,b) / (np.linalg.norm(a)*np.linalg.norm(b))

def read_wav_and_get_clip_tensor(test_wav_path, nb_samp, window_size, wav_file = True):
    
    if not wav_file:
        X = AudioSegment.from_file(test_wav_path)
        X = X.get_array_of_samples()
        X = np.array(X)
    else:
        X, _ = sf.read(test_wav_path)
    X = X.astype(np.float64)
    X = _normalize_scale(X).astype(np.float32)
    X = X.reshape(1,-1)
    
    nb_time = X.shape[1]
    list_X = []
    if nb_time < nb_samp:
        nb_dup = int(nb_samp / nb_time) + 1
        list_X.append(np.tile(X, (1, nb_dup))[:, :nb_samp][0])
    elif nb_time > nb_samp:
        step = nb_samp - window_size
        iteration = int( (nb_time - window_size) / step ) + 1
        for i in range(iteration):
            if i == 0:
                list_X.append(X[:, :nb_samp][0])
            elif i < iteration - 1:
                list_X.append(X[:, i*step : i*step + nb_samp][0])
            else:
                list_X.append(X[:, -nb_samp:][0])
    else :
        list_X.append(X[0])
    return torch.from_numpy(np.asarray(list_X))

def get_embedding_from_clip_tensor(clip_tensor, model, device):
    model.eval()

    with torch.no_grad():
        # Run every clip through the model and average the per-clip embeddings.
        l_code = []
        mbatch = clip_tensor.unsqueeze(1)  # add a channel dimension
        for batch in mbatch:
            batch = batch.to(device)
            code = model(x = batch, is_test=True)
            l_code.extend(code.cpu().numpy())
        embedding = np.mean(l_code, axis=0)
        return embedding

def _normalize_scale(x):
    '''
    Normalize sample scale as in SincNet.
    '''
    return x / np.max(np.abs(x))

def main_test():
    # parse arguments
    args = get_args()

    wav_file = True if args.wav_file == 1 else False  # pass 0 for .m4a inputs like the paths below

    ## Number of speakers in the VoxCeleb2 dataset.
    ## Not used when computing embeddings, but the classifier head
    ## must match the checkpoint. Do not comment this out.
    args.model['nb_classes'] = 6112

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # device setting
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    print('Device: {}'.format(device))

    model = RawNet(args.model, device).to(device)
    model.load_state_dict(torch.load(load_model_dir, map_location=device))
    nb_params = sum(param.numel() for param in model.parameters())
    nb_samp = args.model['nb_samp']
    window_size = args.window_size
    print('nb_params: {}'.format(nb_params))
    
    X1 = read_wav_and_get_clip_tensor(test_wav_path3, nb_samp, window_size, wav_file)
    emb_X1 = get_embedding_from_clip_tensor(X1, model, device)
    
    X2 = read_wav_and_get_clip_tensor(test_wav_path4, nb_samp, window_size, wav_file)
    emb_X2 = get_embedding_from_clip_tensor(X2, model, device)
    
    sim_score = cos_sim(emb_X1, emb_X2)
    print("Similarity = {}".format(sim_score))

if __name__ == '__main__':
    main_test()
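
If you just want to compare two of your own files, here is a minimal sketch reusing the functions above; the wav names are placeholders, and get_args() is assumed to run with its defaults:

# Hypothetical quick check on two local 16 kHz wav files (paths are placeholders).
args = get_args()
args.model['nb_classes'] = 6112
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = RawNet(args.model, device).to(device)
model.load_state_dict(torch.load(load_model_dir, map_location=device))

emb_a = get_embedding_from_clip_tensor(
    read_wav_and_get_clip_tensor('a.wav', args.model['nb_samp'], args.window_size),
    model, device)
emb_b = get_embedding_from_clip_tensor(
    read_wav_and_get_clip_tensor('b.wav', args.model['nb_samp'], args.window_size),
    model, device)
print(cos_sim(emb_a, emb_b))  # closer to 1.0 suggests the same speaker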

Jungjee pinned this issue Nov 7, 2021
Jungjee (Owner) commented Nov 7, 2021

@ac-alpha thanks for the reply :)
I'll close this

Jungjee closed this as completed Nov 7, 2021
hdubey (Author) commented Jan 5, 2022

@ac-alpha thanks. Using the above script and the provided model leads to the following errors. Is it RawNet, RawNet2, or RawNet2_modified?

RuntimeError: Error(s) in loading state_dict for RawNet:
Unexpected key(s) in state_dict: "block2.0.conv_downsample.weight", "block2.0.conv_downsample.bias".
size mismatch for block2.0.bn1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3]).
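
One way to see which definition the checkpoint actually matches is to diff its keys and shapes against the instantiated model; a minimal sketch, reusing args from get_args() as in the script above:

import torch

# Compare checkpoint keys/shapes with the model's state_dict to spot the mismatch.
ckpt = torch.load("Pre-trained_model/rawnet2_best_weights.pt", map_location="cpu")
model_sd = RawNet(args.model, torch.device("cpu")).state_dict()

print(ckpt.keys() - model_sd.keys())   # keys only in the checkpoint
print(model_sd.keys() - ckpt.keys())   # keys only in the model
for k in ckpt.keys() & model_sd.keys():
    if ckpt[k].shape != model_sd[k].shape:
        print(k, tuple(ckpt[k].shape), '->', tuple(model_sd[k].shape))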

Jungjee (Owner) commented Jun 27, 2022

Closing this now, as I have uploaded RawNet3 and a script to extract speaker embeddings from any 16 kHz, 16-bit mono utterance.
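
If your audio is not already in that format, here is a minimal conversion sketch with pydub (already a dependency of the script above); the file names are placeholders:

from pydub import AudioSegment

# Convert an arbitrary input (e.g. a VoxCeleb2 .m4a) to 16 kHz, 16-bit, mono WAV.
audio = AudioSegment.from_file("input.m4a")
audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)  # 2 bytes = 16 bit
audio.export("utterance_16k.wav", format="wav")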

Jungjee unpinned this issue Jun 27, 2022