In [1]:
import os
os.environ["JAX_ENABLE_X64"] = "1"
import itertools
from functools import partial

import numpy as np
from tqdm import tqdm

import jax
jax.config.update("jax_default_matmul_precision", "highest")
import jax.numpy as jnp
from jax_tqdm import scan_tqdm

from gauge_field_utils import coef_to_lie_group, wilson_action, mean_wilson_rectangle, accurate_wilson_hamiltonian_error
from integrators import int_LF2, int_MN2_omelyan, int_MN4_takaishi_forcrand

%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
def HMC(beta, afn, nfev_approx):
    """Build a single Hybrid Monte Carlo update (MD trajectory + Metropolis test).

    Args:
        beta: coupling passed to ``afn`` and to ``accurate_wilson_hamiltonian_error``.
        afn: action callable, invoked as ``afn(coef_to_lie_group(coef), beta)``.
        nfev_approx: forwarded unchanged to ``int_MN4_takaishi_forcrand``
            (presumably an approximate force-evaluation budget — TODO confirm).

    Returns:
        ``step_fn(coef, tau, random_key) -> (coef_next, (dH, p_acc))``.

    NOTE(review): the energy violation ``dH`` is always computed by
    ``accurate_wilson_hamiltonian_error`` from ``beta`` alone, independent of
    ``afn``; if ``afn`` is not the Wilson action the accept/reject step would be
    inconsistent with the dynamics — confirm before passing another action.
    """
    def action_of_coef(x):
        return afn(coef_to_lie_group(x), beta)

    grad_action = jax.grad(action_of_coef)

    def step_fn(coef, tau, random_key):
        momentum_key, accept_key = jax.random.split(random_key, num=2)

        # Fresh Gaussian momenta matching the coefficients' shape and dtype.
        p0 = jax.random.normal(momentum_key, shape=coef.shape, dtype=coef.dtype)

        # Molecular-dynamics trajectory of length tau, then its energy violation.
        coef_prop, pt = int_MN4_takaishi_forcrand(coef, p0, grad_action, tau, nfev_approx)
        dH = accurate_wilson_hamiltonian_error(coef, p0, coef_prop, pt, beta)

        # Metropolis acceptance probability min(1, exp(-dH)).
        p_acc = jnp.minimum(1, jnp.exp(-dH))
        accepted = jax.random.uniform(accept_key) < p_acc

        coef_next = jax.lax.cond(accepted, lambda: coef_prop, lambda: coef)
        return coef_next, (dH, p_acc)

    return step_fn

def warmup_tint(coef, beta, random_key, observable_fn, tau, iters=2000, nfev_approx=20):
    """Run ``iters`` HMC steps with the Wilson action inside one ``jax.lax.scan``.

    After every step ``observable_fn`` is evaluated on the new configuration and
    a diagnostic line is emitted with ``jax.debug.print`` (every step, so output
    is long for large ``iters``).

    Args:
        coef: starting gauge coefficients.
        beta: gauge coupling.
        random_key: JAX PRNG key; the entire key stream of the run is split from it.
        observable_fn: callable ``coef -> scalar`` recorded each step.
        tau: trajectory length forwarded to the HMC step.
        iters: number of HMC steps.
        nfev_approx: forwarded to ``HMC``.

    Returns:
        ``(final_coef, observations)`` where ``observations`` stacks the
        ``observable_fn`` outputs over the ``iters`` steps.
    """
    hmc_step = jax.jit(HMC(beta, wilson_action, nfev_approx))

    @scan_tqdm(iters, print_rate=1, tqdm_type="notebook")
    def body(state, step):
        cur_coef, key, acc_mean = state
        key, step_key = jax.random.split(key)

        cur_coef, (dH, p_acc) = hmc_step(cur_coef, tau, step_key)
        # Incremental running mean of the acceptance probability over steps 0..step.
        acc_mean = (acc_mean * step + p_acc) / (step + 1)
        obs = observable_fn(cur_coef)
        # NOTE: the "p_acc" field prints the running mean, not this step's value.
        jax.debug.print("warmup step {step} ; o={o} ; dH={dH} ; p_acc={p_acc}", step=step, o=obs, dH=dH, p_acc=acc_mean)

        return (cur_coef, key, acc_mean), obs

    final_state, observations = jax.lax.scan(
        body,
        init=(coef, random_key, 0),
        xs=np.arange(iters),
        length=iters,
    )
    return final_state[0], observations

