In [11]:
import torch
import jax
import jax.numpy as jnp
import time
from env_torch import ParticleEnv
from env_torch_prealloc import ParticleEnvPreAlloc
from env_jax import ParticleEnvJAX

In [16]:
class ParticleEnvJAX:
    def __init__(self, key, n_env=10, n_particles=10, dt=0.01, box_size=1.0, k=10.0):
        self.n_env = n_env
        self.n_particles = n_particles
        self.dt = dt
        self.box_size = box_size
        self.k = k

        # state initialized to zeros/ones
        self.pos = jnp.zeros((n_env, n_particles, 2))
        self.vel = jnp.zeros((n_env, n_particles, 2))
        self.charge = jnp.ones((n_env, n_particles, 1))

        # buffers
        self.r = jnp.zeros((n_env, n_particles, n_particles, 2))
        self.dist3 = jnp.zeros((n_env, n_particles, n_particles, 1))
        self.qiqj = jnp.zeros((n_env, n_particles, n_particles, 1))
        self.F = jnp.zeros((n_env, n_particles, n_particles, 2))
        self.F_total = jnp.zeros((n_env, n_particles, 2))
        self.over_pos = jnp.zeros((n_env, n_particles, 2))
        self.under_pos = jnp.zeros((n_env, n_particles, 2))
        self.boundary_force = jnp.zeros((n_env, n_particles, 2))

        self.init_random(key)

    def init_random(self, key, pos_range=1.0, vel_range=0.1, charge_range=1.0):
        key1, key2, key3 = jax.random.split(key, 3)
        self.pos = jax.random.uniform(key1, (self.n_env, self.n_particles, 2), minval=-pos_range, maxval=pos_range)
        self.vel = jax.random.uniform(key2, (self.n_env, self.n_particles, 2), minval=-vel_range, maxval=vel_range)
        #self.charge = jax.random.uniform(key3, (self.n_env, self.n_particles, 1), minval=0.0, maxval=charge_range)

    def step(self):
        eps = 1e-4

        self.r = self.pos[:, :, None, :] - self.pos[:, None, :, :]
        print("1 r", self.r.shape)

        self.dist3 = jnp.linalg.norm(self.r, axis=-1, keepdims=True)**3 + eps
        print("2 dist3", self.dist3.shape)

        self.qiqj = self.charge * jnp.transpose(self.charge, (0,2,1))
        print("3 qiqj", self.qiqj.shape)

        self.F = self.r * self.qiqj[:, :, :, None] / self.dist3
        print("4 F", self.F.shape)

        self.F_total = jnp.sum(self.F, axis=2)
        print("5 F_total", self.F_total.shape)

        self.over_pos = jnp.clip(self.pos - self.box_size, a_min=0.0)
        print("6 over_pos", self.over_pos.shape)

        self.under_pos = jnp.clip(-self.box_size - self.pos, a_min=0.0)
        print("7 under_pos", self.under_pos.shape)

        self.boundary_force = -self.k * (self.over_pos + self.under_pos)
        print("8 boundary_force", self.boundary_force.shape)

        self.vel = self.vel + (self.F_total + self.boundary_force) * self.dt
        print("9 vel", self.vel.shape)

        self.pos = self.pos + self.vel * self.dt
        print("10 pos", self.pos.shape)

In [17]:
key = jax.random.PRNGKey(0)
n_steps = 100
n_env = 2
n_particles = 3
dt = 0.01

In [18]:
env_jax = ParticleEnvJAX(key, n_env=n_env, n_particles=n_particles, dt=dt)


In [19]:
env_jax.step()



1 r (2, 3, 3, 2)
2 dist3 (2, 3, 3, 1)
3 qiqj (2, 3, 3)
4 F (2, 3, 3, 2)
5 F_total (2, 3, 2)
6 over_pos (2, 3, 2)
7 under_pos (2, 3, 2)
8 boundary_force (2, 3, 2)
9 vel (2, 3, 2)
10 pos (2, 3, 2)


In [None]:
jax.device_get(env_jax.pos)  # sync computation


In [4]:
env_pre.charge*env_pre.charge.transpose(1,2)

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]], device='cuda:0')

In [5]:
env_pre.step(dt=dt)

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]], device='cuda:0')


In [56]:
env = ParticleEnv(n_env,n_particles)

In [57]:
env.step(dt=dt)

r torch.Size([2, 3, 3, 2])
qiqj torch.Size([2, 3, 3])
dist3 torch.Size([2, 3, 3, 1])


In [None]:
import torch

class ParticleEnv:
    def __init__(self, n_env=10, n_particles=10, device='cuda'):
        self.n_env = n_env
        self.n_particles = n_particles
        self.device = device

        self.pos = torch.zeros((n_env, n_particles, 2), device=device)
        self.vel = torch.zeros((n_env, n_particles, 2), device=device)
        self.charge = torch.ones((n_env, n_particles, 1), device=device)

        self.init_random()

    def init_random(self, pos_range=1.0, vel_range=0.1, charge_range=1.0):
        self.pos.uniform_(-pos_range, pos_range)
        self.vel.uniform_(-vel_range, vel_range)
        #self.charge.uniform_(-charge_range, charge_range)

    def step(self, dt=0.01, box_size=1.0, k=10.0):
        eps = 1e-4
        r = self.pos[:, :, None, :] - self.pos[:, None, :, :]
        dist3 = r.norm(dim=-1, keepdim=True).pow(3) + eps
        qiqj = self.charge * self.charge.transpose(1,2)
        
        F = r * qiqj[:, :, :, None] 
        F = F/ dist3
        F_total = F.sum(dim=2)

        # boundary force as true force
        over_pos = torch.clamp(self.pos - box_size, min=0.0)
        under_pos = torch.clamp(-box_size - self.pos, min=0.0)
        F_boundary = -k * (over_pos + under_pos)

        self.vel += (F_total + F_boundary) * dt
        self.pos += self.vel * dt

In [None]:
for _ in range(n_steps):
    env.step(dt=dt)
torch.cuda.synchronize()
print("PyTorch standard:", time.time()-start, "s")
del env; torch.cuda.empty_cache()