In [None]:
# default_exp profiling_callback

In [None]:
#hide
#ci
!pip install -U fastai --upgrade

In [None]:
#hide
#local
%cd ..
from my_timesaver_utils.profiling import *
%cd nbs

/Users/butch/devt/workspaces/python3/fastai2_2020/experiments/my_timesaver_utils
/Users/butch/devt/workspaces/python3/fastai2_2020/experiments/my_timesaver_utils/nbs


# Profiling Callback

> Applies profiling to the fastai learner callback events,
> enabling profiling of fastai model training.

In [None]:
#hide
#local
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#export
from my_timesaver_utils.profiling import *

In [None]:
#export
import warnings

# Optimistically assume fastai is importable; the handler below flips the
# flag if any of the required imports is missing.
FASTAI_AVAILABLE = True
try:
    from fastai.callback.core import Callback
    from fastai.learner import Learner
    from fastcore.foundation import patch
except ImportError:
    # Keep the module importable without fastai, but tell the user.
    FASTAI_AVAILABLE = False
    warnings.warn('fastai package not installed, callback simulated')

In [None]:
#export
if not FASTAI_AVAILABLE:
    # Minimal stand-ins for the names the rest of this module references,
    # so that it can still be imported when fastai is missing.
    class Callback:
        'Placeholder base class used when fastai is not installed'
    class Learner:
        'Placeholder for the fastai `Learner` when fastai is not installed'
    def patch(fn, *args,**kwargs):
        'No-op stand-in for the `patch` decorator: returns `fn` unchanged and attaches it to nothing'
        return fn

**FastAI Training Event Lifecycle Methods**
```
before_fit
      before_epoch
            before_train
                    before_batch
                           after_pred
                           after_loss
                           after_backward
                           after_step
                    after_cancel_batch
                    after_batch
            after_train
            after_cancel_train
            before_validate
                    before_batch
                           after_pred
                           after_loss
                           after_backward   (training only; not fired during validation)
                           after_step       (training only; not fired during validation)
                    after_cancel_batch
                    after_batch

            after_cancel_validate
            after_validate

      after_epoch
after_cancel_fit
after_fit

```

In [None]:
#exporti
def _print_stat(func_name, level, data, indent_per_level=3):
    indent = ' ' * indent_per_level * level
    if data is None:
        print(f'{indent}{func_name} has no data')
        return
    max_time = max(data)
    avg_time = sum(data) / len(data)
    print(f'{indent}{func_name}  called {len(data)} times. max: {max_time:.3f} avg: {avg_time:.3f}')

