# Callbacks version 1

In [None]:
%install '.package(path: "$cwd/FastaiNotebook_03_minibatch_training")' FastaiNotebook_03_minibatch_training

## Load data

In [None]:
import FastaiNotebook_03_minibatch_training

In [None]:
// export
import Path
import TensorFlow

In [None]:
// Raw MNIST train/valid splits; `flat: true` is forwarded to `loadMNIST`.
var (xTrain, yTrain, xValid, yValid) = loadMNIST(path: mnistPath,
                                                 flat: true)

In [None]:
let n = Int(xTrain.shape[0])  // rows of xTrain
let m = Int(xTrain.shape[1])  // columns of xTrain
let c = yTrain.max() + 1      // largest label + 1
print(n, m, c)

These values can't be used to define the model as-is: `c` is a tensor (the result of `yTrain.max()+1`), not an `Int`. So we hard-code the sizes instead:

In [None]:
// Plain `Int` sizes, usable as layer dimensions when building the model.
let n = 60000, m = 784
let c = 10
let nHid = 50

In [None]:
// export
/// A two-layer fully connected network: a ReLU `Dense` layer of width `nHid`
/// followed by a second `Dense` layer (library-default activation) of width `nOut`.
public struct BasicModel: Layer {
    /// Input -> hidden layer (ReLU).
    public var layer1: Dense<Float>
    /// Hidden -> output layer.
    public var layer2: Dense<Float>

    public init(nIn: Int, nHid: Int, nOut: Int) {
        self.layer1 = Dense(inputSize: nIn, outputSize: nHid, activation: relu)
        self.layer2 = Dense(inputSize: nHid, outputSize: nOut)
    }

    /// Runs `input` through `layer1` and then `layer2` in the given context.
    @differentiable
    public func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        let hidden = layer1.applied(to: input, in: context)
        return layer2.applied(to: hidden, in: context)
    }
}

In [None]:
// Two-layer MLP: m inputs -> nHid hidden units -> c outputs.
var model = BasicModel(
    nIn: m,
    nHid: nHid,
    nOut: c
)

In [None]:
// export
/// Pairs the training and validation datasets a learner iterates over.
public struct DataBunch<Element: TensorGroup> {
    /// Dataset iterated during the training phase.
    public var train: Dataset<Element>
    /// Dataset iterated during the validation phase.
    public var valid: Dataset<Element>

    public init(train: Dataset<Element>, valid: Dataset<Element>) {
        self.valid = valid
        self.train = train
    }
}

In [None]:
//export
/// Loads MNIST via `loadMNIST` and wraps both splits as batched datasets.
///
/// - Parameters:
///   - path: location passed to `loadMNIST`.
///   - flat: forwarded to `loadMNIST` (presumably flattens each image — confirm there).
///   - bs: batch size used for both the training and validation splits.
public func mnistDataBunch(
    path: Path = mnistPath, flat: Bool = false, bs: Int = 64
) -> DataBunch<DataBatch<Tensor<Float>, Tensor<Int32>>> {
    let (xTrain, yTrain, xValid, yValid) = loadMNIST(path: path, flat: flat)
    let batchSize = Int64(bs)
    let trainSet = Dataset(elements: DataBatch(xb: xTrain, yb: yTrain))
    let validSet = Dataset(elements: DataBatch(xb: xValid, yb: yValid))
    return DataBunch(train: trainSet.batched(batchSize),
                     valid: validSet.batched(batchSize))
}

In [None]:
// Batched train/valid MNIST datasets (defaults spelled out explicitly).
let data = mnistDataBunch(path: mnistPath, flat: true, bs: 64)

## Learner (Marc's version)

In [None]:
// TODO: consider a better name — these are control-flow signals rather than failures.
/// Signals a callback can throw to cut part of the training loop short.
///
/// Conforms to `Error` so the values can be thrown by callbacks and matched in
/// `catch CallbackException.xxx` / `catch let e as CallbackException` clauses.
/// Cases without associated values also give automatic `Equatable` conformance,
/// which the `where error != .cancelTraining` filters rely on.
enum CallbackException: Error {
    /// Stop the whole training run.
    case cancelTraining
    /// Stop the current training or validation pass.
    case cancelEpoch
    /// Skip the rest of the current batch.
    case cancelBatch
}

/// Points in the training loop at which `Learner` invokes its callback.
/// Not every event of the reference design is implemented yet.
enum CallbackEvent {
    case beginFit, beginEpoch, beginBatch, beginValidate
    case afterForwardsBackwards, afterEpoch, afterFit
}

