In [1]:
import os
# Expose only the GPU with index 3 to this process. This is set in the first
# cell, ahead of the `import jax` that happens in a later cell.
os.environ.update(CUDA_VISIBLE_DEVICES="3")

In [2]:
import numpy as np
import quimb.tensor as qtn

import symmray as sr

# Lattice geometry: an Lx-by-Ly grid and its total number of sites.
Lx, Ly = 8, 8
nsites = Lx * Ly

# PEPS bond dimension, and the maximum bond kept during contraction
# (set equal to D here).
D = 4
chi = D

# Seed reused for the random PEPS and for sampling configurations.
seed = 42

# Batch size: number of configurations evaluated together.
B = 1024

# Use the flat backend, the only one compatible with jax.jit.
flat = True

# Random fermionic PEPS with Z2 symmetry on the Lx-by-Ly lattice.
#
# `phys_dim` lists one (charge, offset) pair per physical state, ordered by
# the state's linear index:
#   linear index 0 -> charge 0, offset 0
#   linear index 1 -> charge 1, offset 1
#   linear index 2 -> charge 1, offset 0
#   linear index 3 -> charge 0, offset 1
peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    seed=seed,
    flat=flat,
    subsizes="equal",
    phys_dim=[(0, 0), (1, 1), (1, 0), (0, 1)],
)

In [3]:
import jax

# Split the PEPS into a pytree of its initial parameters (`params`) and a
# reference tensor network structure (`skeleton`); `qtn.unpack(params,
# skeleton)` is used below to rebuild the network from a set of parameters.
params, skeleton = qtn.pack(peps)


def amplitude(x, params):
    """Contract the PEPS amplitude of one configuration using HOTRG.

    Parameters
    ----------
    x : sequence of int
        Physical index selected on each site; ``x[i]`` is applied to the
        ``i``-th entry of the network's ``sites``.
    params : pytree
        Tensor parameters in the layout returned by ``qtn.pack``; they are
        recombined with the module-level ``skeleton``.

    Returns
    -------
    (mantissa, exponent)
        The contracted value with its exponent stripped off and returned
        separately, which can avoid stability issues with very small or
        very large values.
    """
    psi = qtn.unpack(params, skeleton)

    # Entries of x are matched to sites purely by position in `psi.sites`;
    # the right site ordering might need to be specified here.
    selector = {}
    for i, site in enumerate(psi.sites):
        selector[psi.site_ind(site)] = x[i]
    psi_x = psi.isel(selector)

    hotrg_opts = {
        "max_bond": chi,
        "cutoff": 0.0,
        # Together these two options make the return value a
        # (mantissa, exponent) pair rather than a single number.
        "equalize_norms": 1.0,
        "final_contract_opts": {"strip_exponent": True},
    }
    return psi_x.contract_hotrg(**hotrg_opts)


# Compiled version of `amplitude` for a single configuration.
amplitude_jit = jax.jit(amplitude)
# Compiled batched version: vmap maps over axis 0 of the first argument (the
# configurations) and, via `None`, passes the second argument (params)
# unmapped to every batch element.
amplitude_vmap = jax.jit(jax.vmap(amplitude, in_axes=(0, None)))

In [4]:
# Generate B half-filling configurations: each row holds nsites // 2 zeros and
# nsites // 2 ones (int32), shuffled independently along the site axis.
# NOTE(review): only local states 0 and 1 are ever sampled here, while the PEPS
# is built with four physical states per site -- confirm this is intended.
rng = np.random.default_rng(seed)
xs = rng.permuted(
    np.tile(
        np.repeat(np.array([0, 1], dtype=np.int32), nsites // 2),
        (B, 1),
    ),
    axis=1,
)

First, test the non-jitted (eager) version:

In [5]:
# Reference evaluation: call the plain Python `amplitude` (no jax.jit) on the
# first configuration and display the (mantissa, exponent) pair.
mantissa, exponent = amplitude(xs[0], params)
mantissa, exponent

(np.float64(1.0), np.float64(20.24082443830696))

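Then test and warm up the JAX-jitted version (the first call includes compilation). Note that the jitted output below is float32 whereas the eager output above is float64, so the two exponents agree only to about four significant digits, presumably because of the lower precision.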

In [6]:
%%time
# test single amplitude is working
mantissa, exponent = amplitude_jit(xs[0], params)
mantissa, exponent

W1212 17:42:16.858054  411880 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1212 17:42:16.868534  405761 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
E1212 17:47:38.182433  413111 slow_operation_alarm.cc:73] 
********************************
[Compiling module jit_amplitude for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
E1212 17:47:44.480987  405761 slow_operation_alarm.cc:140] The operation took 2m6.298655514s

********************************
[Compiling module jit_amplitude for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


CPU times: user 13min 6s, sys: 19.9 s, total: 13min 26s
Wall time: 12min 32s


(Array(1., dtype=float32), Array(20.242916, dtype=float32))

In [7]:
%%timeit
# Time repeated calls of the already-compiled single-configuration amplitude;
# block_until_ready waits for the result before the timer stops.
mantissa, exponent = jax.block_until_ready(amplitude_jit(xs[0], params))

35.5 ms ± 144 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


Then test and warm up the JAX-vmapped (batched) version:

In [8]:
%%time
# test batch
mantissas, exponents = amplitude_vmap(xs, params)
mantissas[0], exponents[0]

CPU times: user 8min 19s, sys: 21.4 s, total: 8min 40s
Wall time: 8min 20s


(Array(1., dtype=float32), Array(20.242916, dtype=float32))

In [9]:
%%timeit
# Time repeated calls of the already-compiled batched amplitude over the whole
# batch; block_until_ready waits for the result before the timer stops.
mantissas, exponents = jax.block_until_ready(amplitude_vmap(xs, params))

90.2 ms ± 408 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
