#### Imports

In [2]:
import os
import errno
import random
import shutil
import librosa
import scipy
import webrtcvad
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
%matplotlib inline

from collections import defaultdict
from tqdm import tqdm_notebook, tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

In [3]:
vad = webrtcvad.Vad(3)

#### Methods

In [4]:
def path_hierarchy(path: str) -> dict:
    """Computes all paths in a recursive manner to a dictionary"""
    hierarchy = {
        'type': 'folder',
        'name': os.path.basename(path),
        'path': path,
    }

    try:
        hierarchy['children'] = [
            path_hierarchy(os.path.join(path, contents))
            for contents in os.listdir(path)
        ]
    except OSError as e:
        if e.errno != errno.ENOTDIR:
            raise
        hierarchy['type'] = 'file'

    return hierarchy


def split_speaker_wav(level):
    """Splits path into speaker path and corresponding wav file"""
    root = "/".join(level['path'].split("/")[:-2]) + "/"
    filename = "/".join(level['path'].split("/")[-2:])
    return root, filename

def get_speaker_wav_dict(root):
    """Returns a dictionary with speakers paths as keys and corresponding wavs as values"""
    data_paths = defaultdict(list)
    for level_a in path_hierarchy(root)['children']:
        for level_b in level_a['children']:
            for level_c in level_b['children']:
                if level_c['type'] == 'folder':
                    for level_d in level_c['children']:
                        root, filename = split_speaker_wav(level_d)
                        data_paths[root].append(filename)
                else:
                    root, filename = split_speaker_wav(level_c)
                    data_paths[root].append(filename)
    return data_paths

In [5]:
def sample_noise_path(noise_folders):
    """Randomly samples a noise wav file"""
    sample_folder = random.choice(noise_folders)
    sample_files = os.listdir(sample_folder)
    sample_filepath = sample_folder + random.choice(sample_files)
    return sample_filepath

In [6]:
def get_failed_paths(train_paths):
    """Return files that cannot be used for training"""
    failed = []
    for train_path in tqdm(train_paths):
        db_dir = train_path.replace("wav", "meldb")
        for wav_path in train_paths[train_path]:
            npy_path = wav_path.replace('.wav', '.npy')
            db_path = db_dir + npy_path
            audio_path = train_path + wav_path
            try:
                mel_db = np.load(db_path)
                if mel_db.shape[0] < 180:
                    failed.append(wav_path)
            except Exception as e:
                failed.append(wav_path)
    return failed

#### Speech processing

In [7]:
def blend(signal, noise, target_snr=10):
    """Blends audio signal with noise with target SNR"""
    if len(signal) < len(noise):
        noise = noise[:len(signal)]
    elif len(noise) < len(signal):
        signal = signal[:len(noise)]

    if target_snr == np.inf:
        scaler, prescaler = 0, 1
    elif target_snr == -np.inf:
        scaler, prescaler = 1, 0
    else:
        signal_power = np.sum(signal**2)
        noise_power = np.sum(noise**2)
        scaler = np.sqrt( signal_power / (noise_power * 10.**(target_snr/10.)) )
        prescaler = 1

    return prescaler * signal + scaler * noise

In [8]:
def get_speech(x, sample_rate, frame_duration, hop_duration):
    """Removes silences from speech"""
    frame_length = int(frame_duration * sample_rate)
    hop_length = int(hop_duration * sample_rate)
    frames = scipy.array([x[i:i+frame_length] for i in range(0, len(x) - frame_length, hop_length)])
    is_speech = [vad.is_speech(frame.tobytes(), sample_rate) for frame in frames]
    speech_indices = np.nonzero(is_speech)
    speech_only = frames[speech_indices].flatten()
    return speech_only

def get_meldb(path, noise_folders, audio_params):
    """Returns 40 filter banks"""
    x, _ = librosa.core.load(path, sr=audio_params['sr'])
    if random.choice([True, True, False]):
        noise_path = sample_noise_path(noise_folders)
        noise, _ = librosa.core.load(noise_path, sr=audio_params['sr'])
        mean_snr = random.choice([7.5, 10 , 12.5, 15, 17.5, 20])
        std_snr = random.uniform(0.1, 1)
        snr = np.random.normal(mean_snr, std_snr)
        x = blend(x, noise, target_snr=snr)
    x = get_speech(x, _, 0.01, 0.01)
    window_length = int(audio_params['window'] * audio_params['sr'])
    hop_length = int(audio_params['hop'] * audio_params['sr'])
    spec = librosa.stft(x, n_fft=audio_params['nfft'],
                        hop_length=hop_length, 
                        win_length=window_length)
    mag_spec = np.abs(spec)
    mel_basis = librosa.filters.mel(audio_params['sr'], audio_params['nfft'],
                                    n_mels=audio_params['nmels'])
    mel_spec = np.dot(mel_basis, mag_spec)
    mel_db = librosa.amplitude_to_db(mel_spec).T
    return mel_db