/// No-op callback used as the initial link of a `Learner`'s callback chain.
func defaultCallback(e _: CallbackEvent) {
    // Intentionally empty: every event is ignored.
}

The basic `Learner` class:

In [None]:
/// Drives training of `Opt.Model` with optimizer `Opt` on batches of
/// `(Model.Input, Labels)`.
///
/// The training loop's state (phase, counters, current batch, loss, gradient)
/// lives in mutable properties so that callbacks can read and change it.
class Learner<Opt: Optimizer, Labels: TensorGroup>
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input: TensorGroup
{
    // MARK: Type aliases

    typealias Model = Opt.Model
    typealias Inputs = Model.Input
    typealias Data = DataBunch<DataBatch<Inputs, Labels>>
    typealias Callback = (CallbackEvent) throws -> ()

    /// `(model, context, inputs, labels) -> (loss, output, gradient)`.
    ///
    /// A combined loss-and-gradient function is stored instead of a plain loss
    /// function, to work around crashes seen in AD-generated code when a
    /// `lossFunc` was stored in the learner.
    typealias LossWithGradient =
        (Model, Context, Inputs, Labels) -> (Tensor<Float>, Model.Output?, Model.AllDifferentiableVariables)

    // MARK: Components

    var model: Model
    var lossWithGradient: LossWithGradient
    var optimizer: Opt
    var data: Data
    /// Learning phase handed to the model (callbacks such as `installTrainEval`
    /// switch it between training and inference).
    var context: Context = Context(learningPhase: .training)
    /// Head of the callback chain; starts out as a no-op.
    var callback: Callback = defaultCallback

    // MARK: Current batch

    // Optional because there is no natural placeholder before the first batch
    // is loaded. Open question: is there a better way to initialize these so
    // they don't need to be Optionals?
    var input: Model.Input? = nil
    var target: Labels? = nil
    var output: Model.Output? = nil

    var loss: Tensor<Float> = Tensor(0)
    var grad: Model.AllDifferentiableVariables = Model.AllDifferentiableVariables.zero

    // MARK: Progress

    var inTrain: Bool = false
    var epoch: Int = 0
    var epochs: Int = 0
    var nEpochs: Float = 0.0
    var nIter: Int = 0
    var iters: Int = 0

    init(model: Model,
         lossWithGradient: @escaping LossWithGradient,
         optimizer: Opt,
         data: Data) {
        self.model = model
        self.lossWithGradient = lossWithGradient
        self.optimizer = optimizer
        self.data = data
    }
}

Then let's write the parts of the training loop:

In [None]:
extension Learner {
    /// Processes a single batch.
    ///
    /// The method does the following, in order:
    /// 1. Fires `.beginBatch`.
    /// 2. Stores loss, output and gradient from `lossWithGradient`.
    /// 3. Fires `.afterForwardsBackwards`.
    /// 4. In the training phase only, lets the optimizer update the model.
    ///
    /// Once the forward/backward pass has run, `loss` and `grad` are reset to
    /// zero on every exit path, so no later reader sees stale values.
    /// - Throws: whatever the callbacks throw.
    func trainOneBatch(xb: Inputs, yb: Labels) throws {
        try callback(.beginBatch)
        (loss, output, grad) = lossWithGradient(model, context, xb, yb)
        defer {
            loss = Tensor(0)
            grad = Model.AllDifferentiableVariables.zero
        }
        try callback(.afterForwardsBackwards)
        guard inTrain else { return }
        optimizer.update(&model.allDifferentiableVariables, along: grad)
    }

    /// Runs every batch of the current phase's dataset through `trainOneBatch`.
    ///
    /// The dataset is `data.train` when `inTrain` is true, otherwise `data.valid`.
    /// A `CallbackException.cancelBatch` skips only the current batch; any other
    /// error propagates to the caller.
    func trainOneEpoch() throws {
        let dataset = inTrain ? data.train : data.valid
        // Element count of the dataset; callbacks read it through `iters`.
        iters = dataset.count(where: { _ in true })
        for batch in dataset {
            input = batch.xb
            target = batch.yb
            do {
                try trainOneBatch(xb: batch.xb, yb: batch.yb)
            } catch CallbackException.cancelBatch {
                // Move on to the next batch.
            }
        }
    }
}

And the whole fit function.

