<a href="https://colab.research.google.com/github/huan/tensorflow-handbook-swift/blob/master/tensorflow-handbook-swift-example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Swift MNIST Example

Learn more on GitHub: https://github.com/huan/tensorflow-handbook-swift


In [0]:
import TensorFlow
import Python
import Foundation

## MNIST Dataset Helper

In [0]:
/// Loads the MNIST *training* files into tensors and offers helpers to split
/// them and to score predictions.
///
/// The two files are downloaded into the current working directory on first
/// use and reused afterwards. Only the training files are fetched, so the
/// "test" split returned by `splitTrainTest` is carved out of the training data.
class Mnist {

  let mnistBaseURL = "https://raw.githubusercontent.com/tensorflow/swift-models/master/Datasets/MNIST/"
  let mnistFiles = ["train-images-idx3-ubyte", "train-labels-idx1-ubyte"]

  /// Images in NHWC layout (`[N, 28, 28, 1]`), pixel values scaled to `[0, 1]`.
  var images: Tensor<Float>?
  /// One integer class label per image (`[N]`).
  var labels: Tensor<Int32>?

  /// Downloads (if needed) and parses both MNIST files.
  init() {
    (self.images, self.labels) = self.readMNIST(
      imagesFile: self.mnistFiles[0],
      labelsFile: self.mnistFiles[1]
    )
  }

  /// Splits the loaded data into a leading training part and a trailing test part.
  ///
  /// No shuffling is done: the first `trainFraction` of the rows become the
  /// training set, and the remaining rows become the test set.
  ///
  /// - Parameter trainFraction: Share of rows used for training. Must be in `[0, 1]`.
  /// - Returns: `(trainX, trainY, testX, testY)`.
  public func splitTrainTest(trainFraction: Float = 0.8) -> (
    Tensor<Float>,
    Tensor<Int32>,
    Tensor<Float>,
    Tensor<Int32>
  ) {
    precondition(
      trainFraction >= 0 && trainFraction <= 1,
      "trainFraction must be in [0, 1], got \(trainFraction)"
    )
    guard let data = self.images, let labels = self.labels else {
      fatalError("MNIST data has not been loaded.")
    }

    let sampleCount = Int(data.shape[0])
    // Clamp so float rounding can never push the split index outside [0, sampleCount].
    let split = min(sampleCount, max(0, Int(trainFraction * Float(sampleCount))))

    let trainX = data[0..<split]
    let trainY = labels[0..<split]

    let testX = data[split..<sampleCount]
    let testY = labels[split..<sampleCount]

    return (trainX, trainY, testX, testY)
  }

  /// Returns the fraction of rows whose arg-max logit equals the label.
  ///
  /// - Parameters:
  ///   - y: Integer labels, one per row.
  ///   - logits: Per-class scores, one row per example (classes on axis 1).
  /// - Returns: Accuracy in `[0, 1]`. An empty batch returns `0` instead of NaN.
  public func getAccuracy(
    y: Tensor<Int32>,
    logits: Tensor<Float>
  ) -> Float {
    let total = y.shape[0]
    guard total > 0 else { return 0 }
    let correct = Tensor<Int32>(logits.argmax(squeezingAxis: 1) .== y).sum().scalarized()
    return Float(correct) / Float(total)
  }

  /// Fetches every MNIST file that is not already present in the working directory.
  /// Uses Python's `urllib.request.urlretrieve`.
  private func download() {
    let urllibRequest = Python.import("urllib.request")

    for file in self.mnistFiles where !FileManager.default.fileExists(atPath: file) {
      print("Downloading \(file) ...")
      let url = self.mnistBaseURL + file
      urllibRequest.urlretrieve(url, file)
    }
  }

  /// Reads a file into an array of bytes. Aborts with the offending path if the
  /// file cannot be read.
  private func readFile(_ path: String) -> [UInt8] {
    let url = URL(fileURLWithPath: path)
    do {
      return [UInt8](try Data(contentsOf: url, options: []))
    } catch {
      fatalError("Unable to read MNIST file at \(path): \(error)")
    }
  }

  /// Reads MNIST images and labels from the given file paths.
  ///
  /// The IDX headers are skipped (16 bytes for images, 8 bytes for labels).
  /// Each remaining byte is one pixel or one label.
  ///
  /// - Returns: Images as `[N, 28, 28, 1]` scaled to `[0, 1]`, and labels as `[N]`.
  private func readMNIST(
    imagesFile: String,
    labelsFile: String
  ) -> (
    images: Tensor<Float>,
    labels: Tensor<Int32>
  ) {
    download()

    print("Reading data.")
    let imageHeight = 28, imageWidth = 28
    let images = readFile(imagesFile).dropFirst(16).map(Float.init)
    let labels = readFile(labelsFile).dropFirst(8).map(Int32.init)
    let rowCount = labels.count
    // Fail with a readable message (rather than a tensor-shape crash) if a file
    // is truncated or the two files do not belong together.
    precondition(
      images.count == rowCount * imageHeight * imageWidth,
      "MNIST files are inconsistent: \(images.count) pixels for \(rowCount) labels"
    )

    print("Constructing data tensors.")
    return (
      images: Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
        .transposed(withPermutations: [0, 2, 3, 1]) / 255, // NHWC
      labels: Tensor(labels)
    )
  }
}

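## Define a Simple MLP Model

Despite its name, this `MLP` has no hidden layer. It flattens each 28×28 image into 784 values and applies a single dense layer that outputs 10 class logits.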

In [0]:
/// Classifier for 28×28 single-channel images: flatten, then one dense layer
/// that produces 10 logits. No activation argument is passed to the dense layer.
struct MLP: Layer {
  typealias Input = Tensor<Float>
  typealias Output = Tensor<Float>

  /// Turns each image into a flat row of pixel values.
  var flatten = Flatten<Float>()
  /// Maps the 28 * 28 flattened pixels to 10 class scores.
  var dense = Dense<Float>(inputSize: 28 * 28, outputSize: 10)

  @differentiable
  public func callAsFunction(_ input: Input) -> Output {
    let flattened = flatten(input)
    return dense(flattened)
  }
}

// Trainable model plus the Adam optimizer that updates its parameters below.
var model: MLP = .init()
let optimizer: Adam<MLP> = Adam(for: model)

## Training

In [4]:
let mnist = Mnist()
let (trainImages, trainLabels, testImages, testLabels) = mnist.splitTrainTest()

// Images and labels are batched identically (same order, same size), so zipping
// the two datasets pairs each image batch with its matching label batch.
let imageBatch = Dataset(elements: trainImages).batched(32)
let labelBatch = Dataset(elements: trainLabels).batched(32)

// A single pass over the training split.
for (batchImages, batchLabels) in zip(imageBatch, labelBatch) {
  // Calculate the gradient of the softmax cross-entropy loss w.r.t. the model.
  let grads = gradient(at: model) { model -> Tensor<Float> in
    softmaxCrossEntropy(logits: model(batchImages), labels: batchLabels)
  }

  // Let the optimizer apply the gradient step to the parameters.
  optimizer.update(&model.allDifferentiableVariables, along: grads)
}

// Score the trained model on the held-out tail of the training files.
let logits = model(testImages)
let acc = mnist.getAccuracy(y: testLabels, logits: logits)

print("Test Accuracy: \(acc)" )

Downloading train-images-idx3-ubyte ...
Downloading train-labels-idx1-ubyte ...
Reading data.
Constructing data tensors.
Test Accuracy: 0.91216666


- Credit: This example is inspired by [A set of notebooks explaining Swift for TensorFlow, optimized to run in Google Colaboratory](https://github.com/zaidalyafeai/Swift4TF)
- License: [Apache-2.0](https://github.com/tensorflow/swift-models/blob/stable/LICENSE)