# center_loss training

paper:  http://ydwen.github.io/papers/WenECCV16.pdf  
code: https://github.com/KaiyangZhou/pytorch-center-loss

## Environment

In [11]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [12]:
# Make the parent directory importable so project packages can be imported from this notebook.
sys.path.append('../')

# Pin GPU enumeration order and restrict the process to one device.
# These must be set before CUDA is initialised to take effect.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "2",
})

### Center-loss algorithm

In [13]:
import torch                                                                                                                                                                                                                                                                 
import torch.nn as nn                                                                                                                                                                                                                                                        

class CenterLoss(nn.Module):
    """Center loss.

    Keeps one learnable center per class and penalises the squared Euclidean
    distance between each feature vector and the center of its own class.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
     num_classes (int): number of classes.
     feat_dim (int): feature dimension.
     use_gpu (bool): if True, the centers are created on the current CUDA device.
    """
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
         x: feature matrix with shape (batch_size, feat_dim).
         labels: ground truth labels with shape (batch_size).

        Returns:
         Scalar tensor: mean over the batch of the (clamped) squared distance
         between each sample and the center of its labelled class.
        """
        batch_size = x.size(0)

        # Squared Euclidean distances via ||x||^2 + ||c||^2 - 2 * x.c^T,
        # giving a (batch_size, num_classes) matrix.
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
            torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # Keyword form of beta/alpha: the positional (beta, alpha, mat1, mat2)
        # signature is rejected by current PyTorch releases.
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        # Boolean mask with True at (i, labels[i]); the class indices are built
        # on the same device as the centers instead of relying on use_gpu.
        classes = torch.arange(self.num_classes, dtype=torch.long, device=self.centers.device)
        mask = labels.unsqueeze(1).expand(batch_size, self.num_classes).eq(
            classes.expand(batch_size, self.num_classes))

        # Row-major boolean selection yields the same values, in the same
        # order, as picking distmat[i][mask[i]] sample by sample.
        dist = distmat[mask].clamp(min=1e-12, max=1e+12)  # for numerical stability
        loss = dist.mean()

        return loss

### Configuration

In [58]:
from sv_system.utils.parser import set_train_config
import easydict
# Training options for this experiment, wrapped in an EasyDict so they can be
# read with attribute access, then handed to the project's config builder.
args = easydict.EasyDict({
    "dataset": "gcommand_equal30_wav",
    "input_frames": 101,
    "splice_frames": [50, 101],
    "stride_frames": 1,
    "input_format": 'fbank',
    "cuda": True,
    "lrs": [0.01, 0.01],
    "lr_schedule": [20],
    "seed": 1337,
    "no_eer": False,
    "batch_size": 128,
    "arch": "ResNet34",
    "loss": "softmax",
    "n_epochs": 50,
})
config = set_train_config(args)

### Dataset and Dataloader

In [47]:
from sv_system.data.data_utils import find_dataset, find_trial

# Look up the datasets and the trial list for this config, resolving paths
# relative to the parent directory. find_dataset returns a pair; only the
# second element is kept here. `trial` is passed to the verification test in
# the epoch loop.
_, datasets = find_dataset(config, basedir='../')
trial = find_trial(config, basedir='../')

In [48]:
from sv_system.data.dataloader import init_loaders

# Build the data loaders from the datasets. The result is unpacked further down
# into three loaders (train/val/test), plus a fourth verification loader when
# config['no_eer'] is false.
dataloaders = init_loaders(config, datasets)

### Define Model

In [49]:
# from sv_system.model.model_utils import find_model
# model = find_model(config)

In [50]:
from sv_system.model.ResNet34 import ResNet34

class ResNet34_v1(ResNet34):
    """ResNet34 variant whose classifier has one extra hidden layer.

    The classifier is Linear -> ReLU -> output layer, where the output layer
    type is chosen from config["loss"] ("softmax" or "angular").

    Args:
        config: mapping that must provide the "loss" key.
        inplanes (int): base channel width; the classifier input size is 8 * inplanes.
        n_labels (int): number of output classes.
        fc_dims (int, optional): width of the hidden layer; falls back to
            8 * inplanes when not given (or falsy).
    """

    def __init__(self, config, inplanes=16, n_labels=1000, fc_dims=None):
        super().__init__(config, inplanes, n_labels)

        embed_dim = 8*inplanes
        hidden_dim = fc_dims if fc_dims else embed_dim

        # The hidden layer is created before the output layer so parameter
        # initialisation happens in the same order as before.
        layers = [nn.Linear(embed_dim, hidden_dim), nn.ReLU(inplace=True)]

        loss_type = config["loss"]
        if loss_type not in ("angular", "softmax"):
            print("not implemented loss")
            raise NotImplementedError

        if loss_type == "angular":
            # NOTE(review): AngleLinear is not defined or imported in the
            # visible part of this notebook - confirm it is in scope.
            output_layer = AngleLinear(hidden_dim, n_labels)
        else:
            output_layer = nn.Linear(hidden_dim, n_labels)
        layers.append(output_layer)

        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier output for input x."""
        return self.classifier(self.embed(x))

    def feat_output(self, x):
        """Return (feat, out): the result of self.embed(x) and the classifier output computed from it."""
        embedding = self.embed(x)
        return embedding, self.classifier(embedding)

In [51]:
# Instantiate the network, then move it to the GPU unless config['no_cuda'] is set.
model = ResNet34_v1(config, n_labels=config['n_labels'])
model = model if config['no_cuda'] else model.cuda()

### Model Train

In [52]:
import torch