In [None]:
extension Learner{
    /// Trains for `epochs` epochs. Each epoch is a training pass followed by a
    /// validation pass over `data`.
    ///
    /// Events fire in this order:
    /// - `.beginFit`.
    /// - For each epoch: `.beginEpoch`, the training batches, `.beginValidate`,
    ///   the validation batches, then `.afterEpoch`.
    /// - `.afterFit`. It runs on every exit path, including when `.beginFit`
    ///   itself throws; errors thrown by `.afterFit` callbacks are printed.
    ///
    /// Cancellation rules:
    /// - A `CallbackException` other than `.cancelTraining` that escapes a
    ///   training or validation pass ends only that pass.
    /// - `.cancelTraining` from any event, including `.afterEpoch`, stops
    ///   training quietly.
    /// - Any `CallbackException` thrown directly by `.beginFit`, `.beginEpoch`
    ///   or `.beginValidate` stops training quietly.
    ///
    /// Epochs are numbered from 1. An `epochs` value below 1 runs no epochs;
    /// it does not trap.
    ///
    /// - Throws: errors that are not `CallbackException`s. The exception is
    ///   errors from `.afterEpoch` and `.afterFit` callbacks, which are printed.
    func fit(epochs: Int) throws {
        self.epochs = epochs
        do {
            // Registered before `.beginFit` so `.afterFit` runs even if `.beginFit` throws.
            defer {
                do { try callback(.afterFit) }
                catch { print("Error during callback(.afterFit): \(error)")}
            }
            try callback(.beginFit)
            // `stride` yields nothing when epochs < 1, whereas `1...epochs` would trap.
            for epoch in stride(from: 1, through: epochs, by: 1) {
                self.epoch = epoch
                try callback(.beginEpoch)
                do { try trainOneEpoch() }
                catch let error as CallbackException where error != .cancelTraining {}  // Continue
                try callback(.beginValidate)
                do { try trainOneEpoch() }
                catch let error as CallbackException where error != .cancelTraining {}  // Continue
                do { try callback(.afterEpoch) }
                catch let error as CallbackException where error == .cancelTraining {
                    // Let a callback end training after an epoch (e.g. early stopping);
                    // previously this was printed and ignored.
                    throw error
                }
                catch { print("Error during callback(.afterEpoch): \(error)")}
            }
        } catch is CallbackException {}  // Catch all CallbackExceptions.
    }
}

In [None]:
/// Prepends `newCallback` to `learner`'s callback chain, so callers don't
/// need to do the chaining themselves.
///
/// On every event the new callback runs first, then the callback that was
/// installed before it. If the new callback throws, the earlier callbacks do
/// not see that event.
func chainCallback<Opt, Labels>(on learner: Learner<Opt, Labels>, newCallback: @escaping (CallbackEvent) throws -> ()) {
    let previousCallback = learner.callback
    func combined(_ event: CallbackEvent) throws {
        try newCallback(event)
        try previousCallback(event)
    }
    learner.callback = combined
}

In [None]:
/// Installs the callback that switches `learner` between training and
/// validation phases and tracks progress counters.
///
/// Events handled:
/// - `.beginFit`: resets `nEpochs` and `nIter`.
/// - `.beginEpoch`: prints the epoch number, sets `nEpochs` to the number of
///   epochs already completed, and enters the training phase.
/// - `.afterForwardsBackwards`, training phase only: advances `nEpochs` by one
///   batch's share of an epoch and increments `nIter`.
/// - `.beginValidate`: enters the inference phase.
///
/// The learner is captured weakly. The closure is stored in `learner.callback`,
/// so a strong capture would form a reference cycle and the learner would
/// never be freed.
func installTrainEval<Opt, Labels>(on learner: Learner<Opt, Labels>) {
    chainCallback(on: learner) { [weak learner] event in
        guard let learner = learner else { return }
        switch event {
        case .beginFit:
            learner.nEpochs = 0.0
            learner.nIter = 0
        case .beginEpoch:
            print("Beginning epoch \(learner.epoch)")
            // `Learner.fit` numbers epochs from 1, so epoch k begins with k-1
            // completed epochs. Using `epoch` directly made `nEpochs` one too high.
            learner.nEpochs = Float(learner.epoch - 1)
            learner.context = Context(learningPhase: .training)
            learner.inTrain = true
        case .afterForwardsBackwards:
            if learner.inTrain {
                learner.nEpochs += 1.0/Float(learner.iters)
                learner.nIter   += 1
            }
        case .beginValidate:
            learner.context = Context(learningPhase: .inference)
            learner.inTrain = false
        default: break
        }
    }
}

