# Mix-up training

paper: https://arxiv.org/abs/1710.09412  
code: https://github.com/facebookresearch/mixup-cifar10

## Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

sys.path.append('../')
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


### Configuration

In [17]:
from sv_system.utils.parser import set_train_config
import easydict
args = easydict.EasyDict(dict(dataset="voxc1_fbank_xvector",
                              input_frames=800, splice_frames=[300, 800], stride_frames=1, 
                              input_format='fbank',
                              cuda=True,
                              lrs=[0.1, 0.01], lr_schedule=[30], seed=1337,
                              no_eer=False,
                              batch_size=128,
                              arch="ResNet34_v4", loss="softmax",
                              n_epochs=10
                             ))
config = set_train_config(args)

### Dataset and Dataloader

In [18]:
from sv_system.data.data_utils import find_dataset, find_trial

_, datasets = find_dataset(config, basedir='../')
trial = find_trial(config, basedir='../')

In [19]:
from sv_system.data.dataloader import init_loaders

dataloaders = init_loaders(config, datasets)

### Define Model

In [20]:
from sv_system.model.model_utils import find_model
model = find_model(config)

### Load Model

In [30]:
import torch
saved_state_dict = torch.load("./softmax_model_per_epoch/model_9.pt")

In [31]:
import itertools
model_state = model.state_dict()
for k1, k2 in zip(saved_state_dict, model_state):
#     print(k1, k2)
    assert saved_state_dict[k1].shape == model_state[k2].shape
    model_state[k2] = saved_state_dict[k1]
    

model.load_state_dict(model_state)

## Train

In [32]:
from sv_system.train.train_utils import set_seed, find_optimizer

criterion, optimizer = find_optimizer(config, model)

In [33]:
set_seed(config)

In [34]:
if not config['no_eer']:
    train_loader, val_loader, test_loader, sv_loader = dataloaders
else:
    train_loader, val_loader, test_loader = dataloaders

In [35]:
import torch 

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [36]:
from tqdm import tqdm_notebook
from sv_system.train.train_utils import print_eval

def mixup_train(config, train_loader, model, optimizer, criterion):
    model.train()
    loss_sum = 0
    corrects = 0
    total = 0
    print_steps = (np.array([0.25, 0.5, 0.75, 1.0]) \
                    * len(train_loader)).astype(np.int64)

    splice_frames = config['splice_frames']
    if len(splice_frames) > 1:
        splice_frames_ = np.random.randint(splice_frames[0], splice_frames[1])
    else:
        splice_frames_ = splice_frames[-1]

    for batch_idx, (X, y) in tqdm_notebook(enumerate(train_loader), ncols=300,
            total=len(train_loader)):
        # X.shape is (batch, channel, time, bank)
        X = X.narrow(2, 0, splice_frames_)
        X, y_a, y_b, lam = mixup_data(x=X, y=y, alpha=0.4, use_cuda=False)
        if not config["no_cuda"]:
            X = X.cuda()
            y_a = y_a.cuda()
            y_b = y_b.cuda()
        optimizer.zero_grad()
        scores = model(X)
        loss = mixup_criterion(criterion, scores, y_a, y_b, lam)
        loss_sum += loss.item()
        loss.backward()
        # learning rate change
        optimizer.step()
        # schedule over iteration
        predicted = torch.argmax(scores, dim=1)
        corrects += (lam * predicted.eq(y_a).cpu().sum().float()
                    + (1 - lam) * predicted.eq(y_b).cpu().sum().float())
        total += y_a.size(0)
        if batch_idx in print_steps:
            print("train loss, acc: {:.4f}, {:.5f} ".format(corrects/total, loss_sum))
            
    return loss_sum, corrects/total

In [37]:
from sv_system.train.si_train import train, val, sv_test

for epoch_idx in range(0, config['n_epochs']):
    print("-"*30)
    curr_lr = optimizer.state_dict()['param_groups'][0]['lr']
    idx = 0
    while(epoch_idx >= config['lr_schedule'][idx]):
    # use new lr from schedule epoch not a next epoch
        idx += 1
        if idx == len(config['lr_schedule']):
            break
    curr_lr = config['lrs'][idx]
    optimizer.state_dict()['param_groups'][0]['lr'] = curr_lr
    print("curr_lr: {}".format(curr_lr))

#     train code
#     train_loss, train_acc = train(config, train_loader, model, optimizer, criterion, tqdm=tqdm_notebook)
    train_loss, train_acc = mixup_train(config, train_loader, model, optimizer, criterion)

#     validation code
    val_loss, val_acc = val(config, val_loader, model, criterion, tqdm=tqdm_notebook)
    print("epoch #{}, val accuracy: {}".format(epoch_idx, val_acc))

#     evaluate best_metric
    if not config['no_eer']:
        # eer validation code
        
        
        eer, label, score = sv_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))
    
#     torch.save(model.state_dict(), open("softmax_model_per_epoch/model_{}.pt".format(epoch_idx), "wb"))
    

------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.5590, 801.06744 
train loss, acc: 0.5795, 1532.83120 
train loss, acc: 0.5812, 2287.06029 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #0, val accuracy: 0.32829347252845764


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #0, sv eer: 0.11058460227877755
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6634, 648.76938 
train loss, acc: 0.6382, 1356.88433 
train loss, acc: 0.6434, 1996.20084 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #1, val accuracy: 0.39498400688171387


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #1, sv eer: 0.11345969545309338
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6579, 628.00448 
train loss, acc: 0.6542, 1262.84392 
train loss, acc: 0.6495, 1915.02051 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #2, val accuracy: 0.30077072978019714


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #2, sv eer: 0.11665424342455542
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6500, 631.72888 
train loss, acc: 0.6397, 1302.73551 
train loss, acc: 0.6368, 1966.59232 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #3, val accuracy: 0.4954725205898285


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #3, sv eer: 0.10435523373442658
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.5603, 781.10076 
train loss, acc: 0.5672, 1527.77924 
train loss, acc: 0.5761, 2237.13348 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #4, val accuracy: 0.4427644908428192


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #4, sv eer: 0.13773826003620487
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6761, 616.73998 
train loss, acc: 0.6699, 1218.53699 
train loss, acc: 0.6654, 1840.47168 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #5, val accuracy: 0.44853436946868896


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #5, sv eer: 0.112874028324992
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6458, 643.27675 
train loss, acc: 0.6357, 1293.45612 
train loss, acc: 0.6371, 1917.60055 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #6, val accuracy: 0.3433498740196228


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #6, sv eer: 0.11138323927164306
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6809, 580.70911 
train loss, acc: 0.6769, 1182.31042 
train loss, acc: 0.6801, 1766.09261 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #7, val accuracy: 0.5009475946426392


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #7, sv eer: 0.10547332552443829
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.7017, 570.00859 
train loss, acc: 0.6841, 1184.33853 
train loss, acc: 0.6817, 1769.03566 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #8, val accuracy: 0.49159789085388184


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #8, sv eer: 0.09610265147481631
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6860, 588.69394 
train loss, acc: 0.6972, 1121.78004 
train loss, acc: 0.6866, 1734.32402 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #9, val accuracy: 0.4933035969734192


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #9, sv eer: 0.1252262804813119


training acc는 얼마나 나오나?

In [39]:
train_acc

tensor(0.6866)