In [None]:
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegFileWriter
from tqdm.auto import tqdm
from functools import lru_cache
from plot_tools import texnum

from jarzynski import init_piston, forward, imap

In [None]:
# JIT-compiled versions of the simulation entry points imported from `jarzynski`.
# Positional argument 1 of `init_piston` is static, so each distinct value
# passed there (presumably the number of balls — verify against `jarzynski`)
# triggers a separate compilation.
init_piston_j = jax.jit(init_piston, static_argnums=1)
forward_j = jax.jit(forward)

In [None]:
# Smoke test: build a 20-ball state with r = 0 and advance it once with dt = 1.0.
# The result of `forward_j` is the cell output.
state = init_piston_j(jax.random.PRNGKey(0), 20, 0.0)
forward_j(1.0, state)

In [None]:
def experiment(key, xs, n, vel, r):
    """Run one piston simulation and return the cumulative work along it.

    Args:
        key: JAX PRNG key used to initialise the piston state.
        xs: increasing sequence of total piston displacements. The first
            entry is treated as the starting point (zero work) and is not
            simulated.
        n: second argument of ``init_piston_j`` (presumably the ball count).
        vel: relative speed of the two walls; each wall is given ``vel / 2``
            along z, with opposite signs. Must be non-zero.
        r: third argument of ``init_piston_j`` (presumably the ball radius).

    Returns:
        Array with one entry per element of ``xs``: the running sum of the
        work values reported by ``forward_j``, starting at 0.
    """
    state = init_piston_j(key, n, r)

    # Both walls move along z in opposite directions.
    half_vel = vel / 2
    state['walls']['v'] = jnp.array([
        [0.0, 0.0, half_vel],
        [0.0, 0.0, -half_vel],
    ])

    elapsed = 0
    cumulative_work = [0]
    for displacement in xs[1:]:
        # Displacement `x` is reached at time x / vel; step by the remainder.
        step = displacement / vel - elapsed
        _, state, work = forward_j(step, state)
        elapsed += step

        cumulative_work.append(cumulative_work[-1] + work)

    return jnp.array(cumulative_work)

# Piston displacements at which the cumulative work is recorded.
xs = jnp.linspace(0.0, 1.5, 100)


@lru_cache
def foobar(n, vel, r, rep):
    """Run ``rep`` independent experiments and stack their work traces.

    Results are memoised on ``(n, vel, r, rep)``. The module-level ``xs`` is
    not part of the cache key, so changing ``xs`` requires clearing the cache.

    Returns:
        Array of shape ``(len(xs), rep)``; column ``k`` is the cumulative work
        of the experiment seeded with the ``k``-th key split from
        ``PRNGKey(1)``.
    """
    keys = jax.random.split(jax.random.PRNGKey(1), num=rep)

    traces = []
    for key in tqdm(keys):
        traces.append(experiment(key, xs, n, vel, r))

    return jnp.stack(traces, axis=1)

In [None]:
# Reference volumes for the approximate partition function below.
# Va is the initial volume; Vb is the volume at each displacement in `xs`
# (presumably a cylinder of unit radius whose height shrinks from 2 to 2 - x).
Va = 2 * jnp.pi
Vb = (2 - xs) * jnp.pi

def Z(n, V, r):
    """Approximate configurational partition function with a pair correction.

    Computes ``V**n * (1 - n(n-1)/2 * v / V)`` where ``v = 4/3 * pi * r**3``.

    NOTE(review): if ``r`` is the ball radius, ``v`` is the volume of a single
    ball; the volume excluded to a second ball's centre would be
    ``4/3 * pi * (2r)**3``. Confirm which one is intended.

    Args:
        n: number of balls.
        V: container volume (scalar or array).
        r: ball size parameter.
    """
    ball_volume = 4/3 * jnp.pi * r**3
    pair_count = n * (n-1) / 2
    return V**n * (1 - pair_count * ball_volume / V)

plt.style.use('classic')

