In [None]:
%install '.package(path: "$cwd/FastaiNotebook_04_callbacks")' FastaiNotebook_04_callbacks

In [None]:
import FastaiNotebook_04_callbacks
import TensorFlow
%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")

In [None]:
// Load MNIST as a data bunch. `flat: true` presumably yields each image as a single
// flattened feature vector rather than a 2-D image; the shape checks in the next cell
// assume a (batch, features) layout. NOTE(review): confirm against `mnistDataBunch`.
let data = mnistDataBunch(flat: true)

In [None]:
// Take the first mini-batch of the training set so the shapes below can be read off
// real tensors. The always-true predicate simply selects the first element.
let firstBatch = data.train.first(where: { _ in true })!

// Axis 0 of `xb` is the batch axis and axis 1 holds one example's features, which are
// expected to be a flattened square image of side `exampleSide`.
let batchShape = firstBatch.xb.shape
let (batchSize, exampleSize) = (batchShape.dimensions[0], batchShape.dimensions[1])
let exampleSide: Int32 = 28
assert(exampleSize == exampleSide * exampleSide)
print("Batch size: \(batchSize)")
print("Example size: \(exampleSize)")

// Axis 1 of the labels is taken as the number of classes (labels are presumably
// one-hot encoded; confirm against the data bunch).
let classCount = firstBatch.yb.shape.dimensions[1]
print("Class count: \(classCount)")

In [None]:
/// A small convolutional classifier for flattened square images.
///
/// The input is a batch of flattened `exampleSide` x `exampleSide` images, one row per
/// example. The model:
/// 1. reshapes them into single-channel square images;
/// 2. applies four stride-2, `.same`-padded ReLU convolutions;
/// 3. averages the remaining feature map down to 1x1;
/// 4. flattens the result and projects it to `classCount` outputs.
struct CnnModel: Layer {
    // `-1` lets the batch dimension be inferred from the input, so the model accepts any
    // batch size. Baking in the first batch's size would make the reshape fail for any
    // batch of a different length (e.g. a short final batch, or a validation batch of
    // another size).
    var reshapeToSquare = Reshape<Float>([-1, exampleSide, exampleSide, 1])

    // Each stride-2 `.same` convolution halves the spatial size, rounding up:
    // 28 -> 14 -> 7 -> 4 -> 2 for 28x28 inputs.
    var conv1 = Conv2D<Float>(
        filterShape: (5, 5, 1, 8),
        strides: (2, 2),
        padding: .same,
        activation: relu)
    var conv2 = Conv2D<Float>(
        filterShape: (3, 3, 8, 16),
        strides: (2, 2),
        padding: .same,
        activation: relu)
    var conv3 = Conv2D<Float>(
        filterShape: (3, 3, 16, 32),
        strides: (2, 2),
        padding: .same,
        activation: relu)
    var conv4 = Conv2D<Float>(
        filterShape: (3, 3, 32, 32),
        strides: (2, 2),
        padding: .same,
        activation: relu)

    // The Python notebook uses `AdaptiveAvgPool2d` (presumably pooling to 1x1, which is what
    // `linear`'s input size of 32 implies). After conv4 the feature map is 2x2, so a 2x2
    // average pool with stride 1 and default padding produces the same 1x1x32 result.
    // NOTE: this equivalence only holds for 28x28 inputs; another `exampleSide` would need
    // a different pool size.
    var pool = AvgPool2D<Float>(poolSize: (2, 2), strides: (1, 1))

    var flatten = Flatten<Float>()
    // 32 = conv4's output channels times the 1x1 spatial size left after pooling.
    var linear = Dense<Float>(inputSize: 32, outputSize: Int(classCount))

    /// Runs the full forward pass.
    /// - Parameters:
    ///   - input: Flattened images, one row per example.
    ///   - context: Passed through to every layer (e.g. the learning phase).
    /// - Returns: Class scores of shape `(batch, classCount)`. `linear` is given no
    ///   activation, so these are raw scores.
    @differentiable
    func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        // `sequenced(in:through:)` has no overload that accepts all eight layers, so the
        // chain is split into a convolutional feature extractor and a classification head.
        let features = input.sequenced(
            in: context,
            through: reshapeToSquare, conv1, conv2, conv3, conv4)
        return features.sequenced(in: context, through: pool, flatten, linear)
    }
}

In [None]:
// Smoke test: push the first training batch through a freshly initialized model to
// confirm the layer shapes line up end to end.
let predictions = CnnModel().applied(
    to: firstBatch.xb,
    in: Context(learningPhase: .training)
)

In [None]:
// Shown as the cell's result; expected to be [batchSize, classCount].
predictions.shape