#### Paths

In [9]:
train_root = "./dataset/dev/wav"
test_root = "./dataset/test/wav"
noise_root = "./dataset/QUT-NOISE/split_noises/"

In [10]:
train_paths = get_speaker_wav_dict(train_root)
test_paths = get_speaker_wav_dict(test_root)
noise_folders = [noise_root + file + "/" for file in os.listdir(noise_root)]

In [11]:
print("Total speakers in train: ", len(train_paths))
print("Total speakers in test: ", len(test_paths))
print("Total files in train: ", sum([len(train_paths[root]) for root in train_paths]))
print("Total files in test: ", sum([len(test_paths[root]) for root in test_paths]))

Total speakers in train:  1211
Total speakers in test:  40
Total files in train:  148642
Total files in test:  4874


### Precompute features for training speedup 

In [13]:
# audio_params = {'mean_snr': 15,
#                'sd_snr':10,
#                'hop': 0.01,
#                'window': 0.025,
#                'sr': 16000,
#                'nfft': 512,
#                'nmels': 40
#                }

# failed = []
# for train_path in tqdm(train_paths):
#     db_dir = train_path.replace("wav", "meldb")
#     for wav_path in train_paths[train_path]:
#         npy_path = wav_path.replace('.wav', '.npy')
#         db_path = db_dir + npy_path
#         audio_path = train_path + wav_path
#         os.makedirs(os.path.dirname(db_path), exist_ok=True)
#         db = get_meldb(audio_path, noise_folders, audio_params)
#         if db.shape[0] >= 140:
#             np.save(db_path, db)
#         else:
#             failed.append(wav_path)

#### Dataset

In [14]:
class VoxCelebDatasest(Dataset):
    def __init__(self, train_paths, noise_folders, failed, M, N, training=True):
        self.batch_size = M
        self.failed = set(failed)
        self.count = 0
        self.segment_size = None
        if training:
            self.paths = train_paths
            self.noise_folders = noise_folders
            self.N = N
        self.speakers = list(train_paths)
        random.shuffle(self.speakers)
        
    def __len__(self):
        return len(self.speakers)
    
    def __getitem__(self, idx):
        speaker = self.speakers[idx]
        wav_files = self.paths[speaker]
        wav_files = [file for file in wav_files if file not in failed]
        wav_files = random.sample(wav_files, self.N)
        
        self.keep_segment_size()
        mel_dbs = []
        for f in wav_files:
            mel_db = np.load(speaker.replace("/wav/", "/meldb/") + f.replace(".wav", ".npy"))
            self.segment_size = min(self.segment_size, len(mel_db))
            mel_dbs.append(mel_db)
            
        for idx, mel_db in enumerate(mel_dbs):
            last = len(mel_db) - self.segment_size
            beg = random.randint(0, last)
            mel_db = mel_db[beg:beg + self.segment_size]
            mel_dbs[idx] = mel_db
        
        return torch.Tensor(mel_dbs)
    
    def keep_segment_size(self):
        self.count += 1
        if self.count % (self.batch_size + 1) == 0:
            self.count = 1
        if self.count == 1:
            self.segment_size = random.randint(140, 180)

#### Model

In [15]:
class GE2EModel(nn.Module):
    def __init__(self, nlstm, dembed, dhid, dout, dropout=0.):
        super(GE2EModel, self).__init__()
        self.dropout = dropout
        self.lstm = nn.LSTM(dembed, dhid, num_layers=nlstm, batch_first=True)
        self.linear = nn.Linear(dhid, dout)
    def forward(self, inp):
        lstm_out, _ = self.lstm(inp)
        out = self.linear(lstm_out[:, -1, :])
        out_norm = out / torch.norm(out, p=2, dim=0)
        return out_norm
            
        