fig, axs = plt.subplots(1, 3, dpi=100, figsize=(8, 3))

# One (n, vel, r, rep) tuple per panel. `foobar` requires the repetition count
# `rep` as its fourth positional argument, so it has to be part of each tuple
# (calling it with only three arguments raises a TypeError).
# NOTE(review): 1000 repetitions mirrors the two-panel variant of this figure
# further down in the notebook — confirm the intended value.
n_rep = 1000
exps = [(5, 0.01, 0.1, n_rep), (5, 2.0, 0.1, n_rep), (50, 0.01, 0.1, n_rep)]
# exps = [(10, 0.01, 1e-6, n_rep)]

for ax, (n, vel, r, rep) in zip(axs, exps):
    plt.sca(ax)
    works = foobar(n, vel, r, rep)

    # The predicted ratio Z_B / Z_A does not depend on how many runs are averaged.
    z_ratio = Z(n, Vb, r) / Z(n, Va, r)

    # Jarzynski average over the first m runs, for increasing m.
    for m in [1, 10, works.shape[1]]:
        plt.plot(
            z_ratio,
            jnp.exp(-works[:, :m]).mean(1),
            '-',
            label=fr'$N_{{\mathrm{{avg}}}} = {m}$'
        )

    plt.legend(loc=0, fontsize=12, labelspacing=0, handlelength=1.5, handletextpad=0.2, frameon=False)
    # Diagonal: where the Jarzynski equality holds exactly.
    plt.plot([0,1], [0,1], '--k')
    plt.xscale('log')
    plt.yscale('log')
    plt.title(fr'$n={n}$, $v_p={vel}$', fontsize=13)

# Square axis ranges per panel; the third panel reaches smaller values.
for ax, lower in zip(axs, (1e-4, 1e-4, 3e-7)):
    plt.sca(ax)
    plt.xlim(lower, 1)
    plt.ylim(lower, 1)
    plt.yticks([])

plt.sca(axs[1])
plt.xlabel(r'$e^{-\Delta F / T} = Z_B/Z_A$')
plt.sca(axs[0])
plt.ylabel(r'$\langle e^{-W / T} \rangle$')

plt.tight_layout(pad=0.1)

# NOTE(review): later figure cells in this notebook also write to 'je.pgf'.
plt.savefig('je.pgf')

In [None]:
def random_piston_positions(key, z, n, r):
    """Place ``n`` balls by sequential rejection sampling.

    Candidates are drawn uniformly from the box
    ``[-(1-r), 1-r]^2 x [-(z-r), z-r]`` and accepted when their x-y distance
    from the axis is below ``1 - r`` and they are more than ``2 r`` away from
    every row of the position array.

    Args:
        key: JAX PRNG key.
        z: half-height of the container along the third axis.
        n: number of balls (a Python int; it fixes the array shape).
        r: ball radius.

    Returns:
        ``(key, pos, ok)``: the advanced key, an ``(n, 3)`` array of accepted
        positions, and a flag that is 1 only if every ball was accepted on its
        first draw (0 as soon as any ball had to be re-drawn).
    """
    def draw(rng):
        rng, sub = jax.random.split(rng)
        unit = jax.random.uniform(sub, (3,), minval=-1.0, maxval=1.0)
        candidate = unit * jnp.array([1 - r, 1 - r, z - r])
        return rng, candidate

    def place_ball(i, carry):
        rng, placed, first_try = carry

        def is_valid(candidate):
            inside = jnp.linalg.norm(candidate[:2]) < 1.0 - r
            apart = jnp.all(jnp.linalg.norm(placed - candidate, axis=1) > 2 * r)
            return inside & apart

        def needs_redraw(loop_state):
            _, candidate, _ = loop_state
            return ~is_valid(candidate)

        def redraw(loop_state):
            loop_rng, _, _ = loop_state
            # Any re-draw clears the "accepted on first try" flag.
            return draw(loop_rng) + (0,)

        rng, candidate = draw(rng)
        rng, candidate, first_try = jax.lax.while_loop(
            needs_redraw, redraw, (rng, candidate, first_try)
        )
        return rng, placed.at[i].set(candidate), first_try

    # Not-yet-placed rows sit at (2, 2, 2) so they do not block early balls.
    placeholder = 2.0 * jnp.ones((n, 3))
    key, pos, ok = jax.lax.fori_loop(0, n, place_ball, (key, placeholder, 1))

    return key, pos, ok


