# Contraction Options & Performance


> **Run notes**
>
> - Open from the **repo root** or install PTNT as editable (`pip install -e .`).
> - CPU is fine; GPU JAX improves throughput if `nvidia-smi` works and `jax[cuda12]` is installed.
> - The first JAX call compiles with XLA (a one-time warm-up).


In [None]:

import os, sys, importlib, pathlib

# Make the in-repo `ptnt` package importable when the notebook is opened from
# the repo root or a folder up to three levels below it.
_cwd = pathlib.Path.cwd()
_candidates = [_cwd, _cwd.parent, _cwd.parent.parent, _cwd.parent.parent.parent]
for root in _candidates:
    ptnt_pkg = root / "ptnt"
    # Require a real package (``__init__.py``), not just any directory named
    # ``ptnt``: e.g. a checkout folder that happens to be called ``ptnt`` would
    # otherwise put the wrong directory on sys.path.
    # NOTE(review): assumes ptnt is a regular (non-namespace) package.
    if (ptnt_pkg / "__init__.py").is_file():
        if str(root) not in sys.path:
            sys.path.insert(0, str(root))
        break

try:
    import ptnt
    from ptnt._version import __version__ as ptnt_version
except Exception as e:
    # Print a setup hint, then re-raise so the notebook stops here.
    print("[ptnt] import failed:", e)
    print("Install editable with `pip install -e .` from the repo root, then restart the kernel.")
    raise
else:
    print("[ptnt] import OK, version:", ptnt_version)

# JAX is optional: report which devices are visible, or why it is unavailable.
try:
    import jax
    _jax_devices = jax.devices()
except Exception as e:
    print("[ptnt] JAX not available:", e)
else:
    print("[ptnt] JAX devices:", _jax_devices)


In [None]:

import time, numpy as np
import quimb.tensor as qtn
from ptnt.preprocess import shadow as sh
from ptnt.tn.pepo import create_PT_PEPO_guess, produce_LPDO

# Problem size used throughout this notebook.
nQ, nS = 2, 2

# One list per qubit, each with nS + 1 entries: 2 at both ends, 1 in between.
K_lists = [[2, *([1] * (nS - 1)), 2] for _ in range(nQ)]

# First row holds nQ - 1 bonds; each of the nS following rows is 2 ... 2.
# NOTE(review): when nQ == 2 the `nQ - 3` term is negative, so the later rows
# have 2 entries while the first row has 1 -- confirm this is what
# create_PT_PEPO_guess expects.
vertical_bonds = [[2] * (nQ - 1)]
vertical_bonds += [[2, *([2] * (nQ - 3)), 2] for _ in range(nS)]

horizontal_bonds = [1] * nS

pepo_half = create_PT_PEPO_guess(nS, nQ, horizontal_bonds, vertical_bonds, K_lists)
lpdo = produce_LPDO(pepo_half)

# A single shadow sequence with a single key, converted into a tensor network
# that can be joined with the LPDO and contracted.
seq = np.array([[0, 1], [2, 3]])
sequences = [[seq]]
keys = [["00"]]
op_full = sh.shadow_seqs_to_op_array(
    sequences,
    keys,
    sh.clifford_measurements_vT,
    sh.clifford_unitaries_vT,
)
seqTN = sh.op_arrays_to_single_vector_TN_padded(op_full[0])

def time_contract(opt, tn=None, *, repeats=1):
    """Contract a tensor network with path optimizer ``opt`` and time it.

    Parameters
    ----------
    opt : str
        Value forwarded to ``tn.contract(optimize=...)``.
    tn : object, optional
        Network to contract; it must provide ``contract(optimize=...)``.
        Defaults to ``lpdo & seqTN`` from this notebook. The network is built
        before the clock starts, so only the contraction itself is timed.
    repeats : int, keyword-only
        Number of timed contractions (>= 1). The fastest wall time is reported.

    Returns
    -------
    tuple[float, float]
        ``(value, seconds)``: the contracted value converted with ``float()``,
        and the best wall-clock time over ``repeats`` runs.

    Raises
    ------
    ValueError
        If ``repeats`` is less than 1.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    if tn is None:
        # Build outside the timed region: joining networks is not contraction cost.
        tn = lpdo & seqTN

    value = None
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        result = tn.contract(optimize=opt)
        # Convert to a host float *before* reading the clock, so backends that
        # dispatch asynchronously (e.g. JAX) are timed to completion.
        value = float(result)
        best = min(best, time.perf_counter() - t0)
    return value, best

# Compare contraction-path optimizers on the same network. Some presets need
# optional packages and raise when those are missing; report the exception
# type so genuine failures are not mistaken for a missing dependency.
for opt in ["greedy", "random-greedy", "hyper-kahypar", "auto-hq"]:
    try:
        v, dt = time_contract(opt)
    except Exception as e:  # deliberate best-effort: keep comparing the rest
        print(f"{opt:>14s}: unavailable or failed ({type(e).__name__}: {e})")
        continue
    print(f"{opt:>14s}: value={v:.6f}  time={dt:.4f}s")
