<a href="https://colab.research.google.com/github/itsvismay/Trying-Differentiable-Physics-in-Jax/blob/main/4_Jax_Matching_A_Spring's_Motion_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%matplotlib inline
import jax.numpy as jp
from jax import grad, jit, vmap
from jax.lax import scan
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML
from functools import partial

# Plotting tools
def plot(state):
    """Draw a single mesh configuration as connected green markers.

    Only the first two columns of ``state`` (read as ``state[:, 0]`` and
    ``state[:, 1]``) are plotted; any further columns are ignored.

    Args:
        state: 2-D array with one row per node.
    """
    axes_options = {"autoscale_on": False, "xlim": (-1, 10), "ylim": (-1, 10)}
    figure = plt.figure()
    axes = figure.add_subplot(**axes_options)

    xs = state[:, 0]
    ys = state[:, 1]
    axes.plot(xs, ys, 'go-')

    # Equal aspect keeps the plotted geometry undistorted.
    axes.set_aspect('equal')
    axes.grid()
    plt.show()
    
def animate(states):
    """Build an animation of one mesh trajectory.

    Args:
        states: array indexed as ``states[frame, node, coordinate]``; its
            leading dimension gives the number of frames. Only coordinates
            0 and 1 are drawn.

    Returns:
        The ``animation.FuncAnimation`` object (100 ms between frames).
    """
    axes_options = {"autoscale_on": False, "xlim": (-1, 10), "ylim": (-1, 10)}
    figure = plt.figure()
    axes = figure.add_subplot(**axes_options)
    (line,) = axes.plot([], [], 'go-')

    def draw_frame(frame):
        # Push the x/y columns of the requested frame into the line artist.
        frame_nodes = states[frame, :]
        line.set_data(frame_nodes[:, 0], frame_nodes[:, 1])
        return (line,)

    frame_count = states.shape[0]
    movie = animation.FuncAnimation(figure, draw_frame, frame_count, interval=100)
    # Close the current figure; presumably so the notebook shows only the
    # rendered animation and not an extra static plot.
    plt.close()
    return movie

def animate_both(states1, states2):
    """Build an animation overlaying two mesh trajectories on the same axes.

    Args:
        states1: array indexed as ``states1[frame, node, coordinate]``; drawn
            with green markers ('go-').
        states2: array with the same layout; drawn with blue markers ('bo-').

    Returns:
        The ``animation.FuncAnimation`` object (100 ms between frames). Only
        coordinates 0 and 1 of each node are drawn.
    """
    fig = plt.figure()
    ax = fig.add_subplot(autoscale_on=False, xlim=(-1, 10), ylim=(-1, 10))
    springs1, = ax.plot([], [], 'go-')
    springs2, = ax.plot([], [], 'bo-')

    def update(i):
        # Refresh both line artists from frame i of their own trajectory.
        for line, trajectory in ((springs1, states1), (springs2, states2)):
            nodes = trajectory[i, :]
            line.set_data(nodes[:, 0], nodes[:, 1])
        return springs1, springs2

    # The frame count must come from the arguments. The previous code read
    # ``states.shape[0]``, a name that is not defined in this function, so it
    # either raised NameError or picked up an unrelated global. Using the
    # shorter trajectory keeps every frame index valid for both inputs.
    n_frames = min(states1.shape[0], states2.shape[0])
    anim = animation.FuncAnimation(fig, update, n_frames, interval=100)
    plt.close()
    return anim


# Target mesh: rest positions, one (x, y, z) row per vertex.
X0 = jp.array([
    [0.0, 5.0, 0.0],
    [0.0, 3.0, 0.0],
    [5.0, 5.0, 0.0],
])
# Edge list: each row is the pair of vertex indices (n1, n2) joined by a spring.
E = jp.array([[0, 1], [1, 2]])
mesh_obj = {"verts": X0, "edges": E}

# Second mesh: same edges; only the middle vertex's rest position differs
# (y = 0 instead of y = 3).
X0_2 = jp.array([
    [0.0, 5.0, 0.0],
    [0.0, 0.0, 0.0],
    [5.0, 5.0, 0.0],
])
mesh_obj2 = {"verts": X0_2, "edges": E}