def partition_function(key, z, n, r, num):
    """Monte-Carlo estimate of the configurational partition function.

    Runs ``random_piston_positions`` for ``num`` independent keys and
    multiplies the fraction of runs whose ``ok`` flag is set by the sampling
    box volume raised to the power ``n``.

    Args:
        key: JAX PRNG key, split into ``num`` sub-keys.
        z: half-height of the container.
        n: number of balls (static under jit).
        r: ball radius.
        num: number of Monte-Carlo samples (static under jit).
    """
    keys = jax.random.split(key, num=num)

    # Map over the keys only; z, n and r are shared by all samples.
    sampler = imap(random_piston_positions, (0, None, None, None), 0)
    _, _, first_try = sampler(keys, z, n, r)

    box_volume = (2 - 2 * r)**2 * (2 * z - 2 * r)
    return box_volume**n * jnp.mean(first_try)


# `n` and `num` determine array shapes, so they must be compile-time constants.
partition_function = jax.jit(partition_function, static_argnums=(2, 4))


@lru_cache
def partition_function2(n, r, num):
    """Estimate the partition function at every displacement in ``xs``.

    A displacement ``x`` corresponds to a half-height of ``1 - x / 2``. Every
    estimate uses the same ``PRNGKey(0)``. Results are memoised on
    ``(n, r, num)``; the module-level ``xs`` is not part of the cache key.

    Returns:
        Array with one estimate per element of ``xs``.
    """
    estimates = []
    for x in tqdm(xs):
        half_height = 1 - x / 2
        estimates.append(
            partition_function(jax.random.PRNGKey(0), half_height, n, r, num)
        )
    return jnp.stack(estimates)

In [None]:
# Compute (and cache) the numerical partition function for n = 8, r = 0.1
# with 200 000 samples per displacement; the figure cell below reuses it.
partition_function2(8, 0.1, 200_000)

In [None]:
plt.style.use('classic')

# Reference volumes for the approximate partition function below.
# Va is the initial volume; Vb is the volume at each displacement in `xs`
# (presumably a cylinder of unit radius whose height shrinks from 2 to 2 - x).
Va = 2 * jnp.pi
Vb = (2 - xs) * jnp.pi

def Z(n, V, r):
    """Approximate configurational partition function with a pair correction.

    Computes ``V**n * (1 - n(n-1)/2 * v / V)`` where ``v = 4/3 * pi * r**3``.

    NOTE(review): if ``r`` is the ball radius, ``v`` is the volume of a single
    ball; the volume excluded to a second ball's centre would be
    ``4/3 * pi * (2r)**3``. Confirm which one is intended.

    Args:
        n: number of balls.
        V: container volume (scalar or array).
        r: ball size parameter.
    """
    ball_volume = 4/3 * jnp.pi * r**3
    pair_count = n * (n-1) / 2
    return V**n * (1 - pair_count * ball_volume / V)

fig, axs = plt.subplots(1, 2, dpi=100, figsize=(8, 3))

# One (n, vel, r, rep) tuple per panel.
exps = [(8, 0.01, 1e-6, 2000), (8, 0.01, 0.1, 2000), ]

