In [1]:
import os
os.environ["JAX_ENABLE_X64"] = "1"
import itertools
from functools import partial

import numpy as np
from tqdm import tqdm

import jax
jax.config.update("jax_default_matmul_precision", "highest")
import jax.numpy as jnp
from jax_tqdm import scan_tqdm

from gauge_field_utils import coef_to_lie_group, wilson_action, mean_wilson_rectangle, accurate_wilson_hamiltonian_error
from integrators import int_LF2, int_MN2_omelyan, int_MN4_takaishi_forcrand

%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
def HMC(beta, afn, nfev_approx):
    """Build a single Hybrid Monte Carlo update for gauge coefficients.

    Args:
        beta: gauge coupling, forwarded to ``afn`` and to the Hamiltonian-error routine.
        afn: action function called as ``afn(links, beta)`` where ``links`` is
            ``coef_to_lie_group(coef)``; it is differentiated w.r.t. the coefficients.
        nfev_approx: forwarded unchanged to ``int_MN4_takaishi_forcrand``
            (presumably an approximate force-evaluation budget -- TODO confirm).

    Returns:
        ``step_fn(coef, tau, random_key) -> (coef_next, (dH, p_acc))``. It draws
        standard-normal momenta with the shape and dtype of ``coef``. It integrates for
        trajectory length ``tau`` with the 4th-order Takaishi-Forcrand integrator, then
        applies a Metropolis accept/reject with probability ``min(1, exp(-dH))``.

    NOTE(review): dH always comes from ``accurate_wilson_hamiltonian_error`` regardless
    of ``afn``; presumably only consistent when ``afn`` is the Wilson action -- verify.
    """
    def action_fn(x):
        return afn(coef_to_lie_group(x), beta)

    action_grad_fn = jax.grad(action_fn)

    def step_fn(coef, tau, random_key):
        momentum_key, accept_key = jax.random.split(random_key, num=2)
        p_init = jax.random.normal(momentum_key, shape=coef.shape, dtype=coef.dtype)

        coef_trial, p_final = int_MN4_takaishi_forcrand(coef, p_init, action_grad_fn, tau, nfev_approx)
        dH = accurate_wilson_hamiltonian_error(coef, p_init, coef_trial, p_final, beta)

        # Metropolis acceptance probability.
        p_acc = jnp.minimum(1, jnp.exp(-dH))
        accept = jax.random.uniform(accept_key) < p_acc
        coef_next = jax.lax.cond(accept, lambda: coef_trial, lambda: coef)

        return coef_next, (dH, p_acc)

    return step_fn

def warmup_tint(coef, beta, random_key, observable_fn, tau, iters=2000, nfev_approx=20, verbose=True):
    """Run ``iters`` Wilson-action HMC trajectories inside one ``jax.lax.scan``.

    Args:
        coef: starting gauge coefficients.
        beta: gauge coupling passed to ``HMC``.
        random_key: JAX PRNG key; it is split once per trajectory. The advanced key is
            NOT returned, so callers should not keep splitting the same key afterwards.
        observable_fn: ``coef -> scalar``, evaluated after every trajectory.
        tau: trajectory length passed to the HMC step.
        iters: number of trajectories.
        nfev_approx: forwarded to ``HMC``.
        verbose: if True (the default, matching previous behavior), emit one
            ``jax.debug.print`` line per trajectory. Set False for long runs; each line is
            a host callback, and 10k of them swamp notebook output.

    Returns:
        ``(coef, O)``: the final coefficients and the stacked per-trajectory
        observable history (length ``iters``), e.g. for autocorrelation analysis.
    """
    stepper_fn = jax.jit(HMC(beta, wilson_action, nfev_approx))

    @scan_tqdm(iters, print_rate=1, tqdm_type="notebook")
    def warmup_step(carry, step):
        coef, rng_key, running_p_acc = carry
        rng_key, k1 = jax.random.split(rng_key)

        coef, (dH, p_acc) = stepper_fn(coef, tau, k1)
        # Incremental mean of the acceptance probability over trajectories 0..step.
        running_p_acc = (running_p_acc * step + p_acc) / (step + 1)
        o = observable_fn(coef)
        if verbose:
            # Python-level flag: resolved at trace time, so no callback is compiled in when False.
            # Note that the printed "p_acc" is the running mean, not this step's value.
            jax.debug.print("warmup step {step} ; o={o} ; dH={dH} ; p_acc={p_acc}", step=step, o=o, dH=dH, p_acc=running_p_acc)

        return (coef, rng_key, running_p_acc), o

    (coef, *_), O = jax.lax.scan(
        warmup_step,
        # Explicit float start for the running mean (scan promotes it to the carry dtype).
        init=(coef, random_key, 0.0),
        xs=np.arange(iters),
        length=iters,
    )

    return coef, O