class WorldPhysicsState:
    """Global simulation constants shared by every object in the scene."""

    # Gravity stored as a single (x, y, z) row. NOTE(review): the energy code
    # visible in this file hard-codes 9.81 rather than reading this value.
    g: jp.ndarray = jp.array([[0.0, -9.8, 0.0]])
    # Time-step size handed to the simulator.
    dt: float = 0.1
        
class SpringPhysicsState:
    """Attribute container describing one spring mesh.

    The class only declares field names; instances start empty and are
    populated attribute-by-attribute (see ``initSpring``).
    """

    # --- constants ---
    L0: jp.ndarray    # rest length of every edge
    X0: jp.ndarray    # rest positions, one row per vertex
    E: jp.ndarray     # edge list, one (n1, n2) vertex-index pair per row
    M: jp.ndarray     # mass per vertex; this is what initSpring sets and simulate reads
    invM: jp.ndarray  # NOTE(review): declared but never assigned in the visible code
    pin: jp.ndarray   # indices of the pinned vertices
    # --- physics variables ---
    x: jp.ndarray     # current vertex positions
    k: jp.ndarray     # stiffness per edge
    v: jp.ndarray     # current vertex velocities
    states: list      # NOTE(review): declared but never assigned in the visible code

def initWorld():
    """Create and return a new ``WorldPhysicsState`` instance."""
    world_state = WorldPhysicsState()
    return world_state
    
def initSpring(mesh_obj, pin=(0, 2)):
    """Create a ``SpringPhysicsState`` at rest from a mesh description.

    Args:
        mesh_obj: mapping with a ``"verts"`` array (rest positions, one row
            per vertex) and an ``"edges"`` array (one pair of vertex indices
            per row).
        pin: indices of the vertices held fixed during simulation. Defaults to
            ``(0, 2)``, the value that used to be hard-coded here.

    Returns:
        The populated state: rest lengths ``L0`` derived from the rest
        positions, mass 0.1 per vertex, stiffness 1 per edge, positions equal
        to the rest positions and zero velocities.
    """
    springs = SpringPhysicsState()

    # Physics constants
    springs.X0 = mesh_obj["verts"]
    springs.E = mesh_obj["edges"]
    edge_tails = springs.X0[springs.E[:, 0]]
    edge_heads = springs.X0[springs.E[:, 1]]
    springs.L0 = jp.linalg.norm(edge_heads - edge_tails, axis=1)
    springs.M = 0.1*jp.ones(springs.X0.shape[0])  # mass per vertex

    # Pinned vertex indices; the integer dtype keeps an empty selection usable
    # as an index array.
    springs.pin = jp.asarray(pin, dtype=int)

    # Physics variables
    springs.x = springs.X0.copy()
    springs.k = jp.ones(springs.E.shape[0])  # stiffness per spring (element)
    springs.v = 0*springs.x

    return springs

# Global constants (gravity, time step) and the target spring built from the
# first mesh; its simulated motion is what the optimization tries to match.
world = initWorld()
spring_target = initSpring(mesh_obj)

def mesh_energy(X, K, X0, Ele, M):
    """Total energy of a spring mesh: spring terms plus per-vertex height terms.

    Args:
        X: current vertex positions, one row per vertex.
        K: stiffness per edge.
        X0: rest vertex positions, same layout as ``X``.
        Ele: edge list; each row holds the two vertex indices of one spring.
        M: mass per vertex.

    Returns:
        Scalar: the sum over edges of ``0.5 * K * (length - rest_length)**2``
        plus the sum over vertices of ``M * X[vertex, 1] * 9.81``.
    """
    def _edge_length(endpoints):
        # Distance between the two endpoints of one edge.
        return jp.linalg.norm(endpoints[0, :] - endpoints[1, :])

    def _spring_energy(endpoints, stiffness, rest_endpoints):
        stretch = _edge_length(endpoints) - _edge_length(rest_endpoints)
        return 0.5 * stiffness * stretch * stretch

    def _height_energy(position, mass):
        # Uses the second coordinate of the vertex position.
        return mass * position[1] * 9.81

    # Gather both endpoints of every edge, then evaluate edges and vertices in bulk.
    elastic = vmap(_spring_energy)(X[Ele, :], K, X0[Ele, :])
    gravitational = vmap(_height_energy)(X, M)
    return jp.sum(elastic) + jp.sum(gravitational)


