Variational Autoencoder (VAE)
---------------------------------


### Environment

In [7]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [8]:
# Let modules that live one directory above this notebook be imported.
sys.path.append('../')

# GPU selection through environment variables: enumerate devices by PCI bus id
# (see issue #152) and expose only the device with index "1" to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

### VAE Model

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Fully connected variational auto-encoder with a diagonal Gaussian latent.

    The encoder is one hidden ReLU layer followed by two linear heads that
    produce the latent mean and log-variance; the decoder mirrors it with one
    hidden ReLU layer and a sigmoid output layer.

    Args:
        input_dim: width of the input (and reconstructed) vectors.
        hidden_dim: width of the hidden layer in both encoder and decoder.
        latent_dim: width of the latent code.

    The defaults (512 / 400 / 20) reproduce the originally hard-coded sizes,
    and the layer attribute names are unchanged, so ``VAE()`` stays compatible
    with previously saved state dicts.
    """

    def __init__(self, input_dim=512, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # latent mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # latent log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        The noise is created with ``randn_like`` so it always has the device
        and dtype of ``std``. Choosing the device from
        ``torch.cuda.is_available()`` breaks whenever the model runs on the
        CPU of a machine that also has a GPU.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent code ``z`` back to input space."""
        h3 = F.relu(self.fc3(z))
        # NOTE(review): the sigmoid bounds every reconstructed value to (0, 1);
        # confirm the training inputs are scaled to that range.
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar

### Dataset

In [31]:
import torch.utils.data as data

class embedDataset(data.Dataset):
    """Map-style dataset that pairs the i-th embedding with the i-th label."""

    def __init__(self, embeds, labels):
        super().__init__()
        self.embeds, self.labels = embeds, labels

    def __len__(self):
        # The dataset size is the length of the embeddings' first axis.
        n_rows = self.embeds.shape[0]
        return n_rows

    def __getitem__(self, index):
        sample = self.embeds[index]
        target = self.labels[index]
        return sample, target

def embedToDataset(embeds, key_df):
    """Wrap ``embeds`` and the ``label`` column of ``key_df`` in an ``embedDataset``.

    Returns:
        A 3-tuple ``(dataset, embed_dim, n_labels)`` where ``embed_dim`` is
        ``embeds.shape[1]`` and ``n_labels`` is the number of distinct values
        in ``key_df.label``.
    """
    label_column = key_df["label"]
    dataset = embedDataset(embeds, label_column.tolist())

    embed_width = embeds.shape[1]
    n_distinct_labels = len(label_column.unique())
    return dataset, embed_width, n_distinct_labels

def key2df(keys):
    """Build a per-utterance table from a sequence of key strings.

    Columns of the returned frame, in order:
        key    -- the key as given
        spk    -- the part of the key before the first "-"
        label  -- integer group number of ``spk`` (``groupby('spk').ngroup()``)
        origin -- 'voxc2' when ``spk`` starts with "id", otherwise 'voxc1'
    """
    def _speaker_of(key):
        return key.split("-")[0]

    def _corpus_of(spk):
        return 'voxc2' if spk.startswith('id') else 'voxc1'

    return (
        pd.DataFrame(keys, columns=['key'])
        .assign(spk=lambda df: df['key'].map(_speaker_of))
        .assign(label=lambda df: df.groupby('spk').ngroup())
        .assign(origin=lambda df: df['spk'].map(_corpus_of))
    )

