In [1]:
import itertools
from functools import partial
from pathlib import Path

import numpy as np
from tqdm import tqdm

import jax
jax.config.update("jax_default_matmul_precision", "highest")
jax.config.update("jax_debug_nans", True)
import jax.numpy as jnp
from jax_tqdm import scan_tqdm

from special_unitary import (
    fast_expi_su3,
    special_unitary_grad,
    unitary_violation,
    proj_SU3
)

from integrators import int_2MN, int_4MN4FP
from gauge_field_utils import (
    wilson_action,
    wilson_gauge_error,
    luscher_weisz_action,
    luscher_weisz_gauge_error,
    wilson_loops_range,
    smear_HYP,
    smear_stout,
    mean_plaquette
)

%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
def HMC(action_fn, error_fn, integrator, tau_md=1.0, steps_md=10, unitary_violation_tol=5e-6):
    """Build a jitted Hybrid Monte Carlo update step for a gauge field.

    Parameters
    ----------
    action_fn : callable
        Gauge action of the links; it is wrapped with ``special_unitary_grad``
        and the result is handed to ``integrator`` as the force function.
    error_fn : callable
        ``error_fn(links, p0, links_next, p_final)`` returning the Hamiltonian
        change ``delta_H`` over the trajectory, used for the Metropolis test.
    integrator : callable
        ``integrator(links, p0, action_grad, tau_md, steps_md)`` returning
        ``(links_next, p_final)``.
    tau_md : float
        Trajectory length forwarded to ``integrator``.
    steps_md : int
        Number of integrator steps forwarded to ``integrator``.
    unitary_violation_tol : float or None
        If not None, the evolved links are passed through ``proj_SU3`` whenever
        their mean unitarity violation exceeds this tolerance. ``None``
        disables the check entirely.

    Returns
    -------
    callable
        Jitted ``step_fn(links, random_key, skip_metropolis=False)`` returning
        ``(new_links, (delta_hamiltonian, p_acc))``.
    """
    action_grad = special_unitary_grad(action_fn)

    def _maybe_reunitarize(links_next):
        # Decided in Python rather than with lax.cond: lax.cond traces BOTH
        # branches, so comparing against a None tolerance would raise at
        # trace time even when the check is meant to be disabled.
        if unitary_violation_tol is None:
            return links_next
        return jax.lax.cond(
            unitary_violation(links_next, "mean") > unitary_violation_tol,
            proj_SU3,
            lambda x: x,
            links_next,
        )

    def step_fn(links, random_key, skip_metropolis=False):
        """Run one trajectory and (unless skipped) a Metropolis accept/reject.

        Returns ``(new_links, (delta_hamiltonian, p_acc))``; ``p_acc`` is
        reported even when ``skip_metropolis`` is true.
        """
        key_momenta, key_accept = jax.random.split(random_key, 2)
        Nc = links.shape[-1]

        # Gaussian momenta with Nc*Nc - 1 real components per link matrix
        # (the trailing two axes of `links` are the Nc x Nc matrix).
        p0 = jax.random.normal(key_momenta, shape=(*links.shape[:-2], Nc * Nc - 1))

        links_next, p_final = integrator(links, p0, action_grad, tau_md, steps_md)
        links_next = _maybe_reunitarize(links_next)

        delta_hamiltonian = error_fn(links, p0, links_next, p_final)
        # If delta_hamiltonian is NaN, p_acc is NaN and the comparison below
        # is False, i.e. the proposal is rejected.
        p_acc = jnp.minimum(1, jnp.exp(-delta_hamiltonian))
        accept = jax.random.uniform(key_accept) < p_acc

        new_links = jax.lax.cond(
            skip_metropolis,
            lambda: links_next,
            lambda: jax.lax.cond(accept, lambda: links_next, lambda: links),
        )
        return new_links, (delta_hamiltonian, p_acc)

    return jax.jit(step_fn)

In [3]:
# Run configuration. BETA and U0 are shared by the action (MD force) and the
# error function (Metropolis delta_H) so the two can never drift apart.
LATTICE_SHAPE = (16, 16, 16, 16)
BETA = 8.00
U0 = 0.8876875888655319  # presumably the tadpole/mean-link factor expected by luscher_weisz_* — confirm
# Path to a saved .npy configuration to start from; None starts from a random configuration.
INITIAL_CONFIG_PATH = None  # e.g. "../results/configs_3-25-25_1/step_205_gauge.npy"

random_key, _k = jax.random.split(jax.random.key(0), num=2)

L = LATTICE_SHAPE
if INITIAL_CONFIG_PATH is None:
    # Random start: shape (*L, 4, 8) coefficients fed through fast_expi_su3
    # (presumably one 8-component su(3) vector per site and direction — confirm).
    gauge_links = fast_expi_su3(jax.random.normal(
        _k,
        shape=(*L, 4, 8),
        dtype=jnp.float32
    ))
else:
    gauge_links = jnp.load(INITIAL_CONFIG_PATH)

stepper_fn = HMC(
    action_fn=partial(luscher_weisz_action, beta=BETA, u0=U0),
    error_fn=partial(luscher_weisz_gauge_error, beta=BETA, u0=U0),
    integrator=int_4MN4FP,
    tau_md=1.0,
    steps_md=25,
    unitary_violation_tol=5e-6,
)

In [4]:
# Production run settings.
N_TRAJECTORIES = 1000
# Offset for saved file numbering. When resuming after an interrupt, set this to
# the next unused index so earlier configurations are not overwritten.
START_STEP = 0
SAVE_DIR = Path("../results/configs_3-25-25_2")
# jnp.save does not create parent directories.
SAVE_DIR.mkdir(parents=True, exist_ok=True)

running_p_acc = 0

for step in (bar := tqdm(range(N_TRAJECTORIES))):
    random_key, _k = jax.random.split(random_key, num=2)

    gauge_links, aux = stepper_fn(gauge_links, _k)
    delta_h, p_acc = aux
    mean_plaq = mean_plaquette(gauge_links)
    jnp.save(SAVE_DIR / f"step_{START_STEP + step}_gauge.npy", gauge_links)
    # Incremental mean of the acceptance probability over this run.
    running_p_acc = (running_p_acc * step + p_acc) / (step + 1)

    bar.set_postfix({
        "pl": mean_plaq,
        "delta_H": delta_h,
        "p_acc": p_acc,
        "running_p_acc": running_p_acc
    })

  3%|▎         | 32/1000 [07:32<3:48:17, 14.15s/it, pl=1.8637114, delta_H=0.6849365, p_acc=0.50412226, running_p_acc=0.59303844] 


KeyboardInterrupt: 