# ResNet34_only_time_pooling

kaldi의 tdnn_xvector 구조를 똑같이 구현하려고한다.

목표는 utterance-level xvector의 generative한 특성을 살릴 수 있을까 알고싶은 것이다.

## Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
sys.path.append('../')
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="2"

### Configuration

In [3]:
from sv_system.utils.parser import set_train_config
import easydict

# datasets
# voxc1_fbank_xvector
# gcommand_fbank_xvector

args = easydict.EasyDict(dict(dataset="voxc1_fbank_xvector",
                              input_frames=800, splice_frames=[300, 800], stride_frames=1, input_format='fbank',
                              cuda=True,
                              lrs=[0.1, 0.01], lr_schedule=[20], seed=1337,
                              no_eer=False,
                              batch_size=128,
                              arch="tdnn_conv", loss="softmax",
                              n_epochs=50
                             ))
config = set_train_config(args)

### Dataset and Dataloader

In [4]:
from sv_system.data.data_utils import find_dataset, find_trial

_, datasets = find_dataset(config, basedir='../')
trial = find_trial(config, basedir='../')

In [5]:
from sv_system.data.dataloader import init_loaders

dataloaders = init_loaders(config, datasets)

### Define Model

In [6]:
# from sv_system.model.model_utils import find_model
# model = find_model(config)

In [29]:
from sv_system.model.ResNet34 import ResNet34
import torch.nn as nn
import torch.nn.functional as F

class ResNet34_time_pool(ResNet34):
    """
        additional fc layer before output layer
    """
    def __init__(self, config, inplanes=16, n_labels=1000, fc_dims=None):
        super().__init__(config, inplanes, n_labels)

        extractor_output_dim = 8*inplanes * 9
        if not fc_dims:
            fc_dims = extractor_output_dim

        classifier = [nn.Linear(extractor_output_dim,
            fc_dims),
            nn.ReLU(inplace=True)]

        loss_type = config["loss"]
        if loss_type == "angular":
            classifier.append(AngleLinear(fc_dims, n_labels))
        elif loss_type == "softmax":
            classifier.append(nn.Linear(fc_dims, n_labels))
        else:
            print("not implemented loss")
            raise NotImplementedError

        self.classifier = nn.Sequential(*classifier)
        
    def embed(self, x):
        x = self.extractor(x)
        x = F.avg_pool2d(x,(x.shape[-2],1))
        x = x.view(x.size(0), -1)
        
        return x
        
    def forward(self, x):
        feat = self.embed(x)
        out = self.classifier(feat)
        
        return out

In [30]:
model = ResNet34_time_pool(config, inplanes=16, n_labels=config['n_labels'], fc_dims=512)

In [31]:
model

ResNet34_time_pool(
  (extractor): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3)

In [32]:
if not config['no_cuda']:
    model = model.cuda()

### Model Train

In [33]:
from sv_system.train.train_utils import set_seed, find_optimizer
from torch.optim.lr_scheduler import ReduceLROnPlateau

criterion, optimizer = find_optimizer(config, model)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5)

In [34]:
set_seed(config)

In [35]:
if not config['no_eer']:
    train_loader, val_loader, test_loader, sv_loader = dataloaders
else:
    train_loader, val_loader, test_loader = dataloaders

In [36]:
import torch

def train(model):
    model.train()
    loss_sum = 0
    n_corrects = 0
    total = 0
    for batch_idx, (X, y) in enumerate(train_loader):
        if not config['no_cuda']:
            X = X.cuda()
            y = y.cuda()

        optimizer.zero_grad()

        logit = model(X)
        loss = criterion(logit, y)
        loss.backward()
        optimizer.step()
                        
        loss_sum += loss.item()
        n_corrects += torch.sum(torch.eq(torch.argmax(logit, dim=1), y)).item()
        total += y.size(0)
        
        if (batch_idx+1) % 100 == 0:
            print("Batch {}/{}\t Loss {:.6f}" \
                  .format(batch_idx+1, len(train_loader), loss_sum /(batch_idx+1),)
                 )
        acc = n_corrects / total

    acc = n_corrects / total
    return loss_sum, acc 

In [37]:
from sv_system.train.si_train import val, sv_test

for epoch_idx in range(0, config['n_epochs']):
    print("-"*30)
    curr_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print("curr_lr: {}".format(curr_lr))

#     train code
    train_loss, train_acc = train(model)

#     validation code
    val_loss, val_acc = val(config, val_loader, model, criterion)
    print("epoch #{}, val accuracy: {}".format(epoch_idx, val_acc))

