In [1]:
import itertools
import torch
import numpy as np
from sun_utils import SUGenerator

In [68]:
gen = SUGenerator(3, module="torch", dtype=torch.complex64, device="cuda:0")

def coef_to_lie_group(coef):
    # Ensure coef has the correct dtype before contraction.
    coef = coef.to(dtype=gen.generators.dtype)
    # Contract with the generators (note: same Einstein summation as in JAX)
    su = torch.einsum("...N,Nij->...ij", coef, gen.generators)
    # Use PyTorch’s matrix exponential for complex matrices.
    SU_N = torch.linalg.matrix_exp(1j * su)
    return SU_N

def wilson_action(field, beta):
    N = field.shape[-1]

    def plaquette(mu, nu):
        # Select the mu and nu directions (assume field shape: [..., 4, N, N])
        U_mu = field[..., mu, :, :]
        U_nu = field[..., nu, :, :]
        # Roll along the corresponding axis
        U_mu_shifted = torch.roll(U_mu, shifts=-1, dims=nu)
        U_nu_shifted = torch.roll(U_nu, shifts=-1, dims=mu)
        # Compute the conjugate transpose as .transpose(-2, -1).conj()
        U_mu_shifted_H = U_mu_shifted.transpose(-2, -1).conj()
        U_nu_H = U_nu.transpose(-2, -1).conj()
        # Compute the trace contraction (like the JAX einsum)
        Re_Tr_Plaquettes = torch.einsum("...AB,...BC,...CD,...DA->...", 
                                         U_mu, U_nu_shifted, U_mu_shifted_H, U_nu_H).real
        return 1 - Re_Tr_Plaquettes / N

    S = 0
    for mu in range(4):
        for nu in range(mu+1, 4):
            S += torch.sum(plaquette(mu, nu))
    return beta * S

def tree_level_improved_action(field, beta):
    def P(mu, nu):
        U_mu = field[:, mu, :, :]
        U_nu = field[:, nu, :, :]
        rolled1 = torch.roll(U_nu, shifts=-1, dims=mu)
        # Roll U_mu in the nu direction and take conjugate transpose
        rolled2 = torch.roll(U_mu, shifts=-1, dims=nu).transpose(-2, -1).conj()
        U_nu_H = U_nu.transpose(-2, -1).conj()
        return torch.einsum("...AB,...BC,...CD,...DA->...", 
                             U_mu, rolled1, rolled2, U_nu_H).real

    def R(mu, nu):
        U_mu = field[:, mu, :, :]
        U_nu = field[:, nu, :, :]
        R1 = torch.einsum(
            "...AB,...BC,...CD,...DE,...EF,...FA->...",
            U_mu,
            torch.roll(U_nu, shifts=-1, dims=mu),
            torch.roll(U_nu, shifts=(-1, -1), dims=(mu, nu)),
            torch.roll(U_mu, shifts=-2, dims=nu).transpose(-2, -1).conj(),
            torch.roll(U_nu, shifts=-1, dims=nu).transpose(-2, -1).conj(),
            U_nu.transpose(-2, -1).conj()
        ).real
        R2 = torch.einsum(
            "...AB,...BC,...CD,...DE,...EF,...FA->...",
            U_mu,
            torch.roll(U_mu, shifts=-1, dims=mu),
            torch.roll(U_nu, shifts=-2, dims=mu),
            torch.roll(U_mu, shifts=(-1, -1), dims=(mu, nu)).transpose(-2, -1).conj(),
            torch.roll(U_mu, shifts=-1, dims=nu).transpose(-2, -1).conj(),
            U_nu.transpose(-2, -1).conj()
        ).real
        return R1 + R2

    # Compute the average plaquette value to define u0.
    P_vals = [P(mu, nu) for mu in range(4) for nu in range(4) if mu != nu]
    P_mean = torch.stack(P_vals).mean() / 3.0
    u0_sqr = torch.pow(P_mean, 1/2)
    print(f"u0^2 = {u0_sqr}")

    S = 0
    for mu in range(4):
        for nu in range(mu+1, 4):
            S += (5 * beta) * (1 - P(mu, nu).sum() / 3) - (beta / (4 * u0_sqr)) * (1 - R(mu, nu).sum() / 3)
    return S

def chain_matmul_einsum(arrs, trace_last=True):
    # Sequentially multiply the list of matrices.
    result = arrs[0]
    for A in arrs[1:]:
        result = torch.matmul(result, A)
    if trace_last:
        # Take the trace over the last two dimensions.
        result = result.diagonal(offset=0, dim1=-2, dim2=-1).sum(-1)
    return result

