In [None]:
import stim
import pymatching as pm
import numpy as np

def build_repetition_circuit(d=5, rounds=None, p1=0.001, p2=0.005, pm=0.01):
    """
    Build a noisy bit-flip repetition-code memory experiment as a stim circuit.

    Layout:
      - d data qubits (indices 0..d-1) and d-1 ancillas (indices d..2d-2).
      - Ancilla j checks the Z-parity of data qubits j and j+1.
      - rounds defaults to d.

    Each round:
      - DEPOLARIZE1(p1) on every qubit.
      - For each check, CX data->ancilla for both neighbours, each CX followed
        by DEPOLARIZE2(p2).
      - X_ERROR(pm) on the ancillas (readout flip), then MR (measure + reset),
        so every round's ancilla outcome is that round's syndrome bit.

    Detectors:
      - Round 0: each ancilla outcome on its own (0 in the noiseless circuit).
      - Round t > 0: ancilla j in round t XOR ancilla j in round t-1.
      - After the final data readout: last-round ancilla j XOR data j XOR data j+1.

    Observable 0: the final measurement of data[0], i.e. the logical Z of the
    code. A logical bit flip (X on every data qubit) flips it.

    Args:
        d: number of data qubits (code distance); must be >= 2.
        rounds: number of syndrome-extraction rounds; must be >= 1.
        p1: single-qubit depolarizing probability per round (0 disables it).
        p2: two-qubit depolarizing probability after each CX (0 disables it).
        pm: readout-flip probability, applied as X_ERROR before each
            measurement (0 disables it). NOTE: inside this function the name
            shadows the module-level ``pymatching as pm`` alias.

    Returns:
        stim.Circuit for the experiment.

    Raises:
        ValueError: if d < 2 or rounds < 1.
    """
    if rounds is None:
        rounds = d
    if d < 2:
        raise ValueError(f"d must be >= 2, got {d}")
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds}")

    n_data = d
    n_anc = d - 1
    data_qubits = list(range(n_data))
    anc_qubits = [n_data + j for j in range(n_anc)]
    c = stim.Circuit()

    # Reset all qubits to |0>.
    c.append("R", data_qubits + anc_qubits)

    for t in range(rounds):
        # Light single-qubit noise on every qubit.
        if p1 > 0:
            c.append("DEPOLARIZE1", data_qubits + anc_qubits, p1)

        # Z-parity extraction: CX data -> ancilla XORs each data bit into the
        # ancilla. A CZ here would leave a |0> ancilla's Z outcome independent
        # of the data, i.e. it would measure nothing.
        for j, a in enumerate(anc_qubits):
            for dq in (j, j + 1):
                c.append("CX", [dq, a])
                if p2 > 0:
                    c.append("DEPOLARIZE2", [dq, a], p2)

        # Readout flip, then measure AND reset the ancillas, so each round's
        # outcome is that round's syndrome rather than a running XOR.
        if pm > 0:
            c.append("X_ERROR", anc_qubits, pm)
        c.append("MR", anc_qubits)

        # Ancillas are measured in order 0..n_anc-1, so in the current round
        # ancilla j sits at rec[-(n_anc - j)] and in the previous round at
        # rec[-(2 * n_anc - j)].
        for j in range(n_anc):
            cur = stim.target_rec(-(n_anc - j))
            if t == 0:
                c.append("DETECTOR", [cur])
            else:
                prev = stim.target_rec(-(2 * n_anc - j))
                c.append("DETECTOR", [prev, cur])

    # --- Final data measurements (space-like closures) ---
    if pm > 0:
        c.append("X_ERROR", data_qubits, pm)  # simple readout-flip model
    c.append("M", data_qubits)

    # After the data Ms, the last n_data records are data[0..d-1] in order.
    # The last round's ancilla records precede them, so ancilla j is at
    # rec[-(n_data + n_anc - j)].
    for j in range(n_anc):
        data_j = stim.target_rec(-(n_data - j))         # M(data j)
        data_j1 = stim.target_rec(-(n_data - (j + 1)))  # M(data j+1)
        last_round_anc_j = stim.target_rec(-(n_data + (n_anc - j)))
        # Last measured syndrome must agree with the parity of the data readout.
        c.append("DETECTOR", [last_round_anc_j, data_j, data_j1])

    # Logical observable: Z on a single data qubit (data[0]). Using
    # data[0] xor data[d-1] instead would be a product of stabilizers, which a
    # logical bit flip never changes.
    c.append("OBSERVABLE_INCLUDE", [stim.target_rec(-n_data)], 0)

    return c

