In [1]:
import numpy as np
import torch

from gauge_field_utils import *
from special_unitary import *
from integrators import *

In [6]:
def HMC(action_fn, error_fn, integrator, tau_md=1.0, steps_md=10, unitary_violation_tol=5e-6):

    action_grad = special_unitary_grad(action_fn)
    
    def step_fn(links, skip_metropolis=False):
        Nc = links.shape[-1]

        p0 = torch.normal(0, 1, size=(*links.shape[:-2], Nc*Nc-1))

        links_next, p_final = integrator(links, p0, action_grad, tau_md, steps_md)
        links_next = torch.cond(
            unitary_violation_tol is not None,
            lambda: torch.cond(
                unitary_violation(links_next, "mean") > unitary_violation_tol,
                proj_SU3,
                lambda x: x,
                links_next
            ),
            lambda: links_next
        )

        delta_hamiltonian = error_fn(links, p0, links_next, p_final)
        p_acc = torch.minimum(torch.scalar_tensor(1, dtype=delta_hamiltonian.dtype, device=delta_hamiltonian.device), torch.exp(-delta_hamiltonian))
        
        return torch.cond(
            skip_metropolis,
            lambda: links_next,
            lambda: torch.cond(
                torch.rand() < p_acc,
                lambda: links_next,
                lambda: links
            )
        ), (delta_hamiltonian, p_acc)
    
    return torch.compile(step_fn, backend="cudagraphs")

In [7]:
L = (16, 16, 16, 16)

gauge_links = FastExpiSU3.apply(torch.normal(
    0, 1,
    size=(*L, 4, 8),
    dtype=torch.float32
))

stepper_fn = HMC(
    action_fn=partial(luscher_weisz_action, beta=8.00, u0=0.8876875888655319),
    error_fn=partial(luscher_weisz_gauge_error, beta=8.00, u0=0.8876875888655319),
    integrator=int_4MN4FP,
    tau_md=1.0,
    steps_md=25,
    unitary_violation_tol=5e-6,
)

In [8]:
test_next = stepper_fn(gauge_links)

skipping cudagraphs due to skipping cudagraphs due to cpu device (normal)
skipping cudagraphs due to skipping cudagraphs due to cpu device (arg0_1). Found from : 
   File "/mnt/d/other-projects/lqcd-lambda/torch-implementation/integrators.py", line 13, in _torch_scan
    val, y = f(val, x)
  File "/mnt/d/other-projects/lqcd-lambda/torch-implementation/integrators.py", line 75, in scan_fn
    q = T_operator(q, p, rho)
  File "/mnt/d/other-projects/lqcd-lambda/torch-implementation/integrators.py", line 69, in <lambda>
    T_operator = lambda q, p, coef: torch.matmul(FastExpiSU3.apply(coef * eps * p), q)



Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/nobe/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3546, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_9754/2752613319.py", line 1, in <module>
    test_next = stepper_fn(gauge_links)
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nobe/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_9754/615773733.py", line 5, in step_fn
    def step_fn(links, skip_metropolis=False):
  File "/mnt/d/other-projects/lqcd-lambda/torch-implementation/integrators.py", line 66, in int_4MN4FP
    def int_4MN4FP(q0, p0, F_func, tau, steps_md, rho=0.1786178958448091, theta=-0.06626458266981843, lambd=0.7123418310626056):
  File "/mnt/d/other-projects/lqcd-lambda/torch-implementation/integrators.py", line 9, in _torch_scan
    