@partial(jax.jit, static_argnames=["R_range", "T_range"])
def calculate_wilson_loops(gauge_coef, R_range, T_range):
    """Mean Wilson-rectangle values for every (R, T) in two inclusive ranges.

    Args:
        gauge_coef: gauge coefficients, mapped to group elements once via
            ``coef_to_lie_group``.
        R_range: inclusive ``(R_min, R_max)`` tuple. Static under jit, so it must
            be hashable and each new value triggers a recompilation.
        T_range: inclusive ``(T_min, T_max)`` tuple; same constraints as ``R_range``.

    Returns:
        Array of shape ``(R_max - R_min + 1, T_max - T_min + 1)`` whose entry
        ``[i, j]`` is ``mean_wilson_rectangle(links, R_min + i, T_min + j, time_unique=False)``.
    """
    R_min, R_max = R_range
    T_min, T_max = T_range
    # Loop-invariant: convert to group elements once instead of once per (R, T)
    # pair, which otherwise bloats the traced program with identical copies.
    links = coef_to_lie_group(gauge_coef)
    R_values = range(R_min, R_max + 1)
    T_values = range(T_min, T_max + 1)
    # itertools.product iterates R in the outer loop and T in the inner loop, so a
    # row-major reshape puts R on axis 0 and T on axis 1.
    wilson_loop_values = jnp.array([
        mean_wilson_rectangle(links, R, T, time_unique=False)
        for R, T in itertools.product(R_values, T_values)
    ]).reshape(len(R_values), len(T_values))
    return wilson_loop_values

In [3]:
# Lattice extents and the inclusive (min, max) ranges of R and T passed to
# calculate_wilson_loops.
L = (20, 10, 10, 10)
R_range = (1, 10)
T_range = (1, 20)

# One root key split into: the chain's running key, the key for the random start,
# and a spare key.
random_key, key1, key2 = jax.random.split(jax.random.key(0), num=3)

# Random ("hot") start: 4 link directions x 8 coefficients per site.
# NOTE(review): JAX_ENABLE_X64 is set at the top of the notebook, but the start is
# float32 and the HMC momenta inherit coef.dtype — confirm float32 is intended.
coef_shape = (*L, 4, 8)
coef = jax.random.normal(key1, shape=coef_shape, dtype=jnp.float32)
# Optional warm start; NOTE(review): the file name suggests a different lattice size
# than L — check the shape before enabling.
# coef = jnp.load("warmed_16_8x3_beta_6p7.npy")

In [None]:
# Thermalize the random start with HMC, recording the real part of the mean 3x3
# Wilson rectangle after every step.
# warmup_tint derives its whole key stream by repeatedly splitting the key it is
# given. Passing `random_key` itself would make any later cell that keeps
# splitting `random_key` (the production loop) replay exactly the same keys, so
# hand it a dedicated subkey and advance `random_key`.
random_key, warmup_key = jax.random.split(random_key)
coef, O = warmup_tint(
    coef,
    beta=6.7,
    random_key=warmup_key,
    observable_fn=jax.jit(lambda x: mean_wilson_rectangle(coef_to_lie_group(x), 3, 3, time_unique=False).real),
    tau=1.0,
    iters=10000,
    nfev_approx=10,
)

In [None]:
# Production run: continue the HMC chain from the current `coef` and measure the
# full (R, T) table of Wilson loops on every 50th trajectory (0, 50, 100, ...).
stepper_fn = jax.jit(HMC(6.7, wilson_action, nfev_approx=8))
wilson_loops = []

