In [21]:
import itertools
from functools import partial

import numpy as np
from tqdm import tqdm

import jax
jax.config.update("jax_default_matmul_precision", "highest")
import jax.numpy as jnp
from jax_tqdm import scan_tqdm

from special_unitary import LALG_SU_N, proj_SU3, unitary_violation, special_unitary_grad

from gauge_field_utils import wilson_action, accurate_wilson_hamiltonian_error, smear_HYP, smear_stout, wilson_loops_range, luscher_weisz_action
from integrators import int_LF2, int_MN2_omelyan, int_MN4_takaishi_forcrand

%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
def HMC(action_fn, beta, integrator, tau_md=1.0, steps_md=10, unitary_violation_tol=5e-6,
        hamiltonian_error_fn=None):
    """Build a jitted Hybrid Monte Carlo update for a lattice gauge field.

    Args:
        action_fn: ``action_fn(links, beta)`` -> scalar gauge action. Its
            group gradient (via ``special_unitary_grad``) is the MD force.
        beta: gauge coupling, forwarded to ``action_fn`` and to
            ``hamiltonian_error_fn``.
        integrator: ``integrator(links, p0, action_grad, tau_md, steps_md)``
            -> ``(links_final, p_final)`` after molecular-dynamics evolution.
        tau_md: trajectory length passed to the integrator.
        steps_md: number of integrator steps per trajectory.
        unitary_violation_tol: if not None, the evolved links are passed
            through ``proj_SU3`` whenever ``unitary_violation(links, "mean")``
            exceeds this value. ``None`` disables the check.
        hamiltonian_error_fn: ``fn(links, p0, links_next, p_final, beta)`` ->
            energy violation dH of the trajectory. Defaults to
            ``accurate_wilson_hamiltonian_error``, which presumably assumes the
            Wilson action -- pass a matching function when ``action_fn`` is
            something else (e.g. ``luscher_weisz_action``).

    Returns:
        Jitted ``step_fn(links, random_key, skip_metropolis=False)`` returning
        ``(new_links, (delta_hamiltonian, p_acc))``.
    """
    if hamiltonian_error_fn is None:
        hamiltonian_error_fn = accurate_wilson_hamiltonian_error

    action_grad = special_unitary_grad(lambda x: action_fn(x, beta))

    def _reunitarize(links):
        # Re-project drifted links; a no-op when the check is disabled.
        # Branch in Python (trace time) on the static tolerance: `jax.lax.cond`
        # traces BOTH branches, so `violation > None` would raise a TypeError
        # during tracing if the tolerance were None.
        if unitary_violation_tol is None:
            return links
        # NOTE(review): proj_SU3 presumably assumes Nc == 3 -- confirm before
        # using this sampler for other gauge groups.
        return jax.lax.cond(
            unitary_violation(links, "mean") > unitary_violation_tol,
            proj_SU3,
            lambda x: x,
            links,
        )

    def step_fn(links, random_key, skip_metropolis=False):
        """Run one MD trajectory, then an (optional) Metropolis test.

        Args:
            links: current configuration; the last two axes hold the
                Nc x Nc link matrices.
            random_key: JAX PRNG key, split into a momentum key and an
                accept/reject key.
            skip_metropolis: if true, the proposal is always returned.

        Returns:
            ``(new_links, (delta_hamiltonian, p_acc))``; the diagnostics are
            those of the proposed trajectory even when it is rejected.
        """
        momentum_key, accept_key = jax.random.split(random_key, 2)
        Nc = links.shape[-1]

        # One Gaussian coefficient per su(Nc) generator (Nc^2 - 1) per link.
        p0 = jax.random.normal(momentum_key, shape=(*links.shape[:-2], Nc * Nc - 1))

        links_next, p_final = integrator(links, p0, action_grad, tau_md, steps_md)
        links_next = _reunitarize(links_next)

        delta_hamiltonian = hamiltonian_error_fn(links, p0, links_next, p_final, beta)
        p_acc = jnp.minimum(1, jnp.exp(-delta_hamiltonian))

        # Metropolis: accept with probability p_acc, or unconditionally when
        # skip_metropolis is set (it may be traced, so combine it as an array op).
        accept = jnp.logical_or(skip_metropolis, jax.random.uniform(accept_key) < p_acc)
        new_links = jax.lax.cond(accept, lambda: links_next, lambda: links)
        return new_links, (delta_hamiltonian, p_acc)

    return jax.jit(step_fn)

In [4]:
# Lattice extent: 12 sites along each of the 4 dimensions.
L = (12,) * 4

# Seeded key tree. Only `key1` is consumed in this cell; `random_key` and
# `key2` are left for later cells.
random_key, key1, key2 = jax.random.split(jax.random.key(0), num=3)

# Random start: 8 Gaussian coefficients per link (8 = 3*3 - 1) on an extra axis
# of size 4 (presumably one per lattice direction); LALG_SU_N presumably maps
# them to SU(3) link matrices.
gauge_links = LALG_SU_N(
    jax.random.normal(key1, shape=L + (4, 8), dtype=jnp.float32)
)
# Alternative start from a saved (presumably pre-thermalised) configuration:
# gauge_links = LALG_SU_N(jnp.load("warmed_32_16x3_beta_6p0.npy"))

In [5]:
# HMC update for the Wilson action at beta = 5.7, integrated with
# int_MN2_omelyan: 20 MD steps over a trajectory of length 1.0. Links are
# re-projected once their mean unitarity violation exceeds 6e-5.
stepper_fn = HMC(
    wilson_action,
    5.7,
    int_MN2_omelyan,
    tau_md=1.0,
    steps_md=20,
    unitary_violation_tol=6e-5,
)

In [50]:
# Production run: N_TRAJECTORIES HMC trajectories starting from `gauge_links`.
# Per-trajectory keys are split from the seeded `random_key`, so the run is
# reproducible and no PRNG key is ever reused (drawing seeds from the unseeded
# global NumPy RNG with only 10^4 possible values was neither).
N_TRAJECTORIES = 200

trajectory_keys = jax.random.split(random_key, N_TRAJECTORIES)

# Preallocate host storage instead of appending + np.stack, which would hold
# two full copies of the ensemble in memory at the end.
configs = np.empty((N_TRAJECTORIES, *gauge_links.shape), dtype=gauge_links.dtype)
delta_h = np.empty(N_TRAJECTORIES)      # energy violation dH of each trajectory
accept_prob = np.empty(N_TRAJECTORIES)  # Metropolis probability min(1, exp(-dH))

next_links = gauge_links
for i in (bar := tqdm(range(N_TRAJECTORIES))):
    next_links, (dH, p_acc) = stepper_fn(next_links, trajectory_keys[i])
    configs[i] = np.asarray(next_links)
    delta_h[i] = dH
    accept_prob[i] = p_acc
    bar.set_postfix({"dH": delta_h[i], "<p_acc>": accept_prob[: i + 1].mean()})

100%|██████████| 200/200 [11:17<00:00,  3.39s/it, dH=0.026435852]  
