# original_tdnn_xvector
----
trello: https://trello.com/c/h9uUoFVc

Kaldi의 tdnn_xvector 구조를 그대로 구현하려고 한다.

목표는 utterance-level x-vector의 generative한 특성을 살릴 수 있는지 알아보는 것이다.

----

30차원 MFCC를 가져와야 하나? (현재 설정은 40차원 fbank)

**splice_frames를 800까지 높이면 터진다.**

## Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
# Make the project root (which contains the `sv_system` package) importable.
sys.path.append('../')

# Enumerate GPUs by PCI bus ID so device indices match nvidia-smi, and expose
# only GPU 1 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "1",
})

### Configuration

In [3]:
from sv_system.utils.parser import set_train_config
import easydict

# Other dataset names noted by the author as alternatives to "voxc1_wav":
#   voxc1_fbank_xvector
#   gcommand_fbank_xvector

# Run hyper-parameters; set_train_config turns them into the `config`
# mapping that the remaining cells index into.
args = easydict.EasyDict(dict(
    # data / features
    dataset="voxc1_wav",
    input_frames=300,
    splice_frames=[300],
    stride_frames=1,
    input_format='fbank',
    # device
    cuda=True,
    # optimisation schedule
    lrs=[0.01, 0.001],
    lr_schedule=[20],
    seed=1337,
    # evaluation
    no_eer=False,
    batch_size=128,
    # model / objective
    arch="tdnn_original",
    loss="softmax",
    n_epochs=50,
))
config = set_train_config(args)

### Dataset and Dataloader

In [4]:
from sv_system.data.data_utils import find_dataset, find_trial

# Build the datasets described by `config`; only the second returned value is kept.
_, datasets = find_dataset(config, basedir='../')
# Presumably the speaker-verification trial list (it is later passed to sv_test) — confirm.
trial = find_trial(config, basedir='../')

In [5]:
from sv_system.data.dataloader import init_loaders

# Wrap the datasets in loaders; unpacked into train/val/test(/sv) loaders further below.
dataloaders = init_loaders(config, datasets)

In [6]:
# Number of input feature channels for the first Conv1d (tdnn1) of the model.
# NOTE(review): hard-coded; assumes 40-dim features for input_format='fbank' — confirm against the dataset.
config['input_dim'] = 40

### Define Model

In [7]:
# from sv_system.model.model_utils import find_model
# model = find_model(config)

In [8]:
from sv_system.model.tdnnModel import gTDNN, st_pool_layer
from sv_system.model.ResNet34 import BasicBlock

import torch.nn as nn
import torch.nn.functional as F