In [32]:
def _load_embedding_split(split_dir):
    """Load one embedding split stored as ``key.pkl`` + ``feat.npy`` in ``split_dir``.

    Returns:
        ``(keys, embeds, key_df)``: the unpickled keys, the array read from
        ``feat.npy`` and the table built from the keys by ``key2df``.

    Raises:
        ValueError: if the number of keys differs from the number of
            embedding rows, since labels derived from the keys are later
            paired with embedding rows by index.
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted files.
    # The `with` block closes the file handle, which a bare
    # `pickle.load(open(...))` never does.
    with open(os.path.join(split_dir, "key.pkl"), "rb") as key_file:
        keys = pickle.load(key_file)
    embeds = np.load(os.path.join(split_dir, "feat.npy"))

    if len(keys) != embeds.shape[0]:
        raise ValueError(
            f"{split_dir}: {len(keys)} keys but {embeds.shape[0]} embedding rows")

    return keys, embeds, key2df(keys)


trial = pd.read_pickle("../dataset/dataframes/voxc1/voxc_trial.pkl")

# si_set
si_keys, si_embeds, si_key_df = _load_embedding_split(
    "../embeddings/voxc12/xvectors/xvectors_tdnn7b/train_feat")

# sv_set
sv_keys, sv_embeds, sv_key_df = _load_embedding_split(
    "../embeddings/voxc12/xvectors/xvectors_tdnn7b/test_feat")

In [33]:
# Training split: keep the embedding width and the number of distinct labels too.
si_dataset, embed_dim, n_labels = embedToDataset(embeds=si_embeds, key_df=si_key_df)

# Evaluation split: only the dataset object is needed.
sv_dataset, _, _ = embedToDataset(embeds=sv_embeds, key_df=sv_key_df)

### Training

In [34]:
# Optimisation hyper-parameters.
learning_rate = 1e-3
batch_size = 128
num_epochs = 100

# When False, the training cell moves the model and every batch to the GPU.
no_cuda = False

In [35]:
# Instantiate the VAE as defined above; its first layer expects 512-d inputs.
# NOTE(review): `embed_dim` from the dataset cell is not used here — confirm
# the loaded embeddings really are 512 wide.
model = VAE()

In [36]:
# Squared error summed (not averaged) over every element of the batch.
reconstruction_function = nn.MSELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """
    VAE objective: summed squared reconstruction error plus KL divergence.

    recon_x: reconstruction produced by the decoder
    x: original input
    mu: latent mean
    logvar: latent log variance
    """
    recon_term = reconstruction_function(recon_x, x)

    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_inner)

    return recon_term + kl_term

In [37]:
# Adam over every parameter of the VAE, with a small weight-decay term.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    weight_decay=1e-5,
    lr=learning_rate,
)

In [38]:
from torch.utils.data.dataloader import DataLoader

# Training loader. shuffle=True draws a fresh random order every epoch; without
# it each batch is a fixed slice of the stored file order. Pinned host memory
# only helps host-to-GPU copies, so it is requested only when a GPU will
# actually be used.
si_loader = DataLoader(si_dataset, num_workers=0, batch_size=batch_size,
                       shuffle=True, drop_last=True,
                       pin_memory=(not no_cuda) and torch.cuda.is_available())

# Evaluation loader: fixed order, every sample kept, same batch size as training.
sv_loader = DataLoader(sv_dataset, batch_size=batch_size, num_workers=0, shuffle=False)

In [None]:
# Use the GPU only when it is both requested (no_cuda is False) and present,
# so the cell also runs on CPU-only machines instead of failing in .cuda().
device = torch.device(
    "cuda" if (not no_cuda and torch.cuda.is_available()) else "cpu")
model = model.to(device)


def _run_epoch(loader, train):
    """Run one full pass over ``loader`` and return the mean per-sample loss.

    With ``train=True`` the model is put in training mode and the optimizer
    takes one step per batch. With ``train=False`` the model is put in eval
    mode and autograd is disabled, so evaluation neither updates weights nor
    builds computation graphs.
    """
    model.train(train)
    loss_sum, total = 0.0, 0
    with torch.set_grad_enabled(train):
        for X, _ in loader:
            X = X.to(device)
            # ===================forward=====================
            recon_X, mu, logvar = model(X)
            loss = loss_function(recon_X, X, mu, logvar)
            # ===================backward====================
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # loss_function sums over the batch, so accumulate the sum and
            # normalise by the number of samples seen.
            loss_sum += loss.item()
            total += X.size(0)
    # max(..., 1) guards against a loader that yields no batches.
    return loss_sum / max(total, 1)


for epoch in range(num_epochs):
    # ===================log========================
    # Both figures are epoch averages per sample, not the last batch's sum.
    train_loss = _run_epoch(si_loader, train=True)
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, train_loss))

    # =================sv_loss======================
    sv_loss = _run_epoch(sv_loader, train=False)
    print(f"sv loss: {sv_loss}")

In [44]:
# Create the output directory if needed and pass the path to torch.save so it
# opens and closes the file itself; the handle from a bare open(..., "wb") was
# never closed.
os.makedirs("saved_models", exist_ok=True)
torch.save(model.state_dict(), "saved_models/vae_test.pt")