@partial(jax.jit, static_argnames=["R_range", "T_range"])
def calculate_wilson_loops(gauge_coef, R_range, T_range):
    """Table of mean Wilson rectangles over an inclusive grid of loop extents.

    Args:
        gauge_coef: gauge-field coefficients, converted to links via ``coef_to_lie_group``.
        R_range: ``(R_min, R_max)``, inclusive. Static under ``jax.jit``, so it must be
            hashable (a tuple, not a list); every distinct value triggers a recompile.
        T_range: ``(T_min, T_max)``, inclusive; same static/hashable requirement.

    Returns:
        Array of shape ``(R_max - R_min + 1, T_max - T_min + 1)`` whose ``[i, j]`` entry
        is ``mean_wilson_rectangle(links, R_min + i, T_min + j, time_unique=False)``.
    """
    R_min, R_max = R_range
    T_min, T_max = T_range
    # Loop-invariant: build the links once rather than once per (R, T) pair, which
    # avoids tracing coef_to_lie_group n_R * n_T times into the compiled graph.
    links = coef_to_lie_group(gauge_coef)
    R_values = range(R_min, R_max + 1)
    T_values = range(T_min, T_max + 1)
    # itertools.product varies T fastest, matching the row-major reshape.
    wilson_loop_values = jnp.array([
        mean_wilson_rectangle(links, R, T, time_unique=False)
        for R, T in itertools.product(R_values, T_values)
    ]).reshape(len(R_values), len(T_values))
    return wilson_loop_values

In [3]:
# Lattice extents (4-tuple) and the inclusive (min, max) loop sizes measured later
# by calculate_wilson_loops. Which axis is "time" is decided by the helpers -- TODO confirm.
L = (40, 20, 20, 20)
R_range = (1, 15)
T_range = (1, 20)

# One root seed drives everything. key1 seeds the random start; key2 is split off
# but not used in this cell.
root_key = jax.random.key(0)
random_key, key1, key2 = jax.random.split(root_key, num=3)

# Random ("hot") start. The trailing (4, 8) is presumably 4 link directions x 8
# su(3) coefficients -- verify against coef_to_lie_group.
coef_shape = (*L, 4, 8)
coef = jax.random.normal(key1, coef_shape, jnp.float32)

# Alternative: resume from a previously thermalized configuration instead.
# coef = jnp.load("warmed_16_8x3_beta_6p7.npy")

In [4]:
def wilson_3x3_observable(x):
    """Real part of the mean 3x3 Wilson rectangle of the configuration ``x``."""
    return mean_wilson_rectangle(coef_to_lie_group(x), 3, 3, time_unique=False).real


# Thermalize from the current `coef`, tracking the 3x3 loop after every trajectory.
coef, O = warmup_tint(
    coef,
    beta=6.0,
    random_key=random_key,
    observable_fn=jax.jit(wilson_3x3_observable),
    tau=1.5,
    iters=10000,
    nfev_approx=20,
)

