#### Imports

In [1]:
import os
import errno
import random
import librosa
import scipy
import webrtcvad
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
%matplotlib inline

from collections import defaultdict
from tqdm import tqdm_notebook, tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [2]:
vad = webrtcvad.Vad(3)

#### Methods

In [3]:
def path_hierarchy(path):
    hierarchy = {
        'type': 'folder',
        'name': os.path.basename(path),
        'path': path,
    }

    try:
        hierarchy['children'] = [
            path_hierarchy(os.path.join(path, contents))
            for contents in os.listdir(path)
        ]
    except OSError as e:
        if e.errno != errno.ENOTDIR:
            raise
        hierarchy['type'] = 'file'

    return hierarchy

In [4]:
def get_paths(level):
    root = "/".join(level['path'].split("/")[:-2]) + "/"
    filename = "/".join(level['path'].split("/")[-2:])
    return root, filename
def get_path_jsons(root):
    data_paths = defaultdict(list)
    for level_a in path_hierarchy(root)['children']:
        for level_b in level_a['children']:
            for level_c in level_b['children']:
                if level_c['type'] == 'folder':
                    for level_d in level_c['children']:
                        root, filename = get_paths(level_d)
                        data_paths[root].append(filename)
                else:
                    root, filename = get_paths(level_c)
                    data_paths[root].append(filename)
    return data_paths

In [5]:
def blend(signal, noise, target_snr=10):
    if len(signal) < len(noise):
        noise = noise[:len(signal)]
    elif len(noise) < len(signal):
        signal = signal[:len(noise)]

    if target_snr == np.inf:
        scaler, prescaler = 0, 1
    elif target_snr == -np.inf:
        scaler, prescaler = 1, 0
    else:
        signal_power = np.sum(signal**2)
        noise_power = np.sum(noise**2)
        scaler = np.sqrt( signal_power / (noise_power * 10.**(target_snr/10.)) )
        prescaler = 1

    # blend
    return prescaler * signal + scaler * noise

In [6]:
def sample_noise_path(noise_folders):
    sample_folder = random.choice(noise_folders)
    sample_files = os.listdir(sample_folder)
    sample_filepath = sample_folder + random.choice(sample_files)
    return sample_filepath

In [7]:
def get_speech(x, sample_rate, frame_duration, hop_duration):
    frame_length = int(frame_duration * sample_rate)
    hop_length = int(hop_duration * sample_rate)
    frames = scipy.array([x[i:i+frame_length] for i in range(0, len(x) - frame_length, hop_length)])
    is_speech = [vad.is_speech(frame.tobytes(), sample_rate) for frame in frames]
    speech_indices = np.nonzero(is_speech)
    speech_only = frames[speech_indices].flatten()
    return speech_only

def get_meldb(path, noise_folders, audio_params, tisv_frame):
    x, _ = librosa.core.load(path, sr=audio_params['sr'])
    if random.choice([True, True, False]):
        noise_path = sample_noise_path(noise_folders)
        noise, _ = librosa.core.load(noise_path, sr=audio_params['sr'])
        mean_snr = random.choice([7.5, 10 , 12.5, 15, 17.5, 20])
        std_snr = random.uniform(0.1, 1)
        snr = np.random.normal(mean_snr, std_snr)
        x = blend(x, noise, target_snr=snr)
    x = get_speech(x, _, 0.01, 0.01)
    window_length = int(audio_params['window'] * audio_params['sr'])
    hop_length = int(audio_params['hop'] * audio_params['sr'])
    spec = librosa.stft(x, n_fft=audio_params['nfft'],
                        hop_length=hop_length, 
                        win_length=window_length)
    mag_spec = np.abs(spec)
    mel_basis = librosa.filters.mel(audio_params['sr'], audio_params['nfft'],
                                    n_mels=audio_params['nmels'])
    mel_spec = np.dot(mel_basis, mag_spec)
    mel_db = librosa.amplitude_to_db(mel_spec).T
    last = len(mel_db) - tisv_frame - 1
    beg = random.randint(0, last)
    return mel_db[beg:beg+tisv_frame], x

#### Paths

In [8]:
train_root = "./dataset/dev/wav/"
test_root = "./dataset/test/wav"
noise_root = "./dataset/QUT-NOISE/split_noises/"

In [9]:
train_paths = get_path_jsons(train_root)
test_paths = get_path_jsons(test_root)
noise_folders = [noise_root + file + "/" for file in os.listdir(noise_root)]

In [10]:
print("Total speakers in train: ", len(train_paths))
print("Total speakers in test: ", len(test_paths))
print("Total files in train: ", sum([len(train_paths[root]) for root in train_paths]))
print("Total files in test: ", sum([len(test_paths[root]) for root in test_paths]))

Total speakers in train:  1211
Total speakers in test:  40
Total files in train:  148642
Total files in test:  4874


#### Dataset

In [23]:
class VoxCelebDatasest(Dataset):
    def __init__(self, train_paths, noise_folders, M, N, training=True):
        self.batch_size = M
        self.count = 0
        self.segment_size = None
        if training:
            self.paths = train_paths
            self.noise_folders = noise_folders
            self.utterance_number = N
        self.speakers = list(train_paths)
        random.shuffle(self.speakers)
        
    def __len__(self):
        return len(self.speakers)
    
    def __getitem__(self, idx):
        speaker = self.speakers[idx]
        wav_files = self.paths[speaker]
        random.shuffle(wav_files)
        wav_files = random.sample(wav_files, self.utterance_number)
        mel_dbs = []
        self.keep_segment_size()
        for f in wav_files:
            mel_db, _ = get_meldb(speaker + f, self.noise_folders, audio_params, self.segment_size)
            mel_dbs.append(mel_db)
        return torch.Tensor(mel_dbs)
    
    def keep_segment_size(self):
        self.count += 1
        if self.count % (self.batch_size + 1) == 0:
            self.count = 1
        if self.count == 1:
            self.segment_size = random.randint(140, 180)

