In [None]:
import jax

# Enable 64-bit floats/ints before any arrays are created. The old
# `from jax.config import config` import path was deprecated and later
# removed from JAX; `jax.config.update` is the supported API.
jax.config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

from jarzynski import init_square, forward, imap, forward_n

In [None]:
# JIT-compiled versions of the simulation entry points.
# Positional argument 1 of `init_square` (2, 160 and 155 in the calls below) is
# marked static; presumably it is the number of balls, which would fix array
# shapes -- TODO confirm against jarzynski.init_square.
init_square_j = jax.jit(init_square, static_argnums=(1,))
forward_j = jax.jit(forward)

In [None]:
# Quick check on a tiny system: build it, then advance it by a time step of 1.0.
# NOTE(review): the third init argument is presumably a ball radius -- confirm.
key = jax.random.PRNGKey(0)
state = init_square_j(key, 2, 0.0)
forward_j(1.0, state)

In [None]:
def plot_state(state, proj, draw_walls=True, ax=None, lim=0.4):
    """Draw the balls (and optionally the walls) of `state` as a 2-D picture.

    Parameters
    ----------
    state : dict
        Simulation state. Reads ``state['balls']['x']`` (centres) and
        ``state['balls']['r']`` (radii), and, if `draw_walls`, the wall
        vectors ``state['walls']['x']``, ``['j']`` and ``['k']``.
    proj : callable
        Maps coordinates (last axis) to two components, e.g. ``xy``, ``xz``
        or ``zy``.
    draw_walls : bool, default True
        If True, draw each wall as the parallelogram with corner ``x`` and
        edges ``j`` and ``k``.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``), which
        matches the previous pyplot-state-machine behaviour.
    lim : float, default 0.4
        Half-width of the square view window centred on the origin.
    """
    if ax is None:
        ax = plt.gca()
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('square')
    # Set the limits after axis('square'), which would otherwise adjust them.
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)

    radii = state['balls']['r']
    centers = proj(state['balls']['x'])
    angle = jnp.linspace(0, 2 * jnp.pi, 100)
    # Broadcast to (n_angles, n_balls): each column is one closed circle outline.
    ax.plot(centers[:, 0] + radii * jnp.cos(angle[:, None]),
            centers[:, 1] + radii * jnp.sin(angle[:, None]),
            lw=1, color='black')

    if draw_walls:
        origin = proj(state['walls']['x'])
        j = proj(state['walls']['j'])
        k = proj(state['walls']['k'])
        # Trace the four corners of every wall and return to the start.
        corners = [origin, origin + j, origin + j + k, origin + k, origin]
        ax.plot([c[:, 0] for c in corners], [c[:, 1] for c in corners], 'black')


def _pick(pos, first, second):
    """Return components `first` and `second` of the last axis of `pos`.

    Leading (batch) axes are kept; fancy indexing selects the pair.
    """
    return pos[..., [first, second]]


def xy(pos):
    """Project onto the x-y plane (components 0 and 1)."""
    return _pick(pos, 0, 1)


def xz(pos):
    """Project onto the x-z plane (components 0 and 2)."""
    return _pick(pos, 0, 2)


def zy(pos):
    """Project onto the z-y plane (components 2 and 1)."""
    return _pick(pos, 2, 1)

In [None]:
N = 5  # number of snapshots, one subplot per snapshot
fig, axs = plt.subplots(1, N, figsize=(5.5, 1.27))


def view(ax, state):
    """Clear `ax` and draw the given `state` projected onto the x-y plane.

    Draws its `state` argument explicitly instead of relying on a global.
    """
    plt.sca(ax)
    plt.cla()
    plot_state(state, xy)


def packing_fraction(state):
    """Return total ball area divided by the area of the wall-bounded box.

    Ball area is sum(pi * r**2) over ``state['balls']['r']``. The box width is
    the x-distance between walls 0 and 1, and the height is the y-distance
    between walls 2 and 3, both taken from ``state['walls']['x']``.
    """
    ball_area = jnp.sum(jnp.pi * state['balls']['r']**2)
    walls = state['walls']['x']
    box_area = (walls[1, 0] - walls[0, 0]) * (walls[3, 1] - walls[2, 1])
    return ball_area / box_area


state = init_square_j(jax.random.PRNGKey(0), 160, 4e-2)
vel = 15
# Wall 0 moves in +x, wall 1 in -x, wall 2 in +y and wall 3 in -y.
state['walls']['v'] = jnp.array([
    [vel, 0.0, 0.0],
    [-vel, 0.0, 0.0],
    [0.0, vel, 0.0],
    [0.0, -vel, 0.0],
])
# N - 1 equal steps; in total each wall travels vel * dt * (N - 1) = 0.5.
dt = 0.5 / vel / (N - 1)

for i in range(N):
    ax = axs[i]
    view(ax, state)
    phi = packing_fraction(state)
    ax.set_title(fr"$\phi = {float(phi):.2f}$")

    if i < N - 1:
        # NOTE(review): the meaning of the first return value of `forward` is
        # defined in jarzynski -- TODO document what `n` counts.
        n, state, _ = forward_j(dt, state)
        print(n)

plt.tight_layout(w_pad=0.0, h_pad=0.3)
plt.savefig('jamming.pgf')

In [None]:
# pi / (2 * sqrt(3)) ~= 0.9069 is the density of the hexagonal circle packing;
# presumably shown as a reference for the phi values in the jamming figure.
PHI_HEX = jnp.pi / (2 * 3**0.5)
PHI_HEX

In [None]:
# Second compression run: a fresh system with the walls driven at speed `vel`.
# NOTE(review): 155 is presumably the number of balls and 4e-2 their radius --
# confirm against jarzynski.init_square.
state = init_square_j(jax.random.PRNGKey(0), 155, 4e-2)
vel = 10.0
# Wall 0 moves in +x, wall 1 in -x, wall 2 in +y and wall 3 in -y.
state['walls']['v'] = jnp.array([
    [vel, 0.0, 0.0],
    [-vel, 0.0, 0.0],
    [0.0, vel, 0.0],
    [0.0, -vel, 0.0],
])

In [None]:
t, state = forward_n(10_000, state)

# Packing fraction after the run: total ball area over the box area, with the
# box width from walls 0/1 (x) and the height from walls 2/3 (y).
x = state['walls']['x']
A = (jnp.pi * state['balls']['r']**2).sum()
width = x[1, 0] - x[0, 0]
height = x[3, 1] - x[2, 1]
B = width * height
phi = A / B

phi