# Euclidean Sequential Monte Carlo
We can use sequential Monte Carlo (SMC) to project a population of walkers on a single timeslice toward the ground state: each walker is evolved stochastically in imaginary time, its weight tracks the normalization of that evolution, and resampling keeps the population representative. This does not give a convenient handle on finite temperature, but it could be useful for studying quantum critical points and other zero-temperature properties.

In [None]:
import analysis as al
import matplotlib.pyplot as plt
import numpy as np
import tqdm.auto as tqdm
%matplotlib widget

# Square lattice

### Hamiltonian
The magnetic operator appearing in the Hamiltonian takes the form
$$
-J P_{wxyz} X_w X_x X_y X_z
$$
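which is a controlled-$X$ operator: it flips all four links $w, x, y, z$ of the plaquette when the projector $P_{wxyz}$ signals that they are oriented consistently around it (a flippable plaquette), and vanishes otherwise. The Rokhsar-Kivelson term takes the form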
$$
\lambda P_{wxyz}.
$$

In the height representation, we have
$$
-J P_{abcd} X_{p} + \lambda P_{abcd},
$$
where $a$, $b$, $c$, $d$ are the neighboring plaquettes and $p$ is the current plaquette.

### Time evolution
Exponentiating $-dt$ times the Hamiltonian gives the Trotterized imaginary-time evolution operator for a single plaquette:
$$
\exp(dt J P X - dt \lambda P) = (1-P) + P e^{-dt \lambda} [\cosh(dt J) + \sinh(dt J) X] = (1-P) + P N [(1-p) + p X],
$$
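where $p = e^{-dt J} \sinh(dt J)$, $1-p = e^{-dt J} \cosh(dt J)$, and $N = e^{dt J - dt \lambda}$. The first equality uses that $P$ is a projector commuting with $X$ (flipping a flippable plaquette leaves it flippable), so the exponential acts nontrivially only on the $P = 1$ subspace. Because $(1-p) + p = 1$ with $0 \le p < 1/2$ for $J \ge 0$, a flippable plaquette can be flipped with probability $p$ while the walker's weight is multiplied by $N$.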


In [None]:
def make_eo_mask(shape):
    """Return a boolean checkerboard mask for a lattice of the given shape.

    Entry ``[i0, i1, ...]`` is True when ``i0 + i1 + ...`` is even (the
    "even" sublattice) and False otherwise.
    """
    open_grids = np.ix_(*(np.arange(extent) for extent in shape))
    coordinate_sum = sum(open_grids)
    return coordinate_sum % 2 == 0

In [None]:
def init_cold_links(shape):
    """Initialize links to a reference state that is reasonably flippable.

    ``shape`` is ``(Nd, *lattice)``; only Nd = 3 with a 3-D lattice is
    supported. Links along directions 0 and 2 are +1 on even sites and -1 on
    odd sites, and links along direction 1 carry the opposite sign pattern.
    """
    assert shape[0] == 3 and len(shape[1:]) == 3, 'specialized for Nd=3'
    # +1 on the even sublattice, -1 on the odd sublattice
    checker_sign = 2 * make_eo_mask(shape[1:]) - 1
    return np.stack([checker_sign, -checker_sign, checker_sign])

In [None]:
def step(x, *, dtJ, dtLam):
    """Apply one Trotterized imaginary-time step to a link configuration in place.

    For every plane (mu, nu) with mu < nu, plaquettes are updated in two
    checkerboard passes (even sites, then odd sites), so plaquettes updated
    together share no links. A plaquette is flippable when its four oriented
    links are equal. Each flippable plaquette is flipped with probability
    ``p = exp(-dtJ) * sinh(dtJ)`` and contributes ``log N = dtJ - dtLam`` to the
    log-weight, following

        exp(dtJ P X - dtLam P) = (1 - P) + P N [(1 - p) + p X].

    Parameters
    ----------
    x : ndarray
        Link configuration indexed as ``x[mu][site]``, shape ``(Nd, *lattice)``
        with +/-1 entries (as produced by ``init_cold_links``). Modified in
        place.
    dtJ, dtLam : float
        Trotter step size times the couplings J and lambda.

    Returns
    -------
    float
        Log of the normalization factor accumulated during this step.

    Raises
    ------
    ValueError
        If any lattice extent is odd. With periodic boundaries the checkerboard
        would then put link-sharing plaquettes on the same sublattice.
    """
    lattice_shape = x.shape[1:]
    if any(extent % 2 for extent in lattice_shape):
        raise ValueError(
            f'step requires even lattice extents for the checkerboard update, '
            f'got {lattice_shape}')
    # log N for every flippable plaquette, N = exp(dtJ - dtLam)
    log_norm = dtJ - dtLam

    def update_sublattice(mu, nu, mask):
        """Update (mu, nu) plaquettes at sites in `mask`; return the log-weight increment."""
        # Oriented links around the plaquette at site s:
        # x[nu][s], x[mu][s + e_nu], -x[nu][s + e_mu], -x[mu][s].
        a = x[nu]
        b = np.roll(x[mu], -1, axis=nu)
        c = -np.roll(x[nu], -1, axis=mu)
        d = -x[mu]
        P = (a == b) & (b == c) & (c == d)  # flippable plaquettes
        dlogw = np.sum(mask & P) * log_norm
        p = P*np.sinh(dtJ)*np.exp(-dtJ)
        r = np.random.random(size=p.shape)
        assert p.shape == lattice_shape and p.shape == mask.shape
        flip = mask & (r < p)
        # Flip all four links of each selected plaquette. Rolling by +1 moves the
        # plaquette-site mask onto the site that stores the shifted link.
        x[nu][flip] *= -1
        x[mu][np.roll(flip, 1, axis=nu)] *= -1
        x[nu][np.roll(flip, 1, axis=mu)] *= -1
        x[mu][flip] *= -1
        return dlogw

    logw = 0.0
    even = make_eo_mask(lattice_shape)
    for mu in range(x.shape[0]):
        for nu in range(mu + 1, x.shape[0]):
            # Same-parity plaquettes in a plane share no links (for even
            # extents), so each sublattice is updated in one vectorized pass.
            for mask in (even, ~even):
                logw += update_sublattice(mu, nu, mask)
    return logw

In [None]:
def resample(logw):
    """Multinomially resample walkers according to their log-weights.

    Parameters
    ----------
    logw : 1-D array of float
        Log-weights of the walkers. Not modified.

    Returns
    -------
    inds : ndarray of int
        Indices of the selected walkers, ``len(logw)`` of them, drawn with
        probability proportional to ``exp(logw)``.
    new_logw : float
        Log of the mean weight. Assigning it to every resampled walker
        preserves the total weight of the population.

    Raises
    ------
    ValueError
        If ``logw`` is not a non-empty 1-D array.
    """
    logw = np.asarray(logw)
    if logw.ndim != 1 or len(logw) == 0:
        raise ValueError(f'logw must be a non-empty 1-D array, got shape {logw.shape}')
    n = len(logw)
    log_total = np.logaddexp.reduce(logw)
    new_logw = log_total - np.log(n)
    # Normalized probabilities, computed on a new array so the caller's
    # log-weights are left untouched.
    probs = np.exp(logw - log_total)
    inds = np.random.choice(n, size=n, p=probs)
    return inds, new_logw

In [None]:
def compute_ess(logw):
    """Return the effective sample size of weighted walkers, as a fraction of their count.

    ESS / n = (sum w)^2 / (n * sum w^2) with w = exp(logw). The sums are
    evaluated in log space so that large or small log-weights do not overflow.
    """
    n_walkers = len(logw)
    log_sum_w = np.logaddexp.reduce(logw)
    log_sum_w2 = np.logaddexp.reduce(2*logw)
    return np.exp(2*log_sum_w - log_sum_w2) / n_walkers

In [None]:
# TODO: Update to flux basis
def measure_Mx(x):
    """Return the averages of ``x`` over the even and odd checkerboard sites.

    The checkerboard is built over every axis of ``x``. Each masked mean is
    doubled, because each sublattice covers half of the entries when the
    total size is even.
    """
    even = make_eo_mask(x.shape)
    odd = 1 - even
    even_avg = 2*np.mean(x * even)
    odd_avg = 2*np.mean(x * odd)
    return even_avg, odd_avg

In [None]:
def measure_F(x):
    """Return the fraction of flippable plaquettes in each plane (mu, nu), mu < nu.

    The plaquette at site s is built from the oriented links
    ``x[nu][s], x[mu][s + e_nu], -x[nu][s + e_mu], -x[mu][s]`` and is flippable
    when all four are equal. Planes are ordered (0, 1), (0, 2), ..., (1, 2), ...
    """
    n_dirs = x.shape[0]
    planes = [(mu, nu) for mu in range(n_dirs) for nu in range(mu + 1, n_dirs)]
    fractions = []
    for mu, nu in planes:
        side0 = x[nu]
        side1 = np.roll(x[mu], -1, axis=nu)
        side2 = -np.roll(x[nu], -1, axis=mu)
        side3 = -x[mu]
        flippable = (side0 == side1) & (side1 == side2) & (side2 == side3)
        fractions.append(np.mean(flippable))
    return np.stack(fractions)

In [None]:
def run_walkers(x, *, dt, J, lam, n_iter, resample_thresh, n_meas):
    """Evolve a population of walkers in imaginary time with SMC resampling.

    Parameters
    ----------
    x : ndarray
        Walker configurations indexed by walker first, e.g.
        ``(n_walkers, Nd, *lattice)``. The caller's array is not modified;
        a copy is evolved.
    dt, J, lam : float
        Trotter step and couplings. ``step`` is called with ``dtJ=dt*J`` and
        ``dtLam=dt*lam``.
    n_iter : int
        Number of Trotter steps.
    resample_thresh : float
        Resample when the normalized effective sample size drops below this.
    n_meas : int
        Record measurements after every ``n_meas`` steps.

    Returns
    -------
    dict
        ``'logw'``: copy of the walker log-weights after every step.
        ``'ess'``: normalized ESS after every step (1.0 right after resampling).
        ``'Mlogw'``, ``'F'``, ``'cfgs'``: log-weights, per-walker
        flippable-plaquette fractions and configurations at measurement steps.
        ``'ts'``: imaginary time of each measurement step.
    """
    # step() mutates walkers in place; evolve a copy so the caller's array is kept.
    x = np.copy(x)
    logw = np.zeros(x.shape[0])
    # Measurements happen after steps n_meas, 2*n_meas, ..., so their imaginary
    # times are dt*n_meas, 2*dt*n_meas, ... (exactly one entry per measurement).
    ts = dt*np.arange(n_meas, n_iter + 1, n_meas)
    hist = dict(logw=[], ess=[], Mlogw=[], F=[], cfgs=[], ts=ts)
    for i in tqdm.tqdm(range(n_iter)):
        for j, xj in enumerate(x):
            logw[j] += step(xj, dtJ=dt*J, dtLam=dt*lam)
        ess = compute_ess(logw)
        if ess < resample_thresh:
            inds, new_logw = resample(logw)
            x = x[inds]
            # Every survivor carries the mean weight, preserving the total weight.
            logw[:] = new_logw
            ess = 1.0
        hist['ess'].append(ess)
        if (i+1) % n_meas == 0:
            hist['cfgs'].append(np.copy(x))
            # Mx measurement disabled until measure_Mx is updated to the flux basis.
            hist['F'].append(np.stack([measure_F(xj) for xj in x]))
            hist['Mlogw'].append(np.copy(logw))
        hist['logw'].append(np.copy(logw))
    return hist

In [None]:
# 128 walkers, each starting from the cold reference links on a 4x4x4 lattice.
n_walkers = 128
x = np.stack([init_cold_links((3, 4, 4, 4)) for _ in range(n_walkers)])
hist = run_walkers(
    x,
    dt=0.01, J=1.0, lam=0.0,
    n_iter=1000, resample_thresh=0.7, n_meas=10,
)

In [None]:
# Left panel: log-weight trajectory of every walker vs. Trotter step.
# Right panel: weighted, bootstrapped estimate of the flippable-plaquette
# fraction in each plane vs. imaginary time.
fig, axes = plt.subplots(1,2, figsize=(6,3))
# logw = al.bootstrap(np.stack(hist['logw'], axis=1), Nboot=1000, f=al.rmean)
# al.add_errorbar(logw, ax=ax)
# np.stack(hist['logw'], axis=1) has one row per walker and one column per step.
for logwi in np.stack(hist['logw'], axis=1):
    axes[0].plot(logwi, linewidth=0.5, color='0.5')
def weighted_meas(logw, M):
    """Return the exp(logw)-weighted average of |M|.

    The maximum over axis 0 is subtracted before exponentiating to avoid
    overflow. NOTE(review): this cancels in the ratio only if al.rmean
    averages over axis 0 (the walker axis here) -- confirm in analysis.py.
    """
    logw = logw - np.max(logw, axis=0)
    return al.rmean(np.exp(logw)*np.abs(M)) / al.rmean(np.exp(logw))
# Inputs: Mlogw stacked to (n_walkers, n_meas_points, 1), which broadcasts
# against F stacked to (n_walkers, n_meas_points, n_planes).
# NOTE(review): assumes al.bootstrap resamples both arrays along the walker axis.
F = np.stack(al.bootstrap(
    np.stack(hist['Mlogw'], axis=1)[...,None], np.stack(hist['F'], axis=1),
    Nboot=100, f=weighted_meas))
print(f'{F.shape=}')
# Mx plots are disabled; measure_Mx has not been updated to the flux basis.
# print(f'{Mx.shape=}')
# al.add_errorbar(Mx[:,:,0], xs=np.arange(0, 10000, 10), ax=axes[1], label='Mx even')
# al.add_errorbar(Mx[:,:,1], xs=np.arange(0, 10000, 10), ax=axes[1], label='Mx odd')
# Plane order (0,1), (0,2), (1,2) follows the mu < nu loop in measure_F.
al.add_errorbar(F[:,:,0], xs=hist['ts'], ax=axes[1], label='F (0,1)')
al.add_errorbar(F[:,:,1], xs=hist['ts'], ax=axes[1], label='F (0,2)')
al.add_errorbar(F[:,:,2], xs=hist['ts'], ax=axes[1], label='F (1,2)')
axes[1].legend()
plt.show()

In [None]:
# Show the last four recorded configurations in walker slot 0. After
# resampling, this slot need not hold the same lineage at each snapshot.
# Each configuration is a link array (Nd, L0, L1, L2), which imshow cannot draw
# directly. Display a 2-D slice instead: links along direction `link_dir` in the
# plane at index `plane_idx` of the last lattice axis.
# Top row: even sites of the slice. Bottom row: odd sites.
link_dir, plane_idx = 0, 0
last_cfgs = np.stack(hist['cfgs'])[-4:, 0]
fig, axes = plt.subplots(2,4)
mask = make_eo_mask(last_cfgs.shape[2:4])
cmap = plt.get_cmap('Grays')
cmap.set_bad(alpha=0.0)  # NaN entries (the other sublattice) are transparent
for ax_col, cfg in zip(np.transpose(axes), last_cfgs):
    cfg_slice = cfg[link_dir, :, :, plane_idx]
    cfg_even = np.where(mask, cfg_slice, float('nan'))
    cfg_odd = np.where(~mask, cfg_slice, float('nan'))
    ax_col[0].imshow(cfg_even, interpolation='nearest', cmap=cmap)
    ax_col[1].imshow(cfg_odd, interpolation='nearest', cmap=cmap)
    for ax in ax_col:
        ax.set_aspect(1.0)
plt.show()