In [1]:
import os
import time

import numpy as np
import pandas as pd
from scipy.io import wavfile
from scipy import signal

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.transforms import Compose

from sklearn.metrics import confusion_matrix

import tensorboardX
from tqdm import tqdm

import matplotlib.pyplot as plt

In [2]:
import warnings
warnings.filterwarnings('ignore') # scipy throws future warnings on fft (known bug)

In [3]:
def wav2spectrogram(path, segment_len=3, window='hamming', Tw=25, Ts=10, 
                    pre_emphasis=0.97, alpha=0.99, return_onesided=False):
    """Load a .wav file and return the magnitude spectrogram of a random segment.

    Pipeline: pre-emphasis -> DC removal plus a tiny dither -> random crop of
    `segment_len` seconds -> STFT magnitude scaled by ``rate / 10``.

    Args:
        path: path to the .wav file.
        segment_len: length of the randomly selected segment, in seconds.
        window: window name passed to ``scipy.signal.spectrogram``.
        Tw: frame (window) length in milliseconds.
        Ts: frame shift (hop) in milliseconds.
        pre_emphasis: coefficient of the pre-emphasis filter ``y[n] = x[n] - c * x[n-1]``.
        alpha: pole of the DC-removal filter ``(1 - z^-1) / (1 - alpha * z^-1)``.
        return_onesided: forwarded to ``scipy.signal.spectrogram``; with the default
            False all ``nfft`` frequency bins are returned.

    Returns:
        ndarray of shape (nfft, n_frames), or (nfft // 2 + 1, n_frames) when
        ``return_onesided`` is True.

    Raises:
        ValueError: if ``wavfile.read`` fails with ValueError, or if the recording is
            shorter than `segment_len` seconds.
    """
    try:
        rate, samples = wavfile.read(path)
    except ValueError as err:
        # keep the offending path in the error instead of print + assert
        raise ValueError('could not read wav file: {}'.format(path)) from err

    ## parameters
    # frame length (samples)
    Nw = int(rate * Tw * 1e-3)
    # overlap between consecutive frames (samples): frame length minus frame shift
    Ns = int(rate * (Tw - Ts) * 1e-3)
    # FFT size: smallest power of two >= Nw
    nfft = 2 ** (Nw - 1).bit_length()

    # pre-emphasis filter
    samples = np.append(samples[0], samples[1:] - pre_emphasis * samples[:-1])

    # remove the DC component of the signal and add a small dither
    samples = signal.lfilter([1, -1], [1, -alpha], samples)
    dither = np.random.uniform(-1, 1, samples.shape)
    spow = np.std(samples)
    samples = samples + 1e-6 * spow * dither

    # segment selection: valid start offsets are 0..len - seg_samples (inclusive),
    # hence the +1 because randint's upper bound is exclusive
    seg_samples = int(segment_len * rate)
    n_starts = len(samples) - seg_samples + 1
    if n_starts < 1:
        raise ValueError('{} is shorter than segment_len={}s'.format(path, segment_len))
    start = np.random.randint(0, n_starts)
    samples = samples[start:start + seg_samples]

    # spectrogram
    _, _, spec = signal.spectrogram(samples, rate, window, Nw, Ns, nfft, 
                                    mode='magnitude', return_onesided=return_onesided)

    # empirical scale: rate / 10 (= 1600 at 16 kHz) makes these spectrograms match
    # the magnitudes of the reference ("paper") implementation
    spec *= rate / 10
    
    return spec

