# PTNT Quickstart · Reproduce a small Figure‑3 pipeline


> **Run notes**
>
> - Launch the notebook from the **repo root** (or a subdirectory up to three levels below it), or install PTNT in editable mode (`pip install -e .`).
> - A CPU is sufficient. A GPU-enabled JAX build improves throughput if `nvidia-smi` works and `jax[cuda12]` is installed.
> - The first JAX call triggers XLA compilation, so expect a one-time warm‑up delay.


In [None]:

import os, sys, importlib, pathlib

# Make `ptnt` importable without an editable install: check the current
# directory and its three nearest ancestors, and put the first one that holds
# a `ptnt/` directory at the front of sys.path.
_here = pathlib.Path.cwd()
for _repo in (_here, _here.parent, _here.parent.parent, _here.parent.parent.parent):
    if not (_repo / "ptnt").is_dir():  # is_dir() is False for missing paths too
        continue
    _repo_str = str(_repo)
    if _repo_str not in sys.path:
        sys.path.insert(0, _repo_str)
    break

# ptnt is required: report the version on success, otherwise print install
# instructions and re-raise so the notebook stops here.
try:
    import ptnt
    from ptnt._version import __version__ as ptnt_version
except Exception as e:
    print("[ptnt] import failed:", e)
    print("Install editable with `pip install -e .` from the repo root, then restart the kernel.")
    raise
else:
    print("[ptnt] import OK, version:", ptnt_version)

# JAX is best-effort here: list devices if import and device discovery both
# succeed; otherwise just report why and carry on.
try:
    import jax
    _jax_devices = jax.devices()
except Exception as e:
    print("[ptnt] JAX not available:", e)
else:
    print("[ptnt] JAX devices:", _jax_devices)


The next cell runs a **small** version of the Figure‑3 pipeline (one training epoch, reduced jobs/shadows/shots) and shows where the metrics and plots are saved.

In [None]:

from pathlib import Path
from ptnt.io.run import run_from_config, default_config_for_experiment
import json

# Start from the stock Figure-3 config and shrink it so the quickstart finishes
# quickly: one training epoch and fewer jobs, shadows, and shots.
cfg = default_config_for_experiment("figure3")
_overrides = {
    "training": {"epochs": 1, "batch_size": 64},
    "data": {
        "jobs": 2,
        "shadows_per_job": 60,
        "shots_per_char": 256,
        "val_shadows": 20,
        "shots_per_val": 1024,
    },
}
for _section, _values in _overrides.items():
    for _key, _value in _values.items():
        cfg[_section][_key] = _value
# Keep any optimizer the default config already specifies; otherwise use "greedy".
cfg["training"].setdefault("opt", "greedy")

metrics = run_from_config(cfg)  # kept in the namespace for inspection

# NOTE(review): assumes run_from_config writes its ptnt_* artifacts to
# cfg["output"]["dir"] (current directory when unset) — confirm in ptnt.io.run.
run_dir = Path(cfg.get("output", {}).get("dir", "."))
print("Run directory:", run_dir.resolve())
print("Saved:", [p.name for p in run_dir.glob("ptnt_*.*")])

# Preview the metrics JSON, marking it as truncated only when it actually is.
_PREVIEW_CHARS = 1200
mp = run_dir / "ptnt_metrics.json"
if mp.exists():
    _text = mp.read_text(encoding="utf-8")
    _preview = _text[:_PREVIEW_CHARS]
    if len(_text) > _PREVIEW_CHARS:
        _preview += " ..."
    print(_preview)


In [None]:

import numpy as np
import matplotlib.pyplot as plt


In [None]:

from pathlib import Path
from PIL import Image
# Show the loss-curve image from the run directory, if the pipeline saved one.
run_dir = Path(cfg.get("output", {}).get("dir", "."))
p = run_dir / "ptnt_losses.png"
if p.exists():
    # Image.open is lazy and holds the file open; the context manager closes it
    # once the image has been rendered.
    with Image.open(p) as img:
        display(img)  # `display` is injected by IPython/Jupyter, not imported here
else:
    print("Loss plot not found.")
