# Contrastive Loss
------------------------------

Constrastive loss: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

Siamese Network for one shot learning: https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf

### Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
sys.path.append('../')
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

### Configuration

In [3]:
from sv_system.utils.parser import set_train_config
import easydict

# datasets
# voxc1_fbank_xvector
# gcommand_fbank_xvector

args = easydict.EasyDict(dict(dataset="voxc1_fbank_xvector",
                              input_frames=100, splice_frames=[50, 100], stride_frames=1, input_format='fbank',
                              cuda=True,
                              lrs=[0.1, 0.01], lr_schedule=[20], seed=1337,
                              no_eer=False,
                              batch_size=128,
                              arch="tdnn_conv", loss="softmax",
                              n_epochs=50
                             ))
config = set_train_config(args)

### Dataset

In [4]:
import torch.utils.data as data

class embedDataset(data.Dataset):
    def __init__(self, embeds, labels):
        super().__init__()
        self.embeds = embeds
        self.labels = labels
        
    def __getitem__(self, index):
        
        return self.embeds[index], self.labels[index]
    
    def __len__(self):
        
        return self.embeds.shape[0]

In [5]:
def embedToDataset(embeds, key_df):
    labels = key_df.label.tolist()
    dataset = embedDataset(embeds, labels)
    
    return dataset, embeds.shape[1], len(key_df.label.unique())

In [6]:
trial = pd.read_pickle("../dataset/dataframes/voxc1/voxc_trial.pkl")

In [7]:
si_keys = pickle.load(open("../embeddings/voxc12/xvectors/xvectors_tdnn7b/train_feat/key.pkl", "rb"))
si_embeds = np.load("../embeddings/voxc12/xvectors/xvectors_tdnn7b/train_feat/feat.npy")

sv_keys = pickle.load(open("../embeddings/voxc12/xvectors/xvectors_tdnn7b/test_feat/key.pkl", "rb"))
sv_embeds = np.load("../embeddings/voxc12/xvectors/xvectors_tdnn7b/test_feat/feat.npy")

In [8]:
# voxc1_keys = embed_keys[embed_keys.origin == 'voxc1']

In [9]:
def key2df(keys):
    key_df = pd.DataFrame(keys, columns=['key'])
    key_df['spk'] = key_df.key.apply(lambda x: x.split("-")[0])
    key_df['label'] = key_df.groupby('spk').ngroup()
    key_df['origin'] = key_df.spk.apply(lambda x: 'voxc2' if x.startswith('id') else 'voxc1')
    
    return key_df

In [10]:
si_key_df = key2df(si_keys)
sv_key_df = key2df(sv_keys)

In [11]:
si_dataset, embed_dim, n_labels = embedToDataset(si_embeds, si_key_df)
sv_dataset, _, _ = embedToDataset(sv_embeds, sv_key_df)

### Batch Sampler

In [12]:
import math
import random
import itertools

def index_dataset(dataset):
    return {c : [example_idx for example_idx, (_, class_label_ind) in \
                 enumerate(zip(dataset.embeds, dataset.labels)) if class_label_ind == c] 
            for c in set(dataset.labels)}

def sample_from_class(images_by_class, class_label_ind):
    return images_by_class[class_label_ind][random.randrange(len(images_by_class[class_label_ind]))]

def simple(batch_size, dataset, class2img = None, prob_other = 0.5):
    '''lazy sampling, not like in lifted_struct. they add to the pool all postiive combinations, 
    then compute the average number of positive pairs per image,
    then sample for every image the same number of negative pairs'''
    
    if class2img is not None:
        images_by_class = class2img
    else:
        images_by_class = index_dataset(dataset)
        
    for batch_idx in range(int(math.ceil(len(dataset) * 1.0 / batch_size))):
        example_indices = []
        for i in range(0, batch_size, 2):
            perm = random.sample(images_by_class.keys(), 2)
            example_indices += [sample_from_class(images_by_class, perm[0]), 
                                sample_from_class(images_by_class, perm[0 if i == 0 or random.random() > prob_other else 1])]
        yield example_indices[:batch_size]

In [13]:
# use values to exclude unnecessary index
si_key_df['num_id'] = range(len(si_key_df))
si_class2idx = si_key_df.groupby('label').apply(lambda x: x.num_id.values).to_dict()

