# Center-loss training for a TDNN x-vector model

Paper (Wen et al., ECCV 2016): http://ydwen.github.io/papers/WenECCV16.pdf  
Reference PyTorch implementation: https://github.com/KaiyangZhou/pytorch-center-loss

## Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
# Make the project package importable and pin this notebook to a single GPU.
sys.path.append('../../sv_system/')
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "3",
})

### Center-loss algorithm

In [3]:
import torch                                                                                                                                                                                                                                                                 
import torch.nn as nn                                                                                                                                                                                                                                                        

class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Penalizes the squared Euclidean distance between each feature vector and the
    learnable center of its ground-truth class, averaged over the batch.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): if True, the class centers are created on the default CUDA device.
    """
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        centers = torch.randn(self.num_classes, self.feat_dim)
        if self.use_gpu:
            centers = centers.cuda()
        self.centers = nn.Parameter(centers)

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth class indices with shape (batch_size,).

        Returns:
            Scalar tensor: mean over the batch of ||x_i - c_{labels_i}||^2, each term
            clamped to [1e-12, 1e12] for numerical stability.
        """
        batch_size = x.size(0)
        # Pairwise squared distances via ||x||^2 + ||c||^2 - 2 x.c^T -> (batch_size, num_classes).
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
            torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # Keyword beta/alpha: the positional addmm_(beta, alpha, mat1, mat2) overload is deprecated.
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        # Build the one-hot mask on the same device as the distances instead of
        # relying on the use_gpu flag.
        classes = torch.arange(self.num_classes, device=distmat.device).long()
        labels = labels.to(distmat.device).unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        # Boolean indexing visits the mask row by row, so this selects the same
        # (sample, own-class) distances the former per-sample Python loop did.
        dist = distmat[mask].clamp(min=1e-12, max=1e+12)  # for numerical stability
        return dist.mean()

### Configuration

In [4]:
from utils.parser import set_train_config
import easydict
# Experiment settings, grouped by purpose; set_train_config turns them into the
# `config` mapping used by every later cell.
args = easydict.EasyDict({
    # data / features
    "dataset": "gcommand_mfcc30",
    "n_labels": 1759,
    "input_frames": 100,
    "splice_frames": [100],
    "stride_frames": 1,
    "input_format": 'mfcc',
    # runtime
    "cuda": True,
    "random_clip": False,
    # optimisation / evaluation
    "lrs": [0.1, 0.01],
    "lr_schedule": [20],
    "seed": 1337,
    "no_eer": False,
    "batch_size": 256,
    "num_workers": 4,
    # model
    "arch": "tdnn_xvector",
    "loss": "softmax",
    "n_epochs": 100,
})
config = set_train_config(args)

### Dataset and Dataloader

In [5]:
from data.data_utils import find_dataset, find_trial

# Both helpers resolve their data relative to the project root.
sv_basedir = '../../sv_system/'
dfs, datasets = find_dataset(config, basedir=sv_basedir)
trial = find_trial(config, basedir=sv_basedir)

=> loaded trial: gcommand_equal_num_30spk_trial


In [6]:
from data.dataloader import init_loaders

# Build the loaders for the datasets loaded above. The result is unpacked in the
# "Model Train" section into 4 loaders (train/val/test/sv) when EER evaluation is
# enabled (config['no_eer'] is False), otherwise into 3.
dataloaders = init_loaders(config, datasets)

### Define Model

In [7]:
# from sv_system.model.model_utils import find_model
# model = find_model(config)

In [8]:
from model.tdnnModel import st_pool_layer

