# Setup

In [None]:
import analysis as al
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os
import tqdm.auto as tqdm
%matplotlib widget

In [None]:
# Primary ensemble: raw complex128 heatbath dump, reshaped to (n_cfg, L, L, Nc).
# Alternative runs (change ENS_PATH / ENS_L accordingly):
#   ../heatbath_cpp/data/cpn_b6.2_L256_Nc3_ens.dat     (L=256, no skip)
#   ../heatbath_cpp/data/cpn_b4.5_L64_Nc3_big_ens.dat  (L=64,  no skip)
ENS_L, ENS_NC = 64, 3
ENS_SKIP = 150  # leading configurations dropped (presumably burn-in; TODO confirm)
ENS_PATH = f'../heatbath_cpp/data/cpn_b4.0_L{ENS_L}_Nc{ENS_NC}_big_ens.dat'
ens = np.fromfile(ENS_PATH, dtype=np.complex128).reshape(-1, ENS_L, ENS_L, ENS_NC)[ENS_SKIP:]

In [None]:
# Second ensemble, stored as HDF5: skip the recorded burn-in and thin by `stepsize`.
H5_PATH = '../heatbath_cpp/data/CP2_lat64x64_b4.0.h5'
stepsize = 5  # deliberately overrides the stored value int(h5["configs"]["info"]["stepsize"][0])
# Context manager closes the file handle (previously left open for the kernel's lifetime).
with h5py.File(H5_PATH, 'r') as h5:
    burn = int(h5['configs']['info']['burn'][0])
    # np.array materialises the slice so the data remains valid after the file closes.
    ens2 = np.array(h5['configs']['vectors'][burn::stepsize])

In [None]:
# Report the (n_cfg, Lx, Lt, Nc) shapes of both loaded ensembles.
for _name, _arr in (('ens', ens), ('ens2', ens2)):
    print(f'{_name}.shape={_arr.shape}')

In [None]:
# Phase of z_1 / z_2 on the last configuration of each ensemble.
# vmin/vmax pin the cyclic 'twilight' map to np.angle's (-pi, pi] range, so the
# colormap wraps exactly where the phase wraps.
fig, axes = plt.subplots(1, 2, figsize=(8,4))
for ax, cfgs, title in zip(axes, (ens, ens2), ('ens: arg(z1/z2), last cfg', 'ens2: arg(z1/z2), last cfg')):
    ax.imshow(np.angle(cfgs[-1, :, :, 1] / cfgs[-1, :, :, 2]),
              cmap='twilight', interpolation='nearest', vmin=-np.pi, vmax=np.pi)
    ax.set_title(title)
plt.show()

In [None]:
def action(z, zbar, *, beta):
    """Nearest-neighbour lattice action of each configuration.

    Computes, per batch entry,
        S = beta * sum_x sum_{mu in axes 1, 2} [1 - (z(x).zbar(x+mu)) * (zbar(x).z(x+mu))]
    where `.` sums over the last (colour) axis and x+mu is a periodic one-site
    shift along array axis mu (via np.roll). `zbar` is taken as given rather than
    conjugated here, so independent (z, zbar) pairs may be passed.

    Returns an array of shape (batch,); complex if the inputs are complex.
    """
    assert z.ndim == 4, 'z must have shape (batch, Lx, Lt, Nc)'
    assert z.shape == zbar.shape
    total = np.zeros(z.shape[0])
    for axis in (1, 2):
        fwd = np.sum(z * np.roll(zbar, -1, axis=axis), axis=-1)
        bwd = np.sum(zbar * np.roll(z, -1, axis=axis), axis=-1)
        total = total + (1.0 - fwd * bwd).sum(axis=(1, 2))
    return beta * total

