In [None]:
import sys
from functools import partial
from typing import Dict, Callable
import torch
from torch import nn, optim, Tensor
from torchvision import transforms
import pandas as pd
import seaborn as sns
from tqdm.autonotebook import tqdm
from dataloaders import get_cifar10_data_loaders

cuda_flag = torch.cuda.is_available()
# Report GPU / cuDNN availability, then pick the device used by later cells.
print(cuda_flag, torch.backends.cudnn.enabled)
device = torch.device('cuda' if cuda_flag else 'cpu')

In [11]:
def channel_avg(train_dl, num_channels=3):
    """Compute the mean pixel value of each image channel over a data loader.

    Args:
        train_dl: iterable yielding ``(x, y)`` batches where ``x`` is a
            (batch, channel, height, width) tensor; ``y`` is ignored.
        num_channels: number of leading channels to average (default 3).

    Returns:
        List of per-channel means ``[ch0, ch1, ...]`` as Python floats, in the
        channel order stored in ``x`` (for RGB images: ``[R, G, B]``).

    Raises:
        ValueError: if a batch has fewer than ``num_channels`` channels, or if
            ``train_dl`` yields no pixels at all.
    """
    channel_sums = torch.zeros(num_channels, dtype=torch.float64)
    num_pix = 0
    for x, _ in train_dl:
        if x.shape[1] < num_channels:
            raise ValueError(f'expected at least {num_channels} channels, got {x.shape[1]}')
        # Accumulate in float64 on the CPU so long sums over many batches stay accurate.
        x = x.detach().to(device='cpu', dtype=torch.float64)
        channel_sums += x[:, :num_channels].sum(dim=(0, 2, 3))
        # Count pixels per channel from the data instead of assuming 50000 32x32 images.
        num_pix += x[:, 0].numel()
    if num_pix == 0:
        raise ValueError('train_dl yielded no pixels')
    return (channel_sums / num_pix).tolist()

# Per-channel pixel means of the first (training) loader; the second loader
# returned by the helper (presumably the test split) is not needed here.
dl, _ = get_cifar10_data_loaders(batch_size=100, data_dir='./data/cifar10')
rgb_ave = channel_avg(train_dl=dl)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


KeyboardInterrupt: 

In [None]:
# schedulers.py

def _triangular_f(it, ss, min_lr, max_lr):
    'TODO: docstring'
    # calculate number of completed cycles
    cyc = math.floor(it / (ss * 2))
    # calculate number of completed iterations in current cycle
    it_cyc = it - (cyc * 2 * ss)
    # calculate distance from lr_max iteration
    mid_dist = math.fabs(it_cyc - ss)
    # scale lr difference
    scalar = mid_dist / ss
    return min_lr + (1 - scalar) * (max_lr - min_lr)

def _triangular2_f(it, ss, min_lr, max_lr):
    'TODO: docstring'
    # calculate number of completed cycles
    cyc = math.floor(it / (ss * 2))
    # calculate number of completed iterations in current cycle
    it_cyc = it - (cyc * 2 * ss)
    # calculate distance from lr_max iteration
    mid_dist = math.fabs(it_cyc - ss)
    # scale lr difference
    scalar = mid_dist / ss
    return min_lr + (1 - scalar) * ((max_lr - min_lr) / 2 ** cyc)
    
class FixedScheduler(optim.lr_scheduler._LRScheduler):
    """Scheduler that never changes the learning rate.

    A constant-rate baseline for code that expects a scheduler object: every
    step simply reports back the rate already stored in each param group.
    """

    def __init__(self, optimizer:optim.Optimizer):
        """Attach to ``optimizer`` through the base-class constructor."""
        super().__init__(optimizer)

    def get_lr(self):
        """Return the current learning rate of every parameter group, unchanged."""
        groups = self.optimizer.param_groups
        return list(map(lambda group: group['lr'], groups))
    
class TriangularScheduler(optim.lr_scheduler._LRScheduler):
    """Cyclical learning-rate schedule following a fixed-amplitude triangle wave.

    The rate climbs linearly from ``min_lr`` to ``max_lr`` over ``step_size``
    scheduler steps, falls back to ``min_lr`` over the next ``step_size`` steps,
    and repeats. Every parameter group receives the same rate.
    """

    def __init__(self, step_size:int, min_lr:float, max_lr:float, optimizer:optim.Optimizer):
        """Record the cycle parameters, then hand ``optimizer`` to the base class.

        The attributes are set first because the base-class constructor already
        performs an initial step, which calls ``get_lr``.
        """
        self.step_size = step_size
        self.min_lr = min_lr
        self.max_lr = max_lr
        super().__init__(optimizer)

    def get_lr(self):
        """Return the triangle-wave rate for the current step, once per param group."""
        current = _triangular_f(self.last_epoch, self.step_size, self.min_lr, self.max_lr)
        return [current for _ in self.optimizer.param_groups]
    