class tdnn_xvector_v1(nn.Module):
    """x-vector TDNN.

    The output of ``tdnn6_affine`` is the embedding. The classifier layers
    (``tdnn6_bn`` onward through ``tdnn8_last``) are kept as separate attributes
    rather than folded into ``self.tdnn``. This keeps the embedding position
    flexible, and ``feat_out`` can return the embedding alongside the logits.

    Args:
        config: mapping with key "loss" ("softmax" or "angular").
        base_width (int): channel width of the frame-level layers and the embedding size.
        n_labels (int): number of output classes.
        in_dim (int): number of input feature bins per frame (30 for the MFCC-30 setup).
    """
    def __init__(self, config, base_width, n_labels=31, in_dim=30):
        super(tdnn_xvector_v1, self).__init__()
        pool_channels = 1500
        self.tdnn = nn.Sequential(
            nn.Conv1d(in_dim, base_width, stride=1, dilation=1, kernel_size=5),
            nn.BatchNorm1d(base_width),
            nn.ReLU(True),
            nn.Conv1d(base_width, base_width, stride=1, dilation=3, kernel_size=3),
            nn.BatchNorm1d(base_width),
            nn.ReLU(True),
            nn.Conv1d(base_width, base_width, stride=1, dilation=4, kernel_size=3),
            nn.BatchNorm1d(base_width),
            nn.ReLU(True),
            nn.Conv1d(base_width, base_width, stride=1, dilation=1, kernel_size=1),
            nn.BatchNorm1d(base_width),
            nn.ReLU(True),
            nn.Conv1d(base_width, pool_channels, stride=1, dilation=1, kernel_size=1),
            nn.BatchNorm1d(pool_channels),
            nn.ReLU(True),
            st_pool_layer(),
        )

        # The classifier is created before the tdnn6/7 layers (as before) so that
        # parameter-creation order, and hence RNG consumption, is unchanged.
        loss_type = config["loss"]
        if loss_type == "angular":
            # NOTE(review): AngleLinear is not imported anywhere in this notebook;
            # this branch raises NameError until it is.
            last_fc = AngleLinear(base_width, n_labels)
        elif loss_type == "softmax":
            last_fc = nn.Linear(base_width, n_labels)
        else:
            print("not implemented loss")
            raise NotImplementedError

        # 2 * pool_channels: assumes st_pool_layer concatenates a per-channel mean
        # and std vector — TODO confirm against model.tdnnModel.
        self.tdnn6_affine = nn.Linear(2 * pool_channels, base_width)
        self.tdnn6_bn = nn.BatchNorm1d(base_width)
        self.tdnn6_relu = nn.ReLU(True)
        self.tdnn7_affine = nn.Linear(base_width, base_width)
        self.tdnn7_bn = nn.BatchNorm1d(base_width)
        self.tdnn7_relu = nn.ReLU(True)
        self.tdnn8_last = last_fc

        self._initialize_weights()

    def embed(self, x):
        """Run the frame-level layers and pooling.

        Assumes x is (batch, 1, time, freq) — TODO confirm. Dim 1 is squeezed, and
        (batch, time, freq) is permuted to (batch, freq, time) for Conv1d.
        """
        x = x.squeeze(1)
        x = x.permute(0, 2, 1)
        return self.tdnn(x)

    def feat_out(self, x):
        """Return ``(feat, logits)``.

        ``feat`` is the ``tdnn6_affine`` output taken before BatchNorm/ReLU.
        """
        feat = self.tdnn6_affine(self.embed(x))
        # tdnn6_bn returns a new tensor, so the in-place ReLU does not modify `feat`.
        x = self.tdnn6_relu(self.tdnn6_bn(feat))
        x = self.tdnn7_relu(self.tdnn7_bn(self.tdnn7_affine(x)))
        return feat, self.tdnn8_last(x)

    def forward(self, x):
        """Return the classifier logits (same computation as ``feat_out``)."""
        return self.feat_out(x)[1]

    def _initialize_weights(self):
        """He-style normal init for convolutions, unit/zero BatchNorm, N(0, 0.01) Linear."""
        # Import explicitly: `math` was previously only reachable through the
        # %pylab / old-numpy namespace leak, which newer numpy no longer provides.
        import math
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()


