In [1]:
import numpy as np
import stim
from qecdec import BPDecoder, RotatedSurfaceCode, detector_error_model_to_check_matrices
from utils import visualize_bp_decoding_process

In [2]:
def get_chk_coords_and_var_coords_from_dem(dem: "stim.DetectorErrorModel") -> tuple[np.ndarray, np.ndarray]:
    """Assign 3-D plotting coordinates to every check (detector) and variable (error) of a DEM.

    Check coordinates are copied from the DEM's ``detector(...)`` annotations. Each
    variable (``error`` instruction) is placed relative to the detector(s) it flips:

    * one detector, no observable      -> that detector's coordinate shifted by +1 on axis 1;
    * two detectors, no observable     -> midpoint of the two detectors' coordinates;
    * one detector and one observable  -> that detector's coordinate shifted by -1 on axis 1
      (target order inside the instruction does not matter).

    Args:
        dem: Detector error model. Variables are indexed in the order their ``error``
            instructions appear in ``dem.flattened()``.

    Returns:
        ``(chk_coords, var_coords)`` with shapes ``(dem.num_detectors, 3)`` and
        ``(dem.num_errors, 3)``. Detectors with no coordinate annotation stay at the
        origin. Annotations with fewer than three values are zero-padded, and values
        beyond the third are ignored.

    Raises:
        NotImplementedError: if an ``error`` instruction has any other target pattern
            (e.g. decomposed errors containing ``^`` separators, three or more
            detectors, or a detector together with several observables).
    """
    chk_coords = np.zeros((dem.num_detectors, 3))
    var_coords = np.zeros((dem.num_errors, 3))
    above = np.array([0., 1., 0.])
    below = np.array([0., -1., 0.])

    # Flatten once and reuse for both passes. Target values are used directly as
    # detector indices into chk_coords.
    flat_dem = dem.flattened()

    # Pass 1: detector coordinates. This must finish before pass 2 because detector
    # annotations may appear after the errors that reference them.
    for instruction in flat_dem:
        if instruction.type != "detector":
            continue
        xyz = instruction.args_copy()[:3]
        for target in instruction.targets_copy():
            chk_coords[target.val, :len(xyz)] = xyz

    # Pass 2: one coordinate per error instruction, in DEM order.
    error_instructions = (ins for ins in flat_dem if ins.type == "error")
    for j, instruction in enumerate(error_instructions):
        targets = instruction.targets_copy()
        dets = [t.val for t in targets if t.is_relative_detector_id()]
        num_obs = sum(1 for t in targets if t.is_logical_observable_id())
        # Any target that is neither a detector nor an observable (e.g. a "^" separator)
        # makes the pattern unsupported.
        only_dets_and_obs = len(dets) + num_obs == len(targets)

        if only_dets_and_obs and len(dets) == 1 and num_obs == 0:
            var_coords[j] = chk_coords[dets[0]] + above
        elif only_dets_and_obs and len(dets) == 2 and num_obs == 0:
            var_coords[j] = (chk_coords[dets[0]] + chk_coords[dets[1]]) / 2
        elif only_dets_and_obs and len(dets) == 1 and num_obs == 1:
            var_coords[j] = chk_coords[dets[0]] + below
        else:
            raise NotImplementedError(
                f"Cannot place error instruction {instruction!s}: "
                "supported target patterns are 'D', 'D D' and 'D L'."
            )

    return chk_coords, var_coords

In [3]:
d = 5  # code distance
p = 0.01  # physical error rate, used for both data-qubit and measurement errors
rounds = 3  # syndrome-extraction rounds in the memory experiment
max_iter = 50  # BP iteration cap
num_shots = 1_000
sampler_seed = 0  # fixed so that the shot inspected below is reproducible

code = RotatedSurfaceCode(d=d)
circuit = code.make_circuit_memory_z_experiment(
    rounds=rounds,
    data_qubit_error_rate=p,
    meas_error_rate=p,
    keep_z_detectors_only=True,
)
dem = circuit.detector_error_model()

# Dense uint8 check / observable matrices and float64 priors for the decoder.
matrices = detector_error_model_to_check_matrices(dem)
chkmat = matrices.check_matrix.toarray().astype(np.uint8)
obsmat = matrices.observables_matrix.toarray().astype(np.uint8)
pvec = matrices.priors.astype(np.float64)

sampler = circuit.compile_detector_sampler(seed=sampler_seed)
syndrome_batch, observable_batch = sampler.sample(num_shots, separate_observables=True)
syndrome_batch = syndrome_batch.astype(np.uint8)
observable_batch = observable_batch.astype(np.uint8)

decoder = BPDecoder(pcm=chkmat, prior=pvec, max_iter=max_iter)

chk_coords, var_coords = get_chk_coords_and_var_coords_from_dem(dem)
# The visualization pairs row i / column j of chkmat with chk_coords[i] / var_coords[j].
# That pairing assumes detector_error_model_to_check_matrices keeps one column per DEM
# error instruction, in DEM order. The count checks below catch merged or dropped errors.
# Column ordering is assumed, not verified.
assert chk_coords.shape[0] == chkmat.shape[0], "expected one coordinate per check-matrix row"
assert var_coords.shape[0] == chkmat.shape[1], "expected one coordinate per check-matrix column"

In [4]:
shot_idx = 17

syndrome = syndrome_batch[shot_idx]
ehat = decoder.decode(syndrome, record_llr_history=True)
llr_history = decoder.get_llr_history()

# Re-derive the syndrome and the logical observable implied by the hard decision `ehat`.
# If the operands are uint8, the matrix products wrap modulo 256. Because 256 is even,
# the final `% 2` parity is still correct.
predicted_syndrome = (chkmat @ ehat) % 2
predicted_observable = (obsmat @ ehat) % 2
syndrome_matched = np.all(syndrome == predicted_syndrome)
observable_correct = np.all(observable_batch[shot_idx] == predicted_observable)

# NOTE(review): an iteration count equal to the decoder's max_iter combined with an
# unmatched syndrome presumably means BP hit its cap without converging for this shot.
print("Number of iterations: ", llr_history.shape[0])
print("Has the decoder matched the syndrome? ", syndrome_matched)
print("Has the decoder correctly predicted the logical observable? ", observable_correct)


Number of iterations:  50
Has the decoder matched the syndrome?  False
Has the decoder correctly predicted the logical observable?  True


In [5]:
# Visualize BP on the decoded shot, using the check matrix, the DEM-derived node
# coordinates, the shot's syndrome and the per-iteration LLR history recorded above.
# NOTE(review): `visualize_bp_decoding_process` lives in the local `utils` module. Its exact
# rendering (static plot vs. animation, colour scheme) is not visible here; confirm there.
visualize_bp_decoding_process(chkmat, chk_coords, var_coords, syndrome, llr_history)