In [None]:
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegFileWriter
from tqdm.auto import tqdm
from functools import lru_cache

from jarzynski import init_piston, forward

In [None]:
# jit-compiled entry points of the simulation. Positional argument 1 of
# init_piston (the `n` passed by the callers below) is marked static, so JAX
# compiles a separate trace for each distinct value of it.
forward_j = jax.jit(forward)
init_piston_j = jax.jit(init_piston, static_argnums=(1,))

In [None]:
def experiment(key, xs, n, vel, r):
    """Simulate one piston compression and return the cumulative work along it.

    The two walls get z-velocities +vel/2 and -vel/2, so their separation
    changes by `vel` per unit time and displacement `x` is reached at time
    `x / vel`. The state is advanced with `forward_j` from one displacement in
    `xs` to the next, and the work of each step is accumulated.

    Parameters
    ----------
    key : PRNG key used by `init_piston_j` to draw the initial state.
    xs : 1-D array of displacements at which the work is recorded.
        NOTE(review): assumes xs[0] == 0, i.e. the first entry is the initial
        state, whose accumulated work is 0. Confirm for any new caller.
    n : presumably the number of balls; forwarded to `init_piston_j`.
    vel : relative speed of the two walls.
    r : presumably the ball radius; forwarded to `init_piston_j`.

    Returns
    -------
    Array with one entry per element of `xs`: the work accumulated up to that
    displacement (starting from 0).
    """
    state = init_piston_j(key, n, r)
    state['walls']['v'] = jnp.array([
        [0.0, 0.0, vel / 2],
        [0.0, 0.0, -vel / 2],
    ])
    t = 0
    ws = [0]
    for x in xs[1:]:
        dt = x / vel - t
        # forward_j's first output is not needed. Discard it instead of
        # overwriting the `n` argument with an unrelated value.
        _, state, work = forward_j(dt, state)
        t += dt
        ws.append(ws[-1] + work)

    return jnp.array(ws)

# One PRNG key per trajectory (500 trajectories) and the 100 displacements,
# from 0 to 1.5, at which the work is recorded.
keys = jax.random.split(jax.random.PRNGKey(1), num=500)
xs = jnp.linspace(0.0, 1.5, 100)

@lru_cache(maxsize=128)
def foobar(n, vel, r):
    """Run `experiment` once per key in the global `keys` and stack the results.

    Column j of the returned array is the cumulative-work curve of trajectory j.
    If each step's work is a scalar, the shape is (len(xs), len(keys)).

    NOTE: results are memoised on (n, vel, r) only. The cache does not notice
    changes to the globals `keys` or `xs`, so re-run this cell after editing them.
    """
    curves = []
    for key in tqdm(keys):
        curves.append(experiment(key, xs, n, vel, r))
    return jnp.stack(curves, axis=1)

In [None]:
# Volumes entering Z: before the compression (A) and at each displacement x (B).
# Presumably a cylinder of cross-section pi whose length shrinks from 2 to 2 - x.
Va = jnp.pi * 2
Vb = jnp.pi * (2 - xs)

def Z(n, V, r):
    """Approximate configurational partition function of n balls of radius r in volume V.

    Only the leading pair-overlap correction is kept:

        Z ~ V**n * (1 - n (n - 1) / 2 * v / V),   v = 4/3 * pi * r**3

    The truncation is only sensible while n (n - 1) / 2 * v / V << 1. It becomes
    non-positive once n (n - 1) / 2 * v >= V.

    NOTE(review): for hard spheres of radius r, the pair excluded volume is
    usually 4/3 * pi * (2 r)**3, i.e. 8x the ball volume used here. Confirm
    against how `r` enters the collision code before relying on this.
    """
    ball_volume = (4 / 3) * jnp.pi * r ** 3
    pair_count = n * (n - 1) / 2
    return V ** n * (1 - pair_count * ball_volume / V)

plt.style.use('classic')

fig, axs = plt.subplots(1, 3, dpi=100, figsize=(8, 3))

# One entry per panel: (n, piston speed, ball radius, lower limit of both log axes).
panels = [
    (5, 0.01, 0.1, 1e-4),
    (5, 2.0, 0.1, 1e-4),
    (50, 0.01, 0.1, 3e-7),
]

for ax, (n, vel, r, lo) in zip(axs, panels):
    works = foobar(n, vel, r)
    # Z_B / Z_A from the approximate Z; identical for every curve of the panel.
    z_ratio = Z(n, Vb, r) / Z(n, Va, r)

    # Jarzynski estimator <exp(-W)> using the first m trajectories. The works
    # enter as W / T, i.e. presumably in units of T (see the axis labels).
    for m in [1, 10, works.shape[1]]:
        ax.plot(
            z_ratio,
            jnp.exp(-works[:, :m]).mean(1),
            '-',
            label=fr'$N_{{\mathrm{{avg}}}} = {m}$'
        )

    ax.legend(loc=0, fontsize=12, labelspacing=0, handlelength=1.5, handletextpad=0.2, frameon=False)
    # y = x guide, drawn after the legend so it is not listed there. It starts at
    # the positive lower limit rather than 0, so it does not depend on how the
    # log scale treats non-positive values.
    ax.plot([lo, 1], [lo, 1], '--k')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_title(fr'$n={n}$, $v_p={vel}$, $r={r}$', fontsize=13)
    ax.set_xlim(lo, 1)
    ax.set_ylim(lo, 1)
    # y ticks are removed on every panel (both axes share the same limits).
    ax.set_yticks([])

