# Mix-up training

paper: https://arxiv.org/abs/1710.09412  
code: https://github.com/facebookresearch/mixup-cifar10

## Environment

In [1]:
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline

import pandas as pd
import pickle
import numpy as np
import sys
import os

sys.path.append('../')
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="3"

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


### Mix-up algorithm

In [2]:
import torch 

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

### Configuration

In [30]:
from sv_system.utils.parser import set_train_config
import easydict
args = easydict.EasyDict(dict(dataset="voxc1_fbank_xvector",
                              input_frames=800, splice_frames=[300, 800], stride_frames=1, input_format='fbank',
                              cuda=True,
                              lrs=[0.001, 0.001], lr_schedule=[20], seed=1337,
                              no_eer=False,
                              batch_size=128,
                              arch="ResNet34", loss="softmax",
                              n_epochs=10
                             ))
config = set_train_config(args)

### Dataset and Dataloader

In [31]:
from sv_system.data.data_utils import find_dataset, find_trial

_, datasets = find_dataset(config, basedir='../')
trial = find_trial(config, basedir='../')

In [32]:
from sv_system.data.dataloader import init_loaders

dataloaders = init_loaders(config, datasets)

### Define Model

In [33]:
from sv_system.model.model_utils import find_model
model = find_model(config)

### Train

In [25]:
from sv_system.train.train_utils import set_seed, find_optimizer

criterion, optimizer = find_optimizer(config, model)

In [26]:
set_seed(config)

In [27]:
if not config['no_eer']:
    train_loader, val_loader, test_loader, sv_loader = dataloaders
else:
    train_loader, val_loader, test_loader = dataloaders

In [1]:
from tqdm import tqdm_notebook
from sv_system.train.train_utils import print_eval

def train(config, train_loader, model, optimizer, criterion):
    model.train()
    loss_sum = 0
    corrects = 0
    total = 0
    print_steps = (np.array([0.25, 0.5, 0.75, 1.0]) \
                    * len(train_loader)).astype(np.int64)

    splice_frames = config['splice_frames']
    if len(splice_frames) > 1:
        splice_frames_ = np.random.randint(splice_frames[0], splice_frames[1])
    else:
        splice_frames_ = splice_frames[-1]

    for batch_idx, (X, y) in tqdm_notebook(enumerate(train_loader), ncols=300,
            total=len(train_loader)):
        # X.shape is (batch, channel, time, bank)
        X = X.narrow(2, 0, splice_frames_)
        X, y_a, y_b, lam = mixup_data(x=X, y=y, alpha=0.4, use_cuda=False)
        if not config["no_cuda"]:
            X = X.cuda()
            y_a = y_a.cuda()
            y_b = y_b.cuda()
        optimizer.zero_grad()
        scores = model(X)
        loss = mixup_criterion(criterion, scores, y_a, y_b, lam)
        loss_sum += loss.item()
        loss.backward()
        # learning rate change
        optimizer.step()
        # schedule over iteration
        predicted = torch.argmax(scores, dim=1)
        corrects += (lam * predicted.eq(y_a).cpu().sum().float()
                    + (1 - lam) * predicted.eq(y_b).cpu().sum().float())
        total += y_a.size(0)
        if batch_idx in print_steps:
            print("train loss, acc: {:.4f}, {:.5f} ".format(corrects/total, loss_sum))
            
    return loss_sum, corrects/total

ModuleNotFoundError: No module named 'sv_system'

### lr 0.1

In [29]:
from sv_system.train.si_train import val, sv_test

for epoch_idx in range(0, config['n_epochs']):
    print("-"*30)
    curr_lr = optimizer.state_dict()['param_groups'][0]['lr']
    idx = 0
    while(epoch_idx >= config['lr_schedule'][idx]):
    # use new lr from schedule epoch not a next epoch
        idx += 1
        if idx == len(config['lr_schedule']):
            break
    curr_lr = config['lrs'][idx]
    optimizer.state_dict()['param_groups'][0]['lr'] = curr_lr
    print("curr_lr: {}".format(curr_lr))

#     train code
    train_loss, train_acc = train(config, train_loader, model, optimizer, criterion)

#     validation code
    val_loss, val_acc = val(config, val_loader, model, criterion, tqdm=tqdm_notebook)
    print("epoch #{}, val accuracy: {}".format(epoch_idx, val_acc))

#     evaluate best_metric
    if not config['no_eer']:
        # eer validation code
        eer, label, score = sv_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))