class tdnn_xvector_orig(gTDNN):
    """x-vector TDNN modelled on Kaldi's tdnn_xvector topology.

    Frame-level layers tdnn1-tdnn5 are ``nn.Conv1d`` over the time axis.
    Their kernel size and dilation give each layer the frame context noted
    beside it. With no padding, they shorten the time axis by 4 + 4 + 6 = 14
    frames before pooling.

    Every hidden layer applies ReLU and then BatchNorm. A statistics-pooling
    layer collapses the time axis. ``tdnn6`` produces the 512-dim embedding
    returned by :meth:`embed`, and ``classifier`` (tdnn7 + output layer)
    maps that embedding to ``n_labels`` logits.

    Args:
        config: training config; ``config['input_dim']`` sets the number of
            input feature channels of tdnn1.
        n_labels: number of output classes.
    """
    def __init__(self, config, n_labels=31):
        super(tdnn_xvector_orig, self).__init__(config, n_labels)
        inDim = config['input_dim']
        self.tdnn = nn.Sequential(
            # tdnn1: frames t-2 .. t+2
            nn.Conv1d(inDim, 512, stride=1, dilation=1, kernel_size=5),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            # tdnn2: frames {t-2, t, t+2}
            nn.Conv1d(512, 512, stride=1, dilation=2, kernel_size=3),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            # tdnn3: frames {t-3, t, t+3}
            nn.Conv1d(512, 512, stride=1, dilation=3, kernel_size=3),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            # tdnn4: frame t only
            nn.Conv1d(512, 512, stride=1, dilation=1, kernel_size=1),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            # tdnn5: frame t only, widened to 1500 channels
            nn.Conv1d(512, 1500, stride=1, dilation=1, kernel_size=1),
            nn.ReLU(True),
            nn.BatchNorm1d(1500),
            # statistics pooling over time. tdnn6 expects 3000 = 2 * 1500
            # features (presumably mean and std concatenated — confirm in
            # st_pool_layer).
            st_pool_layer(),
            # tdnn6: segment-level layer whose output is the embedding.
            # NOTE(review): BatchNorm1d in train mode needs > 1 sample per batch.
            nn.Linear(3000, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
        )

        self.classifier = nn.Sequential(
            # tdnn7
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            # output layer (logits)
            nn.Linear(512, n_labels),
        )

        self._initialize_weights()

    def embed(self, x):
        """Return the tdnn6 embedding (batch, 512) for a batch of features.

        After singleton axes are removed, ``x`` must be laid out as
        (batch, time, freq). For example, a channel axis of size 1 is dropped.
        """
        # Squeeze singleton axes but never the batch axis. A bare
        # ``x.squeeze()`` also drops the batch axis when the batch holds a
        # single utterance, and the 3-axis permute below would then fail.
        for dim in reversed(range(1, x.dim())):
            if x.size(dim) == 1:
                x = x.squeeze(dim)
        # (batch, time, freq) -> (batch, freq, time): Conv1d wants channels first
        x = x.permute(0, 2, 1)
        x = self.tdnn(x)

        return x

    def forward(self, x):
        """Return class logits (batch, n_labels) for a batch of features."""
        x = self.embed(x)
        x = self.classifier(x)

        return x

    def _initialize_weights(self):
        """Initialise Conv1d, BatchNorm1d and Linear parameters in place."""
        # ``math`` is not among the notebook's explicit imports, so import it here.
        import math

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                # He (Kaiming) normal init with fan_out = kernel_size * out_channels
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

In [9]:
# Instantiate the network with one output per label in the configured dataset.
model = tdnn_xvector_orig(config, config['n_labels'])

In [10]:
# Move the model to the (single visible) GPU unless CUDA is disabled.
model = model if config['no_cuda'] else model.cuda()

### Model Train

In [11]:
from sv_system.train.train_utils import set_seed, find_optimizer
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR

criterion, optimizer = find_optimizer(config, model)
# Alternative kept for experiments:
# scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5)

# Decay the LR by 10x at the epochs listed in the configured `lr_schedule`,
# so the schedule actually used matches the one declared in the config
# instead of a separately hard-coded milestone. Fall back to epoch 10 if the
# config carries no 'lr_schedule' entry.
scheduler = MultiStepLR(optimizer, config.get('lr_schedule', [10]), 0.1)

In [12]:
# Seed the RNGs from config (seed=1337 in the args above) for reproducibility.
set_seed(config)

In [13]:
# With EER evaluation enabled, init_loaders also yields a speaker-verification loader.
if config['no_eer']:
    train_loader, val_loader, test_loader = dataloaders
else:
    train_loader, val_loader, test_loader, sv_loader = dataloaders

In [16]:
import torch

def train(model):
    """Run one training epoch of ``model`` over the notebook-level ``train_loader``.

    Uses the notebook globals ``config``, ``optimizer`` and ``criterion``.
    Every 100 batches it prints the running mean batch loss.

    Args:
        model: network returning logits of shape (batch, n_classes).

    Returns:
        tuple: ``(loss_sum, acc)``. ``loss_sum`` is the *sum* of the
        per-batch losses (not averaged). ``acc`` is the fraction of samples
        whose argmax prediction matches the label, or 0.0 if the loader
        yielded no samples.
    """
    model.train()
    loss_sum = 0
    n_corrects = 0
    total = 0
    for batch_idx, (X, y) in enumerate(train_loader):
        if not config['no_cuda']:
            X = X.cuda()
            y = y.cuda()

        optimizer.zero_grad()
        logit = model(X)
        loss = criterion(logit, y)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        n_corrects += torch.sum(torch.eq(torch.argmax(logit, dim=1), y)).item()
        total += y.size(0)

        if (batch_idx + 1) % 100 == 0:
            print("Batch {}/{}\t Loss {:.6f}".format(
                batch_idx + 1, len(train_loader), loss_sum / (batch_idx + 1)))

    # Accuracy is computed once, after the epoch. The guard avoids dividing
    # by zero when the loader is empty.
    acc = n_corrects / total if total else 0.0
    return loss_sum, acc

In [None]:
# val / sv_test are called below, so their import must be active.
from sv_system.train.si_train import val, sv_test

for epoch_idx in range(0, config['n_epochs']):
    print("-"*30)

    # The LR schedule is advanced at the start of the epoch.
    # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after the epoch's
    # optimizer.step() calls — confirm against the installed torch version.
    scheduler.step()
    # Read the LR after stepping so the printed value is the one this epoch
    # trains with. Reading it before the step lagged by one epoch at every
    # milestone.
    curr_lr = optimizer.param_groups[0]['lr']
    print("curr_lr: {}".format(curr_lr))

    # training
    train_loss, train_acc = train(model)
    print("epoch #{}, train accuracy: {}".format(epoch_idx, train_acc))

    # validation
    val_loss, val_acc = val(config, val_loader, model, criterion)
    print("epoch #{}, val accuracy: {}".format(epoch_idx, val_acc))

    # speaker-verification EER on the trial list (skipped when no_eer is set)
    if not config['no_eer']:
        eer, label, score = sv_test(config, sv_loader, model, trial)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))
    
