# Implement Callback Mechanism

In [None]:
import TensorFlow

/// One mini-batch of training data: a tensor of inputs paired with a tensor of labels.
struct DataBatch {
    // Simplifying assumption: Model inputs and outputs are Tensor<Float>
    /// Inputs for this batch; `Learner` passes them to the loss function as the model inputs.
    var xb: Tensor<Float>
    /// Labels for this batch; `Learner` passes them to the loss function as the targets.
    var yb: Tensor<Float>
}

/// The dataset handed to a `Learner`: an in-memory list of training batches.
struct Data {
    // Simplifying assumption: Batches are in an array.
    /// Training batches; `Learner.trainOneEpoch` iterates over them in array order.
    var trainBatches: [DataBatch]
}

/// The verdict a callback hands back to the training loop after seeing an event.
///
/// `.proceed` lets the loop carry on. Any other value short-circuits the remaining
/// callbacks for that event and is passed back to the `Learner`, which decides how
/// `.skip` and `.stop` affect the rest of training.
enum CallbackResult {
    case proceed, skip, stop
}

/// The points in the training loop at which `Learner` notifies its callbacks.
///
/// - `beginFit`: sent once by `fit(epochs:)` before the first epoch.
/// - `beginEpoch`: sent by `trainOneEpoch()` before any batch of that epoch is trained.
/// - `beginBatch`: sent by `trainOneBatch(xb:yb:)` before the loss and gradient are computed.
/// - `afterForwardsBackwards`: sent once the loss and gradient are available, before the
///   optimizer update is applied.
enum CallbackEvent {
    // I haven't implemented all the events.
    case beginFit, beginEpoch, beginBatch, afterForwardsBackwards
}

/// Base class for objects that observe, and can steer, a `Learner`'s training loop.
///
/// Subclasses override `apply(event:learner:)`. The generic constraints mirror the ones on
/// `Learner` so that a callback can be handed a `Learner<Opt>`.
class Callback<Opt: Optimizer>
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input == Tensor<Float>,
      Opt.Model.Output == Tensor<Float> {
    /// Reacts to a training event.
    ///
    /// - Parameters:
    ///   - event: The point in the training loop that was just reached.
    ///   - learner: The learner that raised the event; callbacks may read or mutate its state.
    /// - Returns: `.proceed` in this default implementation, i.e. the base class never
    ///   interrupts training.
    func apply(event: CallbackEvent, learner: Learner<Opt>) -> CallbackResult {
        return .proceed
    }
}

