In [1]:
@file:Repository("https://kotlin.bintray.com/kotlin-datascience")
@file:DependsOn("org.jetbrains.kotlin-deeplearning:api:[0.1.0]")
import org.jetbrains.kotlinx.dl.api.core.*
import org.jetbrains.kotlinx.dl.api.core.layer.*
import org.jetbrains.kotlinx.dl.api.core.layer.twodim.*
import org.jetbrains.kotlinx.dl.api.core.activation.*
import org.jetbrains.kotlinx.dl.api.core.initializer.*
import org.jetbrains.kotlinx.dl.datasets.Dataset
import org.jetbrains.kotlinx.dl.datasets.handlers.*
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
import org.jetbrains.kotlinx.dl.api.core.optimizer.SGD

In [2]:
// Training hyperparameters and input geometry shared by the model and training cells.

/** Number of passes over the training set performed by fit(). */
val EPOCHS: Int = 2

/** Number of images per batch while training. */
val TRAINING_BATCH_SIZE: Int = 2_000

/** Number of images per batch while evaluating on the test set. */
val TEST_BATCH_SIZE: Int = 1_000

/** Depth of each input image: a single channel. Long because KotlinDL's Input takes Long dims. */
val NUM_CHANNELS: Long = 1L

/** Side length, in pixels, of the square input images (the log reports 28x28 images). */
val IMAGE_SIZE: Long = 28L

/** Seed handed to every HeNormal initializer in the model cell. */
val SEED: Long = 12L

In [9]:
/**
 * LeNet-style CNN for IMAGE_SIZE x IMAGE_SIZE x NUM_CHANNELS inputs:
 * two (5x5 conv + ReLU -> 2x2 max-pool) stages, then a 512-unit ReLU dense layer
 * and a linear output layer producing NUMBER_OF_CLASSES logits.
 *
 * The layer helpers are local to this `run` block so they do not leak into notebook scope.
 */
val model = run {
    // Both convolution stages are identical except for their filter count.
    // SAME padding keeps the spatial size (the log shows [None, 28, 28, 32] after the first one).
    fun convolution(filterCount: Int) = Conv2D(
        filters = filterCount,
        kernelSize = longArrayOf(5, 5),
        strides = longArrayOf(1, 1, 1, 1),
        activation = Activations.Relu,
        kernelInitializer = HeNormal(SEED),
        biasInitializer = HeNormal(SEED),
        padding = ConvPadding.SAME
    )

    // 4-element window/stride arrays in (batch, height, width, channels) order;
    // halves height and width (the log shows 28x28 -> 14x14 after the first pool).
    fun pooling() = MaxPool2D(
        poolSize = intArrayOf(1, 2, 2, 1),
        strides = intArrayOf(1, 2, 2, 1)
    )

    // Fully connected layer with He-normal weights and a constant 0.1 bias.
    fun dense(units: Int, activation: Activations) = Dense(
        outputSize = units,
        activation = activation,
        kernelInitializer = HeNormal(SEED),
        biasInitializer = Constant(0.1f)
    )

    Sequential.of(
        Input(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
        convolution(32),
        pooling(),
        convolution(64),
        pooling(),
        Flatten(),
        dense(512, Activations.Relu),
        // Linear output: raw logits, consumed by SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS in the training cell.
        dense(NUMBER_OF_CLASSES, Activations.Linear)
    )
}

In [10]:
// Load the MNIST train/test archives via the dataset handlers' extractImages/extractLabels.
val (train, test) = Dataset.createTrainAndTestDatasets(
    TRAIN_IMAGES_ARCHIVE,
    TRAIN_LABELS_ARCHIVE,
    TEST_IMAGES_ARCHIVE,
    TEST_LABELS_ARCHIVE,
    NUMBER_OF_CLASSES,
    ::extractImages,
    ::extractLabels
)

// `use` closes the model (AutoCloseable) once training and evaluation finish.
model.use { network ->
    network.compile(
        optimizer = SGD(learningRate = 0.1f),
        loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
        metric = Metrics.ACCURACY
    )

    // NOTE(review): the recorded run logs batch losses of about 2.3026 (= ln 10, chance level
    // for presumably 10 classes) with accuracy near 0.11, i.e. the network is barely learning.
    // 60000 images / batch 2000 * 2 epochs gives only ~60 SGD updates -- consider a smaller
    // batch, more epochs, or an adaptive optimizer. TODO confirm before relying on the result.
    network.fit(dataset = train, epochs = EPOCHS, batchSize = TRAINING_BATCH_SIZE, verbose = true)

    val evaluation = network.evaluate(dataset = test, batchSize = TEST_BATCH_SIZE)
    // May be null if the accuracy metric was not recorded; it is then printed as "null".
    val accuracy = evaluation.metrics[Metrics.ACCURACY]

    println("Accuracy: $accuracy")
}

Extracting 60000 images of 28x28 from datasets/mnist/train-images-idx3-ubyte.gz
Extracting 60000 labels from datasets/mnist/train-labels-idx1-ubyte.gz
Extracting 10000 images of 28x28 from datasets/mnist/t10k-images-idx3-ubyte.gz
Extracting 10000 labels from datasets/mnist/t10k-labels-idx1-ubyte.gz
19:39:33.159 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - Conv2D(filters=32, kernelSize=[5, 5], strides=[1, 1, 1, 1], dilations=[1, 1, 1, 1], activation=Relu, kernelInitializer=HeNormal(seed=12) VarianceScaling(scale=2.0, mode=FAN_IN, distribution=TRUNCATED_NORMAL, seed=12), biasInitializer=HeNormal(seed=12) VarianceScaling(scale=2.0, mode=FAN_IN, distribution=TRUNCATED_NORMAL, seed=12), kernelShape=[5, 5, 1, 32], padding=SAME); outputShape: [None, 28, 28, 32]
19:39:33.160 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - MaxPool2D(poolSize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding=VALID); outputShape: [None, 14, 14, 32]
19:39:33.165 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - C

19:42:01.488 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - Batch stat: { lossValue: 2.3011773 metricValue: 0.1085 }
19:42:04.506 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - Batch stat: { lossValue: 2.3021197 metricValue: 0.106 }
19:42:07.903 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - Batch stat: { lossValue: 2.3016684 metricValue: 0.119 }
19:42:10.936 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - Batch stat: { lossValue: 2.3013375 metricValue: 0.118 }
19:42:14.024 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - Batch stat: { lossValue: 2.300793 metricValue: 0.112 }
19:42:17.101 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - Batch stat: { lossValue: 2.3021295 metricValue: 0.106 }
19:42:20.734 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - Batch stat: { lossValue: 2.3007288 metricValue: 0.1175 }
19:42:23.588 [main] DEBUG o.j.kotlinx.dl.api.core.Sequential - Batch stat: { lossValue: 2.302591 metricValue: 0.104 }
19:42:26.373 [main] DEBUG o.j.kotlinx.dl.api.cor