<a href="https://colab.research.google.com/github/dbolella/s4tf-lenet-mnist/blob/master/lenet_mnist_swift_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LeNet-5 & MNIST using Swift for TensorFlow
by Danny Bolella

To learn more about how this Colab works, check out the associated Medium article at: 

This Colab is a reworking of the official S4TF Example found at: https://github.com/tensorflow/swift-models/tree/master/Examples/LeNet-MNIST.

## Installing and Importing Libraries
First, we pull in two libraries as Swift packages from the official S4TF models repo, using `%install` to do so. Once that completes, we import the libraries we'll be using (TensorFlow is already available on Colab, so there is no need to install it).

In [0]:
%install '.package(url: "https://github.com/tensorflow/swift-models.git", .branch("master"))' ImageClassificationModels Datasets

import TensorFlow
import Datasets
import ImageClassificationModels

## Model, Dataset, Optimizer... Oh My!
Next, we instantiate the dataset, model, and optimizer we will be using. We also set up `epochCount` (the number of full passes we'll make over the training data) and `batchSize` (the number of examples we'll train on at a time).

One last thing to do is split our test data into batches using our designated `batchSize`.

In [0]:
let dataset = MNIST()

var model = LeNet()

let optimizer = SGD(for: model, learningRate: 0.1)

let epochCount = 12

let batchSize = 128

let testBatches = dataset.testDataset.batched(batchSize)
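
To get a feel for what `batchSize` means in practice, here is a minimal sketch in plain Swift (no TensorFlow needed) that works out how many batches one pass over the data produces. The `batchLayout` helper is hypothetical and not part of S4TF, and the sketch assumes the final, partial batch is kept rather than dropped.

```swift
// Hypothetical helper, not part of the notebook or the S4TF API.
// Returns how many batches one pass over `exampleCount` items yields
// and how many examples land in the final batch, assuming the final
// partial batch is kept.
func batchLayout(exampleCount: Int, batchSize: Int) -> (batchCount: Int, lastBatchSize: Int) {
    precondition(exampleCount > 0 && batchSize > 0)
    let batchCount = (exampleCount + batchSize - 1) / batchSize  // ceiling division
    let remainder = exampleCount % batchSize
    return (batchCount, remainder == 0 ? batchSize : remainder)
}

// MNIST has 60,000 training images and 10,000 test images.
let trainLayout = batchLayout(exampleCount: 60_000, batchSize: 128)  // 469 batches, last holds 96
let testLayout = batchLayout(exampleCount: 10_000, batchSize: 128)   // 79 batches, last holds 16
print(trainLayout)
print(testLayout)
```

Under that assumption, neither split divides evenly by 128, so the last batch of each pass is smaller than `batchSize`.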

## Benchmarking Prep
Next, we create a `struct` to hold our training and testing benchmarks for each epoch. Note that the struct also has a method that updates the guess-count stats (`correctGuessCount` and `totalGuessCount`). This eliminates duplicate code in our training and testing loops.

In [0]:

struct Statistics {
    var correctGuessCount: Int = 0
    var totalGuessCount: Int = 0
    var totalLoss: Float = 0

    // Adds one batch's results to the running guess counts.
    mutating func updateGuessCounts(logits: Tensor<Float>, labels: Tensor<Int32>, batchSize: Int) {
        // The predicted class is the index of the largest logit in each row.
        let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
        self.correctGuessCount += Int(
            Tensor<Int32>(correctPredictions).sum().scalarized())
        self.totalGuessCount += batchSize
    }
}
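
To make the counting logic concrete, here is a rough plain-Swift equivalent of the `argmax` comparison, written against ordinary arrays so it runs without TensorFlow. The `countCorrect` function and the sample values are made up for illustration. This is a sketch of the idea, not how S4TF implements it.

```swift
// Hypothetical stand-in for `logits.argmax(squeezingAxis: 1) .== labels`
// followed by a sum, using plain arrays instead of tensors.
func countCorrect(logits: [[Float]], labels: [Int]) -> Int {
    precondition(logits.count == labels.count)
    var correct = 0
    for (row, label) in zip(logits, labels) {
        // The index of the largest score in the row is the predicted class.
        guard let best = row.indices.max(by: { row[$0] < row[$1] }) else { continue }
        if best == label { correct += 1 }
    }
    return correct
}

// Made-up scores for three examples over three classes.
let sampleLogits: [[Float]] = [
    [0.1, 2.5, -1.0],  // predicts class 1
    [3.0, 0.2, 0.4],   // predicts class 0
    [0.0, 0.1, 0.9],   // predicts class 2
]
let sampleLabels = [1, 2, 2]
let sampleCorrect = countCorrect(logits: sampleLogits, labels: sampleLabels)
print("\(sampleCorrect)/\(sampleLabels.count) correct")
```

The first and third rows match their labels and the second does not, so two of the three guesses count as correct.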

## Training Day
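Lastly, we run our training! The training loop runs once per epoch, `epochCount` times in total. In each epoch, we loop through batches of our training data, run each batch through our model, update our benchmarks, and optimize along the gradients. We then run the test batches through the model to measure how well it generalizes.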

At the end of each epoch, we print out our benchmark data. We should generally see our loss decrease and our accuracy increase as training progresses. Note that `totalLoss` is the sum of the per-batch losses over the epoch, not an average.
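
To illustrate what "optimize along the gradients" means, here is a minimal sketch of a plain stochastic gradient descent step on a toy problem. The `sgdStep` helper and the toy objective are hypothetical. The sketch assumes plain SGD with no momentum or decay and is not the S4TF optimizer itself.

```swift
// Hypothetical helper showing one plain SGD step: p <- p - learningRate * g.
func sgdStep(parameters: [Float], gradients: [Float], learningRate: Float) -> [Float] {
    precondition(parameters.count == gradients.count)
    return zip(parameters, gradients).map { pair in pair.0 - learningRate * pair.1 }
}

// Toy problem: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
var toyWeights: [Float] = [0]
for _ in 0..<50 {
    let toyGradients = toyWeights.map { 2 * ($0 - 3) }
    toyWeights = sgdStep(parameters: toyWeights, gradients: toyGradients, learningRate: 0.1)
}
print(toyWeights)  // approaches [3.0]
```

Each step moves the weight a fraction of the way toward the minimum at 3. The training loop does the same thing for every parameter in the model.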

In [0]:
print("Beginning training...")

// The training loop.
for epoch in 1...epochCount {
    var trainStats = Statistics()
    var testStats = Statistics()
    let trainingShuffled = dataset.trainingDataset.shuffled(
        sampleCount: dataset.trainingExampleCount, randomSeed: Int64(epoch))
    Context.local.learningPhase = .training
    for batch in trainingShuffled.batched(batchSize) {
        let (labels, images) = (batch.label, batch.data)
        // Compute the gradient with respect to the model.
        let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
            let logits = model(images)
            // Count the examples actually in this batch; the final batch may be smaller than batchSize.
            trainStats.updateGuessCounts(logits: logits, labels: labels, batchSize: labels.shape[0])
            return softmaxCrossEntropy(logits: logits, labels: labels)
        }
        trainStats.totalLoss += loss.scalarized()
        optimizer.update(&model, along: gradients)
    }

    Context.local.learningPhase = .inference
    for batch in testBatches {
        let (labels, images) = (batch.label, batch.data)
        // Compute loss on test set
        let logits = model(images)
        // Count the examples actually in this batch; the final batch may be smaller than batchSize.
        testStats.updateGuessCounts(logits: logits, labels: labels, batchSize: labels.shape[0])
        let loss = softmaxCrossEntropy(logits: logits, labels: labels)
        testStats.totalLoss += loss.scalarized()
    }

    let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
    let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
    print("""
          [Epoch \(epoch)] \
          Training Loss: \(trainStats.totalLoss), \
          Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
          (\(trainAccuracy)), \
          Test Loss: \(testStats.totalLoss), \
          Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
          (\(testAccuracy))
          """)
}
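
For readers curious about the loss being accumulated above, here is a minimal plain-Swift sketch of softmax cross-entropy on ordinary arrays. The function names are hypothetical. Averaging over the batch is our assumption about the default behavior of `softmaxCrossEntropy`, so treat this as an illustration rather than the library's implementation.

```swift
import Foundation

// Hypothetical plain-Swift softmax cross-entropy for a single example:
// loss = log(sum(exp(logits))) - logits[label].
func exampleCrossEntropy(logits: [Float], label: Int) -> Float {
    precondition(logits.indices.contains(label))
    // Subtract the largest logit before exponentiating for numerical stability.
    let maxLogit = logits.max()!
    let sumOfExps = logits.map { exp($0 - maxLogit) }.reduce(0, +)
    return maxLogit + log(sumOfExps) - logits[label]
}

// Assumption: the batch loss is the mean of the per-example losses.
func meanCrossEntropy(logits: [[Float]], labels: [Int]) -> Float {
    precondition(!logits.isEmpty && logits.count == labels.count)
    let total = zip(logits, labels)
        .map { pair in exampleCrossEntropy(logits: pair.0, label: pair.1) }
        .reduce(0, +)
    return total / Float(logits.count)
}

// An undecided prediction over three classes costs log(3), about 1.0986.
let uncertainLoss = exampleCrossEntropy(logits: [0, 0, 0], label: 0)
// A confident, correct prediction costs almost nothing.
let confidentLoss = exampleCrossEntropy(logits: [10, 0, 0], label: 0)
print(uncertainLoss, confidentLoss)
print(meanCrossEntropy(logits: [[0, 0, 0], [10, 0, 0]], labels: [0, 0]))
```

This is why the loss shrinks as training progresses: the model assigns higher scores to the correct digit, so each example contributes less.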