Auto-Encoder
---------------------------------

According to Kyungmi's [paper](https://drive.google.com/file/d/1RArk7z4NqY5HkwkUWx4cR2ApZNnAQxdF/view?usp=sharing), an AE extracts more generalizable features. That said, the experiments there were on small images (28x28, 32x32) and the generalization discussed is across tasks, so the context is somewhat different.

So, as a first step, I'll implement a simple AE on top of x-vector features.

Center-Loss
----------------------

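What happens if we add a center loss to the Auto-Encoder's MSELoss so that the latent features of each class cluster more tightly?

Because the center loss needs class labels, this turns the setup from unsupervised into supervised.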
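A minimal sketch of the combined objective, assuming the same mean-squared reconstruction term and a mean squared-distance center term as the cells below (`combined_loss` is a hypothetical helper, not part of the notebook). The labels `y` enter only through the center term:

```python
import torch
import torch.nn.functional as F

def combined_loss(x, x_hat, z, y, centers, weight_cent=0.01):
    """Hypothetical helper: AE reconstruction loss plus weighted center loss.

    x, x_hat: (batch, input_dim) input and reconstruction
    z:        (batch, latent_dim) latent codes
    y:        (batch,) integer class labels
    centers:  (num_classes, latent_dim) class centers
    """
    recon = F.mse_loss(x_hat, x)                        # unsupervised term
    center = (z - centers[y]).pow(2).sum(dim=1).mean()  # supervised term: needs labels
    return recon + weight_cent * center

torch.manual_seed(0)
x = torch.randn(4, 8)
centers = torch.randn(3, 2)
y = torch.tensor([0, 1, 2, 1])

# Perfect reconstruction and latents sitting exactly on their centers -> zero loss
print(combined_loss(x, x.clone(), centers[y].clone(), y, centers).item())  # 0.0
```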

Faster
---------------

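With one centroid per training speaker there are too many centroids, so the computation is slow.
However, the only centroids actually used in a step are those of the classes present in the batch.

Let's find a way to compute only those efficiently without breaking the autograd graph.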
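One way to do this, sketched below on random tensors, is to gather only the rows of the center matrix that the batch labels point to (the `CenterLoss` cell below does this with `self.centers[labels]`). Integer indexing is differentiable, so the graph stays intact: gradients flow back into the full `centers` parameter, but only the gathered rows receive non-zero gradient. The full distance-matrix version is an assumed baseline for comparison, not code from this notebook.

```python
import torch

torch.manual_seed(0)
num_classes, feat_dim, batch = 1000, 16, 8
centers = torch.nn.Parameter(torch.randn(num_classes, feat_dim))
x = torch.randn(batch, feat_dim)
labels = torch.randint(0, num_classes, (batch,))

# Baseline: distances to every center, O(batch * num_classes), then keep each row's own class
with torch.no_grad():
    diff = x[:, None, :] - centers[None, :, :]        # (batch, num_classes, feat_dim)
    distmat = diff.pow(2).sum(dim=-1)                 # (batch, num_classes)
    loss_full = distmat[torch.arange(batch), labels].mean()

# Batch-only: gather just the centers that appear in the batch, O(batch)
loss_batch = (x - centers[labels]).pow(2).sum(dim=1).mean()
print(torch.allclose(loss_full, loss_batch))  # True

# Backprop still reaches the full parameter; only the used rows get gradient
loss_batch.backward()
touched = (centers.grad.abs().sum(dim=1) > 0).nonzero().flatten()
print(set(touched.tolist()) == set(labels.tolist()))  # True
```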

### Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
sys.path.append('../')
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

### AE Model

In [3]:
import torch
import torch.nn as nn

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(512, 400),
            nn.ReLU(True),
            nn.Linear(400, 300),
            nn.ReLU(True),
            nn.Linear(300, 256),
            nn.ReLU(True),
            nn.Linear(256, 128))
        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, 300),
            nn.ReLU(True),
            nn.Linear(300, 400),
            nn.ReLU(True),
            nn.Linear(400, 512),
            # Tanh bounds the reconstruction to [-1, 1], so input values outside
            # that range can never be reconstructed exactly under MSELoss
            nn.Tanh())
        
        self.latent_dim = 128

    def forward(self, x):
        latent = self.encoder(x)
        output = self.decoder(latent)
        return latent, output
    
    def embed(self, x):
        latent = self.encoder(x)
        return latent

### Center-Loss

In [4]:
import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): whether to allocate the centers on the GPU.
    """
    def __init__(self, num_classes, feat_dim, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size,).
        """
        # Gather only the centers of the classes present in the batch
        centers_for_batch = self.centers[labels]
        # Squared Euclidean distance per sample: ||x||^2 + ||c||^2 - 2 x.c
        distvec = torch.pow(x, 2).sum(dim=1) + torch.pow(centers_for_batch, 2).sum(dim=1)
        distvec += torch.sum(-2 * x * centers_for_batch, dim=1)
        # clamp() is out-of-place, so its result must be assigned back
        distvec = distvec.clamp(min=1e-12, max=1e+12)

        loss = distvec.mean()

        return loss
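`CenterLoss.forward` computes the squared distance through the expansion ||x - c||² = ||x||² + ||c||² - 2 x·c. The check below, on random tensors, confirms that the expansion matches direct subtraction, and illustrates that `Tensor.clamp` is out-of-place, so calling it without assigning the result changes nothing.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 128)
c = torch.randn(5, 128)

# Expanded form, as in CenterLoss.forward
expanded = x.pow(2).sum(dim=1) + c.pow(2).sum(dim=1) - 2 * (x * c).sum(dim=1)
# Direct form
direct = (x - c).pow(2).sum(dim=1)
print(torch.allclose(expanded, direct, atol=1e-3))  # True

# clamp() returns a new tensor; the original is untouched unless reassigned
d = torch.tensor([-1e-7, 3.0])
d.clamp(min=1e-12)
print(d[0].item() < 0)  # True: still negative
d = d.clamp(min=1e-12)  # reassign (or use the in-place d.clamp_(...))
print(d[0].item())      # now clamped to about 1e-12
```

The expansion can come out slightly negative from float rounding when x is very close to c, which is what the clamp guards against.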

### Dataset

In [5]:
import torch.utils.data as data

class embedDataset(data.Dataset):
    def __init__(self, embeds, labels):
        super().__init__()
        self.embeds = embeds
        self.labels = labels
        
    def __getitem__(self, index):
        
        return self.embeds[index], self.labels[index]
    
    def __len__(self):
        
        return self.embeds.shape[0]

def embedToDataset(embeds, key_df):
    labels = key_df.label.tolist()
    dataset = embedDataset(embeds, labels)
    
    return dataset, embeds.shape[1], len(key_df.label.unique())

def key2df(keys):
    key_df = pd.DataFrame(keys, columns=['key'])
    key_df['spk'] = key_df.key.apply(lambda x: x.split("-")[0])
    key_df['label'] = key_df.groupby('spk').ngroup()
    key_df['origin'] = key_df.spk.apply(lambda x: 'voxc2' if x.startswith('id') else 'voxc1')
    
    return key_df
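To see how `key2df` assigns labels, here is the same logic applied to a few hypothetical keys in `<speaker>-<video>-<utterance>` form (made up for illustration, not taken from the real lists). `groupby(...).ngroup()` numbers speakers in sorted order, and the `origin` rule only separates the two corpora if VoxCeleb1 keys are not `id`-prefixed, which is worth checking against the actual key lists.

```python
import pandas as pd

# Hypothetical keys; "Jane_Doe" stands in for a non-"id" speaker name
keys = ["id00012-abc-00001", "id00012-abc-00002", "Jane_Doe-xyz-00001"]
df = pd.DataFrame(keys, columns=["key"])
df["spk"] = df["key"].str.split("-").str[0]
df["label"] = df.groupby("spk").ngroup()  # groups numbered in sorted speaker order
df["origin"] = df["spk"].apply(lambda s: "voxc2" if s.startswith("id") else "voxc1")
print(df[["spk", "label", "origin"]])
```

Since uppercase letters sort before lowercase ones, `Jane_Doe` gets label 0 and `id00012` gets label 1.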

In [20]:
trial = pd.read_pickle("../dataset/dataframes/voxc1/voxc_trial.pkl")

def load_embeds(path):
    # The pickle holds a {key: d-vector} dict; keep the key order and stack the vectors
    with open(path, "rb") as f:
        embed_dict = pickle.load(f)
    keys = list(embed_dict.keys())
    embeds = np.stack([np.asarray(embed_dict[k], dtype=np.float32) for k in keys])
    return keys, embeds

# si_set
si_keys, si_embeds = load_embeds("../best_models/voxc1/ResNet34_v4_softmax/voxc_train_dvectors.pkl")
si_key_df = key2df(si_keys)

In [22]:
# sv_set
sv_keys, sv_embeds = load_embeds("../best_models/voxc1/ResNet34_v4_softmax/voxc_test_dvectors.pkl")
sv_key_df = key2df(sv_keys)

In [23]:
si_dataset, embed_dim, n_labels = embedToDataset(si_embeds, si_key_df)
sv_dataset, _, _ = embedToDataset(sv_embeds, sv_key_df)
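`embedToDataset` reads `embeds.shape`, so it needs a 2-D array rather than the `{key: vector}` dict stored in the pickles (or a `dict_values` view of it). A minimal sketch of the conversion on synthetic vectors; `dict_to_matrix` is a hypothetical helper name:

```python
import numpy as np

def dict_to_matrix(embed_dict):
    """Hypothetical helper: split a {key: vector} dict into (keys, float32 matrix).

    Row i of the matrix is the vector for keys[i]; dict insertion order is kept.
    """
    keys = list(embed_dict.keys())
    matrix = np.stack([np.asarray(embed_dict[k], dtype=np.float32) for k in keys])
    return keys, matrix

# Synthetic stand-in for the pickled d-vectors (512-dim, matching the AE input)
rng = np.random.default_rng(0)
fake = {f"spk{i}-vid-0000{i}": rng.standard_normal(512) for i in range(3)}
keys, mat = dict_to_matrix(fake)
print(mat.shape, mat.dtype)  # (3, 512) float32

# np.array() on a dict_values view does not build a matrix: it wraps the view
# in a 0-d object array, which is why .shape-based code breaks
print(np.array(fake.values()).shape)  # ()
```

Casting to `float32` also matches the dtype of the `nn.Linear` weights, which would otherwise reject float64 input.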

### Training

In [8]:
import torch.nn.functional as F
from sklearn.metrics import roc_curve

def embeds_utterance(val_dataloader, model):
    """Embed every utterance in the loader with the encoder; returns (embeddings, labels)."""
    embeddings = []
    labels = []
    use_cuda = torch.cuda.is_available() and not no_cuda
    if use_cuda:
        model = model.cuda()
    model.eval()

    with torch.no_grad():
        for X, y in val_dataloader:
            if use_cuda:
                X = X.cuda()
            embeddings.append(model.embed(X).cpu())
            labels.append(y.numpy())
    embeddings = torch.cat(embeddings)
    labels = np.hstack(labels)
    return embeddings, labels

def sv_test(sv_loader, model, trial):
    """Score each trial pair by cosine similarity of AE latents and return the EER."""
    embeddings, _ = embeds_utterance(sv_loader, model)
    trial_enroll = embeddings[trial.enrolment_id.tolist()]
    trial_test = embeddings[trial.test_id.tolist()]

    score_vector = F.cosine_similarity(trial_enroll, trial_test, dim=1).numpy()
    label_vector = np.array(trial.label)
    fpr, tpr, thres = roc_curve(label_vector, score_vector, pos_label=1)
    # EER: the operating point where the false-positive rate equals the miss rate (1 - TPR)
    eer = fpr[np.nanargmin(np.abs(fpr - (1 - tpr)))]

    return eer
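The EER line in `sv_test` picks the ROC point where the false-positive rate is closest to the miss rate (1 - TPR). The sketch below applies the same formula to toy trials (made-up scores, not results from this model); `approx_eer` is a hypothetical name.

```python
import numpy as np
from sklearn.metrics import roc_curve

def approx_eer(labels, scores):
    """Same approximation as sv_test: FPR at the point where FPR is closest to 1 - TPR."""
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    return fpr[np.nanargmin(np.abs(fpr - (1 - tpr)))]

# Toy trials: 1 = same speaker, 0 = different speaker
print(approx_eer([1, 1, 0, 0], [0.9, 0.8, 0.3, 0.1]))  # 0.0, perfectly separated
print(approx_eer([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1]))  # 0.5
```

With few trials the ROC is coarse, so FPR and the miss rate at the chosen point can differ; averaging the two at that point is a common refinement.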

In [9]:
num_epochs = 100
batch_size = 128
learning_rate = 1e-2
no_cuda = False

In [10]:
model = autoencoder().cuda()

Lowering cent_lr (the learning rate of `optimizer_cent`) to 0.1 pushed the EER above 20%.

In [11]:
weight_cent = 0.01
criterion_ae = nn.MSELoss()
# One center per training speaker; n_labels comes from embedToDataset above
criterion_cent = CenterLoss(num_classes=n_labels, feat_dim=model.latent_dim, use_gpu=(not no_cuda))
optimizer_ae = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-05, momentum=0.9)
optimizer_cent = torch.optim.SGD(criterion_cent.parameters(), lr=0.5)

In [12]:
from torch.utils.data.dataloader import DataLoader

# Shuffle the training set so each batch mixes speakers
si_loader = DataLoader(si_dataset, num_workers=0, batch_size=batch_size, shuffle=True,
                       drop_last=True, pin_memory=True)

sv_loader = DataLoader(sv_dataset, batch_size=batch_size, num_workers=0, shuffle=False)

In [None]:
if not no_cuda:
    model = model.cuda()

for epoch in range(num_epochs):
    # ===================train======================
    model.train()
    loss_sum = ae_sum = cent_sum = 0.0
    n_batches = 0
    for batch_idx, (X, y) in enumerate(si_loader):
        if not no_cuda:
            X = X.cuda()
            y = y.cuda()
        # ===================forward=====================
        latent, output = model(X)
        loss_ae = criterion_ae(output, X)
        loss_cent = criterion_cent(latent, y) * weight_cent
        loss = loss_ae + loss_cent
        # ===================backward====================
        optimizer_ae.zero_grad()
        optimizer_cent.zero_grad()
        loss.backward()

        optimizer_ae.step()
        # Undo weight_cent for the centers so optimizer_cent's lr alone sets their step size
        for param in criterion_cent.parameters():
            param.grad.data *= (1. / weight_cent)
        optimizer_cent.step()

        loss_sum += loss.item()
        ae_sum += loss_ae.item()
        cent_sum += loss_cent.item()
        n_batches += 1

    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}, ae_loss:{:.4f}, cent_loss:{:.4f}'
          .format(epoch + 1, num_epochs, loss_sum / n_batches, ae_sum / n_batches, cent_sum / n_batches))

    # =================sv_loss======================
    # sv labels come from a separate key2df call, so they index training-speaker
    # centers unrelated to the test speakers; only reconstruction loss is tracked.
    model.eval()
    sv_ae_sum, sv_batches = 0.0, 0
    with torch.no_grad():
        for X, y in sv_loader:
            if not no_cuda:
                X = X.cuda()
            latent, output = model(X)
            sv_ae_sum += criterion_ae(output, X).item()
            sv_batches += 1
    eer = sv_test(sv_loader, model, trial)
    print("sv ae loss: {:.4f}, sv eer: {:.4f}".format(sv_ae_sum / sv_batches, eer))
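The `param.grad.data *= (1. / weight_cent)` step in the training loop cancels the `weight_cent` factor for the centers only: the encoder still sees the down-weighted center term, while the center updates are driven by the unweighted gradient and `optimizer_cent`'s learning rate. A minimal check on random tensors (names mirror the notebook, values are synthetic):

```python
import torch

torch.manual_seed(0)
weight_cent = 0.01
centers = torch.nn.Parameter(torch.randn(4, 3))
z = torch.randn(6, 3)
y = torch.tensor([0, 1, 2, 3, 0, 1])

def center_term():
    return (z - centers[y]).pow(2).sum(dim=1).mean()

# Gradient of the unweighted center loss w.r.t. the centers
center_term().backward()
g_unweighted = centers.grad.clone()
centers.grad = None

# Gradient of the weighted loss, then rescaled as in the training loop
(weight_cent * center_term()).backward()
centers.grad *= 1.0 / weight_cent
print(torch.allclose(centers.grad, g_unweighted, atol=1e-6))  # True
```

This is also why changing cent_lr has such a direct effect on where the centers end up.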

In [None]:
os.makedirs("saved_models", exist_ok=True)
torch.save(model.state_dict(), "saved_models/center_ae_test.pt")