<a href="https://colab.research.google.com/github/dbolella/s4tf-lenet-mnist/blob/master/lenet_mnist_swift_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LeNet-5 & MNIST using Swift for TensorFlow
by Danny Bolella

Edit: Updated to use Swift for TensorFlow v0.6

To learn more about how this Colab works, check out the associated Medium article at: https://heartbeat.fritz.ai/swifty-ml-an-intro-to-swift-for-tensorflow-9edc7045bc0c

This Colab is a reworking of the official S4TF example found at: https://github.com/tensorflow/swift-models/tree/master/Examples/LeNet-MNIST

## Installing and Importing Libraries
First, we pull in two libraries as Swift packages from the official S4TF models repo, using `%install` to do so.  Once that completes, we import the libraries we'll be using (TensorFlow is already available on Colab, so there is no need to install it).

In [1]:
%install '.package(url: "https://github.com/tensorflow/swift-models.git", .branch("master"))' ImageClassificationModels Datasets

import TensorFlow
import Datasets
import ImageClassificationModels

Installing packages:
	.package(url: "https://github.com/tensorflow/swift-models.git", .branch("master"))
		ImageClassificationModels
		Datasets
With SwiftPM flags: []
Working in: /tmp/tmp_82lvl7h/swift-install
Fetching https://github.com/tensorflow/swift-models.git
Fetching https://github.com/kylef/Commander.git
Fetching https://github.com/kylef/Spectre.git
Cloning https://github.com/tensorflow/swift-models.git
Resolving https://github.com/tensorflow/swift-models.git at master
Cloning https://github.com/kylef/Commander.git
Resolving https://github.com/kylef/Commander.git at 0.9.1
Cloning https://github.com/kylef/Spectre.git
Resolving https://github.com/kylef/Spectre.git at 0.9.0
[1/14] Compiling ModelSupport Stderr.swift
[2/14] Compiling ImageClassificationModels DenseNet121.swift
[3/14] Compiling ModelSupport Image.swift
[4/15] Merging module ModelSupport
[5/20] Compiling Datasets LabeledExample.swift
[6/20] Compiling Datasets MNIST.swift
[7/20] Compiling Datasets CIFAR10.swift
[8/20]

## Model, Dataset, Optimizer... Oh My!
Next, we instantiate the dataset, model, and optimizer we will be using.  We also set up `epochCount` (the number of full passes we'll make over the training data) and `batchSize` (the number of examples we'll feed the model at a time).

The last thing to do is split our test data into batches of `batchSize`.

*Note: if this code block hangs, split the calls here into separate code blocks, restart the runtime, and re-run the notebook one code block at a time.*

In [2]:
let epochCount = 12

let batchSize = 128

let dataset = MNIST()

var model = LeNet()

let optimizer = SGD(for: model, learningRate: 0.1)

let testBatches = dataset.testDataset.batched(batchSize)

Loading resource: train-images-idx3-ubyte
Loading local data at: /content/train-images-idx3-ubyte
Succesfully loaded resource: train-images-idx3-ubyte
Loading resource: train-labels-idx1-ubyte
Loading local data at: /content/train-labels-idx1-ubyte
Succesfully loaded resource: train-labels-idx1-ubyte
Loading resource: t10k-images-idx3-ubyte
Loading local data at: /content/t10k-images-idx3-ubyte
Succesfully loaded resource: t10k-images-idx3-ubyte
Loading resource: t10k-labels-idx1-ubyte
Loading local data at: /content/t10k-labels-idx1-ubyte
Succesfully loaded resource: t10k-labels-idx1-ubyte
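
To make the effect of `batchSize` concrete, here is a minimal plain-Swift sketch of how many batches each split produces.  It assumes the standard MNIST sizes (60,000 training and 10,000 test images) and that the final partial batch is kept rather than dropped; `batchLayout` is a hypothetical helper written for this illustration, not part of the swift-models package.

```swift
// Hypothetical helper (not part of swift-models): describes how a dataset of
// `exampleCount` items splits into batches of `batchSize`, assuming the final
// partial batch is kept.
func batchLayout(exampleCount: Int, batchSize: Int) -> (batchCount: Int, lastBatchSize: Int) {
    precondition(exampleCount > 0 && batchSize > 0, "counts must be positive")
    // Ceiling division: a partial final batch still counts as one batch.
    let batchCount = (exampleCount + batchSize - 1) / batchSize
    // Number of examples left over for the final batch.
    let remainder = exampleCount % batchSize
    let lastBatchSize = remainder == 0 ? batchSize : remainder
    return (batchCount, lastBatchSize)
}

let trainLayout = batchLayout(exampleCount: 60_000, batchSize: 128)
let testLayout = batchLayout(exampleCount: 10_000, batchSize: 128)
print("train: \(trainLayout.batchCount) batches, last has \(trainLayout.lastBatchSize) examples")
print("test: \(testLayout.batchCount) batches, last has \(testLayout.lastBatchSize) examples")
```

Under these assumptions, neither split divides evenly by 128, so the last batch of each is smaller than the rest.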


## Benchmarking Prep
Next, we create a `struct` to hold our training and testing benchmarks for each epoch.  The struct also has a method, `updateGuessCounts`, that updates the correct and total guess counts.  This eliminates duplicate code in our training and testing loops.

In [0]:

struct Statistics {
    var correctGuessCount: Int = 0
    var totalGuessCount: Int = 0
    var totalLoss: Float = 0

    // Adds one batch's correct predictions and its size to the running counts.
    mutating func updateGuessCounts(logits: Tensor<Float>, labels: Tensor<Int32>, batchSize: Int) {
        // A prediction is correct when the highest-scoring class matches the label.
        let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
        self.correctGuessCount += Int(
            Tensor<Int32>(correctPredictions).sum().scalarized())
        self.totalGuessCount += batchSize
    }
}
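
To show what the counting in `updateGuessCounts` amounts to, here is a small sketch using plain Swift arrays instead of tensors.  The helpers `argmax` and `countCorrect` and the sample scores are made up for this illustration; they are not TensorFlow APIs.

```swift
// Plain-Swift illustration of the counting done in `updateGuessCounts`.
// `argmax` and `countCorrect` are hypothetical helpers, not TensorFlow APIs.

/// Index of the largest value in a non-empty row of scores.
func argmax(_ row: [Float]) -> Int {
    precondition(!row.isEmpty, "row must not be empty")
    var best = 0
    for (index, value) in row.enumerated() where value > row[best] {
        best = index
    }
    return best
}

/// Number of rows whose highest-scoring class equals the label.
func countCorrect(logits: [[Float]], labels: [Int]) -> Int {
    precondition(logits.count == labels.count, "need one label per row")
    return zip(logits, labels).filter { pair in argmax(pair.0) == pair.1 }.count
}

// Made-up scores for three examples and three classes.
let logits: [[Float]] = [
    [0.1, 2.0, -1.0],  // highest score at class 1
    [1.5, 0.2, 0.3],   // highest score at class 0
    [0.0, 0.1, 0.9],   // highest score at class 2
]
let labels = [1, 2, 2]
let correct = countCorrect(logits: logits, labels: labels)
print("\(correct)/\(labels.count) correct")  // prints "2/3 correct"
```

The tensor version does the same thing for a whole batch at once: `argmax(squeezingAxis: 1)` picks the best class per row, `.==` compares against the labels element-wise, and `sum()` counts the matches.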

## Training Day
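Finally, we run our training!  The training loop runs `epochCount` times.  In each epoch, we loop through batches of our training data, run each batch through our model, update our benchmarks, and update the model along the gradients.  We then evaluate the model on the test batches.

At the end of each epoch, we print out our benchmark data.  We should see our loss generally decrease and our accuracy generally increase as training progresses.  Two things to keep in mind when reading the output: the printed losses are sums over all batches in the epoch, not averages, and the guess totals (60032 and 10112) are slightly larger than MNIST's 60,000 training and 10,000 test images because every batch is counted as a full `batchSize`, which suggests the smaller final batch is counted as if it were full.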
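
To illustrate what "optimizing along the gradients" means, here is a minimal plain-Swift sketch of a gradient descent step on a toy one-parameter problem.  It assumes plain SGD with no momentum; the function `sgdStep` and the quadratic loss are invented for this example and are not how the `SGD` optimizer is implemented.

```swift
// Minimal sketch of the SGD rule: parameter -= learningRate * gradient.
// Toy problem (an assumption for illustration): minimize f(w) = (w - 3)^2,
// whose gradient is 2 * (w - 3).
func sgdStep(parameter: Float, gradient: Float, learningRate: Float) -> Float {
    return parameter - learningRate * gradient
}

var w: Float = 0
let learningRate: Float = 0.1  // same rate the notebook passes to SGD
var losses: [Float] = []
for _ in 1...20 {
    // Record the loss before the update, then step against the gradient.
    let loss = (w - 3) * (w - 3)
    losses.append(loss)
    let gradient = 2 * (w - 3)
    w = sgdStep(parameter: w, gradient: gradient, learningRate: learningRate)
}
print("w after 20 steps: \(w)")
print("first loss: \(losses[0]), last loss: \(losses[19])")
```

Each step moves `w` a little closer to 3 and the loss shrinks.  The notebook's loop does the same thing at a larger scale, with `valueWithGradient` supplying the gradients for every model parameter.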

In [4]:
print("Beginning training...")

// The training loop.
for epoch in 1...epochCount {
    var trainStats = Statistics()
    var testStats = Statistics()
    let trainingShuffled = dataset.trainingDataset.shuffled(
        sampleCount: dataset.trainingExampleCount, randomSeed: Int64(epoch))
    Context.local.learningPhase = .training
    for batch in trainingShuffled.batched(batchSize) {
        let (labels, images) = (batch.label, batch.data)
        // Compute the gradient with respect to the model.
        let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
            let logits = model(images)
            trainStats.updateGuessCounts(logits: logits, labels: labels, batchSize: batchSize)
            return softmaxCrossEntropy(logits: logits, labels: labels)
        }
        trainStats.totalLoss += loss.scalarized()
        optimizer.update(&model, along: gradients)
    }

    Context.local.learningPhase = .inference
    for batch in testBatches {
        let (labels, images) = (batch.label, batch.data)
        // Compute loss on test set
        let logits = model(images)
        testStats.updateGuessCounts(logits: logits, labels: labels, batchSize: batchSize)
        let loss = softmaxCrossEntropy(logits: logits, labels: labels)
        testStats.totalLoss += loss.scalarized()
    }

    let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
    let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
    print("""
          [Epoch \(epoch)] \
          Training Loss: \(trainStats.totalLoss), \
          Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
          (\(trainAccuracy)), \
          Test Loss: \(testStats.totalLoss), \
          Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
          (\(testAccuracy))
          """)
}

Beginning training...
[Epoch 1] Training Loss: 224.62172, Training Accuracy: 51032/60032 (0.85007995), Test Loss: 10.67791, Test Accuracy: 9563/10112 (0.9457081)
[Epoch 2] Training Loss: 53.248043, Training Accuracy: 57909/60032 (0.96463555), Test Loss: 6.694432, Test Accuracy: 9737/10112 (0.96291536)
[Epoch 3] Training Loss: 37.35574, Training Accuracy: 58520/60032 (0.97481346), Test Loss: 4.70572, Test Accuracy: 9805/10112 (0.96964)
[Epoch 4] Training Loss: 30.11113, Training Accuracy: 58841/60032 (0.9801606), Test Loss: 4.8403378, Test Accuracy: 9817/10112 (0.97082675)
[Epoch 5] Training Loss: 25.07181, Training Accuracy: 59017/60032 (0.98309237), Test Loss: 3.5734432, Test Accuracy: 9849/10112 (0.9739913)
[Epoch 6] Training Loss: 21.899233, Training Accuracy: 59100/60032 (0.98447496), Test Loss: 4.1357136, Test Accuracy: 9824/10112 (0.971519)
[Epoch 7] Training Loss: 19.437084, Training Accuracy: 59222/60032 (0.9865072), Test Loss: 3.4082255, Test Accuracy: 9852/10112 (0.974288)
[E