class Triangular2Scheduler(optim.lr_scheduler._LRScheduler):
    """Cyclical learning-rate schedule whose triangle wave shrinks every cycle.

    Like ``TriangularScheduler``, the rate oscillates between ``min_lr`` and a
    peak with half-period ``step_size``, but the peak height above ``min_lr``
    is halved after each completed cycle. Every parameter group receives the
    same rate.
    """

    def __init__(self, step_size:int, min_lr:float, max_lr:float, optimizer:optim.Optimizer):
        """Record the cycle parameters, then hand ``optimizer`` to the base class.

        The attributes are set first because the base-class constructor already
        performs an initial step, which calls ``get_lr``.
        """
        self.step_size = step_size
        self.min_lr = min_lr
        self.max_lr = max_lr
        super().__init__(optimizer)

    def get_lr(self):
        """Return the decaying triangle-wave rate for the current step, per param group."""
        current = _triangular2_f(self.last_epoch, self.step_size, self.min_lr, self.max_lr)
        return [current for _ in self.optimizer.param_groups]


In [None]:
# model.py

class Cifar10Net(nn.Module):
    """Small three-convolution CNN for 32x32 three-channel images.

    Layout: conv(5x5, 32) -> maxpool(3, s2) -> ReLU -> conv(5x5, 32) -> ReLU ->
    avgpool(3, s2) -> conv(5x5, 64) -> ReLU -> linear(64*7*7 -> 64) ->
    linear(64 -> num_classes).

    The padded 5x5 convolutions preserve spatial size and each 3x3/stride-2
    pool maps 32 -> 15 -> 7, which is where ``ip1``'s ``64 * 7 * 7`` input size
    comes from; other input resolutions will not fit ``ip1``.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.AvgPool2d(kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        # Without a non-linearity here, conv3 and ip1 compose into a single linear map.
        # ReLU has no parameters, so state_dict keys are unchanged.
        self.relu3 = nn.ReLU()
        self.ip1 = nn.Linear(64 * 7 * 7, 64)
        # Honour the constructor argument instead of a hard-coded 10 outputs.
        self.ip2 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) input to unnormalised scores of shape (batch, num_classes)."""
        out = self.conv1(x)
        out = self.pool1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.relu2(out)
        out = self.pool2(out)
        out = self.relu3(self.conv3(out))
        # flatten everything but the batch dimension (also safe for non-contiguous tensors)
        out = torch.flatten(out, 1)
        out = self.ip1(out)
        return self.ip2(out)
    
def init_weights(module:nn.Module):
    """Initialise a Conv2d / Linear layer: weights ~ N(0, 0.01^2), biases = 0.

    Suitable for ``model.apply(init_weights)``; modules of any other type are
    left untouched, and layers created with ``bias=False`` only get their
    weights initialised.
    """
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        # nn.init functions already run under torch.no_grad(), so `.data` is unnecessary.
        nn.init.normal_(module.weight, mean=0, std=0.01)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

In [None]:
# train.py