In [None]:
fig, ax = plt.subplots(1,1)
# Energy density per configuration: action at beta=1 divided by the volume of *that* ensemble
# (E2 was previously normalised by ens's volume).
E = action(ens, np.conj(ens), beta=1.0) / (ens.shape[-2]*ens.shape[-3])
E_est = al.bootstrap(al.bin_data(E, binsize=100)[1], Nboot=1000, f=al.rmean)
E2 = action(ens2, np.conj(ens2), beta=1.0) / (ens2.shape[-2]*ens2.shape[-3])
E2_est = al.bootstrap(al.bin_data(E2, binsize=100)[1], Nboot=1000, f=al.rmean)
for name, series, est in (('ens', E, E_est), ('ens2', E2, E2_est)):
    # action() returns a complex array for complex input; plot the real part explicitly
    # instead of letting matplotlib drop the imaginary part with a ComplexWarning.
    (line,) = ax.plot(np.real(series))
    # Bootstrap mean +/- error band, drawn in the trace's own colour so the ensembles are distinguishable.
    ax.fill_between(
        [0, len(series)], [est[0]-est[1]]*2, [est[0]+est[1]]*2,
        ec='none', color=line.get_color(), alpha=0.5, zorder=2,
        label=rf'{name}: ${est[0]:.3f} \pm {est[1]:.3f}$')
ax.legend()
ax.set_xlabel('mc step')
ax.set_ylabel('$E$')
plt.show()

In [None]:
# Bootstrap estimate of the volume-averaged bilinear <z_i conj(z_j)> for every colour pair.
n_col = ens.shape[-1]  # Nc, taken from the data instead of hard-coding 3
Pij_est = []
for i in range(n_col):
    for j in range(n_col):
        # Per-configuration volume average over the two lattice axes.
        Pij = np.mean(ens[...,i]*np.conj(ens[...,j]), axis=(-1,-2))
        Pij_est.append(al.bootstrap(al.bin_data(Pij, binsize=250)[1], Nboot=1000, f=al.rmean))
        print(f'{i=} {j=} {Pij_est[-1]=}')
Pij_est = np.stack(Pij_est, axis=-1)
print(f'{Pij_est.shape=}')
fig, ax = plt.subplots(1,1)
# Flattened pair index i*Nc + j on the x axis, labelled "ij".
flat_idx = np.arange(n_col * n_col)
al.add_errorbar(Pij_est, xs=flat_idx, ax=ax, marker='o', linestyle='', capsize=2, fillstyle='none')
ax.set_xticks(flat_idx)
ax.set_xticklabels([f'{a}{b}' for a in range(n_col) for b in range(n_col)])
ax.set_xlabel('$ij$')
ax.set_ylabel(r'$\langle P_{ij} \rangle$')
plt.show()

# Signal-to-noise (StN) statistics

