# DDA (Deep Discriminant Analysis)

putting dvectors through additional network

In [6]:
import pandas as pd
import pickle
import numpy as np
%pylab
%matplotlib inline

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
import sys
sys.path.append('../')

In [33]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="3"

In [4]:
trial = pd.read_pickle("../dataset/dataframes/voxc1/voxc_trial.pkl")

### Load embeddings

In [5]:
embedding_dict = pickle.load(open("../best_models/voxc1/ResNet34_v4_softmax/voxc_train_dvectors.pkl", 
                                  "rb"))

In [6]:
embeds = np.array([v for v in embedding_dict.values()])

In [7]:
spks = [k.split("/")[0] for k in embedding_dict.keys()]
spk2label = pd.Series(spks).unique().tolist()
labels = [spk2label.index(spk) for spk in spks]

In [8]:
test_embedding_dict = pickle.load(open("../best_models/voxc1/ResNet34_v4_softmax/voxc_test_dvectors.pkl", 
                                  "rb"))
test_embeds = np.array([v for v in test_embedding_dict.values()])
test_spks = [k.split("/")[0] for k in test_embedding_dict.keys()]
test_spk2label = pd.Series(test_spks).unique().tolist()
test_labels = [test_spk2label.index(spk) for spk in test_spks]

### Dataset and Dataloader

In [10]:
dataset = embedDataset(embeds, labels)
test_dataset = embedDataset(test_embeds, test_labels)

### Model Define

In [11]:
import torch                                                                                                                                                                                                                                                                 
import torch.nn as nn                                                                                                                                                                                                                                                        

class CenterLoss(nn.Module):                                                                                                                                                                                                                                                 
    """Center loss.                                                                                                                                                                                                                                                          

    Reference:                                                                                                                                                                                                                                                               
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.                                                                                                                                                                              

    Args:                                                                                                                                                                                                                                                                    
     num_classes (int): number of classes.                                                                                                                                                                                                                                
     feat_dim (int): feature dimension.                                                                                                                                                                                                                                   
    """                                                                                                                                                                                                                                                                      
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):                                                                                                                                                                                                            
        super(CenterLoss, self).__init__()                                                                                                                                                                                                                                   
        self.num_classes = num_classes                                                                                                                                                                                                                                       
        self.feat_dim = feat_dim                                                                                                                                                                                                                                             
        self.use_gpu = use_gpu                                                                                                                                                                                                                                               

        if self.use_gpu:                                                                                                                                                                                                                                                     
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())                                                                                                                                                                                 
        else:                                                                                                                                                                                                                                                                
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))                                                                                                                                                                                        

    def forward(self, x, labels):                                                                                                                                                                                                                                            
        """                                                                                                                                                                                                                                                                  
        Args:                                                                                                                                                                                                                                                                
         x: feature matrix with shape (batch_size, feat_dim).                                                                                                                                                                                                             
         labels: ground truth labels with shape (num_classes).                                                                                                                                                                                                            
        """                                                                                                                                                                                                                                                                  
        batch_size = x.size(0)                                                                                                                                                                                                                                               
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
            torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()                                                                                                                                                               
        distmat.addmm_(1, -2, x, self.centers.t())                                                                                                                                                                                                                           

        classes = torch.arange(self.num_classes).long()                                                                                                                                                                                                                      
        if self.use_gpu: classes = classes.cuda()                                                                                                                                                                                                                            
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)                                                                                                                                                                                                    
        mask = labels.eq(classes.expand(batch_size, self.num_classes))                                                                                                                                                                                                       

        dist = []                                                                                                                                                                                                                                                            
        for i in range(batch_size):                                                                                                                                                                                                                                          
            value = distmat[i][mask[i]]                                                                                                                                                                                                                                      
            value = value.clamp(min=1e-12, max=1e+12) # for numerical stability                                                                                                                                                                                              
            dist.append(value)                                                                                                                                                                                                                                               
        dist = torch.cat(dist)                                                                                                                                                                                                                                               
        loss = dist.mean()                                                                                                                                                                                                                                                   

        return loss