for ax, (n, vel, r, rep) in zip(axs, exps):
    plt.sca(ax)
    works = foobar(n, vel, r, rep)
    zs = partition_function2(n, r, 200_000)

    # Both curves share the same Jarzynski average over all runs; they differ
    # only in the Z_B / Z_A reference plotted on the x-axis.
    jarzynski_avg = jnp.exp(-works).mean(1)
    reference_curves = [
        (Z(n, Vb, r) / Z(n, Va, r), '-b', 'approximate Z'),
        (zs / zs[0], '-r', 'numerical Z'),
    ]
    for ref_ratio, line_fmt, line_label in reference_curves:
        plt.plot(ref_ratio, jarzynski_avg, line_fmt, label=line_label)

    plt.legend(loc=0, fontsize=12, labelspacing=0, handlelength=1.5, handletextpad=0.2, frameon=False)
    # Diagonal: where the Jarzynski equality holds exactly.
    plt.plot([0,1], [0,1], '--k')
    plt.xscale('log')
    plt.yscale('log')
    plt.title(fr'$n={n}$, $v_p={texnum(vel)}$, $r={texnum(r)}$', fontsize=13)
    plt.yticks([])
    plt.xlabel(r'$e^{-\Delta F / T} = Z_B/Z_A$')

# Square axis ranges; the second panel reaches smaller values.
for ax, lower in zip(axs, (1e-5, 1e-7)):
    plt.sca(ax)
    plt.xlim(lower, 1)
    plt.ylim(lower, 1)

plt.sca(axs[0])
plt.ylabel(r'$\langle e^{-W / T} \rangle$')

plt.tight_layout(pad=0.1)

plt.savefig('je.pgf')

In [None]:
# Sanity check: `works` from the last panel is (len(xs), rep).
works.shape

In [None]:
# Sanity check: one partition-function estimate per entry of `xs`.
zs.shape

In [None]:
# Same shape check as two cells above, repeated.
works.shape

In [None]:
# Display every numerical partition-function estimate from the last panel.
zs

In [None]:
# Convergence of the Jarzynski average at a single piston position: running
# mean of exp(-W) over the first k runs, minus the numerical Z ratio.
i = -1  # row of `works` / entry of `zs` to inspect; -1 is the last displacement

rep = jnp.arange(1, works.shape[1] + 1)
running_mean = jnp.cumsum(jnp.exp(-works[i])) / rep
error = running_mean - (zs[i] / zs[0])

plt.plot(rep, error, '.-')
# plt.xscale('log')
# plt.yscale('log')


In [None]:
# Distribution over runs of the cumulative work at the last displacement.
plt.hist(works[-1], bins=40);

In [None]:
# For each displacement, the number of runs whose cumulative work is negative.
(works < 0.0).sum(1)

In [None]:
# Reference volumes for the approximate partition function below.
# Va is the initial volume; Vb is the volume at each displacement in `xs`
# (presumably a cylinder of unit radius whose height shrinks from 2 to 2 - x).
Va = 2 * jnp.pi
Vb = (2 - xs) * jnp.pi

def Z(n, V, r):
    """Approximate configurational partition function with a pair correction.

    Computes ``V**n * (1 - n(n-1)/2 * v / V)`` where ``v = 4/3 * pi * r**3``.

    NOTE(review): if ``r`` is the ball radius, ``v`` is the volume of a single
    ball; the volume excluded to a second ball's centre would be
    ``4/3 * pi * (2r)**3``. Confirm which one is intended.

    Args:
        n: number of balls.
        V: container volume (scalar or array).
        r: ball size parameter.
    """
    ball_volume = 4/3 * jnp.pi * r**3
    pair_count = n * (n-1) / 2
    return V**n * (1 - pair_count * ball_volume / V)

plt.style.use('classic')

fig, axs = plt.subplots(1, 2, dpi=100, figsize=(8, 3))

# One (n, vel, r, rep) tuple per panel.
exps = [(8, 0.01, 1e-6, 1000), (8, 0.01, 0.1, 1000), ]

