# Mass-Spring Notebook

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.

## Introduction

The Mass-Spring Notebook shows how to use automatic differentiation in **DiffKt** with user-defined types, applied to a mass-spring differential equation model. Mass-spring systems are one of the models used in computer graphics to represent flexible objects such as cloth. The notebook is an example of a complex system simulation built from user-defined types: it demonstrates a list of simulation objects, where each object is a composite of several objects that can themselves be user-defined types.

### Housekeeping

In [1]:
@file:DependsOn("../kotlin/api/build/libs/api.jar")

In [2]:
%useLatestDescriptors
%use lets-plot

### Imports

In [3]:
import org.diffkt.*
import java.util.Vector
import jetbrains.letsPlot.intern.Plot

## The Model

The notebook presents two mass-spring models: a simple one and a more complex one.

The simple mass-spring model has one fixed mass, a spring, and a movable mass that hangs on the bottom of the spring. The simple mass-spring model is centered at x = 0, and the movable mass only moves on the y axis. The following is a picture of the simple mass-spring model.

![Simple Mass-Spring Model](resources/simple_mass_spring.png)

The complex mass-spring model has four masses (two non-movable and two movable) and three springs. The two non-movable masses are each connected by a spring to a movable mass, and a third spring hangs a second movable mass below the first. The movable masses are initially placed to the left, so the resulting motion is along both the x and y axes. The following is a picture of the complex mass-spring model.

![Complex Mass-Spring Model](resources/complex_mass_spring.png)




### Constants

In [4]:
// Constants

object Constants {
   
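// Physical Properties (SI units assumed: kg, m, s)
    
    val defaultMass = 1f              // mass of each Mass
    val defaultRestingLength = 1f     // spring resting length
    val defaultDampen = .9995f        // velocity damping factor applied every step
    val ground = 0f                   // reference height y_0 for gravitational energy
    val gravityConstant = 9.8f        // g, gravitational acceleration
    val hooksConstant = 600f          // h, Hooke's spring constant (stiffness)
    val velocityScaling = 10          // plotting only: velocity is divided by this
    val accelerationScaling = 100     // plotting only: acceleration is divided by this
    val dt = .00025f                  // time step size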
    
// Plotting 
    
    val xGraphTitle =             "X vs Time"
    val xPositionAxisLable =      "X Position"
    val xVelocityAxisLable =      "X Velocity"       
    val xAccelerationAxisLable =  "X Acceleration"
        
    val yGraphTitle =             "Y vs Time"
    val yPositionAxisLable =      "Y Position"
    val yVelocityAxisLable =      "Y Velocity"       
    val yAccelerationAxisLable =  "Y Acceleration"
    val timeAxisLable =           "Time"
    
}

### The Vector2 Coordinates

The `Vector2` class holds the x and y position coordinates. It is also used for velocity, acceleration, and the derivatives of the position. This class demonstrates operator overloading with composite data types.

The `Vector2` class implements the `Differentiable<T>` interface. Its `wrap` function wraps both the `x` and `y` variables.

In [5]:
// class Vector2

open class Vector2(val x: DScalar, val y: DScalar) : Differentiable<Vector2> {
    
    constructor(x: Float, y:Float): this(FloatScalar(x), FloatScalar(y))
    
    override fun wrap(wrapper: Wrapper) : Vector2 {
        return Vector2(wrapper.wrap(x), wrapper.wrap(y))
    }
    
    operator fun unaryMinus(): Vector2 {
        return Vector2(-x, -y)
    }
    
    operator fun plus(b: Vector2): Vector2 {
        return Vector2(x + b.x, y + b.y)
    }
    
    operator fun minus(b: Vector2): Vector2 {
        return Vector2(x - b.x, y - b.y)
    }
    
    operator fun times(c: Float) : Vector2 {
        return Vector2(c * x, c * y)
    }
    
    operator fun div(c: Float) : Vector2 {
        return Vector2(x / c, y / c)
    }
    
    fun norm(): DScalar {
        return (x * x + y * y).pow(0.5f)
    }
}
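To show the operator-overloading pattern on its own, here is a minimal sketch that mirrors `Vector2` with plain `Double` fields instead of DiffKt's `DScalar`. The name `Vec2` is hypothetical and not part of the notebook; `Vector2` applies the same operators to `DScalar` fields, which keeps the arithmetic differentiable.

```kotlin
import kotlin.math.sqrt

// Hypothetical plain-Double analogue of Vector2 (no DiffKt types),
// using the same operator-overloading pattern.
data class Vec2(val x: Double, val y: Double) {
    operator fun unaryMinus() = Vec2(-x, -y)
    operator fun plus(b: Vec2) = Vec2(x + b.x, y + b.y)
    operator fun minus(b: Vec2) = Vec2(x - b.x, y - b.y)
    operator fun times(c: Double) = Vec2(c * x, c * y)
    operator fun div(c: Double) = Vec2(x / c, y / c)

    // Euclidean length, like Vector2.norm()
    fun norm(): Double = sqrt(x * x + y * y)
}

// Operators compose like ordinary arithmetic: the displacement between two masses
// and its length (the current length of a spring joining them)
val top = Vec2(0.0, 20.0)
val bottom = Vec2(0.0, 15.0)
val springLength = (top - bottom).norm()   // 5.0
```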

### Spring

A `Spring` connects two masses. Every `Spring` object has an `id` and connects a left `Mass` object to a right `Mass` object; "left" and "right" are only labels and do not depend on the actual orientation. In both example systems, each `Spring` is given a resting length equal to the `defaultRestingLength` constant when the system is built.

The `Spring` class does not implement **DiffKt**'s `Differentiable<T>` interface; it holds only constant connection data.

In [6]:
// Spring

data class Spring(val id : Int,
                  val leftMassId : Int,
                  val rightMassId : Int,
                  val restingLength : Float)
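Because a spring is just an undirected connection between two mass ids, each mass's spring list can be derived from one global list instead of being assembled by hand (as `SystemBuild2` does below with `springList0` to `springList3`). The helper `springsAttachedTo` is hypothetical, a sketch that reuses the notebook's `Spring` shape.

```kotlin
// Same shape as the notebook's Spring data class
data class Spring(val id: Int, val leftMassId: Int, val rightMassId: Int, val restingLength: Float)

// Hypothetical helper: a spring belongs to a mass if the mass is at either end,
// since "left" and "right" are only labels
fun springsAttachedTo(massId: Int, springs: List<Spring>): List<Spring> =
    springs.filter { it.leftMassId == massId || it.rightMassId == massId }

// Springs of the complex system: Mass0-Mass2, Mass1-Mass2, Mass2-Mass3
val springs = listOf(
    Spring(0, 0, 2, 1f),
    Spring(1, 1, 2, 1f),
    Spring(2, 2, 3, 1f),
)

// Spring ids attached to masses 0..3
val perMass = (0..3).map { id -> springsAttachedTo(id, springs).map { it.id } }
```

This reproduces the hand-built lists in `SystemBuild2`: `Mass2` gets all three springs, and every other mass gets one.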

### Mass

Every `Mass` object has an `id`. A `Mass` object can be `movable` or non-`movable`, and it has a `mass` and a list of `Spring` objects that connect it to other `Mass` objects. Each `Mass` object represents the state at a particular `time`: its position, velocity, and acceleration along both the x and y axes. For a non-movable `Mass` object, the position is set when the object is initialized, and the velocity and acceleration are zero.

The class `Mass` implements the `Differentiable<T>` interface. The position of a `Mass` object is differentiable in **DiffKt**; the velocity and acceleration are not. A `Mass` object can be initialized with either `Vector2` values (holding `DScalar`s) or `Float` coordinates. Because `Mass` implements `Differentiable<T>`, it overrides the `wrap` function and calls `wrap` only on `position`, a `Vector2` whose own `wrap` wraps `x` and `y`. The other fields are ignored by **DiffKt** because they are not wrapped.

In [7]:
// Mass

class Mass( val id : Int,
            val time : Float = 0f,
            val movable : Boolean,
            val mass : Float,
            val springList : List<Spring>, 
            val position : Vector2 = Vector2(0f, 0f),                   
            val velocity : Vector2 = Vector2(0f, 0f),
            val acceleration : Vector2 = Vector2(0f, 0f)) : Differentiable<Mass> {
  
    // Alternative Constructor
    
    constructor(id : Int,
                time : Float = 0f,
                movable : Boolean,
                mass : Float,
                springList : List<Spring>,
                xPosition : Float = 0f,
                xVelocity : Float = 0f,
                xAcceleration : Float = 0f,
                yPosition : Float = 0f,
                yVelocity : Float = 0f,
                yAcceleration : Float = 0f) : this ( id,
                                                time,
                                                movable,
                                                mass,
                                                springList,
                                                Vector2(xPosition, yPosition),
                                                Vector2(xVelocity, yVelocity),
                                                Vector2(xAcceleration, yAcceleration))
                                                
    
    override fun wrap(wrapper : Wrapper) : Mass {
        return Mass(id,
                    time,
                    movable,
                    mass,
                    springList,
                    wrapper.wrap(position), // wrapped
                    velocity,
                    acceleration)
    }
       
}
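The selective-wrapping idea can be illustrated without DiffKt. The sketch below uses hypothetical `ScalarWrapper` and `Wrappable` types that only imitate the shape of `Wrapper` and `Differentiable<T>`; they are not DiffKt's API. As in `Mass.wrap`, only the position is passed through the wrapper, so the velocity comes back unchanged.

```kotlin
// Hypothetical stand-ins (NOT DiffKt's API): a wrapper that transforms one scalar value
fun interface ScalarWrapper {
    fun wrap(value: Double): Double
}

interface Wrappable<T> {
    fun wrap(wrapper: ScalarWrapper): T
}

// Like Vector2: both coordinates are wrapped
data class Point(val x: Double, val y: Double) : Wrappable<Point> {
    override fun wrap(wrapper: ScalarWrapper) = Point(wrapper.wrap(x), wrapper.wrap(y))
}

// Like Mass: only position is wrapped; id and velocity pass through untouched
data class SimpleMass(val id: Int, val position: Point, val velocity: Point) : Wrappable<SimpleMass> {
    override fun wrap(wrapper: ScalarWrapper) = SimpleMass(id, position.wrap(wrapper), velocity)
}

// Usage: a "wrapper" that adds 100 to each value, just to make its effect visible
val m = SimpleMass(1, Point(0.0, 15.0), Point(0.5, -0.5))
val wrapped = m.wrap { it + 100.0 }
```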

### Energy Calculations

Two forces act on a mass: gravity and the force of the springs attached to it. Gravity acts only along the y axis. The spring force is modeled as a Hookean spring force and can act along both the x and y axes. Forces are applied only to movable masses; for non-movable masses, no forces or energies are calculated.

Each force is derived from a potential energy: the force is the negative derivative of the energy with respect to position. The notebook therefore computes energies and lets **DiffKt** differentiate them. The energies for a mass are additive.

#### Gravitational Energy
The equation for the gravitational energy is

$E_g = (y - y_0) \, g \, m$

where $y_0$ is the ground height (zero here), $g$ is the gravitational acceleration, 9.8 m/s², and $m$ is the mass.
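As a worked instance of the formula above, with $m$ the mass, differentiating the gravitational energy with respect to $y$ gives the gravity force (the negative derivative of the energy):

$$E_g = m g (y - y_0), \qquad \frac{\partial E_g}{\partial y} = m g, \qquad F_y = -\frac{\partial E_g}{\partial y} = -m g$$

For `Mass1` in the simple system below ($m = 1$, $g = 9.8$, $y_0 = 0$, initial $y = 15$), this gives $E_g = 147$, and gravity contributes $-9.8$ to the y acceleration.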

#### Spring Energy
The spring energy is more involved: a spring energy is calculated for each spring connected to a mass.

$E_s = \frac{1}{2} \, h \, \Delta l^2$

where $h$ is Hooke's constant (the spring stiffness) and $\Delta l$ is the distance between the two masses minus the spring's resting length. In this model a spring only resists stretching: if the distance is less than the resting length, $\Delta l$ is clamped to zero, so a compressed spring stores no energy and exerts no force.
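A plain-`Double` sketch of the clamped spring energy, together with its hand-derived gradient with respect to one end, may help when reading the DiffKt code below. The function names are hypothetical; DiffKt computes the gradient automatically, so nothing here is needed by the notebook itself.

```kotlin
import kotlin.math.sqrt

val hooke = 600.0  // h, same value as Constants.hooksConstant

// E_s = 0.5 * h * max(|p - q| - rest, 0)^2
fun springEnergy(px: Double, py: Double, qx: Double, qy: Double, rest: Double): Double {
    val length = sqrt((px - qx) * (px - qx) + (py - qy) * (py - qy))
    val dl = maxOf(length - rest, 0.0)  // a compressed spring stores no energy
    return 0.5 * hooke * dl * dl
}

// Gradient of E_s with respect to p = (px, py): h * dl * (p - q) / |p - q|.
// Undefined when p == q (zero length), just like the derivative of Vector2.norm().
fun springEnergyGrad(px: Double, py: Double, qx: Double, qy: Double, rest: Double): Pair<Double, Double> {
    val dx = px - qx
    val dy = py - qy
    val length = sqrt(dx * dx + dy * dy)
    val dl = maxOf(length - rest, 0.0)
    return Pair(hooke * dl * dx / length, hooke * dl * dy / length)
}

// Simple system: movable mass at (0, 15), fixed mass at (0, 20), resting length 1
val e = springEnergy(0.0, 15.0, 0.0, 20.0, 1.0)        // 0.5 * 600 * 4^2 = 4800
val grad = springEnergyGrad(0.0, 15.0, 0.0, 20.0, 1.0) // (0, -2400)
val slack = springEnergy(0.0, 19.5, 0.0, 20.0, 1.0)     // length 0.5 < 1, so 0
```

The negative gradient, $(0, 2400)$, points toward the fixed mass, which is the upward pull of the stretched spring.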

#### Total Energy
The total energy is the gravitational energy of a mass plus the spring energy of each spring connected to the mass.

In [8]:
// Energy Calculations

object EnergyCalculations {
    
    // Gravity Energy 
    
    fun gravitationalEnergy(id : Int, ml : List<Mass>) : DScalar {
        
        var energy = FloatScalar(0f) as DScalar
        if (ml[id].movable == true) {
            energy = (ml[id].position.y - Constants.ground) * Constants.gravityConstant * ml[id].mass
        }
        
 
        return energy
    }

    // Spring Energy
    
    fun springEnergy( id : Int, ml : List<Mass>) : DScalar {

        var energy = FloatScalar(0f) as DScalar
        if (ml[id].movable == true) {
            energy = ml[id].springList.fold(FloatScalar(0f) as DScalar) {
                accumulatedEnergy, spring ->
                
                    val leftMassPosition = ml[spring.leftMassId]
                    val rightMassPosition = ml[spring.rightMassId]
            
                    val deltaPosition = leftMassPosition.position - rightMassPosition.position
                    val length = deltaPosition.norm()
                    
                    // A spring shorter than its resting length stores no energy (no compression force)
                    
                    val deltaSpring = length - spring.restingLength
                    val deltaLength = if (deltaSpring < 0f) {FloatScalar(0f)} else {deltaSpring}
                 
        
                    val potentialEnergy =  0.5f * Constants.hooksConstant * deltaLength.pow(2f)            
                    accumulatedEnergy + potentialEnergy
                }
        }       
        return energy
    }

    // Total Energy
    
    fun totalEnergy(id : Int, ml : List<Mass>) : DScalar {
                
        val gEnergy = gravitationalEnergy(id, ml)
        val sEnergy = springEnergy(id, ml)  
        val total = gEnergy + sEnergy

        return total
    }
}

### The Mass-Spring System

The `MassSpringSystem` class contains a list of all `Mass` objects in the system, together with their associated `Spring` objects. `MassSpringSystem` implements the `Differentiable<T>` interface, so it overrides the `wrap` function and calls `wrapper.wrap` on its `List` of `Mass` objects. **DiffKt** can wrap a `List` whose elements are differentiable, and `Mass` was written to implement `Differentiable<T>` for this purpose.

The `totalSystemEnergy` function sums the energies of every `Mass`. It takes the mass list as a parameter rather than reading the `massList` property, so that it can be evaluated on the list passed in by `primalAndReverseDerivative`.

In [9]:
// MassSpringSystem class

class MassSpringSystem(val massList : List<Mass>) : Differentiable<MassSpringSystem> {
    
    override fun wrap(wrapper : Wrapper) : MassSpringSystem {
        return MassSpringSystem(wrapper.wrap(massList))
    }
         
    fun totalSystemEnergy(massList : List<Mass>) : DScalar {
           
        // Sums the energy of each Mass
        val totalSystemEnergy =  massList.fold(FloatScalar(0f) as DScalar) 
            {accumulatedEnergy, mass ->
                accumulatedEnergy + EnergyCalculations.totalEnergy(mass.id, massList)}
        
        return totalSystemEnergy            
    }
}
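For reference, the same fold can be written over plain `Double` data. This sketch uses hypothetical stand-in types `S` and `M` rather than the notebook's classes, and, like `EnergyCalculations`, it indexes the list by `id`, so ids must match list positions. It evaluates the simple two-mass system built later in the notebook.

```kotlin
import kotlin.math.sqrt

// Hypothetical plain-Double stand-ins for Spring and Mass
data class S(val leftMassId: Int, val rightMassId: Int, val restingLength: Double)
data class M(val id: Int, val movable: Boolean, val mass: Double,
             val x: Double, val y: Double, val springs: List<S>)

val g = 9.8        // gravitational acceleration
val hooke = 600.0  // Hooke's constant
val ground = 0.0

// Gravity plus clamped spring energy for one mass; fixed masses contribute nothing
fun massEnergy(id: Int, ml: List<M>): Double {
    val m = ml[id]
    if (!m.movable) return 0.0
    val gravity = (m.y - ground) * g * m.mass
    val springs = m.springs.fold(0.0) { acc, s ->
        val a = ml[s.leftMassId]
        val b = ml[s.rightMassId]
        val length = sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y))
        val dl = maxOf(length - s.restingLength, 0.0)
        acc + 0.5 * hooke * dl * dl
    }
    return gravity + springs
}

// Same fold as MassSpringSystem.totalSystemEnergy
fun totalSystemEnergy(ml: List<M>): Double =
    ml.fold(0.0) { acc, m -> acc + massEnergy(m.id, ml) }

// The simple system from SystemBuild1
val spring0 = S(0, 1, 1.0)
val simple = listOf(
    M(0, movable = false, mass = 1.0, x = 0.0, y = 20.0, springs = listOf(spring0)),
    M(1, movable = true, mass = 1.0, x = 0.0, y = 15.0, springs = listOf(spring0)),
)
val energy = totalSystemEnergy(simple)  // 147 (gravity) + 4800 (spring)
```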

### Derivatives

The x and y derivatives of the position of a `Mass` are stored in an instance of the `MassDerivatives` class. The `TotalDerivatives` class holds a `List` of `MassDerivatives`, collecting the derivatives for all the `Mass` instances.

In [10]:
// Derivatives of the x & y position of a single Mass
class MassDerivatives(x : DScalar, y : DScalar) : Vector2(x, y)
    
// Derivatives of all the Masses in the system
data class TotalDerivatives(val d : List<MassDerivatives>)

### Calculating the Derivatives

The `calculateDerivatives()` method of the `CalculateDerivatives` class calls `primalAndReverseDerivative()`, the **DiffKt** reverse-mode automatic differentiation function for user-defined types. Several levels of user-defined differentiable types are involved: the `x` and `y` positions are held in a `Vector2`, which is differentiable; each `Mass` instance is differentiable; and the `List` that holds the `Mass` instances is differentiable, as is the `MassSpringSystem` instance. The function being differentiated is `system::totalSystemEnergy`, the energy function of the system. The custom function `makeDerivatives` extracts the derivatives with respect to the `x` and `y` position of each `Mass` instance.


In [11]:
// Calculate the derivatives of the energy of the system

class CalculateDerivatives(val system : MassSpringSystem) {
    
    // Extract the derivatives of each variable in the system
    
    fun makeDerivatives(input : List<Mass>, output : DScalar, 
                        extractDerivative: (DTensor, DTensor) -> DTensor) : TotalDerivatives {
        
        val listMassDerivatives = input.map {
            val dmassDx = extractDerivative(it.position.x, output) as DScalar
            val dmassDy = extractDerivative(it.position.y, output) as DScalar
            MassDerivatives(dmassDx, dmassDy)
        }
        
        return TotalDerivatives(listMassDerivatives)
    }
    
    // Calculate the derivatives of the total energy of the system
    
    fun calculateDerivatives() : TotalDerivatives {
        
        return primalAndReverseDerivative(
                    x = system.massList,
                    f = system::totalSystemEnergy,
                    extractDerivative = ::makeDerivatives).second  
    }
    
}
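An independent way to sanity-check derivatives like those returned by `primalAndReverseDerivative` is central finite differences. This is a swapped-in technique, not what DiffKt does: it needs two energy evaluations per coordinate instead of one reverse pass. The sketch below is hypothetical and uses plain `Double` arithmetic, with the simple system's energy written as a function of `Mass1`'s y position alone.

```kotlin
import kotlin.math.abs

// Hypothetical helper, not part of DiffKt: central finite differences,
// (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps) for each coordinate i
fun numericalGradient(f: (DoubleArray) -> Double, x: DoubleArray, eps: Double = 1e-5): DoubleArray =
    DoubleArray(x.size) { i ->
        val plus = x.copyOf().also { it[i] += eps }
        val minus = x.copyOf().also { it[i] -= eps }
        (f(plus) - f(minus)) / (2 * eps)
    }

// Energy of the simple system as a function of Mass1's y position only
// (Mass0 fixed at y = 20, resting length 1, m = 1, g = 9.8, h = 600)
fun simpleEnergy(v: DoubleArray): Double {
    val y = v[0]
    val dl = maxOf(abs(20.0 - y) - 1.0, 0.0)
    return 9.8 * y + 0.5 * 600.0 * dl * dl
}

val fd = numericalGradient({ simpleEnergy(it) }, doubleArrayOf(15.0))
// Analytic value at y = 15: dE/dy = 9.8 - 600 * 4 = -2390.2
```

Dividing the negative of this derivative by the mass gives the initial y acceleration of `Mass1`, about 2390.2.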

### Numerical Integration of the Differential Equation

The equations of motion of the mass-spring system are solved with the forward Euler method for differential equations.

The forward Euler update for a differential equation $y' = f(y)$ is

$y_{n+1} = y_n + h \, f(y_n)$

where
$n$ is the step index, and
$h$ is the time step size.
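Newton's second law is second order, so forward Euler is applied to the pair (position, velocity). Writing the step size as $\Delta t$ (`Constants.dt`) to avoid a clash with Hooke's constant $h$, and with $d$ the damping factor `Constants.defaultDampen` that this notebook applies to the velocity, one step of the scheme reads:

$$
\begin{aligned}
a_n &= -\frac{1}{m}\,\nabla E(x_n) \\
v_{n+1} &= d\,\left(v_n + \Delta t\, a_n\right) \\
x_{n+1} &= x_n + \Delta t\, v_n
\end{aligned}
$$

With $d = 1$ this is plain forward Euler for $x' = v$, $v' = a(x)$.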

`systemAcceleration` holds the derivatives of the total system energy with respect to the x and y position of each mass.<br>

For a movable mass, the `acceleration` is the negative of these derivatives divided by the mass of the `Mass`.<br>
The `velocity` is updated by adding `acceleration` times `dt` to the current `velocity`; the result is then multiplied by the damping constant `defaultDampen`.<br>
The `position` is updated by adding the current `velocity` times `dt` to the current `position`.

The `integrate()` function steps the equations forward in time and saves the system state at each step; the returned list starts with the initial state.

In [12]:
// Forward Euler Solver

class ForwardEulerSolver() {
    
    // Solve the differential equation for one step
    
    fun integrateOneStep(system: MassSpringSystem, step : Int) : MassSpringSystem {
    
        val grad = CalculateDerivatives(system)
    
        val systemAcceleration = grad.calculateDerivatives()
        
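        fun newState(p : Mass, acc : MassDerivatives) : Mass {
            
            // Non-movable masses keep their position and have zero velocity and acceleration
            // a_n = -(dE/dposition) / mass, evaluated at the current positions
            val acceleration = if (p.movable) {-acc / p.mass } else Vector2(0f, 0f)
            // v_{n+1} = (v_n + a_n * dt) * damping, using the acceleration just computed
            val velocity =     if (p.movable) {(p.velocity + (acceleration * Constants.dt)) * Constants.defaultDampen } else Vector2(0f, 0f)
            // x_{n+1} = x_n + v_n * dt (forward Euler uses the old velocity)
            val position =     if (p.movable) {p.position + (p.velocity * Constants.dt)} else p.position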
            
            return Mass (id = p.id,
                         time = step * Constants.dt,
                         movable = p.movable,
                         mass = p.mass,
                         springList = p.springList,
                         position = position,
                         velocity = velocity,
                         acceleration = acceleration)
        }  
        
        val massList = system.massList.mapIndexed{idx, value -> newState(value, systemAcceleration.d[idx])}
        
        // returns the new state of the system
        
        return MassSpringSystem(massList)

    }
    
    // Solve the differential equation for a fixed length of time
    
    fun integrate(system : MassSpringSystem, time : Float) : List<MassSpringSystem> {
        
        val steps = (time / Constants.dt).toInt()
        val numSteps = 1..steps
        
        // Save the system state of each step in observations
        val observations = numSteps.runningFold(system) {newSystem, step -> integrateOneStep(newSystem,step) }
        
        return observations
    }
}
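A one-dimensional, plain-`Double` sketch of the same damped forward-Euler loop for the simple system may help separate the numerics from the DiffKt machinery. It is hypothetical (names such as `State` and `step` are not from the notebook) and writes the energy gradient by hand, with the spring length $|y - 20|$ as in the notebook's 2D norm on the line $x = 0$.

```kotlin
import kotlin.math.abs
import kotlin.math.sign

// Mass1 moves on the line x = 0; Mass0 is fixed at y = 20
data class State(val time: Double, val y: Double, val v: Double)

val m = 1.0
val g = 9.8
val h = 600.0
val rest = 1.0
val anchorY = 20.0
val dt = 0.00025
val damp = 0.9995

// a = -(dE/dy) / m with E = m g y + 0.5 h max(|y - anchorY| - rest, 0)^2
fun acceleration(y: Double): Double {
    val dl = maxOf(abs(y - anchorY) - rest, 0.0)
    val dEdy = m * g + h * dl * sign(y - anchorY)
    return -dEdy / m
}

// Forward Euler: the new position uses the old velocity; the new velocity uses the
// acceleration at the old position, then the damping factor is applied
fun step(s: State, n: Int): State {
    val a = acceleration(s.y)
    return State(n * dt, s.y + s.v * dt, (s.v + a * dt) * damp)
}

// Like ForwardEulerSolver.integrate: runningFold keeps the initial state plus one state per step
fun integrate(initial: State, steps: Int): List<State> =
    (1..steps).runningFold(initial) { s, n -> step(s, n) }

val obs = integrate(State(0.0, 15.0, 0.0), 4000)  // 4000 steps of 0.00025 s = 1 s
```

Using `runningFold` keeps the loop free of mutable state, at the cost of storing every intermediate state, which is what the plotting step needs anyway.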

### Plotting

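The `PlotSimulation` class plots `position`, `velocity`, and `acceleration` for the `x` and `y` coordinates of each `Mass`. To fit on one set of axes, velocity is divided by `velocityScaling` (10) and acceleration by `accelerationScaling` (100).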

In [13]:
// Plot each variable for each mass

class PlotSimulation(val observations : List<MassSpringSystem>) {
    
    val dataLength = observations.size

   
    // Plot position, velocity, and acceleration for the x variable
    
    fun plotX(id : Int) : Plot {
             
        val dataColor = List(dataLength){Constants.xPositionAxisLable} + 
                        List(dataLength){Constants.xVelocityAxisLable + " / ${Constants.velocityScaling}" } +
                        List(dataLength){Constants.xAccelerationAxisLable + " / ${Constants.accelerationScaling}" }
                        
        val time = observations.map{it.massList[id].time}                
        val xData =  time + time + time
                   
        val yData = observations.map{(it.massList[id].position.x as FloatScalar).value} +
                    observations.map{(it.massList[id].velocity.x as FloatScalar).value / Constants.velocityScaling} +
                    observations.map{(it.massList[id].acceleration.x as FloatScalar).value / Constants.accelerationScaling}
                    
        val title = "Mass ${id} - " + Constants.xGraphTitle
        val xLable = Constants.timeAxisLable
        val yLable = "Value"
        
        val plotData = mapOf<String, Any>(
            "time" to xData,
            "values" to yData,
            "color" to dataColor)
       
        val p = plot(plotData, title, xLable, yLable)
        return p
        
    }
    
    // Plot position, velocity, and acceleration for the y variable
    
    fun plotY(id : Int) : Plot {
        
        val dataColor = List(dataLength){Constants.yPositionAxisLable} + 
                        List(dataLength){Constants.yVelocityAxisLable + " / ${Constants.velocityScaling}" } +
                        List(dataLength){Constants.yAccelerationAxisLable + " / ${Constants.accelerationScaling}"}
    
        val time = observations.map{it.massList[id].time}                
        val xData =  time + time + time                      
                    
        val yData = observations.map{(it.massList[id].position.y as FloatScalar).value} +
                    observations.map{(it.massList[id].velocity.y as FloatScalar).value / Constants.velocityScaling } +
                    observations.map{(it.massList[id].acceleration.y as FloatScalar).value / Constants.accelerationScaling}
                    
        val title = "Mass ${id} - " + Constants.yGraphTitle
        val xLable = Constants.timeAxisLable
        val yLable = "Value"
        
        val plotData = mapOf<String, Any>(
            "time" to xData,
            "values" to yData,
            "color" to dataColor)
        
        val p = plot(plotData, title, xLable, yLable)
        return p
        
    }
    
    
    // Plot function
    
    fun plot(plotData : Map<String,Any>, 
             title : String, 
             xLable : String, 
             yLable : String) : Plot {
        
        val p = letsPlot(plotData) {x="time"; y="values"; color="color";} + 
            ggsize(1000,600) + 
            geomLine(sampling=samplingSystematic(10000)) + 
            ggtitle(title) +
            xlab(xLable) + 
            ylab(yLable)

        return p    
        
    }
    
    fun plotAll() {
        
        fun plotXandY(id : Int) {
            GGBunch().addPlot(plotX(id), 0, 0)
                     .addPlot(plotY(id), 0, 600)
                     .show()
        }
        
        repeat(observations[0].massList.size) {index -> plotXandY(index)}
        
    }
}
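`plotX` and `plotY` build a "long" table for lets-plot: the time axis is repeated once per series (`time + time + time`), the values are concatenated, and a `color` column names the series of each row. The hypothetical helper below sketches that layout for any number of series; its result is a `Map<String, Any>`, so it could be passed to the notebook's `plot` function. The sample numbers are made up for illustration and are not simulation output.

```kotlin
// Hypothetical helper mirroring how plotX/plotY assemble lets-plot input
fun longFormat(time: List<Float>, series: Map<String, List<Float>>): Map<String, List<Any>> {
    require(series.values.all { it.size == time.size }) { "each series must match the time axis" }
    val times = series.values.flatMap { time }      // time + time + ... (once per series)
    val values = series.values.flatten()            // series concatenated in order
    val labels = series.flatMap { (name, data) -> List(data.size) { name } }
    return mapOf<String, List<Any>>("time" to times, "values" to values, "color" to labels)
}

// Illustrative input only (made-up numbers)
val t = listOf(0f, 0.5f, 1f)
val table = longFormat(t, mapOf(
    "Y Position" to listOf(15f, 16f, 17f),
    "Y Velocity / 10" to listOf(0f, 1f, 2f),
))
```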


### Simple Mass-Spring System

![Simple Mass-Spring System](resources/simple_mass_spring.png)

The simple mass-spring system has a non-movable top mass, a movable mass below it, and a spring in between. The motion is only along the y axis; there is no movement along the x axis.

In [14]:
// Simple Mass-Spring System

object SystemBuild1 {
    
    fun twoMassOneSpring() : MassSpringSystem {

        val spring0 = Spring(id = 0,
                             leftMassId = 0,
                             rightMassId = 1,
                             restingLength = Constants.defaultRestingLength)

        val springList = listOf(spring0)

        val mass0 = Mass(id = 0,    
                         movable = false,
                         mass = Constants.defaultMass,
                         springList = springList,
                         yPosition = 20f)
                         

        val mass1 = Mass(id = 1,
                        movable = true,
                        mass = Constants.defaultMass,
                        springList = springList,
                        yPosition = 15f)
     
        return MassSpringSystem(listOf(mass0, mass1))     
    }
}

### Simulation of Simple Mass-Spring System

The simple mass-spring system is simulated for five seconds. First, the system is created; next, the solver; finally, the solver integrates the system over five seconds and the results are plotted. `Constants.dt` sets the step size to 0.00025 seconds, so the run takes about 20,000 steps.

In [15]:
// Simple Mass-Spring System simulation

fun simulation1() {
    
    val time = 5f
    
    val system = SystemBuild1.twoMassOneSpring()
                           
    val solver = ForwardEulerSolver()
    val observations = solver.integrate(system, time)
 
    PlotSimulation(observations).plotAll()      
}

simulation1()

### Observations from the Plots

`Mass0` is fixed: it does not move and has zero velocity and acceleration. Its position stays at its initial value of `x` = 0 and `y` = 20.

`Mass1` shows a damped oscillation along the `y` axis and does not move along the `x` axis. The acceleration flattens out for periods of time because a spring shorter than its resting length exerts no force in this model. During those intervals the only force on `Mass1` is gravity, so its `y` acceleration stays at the constant -9.8 (about -0.1 after the plot's 1/100 scaling, which sits close to zero on the plot), and its velocity changes only under gravity and the damping.
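Where the oscillation should settle can be derived from the constants alone (a worked estimate, not a value read off the plots). In static equilibrium the spring force balances gravity, $h\,\Delta l^* = m g$, so with `Mass0` at $y = 20$ and a resting length of 1:

$$\Delta l^* = \frac{m g}{h} = \frac{1 \cdot 9.8}{600} \approx 0.0163, \qquad y^* = 20 - 1 - \Delta l^* \approx 18.98$$

The damped motion of `Mass1` decays toward this position.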

### Complex Mass-Spring System

![Complex Mass-Spring System](resources/complex_mass_spring.png)

The complex mass-spring system has four masses and three springs. The top two masses are non-movable; the bottom two are movable. The movable masses start at `x` = 0, directly below the left fixed mass and to the left of the right fixed mass at `x` = 5, so the diagonal spring pulls them sideways and causes motion along the x axis as well as the y axis.

In [16]:
// Complex Mass-Spring System

object SystemBuild2 {
    
    fun fourMassThreeSpring() : MassSpringSystem {

        // Springs: spring0 joins Mass0-Mass2, spring1 joins Mass1-Mass2, spring2 joins Mass2-Mass3
        val spring0 = Spring(id = 0,
                             leftMassId = 0,
                             rightMassId = 2,
                             restingLength = Constants.defaultRestingLength)

        val spring1 = Spring(id = 1,
                             leftMassId = 1,
                             rightMassId = 2,
                             restingLength = Constants.defaultRestingLength)

        val spring2 = Spring(id = 2,
                             leftMassId = 2,
                             rightMassId = 3,
                             restingLength = Constants.defaultRestingLength)

        // Each mass lists the springs attached to it
        val springList0 = listOf(spring0)
        val springList1 = listOf(spring1)
        val springList2 = listOf(spring0, spring1, spring2)
        val springList3 = listOf(spring2)

        // Fixed masses at (0, 20) and (5, 20)
        val mass0 = Mass(id = 0,
                         movable = false,
                         mass = Constants.defaultMass,
                         springList = springList0,
                         yPosition = 20f)

        val mass1 = Mass(id = 1,
                         movable = false,
                         mass = Constants.defaultMass,
                         springList = springList1,
                         xPosition = 5f,
                         yPosition = 20f)

        // Movable masses start at (0, 16) and (0, 12)
        val mass2 = Mass(id = 2,
                         movable = true,
                         mass = Constants.defaultMass,
                         springList = springList2,
                         yPosition = 16f)

        val mass3 = Mass(id = 3,
                         movable = true,
                         mass = Constants.defaultMass,
                         springList = springList3,
                         yPosition = 12f)

        return MassSpringSystem(listOf(mass0, mass1, mass2, mass3))
    }
}
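The initial geometry explains the sideways motion: the movable masses start at `x` = 0, directly below `Mass0` but 5 units to the left of `Mass1`, so `spring1` is stretched diagonally and has a horizontal component. The plain-`Double` sketch below (hypothetical names, same clamped Hookean model) computes each spring's initial stretch and the net horizontal spring force on `Mass2`, $\sum h\,\Delta l\,(x_{\text{other}} - x_2)/\text{length}$.

```kotlin
import kotlin.math.sqrt

// Initial (x, y) positions from SystemBuild2, indexed by mass id
val positions = listOf(0.0 to 20.0, 5.0 to 20.0, 0.0 to 16.0, 0.0 to 12.0)
// Springs as (leftMassId, rightMassId): spring0, spring1, spring2
val springPairs = listOf(0 to 2, 1 to 2, 2 to 3)
val rest = 1.0     // Constants.defaultRestingLength
val hooke = 600.0  // Constants.hooksConstant

fun length(a: Int, b: Int): Double {
    val dx = positions[a].first - positions[b].first
    val dy = positions[a].second - positions[b].second
    return sqrt(dx * dx + dy * dy)
}

// Initial stretch of each spring: 3, sqrt(41) - 1 (about 5.40), 3
val stretches = springPairs.map { (a, b) -> length(a, b) - rest }

// Horizontal spring force on Mass2: only the diagonal spring1 contributes at t = 0
val forceX2 = springPairs.filter { (a, b) -> a == 2 || b == 2 }.sumOf { (a, b) ->
    val other = if (a == 2) b else a
    val len = length(2, other)
    val dl = maxOf(len - rest, 0.0)
    hooke * dl * (positions[other].first - positions[2].first) / len
}
```

The positive result means `Mass2` is initially pulled toward `+x`, toward `Mass1`.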

### Simulation of Complex Mass-Spring System

The complex mass-spring system is also simulated for five seconds: the system is created, a solver is created, the solver integrates the system over five seconds, and the results are plotted. `Constants.dt` again sets a 0.00025 second step size.

In [17]:
// Complex Mass-Spring System simulation

fun simulation2() {

    val time = 5f
    
    val system = SystemBuild2.fourMassThreeSpring()
                           
    val solver = ForwardEulerSolver()
    val observations = solver.integrate(system, time)
 
    PlotSimulation(observations).plotAll()
}

simulation2()

### Observations from the Plots

`Mass0` and `Mass1` are non-movable. Their position is fixed and they have zero velocity and acceleration.

`Mass2` and `Mass3` are movable and move along both the x and y axes. Their motion is damped and chaotic. The acceleration flattens out at times because a spring shorter than its resting length exerts no force in this model. While a spring is slack, only gravity and any still-stretched springs act on the mass, so the velocity changes only under those forces and the damping.

## Conclusions

The purpose of this tutorial was to show how simulations can be built from complex data structures using user-defined types in **DiffKt**. In both simulations, a list of `Mass` objects with `DScalar` variables was built, a differentiable energy function was evaluated over those variables, and its derivatives were used in a differential equation solver.

## The End