<a href="https://colab.research.google.com/github/itsvismay/Trying-Differentiable-Physics-in-Jax/blob/main/1_Jax_Basic_Mass_Springs_Simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Starting With A Basic Mass-Spring Simulation**
Here is a preliminary attempt at creating a simple mass-spring simulation using Jax for automatic differentiation and matplotlib for display.

## Imports
Import everything I need (and possibly a few things I don't end up using). The first line,

```
%matplotlib inline
```
renders plots inline in this Colab notebook.



In [None]:
%matplotlib inline
import jax.numpy as jp
from jax import grad, jit, vmap
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML
from functools import partial

## Create the Mass-Spring System
I will simulate a very simple mesh of nodes (masses) connected by edges (springs). The mesh is stored in a dictionary of vertices and edges so that it can be swapped for a real .obj file in the future.

In [None]:
# Rest positions of the nodes, one row per node: [x, y, z]
X0 = jp.array([[0., 5., 0.],
               [0., 3., 0.],
               [5., 5., 0.]])
# Edges (springs), one row per edge: [n1, n2] indices into X0
E = jp.array([[0, 1],
              [1, 2]])
mesh_obj = {"verts": X0, "edges": E}

## Trying to use states - Jax friendly
Jax is very functional, so system parameters are stored in explicit state objects that are passed to functions. For reference, check out [stateful computations in jax](https://jax.readthedocs.io/en/latest/jax-101/07-state.html).

In [None]:
# World physics defines gravity and timestep, using the fancy new python annotations
class WorldPhysicsState:
    g : jp.ndarray = jp.array([[0.0,-9.8,0.0]])
    dt: float = 0.1

#Spring physics defines all the spring parameters        
class SpringPhysicsState:
    L0: jp.ndarray #rest length of springs
    X0: jp.ndarray #rest pose verts
    E: jp.ndarray #edges between nodes
    M: jp.ndarray # masses at each node
    pin: jp.ndarray #pinned vertices, indexed into X0

    x: jp.ndarray #current vert positions
    k: jp.ndarray #stiffness coefficients for each spring
    v: jp.ndarray #current vert velocities
    states: list #stores the history of motion (might not use)

def initWorld():
    return WorldPhysicsState()
    
def initSpring(mesh_obj):
    springs = SpringPhysicsState()
    # Physics constants
    springs.X0 = mesh_obj["verts"]
    springs.E = mesh_obj["edges"]
    springs.L0 = jp.linalg.norm(springs.X0[springs.E[:,1]] - springs.X0[springs.E[:,0]], axis=1)
    springs.M = 0.1*jp.ones(springs.X0.shape[0]) # mass per vertex

    # Pinned verts indices: TODO: better way to do this (not by hand)
    springs.pin = jp.array([0, 2])

    # Physics variables
    springs.x = springs.X0.copy()
    springs.k = jp.ones(springs.E.shape[0]) # stiffness per spring (element), TODO: turn into variable
    springs.v = 0*springs.x #zero initial v
    
    return springs

world = initWorld()
spring_target = initSpring(mesh_obj)
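
Jax transformations such as `jit` and `grad` work most naturally with immutable containers, so one alternative to the mutable classes above is a `NamedTuple`, which Jax treats as a pytree. The following is only a sketch of that idea; the names `SpringState` and `step_positions` are hypothetical and not part of this notebook:

```python
from typing import NamedTuple
import jax
import jax.numpy as jp

class SpringState(NamedTuple):  # hypothetical name, for illustration only
    x: jp.ndarray  # current vert positions
    v: jp.ndarray  # current vert velocities

def step_positions(state, dt):
    # Return a new state instead of mutating the old one
    return state._replace(x=state.x + dt * state.v)

s0 = SpringState(x=jp.zeros((3, 3)), v=jp.ones((3, 3)))
s1 = jax.jit(step_positions)(s0, 0.1)  # a NamedTuple can be passed straight through jit
```

Here `s0` is left untouched and `s1` is a fresh state, which is the pattern the stateful-computations guide recommends.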

## Defining Energies and Gradients
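Spring energy is `0.5*k*(L - L0)^2`, where `L` is the current spring length and `L0` is its rest length.
Gravity energy is `M*g*h`, where `h` is the height (y coordinate) of the node.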

I use the [vmap](https://jax.readthedocs.io/en/latest/jax.html?highlight=vmap#vectorization-vmap) function to vectorize the computation over all elements (springs) and over all vertices.
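
As a small illustration of what `vmap` does, the sketch below evaluates a per-spring energy for a batch of two springs and compares the result with a plain Python loop. The endpoint data and the helper name `one_spring_energy` are made up for this example:

```python
import jax.numpy as jp
from jax import vmap

def one_spring_energy(X, K, L0):  # hypothetical helper: X has shape (2, 3)
    L = jp.linalg.norm(X[0] - X[1])
    return 0.5 * K * (L - L0) ** 2

# Two springs, each given by its two endpoints: shape (2, 2, 3)
ends = jp.array([[[0., 0., 0.], [3., 0., 0.]],
                 [[0., 0., 0.], [0., 4., 0.]]])
K = jp.array([1.0, 2.0])
L0 = jp.array([2.0, 2.0])

# vmap maps over the leading axis of every argument
batched = vmap(one_spring_energy)(ends, K, L0)
# The same computation written as an explicit loop
looped = jp.stack([one_spring_energy(ends[i], K[i], L0[i]) for i in range(2)])
```

Both `batched` and `looped` hold one energy per spring; the `vmap` version just avoids the Python loop.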


In [None]:
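def spring_energy(X, K, L0):
    # X holds the two endpoints of one spring, shape (2, 3)
    L = jp.linalg.norm(X[0,:] - X[1,:])
    return 0.5*K*(L - L0)**2

def ele_energy(X, K, L0):
    return spring_energy(X,K,L0)

def vert_energy(X, M):
    # M*g*h with h = X[1] (the y coordinate); g is hard-coded here rather than read from world.g
    return M*X[1]*9.81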

def mesh_energy(X, K, L0, Ele, M, ele_en, ver_en):
    s = vmap(ele_en)(X[Ele,:], K, L0)
    v = vmap(ver_en)(X, M)
    return jp.sum(s) + jp.sum(v)

e = mesh_energy(spring_target.x,
                spring_target.k, 
                spring_target.L0,
                spring_target.E,  
                spring_target.M, 
                ele_energy, vert_energy)

# grad differentiates with respect to the first argument (the positions X) by default
grad_fcn = grad(mesh_energy)
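
As a sanity check on the automatic gradient, the sketch below compares `grad` of a single spring energy with the analytic derivative `k*(L - L0)*(x0 - x1)/L`. The endpoint values and the name `single_spring_energy` are made up for illustration:

```python
import jax.numpy as jp
from jax import grad

def single_spring_energy(X, K, L0):  # hypothetical helper: X has shape (2, 3)
    L = jp.linalg.norm(X[0] - X[1])
    return 0.5 * K * (L - L0) ** 2

X = jp.array([[0., 0., 0.], [3., 4., 0.]])  # current length is 5
K, L0 = 2.0, 3.0

g = grad(single_spring_energy)(X, K, L0)  # dE/dX, shape (2, 3)

# Analytic gradient with respect to endpoint 0; endpoint 1 gets the opposite sign
d = X[0] - X[1]
L = jp.linalg.norm(d)
g0 = K * (L - L0) * d / L
```

The force on each endpoint is the negative of the corresponding row of `g`, so a stretched spring pulls its two endpoints toward each other.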

## Writing the simulation
Simulate the springs for a fixed number of time steps using a Verlet-style (semi-implicit Euler) update: velocities are updated from the forces first, then positions are advanced with the new velocities. Damping is a simple `-0.1*v`, and at each time step the positions are saved to an array for later animation.

In [None]:
def simulate(spring, grad_fcn, dt, sim_steps):
    outputs = []
    x,v = (spring.x, spring.v)
    outputs.append(x)
    for _ in range(sim_steps):
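        # forces are the negative energy gradient with respect to the positions
        forces = -1*grad_fcn(x, spring.k, spring.L0, spring.E, spring.M, ele_energy, vert_energy)
        forces = forces.at[spring.pin].set(0) # pinned verts feel no force, so they stay at rest
        # M has one entry per vertex, so add an axis to divide each row of forces by its own mass
        v_new = v + dt*(forces/spring.M[:, None]) - 0.1*v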
        x_new = x + v_new * dt
        x = x_new
        v = v_new
        outputs.append(x)
    return jp.stack(outputs)
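
To see the update rule in isolation, here is a minimal sketch of the same velocity-then-position step applied to a single 1D spring. The `energy` and `step` functions and all parameter values are assumptions chosen for illustration, not code from this notebook:

```python
import jax.numpy as jp
from jax import grad

def energy(x, k=1.0):  # hypothetical 1D spring pulling toward x = 0
    return 0.5 * k * x ** 2

def step(x, v, dt, m=0.1, damping=0.1):
    force = -grad(energy)(x)
    v_new = v + dt * force / m - damping * v  # velocity first ...
    x_new = x + dt * v_new                    # ... then position, using the new velocity
    return x_new, v_new

x, v = 1.0, 0.0
xs = [x]
for _ in range(200):
    x, v = step(x, v, dt=0.1)
    xs.append(float(x))
```

With these values the mass oscillates around zero while the damping term shrinks the amplitude, so `xs` decays toward the rest position.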

## Simulate
Simulate for 200 time steps and record the wall-clock time. This is the baseline time for my system before any optimization. In the next notebook, I will optimize some of the code.

In [None]:
%time states = simulate(spring_target, grad_fcn, world.dt, 200)


CPU times: user 9.27 s, sys: 1.06 s, total: 10.3 s
Wall time: 8.93 s
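
One caveat when timing Jax code: Jax dispatches work asynchronously, so a timer can stop before the computation has actually finished. A common pattern is to call `block_until_ready()` on the result before reading the clock. The sketch below uses a made-up workload (`work` is a hypothetical stand-in, not the simulation above):

```python
import time
import jax.numpy as jp
from jax import jit

@jit
def work(x):  # hypothetical stand-in workload
    return jp.sum(jp.sin(x) ** 2 + jp.cos(x) ** 2)

x = jp.arange(1000.0)

t0 = time.perf_counter()
first = work(x).block_until_ready()   # includes tracing and compilation
t1 = time.perf_counter()
second = work(x).block_until_ready()  # reuses the compiled function
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")
```

Timing the first and second call separately also shows how much of the cost is one-time compilation rather than the computation itself.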


## Plot using matplotlib
The plot function will display a single state. The animate function will display a video of the simulation.

In [None]:


# Plotting tools
def plot(state):
    # state is one frame of vertex positions, e.g. states[0]
    fig = plt.figure()
    ax = fig.add_subplot(autoscale_on=False, xlim=(-1, 10), ylim=(-1, 10))
    nodes = jp.reshape(state, (-1, 3)) # the -1 dimension is inferred
    ax.plot(nodes[:,0], nodes[:,1], 'go-')
    ax.set_aspect('equal')
    ax.grid()
    plt.show()
    
def animate(states):
    fig = plt.figure()
    ax = fig.add_subplot(autoscale_on=False, xlim=(-1, 10), ylim=(-1, 10))
    springs, = ax.plot([], [], 'go-')
    def update(i):
        # NOTE: there is no .set_data() for 3 dim data, so only x and y are drawn
        nodes = states[i] # vertex positions at frame i, shape (num_verts, 3)
        springs.set_data(nodes[:,0], nodes[:,1])
        return springs,
    
    anim = animation.FuncAnimation(fig, update, states.shape[0], interval=100)
    plt.close()
    return anim

Let's see the animation.

In [None]:
anim = animate(states)
HTML(anim.to_html5_video())

## Next
Speeding up this code