In [1]:
import stim
import stimcirq

In [2]:
circuit_noisy = stim.Circuit.from_file(".data/surface_code_bZ_d3_r03_center_3_5/circuit_noisy.stim")

detector_error_model_no_decompose_errors = circuit_noisy.detector_error_model()

In [4]:
def get_detector_val(targets_copy):
    target_val = []
    for target in targets_copy:
        if target.is_relative_detector_id():
            target_val.append(target.val)
    return target_val

def get_logical_observable_val(targets_copy):
    target_val = []
    for target in targets_copy:
        if target.is_logical_observable_id():
            target_val.append(target.val)
    return target_val

def get_flip_index(detector_val, logical_observable_val, detector_number):
    flip_index = []
    for detector in detector_val:
        if detector not in flip_index:
            flip_index.append(detector)
        else:
            flip_index.remove(detector)
    for logical_observable in logical_observable_val:
        logical_observable = logical_observable+detector_number
        if logical_observable not in flip_index:
            flip_index.append(logical_observable)
        else:
            flip_index.remove(logical_observable)
    return flip_index

def flip_detector_even_str_by_flip_index(detector_even_str, flip_index):
    flip_detector_even_str = detector_even_str
    
    for i in flip_index:
        if detector_even_str[i] == '0':
            flip_detector_even_str = flip_detector_even_str[:i]+'1'+flip_detector_even_str[i+1:]
        else:
            flip_detector_even_str = flip_detector_even_str[:i]+'0'+flip_detector_even_str[i+1:]
    
    return flip_detector_even_str

import time
def get_error_model(detector_error_model, detector_number):
    # detector_number = 8
    
    initial_detector_even_str = '0'*(detector_number+1)
    error_model = {}
    error_model[initial_detector_even_str] = 1
    n=0
    
    for error in detector_error_model:
        start_time = time.time()
        print(f"n={n}")
        n = n+1
        if error.type == "error":
            probability = error.args_copy()[0]
            targets_copy = error.targets_copy()
            detector_val = get_detector_val(targets_copy)
            logical_observable_val = get_logical_observable_val(targets_copy)
            flip_index = get_flip_index(detector_val, logical_observable_val, detector_number)
            
            table = {}
            for detector_even_str, detector_even_probability in error_model.items():
                no_flip_detector_even_str = detector_even_str
                no_flip_probability = detector_even_probability * (1-probability)
                
                flip_detector_even_str = flip_detector_even_str_by_flip_index(detector_even_str, flip_index)
                flip_probability = detector_even_probability * probability
                
                table[no_flip_detector_even_str] = table.get(no_flip_detector_even_str, 0) + no_flip_probability
                table[flip_detector_even_str] = table.get(flip_detector_even_str, 0) + flip_probability
            error_model = table
        end_time = time.time()
        print(f"table shape:{len(error_model)}")
        print(f"cost time:{end_time-start_time}")
    return error_model

# error_model = get_error_model(detector_error_model=detector_error_model, detector_number=detector_error_model.num_detectors)

In [14]:
import time

def pruing_table(table, pruning_threshold=1e-08):
    for key in list(table.keys()):
        if table[key] < pruning_threshold:
            del table[key]
    return table

def get_error_model_pruning(detector_error_model, detector_number, pruning_n = 10000, pruning_threshold=1e-08):
    # detector_number = 8
    
    initial_detector_even_str = '0'*(detector_number+1)
    error_model = {}
    error_model[initial_detector_even_str] = 1
    n=0
    for error in detector_error_model:
        start_time = time.time()
        print(f"n={n}")
        n = n+1
        if error.type == "error":
            probability = error.args_copy()[0]
            targets_copy = error.targets_copy()
            detector_val = get_detector_val(targets_copy)
            logical_observable_val = get_logical_observable_val(targets_copy)
            flip_index = get_flip_index(detector_val, logical_observable_val, detector_number)
            
            table = {}
            for detector_even_str, detector_even_probability in error_model.items():
                no_flip_detector_even_str = detector_even_str
                no_flip_probability = detector_even_probability * (1-probability)
                
                flip_detector_even_str = flip_detector_even_str_by_flip_index(detector_even_str, flip_index)
                flip_probability = detector_even_probability * probability
                
                table[no_flip_detector_even_str] = table.get(no_flip_detector_even_str, 0) + no_flip_probability
                table[flip_detector_even_str] = table.get(flip_detector_even_str, 0) + flip_probability
            
            if len(table) >= pruning_n:
                table = pruing_table(table, pruning_threshold)
            
            error_model = table
            
        end_time = time.time()
        print(f"table shape:{len(error_model)}")
        print(f"cost time:{end_time-start_time}")
    return error_model

error_model1 = get_error_model_pruning(detector_error_model=detector_error_model_no_decompose_errors, detector_number=detector_error_model_no_decompose_errors.num_detectors, pruning_n = 1000, pruning_threshold=1e-07)
error_model2 = get_error_model_pruning(detector_error_model=detector_error_model_no_decompose_errors, detector_number=detector_error_model_no_decompose_errors.num_detectors, pruning_n = 100000, pruning_threshold=1e-07)
error_model3 = get_error_model_pruning(detector_error_model=detector_error_model_no_decompose_errors, detector_number=detector_error_model_no_decompose_errors.num_detectors, pruning_n = 100000, pruning_threshold=1e-07)