In [None]:
#export
class MyProfileCallback(Callback):
    """Callback to profile training lifecycle event performance

    Each training-loop event starts and/or stops named timers through the
    `start_record` / `end_record` helpers. Timers are addressed by name only
    (no per-instance key is passed to the helpers).

    With `reset=True` the recorded stats are cleared at the start of every
    fit; otherwise they accumulate across fits.
    """
    # (timer name, nesting level) in reporting order; the level drives the
    # indentation used by `print_stats`.
    ordered_callbacks = (
        ('fit',0),
        ('epoch',1),
        ('train',2),
        ('train_batch',3),
        ('train_pred',4),
        ('train_loss',4),
        ('train_backward',4),
        ('train_step',4),
        ('train_zero_grad',4),
        ('valid',2),
        ('valid_batch',3),
        ('valid_pred',4),
        ('valid_loss',4)
    )
    # Per-batch step timers that may still be running when a batch, a phase
    # or the whole fit is cancelled part-way through.
    _train_step_timers = ('train_pred', 'train_loss', 'train_backward', 'train_step', 'train_zero_grad')
    _valid_step_timers = ('valid_pred', 'valid_loss')

    def __init__(self, reset=False):
        self._reset = reset

    @staticmethod
    def _end_if_recording(*names):
        'Stop every timer in `names` that is currently running, in the given order'
        for name in names:
            if is_recording(name):
                end_record(name)

    def before_fit(self):
        'Optionally clear earlier stats, then start timing the fit'
        if self._reset:
            self.clear_stats()
        start_record('fit')

    def before_epoch(self):
        start_record('epoch')

    def before_train(self):
        start_record('train')

    def before_batch(self):
        'Start the batch timer and the prediction timer for the current phase'
        if self.learn.training:
            start_record('train_batch')
            start_record('train_pred')
        else:
            start_record('valid_batch')
            start_record('valid_pred')

    def after_pred(self):
        'Stop the prediction timer; start the loss timer when there are targets'
        if self.learn.training:
            end_record('train_pred')
            # Only time the loss when targets exist (presumably no loss is
            # computed for a batch without targets; confirm against fastai).
            if len(self.learn.yb) > 0:
                start_record('train_loss')
        else:
            end_record('valid_pred')
            if len(self.learn.yb) > 0:
                start_record('valid_loss')

    def after_loss(self):
        'Stop the loss timer; in training, start timing the backward pass'
        if self.learn.training:
            end_record('train_loss')
            start_record('train_backward')
        else:
            end_record('valid_loss')
            # no start train_backward because
            # valid doesnt execute backward

    def after_backward(self):
        end_record('train_backward')
        start_record('train_step')

    def after_step(self):
        # Whatever runs between the optimizer step and `after_batch` is
        # reported under 'train_zero_grad'.
        end_record('train_step')
        start_record('train_zero_grad')

    def after_cancel_batch(self):
        'Stop whichever per-batch step timers are still running for the current phase'
        if self.learn.training:
            self._end_if_recording(*self._train_step_timers)
        else:
            # no more steps after valid_loss
            self._end_if_recording(*self._valid_step_timers)

    def after_batch(self):
        'Close the batch timer (and, in training, a still-running zero_grad timer)'
        if self.learn.training:
            self._end_if_recording('train_zero_grad')
            end_record('train_batch')
        else:
            end_record('valid_batch')

    def after_train(self):
        end_record('train')

    def after_cancel_train(self):
        'Stop any training step timers left running by the cancellation'
        self._end_if_recording(*self._train_step_timers)

    def before_validate(self):
        start_record('valid')

    def after_cancel_validate(self):
        'Stop any validation step timers left running by the cancellation'
        self._end_if_recording(*self._valid_step_timers)

    def after_validate(self):
        end_record('valid')

    def after_epoch(self):
        end_record('epoch')

    def after_cancel_fit(self):
        "Stop every timer except 'fit' that is still running; `after_fit` closes 'fit'"
        self._end_if_recording(
            'epoch', 'train', 'train_batch', *self._train_step_timers,
            'valid', 'valid_batch', *self._valid_step_timers)

    def after_fit(self):
        end_record('fit')

    def print_stats(self, fname=None, indent_per_level=3):
        """Print the timing summary of one timer (`fname`) or of all timers.

        A known `fname` is printed at its own nesting level. An unknown
        `fname` is reported as having no data, at level 0.
        """
        if fname is not None:
            level = dict(self.ordered_callbacks).get(fname)
            if level is None:
                # Unknown timer name: report it under the name the caller
                # passed. The comprehension variable `func_name` used here
                # before is not bound in this scope.
                _print_stat(fname, 0, None, indent_per_level=indent_per_level)
            else:
                _print_stat(fname, level, get_prof_data(fname), indent_per_level=indent_per_level)
            return

        for func_name,level in self.ordered_callbacks:
            data = get_prof_data(func_name)
            _print_stat(func_name, level, data, indent_per_level=indent_per_level)

    def clear_stats(self, fname=None):
        'Clear the recorded data of the timer `fname`, or of all timers when `fname` is None'
        if fname is not None:
            # Clear exactly the timer the caller named.
            clear_prof_data(fname)
            return
        for func_name,_ in self.ordered_callbacks:
            clear_prof_data(func_name)

    def get_stats(self,fname=None):
        """Return the recorded data.

        With `fname`, returns one `(name, level, data)` tuple; an unknown
        name yields `(fname, 0, [])`. Without `fname`, returns a list of such
        tuples for every timer in `ordered_callbacks`, in that order.
        """
        if fname is not None:
            level = dict(self.ordered_callbacks).get(fname)
            if level is None:
                return (fname, 0, [])
            return (fname, level, get_prof_data(fname))
        return [(func_name, level, get_prof_data(func_name))
                for func_name,level in self.ordered_callbacks]

    @property
    def reset(self):
        'Whether the stats are cleared at the start of every fit'
        return self._reset

    @reset.setter
    def reset(self,v):
        self._reset = v

In [None]:
#export        
@patch
def to_my_profile(self:Learner, reset=False):
    'Attach a `MyProfileCallback` to the learner, or update `reset` on the one already attached'
    profiler = MyProfileCallback(reset=reset)
    already_attached = getattr(self, profiler.name, None)
    if already_attached:
        # Reuse the existing callback; only its reset flag changes.
        self.my_profile.reset = reset
    else:
        self.add_cb(profiler)
    return self

