In [1]:
import numpy as np
import quimb as qu
import quimb.tensor as qtn

import symmray as sr

# --- lattice geometry ---
Lx, Ly = 8, 8
nsites = Lx * Ly

# --- bond dimensions ---
D = 8
# max bond kept during the HOTRG contraction in `amplitude`, tied to the PEPS bond dim
chi = D

seed = 42

# only the flat backend is compatible with jax.jit
# NOTE(review): the cells below use torch.jit.trace / torch.vmap rather than jax;
# confirm the same restriction is what motivates flat=True here.
flat = True

# batchsize: number of configurations generated and evaluated in one vmapped call
B = 1024

# random fermionic Z2 PEPS; phys_dim maps each local linear index to
# (Z2 charge, offset within that charge sector)
peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    phys_dim=[
        (0, 0),  # linear index 0 -> charge 0, offset 0
        (1, 1),  # linear index 1 -> charge 1, offset 1
        (1, 0),  # linear index 2 -> charge 1, offset 0
        (0, 1),  # linear index 3 -> charge 0, offset 1
    ],
    subsizes="equal",
    flat=flat,
    seed=seed,
)

In [2]:
# get pytree of initial parameters, and reference tn structure
# `params` holds the arrays (later converted to torch tensors and shared across the
# vmapped batch); `skeleton` is read back inside `amplitude` via qtn.unpack.
params, skeleton = qtn.pack(peps)


def amplitude(x, params):
    """Contract the PEPS amplitude for a single configuration.

    Parameters
    ----------
    x : indexable of local state indices, one entry per site
        Entry ``i`` is used as the local physical index of the i-th site in
        ``tn.sites``.
    params : pytree
        Arrays produced by ``qtn.pack`` (numpy or torch), recombined with the
        global ``skeleton``.

    Returns
    -------
    (mantissa, exponent)
        Because of ``strip_exponent=True`` the amplitude is returned split into
        a mantissa and an exponent, which avoids under/overflow.
    """
    tn = qtn.unpack(params, skeleton)

    # NOTE(review): x[i] is assigned to the i-th entry of tn.sites; the configs
    # must follow the same site ordering as tn.sites (might need to specify it
    # explicitly here).
    selectors = {}
    for i, site in enumerate(tn.sites):
        ind = tn.site_ind(site)
        selectors[ind] = x[i]
    tnx = tn.isel(selectors)

    # equalize_norms + strip_exponent make the return value (mantissa, exponent),
    # which can avoid issues with small/large values and stability
    hotrg_opts = dict(
        max_bond=chi,
        cutoff=0.0,
        equalize_norms=1.0,
        final_contract_opts=dict(strip_exponent=True),
    )
    return tnx.contract_hotrg(**hotrg_opts)

In [3]:
# generate half-filling configs: each row holds nsites // 2 entries equal to 1
# (local index 1, Z2 charge 1 per phys_dim above) and zeros on the remaining sites,
# then every row is shuffled independently.
# The zero block takes `nsites - nsites // 2` columns so each row always has exactly
# `nsites` entries, even when nsites is odd. For even nsites this matches the
# previous `nsites // 2` split exactly, so the generated configs are unchanged.
rng = np.random.default_rng(seed)
n_ones = nsites // 2
xs = np.concatenate(
    [
        np.zeros((B, nsites - n_ones), dtype=np.int32),
        np.ones((B, n_ones), dtype=np.int32),
    ],
    axis=1,
)
# permuted(..., axis=1) shuffles each row independently (not the same permutation
# for every row)
xs = rng.permuted(xs, axis=1)

First, test the plain numpy version (no torch, no jit):

In [4]:
# evaluate the first configuration using the original (numpy) parameters
first_config = xs[0]
mantissa, exponent = amplitude(first_config, params)
print(mantissa, exponent)

-1.0 32.7511611219604


Then test the version with torch GPU tensors:

In [None]:
import torch

# NOTE: this and the following cells require a CUDA device
DEVICE = "cuda:0"
torch.set_default_device(DEVICE)

# convert bitstrings and arrays to torch.
# torch.as_tensor (instead of torch.tensor) keeps this cell idempotent: on a re-run,
# xs / params are already tensors on DEVICE and are returned as-is rather than
# copy-constructed from an existing tensor (which torch warns about).
xs = torch.as_tensor(xs, device=DEVICE)
# params are cast to float32; this presumably explains the small difference from the
# numpy result above (exponent 32.7512 vs 32.7514) — confirm.
params = qu.tree_map(
    lambda x: torch.as_tensor(x, dtype=torch.float32, device=DEVICE),
    params,
)

In [10]:
%%timeit
mantissa, exponent = amplitude(xs[0], params)
mantissa, exponent

804 ms ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Then test a traced and JIT-compiled version:

In [11]:
%%time
# tracing time
amplitude_jit = torch.jit.trace(amplitude, (xs[0], params))

CPU times: user 1min 28s, sys: 1.91 s, total: 1min 30s
Wall time: 1min 29s


In [12]:
%%time
# jit time
mantissa, exponent = amplitude_jit(xs[0], params)
mantissa, exponent

CPU times: user 2min 13s, sys: 418 ms, total: 2min 14s
Wall time: 2min 15s


(tensor(-1., device='cuda:0'), tensor(32.7514, device='cuda:0'))

In [13]:
%%timeit
# warmed up time
mantissa, exponent = amplitude_jit(xs[0], params)

214 ms ± 2.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Then warm up and test the `torch.vmap`-ed version:

In [15]:
# vectorize `amplitude` over configurations: axis 0 of the first argument (x) is the
# batch axis, while params (in_dims=None) are shared by every batch element
vf = torch.vmap(amplitude, in_dims=(0, None))

In [16]:
%%time
# warmup time
vf(xs, params)

CPU times: user 1.73 s, sys: 4.99 ms, total: 1.73 s
Wall time: 1.72 s


(tensor([-1., -1.,  1.,  ..., -1.,  1.,  1.], device='cuda:0'),
 tensor([32.7514, 32.6899, 32.6164,  ..., 33.2167, 32.2461, 33.0070],
        device='cuda:0'))

In [17]:
%%timeit
# final time (to compute full batch)
vf(xs, params)

1.66 s ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