In [None]:
/// Running sums and per-epoch history filled in by `installAverageMetric`.
public class AverageMetrics {
    /// History of averaged `[loss, metric0, metric1, ...]` vectors, appended at the end of an epoch.
    public var metrics: [[Tensor<Float>]] = []
    /// Number of validation samples accumulated so far in the current epoch.
    var count: Int = 0
    /// Sample-weighted running sums: index 0 holds the loss, index `i + 1` holds metric `i`.
    var partials: [Tensor<Float>] = []
}

/// Installs a callback that averages the loss and each of `metrics` over the
/// validation batches of every epoch.
///
/// Each validation batch adds `batchSize * value` to a running sum, so the
/// epoch result is a sample-weighted mean even when batches differ in size.
/// At `.afterEpoch` the vector `[loss, metric0, metric1, ...]` is printed and
/// appended to `AverageMetrics.metrics`.
///
/// - Parameters:
///   - metrics: functions of `(model output, labels)` returning a tensor to average.
///   - learner: the learner to observe. It is captured weakly, because the
///     closure lives in `learner.callback` and a strong capture would leak.
/// - Returns: the object that accumulates and stores the averaged values.
/// - Note: a validation batch is skipped when its target is not a
///   `Tensor<Int32>` or its output is not a `Tensor<Float>`. Previously a
///   non-`Tensor<Float>` output crashed on `as!`.
/// - Note: an epoch with no validation samples is reported and not recorded.
///   Previously this divided by zero.
func installAverageMetric<Opt, Labels>(
    _ metrics: [(Tensor<Float>, Tensor<Int32>) -> Tensor<Float>],
    on learner: Learner<Opt, Labels>
    ) -> AverageMetrics{
    let avgMetrics = AverageMetrics()
    chainCallback(on: learner) { [weak learner] event in
        guard let learner = learner else { return }
        switch event {
        case .beginEpoch:
            avgMetrics.count = 0
            avgMetrics.partials = Array(repeating: Tensor(0), count: metrics.count + 1)
        case .afterForwardsBackwards:
            // Accumulate only during validation, and only for batches the
            // metric functions can score.
            guard !learner.inTrain,
                  let target = learner.target as? Tensor<Int32>,
                  let output = learner.output as? Tensor<Float> else { break }
            let samples = Int(target.shape[0])
            let weight = Float(samples)
            avgMetrics.count += samples
            avgMetrics.partials[0] += weight * learner.loss
            for (i, metric) in metrics.enumerated() {
                avgMetrics.partials[i + 1] += metric(output, target) * weight
            }
        case .afterEpoch:
            guard avgMetrics.count > 0 else {
                print("No validation samples seen this epoch; metrics not recorded.")
                break
            }
            let total = Float(avgMetrics.count)
            avgMetrics.partials = avgMetrics.partials.map { $0 / total }
            avgMetrics.metrics.append(avgMetrics.partials)
            print(avgMetrics.partials)
        default: break
        }
    }
    return avgMetrics
}

### Test

In [None]:
/// Computes softmax cross-entropy for one batch and returns three values:
/// the loss, the model's predictions, and the gradient of the loss with
/// respect to the model's differentiable variables.
func lossWithGrad(
    model: BasicModel,
    in context: Context,
    inputs: Tensor<Float>,
    labels: Tensor<Int32>
) -> (Tensor<Float>, BasicModel.Output, BasicModel.AllDifferentiableVariables) {
    // The forward pass happens inside the differentiated closure; keep its
    // logits so callers get them without running a second forward pass.
    var capturedLogits: BasicModel.Output? = nil
    let (batchLoss, gradient) = model.valueWithGradient { m -> Tensor<Float> in
        let logits = m.applied(to: inputs, in: context)
        capturedLogits = logits
        return softmaxCrossEntropy(logits: logits, labels: labels)
    }
    return (batchLoss, capturedLogits!, gradient)
}

In [None]:
// Plain SGD over BasicModel's parameters, learning rate 0.01.
let opt = SGD<BasicModel, Float>(learningRate: 0.01)

In [None]:
// Bundle model, loss/gradient function, optimizer and data into one learner.
let learner = Learner(model: model, lossWithGradient: lossWithGrad,
                      optimizer: opt, data: data)

In [None]:
// Callbacks run newest-first (see chainCallback), so the metric recorder
// installed second sees each event before the train/eval switcher.
installTrainEval(on: learner)
let avgMetrics = installAverageMetric([accuracy],
                                      on: learner)

In [None]:
// `fit` is a throwing method, so the call must be marked with `try`.
try learner.fit(epochs: 2)