# Differentiable Programming with User Defined Types


**Copyright (c) Meta Platforms, Inc. and affiliates.**
 
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.

## Introduction

This notebook discusses more advanced features of the **DiffKt** API. For an introduction to **DiffKt**, review the notebook "Introduction To Basic API Operations for DiffKt".

This notebook focuses on user defined data types and on the functions __[primalAndForwardDerivative](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivative.html)__ and __[primalAndReverseDerivative](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__, which can perform automatic differentiation over functions of user defined types or complex data structures.



## Housekeeping

This notebook uses `api.jar` from the **DiffKt** project.<br>
`@file:DependsOn("...")` tells the Kotlin Jupyter notebook the path to a jar that it needs.

In [1]:
@file:DependsOn("../kotlin/api/build/libs/api.jar")

## Imports

In [2]:
import org.diffkt.*

## Example Problem

We will use the example from the "Introduction To Basic API Operations for DiffKt" notebook, which was used to show that the Jacobians returned by the forward and reverse derivatives are transposes of each other.

For a vector valued function with N inputs and M outputs, the Jacobian returned by __[forwardDerivative](http://www.diffkt.org/api/api/org.diffkt/forward-derivative.html)__ and __[primalAndForwardDerivative](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivative.html)__ is an M x N matrix. The reverse derivative algorithms, __[reverseDerivative](http://www.diffkt.org/api/api/org.diffkt/reverse-derivative.html)__ and __[primalAndReverseDerivative](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__, return the transpose of the Jacobian, which is an N x M matrix.

For this example there are 3 inputs and 2 outputs.

Let $\mathbf x $ be a 1D tensor of inputs with three elements, $x_0, x_1,$ and $x_2$.

Let $\mathbf y $ be a 1D tensor of function outputs with two functions, $y_0(\mathbf x)$ and $y_1(\mathbf x)$.

Let $y_0(\mathbf x) = 2 x_0 + 2 x_1^2 + 3 x_1 x_2^3$

Let $y_1(\mathbf x) = 3 x_0^3 x_1 + 2 x_1^2 + 3 x_2^3$

Then $\mathbf f(\mathbf x) = (y_0(\mathbf x), y_1(\mathbf x)) = (2 x_0 + 2 x_1^2 + 3 x_1 x_2^3,\; 3 x_0^3 x_1 + 2 x_1^2 + 3 x_2^3)$

The derivative is $ \mathbf \nabla \mathbf f(\mathbf x) = \begin{bmatrix} 2 & 4 x_1 + 3x_2^3 & 9x_1x_2^2 \\ 9x_0^2x_1 & 3x_0^3 + 4x_1 & 9x_2^2 \end{bmatrix} $

The transpose of the derivative is $ \mathbf \nabla \mathbf f(\mathbf x)^T = \begin{bmatrix} 2 & 9x_0^2x_1 \\ 4 x_1 + 3x_2^3 & 3x_0^3 + 4x_1 \\ 9x_1x_2^2 & 9x_2^2 \end{bmatrix} $

If $\mathbf x = \left [ 1, 2, 3 \right ]$,

then $\mathbf f(\mathbf x) = \left [ 172, 95 \right ]$

and

$\mathbf \nabla \mathbf f(\mathbf x) = \begin{bmatrix} 2 & 89 & 162 \\ 18 & 11 & 81 \end{bmatrix}$

$\mathbf \nabla \mathbf f(\mathbf x)^T = \begin{bmatrix} 2 & 18 \\ 89 & 11 \\ 162 & 81 \end{bmatrix}$
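
As a check on the numbers above, each value follows from direct substitution of $x_0 = 1$, $x_1 = 2$, and $x_2 = 3$ into the formulas for $\mathbf f$ and its partial derivatives:

$y_0 = 2(1) + 2(2)^2 + 3(2)(3)^3 = 2 + 8 + 162 = 172$

$y_1 = 3(1)^3(2) + 2(2)^2 + 3(3)^3 = 6 + 8 + 81 = 95$

$\frac{\partial y_0}{\partial x_0} = 2, \quad \frac{\partial y_0}{\partial x_1} = 4(2) + 3(3)^3 = 89, \quad \frac{\partial y_0}{\partial x_2} = 9(2)(3)^2 = 162$

$\frac{\partial y_1}{\partial x_0} = 9(1)^2(2) = 18, \quad \frac{\partial y_1}{\partial x_1} = 3(1)^3 + 4(2) = 11, \quad \frac{\partial y_1}{\partial x_2} = 9(3)^2 = 81$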

In [3]:
// example problem

fun f(x: DTensor) : DTensor {
    
    val y0  = 2f * x[0] + 2f * x[1].pow(2f) + 3f * x[1] * x[2].pow(3f) 
    val y1  = 3f * x[0].pow(3f) * x[1] + 2f * x[1].pow(2f) + 3f * x[2].pow(3f)
  
    return tensorOf(y0 as DScalar, y1 as DScalar) 
}

val x = tensorOf(1f, 2f, 3f)

// calculate the Jacobian by the forward derivative algorithm
val (forwardFx, forwardJacobian) = primalAndForwardDerivative(x, ::f)

// calculate the Jacobian by the reverse derivative algorithm
val (reverseFx, reverseJacobian) = primalAndReverseDerivative(x, ::f)

println("x = ${x}")
println("")
println("forward f(x) = ${forwardFx}")
println("forward Jacobian(f(x)) = ${forwardJacobian}")
println("")
println("reverse f(x) = ${reverseFx}")
println("reverse Jacobian(f(x)) = ${reverseJacobian}")

x = [1.0, 2.0, 3.0]

forward f(x) = [172.0, 95.0]
forward Jacobian(f(x)) = [[2.0, 89.0, 162.0], [18.0, 11.0, 81.0]]

reverse f(x) = [172.0, 95.0]
reverse Jacobian(f(x)) = [[2.0, 18.0], [89.0, 11.0], [162.0, 81.0]]


Both the forward derivative algorithm and reverse derivative algorithm calculate the same value for `f(x)`.
The reverse derivative algorithm calculates the transpose of the Jacobian.
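
To make the shape difference concrete without relying on DiffKt, here is a minimal plain-Kotlin sketch that hard-codes the analytic partial derivatives of the example problem and then transposes the result. The names `exampleJacobian`, `transposeOf`, `forwardShaped`, and `reverseShaped` are illustrative only and are not part of the DiffKt API.

```kotlin
// Plain Kotlin sketch (no DiffKt): the analytic Jacobian of the example problem.
// Rows are outputs (y0, y1) and columns are inputs (x0, x1, x2): an M x N = 2 x 3 matrix.
fun exampleJacobian(x0: Double, x1: Double, x2: Double): List<List<Double>> = listOf(
    listOf(2.0, 4 * x1 + 3 * x2 * x2 * x2, 9 * x1 * x2 * x2),
    listOf(9 * x0 * x0 * x1, 3 * x0 * x0 * x0 + 4 * x1, 9 * x2 * x2)
)

// Swap rows and columns: an M x N matrix becomes an N x M matrix.
fun transposeOf(m: List<List<Double>>): List<List<Double>> =
    m[0].indices.map { col -> m.map { row -> row[col] } }

val forwardShaped = exampleJacobian(1.0, 2.0, 3.0)   // same layout as the forward Jacobian
val reverseShaped = transposeOf(forwardShaped)       // same layout as the reverse Jacobian
```

With $\mathbf x = [1, 2, 3]$, `forwardShaped` is `[[2.0, 89.0, 162.0], [18.0, 11.0, 81.0]]` and `reverseShaped` is `[[2.0, 18.0], [89.0, 11.0], [162.0, 81.0]]`, which matches the two Jacobians printed above.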

## User Defined Types

When you create a __[DTensor](http://www.diffkt.org/api/api/org.diffkt/-d-tensor/index.html)__ or __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__ variable, internally it has an implementation of a function called `wrap()`, which is invoked during differentiation operations. The internal representation is used for both the calculation of the user defined function and the calculation of its derivative. Alternatively, you can create your own user defined type. A user defined type could be a class with __[DTensor](http://www.diffkt.org/api/api/org.diffkt/-d-tensor/index.html)__ or __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__ variables, a list of __[DTensor](http://www.diffkt.org/api/api/org.diffkt/-d-tensor/index.html)__ or __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__ variables, or an even more complex type. When defining a user created type, you have to implement the `wrap()` function as part of the type. There are a couple of ways to implement the `wrap()` function and have it called, which are discussed below.

The advantage of the user defined type is that one has named-member access of a class instead of placing all the variables in an array or tensor and having to use indexing to access the variables.

The purpose of __[primalAndForwardDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivative.html)__ and __[primalAndReverseDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__ is to calculate the derivatives of user defined types. The functions take a user defined input type, a user defined output type, and a user defined derivative type. In addition, the user defines a function for the calculations and a function to extract the derivatives from the calculations and place the results into the user defined derivative type. The lambdas `wrapInput` and `wrapOutput` might also need to be defined so that the `wrap()` function gets called internally. Note that these functions share their names with the tensor versions used in the example problem above; the versions for user defined types additionally take an `extractDerivative` argument.

__[primalAndForwardDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivative.html)__ and __[primalAndReverseDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__ have essentially the same function signature. The function signatures are:

`fun <Input : Any, Output : Any, Derivative : Any>`__[primalAndForwardDerivative](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivative.html)__`(
    x: Input,
    f: (Input) -> Output,
    wrapInput: ((Input, Wrapper) -> Input)? = null,
    wrapOutput: ((Output, Wrapper) -> Output)? = null,
    extractDerivative: (Input, Output, (input: DTensor, output: DTensor) -> DTensor) -> Derivative,
): Pair<Output, Derivative>`

and

`fun <Input : Any, Output : Any, Derivative : Any>`__[primalAndReverseDerivative](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__`(
    x: Input,
    f: (Input) -> Output,
    wrapInput: ((Input, Wrapper) -> Input)? = null,
    wrapOutput: ((Output, Wrapper) -> Output)? = null,
    extractDerivative: (Input, Output, (input: DTensor, output: DTensor) -> DTensor) -> Derivative,
): Pair<Output, Derivative>`

The types for `Input`, `Output`, and `Derivative` are user defined. A user defined type could be a class with __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__ or __[DTensor](http://www.diffkt.org/api/api/org.diffkt/-d-tensor/index.html)__ variables, a list with __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__ or __[DTensor](http://www.diffkt.org/api/api/org.diffkt/-d-tensor/index.html)__ elements, or something more complex.

The function `f: (Input) -> Output` has to know how to access the variables in the `Input` type and produce a return of `Output` type.

The `Derivative` type has to define all the possible derivatives that can be produced from taking the derivative of `f()` with respect to the `Input` type.

The `Input` or `Output` types can inherit from the `Differentiable<T>` interface, which knows how to call the `wrap()` function. If the `Input` and `Output` types do not inherit from the `Differentiable<T>` interface, then a lambda expression needs to be written for the `wrapInput` and/or the `wrapOutput` parameters to call `wrap()` for the `Input` or `Output` type.
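
To give some intuition for why a differentiation routine needs a `wrap()` hook on a user defined type, the following is a toy model in plain Kotlin. It assumes forward-mode differentiation with dual numbers, which is one common technique; the notebook does not show DiffKt's internals, and they may work differently. Every name here (`Dual`, `ToyWrapper`, `ToyDifferentiable`, `ToyInput`, `toyDerivative`) is hypothetical and is not part of DiffKt.

```kotlin
// Toy model only: none of these types are DiffKt types.

// A dual number: a value together with a tangent (derivative) component.
data class Dual(val value: Double, val tangent: Double = 0.0) {
    operator fun plus(other: Dual) = Dual(value + other.value, tangent + other.tangent)
    operator fun times(other: Dual) =
        Dual(value * other.value, tangent * other.value + value * other.tangent)
}

// Hypothetical stand-in for a wrapper: decides how each scalar leaf is replaced.
fun interface ToyWrapper {
    fun wrap(leaf: Dual): Dual
}

// Hypothetical stand-in for Differentiable<T>: the type knows how to rebuild itself
// with every scalar leaf passed through the wrapper.
interface ToyDifferentiable<T> {
    fun wrap(wrapper: ToyWrapper): T
}

class ToyInput(val a: Dual, val b: Dual) : ToyDifferentiable<ToyInput> {
    override fun wrap(wrapper: ToyWrapper): ToyInput =
        ToyInput(wrapper.wrap(a), wrapper.wrap(b))
}

// Differentiate g with respect to one leaf by seeding that leaf's tangent with 1.
fun toyDerivative(x: ToyInput, target: Dual, g: (ToyInput) -> Dual): Double {
    val seeded = x.wrap { leaf -> if (leaf === target) leaf.copy(tangent = 1.0) else leaf }
    return g(seeded).tangent
}

val toyX = ToyInput(Dual(2.0), Dual(3.0))
val toyG: (ToyInput) -> Dual = { it.a * it.b + it.a * Dual(2.0) }   // g = a*b + 2a
val dgDa = toyDerivative(toyX, toyX.a, toyG)   // analytically b + 2
val dgDb = toyDerivative(toyX, toyX.b, toyG)   // analytically a
```

In this sketch the routine never needs to know the fields of `ToyInput`; it only asks the object to rebuild itself with each leaf replaced. At $a = 2$, $b = 3$ the results are `dgDa = 5.0` and `dgDb = 2.0`.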

## Defining a Class as a User Defined Type

The example problem was represented with the input being a 1D vector with three variables and the output being a 1D vector with two variables. In this example we will define three user defined classes for `Input`, `Output`, and `Derivative`.

### Input Class

We will define an `Input` class that is constructed with three __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__ variables: `x0`, `x1`, and `x2`. Now each variable is a named member of the class and does not have to be accessed by indexing an array. The class `Input` inherits from the `Differentiable<T>` interface. Since the class inherits from the `Differentiable<T>` interface it will override the function `wrap()`, which is defined in the `Differentiable<T>` interface. When `wrap()` is called, a new `Input` instance is created with `x0`, `x1`, and `x2` wrapped.

In [4]:
// Input class

class Input(val x0: DScalar, val x1: DScalar, val x2: DScalar) : Differentiable<Input> {
    override fun wrap(wrapper: Wrapper): Input {
        return Input(wrapper.wrap(x0), wrapper.wrap(x1), wrapper.wrap(x2))
    }
}

### Output Class

We will define an `Output` class that is constructed with two __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__ variables: `y0` and `y1`. Now each variable is a named member of the class and does not have to be accessed by indexing an array. The class `Output` inherits from the `Differentiable<T>` interface. Since the class inherits from the `Differentiable<T>` interface it will override the function `wrap()`, which is defined in the `Differentiable<T>` interface. Each variable in the class is wrapped by calling `wrapper.wrap()` as it is passed to the constructor of the `Output` class.

In [5]:
// Output class

class Output(val y0: DScalar, val y1: DScalar) : Differentiable<Output> {
    override fun wrap(wrapper: Wrapper): Output {
        return Output(wrapper.wrap(y0), wrapper.wrap(y1))
    }
}

### Derivative Class

A user defined class is specified to hold the derivatives. Since we have 2 output functions and 3 variables, there are 6 derivatives. 

`dy0Dx0` = $ \frac{\partial y_0(x_0, x_1, x_2)}{\partial x_0} $

`dy0Dx1` = $ \frac{\partial y_0(x_0, x_1, x_2)}{\partial x_1} $

`dy0Dx2` = $ \frac{\partial y_0(x_0, x_1, x_2)}{\partial x_2} $

`dy1Dx0` = $ \frac{\partial y_1(x_0, x_1, x_2)}{\partial x_0} $

`dy1Dx1` = $ \frac{\partial y_1(x_0, x_1, x_2)}{\partial x_1} $

`dy1Dx2` = $ \frac{\partial y_1(x_0, x_1, x_2)}{\partial x_2} $
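
Arranged as a matrix, these six members are the entries of the Jacobian from the example problem, with one row per output and one column per input:

$ \mathbf \nabla \mathbf f(\mathbf x) = \begin{bmatrix} \texttt{dy0Dx0} & \texttt{dy0Dx1} & \texttt{dy0Dx2} \\ \texttt{dy1Dx0} & \texttt{dy1Dx1} & \texttt{dy1Dx2} \end{bmatrix} = \begin{bmatrix} 2 & 4 x_1 + 3x_2^3 & 9x_1x_2^2 \\ 9x_0^2x_1 & 3x_0^3 + 4x_1 & 9x_2^2 \end{bmatrix} $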

In [6]:
// Derivative class

class Derivative(val dy0Dx0: DScalar, val dy0Dx1: DScalar, val dy0Dx2: DScalar, 
                 val dy1Dx0: DScalar, val dy1Dx1: DScalar, val dy1Dx2: DScalar) 
 


### Extracting the Derivatives

We use the function `makeDerivative()` to extract the derivatives from the calculations. Its `extractDerivative` parameter is a function supplied by DiffKt that knows how to extract a derivative from the calculations. We pass `extractDerivative()` an individual input variable and an individual output variable, and it returns the derivative of that output with respect to that input.

In [7]:
// extracting the derivatives

fun makeDerivative(input: Input, output: Output, extractDerivative: (DTensor, DTensor) -> DTensor): Derivative {
          
    val x0 = input.x0
    val x1 = input.x1
    val x2 = input.x2
    val y0 = output.y0
    val y1 = output.y1
            
    val dy0Dx0 = extractDerivative(x0, y0) as DScalar
    val dy0Dx1 = extractDerivative(x1, y0) as DScalar
    val dy0Dx2 = extractDerivative(x2, y0) as DScalar
    val dy1Dx0 = extractDerivative(x0, y1) as DScalar
    val dy1Dx1 = extractDerivative(x1, y1) as DScalar
    val dy1Dx2 = extractDerivative(x2, y1) as DScalar
             
    return Derivative(dy0Dx0, dy0Dx1, dy0Dx2, dy1Dx0, dy1Dx1, dy1Dx2)
}

### User Defined Function

The function `f()` is the user defined function for the calculation presented in the example. The input and output tensors have been replaced by the `Input` and `Output` classes.

In [8]:
// user defined function

fun f(input: Input): Output {
        
    val x0 = input.x0
    val x1 = input.x1
    val x2 = input.x2
        
    val y0  = 2f * x0 + 2f * x1.pow(2f) + 3f * x1 * x2.pow(3f) 
    val y1  = 3f * x0.pow(3f) * x1 + 2f * x1.pow(2f) + 3f * x2.pow(3f)
        
    return Output(y0,y1)
}


### Initialization of Input

In [9]:
// Initialization of x

val x = Input(FloatScalar(1f), FloatScalar(2f), FloatScalar(3f))

### Calling `primalAndForwardDerivative()`

In [10]:
// calculate derivatives with the forward algorithm

val (p,d) = primalAndForwardDerivative(
    x = x,
    f = ::f,
    extractDerivative = ::makeDerivative)

### Output

Compare the output of __[primalAndForwardDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivatives.html)__ with the output of the example problem. The values are exactly the same, but they are held in user defined types instead of tensors.

In [11]:
// print the output

println("y0 = ${p.y0}")
println("y1 = ${p.y1}")
println("dy0Dx0 = ${d.dy0Dx0}")
println("dy0Dx1 = ${d.dy0Dx1}")
println("dy0Dx2 = ${d.dy0Dx2}")
println("dy1Dx0 = ${d.dy1Dx0}")
println("dy1Dx1 = ${d.dy1Dx1}")
println("dy1Dx2 = ${d.dy1Dx2}")


y0 = 172.0
y1 = 95.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 11.0
dy1Dx2 = 81.0


### Calling `primalAndReverseDerivative()`

In [12]:
// calculate the derivatives with the reverse algorithm

val (pr,dr) = primalAndReverseDerivative(
    x = x,
    f = ::f,
    extractDerivative = ::makeDerivative)

### Output

The output of __[primalAndReverseDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__ is the same as the output of __[primalAndForwardDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivative.html)__. 

In [13]:
// print the output

println("y0 = ${pr.y0}")
println("y1 = ${pr.y1}")
println("dy0Dx0 = ${dr.dy0Dx0}")
println("dy0Dx1 = ${dr.dy0Dx1}")
println("dy0Dx2 = ${dr.dy0Dx2}")
println("dy1Dx0 = ${dr.dy1Dx0}")
println("dy1Dx1 = ${dr.dy1Dx1}")
println("dy1Dx2 = ${dr.dy1Dx2}")


y0 = 172.0
y1 = 95.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 11.0
dy1Dx2 = 81.0


## User Defined Types That Don't Inherit From `Differentiable<T>`

In this section we will look at the same example, but the user defined types will not inherit from the `Differentiable<T>` interface. In this case, we have to pass two additional parameters, `wrapInput` and `wrapOutput`, to __[primalAndForwardDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivative.html)__ and __[primalAndReverseDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__. These parameters are lambdas that call the `wrap()` functions on the objects.

### Input Class

The `Input` class does not inherit from the `Differentiable<T>` interface, so `wrap()` is defined as a function of the class and is not overriding a function in the `Differentiable<T>` interface.

In [14]:
// Input class

class Input(val x0: DScalar, val x1: DScalar, val x2: DScalar) {
    fun wrap(wrapper: Wrapper): Input {
        return Input(wrapper.wrap(x0), wrapper.wrap(x1), wrapper.wrap(x2))
    }
}

### Output Class

The same applies to the `Output` class: `wrap()` is an ordinary member function rather than an override.

In [15]:
// Output Class

class Output(val y0: DScalar, val y1: DScalar) {
    fun wrap(wrapper: Wrapper): Output {
        return Output(wrapper.wrap(y0), wrapper.wrap(y1))
    }
}

### Derivative Class

No changes for the `Derivative` class.

In [16]:
// Derivative class

class Derivative(val dy0Dx0: DScalar, val dy0Dx1: DScalar, val dy0Dx2: DScalar, 
                 val dy1Dx0: DScalar, val dy1Dx1: DScalar, val dy1Dx2: DScalar) 
 

### Extracting the Derivatives

The `makeDerivative()` function is the same.

In [17]:
// extracting the derivatives

fun makeDerivative(input: Input, output: Output, extractDerivative: (DTensor, DTensor) -> DTensor): Derivative {
          
    val x0 = input.x0
    val x1 = input.x1
    val x2 = input.x2
    val y0 = output.y0
    val y1 = output.y1
            
    val dy0Dx0 = extractDerivative(x0, y0) as DScalar
    val dy0Dx1 = extractDerivative(x1, y0) as DScalar
    val dy0Dx2 = extractDerivative(x2, y0) as DScalar
    val dy1Dx0 = extractDerivative(x0, y1) as DScalar
    val dy1Dx1 = extractDerivative(x1, y1) as DScalar
    val dy1Dx2 = extractDerivative(x2, y1) as DScalar
       
    return Derivative(dy0Dx0, dy0Dx1, dy0Dx2, dy1Dx0, dy1Dx1, dy1Dx2)
}

### User Defined Function

The function `f()` for the calculation is the same.

In [18]:
// user defined function

fun f(input: Input): Output {
        
    val x0 = input.x0
    val x1 = input.x1
    val x2 = input.x2
        
    val y0  = 2f * x0 + 2f * x1.pow(2f) + 3f * x1 * x2.pow(3f) 
    val y1  = 3f * x0.pow(3f) * x1 + 2f * x1.pow(2f) + 3f * x2.pow(3f)
 
    return Output(y0,y1)
}


### Initialization of the Input

In [19]:
// Initialization of x

val x = Input(FloatScalar(1f), FloatScalar(2f), FloatScalar(3f))

### Calling `primalAndForwardDerivative()`

Since the `Input` and `Output` classes do not inherit from the `Differentiable<T>` interface, we need a way for `wrap()` to be called. The parameters `wrapInput` and `wrapOutput` are assigned lambdas that call the `wrap()` function in the `Input` and `Output` classes.

In [20]:
// Calculate the derivatives with the forward algorithm

val (p,d) = primalAndForwardDerivative(
    x = x,
    f = ::f,
    wrapInput = { i, w -> i.wrap(w) },
    wrapOutput = { o, w -> o.wrap(w) },
    extractDerivative = ::makeDerivative)

### Output

The output of __[primalAndForwardDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivatives.html)__ is the same as the previous example.

In [21]:
// print the output

println("y0 = ${p.y0}")
println("y1 = ${p.y1}")
println("dy0Dx0 = ${d.dy0Dx0}")
println("dy0Dx1 = ${d.dy0Dx1}")
println("dy0Dx2 = ${d.dy0Dx2}")
println("dy1Dx0 = ${d.dy1Dx0}")
println("dy1Dx1 = ${d.dy1Dx1}")
println("dy1Dx2 = ${d.dy1Dx2}")

y0 = 172.0
y1 = 95.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 11.0
dy1Dx2 = 81.0


### Calling `primalAndReverseDerivative()`

In [22]:
// Calculate the derivatives with the reverse algorithm

val (pr,dr) = primalAndReverseDerivative(
    x = x,
    f = ::f,
    wrapInput = { i, w -> i.wrap(w) },
    wrapOutput = { o, w -> o.wrap(w) },
    extractDerivative = ::makeDerivative)

### Output

The output of __[primalAndReverseDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__ is the same as the previous example.

In [23]:
// print the output

println("y0 = ${pr.y0}")
println("y1 = ${pr.y1}")
println("dy0Dx0 = ${dr.dy0Dx0}")
println("dy0Dx1 = ${dr.dy0Dx1}")
println("dy0Dx2 = ${dr.dy0Dx2}")
println("dy1Dx0 = ${dr.dy1Dx0}")
println("dy1Dx1 = ${dr.dy1Dx1}")
println("dy1Dx2 = ${dr.dy1Dx2}")

y0 = 172.0
y1 = 95.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 11.0
dy1Dx2 = 81.0


## A List of Scalars

In this example a list of __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__ is used as the input and output. Internally, a list is already represented as a wrapped data type and has the machinery to call the `wrap()` function. In this example there are no `Input`, `Output`, or `Derivative` types. 

### Extracting the Derivatives

In this example the input variables are variables in a list:

$x_0$ = `input[0]`

$x_1$ = `input[1]`

$x_2$ = `input[2]`

The output variables are variables in a list:

$y_0$ = `output[0]`

$y_1$ = `output[1]`

The extracted derivatives are variables in a list:

`dy0Dx0` = `d[0]`

`dy0Dx1` = `d[1]`

`dy0Dx2` = `d[2]`

`dy1Dx0` = `d[3]`

`dy1Dx1` = `d[4]`

`dy1Dx2` = `d[5]`

In [24]:
// extracting the derivatives

fun makeDerivative(input: List<DScalar>, output: List<DScalar>, extractDerivative: (DTensor, DTensor) -> DTensor): List<DScalar> {
       
    val x0 = input[0]
    val x1 = input[1]
    val x2 = input[2]
    val y0 = output[0]
    val y1 = output[1]
        
    val dy0Dx0 = extractDerivative(x0, y0) as DScalar
    val dy0Dx1 = extractDerivative(x1, y0) as DScalar
    val dy0Dx2 = extractDerivative(x2, y0) as DScalar
    val dy1Dx0 = extractDerivative(x0, y1) as DScalar
    val dy1Dx1 = extractDerivative(x1, y1) as DScalar
    val dy1Dx2 = extractDerivative(x2, y1) as DScalar     
    return listOf(dy0Dx0, dy0Dx1, dy0Dx2, dy1Dx0, dy1Dx1, dy1Dx2)
}
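
The ordering of the returned list can also be expressed with nested loops instead of six named variables. The following is a generic, DiffKt-free sketch of that ordering; `flattenDerivatives` and `flatIndex` are hypothetical helper names, `extract` stands in for `extractDerivative`, and the type parameter `S` stands in for the scalar type. Strings are used as placeholders so the resulting order is visible.

```kotlin
// Generic sketch of the row-major ordering used by makeDerivative() above.
// For each output y (outer loop) and each input x (inner loop), request d(y)/d(x).
fun <S> flattenDerivatives(inputs: List<S>, outputs: List<S>, extract: (S, S) -> S): List<S> =
    outputs.flatMap { y -> inputs.map { x -> extract(x, y) } }

// Position of d(y_i)/d(x_j) in the flattened list when there are nInputs inputs.
fun flatIndex(i: Int, j: Int, nInputs: Int): Int = i * nInputs + j

// Placeholder strings instead of scalars, to show which derivative lands where.
val derivativeNames = flattenDerivatives(
    inputs = listOf("x0", "x1", "x2"),
    outputs = listOf("y0", "y1"),
    extract = { x, y -> "d${y}D${x}" }
)
```

Here `derivativeNames` is `[dy0Dx0, dy0Dx1, dy0Dx2, dy1Dx0, dy1Dx1, dy1Dx2]`, the same order as `d[0]` through `d[5]` listed above, and `flatIndex(1, 2, 3)` is `5`, the position of `dy1Dx2`.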

### User Defined Function

The user defined function `f()` is the same except the variables are now in a list.

In [25]:
// user defined function

fun f(input: List<DScalar>): List<DScalar> {
        
    val x0 = input[0]
    val x1 = input[1]
    val x2 = input[2]
        
    val y0  = 2f * x0 + 2f * x1.pow(2f) + 3f * x1 * x2.pow(3f) 
    val y1  = 3f * x0.pow(3f) * x1 + 2f * x1.pow(2f) + 3f * x2.pow(3f)
    return listOf(y0,y1)
}


### Initialization of X

The initial values for the input variables are placed in a list.

In [26]:
// initialization of x

val x = listOf(FloatScalar(1f), FloatScalar(2f), FloatScalar(3f))


### Calling `primalAndForwardDerivative()`

In [27]:
// Calculate the derivatives with the forward algorithm

val (p,d) = primalAndForwardDerivative(
    x = x,
    f = ::f,
    extractDerivative = ::makeDerivative)

### Output

The output is the same, except that the primal `p` and the derivatives `d` are lists.

In [28]:
// print the output

println("y0 = ${p[0]}")
println("y1 = ${p[1]}")
println("dy0Dx0 = ${d[0]}")
println("dy0Dx1 = ${d[1]}")
println("dy0Dx2 = ${d[2]}")
println("dy1Dx0 = ${d[3]}")
println("dy1Dx1 = ${d[4]}")
println("dy1Dx2 = ${d[5]}")

y0 = 172.0
y1 = 95.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 11.0
dy1Dx2 = 81.0


### Calling `primalAndReverseDerivative()`

In [29]:
// Calculate the derivatives with the reverse algorithm

val (pr,dr) = primalAndReverseDerivative(
    x = x,
    f = ::f,
    extractDerivative = ::makeDerivative)

### Output

The output is the same as for the forward algorithm.

In [30]:
// print the output

println("y0 = ${pr[0]}")
println("y1 = ${pr[1]}")
println("dy0Dx0 = ${dr[0]}")
println("dy0Dx1 = ${dr[1]}")
println("dy0Dx2 = ${dr[2]}")
println("dy1Dx0 = ${dr[3]}")
println("dy1Dx1 = ${dr[4]}")
println("dy1Dx2 = ${dr[5]}")

y0 = 172.0
y1 = 95.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 11.0
dy1Dx2 = 81.0


## Complex Data Structures

This example uses a variety of data structures and mixes the way the wrapper is invoked. 
1. It has a class, `MyVariable`, that holds an individual `DScalar` variable and inherits from the `Differentiable<T>` interface. `MyVariable` has a secondary constructor that takes a `Float` and converts it to a `FloatScalar`.
2. It has a class, `Input`, that holds a `List` of `MyVariable`. The class `Input` does not inherit from `Differentiable<T>`. `List` is a registered class internally in the code and implements the `wrap()` function internally.
3. It has a class, `Output`, that uses a `Pair` to hold the output variables $y_0$ and $y_1$. `Output` does not inherit from `Differentiable<T>`. `Pair` is a registered class internally in the code and implements the `wrap()` function internally.
4. This example shows that you can call either __[primalAndForwardDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivative.html)__ or __[primalAndReverseDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__ with a mixture of data types that use multiple methods for implementing the `wrap()` function.

### MyVariable

`MyVariable` is a class to hold an individual input variable as a __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__. The class inherits from the `Differentiable<T>` interface and overrides `wrap()`.

In [31]:
// MyVariable class

class MyVariable(val v: DScalar) : Differentiable<MyVariable> {
    
    constructor( f: Float) : this(FloatScalar(f))
    
    override fun wrap(wrapper: Wrapper): MyVariable {
        return MyVariable(wrapper.wrap(v))
    }
}

### Input

The `Input` class takes a `List` of `MyVariable` as the input. The `Input` class does not implement the `Differentiable<T>` interface. Notice the hierarchy of the data structure. You have a class -> list -> class -> __[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__. In this data structure there is a mixture of the implementations of the `wrap()` function: inheriting from `Differentiable<T>`, not inheriting from `Differentiable<T>`, and using registration.

In [32]:
// Input class

class Input(val i : List<MyVariable>) {
    
    fun wrap(wrapper: Wrapper): Input {
        return Input(wrapper.wrap(i))
    }
    
    fun getX(index : Int) : DScalar {
        return i[index].v
    }
}

### Output

The `Output` class uses a `Pair<`__[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__`,`__[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__`>` to hold the output variables $y_0$ and $y_1$. `Pair<`__[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__`,`__[DScalar](http://www.diffkt.org/api/api/org.diffkt/-d-scalar/index.html)__`>` does not have to be wrapped because it is registered in the internal code. The `Output` class does not inherit from the `Differentiable<T>` interface.

In [33]:
// Output class

class Output(val o : Pair<DScalar, DScalar>) {
    
    fun wrap(wrapper: Wrapper): Output {
        return Output(Pair(wrapper.wrap(o.first), wrapper.wrap(o.second)))
    }    
    
}

### Derivative

The `Derivative` class is the same as the previous example.

In [34]:
// Derivative class

class Derivative(val dy0Dx0: DScalar, val dy0Dx1: DScalar, val dy0Dx2: DScalar, 
                 val dy1Dx0: DScalar, val dy1Dx1: DScalar, val dy1Dx2: DScalar) 

### Extracting the Derivatives

The `makeDerivative()` function is adapted for the new `Input` and `Output` data structures. A helper function, `getX()`, was implemented in the `Input` class to index the input variables.

In [35]:
// extracting the derivatives

fun makeDerivative(input: Input, output: Output, 
                    extractDerivative: (DTensor, DTensor) -> DTensor): Derivative {
        
    val x0 = input.getX(0) 
    val x1 = input.getX(1) 
    val x2 = input.getX(2) 
    val y0 = output.o.first 
    val y1 = output.o.second
    
    val dy0Dx0 = extractDerivative(x0, y0) as DScalar
    val dy0Dx1 = extractDerivative(x1, y0) as DScalar
    val dy0Dx2 = extractDerivative(x2, y0) as DScalar
    val dy1Dx0 = extractDerivative(x0, y1) as DScalar
    val dy1Dx1 = extractDerivative(x1, y1) as DScalar
    val dy1Dx2 = extractDerivative(x2, y1) as DScalar
    
    return Derivative(dy0Dx0, dy0Dx1, dy0Dx2, dy1Dx0, dy1Dx1, dy1Dx2)                      
}


### User Defined Function

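The user defined function `f()` follows the previous examples, adapted for the new data structures. Note that the middle term of `y1` in this cell is written as `2f * x2.pow(2f)` rather than `2f * x1.pow(2f)`, so `y1`, `dy1Dx1`, and `dy1Dx2` differ from the values in the earlier examples.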

In [36]:
// user defined function

fun f(input: Input): Output {
        
    val x0 = input.getX(0)
    val x1 = input.getX(1)
    val x2 = input.getX(2)
        
    val y0  = 2f * x0 + 2f * x1.pow(2f) + 3f * x1 * x2.pow(3f) 
    val y1  = 3f * x0.pow(3f) * x1 + 2f * x2.pow(2f) + 3f * x2.pow(3f)
     
    return Output(Pair(y0, y1))
}

### Initialization of X

In [37]:
// Initialization of x

val l = listOf(MyVariable(1f),MyVariable(2f), MyVariable(3f))
val x = Input(l)

### Calling `primalAndForwardDerivative()`

Since parts of the `Input` and `Output` data structures do not inherit from the `Differentiable<T>` interface, we need a way for `wrap()` to be called. The parameter `wrapInput` is assigned a lambda that calls the `wrap()` function for `Input`. The parameter `wrapOutput` is assigned a lambda that calls the `wrap()` function for `Output`.  __[primalAndForwardDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivatives.html)__ works with a mixture of wrapping methods for the data types in a complex data structure.

In [38]:
// Calculate the derivatives with the forward algorithm

val (p,d) = primalAndForwardDerivative(
    x = x,
    f = ::f,
    wrapInput = {i, w -> i.wrap(w)},
    wrapOutput = {o, w -> o.wrap(w)},
    extractDerivative = ::makeDerivative)

### Output

In [39]:
// print the output

println("y0 = ${p.o.first}")
println("y1 = ${p.o.second}")
println("dy0Dx0 = ${d.dy0Dx0}")
println("dy0Dx1 = ${d.dy0Dx1}")
println("dy0Dx2 = ${d.dy0Dx2}")
println("dy1Dx0 = ${d.dy1Dx0}")
println("dy1Dx1 = ${d.dy1Dx1}")
println("dy1Dx2 = ${d.dy1Dx2}")

y0 = 172.0
y1 = 105.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 3.0
dy1Dx2 = 93.0
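
The values of `y1`, `dy1Dx1`, and `dy1Dx2` printed above differ from the earlier examples because the middle term of `y1` in this section's `f()` is `2f * x2.pow(2f)`. As a quick check, with that term the function and its partial derivatives at $\mathbf x = [1, 2, 3]$ are:

$y_1 = 3 x_0^3 x_1 + 2 x_2^2 + 3 x_2^3 = 6 + 18 + 81 = 105$

$\frac{\partial y_1}{\partial x_0} = 9 x_0^2 x_1 = 18$

$\frac{\partial y_1}{\partial x_1} = 3 x_0^3 = 3$

$\frac{\partial y_1}{\partial x_2} = 4 x_2 + 9 x_2^2 = 12 + 81 = 93$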


### Calling `primalAndReverseDerivative()`

This works in the same manner as the forward algorithm.

In [40]:
// Calculate the derivatives with the reverse algorithm

val (pr,dr) = primalAndReverseDerivative(
    x = x,
    f = ::f,
    wrapInput = {i, w -> i.wrap(w)},
    wrapOutput = {o, w -> o.wrap(w)},
    extractDerivative = ::makeDerivative)

### Output

In [41]:
// print the output

println("y0 = ${pr.o.first}")
println("y1 = ${pr.o.second}")
println("dy0Dx0 = ${dr.dy0Dx0}")
println("dy0Dx1 = ${dr.dy0Dx1}")
println("dy0Dx2 = ${dr.dy0Dx2}")
println("dy1Dx0 = ${dr.dy1Dx0}")
println("dy1Dx1 = ${dr.dy1Dx1}")
println("dy1Dx2 = ${dr.dy1Dx2}")

y0 = 172.0
y1 = 105.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 3.0
dy1Dx2 = 93.0


## Using Registration

Registration allows you to register a class with __[Wrapper()](http://www.diffkt.org/api/api/org.diffkt/-wrapper/index.html)__ and then use the class without inheriting from `Differentiable<T>` or passing `wrapInput` or `wrapOutput` in every call to __[primalAndForwardDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-forward-derivative.html)__ or __[primalAndReverseDerivative()](http://www.diffkt.org/api/api/org.diffkt/primal-and-reverse-derivative.html)__. Registration is useful for user defined types that are not the `Input` or `Output` classes.

### MyMap

`MyMap` is a container that holds a `Map`, such that the values are wrapped but the keys are not. `MyMap` does not inherit from `Differentiable<T>`. `MyMap` is used as a data structure in both the `Input` and `Output` classes. Because it neither inherits from `Differentiable<T>` nor is an input or output class (which could use `wrapInput` or `wrapOutput`), it needs another way to get its `wrap()` function called. Registration provides that.

In [42]:
// MyMap class

class MyMap(val myMap: Map<String, DScalar>) {
 
    operator fun get(key: String) : DScalar {
        
        return myMap[key] ?: FloatScalar(0f)
    }
    
    fun wrap(wrapper: Wrapper) : MyMap {
        return MyMap(myMap.mapValues {wrapper.wrap(it.value)})
    }
      
}

### Register MyMap

Register `MyMap` with the `Wrapper` class. Once the class is registered with `Wrapper`, it can be used anywhere in the code. The lambda used for the wrapping has the same form as the lambdas used with `wrapInput` and `wrapOutput`.

In [43]:
Wrapper.register(MyMap::class, {m, w -> m.wrap(w)})

### Input

The `Input` class inherits from the `Differentiable<T>` interface. The `wrap()` function is overridden so that it returns an `Input`. `Input` is initialized with a variable, `inputMap`, of type `MyMap`. Because `MyMap` is registered, `wrapper.wrap(inputMap)` can wrap it.

In [44]:
// Input class

class Input(val inputMap: MyMap) : Differentiable<Input> {
    
    override fun wrap(wrapper: Wrapper) : Input {
        return Input(wrapper.wrap(inputMap))
    }
      
}

### Output

The `Output` class does not inherit from the `Differentiable<T>` interface. We will use the `wrapOutput` parameter instead.

In [45]:
// Output class

class Output(val outputMap: MyMap) {
    
    fun wrap(wrapper: Wrapper) : Output {
        return Output(wrapper.wrap(outputMap))
    }
   
}

### Derivative

The `Derivative` class is the same as in the previous examples.

In [46]:
// Derivative class

class Derivative(val dy0Dx0: DScalar, val dy0Dx1: DScalar, val dy0Dx2: DScalar, 
                 val dy1Dx0: DScalar, val dy1Dx1: DScalar, val dy1Dx2: DScalar) 

### Extracting the Derivatives

The `makeDerivative()` function is adapted to use a map to hold the input and output variables.

In [47]:
// extracting the derivatives

fun makeDerivative(input: Input, output: Output, 
                    extractDerivative: (DTensor, DTensor) -> DTensor): Derivative {
    
    val x0 = input.inputMap["x0"]
    val x1 = input.inputMap["x1"]
    val x2 = input.inputMap["x2"]
    val y0 = output.outputMap["y0"]
    val y1 = output.outputMap["y1"]
        
    val dy0Dx0 = extractDerivative(x0, y0) as DScalar
    val dy0Dx1 = extractDerivative(x1, y0) as DScalar
    val dy0Dx2 = extractDerivative(x2, y0) as DScalar
    val dy1Dx0 = extractDerivative(x0, y1) as DScalar
    val dy1Dx1 = extractDerivative(x1, y1) as DScalar
    val dy1Dx2 = extractDerivative(x2, y1) as DScalar

    return Derivative(dy0Dx0, dy0Dx1, dy0Dx2, dy1Dx0, dy1Dx1, dy1Dx2)                      
}

### User Defined Function

The user defined function `f()` is adapted for using a map to hold the input and output variables.

In [48]:
// user defined function

fun f(input: Input): Output {
        
    val x0 = input.inputMap["x0"]
    val x1 = input.inputMap["x1"]
    val x2 = input.inputMap["x2"]
        
    val y0  = 2f * x0 + 2f * x1.pow(2f) + 3f * x1 * x2.pow(3f) 
    val y1  = 3f * x0.pow(3f) * x1 + 2f * x1.pow(2f) + 3f * x2.pow(3f)
        
    return Output(MyMap(mapOf("y0" to y0, "y1" to y1)))
}
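As a cross-check on `f()`, its partial derivatives can be worked out by hand with the power rule and evaluated in plain Kotlin. The sketch below does not use **DiffKt**; `y0Plain`, `y1Plain`, and `handJacobian` are names made up for this illustration. It evaluates the function and its hand-derived Jacobian at the point x0 = 1, x1 = 2, x2 = 3 used in this section.

```kotlin
import kotlin.math.pow

// Plain-Double copies of the two outputs of f(); no DiffKt types involved.
fun y0Plain(x0: Double, x1: Double, x2: Double): Double =
    2 * x0 + 2 * x1.pow(2) + 3 * x1 * x2.pow(3)

fun y1Plain(x0: Double, x1: Double, x2: Double): Double =
    3 * x0.pow(3) * x1 + 2 * x1.pow(2) + 3 * x2.pow(3)

// Partial derivatives derived by hand.
// Row 0 holds dy0/dx0, dy0/dx1, dy0/dx2; row 1 holds dy1/dx0, dy1/dx1, dy1/dx2.
fun handJacobian(x0: Double, x1: Double, x2: Double): List<List<Double>> = listOf(
    listOf(2.0, 4 * x1 + 3 * x2.pow(3), 9 * x1 * x2.pow(2)),
    listOf(9 * x0.pow(2) * x1, 3 * x0.pow(3) + 4 * x1, 9 * x2.pow(2))
)

fun main() {
    println("y0 = ${y0Plain(1.0, 2.0, 3.0)}")
    println("y1 = ${y1Plain(1.0, 2.0, 3.0)}")
    handJacobian(1.0, 2.0, 3.0).forEachIndexed { i, row ->
        row.forEachIndexed { j, d -> println("dy${i}Dx${j} = $d") }
    }
}
```

Both the forward and the reverse derivative of `f()` should agree with these hand-derived values at the same point.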

### Initialization of X

In [49]:
// Initialization of x

var x = Input(MyMap(mapOf("x0" to FloatScalar(1f), "x1" to FloatScalar(2f), "x2" to FloatScalar(3f))))

### Calling `primalAndForwardDerivative()`

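Notice that `wrapOutput` is supplied for the `Output` class, which does not implement `Differentiable<T>`. No `wrapInput` is needed because `Input` does implement it.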

In [50]:
// Calculate the derivatives with the forward algorithm

val (p,d) = primalAndForwardDerivative(
    x = x,
    f = ::f,
    wrapOutput = { o, w -> o.wrap(w)},
    extractDerivative = ::makeDerivative)

### Output

In [51]:
// print the output

println("y0 = ${p.outputMap["y0"]}")
println("y1 = ${p.outputMap["y1"]}")
println("dy0Dx0 = ${d.dy0Dx0}")
println("dy0Dx1 = ${d.dy0Dx1}")
println("dy0Dx2 = ${d.dy0Dx2}")
println("dy1Dx0 = ${d.dy1Dx0}")
println("dy1Dx1 = ${d.dy1Dx1}")
println("dy1Dx2 = ${d.dy1Dx2}")

y0 = 172.0
y1 = 95.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 11.0
dy1Dx2 = 81.0


### Calling `primalAndReverseDerivative()`

As in the forward call, `wrapOutput` is supplied for the `Output` class, and no `wrapInput` is needed because `Input` implements `Differentiable<T>`.

In [52]:
// Calculate the derivatives with the reverse algorithm

val (pr,dr) = primalAndReverseDerivative(
    x = x,
    f = ::f,
    wrapOutput = { o, w -> o.wrap(w)},
    extractDerivative = ::makeDerivative)

### Output

In [53]:
// print the output

println("y0 = ${pr.outputMap["y0"]}")
println("y1 = ${pr.outputMap["y1"]}")
println("dy0Dx0 = ${dr.dy0Dx0}")
println("dy0Dx1 = ${dr.dy0Dx1}")
println("dy0Dx2 = ${dr.dy0Dx2}")
println("dy1Dx0 = ${dr.dy1Dx0}")
println("dy1Dx1 = ${dr.dy1Dx1}")
println("dy1Dx2 = ${dr.dy1Dx2}")

y0 = 172.0
y1 = 95.0
dy0Dx0 = 2.0
dy0Dx1 = 89.0
dy0Dx2 = 162.0
dy1Dx0 = 18.0
dy1Dx1 = 11.0
dy1Dx2 = 81.0
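The forward and reverse algorithms report the same derivatives, and a numerical estimate gives an independent sanity check. The sketch below approximates the Jacobian of the same function with central finite differences in plain Kotlin. It does not use **DiffKt**; `fPlain` and `numericalJacobian` are hypothetical helper names, and the step size `h` is an arbitrary choice for this example.

```kotlin
import kotlin.math.pow

// Plain-Double copy of f(): three inputs, two outputs.
fun fPlain(x: DoubleArray): DoubleArray = doubleArrayOf(
    2 * x[0] + 2 * x[1].pow(2) + 3 * x[1] * x[2].pow(3),
    3 * x[0].pow(3) * x[1] + 2 * x[1].pow(2) + 3 * x[2].pow(3)
)

// Central-difference estimate of the Jacobian: entry [i][j] approximates dy_i/dx_j.
fun numericalJacobian(
    f: (DoubleArray) -> DoubleArray,
    x: DoubleArray,
    h: Double = 1e-5
): Array<DoubleArray> {
    val m = f(x).size
    val jac = Array(m) { DoubleArray(x.size) }
    for (j in x.indices) {
        // Perturb only the j-th input in each direction.
        val plus = x.copyOf().also { it[j] += h }
        val minus = x.copyOf().also { it[j] -= h }
        val fPlus = f(plus)
        val fMinus = f(minus)
        for (i in 0 until m) {
            jac[i][j] = (fPlus[i] - fMinus[i]) / (2 * h)
        }
    }
    return jac
}

fun main() {
    val jac = numericalJacobian(::fPlain, doubleArrayOf(1.0, 2.0, 3.0))
    jac.forEachIndexed { i, row ->
        row.forEachIndexed { j, d -> println("dy${i}Dx${j} ~ ${"%.4f".format(d)}") }
    }
}
```

The estimates are only approximate, so they should be compared with the printed derivatives up to a small tolerance rather than for exact equality.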


## The End

You have seen many examples of user defined types. Hopefully, this tutorial will assist you in creating and using your own user defined types with **DiffKt**.