class GE2ELoss(nn.Module):
    def __init__(self, M):
        super(GE2ELoss, self).__init__()
        self.w = nn.Parameter(torch.tensor(10.0).to(device), requires_grad=True)
        self.b = nn.Parameter(torch.tensor(-5.0).to(device), requires_grad=True)
        self.M = M
        self.device = device
    def forward(self, k):
        groups = torch.split(k, self.M, dim=0)
        centroids = torch.stack([torch.mean(group, dim=0) for group in groups])
        sims = []
        for idx, item in enumerate(k.repeat(N, 1, 1).transpose(1, 0)):
            cur_centroid_idx = idx // self.M
            for_stability = (centroids[cur_centroid_idx] - item[0]) * self.M / (self.M - 1)
            updated_centroids = centroids.clone()
            updated_centroids[cur_centroid_idx] = for_stability
            sims.append(F.cosine_similarity(item, updated_centroids, dim=-1, eps=1e-8))
        cossim = torch.stack(sims)
        sim_matrix = self.w * cossim + self.b
        loss = self.compute_loss(sim_matrix)
        return loss
    def compute_loss(self, sim_matrix):
        loss = 0 
        for idx, group in enumerate(torch.split(sim_matrix, self.M, dim=0)):
            for utterance_sims in group:
                cur = utterance_sims[idx]
                rest = torch.cat([utterance_sims[:idx], utterance_sims[idx+1:]])
                loss += - cur + torch.log(torch.sum(torch.exp(rest)))
        return loss

### Training methods

