<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><ul class="toc-item"><li><span><a href="#Improvements-pursued-in-this-notebook" data-toc-modified-id="Improvements-pursued-in-this-notebook-0.1"><span class="toc-item-num">0.1&nbsp;&nbsp;</span>Improvements pursued in this notebook</a></span></li></ul></li><li><span><a href="#Improvement-#1:-multi-label-classification" data-toc-modified-id="Improvement-#1:-multi-label-classification-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Improvement #1: multi-label classification</a></span><ul class="toc-item"><li><span><a href="#Performance-Tweaks" data-toc-modified-id="Performance-Tweaks-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Performance Tweaks</a></span><ul class="toc-item"><li><span><a href="#Leaky-ReLU" data-toc-modified-id="Leaky-ReLU-1.1.1"><span class="toc-item-num">1.1.1&nbsp;&nbsp;</span>Leaky ReLU</a></span></li></ul></li></ul></li></ul></div>

## Improvements pursued in this notebook

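1. Change from a binary classifier to a multi-class classifier (digits 0–9, one label per image):
    - add images for all digits 0–9
    - change the loss function to cross-entropy with softmax (note: the `loss` defined below currently computes 1 − p(true class), not cross-entropy)
    - change the shape of the final layer's output (activations) from 1 to 10
    - change the labels to one-hot encoding (1HE)
2. Add support for RGB (3-channel) images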

Condensed version, with all of the helper functions in a single cell:

# Improvement #1: multi-label classification

In [None]:
from fastai.vision.all import *

### Data ###
def _load_split(path, split, im_size, n_cls):
    """Load digit folders 0..n_cls-1 of one split ('training' or 'testing') under `path`.

    Returns `(ims, lbls)`: images flattened to rows of length `im_size` and
    scaled from 0-255 to [0, 1], plus a matching tensor of one-hot labels.
    """
    ims, lbls = [], []
    for i in range(n_cls):
        # List the folder once so the images and their labels come from the same file list.
        fns = (path/split/f'{i}').ls()
        ims.append(torch.stack([tensor(Image.open(fn)) for fn in fns]).float()/255)
        one_hot = [0]*n_cls
        one_hot[i] = 1
        lbls += [one_hot] * len(fns)
    # Concatenate once at the end rather than growing a tensor inside the loop.
    return torch.cat(ims).view(-1, im_size), tensor(lbls)

def init_data(path, im_size, n_cls, batch_size, return_valid=False):
    """Build DataLoaders of flattened images and one-hot labels from a digit-folder dataset.

    Expects `path/'training'/'<digit>'` (and, if `return_valid`, `path/'testing'/'<digit>'`)
    folders of images for digits 0..n_cls-1.

    Args:
        path: dataset root; `.ls()` is the fastcore Path extension.
        im_size: pixels per image (e.g. 28*28); each image becomes one row of this length.
        n_cls: number of classes / digit folders to load.
        batch_size: mini-batch size for the DataLoader(s).
        return_valid: if False (default, original behaviour) return only the shuffled
            training DataLoader; if True, also load the 'testing' split and return
            `(train_dl, valid_dl)` so accuracy can be measured on held-out data.
    """
    train_ims, train_lbls = _load_split(path, 'training', im_size, n_cls)
    train_dl = DataLoader(L(zip(train_ims, train_lbls)), batch_size=batch_size, shuffle=True)
    if not return_valid:
        return train_dl
    valid_ims, valid_lbls = _load_split(path, 'testing', im_size, n_cls)
    # Evaluation data does not need shuffling.
    valid_dl = DataLoader(L(zip(valid_ims, valid_lbls)), batch_size=batch_size, shuffle=False)
    return train_dl, valid_dl

### Model ###
def init_mod(im_size, n_cls, hidden_params):
    """Build a one-hidden-layer MLP: Linear(im_size -> hidden_params) -> ReLU -> Linear(hidden_params -> n_cls).

    The network outputs `n_cls` raw scores per input row; no softmax is applied here.
    """
    layers = [
        nn.Linear(im_size, hidden_params),   # input pixels -> hidden units
        nn.ReLU(),
        nn.Linear(hidden_params, n_cls),     # hidden units -> one score per class
    ]
    return nn.Sequential(*layers)

### Create SGD Stepper ###
class ParamStepper:
    """Minimal SGD optimizer: `p <- p - lr * p.grad` for each parameter.

    Exposes the same `step()` / `zero_grad()` interface as `torch.optim` optimizers.
    """
    def __init__(self, p, lr):
        # Materialise the iterable: e.g. `mod.parameters()` returns a one-shot generator.
        self.p, self.lr = list(p), lr

    def step(self, *args, **kwargs):
        """Apply one SGD update to every parameter that currently has a gradient."""
        with torch.no_grad():               # the update itself must not be tracked by autograd
            for o in self.p:
                if o.grad is None:          # unused param, or step() called before backward()
                    continue
                o -= o.grad * self.lr

    def zero_grad(self, *args, **kwargs):
        """Clear gradients by setting each parameter's `.grad` to None."""
        for o in self.p:
            o.grad = None

### Calculate accuracy over entire dl given dl,mod ###
def validate_epoch(dl, mod):
    """Return the accuracy of `mod` over every example in `dl`, rounded to 5 decimals.

    Per-batch accuracies from `acc` are weighted by batch size, so a shorter
    final batch does not skew the result. The forward passes run under
    `torch.no_grad()`: evaluation never calls backward, so no autograd graph
    is needed and no gradients are computed or stored.

    Raises:
        ValueError: if `dl` yields no batches.
    """
    correct, total = 0.0, 0
    with torch.no_grad():
        for xb, yb in dl:
            n = len(xb)
            correct += acc(mod(xb), yb).item() * n   # batch accuracy -> number correct
            total += n
    if total == 0:
        raise ValueError('validate_epoch: dataloader yielded no batches')
    return round(correct / total, 5)   # plain Python float, not a tensor

def train_once(dl, mod, stepper):
    """Run one epoch over `dl`: for every mini-batch, backprop, update parameters, then clear gradients."""
    for inputs, targets in dl:
        calc_grad(inputs, targets, mod)  # forward + backward: fills .grad on mod's parameters
        stepper.step()                   # move the parameters using those gradients
        stepper.zero_grad()              # reset so the next batch's gradients don't accumulate

def calc_grad(x, y, mod):
    """Forward `x` through `mod`, score it against targets `y` with `loss`, and backpropagate.

    Gradients are accumulated into the `.grad` of `mod`'s parameters; nothing is returned.
    """
    loss(mod(x), y).backward()

### Run `train_once` `epochs` times given data `dl`, model `mod`, and stepper `stepper`
def train_model(dl, mod, stepper, epochs, valid_dl=None):
    """Train `mod` for `epochs` passes over `dl`, recording accuracy after each epoch.

    Args:
        dl: DataLoader of training batches.
        mod: model to train (updated in place via `stepper`).
        stepper: optimizer exposing `step()` and `zero_grad()`.
        epochs: number of passes over `dl`.
        valid_dl: optional DataLoader to measure accuracy on. Defaults to `dl`
            itself, which reports *training* accuracy (the original behaviour).

    Returns:
        L of per-epoch accuracies (floats from `validate_epoch`).
    """
    eval_dl = dl if valid_dl is None else valid_dl
    accs = L()
    for _ in range(epochs):
        train_once(dl, mod, stepper)
        accs += [validate_epoch(eval_dl, mod)]
    return accs

### Perform n training sessions ###
def train_model_n_times(dl, im_size, n_cls, hidden_params, epochs, lr, n):
    """Train `n` freshly initialised models on `dl` and collect their per-epoch accuracies.

    Each session builds a new model with `init_mod` and a new `ParamStepper`
    (learning rate `lr`), then runs `train_model` for `epochs` epochs. Session
    indices are printed as progress.

    Returns:
        tensor of shape (n, epochs); row i is session i's accuracy history.
    """
    print('Current Session: ')
    history = L()
    for session in range(n):
        net = init_mod(im_size, n_cls, hidden_params)
        history += train_model(dl, net, ParamStepper(net.parameters(), lr), epochs)
        print(session, end=', ')
    return tensor(history).reshape(n, epochs)
    
### Loss & Accuracy ###
def softmax(t):
    """Softmax of `t`: over its only axis if 1-D, otherwise over dim 1.

    The per-slice max is subtracted before exponentiating. Softmax is invariant
    to that shift, and it keeps `exp` from overflowing to inf for large inputs,
    which would otherwise turn the result into nan (inf / inf).
    """
    dim = 0 if t.dim() == 1 else 1
    shifted = t - t.max(dim=dim, keepdim=True).values
    e = torch.exp(shifted)
    return e / e.sum(dim=dim, keepdim=True)
def loss(yp, y):
    """Batch mean of 1 - (softmax probability assigned to the target class).

    `y` is expected to be one-hot (as `init_data` builds the labels), so
    `(probs * y).sum(dim=1)` picks out the probability of the true class.
    Note: this is not cross-entropy.
    """
    probs = softmax(yp)
    p_true = (probs * y).sum(dim=1, keepdim=True)
    return (1 - p_true).mean()
def acc(yp, y):
    """Fraction of rows where the arg-max of `yp` (dim 1) matches the target class.

    Args:
        yp: model scores with classes along dim 1. Only the per-row arg-max is
            used, so logits and probabilities give the same result.
        y: targets, either one-hot with classes along dim 1 (as `init_data`
            builds them) or a 1-D tensor of integer class indices.

    Returns:
        0-dim float tensor in [0, 1].
    """
    pred = yp.argmax(dim=1)
    target = y if y.dim() == 1 else y.argmax(dim=1)
    return (pred == target).float().mean()

In [None]:
### Init ###
SEED          = 42
set_seed(SEED)                     # fastai: seeds Python `random`, NumPy and torch so weight init and DataLoader shuffling repeat
path          = untar_data(URLs.MNIST)
n_cls         = 10                 # digits 0-9
im_size       = 28*28              # each 28x28 image is flattened to one row
batch_size    = 64*2*2*2           # = 512
dl            = init_data(path, im_size, n_cls, batch_size)

In [None]:
### Training ###
test1 = train_model_n_times(
    dl, im_size, n_cls,
    hidden_params=30,
    epochs=40,
    lr=.1,
    n=4,
)
test1

0	1	2	3	

tensor([[0.2079, 0.3295, 0.5562, 0.7139, 0.7352, 0.7908, 0.8014, 0.8095, 0.8134,
         0.8182, 0.8212, 0.8233, 0.8253, 0.8264, 0.8285, 0.8293, 0.8308, 0.8316,
         0.8323, 0.8334, 0.8346, 0.8347, 0.8357, 0.8374, 0.8379, 0.8379, 0.8386,
         0.8389, 0.8394, 0.8399, 0.8402, 0.8411, 0.8410, 0.8416, 0.8425, 0.8427,
         0.8428, 0.8435, 0.8443, 0.8445],
        [0.2649, 0.4934, 0.6481, 0.6592, 0.6651, 0.6688, 0.6711, 0.7343, 0.7453,
         0.7495, 0.7529, 0.7544, 0.7566, 0.7572, 0.7580, 0.7595, 0.7601, 0.7609,
         0.7614, 0.7628, 0.8073, 0.8162, 0.8200, 0.8234, 0.8252, 0.8268, 0.8278,
         0.8291, 0.8300, 0.8319, 0.8324, 0.8332, 0.8334, 0.8345, 0.8355, 0.8360,
         0.8368, 0.8373, 0.8385, 0.8381],
        [0.2214, 0.4365, 0.6957, 0.7266, 0.7362, 0.8078, 0.8185, 0.8230, 0.8260,
         0.8286, 0.8319, 0.8331, 0.8344, 0.8361, 0.8374, 0.8381, 0.8400, 0.8400,
         0.8419, 0.8426, 0.8431, 0.8435, 0.8446, 0.8450, 0.8463, 0.8468, 0.8470,
         0.8470, 0.8480, 

## Performance Tweaks

### Leaky ReLU