------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.0092, 1804.66924 
train loss, acc: 0.0143, 3512.80904 
train loss, acc: 0.0216, 5095.39023 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #0, val accuracy: 0.02242671698331833


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #0, sv eer: 0.2500266212330955
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.1148, 1335.55180 
train loss, acc: 0.1332, 2599.64403 
train loss, acc: 0.1579, 3750.68386 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #1, val accuracy: 0.07555592805147171


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #1, sv eer: 0.1752209562346928
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.2908, 1031.77048 
train loss, acc: 0.3074, 2020.66789 
train loss, acc: 0.3222, 2979.92552 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #2, val accuracy: 0.16753706336021423


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #2, sv eer: 0.12985837503993186
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.4113, 873.79143 
train loss, acc: 0.4152, 1754.43987 
train loss, acc: 0.4209, 2609.06752 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #3, val accuracy: 0.2362280935049057


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #3, sv eer: 0.1291662229794484
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.4043, 920.54852 
train loss, acc: 0.4175, 1799.99574 
train loss, acc: 0.4308, 2635.55550 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #4, val accuracy: 0.22464622557163239


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #4, sv eer: 0.18102438504951548
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.5522, 728.97903 
train loss, acc: 0.5547, 1428.80001 
train loss, acc: 0.5538, 2144.40922 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #5, val accuracy: 0.32538744807243347


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #5, sv eer: 0.11835800234266851
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.5586, 720.08795 
train loss, acc: 0.5512, 1441.56012 
train loss, acc: 0.5549, 2137.52586 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #6, val accuracy: 0.23999746143817902


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #6, sv eer: 0.16329464380790118
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6210, 634.91511 
train loss, acc: 0.6174, 1284.26613 
train loss, acc: 0.6202, 1915.47532 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #7, val accuracy: 0.3838232755661011


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #7, sv eer: 0.11883718453838782
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6537, 604.41302 
train loss, acc: 0.6376, 1249.55388 
train loss, acc: 0.6377, 1860.32717 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #8, val accuracy: 0.4802686870098114


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #8, sv eer: 0.11542966670216165
------------------------------
curr_lr: 0.1


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.6478, 613.16104 
train loss, acc: 0.6618, 1167.65365 
train loss, acc: 0.6537, 1793.08263 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #9, val accuracy: 0.377147912979126


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #9, sv eer: 0.13560856138856353


### lr 0.001

In [34]:
from sv_system.train.si_train import val, sv_test

for epoch_idx in range(0, config['n_epochs']):
    print("-"*30)
    curr_lr = optimizer.state_dict()['param_groups'][0]['lr']
    idx = 0
    while(epoch_idx >= config['lr_schedule'][idx]):
    # use new lr from schedule epoch not a next epoch
        idx += 1
        if idx == len(config['lr_schedule']):
            break
    curr_lr = config['lrs'][idx]
    optimizer.state_dict()['param_groups'][0]['lr'] = curr_lr
    print("curr_lr: {}".format(curr_lr))

#     train code
    train_loss, train_acc = train(config, train_loader, model, optimizer, criterion)

#     validation code
    val_loss, val_acc = val(config, val_loader, model, criterion, tqdm=tqdm_notebook)
    print("epoch #{}, val accuracy: {}".format(epoch_idx, val_acc))

#     evaluate best_metric
    if not config['no_eer']:
        # eer validation code
        eer, label, score = sv_test(config, sv_loader, model, trial, tqdm=tqdm_notebook)
        print("epoch #{}, sv eer: {}".format(epoch_idx, eer))

------------------------------
curr_lr: 0.001


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.0014, 1911.63920 
train loss, acc: 0.0015, 3823.58453 
train loss, acc: 0.0015, 5730.24357 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #0, val accuracy: 0.000884433975443244


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #0, sv eer: 0.4569268448514535
------------------------------
curr_lr: 0.001


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.0012, 1913.27459 
train loss, acc: 0.0014, 3824.42908 
train loss, acc: 0.0014, 5730.80413 



HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=53), HTML(value='')), layout=Layout(display='…


epoch #1, val accuracy: 0.000884433975443244


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…


epoch #1, sv eer: 0.45314662975189013
------------------------------
curr_lr: 0.001


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1042), HTML(value='')), layout=Layout(display…

train loss, acc: 0.0013, 1914.19779 


Process Process-1096:
Process Process-1102:
Process Process-1103:
Process Process-1094:
Process Process-1101:
Process Process-1099:
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-1104:
Traceback (most recent call last):
Process Process-1090:
Traceback (most recent call last):
Process Process-1097:
Traceback (most recent call last):
Process Process-1092:
Process Process-1091:
Process Process-1093:
Traceback (most recent call last):
Process Process-1100:
Process Process-1098:
Process Process-1095:
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Process Process-1089:
Tracebac

  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 57, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 52, in _w

KeyboardInterrupt: 

KeyboardInterrupt
KeyboardInterrupt
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "../sv_system/data/dataset.py", line 252, in __getitem__
    return self.preprocess(os.path.join(self.data_folder, self.files[index])), self.labels[index]
KeyboardInterrupt
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt
  File "../sv_system/data/dataset.py", line 230, in preprocess
  