### Example Usage

In [None]:
from fastai.vision.all import *

In [None]:
# Fetch the MNIST_TINY sample dataset; `path` is used below as its root folder.
path = untar_data(URLs.MNIST_TINY)

In [None]:
# Use the dataset root as base path (presumably so paths are displayed relative to it).
Path.BASE_PATH = path

In [None]:
# Describe how the demo dataset is assembled from the image folders.
datablock = DataBlock(
    blocks=(ImageBlock,CategoryBlock),  # image inputs, categorical targets
    get_items=get_image_files,
    get_y=parent_label,  # label comes from the parent folder
    splitter=GrandparentSplitter(),  # train/valid split comes from the grandparent folder
    item_tfms=Resize(28),  # matches the 28 x 28 input shape in the summary below
    batch_tfms=[]  # no batch-level transforms
)

In [None]:
# Build the DataLoaders from the files under `path`.
dls = datablock.dataloaders(path)

In [None]:
# Create a resnet18-based learner that reports accuracy as its metric.
learner = cnn_learner(dls,resnet18,metrics=accuracy)

In [None]:
# Attach the profiling callback; the call returns the learner itself.
learner.to_my_profile()

<fastai.learner.Learner at 0x130955f90>

In [None]:
# Show the layer-by-layer model summary.
learner.summary()

Sequential (Input shape: ['64 x 3 x 28 x 28'])
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               64 x 64 x 14 x 14    9,408      False     
________________________________________________________________
BatchNorm2d          64 x 64 x 14 x 14    128        True      
________________________________________________________________
ReLU                 64 x 64 x 14 x 14    0          False     
________________________________________________________________
MaxPool2d            64 x 64 x 7 x 7      0          False     
________________________________________________________________
Conv2d               64 x 64 x 7 x 7      36,864     False     
________________________________________________________________
BatchNorm2d          64 x 64 x 7 x 7      128        True      
________________________________________________________________
ReLU                 64 x 64 x 7 x 7      0          False     
___________________________________________________

In [None]:
# The callback is reachable as an attribute of the learner.
learner.my_profile

MyProfileCallback

In [None]:
# Nothing has been trained yet, so every timer is expected to report no data.
learner.my_profile.print_stats()

fit has no data
   epoch has no data
      train has no data
         train_batch has no data
            train_pred has no data
            train_loss has no data
            train_backward has no data
            train_step has no data
            train_zero_grad has no data
      valid has no data
         valid_batch has no data
            valid_pred has no data
            valid_loss has no data


