In [None]:
#default_exp callback.training

# Training Callbacks
> Very basic Callbacks to enhance the training experience

In [None]:
#export
# Contains code used/modified by fastai_minima author from fastai
# Copyright 2019 the fast.ai team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#export
from contextlib import contextmanager

from fastcore.basics import ifnone, patch
from fastcore.foundation import L
from fastprogress.fastprogress import progress_bar,master_bar

from fastai_minima.callback.core import Callback
from fastai_minima.learner import Learner
from fastai_minima.utils import defaults, noop

In [None]:
#export
class ProgressCallback(Callback):
    "A `Callback` to handle the display of progress bars"
    order = 60
    # Attributes treated as transient state (excluded from the callback's saved state)
    _stateattrs = ('mbar','pbar')

    def before_fit(self):
        "Setup the master bar over the epochs"
        assert hasattr(self.learn, 'recorder')
        if self.create_mbar:
            epochs = list(range(self.n_epoch))
            self.mbar = master_bar(epochs)
        if self.learn.logger == noop:
            # Logging is disabled: remember that so `after_fit` restores `noop`
            self.old_logger = noop
            return
        # Redirect the learner's logger into the master bar's table, keeping the
        # previous logger so `after_fit` can reinstate it
        previous = self.logger
        self.old_logger = previous
        self.learn.logger = self._write_stats
        # The first table row is the header made of the recorder's metric names
        self._write_stats(self.recorder.metric_names)

    def before_epoch(self):
        "Update the master bar"
        mbar = getattr(self, 'mbar', False)
        if mbar: mbar.update(self.epoch)

    def before_train(self):
        "Launch a progress bar over the training dataloader"
        self._launch_pbar()

    def before_validate(self):
        "Launch a progress bar over the validation dataloader"
        self._launch_pbar()

    def after_train(self):
        "Close the progress bar over the training dataloader"
        self._close_pbar()

    def after_validate(self):
        "Close the progress bar over the validation dataloader"
        self._close_pbar()

    def after_batch(self):
        "Update the current progress bar"
        self.pbar.update(self.iter + 1)
        # Show the smoothed loss next to the bar only when it is available
        if not hasattr(self, 'smooth_loss'): return
        self.pbar.comment = f'{self.smooth_loss:.4f}'

    def _launch_pbar(self):
        # Nest the bar under the master bar when one is active, then draw it at 0
        parent = getattr(self, 'mbar', None)
        self.pbar = progress_bar(self.dl, parent=parent, leave=False)
        self.pbar.update(0)

    def _close_pbar(self):
        # Shared by `after_train` and `after_validate`
        self.pbar.on_iter_end()

    def after_fit(self):
        "Close the master bar"
        if getattr(self, 'mbar', False):
            self.mbar.on_iter_end()
            del self.mbar
        if hasattr(self, 'old_logger'): self.learn.logger = self.old_logger

    def _write_stats(self, log):
        # Write one table row to the master bar: floats with 6 decimals, anything else via `str`
        mbar = getattr(self, 'mbar', False)
        if not mbar: return
        cells = [f'{l:.6f}' if isinstance(l, float) else str(l) for l in log]
        mbar.write(cells, table=True)

if not hasattr(defaults, 'callbacks'):
    # `TrainEvalCallback` and `Recorder` are not imported at module level, so bring them in
    # on this path only (it previously raised NameError).
    # NOTE(review): module locations assumed to mirror fastai's layout -- confirm.
    from fastai_minima.callback.core import TrainEvalCallback
    from fastai_minima.learner import Recorder
    defaults.callbacks = [TrainEvalCallback, Recorder, ProgressCallback]
elif ProgressCallback not in defaults.callbacks:
    # Register progress bars as a default callback exactly once
    defaults.callbacks.append(ProgressCallback)

