In [1]:
!pip install numpy~=1.0  # 1.0 instead of 2.0 for pymatching compatibility later
!pip install scipy
!pip install stim~=1.14
!pip install pymatching~=2.0



In [2]:
import stim
# Show which stim release was actually imported, as a sanity check on the
# version constraint used in the install cell.
print(stim.__version__)

1.14.0


In [3]:
import stim
import pymatching
import numpy as np
def count_logical_errors(circuit: stim.Circuit, num_shots: int) -> int:
    """Sample ``circuit`` and count the shots that the matching decoder gets wrong.

    The circuit is sampled ``num_shots`` times, a PyMatching decoder is built
    from the circuit's (decomposed) detector error model, and every shot's
    detection events are decoded. A shot counts as a logical error when the
    decoder's predicted observable flips differ from the sampled ones in at
    least one observable.

    Args:
        circuit: Annotated stim circuit (detectors and observables) to sample.
        num_shots: Number of shots to sample and decode.

    Returns:
        The number of shots whose prediction did not match the sampled
        observable flips.
    """
    # Sample the circuit.
    sampler = circuit.compile_detector_sampler()
    detection_events, observable_flips = sampler.sample(num_shots, separate_observables=True)

    # Configure a decoder using the circuit.
    detector_error_model = circuit.detector_error_model(decompose_errors=True)
    matcher = pymatching.Matching.from_detector_error_model(detector_error_model)

    # Run the decoder.
    predictions = matcher.decode_batch(detection_events)

    # Count the mistakes in one vectorized pass instead of a Python loop over
    # shots: a shot is mispredicted if any of its observables disagrees.
    # Assumes both arrays are laid out as (shot, observable) -- this matches the
    # original per-shot indexing `observable_flips[shot]` / `predictions[shot]`.
    mispredicted = np.any(predictions != observable_flips, axis=1)
    return int(np.count_nonzero(mispredicted))

In [4]:
import stim
def get_logical_err_rate(distance: int, rounds: int, phy_err_p: float, num_shots: int) -> float:
    """Estimate the logical error rate of a rotated surface-code Z-memory experiment.

    A ``surface_code:rotated_memory_z`` circuit is generated with the same
    physical error probability applied to all four noise channels (after
    Clifford gates, after resets, before measurements, and on data qubits
    before each round), then sampled and decoded via ``count_logical_errors``.

    Args:
        distance: Code distance of the generated surface code.
        rounds: Number of stabilizer measurement rounds.
        phy_err_p: Physical error probability used for every noise channel.
        num_shots: Number of shots to sample.

    Returns:
        Fraction of shots in which the decoder mispredicted the logical
        observable.

    Raises:
        ZeroDivisionError: If ``num_shots`` is 0.
    """
    # Generate Surface Code circuit with the specified parameters.
    circuit = stim.Circuit.generated(
        "surface_code:rotated_memory_z",
        rounds=rounds,
        distance=distance,
        after_clifford_depolarization=phy_err_p,
        after_reset_flip_probability=phy_err_p,
        before_measure_flip_probability=phy_err_p,
        before_round_data_depolarization=phy_err_p,
    )
    return count_logical_errors(circuit, num_shots) / num_shots

In [5]:
# Example: sampled logical error rate of a distance-7 code over 21 rounds at a
# physical error probability of 1e-3, estimated from 100,000 shots.
print(get_logical_err_rate(distance=7, rounds=21, phy_err_p=0.001, num_shots=100000))

8e-05


In [6]:
def find_distance_for_logical_err_rate(
    target_logical_err_rate: float,
    rounds: int,
    phy_err_p: float,
    num_shots: int,
    min_distance: int = 2,
    max_distance: int = 50,
) -> int:
    """Binary-search for the smallest distance whose sampled logical error rate beats a target.

    Each probe calls ``get_logical_err_rate`` and therefore draws fresh random
    samples, so the comparison is noisy: the search assumes the sampled rate
    decreases with distance, and results near the threshold can vary from run
    to run (increase ``num_shots`` to reduce this).

    Args:
        target_logical_err_rate: A distance is accepted when its sampled
            logical error rate is strictly below this value.
        rounds: Number of stabilizer measurement rounds per experiment.
        phy_err_p: Physical error probability passed to every noise channel.
        num_shots: Shots sampled per probed distance.
        min_distance: Smallest distance considered (inclusive). Must be >= 2,
            the smallest distance stim's surface-code generator accepts.
        max_distance: Largest distance considered (inclusive).

    Returns:
        The smallest accepted distance found in
        ``[min_distance, max_distance]``. If no distance in the range is
        accepted, ``max_distance`` is returned and a ``RuntimeWarning`` is
        emitted.

    Raises:
        ValueError: If ``min_distance < 2`` or ``max_distance < min_distance``.
    """
    # Imported locally so this cell does not rely on an import from another cell.
    import warnings

    if min_distance < 2:
        raise ValueError(f"min_distance must be at least 2, got {min_distance}")
    if max_distance < min_distance:
        raise ValueError(
            f"max_distance ({max_distance}) must not be smaller than min_distance ({min_distance})"
        )

    # Invariant: every distance below `low` was rejected; `high` is either the
    # untested upper bound or a distance that was sampled and accepted.
    low, high = min_distance, max_distance
    while low < high:
        mid = (low + high) // 2
        logical_err_rate = get_logical_err_rate(mid, rounds, phy_err_p, num_shots)
        if logical_err_rate < target_logical_err_rate:
            high = mid
        else:
            low = mid + 1

    # `mid` is always strictly below `high`, so if `high` never moved the upper
    # bound itself has not been sampled yet. Check it so that an unreachable
    # target is reported instead of silently returning max_distance.
    if high == max_distance:
        logical_err_rate = get_logical_err_rate(max_distance, rounds, phy_err_p, num_shots)
        if not logical_err_rate < target_logical_err_rate:
            warnings.warn(
                f"No distance in [{min_distance}, {max_distance}] reached a logical error rate "
                f"below {target_logical_err_rate}; returning max_distance={max_distance} "
                f"(sampled rate {logical_err_rate}).",
                RuntimeWarning,
                stacklevel=2,
            )
    return low

In [7]:
# Example: binary-search for a code distance whose sampled logical error rate
# falls below 1e-4 at a physical error probability of 1e-3 over 21 rounds,
# using 10,000 shots per probed distance.
print(find_distance_for_logical_err_rate(target_logical_err_rate=0.0001, rounds=21, phy_err_p=0.001, num_shots=10000))

6
