# center_loss training

paper:  http://ydwen.github.io/papers/WenECCV16.pdf  
code: https://github.com/KaiyangZhou/pytorch-center-loss

## Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
sys.path.append('../')
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

### Center-loss algorithm

In [3]:
import torch                                                                                                                                                                                                                                                                 
import torch.nn as nn                                                                                                                                                                                                                                                        

class CenterLoss(nn.Module):                                                                                                                                                                                                                                                 
    """Center loss.                                                                                                                                                                                                                                                          

    Reference:                                                                                                                                                                                                                                                               
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.                                                                                                                                                                              

    Args:                                                                                                                                                                                                                                                                    
     num_classes (int): number of classes.                                                                                                                                                                                                                                
     feat_dim (int): feature dimension.                                                                                                                                                                                                                                   
    """                                                                                                                                                                                                                                                                      
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):                                                                                                                                                                                                            
        super(CenterLoss, self).__init__()                                                                                                                                                                                                                                   
        self.num_classes = num_classes                                                                                                                                                                                                                                       
        self.feat_dim = feat_dim                                                                                                                                                                                                                                             
        self.use_gpu = use_gpu                                                                                                                                                                                                                                               

        if self.use_gpu:                                                                                                                                                                                                                                                     
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())                                                                                                                                                                                 
        else:                                                                                                                                                                                                                                                                
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))                                                                                                                                                                                        

    def forward(self, x, labels):                                                                                                                                                                                                                                            
        """                                                                                                                                                                                                                                                                  
        Args:                                                                                                                                                                                                                                                                
         x: feature matrix with shape (batch_size, feat_dim).                                                                                                                                                                                                             
         labels: ground truth labels with shape (num_classes).                                                                                                                                                                                                            
        """                                                                                                                                                                                                                                                                  
        batch_size = x.size(0)                                                                                                                                                                                                                                               
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
            torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()                                                                                                                                                               
        distmat.addmm_(1, -2, x, self.centers.t())                                                                                                                                                                                                                           

        classes = torch.arange(self.num_classes).long()                                                                                                                                                                                                                      
        if self.use_gpu: classes = classes.cuda()                                                                                                                                                                                                                            
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)                                                                                                                                                                                                    
        mask = labels.eq(classes.expand(batch_size, self.num_classes))                                                                                                                                                                                                       

        dist = []                                                                                                                                                                                                                                                            
        for i in range(batch_size):                                                                                                                                                                                                                                          
            value = distmat[i][mask[i]]                                                                                                                                                                                                                                      
            value = value.clamp(min=1e-12, max=1e+12) # for numerical stability                                                                                                                                                                                              
            dist.append(value)                                                                                                                                                                                                                                               
        dist = torch.cat(dist)                                                                                                                                                                                                                                               
        loss = dist.mean()                                                                                                                                                                                                                                                   

        return loss

### Configuration

In [4]:
from sv_system.utils.parser import set_train_config
import easydict
args = easydict.EasyDict(dict(dataset="voxc1_fbank_xvector",
                              input_frames=800, splice_frames=[300, 800], stride_frames=1, input_format='fbank',
                              cuda=True,
                              lrs=[0.1, 0.01], lr_schedule=[20], seed=1337,
                              no_eer=False,
                              batch_size=64,
                              arch="ResNet34_v4", loss="softmax",
                              n_epochs=50
                             ))
config = set_train_config(args)

### Dataset and Dataloader

In [5]:
from sv_system.data.data_utils import find_dataset, find_trial

_, datasets = find_dataset(config, basedir='../')
trial = find_trial(config, basedir='../')

In [6]:
from sv_system.data.dataloader import init_loaders

dataloaders = init_loaders(config, datasets)

### Define Model

In [7]:
from sv_system.model.model_utils import find_model
model = find_model(config)

In [8]:
saved_model = torch.load("../best_models/voxc1/ResNet34_v4_softmax/ResNet34_v4_softmax_best.pth.tar")

import itertools
model_state = model.state_dict()
for k1, k2 in zip(saved_model['state_dict'], model_state):
#     print(k1, k2)
    assert saved_model['state_dict'][k1].shape == model_state[k2].shape
    model_state[k2] = saved_model['state_dict'][k1]
    

model.load_state_dict(model_state)

In [9]:
if not config['no_cuda']:
    model = model.cuda()

### Model Train

In [10]:
import torch

is_cuda = True
weight_cent = 0.01
criterion_xent = nn.CrossEntropyLoss()
criterion_cent = CenterLoss(num_classes=config['n_labels'], feat_dim=128, use_gpu=is_cuda)
optimizer_model = torch.optim.SGD(model.parameters(), lr=config['lrs'][0], weight_decay=5e-04, momentum=0.9)
optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=0.5)

In [11]:
from sv_system.train.train_utils import set_seed, find_optimizer

criterion, optimizer = find_optimizer(config, model)

In [12]:
set_seed(config)