In [4]:
class IdentificationDatasetTrain(Dataset):
    """VoxCeleb1 identification data restricted to the verification-train speakers.

    Reads ``<path>/iden_split.txt`` (space-separated ``phase path`` rows, no header)
    and keeps every utterance whose speaker is not one of the verification-test
    identities id10270--id10309. Items are ``(label, spec)`` pairs. For ids
    id10001--id11251 the labels are re-indexed into the contiguous range [0, 1210].
    """

    # zero-based speaker indices (id1XXXX -> XXXX - 1) of the verification-test block,
    # as a half-open range [269, 309), i.e. id10270..id10309
    _TEST_LO, _TEST_HI = 269, 309
    # number of removed test identities
    _N_TEST = 40

    @staticmethod
    def _speaker_index(track_path):
        # 'id10003/L9_sh8msGV59/00001.wav' -> 2 (speaker numbering in paths starts at 1)
        return int(track_path.split('/')[0].replace('id1', '')) - 1

    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform

        split = pd.read_table(os.path.join(path, 'iden_split.txt'), sep=' ',
                              header=None, names=['phase', 'path'])
        split['label'] = split['path'].apply(self._speaker_index)

        # every speaker index 0..1250 except the verification-test block
        train_ids = np.setdiff1d(np.arange(1251), np.arange(self._TEST_LO, self._TEST_HI))
        keep = split['label'].isin(train_ids)
        self.dataset = split['path'][keep].reset_index(drop=True)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        track_path = self.dataset[idx]

        # CrossEntropyLoss needs labels in [0, num_classes); closing the gap left by
        # the 40 removed test speakers maps 309..1250 onto 269..1210 (1211 classes)
        label = self._speaker_index(track_path)
        if label >= self._TEST_HI:
            label -= self._N_TEST

        spec = wav2spectrogram(os.path.join(self.path, 'audio', track_path))
        if self.transform:
            spec = self.transform(spec)

        return label, spec

In [5]:
class Normalize(object):
    """Per-frequency mean-variance normalization of a single (Freq, Time) spectrogram.

    Each row (frequency bin) is shifted to zero mean and scaled to unit standard
    deviation over time. Statistics are computed per spectrogram, not per batch.
    """
    
    def __call__(self, spec):
        # keepdims=True yields (Freq, 1) statistics for any number of frequency bins;
        # the former reshape(512, 1) only worked for 512-bin (two-sided) spectrograms
        mu = spec.mean(axis=1, keepdims=True)
        sigma = spec.std(axis=1, keepdims=True)
        # a constant row has sigma == 0: map it to zeros instead of NaN/inf
        sigma = np.where(sigma == 0, 1.0, sigma)
        return (spec - mu) / sigma

class ToTensor(object):
    """Turn a 2-D (Freq, Time) numpy spectrogram into a float32 tensor (1, Freq, Time)."""

    def __call__(self, spec):
        # unpacking enforces a 2-D input; names avoid shadowing torch.nn.functional as F
        n_freq, n_frames = spec.shape
        # Conv2d expects an explicit channel axis
        with_channel = spec.reshape(1, n_freq, n_frames)
        # downcast (the spectrogram arrives as float64) before handing it to torch
        return torch.from_numpy(with_channel.astype(np.float32))