for ax, (n, vel, r, rep) in zip(axs, exps):
    plt.sca(ax)
    works = foobar(n, vel, r, rep)

    # Quantities shared by all curves of this panel.
    z_ratio = Z(n, Vb, r) / Z(n, Va, r)
    boltzmann = jnp.exp(-works)

    # Jarzynski average over the first m runs, for increasing m.
    for m in (1, 10, works.shape[1]):
        plt.plot(
            z_ratio,
            boltzmann[:, :m].mean(1),
            '-',
            label=fr'$N_{{\mathrm{{avg}}}} = {m}$',
        )

    plt.legend(loc=0, fontsize=12, labelspacing=0, handlelength=1.5, handletextpad=0.2, frameon=False)
    # Diagonal: where the Jarzynski equality holds exactly.
    plt.plot([0,1], [0,1], '--k')
    plt.xscale('log')
    plt.yscale('log')
    plt.title(fr'$n={n}$, $v_p={vel}$', fontsize=13)

    plt.xlim(1e-7, 1)
    plt.ylim(1e-7, 1)
    plt.yticks([])

plt.sca(axs[1])
plt.xlabel(r'$e^{-\Delta F / T} = Z_B/Z_A$')
plt.sca(axs[0])
plt.ylabel(r'$\langle e^{-W / T} \rangle$')

plt.tight_layout(pad=0.1)

plt.savefig('je.pgf')

In [None]:
# Scratch check of `jax.lax.while_loop` semantics: the condition is evaluated
# before the first iteration, so with an initial value of 0 the body never runs
# and the loop returns 0.
def cond(x):
    """Keep looping while the carried value equals 1."""
    return x == 1

def body(x):
    """Replace the carried value with 1."""
    return 1

jax.lax.while_loop(cond, body, 0)

In [None]:
def plot_piston(state, proj, draw_walls=True):
    """Draw a 2-D projection of the piston state on the current axes.

    Args:
        state: dict with ``state['balls']['x']`` / ``['r']`` (ball positions
            and radii) and ``state['walls']['x']`` / ``['j']`` / ``['k']``
            (presumably a corner and two edge vectors per wall — verify
            against ``jarzynski``).
        proj: function mapping 3-D positions to 2-D, e.g. ``xy``, ``xz``, ``zy``.
        draw_walls: when true, outline each wall in white.
    """
    plt.axis('off')
    plt.axis('square')
    plt.xlim(-1.1, 1.1)
    plt.ylim(-1.1, 1.1)

    radius = state['balls']['r']
    centres = proj(state['balls']['x'])
    # Column vector of angles so every ball becomes one column (one line).
    angle = jnp.linspace(0, 2 * jnp.pi, 100)[:, None]

    circle_x = centres[:, 0] + radius * jnp.cos(angle)
    circle_y = centres[:, 1] + radius * jnp.sin(angle)
    plt.plot(circle_x, circle_y)

    if not draw_walls:
        return

    origin = proj(state['walls']['x'])
    edge_j = proj(state['walls']['j'])
    edge_k = proj(state['walls']['k'])
    # Closed outline: walk around the parallelogram and back to the start.
    corners = [origin, origin + edge_j, origin + edge_j + edge_k, origin + edge_k, origin]
    plt.plot([c[:, 0] for c in corners], [c[:, 1] for c in corners], 'white')


def xy(pos):
    """Project onto the x-y plane: keep components 0 and 1 of the last axis."""
    components = [0, 1]
    return pos[..., components]

def xz(pos):
    """Project onto the x-z plane: keep components 0 and 2 of the last axis."""
    components = [0, 2]
    return pos[..., components]

def zy(pos):
    """Project onto the z-y plane: components 2 and 1 of the last axis, in that order."""
    components = [2, 1]
    return pos[..., components]

In [None]:
plt.style.use('dark_background')

fig, [axs1, axs2] = plt.subplots(2, 3, figsize=(9, 6))