class permute_dim(nn.Module):
    """Module wrapper that swaps the last two axes of a 3-D tensor: (N, A, B) -> (N, B, A)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        swapped = x.permute(0, 2, 1)
        return swapped


### Model Train

In [9]:
# A 4th (speaker-verification) loader is expected only when EER evaluation is enabled.
if config['no_eer']:
    train_loader, val_loader, test_loader = dataloaders
else:
    train_loader, val_loader, test_loader, sv_loader = dataloaders

In [10]:
def train(model):
    """Train ``model`` for one epoch with softmax cross-entropy plus weighted center loss.

    Relies on notebook globals defined in other cells: ``config``, ``train_loader``,
    ``is_cuda``, ``criterion_xent``, ``criterion_cent``, ``weight_cent``,
    ``optimizer_model`` and ``optimizer_centloss``.

    Args:
        model: network whose ``feat_out(X)`` returns ``(features, logits)``. The
            features feed the center loss and the logits feed the cross-entropy.

    Returns:
        tuple: ``(loss_sum, acc)``. ``loss_sum`` is the total loss summed (not
        averaged) over all batches. ``acc`` is the training accuracy over all
        samples, or 0.0 for an empty loader.
    """
    model.train()
    loss_sum = 0
    xent_loss_sum = 0
    cent_loss_sum = 0
    n_corrects = 0
    total = 0

    # One splice length per epoch: drawn from [lo, hi) when a range is configured,
    # otherwise the single configured length.
    splice_frames = config['splice_frames']
    if len(splice_frames) > 1:
        splice_frames_ = np.random.randint(splice_frames[0], splice_frames[1])
    else:
        splice_frames_ = splice_frames[-1]

    for batch_idx, (X, y) in enumerate(train_loader):
        # Keep the first splice_frames_ frames along dim 2.
        X = X.narrow(2, 0, splice_frames_)

        if is_cuda:
            X = X.cuda()
            y = y.cuda()

        feats, outs = model.feat_out(X)
        loss_xent = criterion_xent(outs, y)
        loss_cent = criterion_cent(feats, y) * weight_cent
        loss = loss_xent + loss_cent

        optimizer_model.zero_grad()
        optimizer_centloss.zero_grad()
        loss.backward()
        optimizer_model.step()
        # Remove weight_cent from the centers' gradient so their update size is set
        # by optimizer_centloss's lr alone. Skipped when weight_cent == 0: the
        # gradient is zero then, and 1. / 0 would raise ZeroDivisionError.
        if weight_cent != 0:
            for param in criterion_cent.parameters():
                param.grad.data *= (1. / weight_cent)
        optimizer_centloss.step()

        loss_sum += loss.item()
        xent_loss_sum += loss_xent.item()
        cent_loss_sum += loss_cent.item()
        n_corrects += torch.sum(torch.eq(torch.argmax(outs, dim=1), y)).item()
        total += y.size(0)

        if (batch_idx+1) % 100 == 0:
            print("Batch {}/{}\t Loss {:.6f} XentLoss {:.6f} CenterLoss {:.6f}, spFr :{}" \
                  .format(batch_idx+1, len(train_loader), loss_sum /(batch_idx+1),
                        xent_loss_sum/(batch_idx+1),
                        cent_loss_sum/(batch_idx+1),
                        splice_frames_)
                 )

    acc = n_corrects / total if total > 0 else 0.0
    return loss_sum, acc

In [11]:
def val(config, val_loader, model, criterion):
    """Evaluate classification loss and accuracy on fixed-length windows.

    Each batch ``X`` is cut along dim 2 into consecutive, non-overlapping windows
    of ``config['splice_frames'][0]`` frames. A trailing remainder shorter than one
    window is dropped. Every window is scored independently against ``y``.

    NOTE(review): ``config['stride_frames']`` is not used; windows advance by the
    window length. Confirm whether overlapping windows were intended.

    Args:
        config: mapping with keys 'splice_frames' and 'no_cuda'.
        val_loader: iterable of ``(X, y)`` batches.
        model: module mapping a window to class scores.
        criterion: loss taking ``(scores, y)``.

    Returns:
        tuple: ``(loss_sum, avg_acc)``. ``loss_sum`` is the sum of per-window losses.
        ``avg_acc`` is the accuracy over all scored windows, or 0.0 if the input was
        too short for any window.
    """
    splice_frames = config['splice_frames'][0]
    n_corrects = 0
    total = 0
    loss_sum = 0.0

    model.eval()
    with torch.no_grad():
        for X, y in val_loader:
            if not config["no_cuda"]:
                X = X.cuda()
                y = y.cuda()

            for point in range(0, X.size(2) - splice_frames + 1, splice_frames):
                scores = model(X.narrow(2, point, splice_frames))
                loss_sum += criterion(scores, y).item()
                n_corrects += torch.sum(torch.eq(torch.argmax(scores, dim=1), y)).item()
                total += y.size(0)

    avg_acc = n_corrects / total if total > 0 else 0.0
    return loss_sum, avg_acc

In [12]:
# 512-wide x-vector TDNN; moved to the GPU unless CUDA is disabled in the config.
model = tdnn_xvector_v1(config, 512, n_labels=config['n_labels'])
model = model if config['no_cuda'] else model.cuda()

In [13]:
import torch

from train.train_utils import set_seed, find_optimizer

# Seed before the class centers are randomly initialised so they are reproducible.
# NOTE(review): the model in the previous cell is built before this seed is set.
set_seed(config)

# Follow the config rather than hardcoding CUDA on. The model cell uses the same flag.
is_cuda = not config['no_cuda']
weight_cent = 0.01
criterion_xent = nn.CrossEntropyLoss()
# The center-loss feature size must match the embedding width (tdnn6_affine output).
criterion_cent = CenterLoss(num_classes=config['n_labels'],
                            feat_dim=model.tdnn6_affine.out_features,
                            use_gpu=is_cuda)
optimizer_model = torch.optim.SGD(model.parameters(), lr=config['lrs'][0], weight_decay=5e-04, momentum=0.9)
optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=0.5)

# `criterion` is used for validation. `optimizer` is not stepped by train(), which
# uses optimizer_model.
criterion, optimizer = find_optimizer(config, model)

Note: the center-loss run with batch size 64 reached its lowest value, 13.9% (presumably EER), at epoch 25.

In [14]:
import bisect

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# NOTE(review): the step schedule below resets the lr at the start of every epoch,
# so reductions made by this scheduler do not persist. Keep only one mechanism.
plateau_scheduler = ReduceLROnPlateau(optimizer_model, 'min', factor=0.1, patience=10)
best_metric = 1.0

for epoch_idx in range(0, config['n_epochs']):
    print("-"*30)
    # Step schedule: use lrs[k], where k is the number of lr_schedule boundaries
    # already reached. A boundary epoch itself uses the new lr. Assumes lr_schedule
    # is sorted ascending.
    curr_lr = config['lrs'][bisect.bisect_right(config['lr_schedule'], epoch_idx)]
    # Apply it to the optimizer train() actually steps. Writing into
    # `optimizer.state_dict()` had no effect: state_dict() returns copies of the
    # param groups, and that optimizer is not the one being stepped.
    for group in optimizer_model.param_groups:
        group['lr'] = curr_lr
    print("curr_lr: {}".format(curr_lr))

    # train code
    train_loss, train_acc = train(model)
    print("epoch #{}, train accuracy: {}".format(epoch_idx, train_acc))

    # validation code
    val_loss, val_acc = val(config, val_loader, model, criterion)
    print("epoch #{}, val accuracy: {}".format(epoch_idx, val_acc))

    # evaluate best_metric (EER, lower is better)
    if not config['no_eer']:
        # NOTE(review): sv_test is not imported anywhere in this notebook.
        eer, label, score = sv_test(config, sv_loader, model, trial)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))
        # Checkpoint selection is EER-based, so it only runs when EER was computed.
        # Otherwise `eer` would be undefined.
        if eer < best_metric:
            best_metric = eer
            print("save best model")
            # Pass a path so torch.save opens and closes the file itself.
            torch.save(model.state_dict(), "checkpoint.pt")
    plateau_scheduler.step(train_loss)

------------------------------
curr_lr: 0.1
Batch 100/162	 Loss 5.445306 XentLoss 5.445306 CenterLoss 0.000000, spFr :100


Process Process-4:
Process Process-3:
Process Process-2:
Process Process-1:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda

KeyboardInterrupt: 

In [None]:
# Where this experiment's checkpoints are written.
output_dir = ("../../sv_system/models/"
              "gcommand_mfcc30/tdnn_xvector_center")

In [None]:
from train.train_utils import save_checkpoint

# Build the checkpoint path from output_dir instead of repeating the literal path.
os.makedirs(output_dir, exist_ok=True)
save_checkpoint({
    'epoch': epoch_idx,
    'step_no': (epoch_idx+1) * len(train_loader),
    'arch': config["arch"],
    'n_labels': config["n_labels"],
    'dataset': config["dataset"],
    'loss': config["loss"],
    'state_dict': model.state_dict(),
    # NOTE(review): this is the last epoch's validation accuracy. The training loop
    # tracks its best value (EER) in `best_metric`. Confirm which one consumers expect.
    'best_metric': val_acc,
    # Save the optimizer train() actually steps. `optimizer` from find_optimizer is
    # never stepped, so its state would be useless for resuming.
    'optimizer' : optimizer_model.state_dict(),
    }, epoch_idx, False, filename=os.path.join(output_dir, "model_best.pth.tar"))