# JIT-compiled gradient of mesh_energy with respect to its first argument (X);
# the simulator negates it to obtain forces.
grad_fcn = jit(grad(mesh_energy))

def simulate(optVar, spring, grad_fcn, dt, sim_steps, damping=0.1):
    """Roll a spring mesh forward in time and return every visited position.

    Each step computes forces as the negated energy gradient, zeroes the
    forces on pinned vertices, updates the velocity (including a damping term
    proportional to the old velocity), then advances positions with the new
    velocity.

    Args:
        optVar: pair ``(K, X0)``: stiffness per edge and rest positions. The
            rest positions are also used as the initial positions.
        spring: object providing ``E`` (edges), ``M`` (mass per vertex),
            ``pin`` (pinned vertex indices) and ``v`` (initial velocities).
        grad_fcn: callable ``grad_fcn(x, K, X0, E, M)`` returning the energy
            gradient with respect to ``x``.
        dt: time-step size.
        sim_steps: number of steps to take.
        damping: fraction of the previous velocity subtracted every step.
            Defaults to 0.1, the value that used to be hard-coded.

    Returns:
        Positions after each step, stacked along a new leading axis of length
        ``sim_steps``.
    """
    K, X0 = optVar[0], optVar[1]

    # Mass as a column so each vertex's force row is divided by that vertex's
    # mass. Dividing by the flat (n,) array broadcast mass over the coordinate
    # axis instead, which is wrong for non-uniform masses and a shape error
    # whenever the vertex count differs from the number of coordinates.
    mass_column = jp.reshape(spring.M, (-1, 1))

    def step(carry, _):
        x, v = carry
        forces = -1 * grad_fcn(x, K, X0, spring.E, spring.M)
        forces = forces.at[spring.pin].set(0)  # pinned vertices feel no force
        v_new = v + dt * (forces / mass_column) - damping * v
        x_new = x + v_new * dt
        return (x_new, v_new), x_new

    # scan threads (x, v) through the steps and stacks each step's x_new.
    _, outputs = scan(step, (X0, spring.v), xs=None, length=sim_steps)
    return outputs

Create a new spring (green) and simulate both the target spring (blue) and the new spring.

In [None]:
# Reference trajectory: 100 steps of the target spring with its own stiffness
# and rest shape.
spring_target_states = simulate((spring_target.k, spring_target.X0), spring_target, grad_fcn, world.dt, 100)
# create new spring
spring_current = initSpring(mesh_obj2)
# change up its stiffness
spring_current.k = 0.25*spring_current.k

# Trajectory of the new spring before any optimization, shown together with
# the target (first argument is drawn in green, second in blue).
initial_states = simulate((spring_current.k, spring_current.X0), spring_current, grad_fcn, world.dt, 100)
HTML(animate_both(initial_states, spring_target_states).to_html5_video())

In [None]:
@jit
def simulation_match_loss(opt_var, target_states):
    """Measure how far the simulated motion is from a target motion.

    Args:
        opt_var: pair ``(K, X0)`` passed straight to ``simulate``.
        target_states: trajectory to match; must be subtractable from the
            100-step trajectory produced for ``spring_current``.

    Returns:
        Scalar norm of the difference between the target trajectory and the
        simulated one, taken over all time steps at once.
    """
    simulated = simulate(opt_var, spring_current, grad_fcn, world.dt, 100)
    mismatch = target_states - simulated
    return jp.linalg.norm(mismatch)

# Gradient of the loss with respect to opt_var (argument 0); because opt_var is
# a (K, X0) pair, the result is a matching pair of gradients.
grad_simulation_match_loss = grad(simulation_match_loss, argnums=0)

