In [None]:
# default_exp profiling_callback

In [None]:
#hide
#ci
!pip install -Uqq fastai --upgrade

In [None]:
#hide
#local
%cd ..
from my_timesaver_utils.profiling import *
%cd nbs

/Users/butch/devt/workspaces/python3/fastai2_2020/portfolio-projects/my_timesaver_utils
/Users/butch/devt/workspaces/python3/fastai2_2020/portfolio-projects/my_timesaver_utils/nbs


# Profiling Callback

> Applies profiling to the fastai learner callback functions.
> This enables profiling of fastai model training.

In [None]:
#hide
#local
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#export
from my_timesaver_utils.profiling import *

In [None]:
#export
import warnings
FASTAI_AVAILABLE = True
try:
    from fastai.callback.core import Callback
    from fastai.learner import Learner
    from fastcore.foundation import patch
except ImportError as e:
    FASTAI_AVAILABLE = False
    warnings.warn('fastai package not installed, callback simulated')

In [None]:
#export
if not FASTAI_AVAILABLE:
    # The fastai import failed: define inert stand-ins for the three names
    # it would otherwise have provided.
    class Callback:
        'Empty placeholder for the fastai `Callback` base class'

    class Learner:
        'Empty placeholder for the fastai `Learner` class'

    def patch(fn, *args, **kwargs):
        'No-op stand-in for the `patch` decorator: returns `fn` unchanged'
        return fn

**FastAI Training Event Lifecycle Methods**
```
after_create

before_fit
      before_epoch
            before_train
                before_batch
                    after_pred
                    after_loss
                    before_backward
                    after_backward
                    before_step
                    after_cancel_step
                    after_step
                    after_cancel_batch
                after_batch 
            after_cancel_train                    
            after_train
            before_validate
                before_batch
                    after_pred
                    after_loss
                after_cancel_batch
                after_batch
            after_cancel_validate
            after_validate
      after_epoch
after_cancel_fit       
after_fit

```

In [None]:
#exporti
def _print_stat(func_name, level, data, indent_per_level=3):
    """Print a one-line summary of the recorded values for `func_name`.

    func_name: label printed at the start of the line.
    level: nesting depth; the line is indented by `indent_per_level * level` spaces.
    data: sequence of recorded numeric values, or None / empty when nothing was recorded.
    indent_per_level: number of spaces added per nesting level.

    Prints "<func_name> has no data" when there is nothing to summarise,
    otherwise the call count, the maximum and the average (3 decimals each).
    """
    indent = ' ' * indent_per_level * level
    # An empty sequence is handled like None: max() and the average are
    # undefined for it and would raise.
    if data is None or len(data) == 0:
        print(f'{indent}{func_name} has no data')
        return
    max_time = max(data)
    avg_time = sum(data) / len(data)
    print(f'{indent}{func_name}  called {len(data)} times. max: {max_time:.3f} avg: {avg_time:.3f}')