In [None]:
#hide
import torch
from torch.utils.data import TensorDataset, DataLoader
from fastai_minima.learner import DataLoaders
from torch import nn
def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2):
    """Build train/valid `DataLoaders` for a 1-D regression toy problem.

    `x` is drawn from a standard normal and `y = a*x + b` plus Gaussian noise with std 0.1.
    `n_train` / `n_valid` count batches of size `bs`, so each split holds `int(bs*n)` samples.
    """
    def make_ds(n_batches):
        n_items = int(bs*n_batches)
        xs = torch.randn(n_items)
        ys = a*xs + b + 0.1*torch.randn(n_items)
        return TensorDataset(xs, ys)
    # Build both datasets before any loader so the random draws happen in a fixed order
    train_ds, valid_ds = make_ds(n_train), make_ds(n_valid)
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=0)
    valid_dl = DataLoader(valid_ds, batch_size=bs, num_workers=0)
    return DataLoaders(train_dl, valid_dl)
def synth_learner(n_train=10, n_valid=2, lr=defaults.lr, **kwargs):
    "Create a `Learner` that fits a `RegModel` with MSE loss on `synth_dbunch` data; extra kwargs go to `Learner`."
    dls = synth_dbunch(n_train=n_train, n_valid=n_valid)
    # The model is created after the data so the random-number sequence matches data-then-model
    model, loss_func = RegModel(), nn.MSELoss()
    return Learner(dls, model, loss_func=loss_func, lr=lr, **kwargs)

class RegModel(nn.Module):
    "A linear regression model `y = a*x + b` with single-element learnable parameters `a` and `b`, randomly initialised from a standard normal"
    def __init__(self):
        super().__init__()
        self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))
    def forward(self, x): return x*self.a + self.b

In [None]:
# Train the synthetic regression learner for 5 epochs with the default callbacks (progress bars included)
learn = synth_learner(n_train=10, n_valid=2)
learn.fit(5)

epoch,train_loss,valid_loss,time
0,10.118802,8.17487,00:00
1,10.075614,8.11403,00:00
2,10.035525,8.052789,00:00
3,9.972941,7.991217,00:00
4,9.917836,7.931038,00:00


  allow_unreachable=True)  # allow_unreachable flag


In [None]:
#export
@patch
@contextmanager
def no_bar(self:Learner):
    "Context manager that deactivates the use of progress bars"
    progress_cb_present = hasattr(self, 'progress')
    if progress_cb_present:
        self.remove_cb(self.progress)
    try:
        yield self
    finally:
        # A fresh `ProgressCallback` is registered on exit, not the instance that was removed
        if progress_cb_present:
            self.add_cb(ProgressCallback())

In [None]:
learn = synth_learner(n_train=10, n_valid=2)
# Same training run with progress bars disabled; per-epoch stats show up as plain lists (see output) instead of a table
with learn.no_bar():
    learn.fit(5)

[0, 20.658222198486328, 24.910701751708984, '00:00']
[1, 20.54644775390625, 24.757198333740234, '00:00']
[2, 20.48363494873047, 24.60427474975586, '00:00']
[3, 20.415714263916016, 24.45093536376953, '00:00']
[4, 20.329593658447266, 24.297313690185547, '00:00']


In [None]:
#hide
#Check validate works without any training
import torch.nn.functional as F
def tst_metric(out, targ):
    "Toy metric: mean squared error between `out` and `targ`"
    err = F.mse_loss(out, targ)
    return err
learn = synth_learner(metrics=tst_metric)
# NOTE(review): nothing is asserted; this only checks that `validate` runs on an untrained learner.
# The two unpacked values are whatever `validate` returns (presumably loss and metric, not predictions/targets) -- confirm.
preds,targs = learn.validate()

In [None]:
show_doc(ProgressCallback.before_fit)  # render API docs for this event handler

<h4 id="ProgressCallback.before_fit" class="doc_header"><code>ProgressCallback.before_fit</code><a href="__main__.py#L6" class="source_link" style="float:right">[source]</a></h4>

> <code>ProgressCallback.before_fit</code>()

