In [None]:
%install '.package(path: "$cwd/FastaiNotebooks")' FastaiNotebooks

# Implement Callback Mechanism

In [None]:
import TensorFlow

// NOTE: placeholder name — please pick a better one! :-)
/// Errors a callback may throw to steer `Learner`'s training loop.
///
/// Must conform to `Error`: otherwise callbacks cannot `throw` these values and the
/// `catch CallbackException.…` clauses in `Learner` can never match.
enum CallbackException: Error {
    /// Stop `fit(epochs:)` entirely.
    case cancelTraining
    /// Abandon the current epoch and continue with the next one.
    case cancelEpoch
    /// Abandon the current batch and continue with the next one.
    case cancelBatch
}

/// The points in the training loop at which `Learner` invokes its callback.
/// Not every stage of a full training loop is represented yet.
enum CallbackEvent {
    case beginFit, beginEpoch, beginBatch
    case afterForwardsBackwards
    case afterFit
}

/// The do-nothing callback a `Learner` starts out with: every event is ignored.
func defaultCallback(e: CallbackEvent) {
    return
}

/// Pairs model inputs (`xb`) with their labels (`yb`).
struct DataBatch<Inputs, Labels>: TensorGroup
where Inputs: Differentiable & TensorGroup, Labels: TensorGroup {
    /// The model inputs.
    var xb: Inputs
    /// The labels matching `xb`.
    var yb: Labels
}

In [None]:
/// A minimal training loop that reports its progress to a (chainable) callback.
///
/// Callbacks steer the loop by throwing a `CallbackException`:
/// - `.cancelBatch` abandons the current batch (skipping the optimizer update if thrown
///   before it) and moves on to the next batch;
/// - `.cancelEpoch` abandons the current epoch and moves on to the next epoch;
/// - `.cancelTraining` ends `fit(epochs:)` early.
/// Errors of other types thrown by callbacks propagate out of `fit(epochs:)`, except
/// that errors thrown by the `.afterFit` callback are printed rather than rethrown.
class Learner<Opt: Optimizer, Labels: TensorGroup>
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input: TensorGroup
{
    typealias Model = Opt.Model
    /// The model being trained; `optimizer` updates it in place after every batch.
    var model: Model
    
    typealias Inputs = Model.Input
    
    // I'm getting some crashes in AD-generated code if I put a `lossFunc` in the learner.
    // So I'm putting a `lossWithGradient` for now, to work around this.
    // (model, context, inputs, labels) -> (loss, grad)
    typealias LossWithGradient = (Model, Context, Inputs, Labels) -> (Tensor<Float>, Model.AllDifferentiableVariables)
    var lossWithGradient: LossWithGradient
    
    var optimizer: Opt
    
    typealias Data = Dataset<DataBatch<Inputs, Labels>>
    /// Training data; each element is handed to `trainOneBatch(xb:yb:)`.
    var data: Data
    
    var context: Context = Context(learningPhase: .training)

    typealias Callback = (CallbackEvent) throws -> ()
    /// Invoked at every `CallbackEvent`. Install more behavior by wrapping this closure.
    var callback: Callback = defaultCallback
    
    /// Loss of the batch in progress. Valid during the `.afterForwardsBackwards` callback;
    /// reset to zero when the batch finishes.
    var loss: Tensor<Float> = Tensor(0)
    /// Gradient of `loss`; valid for the same window as `loss`.
    var grad: Model.AllDifferentiableVariables = Model.AllDifferentiableVariables.zero
    
    /// 1-based index of the epoch in progress (set by `fit(epochs:)`).
    var epoch: Int = 0
    /// Number of epochs requested by the current `fit(epochs:)` call.
    var epochs: Int = 0
    
    init(
        model: Model,
        lossWithGradient: @escaping LossWithGradient,
        optimizer: Opt,
        data: Data
    ) {
        self.model = model
        self.lossWithGradient = lossWithGradient
        self.optimizer = optimizer
        self.data = data
    }
    
    /// Runs `.beginBatch`, computes loss and gradient, runs `.afterForwardsBackwards`,
    /// then applies one optimizer step.
    func trainOneBatch(xb: Inputs, yb: Labels) throws {
        try callback(.beginBatch)
        (self.loss, self.grad) = lossWithGradient(model, self.context, xb, yb)
        defer {
            // Zero out the loss & gradient to ensure stale values aren't used.
            self.loss = Tensor(0)
            self.grad = Model.AllDifferentiableVariables.zero        
        }
        try callback(.afterForwardsBackwards)
        optimizer.update(&model.allDifferentiableVariables, along: self.grad)
    }
    
    /// Runs `.beginEpoch`, then trains on every element of `data`, moving past any batch
    /// whose callbacks throw `.cancelBatch`.
    func trainOneEpoch() throws {
        try callback(.beginEpoch)
        for batch in self.data {
            do {
                try trainOneBatch(xb: batch.xb, yb: batch.yb)
            } catch CallbackException.cancelBatch {}  // Continue
        }
    }

    /// Trains for `epochs` epochs (none if `epochs < 1`), bracketed by `.beginFit` and
    /// `.afterFit`. `CallbackException`s never escape this method.
    func fit(epochs: Int) throws {
        // I haven't implemented validation.
        self.epochs = epochs
        do {
            try callback(.beginFit)
            // Registered only after `.beginFit` succeeded, so `.afterFit` always pairs with it.
            defer {
                do {
                    try callback(.afterFit)
                } catch {
                    print("Error during callback(.afterFit): \(error)")
                }
            }
            // `stride` yields nothing when `epochs < 1`, whereas the closed range
            // `1...epochs` traps at runtime in that case.
            for epoch in stride(from: 1, through: epochs, by: 1) {
                self.epoch = epoch
                do {
                    try trainOneEpoch()
                } catch let error as CallbackException where error != .cancelTraining {}  // Continue
            }
        } catch is CallbackException {}  // Catch all CallbackExceptions.
    }
}