axs[1].set_xlabel(r'$e^{-\Delta F / T} = Z_B/Z_A$')
axs[0].set_ylabel(r'$\langle e^{-W / T} \rangle$')

plt.tight_layout(pad=0.1)

plt.savefig('je.pgf')

In [None]:
def plot_piston(state, proj, draw_walls=True):
    """Draw the balls, and optionally the walls, of `state` on the current axes.

    Parameters
    ----------
    state : dict
        Simulation state. Reads state['balls']['x'] and state['balls']['r'],
        plus state['walls']['x'], ['j'] and ['k'] when walls are drawn.
    proj : callable
        Projection of 3-D coordinates to 2-D, e.g. `xy`, `xz` or `zy`.
    draw_walls : bool
        If true, outline each wall as the closed path x, x+j, x+j+k, x+k, x.
    """
    plt.axis('off')
    plt.axis('square')
    plt.xlim(-1.1, 1.1)
    plt.ylim(-1.1, 1.1)

    # Each ball becomes one column: its projected centre plus radius * (cos, sin),
    # sampled at 100 angles.
    angles = jnp.linspace(0, 2 * jnp.pi, 100)[:, None]
    centres = proj(state['balls']['x'])
    radii = state['balls']['r']
    plt.plot(
        centres[:, 0] + radii * jnp.cos(angles),
        centres[:, 1] + radii * jnp.sin(angles),
    )

    if not draw_walls:
        return

    origin = proj(state['walls']['x'])
    edge_j = proj(state['walls']['j'])
    edge_k = proj(state['walls']['k'])
    corners = [origin, origin + edge_j, origin + edge_j + edge_k, origin + edge_k, origin]
    plt.plot([c[:, 0] for c in corners], [c[:, 1] for c in corners], 'white')


def _pick(pos, first, second):
    """Return components `first` and `second` of the last axis of `pos`, in that order."""
    return pos[..., [first, second]]


def xy(pos):
    """Project onto the (x, y) plane: last-axis components 0 and 1."""
    return _pick(pos, 0, 1)


def xz(pos):
    """Project onto the (x, z) plane: last-axis components 0 and 2."""
    return _pick(pos, 0, 2)


def zy(pos):
    """Project onto the (z, y) plane: last-axis components 2 and 1."""
    return _pick(pos, 2, 1)

In [None]:
plt.style.use('dark_background')

fig, [axs1, axs2] = plt.subplots(2, 3, figsize=(9, 6))

def view(axs, states):
    """Draw one simulation state on a row of three axes.

    Parameters
    ----------
    axs : sequence of three matplotlib Axes
        They receive the x-z projection, the x-y projection (balls only, plus a
        unit circle) and the z-y projection, in that order.
    states : dict
        The single state to draw, as returned by `init_piston_j` / `forward_j`.
        Despite the plural name, this is one state. Previously this argument
        was ignored and the global `state` was drawn instead.
    """
    ax1, ax2, ax3 = axs

    plt.sca(ax1)
    plt.cla()
    plot_piston(states, xz)

    plt.sca(ax2)
    plt.cla()
    plot_piston(states, xy, False)
    # Unit circle in the x-y view; presumably the outline of the cylinder wall.
    phi = jnp.linspace(0, 2 * jnp.pi, 100)
    plt.plot(jnp.cos(phi), jnp.sin(phi), 'white')

    plt.sca(ax3)
    plt.cla()
    plot_piston(states, zy)

# Top row: initial state.
state = init_piston_j(jax.random.PRNGKey(0), 50, 1e-1)
view(axs1, state)

# Bottom row: after closing the walls by 1.5 at relative speed `vel`.
vel = 15
state['walls']['v'] = jnp.array([
    [0.0, 0.0, vel / 2],
    [0.0, 0.0, -vel / 2],
])
_, state, _ = forward_j(1.5 / vel, state)
view(axs2, state)

plt.tight_layout()
plt.savefig('compression.pdf')

In [None]:
plt.style.use('dark_background')

fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=(9, 3))

for ax in [ax1, ax2, ax3]:
    plt.sca(ax)
    plt.axis('off')
plt.tight_layout()

state = init_piston_j(jax.random.PRNGKey(0), 50, 1e-1)

fps = 20
dt = 5e-2           # simulated time advanced per movie frame
n_frames = 300

# Wall 1 moves during simulated time [5, 10) and closes the piston by 1.5 in
# total. Integer frame bounds give exactly 100 moving frames. The previous
# float test `5.0 < i * dt < 10.0` skipped frame 100 (100 * 0.05 == 5.0), so
# the wall moved for only 99 frames and travelled 1.485 instead of 1.5.
t_start, t_stop = 5.0, 10.0
moving_frames = range(round(t_start / dt), round(t_stop / dt))
v_move = -1.5 / (t_stop - t_start)

moviewriter = FFMpegFileWriter(fps=fps)

with moviewriter.saving(fig, 'myfile.mp4', dpi=200):
    for i in tqdm(range(n_frames)):
        v = v_move if i in moving_frames else 0.0

        # Wall 0 stays fixed; only wall 1 moves along z.
        state['walls']['v'] = jnp.array([
            [0.0, 0.0, 0.0],
            [0.0, 0.0, v],
        ])

        _, state, _ = forward_j(dt, state)

        # Same three panels as the compression figure (x-z, x-y with circle, z-y).
        view((ax1, ax2, ax3), state)

        moviewriter.grab_frame()