# Chapter 4, Exercise 1: Implement your own Learner

> Create your own implmentation of Learner from scratch, based on the training loop shown in this chapter.

As a reminder, the loop is:

- Init
- Predict
- Loss 
- Gradient
- Step
- Stop

Let's start with the boilerplate:

In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.set_printoptions(edgeitems=2)
torch.manual_seed(42) # Life, the Universe, and Everything

<torch._C.Generator at 0x7fbd18632c90>

I'll use the signature from the book; however, for now I'm going to leave out metrics.  I may come back to this later.

Let's create our model.  [Weights are  initialized for us](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear).

In [2]:
my_model = nn.Sequential(
    nn.Linear(in_features=28*28, out_features=30),
    nn.ReLU(),
    nn.Linear(30, 1)
)

Next, the data loader...which I guess means we'll need some data.  We'll use the FastAI 3/7 image set.

In [3]:
from fastai.vision.all import *
path = untar_data(URLs.MNIST_SAMPLE)
path.ls()

(#3) [Path('/home/aardvark/.fastai/data/mnist_sample/labels.csv'),Path('/home/aardvark/.fastai/data/mnist_sample/valid'),Path('/home/aardvark/.fastai/data/mnist_sample/train')]

Let's load those into tensors:

In [4]:
training_3 = torch.stack([tensor(Image.open(o)) for o in (path /'train/3').ls().sorted()]).float() / 255.0
training_7 = torch.stack([tensor(Image.open(o)) for o in (path /'train/7').ls().sorted()]).float() / 255.0
len(training_3), len(training_7)

(6131, 6265)

In [5]:
train_x = torch.cat([training_3, training_7]).view(-1, 28*28)
train_x.shape

torch.Size([12396, 784])

Time for some labels.

In [6]:
train_y = tensor([1] * len(training_3) + [0] * len(training_7))
train_y.shape

torch.Size([12396])

Now time for the loader:

In [7]:
dset = list(zip(train_x, train_y))

In [8]:
valid_3 = torch.stack([tensor(Image.open(o)) for o in (path / 'valid/3').ls().sorted()]).float() / 255.0
valid_7 = torch.stack([tensor(Image.open(o)) for o in (path / 'valid/7').ls().sorted()]).float() / 255.0
valid_x = torch.cat([valid_3, valid_7]).view(-1, 28*28)
valid_y = tensor([1] * len(valid_3) + [0] * len(valid_7))
valid_dset = list(zip(valid_x, valid_y))
valid_x.shape, valid_y.shape

(torch.Size([2038, 784]), torch.Size([2038]))

In [9]:
dl = DataLoader(dset, batch_size=256)
xb, yb = first(dl)
xb.shape, yb.shape

(torch.Size([256, 784]), torch.Size([256]))

In [10]:
valid_dl = DataLoader(valid_dset, batch_size=256)

Let's make sure this works as expected:

Next up would be optimizer.  I'm going to use the PyTorch SGD optimizer here:

In [11]:
my_optimizer = optim.SGD(my_model.parameters(), lr=0.01, momentum=0.9)

As for loss function, let's just go simple and use mse.   We'll save something fancier for when we get into MNIST.

In [12]:
def my_loss(predicted, actual):
    return (torch.mean(predicted - actual))**2

Now it's time to try some training!

In [13]:
class MyLearner():
    
    def __init__(self, dl, model, opt):
        self.dl_train = dl[0]
        self.dl_valid = dl[1]
        self.model = model
        self.opt = opt(self.model.parameters(), lr=0.1)
        
    def mnist_loss(self, preds, targets):
        preds = preds.sigmoid()
        return torch.where(targets==1, 1-preds, preds).mean()
    
    def batch_acccuracy(self, xb, yb):
        preds = xb.sigmoid()
        correct = (preds > 0.5) == yb
        return correct.float().mean()
    
    def validate_epoch(self):
        accs = [self.batch.accuracy(self.model(x), y) for x, y in self.dl_valid]
        return round(torch.stack(accs.mean().item()), 4)
    
    def cal_grad(self, x, y):
        preds = self.model(x)
        loss = self.mnist_loss(preds, y)
        loss.backward
    
    def train_epoch(self):
        for x, y in self.dl_train:
            self.cal_grad(x, y)
            self.opt.step()
            # self.opt.zero_grad()
            
    def fit(self, epochs):
        for i in range(epochs):
            self.train_epoch()
            print(self.validate_epoch(), end=" ")
        
class MyOriginalLearner:        
    def fit(self, epochs=10, verbose=False):
        '''Fit method 
        '''
        for i in range(0, epochs):
            while True:
                try:
                    X, y = self.loader.next()
                    X = torch.squeeze(X[0])
                    pred = self.model(X)
                    loss = self.loss_func(pred, y)
                    self.metrics['loss'] += loss
                    loss.backward()
                    self.opt_func.step()
                    self.opt_func.zero_grad()
                    if verbose is True and self.loader.counter % 1000 == 0:
                        print("Pred: {}, Actual: {}, Loss: {}".format(pred, y, loss))
                except NoMoreData:
                    self.loader.reset()
                    self.print_epoch_loss()
                    break
                    
    def fit_once(self):
        X, y = self.loader.next()
        X = torch.squeeze(X[0])
        pred = self.model(X)
        loss = self.loss_func(pred, y)
        loss.backward()
        self.opt_func.step()
        self.opt_func.zero_grad()
        return X
                                                                                        
    def print_epoch_loss(self):
        self.metrics['loss'] /= self.loader.length
        print("Epoch loss: {}".format(self.metrics['loss']))
        self.metrics['loss'] = 0

And now to put it all together:

In [14]:
my_learner = MyLearner([dl, valid_dl], my_model, my_optimizer)


TypeError: 'SGD' object is not callable

Let's try out the `fit_once()` method:

In [31]:
X = my_learner.fit_once()

We'll reset the counter...

In [32]:
my_learner.loader.reset()
my_learner.loader.counter

0

Now let's try it out!

In [35]:
%timeit my_learner.fit()

Epoch loss: 0.0011706968070939183
Epoch loss: 0.0011706004152074456
Epoch loss: 0.0011706734076142311
Epoch loss: 0.001170855131931603
Epoch loss: 0.0011711455881595612
Epoch loss: 0.0011715400032699108
Epoch loss: 0.0011719489702954888
Epoch loss: 0.0011723671341314912
Epoch loss: 0.001172703574411571
Epoch loss: 0.0011729162652045488
Epoch loss: 0.0011731692356988788
Epoch loss: 0.0011733882129192352
Epoch loss: 0.001173645374365151
Epoch loss: 0.001173866563476622
Epoch loss: 0.0011740544578060508
Epoch loss: 0.0011741825146600604
Epoch loss: 0.0011742532951757312
Epoch loss: 0.0011742805363610387
Epoch loss: 0.0011742844944819808
Epoch loss: 0.0011742758797481656
Epoch loss: 0.0011742645874619484
Epoch loss: 0.0011742579517886043
Epoch loss: 0.0011742522474378347
Epoch loss: 0.0011742470087483525
Epoch loss: 0.0011742417700588703
Epoch loss: 0.0011742364149540663
Epoch loss: 0.0011742307106032968
Epoch loss: 0.001174224424175918
Epoch loss: 0.0011742182541638613
Epoch loss: 0.00117

In [36]:
validation_3 = torch.stack([tensor(Image.open(o)) for o in (path /'valid/3').ls().sorted()]).float() / 255.0
validation_7 = torch.stack([tensor(Image.open(o)) for o in (path /'valid/7').ls().sorted()]).float() / 255.0
len(validation_3), len(validation_7)

(1010, 1028)

In [41]:
all_valid_data = torch.cat([validation_3, validation_7])
all_valid_data.shape

torch.Size([2038, 28, 28])

In [43]:
valid_labels_3 = torch.ones([len(validation_3)])
valid_labels_7 = torch.zeros([len(validation_7)])
all_valid_labels = torch.cat([valid_labels_3, valid_labels_7])
all_valid_labels.shape

torch.Size([2038])

In [46]:
my_valid_loader = MyLoader(training_data=all_valid_data.view(-1, 28*28), labels=all_valid_labels, batch_size=5)

In [62]:
def batch_accuracy(xb, yb):
    preds = my_learner.model(xb)
    correct = (preds < 0.5) == yb
    return correct.float().mean()

In [91]:
def validate_epoch(model):
    accs = []
    try:
        while True:
            xb, yb = my_valid_loader.next()
            for i in range(len(xb)):
                # print(yb[i])
                # acc = batch_accuracy[]
                accs += [batch_accuracy((xb[i]), yb[i])]
                # print(accs)
                # accs += [batch_accuracy(model(xb_item, ybitem)) for xbitem, yb in my_valid_loader.next()]
    except NoMoreData:
        my_valid_loader.reset()
    return round(torch.stack(accs).mean().item(), 4)

In [92]:
my_valid_loader.reset()

In [93]:
validate_epoch(my_learner.model)

0.4956

# Status

## Predictions

Once I figured out where I needed to put the call to `zero_grad()`, I got predictions.  The loss is strangely low in training, but *terrible* in validation.  OTOH, the loss is similarly terrible on pg 173.

## Overall

I think this is a bit of a mess at the moment.  

This chapter of the book has two big sections:

- One where a simple linear module is used to solve a parabola, with all of the steps being coded from scratch;
- And one where a neural network coded in PyTorch is used to build a model for MNIST.

The two exercises match these sections.

However, the approach I've taken above is a mix of those two approaches.  The `MyLearner` class uses the signature from the PyTorch model and the steps of the first.  `my_model` is a PyTorch model.  Making these things match each other is a bit messy.

In [None]:
class Learner_try:
    def __init__(self, dl, model, opt):
        self.dl_train = dl[0]
        self.dl_valid = dl[1]
        self.model = model
        self.opt = opt(self.model.parameters(), lr = 0.1)


    def mnist_loss(self, preds, targets):
        preds = preds.sigmoid()
        return torch.where(targets==1, 1 - preds, preds).mean()

    def batch_accuracy(self, x, y):
        preds = x.sigmoid()
        correct = (preds>0.5) == y
        return correct.float().mean()

    def validate_epoch(self):
        accs = [self.batch_accuracy(self.model(x), y) for x,y in self.dl_valid]
        return round(torch.stack(accs).mean().item(), 4)

    def cal_grad(self, x, y):
        preds = self.model(x)
        loss = self.mnist_loss(preds, y)
        loss.backward()

    def train_epoch(self):
        for x, y in self.dl_train:
            self.cal_grad(x, y)
            self.opt.step()
            #self.opt.zero_grad()      #This is the step which is acting wierd

    def fit(self, epochs):
        for i in range(epochs):
            self.train_epoch()
            print(self.validate_epoch(), end = " ")
    
simple_net = nn.Sequential(nn.Linear(28 * 28, 30),
                      nn.ReLU(),
                      nn.Linear(30, 1),
                      nn.Sigmoid())

opt = SGD

learn = Learner_try(dls, simple_net, opt = opt)

learn.fit(20)