# **DiffKt** Model

 **Copyright (c) Meta Platforms, Inc. and affiliates.**
 
 This source code is licensed under the MIT license found in the
 LICENSE file in the root directory of this source tree.

## Introduction

This notebook discusses the model API, which lives in `diffkt/kotlin/api/src/main/kotlin/org/diffkt/model`.
The model API builds deep neural networks on top of the automatic differentiation in **DiffKt**.
This notebook uses a simple linear regression example to show how to use the model API. It is based on the example __[Linear Regression](https://github.com/facebookresearch/diffkt/tree/main/kotlin/examples/src/main/kotlin/examples/linreg)__.

Other examples that use the model API:

__[Iris](https://github.com/facebookresearch/diffkt/tree/main/kotlin/examples/src/main/kotlin/examples/iris)__, a classification example using dense layers in a neural network.

__[MNIST](https://github.com/facebookresearch/diffkt/tree/main/kotlin/examples/src/main/kotlin/examples/mnist)__, an image processing example using a convolutional neural network.

__[RESNET](https://github.com/facebookresearch/diffkt/tree/main/kotlin/examples/src/main/kotlin/examples/resnet)__, an image processing example using a deep convolutional neural network.


### Housekeeping

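The notebook depends on the following JARs from the **DiffKt** build. The paths are relative to the notebook's directory, so the `api` and `data` modules must be built first.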

```kotlin
@file:DependsOn("../kotlin/api/build/libs/api.jar")
@file:DependsOn("../kotlin/data/build/libs/data.jar")
```

Import the following classes for the example.

```kotlin
import org.diffkt.*
import org.diffkt.data.Data
import org.diffkt.model.*
import org.diffkt.tracing.jit
import kotlin.math.min
import kotlin.random.Random
```

## Linear Regression

This is a simple linear regression model of $y = ax + b$, where

$x$ - feature or input 

$y$ - label or output

$a$ - the weight

$b$ - the bias

The goal of linear regression is to recover the weight and the bias given the model, the input data, and the output data.
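For comparison, when the data are noise-free, ordinary least squares recovers the weight and bias in closed form: $a = \sum_i (x_i - \bar{x})(y_i - \bar{y}) / \sum_i (x_i - \bar{x})^2$ and $b = \bar{y} - a\bar{x}$. The plain-Kotlin sketch below uses no DiffKt, and `fitLine` is a hypothetical helper. The notebook itself does not use this formula. It approaches the same values iteratively by training the model.

```kotlin
// Closed-form ordinary least squares for y = a*x + b (plain Kotlin, no DiffKt).
// fitLine is a hypothetical helper used only for illustration.
fun fitLine(x: DoubleArray, y: DoubleArray): Pair<Double, Double> {
    require(x.size == y.size && x.size >= 2) { "need at least two paired points" }
    val meanX = x.average()
    val meanY = y.average()
    var cov = 0.0  // sum of (x - meanX) * (y - meanY)
    var varX = 0.0 // sum of (x - meanX)^2
    for (i in x.indices) {
        val dx = x[i] - meanX
        cov += dx * (y[i] - meanY)
        varX += dx * dx
    }
    require(varX > 0.0) { "x values must not all be equal" }
    val a = cov / varX
    val b = meanY - a * meanX
    return Pair(a, b)
}

// Noise-free data generated from a = 0.7, b = 0.2.
val xs = doubleArrayOf(0.0, 0.25, 0.5, 0.75, 1.0)
val ys = DoubleArray(xs.size) { 0.7 * xs[it] + 0.2 }
val fit = fitLine(xs, ys) // (a, b), approximately (0.7, 0.2)
```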

## The Model API

Using the model API takes seven steps:

1) Create training data to build the model from,

2) Create an iterator over the data for training the model,

3) Create a linear regression model that inherits from the `Model` class,

4) Create a loss function,

5) Create an optimizer,

6) Create a learner class,

7) Train the model.
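The following plain-Kotlin analogue uses no DiffKt and makes the steps concrete. The gradient is derived by hand instead of by automatic differentiation, and the fixed learning rate of `0.5f / n` mirrors the optimizer used later. All names (`Line`, `sse`, `step`, `trainLine`) are hypothetical. This is a sketch of the idea, not the model API.

```kotlin
import kotlin.random.Random

// Step 3 (analogue): an immutable model with a weight `a` and a bias `b`.
data class Line(val a: Float, val b: Float) {
    fun predict(x: Float): Float = a * x + b
}

// Step 4 (analogue): sum-of-squared-errors loss.
fun sse(model: Line, xs: FloatArray, ys: FloatArray): Float {
    var total = 0f
    for (i in xs.indices) {
        val d = model.predict(xs[i]) - ys[i]
        total += d * d
    }
    return total
}

// Step 5 (analogue): one fixed-learning-rate gradient-descent step on a batch,
// using the hand-derived gradient of the summed squared error.
fun step(model: Line, xs: FloatArray, ys: FloatArray, lr: Float): Line {
    var gradA = 0f
    var gradB = 0f
    for (i in xs.indices) {
        val d = model.predict(xs[i]) - ys[i]
        gradA += 2f * d * xs[i] // dL/da
        gradB += 2f * d         // dL/db
    }
    return Line(model.a - lr * gradA, model.b - lr * gradB)
}

// Steps 1, 2, 6 and 7 (analogue): returns (true line, trained line, final loss).
fun trainLine(epochs: Int, seed: Int = 1234567, n: Int = 100): Triple<Line, Line, Float> {
    val random = Random(seed)
    // Step 1: noise-free training data from a hidden "true" line.
    val truth = Line(random.nextFloat(), random.nextFloat())
    val xs = FloatArray(n) { random.nextFloat() }
    val ys = FloatArray(n) { truth.predict(xs[it]) }
    // Step 2: a single full batch, like a batch size equal to n.
    // Steps 6 and 7: start from a random model and take one step per epoch.
    var model = Line(random.nextFloat(), random.nextFloat())
    repeat(epochs) { model = step(model, xs, ys, 0.5f / n) }
    return Triple(truth, model, sse(model, xs, ys))
}
```

Calling `trainLine(300)` should return a trained line whose weight and bias are close to the true ones, with a final loss near zero.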

### Setup

The training set contains 100 data points. The random number generator is seeded with a fixed value, so every run produces the same data.

```kotlin
// Setup

val trainingDataSize = 100
val random = Random(1234567)
```
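Seeding makes the notebook reproducible because two `kotlin.random.Random` instances created with the same seed produce the same sequence. The quick plain-Kotlin check below demonstrates this. The helper name `firstFloats` is hypothetical.

```kotlin
import kotlin.random.Random

// Draws the first `count` floats from a generator seeded with `seed`.
fun firstFloats(seed: Int, count: Int): List<Float> {
    val rng = Random(seed)
    return List(count) { rng.nextFloat() }
}

// Same seed, same sequence, so the generated training data is repeatable.
val runA = firstFloats(1234567, 5)
val runB = firstFloats(1234567, 5)
```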

### Training Data

The function `makeTrainingData()` creates a vector (a 1D tensor) of 100 random inputs, or features. Tensor arithmetic then computes the labels vector as $labels = features \cdot trueWeight + trueBias$. The features, labels, `trueWeight`, and `trueBias` are returned in a `TrainingData` object. `trueWeight` and `trueBias` are stored with the data so we can check how closely training recovers them.

```kotlin
// Training Data

class TrainingData(val features : FloatTensor,
                   val labels : FloatTensor,
                   val trueWeight : FloatScalar,
                   val trueBias : FloatScalar) {

    companion object {

        fun makeTrainingData(trainingDataSize: Int, random : Random ) : TrainingData {

            val trueWeight = FloatScalar(random.nextFloat())
            val trueBias = FloatScalar(random.nextFloat())

            val features = FloatTensor(Shape(trainingDataSize)) { random.nextFloat() }
            val labels = (features * trueWeight + trueBias) as FloatTensor

            return TrainingData(features, labels, trueWeight, trueBias)
        }
    }
}
```
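The expression `features * trueWeight + trueBias` applies the scalar weight and bias to every element of the feature vector. The loop below writes out the same computation with plain `FloatArray`s. It assumes element-wise semantics with the scalars broadcast across the vector, and the function name `makeLabels` is hypothetical.

```kotlin
// Element-wise equivalent of `features * trueWeight + trueBias`:
// each label is feature * trueWeight + trueBias.
fun makeLabels(features: FloatArray, trueWeight: Float, trueBias: Float): FloatArray =
    FloatArray(features.size) { i -> features[i] * trueWeight + trueBias }

val sampleLabels = makeLabels(floatArrayOf(0f, 0.5f, 1f), trueWeight = 2f, trueBias = 1f)
// sampleLabels holds 1.0, 2.0, 3.0
```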

### Create the Training Data

```kotlin
val trainingData = TrainingData.makeTrainingData(trainingDataSize, random)
```

### Data Iterator

The `SimpleDataIterator` class is an `Iterable<Data>` that splits the features and labels into consecutive batches of at most `batchSize` elements and returns each batch as a `Data` object. Class `Data`, located in __[diffkt/kotlin/data/src/main/kotlin/org/diffkt/data/Data.kt](https://github.com/facebookresearch/diffkt/blob/main/kotlin/data/src/main/kotlin/org/diffkt/data/Data.kt)__, holds the features and labels of a training set or batch.

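```kotlin
// Splits features and labels into consecutive batches of at most batchSize examples.
class SimpleDataIterator(val features: FloatTensor,
                         val labels: FloatTensor,
                         val batchSize: Int = 1): Iterable<Data> {

    init {
        // Features and labels must describe the same number of examples,
        // and a non-positive batch size would never advance the iterator.
        require(features.shape.first == labels.shape.first)
        require(batchSize > 0)
    }

    private val n = features.shape.first

    // Returns an iterator over the same data with a different batch size.
    fun withBatchSize(batchSize: Int) = SimpleDataIterator(features, labels, batchSize)

    override fun iterator(): Iterator<Data> = object : Iterator<Data> {
        var loc = 0 // index of the first example in the next batch
        override fun hasNext(): Boolean = loc < n
        override fun next(): Data {
            if (!hasNext()) throw NoSuchElementException()

            // The last batch may be shorter than batchSize.
            val start = loc
            val end = min(loc + batchSize, n)
            val f = features.slice(start, end)
            val l = labels.slice(start, end)
            loc = end
            return Data(f, l)
        }
    }
}
```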
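The iterator's batching arithmetic can be checked without DiffKt. The sketch below computes the `[start, end)` ranges that `next()` would pass to `slice`, using the same `min(loc + batchSize, n)` rule, so the last batch may be shorter than `batchSize`. The function name `batchBounds` is hypothetical.

```kotlin
import kotlin.math.min

// Returns the [start, end) index ranges that SimpleDataIterator would slice out.
fun batchBounds(n: Int, batchSize: Int): List<Pair<Int, Int>> {
    require(batchSize > 0) { "batchSize must be positive" }
    val bounds = mutableListOf<Pair<Int, Int>>()
    var loc = 0
    while (loc < n) {
        val end = min(loc + batchSize, n)
        bounds += Pair(loc, end)
        loc = end
    }
    return bounds
}
```

For example, `batchBounds(10, 4)` gives `[(0, 4), (4, 8), (8, 10)]`. `batchBounds(100, 100)` gives the single batch `[(0, 100)]` that this notebook uses.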

### Create the Data Iterator

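Passing `trainingDataSize` as the batch size puts the entire training set in one batch, so each epoch makes a single optimizer step.

```kotlin
val dataIterator = SimpleDataIterator(trainingData.features, trainingData.labels, batchSize = trainingDataSize)
```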

### Linear Regression Model

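`LinearRegression` wraps a single `AffineTransform` layer whose weight and bias are `TrainableTensor`s. The secondary constructor initializes both from random scalars. `withLayers` does not modify the model. Instead it returns a new `LinearRegression` built from the replacement layer. `equals` and `hashCode` compare models by their layer.

```kotlin
class LinearRegression(val l: AffineTransform): Model<LinearRegression>() {

    constructor(m: DScalar, b: DScalar) : this(AffineTransform(TrainableTensor(m), TrainableTensor(b)))
    constructor(random: Random) : this(FloatScalar(random.nextFloat()), FloatScalar(random.nextFloat()))

    override val layers: List<Layer<*>> = listOf(l)

    override fun withLayers(newLayers: List<Layer<*>>): LinearRegression {
        require(newLayers.size == 1)
        val newLayer = newLayers[0] as AffineTransform
        return LinearRegression(newLayer)
    }

    override fun hashCode(): Int = combineHash("LinearRegression", l)
    override fun equals(other: Any?): Boolean = other is LinearRegression &&
            other.l == l
}
```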
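`withLayers` follows an immutable-update style: it builds a new model rather than changing the existing one. The plain-Kotlin sketch below shows the same idea, with a hypothetical `Params` data class standing in for the model. Its `copy`-based `withParams` plays the role of `withLayers`.

```kotlin
// Hypothetical stand-in for an immutable model: updates return new instances.
data class Params(val weight: Float, val bias: Float) {
    fun predict(x: Float): Float = weight * x + bias

    // Analogous to withLayers: return a new model and leave `this` untouched.
    fun withParams(weight: Float, bias: Float): Params = copy(weight = weight, bias = bias)
}

val before = Params(weight = 0.5f, bias = 0.25f)
val after = before.withParams(weight = 1f, bias = 0f)
```

The data class's generated `equals` and `hashCode` give the same value-based comparison that `LinearRegression` implements by hand.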

### Create the Linear Regression Model

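The initial weight and bias are random values drawn from the same seeded generator as the training data.

```kotlin
val linReg = LinearRegression(random)
```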

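### Loss Function

The loss is the sum of the squared differences between the predictions and the labels.

```kotlin
fun lossFun(predictions: DTensor, labels: DTensor): DScalar {
    val diff = predictions - labels
    return (diff * diff).sum()
}
```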
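For this model, the derivatives that `primalAndReverseDerivative` computes automatically have a simple closed form. With $L = \sum_i (a x_i + b - y_i)^2$, $\partial L/\partial a = 2\sum_i (a x_i + b - y_i)\,x_i$ and $\partial L/\partial b = 2\sum_i (a x_i + b - y_i)$. The plain-Kotlin sketch below evaluates these expressions and checks them against central finite differences. The helper names are hypothetical.

```kotlin
// Sum-of-squared-errors loss for predictions a*x + b, as in lossFun.
fun sseLoss(a: Double, b: Double, xs: DoubleArray, ys: DoubleArray): Double {
    var total = 0.0
    for (i in xs.indices) {
        val diff = a * xs[i] + b - ys[i]
        total += diff * diff
    }
    return total
}

// Hand-derived gradient (dL/da, dL/db).
fun sseGradient(a: Double, b: Double, xs: DoubleArray, ys: DoubleArray): Pair<Double, Double> {
    var gradA = 0.0
    var gradB = 0.0
    for (i in xs.indices) {
        val diff = a * xs[i] + b - ys[i]
        gradA += 2.0 * diff * xs[i]
        gradB += 2.0 * diff
    }
    return Pair(gradA, gradB)
}

// Central finite differences, used only as an independent check.
fun numericGradient(a: Double, b: Double, xs: DoubleArray, ys: DoubleArray, h: Double = 1e-6): Pair<Double, Double> {
    val dA = (sseLoss(a + h, b, xs, ys) - sseLoss(a - h, b, xs, ys)) / (2 * h)
    val dB = (sseLoss(a, b + h, xs, ys) - sseLoss(a, b - h, xs, ys)) / (2 * h)
    return Pair(dA, dB)
}

val gx = doubleArrayOf(0.1, 0.4, 0.9)
val gy = doubleArrayOf(1.0, 0.5, 2.0)
val analytic = sseGradient(0.3, -0.2, gx, gy)
val numeric = numericGradient(0.3, -0.2, gx, gy)
```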

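### Optimizer

`FixedLearningRateOptimizer` updates the model with a constant learning rate, here `0.5F / trainingDataSize`. The loss sums the squared errors over all points rather than averaging them, so its gradient grows with the number of points. Dividing the learning rate by that number compensates.

```kotlin
val optimizer = FixedLearningRateOptimizer<LinearRegression>(0.5F / trainingDataSize)
```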
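Assume the optimizer applies the plain update $\theta \leftarrow \theta - \eta \nabla L$ (an assumption about `FixedLearningRateOptimizer`, not confirmed here). The mean loss is $\bar{L} = L/n$, so $\nabla \bar{L} = \nabla L / n$. A step of $0.5/n$ on the summed loss is therefore the same as a step of $0.5$ on the mean loss. The small numeric check below uses a hypothetical helper `gdStep` and an illustrative gradient value.

```kotlin
// One gradient-descent step on a single parameter: theta - lr * grad.
fun gdStep(theta: Double, grad: Double, lr: Double): Double = theta - lr * grad

val n = 100
val sumGrad = 37.5         // gradient of the summed loss (illustrative value)
val meanGrad = sumGrad / n // the gradient of the mean loss is 1/n of it
val viaSum = gdStep(1.0, sumGrad, 0.5 / n)
val viaMean = gdStep(1.0, meanGrad, 0.5)
// Both are 0.8125 up to floating-point rounding.
```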

### Learner

`Learner` holds the training loop. For each epoch it folds over the batches. For each batch, it uses `primalAndReverseDerivative` to compute the loss and its reverse-mode derivative with respect to the model, then passes that tangent to the optimizer, which returns the updated model. When `useJit` is `true`, the training step is compiled with `jit`. Time spent in training steps accumulates in `totalTime` and is reported by `dumpTimes()`.

```kotlin
class Learner<T : Model<T>>(val batchedData: Iterable<Data>,
                            val lossFunc: (predictions: DTensor, labels: DTensor) -> DScalar,
                            val optimizer: Optimizer<T> = AdamOptimizer(),
                            val useJit: Boolean = false)
{
    // Accumulated time spent in training steps, in nanoseconds.
    var totalTime = 0L

    /**
     * Trains [model] for [epochs] epochs over the batches in [batchedData] and returns the trained model.
     * If [maxIters] is set, training stops after that many batches in total. [printProgress] prints the
     * cumulative loss every 10 epochs, [printProgressFrequently] prints the loss of every batch, and each
     * batch is moved to [device] before its training step.
     */
    fun train(model: T,
              epochs: Int,
              printProgress: Boolean = false,
              maxIters: Int? = null,
              printProgressFrequently: Boolean = false,
              device: Device = Device.CPU): T
    {
        var totalIters = 0

        // One training step: compute the loss and its reverse-mode derivative with
        // respect to the model, then let the optimizer produce the updated model.
        fun modelTrainStep(model2: T, batch: Data): Pair<DScalar, T>
        {
            val (loss, tangent) = primalAndReverseDerivative(
                x = model2,
                f = { model3: T ->
                    val output = model3.predict(batch.features)
                    val loss = lossFunc(output, batch.labels)
                    loss
                },
                extractDerivative = { model3: T,
                                      loss: DScalar,
                                      extractor: (input: DTensor, output: DTensor) -> DTensor ->
                    model3.extractTangent(loss, extractor)
                }
            )

            val trainedModel: T = optimizer.train(model2, tangent)
            return Pair(loss, trainedModel)
        }

        // Single-argument wrapper so the step can be compiled with jit when useJit is true.
        fun trainingFunction(p: Pair<T, Data>): Pair<DScalar, T> = modelTrainStep(p.first, p.second)

        val jittedTrainingFunction = if (useJit) jit(::trainingFunction) else ::trainingFunction

        val optimizedModel = (0 until epochs).fold(model) { model1: T, e: Int ->
            var lossTotal: DScalar = FloatScalar.ZERO
            val trainedModel = batchedData.fold(model1) { model2: T, batch: Data ->
                val batchOnDevice = batch.to(device)
                val t1 = System.nanoTime()
                val (loss, trainedModel) = jittedTrainingFunction(Pair(model2, batchOnDevice))
                val t2 = System.nanoTime()
                totalTime += t2 - t1
                if (printProgress) lossTotal += loss
                totalIters++
                if (printProgressFrequently) println("Iter $totalIters Batch Loss: $loss")
                // fold is inline, so this returns from train() once maxIters batches are done.
                if (maxIters != null && totalIters >= maxIters) return trainedModel
                trainedModel
            }
            if (printProgress && ((e % 10) == 0)) println("Epoch $e Cumulative Loss: $lossTotal")
            trainedModel
        }

        return optimizedModel
    }

    // Converts nanoseconds to seconds.
    private fun e(n: Long) = n / 1e9f

    fun dumpTimes() {
        println("running time:  ${e(totalTime)} sec")
    }
}
```
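Stripped of DiffKt types, `train` is a nested fold. An outer fold runs over epochs and an inner fold runs over batches, threading the model through each step and stopping early once `maxIters` batches have been processed. The generic plain-Kotlin sketch below reproduces only that control flow. The name `trainLoop` and its parameters are hypothetical.

```kotlin
// Control flow of Learner.train with M as the model type and B as the batch type.
// `step` returns the batch loss and the updated model.
fun <M, B> trainLoop(
    model: M,
    batches: Iterable<B>,
    epochs: Int,
    maxIters: Int? = null,
    step: (M, B) -> Pair<Float, M>,
): M {
    var totalIters = 0
    return (0 until epochs).fold(model) { epochModel, _ ->
        batches.fold(epochModel) { batchModel, batch ->
            val (_, updated) = step(batchModel, batch)
            totalIters++
            // Non-local return: fold is inline, so this exits trainLoop itself.
            if (maxIters != null && totalIters >= maxIters) return updated
            updated
        }
    }
}
```

With an `Int` "model" whose step adds one, `trainLoop(0, listOf("a", "b", "c"), epochs = 4) { m, _ -> Pair(0f, m + 1) }` counts 12 steps. Adding `maxIters = 5` stops at 5.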


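### Train the Model

Train for 300 epochs with JIT compilation enabled, print the cumulative loss every 10 epochs, then report the training time.

```kotlin
val learner = Learner(batchedData = dataIterator,
                      lossFunc = ::lossFun,
                      optimizer = optimizer,
                      useJit = true)

val trainedModel = learner.train(linReg, 300, printProgress = true)
learner.dumpTimes()
```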



Epoch 0 Cumulative Loss: 68.87204
Epoch 10 Cumulative Loss: 0.06945361
Epoch 20 Cumulative Loss: 0.02103138
Epoch 30 Cumulative Loss: 0.0063685565
Epoch 40 Cumulative Loss: 0.0019284704
Epoch 50 Cumulative Loss: 5.839721E-4
Epoch 60 Cumulative Loss: 1.768291E-4
Epoch 70 Cumulative Loss: 5.3545296E-5
Epoch 80 Cumulative Loss: 1.6213318E-5
Epoch 90 Cumulative Loss: 4.909573E-6
Epoch 100 Cumulative Loss: 1.4866098E-6
Epoch 110 Cumulative Loss: 4.5023984E-7
Epoch 120 Cumulative Loss: 1.3625511E-7
Epoch 130 Cumulative Loss: 4.1285187E-8
Epoch 140 Cumulative Loss: 1.2533533E-8
Epoch 150 Cumulative Loss: 3.7813876E-9
Epoch 160 Cumulative Loss: 1.1489654E-9
Epoch 170 Cumulative Loss: 3.4492587E-10
Epoch 180 Cumulative Loss: 1.0542678E-10
Epoch 190 Cumulative Loss: 3.381828E-11
Epoch 200 Cumulative Loss: 1.0352608E-11
Epoch 210 Cumulative Loss: 3.2045477E-12
Epoch 220 Cumulative Loss: 2.5330849E-12
Epoch 230 Cumulative Loss: 2.5330849E-12
Epoch 240 Cumulative Loss: 2.5330849E-12
Epoch 250 Cumul