In [6]:
class VoiceNet(nn.Module):
    """VGG-M style CNN over (B, 1, Freq, Time) spectrograms.

    Output depends on ``forward``'s ``phase``:
      * 'train_iden' / 'eval_mining': ``forward_once``. In training mode it returns
        fc8 logits (fc7 -> bn7 -> relu -> fc8). In eval mode it returns the
        L2-normalized fc7 output, used for hard negative mining.
      * 'train_veri' / 'eval_veri': a pair of L2-normalized
        fc8(relu(bn7(fc7(x)))) embeddings in BOTH train and eval mode. The head
        that is fine-tuned for verification is therefore also the one evaluated.
    Any other phase raises ValueError.
    """

    def __init__(self, num_classes=2):
        super(VoiceNet, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=96, kernel_size=7, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        
        self.bn1 = nn.BatchNorm2d(num_features=96)
        self.bn2 = nn.BatchNorm2d(num_features=256)
        self.bn3 = nn.BatchNorm2d(num_features=256)
        self.bn4 = nn.BatchNorm2d(num_features=256)
        self.bn5 = nn.BatchNorm2d(num_features=256)
        self.bn6 = nn.BatchNorm2d(num_features=4096)
        self.bn7 = nn.BatchNorm1d(num_features=1024)
        
        self.relu = nn.ReLU()
        
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.mpool5 = nn.MaxPool2d(kernel_size=(5, 3), stride=(3, 2))
        
        # Conv2d with weights of size (H, 1) is identical to FC with H weights
        self.fc6 = nn.Conv2d(in_channels=256, out_channels=4096, kernel_size=(9, 1))
        self.fc7 = nn.Linear(in_features=4096, out_features=1024)
        self.fc8 = nn.Linear(in_features=1024, out_features=num_classes)

    def _features(self, x):
        """Conv trunk + fc6 + average pooling over time, flattened per sample.

        The result is (B, 4096 * H'), i.e. (B, 4096) when fc6 collapses the frequency
        axis to 1 (as it does for 512-bin inputs).
        """
        x = self.mpool1(self.relu(self.bn1(self.conv1(x))))
        x = self.mpool2(self.relu(self.bn2(self.conv2(x))))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.relu(self.bn5(self.conv5(x)))
        x = self.mpool5(x)
        x = self.relu(self.bn6(self.fc6(x)))
        # average over the whole (variable-length) time axis; the functional form
        # avoids registering a fresh AvgPool2d submodule on every forward call
        x = F.avg_pool2d(x, kernel_size=(1, x.size(3)))
        return x.view(x.size(0), -1)

    def _fc8_head(self, x):
        """fc7 -> bn7 -> relu -> fc8 on flattened features."""
        return self.fc8(self.relu(self.bn7(self.fc7(x))))

    def forward_once(self, x):
        """Identification / mining output for one batch (depends on ``self.training``)."""
        x = self._features(x)
        if self.training:
            # logits: no softmax needed, CrossEntropyLoss applies log-softmax itself
            return self._fc8_head(x)
        # eval: normalized fc7 output, used for hard negative mining
        return F.normalize(self.fc7(x))

    def _verification_embedding(self, x):
        """Unit-norm fc8 embedding used for both verification training and testing."""
        return F.normalize(self._fc8_head(self._features(x)))

    # phase: [training_iden, inference_negative_mining, training_siamese, verif_test]
    def forward(self, voice1, voice2=None, phase='train_iden'):
        if phase in ['train_iden', 'eval_mining']:
            return self.forward_once(voice1)

        if phase in ['train_veri', 'eval_veri']:
            return self._verification_embedding(voice1), self._verification_embedding(voice2)

        raise ValueError('unknown phase: {!r}'.format(phase))

In [7]:
# --- paths & run length ------------------------------------------------------------
DATASET_PATH = '/home/nvme/data/vc1/'
# a fresh, timestamped log directory per run
LOG_PATH = '/home/nvme/logs/VoxCeleb/_grad_test_gpu1_{}'.format(time.time())
EPOCH_NUM = 30

# --- batch size for identification pre-training ------------------------------------
# The reference code uses B = 100. With default cuDNN settings PyTorch ran out of CUDA
# memory already at B = 97, although B = 96 used only ~90.6% of GPU memory (see
# https://discuss.pytorch.org/t/lesser-memory-consumption-with-a-larger-batch-in-multi-gpu-setup/29087).
# With deterministic cuDNN algorithms enabled, B = 100 fits.
torch.backends.cudnn.deterministic = True
B = 100

# --- optimization --------------------------------------------------------------------
WEIGHT_DECAY = 5e-4
LR_INIT = 1e-2
LR_LAST = 1e-4
# StepLR factor: LR_INIT * gamma ** (EPOCH_NUM - 1) == LR_LAST
gamma = 10 ** (np.log10(LR_LAST / LR_INIT) / (EPOCH_NUM - 1))
MOMENTUM = 0.9

# --- reproducibility -------------------------------------------------------------------
# cudnn.deterministic alone does not make runs repeatable: segment selection, dither
# and mining all draw from the global numpy / torch RNGs
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- hardware --------------------------------------------------------------------------
# single source of truth for the GPU used by DEVICE and by set_device
GPU_ID = 1
DEVICE = 'cuda:{}'.format(GPU_ID)
torch.cuda.set_device(GPU_ID)
NUM_WORKERS = 4
TBoard = tensorboardX.SummaryWriter(log_dir=LOG_PATH)

In [8]:
# net = VoiceNet(num_classes=1211)
# net.to(DEVICE)

In [9]:
# transforms = Compose([
#     Normalize(),
#     ToTensor()
# ])

# trainset = IdentificationDatasetTrain(DATASET_PATH, transform=transforms)
# trainsetloader = torch.utils.data.DataLoader(trainset, batch_size=B, 
#                                              num_workers=NUM_WORKERS, shuffle=True)

# criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(net.parameters(), LR_INIT, MOMENTUM, weight_decay=WEIGHT_DECAY)
# lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

In [10]:
# for epoch_num in range(EPOCH_NUM):
#     lr_scheduler.step()
    
#     # train
#     net.train()
    
#     for iter_num, (labels, specs) in tqdm(enumerate(trainsetloader)):
#         optimizer.zero_grad()
#         labels, specs = labels.to(DEVICE), specs.to(DEVICE)
#         scores = net(specs, phase='train_iden')
#         loss = criterion(scores, labels)
#         loss.backward()
#         optimizer.step()
        
#         # TBoard
#         step_num = epoch_num * len(trainsetloader) + iter_num
#         TBoard.add_scalar('Metrics/train_loss', loss.item(), step_num)
#         TBoard.add_scalar('Metrics/lr', lr_scheduler.get_lr()[0], step_num)
        
# # when the training is finished save the model
# torch.save(net.state_dict(), os.path.join(LOG_PATH, 'model_snapshot_{}.txt'.format(time.time())))
# TBoard.close()

In [11]:
# del specs, labels, net
# torch.cuda.empty_cache()

In [12]:
# pretrained_dict = torch.load(os.path.join(LOG_PATH, 'model_snapshot_1542979501.519298.txt'))

# net = VoiceNet(num_classes=1211)
# net.to(DEVICE)

# model_dict = net.state_dict()

# # 1. filter out unnecessary keys
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# # 2. overwrite entries in the existing state dict
# model_dict.update(pretrained_dict) 
# # 3. load the new state dict
# net.load_state_dict(model_dict)

In [13]:
# Identification model trained earlier. It is loaded from an explicit path because
# LOG_PATH is re-timestamped on every run.
PRETRAINED_PATH = '/home/nvme/logs/VoxCeleb/verif_class/model_snapshot_1542979501.519298.txt'

net = VoiceNet(num_classes=1211)
net.to(DEVICE)

# map_location: without it torch.load restores tensors onto the device they were
# saved from, which need not be DEVICE
pretrained_dict = torch.load(PRETRAINED_PATH, map_location=DEVICE)
# strict load: every key must match the 1211-class VoiceNet
net.load_state_dict(pretrained_dict)

In [14]:
class VerificationDatasetTrain(Dataset):
    """Siamese training batches with online negative mining for speaker verification.

    The training identities (ids 1--1251 except the test ids 270--309) are split into
    ``len(ids) // batch_size`` groups of at least ``batch_size`` ids. Item ``idx``
    is a complete mini-batch built from group ``idx``:

    * positives: for every identity, two segments from two different tracks form an
      (anchor, positive) pair with label 1;
    * negatives: every anchor is paired with the positive of another identity (label 0).
      The first half of the anchors get a uniformly random negative. The second half
      get a hard negative chosen from the similarity ranking of ``model``'s
      'eval_mining' embeddings: the negative at rank ``round(tau * (B - 2))`` counted
      from the MOST similar one.

    Returns ``(labels, anchors, counterparts)`` of shapes (2B,), (2B, 1, F, T) and
    (2B, 1, F, T), where B is the group size. This assumes ``transform`` returns
    (1, F, T) tensors, as ToTensor does; the ``torch.cat`` calls need tensors.

    NOTE(review): ``model`` runs on ``device`` inside ``__getitem__``; the notebook uses
    this dataset with ``num_workers=0``.
    """

    def __init__(self, path, model, batch_size, device, transform=None, tau=0.1):
        self.path = path
        self.model = model
        self.transform = transform
        self.device = device
        # rank (as a fraction of the B - 2 candidates) of the hard negative
        self.tau = tau

        fullid_arr = np.arange(1, 1252)  # 1--1251
        testid_arr = np.arange(270, 310)  # 270--309
        trainid_arr = np.setdiff1d(fullid_arr, testid_arr)  # 1--1251 \ 270--309

        # mining needs at least one negative candidate per anchor
        if not 2 <= batch_size <= len(trainid_arr):
            raise ValueError('batch_size must be in [2, {}], got {}'.format(
                len(trainid_arr), batch_size))
        # `len(trainid_arr) // batch_size` groups, each holding >= batch_size ids
        self.splits = np.array_split(trainid_arr, len(trainid_arr) // batch_size)

    def __len__(self):
        return len(self.splits)

    def cosine_sim_matrix(self, tensor1, tensor2):
        """Return S with S[i, j] = cos(tensor1[i], tensor2[j]) for (N, D) and (M, D) inputs."""
        return F.normalize(tensor1, dim=1) @ F.normalize(tensor2, dim=1).t()

    def _sample_positive_pair(self, speaker_id):
        """Spectrograms of two random segments of one identity from two different tracks."""
        # folder layout: <id>/<track>/<segment>.wav, e.g. id10254/7gWzIy6yIIk/00001.wav
        full_id = 'id1{:04d}'.format(speaker_id)  # 265 -> id10265
        id_dir = os.path.join(self.path, 'audio', full_id)
        # two distinct tracks (videos), then one random segment from each
        track1, track2 = np.random.choice(os.listdir(id_dir), 2, replace=False)
        track1_dir = os.path.join(id_dir, track1)
        track2_dir = os.path.join(id_dir, track2)
        voice1_path = os.path.join(track1_dir, np.random.choice(os.listdir(track1_dir)))
        voice2_path = os.path.join(track2_dir, np.random.choice(os.listdir(track2_dir)))

        spec1 = wav2spectrogram(voice1_path)
        spec2 = wav2spectrogram(voice2_path)
        if self.transform:
            spec1 = self.transform(spec1)
            spec2 = self.transform(spec2)
        return spec1, spec2

    def _embed(self, specs):
        """'eval_mining' embeddings of a CPU batch; returned on the CPU."""
        was_training = self.model.training
        self.model.eval()
        try:
            # mining only ranks candidates: no autograd graph is needed
            with torch.no_grad():
                embeddings = self.model(specs.to(self.device), phase='eval_mining')
        finally:
            # restore the caller's mode instead of unconditionally forcing train()
            self.model.train(was_training)
        return embeddings.cpu()

    def _mine_negative_indices(self, sim_mat):
        """Pick one negative positive-index per anchor from an (B, B) anchor/positive similarity.

        Anchors [0, B // 2) get a uniformly random negative; anchors [B // 2, B) get the
        negative at rank ``round(tau * (B - 2))`` in order of decreasing similarity.
        """
        B = sim_mat.size(0)
        # most similar (hardest) candidates first
        _, order = sim_mat.sort(dim=1, descending=True)
        # drop the diagonal (each anchor's own positive): exactly one entry per row
        not_self = order != torch.arange(B).view(B, 1)
        order = order[not_self].view(B, B - 1)

        half = B // 2
        # HARD NEGATIVE MINING PART
        hard_col = round(self.tau * (B - 2))
        hard_idx = order[half:, hard_col]
        # RANDOM PART: a uniform column in [0, B - 2]
        rand_col = torch.from_numpy(np.random.uniform(size=(B, 1)) * (B - 1)).long()
        rand_idx = torch.gather(order, dim=1, index=rand_col)[:half]
        return torch.cat([rand_idx.view(-1), hard_idx.view(-1)])

    def __getitem__(self, idx):
        # shuffle the group so the order of identities (and hence which anchors get
        # random vs. hard negatives) changes between epochs
        ids = np.random.permutation(self.splits[idx])

        ## POSITIVE PART
        pairs = [self._sample_positive_pair(speaker_id) for speaker_id in ids]
        # (1, F, T) per item -> (B, 1, F, T)
        anchor_specs = torch.cat([anchor for anchor, _ in pairs]).unsqueeze(1)
        positive_specs = torch.cat([positive for _, positive in pairs]).unsqueeze(1)

        ## NEGATIVE PART
        sim_mat = self.cosine_sim_matrix(self._embed(anchor_specs), self._embed(positive_specs))
        torch.cuda.empty_cache()
        negative_specs = positive_specs[self._mine_negative_indices(sim_mat)]

        B = len(ids)
        labels = torch.cat([torch.ones(B), torch.zeros(B)])
        anchors_anchors_specs = torch.cat([anchor_specs, anchor_specs])
        positives_negatives_specs = torch.cat([positive_specs, negative_specs])
        return labels, anchors_anchors_specs, positives_negatives_specs

In [15]:
class VerificationDatasetTest(Dataset):
    """VoxCeleb1 verification trial list read from ``<path>/veri_test.txt``.

    Each line is ``label voice1 voice2`` separated by single spaces (no header).
    Item ``idx`` is ``(label, spec1, spec2)``, with the spectrograms computed by
    ``wav2spectrogram`` from the two files under ``<path>/audio``.
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform
        trials_file = os.path.join(self.path, 'veri_test.txt')
        self.dataset = pd.read_table(trials_file, sep=' ', header=None)

    def __len__(self):
        return len(self.dataset)

    def _audio_path(self, rel_path):
        return os.path.join(self.path, 'audio', rel_path)

    def __getitem__(self, idx):
        label, voice1, voice2 = self.dataset.iloc[idx]

        # both spectrograms are computed before any transform is applied
        specs = [wav2spectrogram(self._audio_path(voice)) for voice in (voice1, voice2)]
        if self.transform:
            specs = [self.transform(spec) for spec in specs]

        return label, specs[0], specs[1]

In [16]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss on L2-normalized embeddings.

        loss = mean( y * d^2 + (1 - y) * max(margin - d, 0)^2 )

    Here d is the ``nn.PairwiseDistance`` between the normalized anchor and
    counterpart, and y is 1 for same-speaker pairs and 0 otherwise.

    Diagnostics are written with ``writer.add_scalar(tag, value, step)``. If no writer
    is passed, a module-level ``TBoard`` is used when one exists; otherwise nothing
    is logged.
    """

    def __init__(self, margin, writer=None):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.writer = writer
        self.pdist = nn.PairwiseDistance()

    def forward(self, labels, anchors, counterparts, step_num=None):
        dists = self.pdist(F.normalize(anchors), F.normalize(counterparts))
        # accept (N,) as well as e.g. (1, N) labels from a batch_size=1 DataLoader
        labels = labels.reshape(-1).to(dists.dtype)
        pos_part = labels * dists ** 2
        neg_part = (1 - labels) * (self.margin - dists).clamp(min=0) ** 2
        loss = torch.mean(pos_part + neg_part)
        self._log(labels, dists, pos_part, neg_part, step_num)
        return loss

    def _log(self, labels, dists, pos_part, neg_part, step_num):
        """Write diagnostic scalars; a no-op when no writer is available."""
        # fall back to the notebook-level SummaryWriter so existing callers keep logging
        writer = self.writer if self.writer is not None else globals().get('TBoard')
        if writer is None:
            return
        # the two parts sum to the loss; neg_part is the clamped term that enters it
        writer.add_scalar('Metrics_verification/pos_part_mean', pos_part.mean().item(), step_num)
        writer.add_scalar('Metrics_verification/neg_part_mean', neg_part.mean().item(), step_num)
        writer.add_scalar('Metrics_verification/margin', self.margin, step_num)

        # mean distance over the pairs of each class only (not diluted by the other class)
        is_pos = labels > 0.5
        is_neg = labels < 0.5
        if is_pos.any():
            writer.add_scalar('Metrics_verification/pos_dist_mean', dists[is_pos].mean().item(), step_num)
        if is_neg.any():
            writer.add_scalar('Metrics_verification/neg_dist_mean', dists[is_neg].mean().item(), step_num)

In [17]:
def Cdet_min(y_true, pred_dists, threshold_step=1e-3, Cmiss=1, Cfa=1, Ptar=0.01,
             threshold_max=2):
    """Minimum detection cost over a grid of distance thresholds.

    A pair is predicted "same speaker" when its distance is strictly below the
    threshold. For each t in ``np.arange(0, threshold_max, threshold_step)``:

        Pfa(t)   = #{negatives with d < t} / #negatives
        Pmiss(t) = #{positives with d >= t} / #positives
        Cdet(t)  = Cmiss * Pmiss(t) * Ptar + Cfa * Pfa(t) * (1 - Ptar)

    Args:
        y_true: array-like, 1 for same-speaker pairs and 0 otherwise.
        pred_dists: array-like of pair distances, aligned with ``y_true``.
        threshold_step: spacing of the threshold grid.
        Cmiss, Cfa, Ptar: detection-cost parameters.
        threshold_max: exclusive end of the grid. The default 2 spans all Euclidean
            distances between unit-norm embeddings.

    Returns:
        ``(min Cdet, Pfas, Pmisses)``, the last two holding one value per threshold.

    Raises:
        ValueError: if ``y_true`` lacks either positive or negative pairs.
    """
    y_true = np.asarray(y_true).ravel()
    pred_dists = np.asarray(pred_dists, dtype=np.float64).ravel()
    is_target = y_true == 1
    pos_dists = np.sort(pred_dists[is_target])
    neg_dists = np.sort(pred_dists[~is_target])
    P, N = len(pos_dists), len(neg_dists)
    if P == 0 or N == 0:
        raise ValueError('y_true must contain both positive (1) and negative (0) pairs')

    thresholds = np.arange(0, threshold_max, threshold_step)
    # searchsorted(side='left') counts the distances strictly below each threshold,
    # i.e. the `pred_dists < threshold` rule, for all thresholds in one pass instead
    # of one confusion matrix per threshold
    FP = np.searchsorted(neg_dists, thresholds, side='left')
    FN = P - np.searchsorted(pos_dists, thresholds, side='left')
    Pfas = FP / N
    Pmisses = FN / P
    Cdets = Cmiss * Pmisses * Ptar + Cfa * Pfas * (1 - Ptar)

    return np.min(Cdets), Pfas, Pmisses


def EER(Pfas, Pmisses):
    """Equal error rate from aligned false-alarm and miss-rate arrays.

    Finds the threshold index where |Pfa - Pmiss| is smallest and returns the mean
    of the two rates there, e.g. Pfa = 0.114 and Pmiss = 0.112 give EER = 0.113.
    """
    closest = np.argmin(np.abs(Pfas - Pmisses))
    return (Pfas[closest] + Pmisses[closest]) / 2

In [18]:
# Freeze every pre-trained parameter, then attach a fresh fc8 producing a 1024-d
# verification embedding. The new layer is created after freezing, so its weights
# and bias are the only parameters left with requires_grad=True.
for p in net.parameters():
    p.requires_grad_(False)

EMBEDDING_DIM = 1024
net.fc8 = nn.Linear(in_features=net.fc8.in_features, out_features=EMBEDDING_DIM)
net.to(DEVICE);

In [19]:
# Two different "batch sizes" are involved in verification fine-tuning:
#   * B = number of identities per *dataset item*: every VerificationDatasetTrain item
#     is already a full mini-batch of about 2 * B pairs (positives + mined negatives),
#     built from one group of >= B identities;
#   * hence the DataLoader uses batch_size=1 and yields one such mini-batch per step
#     (its extra leading dimension is squeezed in the training loop).
# NOTE(review): num_workers=0, presumably because the training set runs `net` on
# DEVICE while building its items.
B = 30

transforms = Compose([Normalize(), ToTensor()])

trainset = VerificationDatasetTrain(DATASET_PATH, model=net, batch_size=B,
                                    device=DEVICE, transform=transforms)
trainsetloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True,
                                             num_workers=0)

testset = VerificationDatasetTest(DATASET_PATH, transform=transforms)
testsetloader = torch.utils.data.DataLoader(testset, batch_size=1, num_workers=0)

criterion = ContrastiveLoss(margin=1)
optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

In [20]:
for epoch_num in range(EPOCH_NUM):
    # NOTE(review): lr_scheduler is never stepped, so the LR stays at its initial value
    # for the whole run — confirm this is intended.

    ## TRAIN
    net.train()

    for iter_num, (labels, anchors, counterparts) in enumerate(tqdm(trainsetloader)):
        step_num = epoch_num * len(trainsetloader) + iter_num

        # each loader item is one pre-built mini-batch: drop the DataLoader's batch dim of 1
        labels = labels.squeeze(0).to(DEVICE)
        anchors = anchors.squeeze(0).to(DEVICE)
        counterparts = counterparts.squeeze(0).to(DEVICE)

        optimizer.zero_grad()
        anchors, counterparts = net(anchors, counterparts, phase='train_veri')
        loss = criterion(labels, anchors, counterparts, step_num)
        loss.backward()
        optimizer.step()

        # TBoard. The LR is read from the optimizer: StepLR.get_lr() is not meant to be
        # called outside step(). Tensor scalars are converted with .item().
        TBoard.add_scalar('Metrics_verification/train_loss', loss.item(), step_num)
        TBoard.add_scalar('Metrics_verification/lr', optimizer.param_groups[0]['lr'], step_num)
        TBoard.add_scalar('Metrics_verification/conv5', net.conv5.weight.mean().item(), step_num)
        TBoard.add_scalar('Metrics_verification/fc8', net.fc8.weight.mean().item(), step_num)

        torch.cuda.empty_cache()

    ## TEST
    net.eval()

    test_labels = []
    pred_dists = []
    # inference only: no autograd graph
    with torch.no_grad():
        for label, spec1, spec2 in tqdm(testsetloader):
            spec1, spec2 = spec1.to(DEVICE), spec2.to(DEVICE)
            emb1, emb2 = net(spec1, spec2, phase='eval_veri')
            test_labels.append(label.item())
            pred_dists.append(F.pairwise_distance(emb1, emb2).item())

    Cdetmin, Pfas, Pmisses = Cdet_min(np.array(test_labels), np.array(pred_dists))
    eer = EER(Pfas, Pmisses)

    TBoard.add_scalar('Metrics_verification/Cdet_min', Cdetmin, epoch_num)
    TBoard.add_scalar('Metrics_verification/EER', eer, epoch_num)

# when the training is finished save the model
torch.save(net.state_dict(), os.path.join(LOG_PATH, 'model_snapshot_{}.txt'.format(time.time())))
TBoard.close()

40it [00:26,  1.54it/s]
40it [00:26,  1.52it/s]
40it [00:26,  1.54it/s]
40it [00:26,  1.49it/s]
40it [00:26,  1.53it/s]
40it [00:26,  1.51it/s]
40it [00:26,  1.54it/s]
40it [00:26,  1.54it/s]
40it [00:26,  1.53it/s]
40it [00:26,  1.54it/s]
40it [00:26,  1.52it/s]
40it [00:26,  1.52it/s]
40it [00:26,  1.52it/s]
40it [00:26,  1.51it/s]
40it [00:26,  1.52it/s]
40it [00:26,  1.53it/s]
40it [00:26,  1.53it/s]
40it [00:26,  1.53it/s]
40it [00:26,  1.55it/s]
40it [00:26,  1.54it/s]
40it [00:26,  1.53it/s]
40it [00:26,  1.53it/s]
40it [00:26,  1.51it/s]
40it [00:26,  1.51it/s]
40it [00:26,  1.55it/s]
40it [00:26,  1.53it/s]
40it [00:26,  1.52it/s]
40it [00:26,  1.54it/s]
40it [00:26,  1.53it/s]
40it [00:26,  1.55it/s]