In [16]:
def save_model(PATH, epoch, ge2e_model, ge2e_loss, optimizer, loss):
    torch.save({
            'epoch': epoch,
            'ge2e_model_state_dict': ge2e_model.state_dict(),
            'ge2e_loss_state_dict': ge2e_loss.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, PATH)
    print("SAVED weights! @ epoch {}".format(epoch))
    
def load_model(PATH, ge2e_model, ge2e_loss, optimizer):
    checkpoint = torch.load(PATH)
    ge2e_model.load_state_dict(checkpoint['ge2e_model_state_dict'])
    ge2e_loss.load_state_dict(checkpoint['ge2e_loss_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return ge2e_model, ge2e_loss, optimizer

In [37]:
def train_epoch(train_loader, ge2e_model, ge2e_loss, optimizer, epoch, batch_size):
    shutil.rmtree("runs")
    writer = SummaryWriter()
    ge2e_model.train()
    total_loss = 0
    with SummaryWriter() as writer:
        for batch_id, mel_db_batch in enumerate(train_loader):
            optimizer.zero_grad()
            num_audio_per_batch = mel_db_batch.size(0) * mel_db_batch.size(1)
            mel_db_batch = torch.reshape(mel_db_batch, 
                                         (num_audio_per_batch, 
                                          mel_db_batch.size(2), 
                                          mel_db_batch.size(3)))
            
            perm = random.sample(range(0, num_audio_per_batch), num_audio_per_batch)
            unperm = list(perm)
            for idx, org_idx in enumerate(perm):
                unperm[org_idx] = idx
            
            mel_db_batch = mel_db_batch[perm]
            embeddings = ge2e_model(mel_db_batch.cuda())
            embeddings = embeddings[unperm]
            loss = ge2e_loss(embeddings)
            loss.backward()
            
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(ge2e_model.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            
            # Scale gradients
            for name, param in ge2e_model.named_parameters(): 
                if 'linear' in name:
                    param.grad *= 0.5
            for name, param in ge2e_loss.named_parameters():
                param.grad *= 0.01
                
            optimizer.step()
            total_loss += loss.item()
            avg_loss = total_loss / (batch_id + 1)
            
            if (batch_id + 1) % 5 == 0:
                for name, param in ge2e_model.named_parameters():          
                    writer.add_histogram('ge2e_model.' + name, param, batch_id + 1)
                for name, param in ge2e_loss.named_parameters():          
                    writer.add_histogram('ge2e_loss.' + name, param, batch_id + 1)
                print("epoch {:3d}, batch {:3d}, loss {:3.2f}, mean_loss {:3.2f}".format(epoch, batch_id + 1, loss.item(), avg_loss))
            writer.add_scalar('epoch_{}/avg_loss'.format(epoch), avg_loss, batch_id + 1)
            writer.add_scalar('epoch_{}/batch_loss'.format(epoch), loss.item(), batch_id + 1)
    
    print("epoch {}, loss {}".format(epoch, avg_loss))
    os.makedirs('weights', exist_ok=True)
    PATH = "weights/epoch_{}_param_dict".format(epoch)
    save_model(PATH, epoch, ge2e_model, ge2e_loss, optimizer, avg_loss)

In [19]:
def train(train_loader, nepoch, batch_size, saved_epoch=None, device='cuda'):
    ge2e_model = GE2EModel(nlstm=3, dembed=40, dhid=768, dout=256).to(device)
    ge2e_loss = GE2ELoss(M).to(device)
    
    optimizer = torch.optim.SGD([
            {'params': ge2e_model.parameters()},
            {'params': ge2e_loss.parameters()}
            ], lr= 0.01)
    
    if saved_epoch:
        PATH = "weights/epoch_{}_param_dict".format(saved_epoch)
        ge2e_model, ge2e_loss, optimizer = load_model(PATH, ge2e_model, ge2e_loss, optimizer)
        beg_epoch = saved_epoch
        
    else:
        beg_epoch = 0
    
    with SummaryWriter(comment='GE2EModel') as w:
        w.add_graph(GE2EModel(nlstm=3, dembed=40, dhid=768, dout=256), torch.randn(20, 140, 40), False)
    
    for epoch in range(beg_epoch, beg_epoch + nepoch):
        train_epoch(train_loader, ge2e_model, ge2e_loss, optimizer, epoch + 1, batch_size)
        
    return ge2e_model, ge2e_loss, optimizer

In [20]:
def eval(test_loader, ge2e_model, ge2e_loss, epoch, device='cuda'):
    ge2e_model.eval()
    for batch_id, mel_db_batch in test_loader:
        with torch.no_grad():
            pass

#### Data params

In [21]:
audio_params = {'mean_snr': 15,
               'sd_snr':10,
               'hop': 0.01,
               'window': 0.025,
               'sr': 16000,
               'nfft': 512,
               'nmels': 40
               }
M = 4 # Speakers
N = 5 # Utterances
device = 'cuda'

In [22]:
failed = get_failed_paths(train_paths)
train_dataset = VoxCelebDatasest(train_paths, noise_folders, failed, M, N)
train_loader = DataLoader(train_dataset, batch_size=M, shuffle=True, 
                          num_workers=4, drop_last=True) 

100%|██████████| 1211/1211 [01:29<00:00, 13.48it/s]


#### Train

In [39]:
nepoch, batch_size  = 1, M
ge2e_model, ge2e_loss, optimizer = train(train_loader, nepoch, batch_size)

epoch   1, batch   5, loss 270.04, mean_loss 274.65
epoch   1, batch  10, loss 256.41, mean_loss 269.01
epoch   1, batch  15, loss 244.48, mean_loss 262.33
epoch   1, batch  20, loss 236.50, mean_loss 256.45
epoch   1, batch  25, loss 230.41, mean_loss 251.49
epoch   1, batch  30, loss 222.50, mean_loss 247.42
epoch   1, batch  35, loss 225.36, mean_loss 244.06
epoch   1, batch  40, loss 208.40, mean_loss 240.89
epoch   1, batch  45, loss 207.44, mean_loss 238.10
epoch   1, batch  50, loss 213.26, mean_loss 235.65
epoch   1, batch  55, loss 222.66, mean_loss 233.88
epoch   1, batch  60, loss 225.68, mean_loss 232.62
epoch   1, batch  65, loss 209.42, mean_loss 230.77
epoch   1, batch  70, loss 196.86, mean_loss 229.17
epoch   1, batch  75, loss 215.31, mean_loss 227.82
epoch   1, batch  80, loss 218.56, mean_loss 226.74
epoch   1, batch  85, loss 195.11, mean_loss 225.59
epoch   1, batch  90, loss 180.02, mean_loss 223.89
epoch   1, batch  95, loss 205.41, mean_loss 222.76
epoch   1, b

#### Testing and noise generation

In [180]:
#noise_paths = ["./dataset/QUT-NOISE/" + file for file in os.listdir("./dataset/QUT-NOISE/") if file.endswith(".wav")]
# for noise_path in tqdm(noise_paths):
#     cur_noise_dir = noise_path.replace(".wav", "").replace("-1", "").replace("-2", "")
#     cur_noise_dir_split = cur_noise_dir.split("/")
#     cur_noise_dir_split.insert(-1, "split_noises")
#     cur_noise_dir = "/".join(cur_noise_dir_split) + "/"
#     if not os.path.exists(cur_noise_dir):
#         os.makedirs(cur_noise_dir)
#     wpath = cur_noise_dir + noise_path.split("/")[-1].replace(".wav", "")
#     os.system("ffmpeg -i {} -f segment -segment_time 2 -c copy {}%05d.wav".format(noise_path, wpath))


In [181]:
# segment = random.randint(140, 180)
# sample_speaker = random.sample(list(train_paths), 1)[0]
# sample_file = random.sample(train_paths[sample_speaker], 1)[0]
# path = sample_speaker + sample_file
# meldb, mixed = get_meldb(path, noise_folders, audio_params, segment)
# ipd.Audio(mixed, rate=16000)