<a href="https://colab.research.google.com/github/itsvismay/Trying-Differentiable-Physics-in-Jax/blob/main/3_Jax_Optmizing_Over_Physics_Based_Loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%matplotlib inline
import jax.numpy as jp
from jax import grad, jit, vmap
from jax.lax import scan
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML
from functools import partial

# Plotting tools
def plot(state, xlim=(-1, 10), ylim=(-1, 10)):
    """Draw a single frame of the mesh as green markers joined by lines.

    Args:
        state: Vertex positions; column 0 is plotted as x, column 1 as y.
        xlim: (min, max) limits of the x axis. Defaults to the original
            hard-coded (-1, 10) window.
        ylim: (min, max) limits of the y axis. Defaults to (-1, 10).
    """
    fig = plt.figure()
    ax = fig.add_subplot(autoscale_on=False, xlim=xlim, ylim=ylim)
    # NOTE: 'go-' connects vertices in row order, not through the edge list.
    # That matches the demo mesh (edges 0-1 and 1-2) but not general meshes.
    ax.plot(state[:, 0], state[:, 1], 'go-')
    ax.set_aspect('equal')
    ax.grid()
    plt.show()
    
def animate(states, xlim=(-1, 10), ylim=(-1, 10), interval=100):
    """Build a matplotlib animation of a simulated trajectory.

    Args:
        states: Positions per frame; ``states[i]`` is one frame whose
            column 0 is plotted as x and column 1 as y.
        xlim: (min, max) limits of the x axis. Defaults to (-1, 10).
        ylim: (min, max) limits of the y axis. Defaults to (-1, 10).
        interval: Delay between frames in milliseconds. Defaults to 100.

    Returns:
        The ``FuncAnimation`` object (one frame per entry of ``states``).
    """
    fig = plt.figure()
    ax = fig.add_subplot(autoscale_on=False, xlim=xlim, ylim=ylim)
    springs, = ax.plot([], [], 'go-')

    def update(i):
        # Vertices are joined in row order (see the 'go-' format above).
        nodes = states[i]
        springs.set_data(nodes[:, 0], nodes[:, 1])
        return springs,

    anim = animation.FuncAnimation(fig, update, states.shape[0], interval=interval)
    # Close this specific figure so the inline backend does not also render
    # it as a static image; the animation object keeps its own reference.
    plt.close(fig)
    return anim


# Rest positions of the mesh vertices: one (x, y, z) row per vertex.
X0 = jp.array([
    [0.0, 5.0, 0.0],
    [0.0, 3.0, 0.0],
    [5.0, 5.0, 0.0],
])
# Springs, each given as a pair of row indices into X0.
E = jp.array([
    [0, 1],
    [1, 2],
])
mesh_obj = dict(verts=X0, edges=E)

class WorldPhysicsState:
    """Scene-wide simulation settings, stored as class-level defaults."""

    # Gravitational acceleration as a (1, 3) row vector.
    g: jp.ndarray = jp.array([[0.0, -9.8, 0.0]])
    # Integration time step.
    dt: float = 0.1
        
class SpringPhysicsState:
    """Plain container for a mass-spring mesh and its simulation variables.

    The annotations only document the expected fields; they create no
    attributes. Instances start empty and are populated by ``initSpring``.
    """

    # Rest configuration / constants
    L0: jp.ndarray  # rest length per spring
    X0: jp.ndarray  # rest vertex positions, one row per vertex
    E: jp.ndarray  # springs as pairs of vertex indices
    # Mass per vertex. This was previously declared as ``invM``, but the
    # notebook assigns and reads ``M`` (initSpring / simulate).
    M: jp.ndarray
    pin: jp.ndarray  # indices of vertices held fixed
    # Physics variables
    x: jp.ndarray  # current vertex positions
    k: jp.ndarray  # stiffness per spring
    v: jp.ndarray  # current vertex velocities
    # NOTE(review): never assigned in this notebook.
    states: list

def initWorld():
    """Return a fresh WorldPhysicsState holding gravity and the time step."""
    world_state = WorldPhysicsState()
    return world_state
    
def initSpring(mesh_obj, pin=(0, 2), mass=0.1, stiffness=1.0):
    """Build a SpringPhysicsState from a mesh dictionary.

    Args:
        mesh_obj: Dict with ``"verts"`` (rest vertex positions, one row per
            vertex) and ``"edges"`` (pairs of vertex indices, one row per
            spring).
        pin: Indices of vertices to hold fixed. The default ``(0, 2)`` keeps
            the hand-picked pins used for the demo mesh.
        mass: Mass assigned to every vertex (default 0.1).
        stiffness: Stiffness assigned to every spring (default 1.0).

    Returns:
        A SpringPhysicsState with rest data (``X0``, ``E``, ``L0``, ``M``,
        ``pin``) and physics variables (``x`` at rest, ``k``, and zero
        velocity ``v``).
    """
    springs = SpringPhysicsState()
    # Physics constants
    springs.X0 = mesh_obj["verts"]
    springs.E = mesh_obj["edges"]
    # Rest length of each spring = distance between its two endpoints at rest.
    springs.L0 = jp.linalg.norm(
        springs.X0[springs.E[:, 1]] - springs.X0[springs.E[:, 0]], axis=1
    )
    springs.M = mass * jp.ones(springs.X0.shape[0])  # mass per vertex
    springs.pin = jp.array(pin)

    # Physics variables
    springs.x = springs.X0.copy()
    springs.k = stiffness * jp.ones(springs.E.shape[0])  # stiffness per spring
    springs.v = jp.zeros_like(springs.x)
    return springs