/// Runs a training loop over `Data` and notifies a list of callbacks at fixed points.
///
/// The learner owns the model, the optimizer and the data. For every batch it computes the
/// loss and gradient through `lossWithGradient`, lets the callbacks inspect them, and then
/// applies the optimizer update. Callbacks steer the loop through the `CallbackResult` they
/// return:
///
/// - `.proceed` continues normally.
/// - `.skip` returned while training a batch abandons only that batch; training continues
///   with the next batch.
/// - `.stop` ends the fit.
class Learner<Opt: Optimizer>
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input == Tensor<Float>,
      Opt.Model.Output == Tensor<Float>
{
    typealias Model = Opt.Model

    /// The model being trained; the optimizer updates its differentiable variables in place.
    var model: Model
    
    // I'm getting some crashes in AD-generated code if I put a `lossFunc` in the learner.
    // So I'm putting a `lossWithGradient` for now, to work around this.
    // (model, context, inputs, labels) -> (loss, grad)
    typealias LossWithGradient = (Model, Context, Tensor<Float>, Tensor<Float>) -> (Tensor<Float>, Model.AllDifferentiableVariables)

    /// Computes the loss for one batch together with the gradient used for the update.
    var lossWithGradient: LossWithGradient
    
    /// The optimizer that applies `grad` to the model after each batch.
    var optimizer: Opt
    /// The batches to train on.
    var data: Data
    /// Callbacks notified, in array order, for every `CallbackEvent`.
    var callbacks: [Callback<Opt>]
    
    /// Loss of the batch currently being trained; reset to zero once the batch is finished.
    var loss: Tensor<Float> = Tensor(0)
    /// Gradient of the batch currently being trained; reset to zero once the batch is finished.
    var grad: Model.AllDifferentiableVariables = Model.AllDifferentiableVariables.zero
    
    /// Number of the epoch currently being trained; `fit(epochs:)` counts from 1.
    var epoch: Int = 0
    /// Total number of epochs requested by the most recent `fit(epochs:)` call.
    var epochs: Int = 0
    
    /// Creates a learner.
    ///
    /// - Parameters:
    ///   - model: The model to train.
    ///   - lossWithGradient: Function returning the loss and gradient for one batch.
    ///   - optimizer: The optimizer that applies gradients to the model.
    ///   - data: The training batches.
    ///   - callbacks: Callbacks to notify during training, in the order they should run.
    init(
        model: Model,
        lossWithGradient: @escaping LossWithGradient,
        optimizer: Opt,
        data: Data,
        callbacks: [Callback<Opt>]
    ) {
        self.model = model
        self.lossWithGradient = lossWithGradient
        self.optimizer = optimizer
        self.data = data
        self.callbacks = callbacks
    }
    
    /// Zeroes `loss` and `grad` so values from a finished batch cannot leak into the next one.
    private func resetPerBatchValues() {
        self.loss = Tensor(0)
        self.grad = Model.AllDifferentiableVariables.zero
    }
    
    /// Trains on a single batch.
    ///
    /// - Parameters:
    ///   - xb: Inputs of the batch.
    ///   - yb: Labels of the batch.
    /// - Returns: `.proceed` if the optimizer update was applied; otherwise the first
    ///   non-`.proceed` result returned by a callback, in which case no update is applied.
    func trainOneBatch(xb: Tensor<Float>, yb: Tensor<Float>) -> CallbackResult {
        let beginResult = runCallbacks(event: .beginBatch)
        if beginResult != .proceed {
            return beginResult
        }
        let context = Context(learningPhase: .training)
        (self.loss, self.grad) = lossWithGradient(model, context, xb, yb)
        defer {
            // Zero out the loss & gradient to ensure stale values aren't used.
            resetPerBatchValues()
        }
        let afterResult = runCallbacks(event: .afterForwardsBackwards)
        if afterResult != .proceed {
            return afterResult
        }
        optimizer.update(&model.allDifferentiableVariables, along: self.grad)
        return .proceed
    }
    
    /// Trains on every batch in `data.trainBatches` once.
    ///
    /// - Returns: `.stop` if a callback asked to stop, `.skip` if a callback returned `.skip`
    ///   for `.beginEpoch` (which is reported as unexpected), and `.proceed` otherwise.
    func trainOneEpoch() -> CallbackResult {
        switch runCallbacks(event: .beginEpoch) {
        case .stop: return .stop
        case .skip:
            print("Unexpected .skip returned from running callbacks(event: .beginEpoch)")
            return .skip
        case .proceed: break
        }
        for batch in self.data.trainBatches {
            switch trainOneBatch(xb: batch.xb, yb: batch.yb) {
            case .stop:
                return .stop
            case .skip:
                // A skipped batch only abandons that batch; it must not end the epoch
                // (and with it the whole fit) the way `.stop` does.
                continue
            case .proceed:
                break
            }
        }
        return .proceed
    }

    /// Trains for the given number of epochs, or until a callback interrupts training.
    ///
    /// The `.beginFit` callbacks always run. Training ends early when they, or an epoch,
    /// return anything other than `.proceed`.
    ///
    /// - Parameter epochs: Number of epochs to train. Values below 1 train nothing.
    func fit(epochs: Int) {
        // I haven't implemented validation.
        self.epochs = epochs
        if runCallbacks(event: .beginFit) != .proceed {
            return
        }
        // `1...epochs` traps at runtime when `epochs < 1`, so bail out before forming the range.
        guard epochs >= 1 else {
            return
        }
        for epoch in 1...epochs {
            self.epoch = epoch
            if trainOneEpoch() != .proceed {
                return
            }
        }
    }
    
    /// Sends `event` to every callback in order.
    ///
    /// - Returns: The first non-`.proceed` result (later callbacks are not invoked for this
    ///   event), or `.proceed` if every callback returned `.proceed`.
    private func runCallbacks(event: CallbackEvent) -> CallbackResult {
        for callback in callbacks {
            let cbResult = callback.apply(event: event, learner: self)
            if cbResult != .proceed {
                return cbResult
            }
        }
        return .proceed
    }
}