def view(axs, states):
    """Draw the xz, xy and zy projections of a piston state on three axes.

    Args:
        axs: sequence of three matplotlib axes (one figure row).
        states: the piston state dict to draw. Despite the plural name this is
            a single state; the name is kept for existing callers.

    The state passed in is the one drawn. The function does not read the
    module-level ``state`` variable.
    """
    ax1, ax2, ax3 = axs

    plt.sca(ax1)
    plt.cla()
    plot_piston(states, xz)

    # Looking along z the walls are not drawn; show the unit circle instead.
    plt.sca(ax2)
    plt.cla()
    plot_piston(states, xy, False)
    phi = jnp.linspace(0, 2 * jnp.pi, 100)
    plt.plot(jnp.cos(phi), jnp.sin(phi), 'white')

    plt.sca(ax3)
    plt.cla()
    plot_piston(states, zy)
    
# Top row: the freshly initialised 50-ball state.
state = init_piston_j(jax.random.PRNGKey(0), 50, 1e-1)
view(axs1, state)

# Bottom row: the same state after the walls have closed in by a total
# displacement of 1.5, each wall moving at vel / 2 along z.
vel = 15
half_vel = vel / 2
state['walls']['v'] = jnp.array([
    [0.0, 0.0, half_vel],
    [0.0, 0.0, -half_vel],
])
compression_time = 1.5 / vel
_, state, _ = forward_j(compression_time, state)
view(axs2, state)

plt.tight_layout()
plt.savefig('compression.pdf')

In [None]:
plt.style.use('dark_background')

fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=(9, 3))

def view(state):
    """Redraw the xz, xy and zy projections of ``state`` on the movie figure.

    Uses the module-level axes ``ax1``, ``ax2`` and ``ax3``. Each one is
    cleared before drawing, so the function can be called once per frame.
    """
    # (axes, projection, draw walls?) for the three panels, left to right.
    panels = (
        (ax1, xz, True),
        (ax2, xy, False),
        (ax3, zy, True),
    )
    for axes, proj, with_walls in panels:
        plt.sca(axes)
        plt.cla()
        plot_piston(state, proj, with_walls)
        if not with_walls:
            # Looking along z: show the unit circle instead of the walls.
            phi = jnp.linspace(0, 2 * jnp.pi, 100)
            plt.plot(jnp.cos(phi), jnp.sin(phi), 'white')

    plt.tight_layout()

# Movie parameters.
fps = 10  # frames per second of the output file
moviewriter = FFMpegFileWriter(fps=fps)

sout = 2.0  # seconds recorded with the current wall velocity (see capture(state, sout))
sin = 8.0   # seconds intended for the compression phase
vel = 1.0   # relative wall speed for the compression phase

# Simulation time per frame: a displacement of 1.5 at speed `vel`, spread over
# the `sin * fps` frames of the compression phase.
dt = (1.5 / vel) / (sin * fps)
state = init_piston_j(jax.random.PRNGKey(0), 50, 1e-1)

def capture(state, sec):
    """Record ``sec`` seconds of movie starting from ``state``.

    Each frame draws the current state, grabs it with the module-level
    ``moviewriter`` and then advances the simulation by the module-level ``dt``.

    Args:
        state: piston state to start from.
        sec: duration in seconds; ``round(sec * fps)`` frames are written.

    Returns:
        The state after the last frame's time step.
    """
    n_frames = round(sec * fps)
    for _frame in tqdm(range(n_frames)):
        view(state)
        moviewriter.grab_frame()
        _count, state, _work = forward_j(dt, state)

    return state

# Record the movie. Only the first segment (`sout` seconds with the initial
# wall velocity) is currently written.
with moviewriter.saving(fig, 'myfile.mp4', dpi=200):
    state = capture(state, sout)

# Disabled: compression phase (walls moving at +/- vel/2 for `sin` seconds)
# followed by a second segment with the walls stopped.
#     state['walls']['v'] = jnp.array([
#         [0.0, 0.0, vel/2],
#         [0.0, 0.0, -vel/2],
#     ])
    
#     state = capture(state, sin)

#     state['walls']['v'] = jnp.array([
#         [0.0, 0.0, 0],
#         [0.0, 0.0, 0],
#     ])

#     state = capture(state, sout)