# Implement some example callbacks

In [None]:
/// Installs a callback that prints a progress line at the start of every epoch,
/// chaining to the learner's existing callback by hand.
func installProgress1<Opt, Labels>(on learner: Learner<Opt, Labels>) {
    let chainedCallback = learner.callback  // Keep a handle to the current callback.
    // Capture `learner` unowned: this closure is stored in `learner.callback`, so a strong
    // capture would form a learner -> callback -> learner reference cycle and leak both.
    learner.callback = { [unowned learner] event in
        if case .beginEpoch = event {
            print("Starting new epoch: \(learner.epoch) of \(learner.epochs)!")
        }
        try chainedCallback(event)  // Don't forget to call the previous callback!
    }
}


In [None]:
/// Helper so you don't need to do the chaining yourself. :-)
///
/// Puts `newCallback` in front of `learner`'s current callback: every event goes to
/// `newCallback` first, then to whatever was installed before. If `newCallback` throws,
/// the earlier callbacks do not see that event.
func chainCallback<Opt, Labels>(on learner: Learner<Opt, Labels>, newCallback: @escaping (CallbackEvent) throws -> ()) {
    let previous = learner.callback
    func chained(_ event: CallbackEvent) throws {
        try newCallback(event)
        try previous(event)
    }
    learner.callback = chained
}


In [None]:
/// Installs a callback that prints a progress line at the start of every epoch.
func installProgress<Opt, Labels>(on learner: Learner<Opt, Labels>) {
    // `unowned` avoids a learner -> callback -> learner reference cycle: the closure ends up
    // stored in `learner.callback`, so a strong capture would keep the learner alive forever.
    chainCallback(on: learner) { [unowned learner] event in
        if case .beginEpoch = event {
            print("Starting new epoch: \(learner.epoch) of \(learner.epochs)!")
        }
    }
}

In [None]:
%include "EnableIPythonDisplay.swift"
// Python interop handle to matplotlib's pyplot (used by `RecordedInfo.plot()`),
// with figures rendered inline in the notebook.
let plt = Python.import("matplotlib.pyplot")
IPythonDisplay.shell.enable_matplotlib("inline")

/// Training history collected by `installRecorder(on:)`.
public class RecordedInfo {
    /// Loss recorded after each forward/backward pass.
    public var losses: [Float] = []
    /// Optimizer learning rate recorded after each forward/backward pass.
    public var lrs: [Float] = []
    
    /// Plots losses and learning rates side by side, each on its own axes.
    /// The two series typically differ by orders of magnitude, so a shared y-axis would
    /// flatten the learning-rate curve against zero.
    func plot() {
        plt.subplot(1, 2, 1)
        plt.title("loss")
        plt.plot(self.losses)
        plt.subplot(1, 2, 2)
        plt.title("learning rate")
        plt.plot(self.lrs)
    }
}

/// Installs a callback that records the loss and learning rate after every
/// forward/backward pass, clearing the history at the start of each `fit`.
/// - Returns: The `RecordedInfo` that the callback appends to.
func installRecorder<Opt, Labels>(on learner: Learner<Opt, Labels>) -> RecordedInfo where Opt.Scalar == Float {
    let recorder = RecordedInfo()
    // `unowned` breaks the learner -> callback -> learner reference cycle; the closure is
    // stored in `learner.callback`.
    chainCallback(on: learner) { [unowned learner] event in
        switch event {
        case .beginFit:
            recorder.losses = []
            recorder.lrs = []
        case .afterForwardsBackwards:
            // NOTE(review): assumes `lossWithGradient` returns a scalar loss; `.scalar` is nil
            // (and this traps) otherwise.
            recorder.losses.append(learner.loss.scalar!)
            recorder.lrs.append(learner.optimizer.learningRate)
        default: break
        }
    }
    return recorder
}


