In [1]:
from functools import partial
import numpy as np
import torch

from gauge_field_utils import *
from special_unitary import *
from integrators import *

In [2]:
def HMC(action_fn, error_fn, integrator, tau_md=1.0, steps_md=10, unitary_violation_tol=5e-6):
    """Build a compiled HMC update step for a field of gauge links.

    Args:
        action_fn: Action functional. Its gradient is obtained through
            ``special_unitary_grad`` and handed to ``integrator``.
        error_fn: Callable ``error_fn(links, p0, links_next, p_final)`` whose
            result is used as the Hamiltonian difference of the trajectory.
        integrator: Callable
            ``integrator(links, p0, action_grad, tau_md, steps_md)`` returning
            the pair ``(links_next, p_final)``.
        tau_md: Forwarded unchanged to ``integrator``.
        steps_md: Forwarded unchanged to ``integrator``.
        unitary_violation_tol: When ``unitary_violation(links_next)`` exceeds
            this value the proposal is passed through ``proj_SU3``. ``None``
            disables the check entirely.

    Returns:
        ``step_fn(links, skip_metropolis=False)`` wrapped by
        ``torch.compile(..., backend="cudagraphs")``. Calling it returns
        ``(links_out, (delta_hamiltonian, p_acc))`` where ``links_out`` is the
        proposal when it is accepted (or when ``skip_metropolis`` is true) and
        the input ``links`` otherwise.
    """
    action_grad = special_unitary_grad(action_fn)

    def step_fn(links, skip_metropolis=False):
        Nc = links.shape[-1]

        # One Gaussian momentum vector with Nc*Nc-1 components per link
        # (presumably coordinates in the su(Nc) algebra -- confirm against the
        # integrator's convention).
        p0 = torch.normal(0, 1, size=(*links.shape[:-2], Nc*Nc-1), dtype=links.real.dtype, device=links.device)

        links_next, p_final = integrator(links, p0, action_grad, tau_md, steps_md)

        # `unitary_violation_tol` is a Python constant fixed when HMC() is
        # called, so it is branched on with a plain `if`; passing a constant
        # predicate to torch.cond only specialises to one branch and warns.
        if unitary_violation_tol is not None:
            tol = torch.scalar_tensor(
                unitary_violation_tol,
                dtype=links_next.real.dtype,
                device=links_next.device,
            )
            # Data-dependent predicate: keep torch.cond so the projection is
            # only executed when the violation actually exceeds the tolerance.
            links_next = torch.cond(
                unitary_violation(links_next) > tol,
                proj_SU3,
                lambda x: x,
                (links_next,),
            )

        delta_hamiltonian = error_fn(links, p0, links_next, p_final)
        # Metropolis acceptance probability min(1, exp(-dH)).
        p_acc = torch.clamp(torch.exp(-delta_hamiltonian), max=1.0)

        def metropolis_select():
            # Accept the proposal with probability p_acc, otherwise keep the
            # configuration the trajectory started from.
            return torch.cond(
                torch.rand_like(p_acc) < p_acc,
                lambda: links_next,
                lambda: links,
            )

        if isinstance(skip_metropolis, torch.Tensor):
            # A tensor predicate still has to be resolved at run time.
            links_out = torch.cond(skip_metropolis, lambda: links_next, metropolis_select)
        elif skip_metropolis:
            links_out = links_next
        else:
            links_out = metropolis_select()

        return links_out, (delta_hamiltonian, p_acc)

    return torch.compile(step_fn, backend="cudagraphs")

In [3]:
# Leading dimensions of the link field (a 4x4x4x4 lattice).
L = (4, 4, 4, 4)

# Coupling parameters shared by the action and the matching error estimate;
# named once so the two partials below cannot drift apart.
BETA = 8.00
U0 = 0.8876875888655319

# Random starting configuration: Gaussian coefficients of shape (*L, 4, 8)
# mapped through `expi` (presumably 4 directions x 8 su(3) generators --
# confirm against `expi`).
gauge_links = expi(torch.normal(
    0, 1,
    size=(*L, 4, 8),
    dtype=torch.float64,
    device="cuda"
))

stepper_fn = HMC(
    action_fn=partial(luscher_weisz_action, beta=BETA, u0=U0),
    error_fn=partial(luscher_weisz_gauge_error, beta=BETA, u0=U0),
    integrator=int_2MN,
    tau_md=1.0,
    steps_md=25,
    unitary_violation_tol=5e-6,
)