In [None]:
#export
class MyProfileCallback(Callback):
    'Callback to profile training lifecycle event performance'
    order = -15 # should run before any callbacks
    # (record name, nesting level) pairs, listed in the order they are reported
    ordered_callbacks = (
        ('fit',0),
        ('epoch',1),
        ('train',2),
        ('train_batch',3),
        ('train_pred',4),
        ('train_loss',4),
        ('train_backward',4),
        ('train_step',4),
        ('train_zero_grad',4),
        ('valid',2),
        ('valid_batch',3),
        ('valid_pred',4),
        ('valid_loss',4)
    )
    # per-batch records that may still be open when a batch or phase is cancelled
    _train_steps = ('train_pred','train_loss','train_backward','train_step','train_zero_grad')
    _valid_steps = ('valid_pred','valid_loss')

    def __init__(self, reset=False):
        'Store `reset`; when true, recorded stats are cleared at the start of each fit'
        self._reset = reset

    @staticmethod
    def _end_open(names):
        'End, in the given order, every record in `names` that is currently recording'
        for name in names:
            if is_recording(name):
                end_record(name)

    def _lookup(self, fname):
        'Return the `(name, level)` entry of `ordered_callbacks` named `fname`, or None'
        return next((entry for entry in self.ordered_callbacks if entry[0] == fname), None)

    def before_fit(self):
        'Optionally clear previous stats, then start the `fit` record'
        if self._reset:
            self.clear_stats()
        start_record('fit')

    def before_epoch(self):
        'Start the `epoch` record'
        start_record('epoch')

    def before_train(self):
        'Start the `train` record'
        start_record('train')

    def before_batch(self):
        'Start the batch and pred records of the current phase (train or valid)'
        if self.learn.training:
            start_record('train_batch')
            start_record('train_pred')
        else:
            start_record('valid_batch')
            start_record('valid_pred')

    def after_pred(self):
        'End the pred record and, when `learn.yb` is non-empty, start the loss record'
        if self.learn.training:
            end_record('train_pred')
            if len(self.learn.yb) > 0:
                start_record('train_loss')
        else:
            end_record('valid_pred')
            if len(self.learn.yb) > 0:
                start_record('valid_loss')

    def after_loss(self):
        'End the loss record of the current phase'
        if self.learn.training:
            end_record('train_loss')
        else:
            end_record('valid_loss')
            # no start train_backward because
            # valid doesnt execute backward

    def before_backward(self):
        'Start the `train_backward` record'
        start_record('train_backward')

    def before_step(self):
        'End `train_backward` and start `train_step`'
        end_record('train_backward')
        start_record('train_step')

    def after_step(self):
        'End `train_step` and start `train_zero_grad` (ended in `after_batch`)'
        end_record('train_step')
        start_record('train_zero_grad')

    def after_cancel_batch(self):
        'End whichever per-batch records of the current phase are still open'
        if self.learn.training:
            self._end_open(self._train_steps)
        else:
            # no more steps after valid_loss
            self._end_open(self._valid_steps)

    def after_batch(self):
        'End the batch record of the current phase (and `train_zero_grad` if open)'
        if self.learn.training:
            if is_recording('train_zero_grad'):
                end_record('train_zero_grad')
            end_record('train_batch')
        else:
            end_record('valid_batch')

    def after_train(self):
        'End the `train` record'
        end_record('train')

    def after_cancel_train(self):
        'End any per-batch training records that are still open'
        self._end_open(self._train_steps)

    def before_validate(self):
        'Start the `valid` record'
        start_record('valid')

    def after_cancel_validate(self):
        'End any per-batch validation records that are still open'
        self._end_open(self._valid_steps)

    def after_validate(self):
        'End the `valid` record'
        end_record('valid')

    def after_epoch(self):
        'End the `epoch` record'
        end_record('epoch')

    def after_cancel_fit(self):
        'End every record below `fit` that is still open; `fit` itself ends in `after_fit`'
        self._end_open(
            ('epoch','train','train_batch') + self._train_steps
            + ('valid','valid_batch') + self._valid_steps
        )

    def after_fit(self):
        'End the `fit` record'
        end_record('fit')

    def print_stats(self, fname=None, indent_per_level=3):
        'Print stats for every record, or only for the record `fname` when given'
        if fname is not None:
            entry = self._lookup(fname)
            if entry is not None:
                func_name, level = entry
                data = get_prof_data(func_name)
                _print_stat(func_name, level, data, indent_per_level=indent_per_level)
            else:
                # `fname` is not one of `ordered_callbacks`: report it as
                # having no data, at the top nesting level
                _print_stat(fname, 0, None, indent_per_level=indent_per_level)
            return

        for func_name,level in self.ordered_callbacks:
            data = get_prof_data(func_name)
            _print_stat(func_name, level, data, indent_per_level=indent_per_level)

    def clear_stats(self, fname=None):
        'Clear the data of every record, or only of the record `fname` when given'
        if fname is not None:
            clear_prof_data(fname)
            return
        for func_name,_ in self.ordered_callbacks:
            clear_prof_data(func_name)

    def get_stats(self,fname=None):
        """Return recorded data as `(name, level, data)` tuples.

        With `fname`, return the single tuple for that record; a name that is
        not in `ordered_callbacks` yields `(fname, 0, [])`.
        Without `fname`, return a list with one tuple per entry of
        `ordered_callbacks`, in that order.
        """
        if fname is not None:
            entry = self._lookup(fname)
            if entry is None:
                return (fname, 0, [])
            func_name, level = entry
            return (func_name, level, get_prof_data(func_name))
        return [(func_name, level, get_prof_data(func_name))
                for func_name,level in self.ordered_callbacks]

    @property
    def reset(self):
        'Whether stats are cleared at the start of each fit'
        return self._reset

    @reset.setter
    def reset(self,v):
        self._reset = v