In [None]:
/// Installs a callback that, before every batch, sets the learner property at
/// `paramKeyPath` to `schedule(progress)`. `progress` is the fraction of epochs already
/// completed: 0 throughout the first epoch, `(epochs - 1) / epochs` in the last.
func installParameterScheduler<Opt, Labels, Param>(
    on learner: Learner<Opt, Labels>,
    forParameter paramKeyPath: ReferenceWritableKeyPath<Learner<Opt, Labels>, Param>,
    schedule: @escaping (Float) -> Param) {
    // `unowned` breaks the learner -> callback -> learner reference cycle.
    chainCallback(on: learner) { [unowned learner] event in
        guard case .beginBatch = event else { return }
        // `learner.epoch` is 1-based. Subtract one so the schedule starts at 0 rather than
        // 1/epochs, and does not reach 1.0 while training is still in progress.
        let progress = Float(learner.epoch - 1) / Float(learner.epochs)
        learner[keyPath: paramKeyPath] = schedule(progress)
    }
}


# The model and data

In [None]:
import FastaiNotebooks
import Path

// Load MNIST from ~/.fastai/data/mnist_tst. `var` because `xTrain` is reshaped below.
// The validation split is loaded but not used (the learner has no validation yet).
var (xTrain,yTrain,xValid,yValid) = loadMNIST(path: Path.home/".fastai"/"data"/"mnist_tst")

In [None]:
// Flatten each image into a 784-element row (28 * 28).
// NOTE(review): the 60000 row count is hard-coded — confirm it matches what loadMNIST returns.
xTrain = xTrain.reshaped(toShape: [60000, 784])

// n = number of training rows, m = number of columns (features) per row.
let (n,m) = (Int(xTrain.shape[0]),Int(xTrain.shape[1]))
// Presumably the number of classes (if labels are 0-based). NOTE(review): unused below;
// MyModel hard-codes `outputCount = 10` instead.
let c = yTrain.max()+1

// Hidden-layer width and batch size.
let nh = 50
let bs: Int32 = 64

// Wrap the whole training set in a single DataBatch, then batch it `bs` rows at a time.
let train_ds: Dataset<DataBatch> = Dataset(elements: DataBatch(xb: xTrain, yb: yTrain)).batched(Int64(bs))

In [None]:
let outputCount = 10

/// A two-layer network: `m` inputs -> `nh` hidden units (ReLU) -> `outputCount` logits.
struct MyModel: Layer {
    var layer1 = Dense<Float>(inputSize: m, outputSize: nh, activation: relu)
    var layer2 = Dense<Float>(inputSize: nh, outputSize: outputCount)
    
    /// A bias added to the output. It is `@noDerivative`, so it is never trained; it exists
    /// only to show off the parameter scheduler.
    @noDerivative var sillyExtraBiasParam: Tensor<Float> = Tensor(zeros: [Int32(outputCount)])
    
    @differentiable
    func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        let hidden = layer1.applied(to: input, in: context)
        let output = layer2.applied(to: hidden, in: context)
        return output + sillyExtraBiasParam
    }
}

var model = MyModel()

/// Computes the softmax cross-entropy loss of `model` on one batch, together with the
/// gradient of that loss with respect to the model's differentiable variables.
/// Its signature matches `Learner.LossWithGradient`, so it can be passed to a `Learner`.
func lossWithGrad(
    model: MyModel,
    in context: Context,
    inputs: Tensor<Float>,
    labels: Tensor<Int32>
) -> (Tensor<Float>, MyModel.AllDifferentiableVariables) {
    let (value, gradient) = model.valueWithGradient { (current: MyModel) -> Tensor<Float> in
        let logits = current.applied(to: inputs, in: context)
        return softmaxCrossEntropy(logits: logits, labels: labels)
    }
    return (value, gradient)
}

# Run the learner

In [None]:
// Some typealiases to reduce repeatedly typing types.
// SGD over MyModel's parameters with Float scalars.
typealias MyOptimizer = SGD<MyModel, Float>
// A learner driven by MyOptimizer whose labels are Int32 tensors.
typealias MyLearner = Learner<MyOptimizer, Tensor<Int32>>

In [None]:
// SGD optimizer with a learning rate of 0.01.
let optimizer = MyOptimizer(learningRate: 0.01)

In [None]:
// Combine the model, loss/gradient function, optimizer and training batches into one learner.
let learner = MyLearner(model: model, lossWithGradient: lossWithGrad, optimizer: optimizer, data: train_ds)

In [None]:
// The learning rate can't be scheduled yet: the `Optimizer` protocol doesn't allow setting it.
// If it did, `installParameterScheduler` could target it with
// `forParameter: \MyLearner.optimizer.learningRate`.
// Here we schedule the silly bias instead: zero for the first half of training, then shifted.
installParameterScheduler(on: learner, forParameter: \MyLearner.model.sillyExtraBiasParam) { progress in
    if progress < 0.5 {
        return Tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    }
    return Tensor([10, 20, 30, 0, 0, 0, 0, 0, 0, 0])
}

In [None]:
// Record the loss and learning rate after every forward/backward pass.
let recorder = installRecorder(on: learner)

In [None]:
// Print a progress line at the start of each epoch.
installProgress(on: learner)

In [None]:
// `fit(epochs:)` is a throwing function, so it must be called with `try`. It already
// swallows `CallbackException`s itself; report any other error that escapes it.
do {
    try learner.fit(epochs: 6)
} catch {
    print("Training failed: \(error)")
}

In [None]:
// Plot the recorded losses and learning rates.
recorder.plot()