# Follow the same flag that decided where the model lives, so batches, the
# model and the class centers always end up on the same device.
is_cuda = not config['no_cuda']

# Weight of the center loss relative to the cross-entropy term.
weight_cent = 0.01

# The center loss is applied to the features returned by model.feat_output,
# i.e. the input of the classifier's first Linear layer; read the size from
# that layer instead of hard-coding it (128 for the default inplanes=16).
feat_dim = model.classifier[0].in_features

criterion_xent = nn.CrossEntropyLoss()
criterion_cent = CenterLoss(num_classes=config['n_labels'], feat_dim=feat_dim, use_gpu=is_cuda)

# Separate optimizers: one for the network weights, one for the class centers.
optimizer_model = torch.optim.SGD(model.parameters(), lr=config['lrs'][0], weight_decay=5e-04, momentum=0.9)
optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=0.5)

In [53]:
from sv_system.train.train_utils import set_seed, find_optimizer

# Project-default criterion/optimizer pair for this config and model.
# `criterion` is passed to val() in the epoch loop. Note that train() steps
# optimizer_model and optimizer_centloss, not this `optimizer`.
criterion, optimizer = find_optimizer(config, model)

In [54]:
# Seed the random number generators from the config (presumably config['seed'] - confirm in set_seed).
# NOTE(review): this cell comes after the model and the class centers are
# created, so their initial values are not covered by the seed - consider
# seeding before the model is built.
set_seed(config)

In [55]:
# When EER evaluation is disabled only three loaders are expected; otherwise a
# fourth loader for the verification trials is unpacked as well.
if config['no_eer']:
    train_loader, val_loader, test_loader = dataloaders
else:
    train_loader, val_loader, test_loader, sv_loader = dataloaders

In [56]:
def train(model):
    """Run one training epoch with cross-entropy plus weighted center loss.

    Relies on notebook-level names: train_loader, is_cuda, weight_cent,
    criterion_xent, criterion_cent, optimizer_model and optimizer_centloss.

    Args:
        model: network exposing feat_output(X) -> (features, logits).

    Returns:
        (loss_sum, acc): the sum over batches of the combined loss value (not
        averaged), and the fraction of correctly classified training samples
        (0.0 when the loader yields no samples).
    """
    model.train()
    loss_sum = 0
    xent_loss_sum = 0
    cent_loss_sum = 0
    n_corrects = 0
    total = 0
    for batch_idx, (X, y) in enumerate(train_loader):
        if is_cuda:
            X = X.cuda()
            y = y.cuda()

        feats, outs = model.feat_output(X)
        loss_xent = criterion_xent(outs, y)
        # Out-of-place scaling: avoids mutating the loss tensor in place.
        loss_cent = weight_cent * criterion_cent(feats, y)
        loss = loss_xent + loss_cent

        optimizer_model.zero_grad()
        optimizer_centloss.zero_grad()
        loss.backward()
        optimizer_model.step()
        # Undo the loss weight on the center gradients so that weight_cent
        # does not change the effective step size of the center update.
        for param in criterion_cent.parameters():
            param.grad.data *= (1. / weight_cent)
        optimizer_centloss.step()

        loss_sum += loss.item()
        xent_loss_sum += loss_xent.item()
        cent_loss_sum += loss_cent.item()
        n_corrects += torch.sum(torch.eq(torch.argmax(outs, dim=1), y)).item()
        total += y.size(0)

        if (batch_idx+1) % 100 == 0:
            # Running averages over the batches seen so far in this epoch.
            print("Batch {}/{}\t Loss {:.6f} XentLoss {:.6f} CenterLoss {:.6f}" \
                  .format(batch_idx+1, len(train_loader), loss_sum /(batch_idx+1),
                        xent_loss_sum/(batch_idx+1),
                        cent_loss_sum/(batch_idx+1))
                 )

    # Accuracy is computed once, after the epoch; guard against an empty loader.
    acc = n_corrects / total if total else 0.0
    return loss_sum, acc

In [None]:
# val and sv_euc_test are called below, so they have to be imported
# (sv_test is the alternative verification scorer from the same module).
from sv_system.train.si_train import val, sv_euc_test
from tqdm import tqdm_notebook

for epoch_idx in range(0, config['n_epochs']):
    print("-"*30)

    # Pick the learning rate for this epoch: idx counts how many schedule
    # boundaries have been reached, so the new rate applies from the boundary
    # epoch itself. An empty schedule keeps lrs[0].
    idx = 0
    while idx < len(config['lr_schedule']) and epoch_idx >= config['lr_schedule'][idx]:
        idx += 1
    curr_lr = config['lrs'][idx]

    # Write the rate into the live param_groups of the optimizer that train()
    # actually steps. Assigning into optimizer.state_dict() has no effect,
    # because state_dict() returns a copy of the param groups.
    for param_group in optimizer_model.param_groups:
        param_group['lr'] = curr_lr
    print("curr_lr: {}".format(curr_lr))

#     train code
    train_loss, train_acc = train(model)
    print("epoch #{}, train accuracy: {}".format(epoch_idx, train_acc))

#     validation code
    val_loss, val_acc = val(config, val_loader, model, criterion, tqdm=tqdm_notebook)
    print("epoch #{}, val accuracy: {}".format(epoch_idx, val_acc))

#     evaluate best_metric
    if not config['no_eer']:
        # eer validation code
#         eer, label, score = sv_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)
        eer, label, score = sv_euc_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))
        
#     torch.save(model.state_dict(), open("softmax_model_per_epoch/model_{}.pt".format(epoch_idx), "wb"))