2025-03-09 01:02:28.016051: W external/xla/xla/hlo/transforms/host_offloader.cc:360] Token parameters are not supported for streaming.


warmup step 0 ; o=8.51610311656259e-05 ; dH=-18.2509765625 ; p_acc=1.0


Running for 10,000 iterations:   0%|          | 0/10000 [00:00<?, ?it/s]

warmup step 1 ; o=0.0006130764959380031 ; dH=-11.4896240234375 ; p_acc=1.0
warmup step 2 ; o=-0.0007965529803186655 ; dH=-4.78375244140625 ; p_acc=1.0
warmup step 3 ; o=0.0010879465844482183 ; dH=-1.8739013671875 ; p_acc=1.0
warmup step 4 ; o=0.0008943431312218308 ; dH=-0.32208251953125 ; p_acc=1.0
warmup step 5 ; o=0.0013544571120291948 ; dH=0.306884765625 ; p_acc=0.9559559226036072
warmup step 6 ; o=0.001267742714844644 ; dH=0.69915771484375 ; p_acc=0.8903913497924805
warmup step 7 ; o=0.001267742714844644 ; dH=0.84063720703125 ; p_acc=0.833021879196167
warmup step 8 ; o=0.001267742714844644 ; dH=1.03125 ; p_acc=0.7800818085670471
warmup step 9 ; o=0.0008233338012360036 ; dH=0.984832763671875 ; p_acc=0.7394238114356995
warmup step 10 ; o=0.0008233338012360036 ; dH=1.08013916015625 ; p_acc=0.7030714750289917
warmup step 11 ; o=0.0005616701673716307 ; dH=0.951904296875 ; p_acc=0.676649272441864
warmup step 12 ; o=0.0018294022884219885 ; dH=1.10028076171875 ; p_acc=0.6501976251602173
wa

E0309 02:06:48.017466    7468 pjrt_stream_executor_client.cc:3050] Execution of replica 0 failed: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/asyncio/base_events.py", line 608, in run_forever
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/asyncio/base_events.py

XlaRuntimeError: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/asyncio/base_events.py", line 608, in run_forever
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/asyncio/events.py", line 84, in _run
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 534, in process_one
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
  File "/tmp/ipykernel_7468/1254799174.py", line 1, in <module>
  File "/tmp/ipykernel_7468/763720248.py", line 40, in warmup_tint
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 331, in scan
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/core.py", line 502, in bind
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/core.py", line 520, in _true_bind
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/core.py", line 525, in bind_with_trace
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/core.py", line 1024, in process_primitive
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/pjit.py", line 341, in cache_miss
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/pjit.py", line 195, in _python_pjit_helper
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/pjit.py", line 1679, in _pjit_call_impl_python
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/profiler.py", line 334, in wrapper
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1287, in __call__
  File "/home/ubuntu/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/jax/_src/callback.py", line 777, in _wrapped_callback
KeyboardInterrupt: 

In [None]:
# Production run: continue the Markov chain from the thermalized `coef` and record the
# full Wilson-loop table every MEASURE_EVERY trajectories.
# NOTE(review): this cell previously used beta=6.7, while the warmup cell thermalizes at
# beta=6.0, so the chain started out of equilibrium for the measured coupling. Keep
# PRODUCTION_BETA equal to the warmup beta (confirm the intended coupling).
PRODUCTION_BETA = 6.0
PRODUCTION_TAU = 1.0
PRODUCTION_NFEV = 8
N_TRAJECTORIES = 20000
MEASURE_EVERY = 50

wilson_loops = []

stepper_fn = jax.jit(HMC(PRODUCTION_BETA, wilson_action, nfev_approx=PRODUCTION_NFEV))

# warmup_tint splits `random_key` internally but never returns the advanced key.
# Splitting the same key again here would replay the warmup's exact key sequence
# (identical momenta and accept draws), so derive an independent stream first.
random_key = jax.random.fold_in(random_key, 1)

