# Just-In-Time (JIT) Notebook

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.

## Introduction

This notebook demonstrates the just-in-time (jit) API. The jit API produces an optimized version of **DiffKt** code, which is useful when you call a function repeatedly. On the first call to a jitted function, an optimized version is created. On subsequent calls, the optimized version is called, which should speed up the program.

## Jit Tips and Tricks

There are lots of subtle things you need to get right to take full advantage of the jit:

Make sure there is a good `equals()` and `hashCode()` function for the jitted function's input type. The jit cache needs that.
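
To see why the cache depends on these two functions, consider a minimal sketch in plain Kotlin. `TraceCache`, `PlainKey`, `StructuralKey`, and `compilationsFor` are hypothetical names used only for illustration, and the `HashMap` is a stand-in for the jit cache, not DiffKt's implementation. A key type without structural `equals()` and `hashCode()` misses the cache on every call, so the expensive step runs each time.

```kotlin
// Hypothetical stand-in for a jit cache (not DiffKt's implementation).
// It maps an input key to a "compiled" function and counts the compilations.
class TraceCache<K> {
    private val cache = HashMap<K, (Double) -> Double>()
    var compilations = 0
        private set

    fun lookup(key: K): (Double) -> Double =
        cache.getOrPut(key) {
            compilations += 1                       // the expensive step
            val compiled: (Double) -> Double = { x -> x * 2.0 }
            compiled
        }
}

// Identity-based equals()/hashCode(): two separate instances are never equal.
class PlainKey(val size: Int)

// A data class derives structural equals()/hashCode() from its properties.
data class StructuralKey(val size: Int)

// Look up two structurally identical keys and report how many compilations ran.
fun compilationsFor(makeKey: (Int) -> Any): Int {
    val cache = TraceCache<Any>()
    cache.lookup(makeKey(4))
    cache.lookup(makeKey(4))
    return cache.compilations
}

fun main() {
    println(compilationsFor { PlainKey(it) })       // 2: every call misses
    println(compilationsFor { StructuralKey(it) })  // 1: the second call hits
}
```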

For the purposes of the jit, wrapping more of the input is better. For example, if some inputs are not active variables of differentiation inside the body of the jitted function, it is still valuable to wrap them so that you get a cache hit when their values change. That means you may want to use a different (explicit) `wrapInput` lambda when taking the derivative.

Don't use mutable variables from an enclosing scope. If they are `val` variables (i.e. they don't change) that is OK, but if a value might change from call to call of the jitted function, it should be an explicit input to the function.
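
The hazard can be sketched in plain Kotlin. The `traceOnce` helper below is hypothetical and only an analogy for "build once, reuse afterwards"; it is not DiffKt's API and makes no claim about how the jit is implemented. A value read from an enclosing scope while the function is built stays frozen, whereas an explicit input is seen on every call.

```kotlin
// Hypothetical "trace once, reuse afterwards" helper. It is an analogy for
// the jit's behaviour, not DiffKt's API.
fun traceOnce(build: () -> (Double) -> Double): (Double) -> Double {
    val compiled by lazy(build)          // built on the first call only
    return { x -> compiled(x) }
}

// Captured `var`: the value is read once, while the function is being built.
fun capturedVarDemo(): Pair<Double, Double> {
    var scale = 2.0
    val scaled = traceOnce {
        val frozen = scale
        val traced: (Double) -> Double = { x -> x * frozen }
        traced
    }
    val before = scaled(3.0)
    scale = 10.0                         // this change is not seen by `scaled`
    val after = scaled(3.0)
    return before to after
}

// Explicit input: the changing value is passed on every call.
fun scaledExplicit(x: Double, scale: Double): Double = x * scale

fun main() {
    println(capturedVarDemo())           // (6.0, 6.0)
    println(scaledExplicit(3.0, 10.0))   // 30.0
}
```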

Don't have side-effects in the jitted function; it should be a pure function. That means no print statements, random number generation, or taking the time of day.

### Housekeeping

In [1]:
@file:DependsOn("../kotlin/api/build/libs/api.jar")
@file:DependsOn("net.bytebuddy:byte-buddy:LATEST")

In [2]:
%useLatestDescriptors
%use lets-plot

### Imports

In [3]:
import org.diffkt.*
import java.util.Vector
import jetbrains.letsPlot.intern.Plot
import org.diffkt.tracing.*

## The Model

We will reuse the complex mass-spring model from the mass-spring notebook.

The complex mass-spring model has four masses (two non-movable and two movable) and three springs. The two non-movable masses are connected by springs to a movable mass. Below that movable mass is a spring and a second movable mass. The movable masses are initially placed to the left, so motion occurs along both the x and y axes. The following is a picture of the complex mass-spring model.

![Complex Mass-Spring Model](resources/complex_mass_spring.png)




### Constants

In [4]:
// Constants

object Constants {
   
// Physical Properties
    
    val defaultMass = FloatScalar(1f)
    val defaultRestingLength = FloatScalar(1f)
    val defaultDampen = FloatScalar(.9995f) 
    val ground = FloatScalar(0f)
    val gravityConstant = FloatScalar(9.8f)
    val hooksConstant = FloatScalar(600f)
    val velocityScaling = 10
    val accelerationScaling = 100
    val dt = FloatScalar(.00025f)
    
// Plotting 
    
    val xGraphTitle =             "X vs Time"
    val xPositionAxisLable =      "X Position"
    val xVelocityAxisLable =      "X Velocity"       
    val xAccelerationAxisLable =  "X Acceleration"
        
    val yGraphTitle =             "Y  vs Time"
    val yPositionAxisLable =      "Y Position"
    val yVelocityAxisLable =      "Y Velocity"       
    val yAccelerationAxisLable =  "Y Acceleration"
    val timeAxisLable =           "Time"
    
}

### The Vector2 Coordinates

The `Vector2` class is used to hold the values of the x and y position coordinates. It is also used for velocity, acceleration, and the derivatives of the position. This class demonstrates using operator overloading with complex data types.

The `Vector2` class implements the `Differentiable<T>` interface. The `Vector2` class wraps both the x and y variables.

Tip for jitting a class: classes need to be uniquely identified with the `hashCode()` and `equals()` functions.

In [5]:
// class Vector2

open class Vector2(val x: DScalar, val y: DScalar) : Differentiable<Vector2> {
    
    constructor(x: Float, y: Float): this(FloatScalar(x), FloatScalar(y))
    
    // tip for using jit, make sure hashCode() is unique
    
    override fun hashCode(): Int = combineHash("Vector2")
    
    // tip for using jit, make sure equals() means the hashCode() of two objects are the same
    
    override fun equals(other: Any?): Boolean = other.hashCode() == hashCode()
    
    override fun wrap(wrapper: Wrapper) : Vector2 {
        return Vector2(wrapper.wrap(x), wrapper.wrap(y))
    }
    
    operator fun unaryMinus(): Vector2 {
        return Vector2(-x, -y)
    }
    
    operator fun plus(b: Vector2): Vector2 {
        return Vector2(x + b.x, y + b.y)
    }
    
    operator fun minus(b: Vector2): Vector2 {
        return Vector2(x - b.x, y - b.y)
    }
    
    operator fun times(c: Float) : Vector2 {
        return Vector2(c * x, c * y)
    }
    
    operator fun times(c: DScalar) : Vector2 {
        return Vector2(c * x, c * y)
    }
    
    operator fun div(c: Float) : Vector2 {
        return Vector2(x / c, y / c)
    }
    
    operator fun div(c: DScalar) : Vector2 {
        return Vector2(x / c, y / c)
    }
    
    fun norm(): DScalar {
        return (x * x + y * y).pow(0.5f)
    }
}

### Spring

A `Spring` is connected to two masses. Every `Spring` object has an `id`. A `Spring` object connects to a left `Mass` object and a right `Mass` object, regardless of the actual orientation. A `Spring` object has a resting length set to the `defaultRestingLength` constant when the `MassSpringSystem` is constructed.

The `Spring` class is not a differentiable class in **DiffKt**.

In [6]:
// Spring

data class Spring(val id : Int,
                  val leftMassId : Int,
                  val rightMassId : Int,
                  val restingLength : FloatScalar)

### Mass

Every `Mass` object has an `id`. A `Mass` object can be `movable` or non-`movable`. A `Mass` object has a `mass`. A `Mass` object has a list of `Spring` objects that connect it to other `Mass` objects. The `Mass` object is evaluated at a particular time. At the time evaluated, it has a position, velocity, and acceleration for both the x and y axis. For a non-movable `Mass` object, the position is set at the initialization of the `Mass` object, and the velocity and acceleration are zero.

The class `Mass` implements the `Differentiable<T>` interface. The position of a `Mass` object is a differentiable variable for **DiffKt**. The velocity and acceleration are not differentiable variables. A `Mass` object can be initialized with either `DScalar` or `Float` values. Since the class `Mass` implements the `Differentiable<T>` interface, it overrides the `wrap` function, which calls `wrap` on the `position` as well as on the other `DScalar` and `Vector2` fields. Only `id`, `movable`, and `springList` are left unwrapped.

Tip for jitting a class: classes need to be uniquely identified with the `hashCode()` and `equals()` functions. Also note that all the `Float` variables are implemented as `DScalar` and wrapped in the `wrap()` function.

In [7]:
// Mass

class Mass( val id : Int,
            val time : DScalar = FloatScalar(0f),
            val step : DScalar = FloatScalar(0f),
            val movable : Boolean = false,
            val mass : DScalar = FloatScalar(1f),
            val springList : List<Spring>, 
            val position : Vector2 = Vector2(0f, 0f),                   
            val velocity : Vector2 = Vector2(0f, 0f),
            val acceleration : Vector2 = Vector2(0f, 0f)) : Differentiable<Mass> {
  
    // Alternative Constructor
    
    constructor(id : Int,
                time : Float = 0f,
                step : Float = 0f,
                movable : Boolean = false,
                mass : Float = 1f,
                springList : List<Spring>,
                xPosition : Float = 0f,
                xVelocity : Float = 0f,
                xAcceleration : Float = 0f,
                yPosition : Float = 0f,
                yVelocity : Float = 0f,
                yAcceleration : Float = 0f) : this ( id,
                                                FloatScalar(time),
                                                FloatScalar(step),
                                                movable,
                                                FloatScalar(mass),
                                                springList,
                                                Vector2(xPosition, yPosition),
                                                Vector2(xVelocity, yVelocity),
                                                Vector2(xAcceleration, yAcceleration))
    
    // tip for using jit, make sure hashCode() is unique
    override fun hashCode(): Int = combineHash("Mass")

    // tip for using jit, make sure equals() means the hashCode() of two objects are the same
    override fun equals(other: Any?): Boolean = other.hashCode() == hashCode()
                                                
    // tip for using jit, wrap as many of the variables to speed up jit
    override fun wrap(wrapper : Wrapper) : Mass {
        return Mass(id,
                    wrapper.wrap(time),
                    wrapper.wrap(step),
                    movable,
                    wrapper.wrap(mass),
                    springList,
                    wrapper.wrap(position), // wrapped
                    wrapper.wrap(velocity),
                    wrapper.wrap(acceleration))
    }
       
}

### Energy Calculations

There are two forces acting on a mass, the force of gravity and the force of the springs attached to the mass. The force of gravity only acts on the y axis of a mass. The spring force is modeled as a Hookean spring force. The spring force can act in both the x and y axes. The forces are only applied if the mass is a movable mass. If the masses are non-movable, there are no forces or energy calculated. 

The energy from these forces is the integral of the forces. The energies for a mass are additive.

#### Gravitational Energy
The equation for the gravity energy is

$energy = (y - y_0) * g * mass$

where $y_0$ is the ground, which is zero, and $g$ is the acceleration due to gravity, 9.8 m/s$^2$.

#### Spring Energy
The equation for the spring energy is more complicated: a spring energy is calculated for each spring connected to a mass.

$energy = 0.5 * h * \Delta l^2$

where $h$ is Hooke's constant and $\Delta l$ is the distance between the two masses minus the resting length. Note that the model assumes the spring is never shorter than its resting length, so $\Delta l$ is never less than zero; the code below does not enforce this.

#### Total Energy
The total energy is the gravitational energy of a mass plus the spring energy of each spring connected to the mass.
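
As a worked example of these formulas, here is a minimal sketch using plain `Double` values instead of `DScalar`. The constant names `GRAVITY`, `HOOKE`, and `GROUND` and the function signatures are assumptions made for illustration; the values mirror the notebook's `Constants` object. It computes the energy of a single mass hanging from one spring.

```kotlin
import kotlin.math.hypot

// Plain-Double sketch of the energy formulas. The values mirror the
// notebook's Constants object; the names here are illustrative only.
const val GRAVITY = 9.8
const val HOOKE = 600.0
const val GROUND = 0.0

// energy = (y - y0) * g * mass
fun gravitationalEnergy(y: Double, mass: Double): Double =
    (y - GROUND) * GRAVITY * mass

// energy = 0.5 * h * deltaLength^2, where deltaLength is the distance
// between the two end points minus the resting length.
fun springEnergy(x1: Double, y1: Double, x2: Double, y2: Double,
                 restingLength: Double): Double {
    val deltaLength = hypot(x1 - x2, y1 - y2) - restingLength
    return 0.5 * HOOKE * deltaLength * deltaLength
}

fun main() {
    // A mass of 1 at (0, 12) joined by one spring of resting length 1
    // to a point at (0, 16).
    val gravity = gravitationalEnergy(y = 12.0, mass = 1.0)  // about 117.6
    val spring = springEnergy(0.0, 16.0, 0.0, 12.0, 1.0)     // 0.5 * 600 * 3^2 = 2700
    println(gravity + spring)
}
```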

In [8]:
// Energy Calculations

object EnergyCalculations {
    
    // Gravity Energy 
    
    fun gravitationalEnergy(id : Int, ml : List<Mass>) : DScalar {
        
        val energy = if (ml[id].movable) {
                (ml[id].position.y - Constants.ground) * Constants.gravityConstant * ml[id].mass
            } else {
                FloatScalar(0f) as DScalar
            }
        
 
        return energy
    }

    // Spring Energy
    
    fun springEnergy( id : Int, ml : List<Mass>) : DScalar {
        
        val energy = if (ml[id].movable) {
            ml[id].springList.fold(FloatScalar(0f) as DScalar) { accumulatedEnergy, spring ->
                
                    val leftMassPosition = ml[spring.leftMassId]
                    val rightMassPosition = ml[spring.rightMassId]
            
                    val deltaPosition = leftMassPosition.position - rightMassPosition.position
                    val length = deltaPosition.norm()
                    
                    // Note: the model assumes the spring never becomes shorter than its resting length (no clamping is applied here)
                    
                    val deltaSpring = length - spring.restingLength

                    val deltaLength = deltaSpring
        
                    val potentialEnergy =  FloatScalar(0.5f) * Constants.hooksConstant * deltaLength.pow(2f)
                    
                    accumulatedEnergy + potentialEnergy
                } 
            } else { 
                FloatScalar(0f) as DScalar
            }
               
        return energy
    }

    // Total Energy
    
    fun totalEnergy(id : Int, ml : List<Mass>) : DScalar {
                
        val gEnergy = gravitationalEnergy(id, ml)
        val sEnergy = springEnergy(id, ml)  
        val total = gEnergy + sEnergy

        return total
    }
}

### The Mass-Spring System

The `MassSpringSystem` class contains a list of all `Mass` objects in the system, with their associated `Spring` objects. The `MassSpringSystem` class implements the `Differentiable<T>` interface, so it overrides the `wrap` function and calls `wrap` on a `List` of `Mass` objects. `List` implements the `Differentiable<T>` interface in **DiffKt** and is an internal data structure in **DiffKt**. `Mass` was written to implement the `Differentiable<T>` interface.
    
The `totalSystemEnergy` is the sum of all the energies associated with every `Mass`.

Tip for jitting a function, objects need to be uniquely identified with the `hashCode()` and `equals()` functions.

In [9]:
// MassSpringSystem class

class MassSpringSystem(val massList : List<Mass>) : Differentiable<MassSpringSystem> {
    
    override fun hashCode(): Int = combineHash("MassSpringSystem")
    
    override fun equals(other: Any?): Boolean = other.hashCode() == hashCode()
    
    override fun wrap(wrapper : Wrapper) : MassSpringSystem {
        return MassSpringSystem(wrapper.wrap(massList))
    }
         
    fun totalSystemEnergy(massList : List<Mass>) : DScalar {
           
        // Sums the energy of each Mass
        val totalSystemEnergy =  massList.fold(FloatScalar(0f) as DScalar) 
            {accumulatedEnergy, mass ->
                accumulatedEnergy + EnergyCalculations.totalEnergy(mass.id, massList)}
        
        return totalSystemEnergy            
    }
}

### Derivatives

The x and y derivatives of the position of a `Mass` are kept in instances of the `MassDerivatives` class. The `TotalDerivatives` class holds a `List` of `MassDerivatives` to collect the derivatives from all the `Mass` instances.

In [10]:
// Derivatives of the x & y position of a single Mass
class MassDerivatives(x : DScalar, y : DScalar) : Vector2(x, y)
    
// Derivatives of all the Masses in the system
data class TotalDerivatives(val d : List<MassDerivatives>)

### Calculating the Derivatives

The `calculateDerivatives()` function calls the `primalAndReverseDerivative()` function, the automatic differentiation function for user-defined types. The user-defined types are the `x` and `y` position variables in a `Vector2` class, which is also differentiable. The `Mass` instances are user-defined types and differentiable. The `List` that holds the `Mass` instances and the `MassSpringSystem` instance are differentiable. The function to be differentiated is `system::totalSystemEnergy`, which is the energy function for the system. A custom function, `makeDerivatives`, extracts the derivatives for each `x` and `y` position of a `Mass` instance.
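
To make concrete what "the derivative of the energy with respect to a position" means, the sketch below compares a hand-derived derivative with a central finite difference, using plain `Double` values. This is only a cross-check on a hypothetical single-mass energy (one vertical spring to a fixed anchor). It is not how **DiffKt** computes derivatives, and the names `energy`, `numericalDerivative`, and `analyticDerivative` are assumptions for illustration.

```kotlin
import kotlin.math.abs

// Hypothetical single-mass energy: gravity plus one vertical spring
// attached to a fixed anchor. The defaults mirror the notebook's constants.
fun energy(y: Double, anchorY: Double = 16.0, mass: Double = 1.0,
           g: Double = 9.8, hooke: Double = 600.0,
           restingLength: Double = 1.0): Double {
    val deltaLength = abs(anchorY - y) - restingLength
    return y * g * mass + 0.5 * hooke * deltaLength * deltaLength
}

// Central-difference approximation of dE/dy.
fun numericalDerivative(f: (Double) -> Double, y: Double,
                        eps: Double = 1e-4): Double =
    (f(y + eps) - f(y - eps)) / (2 * eps)

// Hand-derived dE/dy for y below the anchor, with the default constants:
// g * mass - hooke * (anchorY - y - restingLength).
fun analyticDerivative(y: Double): Double =
    9.8 * 1.0 - 600.0 * (16.0 - y - 1.0)

fun main() {
    val numeric = numericalDerivative({ energy(it) }, 12.0)
    val analytic = analyticDerivative(12.0)   // 9.8 - 1800 = -1790.2
    println(abs(numeric - analytic) < 1e-3)
}
```

The negative sign means the energy falls as the mass moves up, so the resulting force (the negated derivative) pulls the mass toward the anchor.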


In [11]:
// Calculate the derivatives of the energy of the system

object Derivatives {
    
    // Extract the derivatives of each variable in the system
    
    fun makeDerivatives(input : List<Mass>, output : DScalar, 
                        extractDerivative: (DTensor, DTensor) -> DTensor) : TotalDerivatives {
        
        val listMassDerivatives = input.map {
            val dmassDx = extractDerivative(it.position.x, output) as DScalar
            val dmassDy = extractDerivative(it.position.y, output) as DScalar
            MassDerivatives(dmassDx, dmassDy)
        }
        
        return TotalDerivatives(listMassDerivatives)
    }
    
    // Calculate the derivatives of the total energy of the system
    
    fun calculateDerivatives(system : MassSpringSystem) : TotalDerivatives {
        
        return primalAndReverseDerivative(
                    x = system.massList,
                    f = system::totalSystemEnergy,
                    extractDerivative = ::makeDerivatives).second  
    }
    
}

### Numerical Integration of the ODEs

This section demonstrates using the jit. The jitted function is `Derivatives::calculateDerivatives()`. A wrapper is added to measure the time of each call. Notice that the first call, where an optimized version of the function is created, takes longer than subsequent calls. The subsequent calls are orders of magnitude faster.

The Mass-Spring System is solved using the forward Euler method for solving ordinary differential equations (ODEs). It is probably one of the least desirable algorithms for solving ordinary differential equations, since it can be unstable and slow, but it is very simple and easy to show as an example.

The forward Euler algorithm to solve an ODE $y' = f(y)$ is

$y_{n+1} = y_n + h*f(y_n)$

where
$n$ is the time step, and
$h$ is the time step size.
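
The update rule above can be sketched in a few lines of plain Kotlin. The names `eulerStep` and `eulerSolve` and the test problem $y' = -y$ are assumptions chosen for illustration; they do not appear in the notebook's solver.

```kotlin
// One forward Euler step: y_{n+1} = y_n + h * f(y_n).
fun eulerStep(y: Double, h: Double, f: (Double) -> Double): Double =
    y + h * f(y)

// Repeated steps, keeping every intermediate value in a list.
fun eulerSolve(y0: Double, h: Double, steps: Int,
               f: (Double) -> Double): List<Double> =
    (1..steps).runningFold(y0) { y, _ -> eulerStep(y, h, f) }

fun main() {
    // Hypothetical test problem y' = -y with y(0) = 1.
    // The exact solution is exp(-t).
    val trajectory = eulerSolve(y0 = 1.0, h = 0.1, steps = 10) { y -> -y }
    println(trajectory.size)    // 11: the initial value plus ten steps
    println(trajectory.last())  // close to 0.9^10, an approximation of exp(-1)
}
```

Using `runningFold` keeps the whole trajectory rather than only the final value, which is useful when every step's state is wanted.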

The `systemAcceleration` is the derivative of the energy of each movable mass.<br>

The `acceleration` is the negated `x` and `y` derivative of a mass's energy divided by the mass.<br>
The `velocity` is updated by adding the current `velocity` to `acceleration` times `dt`. The resulting velocity is then multiplied by a dampening constant.<br>
The `position` is updated by adding the current `position` to `velocity` times `dt`.
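
These three update rules can be sketched for a single axis with plain `Double` values. `State1D` and `nextState` are hypothetical names. The ordering assumed here, where the velocity update uses the previous acceleration and the position update uses the previous velocity, follows the `newState` function in the solver cell.

```kotlin
// State of one movable mass along a single axis.
data class State1D(val position: Double, val velocity: Double,
                   val acceleration: Double)

// One update of the state, given dE/dx for this mass.
fun nextState(s: State1D, dEnergyDx: Double, mass: Double,
              dt: Double, dampen: Double): State1D {
    // acceleration: negated energy derivative divided by the mass
    val acceleration = -dEnergyDx / mass
    // velocity: previous velocity plus previous acceleration times dt, then dampened
    val velocity = (s.velocity + s.acceleration * dt) * dampen
    // position: previous position plus previous velocity times dt
    val position = s.position + s.velocity * dt
    return State1D(position, velocity, acceleration)
}

fun main() {
    val start = State1D(position = 1.0, velocity = 2.0, acceleration = 4.0)
    val next = nextState(start, dEnergyDx = 3.0, mass = 1.5, dt = 0.5, dampen = 0.5)
    println(next)  // State1D(position=2.0, velocity=2.0, acceleration=-2.0)
}
```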

The `integrate()` function integrates the equations across time, and saves the system state at each time step. 

### JIT

The function that is jitted is `Derivatives::calculateDerivatives()`. This is a wrapper around `primalAndReverseDerivative()`. The jitting is set up in the `integrate()` function. The jitted function has to be called for each integration step. There is some timing wrapped around the function call so you can see two things: the jitted function is faster than the non-jitted function, and the first call to the jitted function is slower than the rest of the calls. This may be difficult to see because the first call to either the jitted or the non-jitted function is slower than the subsequent calls, due to JVM optimizations as the code is executed. The timing for the first twenty steps and the last twenty steps is printed out.
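
The timing pattern can be isolated in a small helper. This is a sketch in plain Kotlin: `timeCalls` is a hypothetical name, and the summation workload merely stands in for `calculateDerivatives()`. Because of JVM warm-up, the first measurement is often the largest even without any jit involved, which is why the two effects are hard to separate.

```kotlin
import kotlin.system.measureNanoTime

// Time `repeats` consecutive calls of `block`, returning nanoseconds per call.
fun <T> timeCalls(repeats: Int, block: () -> T): List<Long> =
    List(repeats) { measureNanoTime { block() } }

fun main() {
    // Hypothetical workload standing in for calculateDerivatives().
    val timings = timeCalls(5) { (1..100_000).sumOf { it.toLong() } }
    timings.forEachIndexed { call, nanos ->
        println("call $call - ${nanos / 1e9} sec")
    }
}
```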

In [12]:
// Forward Euler ODE Solver

class ForwardEulerSolver() {
    
    // Solve the ODEs for one step
    
    fun integrateOneStep(system: MassSpringSystem, systemAcceleration : TotalDerivatives) : MassSpringSystem {
       
        fun newState(p : Mass, acc : MassDerivatives) : Mass {
             
            val acceleration = if (p.movable) {-acc / p.mass } else Vector2(0f, 0f)
            val velocity =     if (p.movable) {(p.velocity + (p.acceleration * Constants.dt)) * Constants.defaultDampen } else Vector2(0f, 0f)
            val position =     if (p.movable) {p.position + (p.velocity * Constants.dt)} else p.position
            
            val newStep  = (p.step + FloatScalar(1f))
            
            return Mass (id = p.id,
                         step = newStep,
                         time = (newStep * Constants.dt.value),
                         movable = p.movable,
                         mass = p.mass,
                         springList = p.springList,
                         position = position,
                         velocity = velocity,
                         acceleration = acceleration) 
        }  
        
        val massList = system.massList.mapIndexed{idx, mass -> newState(mass, systemAcceleration.d[idx])}
        
        // returns the new state of the system
        
        return MassSpringSystem(massList)

    }
    
    // Solve the ODEs for a fixed length of time
    
    fun integrate(system : MassSpringSystem, time : Float, jit : Boolean) : List<MassSpringSystem> {
        
        val steps = ((FloatScalar(time) / Constants.dt) as FloatScalar).value.toInt()
        val numSteps = 1..steps
        
        val calcDerivatives = if (jit == true) {jit(Derivatives::calculateDerivatives)} else {Derivatives::calculateDerivatives}
        
        fun timedCalcDerivatives(newSystem : MassSpringSystem, calc : (MassSpringSystem) -> TotalDerivatives) : TotalDerivatives {    
            
            var totalDerivatives : TotalDerivatives 
            val step = (newSystem.massList[0].step as FloatScalar).value
            
            if (step <= 20) {      
                val elapsed = kotlin.system.measureNanoTime { 
                    totalDerivatives = calc(newSystem)
                }
                
                val formattedElapsed = String.format("%8f" , elapsed / 1000000000.0 )
                println("step ${step} - ${formattedElapsed} sec")
            } else if (step > (steps - 20)) {
                
                val elapsed = kotlin.system.measureNanoTime { 
                    totalDerivatives = calc(newSystem)
                }
                
                val formattedElapsed = String.format("%8f" , elapsed / 1000000000.0 )
                println("step ${step} - ${formattedElapsed} sec")
            } else {
               totalDerivatives = calc(newSystem) 
            }
            
            return totalDerivatives
        }
        
        // Save the system state of each step in observations
        val observations = numSteps.runningFold(system) {newSystem, _ -> integrateOneStep(newSystem, timedCalcDerivatives(newSystem, calcDerivatives)) }
        
        return observations
    }
}

### Complex Mass-Spring System

![Complex Mass-Spring System](resources/complex_mass_spring.png)

The complex mass-spring system has four masses and three springs. The top two masses are non-movable. The bottom two masses are movable. The system is initialized with the two movable masses pulled to the left to cause motion in the x axis, as well as the y axis.

In [13]:
// Complex Mass-Spring System

object SystemBuild {
    
    fun fourMassThreeSpring() : MassSpringSystem {
                
        val spring0 = Spring(id = 0,
                             leftMassId = 0,
                             rightMassId = 2,
                             restingLength = Constants.defaultRestingLength)
    
 
        
        val spring1 = Spring(id = 1,
                             leftMassId = 1,
                             rightMassId = 2,
                             restingLength = Constants.defaultRestingLength)
        
 
        
        val spring2 = Spring(id = 2,
                             leftMassId = 2,
                             rightMassId = 3,
                             restingLength = Constants.defaultRestingLength)
                
        val springList0 = listOf(spring0)
        
        val springList1 = listOf(spring1)
        
        val springList2 = listOf(spring0, spring1, spring2)
        
        val springList3 = listOf(spring2)
        

        val mass0 = Mass(id = 0,    
                         movable = false,
                         mass = 1f,
                         springList = springList0,
                         yPosition = 20f)


        val mass1 = Mass(id = 1,    
                         movable = false,
                         mass = 1f,
                         springList = springList1,
                         xPosition = 5f,    
                         yPosition = 20f)          
        
        val mass2 = Mass(id = 2,    
                         movable = true,
                         mass = 1f,
                         springList = springList2,
                         yPosition = 16f)
        
        val mass3 = Mass(id = 3,    
                         movable = true,
                         mass = 1f,
                         springList = springList3,
                         yPosition = 12f)



        

        return MassSpringSystem(listOf(mass0, mass1, mass2, mass3))
 
    }
}

### Simulation of Complex Mass-Spring System

The complex mass-spring system is simulated for 100 seconds. First, a complex mass-spring system is created. Next, a solver is created. Lastly, the solver integrates the system for 100 seconds, and the elapsed time is printed. `dt` is set in the constants to a step size of 0.00025 seconds.

### Non-jitted Simulation

Note, jit is set to false.

In [14]:
// Complex Mass-Spring System simulation

fun simulation1() {

    val time = 100.0f
    val jit = false
    
    val system = SystemBuild.fourMassThreeSpring()
                           
    val solver = ForwardEulerSolver()
    
    val elapsed = kotlin.system.measureTimeMillis { 
        val observations = solver.integrate(system, time, jit)
    }
    
    println("Total Elapsed Time - ${elapsed / 1000.0} sec")
 
}

simulation1()

step 0.0 - 0.047710 sec
step 1.0 - 0.003520 sec
step 2.0 - 0.001328 sec
step 3.0 - 0.001267 sec
step 4.0 - 0.001303 sec
step 5.0 - 0.002055 sec
step 6.0 - 0.001323 sec
step 7.0 - 0.001637 sec
step 8.0 - 0.001144 sec
step 9.0 - 0.001240 sec
step 10.0 - 0.001230 sec
step 11.0 - 0.001466 sec
step 12.0 - 0.001095 sec
step 13.0 - 0.001108 sec
step 14.0 - 0.000987 sec
step 15.0 - 0.001495 sec
step 16.0 - 0.001531 sec
step 17.0 - 0.001516 sec
step 18.0 - 0.001603 sec
step 19.0 - 0.001683 sec
step 20.0 - 0.001124 sec
step 399980.0 - 0.000073 sec
step 399981.0 - 0.000141 sec
step 399982.0 - 0.000129 sec
step 399983.0 - 0.000128 sec
step 399984.0 - 0.000135 sec
step 399985.0 - 0.000107 sec
step 399986.0 - 0.000098 sec
step 399987.0 - 0.000210 sec
step 399988.0 - 0.000109 sec
step 399989.0 - 0.000100 sec
step 399990.0 - 0.000176 sec
step 399991.0 - 0.000105 sec
step 399992.0 - 0.000107 sec
step 399993.0 - 0.000185 sec
step 399994.0 - 0.000108 sec
step 399995.0 - 0.000099 sec
step 399996.0 - 0.000

### Jitted Simulation

Note, jit is set to true.

This simulation is faster than the non-jitted simulation.

In [15]:
// Complex Mass-Spring System simulation

fun simulation2() {

    val time = 100.0f
    val jit = true
    
    val system = SystemBuild.fourMassThreeSpring()
                           
    val solver = ForwardEulerSolver()
    
    val elapsed = kotlin.system.measureTimeMillis { 
        val observations = solver.integrate(system, time, jit)
    }
    
    println("Total Elapsed Time - ${elapsed / 1000.0} sec")
 
}

simulation2()

step 0.0 - 0.287756 sec
step 1.0 - 0.002216 sec
step 2.0 - 0.000245 sec
step 3.0 - 0.000268 sec
step 4.0 - 0.000172 sec
step 5.0 - 0.000222 sec
step 6.0 - 0.000182 sec
step 7.0 - 0.000209 sec
step 8.0 - 0.000214 sec
step 9.0 - 0.000244 sec
step 10.0 - 0.000194 sec
step 11.0 - 0.000185 sec
step 12.0 - 0.000305 sec
step 13.0 - 0.000231 sec
step 14.0 - 0.000189 sec
step 15.0 - 0.000165 sec
step 16.0 - 0.000334 sec
step 17.0 - 0.000323 sec
step 18.0 - 0.000281 sec
step 19.0 - 0.000277 sec
step 20.0 - 0.000370 sec
step 399980.0 - 0.000081 sec
step 399981.0 - 0.000155 sec
step 399982.0 - 0.000130 sec
step 399983.0 - 0.000116 sec
step 399984.0 - 0.000087 sec
step 399985.0 - 0.000063 sec
step 399986.0 - 0.000060 sec
step 399987.0 - 0.000091 sec
step 399988.0 - 0.000115 sec
step 399989.0 - 0.000103 sec
step 399990.0 - 0.000167 sec
step 399991.0 - 0.000119 sec
step 399992.0 - 0.000107 sec
step 399993.0 - 0.000138 sec
step 399994.0 - 0.000114 sec
step 399995.0 - 0.000096 sec
step 399996.0 - 0.000

## Discussion

The results show that the simulation using the jitted function was faster than the simulation using the non-jitted function.

The purpose of this notebook was to show how to use the jit API. The jit API is easy to use, but there are a number of tips for using it successfully: 1) `hashCode()` and `equals()` need to be able to uniquely identify the object, 2) wrap all the variables in an object that implements `Differentiable` for the fastest jit code, 3) make sure the function that is jitted, and all the code it calls, are pure functions with no side effects, and 4) make sure the jitted function uses only `val` variables, not `var` variables, from an enclosing scope.



## The End