In [12]:
import torch.nn as nn

class dda_model(nn.Module):
    def __init__(self, in_dims, n_labels):
        super().__init__()
        self.input_layer = nn.Sequential(
            nn.Linear(in_dims, 2*in_dims),
            nn.ReLU()
        )
        
        self.hidden_layer = nn.Sequential(
            nn.Linear(2*in_dims, 2*in_dims),
            nn.ReLU(),
        )    
#         self.hidden_batch = nn.BatchNorm2d(1)
    
        self.embedding_layer = nn.Linear(2*in_dims, n_labels)
    
    def forward(self, x):           
        x = self.input_layer(x)
        feat = self.hidden_layer(x)
        out = self.embedding_layer(feat)
        
        return feat, out
    
    def embed(self, x):
        x = self.input_layer(x)
        x = self.hidden_layer(x)
    
        return x

In [17]:
dda_net = dda_model(embeds.shape[1], len(spk2label)) 

In [18]:
dda_net

dda_model(
  (input_layer): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
  )
  (hidden_layer): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
  )
  (embedding_layer): Linear(in_features=256, out_features=1211, bias=True)
)

### Model train

In [19]:
import torch

is_cuda = True
weight_cent = 0.01
criterion_xent = nn.CrossEntropyLoss()
criterion_cent = CenterLoss(num_classes=len(spk2label), feat_dim=256, use_gpu=is_cuda)
optimizer_model = torch.optim.SGD(dda_net.parameters(), lr=0.001, weight_decay=5e-04, momentum=0.9)
optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=0.5)

dataloader = DataLoader(dataset, batch_size=128, num_workers=4, shuffle=True)

In [34]:
dda_net.train()

if is_cuda:
    dda_net = dda_net.cuda()

for epoch_idx in range(20):
    print(f"epoch: {epoch_idx}")
    loss_sum = 0
    xent_loss_sum = 0
    cent_loss_sum = 0
    n_corrects = 0
    for batch_idx, (X, y) in enumerate(dataloader):
        if is_cuda:
            X = X.cuda()
            y = y.cuda()

        feats, outs  = dda_net(X)
        loss_xent = criterion_xent(outs, y)
        loss_cent = criterion_cent(feats, y)
        loss_cent *= weight_cent
        loss = loss_xent + loss_cent

        optimizer_model.zero_grad()
        optimizer_centloss.zero_grad()
        loss.backward()
        optimizer_model.step()
        for param in criterion_cent.parameters():                                                                                 
            param.grad.data *= (1. / weight_cent)                                                                            
        optimizer_centloss.step()                                                                                                 
                        
        loss_sum += loss.item()
        xent_loss_sum += loss_xent.item()
        cent_loss_sum += loss_cent.item()
        n_corrects += torch.sum(torch.eq(torch.argmax(outs, dim=1), y))
        
        if (batch_idx+1) % 100 == 0:
            print("Batch {}/{}\t Loss {:.6f} XentLoss {:.6f} CenterLoss {:.6f}" \
                  .format(batch_idx+1, len(dataloader), loss_sum /(batch_idx+1), 
                        xent_loss_sum/(batch_idx+1), 
                        cent_loss_sum/(batch_idx+1))
                 )

    acc = n_corrects.item() / embeds.shape[0]
    print(f"loss: {loss_sum}, acc:{acc}")

