In [1]:
import torch
import jax
import jax.numpy as jnp
import time
from env_jax import ParticleEnvJAX
from make_video import watch

In [2]:
from predictive_latent import JaxNetwork

In [3]:
key = jax.random.PRNGKey(0)
env = ParticleEnvJAX(key, n_env=10000, n_particles=10, dt=0.01,
                     interaction_strength=0.2, wall_force_coeff=1000, damping=1, vel_range=2)

# Flatten state: positions + velocities
def flatten_state(pos, vel):
    return jnp.concatenate([pos.reshape(pos.shape[0], -1), vel.reshape(vel.shape[0], -1)], axis=-1)

state_dim = env.n_particles * 2 * 2  # pos + vel for each particle
net = JaxNetwork(key, [state_dim, 128, 128, state_dim])  # predict next state


In [4]:
lr = 1e-3
n_steps = 100000

for step in range(n_steps):
    # 1. Take one environment step
    env.step()
    
    # 2. Get current & next state
    s_n = flatten_state(env.pos, env.vel)
    env.step()
    s_n_next = flatten_state(env.pos, env.vel)
    
    # 3. Update network params (SGD)
    net.params = net.sgd_step(net.params, s_n, s_n_next, lr=lr)
    
    # 4. Optionally measure MSE loss for monitoring
    if step % 50 == 0:
        loss_val = JaxNetwork._loss(net.params, s_n, s_n_next)
        print(f"Step {step}, MSE: {loss_val:.5f}")

Step 0, MSE: 1.08487
Step 50, MSE: 1.18709
Step 100, MSE: 1.17319
Step 150, MSE: 1.15927
Step 200, MSE: 1.14735
Step 250, MSE: 1.13796
Step 300, MSE: 1.12496
Step 350, MSE: 1.11725
Step 400, MSE: 1.10736
Step 450, MSE: 1.09895
Step 500, MSE: 1.09198
Step 550, MSE: 1.08512
Step 600, MSE: 1.07533
Step 650, MSE: 1.06747
Step 700, MSE: 1.06462
Step 750, MSE: 1.05956
Step 800, MSE: 1.05316
Step 850, MSE: 1.04727
Step 900, MSE: 1.03942
Step 950, MSE: 1.03591
Step 1000, MSE: 1.03286
Step 1050, MSE: 1.02911
Step 1100, MSE: 1.02353
Step 1150, MSE: 1.02073
Step 1200, MSE: 1.01800
Step 1250, MSE: 1.01184
Step 1300, MSE: 1.00814
Step 1350, MSE: 1.00344
Step 1400, MSE: 1.00226
Step 1450, MSE: 0.99677
Step 1500, MSE: 0.99456
Step 1550, MSE: 0.99025
Step 1600, MSE: 0.98916
Step 1650, MSE: 0.98597
Step 1700, MSE: 0.98162
Step 1750, MSE: 0.97933
Step 1800, MSE: 0.97646
Step 1850, MSE: 0.97363
Step 1900, MSE: 0.97134
Step 1950, MSE: 0.96751
Step 2000, MSE: 0.96522
Step 2050, MSE: 0.96423
Step 2100, MSE:

KeyboardInterrupt: 

In [5]:
watch(env, model=net, n_steps=200, dt_pred=5, scale=600, output_file="predicted.mp4", fps=30)

TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (40,).

In [6]:
lr = 1e-3
n_steps = 500

In [4]:
env.step()
s_n = flatten_state(env.pos, env.vel)
env.step()
s_n_next = flatten_state(env.pos, env.vel)

In [8]:
net.params = net.sgd_step(net.params, s_n, s_n_next, lr=lr)


In [19]:
s_n.shape

(10, 40)

In [5]:
key = jax.random.PRNGKey(0)
n_steps = 1000

In [12]:
env_jax = ParticleEnvJAX(key, n_env=100000, n_particles=10, dt=0.001,interaction_strength=0.05, damping=0.99,wall_force_coeff=5,vel_range=1)
start = time.time()
for _ in range(n_steps):
    env_jax.step()
jax.device_get(env_jax.pos)  # sync computation
print("JAX:", time.time() - start, "s")
del env_jax

JAX: 0.6337120532989502 s


In [4]:

env_jax = ParticleEnvJAX(key, n_env=1, n_particles=10, dt=0.01,interaction_strength=0.2, wall_force_coeff=1000, damping=1, vel_range=2)


In [6]:
env_jax.vel

DeviceArray([[[-0.54249716, -1.6424489 ],
              [ 0.5954504 ,  0.06958485],
              [ 0.95345116,  0.58791924],
              [-0.5985122 ,  0.82822037],
              [-1.9103389 ,  1.6309776 ],
              [ 0.02079391,  1.5843863 ],
              [ 1.4163332 ,  0.59623003],
              [ 1.2774882 ,  0.5948634 ],
              [-0.7074404 , -1.7265964 ],
              [ 1.2423677 , -0.7262883 ]]], dtype=float32)

In [17]:
env_jax.pos = env_jax.pos.at[0,0,0].set(0.78)


In [5]:
watch(env_jax, n_steps=1000, scale=1000, output_file="simulation.mp4", fps=30)

TypeError: _measure_KE() missing 1 required positional argument: 'vel'

In [31]:
env_jax.pos

DeviceArray([[[ 0.8177988 , -0.73890257],
              [-0.20882733, -0.86405367]]], dtype=float32)