n=0
table shape:2
cost time:0.0
n=1
table shape:4
cost time:0.0
n=2
table shape:8
cost time:0.0
n=3
table shape:16
cost time:0.0
n=4
table shape:32
cost time:0.0009982585906982422
n=5
table shape:32
cost time:0.0
n=6
table shape:32
cost time:0.0
n=7
table shape:64
cost time:0.0
n=8
table shape:128
cost time:0.0005800724029541016
n=9
table shape:256
cost time:0.0003147125244140625
n=10
table shape:256
cost time:0.0005629062652587891
n=11
table shape:256
cost time:0.0005421638488769531
n=12
table shape:256
cost time:0.0
n=13
table shape:256
cost time:0.0009999275207519531
n=14
table shape:256
cost time:0.0
n=15
table shape:256
cost time:0.0010006427764892578
n=16
table shape:256
cost time:0.0
n=17
table shape:256
cost time:0.001001119613647461
n=18
table shape:256
cost time:0.0
n=19
table shape:512
cost time:0.0009996891021728516
n=20
table shape:512
cost time:0.0012009143829345703
n=21
table shape:512
cost time:0.0011646747589111328
n=22
table shape:1024
cost time:0.0009801387786865234


In [15]:
import numpy as np

counts = {}
for detector_even_str, detector_even_probability in error_model.items():
    key = tuple(np.int32(list(detector_even_str)))
    counts[key] = detector_even_probability

In [16]:
def get_look_up_table_all_detector_purning(counts, syndrome_num=8):
    """根据syndrome测量值和逻辑值的概率分布, 构建look up table.

    Args:
        counts (_type_): syndrome测量值和逻辑值的概率分布

    Returns:
        dict: look up table
    """
    look_up_table = {}
    for key, value in counts.items():
        stablizer = key[:syndrome_num]
        logical = key[syndrome_num:]
        if stablizer not in look_up_table:
            look_up_table[stablizer] = logical[0]
        elif stablizer in look_up_table:
            if value > counts.get(stablizer +(0,), 0):
                # print(key, mz +(0,))
                # print(value, counts[mz +(0,)])
                look_up_table[stablizer] = logical[0]
    return look_up_table

In [17]:
look_up_table = get_look_up_table_all_detector_purning(counts=counts, syndrome_num=detector_error_model_no_decompose_errors.num_detectors)

## 对比lookup table方法和MWPM方法的差别。

In [18]:
import pymatching
detector_error_model_decompose_errors = circuit_noisy.detector_error_model(decompose_errors=True)

matcher = pymatching.Matching.from_detector_error_model(detector_error_model_decompose_errors)

In [19]:
import time

list_time = []
list_error = []
for i in range(10):
    shots = 100000
    detector_number = detector_error_model_no_decompose_errors.num_detectors
    sampler = circuit_noisy.compile_detector_sampler()
    # syndrome, actual_observables = sampler.sample(shots=shots, separate_observables=True)
    results = sampler.sample(shots=shots, append_observables=True)
    syndrome = results[:,:detector_number]
    actual_observables = results[:,detector_number:]
    
    start_time = time.time()
    error = 0
    for i in range(shots):
        detector_event = tuple(np.int32(results[i,:detector_number]))
        obs_flips_actual = int(np.int32(results[i,detector_number:])[0])
        # 如果没有look up table，说明都被剪枝，默认没有错误。
        if look_up_table.get(detector_event, 0) != obs_flips_actual:
            error += 1
    look_up_table_time = time.time() - start_time
    print(error)
    print(f"基于lookup table的解码方法: {error/shots}")
    print("基于lookup table的解码方法运行时间:", look_up_table_time)

    start_time = time.time()
    num_errors = 0
    for i in range(syndrome.shape[0]):
        # 将syndrome输入到matcher中，得到预测的observables，判断输入的和预测是是否相同。
        predicted_observables = matcher.decode(syndrome[i, :])
        num_errors += not np.array_equal(actual_observables[i, :], predicted_observables)
    mwpm_time = time.time() - start_time
    print(num_errors)
    print(f"基于pymatching的解码方法: {num_errors/shots}")
    print("基于pymatching的解码方法运行时间:", mwpm_time)
    
    list_error.append((error/shots, num_errors/shots))
    list_time.append((look_up_table_time, mwpm_time))

18013
基于lookup table的解码方法: 0.18013
基于lookup table的解码方法运行时间: 0.7080121040344238
8374
基于pymatching的解码方法: 0.08374
基于pymatching的解码方法运行时间: 1.7555673122406006
17981
基于lookup table的解码方法: 0.17981
基于lookup table的解码方法运行时间: 0.6848087310791016
8336
基于pymatching的解码方法: 0.08336
基于pymatching的解码方法运行时间: 1.7385199069976807
17831
基于lookup table的解码方法: 0.17831
基于lookup table的解码方法运行时间: 0.6810574531555176
8111
基于pymatching的解码方法: 0.08111
基于pymatching的解码方法运行时间: 1.7518293857574463
17914
基于lookup table的解码方法: 0.17914
基于lookup table的解码方法运行时间: 0.6789834499359131
8070
基于pymatching的解码方法: 0.0807
基于pymatching的解码方法运行时间: 1.7681310176849365
17827
基于lookup table的解码方法: 0.17827
基于lookup table的解码方法运行时间: 0.7029664516448975


KeyboardInterrupt: 

: 