for i in (bar := tqdm(range(N_TRAJECTORIES))):
    random_key, key1 = jax.random.split(random_key)
    coef, (dH, p_acc) = stepper_fn(coef, tau=PRODUCTION_TAU, random_key=key1)

    if i % MEASURE_EVERY == 0:
        wilson_loops.append(calculate_wilson_loops(coef, R_range, T_range))

    bar.set_postfix({"dH": float(dH)})

In [27]:
# Stack the per-measurement tables into one array of shape (n_measurements, n_R, n_T).
wilson_loops = jnp.stack(wilson_loops)
# Ensemble average for each (R, T), keeping the real part.
mean_loops = wilson_loops.mean(axis=0).real
# 1 - <W>/3. NOTE(review): presumably W is an unnormalized SU(3) trace (max 3) -- confirm.
omrt_loops = 1 - mean_loops / 3

In [None]:
# Effective static potential V(R) ~ log(W(R, T) / W(R, T+1)) at a fixed column pair,
# dropping the two largest R rows. Plot it against the physical R, not the row index.
V_eff = jnp.log(mean_loops[:-2, 5] / mean_loops[:-2, 6])
R_vals = R_range[0] + jnp.arange(V_eff.shape[0])
T_plot = T_range[0] + 5

fig, ax = plt.subplots()
ax.plot(R_vals, V_eff, marker="o")
ax.set(
    xlabel="R (lattice units)",
    ylabel="log W(R,T) / W(R,T+1)",
    title=f"Effective potential (T = {T_plot} vs {T_plot + 1})",
)
plt.show()

In [39]:
# Fit V(R) = V0 + alpha/R + sigma*R to the effective potential (lattice units).
# Multiplying by R gives a quadratic, R*V = alpha + V0*R + sigma*R**2. np/jnp.polyfit
# returns coefficients highest degree first, hence (sigma, V0, alpha).
# R values are derived from R_range and the data length. A hard-coded 6-point grid
# only matched an older R_range and fails to broadcast for the current one.
V_eff = jnp.log(mean_loops[:-2, 5] / mean_loops[:-2, 6])
R_fit = R_range[0] + jnp.arange(V_eff.shape[0])

sigma, V0, alpha = jnp.polyfit(
    R_fit.astype(jnp.float32),
    V_eff * R_fit,
    deg=2
)

In [None]:
# Overlay the fitted V(R) = V0 + alpha/R + sigma*R on the measured effective potential.
# The R grid is derived from R_range and the data length so the x and y sizes always agree.
V_eff = jnp.log(mean_loops[:-2, 5] / mean_loops[:-2, 6])
R_vals = R_range[0] + jnp.arange(V_eff.shape[0])

x = jnp.linspace(float(R_vals[0]), float(R_vals[-1]), 100)
y = V0 + alpha / x + sigma * x

fig, ax = plt.subplots()
ax.scatter(R_vals, V_eff, label="measured")
ax.plot(x, y, label="V0 + alpha/R + sigma*R")
ax.set(xlabel="R (lattice units)", ylabel="V(R) (lattice units)", title="Static potential fit")
ax.legend()
plt.show()

In [None]:
# Lattice spacing in fm from the fitted string tension sigma (lattice units):
#   a = sqrt(sigma * a^2) * (hbar*c) / sqrt(sigma_phys)
# with the conventional sqrt(sigma_phys) = 440 MeV. hbar*c = 197.3269804 MeV*fm
# (CODATA). The previous constant 0.1973164956590371 GeV*fm is 1/5.068, a rounded
# approximation with about 5e-5 relative error.
SQRT_SIGMA_PHYS_MEV = 440.0
HBARC_MEV_FM = 197.3269804
jnp.sqrt(sigma) * HBARC_MEV_FM / SQRT_SIGMA_PHYS_MEV