In [1]:
import numpy as np
import quimb as qu
import quimb.tensor as qtn

import symmray as sr

# lattice dimensions and the resulting total number of sites
Lx, Ly = 8, 8
nsites = Lx * Ly

# PEPS bond dimension; the same value is reused as the maximum bond
# dimension (`chi`) of the approximate contraction
D = 8
chi = D

# single seed shared by the random PEPS and the configuration sampler
seed = 42

# only the flat backend is compatible with jax.jit
# NOTE(review): this notebook only uses torch.vmap, not jax.jit -- confirm
# the flat backend is likewise what torch.vmap needs
flat = True

# batchsize
B = 1024

# random fermionic PEPS with Z2 symmetry and a 4-dimensional physical index
peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    phys_dim=[
        (0, 0),  # linear index 0 -> charge 0, offset 0
        (1, 1),  # linear index 1 -> charge 1, offset 1
        (1, 0),  # linear index 2 -> charge 1, offset 0
        (0, 1),  # linear index 3 -> charge 0, offset 1
    ],
    subsizes="equal",
    flat=flat,
    seed=seed,
)

In [2]:
# get pytree of initial parameters (`params`), and reference tn structure
# (`skeleton`); `amplitude` below rebuilds the network from the two of them
params, skeleton = qtn.pack(peps)


def amplitude(x, params):
    """Contract the PEPS projected onto a single configuration ``x``.

    The network is rebuilt from ``params`` and the module-level ``skeleton``,
    the physical index of each site is fixed to the matching entry of ``x``,
    and the projected network is contracted with ``contract_hotrg`` using the
    module-level ``chi`` as the maximum bond dimension.

    Parameters
    ----------
    x : indexable
        One physical index value per site; ``x[i]`` is applied to the
        ``i``-th entry of the network's ``sites``.
    params : pytree
        Raw arrays in the layout returned by ``qtn.pack``.

    Returns
    -------
    The value returned by ``contract_hotrg``; with the options used here
    this is a ``(mantissa, exponent)`` pair.
    """
    network = qtn.unpack(params, skeleton)

    # might need to specify the right site ordering here
    selectors = {}
    for position, site in enumerate(network.sites):
        selectors[network.site_ind(site)] = x[position]
    projected = network.isel(selectors)

    # these two options make the return value (mantissa, exponent)
    # which can avoid issues with small/large values and stability
    stability_opts = {
        "equalize_norms": 1.0,
        "final_contract_opts": {"strip_exponent": True},
    }
    return projected.contract_hotrg(max_bond=chi, cutoff=0.0, **stability_opts)

In [3]:
# generate half-filling configs
rng = np.random.default_rng(seed)

