# MNIST Loss Function Example

In [1]:
import torch
import matplotlib.pyplot as plt

import fastbook
fastbook.setup_book()
from fastai.vision.all import *
from fastbook import *
#!pip install torch==1.7.0+cu110 torchvision==0.8.1+cu110 torchaudio===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html


  return torch._C._cuda_getDeviceCount() > 0


In [61]:
path = untar_data(URLs.MNIST_SAMPLE)
path

Path('/home/azhasc/.fastai/data/mnist_sample')

In [67]:
threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()

6131

In [4]:
valid_3_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'valid'/'3').ls()])
valid_3_tens = valid_3_tens.float()/255
valid_7_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'valid'/'7').ls()])
valid_7_tens = valid_7_tens.float()/255

three_tensors = [tensor(Image.open(o)) for o in threes]
seven_tensors = [tensor(Image.open(o)) for o in sevens]

In [70]:
stacked_threes = torch.stack(three_tensors).float()/255
stacked_sevens = torch.stack(seven_tensors).float()/255


(6131, 6265)

In [69]:
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1,28*28)
train_y = tensor([1]*len(threes) + [0]* len(sevens)).unsqueeze(1)

dset = list(zip(train_x,train_y))
x,y = dset[0]
x.shape,y

12396

In [25]:
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x,valid_y))

In [8]:
#Step 1
def init_params(size, std = 1.0): return (torch.randn(size)*std).requires_grad_()

In [9]:
weights = init_params((28*28,1))
bias = init_params(1)

In [10]:
def linear1(xb): return ((xb@weights) + bias)

def sigmoid(x): return (1/(1+torch.exp(-x)))
def mnist_sig_loss(predictions,targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1,1-predictions,predictions).mean()

In [26]:
#Create mini batches

dl = DataLoader(dset[:100], batch_size = 256)
valid_dl = DataLoader(valid_dset[:100], batch_size = 256)
xb,yb = first(dl)
xv,yv = first(valid_dl)

In [12]:
def cal_grad(xb,yb,model):
    preds = model(xb)
    loss = mnist_sig_loss(preds,yb)
    loss = loss.backward()
    return weights.grad.mean(),bias.grad

In [13]:
def train_epoch(model,lr,params):
    for xb,yb in dl:
        cal_grad(xb,yb,model)
        for p in params:
            p.data -= p.grad * lr
            p.grad.zero_()

In [14]:
def valid_epoch(model,lr,params):
    for xb,yb in valid_dl:
        cal_grad(xb,yb,model)
        for p in params:
            p.data -= p.grad * lr
            p.grad.zero_()

In [15]:
def train_accuracy(xb,yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()

In [16]:
def validate_train_epoch(model):
    accs = [train_accuracy(model(xb),yb) for xb,yb in dl]
    return round(torch.stack(accs).mean().item(),4)

In [17]:
def validate_valid_epoch(model):
    accs = [train_accuracy(model(xb),yb) for xb,yb in valid_dl]
    return round(torch.stack(accs).mean().item(),4)

In [18]:
lr = 1.0
params = weights,bias

train_epoch(linear1,lr,params)
validate_train_epoch(linear1)

0.7205

In [19]:
valid_epoch(linear1,lr,params)
validate_valid_epoch(linear1)

0.845

In [20]:
for i in range(20):
    train_epoch(linear1, lr, params)
    print(validate_train_epoch(linear1), end=' ')

0.8542 0.9051 0.9265 0.9396 0.9477 0.9544 0.958 0.961 0.9631 0.9647 0.9663 0.9676 0.9693 0.9706 0.9718 0.9727 0.9736 0.9745 0.9753 0.9755 

In [21]:
for i in range(20):
    valid_epoch(linear1, lr, params)
    print(validate_valid_epoch(linear1), end=' ')

0.9687 0.9687 0.9696 0.9696 0.9696 0.9701 0.9701 0.9716 0.9721 0.9726 0.9726 0.9736 0.9745 0.9745 0.975 0.975 0.9755 0.976 0.976 0.9765 