#### Model

In [151]:
class GE2E(nn.Module):
    def __init__(self, nlstm, dembed, dhid, dout, dropout=0.):
        super(GE2E, self).__init__()
        self.dropout = dropout
        self.lstm = nn.LSTM(dembed, dhid, num_layers=nlstm, batch_first=True)
        self.linear = nn.Linear(dhid, dout)
    def forward(self, inp):
        lstm_out, _ = self.lstm(inp)
        out = self.linear(lstm_out[:, -1, :])
        out_norm = out / torch.norm(out, p=2, dim=0)
        return out_norm
            
        
class GE2ELoss(nn.Module):
    def __init__(self):
        super(GE2ELoss, self).__init__()
        self.w = nn.Parameter(torch.tensor(10.0).to(device), requires_grad=True)
        self.b = nn.Parameter(torch.tensor(-5.0).to(device), requires_grad=True)
        self.device = device
    def forward(self, k, M):
        groups = torch.split(k, M, dim=0)
        centroids = torch.stack([torch.mean(group, dim=0) for group in groups])
        sims = []
        for idx, item in enumerate(k.repeat(N, 1, 1).transpose(1, 0)):
            cur_centroid_idx = idx // M
            for_stability = (centroids[cur_centroid_idx] - item[0]) * M / (M - 1)
            updated_centroids = centroids.clone()
            updated_centroids[cur_centroid_idx] = for_stability
            sims.append(F.cosine_similarity(item, updated_centroids, dim=-1, eps=1e-8))
        cossim = torch.stack(sims)
        sim_matrix = self.w * cossim + self.b
        loss = self.compute_loss(sim_matrix)
        return loss
    def compute_loss(self, sim_matrix):
        loss = 0 
        for idx, group in enumerate(torch.split(sim_matrix, M, dim=0)):
            for utterance_sims in group:
                cur = utterance_sims[idx]
                rest = torch.cat([utterance_sims[:idx], utterance_sims[idx+1:]])
                loss += - cur + torch.log(torch.sum(torch.exp(rest)))
        return loss

In [37]:
device = 'cuda'

#### Data params

In [24]:
audio_params = {'mean_snr': 15,
               'sd_snr':10,
               'hop': 0.01,
               'window': 0.025,
               'sr': 16000,
               'tisv_frame': 180,
               'nfft': 512,
               'nmels': 40
               }
M = 4
N = 5

In [163]:
train_dataset = VoxCelebDatasest(train_paths, noise_folders, M, N)
train_loader = DataLoader(train_dataset, batch_size=M, shuffle=True, 
                          num_workers=1, drop_last=True) 

In [177]:
def train(train_loader, ge2e, ge2e_loss):
    ge2e.train()
    for batch_id, mel_db_batch in enumerate(train_loader):
        optimizer.zero_grad()
        mel_db_batch = torch.reshape(mel_db_batch, (5*batch_size, mel_db_batch.size(2), mel_db_batch.size(3)))
        embeddings = ge2e(mel_db_batch.cuda())
        loss = ge2e_loss(embeddings, M)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(ge2e.parameters(), 3.0)
        torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
        optimizer.step()
        print(batch_id, loss.item())
        if batch_id == 10:
            break

#### Define model and train

In [178]:
ge2e = GE2E(nlstm=3, dembed=40, dhid=768, dout=256).to(device)
ge2e_loss = GE2ELoss().to(device)
optimizer = torch.optim.SGD([
                {'params': ge2e.parameters()},
                {'params': ge2e_loss.parameters()}
            ], lr= 0.01)

In [179]:
train(train_loader, ge2e, ge2e_loss)

0 276.13323974609375
1 275.7585144042969
2 266.4223937988281
3 262.79620361328125
4 264.90338134765625
5 251.80618286132812
6 246.74217224121094
7 244.52845764160156
8 244.03240966796875
9 239.78164672851562
10 238.54244995117188


#### Testing and noise generation

In [180]:
#noise_paths = ["./dataset/QUT-NOISE/" + file for file in os.listdir("./dataset/QUT-NOISE/") if file.endswith(".wav")]
# for noise_path in tqdm(noise_paths):
#     cur_noise_dir = noise_path.replace(".wav", "").replace("-1", "").replace("-2", "")
#     cur_noise_dir_split = cur_noise_dir.split("/")
#     cur_noise_dir_split.insert(-1, "split_noises")
#     cur_noise_dir = "/".join(cur_noise_dir_split) + "/"
#     if not os.path.exists(cur_noise_dir):
#         os.makedirs(cur_noise_dir)
#     wpath = cur_noise_dir + noise_path.split("/")[-1].replace(".wav", "")
#     os.system("ffmpeg -i {} -f segment -segment_time 2 -c copy {}%05d.wav".format(noise_path, wpath))


In [181]:
# segment = random.randint(140, 180)
# sample_speaker = random.sample(list(train_paths), 1)[0]
# sample_file = random.sample(train_paths[sample_speaker], 1)[0]
# path = sample_speaker + sample_file
# meldb, mixed = get_meldb(path, noise_folders, audio_params, segment)
# ipd.Audio(mixed, rate=16000)