In [None]:
# Start the optimization from the new spring's current stiffness and rest shape.
ov_k, ov_X0 = spring_current.k,spring_current.X0
optVars = (ov_k, ov_X0)
# Plain gradient descent: 200 iterations with a fixed step size of 0.01.
for _ in range(200):
    dk,dX0 = grad_simulation_match_loss(optVars, spring_target_states)
    # Zero the rest-position gradient on pinned vertices so they are not moved.
    dX0 = dX0.at[spring_current.pin].set(0)
    ov_k -= dk*0.01
    ov_X0 -= dX0*0.01
    optVars = ov_k, ov_X0
    # Loss is re-evaluated with the updated variables before printing.
    print("loss: "+str(simulation_match_loss(optVars, spring_target_states)))


loss: 28.995747
loss: 27.392427
loss: 25.937677
loss: 24.579266
loss: 23.287758
loss: 22.044804
loss: 20.838396
loss: 19.6605
loss: 18.505585
loss: 17.369783
loss: 16.250349
loss: 15.145308
loss: 14.053273
loss: 12.97332
loss: 11.904969
loss: 10.848172
loss: 9.803411
loss: 8.771848
loss: 7.7556205
loss: 6.758377
loss: 5.7862835
loss: 4.849889
loss: 3.9677947
loss: 3.173218
loss: 2.522156
loss: 2.0842235
loss: 1.8744988
loss: 1.8032478
loss: 1.7782508
loss: 1.7645779
loss: 1.7538676
loss: 1.7441736
loss: 1.7349072
loss: 1.725809
loss: 1.716742
loss: 1.707622
loss: 1.698396
loss: 1.6890308
loss: 1.6794987
loss: 1.6697834
loss: 1.6598693
loss: 1.6497401
loss: 1.6393837
loss: 1.6287875
loss: 1.6179373
loss: 1.6068197
loss: 1.5954221
loss: 1.5837288
loss: 1.5717242
loss: 1.5593954
loss: 1.5467257
loss: 1.5336982
loss: 1.520299
loss: 1.5065099
loss: 1.492312
loss: 1.4776905
loss: 1.4626256
loss: 1.4470999
loss: 1.4310955
loss: 1.4145913
loss: 1.3975722
loss: 1.380017
loss: 1.3619075
loss: 1.

In [None]:
# Re-simulate 100 steps with the optimized (k, X0) and overlay the result on
# the target trajectory.
states = simulate(optVars, spring_current, grad_fcn, world.dt, 100)
anim2 = animate_both(states, spring_target_states)
HTML(anim2.to_html5_video())

In [None]:
# Short 10-step rollout, printed alongside the optimized variables for inspection.
states = simulate(optVars, spring_current, grad_fcn, world.dt, 10)
print(optVars)
print(states)

(DeviceArray([1.0898231, 0.9927029], dtype=float32), DeviceArray([[0.0000000e+00, 5.0000000e+00, 0.0000000e+00],
             [8.6539518e-04, 3.0649388e+00, 0.0000000e+00],
             [5.0000000e+00, 5.0000000e+00, 0.0000000e+00]],            dtype=float32))
[[[0.0000000e+00 5.0000000e+00 0.0000000e+00]
  [8.6539518e-04 2.9668388e+00 0.0000000e+00]
  [5.0000000e+00 5.0000000e+00 0.0000000e+00]]

 [[0.0000000e+00 5.0000000e+00 0.0000000e+00]
  [4.1885194e-03 2.7924933e+00 0.0000000e+00]
  [5.0000000e+00 5.0000000e+00 0.0000000e+00]]

 [[0.0000000e+00 5.0000000e+00 0.0000000e+00]
  [1.6313605e-02 2.5712357e+00 0.0000000e+00]
  [5.0000000e+00 5.0000000e+00 0.0000000e+00]]

 [[0.0000000e+00 5.0000000e+00 0.0000000e+00]
  [4.3233473e-02 2.3357904e+00 0.0000000e+00]
  [5.0000000e+00 5.0000000e+00 0.0000000e+00]]

 [[0.0000000e+00 5.0000000e+00 0.0000000e+00]
  [8.9501612e-02 2.1178215e+00 0.0000000e+00]
  [5.0000000e+00 5.0000000e+00 0.0000000e+00]]

 [[0.0000000e+00 5.0000000e+00 0.000000