In [1]:
import flax.linen as nn
import jax, jax.numpy as jnp

In [2]:
training_epochs = 4000
training_time_steps = 10
training_datasize = 20480

In [11]:
# (batch, time, features)
x = jnp.ones((10, 50, 32)).astype('float32')
lstm = nn.RNN(nn.LSTMCell(64))
variables = lstm.init(jax.random.key(0), x)
y = lstm.apply(variables, x)
y

Array([[[-0.02776539, -0.09947269, -0.35887945, ...,  0.1575224 ,
          0.02195803,  0.01176363],
        [-0.06707337, -0.10323231, -0.49960035, ...,  0.21516931,
          0.02352902,  0.00946351],
        [-0.10023084, -0.08195096, -0.565464  , ...,  0.23122399,
          0.02382251,  0.00786299],
        ...,
        [-0.14599933, -0.05934289, -0.6633949 , ...,  0.24032372,
          0.02364738,  0.01068131],
        [-0.14599778, -0.05934377, -0.66339415, ...,  0.2403252 ,
          0.02364752,  0.01068177],
        [-0.14599648, -0.05934456, -0.66339356, ...,  0.2403265 ,
          0.02364765,  0.01068216]],

       [[-0.02776539, -0.09947269, -0.35887945, ...,  0.1575224 ,
          0.02195803,  0.01176363],
        [-0.06707337, -0.10323231, -0.49960035, ...,  0.21516931,
          0.02352902,  0.00946351],
        [-0.10023084, -0.08195096, -0.565464  , ...,  0.23122399,
          0.02382251,  0.00786299],
        ...,
        [-0.14599933, -0.05934289, -0.6633949 , ...,  

In [149]:
class Model(nn.Module):
    lstm_features: int
    action_embedding_dimension: int
    
    @nn.compact
    def __call__(self, x, train):
        # x shape should be (batch, time, features)
        # features is x, xdot, v, vdot, t
        x = nn.RNN(nn.LSTMCell(features=self.lstm_features))(x)
        
        # Dropout during training
        x = nn.Dropout(0.3, deterministic=not train)(x)
        
        # Is this actually what we want or are we misunderstanding the problem
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(self.action_embedding_dimension)(x)
        return x

In [150]:
model = Model(lstm_features=10, action_embedding_dimension=3)

In [122]:
x = jnp.ones((100, 100, 2)).astype('float32')
# We pass a sample input data into hte init function as
variables = model.init(jax.random.key(0), jnp.ones_like(x), train=False)
y = model.apply(variables, x, False)

In [124]:
y.shape

(100, 3)

In [112]:
y.shape

(100, 3)

In [160]:
# Same as JAX version but using model.apply().
@jax.jit
def mse2(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x, train=True, rngs={'dropout': jax.random.key(1)})
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

In [101]:
learning_rate = 0.3

In [161]:
import optax
import numpy as np

tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(variables)
loss_grad_fn = jax.value_and_grad(mse2)

In [133]:
x_samples = jnp.load('/Users/joshuacoles/Developer/checkouts/fyp/slimplectic-jax/nn/xData_lowNoise.npy').astype('float32')
np.array(x_samples).shape

(20480, 41, 2)

In [134]:
y_samples = jnp.load('/Users/joshuacoles/Developer/checkouts/fyp/slimplectic-jax/nn/yData_lowNoise.npy').astype('float32')
y_samples

Array([[ 15.913245  ,  -4.2763476 , -10.811048  ],
       [ -0.8697091 ,   2.8883822 ,  -8.701813  ],
       [ -0.91460717,   3.2345715 ,  11.649809  ],
       ...,
       [-18.918459  , -10.682423  ,  17.807985  ],
       [-16.834541  , -14.153472  ,   6.4526916 ],
       [-17.413532  ,  17.312805  ,   2.3421843 ]], dtype=float32)

In [162]:
init_rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
params = model.init(init_rngs, jnp.ones_like(x_samples).astype('float32'), train=False)

In [163]:
for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

ScopeParamShapeError: Initializer expected to generate shape (410, 3) but got shape (10, 3) instead for parameter "kernel" in "/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [120]:
import flax

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))

initialized parameter shapes:
 {'params': {'Dense_0': {'bias': (3,), 'kernel': (1000, 3)}, 'LSTMCell_0': {'hf': {'bias': (10,), 'kernel': (10, 10)}, 'hg': {'bias': (10,), 'kernel': (10, 10)}, 'hi': {'bias': (10,), 'kernel': (10, 10)}, 'ho': {'bias': (10,), 'kernel': (10, 10)}, 'if': {'kernel': (5, 10)}, 'ig': {'kernel': (5, 10)}, 'ii': {'kernel': (5, 10)}, 'io': {'kernel': (5, 10)}}}}