# every row starts as `nsites // 2` zeros followed by `nsites // 2` ones ...
xs = np.tile(
    np.repeat(np.array([0, 1], dtype=np.int32), nsites // 2),
    (B, 1),
)
# ... and is then shuffled independently of the other rows
xs = rng.permuted(xs, axis=1)

First, test the plain (numpy, non-batched) version on a single configuration:

In [4]:
# evaluate one configuration and unpack the (mantissa, exponent) pair
mantissa, exponent = amplitude(xs[0], params)
print(mantissa, exponent)

1.0 32.66959439764813


Then test the same function with torch tensors on the GPU:

In [5]:
import torch

# create new torch tensors on the first GPU unless a device is given
torch.set_default_device("cuda:0")

# convert bitstrings and arrays to torch:
# the sampled configurations are converted as they are ...
xs = torch.tensor(xs)
# ... while every leaf of the parameter pytree is cast to float32
params = qu.tree_map(
    lambda leaf: torch.tensor(leaf, dtype=torch.float32),
    params,
)

In [6]:
%%timeit
# time one un-batched torch evaluation of a single configuration
mantissa, exponent = amplitude(xs[0], params)
mantissa, exponent

1.92 s ± 34.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Then test and warm up the torch vmapped version:

In [7]:
# vectorize `amplitude` over a batch of configurations:
# batch on configs (axis 0 of the first argument), not parameters
config_batch_dims = (0, None)
vf = torch.vmap(amplitude, in_dims=config_batch_dims)

In [8]:
%%time
# warmup time: first call of the vmapped function on the full batch
vf(xs, params)

CPU times: user 4.12 s, sys: 28.1 ms, total: 4.15 s
Wall time: 4.13 s


(tensor([ 1., -1.,  1.,  ..., -1.,  1., -1.], device='cuda:0'),
 tensor([32.6688, 32.7025, 32.7157,  ..., 32.2940, 32.6343, 32.9307],
        device='cuda:0'))

In [9]:
%%timeit
# final time (to compute full batch), measured on repeated calls of `vf`
vf(xs, params)

3.93 s ± 244 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Using `torch.nn.Module`

Here we wrap the pure function as a torch module; this enables various
functionality, such as training and exporting it.

In [10]:
class TNAmplitudeModel(torch.nn.Module):
    """Torch module that evaluates a function of a tensor network.

    The arrays of ``tn`` are registered as module parameters, so they are
    visible to torch through ``self.params``. On every call the network is
    rebuilt from those parameters and handed to ``fn``.

    Parameters
    ----------
    fn : callable
        Invoked as ``fn(x, tn, **kwargs)``, where ``tn`` is the tensor
        network rebuilt from the current parameters (not the packed arrays).
    tn : tensor network
        Object accepted by ``quimb.tensor.pack``; provides both the initial
        parameter values and the network structure.
    vmap : bool, optional
        If true, the single-input function is wrapped with ``torch.vmap``
        before being stored.
    **kwargs
        Forwarded unchanged to ``fn`` on every call.
    """

    def __init__(self, fn, tn, vmap=False, **kwargs):
        import quimb as qu
        import quimb.tensor as qtn

        super().__init__()

        # separate the raw arrays from the bare network structure
        arrays, skeleton = qtn.pack(tn)

        # torch wants a flat list, so flatten the pytree and keep the
        # reference needed to rebuild it later
        leaves, treedef = qu.utils.tree_flatten(arrays, get_ref=True)

        # registering the list makes the arrays parameters of this module
        self.params = torch.nn.ParameterList(
            [torch.as_tensor(leaf, dtype=torch.float32) for leaf in leaves]
        )

        def evaluate(x):
            # rebuild pytree -> tensor network from the current parameters
            rebuilt = qtn.unpack(
                qu.utils.tree_unflatten(self.params, treedef),
                skeleton,
            )
            return fn(x, rebuilt, **kwargs)

        self.f = torch.vmap(evaluate) if vmap else evaluate

    def forward(self, *args):
        """Pass all positional arguments to the stored function."""
        return self.f(*args)

In [11]:
# instantiate the model
# NOTE(review): TNAmplitudeModel calls `fn(x, tn)` with a rebuilt tensor
# network, whereas `amplitude` is written as `amplitude(x, params)` and
# unpacks its second argument itself -- confirm this runs on a fresh kernel
model = TNAmplitudeModel(amplitude, peps)

In [12]:
# compute output, with gradients; call the module itself rather than
# `.forward` so that any registered torch hooks are run as well
model(xs[0])

(tensor(1., device='cuda:0', grad_fn=<DivBackward0>),
 tensor(32.6688, device='cuda:0', grad_fn=<AddBackward0>))

In [13]:
# compute output, no gradients; call the module itself rather than
# `.forward` so that any registered torch hooks are run as well
with torch.inference_mode():
    print(model(xs[0]))

(tensor(1., device='cuda:0'), tensor(32.6688, device='cuda:0'))


In [14]:
# instantiate a vmapped model: with vmap=True the wrapped function is passed
# through torch.vmap, so the model is called on a whole batch of configs
vmodel = TNAmplitudeModel(amplitude, peps, vmap=True)

In [15]:
# compute batch output, no gradients; call the module itself rather than
# `.forward` so that any registered torch hooks are run as well
with torch.inference_mode():
    print(vmodel(xs))

(tensor([ 1., -1.,  1.,  ..., -1.,  1., -1.], device='cuda:0'), tensor([32.6688, 32.7025, 32.7157,  ..., 32.2940, 32.6343, 32.9307],
       device='cuda:0'))
