Auto-Encoder
---------------------------------

경미의 [논문](https://drive.google.com/file/d/1RArk7z4NqY5HkwkUWx4cR2ApZNnAQxdF/view?usp=sharing)에 따르면 AE가 좀더 generalize한 feature를 뽑아준다고한다. 물론 image에 대해서 실험했고 (28x28, 32x32 의 작은...) task 간의 generalization에 대해 언급해서 조금 context가 다르다.

그래서 일단 xvector의 feature를 가지고 간단한게 AE를 구현해보려고한다.

Center-Loss
----------------------

Auto-Encoder의 MSELoss에 Center loss를 넣어서 더 모이게 한다면 어떻게 될까?

Unsupervise에서 Supervise로 된거다.

### Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
sys.path.append('../')
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

### AE Model

In [3]:
import torch
import torch.nn as nn

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(512, 400),
            nn.ReLU(True),
            nn.Linear(400, 300),
            nn.ReLU(True), nn.Linear(300, 256), nn.ReLU(True), nn.Linear(256, 128))
        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, 300),
            nn.ReLU(True),
            nn.Linear(300, 400),
            nn.ReLU(True), nn.Linear(400, 512), nn.Tanh())
        
        self.latent_dim = 128

    def forward(self, x):
        latent = self.encoder(x)
        output = self.decoder(latent)
        return latent, output

### Center-Loss

In [4]:
import torch                                                                                                                                                                                                                                                                 
import torch.nn as nn                                                                                                                                                                                                                                                        

class CenterLoss(nn.Module):                                                                                                                                                                                                                                                 
    """Center loss.                                                                                                                                                                                                                                                          

    Reference:                                                                                                                                                                                                                                                               
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.                                                                                                                                                                              

    Args:                                                                                                                                                                                                                                                                    
     num_classes (int): number of classes.                                                                                                                                                                                                                                
     feat_dim (int): feature dimension.                                                                                                                                                                                                                                   
    """                                                                                                                                                                                                                                                                      
    def __init__(self, num_classes, feat_dim, use_gpu=True):                                                                                                                                                                                                            
        super(CenterLoss, self).__init__()                                                                                                                                                                                                                                   
        self.num_classes = num_classes                                                                                                                                                                                                                                       
        self.feat_dim = feat_dim                                                                                                                                                                                                                                             
        self.use_gpu = use_gpu                                                                                                                                                                                                                                               

        if self.use_gpu:                                                                                                                                                                                                                                                     
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())                                                                                                                                                                                 
        else:                                                                                                                                                                                                                                                                
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))                                                                                                                                                                                        

    def forward(self, x, labels):                                                                                                                                                                                                                                            
        """                                                                                                                                                                                                                                                                  
        Args:                                                                                                                                                                                                                                                                
         x: feature matrix with shape (batch_size, feat_dim).  |                                                                                                                                                                                                           
         labels: ground truth labels with shape (num_classes).                                                                                                                                                                                                            
        """                                                                                                                                                                                                                                                                  
        batch_size = x.size(0)                                                                                                                                                                                                                                               
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
            torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()                                                                                                                                                               
        distmat.addmm_(1, -2, x, self.centers.t())                                                                                                                                                                                                                           

        classes = torch.arange(self.num_classes).long()                                                                                                                                                                                                                      
        if self.use_gpu: classes = classes.cuda()                                                                                                                                                                                                                            
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)                                                                                                                                                                                                    
        mask = labels.eq(classes.expand(batch_size, self.num_classes))                                                                                                                                                                                                       

        dist = []                                                                                                                                                                                                                                                            
        for i in range(batch_size):                                                                                                                                                                                                                                          
            value = distmat[i][mask[i]]
            # for checking 'nan's
            value_ = value.clone().cpu()
            assert np.count_nonzero(np.isnan(value_.detach().numpy())) == 0
            value = value.clamp(min=1e-12, max=1e+12) # for numerical stability                                                                                                                                                                                              
            dist.append(value)                                                                                                                                                                                                                                               
        dist = torch.cat(dist)                                                                                                                                                                                                                                               
        loss = dist.mean()                                                                                                                                                                                                                                                   

        return loss

