# Building a MNIST data processing pipeline using callbacks

## Set up environment

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Import all code from previous notebooks

In [2]:
#export
from exp.nb_03 import *

## Get the MNIST training and validation data sets

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=4799)

In [3]:
# Load the MNIST tensors and wrap each split in a Dataset.
x_train, y_train, x_valid, y_valid = get_data()
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)

# Hyperparameters: hidden-layer width and minibatch size.
n_hidden = 50
batch_size = 64

# Number of output classes: largest label plus one
# (assumes labels start at 0, as with the MNIST digits 0-9).
n_out = y_train.max().item() + 1

# Loss function used for every model in this notebook.
loss_func = F.cross_entropy

In [4]:
# data set properties x and y come from Dataset
# The Dataset exposes its tensors as the attributes `x` and `y`; show one shape per line.
print(valid_ds.x.shape, valid_ds.y.shape, sep='\n')

torch.Size([10000, 784])
torch.Size([10000])


# 1. Improving the fit() function

Currently, a call to `fit()` has many inputs

`fit(model, loss_func, opt, train_dl, valid_dl, n_epochs, )`

We can simplify the call to `fit()` to look like this:

`fit(learner, n_epochs)`

where `learner` is a `Learner` Class object, that we will define as a container for the `model`, `loss_func`, `opt`, `train_dl`, `valid_dl` inputs to be passed into `fit()`

This will allow us to tweak what's happening inside the training loop in other places of the code. Because the `Learner` Class object will be mutable, changing any of its attributes elsewhere will modify our training loop.

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5363)

### The DataBunch class
For convenience we define a `DataBunch` class as a container for the two `DataLoader` objects and, optionally, `n_out`, the number of output classes.
`DataBunch` takes `train_dl`, `valid_dl`, and optionally `n_out` as inputs.
It exposes the underlying datasets through the properties `train_ds` and `valid_ds`.

In [5]:
#export
class DataBunch():
    """Bundle the training and validation DataLoaders with an optional class count."""

    def __init__(self, train_dl, valid_dl, n_out=None):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.n_out = n_out

    @property
    def train_ds(self):
        """The `dataset` attribute of the training DataLoader."""
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        """The `dataset` attribute of the validation DataLoader."""
        return self.valid_dl.dataset
    

### The Learner Class

In [6]:
#export
# instantiates the model and the optimizer, given the data and parameters
def get_model(data, learning_rate=0.5, n_hidden = 50):
    """Build a two-layer MLP sized to `data`, plus an SGD optimizer over its parameters.

    The input width is the second dimension of `data.train_ds.x`; the output width is
    `data.n_out`. Returns the tuple `(model, optimizer)`.
    """
    # `nn` and `optim` are not imported in this notebook directly; presumably they
    # arrive through `from exp.nb_03 import *` -- TODO confirm.
    in_features = data.train_ds.x.shape[1]
    layers = [
        nn.Linear(in_features, n_hidden),
        nn.ReLU(),
        nn.Linear(n_hidden, data.n_out),
    ]
    model = nn.Sequential(*layers)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    return model, optimizer

In [7]:
#export
# the Learner() class is a container for the model, optimization, loss function and data 
class Learner():
    """Plain container grouping the model, optimizer, loss function and data."""

    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

## 1A. Refactor the fit() function according to plan

In [8]:
# refactor the fit() function according to our plan
#     Note the order of the inputs has been switched in order to allow n_epochs to be a keyword argument
def fit(learn,n_epochs):
    """Train `learn.model` for `n_epochs` epochs, validating after each one.

    Parameters
    ----------
    learn : object exposing `model`, `opt`, `loss_func` and `data` (with `train_dl`
        and `valid_dl`), e.g. a `Learner`.
    n_epochs : int
        Number of passes over the training DataLoader.

    Returns
    -------
    tuple
        `(avg_loss, avg_acc)` from the last epoch's validation pass, or
        `(None, None)` when `n_epochs` is 0 and no epoch ran.
    """
    # Bound up front so the final `return` is valid even if the loop body never runs.
    avg_loss, avg_acc = None, None

    for epoch in range(n_epochs):

        # training phase
        learn.model.train()
        for xb,yb in learn.data.train_dl:
            pred = learn.model(xb)
            loss = learn.loss_func(pred, yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        # validation phase: eval() puts the model in evaluation mode and no_grad()
        # turns off gradient tracking while the totals are accumulated
        learn.model.eval()
        with torch.no_grad():
            tot_loss,tot_acc = 0.,0.
            for xb,yb in learn.data.valid_dl:
                pred = learn.model(xb)
                tot_loss += learn.loss_func(pred, yb)
                tot_acc  += accuracy(pred,yb)

        # len(valid_dl) is the number of batches, so these are means of per-batch
        # values: a smaller final batch weighs as much as a full one.
        n_valid = len(learn.data.valid_dl)
        avg_loss, avg_acc = tot_loss/n_valid, tot_acc/n_valid
        print(epoch, avg_loss ,avg_acc )

    return avg_loss, avg_acc

## 1B. Prepare for minibatch training

### package up the data into a DataBunch object

In [9]:
# Bundle the DataLoaders into a DataBunch.
# The `*` unpacks whatever get_dls() returns (presumably the training and validation
# DataLoaders -- verify in nb_03) into DataBunch's leading positional arguments.
batch_size = 64
n_out = y_train.max().item() + 1
data = DataBunch(*get_dls(train_ds, valid_ds, batch_size=batch_size), n_out=n_out)

# The training Dataset, and its x / y tensors, are reachable through the DataBunch.
print(data.train_ds.x.shape, data.train_ds.y.shape, sep='\n')

torch.Size([50000, 784])
torch.Size([50000])


### lengths of the training and validation dataloaders are the number of batches in the training and validation data sets, respectively

In [10]:
# Number of batches in each loader.
print(f"{len(data.train_dl)} {len(data.valid_dl)}")
# Batches times batch size gives an upper bound on the number of samples. The validation
# loader is scaled by 2*batch_size -- presumably get_dls() doubles its batch size; verify in nb_03.
print(f"{len(data.train_dl)*batch_size} {len(data.valid_dl)*2*batch_size}")

782 79
50048 10112


### package up the model, opt, loss_func, and data in a Learner Object

In [11]:
# Build a fresh model/optimizer pair and wrap them, the loss function and the data in a Learner.
learn = Learner(*get_model(data), loss_func=loss_func, data=data)

In [12]:
# Both train() and eval() are called before anything is printed; the captured output shows
# the Sequential layer list for each. eval() runs last.
print(learn.model.train(), learn.model.eval(), sep='\n')

Sequential(
  (0): Linear(in_features=784, out_features=50, bias=True)
  (1): ReLU()
  (2): Linear(in_features=50, out_features=10, bias=True)
)
Sequential(
  (0): Linear(in_features=784, out_features=50, bias=True)
  (1): ReLU()
  (2): Linear(in_features=50, out_features=10, bias=True)
)


## 1C. Instantiate a minibatch training loop and train the model!

In [13]:
# Train for one epoch with the refactored fit(). It prints the epoch index together with the
# validation loss and accuracy, and returns those two averages.
loss,acc = fit(learn, n_epochs = 1)

0 tensor(0.1808) tensor(0.9437)


# 2. Adding callbacks 
Callbacks are functions that can be set to initiate desired sequences of actions at various stages in the minibatch training process. They will allow us to simplify the training loop, and provide added flexibility.

Here is our training loop (without validation) from the previous notebook, with the inner loop contents factored out:

```python
def one_batch(xb,yb):
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
    
def fit():
    for epoch in range(n_epochs):
        for b in train_dl: one_batch(*b)
```

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5628)

## The CallBack() Class

Is a container for a set of nine basic callback methods.

Each method is associated with a well-defined stage in the training process, and returns `True` 


In [14]:
class Callback():
    """Base callback: one hook per stage of the training loop, each returning True.

    Hooks that receive data store it on the instance so subclasses can read it later.
    """

    # --- whole fit ---
    def begin_fit(self, learn):
        """Remember the learner being trained."""
        self.learn = learn
        return True

    def after_fit(self):
        return True

    # --- per epoch ---
    def begin_epoch(self, epoch):
        """Remember the index of the epoch that is starting."""
        self.epoch_number = epoch
        return True

    def begin_validate(self):
        return True

    def after_epoch(self):
        return True

    # --- per batch ---
    def begin_batch(self, xb, yb):
        """Remember the current batch."""
        self.xb = xb
        self.yb = yb
        return True

    def after_loss(self, loss):
        """Remember the loss just computed for the current batch."""
        self.loss = loss
        return True

    def after_backward(self):
        return True

    def after_step(self):
        return True

## The CallbackHandler Class

is a container for a collection of methods comprised of

    (1) a "wrapper" for each method in the Callback() class 

    (2) a do_stop() method that can set the learn.stop flag
    
takes a list of Callback() objects  as input

each "wrapper" calls the matching method on every callback and returns a combined state flag, as follows:
    - optionally set some object properties and/or run some methods
    - initialize the state flag (usually to `True`; `after_fit` starts from `not in_train` and `after_loss` starts from `in_train`)
    - for each Callback() object in the input list, in order:
         if the state flag is still truthy, execute the callback's method and store its return value in the state flag
         if the state flag is already falsy, skip the callback (the `and` short-circuits), so the flag stays falsy
    - return the state flag 

In [15]:
class CallbackHandler():
    """Fan each training-stage event out to a list of callbacks and combine their answers.

    Every stage method returns a state flag computed as `state and callback.<stage>(...)`
    over the callbacks in list order. Because of the `and`, callbacks after the first
    falsy answer are not called for that stage.
    """

    def __init__(self,callbacks=None):
        self.callbacks = callbacks or []

    def _dispatch(self, hook_name, state, *args):
        """Fold `state and callback.<hook_name>(*args)` over the callbacks, starting from `state`."""
        for cb in self.callbacks:
            state = state and getattr(cb, hook_name)(*args)
        return state

    def begin_fit(self, learn):
        """Store the learner, mark training mode and clear the stop flag before notifying callbacks."""
        self.learn = learn
        self.in_train = True
        learn.stop = False
        return self._dispatch('begin_fit', True, learn)

    def after_fit(self):
        # Starts from `not in_train`, so this is falsy unless begin_validate() ran last.
        return self._dispatch('after_fit', not self.in_train)

    def begin_epoch(self, epoch):
        """Switch the model to training mode, then notify callbacks."""
        self.learn.model.train()
        self.in_train = True
        return self._dispatch('begin_epoch', True, epoch)

    def begin_validate(self):
        """Switch the model to evaluation mode, then notify callbacks (with debug output)."""
        self.learn.model.eval()
        self.in_train = False
        state = True
        for cb in self.callbacks:
            state = state and cb.begin_validate()
            print('begin_validate state ',state)
            print('self.learn.stop ',self.learn.stop)
        return state

    def after_epoch(self):
        return self._dispatch('after_epoch', True)

    def begin_batch(self, xb, yb):
        return self._dispatch('begin_batch', True, xb, yb)

    def after_loss(self, loss):
        # Starts from `in_train`: during validation this is falsy and no callback is called.
        return self._dispatch('after_loss', self.in_train, loss)

    def after_backward(self):
        return self._dispatch('after_backward', True)

    def after_step(self):
        return self._dispatch('after_step', True)

    def do_stop(self):
        """Return the learner's stop flag and reset it to False in the same call."""
        try:
            return self.learn.stop
        finally:
            self.learn.stop = False

## A TestCallback Class

In [16]:
class TestCallback(Callback):
    """Callback that requests a stop after a fixed number of parameter updates.

    Only begin_fit() and after_step() are overridden; every other hook is inherited
    from the Callback base class (not from CallbackHandler, which is unrelated by
    inheritance).
    """

    def begin_fit(self,learn):
        """Store the learner through the base class, then zero the counters."""
        super().begin_fit(learn)
        self.max_iter = 10
        self.n_iter = 0
        self.n_epoch_float = 0.
        self.n_batch = 0
        return True

    def after_step(self):
        """Count one parameter update, set `learn.stop` once `max_iter` is reached, and report.

        The base-class after_step() only returns True, so it is not called here.
        Nothing in this class changes `n_batch`, so the printed batch number stays
        at the 0 set in begin_fit().
        """
        self.n_iter += 1
        limit_reached = self.n_iter >= self.max_iter
        if limit_reached:
            self.learn.stop = True
        print('parameter update iteration number ',self.n_iter,'batch_number ',self.n_batch,' learn.stop = ',self.learn.stop)
        return True

## Refactor the training loop using callbacks

In [17]:
# process a single batch, with an input list of callbacks
def one_batch(xb, yb, callback):
    """Process one batch, consulting `callback` before each step.

    Sequence: begin_batch -> forward pass and loss -> after_loss -> backward ->
    after_backward -> optimizer step -> after_step -> zero_grad. A falsy answer from
    begin_batch or after_loss ends the batch early; a falsy answer from after_backward
    or after_step only skips the single step that follows it.
    """
    if not callback.begin_batch(xb,yb):
        print('exiting from one_batch due to begin_batch...')
        return

    # forward pass, then the loss for this batch
    predictions = callback.learn.model(xb)
    loss = callback.learn.loss_func(predictions, yb)

    print('executing one_batch...')

    # a falsy after_loss means no backpropagation for this batch
    if not callback.after_loss(loss):
        print('exiting from one_batch due to after_loss...')
        return
    loss.backward()

    # parameter update, unless a callback vetoes it
    may_step = callback.after_backward()
    if may_step:
        callback.learn.opt.step()

    # clear the gradients for the next batch, unless a callback vetoes it
    may_zero = callback.after_step()
    if may_zero:
        callback.learn.opt.zero_grad()
        
# process the entire dataset, with an input list of callbacks
#   (i.e. loop through all the batches)
def all_batches(dataloader, callback):
    """Run one_batch() on successive batches until the loader is exhausted or a stop is requested."""
    for batch in dataloader:
        inputs, targets = batch
        one_batch(inputs, targets, callback)

        # do_stop() is consulted after every batch; a truthy answer ends the pass early
        if callback.do_stop():
            break

# training loop with an input callback
def fit(learn, callback, n_epochs):
    """Callback-driven training loop.

    Does nothing (and does not call after_fit) when begin_fit() answers falsy. An epoch
    whose begin_epoch() answers falsy is skipped entirely, including its stop checks.
    Otherwise each epoch runs a training pass, then a validation pass under
    torch.no_grad() if begin_validate() allows it, and the loop ends early when
    do_stop() is truthy or after_epoch() is falsy. after_fit() is called once the loop
    is over.
    """
    if not callback.begin_fit(learn):
        return

    for epoch in range(n_epochs):
        if callback.begin_epoch(epoch):
            # training pass
            all_batches(learn.data.train_dl, callback)

            # validation pass, without gradient tracking
            if callback.begin_validate():
                with torch.no_grad():
                    all_batches(learn.data.valid_dl, callback)

            # stop requested, or a callback asked not to continue
            if callback.do_stop() or not callback.after_epoch():
                break

    callback.after_fit()

In [18]:
# Run the callback-driven training loop with a TestCallback wrapped in a CallbackHandler.
# Q: Why are there only 10 parameter update iterations?
# A: Each call to one_batch() performs one parameter update, and TestCallback.after_step()
#    counts them. When n_iter reaches max_iter (10) it sets learn.stop; all_batches() sees
#    that through callback.do_stop() and leaves the training loop, so only the first
#    10 training batches are processed. fit() then moves on to begin_validate().
fit(learn, callback=CallbackHandler([TestCallback()]),n_epochs=1)

executing one_batch...
parameter update iteration number  1 batch_number  0  learn.stop =  False
executing one_batch...
parameter update iteration number  2 batch_number  0  learn.stop =  False
executing one_batch...
parameter update iteration number  3 batch_number  0  learn.stop =  False
executing one_batch...
parameter update iteration number  4 batch_number  0  learn.stop =  False
executing one_batch...
parameter update iteration number  5 batch_number  0  learn.stop =  False
executing one_batch...
parameter update iteration number  6 batch_number  0  learn.stop =  False
executing one_batch...
parameter update iteration number  7 batch_number  0  learn.stop =  False
executing one_batch...
parameter update iteration number  8 batch_number  0  learn.stop =  False
executing one_batch...
parameter update iteration number  9 batch_number  0  learn.stop =  False
executing one_batch...
parameter update iteration number  10 batch_number  0  learn.stop =  True
begin_validate state  True
sel

This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing the `callback` list to so many functions is a strong hint they should all be in the same class!

### Q: Why is learn.stop still `False`? The training loop set it to `True` before exiting.
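### A: because `do_stop()` resets `learn.stop` to `False` (in its `finally` clause) every time it reads the flag, so the flag was already cleared when the training loop exited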

In [19]:
# do_stop() clears learn.stop (in its `finally`) each time it reads it, so the flag is False again here
learn.stop

False

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5811)

# 3. Refactor the Callback() class
`Callback()` now holds a reference to the object that runs the training loop (the runner), which is attached via the `set_runner()` method.

We add a `name` property, an `_order` class attribute, a `run` attribute, and a `__getattr__` method that forwards any attribute the callback does not define itself to the runner.

In [20]:
#export
import re

# helper function uses regular expressions to transform a CamelCase callback name to snake_case
# First pass: split before a capitalised word (an upper-case letter followed by lower-case letters).
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
# Second pass: split where a lower-case letter or digit is directly followed by an upper-case letter.
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    """Convert a CamelCase identifier to snake_case, e.g. 'TrainEval' -> 'train_eval'."""
    underscored = name
    for pattern in (_camel_re1, _camel_re2):
        underscored = pattern.sub(r'\1_\2', underscored)
    return underscored.lower()

# refactored Callback() class
class Callback():
    """Base class for callbacks that are driven by a runner object.

    A callback defines only the stage methods it cares about. Any attribute it does not
    define itself is looked up on the runner attached with set_runner(), so a callback
    can read runner state as if it were its own.
    """

    # Sort key for callbacks: Runner.__call__ invokes them in ascending `_order`.
    _order=0

    def set_runner(self, run):
        """Attach the object that unresolved attribute lookups are delegated to."""
        self.run=run

    def __getattr__(self, callback_name):
        """Delegate a lookup that failed on the callback itself to `self.run`.

        Python calls `__getattr__` only after normal attribute lookup has failed.

        Raises
        ------
        AttributeError
            If no runner has been attached yet, or the runner lacks the attribute.
        """
        # If `run` itself is what is missing, reading `self.run` below would re-enter this
        # method endlessly and end in RecursionError; fail with a plain AttributeError instead.
        if callback_name == 'run':
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'run'; call set_runner() first")
        return getattr(self.run, callback_name)

    @property
    def name(self):
        """snake_case class name with a trailing 'Callback' removed.

        For the base class the stripped name is empty, so 'callback' is used instead.
        """
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')
       

## How Callback() works now:

In [21]:
# Instantiate a Callback() object. No runner has been attached yet, so `run` cannot be resolved.
c1 = Callback()
try:
    print(c1.run)
except (AttributeError, RecursionError):
    # Catch only the failed-lookup errors (a missing attribute, or __getattr__ re-entering
    # itself while looking for `run`) instead of swallowing every exception.
    print('c1.run is undefined')

c1.run is undefined


In [22]:
# Attach a runner: set_runner() stores its argument as c1.run.
# A TestCallback instance is used here simply as an object to delegate to.
c1.set_runner(TestCallback())
print('c1.run is now ',c1.run)

c1.run is now  <__main__.TestCallback object at 0x0000027228BEA048>


In [23]:
# Check the __getattr__ method: c1 has no after_step of its own, so the lookup is
# forwarded to c1.run and comes back as that object's bound method.
f = getattr(c1, 'after_step', None)
print(f)

<bound method TestCallback.after_step of <__main__.TestCallback object at 0x0000027228BEA048>>


In [24]:
# Calling __getattr__ directly skips normal lookup and goes straight to the delegation step;
# it gives the same bound method here because c1 does not define after_step itself.
g = Callback.__getattr__(c1,'after_step')
print(g)

<bound method TestCallback.after_step of <__main__.TestCallback object at 0x0000027228BEA048>>


## Example 1: Transform a CamelCase string to snake_case

In [25]:
# Convert a CamelCase name to snake_case with camel2snake().
CamelCaseString = 'IngeniousNewAmazingTrick'
print('CamelCase string is ',CamelCaseString,'; snake_case is ',camel2snake(CamelCaseString))

CamelCase string is  IngeniousNewAmazingTrick ; snake_case is  ingenious_new_amazing_trick


## Refactor the TestCallback() class 

after_step checks whether the number of iterations has reached the maximum allowed value, and if so returns `True`

In [26]:
class TestCallback(Callback):
    """Callback whose after_step() hook turns truthy once enough iterations have run."""

    # Iteration threshold. Kept as a class attribute so it resolves on the callback itself
    # rather than being delegated to the runner by Callback.__getattr__.
    max_iter = 10

    def after_step(self):
        """Report progress and return True once the runner's `n_iter` has reached `max_iter`.

        `n_iter` and `n_batch` are not attributes of this callback; they are read from
        the attached runner through Callback.__getattr__. The runner registers only
        callbacks built from `callback_funcs` as named attributes, so the counter is
        read directly instead of through `self.train_eval`.
        """
        print('parameter update iteration number ',self.n_iter,'batch_number ',self.n_batch)

        return self.n_iter >= self.max_iter

## Define a TrainEvalCallback Class
This is just a collection of methods to handle training and validation and to maintain count of iterations and of fraction of an epoch completed after each batch  

In [27]:
#export
class TrainEvalCallback(Callback):
    """Keep the runner's progress counters and train/eval mode up to date.

    All state is written to `self.run` (the runner) so other callbacks can read it.
    Attributes read here but not defined on the callback (`in_train`, `n_batches`,
    `epoch_number`, `model`) are resolved on the runner via Callback.__getattr__.
    """

    def begin_fit(self):
        """Zero the counters at the start of a fit."""
        # n_epoch_float tracks the fractional number of elapsed epochs
        self.run.n_epoch_float = 0.
        self.run.n_batch = 0
        self.run.n_iter = 0

    def after_batch(self):
        """Advance the counters after a training batch; validation batches are not counted."""
        if not self.in_train:
            return
        # each batch represents a fraction 1/n_batches of an epoch
        self.run.n_epoch_float += 1./self.n_batches
        self.run.n_batch   += 1
        # n_iter was initialised in begin_fit but never advanced; count training batches here
        self.run.n_iter    += 1

    def begin_epoch(self):
        """Put the model in training mode and mark the runner as training."""
        # Snap the fractional counter to the whole epoch number so rounding error from the
        # repeated `+= 1/n_batches` does not carry over into the next epoch.
        self.run.n_epoch_float = float(self.epoch_number)
        self.model.train()
        self.run.in_train=True

    def begin_validate(self):
        """Put the model in evaluation mode and mark the runner as validating."""
        self.model.eval()
        self.run.in_train=False

## Example 2: Constructing the callback name
By convention, a CamelCase callback name contains a suffix, 'Callback'.

We remove the suffix, then transform to snake_case

In [28]:
# The `name` property strips the 'Callback' suffix from the class name, then converts to snake_case.
print('TrainEvalCallback().name is ',TrainEvalCallback().name)

TrainEvalCallback().name is  train_eval


# 4. The Runner Class, a hackable training/validation module
to be implemented in fastai v2.0

In [29]:
#export
from typing import *

# helper function to convert any input into a list
def listify(o):
    """Coerce `o` to a list.

    None becomes an empty list, a list is returned unchanged (the same object),
    a string or any non-iterable value is wrapped in a one-element list, and any
    other iterable is expanded into a new list.
    """
    if o is None:
        return []
    if isinstance(o, list):
        return o
    # strings are iterable, but should be treated as a single item
    if isinstance(o, str) or not isinstance(o, Iterable):
        return [o]
    return list(o)


class Runner():
    """Callback-driven training/validation loop.

    The runner owns the loop structure (`fit` -> `all_batches` -> `one_batch`). Between
    steps it calls itself with a stage name, e.g. `self('after_loss')`, which runs the
    same-named method on every callback that provides one. A truthy answer from any
    callback makes the loop skip whatever would have come next.
    """

    def __init__(self, callbacks=None, callback_funcs=None):
        """Collect the callbacks to run.

        callbacks: a ready-made callback instance, or a list of them.
        callback_funcs: a zero-argument factory, or a list of factories. Each one is
            called here, and the callback it returns is also stored as an attribute of
            the runner under its snake_case `name`.
        """
        # Copy the list: listify() hands a list argument back unchanged, so appending to it
        # below would otherwise modify the caller's own list.
        callbacks = list(listify(callbacks))
        for callback_func in listify(callback_funcs):
            callback = callback_func()
            setattr(self, callback.name, callback)
            callbacks.append(callback)
        # the stop flag starts cleared; a TrainEvalCallback always heads the list
        self.stop,self.callbacks = False,[TrainEvalCallback()]+callbacks

    # Read-only shortcuts to the Learner currently being fitted.
    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def one_batch(self, xb, yb):
        """Process one batch; a truthy callback answer at any stage ends the batch there.

        Outside training (`in_train` falsy) the batch ends after the loss is computed,
        so nothing is back-propagated and no parameters are updated.
        """
        self.xb,self.yb = xb,yb
        if self('begin_batch'):
            return
        # forward pass
        self.pred = self.model(self.xb)
        if self('after_pred'):
            return
        # loss for this batch
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train:
            return
        # backpropagation
        self.loss.backward()
        if self('after_backward'):
            return
        # parameter update
        self.opt.step()
        if self('after_step'):
            return
        # zero the gradients to prepare for the next batch
        self.opt.zero_grad()

    def all_batches(self, dataloader):
        """Run one_batch() over `dataloader`, stopping early if `self.stop` has been set."""
        # number of batches in this pass, for callbacks that track progress
        self.n_batches = len(dataloader)
        for xb,yb in dataloader:
            if self.stop:
                break
            self.one_batch(xb, yb)
            self('after_batch')
        # clear the flag so the next pass starts fresh
        self.stop=False

    def fit(self, learn, n_epochs):
        """Train and validate `learn` for `n_epochs` epochs.

        Each callback is first given a reference to this runner. The 'after_fit' stage
        always runs and the learner reference is always dropped, even when an exception
        escapes the loop.
        """
        self.n_epochs,self.learn = n_epochs,learn

        try:
            for callback in self.callbacks:
                callback.set_runner(self)
            if self('begin_fit'):
                return
            for epoch_number in range(n_epochs):
                self.epoch_number = epoch_number

                # training phase
                if not self('begin_epoch'):
                    self.all_batches(self.data.train_dl)

                # validation phase, without gradient tracking
                with torch.no_grad():
                    if not self('begin_validate'):
                        self.all_batches(self.data.valid_dl)
                # a truthy 'after_epoch' answer ends training early
                if self('after_epoch'):
                    break

        finally:
            self('after_fit')
            # release the Learner
            self.learn = None

    def __call__(self, callback_name):
        """Run stage `callback_name` on the callbacks, in ascending `_order`.

        Callbacks that do not provide the stage are skipped. Returns True as soon as one
        callback's method returns a truthy value (later callbacks are then not called
        for this stage), otherwise False.
        """
        for callback in sorted(self.callbacks, key=lambda x: x._order):
            f = getattr(callback, callback_name, None)
            if f and f():
                return True
        return False

## AvgStatsCallback computes loss and metrics, such as accuracy

In [30]:
#export
class AvgStats():
    """Running, sample-weighted totals of the loss and metrics for one phase (train or valid)."""

    def __init__(self, metrics, in_train):
        """
        metrics: a metric function, or a list of them, each called as `metric(pred, yb)`.
        in_train: True for training statistics, False for validation statistics.
        """
        self.metrics,self.in_train = listify(metrics),in_train
        # Start from zeroed totals so repr() and the stats properties are usable before
        # the first explicit reset().
        self.reset()

    def reset(self):
        """Zero the loss total, the sample count and every metric total."""
        # count keeps track of total samples processed
        self.total_loss,self.count = 0.,0
        self.total_metrics = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """List of `[total_loss, *total_metrics]`, with the loss as a Python float."""
        # total_loss is a plain float right after reset() and becomes whatever `run.loss`
        # produces once a batch is accumulated; float() converts either to a Python number,
        # whereas `.item()` fails on the plain float.
        return [float(self.total_loss)] + self.total_metrics

    @property
    def avg_stats(self):
        """Per-sample averages of `all_stats`; raises ZeroDivisionError while `count` is 0."""
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        """Empty string until something has been accumulated, else '<phase>: <averages>'."""
        if not self.count:
            return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        """Add the current batch of `run` (its `xb`, `yb`, `pred`, `loss`) to the totals."""
        # number of samples in this batch
        n_samples_in_batch = run.xb.shape[0]
        # weight the batch loss and metrics by the batch size, so that dividing by `count`
        # later gives a per-sample average
        self.total_loss += run.loss * n_samples_in_batch
        self.count += n_samples_in_batch
        for i,metric in enumerate(self.metrics):
            self.total_metrics[i] += metric(run.pred, run.yb) * n_samples_in_batch

class AvgStatsCallback(Callback):
    """Callback that keeps separate AvgStats for training and validation and prints them each epoch."""

    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, in_train=True)
        self.valid_stats = AvgStats(metrics, in_train=False)

    def begin_epoch(self):
        """Start every epoch from zeroed statistics."""
        for phase_stats in (self.train_stats, self.valid_stats):
            phase_stats.reset()

    def after_loss(self):
        """Add the batch just evaluated to the statistics of the current phase."""
        stats = self.valid_stats
        if self.in_train:
            stats = self.train_stats
        # accumulate without gradient tracking
        with torch.no_grad():
            stats.accumulate(self.run)

    def after_epoch(self):
        """Print the training statistics, then the validation statistics."""
        print(self.train_stats, self.valid_stats, sep='\n')

# 5. Model Training and Validation

## Training/Validation Method 1: 
### Use the Runner() class with the `callbacks` input

In [31]:
# Build an AvgStatsCallback that tracks the loss plus one metric, accuracy.
stats = AvgStatsCallback(metrics=[accuracy])

In [32]:
# Instantiate a Runner from an already-built callback instance.
# Runner.__init__ passes `callbacks` through listify(), so a single instance is accepted as well as a list.
run = Runner(callbacks=stats)

In [33]:
# Fresh model and optimizer for this experiment, wrapped with the loss function and data in a Learner.
learn = Learner(*get_model(data), loss_func=loss_func, data=data)

In [34]:
# Run two epochs of training and validation.
# AvgStatsCallback.after_epoch prints the train and valid statistics after each epoch.
run.fit(learn, n_epochs=2)

train: [0.3064101953125, tensor(0.9056)]
valid: [0.3826717041015625, tensor(0.8782)]
train: [0.142082958984375, tensor(0.9569)]
valid: [0.2425769775390625, tensor(0.9223)]


In [35]:
# `stats` was updated in place by run.fit(). avg_stats is a list holding the average loss followed
# by one average per metric, so with the single accuracy metric it unpacks into (loss, acc).
loss,acc = stats.valid_stats.avg_stats
# sanity check: validation accuracy should exceed 90% after two epochs
assert acc>0.9
loss,acc

(0.2425769775390625, tensor(0.9223))

## Training/Validation Method #2: 
###  Use the Runner() class with the `callback_funcs()` input

In [36]:
#export
from functools import partial

In [37]:
# Construct a callback factory: calling stats_func() with no arguments is the same as
# calling AvgStatsCallback(accuracy), so each call builds a new callback.
stats_func = partial(AvgStatsCallback,accuracy)

In [38]:
# Instantiate a Runner from a callback factory. Runner.__init__ calls the factory and also
# stores the resulting callback as an attribute named after callback.name.
run = Runner(callback_funcs=stats_func)

In [39]:
# Fresh model and optimizer again, so this run does not continue from the previous one.
learn = Learner(*get_model(data), loss_func=loss_func, data=data)

In [40]:
# Run two epochs of training and validation with the factory-built callback.
run.fit(learn, n_epochs=2)

train: [0.3240140625, tensor(0.8977)]
valid: [0.2048943115234375, tensor(0.9350)]
train: [0.142941611328125, tensor(0.9555)]
valid: [0.136198046875, tensor(0.9598)]


In [41]:
# Get the statistics. The callback is reachable as run.avg_stats because Runner.__init__ stored it
# under callback.name, which for AvgStatsCallback is 'avg_stats'.
print(run.avg_stats.valid_stats.avg_stats)
# printing the AvgStats object itself goes through its __repr__, which adds the 'valid: ' prefix
print(run.avg_stats.valid_stats)

[0.136198046875, tensor(0.9598)]
valid: [0.136198046875, tensor(0.9598)]


## Export

In [42]:
!python notebook2script.py 04_callbacks_jcat.ipynb

Converted 04_callbacks_jcat.ipynb to exp\nb_04.py