In [13]:
if not config['no_eer']:
    train_loader, val_loader, test_loader, sv_loader = dataloaders
else:
    train_loader, val_loader, test_loader = dataloaders

In [14]:
def train(model):
    model.train()
    loss_sum = 0
    xent_loss_sum = 0
    cent_loss_sum = 0
    n_corrects = 0
    total = 0
    for batch_idx, (X, y) in enumerate(train_loader):
        if is_cuda:
            X = X.cuda()
            y = y.cuda()

        feats, outs  = model.feat_output(X)
        loss_xent = criterion_xent(outs, y)
        loss_cent = criterion_cent(feats, y)
        loss_cent *= weight_cent
        loss = loss_xent + loss_cent

#         optimizer_model.zero_grad()
#         optimizer_centloss.zero_grad()
#         loss.backward()
#         optimizer_model.step()
#         for param in criterion_cent.parameters():                                                                                 
#             param.grad.data *= (1. / weight_cent)                                                                            
#         optimizer_centloss.step()                                                                                                 
                        
        loss_sum += loss.item()
        xent_loss_sum += loss_xent.item()
        cent_loss_sum += loss_cent.item()
        n_corrects += torch.sum(torch.eq(torch.argmax(outs, dim=1), y)).item()
        total += y.size(0)
        
        if (batch_idx+1) % 100 == 0:
            print("Batch {}/{}\t Loss {:.6f} XentLoss {:.6f} CenterLoss {:.6f}" \
                  .format(batch_idx+1, len(train_loader), loss_sum /(batch_idx+1), 
                        xent_loss_sum/(batch_idx+1), 
                        cent_loss_sum/(batch_idx+1))
                 )
        acc = n_corrects / total

    acc = n_corrects / total
    return loss_sum, acc 

In [15]:
from sv_system.train.si_train import val, sv_test, sv_euc_test
from tqdm import tqdm_notebook

for epoch_idx in range(0, 1):
    print("-"*30)
    curr_lr = optimizer.state_dict()['param_groups'][0]['lr']
    idx = 0
    while(epoch_idx >= config['lr_schedule'][idx]):
    # use new lr from schedule epoch not a next epoch
        idx += 1
        if idx == len(config['lr_schedule']):
            break
    curr_lr = config['lrs'][idx]
    optimizer.state_dict()['param_groups'][0]['lr'] = curr_lr
    print("curr_lr: {}".format(curr_lr))

#     train code
    train_loss, train_acc = train(model)
    print("epoch #{}, train accuracy: {}".format(epoch_idx, train_acc))

#     validation code
    val_loss, val_acc = val(config, val_loader, model, criterion, tqdm=tqdm_notebook)
    print("epoch #{}, val accuracy: {}".format(epoch_idx, val_acc))

#     evaluate best_metric
    if not config['no_eer']:
        # eer validation code
#         eer, label, score = sv_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)
        eer, label, score = sv_euc_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))

------------------------------
curr_lr: 0.1
Batch 100/2083	 Loss 18.785474 XentLoss 0.034589 CenterLoss 18.750886
Batch 200/2083	 Loss 18.736295 XentLoss 0.033912 CenterLoss 18.702383
Batch 300/2083	 Loss 18.736580 XentLoss 0.034453 CenterLoss 18.702127
Batch 400/2083	 Loss 18.754517 XentLoss 0.034049 CenterLoss 18.720467
Batch 500/2083	 Loss 18.753808 XentLoss 0.034063 CenterLoss 18.719745
Batch 600/2083	 Loss 18.766041 XentLoss 0.034014 CenterLoss 18.732027
Batch 700/2083	 Loss 18.765635 XentLoss 0.034057 CenterLoss 18.731578
Batch 800/2083	 Loss 18.756245 XentLoss 0.034337 CenterLoss 18.721908
Batch 900/2083	 Loss 18.760749 XentLoss 0.033946 CenterLoss 18.726803
Batch 1000/2083	 Loss 18.761072 XentLoss 0.033973 CenterLoss 18.727099
Batch 1100/2083	 Loss 18.761380 XentLoss 0.033785 CenterLoss 18.727595
Batch 1200/2083	 Loss 18.766533 XentLoss 0.033868 CenterLoss 18.732665
Batch 1300/2083	 Loss 18.768227 XentLoss 0.033738 CenterLoss 18.734489
Batch 1400/2083	 Loss 18.766510 XentLoss 0

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=105), HTML(value='')), layout=Layout(display=…


epoch #0, val accuracy: 0.8543792963027954


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=76), HTML(value='')), layout=Layout(display='…


epoch #0, sv eer: 0.20535619209881803


In [18]:
eer, label, score = sv_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)
eer

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=76), HTML(value='')), layout=Layout(display='…




0.05675646895964221

In [20]:
eer, label, score = sv_euc_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=76), HTML(value='')), layout=Layout(display='…




In [21]:
eer

0.05675646895964221