# Setup

In [None]:
import analysis as al
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import tqdm.auto as tqdm
%matplotlib widget

In [None]:
# New tensors are created on the CPU by default; training runs on the GPU when
# one is available and falls back to the CPU otherwise (a hard-coded 'cuda'
# would crash on CPU-only machines when the model is moved to it).
torch.set_default_device('cpu')
train_device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def grab(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and on the host."""
    detached = x.detach()
    on_host = detached.cpu()
    return on_host.numpy()

In [None]:
# Load the raw complex128 ensemble, reshape it to (n_cfg, 64, 64, 3) and drop
# the first 150 configurations (presumably thermalization — confirm).
_ens_raw = np.fromfile(
    '../heatbath_cpp/data/cpn_b4.0_L64_Nc3_big_ens.dat',
    dtype=np.complex128,
)
ens = _ens_raw.reshape(-1, 64, 64, 3)[150:]
del _ens_raw

In [None]:
# Phase of z^1 / z^2 across the lattice for the last and fifth-to-last
# configurations (cyclic colormap for an angle).
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for panel, cfg in zip(axes, (-1, -5)):
    ratio = ens[cfg, :, :, 1] / ens[cfg, :, :, 2]
    panel.imshow(np.angle(ratio), cmap='twilight', interpolation='nearest')
plt.show()

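The action is defined by the formula below. The `action` function further down also adds the constant $\beta$ per link, i.e. it computes $\beta \sum_{i,\mu} \left(1 - |z_i z^\dagger_{i+\hat\mu}|^2\right)$; the constant does not change the gradient.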
$$
S(z) = -\beta \sum_{i,\mu} |z_i z^\dagger_{i+\hat{\mu}}|^2 = -\beta \sum_{i,\mu} (z_i z^\dagger_{i+\hat\mu}) (z_{i+\hat\mu} z^\dagger_i)
$$
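Write $z_i = X^R_i + i X^I_i$. Read $z^\dagger_j$ as the vector of complex conjugates, so that $(z_j z^\dagger_k) = \sum_a z^a_j \bar z^a_k$. Re and Im act component-wise. The gradient of the action (the force, up to a sign) is then:
$$
\frac{\partial}{\partial X^R_i} S(z) = -2 \beta \, \mathrm{Re} \sum_{\mu} \left[ z^\dagger_{i+\hat\mu} (z_{i+\hat\mu} z^\dagger_i) + z^\dagger_{i-\hat\mu} (z_{i-\hat\mu} z^\dagger_i) \right],
\qquad
\frac{\partial}{\partial X^I_i} S(z) = 2 \beta \, \mathrm{Im} \sum_{\mu} \left[ z^\dagger_{i+\hat\mu} (z_{i+\hat\mu} z^\dagger_i) + z^\dagger_{i-\hat\mu} (z_{i-\hat\mu} z^\dagger_i) \right]
$$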

In [None]:
# Torch copy of the complex128 ensemble on the default device; this is the
# tensor handed to the training entry point at the end of the notebook.
ens_torch = torch.tensor(ens, dtype=torch.complex128)

In [None]:
def to_complex(x):
    """Rebuild the complex field and its conjugate from interleaved real channels.

    The last axis of ``x`` holds (Re z^0, Im z^0, Re z^1, Im z^1, ...).
    Returns the pair ``(z, conj(z))``.
    """
    re_part, im_part = x[..., 0::2], x[..., 1::2]
    imag_unit = 1j
    return re_part + imag_unit * im_part, re_part - imag_unit * im_part

In [None]:
def to_real(z):
    """Convert a complex tensor to real form by interleaving Re/Im on the last axis.

    Shape (..., Nc) complex -> (..., 2*Nc) real, ordered
    (Re z^0, Im z^0, Re z^1, Im z^1, ...).
    """
    pairs = torch.stack((z.real, z.imag), dim=-1)  # (..., Nc, 2)
    return pairs.flatten(start_dim=-2, end_dim=-1)

In [None]:
def action(z, zbar, *, beta):
    """Compute the lattice action for a batch of configurations.

    S = beta * sum_{i, mu} (1 - (z_i . zbar_{i+mu}) (zbar_i . z_{i+mu}))
    with periodic boundary conditions in both lattice directions.

    Args:
        z: tensor of shape (batch, Lx, Lt, Nc).
        zbar: tensor of the same shape; the complex conjugate of ``z``.
        beta: coupling constant.

    Returns:
        Tensor of shape (batch,) with the action of each configuration, on
        ``z``'s device. For complex inputs the result is complex: h1 * h2 is
        real only up to rounding, so take ``.real`` for the physical value.
    """
    assert len(z.shape) == 4, 'z must have shape (batch, Lx, Lt, Nc)'
    assert z.shape == zbar.shape
    # Accumulate from a Python scalar instead of torch.zeros(...): the latter
    # is allocated on the global default device and breaks for CUDA inputs.
    S = 0.0
    for mu in range(2):
        # Lattice axes are dims 1 and 2; rolling by -1 brings site i+mu onto i.
        h1 = torch.sum(z * torch.roll(zbar, -1, dims=mu + 1), dim=-1)
        h2 = torch.sum(zbar * torch.roll(z, -1, dims=mu + 1), dim=-1)
        S = S + torch.sum(1.0 - h1 * h2, dim=(1, 2))
    return beta * S

In [None]:
def grad_action(x, *, beta):
    """Gradient of the action with respect to the real field coordinates.

    Args:
        x: real tensor of shape (batch, Lx, Lt, 2*Nc) with real and imaginary
            parts interleaved on the last axis (Re z^0, Im z^0, Re z^1, ...).
        beta: coupling constant forwarded to ``action``.

    Returns:
        Tensor with the same shape as ``x`` holding dS/dx per configuration.
    """
    assert len(x.shape) == 4, 'x must have shape (batch, Lx, Lt, 2*Nc)'

    def _single_action(xs):
        # Rebuild z and its conjugate from the interleaved layout and add a
        # batch axis of size 1, because ``action`` expects batched input.
        re, im = xs[..., ::2], xs[..., 1::2]
        z = (re + 1j * im)[None]
        zbar = (re - 1j * im)[None]
        # The action is real up to rounding; differentiate its real part.
        return action(z, zbar, beta=beta)[0].real

    # _single_action returns a 0-dim tensor, so torch.func.grad (reverse mode)
    # yields dS/dx with x's shape; vmap maps it over the batch axis.
    return torch.func.vmap(torch.func.grad(_single_action))(x)

In [None]:
# Per-site action E = S / (Lx * Lt) at beta = 1, with a binned-bootstrap
# estimate of its mean drawn as a band over the Monte Carlo history.
fig, ax = plt.subplots(1,1)
# `action` is written with torch ops (torch.roll), which reject NumPy arrays,
# so pass tensors; take the real part and return to NumPy for analysis.
ens_t = torch.as_tensor(ens)
ens_t_conj = torch.as_tensor(np.conj(ens))
volume = ens.shape[-2] * ens.shape[-3]
E = action(ens_t, ens_t_conj, beta=1.0).real.cpu().numpy() / volume
E_est = al.bootstrap(al.bin_data(E, binsize=100)[1], Nboot=1000, f=al.rmean)
ax.plot(E)
ax.fill_between(
    [0, len(E)], [E_est[0]-E_est[1]]*2, [E_est[0]+E_est[1]]*2,
    ec='none', color='xkcd:red', alpha=0.5, zorder=2,
    label=rf'${E_est[0]:.3f} \pm {E_est[1]:.3f}$')
ax.legend()
ax.set_xlabel('mc step')
ax.set_ylabel('$E$')
plt.show()

In [None]:
# Bootstrap estimates of the lattice-averaged components
# P_ij = <z^i conj(z^j)> for every pair of colour indices (i, j) in 0..2.
Pij_est = []
for flat in range(3 * 3):
    i, j = divmod(flat, 3)
    Pij = np.mean(ens[..., i] * np.conj(ens[..., j]), axis=(-1, -2))
    Pij_est.append(
        al.bootstrap(al.bin_data(Pij, binsize=250)[1], Nboot=1000, f=al.rmean)
    )
    print(f'{i=} {j=} {Pij_est[-1]=}')
Pij_est = np.stack(Pij_est, axis=-1)
print(f'{Pij_est.shape=}')
fig, ax = plt.subplots(1, 1)
al.add_errorbar(
    Pij_est, ax=ax, marker='o', linestyle='', capsize=2, fillstyle='none',
)
plt.show()

# PINN Training

In [None]:
def reshape_jac(x: torch.Tensor):
    """Flatten a batched Jacobian into a stack of square matrices.

    ``x`` must have shape ``(batch, *s, *s)``, e.g. the output of
    ``vmap(jacfwd(f))(x)`` for a shape-preserving ``f``. The result has shape
    ``(batch, d, d)`` with ``d = prod(s)``.

    Raises:
        ValueError: if the non-batch dims do not split into two equal halves.
    """
    import math  # local import keeps this cell self-contained

    event_shape = tuple(x.shape[1:])
    half, rem = divmod(len(event_shape), 2)
    if rem or event_shape[:half] != event_shape[half:]:
        raise ValueError(f'expected shape (batch, *s, *s), got {tuple(x.shape)}')
    d = math.prod(event_shape[:half])
    return x.reshape(-1, d, d)

In [None]:
class Model(torch.nn.Module):
    def __init__(self, Nc, hidden_ch=8):
        super().__init__()
        self.in_ch = 2*Nc
        self.out_ch = 2*Nc
        # just a crappy conv net
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(self.in_ch, hidden_ch, 3, padding=1, padding_mode='circular'),
            torch.nn.SiLU(),
            torch.nn.Conv2d(hidden_ch, hidden_ch, 3, padding=1, padding_mode='circular'),
            torch.nn.SiLU(),
            torch.nn.Conv2d(hidden_ch, self.out_ch, 3, padding=1, padding_mode='circular'),
        )

    def _vel_single(self, x):
        f = self.net(x)
        fx = torch.sum(f*x)
        return f - fx*x

    def forward(self, x):
        funcF = torch.func.vmap(self._vel_single)
        # F = funcF(x)
        # jacF = torch.func.vmap(torch.func.jacfwd(self._vel_single))(x)
        # jacF = reshape_jac(jacF)
        # divF = torch.einsum('xii->x', jacF)
        # hutch estimator
        eta = torch.randn_like(x)
        F, jvp = torch.func.jvp(funcF, (x,), (eta,))
        inds = tuple(range(1, len(F.shape)))
        divF = (eta*jvp).sum(inds)
        # divF = torch.zeros(F.shape[0]).to(device=F.device)
        return F, divF

In [None]:
def train_step(x, gradS, Q, model, optimizer):
    """Take one optimizer step on the squared residual div F - F . grad S - Q.

    Args:
        x: batch of configurations fed to ``model``.
        gradS: action gradient for each configuration, same shape as ``x``.
        Q: observable per configuration, shape (batch,).
        model: module returning ``(F, divF)``.
        optimizer: torch optimizer over ``model``'s parameters.

    Returns:
        ``dict(loss=...)`` holding the batch loss as a NumPy value.
    """
    optimizer.zero_grad()
    F, divF = model(x)
    reduce_dims = tuple(range(1, F.dim()))
    FgradS = torch.sum(F * gradS, dim=reduce_dims)
    assert divF.shape == FgradS.shape
    assert divF.shape == Q.shape, f'{divF.shape=} {Q.shape=}'
    # NOTE: assumes either on-policy training or <Q> = 0
    residual = divF - FgradS - Q
    loss = residual.pow(2).mean()
    loss.backward()
    optimizer.step()
    return {'loss': grab(loss)}

In [None]:
def main(ens, beta, *, i=0, j=1, batch_size, n_step):
    """Train the PINN vector field on an ensemble and plot the loss history.

    Args:
        ens: complex tensor of configurations, shape (n_cfg, Lx, Lt, Nc).
        beta: coupling used for the action gradient.
        i, j: colour indices of the observable
            Q = Re mean_x z^i(x, t=0) conj(z^j(x, t=0)).
        batch_size: configurations per step, sampled with replacement.
        n_step: number of optimizer steps.

    Returns:
        ``(model, hist)``: the trained model and a dict whose ``'loss'``
        entry is the per-step loss history as a NumPy array.
    """
    # Compute the observable on the complex field, BEFORE switching to the
    # interleaved real layout: there, last-axis indices i and j would select
    # Re/Im parts instead of colour components.
    Q = torch.mean(ens[..., 0, i] * ens[..., 0, j].conj(), dim=-1).real
    Q = Q.to(torch.float32)
    ens = to_real(ens).to(torch.float32)
    gradS = grad_action(ens, beta=beta)
    # TEST
    Q_est = al.bootstrap(grab(Q), Nboot=1000, f=al.rmean)
    print(f'{Q_est=}')
    Nc = ens.shape[-1] // 2
    # Move channels to dim 1 for Conv2d: (n_cfg, Lx, Lt, 2*Nc) -> (n_cfg, 2*Nc, Lx, Lt).
    gradS = gradS.moveaxis(-1, 1)
    ens = ens.moveaxis(-1, 1)
    model = Model(Nc).to(train_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    hist = dict(loss=[])
    # Use `step`, not `i`, so the colour-index parameter is not shadowed.
    for step in range(n_step):
        inds = np.random.randint(len(ens), size=batch_size)
        xi = ens[inds].to(train_device)
        gradSi = gradS[inds].to(train_device)
        Qi = Q[inds].to(train_device)
        res = train_step(xi, gradSi, Qi, model, optimizer)
        hist['loss'].append(res['loss'])
        if (step + 1) % 100 == 0:
            print(f'Step {step+1}: Loss {res["loss"]:.2g}')
    for k in hist:
        hist[k] = np.stack(hist[k])
    fig, ax = plt.subplots(1,1)
    ax.plot(hist['loss'])
    ax.set_ylabel('Loss')
    ax.set_xlabel('Train step')
    ax.set_yscale('log')
    plt.show()
    return model, hist

In [None]:
# Train the PINN model on the loaded ensemble at beta = 4.0.
main(
    ens_torch,
    beta=4.0,
    batch_size=32,
    n_step=10000,
)