## Notes

Porting `class Runner` to Swift is a work in progress (WIP).

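* `TrainerCallback` is a class, not a protocol, because `Trainer` needs to store a heterogeneous `[TrainerCallback]` array; a protocol generic over `Model`/`Opt` would need associated types and so could not be used as an array element type.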

Todos:
* Improve naming: many names (e.g. `begin_batch`, `after_loss`) come directly from Python and do not follow Swift's lowerCamelCase convention.

In [None]:
import TensorFlow

/// Identifies a point in the training loop at which `TrainerCallback` hooks may run.
/// The case names deliberately mirror the Python fast.ai callback names.
enum CallbackKind {
    // Batch-level events, listed in the order of the per-batch pipeline.
    case begin_batch, after_pred, after_loss, after_backward, after_step, after_batch

    // Fit- and epoch-level events.
    case begin_fit, begin_epoch, begin_validate, after_epoch, after_fit
}

/// Base class for training-loop hooks.
///
/// Each hook is a read-only property that returns the closure to run at that point of the
/// loop, or `nil` if this callback does not react to it. Subclasses override only the hooks
/// they need. `Model` and `Opt` are not used here directly; they let subclasses refer to a
/// concrete `Trainer<Model, Opt>`.
open class TrainerCallback<Model : Layer, Opt : Optimizer> {
    public typealias Callback = () -> Void

    // MARK: Batch-level hooks

    open var begin_batch: Callback? { return nil }
    open var after_pred: Callback? { return nil }
    open var after_loss: Callback? { return nil }
    open var after_backward: Callback? { return nil }
    open var after_step: Callback? { return nil }
    open var after_batch: Callback? { return nil }

    // MARK: Fit- and epoch-level hooks

    open var begin_fit: Callback? { return nil }
    open var begin_epoch: Callback? { return nil }
    open var begin_validate: Callback? { return nil }
    open var after_epoch: Callback? { return nil }
    open var after_fit: Callback? { return nil }

    /// Looks up the hook that corresponds to `kind`.
    /// - Returns: The (possibly overridden) closure, or `nil` when this callback ignores `kind`.
    func callback(_ kind: CallbackKind) -> Callback? {
        let hook: Callback?
        switch kind {
        case .begin_fit:      hook = self.begin_fit
        case .begin_epoch:    hook = self.begin_epoch
        case .begin_validate: hook = self.begin_validate
        case .after_epoch:    hook = self.after_epoch
        case .after_fit:      hook = self.after_fit
        case .begin_batch:    hook = self.begin_batch
        case .after_pred:     hook = self.after_pred
        case .after_loss:     hook = self.after_loss
        case .after_backward: hook = self.after_backward
        case .after_step:     hook = self.after_step
        case .after_batch:    hook = self.after_batch
        }
        return hook
    }
}

/// Tracks training progress and switches the trainer between training and validation mode.
/// Swift port of the Python `TrainEvalCallback`.
///
/// NOTE(review): `Trainer` is a struct, so `trainer` below is a *copy* taken in `init`.
/// Writes such as `trainer.isTraining = true` update only this copy, and reads see the
/// trainer's state as of `init`, not the trainer that is actually running. `Trainer` has to
/// become a reference type (or be passed `inout` to hooks) before this behaves as intended.
final class TrainingEvaluationCallback<Model, Opt: Optimizer>: TrainerCallback<Model, Opt>
    where Opt.Model == Model, Opt.Scalar : TensorFlowFloatingPoint,
          Model.Input == Tensor<Opt.Scalar>,
          Model.Output == Tensor<Opt.Scalar>
{
    var trainer: Trainer<Model, Opt>
    init(trainer: Trainer<Model, Opt>) {
        self.trainer = trainer
    }

    /// Intended to hold the fractional number of epochs completed (Python: `n_epochs`).
    var epochCount: Float = 0
    /// Number of training batches seen since `begin_fit` (Python: `n_iter`).
    var iterationCount: Int = 0

    override var begin_fit: Callback {
        return {
            self.epochCount = 0
            self.iterationCount = 0
        }
    }

    override var after_batch: Callback {
        return {
            // Only training batches advance the progress counters.
            guard self.trainer.isTraining else { return }
            // TODO: Python divides by the number of batches per epoch (`self.iters`).
            // The trainer does not expose that here yet, so this divides by the running
            // iteration counter instead. That counter is 0 on the first batch (giving +inf)
            // and is not the per-epoch batch count afterwards.
            self.epochCount += 1 / Float(self.iterationCount)
            self.iterationCount += 1
        }
    }

    override var begin_epoch: Callback {
        return {
            // Resynchronize with the index of the epoch that is starting
            // (Python: `self.n_epochs = self.epoch`). `trainer.epochCount` is the
            // total number of epochs requested by `fit`, not the current one.
            self.epochCount = Float(self.trainer.epoch)
            // NOTE: `module.train()` is an API from PyTorch.
            // It changes all layers (notably dropout and batchnorm) to work in training mode.
            // We cannot represent this because our training flag is passed as an argument (in `Context`).
            // trainer.learner.model.train()
            self.trainer.isTraining = true
        }
    }

    override var begin_validate: Callback {
        return {
            // NOTE: `module.eval()` is an API from PyTorch.
            // It changes all layers (notably dropout and batchnorm) to work in inference mode.
            // We cannot represent this because our training flag is passed as an argument (in `Context`).
            // trainer.model.eval()
            self.trainer.isTraining = false
        }
    }
}

/// Running, batch-size-weighted totals of the loss and of each metric for one phase
/// (training or validation). Swift port of the Python `AvgStats`.
class AverageStatistics<Model : Layer> {
    // typealias Metric = (Tensor<Float>, Tensor<Float>) -> Tensor<Float>
    /// Compares a prediction with the expected labels; presumably returns a scalar
    /// (not enforced by the type).
    typealias Metric = (Model.Output, Model.Output) -> Tensor<Float>

    var metrics: [Metric]
    /// Which phase these statistics belong to.
    var isTraining: Bool

    var totalLoss: Tensor<Float> = Tensor(0)
    /// Number of samples accumulated so far.
    var count: Int = 0
    var totalMetrices: [Tensor<Float>] = []

    init(_ metrics: [Metric], isTraining: Bool) {
        self.metrics = metrics
        self.isTraining = isTraining
        // Size the per-metric totals up front, so `accumulate` cannot index out of range
        // if it is called before the first explicit `reset()`.
        reset()
    }

    /// Zeroes the loss total, the sample count and one total per metric.
    func reset() {
        totalLoss = Tensor<Float>(0)
        count = 0
        totalMetrices = Array(repeating: Tensor(0), count: metrics.count)
    }

    /// The loss total followed by each metric total, in `metrics` order.
    var allStatistics: [Tensor<Float>] {
        return [totalLoss] + totalMetrices
    }

    /// Per-sample averages of `allStatistics`. With `count == 0` this divides by zero
    /// (presumably producing NaN).
    var averageStatistics: [Tensor<Float>] {
        let sampleCount = Float(count)
        return allStatistics.map { $0 / sampleCount }
    }

    /// Adds the trainer's current batch loss and metric values, each weighted by the
    /// batch size (the leading dimension of `trainer.data`).
    func accumulate<Opt>(trainer: Trainer<Model, Opt>) where Model == Opt.Model {
        let batchSize = trainer.data.shape[0]
        let weight = Float(batchSize)
        totalLoss += trainer.loss * weight
        count += batchSize
        for (i, metric) in metrics.enumerated() {
            totalMetrices[i] += metric(trainer.prediction, trainer.labels) * weight
        }
    }
}

/// Accumulates average loss/metrics separately for the training and validation passes and
/// prints them after every epoch. Swift port of the Python `AvgStatsCallback`.
///
/// NOTE(review): `trainer` is a copy of a struct taken in `init`, so `statistics` and
/// `accumulate` read the trainer's state as of `init`, not the running trainer's current
/// batch. This needs `Trainer` to become a reference type to produce meaningful numbers.
final class AverageStatisticsCallback<Model : Layer, Opt : Optimizer> : TrainerCallback<Model, Opt>
    where Opt.Model == Model, Opt.Scalar : TensorFlowFloatingPoint,
          Model.Input == Tensor<Opt.Scalar>,
          Model.Output == Tensor<Opt.Scalar>
{
    var trainer: Trainer<Model, Opt>
    private var trainingStatistics: AverageStatistics<Model>
    private var validationStatistics: AverageStatistics<Model>

    init(trainer: Trainer<Model, Opt>, metrics: [AverageStatistics<Model>.Metric]) {
        self.trainer = trainer
        self.trainingStatistics = AverageStatistics(metrics, isTraining: true)
        self.validationStatistics = AverageStatistics(metrics, isTraining: false)
    }

    /// The statistics for the phase the trainer reports it is in.
    var statistics: AverageStatistics<Model> {
        if trainer.isTraining {
            return trainingStatistics
        } else {
            return validationStatistics
        }
    }

    override var begin_epoch: Callback {
        return {
            // Each epoch reports fresh averages for both phases.
            self.trainingStatistics.reset()
            self.validationStatistics.reset()
        }
    }

    override var after_loss: Callback {
        return {
            self.statistics.accumulate(trainer: self.trainer)
        }
    }

    override var after_epoch: Callback {
        return {
            // Like the Python `__repr__`, a phase with no samples reports nothing.
            let phases = [("train", self.trainingStatistics), ("valid", self.validationStatistics)]
            for (phase, stats) in phases where stats.count > 0 {
                print("\(phase): \(stats.averageStatistics)")
            }
        }
    }
}

/*
// NOTE: We may not want a `Learner` abstraction.
struct Learner<Model : Layer, Opt : Optimizer> {
    var optimizer: Opt
    var model: Model
    var lossFunction: (Model.Output) -> Tensor<Float>
    var data: (train: Batch, validation: Batch)
}
*/

/// Drives the training loop and fires `TrainerCallback` hooks at every stage.
/// Work-in-progress Swift port of the Python `Runner`.
struct Trainer<Model, Opt : Optimizer>
    where Opt.Model == Model, Opt.Scalar : TensorFlowFloatingPoint,
          Model.Input == Tensor<Opt.Scalar>,
          Model.Output == Tensor<Opt.Scalar>
{
    typealias Scalar = Opt.Scalar
    typealias Batch = (data: Tensor<Scalar>, labels: Tensor<Scalar>)

    var callbacks: [TrainerCallback<Model, Opt>]
    /// When set, `runAllBatches` stops before its next batch; it is cleared afterwards.
    var stop: Bool = false

    var isTraining: Bool
    /// Inputs of the batch currently being processed.
    var data: Model.Input
    /// Labels of the batch currently being processed.
    var labels: Model.Output
    /// Model output for `data`, set after the forward pass.
    var prediction: Model.Output
    /// Loss computed from `prediction`.
    var loss: Tensor<Float>
    /// Total number of epochs requested by the current `fit` call.
    var epochCount: Int
    /// Zero-based index of the epoch currently running.
    var epoch: Int
    // var learner: Learner<Model, Opt>

    var optimizer: Opt
    var model: Model
    // NOTE(review): unlike Python's `loss_func(pred, yb)`, this receives only the prediction.
    var lossFunction: (Model.Output) -> Tensor<Float>

    /// Number of batches passed to the most recent `runAllBatches` call (Python: `self.iters`).
    var iterationCount: Int = 0

    /// Runs the forward pass and loss for one batch, firing `begin_batch`, `after_pred`,
    /// `after_loss`, `after_backward` and `after_step` in that order.
    mutating func runOneBatch(_ batch: Batch) {
        self.data = batch.data
        self.labels = batch.labels
        invokeCallback(.begin_batch)

        self.prediction = model.applied(to: self.data)
        invokeCallback(.after_pred)
        self.loss = lossFunction(self.prediction)
        invokeCallback(.after_loss)

        // TODO: compute gradients of the loss and update `model` with `optimizer`
        // (Python: `loss.backward(); opt.step(); opt.zero_grad()`). Until then these
        // hooks fire without a backward pass or optimizer step.
        invokeCallback(.after_backward)
        invokeCallback(.after_step)
    }

    /// Runs `batches` in order, firing `after_batch` after each one, until they are
    /// exhausted or `stop` is set. `stop` is cleared on exit.
    mutating func runAllBatches(_ batches: [Batch]) {
        iterationCount = batches.count
        for batch in batches {
            if stop { break }
            runOneBatch(batch)
            invokeCallback(.after_batch)
        }
        stop = false
    }

    /// Runs `epochCount` epochs. Each epoch fires `begin_epoch`, runs `trainingBatches`,
    /// fires `begin_validate`, runs `validationBatches` and fires `after_epoch`.
    /// `begin_fit` and `after_fit` bracket the whole run. A non-positive `epochCount` runs
    /// no epochs. The batch parameters default to empty, so existing
    /// `fit(epochCount:)` callers keep working.
    mutating func fit(epochCount: Int,
                      trainingBatches: [Batch] = [],
                      validationBatches: [Batch] = []) {
        self.epochCount = epochCount
        // Python fires `after_fit` in a `finally`; `defer` guarantees it on every exit path.
        defer { invokeCallback(.after_fit) }
        invokeCallback(.begin_fit)
        // `stride` yields nothing for a negative count, whereas `0..<epochCount` would trap.
        for epoch in stride(from: 0, to: epochCount, by: 1) {
            self.epoch = epoch
            invokeCallback(.begin_epoch)
            runAllBatches(trainingBatches)

            invokeCallback(.begin_validate)
            runAllBatches(validationBatches)
            invokeCallback(.after_epoch)
        }
    }
}

extension Trainer {
    /// Runs the `kind` hook of every registered callback, in registration order.
    /// Callbacks that do not implement that hook (it returns `nil`) are skipped.
    ///
    /// TODO: find a better way to report errors. In the fast.ai API a `Callback` returns a
    /// boolean, where `true` means an error occurred.
    func invokeCallback(_ kind: CallbackKind) {
        callbacks.forEach { $0.callback(kind)?() }
    }
}


: ignored

## Note

The code below is copied directly from the Python notebook and has yet to be ported to Swift.

In [None]:
# IPython autoreload extension: re-import edited modules (e.g. `exp.nb_03`) automatically.
%load_ext autoreload
%autoreload 2

# Show matplotlib figures inline in the notebook.
%matplotlib inline

: ignored

In [None]:
#export
from exp.nb_03 import *

## DataBunch/Learner

In [None]:
# Raw training/validation arrays, wrapped as datasets.
x_train, y_train, x_valid, y_valid = get_data()
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)
nh = 50  # hidden units (same value as get_model's default `nh`)
bs = 64  # batch size passed to get_dls
loss_func = F.cross_entropy

Factor the related pieces of information out of the `fit()` argument list:

`fit(epochs, model, loss_func, opt, train_dl, valid_dl)`

In [None]:
#export
class DataBunch:
    """Bundle of a training and a validation data loader.

    `c` is one more than the largest training label, i.e. the number of classes
    assuming labels are 0-based integers (TODO confirm for non-contiguous labels).
    """
    def __init__(self, train_dl, valid_dl):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.c = self.train_ds.y.max().item() + 1

    @property
    def train_ds(self):
        """Dataset behind the training loader."""
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        """Dataset behind the validation loader."""
        return self.valid_dl.dataset

In [None]:
# Bundle the loaders returned by get_dls (presumably (train_dl, valid_dl)) with batch size `bs`.
data = DataBunch(*get_dls(train_ds, valid_ds, bs))

In [None]:
#export
def get_model(data, lr=0.5, nh=50):
    """Build a one-hidden-layer MLP sized for `data` plus its SGD optimizer.

    The input width is taken from `data.train_ds.x.shape[1]` and the output
    width from `data.c`. Returns `(model, optimizer)`.
    """
    n_in = data.train_ds.x.shape[1]
    layers = [nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, data.c)]
    model = nn.Sequential(*layers)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    return model, optimizer

class Learner:
    """Plain container grouping everything `fit` needs: model, optimizer,
    loss function and data."""
    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

In [None]:
# Build a model + SGD optimizer for `data` and wrap them with the loss and data in a Learner.
learn = Learner(*get_model(data), loss_func, data)

In [None]:
def fit(epochs, learn):
    """Train `learn.model` for `epochs` epochs, evaluating after each one.

    After every epoch, prints the epoch index and the mean per-batch validation
    loss and accuracy. Returns that (loss, accuracy) pair for the last epoch.
    Assumes `epochs >= 1` (otherwise the final `return` hits unbound locals)
    and a non-empty validation loader.
    """
    for epoch in range(epochs):
        learn.model.train()
        for xb,yb in learn.data.train_dl:
            loss = learn.loss_func(learn.model(xb), yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        learn.model.eval()
        with torch.no_grad():
            tot_loss,tot_acc = 0.,0.
            for xb,yb in learn.data.valid_dl:
                pred = learn.model(xb)
                tot_loss += learn.loss_func(pred, yb)
                tot_acc  += accuracy (pred,yb)
        # Number of validation batches. Read it from the learner: the previous
        # `len(valid_dl)` depended on an undefined module-level global.
        nv = len(learn.data.valid_dl)
        print(epoch, tot_loss/nv, tot_acc/nv)
    return tot_loss/nv, tot_acc/nv

In [None]:
# Train for one epoch; fit returns the mean validation loss and accuracy of that epoch.
loss,acc = fit(1, learn)

0 tensor(0.1720) tensor(0.9471)


## CallbackHandler

Add callbacks so we can move complexity out of the loop and make it more flexible:

In [None]:
def one_batch(xb, yb, cb):
    """Process one batch, letting the callback handler veto each stage.

    A falsy return from `begin_batch` skips the batch entirely. A falsy
    `after_loss` skips backward/step/zero_grad. `after_backward` gates
    `opt.step()` and `after_step` gates `opt.zero_grad()` independently.
    """
    if cb.begin_batch(xb, yb):
        loss = cb.learn.loss_func(cb.learn.model(xb), yb)
        if cb.after_loss(loss):
            loss.backward()
            if cb.after_backward():
                cb.learn.opt.step()
            if cb.after_step():
                cb.learn.opt.zero_grad()

In [None]:
def all_batches(dl, cb):
    """Run `one_batch` over every (xb, yb) pair in `dl`, stopping early as soon
    as the handler's `do_stop()` reports a stop request."""
    for xb, yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop():
            break

In [None]:
def fit(epochs, learn, cb):
    """Callback-driven training loop.

    A falsy `begin_fit` aborts before training (and skips `after_fit`). A falsy
    `begin_epoch` skips that whole epoch. A falsy `begin_validate` skips the
    validation pass. A falsy `after_epoch` stops after the current epoch.
    """
    if not cb.begin_fit(learn):
        return
    for epoch in range(epochs):
        if cb.begin_epoch(epoch):
            all_batches(learn.data.train_dl, cb)
            if cb.begin_validate():
                with torch.no_grad():
                    all_batches(learn.data.valid_dl, cb)
            if not cb.after_epoch():
                break
    cb.after_fit()

In [None]:
class CallbackHandler():
    """Default handler for the callback-driven `fit` / `one_batch` loop.

    Each hook returns a bool: truthy lets the loop carry out the next step,
    falsy skips it. `after_loss` returns `in_train`, so backward/step only run
    while training. Setting `stop` requests an early end of the current batch
    loop; `do_stop` reports the flag and clears it.
    """
    def __init__(self): self.stop,self.cbs = False,[]

    def begin_fit(self, learn):
        self.learn,self.in_train = learn,True
        return True
    def after_fit(self): pass

    def begin_epoch(self, epoch):
        # Use the learner captured in `begin_fit`, not a module-level `learn` global.
        self.learn.model.train()
        self.in_train=True
        return True
    def begin_validate(self):
        self.learn.model.eval()
        self.in_train=False
        return True
    def after_epoch(self): return True

    def begin_batch(self, xb, yb): return True
    def after_loss(self, loss): return self.in_train
    def after_backward(self): return True
    def after_step(self): return True

    def do_stop(self):
        # Return the current flag, then reset it for the next batch loop.
        try:     return self.stop
        finally: self.stop = False

In [None]:
# Same one-epoch training, now driven by a CallbackHandler whose boolean returns gate each step.
fit(1, learn, cb=CallbackHandler())

This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing `cb` to so many functions is a strong hint that they should all be in the same class!

## Runner

In [None]:
#export
class Callback():
    """Base class for Runner callbacks.

    Attribute lookups that fail on the callback fall through to the owning
    runner (`self.run`), so callbacks can read e.g. `self.model` directly.
    `_order` controls the invocation order (the Runner sorts by it).
    """
    _order=0
    def __init__(self, run): self.run=run
    def __getattr__(self, k):
        # Never delegate `run` itself. On an instance created without `__init__`
        # (copy.copy, unpickling, `__new__`), `self.run` is missing. Delegating it
        # would recurse forever instead of raising AttributeError.
        if k == 'run': raise AttributeError(k)
        return getattr(self.run, k)

class TrainEvalCallback(Callback):
    """Tracks fractional epochs and iterations, and switches the model between
    training and evaluation mode.

    `in_train`, `iters`, `epoch` and `model` are read from the runner through
    `Callback.__getattr__`.
    """
    def begin_fit(self):
        self.n_epochs = 0.
        self.n_iter = 0

    def after_batch(self):
        # Only training batches advance the progress counters.
        if self.in_train:
            self.n_epochs += 1. / self.iters
            self.n_iter += 1

    def begin_epoch(self):
        self.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False

In [None]:
#export
def listify(o):
    """Coerce `o` to a list.

    `None` becomes `[]`, a list is returned unchanged (same object), a tuple
    is converted, and anything else is wrapped as a one-element list.
    """
    if o is None:
        return []
    if isinstance(o, list):
        return o
    return list(o) if isinstance(o, tuple) else [o]

In [None]:
#export
class Runner():
    """Callback-driven training loop.

    Every stage calls `self(<hook name>)`. A callback returning a truthy value
    from a hook skips the rest of that stage; for `begin_fit` and `after_epoch`
    it ends the fit. A `TrainEvalCallback` is always installed first, followed
    by `cbs`. Model, optimizer, loss function and data come from the learner
    passed to `fit`.
    """
    def __init__(self, cbs=None):
        self.stop = False
        self.cbs = [TrainEvalCallback(self)] + listify(cbs)

    @property
    def opt(self):
        return self.learn.opt

    @property
    def model(self):
        return self.learn.model

    @property
    def loss_func(self):
        return self.learn.loss_func

    @property
    def data(self):
        return self.learn.data

    def one_batch(self, xb, yb):
        """Forward pass and loss, then (training only) backward, step and zero_grad."""
        self.xb = xb
        self.yb = yb
        if self('begin_batch'):
            return
        self.pred = self.model(self.xb)
        if self('after_pred'):
            return
        self.loss = self.loss_func(self.pred, self.yb)
        cancelled = self('after_loss')
        if cancelled or not self.in_train:
            # Validation batches stop here: no gradient step outside training.
            return
        self.loss.backward()
        if self('after_backward'):
            return
        self.opt.step()
        if self('after_step'):
            return
        self.opt.zero_grad()

    def all_batches(self, dl):
        """Run every batch of `dl` (recording its length in `iters`) until `stop` is set."""
        self.iters = len(dl)
        for xb, yb in dl:
            if self.stop:
                break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop = False

    def fit(self, epochs, learn):
        """Train `learn` for `epochs` epochs. `after_fit` always runs and the learner is released afterwards."""
        self.epochs = epochs
        self.learn = learn
        try:
            if self('begin_fit'):
                return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'):
                    self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'):
                        self.all_batches(self.data.valid_dl)
                if self('after_epoch'):
                    break
        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        """Invoke hook `cb_name` on each callback in `_order`; True as soon as one returns truthy."""
        hooks = (getattr(cb, cb_name, None) for cb in sorted(self.cbs, key=lambda cb: cb._order))
        return any(hook() for hook in hooks if hook)

In [None]:
#export
class AvgStats():
    """Batch-size-weighted running totals of the loss and of each metric for one phase.

    `in_train` only selects the 'train'/'valid' label used by `__repr__`.
    """
    def __init__(self, metrics, in_train):
        self.metrics,self.in_train = listify(metrics),in_train
        # Start from zeroed totals so `accumulate` and `all_stats` work even if
        # `reset()` has not been called yet.
        self.reset()

    def reset(self):
        self.tot_loss,self.count = 0.,0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        # `tot_loss` stays a Python float until the first `accumulate` makes it a
        # tensor. `float()` accepts both, whereas `.item()` exists only on tensors.
        return [float(self.tot_loss)] + self.tot_mets
    @property
    def avg_stats(self):
        # Raises ZeroDivisionError if nothing has been accumulated (`count == 0`).
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        # Weight by batch size, assumed to be the leading dimension of `run.xb`.
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn

class AvgStatsCallback(Callback):
    """Accumulates loss/metric averages separately for the training and
    validation passes and prints both after every epoch."""
    def __init__(self, run, metrics):
        super().__init__(run)
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)

    def stats(self):
        """Return the statistics object for the phase the runner is in."""
        if self.in_train:
            return self.train_stats
        return self.valid_stats

    def begin_epoch(self):
        for phase_stats in (self.train_stats, self.valid_stats):
            phase_stats.reset()

    def after_loss(self):
        # Bookkeeping only: keep the accumulation out of the autograd graph.
        with torch.no_grad():
            self.stats().accumulate(self.run)

    def after_epoch(self):
        for phase_stats in (self.train_stats, self.valid_stats):
            print(phase_stats)

In [None]:
# Rebuild the learner so the Runner-based training below starts from a freshly created model.
learn = Learner(*get_model(data), loss_func, data)

In [None]:
run = Runner()
# Report average loss and accuracy for the train and valid passes after each epoch.
stats = AvgStatsCallback(run, [accuracy])
run.cbs += [stats]

In [None]:
# Train for 3 epochs; each registered callback's after_epoch hook runs at the end of every epoch.
run.fit(3, learn)

train: [0.312986328125, tensor(0.9031)]
valid: [0.145713427734375, tensor(0.9578)]
train: [0.13666587890625, tensor(0.9582)]
valid: [0.11328406982421875, tensor(0.9675)]
train: [0.10290517578125, tensor(0.9683)]
valid: [0.09884705810546875, tensor(0.9717)]


In [None]:
# Sanity check: average validation accuracy of the final epoch should exceed 90%.
loss,acc = stats.valid_stats.avg_stats
assert acc>0.9

## Export

In [None]:
# Export this notebook's `#export` cells into a Python module (reported as nb_04.py).
!./notebook2script.py 04_callbacks.ipynb

Converted 04_callbacks.ipynb to nb_04.py