### Dataloader

In [14]:
adapt_sampler = lambda batch, dataset, sampler, **kwargs: \
type('', (torch.utils.data.sampler.Sampler,), 
     dict(__len__ = dataset.__len__, __iter__ = \
          lambda _: itertools.chain.from_iterable(sampler(batch, dataset, **kwargs))))(dataset)

In [15]:
import torch
from torch.utils.data.dataloader import DataLoader

n_pairs_per_batch = 128
batch_size = n_pairs_per_batch * 2

si_loader = torch.utils.data.DataLoader(si_dataset, 
                                       sampler = adapt_sampler(
                                           batch_size, si_dataset, simple, class2img=si_class2idx
                                       ), 
                                       num_workers = 8, batch_size = batch_size, 
                                       drop_last = True, pin_memory = True)
sv_loader = DataLoader(sv_dataset, batch_size=128, num_workers=2, shuffle=False)

### Model Define

In [16]:
import torch.nn as nn

class dda_model(nn.Module):
    def __init__(self, in_dims, n_labels):
        super().__init__()
        
        self.input_layer = nn.Sequential(
            nn.Linear(in_dims, 1*in_dims),
            nn.PReLU()
        )
        
        self.hidden_layer = nn.Sequential(
            nn.Linear(1*in_dims, 1*in_dims),
            nn.PReLU()
        )    
        self.hidden_batch = nn.BatchNorm1d(1*in_dims)
    
#         self.embedding_layer = nn.Linear(2*in_dims, n_labels)
        
    def embed(self, x):
        x = self.input_layer(x)
        x = self.hidden_layer(x)
        x = self.hidden_batch(x)
    
        return x
    
    def forward(self, x):           
        x = self.embed(x)
        
        return x

### Model Train

In [17]:
import torch.nn.functional as F

def embeds_utterance(config, val_dataloader, model):
#     val_iter = iter(val_dataloader)
    embeddings = []
    labels = []
    model.eval()

    with torch.no_grad():
        for batch in val_dataloader:
            X, y = batch
            if not config['no_cuda']:
                X = X.cuda()
                
            model_output = model.embed(X).cpu().detach()
            embeddings.append(model_output)
            labels.append(y.numpy())
        embeddings = torch.cat(embeddings)
        labels = np.hstack(labels)
    return embeddings, labels 

def sv_test(config, sv_loader, model, trial):
        embeddings, _ = embeds_utterance(config, sv_loader, model)
        sim_matrix = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
        cord = [trial.enrolment_id.tolist(), trial.test_id.tolist()]
        score_vector = sim_matrix[cord].numpy()
        label_vector = np.array(trial.label)
        fpr, tpr, thres = roc_curve(
                label_vector, score_vector, pos_label=1)
        eer = fpr[np.nanargmin(np.abs(fpr - (1 - tpr)))]

        return eer, label_vector, score_vector
    
def sv_euc_test(config, sv_loader, model, trial):
        embeddings, _ = embeds_utterance(config, sv_loader, model)
#         embeddings /= embeddings.norm(dim=1,keepdim=True)
        trial_enroll = embeddings[trial.enrolment_id.tolist()]
        trial_test = embeddings[trial.test_id.tolist()]

        dist = trial_enroll - trial_test
        print(dist)
        score_vector = -dist.norm(dim=1)
#         score_vector = -dist.pow(2).sum(1).sqrt()
        label_vector = np.array(trial.label)
        fpr, tpr, thres = roc_curve(
                label_vector, score_vector, pos_label=1)
        eer = fpr[np.nanargmin(np.abs(fpr - (1 - tpr)))]

        return eer, label_vector, score_vector

In [18]:
def constrastive_hard_mining(anchor, pos_egs, neg_egs, margin):
    pass


In [19]:
def constrastive_loss(n1, n2, label, margin):
    dist_square = (n1 - n2).pow(2).sum(1)
    dist_square = torch.clamp(dist_square, min=1e-16)
    dist = dist_square.sqrt()

    loss = torch.mean(
        (1.0-label)*dist_square + (label)*torch.pow(torch.clamp(margin-dist, min=0.0), 2)
    )
    
    return loss

