<a href="https://colab.research.google.com/github/brettkoonce/swift-models/blob/master/tpu_swift_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install the swift-models packages for Swift for TensorFlow

In [1]:
%install '.package(url: "https://github.com/tensorflow/swift-models", .branch("master"))' Datasets ImageClassificationModels
// NOTE(review): tracking the `master` branch is not reproducible; consider pinning a revision.
// Print the ANSI "erase display" escape sequence (ESC [2J) — presumably to clear the
// verbose package-build log from the cell output. Confirm intent.
print("\u{001B}[2J")




# XLA device

In [6]:
// Select the default XLA-backed device; the bare expression on the next line echoes it as the cell output.
let device: Device = .defaultXLA
device

▿ Device(kind: .GPU, ordinal: 0, backend: .XLA)
  - kind : TensorFlow.Device.Kind.GPU
  - ordinal : 0
  - backend : TensorFlow.Device.Backend.XLA


# MNIST-XLA-TPU

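Next, we will train a small CNN on MNIST using the default XLA device (intended to be a TPU, although the device printed above is a GPU):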

In [7]:
import Datasets
import TensorFlow

/// A small convolutional network for MNIST-style digit classification.
///
/// Layout: two 3×3 `.same`-padded ReLU convolutions (1 → 32 → 32 channels),
/// a 2×2 max-pool with stride 2, then a flatten followed by three dense layers
/// (14·14·32 → 512 → 512 → 10). The last dense layer has no activation, so the
/// network returns raw logits (the caller applies `softmaxCrossEntropy(logits:)`).
///
/// The dense input size of 14·14·32 assumes the pooled feature map is 14×14,
/// i.e. 28×28 single-channel inputs with unit convolution strides.
struct CNN: Layer {
  // Feature extractor.
  var conv1a = Conv2D<Float>(filterShape: (3, 3, 1, 32), padding: .same, activation: relu)
  var conv1b = Conv2D<Float>(filterShape: (3, 3, 32, 32), padding: .same, activation: relu)
  var pool1 = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))

  // Classifier head.
  var flatten = Flatten<Float>()
  var inputLayer = Dense<Float>(inputSize: 14 * 14 * 32, outputSize: 512, activation: relu)
  var hiddenLayer = Dense<Float>(inputSize: 512, outputSize: 512, activation: relu)
  var outputLayer = Dense<Float>(inputSize: 512, outputSize: 10)

  /// Runs a batch of images (presumably `[batch, 28, 28, 1]`) through the network
  /// and returns unnormalized class logits with 10 entries per example.
  @differentiable
  public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    let features = pool1(conv1b(conv1a(input)))
    let hidden = hiddenLayer(inputLayer(flatten(features)))
    return outputLayer(hidden)
  }
}

// MARK: - Configuration

let batchSize = 128
let epochCount = 12
var model = CNN()
var optimizer = SGD(for: model, learningRate: 0.1)
let dataset = MNIST(batchSize: batchSize)

// Move the model parameters and the optimizer state onto the XLA device.
// Each host-side batch is copied to the device inside the loops below.
let device = Device.defaultXLA
model.move(to: device)
optimizer = SGD(copying: optimizer, to: device)

print("Starting training...")

// MARK: - Train / validate for `epochCount` epochs

for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
    // Training pass.
    Context.local.learningPhase = .training
    for batch in epochBatches {
        let deviceImages = Tensor(copying: batch.data, to: device)
        let deviceLabels = Tensor(copying: batch.label, to: device)
        // Only the gradients are needed. Reading the loss back with `scalarized()`
        // would force an extra device-to-host synchronization on every step, and the
        // training loss was never reported anyway.
        let gradients = gradient(at: model) { model -> Tensor<Float> in
            let logits = model(deviceImages)
            return softmaxCrossEntropy(logits: logits, labels: deviceLabels)
        }
        optimizer.update(&model, along: gradients)
        // Materialize this step's pending lazy-tensor work so the traced graph
        // does not keep growing across batches.
        LazyTensorBarrier()
    }

    // Validation pass: mean per-batch loss and overall accuracy.
    Context.local.learningPhase = .inference
    var testLossSum: Float = 0
    var testBatchCount = 0
    var correctGuessCount = 0
    var totalGuessCount = 0
    for batch in dataset.validation {
        let deviceImages = Tensor(copying: batch.data, to: device)
        let deviceLabels = Tensor(copying: batch.label, to: device)
        let logits = model(deviceImages)
        testLossSum += softmaxCrossEntropy(logits: logits, labels: deviceLabels).scalarized()
        testBatchCount += 1

        // The argmax over axis 1 (the 10 class logits) is the predicted digit per example.
        let correctPredictions = logits.argmax(squeezingAxis: 1) .== deviceLabels
        correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
        totalGuessCount += batch.data.shape[0]
        LazyTensorBarrier()
    }

    let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
    print(
        """
        [Epoch \(epoch + 1)] \
        Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy)) \
        Loss: \(testLossSum / Float(testBatchCount))
        """
    )
}

Loading resource: train-images-idx3-ubyte
Loading resource: train-labels-idx1-ubyte
Loading resource: t10k-images-idx3-ubyte
Loading resource: t10k-labels-idx1-ubyte
Starting training...
[Epoch 1] Accuracy: 9630/10000 (0.963) Loss: 0.113842346
[Epoch 2] Accuracy: 9714/10000 (0.9714) Loss: 0.089315146
[Epoch 3] Accuracy: 9817/10000 (0.9817) Loss: 0.05626795
[Epoch 4] Accuracy: 9845/10000 (0.9845) Loss: 0.04341101
[Epoch 5] Accuracy: 9851/10000 (0.9851) Loss: 0.0448411
[Epoch 6] Accuracy: 9849/10000 (0.9849) Loss: 0.0460658
[Epoch 7] Accuracy: 9881/10000 (0.9881) Loss: 0.038156476
[Epoch 8] Accuracy: 9868/10000 (0.9868) Loss: 0.03822924
[Epoch 9] Accuracy: 9884/10000 (0.9884) Loss: 0.039507892
[Epoch 10] Accuracy: 9876/10000 (0.9876) Loss: 0.039613236
[Epoch 11] Accuracy: 9882/10000 (0.9882) Loss: 0.038583163
[Epoch 12] Accuracy: 9888/10000 (0.9888) Loss: 0.043279253
