In [1]:
import os

# expose only GPU 1 to CUDA; this has to happen before jax is imported
# (cell 3), since the variable is read when the GPU backend starts up
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [2]:
import numpy as np
import quimb.tensor as qtn

import symmray as sr

# lattice geometry
Lx, Ly = 8, 8
nsites = Lx * Ly

# PEPS bond dimension, reused as the max_bond of contract_hotrg below
D = 4
chi = D

seed = 42

# only the flat backend is compatible with jax.jit
flat = True

# number of configurations per batched (vmapped) evaluation
B = 1024

# random Z2-symmetric fermionic PEPS with a 4-dimensional local space
peps = sr.networks.PEPS_fermionic_rand(
    "Z2", Lx, Ly, D,
    phys_dim=4, subsizes="equal", flat=flat, seed=seed,
)

In [3]:
import jax

# split the PEPS into a pytree of raw arrays (`params`, which jax traces)
# and a static description of the network structure (`skeleton`)
params, skeleton = qtn.pack(peps)


def _combine_spin_pair(x):
    """Merge interleaved (up, down) occupations into one local index per site.

    Entries ``x[2k]`` and ``x[2k + 1]`` are read as the spin-up and spin-down
    occupations of site ``k`` and mapped to::

        (up, down) = (0, 0) -> 0
                     (1, 1) -> 1
                     (1, 0) -> 2
                     (0, 1) -> 3

    Singly occupied sites land on {2, 3}, empty / doubly occupied sites on
    {0, 1}. This must match the ``phys_dim`` ordering of the PEPS above
    (presumably the two Z2 parity blocks of ``subsizes="equal"`` -- confirm
    against symmray).
    """
    up = x[::2]
    down = x[1::2]
    is_single = up != down
    return 2 * is_single + down


def amplitude(x, params):
    """Amplitude of the PEPS for a single configuration ``x``.

    Parameters
    ----------
    x : array of int
        Spin-orbital occupations with the up / down entries of each site
        interleaved, i.e. length ``2 * nsites``.
    params : pytree
        The PEPS arrays, as produced by ``qtn.pack(peps)``.

    Returns
    -------
    mantissa, exponent
        ``equalize_norms`` together with ``strip_exponent=True`` make the
        HOTRG contraction return the amplitude split into two parts, which
        avoids issues with very small / large values (presumably
        ``mantissa * 10**exponent`` -- confirm with the quimb docs).
    """
    tn = qtn.unpack(params, skeleton)
    local_index = _combine_spin_pair(x)

    # NOTE: assumes position i of `local_index` corresponds to the i-th entry
    # of tn.sites -- the site ordering might need to be specified explicitly
    selection = {}
    for i, site in enumerate(tn.sites):
        selection[tn.site_ind(site)] = local_index[i]
    tnx = tn.isel(selection)

    return tnx.contract_hotrg(
        max_bond=chi,
        cutoff=0.0,
        equalize_norms=1.0,
        final_contract_opts=dict(strip_exponent=True),
    )


# NOTE(review): with jax's default 32-bit mode the compiled versions evaluate
# in float32 (see the dtype=float32 outputs below) while the eager call gives
# float64, and the exponents differ (18.611 vs 18.686); consider enabling
# jax_enable_x64 for a like-for-like comparison.

# compiled amplitude of one configuration
amplitude_jit = jax.jit(amplitude)
# compiled batch version: map over the leading axis of x, share params
amplitude_vmap = jax.jit(jax.vmap(amplitude, in_axes=(0, None)))

In [4]:
# random half-filled configurations: each row has 2 * nsites spin orbitals,
# exactly nsites of them occupied, shuffled independently per row
rng = np.random.default_rng(seed)
occupations = np.zeros((B, 2 * nsites), dtype=np.int32)
occupations[:, nsites:] = 1
xs = rng.permuted(occupations, axis=1)

In [5]:
# peek at the batch: B rows, each holding 2 * nsites spin-orbital occupations
xs

array([[0, 0, 0, ..., 1, 1, 0],
       [1, 1, 1, ..., 1, 1, 0],
       [1, 0, 1, ..., 1, 1, 0],
       ...,
       [0, 1, 1, ..., 0, 1, 1],
       [1, 0, 1, ..., 0, 1, 1],
       [1, 0, 1, ..., 1, 0, 0]], shape=(1024, 128), dtype=int32)

First, test the non-jitted version:

In [6]:
# eager (non-jitted) reference amplitude of the first configuration;
# note that the jitted cells below reuse the names mantissa / exponent
first_config = xs[0]
mantissa, exponent = amplitude(first_config, params)
(mantissa, exponent)

(np.float64(-1.0), np.float64(18.686366895508034))

Then test and warm up the jax-jitted version (the first call includes compilation):

In [7]:
%%time
# test single amplitude is working
# The first call traces and compiles amplitude_jit, so this wall time is
# dominated by compilation (compare with the %%timeit cell below).
# NOTE(review): this overwrites the eager reference mantissa / exponent, and
# the result here is float32 while the eager one was float64 -- enable
# jax_enable_x64 if an exact comparison is wanted.
mantissa, exponent = amplitude_jit(xs[0], params)
mantissa, exponent

CPU times: user 13min 3s, sys: 18.1 s, total: 13min 21s
Wall time: 12min 35s


(Array(-0.99999994, dtype=float32), Array(18.611166, dtype=float32))

In [8]:
%%timeit
# steady-state cost of one compiled single-configuration call; the result is
# passed through block_until_ready so the computation finishes inside the
# timed region (jax dispatches work asynchronously)
mantissa, exponent = jax.block_until_ready(amplitude_jit(xs[0], params))

31.6 ms ± 191 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


Then test and warm up the jax-vmapped (batched) version:

In [9]:
%%time
# test batch
# amplitude_vmap is a separate jax.jit, so its first call compiles again
# (the XLA slow-compile warnings in the output come from this step)
mantissas, exponents = amplitude_vmap(xs, params)
# row 0 uses the same input as the single-configuration jitted call above
mantissas[0], exponents[0]

E0107 17:41:00.573808 2442509 slow_operation_alarm.cc:73] 
********************************
[Compiling module jit_amplitude for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
E0107 17:41:19.126950 2418735 slow_operation_alarm.cc:140] The operation took 2m18.553267586s

********************************
[Compiling module jit_amplitude for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


CPU times: user 7min 41s, sys: 24.8 s, total: 8min 6s
Wall time: 7min 52s


(Array(-0.99999994, dtype=float32), Array(18.611166, dtype=float32))

In [10]:
%%timeit
# steady-state cost of one compiled batched call over all B configurations;
# block_until_ready keeps the GPU work inside the timed region
mantissas, exponents = jax.block_until_ready(amplitude_vmap(xs, params))

88.2 ms ± 391 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