# Global world settings and the spring system whose parameters are optimized.
world, spring_target = initWorld(), initSpring(mesh_obj)

def mesh_energy(X, K, L0, Ele, M, g=9.81):
    """Return the total potential energy of a mass-spring mesh.

    The energy is the sum of one Hookean term per spring,
    ``0.5 * K * (|x_a - x_b| - L0)**2``, plus one gravitational term per
    vertex, ``M * g * y``.

    Args:
        X: Vertex positions, one row per vertex; column 1 is the height (y).
        K: Stiffness per spring, one entry per row of ``Ele``.
        L0: Rest length per spring, one entry per row of ``Ele``.
        Ele: Integer array of vertex-index pairs, one row per spring.
        M: Mass per vertex, one entry per row of ``X``.
        g: Magnitude of gravitational acceleration used for the potential
            term. NOTE(review): the default 9.81 differs from the 9.8 in
            WorldPhysicsState.g; confirm which value is intended.

    Returns:
        Scalar total energy (differentiable w.r.t. ``X``).
    """
    def spring_energy(ends, k, rest_length):
        # ends holds the two endpoint positions of a single spring.
        stretch = jp.linalg.norm(ends[0, :] - ends[1, :]) - rest_length
        return 0.5 * k * stretch * stretch

    def vert_energy(x, m):
        # Gravitational potential of one vertex; x[1] is its height.
        return m * x[1] * g

    s = vmap(spring_energy)(X[Ele, :], K, L0)
    v = vmap(vert_energy)(X, M)
    return jp.sum(s) + jp.sum(v)

# Sanity check: evaluate the energy of the initial (rest) configuration.
e = mesh_energy(
    X=spring_target.x,
    K=spring_target.k,
    L0=spring_target.L0,
    Ele=spring_target.E,
    M=spring_target.M,
)

# dE/dX, jit-compiled; simulate() negates it to obtain per-vertex forces.
grad_fcn = jit(grad(mesh_energy, argnums=0))

def simulate(kL0, spring, grad_fcn, dt, sim_steps):
    """Integrate a damped mass-spring system and return its trajectory.

    Each step uses semi-implicit Euler: the velocity is updated from the
    current forces first, then positions are advanced with the new velocity.

    Args:
        kL0: Pair ``(K, L0)`` of per-spring stiffness and rest length.
        spring: State object providing ``x``, ``v`` (initial positions and
            velocities), ``E`` (springs), ``M`` (mass per vertex, one entry
            per row of ``x``) and ``pin`` (indices of fixed vertices).
        grad_fcn: Callable ``(x, K, L0, E, M)`` returning the energy gradient
            with the same shape as ``x``.
        dt: Time step.
        sim_steps: Number of steps to take.

    Returns:
        Array of shape ``(sim_steps, *spring.x.shape)`` with the positions
        after every step.
    """
    K, L0 = kL0[0], kL0[1]
    # Use the passed-in spring's topology/mass/pins (previously read from
    # the global spring_target, silently ignoring this argument).
    E, M, pin = spring.E, spring.M, spring.pin
    # One mass per vertex: add a trailing axis so each mass divides every
    # coordinate of its own vertex row instead of broadcasting over the
    # coordinate axis.
    M_col = M[:, None]

    def step(carry, _):
        x, v = carry
        forces = -grad_fcn(x, K, L0, E, M)
        # Zero force on pinned vertices; they stay put as long as their
        # velocity is zero (initSpring starts v at zero).
        forces = forces.at[pin].set(0)
        # 0.1 * v is a per-step velocity damping term (not scaled by dt).
        v_new = v + dt * (forces / M_col) - 0.1 * v
        x_new = x + v_new * dt
        return (x_new, v_new), x_new

    _, trajectory = scan(step, (spring.x, spring.v), xs=None, length=sim_steps)
    return trajectory

In [None]:
@jit
def physics_loss(kL0):
    """Return how far the free vertex ends up from the line y = 0.

    Simulates 200 steps of ``spring_target`` with the given
    ``(stiffness, rest_length)`` pair ``kL0`` and returns ``|y|`` of
    vertex 1 (the vertex not pinned by initSpring) in the final frame.
    """
    states = simulate(kL0, spring_target, grad_fcn, world.dt, 200)
    final_y = states[-1, 1, 1]
    # jp.abs rather than jp.linalg.norm: for a scalar both equal |y|, but
    # norm is computed as sqrt(y*y), whose gradient is NaN at y == 0, i.e.
    # exactly at the optimum this loss drives toward.
    return jp.abs(final_y)

# Gradient of the loss w.r.t. the (stiffness, rest length) pair. Jitting the
# gradient compiles the forward and backward passes together and caches the
# result across iterations.
grad_physics_loss = jit(grad(physics_loss))

# Plain gradient descent with a fixed step size.
LEARNING_RATE = 0.01
NUM_ITERATIONS = 100

optVar_k, optVar_L0 = spring_target.k, spring_target.L0
optVars = (optVar_k, optVar_L0)
for _ in range(NUM_ITERATIONS):
    dk, dL0 = grad_physics_loss(optVars)
    optVar_k = optVar_k - LEARNING_RATE * dk
    optVar_L0 = optVar_L0 - LEARNING_RATE * dL0
    optVars = (optVar_k, optVar_L0)

In [None]:
# Re-simulate with the optimized (stiffness, rest length) pair and animate it.
states = simulate(optVars, spring_target, grad_fcn, world.dt, 200)
anim = animate(states)
# Last expression of the cell, so the notebook displays the video.
# NOTE(review): to_html5_video relies on a movie writer (ffmpeg by default)
# being available to matplotlib.
HTML(anim.to_html5_video())