bar = tqdm(range(20000))
for i in bar:
    random_key, key1 = jax.random.split(random_key)
    coef, (dH, p_acc) = stepper_fn(coef, tau=1.0, random_key=key1)

    is_measurement = i % 50 == 0
    if is_measurement:
        wilson_loops.append(calculate_wilson_loops(coef, R_range, T_range))

    # Show the latest trajectory's energy violation in the progress bar.
    bar.set_postfix({"dH": dH})

In [27]:
# Stack the measurements into one array of shape (n_measurements, n_R, n_T) and
# average over measurements, keeping the real part.
wilson_loops = jnp.array(wilson_loops)
mean_loops = jnp.real(jnp.mean(wilson_loops, axis=0))
# 1 - W/3; presumably the deviation from the trace-3 (unit) value for SU(3) —
# TODO confirm the normalization of mean_wilson_rectangle.
omrt_loops = 1.0 - mean_loops / 3

In [None]:
# Effective-potential estimate log(W(R, T) / W(R, T+1)) at fixed T, for all but
# the last two R values. Plot it against R itself rather than the array index
# (row i of mean_loops corresponds to R = R_range[0] + i).
T_COL = 5  # column index into mean_loops; T = T_range[0] + T_COL
v_eff = jnp.log(mean_loops[:-2, T_COL] / mean_loops[:-2, T_COL + 1])
r_vals = R_range[0] + jnp.arange(v_eff.shape[0])

fig, ax = plt.subplots()
ax.plot(r_vals, v_eff, marker="o")
ax.set(
    xlabel="R (lattice units)",
    ylabel=r"$\log\,[W(R,T)/W(R,T+1)]$",
    title=f"Effective potential at T = {T_range[0] + T_COL}",
)
plt.show()

In [39]:
# Fit the Cornell form  aV(R) = V0 + alpha/R + sigma*R  to the effective potential.
# Multiplying through by R gives a quadratic, R*V = alpha + V0*R + sigma*R^2, so
# polyfit(deg=2) returns [sigma, V0, alpha] (highest power first).
# NOTE: fitting R*V rather than V gives larger R more weight in the least squares.
T_COL = 5  # column index into mean_loops; T = T_range[0] + T_COL
v_eff = jnp.log(mean_loops[:-2, T_COL] / mean_loops[:-2, T_COL + 1])
# Derive R from the data (row i <-> R = R_range[0] + i) so x and y always have
# the same length, whatever R_range is.
r_vals = (R_range[0] + jnp.arange(v_eff.shape[0])).astype(v_eff.dtype)

sigma, V0, alpha = jnp.polyfit(r_vals, v_eff * r_vals, deg=2)

In [None]:
# Compare the fitted Cornell curve with the effective-potential data. R values are
# derived from mean_loops (row i <-> R = R_range[0] + i) so both have equal length.
T_COL = 5  # column index into mean_loops; T = T_range[0] + T_COL
v_eff = jnp.log(mean_loops[:-2, T_COL] / mean_loops[:-2, T_COL + 1])
r_vals = R_range[0] + jnp.arange(v_eff.shape[0])

# Evaluate the fit across the full range of measured R.
x = jnp.linspace(r_vals[0], r_vals[-1], 100)
y = V0 + alpha / x + sigma * x

fig, ax = plt.subplots()
ax.scatter(r_vals, v_eff, label="data")
ax.plot(x, y, label=r"fit $V_0 + \alpha/R + \sigma R$")
ax.set(xlabel="R (lattice units)", ylabel="aV(R)", title="Static potential: Cornell fit")
ax.legend()
plt.show()

In [None]:
# Lattice spacing in fm from the fitted dimensionless string tension sigma = sigma_phys * a^2:
#   a [MeV^-1] = sqrt(sigma) / sqrt(sigma_phys),   a [fm] = a [MeV^-1] * hbar*c [MeV fm].
SQRT_SIGMA_PHYS_MEV = 440.0  # assumed physical sqrt(string tension) in MeV (common phenomenological choice)
HBAR_C_MEV_FM = 197.3269804  # CODATA hbar*c in MeV*fm (exact value rather than the rounded 1/5.068 GeV*fm)
jnp.sqrt(sigma) / SQRT_SIGMA_PHYS_MEV * HBAR_C_MEV_FM