def mean_wilson_rectangle(field, R, T, time_unique=True):
    result = 0
    if time_unique:
        # Loop over spatial directions 1,2,3 (assuming axis 0 is time)
        for spatial_dim in [1, 2, 3]:
            link_list = []
            for i in range(R):
                link_list.append(torch.roll(field[:, :, :, :, spatial_dim], shifts=-i, dims=spatial_dim))
            for i in range(T):
                link_list.append(torch.roll(field[:, :, :, :, 0], shifts=(-i, -R), dims=(0, spatial_dim)))
            for i in range(R - 1, -1, -1):
                link_list.append(
                    torch.roll(field[:, :, :, :, spatial_dim], shifts=(-T, -i), dims=(0, spatial_dim))
                         .transpose(-2, -1).conj()
                )
            for i in range(T - 1, -1, -1):
                link_list.append(
                    torch.roll(field[:, :, :, :, 0], shifts=-i, dims=0)
                         .transpose(-2, -1).conj()
                )
            result += chain_matmul_einsum(link_list, trace_last=True).mean()
        return result / 3
    else:
        for mu, nu in [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]:
            link_list = []
            for i in range(R):
                link_list.append(torch.roll(field[:, :, :, :, nu], shifts=-i, dims=nu))
            for i in range(T):
                link_list.append(torch.roll(field[:, :, :, :, mu], shifts=(-i, -R), dims=(mu, nu)))
            for i in range(R - 1, -1, -1):
                link_list.append(
                    torch.roll(field[:, :, :, :, nu], shifts=(-T, -i), dims=(mu, nu))
                         .transpose(-2, -1).conj()
                )
            for i in range(T - 1, -1, -1):
                link_list.append(
                    torch.roll(field[:, :, :, :, mu], shifts=-i, dims=mu)
                         .transpose(-2, -1).conj()
                )
            result += chain_matmul_einsum(link_list, trace_last=True).mean()
        return result / 6

def calculate_wilson_loops(gauge_coef, R_range, T_range):
    R_min, R_max = R_range
    T_min, T_max = T_range
    gauge = coef_to_lie_group(gauge_coef)
    loops = []
    for R, T in itertools.product(range(R_min, R_max + 1), range(T_min, T_max + 1)):
        loops.append(mean_wilson_rectangle(gauge, R, T, time_unique=False))
    wilson_loop_values = torch.tensor(loops).reshape(R_max - R_min + 1, T_max - T_min + 1)
    return wilson_loop_values

# --- HMC trajectory and dual-averaging functions ---

def parametrized_action(coef, beta):
    gauge = coef_to_lie_group(coef)
    return wilson_action(gauge, beta)

def parametrized_action_grad(coef, beta):
    coef.requires_grad_(True)
    U = parametrized_action(coef, beta)
    # Compute gradient with respect to coef.
    grad, = torch.autograd.grad(U, coef, create_graph=True)
    return grad

def parametrized_action_value_and_grad(coef, beta):
    coef.requires_grad_(True)
    U = parametrized_action(coef, beta)
    grad, = torch.autograd.grad(U, coef, create_graph=True)
    return U, grad

# --- Compile (optimize) the key functions using torch.compile ---
parametrized_action = torch.compile(parametrized_action)
parametrized_action_grad = torch.compile(parametrized_action_grad)
parametrized_action_value_and_grad = torch.compile(parametrized_action_value_and_grad)

def dual_averaging_update(log_eps_bar, h, step, accept_prob, mu, target_accept=0.75, gamma=0.05, t0=10, kappa=0.75):
    step = step + 1
    eta = 1.0 / (step + t0)
    h = (1 - eta) * h + eta * (target_accept - accept_prob)
    log_eps = mu - (step**0.5 / gamma) * h
    log_eps_bar = (step ** (-kappa)) * log_eps + (1 - step ** (-kappa)) * log_eps_bar
    new_epsilon = torch.exp(log_eps) if isinstance(log_eps, torch.Tensor) else np.exp(log_eps)
    return new_epsilon, log_eps_bar, h