# Implement some example callbacks

In [None]:
// Notebook setup for inline plotting.
// NOTE(review): presumably this include is what defines `IPythonDisplay` — confirm.
%include "EnableIPythonDisplay.swift"
// Import matplotlib's pyplot module through the Python interop layer.
let plt = Python.import("matplotlib.pyplot")
// Ask the notebook shell to render matplotlib figures inline.
IPythonDisplay.shell.enable_matplotlib("inline")

/// Callback that keeps a per-batch history of the loss and the learning rate during a fit.
class Recorder<Opt: Optimizer> : Callback<Opt>
// Hmm, this boilerplate is kind of annoying.
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input == Tensor<Float>,
      Opt.Model.Output == Tensor<Float>,
      // Extra constraints restrict which learners a callback can be attached to.
      // The optimizer's scalar type must be `Float` so that the recorded learning
      // rates can be handed to `plt.plot`.
      Opt.Scalar == Float {

    /// Loss recorded after the forward/backward pass of each batch, in training order.
    var losses: [Float] = []
    /// Optimizer learning rate recorded alongside each entry of `losses`.
    var lrs: [Float] = []

    /// Clears the history when a fit begins and appends one sample per batch.
    ///
    /// - Returns: Always `.proceed`; the recorder never interrupts training.
    override func apply(event: CallbackEvent, learner: Learner<Opt>) -> CallbackResult {
        switch event {
        case .beginFit:
            // Start every fit with an empty history.
            losses.removeAll()
            lrs.removeAll()
        case .afterForwardsBackwards:
            // The force-unwrap requires the loss to be a scalar tensor.
            let batchLoss = learner.loss.scalar!
            losses.append(batchLoss)
            lrs.append(learner.optimizer.learningRate)
        case .beginEpoch, .beginBatch:
            break
        }
        return .proceed
    }

    /// Plots the recorded losses with matplotlib.
    func plotLosses() {
        plt.plot(losses)
    }

    /// Plots the recorded learning rates with matplotlib.
    func plotLrs() {
        plt.plot(lrs)
    }
}

In [None]:
/// Callback that overwrites one property of the learner at the start of every batch,
/// using a schedule evaluated at the current training progress.
class ParamScheduler<Opt: Optimizer, Param> : Callback<Opt>
// Hmm, this boilerplate is kind of annoying.
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input == Tensor<Float>,
      Opt.Model.Output == Tensor<Float> {

    /// Key path, rooted at the learner, of the property to overwrite.
    let paramKeyPath: ReferenceWritableKeyPath<Learner<Opt>, Param>
    /// Maps training progress (`learner.epoch / learner.epochs`) to the value to assign.
    let schedule: (Float) -> Param

    /// Creates a scheduler.
    ///
    /// - Parameters:
    ///   - paramKeyPath: The learner property to schedule.
    ///   - schedule: Function from training progress to the scheduled value.
    init(paramKeyPath: ReferenceWritableKeyPath<Learner<Opt>, Param>, schedule: @escaping (Float) -> Param) {
        self.paramKeyPath = paramKeyPath
        self.schedule = schedule
    }

    /// Assigns the scheduled value on `.beginBatch`; every other event is ignored.
    ///
    /// - Returns: Always `.proceed`.
    override func apply(event: CallbackEvent, learner: Learner<Opt>) -> CallbackResult {
        guard event == .beginBatch else {
            return .proceed
        }
        // Progress is measured in whole epochs, so every batch of an epoch gets the same value.
        let progress = Float(learner.epoch) / Float(learner.epochs)
        learner[keyPath: paramKeyPath] = schedule(progress)
        return .proceed
    }
}

