# Neural Networks with User-Defined Types

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.


In this notebook we will learn how to build a neural network using DiffKt with user-defined types. Neural networks are not exactly simple, but they are composed of simple mathematical techniques working in orchestration. The calculus behind them can be tedious, however, because gradient descent requires derivatives for every layer. Because weights and biases are applied in nested functions from each layer, it's mathematically like pulling apart an onion layer by layer. Thankfully, DiffKt can calculate the gradients for each layer's weights and biases and spare us the messiness of solving derivatives by hand. Along the way, we will use custom types and demonstrate DiffKt's capabilities with its `Wrapper` interface.

To get started, first bring in DiffKt to use in this notebook as well as these imports. Then we will talk about the structure of a neural network.

In [1]:
@file:DependsOn("../kotlin/api/build/libs/api.jar")

In [2]:
import org.diffkt.*
import java.net.URL
import org.diffkt.random.DiffktRandom

## The Anatomy of a Neural Network 

Let's present a problem adapted from Chapter 7 in the book [*Essential Math for Data Science (O'Reilly)*](https://learning.oreilly.com/library/view/essential-math-for/9781098102920/). We want to train a neural network to predict a light/dark font for a given background color. For example, a <span style="background-color:DarkBlue; color:white"><text color='white'>Dark Blue</text></span> background would warrant a light font and a <span style="background-color:pink; color:black"><text color='white'>Pink</text></span> background would warrant a dark font. We could solve this with a logistic regression or even a [known heuristic](https://stackoverflow.com/questions/1855884/determine-font-color-based-on-background-color), but this will be a nice toy example to discover the workings of a neural network and apply DiffKt.

Let's get to building the neural network. This visual below does not show the activation functions, a critical component to make a neural network work. We will get to that. Let's look at the nodes first. 

<img src="./resources/sGQdjdjUMw.png" style="width: 600px;"/>

The first layer is simply an input of the three variables (R, G, and B values for a given color). In the hidden layer (which resides in the middle), notice that we have three **nodes**, or functions of weights and biases, between the inputs and outputs. There is a weight $ w_i $ between each input node and hidden node, and another set of weights between each hidden node and output node. Each hidden and output node gets an additional bias $ b_i $ added.

The output node repeats the same operation: it takes the weighted and summed outputs from the hidden layer as its inputs and applies another set of weights and a bias. I put "repeat weighting and summing" in the graphic instead of the mathematical expressions because the expressions propagating from the hidden layer are too long to display. Here is the expression for the final node:

$ \text{output} = w_{10}(x_1w_1 + x_2w_2 + x_3w_3 + b_1) $ 

$ + w_{11}(x_1w_4 + x_2w_5 + x_3w_6 + b_2) $ 

$ + w_{12}(x_1w_7 + x_2w_8 + x_3w_9 + b_3) + b_4 $ 
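
As a rough illustration of the expression above, here is a minimal plain-Kotlin sketch that evaluates the final node for one input, without activation functions and without DiffKt. The input, the weight and bias values, and the helper names `weightedSum` and `outputSketch` are made up for this example and are not part of the notebook's model.

```kotlin
// Weighted sum of the inputs plus a bias: the operation each node performs
// before any activation function is applied.
fun weightedSum(inputs: FloatArray, weights: FloatArray, bias: Float): Float {
    require(inputs.size == weights.size) { "Expected one weight per input" }
    var sum = bias
    for (i in inputs.indices) sum += inputs[i] * weights[i]
    return sum
}

// Hypothetical parameters, chosen only so the arithmetic is easy to follow.
val hiddenWeights = listOf(
    floatArrayOf(0.5f, -0.25f, 1.0f),   // w1, w2, w3
    floatArrayOf(-1.0f, 0.5f, 0.5f),    // w4, w5, w6
    floatArrayOf(0.25f, 0.25f, 0.25f)   // w7, w8, w9
)
val hiddenBiases = floatArrayOf(0.25f, 0.0f, 0.5f)   // b1, b2, b3
val outputWeights = floatArrayOf(1.0f, 0.5f, -1.0f)  // w10, w11, w12
val outputBias = 0.125f                              // b4

// Evaluates the three hidden nodes, then feeds their results to the output node.
fun outputSketch(x: FloatArray): Float {
    val hidden = FloatArray(hiddenWeights.size) { i ->
        weightedSum(x, hiddenWeights[i], hiddenBiases[i])
    }
    return weightedSum(hidden, outputWeights, outputBias)
}

fun main() {
    // Hidden nodes evaluate to 1.25, -0.75 and 0.875 for this input.
    println(outputSketch(floatArrayOf(1.0f, 0.0f, 0.5f)))  // prints 0.125
}
```

Each hidden node is just a dot product plus a bias, and the output node applies the same operation to the hidden results.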

We need to solve for each of these weight and bias values, and this is what we call **training** the neural network. But before we get to that later in this notebook, there is one more critical component we need to add. The **activation function** is a nonlinear function that transforms the weighted and summed values in a node, helping separate the weighted data so it can be classified. For this neural network, we will use the _ReLU_ function for the hidden layer and the sigmoid function for the output layer. 

<img src="./resources/PvLebFIsiT.png" style="width: 600px;"/>

The **ReLU ("rectified linear unit") function** takes a numeric input and turns it into $ 0 $ if it is negative; otherwise the input passes through unchanged. ReLU is commonly used for hidden layers in neural networks because of its speed. It also mitigates the [vanishing gradient problem](https://en.wikipedia.org/wiki/Vanishing_gradient_problem), where partial derivatives get so small that they prematurely approach $ 0 $ and bring training to a halt. DiffKt comes packaged with a `relu()` function that is compatible with its tensor and scalar types. It can also be built from scratch using a maximum function between 0 and the input value.

$$
\text{ReLU}(x) = \max(0, x)
$$

<img src="./resources/tKkerIrVkt.png" style="width: 600px;"/>
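
To make the "maximum of 0 and the input" definition concrete, here is a minimal sketch on plain `Float` values. The name `reluSketch` is hypothetical; it is not DiffKt's `relu()`, which operates on DiffKt's own tensor and scalar types.

```kotlin
import kotlin.math.max

// From-scratch ReLU on a plain Float: negative inputs become 0,
// non-negative inputs pass through unchanged.
fun reluSketch(x: Float): Float = max(0f, x)

fun main() {
    val inputs = listOf(-2f, -0.5f, 0f, 0.5f, 2f)
    println(inputs.map { reluSketch(it) })  // prints [0.0, 0.0, 0.0, 0.5, 2.0]
}
```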


The output layer consolidates all the inputs from the hidden layer and turns them into an interpretable result. Because our output is binary (light/dark font), we only need one output node. We use the **sigmoid function** to compress values between 0 and 1 using a logistic curve. The result can be interpreted as a probability, where a value closer to 0 recommends a light font and a value closer to 1 recommends a dark font, matching the light (0) / dark (1) encoding used later in this notebook. We can use $ 0.5 $ as our threshold, so anything less than $ 0.5 $ is considered a light font recommendation and anything equal or higher is a dark font recommendation. DiffKt also comes with a `sigmoid()` function, which is mathematically defined below:

$$
\text{sigmoid}(x) = \frac{1}{1 + e^{-x}}
$$

<img src="./resources/DllsJpEMCJ.png" style="width: 600px;"/>
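
The formula and the $ 0.5 $ threshold can be sketched on plain `Float` values as follows. The names `sigmoidSketch` and `recommendsDarkFont` are hypothetical helpers for illustration only; DiffKt's own `sigmoid()` works on its differentiable types instead.

```kotlin
import kotlin.math.exp

// Logistic curve: squashes any real input into the open interval (0, 1).
fun sigmoidSketch(x: Float): Float = 1f / (1f + exp(-x))

// Apply the 0.5 threshold: values at or above it map to a dark font,
// values below it map to a light font.
fun recommendsDarkFont(probability: Float): Boolean = probability >= 0.5f

fun main() {
    for (x in listOf(-4f, 0f, 4f)) {
        val p = sigmoidSketch(x)
        println("x=$x sigmoid=$p dark=${recommendsDarkFont(p)}")
    }
}
```

An input of exactly 0 lands on the threshold itself, so under this convention it counts as a dark font recommendation.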


## Declaring the Types

We will need to declare our classes in advance that will be targeted for differentiation. The first class `Color` will hold a red, green, and blue value as properties. It will also contain an `asFloatArray()` function for convenience to package the three values as a float array. We make it a `data class` so it prints the object type and its properties. 

In [3]:
data class Color(val r: Float, val g: Float, val b: Float)  {
    fun asFloatArray() = floatArrayOf(r,g,b)
}

The `FontShade` is a simple indicator of a light (0) or dark (1) font. We express it as an enum for explicitness and store the `value` as a float for mathematical operations. We can also convert a float value of `0f` or `1f` into a `FontShade` using the `valueOf()` companion function.

In [4]:
enum class FontShade(val value: Float) {
    LIGHT(0f),
    DARK(1f);

    companion object {
        fun valueOf(value: Float) = values().first { it.value == value }
    }
}

A `ColorAndFontShade` class pairs a `Color` and a `FontShade` together. It also has a secondary constructor to build both of those items at once.

In [5]:

class ColorAndFontShade(val color: Color, val fontShade: FontShade) {

    constructor(r: Float, g: Float, b: Float, fontShadeValue: Float):
            this(Color(r,g,b), FontShade.valueOf(fontShadeValue))

    override fun toString() = "(${color.r},${color.g},${color.b})-${fontShade}"
}

Now let's look at the neural network components, each declared as a user-defined type. We are going to use DiffKt to calculate the derivatives for each layer, which contains nodes with weights and biases. The `NodeDiff` will contain the `weightDiffs` with regard to each weight, and a `biasDiff` for the bias. A `LayerDiff` simply holds the `NodeDiff` for every node in a layer. On both we will also implement a `times()` operator to multiply the diffs by a scalar.

In [6]:
data class NodeDiff(val weightDiffs: List<DScalar>, val biasDiff: DScalar)  {
    operator fun times(scalar: Float) = NodeDiff(weightDiffs.map { it * scalar}, biasDiff * scalar)
}

data class LayerDiff(val nodeDiffs: List<NodeDiff>) {
    operator fun times(scalar: Float) = LayerDiff(nodeDiffs.map { it * scalar })
}


We will then implement a `Node` type which contains the weights and bias for a given node on a layer. It implements the `Differentiable` interface, which makes it wrappable by the DiffKt API.

In [7]:
data class Node(val weights: List<DScalar>,
                val bias: DScalar = DiffktRandom.fromTimeOfDay().nextUniform()
): Differentiable<Node> {

    fun forward(values: List<DScalar>) =
        values.zip(weights) { v, w -> v * w }
            .reduce { x, y -> x + y } + bias

    constructor(connectionCount: Int) : this(
        (0 until connectionCount).map { (DiffktRandom.fromTimeOfDay().nextUniform() * 2f) - 1f }
    )

    override fun wrap(wrapper: Wrapper) =
        Node(wrapper.wrap(weights), wrapper.wrap(bias))

    private fun bound(value: DScalar, lower: Float, upper: Float) : DScalar {
        val l = FloatScalar(lower)
        val u = FloatScalar(upper)
        val trimLower = ifThenElse(value lt l, l, value)
        val trimUpper = ifThenElse(trimLower gt u, u, trimLower)
        return trimUpper
    }

    operator fun minus(nodeDiff: NodeDiff) =
        Node(weights.zip(nodeDiff.weightDiffs) { w, d ->
            bound(w - d, -1f, 1f)
        }, bound(bias - nodeDiff.biasDiff, 0f, 1f))
}


The `forward()` function takes incoming values and applies the weights and bias through a dot product operation. The secondary `constructor` randomly initializes a node with the provided number of connections, giving it one weight per connection and a bias. The `wrap()` function instructs DiffKt how to pass the weights and bias to a new instance of `Node` and enable differentiation. The `bound()` utility function keeps a value inside a range. It is used in `minus()`, which subtracts a `NodeDiff` from a `Node` during gradient descent so that the weight and bias values are adjusted based on the gradients while staying within $ [-1, 1] $ for the weights and $ [0, 1] $ for the bias.
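
To see the dot product and the clamping in isolation, here is a minimal sketch that mirrors `forward()` and `bound()` on plain `Float` values. The names `forwardSketch` and `boundSketch` are hypothetical, and `coerceIn` from the Kotlin standard library stands in for the `ifThenElse` comparisons that the `Node` class performs on `DScalar` values.

```kotlin
// Dot product of incoming values and weights, plus the bias.
fun forwardSketch(values: List<Float>, weights: List<Float>, bias: Float): Float =
    values.zip(weights) { v, w -> v * w }.reduce { x, y -> x + y } + bias

// Keep a value inside [lower, upper], as bound() does for weights and biases.
fun boundSketch(value: Float, lower: Float, upper: Float): Float =
    value.coerceIn(lower, upper)

fun main() {
    // Illustrative numbers only: 1.0*0.5 + 0.0*(-0.25) + 0.5*1.0 + 0.25
    println(forwardSketch(listOf(1f, 0f, 0.5f), listOf(0.5f, -0.25f, 1f), 0.25f))  // prints 1.25

    // A weight update that would overshoot the upper bound is clamped to 1.0.
    println(boundSketch(0.9f - (-0.3f), -1f, 1f))  // prints 1.0
    // A bias update that would go negative is clamped to 0.0.
    println(boundSketch(0.1f - 0.4f, 0f, 1f))      // prints 0.0
}
```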

The `Layer` contains an ordered collection of nodes and an activation function, which is defined as a functional argument taking a `DScalar` and converting into another `DScalar`. 

In [8]:
data class Layer(val nodes: List<Node>,
                 val activation: (DScalar) -> DScalar): Wrappable<Layer> {

    constructor(previousNodeCount: Int,
                nodeCount: Int,
                activation: (DScalar) -> DScalar): this(
        (0 until nodeCount).map { Node(previousNodeCount) },
        activation
    )

    fun forward(values: List<DScalar>) = nodes.map { node ->
        activation(node.forward(values))
    }

    operator fun minus(layerDiff: LayerDiff) =
        Layer(nodes.zip(layerDiff.nodeDiffs) { n,d -> n - d }, activation)

    override fun toString() = nodes.joinToString(",")

    override fun wrap(wrapper: Wrapper) = Layer(wrapper.wrap(nodes), activation)
}


The `Layer` implements `Wrappable` so it can be differentiated, and its `wrap()` function wraps the nodes. The alternate `constructor` takes the previous layer's number of nodes as an argument so it can calculate the number of connections, and thus weights, each node needs. Otherwise the primary constructor is used with an explicit list of nodes. The `forward()` function will apply the nodes (weights and biases) and the activation function to a given input. The `minus()` will subtract a `LayerDiff` to apply differentials to a layer of nodes.

Finally, a `NeuralNetwork` will be constructed from a list of layers and will implement the `Wrappable` interface as well. Notice that there is a wrappable hierarchy with `NeuralNetwork` ⮕ `Layer` ⮕ `Node`.

In [9]:
data class NeuralNetwork(val layers: List<Layer>): Wrappable<NeuralNetwork> {
    constructor(vararg layers: Layer): this(layers.toList())

    fun forward(input: FloatArray): List<DScalar> {
        var forwardProp: List<DScalar> = input.map(::FloatScalar)

        for (layer in layers) {
            forwardProp = layer.forward(forwardProp)
        }
        return forwardProp
    }

    operator fun minus(layerDiffs: List<LayerDiff>) =
        NeuralNetwork(
            layers.zip(layerDiffs) { layer, diff -> layer - diff }
        )

    override fun wrap(wrapper: Wrapper) = NeuralNetwork(wrapper.wrap(layers))
}

This makes it possible to differentiate an entire neural network with respect to each individual weight and bias. When we call the `forward()` function it will propagate the values through each layer, applying the weights, biases, summations, and activation functions. We return the result as a list of `DScalar` which will contain the prediction. This `forward()` function is what we will differentiate in a loss function.

The `minus()` subtracts a list of layer differentials from the entire neural network. This is what drives gradient descent: on each iteration we subtract a fraction of the gradients from the weights and biases.
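
The idea of repeatedly subtracting a fraction of the gradient can be sketched on a single parameter. This toy example does not use DiffKt: the cost function $ (w - 3)^2 $, its hand-written derivative, and the names `costGradient` and `gradientDescentSketch` are assumptions made purely for illustration.

```kotlin
// Toy cost C(w) = (w - 3)^2 with hand-written derivative dC/dw = 2(w - 3).
fun costGradient(w: Float): Float = 2f * (w - 3f)

// Repeatedly step against the gradient, scaled by the learning rate.
fun gradientDescentSketch(start: Float, learningRate: Float, iterations: Int): Float {
    var w = start
    repeat(iterations) {
        w -= costGradient(w) * learningRate
    }
    return w
}

fun main() {
    // The result should approach 3, the value that minimises the toy cost.
    println(gradientDescentSketch(start = 0f, learningRate = 0.1f, iterations = 100))
}
```

The neural network does the same thing, except that there are many parameters and DiffKt supplies the derivatives.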

## Importing the Data



Let's first explore our data stored [here](https://tinyurl.com/y2qmhfsr). It contains 3 input variable columns (red, green, and blue) and the output light/dark font indicator which is a boolean we want to predict. We have 1345 records in this training data. Here is a sample: 

| RED | GREEN | BLUE | LIGHT_OR_DARK_FONT_IND |
|-----|-------|------|------------------------|
| 0   | 0     | 0    | 0                      |
| 0   | 0     | 128  | 0                      |
| 0   | 139   | 69   | 0                      |
| 0   | 154   | 205  | 0                      |
| 0   | 178   | 238  | 1                      |
| 0   | 197   | 205  | 1                      |
| 0   | 199   | 140  | 1                      |
| 0   | 201   | 87   | 1                      |
| 0   | 205   | 0    | 0                      |
| 0   | 205   | 102  | 1                      |

To bring in this data, let's use the Java `URL` class to read the CSV from GitHub. We will use a [regular expression](https://www.oreilly.com/content/an-introduction-to-regular-expressions/) to split the text into lines, and [Sequence](https://kotlinlang.org/docs/sequences.html) operations to clean up the lines like an assembly line. We will map the whole CSV into a list of `ColorAndFontShade` objects. Note that we also divide each color value by 255, so the red, green, and blue values fall between 0 and 1.

In [10]:
val colorsAndFontShades = URL("https://tinyurl.com/y2qmhfsr")
    .readText().split(Regex("\\r?\\n"))
    .asSequence()
    .drop(1)
    .filter { it.isNotBlank() }
    .map { s ->
        s.split(",").map { it.toFloat()  }
    }.map {
        ColorAndFontShade(it[0] / 255, it[1] / 255, it[2] / 255, it[3])
    }.toList()
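
The same pipeline can be tried offline on an inline string. In this sketch the rows are copied from the sample table above, while the names `sampleCsv` and `parseRows` are hypothetical and each row is kept as a plain list of floats rather than a `ColorAndFontShade`.

```kotlin
// A few rows from the sample table, including the header line.
val sampleCsv = """
    RED,GREEN,BLUE,LIGHT_OR_DARK_FONT_IND
    0,0,0,0
    0,0,128,0
    0,178,238,1
""".trimIndent()

// Split into lines, skip the header, drop blank lines, parse the numbers,
// and rescale the three color channels to the 0..1 range.
fun parseRows(csv: String): List<List<Float>> =
    csv.split(Regex("\\r?\\n"))
        .asSequence()
        .drop(1)
        .filter { it.isNotBlank() }
        .map { line -> line.split(",").map { it.toFloat() } }
        .map { listOf(it[0] / 255, it[1] / 255, it[2] / 255, it[3]) }
        .toList()

fun main() {
    parseRows(sampleCsv).forEach { println(it) }
}
```

Only the three color columns are rescaled; the font indicator in the last column stays at 0 or 1.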

## Performing Gradient Descent

Let's declare an instance of our neural network to predict a light or dark font. We expect 3 inputs, which are passed to a middle layer with 12 nodes that uses a ReLU function. The output layer will take those 12 incoming values as inputs, weight them again, and sum them with a bias into a single value, which is why it has only 1 node. That final layer then passes this single value through the `sigmoid()` function.

In [11]:
var nn = NeuralNetwork(
    Layer(3,12,::relu),
    Layer(12,1,::sigmoid)
)

Obviously our neural network is going to perform poorly at first, as the weights and biases are randomly initialized and not optimized yet. The first step in optimizing them is to declare a `loss()` function that measures how far off our neural network outputs are from our training outputs. Let's use a simple squared loss function, where $ C $ is the sum of squared differences between the predicted outputs $ A_2 $ and the actual outputs $ Y $ from the data, taken over all training examples.

$ C = \sum (Y - A_2)^2 $

Let's create this `loss()` function as specified below. We will perform batch gradient descent here and iterate the entire batch, calculating the sum of squared differences between the neural network prediction and actual output.

In [12]:
var count = 0

fun loss(nn: NeuralNetwork): DScalar {
    var loss: DScalar = FloatScalar.ZERO

    for (trainExample in colorsAndFontShades) {
        val (trainInput, trainOutput) = trainExample.let { it.color.asFloatArray() to it.fontShade.value }
        val diff = nn.forward(trainInput)[0] - trainOutput
        loss += diff.pow(2)
    }

    // Report the root-mean-square error for this pass over the data
    println("$count. ${sqrt(loss.basePrimal() / colorsAndFontShades.size)}")
    count++
    return loss
}
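
Note that `loss()` returns the summed squared error but prints its root-mean-square form, which is easier to read because it stays on the same 0 to 1 scale as the predictions. A minimal sketch of both quantities on plain floats, with the hypothetical names `sumSquaredError` and `rootMeanSquaredError` and made-up predictions, could look like this:

```kotlin
import kotlin.math.sqrt

// Sum of squared differences between predictions and actual outputs.
fun sumSquaredError(predictions: List<Float>, actuals: List<Float>): Float {
    require(predictions.size == actuals.size) { "Expected one actual value per prediction" }
    return predictions.zip(actuals) { p, y -> (p - y) * (p - y) }.sum()
}

// Root-mean-square error: the square root of the average squared difference.
fun rootMeanSquaredError(predictions: List<Float>, actuals: List<Float>): Float =
    sqrt(sumSquaredError(predictions, actuals) / predictions.size)

fun main() {
    val predictions = listOf(0.5f, 1f)  // made-up network outputs
    val actuals = listOf(0f, 1f)        // made-up training labels
    println(sumSquaredError(predictions, actuals))  // prints 0.25
    println(rootMeanSquaredError(predictions, actuals))
}
```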

Finally, let's train our neural network. We will use a learning rate of $ .001 $ and 100 iterations. On each iteration we will traverse the entire dataset as implemented in our `loss()` function, and pass the whole `NeuralNetwork` as an input object into `reverseDerivative()`. The `extractDerivative` lambda will map each layer of the neural network to a `LayerDiff`, which in turn contains a `NodeDiff` for each node, holding the `extractTangent` results for that node's weights and bias. We can then take this resulting list of `LayerDiff` objects, multiply them by our learning rate, and subtract them from our neural network.

We can then calculate the proportion of accurate predictions by running all the data through the neural network, and comparing to the actual values. Reading the output below, you should get high accuracy around $ .96 $, give or take. 
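
The accuracy measure described above amounts to thresholding each prediction at $ 0.5 $ and counting how often the result matches the label. Here is a minimal sketch on plain floats; the name `accuracySketch` and the sample values are hypothetical and not taken from the training data.

```kotlin
// Proportion of predictions that match the label once thresholded at 0.5.
fun accuracySketch(probabilities: List<Float>, labels: List<Float>): Float {
    require(probabilities.size == labels.size) { "Expected one label per prediction" }
    val correct = probabilities.zip(labels).count { (p, y) ->
        val predicted = if (p >= 0.5f) 1f else 0f
        predicted == y
    }
    return correct.toFloat() / labels.size
}

fun main() {
    val probabilities = listOf(0.9f, 0.2f, 0.6f, 0.4f)  // made-up network outputs
    val labels = listOf(1f, 0f, 0f, 0f)                 // made-up font indicators
    println(accuracySketch(probabilities, labels))      // prints 0.75
}
```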

In [13]:
// The learning rate
val lr = .001F

// The number of iterations to perform gradient descent
val iterations = 100

// Perform gradient descent
for (i in 0 until iterations) {

    // Differentiate the loss with respect to every weight and bias in the network
    val tangents = reverseDerivative(
        x = nn,
        f = ::loss,
        extractDerivative = { input, output, extractTangent ->
            input.layers.map { layer ->
                LayerDiff(
                    layer.nodes.map { node ->
                        NodeDiff(
                            node.weights.map { extractTangent(it, output) as DScalar },
                            extractTangent(node.bias, output) as DScalar
                        )
                    }
                )
            }
        }
    )

    // Step against the gradient, scaled by the learning rate
    nn -= tangents.map { it * lr }
}

// Calculate accuracy: the share of examples whose thresholded prediction matches the label
val accuracy = colorsAndFontShades.map {
    (nn.forward(it.color.asFloatArray())[0].ge(.5f).eq(it.fontShade.value) as FloatScalar).value
}.sum() / colorsAndFontShades.count()

println(nn)
println("ACCURACY: $accuracy")

0. 0.54920614
1. 0.5178369
2. 0.50264174
3. 0.48897183
4. 0.47608778
5. 0.46379277
6. 0.45166457
7. 0.44003513
8. 0.42941603
9. 0.42018256
10. 0.4114745
11. 0.4035245
12. 0.39601928
13. 0.38903233
14. 0.38250408
15. 0.37636676
16. 0.37055892
17. 0.365396
18. 0.36049214
19. 0.35582516
20. 0.3513578
21. 0.3470943
22. 0.3429735
23. 0.33900622
24. 0.33527622
25. 0.3318883
26. 0.32862973
27. 0.32542607
28. 0.32231262
29. 0.31930184
30. 0.31638476
31. 0.3137104
32. 0.3112201
33. 0.30891708
34. 0.30702764
35. 0.30517948
36. 0.3033503
37. 0.3015414
38. 0.29976398
39. 0.29803848
40. 0.29634753
41. 0.29468715
42. 0.29304844
43. 0.29143605
44. 0.28982726
45. 0.28823754
46. 0.28667375
47. 0.28511363
48. 0.2836921
49. 0.28238967
50. 0.28111345
51. 0.2799777
52. 0.27886778
53. 0.27777228
54. 0.27669948
55. 0.27567115
56. 0.27465108
57. 0.27364227
58. 0.27264416
59. 0.27165627
60. 0.2706783
61. 0.26971057
62. 0.26876372
63. 0.2679001
64. 0.26705843
65. 0.2663019
66. 0.26557568
67. 0.2648647
68. 0.264