def validate(dl:torch.utils.data.DataLoader,
             model:nn.Module,
             criterion:nn.Module,
             device: torch.device):
    """Evaluate ``model`` on every batch of ``dl`` without tracking gradients.

    Args:
        dl: loader yielding ``(inputs, targets)`` batches.
        model: network whose output is scored with ``argmax`` over dim 1.
        criterion: loss assumed to return a per-batch mean (it is re-weighted
            by batch size); TODO confirm if a ``reduction='sum'`` loss is used.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples actually seen.

    Raises:
        ValueError: if ``dl`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    try:
        with torch.no_grad():
            for xs, ys in dl:
                xs = xs.to(device)
                ys = ys.to(device)
                out = model(xs)
                batch_size = xs.size(0)
                # weight the batch-mean loss so a smaller final batch counts correctly
                total_loss += criterion(out, ys).item() * batch_size
                total_correct += (out.argmax(dim=1) == ys).sum().item()
                total_seen += batch_size
    finally:
        # restore whichever mode the caller had the model in
        model.train(was_training)
    # Divide by samples seen, not len(dl.dataset): they differ with drop_last or samplers.
    if total_seen == 0:
        raise ValueError('validation loader yielded no samples')
    return total_loss / total_seen, total_correct / total_seen


def train_batch(xs:Tensor,
                ys:Tensor,
                model:nn.Module,
                criterion:nn.Module,
                optimizer:optim.Optimizer):
    """Run one optimisation step on a single mini-batch.

    Switches ``model`` to training mode, clears stale gradients, back-propagates
    ``criterion(model(xs), ys)`` and lets ``optimizer`` update the parameters.

    Returns:
        The batch loss as a Python float (computed before the parameter update).
    """
    model.train()
    optimizer.zero_grad()
    batch_loss = criterion(model(xs), ys)
    batch_loss.backward()
    optimizer.step()
    loss_value = batch_loss.item()
    return loss_value

def train_run(model:nn.Module,
              train_dl:torch.utils.data.dataloader.DataLoader,
              criterion:nn.Module,
              optimizer:optim.Optimizer,
              scheduler:optim.lr_scheduler._LRScheduler,
              num_it:int,
              on_batch_end:Callable[[tqdm, int, float, float], None],
              device: torch.device):
    """Train ``model`` for exactly ``num_it`` mini-batch iterations.

    Batches are drawn from ``train_dl``, restarting it whenever it is exhausted,
    so ``num_it`` need not be a multiple of the epoch length. After each batch,
    ``on_batch_end(bar, i, loss, lr)`` is called with the progress bar, the
    0-based iteration index, the batch loss and the learning rate that batch
    was trained with; then ``scheduler`` is stepped once.

    Raises:
        ValueError: if ``train_dl`` yields no batches.
    """
    iterator = iter(train_dl)
    # context manager guarantees the progress bar is closed, even on error
    with tqdm(range(num_it)) as bar:
        for i in bar:
            try:
                xs, ys = next(iterator)
            except StopIteration:
                # epoch boundary: restart the loader and keep counting iterations
                iterator = iter(train_dl)
                try:
                    xs, ys = next(iterator)
                except StopIteration:
                    raise ValueError('train_dl yielded no batches') from None
            xs = xs.to(device)
            ys = ys.to(device)
            loss = train_batch(xs, ys, model, criterion, optimizer)
            # get_last_lr() is the public accessor for the rate set by the latest
            # scheduler step; PyTorch discourages calling get_lr() directly.
            on_batch_end(bar, i, loss, scheduler.get_last_lr()[0])
            scheduler.step()
            
def on_batch_end(recorder:Dict,
                 test_dl: torch.utils.data.dataloader.DataLoader,
                 model: nn.Module,
                 criterion: nn.Module,
                 p_bar:tqdm,
                 it_num:int,
                 trn_loss:float,
                 lr: float,
                 device: torch.device = None,
                 log_every: int = 500):
    """Periodically evaluate ``model`` and append training metrics to ``recorder``.

    Presumably bound as ``partial(on_batch_end, recorder, test_dl, model,
    criterion)`` so it matches the ``(bar, it, loss, lr)`` callback that
    ``train_run`` invokes. Metrics are logged on the first iteration and then
    every ``log_every`` iterations.

    Args:
        recorder: dict of lists; keys ``iteration``, ``trn_loss``, ``lr``,
            ``val_loss`` and ``val_acc`` are created if missing.
        test_dl: loader used for validation.
        model: model under training.
        criterion: loss used for validation.
        p_bar: progress bar whose ``write`` prints a metrics line.
        it_num: 0-based iteration index.
        trn_loss: loss of the batch just trained.
        lr: learning rate used for that batch.
        device: device for validation batches; defaults to the device of the
            model's first parameter (CPU if it has none).
        log_every: logging interval in iterations.
    """
    if not (it_num == 0 or (it_num + 1) % log_every == 0):
        return
    if device is None:
        # Evaluate wherever the model lives rather than reading a notebook global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    for key in ('iteration', 'trn_loss', 'lr', 'val_loss', 'val_acc'):
        recorder.setdefault(key, [])
    recorder['iteration'].append(it_num + 1)
    recorder['trn_loss'].append(trn_loss)
    recorder['lr'].append(lr)
    val_loss, val_acc = validate(test_dl, model, criterion, device)
    recorder['val_loss'].append(val_loss)
    recorder['val_acc'].append(val_acc)
    p_bar.write(f'{trn_loss} | {val_loss} | {val_acc}')