### Dataset

In [5]:
import torch.utils.data as data

class embedDataset(data.Dataset):
    def __init__(self, embeds, labels):
        super().__init__()
        self.embeds = embeds
        self.labels = labels
        
    def __getitem__(self, index):
        
        return self.embeds[index], self.labels[index]
    
    def __len__(self):
        
        return self.embeds.shape[0]

def embedToDataset(embeds, key_df):
    labels = key_df.label.tolist()
    dataset = embedDataset(embeds, labels)
    
    return dataset, embeds.shape[1], len(key_df.label.unique())

def key2df(keys):
    key_df = pd.DataFrame(keys, columns=['key'])
    key_df['spk'] = key_df.key.apply(lambda x: x.split("-")[0])
    key_df['label'] = key_df.groupby('spk').ngroup()
    key_df['origin'] = key_df.spk.apply(lambda x: 'voxc2' if x.startswith('id') else 'voxc1')
    
    return key_df

In [6]:
trial = pd.read_pickle("../dataset/dataframes/voxc1/voxc_trial.pkl")

# si_set
si_keys = pickle.load(open("../embeddings/voxc12/xvectors/xvectors_tdnn7b/train_feat/key.pkl", "rb"))
si_embeds = np.load("../embeddings/voxc12/xvectors/xvectors_tdnn7b/train_feat/feat.npy")
si_key_df = key2df(si_keys)

# sv_set
sv_keys = pickle.load(open("../embeddings/voxc12/xvectors/xvectors_tdnn7b/test_feat/key.pkl", "rb"))
sv_embeds = np.load("../embeddings/voxc12/xvectors/xvectors_tdnn7b/test_feat/feat.npy")
sv_key_df = key2df(sv_keys)

In [7]:
si_dataset, embed_dim, n_labels = embedToDataset(si_embeds, si_key_df)
sv_dataset, _, _ = embedToDataset(sv_embeds, sv_key_df)

### Training

In [8]:
num_epochs = 100
batch_size = 128
learning_rate = 1e-3
no_cuda = False

In [9]:
model = autoencoder().cuda()

In [10]:
import torch

weight_cent = 0.01 
criterion_ae = nn.MSELoss()
criterion_cent = CenterLoss(num_classes=7324, feat_dim=model.latent_dim, use_gpu=(not no_cuda))
optimizer_ae = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-05, momentum=0.9)
optimizer_cent = torch.optim.SGD(criterion_cent.parameters(), lr=0.5)

In [11]:
from torch.utils.data.dataloader import DataLoader

si_loader = DataLoader(si_dataset, num_workers=0, batch_size=batch_size, 
                           drop_last=True, pin_memory=True)

sv_loader = DataLoader(sv_dataset, batch_size=batch_size, num_workers=0, shuffle=False)

In [None]:
if not no_cuda:
    model = model.cuda()
    
for epoch in range(num_epochs):
    loss_sum = 0
    total = 0
    for batch_idx, (X, y) in enumerate(si_loader):
        model.train()
        if not no_cuda:
            X = X.cuda()
            y = y.cuda()
        # ===================forward=====================
        latent, output = model(X)
        loss_ae = criterion_ae(output, X)
        loss_cent = criterion_cent(latent, y)
        loss_cent *= weight_cent
        loss = loss_ae + loss_cent
        # ===================backward====================
        optimizer_ae.zero_grad()
        optimizer_cent.zero_grad()
        loss.backward()
        
        optimizer_ae.step()
        for param in criterion_cent.parameters():                                                                                 
            param.grad.data *= (1. / weight_cent)
        optimizer_cent.step()
        
#         loss_sum += loss.item()
#         total += X.size(0)
#         if batch_idx % 100 == 0:
#             print(f"train loss: {loss_sum/total}")
   
    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}, ae_loss:{:.4f}, cent_loss:{:.4f}'
          .format(epoch + 1, num_epochs, loss.item(), loss_ae.item(), loss_cent.item()))

epoch [1/100], loss:2.1070, ae_loss:0.9407, cent_loss:1.1663
epoch [2/100], loss:0.9408, ae_loss:0.9370, cent_loss:0.0038