In [None]:
#export        
@patch
def to_my_profile(self:Learner, reset=False):
    'Attach a `MyProfileCallback` to the learner, or update `reset` on the one already attached'
    profiler = MyProfileCallback(reset=reset)
    already_attached = getattr(self, profiler.name, None)
    if already_attached:
        # something is already registered under the callback's name:
        # keep it and only refresh its `reset` flag
        self.my_profile.reset = reset
    else:
        self.add_cb(profiler)
    return self

### Example Usage

In [None]:
from fastai.vision.all import *

In [None]:
path = untar_data(URLs.MNIST_TINY)

In [None]:
Path.BASE_PATH = path

In [None]:
datablock = DataBlock(
    blocks=(ImageBlock,CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(),
    item_tfms=Resize(28),
    batch_tfms=[]
)

In [None]:
dls = datablock.dataloaders(path)

In [None]:
learner = cnn_learner(dls,resnet18,metrics=accuracy)

In [None]:
learner.summary()

Sequential (Input shape: 64)
Layer (type)         Output Shape         Param #    Trainable 
                     64 x 64 x 14 x 14   
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                      

In [None]:
learner.to_my_profile()

<fastai.learner.Learner at 0x135ecafd0>

In [None]:
learner.my_profile

MyProfileCallback

In [None]:
learner.my_profile.print_stats()

fit has no data
   epoch has no data
      train has no data
         train_batch has no data
            train_pred has no data
            train_loss has no data
            train_backward has no data
            train_step has no data
            train_zero_grad has no data
      valid has no data
         valid_batch has no data
            valid_pred has no data
            valid_loss has no data


In [None]:
learner.fit(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.693655,0.486362,0.749642,00:14


In [None]:
learner.my_profile.print_stats()

fit  called 1 times. max: 14.826 avg: 14.826
   epoch  called 1 times. max: 14.826 avg: 14.826
      train  called 1 times. max: 12.539 avg: 12.539
         train_batch  called 11 times. max: 1.147 avg: 1.093
            train_pred  called 11 times. max: 0.253 avg: 0.219
            train_loss  called 11 times. max: 0.001 avg: 0.001
            train_backward  called 11 times. max: 0.900 avg: 0.861
            train_step  called 11 times. max: 0.014 avg: 0.010
            train_zero_grad  called 11 times. max: 0.002 avg: 0.002
      valid  called 1 times. max: 2.283 avg: 2.283
         valid_batch  called 11 times. max: 0.203 avg: 0.181
            valid_pred  called 11 times. max: 0.202 avg: 0.180
            valid_loss  called 11 times. max: 0.002 avg: 0.001


In [None]:
fit_stats = learner.my_profile.get_stats();fit_stats

[('fit', 0, [14.826272010803223]),
 ('epoch', 1, [14.825506210327148]),
 ('train', 2, [12.53893232345581]),
 ('train_batch',
  3,
  [1.147028923034668,
   1.0965969562530518,
   1.0539379119873047,
   1.0700407028198242,
   1.0998239517211914,
   1.0905580520629883,
   1.0969460010528564,
   1.0751848220825195,
   1.1051452159881592,
   1.0613350868225098,
   1.1278557777404785]),
 ('train_pred',
  4,
  [0.25260400772094727,
   0.2151319980621338,
   0.21577811241149902,
   0.21297788619995117,
   0.2168900966644287,
   0.21621465682983398,
   0.21819210052490234,
   0.2154397964477539,
   0.21802592277526855,
   0.21338295936584473,
   0.21511292457580566]),
 ('train_loss',
  4,
  [0.0011301040649414062,
   0.0007872581481933594,
   0.000743865966796875,
   0.0007627010345458984,
   0.0007507801055908203,
   0.0007741451263427734,
   0.0007429122924804688,
   0.0007698535919189453,
   0.0007410049438476562,
   0.0007848739624023438,
   0.0007369518280029297]),
 ('train_backward',
  4,

In [None]:
learner.my_profile.print_stats('train_batch')

         train_batch  called 11 times. max: 1.147 avg: 1.093


In [None]:
train_batch_stats = learner.my_profile.get_stats('train_batch'); train_batch_stats

('train_batch',
 3,
 [1.147028923034668,
  1.0965969562530518,
  1.0539379119873047,
  1.0700407028198242,
  1.0998239517211914,
  1.0905580520629883,
  1.0969460010528564,
  1.0751848220825195,
  1.1051452159881592,
  1.0613350868225098,
  1.1278557777404785])

In [None]:
learner.my_profile.clear_stats()

In [None]:
learner.my_profile.print_stats()

fit has no data
   epoch has no data
      train has no data
         train_batch has no data
            train_pred has no data
            train_loss has no data
            train_backward has no data
            train_step has no data
            train_zero_grad has no data
      valid has no data
         valid_batch has no data
            valid_pred has no data
            valid_loss has no data


In [None]:
learner.my_profile.print_stats('train')

      train has no data


In [None]:
learner.fine_tune(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.339887,0.245247,0.908441,00:14


epoch,train_loss,valid_loss,accuracy,time
0,0.273823,0.196766,0.919886,00:21


In [None]:
learner.my_profile.print_stats()

fit  called 2 times. max: 21.981 avg: 18.306
   epoch  called 2 times. max: 21.980 avg: 18.305
      train  called 2 times. max: 19.679 avg: 15.999
         train_batch  called 22 times. max: 2.007 avg: 1.426
            train_pred  called 22 times. max: 0.259 avg: 0.219
            train_loss  called 22 times. max: 0.001 avg: 0.001
            train_backward  called 22 times. max: 1.592 avg: 1.150
            train_step  called 22 times. max: 0.148 avg: 0.051
            train_zero_grad  called 22 times. max: 0.007 avg: 0.004
      valid  called 2 times. max: 2.306 avg: 2.301
         valid_batch  called 22 times. max: 0.211 avg: 0.182
            valid_pred  called 22 times. max: 0.209 avg: 0.181
            valid_loss  called 22 times. max: 0.002 avg: 0.001


In [None]:
learner.my_profile.reset = True

In [None]:
learner.fine_tune(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.16097,0.16135,0.944206,00:15


epoch,train_loss,valid_loss,accuracy,time
0,0.120808,0.129148,0.958512,00:23


In [None]:
learner.my_profile.print_stats()

fit  called 1 times. max: 23.155 avg: 23.155
   epoch  called 1 times. max: 23.154 avg: 23.154
      train  called 1 times. max: 20.823 avg: 20.823
         train_batch  called 11 times. max: 1.939 avg: 1.864
            train_pred  called 11 times. max: 0.247 avg: 0.216
            train_loss  called 11 times. max: 0.001 avg: 0.001
            train_backward  called 11 times. max: 1.590 avg: 1.546
            train_step  called 11 times. max: 0.147 avg: 0.093
            train_zero_grad  called 11 times. max: 0.008 avg: 0.007
      valid  called 1 times. max: 2.326 avg: 2.326
         valid_batch  called 11 times. max: 0.214 avg: 0.183
            valid_pred  called 11 times. max: 0.212 avg: 0.182
            valid_loss  called 11 times. max: 0.002 avg: 0.001


In [None]:
learner.my_profile.reset