## Set up environment

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Import all code from previous notebooks

In [2]:
#export
from exp.nb_03 import *

## Get the MNIST training and validation data sets

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=4799)

In [3]:
# get_data, Dataset and F are not imported in this notebook itself;
# presumably they arrive through the star-import of exp.nb_03 above.
x_train,y_train,x_valid,y_valid = get_data()
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)
# nh matches get_model's default n_hidden (50) and bs matches batch_size (64) set further down;
# neither nh nor bs is read again in this notebook.
nh,bs = 50,64
# largest label + 1, i.e. the number of classes if labels run 0..c-1 -- TODO confirm label encoding
c = y_train.max().item()+1
loss_func = F.cross_entropy

In [4]:
# the x and y attributes are provided by the Dataset wrapper;
# print their shapes as a sanity check on the validation split
print(valid_ds.x.shape)
print(valid_ds.y.shape)

torch.Size([10000, 784])
torch.Size([10000])


# 1. Improving the fit() function
### Factor the connected pieces of information (model, optimizer, loss function and data) from the `fit()` argument list into a `Learner Class object`

`fit(n_epochs, model, loss_func, opt, train_dl, valid_dl)`

Let's modify the call to `fit()` to look like this:

`fit(n_epochs, learner)`

Here, `learner` is a `Learner` Class object, that we will define as a container for `model`, `loss_func`, `opt`, `train_dl`, `valid_dl` to be passed into `fit()`

This will allow us to tweak what's happening inside the training loop in other places of the code. Because the `Learner` Class object will be mutable, changing any of its attributes elsewhere will modify our training loop.

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5363)

### The DataBunch class
For convenience we define a `DataBunch` class as a container for the training and validation `DataLoader` objects and, optionally, `n_out`, the number of model outputs (one per class).
`DataBunch()` takes `train_dl`, `valid_dl`, and optionally `n_out` as inputs.
It exposes the underlying datasets through the read-only properties `train_ds` and `valid_ds`.

In [5]:
#export
class DataBunch():
    """Hold the training and validation DataLoaders together with an optional output size.

    train_dl, valid_dl: the two loaders; each must expose a `.dataset` attribute.
    n_out: stored as given (defaults to None); `get_model` reads it as the model's output width.
    """

    def __init__(self, train_dl, valid_dl, n_out=None):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.n_out = n_out

    @property
    def train_ds(self):
        """The dataset wrapped by the training loader."""
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        """The dataset wrapped by the validation loader."""
        return self.valid_dl.dataset
    

### The Learner Class

In [6]:
#export
# instantiates the model and the optimizer, given the data and parameters
def get_model(data, learning_rate=0.5, n_hidden = 50):
    """Build a one-hidden-layer classifier sized for `data`, plus an SGD optimizer for it.

    data: object exposing `train_ds.x` (2-D; its second dimension is the input width)
        and `n_out` (the output width).
    learning_rate: learning rate handed to SGD.
    n_hidden: width of the hidden layer.

    Returns the tuple (model, optimizer).
    """
    n_in = data.train_ds.x.shape[1]
    layers = [nn.Linear(n_in, n_hidden), nn.ReLU(), nn.Linear(n_hidden, data.n_out)]
    model = nn.Sequential(*layers)
    # `nn` and `optim` are module-level names; this notebook does not import them itself,
    # so presumably they come from the star-import of exp.nb_03.
    opt = optim.SGD(model.parameters(), lr=learning_rate)
    return model, opt

In [7]:
#export
# the Learner() class is a container for the model, optimization, loss function and data 
class Learner():
    """Plain container tying together a model, its optimizer, the loss function and the data.

    The attributes are stored by reference and can be reassigned later, so code that
    holds the Learner sees any change made elsewhere.
    """

    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

## Refactor the fit() function according to plan

In [8]:
# refactor the fit() function according to our plan
#     Note the order of the inputs has been switched in order to allow n_epochs to be a keyword argument
def fit(learn,n_epochs):
    """Train `learn.model` for `n_epochs` epochs, validating after each one.

    learn: Learner-like object providing `model`, `opt`, `loss_func` and
        `data.train_dl` / `data.valid_dl`.
    n_epochs: number of passes over the training loader.

    Prints `epoch, avg_loss, avg_acc` after every epoch and returns the
    (avg_loss, avg_acc) pair of the last epoch. Both are None when
    `n_epochs` is 0, because no validation pass has run.
    """
    # Bound up front: with n_epochs == 0 the loop body never runs and the final
    # return would otherwise raise UnboundLocalError.
    avg_loss, avg_acc = None, None

    for epoch in range(n_epochs):

        # training phase: put the model in training mode, then one optimizer step per batch
        learn.model.train()
        for xb,yb in learn.data.train_dl:
            pred = learn.model(xb)
            loss = learn.loss_func(pred, yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        # validation phase: eval mode changes the behaviour of layers such as dropout
        # and batchnorm; no_grad skips gradient tracking
        learn.model.eval()
        with torch.no_grad():
            tot_loss,tot_acc = 0.,0.
            for xb,yb in learn.data.valid_dl:
                pred = learn.model(xb)
                tot_loss += learn.loss_func(pred, yb)
                tot_acc  += accuracy(pred,yb)

        # Unweighted mean over len(valid_dl) entries (the number of batches for a
        # PyTorch DataLoader); a smaller final batch counts as much as a full one.
        n_valid = len(learn.data.valid_dl)
        avg_loss, avg_acc = tot_loss/n_valid, tot_acc/n_valid
        print(epoch, avg_loss, avg_acc)

    return avg_loss, avg_acc

## Prepare for minibatch training: package up the data into a DataBunch

In [9]:
# package up the data in a DataBunch object
# the * unpacks the values returned by get_dls so they become DataBunch's
# positional train_dl and valid_dl arguments
batch_size = 64
n_out = 10
data = DataBunch(*get_dls(train_ds, valid_ds, batch_size=batch_size), n_out=n_out)

# data.train_dl.dataset is the same object the DataBunch.train_ds property returns;
# print the shapes of its x and y attributes
print(data.train_dl.dataset.x.shape)
print(data.train_dl.dataset.y.shape)

torch.Size([50000, 784])
torch.Size([50000])


In [10]:
# instantiate a Learner object
# get_model returns (model, optimizer); the * unpacks that pair into Learner's first two arguments
learn = Learner(*get_model(data), loss_func, data)

In [11]:
# train() and eval() each switch the model's mode and return the model,
# so both prints show the same layer structure; the model is left in eval mode afterwards
print(learn.model.train())
print(learn.model.eval())

Sequential(
  (0): Linear(in_features=784, out_features=50, bias=True)
  (1): ReLU()
  (2): Linear(in_features=50, out_features=10, bias=True)
)
Sequential(
  (0): Linear(in_features=784, out_features=50, bias=True)
  (1): ReLU()
  (2): Linear(in_features=50, out_features=10, bias=True)
)


## Instantiate a minibatch training loop!

In [12]:
# run the minibatch training loop for one epoch;
# fit() returns the last epoch's average validation loss and accuracy
loss,acc = fit(learn, n_epochs = 1)

0 tensor(0.4549) tensor(0.8686)


# 2. Adding callbacks 
Callbacks are functions that can be set to initiate desired sequences of actions at various stages in the minibatch training process. They will allow us to simplify the training loop, and provide added flexibility.

Here is our training loop (without validation) from the previous notebook, with the inner loop contents factored out:

```python
def one_batch(xb,yb):
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
    
def fit():
    for epoch in range(n_epochs):
        for b in train_dl: one_batch(*b)
```

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5628)

## The Callback() Class

It is a container for a set of nine basic callback methods.

Each method is associated with a well-defined stage in the training process and returns `True`, which tells the training loop to carry on with that stage.


In [13]:
class Callback():
    """Base callback: one hook per stage of the training loop, each returning True.

    Hooks that receive arguments store them on the instance so that later hooks
    (and subclasses) can read them. Subclasses override only the hooks they need.
    """

    # --- whole-fit hooks ---
    def begin_fit(self, learn):
        """Remember the learner being trained."""
        self.learn = learn
        return True

    def after_fit(self):
        return True

    # --- per-epoch hooks ---
    def begin_epoch(self, epoch):
        """Remember the index of the epoch that is starting."""
        self.epoch = epoch
        return True

    def begin_validate(self):
        return True

    def after_epoch(self):
        return True

    # --- per-batch hooks ---
    def begin_batch(self, xb, yb):
        """Remember the current inputs and targets."""
        self.xb = xb
        self.yb = yb
        return True

    def after_loss(self, loss):
        """Remember the loss computed for the current batch."""
        self.loss = loss
        return True

    def after_backward(self):
        return True

    def after_step(self):
        return True

## The CallbackHandler Class

is a container for a collection of methods comprising

    (1) a "wrapper" for each method in the Callback() class 

    (2) a do_stop() method that reports the learn.stop flag and then resets it to False
    
takes a list of Callback() objects as input

each "wrapper" calls the method of the same name on every callback and returns a state flag, as follows:
    - initialize the state flag (True, except that after_loss starts from in_train and after_fit from not in_train)
    - optionally set some object properties and/or switch the model between train and eval mode
    - for each Callback() object in the input list, as long as the state flag is still True:
         call the callback's method of the same name
            if it returns True, the state flag stays True
            if it returns False, the state flag becomes False and the remaining callbacks are skipped
    - return the state flag 

In [14]:
class CallbackHandler():
    """Fan each training-loop event out to a list of callbacks.

    Every event method returns a state flag: the and-combination of an initial
    value and the results returned by the callbacks. The training loop uses that
    flag to decide whether to carry on with the corresponding stage.
    """

    def __init__(self,callbacks=None):
        self.callbacks = callbacks or []

    def _relay(self, event, state, *args):
        """Call `event` on each callback in order, and-ing the results into `state`.

        `and` short-circuits, so once `state` is falsy the remaining callbacks
        are not called at all.
        """
        for cb in self.callbacks:
            state = state and getattr(cb, event)(*args)
        return state

    def begin_fit(self, learn):
        """Bind the learner, clear its stop flag and notify the callbacks."""
        self.learn = learn
        self.in_train = True
        learn.stop = False
        return self._relay('begin_fit', True, learn)

    def after_fit(self):
        # the initial state is True only if the last phase entered was validation
        return self._relay('after_fit', not self.in_train)

    def begin_epoch(self, epoch):
        # train() only switches the model to training mode; it does not run any training
        self.learn.model.train()
        self.in_train = True
        return self._relay('begin_epoch', True, epoch)

    def begin_validate(self):
        # eval() switches the model to evaluation mode
        self.learn.model.eval()
        self.in_train = False
        return self._relay('begin_validate', True)

    def after_epoch(self):
        return self._relay('after_epoch', True)

    def begin_batch(self, xb, yb):
        return self._relay('begin_batch', True, xb, yb)

    def after_loss(self, loss):
        # starts from in_train: while validating the result is False
        # and no callback's after_loss is called
        return self._relay('after_loss', self.in_train, loss)

    def after_backward(self):
        return self._relay('after_backward', True)

    def after_step(self):
        return self._relay('after_step', True)

    def do_stop(self):
        """Return the learner's stop flag, then reset it to False."""
        try:
            return self.learn.stop
        finally:
            # runs after the return value has been captured
            self.learn.stop = False

## A TestCallback Class

In [15]:
class TestCallback(Callback):
    """Demo callback: count optimizer steps and request a stop after `max_iter` of them.

    Both methods override the ones defined on the Callback base class
    (CallbackHandler is not an ancestor of this class).
    """

    def begin_fit(self,learn):
        # reuse the base-class behaviour (it stores `learn` on self), then add the counters
        super().begin_fit(learn)
        self.n_iters, self.max_iter = 0, 10
        # self.n_epochs_float = 0.
        return True

    def after_step(self):
        # No super() call here: the base after_step() only returns True,
        # so there is nothing to reuse.
        self.n_iters += 1
        print(self.n_iters)
        limit_reached = self.n_iters >= self.max_iter
        if limit_reached:
            # the training loop polls this flag through CallbackHandler.do_stop()
            self.learn.stop = True
        return True

## Refactor the training loop using callbacks

In [16]:
# process a single batch, consulting the callback handler between stages
def one_batch(xb, yb, callback):
    """Run forward, loss, backward, step and zero_grad for one batch.

    `callback` is the handler object: it supplies the learner (`callback.learn`)
    and its hook results gate each stage. A falsy result from begin_batch or
    after_loss abandons the batch; a falsy after_backward skips only the optimizer
    step, and a falsy after_step skips only the gradient reset.
    """
    if callback.begin_batch(xb, yb):
        # forward pass and loss for the current batch
        pred = callback.learn.model(xb)
        loss = callback.learn.loss_func(pred, yb)

        if callback.after_loss(loss):
            # backpropagation
            loss.backward()

            # parameter update
            if callback.after_backward():
                callback.learn.opt.step()

            # clear the gradients ready for the next batch
            if callback.after_step():
                callback.learn.opt.zero_grad()

# process every batch of a dataloader, consulting the callback handler after each one
def all_batches(dataloader, callback):
    """Feed each (xb, yb) pair from `dataloader` to one_batch().

    After every batch the handler's do_stop() is polled; a truthy answer ends
    the pass early. Returns None in either case.
    """
    for xb, yb in dataloader:
        one_batch(xb, yb, callback)
        stop_requested = callback.do_stop()
        if stop_requested:
            break

# training loop driven by a callback handler
def fit(learn, callback, n_epochs):
    """Train and validate `learn` for up to `n_epochs`, letting `callback` gate each stage.

    Nothing happens (not even after_fit) when begin_fit() answers falsy.
    An epoch whose begin_epoch() answers falsy is skipped entirely. The loop ends
    early when do_stop() is truthy or after_epoch() is falsy; after_epoch() is not
    consulted when do_stop() is already truthy. Returns None.
    """
    if callback.begin_fit(learn):
        for epoch in range(n_epochs):
            if callback.begin_epoch(epoch):
                # training pass
                all_batches(learn.data.train_dl, callback)

                # validation pass, without gradient tracking
                if callback.begin_validate():
                    with torch.no_grad():
                        all_batches(learn.data.valid_dl, callback)

                keep_going = not callback.do_stop() and callback.after_epoch()
                if not keep_going:
                    break

        # tell the callbacks that the loop has finished
        callback.after_fit()

In [17]:
# run a training loop with TestCallback()
# fit() expects a CallbackHandler, so the callback is wrapped in one.
# TestCallback prints its step counter and raises learn.stop once it reaches max_iter (10),
# which is why the output below ends at 10.
fit(learn, callback=CallbackHandler([TestCallback()]),n_epochs=1)

1
2
3
4
5
6
7
8
9
10


This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing the `callback` list to so many functions is a strong hint they should all be in the same class!

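### Why is learn.stop still `False`? The training loop should have set it to `True` before being interrupted.

It *was* set to `True`: `TestCallback.after_step()` raised the flag on the tenth iteration. But `all_batches()` then called `CallbackHandler.do_stop()`, which returns the flag and resets it to `False` in its `finally` clause, so by the time we inspect `learn.stop` it has already been cleared.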

In [18]:
# do_stop() hands back learn.stop and then resets it to False in its `finally` clause,
# so the flag raised by TestCallback.after_step() has already been cleared here
learn.stop

False

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5811)

## Refactor the Callback() class
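Each `Callback()` now holds a reference to the `Runner` that drives training, handed over via the `set_runner()` method.
We add a `name` property, an `_order` class attribute, a `run` attribute, and a `__getattr__` method.

Note: Python calls `__getattr__` only when normal attribute lookup fails, so any attribute a callback does not define itself (e.g. `self.model` or `self.in_train`) is looked up on the runner instead.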

In [19]:
#export
import re

# regular expressions that locate the word boundaries inside a CamelCase name
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    """Convert a CamelCase identifier to snake_case.

    Two passes insert an underscore at each word boundary, then the whole
    string is lower-cased; e.g. 'TrainEval' becomes 'train_eval'.
    """
    # pass 1: underscore before a capital that is followed by lower-case letters
    partly_split = _camel_re1.sub(r'\1_\2', name)
    # pass 2: underscore between a lower-case letter or digit and a capital
    fully_split = _camel_re2.sub(r'\1_\2', partly_split)
    return fully_split.lower()

# refactored Callback() class
class Callback():
    """Base class for callbacks driven by a Runner.

    A callback keeps a reference to its runner in `self.run` (set through
    set_runner()). Any attribute the callback does not define itself is looked
    up on the runner, so hooks can read e.g. `self.model` or `self.in_train`.
    To *write* runner state, assign through `self.run.<attr>`; a plain
    `self.<attr> = ...` would create an attribute on the callback instead.
    """

    # Runner.__call__ sorts callbacks by this value before dispatching an event
    _order=0

    def set_runner(self, run):
        """Attach the runner; until this is called there is no `run` attribute."""
        self.run=run

    def __getattr__(self, cb):
        # Python calls __getattr__ only after normal lookup has failed. If the name
        # being looked up is `run` itself, set_runner() has not been called yet;
        # delegating would evaluate self.run again and recurse until RecursionError,
        # which also breaks hasattr(). Fail with a regular AttributeError instead.
        if cb == 'run':
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'run' (call set_runner() first)")
        return getattr(self.run, cb)

    @property
    def name(self):
        """snake_case class name with a trailing 'Callback' removed.

        For a class named exactly 'Callback' the stripped name is empty, so
        'callback' is used instead.
        """
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')
       

### Example: CamelCase vs. snake_case callback names

In [20]:
# show camel2snake() on a sample class name: underscores go in at the word boundaries
# and the result is lower-cased
callback_name = 'IngeniousNewAmazingCallback'
print('camel case name is ',callback_name,'; snake case name is ',camel2snake(callback_name))

camel case name is  IngeniousNewAmazingCallback ; snake case name is  ingenious_new_amazing_callback


### How Callback() works now:

In [21]:
# instantiate a Callback() object; set_runner() has not been called, so it has no `run` yet
c1 = Callback()
try:
    print(c1.run)
# Catch only the errors a missing attribute can produce, so that e.g. KeyboardInterrupt
# is not swallowed. AttributeError is the usual signal; RecursionError is what a
# __getattr__ that delegates to self.run raises while `run` is still unset.
except (AttributeError, RecursionError):
    print('c1.run is undefined')

c1.run is undefined


In [22]:
# set_runner() stores whatever object it is given as c1.run. In real use that object
# is the Runner; here a TestCallback instance merely stands in to show the attribute being set.
c1.set_runner(TestCallback())
print('c1.run is now ',c1.run)

c1.run is now  <__main__.TestCallback object at 0x000002950146E6D8>


## Refactor the TestCallback() class 

This version just tests whether the number of iterations has reached the maximum allowed value after each parameter update step and if so returns `True`

In [23]:
class TestCallback(Callback):
    """Ask the Runner to stop once `max_iter` training steps have been taken.

    max_iter: number of optimizer steps to allow (default 10).
    """

    def __init__(self, max_iter=10):
        self.max_iter = max_iter

    def after_step(self):
        # `n_iter` is the step counter TrainEvalCallback keeps on the runner; it is
        # read here through Callback.__getattr__. It is incremented in after_batch,
        # i.e. after this hook has run, hence the +1 to count the current step.
        if self.n_iter + 1 >= self.max_iter:
            # Runner.all_batches checks this flag before taking the next batch
            self.run.stop = True
            # NOTE: a truthy result makes Runner.one_batch return before
            # opt.zero_grad(), so the gradients of this last step are left in place.
            return True

## Define a TrainEvalCallback Class
with methods to handle training and validation and to maintain count of iterations and of fraction of an epoch completed after each batch  

In [24]:
#export
class TrainEvalCallback(Callback):
    """Switch the model between train and eval mode and keep progress counters on the runner.

    Counters are written through `self.run.<attr>` so that they live on the runner;
    they are read back through Callback.__getattr__ (e.g. `self.in_train`).
    """

    def begin_fit(self):
        """Zero the progress counters at the start of a fit."""
        # n_epochs_float: epochs completed so far, including the fraction of the
        # current one (distinct from the runner's integer n_epochs)
        self.run.n_epochs_float=0.
        # n_iter: number of training batches processed so far
        self.run.n_iter=0

    def after_batch(self):
        """Advance the counters after a training batch; validation batches are ignored."""
        if not self.in_train:
            return
        # n_iters is set by Runner.all_batches to len(dataloader), so each batch
        # advances the epoch counter by 1/len(dataloader)
        self.run.n_epochs_float += 1./self.n_iters
        self.run.n_iter   += 1

    def begin_epoch(self):
        """Enter the training phase of a new epoch."""
        # Re-synchronise with the integer epoch index set by Runner.fit. This used to
        # assign n_epochs_float to itself, which did nothing.
        self.run.n_epochs_float=self.epoch
        # train() only switches the model to training mode
        self.model.train()
        self.run.in_train=True

    def begin_validate(self):
        """Enter the validation phase."""
        self.model.eval()
        self.run.in_train=False

### Constructing the callback name

In [25]:
# the name property strips a trailing 'Callback' from the class name
# and converts the rest to snake case
print('TrainEvalCallback().name is ',TrainEvalCallback().name)

TrainEvalCallback().name is  train_eval


### Helper function to convert any input to a list (for export)

In [26]:
#export
from typing import *

# function to convert any input into a list
def listify(o):
    """Return `o` as a list.

    None gives []; a list is returned unchanged (the same object, not a copy);
    a string or any non-iterable value is wrapped in a one-element list;
    every other iterable is expanded with list().
    """
    if o is None:
        result = []
    elif isinstance(o, list):
        result = o
    elif isinstance(o, str) or not isinstance(o, Iterable):
        # strings are iterable but are treated as a single item
        result = [o]
    else:
        result = list(o)
    return result

## The Runner Class -- to be implemented in fastai v2.0

In [27]:
#export
class Runner():
    """Training loop that announces each stage to a list of callbacks.

    Calling the runner with an event name (`self('begin_batch')`) invokes that
    hook on every callback that defines it. A truthy hook result means
    "skip / stop here": the stage that would follow is not run.

    callbacks: callback object(s) to use as given.
    callback_funcs: zero-argument factories; each is called once and the
        resulting callback is also stored on the runner under `callback.name`.
    A TrainEvalCallback is always placed first in the callback list.
    """

    def __init__(self, callbacks=None, callback_funcs=None):
        # listify() returns a list argument unchanged, so copy it: the appends
        # below must not modify a list that belongs to the caller
        callbacks = list(listify(callbacks))
        for callback_func in listify(callback_funcs):
            callback = callback_func()
            # expose the callback as an attribute, e.g. run.avg_stats
            setattr(self, callback.name, callback)
            callbacks.append(callback)
        self.stop,self.callbacks = False,[TrainEvalCallback()]+callbacks

    # convenience accessors for the parts of the learner set in fit()
    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def one_batch(self, xb, yb):
        """Process one batch; each stage can be cut short by a truthy hook result."""
        self.xb,self.yb = xb,yb
        if self('begin_batch'):
            return
        # forward pass
        self.pred = self.model(self.xb)
        if self('after_pred'):
            return
        # loss
        self.loss = self.loss_func(self.pred, self.yb)
        # while validating (in_train is False) stop here: no backward pass, no update
        if self('after_loss') or not self.in_train:
            return
        # backpropagation
        self.loss.backward()
        if self('after_backward'):
            return
        # parameter update
        self.opt.step()
        if self('after_step'):
            return
        # clear the gradients ready for the next batch
        self.opt.zero_grad()

    def all_batches(self, dataloader):
        """Run one_batch() over `dataloader`, firing 'after_batch' after each batch."""
        # TrainEvalCallback.after_batch uses n_iters to turn one batch into a fraction
        # of an epoch. The epoch counter itself (n_epochs_float) is owned by
        # TrainEvalCallback and is deliberately not reset here.
        self.n_iters = len(dataloader)
        for xb,yb in dataloader:
            # a callback may have requested a stop during the previous batch
            if self.stop:
                break
            self.one_batch(xb, yb)
            self('after_batch')
        # clear the request so that the next pass starts normally
        self.stop=False

    def fit(self, learn, n_epochs):
        """Train `learn` for up to `n_epochs` epochs, validating after each one.

        'after_fit' is fired and the learner reference dropped even when an
        exception escapes or a callback cancels the run in 'begin_fit'.
        """
        self.n_epochs,self.learn = n_epochs,learn

        try:
            # give every callback access to this runner
            for callback in self.callbacks:
                callback.set_runner(self)
            if self('begin_fit'):
                return
            for epoch in range(n_epochs):
                self.epoch = epoch

                # training phase
                if not self('begin_epoch'):
                    self.all_batches(self.data.train_dl)

                # validation phase, without gradient tracking
                with torch.no_grad():
                    if not self('begin_validate'):
                        self.all_batches(self.data.valid_dl)

                # a truthy 'after_epoch' result ends training early
                if self('after_epoch'):
                    break

        finally:
            self('after_fit')
            # drop the reference to the learner
            self.learn = None

    def __call__(self, cb_name):
        """Fire event `cb_name`; return True as soon as one callback's hook returns truthy.

        Callbacks are visited in ascending `_order`. A callback without a method
        of that name is skipped. Callbacks after the first truthy result are not
        called. Returns False when no hook returned truthy.
        """
        for callback in sorted(self.callbacks, key=lambda x: x._order):
            # None when the callback (and, through its __getattr__, the runner) has no such attribute
            f = getattr(callback, cb_name, None)
            if f and f():
                return True
        return False

## AvgStatsCallback computes the average loss and metrics (here, accuracy)

In [28]:
#export
class AvgStats():
    """Running totals of the loss and of each metric, weighted by batch size.

    metrics: one metric function or a list of them, each called as metric(pred, yb).
    in_train: True for training statistics, False for validation; only affects the repr label.
    """

    def __init__(self, metrics, in_train):
        self.metrics,self.in_train = listify(metrics),in_train
        # start from zeroed totals so that repr() and all_stats work even before
        # the first explicit reset()
        self.reset()

    def reset(self):
        """Zero the loss total, the sample count and every metric total."""
        self.total_loss,self.count = 0.,0
        self.total_metrics = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """[total loss, *total metrics], with the loss as a plain Python number."""
        # total_loss is the float 0. after reset() and takes the type of run.loss
        # once accumulate() has run (a tensor in this notebook); only the latter has .item()
        total_loss = self.total_loss
        if hasattr(total_loss, 'item'):
            total_loss = total_loss.item()
        return [total_loss] + self.total_metrics

    @property
    def avg_stats(self):
        """all_stats divided by the number of samples seen.

        Raises ZeroDivisionError when nothing has been accumulated yet.
        """
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        # nothing to report until at least one batch has been accumulated
        if not self.count:
            return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        """Add the current batch of `run` (its xb, yb, pred and loss) to the totals."""
        # number of samples in this batch
        n_samples_in_batch = run.xb.shape[0]
        # weight the batch loss and metrics by the batch size, so that dividing
        # by the sample count later gives a per-sample average
        self.total_loss += run.loss * n_samples_in_batch
        self.count += n_samples_in_batch
        for i,metric in enumerate(self.metrics):
            self.total_metrics[i] += metric(run.pred, run.yb) * n_samples_in_batch

class AvgStatsCallback(Callback):
    """Collect average loss and metrics, separately for training and validation, and print them per epoch."""

    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, in_train=True)
        self.valid_stats = AvgStats(metrics, in_train=False)

    def begin_epoch(self):
        """Start each epoch with zeroed totals."""
        for stats in (self.train_stats, self.valid_stats):
            stats.reset()

    def after_loss(self):
        """Add the current batch to the totals of whichever phase is running."""
        # in_train is read from the runner through Callback.__getattr__
        if self.in_train:
            stats = self.train_stats
        else:
            stats = self.valid_stats
        # bookkeeping only, so gradient tracking is switched off
        with torch.no_grad():
            stats.accumulate(self.run)

    def after_epoch(self):
        """Print the training line, then the validation line."""
        print(self.train_stats)
        print(self.valid_stats)

In [29]:
# build a fresh model and optimizer (get_model returns the pair) and wrap them in a new Learner
learn = Learner(*get_model(data), loss_func, data)

In [30]:
# collect average loss plus one metric, accuracy
stats = AvgStatsCallback([accuracy])
# instantiate a Runner object; callbacks passed this way are used as given
# (only those created from callback_funcs are also stored as runner attributes)
run = Runner(callbacks=stats)

In [31]:
# run a training loop; AvgStatsCallback prints one train line and one valid line per epoch
run.fit(learn, n_epochs=2)

train: [0.31240044921875, tensor(0.9034)]
valid: [0.21252998046875, tensor(0.9320)]
train: [0.1383840625, tensor(0.9579)]
valid: [0.1312119873046875, tensor(0.9598)]


In [32]:
# avg_stats is [average loss, *average metrics]; with accuracy as the only metric
# it unpacks into exactly two values, taken here from the last epoch's validation pass
loss,acc = stats.valid_stats.avg_stats
assert acc>0.9
loss,acc

(0.1312119873046875, tensor(0.9598))

In [33]:
#export
from functools import partial

In [34]:
# Runner calls each callback_func with no arguments, so bind `accuracy` as the
# metrics argument up front to get a zero-argument factory
acc_callback_func = partial(AvgStatsCallback,accuracy)

In [35]:
# the Runner instantiates the callback itself and stores it under its snake-case
# name, so it is reachable afterwards as run.avg_stats
run = Runner(callback_funcs=acc_callback_func)

In [36]:
# `learn` is the learner already trained for two epochs above, so this continues its training
run.fit(learn, n_epochs=1)

train: [0.1027776171875, tensor(0.9687)]
valid: [0.111236767578125, tensor(0.9683)]


Using Jupyter means we can get tab-completion even for dynamic code like this! :)

In [37]:
# run.avg_stats is the AvgStatsCallback the Runner created; show its validation averages as a list
run.avg_stats.valid_stats.avg_stats

[0.111236767578125, tensor(0.9683)]

In [38]:
# displaying the AvgStats object itself goes through its __repr__, which adds the 'valid: ' label
run.avg_stats.valid_stats

valid: [0.111236767578125, tensor(0.9683)]

## Export

In [40]:
# convert this notebook into a Python module (the output line below shows the target path);
# presumably notebook2script collects the cells tagged `#export` -- TODO confirm
!python notebook2script.py 04_callbacks_jcat.ipynb

Converted 04_callbacks_jcat.ipynb to exp\nb_04.py
