In [1]:
import numpy as np
import quimb.experimental.operatorbuilder as qop
import quimb.tensor as qtn

import symmray as sr

# Lattice geometry, PEPS bond dimension and RNG seed.
Lx, Ly = 3, 3
D = 4
seed = 42
# only the flat backend is compatible with jax.jit
flat = True

# Sites carrying odd Z2 charge -- deliberately an odd number of them, for testing.
odd_sites = ((0, 0), (0, 1), (2, 2))


def site_charge(site):
    """Z2 charge assigned to ``site``: 1 for the sites in ``odd_sites``, else 0."""
    return 1 if site in odd_sites else 0


peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    site_charge=site_charge,
    subsizes="equal",
    flat=flat,
    seed=seed,
)

In [2]:
# Nearest-neighbour bonds of the Lx x Ly square lattice.
edges = qtn.edges_2d_square(Lx, Ly)
# Every (row, col) lattice site, enumerated in row-major order.
sites = [(row, col) for row in range(Lx) for col in range(Ly)]

In [3]:
# Spinless Fermi-Hubbard Hamiltonian as a mapping of local symmray terms.
# NOTE(review): V and mu are repeated as literals in the quimb operator cell
# further down -- keep the two in sync or the comparison is meaningless.
terms = sr.hamiltonians.ham_fermi_hubbard_spinless_from_edges(
    "Z2", edges=edges, V=3.71, mu=0.119
)
if flat:
    # convert every term to the flat backend, matching the flat PEPS
    terms = dict((key, term.to_flat()) for key, term in terms.items())

In [4]:
# Reference energy: quimb's exact local-expectation routine applied to the PEPS
# with the symmray `terms`; `normalized=True` asks for the value normalized by
# the state norm (per quimb's API).
eref = peps.compute_local_expectation_exact(terms, normalized=True)
eref

np.float64(7.7778407241590655)

In [5]:
# The same spinless Fermi-Hubbard model, built as a quimb operator whose Hilbert
# space can enumerate basis configurations and their couplings (used below).
# NOTE(review): V and mu duplicate the literals used to build `terms` above --
# keep them in sync.
H = qop.fermi_hubbard_spinless_from_edges(edges, V=3.71, mu=0.119)
hs = H.hilbert_space
# Site ordering of the flat configurations produced by `hs`; `flat_amplitude`
# pairs configuration entries with sites in exactly this order.
hs.sites

((0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2))

In [6]:
if flat:
    # jax is only needed (and only imported) for the flat, jit-compatible path
    import jax
    import jax.numpy as jnp

    # Convert every PEPS array to a jax array in place (the return value is
    # discarded and `peps` is reused below) so the amplitude can be jitted.
    # NOTE: without `jax_enable_x64`, JAX stores these arrays as float32, which
    # is why the enumerated energy further down is float32 and agrees with
    # `eref` only to roughly six significant digits.
    peps.apply_to_arrays(jnp.array)

In [7]:
def flat_amplitude(fx):
    """Amplitude of ``peps`` for a single flat configuration ``fx``.

    ``fx`` holds one local-state value per site, in the order of ``hs.sites``.
    Each physical site index of ``peps`` is fixed to its value and the
    resulting network is fully contracted.
    """
    site_inds = map(peps.site_ind, hs.sites)
    return peps.isel(dict(zip(site_inds, fx))).contract()


if flat:
    # trace once, then reuse the compiled contraction for same-shaped inputs
    flat_amplitude = jax.jit(flat_amplitude)

In [8]:
# Brute-force check of the amplitudes: enumerate every basis configuration x
# of the Hilbert space, weight it by |<x|psi>|^2 and accumulate the local
# energy  E_loc(x) = sum_y h_xy <y|psi> / <x|psi>,  where the (y, h_xy) pairs
# come from H.flatconfig_coupling.  O / p is then the normalized energy,
# which should reproduce `eref`.
O = 0.0  # running sum_x |<x|psi>|^2 E_loc(x)
p = 0.0  # running sum_x |<x|psi>|^2

# A configuration y is typically reached from many different x; cache its
# amplitude so each configuration is contracted only once.  The cache is
# rebuilt every time this cell runs, so re-running is idempotent.
amplitude_cache = {}


def cached_amplitude(config):
    """Return ``flat_amplitude(config)``, memoized on the configuration values."""
    key = tuple(config)
    if key not in amplitude_cache:
        amplitude_cache[key] = flat_amplitude(config)
    return amplitude_cache[key]


for i in range(hs.size):
    fx = hs.rank_to_flatconfig(i)

    xpsi = cached_amplitude(fx)
    if not xpsi:
        # zero-amplitude configurations carry no weight and E_loc is undefined
        continue

    pi = abs(xpsi) ** 2
    p += pi

    Oloc = 0.0
    for fy, hxy in zip(*H.flatconfig_coupling(fx)):
        Oloc = Oloc + hxy * cached_amplitude(fy) / xpsi

    O += Oloc * pi

O / p

Array(7.7778387, dtype=float32)

In [9]:
# Fail loudly if the enumerated estimate (O / p, previous cell) drifts from the
# exact reference; rtol leaves headroom for float32 arithmetic on the jax path.
np.testing.assert_allclose(O / p, eref, rtol=1e-5)
eref

np.float64(7.7778407241590655)