# Build a circuit that's noisy enough to see logical failures.
circuit = build_repetition_circuit(d=7, rounds=7, p1=0.002, p2=0.01, pm=0.02)
print(circuit)

shots = 50_000

# Detector error model -> MWPM decoder.
dem = circuit.detector_error_model(decompose_errors=True)
matcher = pm.Matching.from_detector_error_model(dem)

# Sample detection events and the actual observable flips, one row per shot.
sampler = circuit.compile_detector_sampler()
dets, obs = sampler.sample(shots, separate_observables=True)

# decode_batch already returns shape (shots, num_observables), so no reshape is
# needed. A shot is a logical failure if ANY predicted observable flip disagrees
# with the actual one; counting per shot stays correct with several observables.
pred = matcher.decode_batch(dets)
num_failures = np.count_nonzero(np.any(pred.astype(bool) != obs, axis=1))
p_L = num_failures / shots
p_L



R 0 1 2 3 4 5 6 7 8 9 10 11 12
DEPOLARIZE1(0.002) 0 1 2 3 4 5 6 7 8 9 10 11 12
CZ 7 0
DEPOLARIZE2(0.01) 7 0
CZ 7 1
DEPOLARIZE2(0.01) 7 1
CZ 8 1
DEPOLARIZE2(0.01) 8 1
CZ 8 2
DEPOLARIZE2(0.01) 8 2
CZ 9 2
DEPOLARIZE2(0.01) 9 2
CZ 9 3
DEPOLARIZE2(0.01) 9 3
CZ 10 3
DEPOLARIZE2(0.01) 10 3
CZ 10 4
DEPOLARIZE2(0.01) 10 4
CZ 11 4
DEPOLARIZE2(0.01) 11 4
CZ 11 5
DEPOLARIZE2(0.01) 11 5
CZ 12 5
DEPOLARIZE2(0.01) 12 5
CZ 12 6
DEPOLARIZE2(0.01) 12 6
X_ERROR(0.02) 7
M 7
X_ERROR(0.02) 8
M 8
X_ERROR(0.02) 9
M 9
X_ERROR(0.02) 10
M 10
X_ERROR(0.02) 11
M 11
X_ERROR(0.02) 12
M 12
DETECTOR rec[-1]
DETECTOR rec[-2]
DETECTOR rec[-3]
DETECTOR rec[-4]
DETECTOR rec[-5]
DETECTOR rec[-6]
DEPOLARIZE1(0.002) 0 1 2 3 4 5 6 7 8 9 10 11 12
CZ 7 0
DEPOLARIZE2(0.01) 7 0
CZ 7 1
DEPOLARIZE2(0.01) 7 1
CZ 8 1
DEPOLARIZE2(0.01) 8 1
CZ 8 2
DEPOLARIZE2(0.01) 8 2
CZ 9 2
DEPOLARIZE2(0.01) 9 2
CZ 9 3
DEPOLARIZE2(0.01) 9 3
CZ 10 3
DEPOLARIZE2(0.01) 10 3
CZ 10 4
DEPOLARIZE2(0.01) 10 4
CZ 11 4
DEPOLARIZE2(0.01) 11 4
CZ 11 5
DEPOLARIZE

np.float64(0.0)