In [6]:
import flax.linen as nn
import flax
import jax, jax.numpy as jnp
import optax
import numpy as np

In [117]:
x_samples = jnp.load('/Users/joshuacoles/Developer/checkouts/fyp/slimplectic-jax/nn/xData_lowNoise.npy').astype(
    'float32')
print(f"x shape {x_samples.shape}")

y_samples = jnp.load('/Users/joshuacoles/Developer/checkouts/fyp/slimplectic-jax/nn/yData_lowNoise.npy').astype(
    'float32')
print(f"y shape {y_samples.shape}")

x shape (20480, 41, 2)
y shape (20480, 3)


In [177]:
class Model(nn.Module):
    lstm_features: int
    action_embedding_dimension: int

    @nn.compact
    def __call__(self, x, train):
        print("Input", x.shape)
        # x shape should be (batch, time, features)
        # features is x, xdot, v, vdot, t
        carry, x = nn.RNN(nn.LSTMCell(features=self.lstm_features), return_carry=True)(x)
        print("Carry", jax.tree_util.tree_map(jnp.shape, carry))
        print("After LSTM", x.shape)

        # Flatten
        x = x.reshape(-1)
        print("After Reshape", x.shape)

        # Dropout during training
        # x = nn.Dropout(0.3, deterministic=not train)(x)
        # print("After Dropout", x.shape)

        # # Is this actually what we want or are we misunderstanding the problem
        # x = x.reshape(x.shape[0])

        x = nn.Dense(self.action_embedding_dimension)(x)
        print("After Dense", x.shape)
        return x


model = Model(lstm_features=5, action_embedding_dimension=3)

In [178]:
loss_rngs = {'dropout': jax.random.key(1)}


# Same as JAX version but using model.apply().
@jax.jit
def loss_function(params, x_batched, y_batched):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x, y):
        print("X", x.shape)
        print("Y", y.shape)
        pred = model.apply(params, x, train=True, rngs=loss_rngs)
        print("PRED", pred.shape)

        a = jnp.inner(y - pred, y - pred) / 2.0
        print(pred.shape)
        print(a.shape)

        return a

    # Vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

In [183]:
learning_rate = 0.3
init_rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
params = model.init(init_rngs, jnp.ones_like(x_samples[0]), train=True)

tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(loss_function)

Input (41, 2)
Carry ((5,), (5,))
After LSTM (41, 5)
After Reshape (205,)
After Dense (3,)


In [180]:
batch_size = 64
batched_x = x_samples.reshape((x_samples.shape[0] // batch_size, batch_size, x_samples.shape[1], x_samples.shape[2]))
batched_y = y_samples.reshape((y_samples.shape[0] // batch_size, batch_size, y_samples.shape[1]))
(batched_x.shape, batched_y.shape)

((320, 64, 41, 2), (320, 64, 3))

In [181]:
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print('Loss step {}: '.format(i), loss_val)

X (41, 2)
Y (3,)
Input (41, 2)
Carry ((5,), (5,))
After LSTM (41, 5)
After Reshape (205,)
After Dense (3,)
PRED (3,)
(3,)
()
Loss step 0:  nan
Loss step 10:  nan
Loss step 20:  nan
Loss step 30:  nan
Loss step 40:  nan
Loss step 50:  nan
Loss step 60:  nan
Loss step 70:  nan
Loss step 80:  nan
Loss step 90:  nan
Loss step 100:  nan


In [None]:
print(f"initialized parameter shapes: {jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params))}")

In [185]:
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)

In [187]:
grads

{'params': {'Dense_0': {'bias': Array([nan, nan, nan], dtype=float32),
   'kernel': Array([[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
 

In [225]:
# Define the squared loss for a single pair (x,y)
def squared_error(p, x, y):
    print("X", x.shape)
    print("Y", y.shape)
    pred = model.apply(p, x, train=False, rngs=loss_rngs)
    delta = jnp.inner(y - pred, y - pred) / 2.0
    print("PRED", pred.shape)
    print("DELTA", delta.shape)

    return delta

In [226]:
jax.grad(lambda params: jnp.mean(jax.vmap(squared_error, in_axes=(None, 0, 0))(
    params,
    x_samples,
    y_samples
)))(params)

X (41, 2)
Y (3,)
Input (41, 2)
Carry ((5,), (5,))
After LSTM (41, 5)
After Reshape (205,)
After Dense (3,)
PRED (3,)
DELTA ()


{'params': {'Dense_0': {'bias': Array([nan, nan, nan], dtype=float32),
   'kernel': Array([[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
 

In [199]:
jax.vmap(squared_error, in_axes=(None, 0, 0))(
    params,
    x_samples,
    y_samples
)

X (41, 2)
Y (3,)
Input (41, 2)
Carry ((5,), (5,))
After LSTM (41, 5)
After Reshape (205,)
After Dense (3,)
PRED (3,)
(3,)
()


(20480,)

In [216]:
jnp.sum(jnp.array()).size

TypeError: Value '<generator object <genexpr> at 0x2e5932dd0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

In [218]:
sum([i.size for i in jax.tree.leaves(params)])

778

In [239]:
from slimpletic import DiscretisedSystem, GGLBundle, SolverManual

q0 = jnp.array([1.0])
pi0 = jnp.array([1.0])

def lagrangian_family(q, v, _, embedding):
    # Fixed power series expansion to make tf happy
    v = q[0] ** 2 * embedding[0] + v[0] ** 2 * embedding[1] + q[0] * v[0] * embedding[2]
    return v


system = DiscretisedSystem(
    ggl_bundle=GGLBundle(r=0),
    dt=0.1,
    lagrangian=lagrangian_family,
    k_potential=None,
    pass_additional_data=True,
)

solver = SolverManual(system)

loss_rngs = {'dropout': jax.random.key(1)}


# Same as JAX version but using model.apply().
# @jax.jit
def loss_function(params, x_batched, y_batched):
    # Define the squared loss for a single pair (x,y)
    def squared_error(trajectory, true_embedding):
        predicted_embedding = model.apply(params, trajectory, train=True, rngs=loss_rngs)
        
        predicted_q, predicted_pi = solver.integrate(
            q0=q0,
            pi0=pi0,
            t0=0,
            iterations=40,
            additional_data=predicted_embedding,
            result_orientation='coordinate'
        )

        true_q, true_pi = solver.integrate(
            q0=q0,
            pi0=pi0,
            t0=0,
            iterations=40,
            additional_data=true_embedding,
            result_orientation='coordinate'
        )
        
        print("TQ", true_q.shape)
        print("PQ", predicted_q.shape)
        diff = (true_q - predicted_q).reshape(-1)
        
        return jnp.dot(diff, diff)

    # Vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

In [243]:
loss_function(params, x_samples, y_samples)

Array(8.07752077e+197, dtype=float64)

In [242]:
jax.value_and_grad(loss_function)(params, x_samples, y_samples)

2024-03-12 12:51:13.710477: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 4m10.814092s

********************************
[Compiling module jit_loss_function] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


(Array(8.07752077e+197, dtype=float64),
 {'params': {'Dense_0': {'bias': Array([nan, nan, nan], dtype=float32),
    'kernel': Array([[nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
           [nan, nan, nan],
          