In [5]:
import os 
import math
import shutil
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [6]:
def get_dataset(img_size, dataset, batch_size=None):
    """Build train/test DataLoaders for one of the supported torchvision datasets.

    Args:
        img_size: Side length every image is resized to (square).
        dataset: One of 'mnist', 'cifar10' or 'svhn'.
        batch_size: Batch size for both loaders. When omitted, the
            notebook-level global ``batch_size`` is used (the original
            behaviour of this function).

    Returns:
        (dataloader_train, dataloader_test, in_ch) where ``in_ch`` is the
        number of image channels (1 for MNIST, 3 otherwise).

    Raises:
        ValueError: If ``dataset`` is not a supported name.
        NameError: If ``batch_size`` is omitted and no global is defined.
    """
    channels = {'mnist': 1, 'cifar10': 3, 'svhn': 3}
    if dataset not in channels:
        raise ValueError(
            "unsupported dataset %r; expected one of %s" % (dataset, sorted(channels)))
    in_ch = channels[dataset]

    if batch_size is None:
        # Backward compatibility: earlier versions read the notebook global.
        try:
            batch_size = globals()['batch_size']
        except KeyError:
            raise NameError("name 'batch_size' is not defined") from None

    steps = [transforms.Resize((img_size, img_size)), transforms.ToTensor()]
    if in_ch == 3:
        # Colour datasets are additionally mapped from [0, 1] to [-1, 1].
        steps.append(transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]))
    transform = transforms.Compose(steps)

    dataset_cls = getattr(datasets, dataset.upper())
    if dataset == 'svhn':
        # SVHN selects its partition with `split=` instead of `train=`.
        train_kwargs, test_kwargs = {'split': 'train'}, {'split': 'test'}
    else:
        # The test loader must use the held-out partition (train=False).
        train_kwargs, test_kwargs = {'train': True}, {'train': False}

    train_data = dataset_cls(root='./data', transform=transform, download=True, **train_kwargs)
    test_data = dataset_cls(root='./data', transform=transform, download=True, **test_kwargs)

    dataloader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    dataloader_test = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return dataloader_train, dataloader_test, in_ch

In [7]:
class CNN(nn.Module):
    """Small VGG-style classifier: conv/BN/ReLU/max-pool stages plus an MLP head.

    Args:
        in_ch: Number of input image channels.
        n_cls: Number of output classes (size of the logit vector).
        img_size: Side length of the (square) input images. The network
            stacks ``int(log2(img_size)) - 1`` stages, each halving the
            spatial size with a 2x2 max-pool.
    """
    def __init__(self, in_ch=3, n_cls=10, img_size=32):
        super(CNN, self).__init__()
        n_convs = int(np.log2(img_size)) - 1

        layers = []
        out_ch = 16
        feat_size = img_size
        for _ in range(n_convs):
            layers.append(nn.Conv2d(in_ch, out_ch, 3, 1, 1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.MaxPool2d(2, 2))
            in_ch = out_ch
            out_ch = out_ch * 2
            # MaxPool2d(2, 2) floors odd sizes, so track the real output size.
            feat_size = feat_size // 2

        self.hidden_layers = nn.Sequential(*layers)
        # Equals 2 * 2 * in_ch for power-of-two sizes; also correct for
        # sizes such as 28 where the final feature map is not 2x2.
        fc_in = feat_size * feat_size * in_ch
        self.classifier = nn.Sequential(
            nn.Linear(fc_in, 512),
            nn.ReLU(inplace=True),
            # Plain Dropout: the input here is flat (N, C); Dropout2d is
            # meant for feature maps and is deprecated for 2-D input.
            nn.Dropout(p=0.5),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, n_cls))

    def forward(self, x):
        """Return class logits of shape (batch, n_cls) for an image batch ``x``."""
        h = self.hidden_layers(x)
        h_flat = h.flatten(start_dim=1)
        out = self.classifier(h_flat)
        return out

In [10]:
class CustomLossFunction:
    """Cross-entropy between logits and soft (probability-vector) targets.

    Args:
        reduction: 'mean' (sum divided by batch size), 'sum', or 'none'
            (one loss value per sample).
    """
    def __init__(self, reduction='mean'):
        self.reduction = reduction

    def xent(self, x, t):
        """Compute the soft-target cross-entropy.

        Args:
            x: Logits of shape (batch, classes).
            t: Target weights with the same shape as ``x``.

        Returns:
            A scalar tensor for 'mean'/'sum', or a tensor of shape (batch,)
            for 'none'.

        Raises:
            ValueError: If ``self.reduction`` is not a supported mode.
        """
        b, _ = x.shape
        weighted_log_prob = t * torch.log_softmax(x, dim=1)
        if self.reduction == 'mean':
            return -torch.sum(weighted_log_prob) / b
        if self.reduction == 'sum':
            return -torch.sum(weighted_log_prob)
        if self.reduction == 'none':
            # Reduce over the class axis only so each sample keeps its own loss.
            return -torch.sum(weighted_log_prob, dim=1)
        raise ValueError(
            "reduction must be 'mean', 'sum' or 'none', got %r" % (self.reduction,))

In [11]:
# --- Experiment configuration -------------------------------------------
img_size = 32
dataset = 'cifar10'
lr = 0.1
momentum = 0.9
epochs = 100
batch_size = 100
seed = 0
# Use the first GPU when present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Mixup draws from both NumPy (mixing coefficient) and torch (permutation),
# so seed both generators for repeatable runs.
np.random.seed(seed)
torch.manual_seed(seed)