In [None]:
def build_corr(ens, dx=(0,25), i=0, j=1, Nboot=1000, bins=None):
    """Distribution of the site-local two-point product of P_ij = z_i conj(z_j).

    For every configuration and site x, forms
        C(x) = P_ij(x) * conj(P_ij(x + dx))
    with a periodic shift along ens axes 1 and 2. It then:
      * prints bootstrap estimates built from the per-configuration volume average of C;
      * plots a histogram of Re C over all configurations and sites;
      * plots a log-density contour of (Re C, Im C) over all configurations and sites.

    Parameters
    ----------
    ens : complex ndarray of shape (n_cfg, Lx, Lt, Nc), as produced by the loading cells.
    dx : (int, int) displacement along axes 1 and 2.
    i, j : colour indices of the bilinear. These were previously ignored, because
        colours 0 and 1 were hard-coded; the defaults reproduce the old behaviour.
    Nboot : number of bootstrap resamples.
    bins : histogram bin edges; the default is np.linspace(-0.1, 0.1), i.e. 50 edges.

    Returns None (prints and plots only).
    """
    Pij = ens[..., i] * np.conj(ens[..., j])
    print(f'{Pij.shape=}')
    # np.roll by -dx brings site x+dx to x, so this is P_ij(x) * conj(P_ij(x+dx)).
    C_site = Pij * np.roll(np.conj(Pij), -np.array(dx), axis=(1, 2))
    C_flat = C_site.flatten()
    # Per-configuration volume averages: the samples that are bootstrapped (computed once).
    C_cfg = np.mean(C_site, axis=(1, 2))

    # Columns: mean, "variance", second moment. Rows are presumably al.bootstrap's (estimate, error).
    # NOTE(review): x**2 squares the *complex* value. If al.rmean takes the real part,
    # the last two columns involve <Re^2 - Im^2> rather than <(Re x)^2> (which is what
    # _plot_res uses). Confirm which is intended.
    Cij_est = np.stack([
        np.stack(al.bootstrap(C_cfg, Nboot=Nboot, f=al.rmean)),
        np.stack(al.bootstrap(C_cfg, Nboot=Nboot, f=lambda x: al.rmean(x**2) - al.rmean(x)**2)),
        np.stack(al.bootstrap(C_cfg, Nboot=Nboot, f=lambda x: al.rmean(x**2))),
    ], axis=-1)
    print(Cij_est)
    C_mean_re = np.real(Cij_est[0, 0])
    C_mean_im = np.mean(C_cfg.imag)

    if bins is None:
        bins = np.linspace(-0.1, 0.1)
    fig, axes = plt.subplots(1, 2, figsize=(8, 3))
    ax = axes[0]
    ax.hist(C_flat.real, bins=bins, density=True)
    ax.axvline(C_mean_re, color='xkcd:red', linestyle='--')
    ax.set_xlabel('Re C')
    ax.set_ylabel('Density')

    ax = axes[1]
    hist, xs, ys = np.histogram2d(C_flat.real, C_flat.imag, bins=bins, density=True)
    print(f'{hist.shape=} {xs.shape=} {ys.shape=}')
    # Bin edges -> bin centres; indexing='ij' matches histogram2d's (x, y) axis order.
    xs = (xs[1:] + xs[:-1]) / 2
    ys = (ys[1:] + ys[:-1]) / 2
    xs, ys = np.meshgrid(xs, ys, indexing='ij')
    # np.ma.log masks empty bins instead of producing -inf and a RuntimeWarning.
    cols = ax.contourf(xs, ys, np.ma.log(hist))
    ax.set_xlabel('Re C')
    ax.set_ylabel('Im C')
    fig.colorbar(cols, ax=ax)
    ax.axhline(0, color='w', linewidth=0.5)
    ax.axvline(0, color='w', linewidth=0.5)
    # Mark the (Re, Im) mean. Previously *Cij_est[0] unpacked three unrelated moments here.
    ax.plot(C_mean_re, C_mean_im, marker='x', color='xkcd:red')
    ax.set_aspect(1.0)
    ax.set_title('Log density')
    # fig.set_tight_layout is deprecated in recent matplotlib; tight_layout() works everywhere.
    fig.tight_layout()
    plt.show()
build_corr(ens, dx=(0,1))

