In [29]:
# This GitHub serves as a guide for replication and further development of LNNs, consequently it´s heavily commented and explained (for detailed explanation go to LNN for 1M and 1S).
# This code stems from various sources, along the code it will be commented in detail and the sources cited from which the code stems
# I was heavily influenced by Cranmer´s form of implementation of LNNs (https://github.com/MilesCranmer/lagrangian_nns) however there are
# not many simple LNNs available from which to learn, therefore AI was used as a guide to know the overall steps the code should follow (I wrote the code) and for
# debugging and optimization of the performace of the code.


# Inverse Problem LNN N Masses, N+1 Springs

k_vals = [4.1,2,3,4.1,2]

import jax
import jax.numpy as jnp
from jax import grad, jacobian, hessian, vmap, random
import optax
import flax.linen as nn
from flax.training import train_state
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

m = 1.0


# Generate Data with equations of motion

def simulate_system_N(y0, k_vals, t_max=10.0, num_points=100):

    N = len(y0) // 2
    k_vals = jnp.array(k_vals)

    def sum_of_acc(t,y):

      x = y[:N]
      a = jnp.zeros_like(x)
      v = y[N:]

      sum = 0
      a = a.at[0].set((-k_vals[0] * x[0] + k_vals[1] * (x[1] - x[0]))/m)
      a = a.at[-1].set((-k_vals[-1] * x[-1] + k_vals[-2] * (x[-2] - x[-1]))/m)
      if N > 2:
        for i in range(1,N-1):
          a = a.at[i].set((k_vals[i+1] * (x[i+1] - x[i]) + k_vals[i] * (x[i-1] - x[i]))/m)


      return jnp.concatenate([v, a])



    t_span = (0, t_max)
    t_eval = np.linspace(0, t_max, num_points)
    sol = solve_ivp(sum_of_acc, t_span, y0, t_eval=t_eval, method='RK45')

    t = jnp.array(sol.t)
    final = len(y0) // 2
    q = jnp.stack([sol.y[i] for i in range(final)], axis=1)
    q_dot = jnp.stack([sol.y[i] for i in range(final, len(y0))], axis=1)

    dt = t[1] - t[0]
    q_ddot = jnp.gradient(q_dot, dt, axis=0)

    return t, q, q_dot, q_ddot


def data(N =2, n_trajectories=10, num_points=100, k_vals=None):

    q_list = []
    q_dot_list = []
    q_ddot_list = []

    rng = np.random.default_rng(0)
    if k_vals is None:
        k_vals = jnp.ones(N + 1)

    for i in range(n_trajectories):

        y0 = rng.uniform(low=-2.0, high=2.0, size=(2 * N,))

        t_simulated, q, q_dot, q_ddot = simulate_system_N(y0, k_vals, num_points=num_points)

        q_list.append(q)
        q_dot_list.append(q_dot)
        q_ddot_list.append(q_ddot)

    return jnp.concatenate(q_list), jnp.concatenate(q_dot_list), jnp.concatenate(q_ddot_list)

# Definition LNN
class LNN_N_Masses(nn.Module):
    log_k: jnp.array
    m: float = 1.0

    @nn.compact
    def __call__(self, q, q_dot):

        # As it is a generalization one should acces the length of the last dimensio of q to know how many masses it is describing

        x = [q[..., i] for i in range(q.shape[-1])]

        v = [q_dot[..., i] for i in range(q_dot.shape[-1])]

        log_k = self.param('log_k', lambda _: self.log_k)

        k = jnp.exp(log_k)

        T = 0.5 * self.m * jnp.sum(jnp.stack([o**2 for o in v]))

        V = 0.5*k[0]* x[0]**2 + 0.5*k[-1]* x[-1]**2
        for i in range(len(x)-1):
            V += 0.5*k[i+1]*(x[i+1] - x[i])**2
        return T - V


# Definition of Lagrangian calculates lagrangian by LNN and returns acceleration which is a parameter that can be integrated to get q

def lagrangian(LNN_returnable, params, q, q_dot):

  def Lagrangian_from_LNN( q, q_dot):
    #call LNN to return value of Lagrangian
    return LNN_returnable.apply(params, q, q_dot).squeeze()

  # apply definition
  sec = jacobian(grad(Lagrangian_from_LNN, 1), 0)(q, q_dot)
  par = grad(Lagrangian_from_LNN, 0)(q, q_dot) - jnp.matmul(sec, q_dot)

  H = hessian(Lagrangian_from_LNN, 1)(q, q_dot)

  q_ddot_pred = jnp.linalg.pinv(H) @ par
  return q_ddot_pred



# Loss Function
def loss_function(params, model, q, q_dot, q_ddot_data):
    # Standard function used in other LNN codes https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html
    q_ddot_pred = vmap(lambda q, q_dot: lagrangian(model, params, q, q_dot))(q, q_dot)
    # Equivalent to tensorflow´s reduce_mean
    return jnp.mean((q_ddot_pred.squeeze() - q_ddot_data)**2)