# --- Data, model, optimiser and losses ----------------------------------
dataloader_train, dataloader_test, in_ch = get_dataset(img_size, dataset)
model = CNN(in_ch=in_ch, n_cls=10, img_size=img_size).to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
cstm_lossfunc = CustomLossFunction()  # soft-target loss used for training
xent = nn.CrossEntropyLoss()          # hard-label loss used for validation
iters = 0

# `scheduler` holds the milestone epochs (50% and 75% of training) at which
# MultiStepLR multiplies the learning rate by gamma=0.1.
scheduler = [int(epochs*0.5), int(epochs*0.75)]
adjust_learning_rate = lr_scheduler.MultiStepLR(optimizer, scheduler, gamma=0.1)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
class AverageMeter(object):
    """Keeps the most recent value of a metric and its running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, mean, weighted sum and sample count back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        # Mean over everything seen since the last reset.
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Return precision@k, in percent, for each k listed in ``topk``.

    ``output`` holds per-class scores with one row per sample and ``target``
    holds the true class index of each sample. The result is a list with one
    scalar tensor per requested k.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Indices of the best `largest_k` classes per sample, transposed so that
    # row r contains every sample's rank-r prediction.
    top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    truth = target.contiguous().view(1, -1).expand_as(top_idx)
    hits = top_idx.eq(truth)

    scale = 100.0 / n_samples
    # A sample counts for k when its label appears among the first k ranks.
    return [hits[:k].contiguous().view(-1).float().sum(0).mul_(scale) for k in topk]

In [13]:
def train(epoch, model, dataloader, optimizer, cstm_lossfunc, iters):
    """Run one epoch of mixup training and return the updated iteration count.

    Each batch is blended with a shuffled copy of itself using a coefficient
    drawn from Beta(1, 1); the one-hot targets are blended the same way and
    the soft-target loss ``cstm_lossfunc.xent`` is minimised.

    Args:
        epoch: Epoch index, used only in the progress message.
        model: Network to train; batches are moved to its parameters' device.
        dataloader: Iterable yielding (images, integer labels) batches.
        optimizer: Optimiser stepped once per batch.
        cstm_lossfunc: Object exposing ``xent(logits, soft_targets)``.
        iters: Number of optimisation steps taken before this call.

    Returns:
        ``iters`` plus the number of batches processed.
    """
    model.train()
    # Follow the model's placement instead of relying on a notebook global.
    run_device = next(model.parameters()).device
    top1 = AverageMeter()
    losses = AverageMeter()
    for idx, (data, tgt) in enumerate(dataloader):
        data, tgt = data.to(run_device), tgt.to(run_device)
        b = data.size(0)
        gamma = float(np.random.beta(1, 1))

        rand_idx = torch.randperm(b).to(run_device)
        mixed_data = gamma * data + (1 - gamma) * data[rand_idx]

        optimizer.zero_grad()
        logits = model(mixed_data)

        # Build the one-hot targets on the logits' device with the class
        # count taken from the model output, not a hard-coded 10.
        onehot = F.one_hot(tgt, num_classes=logits.size(1)).to(logits.dtype)
        mixed_tgt = gamma * onehot + (1 - gamma) * onehot[rand_idx]

        loss = cstm_lossfunc.xent(logits, mixed_tgt)
        loss.backward()
        optimizer.step()

        iters += 1
        # Update the meters on every batch so the logged numbers are true
        # running averages. Accuracy is scored against the un-mixed labels
        # of the primary samples, so it is only indicative under mixup.
        prec1 = accuracy(logits.detach(), tgt, topk=(1,))[0]
        losses.update(loss.item(), b)
        top1.update(prec1.item(), b)
        if idx % 100 == 0:
            print('%d epochs [%d/%d]| loss: %.4f | acc: %.4f |' % (epoch, idx, len(dataloader), losses.avg, top1.avg))
    return iters

In [14]:
def validation(epoch, model, dataloader, xent):
    """Evaluate ``model`` on ``dataloader`` and return the mean top-1 accuracy.

    Args:
        epoch: Epoch index, used only in the summary message.
        model: Network to evaluate; batches are moved to its parameters' device.
        dataloader: Iterable yielding (images, integer labels) batches.
        xent: Callable ``xent(logits, labels)`` returning a scalar loss.

    Returns:
        The sample-weighted average top-1 accuracy in percent
        (``top1.avg`` of the accuracy meter).
    """
    model.eval()
    # Follow the model's placement instead of relying on a notebook global.
    run_device = next(model.parameters()).device
    top1 = AverageMeter()
    losses = AverageMeter()
    # Nothing in evaluation needs gradients, including the loss itself.
    with torch.no_grad():
        for data, tgt in dataloader:
            data, tgt = data.to(run_device), tgt.to(run_device)
            logits = model(data)
            loss = xent(logits, tgt)
            prec1 = accuracy(logits, tgt, topk=(1,))[0]
            losses.update(loss.item(), data.size(0))
            top1.update(prec1, data.size(0))
    # Report the average loss over the whole loader, not the last batch's.
    print('%d epochs | loss: %.4f | acc: %.4f |' % (epoch, losses.avg, top1.avg))
    return top1.avg

In [None]:
# Main loop: one mixup training pass, one learning-rate schedule step and
# one evaluation on the test loader per epoch.
for epoch in range(epochs):
    iters = train(
        epoch=epoch,
        model=model,
        dataloader=dataloader_train,
        optimizer=optimizer,
        cstm_lossfunc=cstm_lossfunc,
        iters=iters,
    )
    # Step the scheduler after the epoch's optimiser updates.
    adjust_learning_rate.step()
    test_acc = validation(epoch=epoch, model=model, dataloader=dataloader_test, xent=xent)