# Training by Maximum Likelihood (+ Causality)


**Goal:** run a small negative log‑likelihood fit with optional causality regularization.

**Tip:** start tiny (a few circuits, small bond dimensions, 1–2 epochs), then scale up.


In [None]:

# Make a nearby PTNT checkout importable if not pip-installed.
import os
import sys
import pathlib

# Candidates are the working directory followed by its ancestors; only the
# first four entries (cwd plus three parents) are inspected.
roots = [pathlib.Path.cwd(), *pathlib.Path.cwd().parents]
for r in roots[:4]:
    root_str = str(r)
    if root_str in sys.path:
        continue  # already importable from here, nothing to add
    if (r / "ptnt").is_dir():
        # Prepend so the local checkout wins over any other copy on the path.
        sys.path.insert(0, root_str)

# Basic environment info
# PTNT is mandatory: report the failure, then re-raise so the notebook stops.
try:
    import ptnt
    from ptnt._version import __version__ as ptnt_version
    print("[ptnt] import OK → version:", ptnt_version)
except Exception as e:
    print("[ptnt] import failed:", e)
    raise

# JAX is only reported on: a failure here is printed and execution continues.
try:
    import jax
    print("[ptnt] JAX devices:", jax.devices())
except Exception as e:
    print("[ptnt] JAX not available:", e)


In [None]:

# Minimal pipeline: build shell → generate shadows → preprocess → build LPDO → tiny fit
import numpy as np, quimb as qu
from qiskit_aer import Aer
from ptnt.circuits.templates import base_PT_circ_template
from ptnt.circuits.noise_models import create_env_IA
from ptnt.circuits.utils import bind_ordered
from ptnt.preprocess.shadow import (
    clifford_param_dict, validation_param_dict, shadow_results_to_data_vec,
    shadow_seqs_to_op_array, pure_measurement,
    clifford_measurements_vT, clifford_unitaries_vT
)
from ptnt.tn.pepo import create_PT_PEPO_guess, expand_initial_guess_
from ptnt.tn.optimize import TNOptimizer
from ptnt.tn.fit import compute_likelihood, causality_keys_to_op_arrays, compute_probabilities
from ptnt.utilities import hellinger_fidelity

# Simulator backend: passed to the template builder here and used to run the circuits below.
backend = Aer.get_backend("aer_simulator")

# Problem size. Q and T are read as module-level globals by `batch` and by the
# bond/K-list construction further down.
# NOTE(review): presumably Q = number of qubits and T = number of time steps — confirm.
Q = 2
T = 2

# NOTE(review): the meaning of the three numeric arguments is defined by
# create_env_IA and is not visible here — confirm before changing them.
env = create_env_IA(0.4, 0.2, 0.3)
template = base_PT_circ_template(Q, T, backend, None, "dd_clifford", env)

def batch(template, N, table):
    """Bind ``N`` randomly parameterised copies of ``template``.

    For every circuit an integer array of shape ``(T + 1, Q)`` is drawn
    uniformly from ``[0, len(table))`` (``T`` and ``Q`` are module-level
    globals). Each index is looked up in ``table`` and the flattened
    parameters are bound to ``template`` via ``bind_ordered``.

    Args:
        template: Circuit template forwarded unchanged to ``bind_ordered``.
        N: Number of circuits to generate.
        table: Lookup supporting ``len()`` and integer indexing; maps a drawn
            index to the parameter values for that choice.

    Returns:
        A pair ``(circs, seqs)``: the list of bound circuits and, aligned with
        it, the list of transposed index arrays (shape ``(Q, T + 1)``).
    """
    bound_circuits = []
    sequences = []
    n_choices = len(table)
    draw_shape = (T + 1, Q)
    for _ in range(N):
        choice = np.random.randint(0, n_choices, draw_shape)
        # The sequence is stored transposed; parameters follow the
        # untransposed, row-major order of the draw.
        sequences.append(choice.T)
        flat_params = np.array([table[k] for k in choice.ravel()]).ravel()
        bound_circuits.append(bind_ordered(template, flat_params))
    return bound_circuits, sequences

# Number of random circuits in the training and validation sets.
N_train = 60
N_val = 20
# Shots per circuit for the training ("char") and validation runs.
shots_char = 256
shots_val = 1024

# Random circuits: training indices are drawn against clifford_param_dict,
# validation indices against validation_param_dict.
train_circs, train_seqs = batch(template, N_train, clifford_param_dict)
val_circs, val_seqs = batch(template, N_val, validation_param_dict)

# Both jobs are submitted before either result is collected.
job_t = backend.run(train_circs, shots=shots_char)
job_v = backend.run(val_circs, shots=shots_val)

train_counts = job_t.result().get_counts()
val_counts = job_v.result().get_counts()