# Standard procedure for jax https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html#using-trainstate-in-flax-nnx , initialices de model LNN, defines params, model an optimizer
def create_train_state(rng, model, learning_rate=1e-3):
    init_q = jnp.ones((1,2))
    params = model.init(rng, init_q, init_q)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optax.adam(learning_rate))
@jax.jit
def train(state, q, q_dot, q_ddot):
    loss, grads = jax.value_and_grad(loss_function)(state.params, model, q, q_dot, q_ddot)
    state = state.apply_gradients(grads=grads)
    return state, loss




# Training of LNN
q, q_dot, q_ddot = data(N = 4, n_trajectories=20, num_points=100, k_vals=k_vals)
print(q[1].shape)
print(q_dot.shape)
print(q_ddot.shape)

rng = jax.random.PRNGKey(0)
model = LNN_N_Masses(log_k=jnp.ones(len(q[1])+1))
state = create_train_state(rng, model)

losses = []
for epoch in range(5000):
    state, loss = train(state, q, q_dot, q_ddot)
    losses.append(loss)

    if epoch % 100 == 0:
        k = jnp.exp(state.params['params']['log_k'])
        print(f"Epoch {epoch}, Loss: {loss:.5f}, k=", k)

# Obtain acceleration from Lagrangian NN
q_ddot_pred = vmap(lambda q, q_dot: lagrangian(model, state.params, q, q_dot))(q, q_dot)



(4,)
(2000, 4)
(2000, 4)
Epoch 0, Loss: 1.67663, k= [2.7210016 2.715565  2.7210016 2.7210016 2.7210016]
Epoch 100, Loss: 0.93849, k= [3.0066671 2.476532  2.957669  2.992978  2.844326 ]
Epoch 200, Loss: 0.50104, k= [3.302761  2.3029873 3.043863  3.2530656 2.6614475]
Epoch 300, Loss: 0.23469, k= [3.5734808 2.1807523 3.032013  3.4914532 2.4705684]
Epoch 400, Loss: 0.09541, k= [3.7832844 2.0960398 3.003677  3.689067  2.3056965]
Epoch 500, Loss: 0.03703, k= [3.9185581 2.0392075 2.9799652 3.8323026 2.1772125]
Epoch 600, Loss: 0.01761, k= [3.9930604 2.0040512 2.964362  3.9222112 2.0892043]
Epoch 700, Loss: 0.01243, k= [4.0300145 1.9845033 2.9557164 3.9715254 2.0369146]
Epoch 800, Loss: 0.01131, k= [4.0471106 1.9747431 2.9515479 3.9955332 2.009967 ]
Epoch 900, Loss: 0.01111, k= [4.0545254 1.9703366 2.9497645 4.006018  1.9978083]
Epoch 1000, Loss: 0.01108, k= [4.0575056 1.9685268 2.9490848 4.010141  1.9929516]
Epoch 1100, Loss: 0.01108, k= [4.05861   1.9678478 2.9488556 4.011599  1.9912231]
Epo

In [24]:
#lnn 4 masas

import jax
import jax.numpy as jnp
from jax import grad, jacobian, hessian, vmap, random
import optax
import flax.linen as nn
from flax.training import train_state
import matplotlib.pyplot as plt


m = 1.0
true_ks = jnp.array([1,1,1,1,2,2])

springs = [(0, 1), (2, 3), (0, 2), (1, 3), (0, 3), (1, 2)]

x0 = jnp.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.]])
a_side = jnp.array([jnp.linalg.norm(x0[a] - x0[b]) for a, b in springs])


# Generate data from equations, on idea similar to N masses but generalized to 2D (calculation of direction of mass effect on others)
def simulate_system_2D(T=10.0, dt=0.01, key=random.PRNGKey(0)):

    steps = int(T/dt)
    key1, key2 = random.split(key)

    def cal_of_acc(x1, x2, k, a_fijo):
      # Definition contribution mass j in i
      abs = x2 - x1
      norm = jnp.linalg.norm(abs)

      direction = abs/norm
      return k * (norm - a_fijo) * direction


    xs = x0 + 0.05 * random.normal(key1, (4, 2))
    vs = 0.05 * random.normal(key2, (4, 2))

    # calculate rest a = lattice parameter (side of square)
    a_side = jnp.array([jnp.linalg.norm(x0[a] - x0[b]) for a, b in springs])

    pos_total = []
    vel_total = []
    acc_total = []

    for i in range(steps):

        sum_acc = jnp.zeros((4, 2))

        for i, (a, b) in enumerate(springs):
            F = cal_of_acc(xs[a], xs[b], true_ks[i], a_side[i])/m

            sum_acc = sum_acc.at[a].add(F)
            sum_acc = sum_acc.at[b].add(-F)

        acc = sum_acc

        #definition update velocity and trajectory
        vs = vs + acc* dt
        xs = xs + vs *dt

        pos_total.append(xs)
        vel_total.append(vs)
        acc_total.append(acc)

    return jnp.stack(pos_total), jnp.stack(vel_total), jnp.stack(acc_total)