Setup the master bar over the epochs

In [None]:
show_doc(ProgressCallback.before_epoch)  # render API docs for this event handler

<h4 id="ProgressCallback.before_epoch" class="doc_header"><code>ProgressCallback.before_epoch</code><a href="__main__.py#L15" class="source_link" style="float:right">[source]</a></h4>

> <code>ProgressCallback.before_epoch</code>()

Update the master bar

In [None]:
show_doc(ProgressCallback.before_train)  # render API docs for this event handler

<h4 id="ProgressCallback.before_train" class="doc_header"><code>ProgressCallback.before_train</code><a href="__main__.py#L19" class="source_link" style="float:right">[source]</a></h4>

> <code>ProgressCallback.before_train</code>()

Launch a progress bar over the training dataloader

In [None]:
show_doc(ProgressCallback.before_validate)  # render API docs for this event handler

<h4 id="ProgressCallback.before_validate" class="doc_header"><code>ProgressCallback.before_validate</code><a href="__main__.py#L23" class="source_link" style="float:right">[source]</a></h4>

> <code>ProgressCallback.before_validate</code>()

Launch a progress bar over the validation dataloader

In [None]:
show_doc(ProgressCallback.after_batch)  # render API docs for this event handler

<h4 id="ProgressCallback.after_batch" class="doc_header"><code>ProgressCallback.after_batch</code><a href="__main__.py#L35" class="source_link" style="float:right">[source]</a></h4>

> <code>ProgressCallback.after_batch</code>()

Update the current progress bar

In [None]:
show_doc(ProgressCallback.after_train)  # render API docs for this event handler

<h4 id="ProgressCallback.after_train" class="doc_header"><code>ProgressCallback.after_train</code><a href="__main__.py#L27" class="source_link" style="float:right">[source]</a></h4>

> <code>ProgressCallback.after_train</code>()

Close the progress bar over the training dataloader

In [None]:
show_doc(ProgressCallback.after_validate)  # render API docs for this event handler

<h4 id="ProgressCallback.after_validate" class="doc_header"><code>ProgressCallback.after_validate</code><a href="__main__.py#L31" class="source_link" style="float:right">[source]</a></h4>

> <code>ProgressCallback.after_validate</code>()

Close the progress bar over the validation dataloader

In [None]:
show_doc(ProgressCallback.after_fit)  # render API docs for this event handler

<h4 id="ProgressCallback.after_fit" class="doc_header"><code>ProgressCallback.after_fit</code><a href="__main__.py#L44" class="source_link" style="float:right">[source]</a></h4>

> <code>ProgressCallback.after_fit</code>()

Close the master bar

In [None]:
#export
class CollectDataCallback(Callback):
    "Collect all batches, along with `pred` and `loss`, into `self.data`. Mainly for testing"
    def before_fit(self):
        "Start every fit with an empty collection"
        # `L` comes from `fastcore.foundation`; it was previously used without being imported
        self.data = L()
    def after_batch(self):
        "Append a detached `(xb, yb, pred, loss)` tuple for the batch that just ran"
        self.data.append(self.learn.to_detach((self.xb,self.yb,self.pred,self.loss)))

In [None]:
# export
class CudaCallback(Callback):
    "Move the model and every batch to `device` (defaults to `default_device()`)"
    def __init__(self, device=None):
        # NOTE(review): helper location assumed to mirror fastai's torch_core -- confirm
        from fastai_minima.utils import default_device
        self.device = ifnone(device, default_device())
    def before_batch(self):
        "Send `xb` and `yb` to `self.device`, the same device the model was moved to"
        from fastai_minima.utils import to_device
        # Pass the device explicitly: without it, batches went to the default device even when
        # a different `device` was requested, leaving them apart from the model
        self.learn.xb,self.learn.yb = to_device(self.xb, device=self.device),to_device(self.yb, device=self.device)
    def before_fit(self):
        "Move the model to `self.device` before training starts"
        self.model.to(self.device)