# NOTE(review): presumably returns a flat vector of outcome frequencies plus the
# matching outcome keys — confirm against shadow_results_to_data_vec.
train_p, train_keys = shadow_results_to_data_vec(train_counts, shots_char, Q)
val_p, val_keys = shadow_results_to_data_vec(val_counts, shots_val, Q)

def reverse_seq_list(seq_list):
    """Return a copy of ``seq_list`` with every sequence flipped along both axes.

    Each element of ``seq_list`` is treated as an iterable of rows. In the
    result the order of the rows is reversed and the entries inside every row
    are reversed as well. The outer order of ``seq_list`` is kept, the input
    is not modified, and rows are always returned as plain lists.
    """
    return [
        [list(reversed(row)) for row in seq][::-1]
        for seq in seq_list
    ]

# Flip each index sequence along both axes, then convert sequences and keys to operator arrays.
# NOTE(review): the flip presumably matches the ordering shadow_seqs_to_op_array expects — confirm.
train_seqs_rev = reverse_seq_list(train_seqs)
train_full = shadow_seqs_to_op_array(
    train_seqs_rev, train_keys, clifford_measurements_vT, clifford_unitaries_vT
)
val_seqs_rev = reverse_seq_list(val_seqs)
val_full = shadow_seqs_to_op_array(
    val_seqs_rev, val_keys, clifford_measurements_vT, clifford_unitaries_vT
)

# One K-list per column index in range(Q): 2 at both ends, 1 for the T - 1 entries in between.
K_lists = [[2, *([1] * (T - 1)), 2] for _ in range(Q)]

# Vertical bonds: a first row with Q - 1 entries followed by T further rows, all set to 2.
# NOTE(review): for Q == 2 each later row is [2, 2] (two entries) while the
# first row has Q - 1 = 1 entry — confirm what create_PT_PEPO_guess expects.
first_vertical_row = [2] * (Q - 1)
later_vertical_rows = [[2] + [2] * (Q - 3) + [2] for _ in range(T)]
vertical_bonds = [first_vertical_row, *later_vertical_rows]

# Horizontal bonds: T entries, all set to 1.
horizontal_bonds = [1] * T

pepo = create_PT_PEPO_guess(T, Q, horizontal_bonds, vertical_bonds, K_lists)
# Re-wrap the guess as a flat 2D quimb tensor network so Lx / Ly are available on it.
grid = qu.tensor.tensor_2d.TensorNetwork2DFlat.from_TN(
    pepo,
    site_tag_id="q{}_I{}",
    Ly=T+1,
    Lx=Q,
    y_tag_id="ROWq{}",
    x_tag_id="COL{}",
)

# Floor every entry at 1e-12.
# NOTE(review): presumably this keeps the log-likelihood finite for unobserved outcomes — confirm.
prob_floor = 1e-12
train_vec = np.maximum(np.array(train_p, dtype=float), prob_floor)
val_vec = np.maximum(np.array(val_p, dtype=float), prob_floor)

epochs = 1
batch_size = 64
# Optimizer steps: 2 * epochs * (number of training entries / batch_size), truncated to an int.
iterations = int(2 * epochs * len(train_vec) / batch_size)

# Options forwarded to the loss function and to the optimizer, named here for readability.
loss_options = {"kappa": 1e-3, "opt": "greedy", "X_decomp": False}
adam_settings = {"name": "adam", "lr": 5e-3}

# Negative log-likelihood fit with the causality term enabled via causality_fn.
optmzr = TNOptimizer(
    grid,
    loss_fn=compute_likelihood,
    causality_fn=causality_keys_to_op_arrays,
    causality_key_size=32,
    training_data=train_vec,
    training_sequences=train_full,
    Lx=grid.Lx,
    Ly=grid.Ly,
    validation_data=list(val_vec),
    validation_sequences=val_full,
    batch_size=batch_size,
    loss_constants={},
    loss_kwargs=loss_options,
    autodiff_backend="jax",
    optimizer=adam_settings,
    progbar=True,
)
# The return value of optimize() is discarded; the model used below is read
# from the optimizer's best_val_mpo attribute.
_ = optmzr.optimize(iterations)
best = optmzr.best_val_mpo

# Predicted values for every validation sequence under the best validation model.
pred = compute_probabilities(best, val_full, X_decomp=False, opt="greedy")
# Rescale so the predictions carry the same total weight as the validation vector.
pred = sum(val_vec) * pred / sum(pred)

# Number of consecutive entries belonging to one circuit in the flat vectors.
Qbits = 2**Q

# numpy is already imported as `np` at the top of this cell; no re-import needed.
fids = []
for i in range(N_val):
    lo, hi = Qbits * i, Qbits * (i + 1)
    p = np.array(pred[lo:hi])
    p = p / p.sum()  # renormalise the prediction within this circuit's slice
    a = np.array(val_vec[lo:hi])  # measured values for the same circuit (not renormalised)
    fids.append(hellinger_fidelity(p, a))
print("mean val fidelity (tiny demo):", float(np.mean(fids)))