def HMC_trajectory(q0, beta, steps, dt, rng=None):
    # Use torch’s random generator if none provided.
    if rng is None:
        rng = torch.Generator()
    p0 = torch.randn_like(q0)
    # fixed leapfrog integration
    q = q0.clone().detach().requires_grad_(True)
    U_initial, grad = parametrized_action_value_and_grad(q, beta)
    H_initial = U_initial + torch.sum(p0 ** 2) / 2
    # initial half step for momentum
    p = p0 + (dt / 2) * grad
    for _ in range(steps):
        # Full step for q.
        q = q + dt * p
        U, grad = parametrized_action_value_and_grad(q, beta)
        # Full step for p.
        p = p - dt * grad
    # Final half step for momentum.
    p = p - (dt / 2) * grad
    U_final = parametrized_action(q, beta)
    H_final = U_final + torch.sum(p ** 2) / 2
    dH = H_final - H_initial
    accept_prob = min(1.0, torch.exp(-dH).item())
    # Metropolis acceptance step.
    if torch.rand(1).item() < accept_prob:
        q_next = q.detach()
        accepted = True
    else:
        q_next = q0
        accepted = False
    return {
        "q_next": q_next,
        "dH": dH,
        "accept_prob": accept_prob,
        "was_accepted": accepted
    }

def warmup_epsilon(coef, beta, target_accept=0.651, eps0=1e-3, warmup_iters=300, trajectory_steps=30):
    import math
    from tqdm import tqdm
    mu = math.log(10 * eps0)
    log_eps_bar = 0.0
    h = 0.0
    eps = eps0
    for step in tqdm(range(warmup_iters), desc="Warmup epsilon"):
        traj = HMC_trajectory(coef, beta, trajectory_steps, eps)
        coef = traj["q_next"]
        accept_prob = traj["accept_prob"]
        print(f"warmup step {step} : eps={eps} ; p={accept_prob} ; dH={traj['dH']}")
        eps, log_eps_bar, h = dual_averaging_update(log_eps_bar, h, step, accept_prob, mu, target_accept=target_accept)
    return coef, eps

def warmup_tint(coef, beta, observable_fn, eps, warmup_iters=2000, trajectory_steps=30):
    from tqdm import tqdm
    O_list = []
    for step in tqdm(range(warmup_iters), desc="Warmup tint"):
        traj = HMC_trajectory(coef, beta, trajectory_steps, eps)
        coef = traj["q_next"]
        o = observable_fn(coef)
        print(f"warmup step {step} ; o={o}")
        O_list.append(o)
    O = torch.stack(O_list)
    O = O.view(-1)
    tint = 1.0
    for s in range(1, O.shape[0] // 2):
        # Using torch.corrcoef (available in recent PyTorch releases)
        corr = torch.corrcoef(torch.stack([O[:-s], O[s:]]))[0, 1]
        tint += 2 * corr
    return coef, tint, O

In [70]:
import jax.numpy as jnp

L = (16, 8, 8, 8)

coef = torch.normal(0, 1, (*L, 4, 8), dtype=torch.float32, device="cuda:0")

In [71]:
coef, eps = warmup_epsilon(
    coef,
    beta=6.7,
    target_accept=0.4,
    eps0=1e-3,
    warmup_iters=300,
    trajectory_steps=80
)

Warmup epsilon:   0%|          | 0/300 [00:00<?, ?it/s]E0306 20:21:20.526000 19306 site-packages/torch/_subclasses/fake_tensor.py:2388] [0/0_1] failed while attempting to run meta for aten.view.dtype
E0306 20:21:20.526000 19306 site-packages/torch/_subclasses/fake_tensor.py:2388] [0/0_1] Traceback (most recent call last):
E0306 20:21:20.526000 19306 site-packages/torch/_subclasses/fake_tensor.py:2388] [0/0_1]   File "/home/nobe/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
E0306 20:21:20.526000 19306 site-packages/torch/_subclasses/fake_tensor.py:2388] [0/0_1]     r = func(*args, **kwargs)
E0306 20:21:20.526000 19306 site-packages/torch/_subclasses/fake_tensor.py:2388] [0/0_1]         ^^^^^^^^^^^^^^^^^^^^^
E0306 20:21:20.526000 19306 site-packages/torch/_subclasses/fake_tensor.py:2388] [0/0_1]   File "/home/nobe/miniconda3/envs/lattice-qcd/lib/python3.11/site-packages/torch/_ops.py", line 723, in __call__
E0306 

BackendCompilerFailed: backend='inductor' raised:
RuntimeError: self.stride(-1) must be 1 to view ComplexFloat as Float (different element sizes), but got 3

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