#     evaluate best_metric
    if not config['no_eer']:
        # eer validation code
        eer, label, score = sv_test(config, sv_loader, model, trial)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))
    
    scheduler.step(train_loss)

------------------------------
curr_lr: 0.1
Batch 100/1042	 Loss 6.925580
Batch 200/1042	 Loss 6.728000
Batch 300/1042	 Loss 6.536262
Batch 400/1042	 Loss 6.359204
Batch 500/1042	 Loss 6.179141
Batch 600/1042	 Loss 6.017569
Batch 700/1042	 Loss 5.856726
Batch 800/1042	 Loss 5.704844
Batch 900/1042	 Loss 5.557880
Batch 1000/1042	 Loss 5.418602
epoch #0, val accuracy: 0.07166020572185516
epoch #0, sv eer: 0.19337663720583537
------------------------------
curr_lr: 0.1
Batch 100/1042	 Loss 3.809774
Batch 200/1042	 Loss 3.730361
Batch 300/1042	 Loss 3.660816
Batch 400/1042	 Loss 3.584737
Batch 500/1042	 Loss 3.508593
Batch 600/1042	 Loss 3.421987
Batch 700/1042	 Loss 3.342123
Batch 800/1042	 Loss 3.266431
Batch 900/1042	 Loss 3.196370
Batch 1000/1042	 Loss 3.129464
epoch #1, val accuracy: 0.16760022938251495
epoch #1, sv eer: 0.16180385475455222
------------------------------
curr_lr: 0.1
Batch 100/1042	 Loss 2.284419
Batch 200/1042	 Loss 2.250874
Batch 300/1042	 Loss 2.199199
Batch 400/10

Batch 300/1042	 Loss 0.192905
Batch 400/1042	 Loss 0.199508
Batch 500/1042	 Loss 0.206493
Batch 600/1042	 Loss 0.210825
Batch 700/1042	 Loss 0.211645
Batch 800/1042	 Loss 0.217223
Batch 900/1042	 Loss 0.222864
Batch 1000/1042	 Loss 0.225701
epoch #19, val accuracy: 0.6495956778526306
epoch #19, sv eer: 0.126930039399425
------------------------------
curr_lr: 0.1
Batch 100/1042	 Loss 0.167281
Batch 200/1042	 Loss 0.170226
Batch 300/1042	 Loss 0.181675
Batch 400/1042	 Loss 0.190420
Batch 500/1042	 Loss 0.191184
Batch 600/1042	 Loss 0.199532
Batch 700/1042	 Loss 0.203335
Batch 800/1042	 Loss 0.208071
Batch 900/1042	 Loss 0.210603
Batch 1000/1042	 Loss 0.216963
epoch #20, val accuracy: 0.6596192121505737
epoch #20, sv eer: 0.12565222021084016
------------------------------
curr_lr: 0.1
Batch 100/1042	 Loss 0.172383
Batch 200/1042	 Loss 0.169364
Batch 300/1042	 Loss 0.176494
Batch 400/1042	 Loss 0.177937
Batch 500/1042	 Loss 0.184656
Batch 600/1042	 Loss 0.186327
Batch 700/1042	 Loss 0.195

Batch 600/1042	 Loss 0.147488
Batch 700/1042	 Loss 0.153435
Batch 800/1042	 Loss 0.154090
Batch 900/1042	 Loss 0.156282
Batch 1000/1042	 Loss 0.157529
epoch #38, val accuracy: 0.6742123961448669
epoch #38, sv eer: 0.1296454051751677
------------------------------
curr_lr: 0.1
Batch 100/1042	 Loss 0.135388
Batch 200/1042	 Loss 0.129423
Batch 300/1042	 Loss 0.132951
Batch 400/1042	 Loss 0.134802
Batch 500/1042	 Loss 0.136828
Batch 600/1042	 Loss 0.143628
Batch 700/1042	 Loss 0.146745
Batch 800/1042	 Loss 0.149981
Batch 900/1042	 Loss 0.150873
Batch 1000/1042	 Loss 0.154479
epoch #39, val accuracy: 0.6676634550094604
epoch #39, sv eer: 0.13182834628900011
------------------------------
curr_lr: 0.1
Batch 100/1042	 Loss 0.168544
Batch 200/1042	 Loss 0.153246
Batch 300/1042	 Loss 0.155489
Batch 400/1042	 Loss 0.149845
Batch 500/1042	 Loss 0.149067
Batch 600/1042	 Loss 0.156104
Batch 700/1042	 Loss 0.159256
Batch 800/1042	 Loss 0.159665
Batch 900/1042	 Loss 0.161150
Batch 1000/1042	 Loss 0.1