# A simple model and data

In [None]:
// Sum of the two inputs is the output.
// Toy regression data, two batches of two examples each.
// Each row of `xb` holds two inputs, and the matching row of `yb` is their sum.
let data = Data(trainBatches: [
    DataBatch(xb: [[0, 1], [2, 3]], yb: [[1], [5]]),
    DataBatch(xb: [[-3, 4], [-10, 2]], yb: [[1], [-8]]),
])

In [None]:
/// A toy model: one dense layer from two inputs to one output, plus an untrained scalar
/// offset that exists only so a parameter scheduler has something to drive.
struct SillyModel : Layer {
    /// The only trainable part of the model.
    var dense = Dense<Float>(inputSize: 2, outputSize: 1)

    // A non-trained parameter to help illustrate the parameter scheduler.
    @noDerivative var sillyExtraBiasParam: Float = 0

    /// Applies the dense layer to `input` and adds the extra bias to the result.
    @differentiable
    func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        let denseOutput = dense.applied(to: input, in: context)
        return denseOutput + sillyExtraBiasParam
    }
}

In [None]:
/// Computes the mean squared error of `model` on one batch, together with the gradient of
/// that loss as returned by `model.valueWithGradient`.
///
/// - Parameters:
///   - model: The model to evaluate.
///   - context: The context forwarded to the model's `applied(to:in:)`.
///   - inputs: Batch inputs.
///   - labels: Targets the predictions are compared against.
/// - Returns: The loss and the gradient, in that order.
func lossWithGrad(
    model: SillyModel,
    in context: Context,
    inputs: Tensor<Float>,
    labels: Tensor<Float>
) -> (Tensor<Float>, SillyModel.AllDifferentiableVariables) {
    return model.valueWithGradient { candidate -> Tensor<Float> in
        let residual = candidate.applied(to: inputs, in: context) - labels
        return residual.squared().mean()
    }
}

In [None]:
// The model instance that will be handed to the learner.
let model = SillyModel()

# Run the learner

In [None]:
// Some typealiases to reduce repeatedly typing types.
typealias MyOptimizer = SGD<SillyModel, Float>
typealias MyLearner = Learner<MyOptimizer>

In [None]:
// SGD optimizer with a fixed learning rate of 0.01.
let optimizer = MyOptimizer(learningRate: 0.01)

In [None]:
// The learning rate cannot be scheduled because the `Optimizer` protocol offers no way to
// set it. If the protocol is changed to allow that, `ParamScheduler` should be able to
// drive it with `paramKeyPath: \MyLearner.optimizer.learningRate`.
// Until then, schedule the model's untrained bias instead: -10 while training progress is
// below one half, 0 afterwards.
let scheduler = ParamScheduler(paramKeyPath: \MyLearner.model.sillyExtraBiasParam) { progress in
    progress < 0.5 ? -10 : 0
}

In [None]:
// Callback that collects the per-batch losses and learning rates for plotting afterwards.
let recorder = Recorder<MyOptimizer>()

In [None]:
// Assemble the learner from the pieces defined above.
// Callbacks run in array order for every event, so the recorder sees each event before the scheduler.
let learner = Learner(
    model: model,
    lossWithGradient: lossWithGrad,
    optimizer: optimizer,
    data: data,
    callbacks: [
        recorder,
        scheduler
    ])

In [None]:
// Train for 100 epochs; the callbacks passed to the learner are notified along the way.
learner.fit(epochs: 100)

In [None]:
// Plot the losses the recorder captured during the fit (one point per batch), then show the figure.
recorder.plotLosses()
plt.show()