# LNN definition, same as N Mass and N+1 Springs but sumatory for other total velocities and lenghts minus the normal length of spring

class LNN_2D_Masses(nn.Module):
    log_ks: jnp.array
    a_side: jnp.ndarray
    m: float = 1.0

    @nn.compact
    def __call__(self, q, q_dot):

        # As it is a generalization one should acces the length of the last dimensio of q to know how many masses it is describing

        x = [q[..., i] for i in range(q.shape[-1])]

        v = [q_dot[..., i] for i in range(q_dot.shape[-1])]

        log_ks = self.param('log_ks', lambda _: self.log_ks)
        print(log_ks)

        #ks = [log_ks[i] for i in range(log_ks.shape[-1])]

        ks = jnp.exp(log_ks)
        print(ks)

        T = 0.5 * self.m * jnp.sum(jnp.stack([o**2 for o in v]))
        V = 0
        for i, (a, b) in enumerate(springs):
            V += 0.5*ks[i]* (jnp.linalg.norm(q[a] - q[b]) - a_side[i])** 2
        return T - V

def lagrangian(LNN_returnable, params, q_new, q_dot_new):

  def Lagrangian_from_LNN( q_2D, q_dot_2D):
    #call LNN to return value of Lagrangian
    q = q_2D.reshape((4, 2))
    q_dot = q_dot_2D.reshape((4, 2))
    return LNN_returnable.apply(params, q, q_dot).squeeze()

  # apply definition
  q, q_dot = q_new.reshape(-1), q_dot_new.reshape(-1)
  sec = jacobian(grad(Lagrangian_from_LNN, 1), 0)(q, q_dot)
  par = grad(Lagrangian_from_LNN, 0)(q, q_dot) - jnp.matmul(sec, q_dot)

  H = hessian(Lagrangian_from_LNN, 1)(q, q_dot)

  q_ddot_pred = jnp.linalg.pinv(H) @ par
  return q_ddot_pred.reshape((4, 2))

# Loss Function
def loss_function(params, model, q, q_dot, q_ddot_data):
    # Standard function used in other LNN codes https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html
    q_ddot_pred = vmap(lambda q, q_dot: lagrangian(model, params, q, q_dot))(q, q_dot)
    # Equivalent to tensorflow´s reduce_mean
    return jnp.mean((q_ddot_pred - q_ddot_data)**2)

# Standard procedure for jax https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html#using-trainstate-in-flax-nnx , initialices de model LNN, defines params, model an optimizer
def create_train_state(rng, model, learning_rate=1e-3):
    params = model.init(rng, jnp.ones((4,2)), jnp.ones((4,2)))
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optax.adam(learning_rate))
@jax.jit
def train(state, q, q_dot, q_ddot):
    loss, grads = jax.value_and_grad(loss_function)(state.params, model, q, q_dot, q_ddot)
    state = state.apply_gradients(grads=grads)
    return state, loss




# Training of model (Standard)
q, q_dot, q_ddot = simulate_system_2D()
log_ks_init = jnp.zeros(6)
model = LNN_2D_Masses(log_ks=log_ks_init, a_side=a_side)
state = create_train_state(random.PRNGKey(0), model)




losses = []
for epoch in range(3000):
    state, loss = train(state, q, q_dot, q_ddot)
    losses.append(loss)
    if epoch % 300 == 0:
        ks_learned = jnp.exp(state.params['params']['log_ks'])
        print(f"Epoch {epoch}, Loss={loss:.6f}, ks={ks_learned}")

# Graph
ks_learned = jnp.exp(state.params['params']['log_ks'])


print("Predictions", ks_learned)
print("Real k:", true_ks)



[0. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 1. 1.]
Traced<ShapedArray(float32[6])>with<JVPTrace> with
  primal = Traced<ShapedArray(float32[6])>with<DynamicJaxprTrace>
  tangent = Traced<ShapedArray(float32[6])>with<JaxprTrace> with
    pval = (ShapedArray(float32[6]), None)
    recipe = LambdaBinding()
Traced<ShapedArray(float32[6])>with<JVPTrace> with
  primal = Traced<ShapedArray(float32[6])>with<DynamicJaxprTrace>
  tangent = Traced<ShapedArray(float32[6])>with<JaxprTrace> with
    pval = (ShapedArray(float32[6]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x798940c7d5e0>, in_tracers=(Traced<ShapedArray(float32[6]):JaxprTrace>, Traced<ShapedArray(float32[6]):JaxprTrace>), out_tracer_refs=[<weakref at 0x798940bcf650; to 'JaxprTracer' at 0x798940bcd080>], out_avals=[ShapedArray(float32[6])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[6] b:f32[6]. let c:f32[6] = mul a b in (c,) }, 'in_shardings': (UnspecifiedValue, UnspecifiedValue), 'out_shardings': (UnspecifiedValue,), '