In [None]:
# Train for one epoch; the callback times each lifecycle stage.
learner.fit(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.646651,0.224889,0.909871,00:14


In [None]:
# Timings for every stage, indented by nesting level.
learner.my_profile.print_stats()

fit  called 1 times. max: 14.688 avg: 14.688
   epoch  called 1 times. max: 14.678 avg: 14.678
      train  called 1 times. max: 12.425 avg: 12.425
         train_batch  called 11 times. max: 1.170 avg: 1.081
            train_pred  called 11 times. max: 0.304 avg: 0.236
            train_loss  called 11 times. max: 0.001 avg: 0.001
            train_backward  called 11 times. max: 0.848 avg: 0.834
            train_step  called 11 times. max: 0.011 avg: 0.008
            train_zero_grad  called 11 times. max: 0.006 avg: 0.003
      valid  called 1 times. max: 2.245 avg: 2.245
         valid_batch  called 11 times. max: 0.195 avg: 0.181
            valid_pred  called 11 times. max: 0.188 avg: 0.178
            valid_loss  called 11 times. max: 0.001 avg: 0.001


In [None]:
# Raw data: one (name, level, durations) tuple per timer.
fit_stats = learner.my_profile.get_stats()
fit_stats

[('fit', 0, [14.687530994415283]),
 ('epoch', 1, [14.67847228050232]),
 ('train', 2, [12.425010204315186]),
 ('train_batch',
  3,
  [1.1697580814361572,
   1.0837702751159668,
   1.0729010105133057,
   1.0813207626342773,
   1.0650250911712646,
   1.0707719326019287,
   1.087137222290039,
   1.0528171062469482,
   1.0657010078430176,
   1.0718588829040527,
   1.0731070041656494]),
 ('train_pred',
  4,
  [0.3040790557861328,
   0.23769092559814453,
   0.2236030101776123,
   0.23487424850463867,
   0.22310113906860352,
   0.22788310050964355,
   0.23348093032836914,
   0.23145484924316406,
   0.21995902061462402,
   0.22831082344055176,
   0.2304689884185791]),
 ('train_loss',
  4,
  [0.0010890960693359375,
   0.0006949901580810547,
   0.0007219314575195312,
   0.0006780624389648438,
   0.0010340213775634766,
   0.0007722377777099609,
   0.0006699562072753906,
   0.0006821155548095703,
   0.0007212162017822266,
   0.0006709098815917969,
   0.0006799697875976562]),
 ('train_backward',
  4

In [None]:
# Print the summary of a single timer, selected by name.
learner.my_profile.print_stats('train_batch')

         train_batch  called 11 times. max: 1.170 avg: 1.081


In [None]:
# Raw data of a single timer as a (name, level, durations) tuple.
train_batch_stats = learner.my_profile.get_stats('train_batch')
train_batch_stats

('train_batch',
 3,
 [1.1697580814361572,
  1.0837702751159668,
  1.0729010105133057,
  1.0813207626342773,
  1.0650250911712646,
  1.0707719326019287,
  1.087137222290039,
  1.0528171062469482,
  1.0657010078430176,
  1.0718588829040527,
  1.0731070041656494])

In [None]:
# Discard the recorded timings of all timers.
learner.my_profile.clear_stats()

In [None]:
# After clearing, every timer reports no data again.
learner.my_profile.print_stats()

fit has no data
   epoch has no data
      train has no data
         train_batch has no data
            train_pred has no data
            train_loss has no data
            train_backward has no data
            train_step has no data
            train_zero_grad has no data
      valid has no data
         valid_batch has no data
            valid_pred has no data
            valid_loss has no data


In [None]:
# The same holds when a single cleared timer is queried by name.
learner.my_profile.print_stats('train')

      train has no data


In [None]:
# fine_tune runs two fits here (two result tables below); both are profiled.
learner.fine_tune(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.370858,0.235521,0.912732,00:14


epoch,train_loss,valid_loss,accuracy,time
0,0.231076,0.167055,0.939914,00:22


In [None]:
# reset is False, so the timings of both fits are accumulated.
learner.my_profile.print_stats()

fit  called 2 times. max: 22.034 avg: 18.256
   epoch  called 2 times. max: 22.030 avg: 18.252
      train  called 2 times. max: 19.765 avg: 16.002
         train_batch  called 22 times. max: 2.025 avg: 1.429
            train_pred  called 22 times. max: 0.302 avg: 0.232
            train_loss  called 22 times. max: 0.001 avg: 0.001
            train_backward  called 22 times. max: 1.520 avg: 1.132
            train_step  called 22 times. max: 0.222 avg: 0.058
            train_zero_grad  called 22 times. max: 0.011 avg: 0.005
      valid  called 2 times. max: 2.258 avg: 2.244
         valid_batch  called 22 times. max: 0.211 avg: 0.180
            valid_pred  called 22 times. max: 0.204 avg: 0.177
            valid_loss  called 22 times. max: 0.002 avg: 0.001


In [None]:
# From now on, clear the stats at the start of every fit.
learner.my_profile.reset = True

In [None]:
# With reset enabled, each fit started by fine_tune begins with empty stats.
learner.fine_tune(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.158287,0.196981,0.9299,00:16


epoch,train_loss,valid_loss,accuracy,time
0,0.090954,0.162717,0.944206,00:23


In [None]:
# Only the timings of the most recent fit remain.
learner.my_profile.print_stats()

fit  called 1 times. max: 23.094 avg: 23.094
   epoch  called 1 times. max: 23.091 avg: 23.091
      train  called 1 times. max: 20.766 avg: 20.766
         train_batch  called 11 times. max: 2.104 avg: 1.863
            train_pred  called 11 times. max: 0.265 avg: 0.231
            train_loss  called 11 times. max: 0.001 avg: 0.001
            train_backward  called 11 times. max: 1.648 avg: 1.519
            train_step  called 11 times. max: 0.180 avg: 0.104
            train_zero_grad  called 11 times. max: 0.011 avg: 0.008
      valid  called 1 times. max: 2.318 avg: 2.318
         valid_batch  called 11 times. max: 0.212 avg: 0.187
            valid_pred  called 11 times. max: 0.209 avg: 0.183
            valid_loss  called 11 times. max: 0.002 avg: 0.001


In [None]:
# Confirm that the reset flag is still set.
learner.my_profile.reset

True