# Mix-up training

paper: https://arxiv.org/abs/1710.09412  
code: https://github.com/facebookresearch/mixup-cifar10

## Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

sys.path.append('../')
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
import torch 

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

### Configuration

In [3]:
from sv_system.utils.parser import set_train_config
import easydict
args = easydict.EasyDict(dict(dataset="voxc1_fbank_xvector",
                              input_frames=800, splice_frames=[300, 800], stride_frames=1, input_format='fbank',
                              cuda=True,
                              lrs=[0.1, 0.01], lr_schedule=[20], seed=1337,
                              no_eer=False,
                              batch_size=128,
                              arch="ResNet34", loss="softmax",
                              n_epochs=10
                             ))
config = set_train_config(args)

### Dataset and Dataloader

In [4]:
from sv_system.data.data_utils import find_dataset, find_trial

_, datasets = find_dataset(config, basedir='../')
trial = find_trial(config, basedir='../')

In [5]:
from sv_system.data.dataloader import init_loaders

dataloaders = init_loaders(config, datasets)

### Define Model

In [6]:
from sv_system.model.model_utils import find_model
model = find_model(config)

### Train

In [7]:
from sv_system.train.train_utils import set_seed, find_optimizer

criterion, optimizer = find_optimizer(config, model)

In [8]:
set_seed(config)

In [9]:
if not config['no_eer']:
    train_loader, val_loader, test_loader, sv_loader = dataloaders
else:
    train_loader, val_loader, test_loader = dataloaders

changes:  

alpha: 0.2

In [12]:
from tqdm import tqdm_notebook
from sv_system.train.train_utils import print_eval

def train(config, train_loader, model, optimizer, criterion):
    model.train()
    loss_sum = 0
    corrects = 0
    total = 0
    print_steps = (np.array([0.25, 0.5, 0.75, 1.0]) \
                    * len(train_loader)).astype(np.int64)

    splice_frames = config['splice_frames']
    if len(splice_frames) > 1:
        splice_frames_ = np.random.randint(splice_frames[0], splice_frames[1])
    else:
        splice_frames_ = splice_frames[-1]

    for batch_idx, (X, y) in tqdm_notebook(enumerate(train_loader), ncols=300,
            total=len(train_loader)):
        # X.shape is (batch, channel, time, bank)
        X = X.narrow(2, 0, splice_frames_)
        X, y_a, y_b, lam = mixup_data(x=X, y=y, alpha=0.2, use_cuda=False)
        if not config["no_cuda"]:
            X = X.cuda()
            y_a = y_a.cuda()
            y_b = y_b.cuda()
        optimizer.zero_grad()
        scores = model(X)
        loss = mixup_criterion(criterion, scores, y_a, y_b, lam)
        loss_sum += loss.item()
        loss.backward()
        # learning rate change
        optimizer.step()
        # schedule over iteration
        predicted = torch.argmax(scores, dim=1)
        corrects += (lam * predicted.eq(y_a).cpu().sum().float()
                    + (1 - lam) * predicted.eq(y_b).cpu().sum().float())
        total += y_a.size(0)
        if batch_idx in print_steps:
            print("train loss, acc: {:.4f}, {:.5f} ".format(corrects/total, loss_sum))
            
    return loss_sum, corrects/total

### lr: 0.1, alpha: 0.2

In [13]:
from sv_system.train.si_train import val, sv_test

for epoch_idx in range(0, config['n_epochs']):
    print("-"*30)
    curr_lr = optimizer.state_dict()['param_groups'][0]['lr']
    idx = 0
    while(epoch_idx >= config['lr_schedule'][idx]):
    # use new lr from schedule epoch not a next epoch
        idx += 1
        if idx == len(config['lr_schedule']):
            break
    curr_lr = config['lrs'][idx]
    optimizer.state_dict()['param_groups'][0]['lr'] = curr_lr
    print("curr_lr: {}".format(curr_lr))

#     train code
    train_loss, train_acc = train(config, train_loader, model, optimizer, criterion)

#     validation code
    val_loss, val_acc = val(config, val_loader, model, criterion, tqdm=tqdm_notebook)
    print("epoch #{}, val accuracy: {}".format(epoch_idx, val_acc))

#     evaluate best_metric
    if not config['no_eer']:
        # eer validation code
        eer, label, score = sv_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))

------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.0091, 1802.39016 
train loss, acc: 0.0151, 3501.87685 
train loss, acc: 0.0247, 5052.08444 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #0, val accuracy: 0.031755391508340836


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #0, sv eer: 0.1956128207858588
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.1068, 1318.34070 
train loss, acc: 0.1295, 2524.00823 
train loss, acc: 0.1533, 3644.98033 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #1, val accuracy: 0.08936994522809982


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #1, sv eer: 0.2426259184325418
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.3324, 925.16417 
train loss, acc: 0.3474, 1808.45402 
train loss, acc: 0.3588, 2680.57668 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #2, val accuracy: 0.18109838664531708


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #2, sv eer: 0.14758811628154617
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.4078, 851.79565 
train loss, acc: 0.4288, 1630.07540 
train loss, acc: 0.4410, 2383.95196 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #3, val accuracy: 0.20914757251739502


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #3, sv eer: 0.16276221914599084
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.5831, 635.37282 
train loss, acc: 0.5861, 1263.07215 
train loss, acc: 0.5844, 1907.54760 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #4, val accuracy: 0.21367503702640533


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #4, sv eer: 0.1470024491534448
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.5147, 703.79828 
train loss, acc: 0.5240, 1385.82251 
train loss, acc: 0.5340, 2027.90802 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #5, val accuracy: 0.3701145648956299


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #5, sv eer: 0.12905973804706633
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.5450, 680.69896 
train loss, acc: 0.5616, 1306.15452 
train loss, acc: 0.5628, 1963.73601 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #6, val accuracy: 0.3462980091571808


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #6, sv eer: 0.16105846022787776
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.7012, 502.10955 
train loss, acc: 0.7040, 993.11866 
train loss, acc: 0.6994, 1517.13050 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #7, val accuracy: 0.45476752519607544


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #7, sv eer: 0.11654775849217336
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6769, 512.26708 
train loss, acc: 0.6652, 1056.92476 
train loss, acc: 0.6641, 1613.22051 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #8, val accuracy: 0.3864765763282776


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #8, sv eer: 0.11292727079118305
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.7513, 454.59503 
train loss, acc: 0.7475, 905.06489 
train loss, acc: 0.7406, 1379.39114 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #9, val accuracy: 0.4226330816745758


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #9, sv eer: 0.11415184751357683
