In [None]:
#hide
#skip
! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab

In [None]:
#export
from fastai.basics import *

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#default_exp callback.rnn

# Callback for RNN training

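> Callbacks that reset the model before training and validation, and that use the outputs of language models to add AR (activation regularization) and TAR (temporal activation regularization)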

In [None]:
#export
@docs
class ModelResetter(Callback):
    "`Callback` that calls the model's `reset` before the training phase, before the validation phase and once after fitting"
    def _reset_model(self):
        # Single place where the reset happens; every event hook below delegates here.
        self.model.reset()

    def before_train(self): self._reset_model()

    def before_validate(self): self._reset_model()

    def after_fit(self): self._reset_model()

    _docs = dict(before_train="Reset the model before training",
                 before_validate="Reset the model before validation",
                 after_fit="Reset the model after fitting")

In [None]:
#export
@docs
class RNNRegularizer(Callback):
    "`Callback` that adds AR and TAR regularization in RNN training"
    def __init__(self, alpha=0., beta=0.):
        # `alpha` scales AR (squared activations of the dropped-out outputs) and
        # `beta` scales TAR (squared step-to-step change of the raw outputs).
        # A value of 0. disables the corresponding term.
        self.alpha,self.beta = alpha,beta

    def after_pred(self):
        # The model's prediction is indexed as (preds, raw outputs, dropped-out outputs).
        # When the last two are lists, only their final entry is kept.
        self.raw_out = self.pred[1][-1] if is_listy(self.pred[1]) else self.pred[1]
        self.out     = self.pred[2][-1] if is_listy(self.pred[2]) else self.pred[2]
        # Only the true predictions are handed on to the loss function.
        self.learn.pred = self.pred[0]

    def after_loss(self):
        # Regularization only affects the training loss.
        if not self.training: return
        # AR: penalize large activations of the dropped-out outputs.
        if self.alpha != 0.:  self.learn.loss += self.alpha * self.out.float().pow(2).mean()
        if self.beta != 0.:
            h = self.raw_out
            # TAR differences consecutive entries along dim 1, presumably the sequence
            # axis of batch-first (bs, seq_len, ...) outputs -- confirm against the model.
            # It therefore needs at least two steps on dim 1. The previous `len(h)` check
            # tested dim 0 (the batch): it skipped TAR for bs=1 and let seq_len=1 through,
            # where the mean of the empty difference tensor is NaN.
            if h.size(1)>1: self.learn.loss += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()

    _docs = dict(after_pred="Save the raw and dropped-out outputs and only keep the true output for loss computation",
                 after_loss="Add AR and TAR regularization")

## Export -

In [None]:
#hide
from nbdev.export import notebook2script
# Convert the notebooks to library modules; the "Converted ..." lines below are this call's output.
notebook2script()

Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.l