In [20]:
import torch

def costrastive_train(model, loader, criterion , margin):
    model.train()
    loss_sum = 0
    n_corrects = 0
    total = 0
    for batch_idx, (X, y) in enumerate(loader):
        if not config['no_cuda']:
            X = X.cuda()
            y = y.cuda()

        optimizer.zero_grad()

        embeds = model(X)
#         embeds = embeds / embeds.norm(dim=1,keepdim=True)
        n1 = embeds[0:batch_size:2]
        n2 = embeds[1:batch_size:2]
        n1_y = y[0:batch_size:2]
        n2_y = y[1:batch_size:2]
        label = n1_y.eq(n2_y).float() # equal or diff
        loss = criterion(n1, n2, label, margin=margin)
        loss.backward()
        optimizer.step()
                        
        loss_sum += loss.item()
        total += y.size(0)
        if (batch_idx+1) % 1000 == 0:
            print("Batch {}/{}\t Loss {:.6f}" \
                  .format(batch_idx+1, len(loader), loss_sum / total))
    return loss_sum / total

In [21]:
model = dda_model(embed_dim, n_labels) 

In [22]:
if not config['no_cuda']:
    model = model.cuda()

In [23]:
from sv_system.train.train_utils import set_seed, find_optimizer
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR

config['lrs'] = [0.01]
_, optimizer = find_optimizer(config, model)

criterion = constrastive_loss
plateau_scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5)
step_scheduler = MultiStepLR(optimizer, [30], 0.1)

In [24]:
from sv_system.train.si_train import val
from sklearn.metrics import roc_curve
from tensorboardX import SummaryWriter

writer = SummaryWriter("logs/xvector_contrasitive_test.tf.log")

for epoch_idx in range(0, config['n_epochs']):
    print("-"*30)
    curr_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print("curr_lr: {}".format(curr_lr))

#     step_scheduler.step()    
    
#     train code
    train_loss = costrastive_train(model, si_loader, criterion, margin=0.4)
    print("epoch #{}, train loss: {}".format(epoch_idx, train_loss))
    writer.add_scalar("train/loss", train_loss, epoch_idx+1)

#     evaluate best_metric
    if not config['no_eer']:
        # eer validation code
        eer, label, score = sv_euc_test(config, sv_loader, model, trial)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))
        writer.add_scalar("sv_test/eer", eer, +1)
    plateau_scheduler.step(train_loss)

------------------------------
curr_lr: 0.01
Batch 1000/4989	 Loss 0.004673
Batch 2000/4989	 Loss 0.002435
Batch 3000/4989	 Loss 0.001687
Batch 4000/4989	 Loss 0.001313
epoch #0, train loss: 0.0010898308916805258
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0000,  0.0000,  ..., -0.0000,  0.0000, -0.0000]])
epoch #0, sv eer: 0.4104994143328719
------------------------------
curr_lr: 0.01
Batch 1000/4989	 Loss 0.000187
Batch 2000/4989	 Loss 0.000184
Batch 3000/4989	 Loss 0.000183
Batch 4000/4989	 Loss 0.000182
epoch #1, train loss: 0.00018160392993415812
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
        [-0.0000,  0.0000,  0.00

Process Process-97:
Process Process-93:
Process Process-92:
Process Process-98:
Process Process-95:
Process Process-96:
Process Process-94:
Process Process-91:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py",

KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt


Traceback (most recent call last):
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-24-857ea17f75dd>", line 15, in <module>
    train_loss = costrastive_train(model, si_loader, criterion, margin=0.4)
  File "<ipython-input-20-2dd2f2d0e60f>", line 8, in costrastive_train
    for batch_idx, (X, y) in enumerate(loader):
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 286, in __next__
    return self._process_next_batch(batch)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 305, in _process_next_batch
    self._put_indices()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 295, in _put_indices
    indices = next(self.sample_iter, None)
  File "/opt/conda/envs/pytorch-py3.6/lib

Exception in thread Thread-14:
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 71, in _worker_manager_loop
    r = in_queue.get()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/resource_s

KeyboardInterrupt: 

In [None]:
torch.save(dda_net.state_dict(), open("temp_dda_net.pt", "wb"))