epoch: 0
Batch 100/1156	 Loss 1.347935 XentLoss 0.669932 CenterLoss 0.678003
Batch 200/1156	 Loss 1.356123 XentLoss 0.675990 CenterLoss 0.680133
Batch 300/1156	 Loss 1.349205 XentLoss 0.671279 CenterLoss 0.677925
Batch 400/1156	 Loss 1.348451 XentLoss 0.670388 CenterLoss 0.678062
Batch 500/1156	 Loss 1.349134 XentLoss 0.670966 CenterLoss 0.678168
Batch 600/1156	 Loss 1.350251 XentLoss 0.671806 CenterLoss 0.678444
Batch 700/1156	 Loss 1.349512 XentLoss 0.671183 CenterLoss 0.678329
Batch 800/1156	 Loss 1.347875 XentLoss 0.669868 CenterLoss 0.678006
Batch 900/1156	 Loss 1.345107 XentLoss 0.667614 CenterLoss 0.677493
Batch 1000/1156	 Loss 1.344864 XentLoss 0.667301 CenterLoss 0.677563
Batch 1100/1156	 Loss 1.344678 XentLoss 0.667193 CenterLoss 0.677485
loss: 1555.0148074626923, acc:0.9490789873931118
epoch: 1
Batch 100/1156	 Loss 1.323208 XentLoss 0.655143 CenterLoss 0.668064
Batch 200/1156	 Loss 1.326369 XentLoss 0.657923 CenterLoss 0.668446
Batch 300/1156	 Loss 1.324209 XentLoss 0.654893

Batch 300/1156	 Loss 1.160116 XentLoss 0.547176 CenterLoss 0.612940
Batch 400/1156	 Loss 1.162422 XentLoss 0.549147 CenterLoss 0.613275
Batch 500/1156	 Loss 1.164449 XentLoss 0.550701 CenterLoss 0.613748
Batch 600/1156	 Loss 1.163393 XentLoss 0.549537 CenterLoss 0.613856
Batch 700/1156	 Loss 1.162055 XentLoss 0.548262 CenterLoss 0.613793
Batch 800/1156	 Loss 1.161496 XentLoss 0.547691 CenterLoss 0.613805
Batch 900/1156	 Loss 1.162428 XentLoss 0.548287 CenterLoss 0.614141
Batch 1000/1156	 Loss 1.163695 XentLoss 0.548994 CenterLoss 0.614701
Batch 1100/1156	 Loss 1.163769 XentLoss 0.549248 CenterLoss 0.614521
loss: 1344.3152918815613, acc:0.9619089464967723
epoch: 11
Batch 100/1156	 Loss 1.149263 XentLoss 0.540494 CenterLoss 0.608768
Batch 200/1156	 Loss 1.150369 XentLoss 0.542444 CenterLoss 0.607924
Batch 300/1156	 Loss 1.148086 XentLoss 0.540227 CenterLoss 0.607859
Batch 400/1156	 Loss 1.148280 XentLoss 0.540011 CenterLoss 0.608269
Batch 500/1156	 Loss 1.149690 XentLoss 0.541157 CenterL

In [22]:
torch.save(dda_net.state_dict(), open("temp_dda_net.pt", "wb"))

### Extracting new embeddings

In [23]:
dda_net.load_state_dict(torch.load("temp_dda_net.pt"))

In [35]:
test_dataloader = DataLoader(test_dataset, batch_size=64, num_workers=1, shuffle=False)

In [36]:
new_embeds = []
for (X, y) in test_dataloader:
        if is_cuda:
            X = X.cuda()
        new_embed = dda_net.embed(X)
        new_embeds += [new_embed]

In [37]:
new_embed_tensor = torch.cat(new_embeds, dim=0)

In [38]:
import torch.nn.functional as F
sim_matrix = F.cosine_similarity(                                                                                                                                                                                                                                    
     new_embed_tensor.cpu().unsqueeze(1), new_embed_tensor.cpu().unsqueeze(0), dim=2)       

In [39]:
from sklearn.metrics import roc_curve

cord = [trial.enrolment_id.tolist(), trial.test_id.tolist()]                                                                                                                                                                                                         
score_vector = sim_matrix[cord].detach().numpy()                                                                                                                                                                                                                              
label_vector = np.array(trial.label)                                                                                                                                                                                                                                 
fpr, tpr, thres = roc_curve(                                                                                                                                                                                                                                         
     label_vector, score_vector, pos_label=1)                                                                                                                                                                                                                     
eer = fpr[np.nanargmin(np.abs(fpr - (1 - tpr)))]

In [40]:
eer

0.064263656692578