In [6]:
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.learning.config.Nesterovs
import org.nd4j.linalg.lossfunctions.LossFunctions

/**
 * Generates a synthetic, linearly separable binary-classification dataset.
 *
 * Each row of the feature matrix holds two values spread between -5 and 5. A sample belongs to
 * class 1 when its two features sum to a positive value, and to class 0 otherwise.
 *
 * @param numSamples number of rows to generate; must be positive.
 * @return a pair of (features, labels). `features` has shape [numSamples, 2]. `labels` is one-hot
 *   encoded with shape [numSamples, 2] (column 0 = class 0, column 1 = class 1). This matches a
 *   two-unit softmax output layer trained with negative log-likelihood.
 * @throws IllegalArgumentException if [numSamples] is not positive.
 */
fun generateData(numSamples: Int): Pair<INDArray, INDArray> {
    require(numSamples > 0) { "numSamples must be positive, was $numSamples" }

    // Nd4j.rand draws values between 0 and 1; shift by -0.5 and scale by 10 to land between -5 and 5.
    val features = Nd4j.rand(numSamples, 2).subi(0.5).muli(10.0)

    // 1.0 where the row sum is positive, 0.0 otherwise. The reshape pins it to an explicit
    // [numSamples, 1] column regardless of whether sum(1) keeps the reduced dimension.
    val isPositive = features.sum(1).gt(0.0).castTo(features.dataType())
        .reshape(numSamples.toLong(), 1L)

    // One-hot encode as [1 - p, p] per row, so the column index equals the class index.
    val labels = Nd4j.hstack(isPositive.rsub(1.0), isPositive)

    return Pair(features, labels)
}

/**
 * Trains a small feed-forward classifier on synthetic data from [generateData] and prints how
 * many held-out test samples it classifies correctly.
 *
 * Network: 2 inputs -> 20 ReLU hidden units -> 2-way softmax output. It is trained with
 * negative log-likelihood and a Nesterov momentum updater (learning rate 0.01, momentum 0.9).
 */
fun main() {
    val numInputs = 2
    val numOutputs = 2
    val numHiddenNodes = 20
    val numTrain = 1000
    val numTest = 100
    val numEpochs = 100

    /**
     * Returns [labels] unchanged if it is already a 2-D matrix with [numClasses] columns.
     * Otherwise it treats [labels] as one class index per sample and expands it into a
     * one-hot [n, numClasses] matrix, which the softmax / NLL output layer requires.
     */
    fun toOneHot(labels: INDArray, numClasses: Int): INDArray {
        if (labels.rank() == 2 && labels.size(1) == numClasses.toLong()) return labels
        val rows = labels.length()
        val oneHot = Nd4j.zeros(rows, numClasses.toLong()).castTo(labels.dataType())
        for (row in 0L until rows) {
            val cls = labels.getDouble(row).toInt()
            require(cls in 0 until numClasses) { "label $cls at row $row is not in [0, $numClasses)" }
            oneHot.putScalar(row, cls.toLong(), 1.0)
        }
        return oneHot
    }

    val conf: MultiLayerConfiguration = NeuralNetConfiguration.Builder()
        .seed(123)
        .weightInit(WeightInit.XAVIER)
        .updater(Nesterovs(0.01, 0.9))
        .list()
        .layer(0, DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
            .activation(Activation.RELU)
            .build())
        .layer(1, OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .activation(Activation.SOFTMAX)
            .nIn(numHiddenNodes).nOut(numOutputs).build())
        .build()

    val model = MultiLayerNetwork(conf)
    model.init()

    // Generate training data; targets must be one-hot to match the softmax output width.
    val (trainFeatures, trainLabels) = generateData(numTrain)
    val trainTargets = toOneHot(trainLabels, numOutputs)

    // Train the model (full-batch, one fit call per epoch).
    repeat(numEpochs) {
        model.fit(trainFeatures, trainTargets)
    }

    // Generate some test data
    val (testFeatures, testLabels) = generateData(numTest)
    val testTargets = toOneHot(testLabels, numOutputs)

    // Evaluate: the predicted and actual class of each row is the index of its largest value.
    // Comparing raw softmax probabilities to 0/1 targets with eq() would almost never match.
    val output = model.output(testFeatures)
    val predicted = output.argMax(1)
    val actual = testTargets.argMax(1)
    // eq() yields a boolean array; cast to a numeric type before summing the matches.
    val nCorrect = predicted.eq(actual).castTo(output.dataType()).sumNumber().toInt()
    println("Correctly classified $nCorrect out of $numTest test examples.")
}