In [None]:
def build_pattern(src, snk, latt_shape, eps=1.0):
    """Source/sink profile on a periodic 2D lattice.

    Returns an array of shape latt_shape with value at site p:
        1/(|p - src|^2 + eps^2) - 1/(|p - snk|^2 + eps^2)
    Each displacement component is wrapped to the minimum-image range around
    zero for that axis, so distances respect the periodic boundaries.
    """
    n0, n1 = latt_shape[0], latt_shape[1]

    def minimum_image(d, n):
        return (d + n//2) % n - n//2

    grid0, grid1 = np.meshgrid(np.arange(n0), np.arange(n1), indexing='ij')

    def bump(point):
        dist2 = minimum_image(point[0] - grid0, n0)**2 + minimum_image(point[1] - grid1, n1)**2
        return 1/(dist2 + eps**2)

    return bump(src) - bump(snk)

In [None]:
def _check_pattern(src=(0, 0), snk=(16, 0), latt_shape=(64, 64)):
    """Visual sanity check of build_pattern on one source/sink pair.

    The colour scale is symmetric about zero, so the diverging PuOr map puts its
    neutral colour at 0. The source lobe (added term) is positive and the sink
    lobe (subtracted term) is negative.
    """
    out = build_pattern(src, snk, latt_shape)
    vmax = np.max(np.abs(out))
    fig, ax = plt.subplots(1, 1)
    img = ax.imshow(out, cmap='PuOr', vmin=-vmax, vmax=vmax)
    fig.colorbar(img, ax=ax)
    ax.set_aspect(1.0)
    ax.set_title(f'build_pattern: src={src}, snk={snk}')
    ax.set_xlabel('lattice axis 1')
    ax.set_ylabel('lattice axis 0')
    plt.show()
_check_pattern()

In [None]:
def build_corr_t(ens, i=0, j=1, omega=0.1, beta=4.0, ts=None, n_chunks=10):
    """Point-to-point correlator of P_ij = z_i conj(z_j), with and without a deformation.

    For each separation t, computes per configuration
        C(t)  = P_ij(0,0) * conj(P_ij(t,0))
    and a deformed counterpart C~(t). For C~, each configuration is mapped to
        Z~ = lam * X + i * Y,   X_c = (Re z_c, Im z_c),   Y_c = alpha_c * (-Im z_c, Re z_c),
    with lam = sqrt(1 + |Y|^2) per site. Here alpha_c is +omega * pattern on colour i
    and -omega * pattern on colour j, where pattern = build_pattern((0,0), (t,0)).
    The deformed observable is multiplied by exp(S - S~) * det(J) / prod_x lam^2.

    Parameters
    ----------
    ens : complex ndarray of shape (n_cfg, Lx, Lt, Nc).
    i, j : colour indices of the bilinear.
    omega : deformation strength.
    beta : coupling passed to action() for both S and S~. The default 4.0 is the
        value that was previously hard-coded (the loaded files are the b4.0 runs);
        change it together with the ensemble.
    ts : separations along ens axis 1. Default: np.arange(4, 32, 4).
    n_chunks : number of configuration chunks. Limits how many per-site Jacobian
        matrices are held in memory at once.

    Returns
    -------
    dict with
        Ct, Ct_tilde : complex arrays of shape (n_cfg, len(ts));
        ts : the separations used.

    Side effect: saves the (Lx, Lt, Nc) profile alpha_vec to alpha_vec_t{t}.npy in the
    working directory if that file does not already exist.
    """
    latt_shape = ens.shape[1:-1]
    Nc = ens.shape[-1]
    if ts is None:
        ts = np.arange(4, 32, 4)
    # np.array_split (unlike np.split) accepts ensemble sizes not divisible by n_chunks.
    chunks = np.array_split(np.arange(len(ens)), n_chunks)
    Ct = []
    Ct_tilde = []
    for t in tqdm.tqdm(ts):
        Cij = []
        Cij_tilde = []
        # The deformation profile depends only on t, so build it once per t rather than per chunk.
        alpha = omega * build_pattern((0, 0), (t, 0), latt_shape)
        for inds_chunk in tqdm.tqdm(chunks, leave=False):
            if len(inds_chunk) == 0:
                continue  # happens when n_chunks > n_cfg
            ens_chunk = ens[inds_chunk]
            S_orig_chunk = action(ens_chunk, np.conj(ens_chunk), beta=beta)
            # Undeformed observable.
            Pij = ens_chunk[..., i] * np.conj(ens_chunk[..., j])
            Cij.append(Pij[:, 0, 0] * np.conj(Pij)[:, t, 0])

            # --- deform ---
            # Real coordinates: X[..., c, :] = (Re z_c, Im z_c).
            X = np.stack([np.real(ens_chunk), np.imag(ens_chunk)], axis=-1)
            # Per-site, per-colour strength: +alpha on colour i, -alpha on colour j, 0 elsewhere.
            alpha_vec = np.zeros_like(X[..., 0])
            alpha_vec[..., i] = alpha
            alpha_vec[..., j] = -alpha
            _fname = f'alpha_vec_t{t}.npy'
            if not os.path.exists(_fname):
                np.save(_fname, alpha_vec[0])
            # Y_c is alpha_c times a 90-degree rotation of X_c, hence orthogonal to X per site.
            Y = alpha_vec[..., None] * np.stack([-np.imag(ens_chunk), np.real(ens_chunk)], axis=-1)
            assert np.allclose(np.sum(X * Y, axis=(-1, -2)), 0.0)
            lam = np.sqrt(1 + np.sum(Y**2, axis=(-1, -2), keepdims=True))
            lam_factor = np.prod((lam**2)[..., 0, 0], axis=(1, 2))
            Ztilde = lam * X + 1j * Y
            z, zbar = Ztilde[..., 0] + 1j * Ztilde[..., 1], Ztilde[..., 0] - 1j * Ztilde[..., 1]
            S_tilde = action(z, zbar, beta=beta)
            # Per-site (2Nc x 2Nc) matrix J whose determinant enters the reweighting.
            # Component order is (Re z_0, Im z_0, Re z_1, Im z_1, ...).
            aX = (alpha_vec[..., None] * X).reshape(-1, *latt_shape, 2 * Nc)
            A = alpha_vec[..., None] * np.identity(Nc)
            AOmega = (np.array([[0, -1], [1, 0]]) * A[..., None, None]).swapaxes(-2, -3).reshape(*A.shape[:-2], 2 * Nc, 2 * Nc)
            J = lam * np.identity(2 * Nc) + (aX[..., None, :] * aX[..., None]) / lam - 1j * AOmega
            detJ = np.prod(np.linalg.det(J), axis=(-1, -2))
            assert detJ.shape == lam_factor.shape, f'{detJ.shape=} {lam_factor.shape=}'
            # Deformed observable: P_ij and its conjugate, with z and zbar treated as independent.
            Pij_tilde = z[..., i] * zbar[..., j]
            Pij_tilde_bar = zbar[..., i] * z[..., j]
            Cij_tilde.append(Pij_tilde[:, 0, 0] * Pij_tilde_bar[:, t, 0] * np.exp(-S_tilde + S_orig_chunk) * detJ / lam_factor)
        Ct.append(np.concatenate(Cij))
        Ct_tilde.append(np.concatenate(Cij_tilde))
    return dict(
        Ct=np.stack(Ct, axis=-1), Ct_tilde=np.stack(Ct_tilde, axis=-1), ts=ts)
res = build_corr_t(ens, omega=0.3)

In [None]:
def _plot_res(res, Nboot=1000):
    """Compare the original and deformed correlators returned by build_corr_t.

    Panels:
      1. C(t) on a linear scale;
      2. C(t) on a log scale;
      3. the bootstrap ratio <(Re C)^2> / <(Re C~)^2> of second moments (original
         over deformed). This is a proxy for the noise reduction from the
         deformation; values above 1 mean the deformed estimator has the smaller
         second moment.

    Parameters
    ----------
    res : dict with keys 'ts', 'Ct' and 'Ct_tilde', as returned by build_corr_t
        (Ct and Ct_tilde have shape (n_cfg, len(ts))).
    Nboot : number of bootstrap resamples.
    """
    ts = res['ts']
    Ct = res['Ct']
    Ct_tilde = res['Ct_tilde']
    C_est = al.bootstrap(Ct, Nboot=Nboot, f=al.rmean)
    Ctilde_est = al.bootstrap(Ct_tilde, Nboot=Nboot, f=al.rmean)
    var_ratio = al.bootstrap(np.real(Ct)**2, np.real(Ct_tilde)**2, Nboot=Nboot, f=lambda x, y: al.rmean(x)/al.rmean(y))
    style = dict(capsize=2, linestyle='', fillstyle='none', linewidth=0.7)
    fig, axes = plt.subplots(3, 1, figsize=(6, 6), sharex=True)
    for ax in axes[:2]:
        al.add_errorbar(C_est, xs=ts, ax=ax, **style, marker='o', color='k', label='C(t) [orig]')
        al.add_errorbar(Ctilde_est, xs=ts, off=0.3, ax=ax, **style, marker='s', color='xkcd:red', label='C(t) [deform]')
        ax.set_ylabel(r'$C(t)$')
    axes[0].legend()
    # Non-positive estimates (e.g. noisy large-t points) are not shown on the log axis.
    axes[1].set_yscale('log')
    ax = axes[2]
    al.add_errorbar(var_ratio, xs=ts, ax=ax, **style, marker='o')
    ax.axhline(1.0, color='gray', linewidth=0.5, linestyle='--')  # no-improvement reference
    ax.set_ylabel(r'$\langle (\mathrm{Re}\,C)^2 \rangle / \langle (\mathrm{Re}\,\tilde C)^2 \rangle$')
    ax.set_xlabel(r'$t$')
    plt.show()
_plot_res(res)