In [71]:
import stim
print(stim.__version__)
import numpy as np
import scipy
from scipy.linalg import kron
from typing import Dict, List
from pprint import pprint
import time
from scipy.sparse import csc_matrix
import operator
import itertools
import random
from functools import reduce
from PyDecoder_polar import PyDecoder_polar_SCL
import pickle
from pathlib import Path

# Code parameters: N = 2**n qubits per ancilla block, target distance d.
n = 7
N = 2 ** n
d = 15
# Encoders below prepare position i in |+> when bin_wt(i) >= wt_thresh, else in |0>.
if d == 7:
    wt_thresh = n - (n-1)//3 # for [[127,1,7]]
elif d == 15:
    wt_thresh = n - (n-1)//2 # for [[127,1,15]]
else:
    # Fail fast: every later cell needs `wt_thresh`, so continuing would only
    # postpone the failure to a confusing NameError.
    raise ValueError(f"unsupported distance {d}; expected 7 or 15")

# Hamming weight of a non-negative integer.
bin_wt = lambda i: bin(i)[2:].count('1')
# Reverse the n-bit binary representation of t.
bit_rev = lambda t: int(bin(t)[2:].rjust(n, '0')[::-1], 2)

# n-bit big-endian bit list <-> integer.
int2bin = lambda i: [int(c) for c in bin(i)[2:].rjust(n, '0')]
bin2int = lambda l: int(''.join(map(str, l)), 2)

def ce(exclude, l=0, u=n): # choose except
    """Return a uniformly random integer from range(l, u) that differs from `exclude`.

    `u` defaults to the value of `n` when this function is defined. An empty
    candidate pool makes random.choice raise IndexError.
    """
    pool = set(range(l, u)).difference({exclude})
    return random.choice(list(pool))

def Eij(i, j, size=None):
    """Return the integer identity matrix with entry (i, j) additionally set to 1.

    For i != j this is the elementary GF(2) row-addition matrix; for i == j the
    result is just the identity.

    Parameters
    ----------
    i, j : int
        Row and column of the extra 1.
    size : int, optional
        Matrix dimension; defaults to the global ``n`` (looked up at call time).
    """
    dim = n if size is None else size
    A = np.eye(dim, dtype=int)
    A[i, j] = 1
    return A

def propagate(
    pauli_string: stim.PauliString,
    circuits: List[stim.Circuit]
) -> stim.PauliString:
    """Push `pauli_string` through every circuit in `circuits`, in list order.

    Each step applies ``stim.PauliString.after``; an empty list returns the input unchanged.
    """
    return reduce(lambda current, circ: current.after(circ), circuits, pauli_string)

def form_pauli_string(
    flipped_pauli_product: List[stim.GateTargetWithCoords],
    num_qubits: int = N,
) -> stim.PauliString:
    """Build a stim.PauliString on `num_qubits` qubits from a list of flipped Pauli targets.

    Each entry sets its 'X', 'Y' or 'Z' on the target qubit (Y sets both the X
    and Z bits). Targets with index >= num_qubits are skipped, and any other
    pauli_type leaves the qubit as identity.
    """
    x_bits = np.zeros(num_qubits, dtype=np.bool_)
    z_bits = np.zeros(num_qubits, dtype=np.bool_)
    for entry in flipped_pauli_product:
        target = entry.gate_target
        qubit = target.value
        if qubit >= num_qubits:
            continue
        pauli = target.pauli_type
        if pauli in ('X', 'Y'):
            x_bits[qubit] = 1
        if pauli in ('Z', 'Y'):
            z_bits[qubit] = 1
    return stim.PauliString.from_numpy(xs=x_bits, zs=z_bits)

1.13.0


In [72]:
def dict_to_csc_matrix(elements_dict, shape):
    """Build a binary `scipy.sparse.csc_matrix` check matrix from a column -> rows mapping.

    Parameters
    ----------
    elements_dict : dict
        Maps each column index to a list of the row indices that hold a 1.
    shape : tuple of int
        (num_rows, num_cols) of the resulting matrix.

    Returns
    -------
    scipy.sparse.csc_matrix
        uint8 matrix with a 1 at every listed (row, column); repeated pairs are summed.
    """
    coords = [(row, col) for col, rows in elements_dict.items() for row in rows]
    row_ind = np.array([row for row, _ in coords], dtype=np.int64)
    col_ind = np.array([col for _, col in coords], dtype=np.int64)
    data = np.ones(len(coords), dtype=np.uint8)
    return csc_matrix((data, (row_ind, col_ind)), shape=shape)

def dem_to_check_matrices(dem: stim.DetectorErrorModel, circuit, num_detector, tick_circuits, flip_type, verbose=False):
    """Turn a detector error model into a fault check matrix plus per-fault metadata.

    Each distinct set of triggered detectors becomes one column ("fault").

    Parameters
    ----------
    dem : stim.DetectorErrorModel
        Error model of `circuit`. Only "error", "detector" and
        "logical_observable" instructions are accepted.
    circuit : stim.Circuit
        Noisy circuit used to explain every DEM error.
    num_detector : int
        Number of rows of the check matrix.
    tick_circuits : list of stim.Circuit
        A fault at tick offset t is propagated through ``tick_circuits[t:]``.
    flip_type : int
        0 keeps the X part of the propagated Pauli (bit flips); 1 keeps the Z part (phase flips).
    verbose : bool
        Print each representative location and its propagated Pauli string.

    Returns
    -------
    check_matrix : scipy.sparse.csc_matrix of shape (num_detector, num_faults)
    priors : np.ndarray of shape (num_faults,)
        Sum of the probabilities of the DEM errors mapped to each column.
    error_dict : dict
        Column index -> representative stim.CircuitErrorLocation.
    residual_error_dict : dict
        Column index -> X or Z support of the propagated representative fault.

    Raises
    ------
    ValueError
        If no circuit location explains some DEM error.
    NotImplementedError
        For any other DEM instruction type.
    """
    explained_errors: List[stim.ExplainedError] = circuit.explain_detector_error_model_errors(
        dem_filter=dem, reduce_to_one_representative_error=False)

    # Index explanations by detector set instead of by list position: stim does not
    # document that this list follows the order of the error instructions in `dem`.
    location_by_dets = {}
    for explained in explained_errors:
        if not explained.circuit_error_locations:
            continue
        explained_dets = frozenset(
            term.dem_target.val
            for term in explained.dem_error_terms
            if term.dem_target.is_relative_detector_id()
        )
        location_by_dets.setdefault(explained_dets, explained.circuit_error_locations[0])

    D_ids: Dict[str, int] = {}  # "D3 D7 ..." detector key -> column index
    priors_dict: Dict[int, float] = {}  # column index -> summed probability
    error_dict = {}  # column index -> representative circuit location
    residual_error_dict = {}  # column index -> propagated X or Z support

    def handle_error(prob: float, detectors: List[int], rep_loc) -> None:
        dets = frozenset(detectors)
        if len(dets) == 0:
            print(rep_loc, "triggers no detector")
        key = " ".join([f"D{s}" for s in sorted(dets)])

        if key not in D_ids:
            D_ids[key] = len(D_ids)
            priors_dict[D_ids[key]] = 0.0

        hid = D_ids[key]
        priors_dict[hid] += prob  # first-order sum; ignores cancellation between errors
        # The last DEM error mapped to a column becomes its representative.
        error_dict[hid] = rep_loc
        # Propagate the fault to the end of the encoder to obtain its residual error.
        final_pauli_string = propagate(form_pauli_string(rep_loc.flipped_pauli_product), tick_circuits[rep_loc.tick_offset:])
        if verbose:
            print(rep_loc)
            print("final pauli string", final_pauli_string, "weight", final_pauli_string.weight)
        # PauliString.to_numpy() returns (xs, zs): index 0 -> X part, index 1 -> Z part.
        residual_error_dict[hid] = final_pauli_string.to_numpy()[flip_type]

    for instruction in dem.flattened():
        if instruction.type == "error":
            p = instruction.args_copy()[0]
            dets: List[int] = [t.val for t in instruction.targets_copy() if t.is_relative_detector_id()]
            rep_loc = location_by_dets.get(frozenset(dets))
            if rep_loc is None:
                raise ValueError(f"no circuit error location explains {instruction}")
            handle_error(p, dets, rep_loc)
        elif instruction.type == "detector":
            pass  # detector declarations carry no error
        elif instruction.type == "logical_observable":
            print("should not have logical observable, instruction", instruction)
        else:
            raise NotImplementedError()

    check_matrix = dict_to_csc_matrix({v: [int(s[1:]) for s in k.split(" ") if s.startswith("D")]
                                       for k, v in D_ids.items()},
                                      shape=(num_detector, len(D_ids)))
    priors = np.zeros(len(D_ids))
    for i, p in priors_dict.items():
        priors[i] = p
    return check_matrix, priors, error_dict, residual_error_dict

def get_pcm(permute, flip_type, verbose=False): # set flip_type to 0 for X-flips, 1 for Z-flips
    """Simulate the noisy |0> encoder for one ancilla and return its fault check matrix.

    Position i < N-1 (physical qubit ``permute[i]``) is prepared in |+> when
    ``bin_wt(i) >= wt_thresh`` and in |0> otherwise, with single-qubit noise.
    Qubit N-1 is reset without noise. Next come n rounds of noisy CNOTs between
    positions of ``permute``; pairs involving position N-1 are skipped. Then n
    rounds of noiseless, un-permuted CNOTs run, and every qubit is measured.
    Detectors go on measurements i with ``bin_wt(i) < wt_thresh`` for bit flips
    (flip_type 0), or ``bin_wt(i) >= wt_thresh`` for phase flips.

    Parameters
    ----------
    permute : sequence of int
        Physical qubit used for each position 0..N-2.
    flip_type : int
        0 for X-flips; any non-zero value for Z-flips.
    verbose : bool
        Forwarded to dem_to_check_matrices.

    Returns
    -------
    pcm : np.ndarray
        Dense (num_detector, num_faults) check matrix.
    error_explain_dict : dict
        Column index -> representative circuit error location.
    residual_error_dict : dict
        Column index -> propagated X (flip_type 0) or Z (flip_type 1) support.
    """
    p_CNOT = 0.001
    p_single = 0.0005
    circuit = stim.Circuit()
    tick_circuits = []  # noiseless copy of each encoding round, for PauliString.after

    # Noisy preparation of positions 0..N-2; qubit N-1 is reset without noise.
    for i in range(N-1):
        if bin_wt(i) >= wt_thresh:
            circuit.append("RX", permute[i])
            circuit.append("Z_ERROR", permute[i], p_single)
        else:
            circuit.append("R", permute[i])
            circuit.append("X_ERROR", permute[i], p_single)
    circuit.append("R", N-1)
    circuit.append("TICK") ############################# TODO: fix bug

    # Noisy encoding rounds; each round is also recorded noise-free in tick_circuits.
    for r in range(n):
        sep = 2 ** r
        tick_circuit = stim.Circuit()
        for j in range(0, N, 2*sep):
            for i in range(sep):
                if j+i+sep < N-1:
                    pair = [permute[j+i+sep], permute[j+i]]
                    circuit.append("CNOT", pair)
                    tick_circuit.append("CNOT", pair)
                    circuit.append("DEPOLARIZE2", pair, p_CNOT)
        circuit.append("TICK")
        tick_circuits.append(tick_circuit)

    # Noiseless, un-permuted CNOT rounds applied before measurement.
    for r in range(n):
        sep = 2 ** r
        for j in range(0, N, 2*sep):
            for i in range(sep):
                circuit.append("CNOT", [j+i+sep, j+i])
        circuit.append("TICK")

    for i in range(N-1):
        circuit.append("MX" if bin_wt(i) >= wt_thresh else "M", i)
    circuit.append("M", N-1)

    # Bit flips watch the bin_wt(i) < wt_thresh measurements; phase flips watch the
    # others. The rule is applied to i = N-1 as well (measured with M above).
    phase_flips = flip_type != 0
    detector_records = [i for i in range(N) if (bin_wt(i) >= wt_thresh) == phase_flips]
    num_detector = len(detector_records)
    circuit += stim.Circuit("".join(f"DETECTOR rec[{-N+i}]\n" for i in detector_records))

    dem = circuit.detector_error_model()
    pcm, _priors, error_explain_dict, residual_error_dict = dem_to_check_matrices(
        dem, circuit, num_detector, tick_circuits, flip_type, verbose=verbose)
    return pcm.toarray(), error_explain_dict, residual_error_dict


def get_plus_pcm(permute, flip_type, verbose=False): # set flip_type to 0 for X-flips, 1 for Z-flips
    """Simulate the noisy |+> encoder for one ancilla and return its fault check matrix.

    Preparation is bit-reversed with respect to get_pcm. For i = 1..N-1, qubit
    ``permute[N-1-i]`` is prepared in |+> when ``bin_wt(i) >= wt_thresh`` and in
    |0> otherwise, with single-qubit noise. Qubit N-1 is reset to |+> without
    noise. The noisy and noiseless CNOT rounds use the reversed CNOT direction.
    Qubits 0..N-2 are measured in order; the j-th measurement gets a detector
    when its basis matches `flip_type` (Z basis for 0, X basis for 1).
    Qubit N-1 is measured in X without a detector.

    Parameters
    ----------
    permute : sequence of int
        Physical qubit used for each position 0..N-2.
    flip_type : int
        0 for X-flips, 1 for Z-flips.
    verbose : bool
        Forwarded to dem_to_check_matrices.

    Returns
    -------
    pcm : np.ndarray
        Dense (num_detector, num_faults) check matrix.
    error_explain_dict : dict
        Column index -> representative circuit error location.
    residual_error_dict : dict
        Column index -> propagated X (flip_type 0) or Z (flip_type 1) support.
    """
    p_CNOT = 0.001
    p_single = 0.0005
    circuit = stim.Circuit()
    tick_circuits = []  # noiseless copy of each encoding round, for PauliString.after

    # |+> initialization, bit-reversed w.r.t |0>.
    for i in range(1, N):
        target = permute[N-1-i]
        if bin_wt(i) >= wt_thresh:
            circuit.append("RX", target)
            circuit.append("Z_ERROR", target, p_single)
        else:
            circuit.append("R", target)
            circuit.append("X_ERROR", target, p_single)
    circuit.append("RX", N-1)
    circuit.append("TICK")

    # Noisy encoding rounds; each round is also recorded noise-free in tick_circuits.
    for r in range(n):
        sep = 2 ** r
        tick_circuit = stim.Circuit()
        for j in range(0, N, 2*sep):
            for i in range(sep):
                if j+i+sep < N-1:
                    pair = [permute[j+i], permute[j+i+sep]]
                    circuit.append("CNOT", pair)
                    tick_circuit.append("CNOT", pair)
                    circuit.append("DEPOLARIZE2", pair, p_CNOT)
        circuit.append("TICK")
        tick_circuits.append(tick_circuit)

    # Noiseless, un-permuted CNOT rounds applied before measurement.
    for r in range(n):
        sep = 2 ** r
        for j in range(0, N, 2*sep):
            for i in range(sep):
                circuit.append("CNOT", [j+i, j+i+sep])
        circuit.append("TICK")

    # i runs N-1 down to 1, so qubit N-1-i ascends 0..N-2 and j is its record index.
    detector_str = ""
    num_detector = 0
    for j, i in enumerate(range(N-1, 0, -1)):
        x_basis = bin_wt(i) >= wt_thresh
        circuit.append("MX" if x_basis else "M", N-1-i)
        if (x_basis and flip_type == 1) or (not x_basis and flip_type == 0):
            detector_str += f"DETECTOR rec[{-N+j}]\n"
            num_detector += 1
    circuit.append("MX", N-1)
    circuit += stim.Circuit(detector_str)

    dem = circuit.detector_error_model()
    pcm, _priors, error_explain_dict, residual_error_dict = dem_to_check_matrices(
        dem, circuit, num_detector, tick_circuits, flip_type, verbose=verbose)
    print("flip type", "Z" if flip_type else "X", " #detectors:", num_detector, " residual error shape", len(residual_error_dict))
    return pcm.toarray(), error_explain_dict, residual_error_dict

In [146]:
# Each ancilla's map is built from products of Eij matrices (Eij is defined in the
# first cell) over the (i, j) lists below.
if d == 7:
    PA = [(1,0),(2,1),(3,2),(4,3),(5,4),(0,3),(1,4)]
    PB = [(2,6),(5,1),(6,0),(0,5),(4,2),(0,3),(1,4)]
    PC = [(3,1),(0,2),(2,6),(6,4),(5,0),(6,5),(3,6)]
    PD = [(5,3),(6,1),(1,2),(2,5),(4,0),(6,5),(3,6)]
elif d == 15:
    PA = [(1,2),(6,0),(4,3),(3,6),(0,1),(2,3),(1,6)]
    # alternative PA, previously noted as working:
    # PA = [(0, 3), (5, 6), (5, 3), (2, 1), (4, 6), (1, 0), (3, 2)]
    PB = [(2,6),(5,1),(6,0),(0,5),(4,2),(0,3),(1,4)]
    PC = [(3,1),(0,2),(2,6),(6,4),(5,0),(6,5),(3,6)]
    PD = [(5,3),(6,1),(1,2),(2,5),(4,0),(3,4),(4,5)]
else:
    PA = []; PB = []; PC = []; PD = []


def list_prod(A):
    """Return Eij(*A[0]) @ Eij(*A[1]) @ ... mod 2 (the identity for an empty A)."""
    # Reducing after every product keeps entries in {0, 1}; the mod-2 result is unchanged.
    return reduce(lambda acc, pair: (acc @ Eij(pair[0], pair[1])) % 2, A, np.eye(n, dtype=int))


# The lists are multiplied in reverse order.
A1 = list_prod(PA[::-1])
A2 = list_prod(PB[::-1])
A3 = list_prod(PC[::-1])
A4 = list_prod(PD[::-1])
# Position i -> N-1 - int(A . bits(N-1-i) mod 2), using big-endian n-bit vectors.
Ax = lambda A, i: N-1-bin2int(A @ np.array(int2bin(N-1-i)) % 2)
a1_permute = [Ax(A1, i) for i in range(N-1)]
a2_permute = [Ax(A2, i) for i in range(N-1)]
a3_permute = [Ax(A3, i) for i in range(N-1)]
a4_permute = [Ax(A4, i) for i in range(N-1)]
# Every listed pair has i != j, so each A is invertible over GF(2) and maps nonzero
# vectors to nonzero vectors: each map must be a permutation of 0..N-2.
for _perm in (a1_permute, a2_permute, a3_permute, a4_permute):
    assert sorted(_perm) == list(range(N-1)), "ancilla map is not a permutation"

def get_inv_dict(pcm):
    """Map each column's detector pattern to the index of the first column with it.

    A column's 0/1 entries are read top-to-bottom as a binary number to form the key.
    A message is printed whenever a later column repeats an earlier pattern.
    """
    inv_dict = {}
    for col_idx, column in enumerate(pcm.T):
        pattern = int(''.join(column.astype('str')), 2)
        if pattern in inv_dict:
            print("two different faults trigger the same set of detectors")
            continue
        inv_dict[pattern] = col_idx
    return inv_dict

####################### Settings ######################
state = '0'
flip_type = 1  # 0 for X-flip, 1 for Z-flip
#######################################################

# Only |0> ('0', get_pcm) and |+> ('+', get_plus_pcm) encoders exist.
if state not in ('0', '+'):
    raise ValueError(f"unsupported state {state!r}; expected '0' or '+'")

build_pcm = get_pcm if state == '0' else get_plus_pcm
a1_pcm, a1_error_explain_dict, a1_residual_error_dict = build_pcm(a1_permute, flip_type)
a2_pcm, a2_error_explain_dict, a2_residual_error_dict = build_pcm(a2_permute, flip_type)
a3_pcm, a3_error_explain_dict, a3_residual_error_dict = build_pcm(a3_permute, flip_type)
a4_pcm, a4_error_explain_dict, a4_residual_error_dict = build_pcm(a4_permute, flip_type)

# Column counts for all four ancillas; later cells also iterate over a1_num_col and a2_num_col.
a1_num_col = a1_pcm.shape[1]
a2_num_col = a2_pcm.shape[1]
a3_num_col = a3_pcm.shape[1]
a4_num_col = a4_pcm.shape[1]
a1_inv_dict = get_inv_dict(a1_pcm)
a2_inv_dict = get_inv_dict(a2_pcm)
a3_inv_dict = get_inv_dict(a3_pcm)
a4_inv_dict = get_inv_dict(a4_pcm)

In [147]:
def test_pairs(a1_permute, a2_permute):
    """Print fault combinations on two ancillas whose detector patterns coincide.

    Both ancillas are built with the encoder for the global `state` (get_pcm for
    '0', get_plus_pcm otherwise) and the global `flip_type`. A combination is
    reported when the faults' detector patterns cancel but the residual error is
    heavier than the fault count allows: 1+1 (> 1), 2+1 and 1+2 (> 2), 2+2 (> 4),
    and 3+1 / 1+3 (> 4 and is_malignant).

    NOTE(review): `is_malignant` (and the `decoder` it uses) is defined in a later
    cell; the 3+1 checks may raise NameError unless that cell has already run.
    NOTE(review): the two-fault tables keep only the first column pair per
    detector pattern, so the 2+2 check does not examine other pairs with the same pattern.
    """
    # Use the same encoder as the top-level settings cell.
    build_pcm = get_pcm if state == '0' else get_plus_pcm
    a1_pcm, a1_error_explain_dict, a1_residual_error_dict = build_pcm(a1_permute, flip_type)
    a2_pcm, a2_error_explain_dict, a2_residual_error_dict = build_pcm(a2_permute, flip_type)
    a1_num_col = a1_pcm.shape[1]
    a2_num_col = a2_pcm.shape[1]
    a1_inv_dict = get_inv_dict(a1_pcm)
    a2_inv_dict = get_inv_dict(a2_pcm)

    def pattern_key(column):
        # Same key as get_inv_dict: the 0/1 column read as a binary number.
        return int(''.join(column.astype('str')), 2)

    def show(label, explain_dict, residual_dict, idx):
        print(f"on ancilla {label},", explain_dict[idx], "final error at", np.where(residual_dict[idx])[0])

    print("test one fault on ancilla 1, one fault on ancilla 2")
    for i in range(a1_num_col):
        key = pattern_key(a1_pcm[:, i])
        if key in a2_inv_dict:  # the two faults' detector patterns cancel
            j = a2_inv_dict[key]
            final_error = a1_residual_error_dict[i]  # residual error on ancilla 1
            if final_error.sum() > 1:
                print("final error weight", final_error.sum())
                print("explained faults:")
                show(1, a1_error_explain_dict, a1_residual_error_dict, i)
                show(2, a2_error_explain_dict, a2_residual_error_dict, j)

    print("test two faults on ancilla 1, one fault on ancilla 2, and create a1 two fault dictionary")
    a1_two_faults_dict = {}  # detector pattern -> residual error of the first a1 pair found
    a1_two_faults_explain_dict = {}  # detector pattern -> (i, j) columns of a1_pcm
    for i in range(a1_num_col):
        for j in range(i+1, a1_num_col):
            key = pattern_key(a1_pcm[:, i] ^ a1_pcm[:, j])
            if key in a2_inv_dict:
                k = a2_inv_dict[key]
                final_error = a1_residual_error_dict[i] ^ a1_residual_error_dict[j]
                if final_error.sum() > 2:
                    print("final error weight", final_error.sum())
                    print("explained faults:")
                    show(1, a1_error_explain_dict, a1_residual_error_dict, i)
                    show(1, a1_error_explain_dict, a1_residual_error_dict, j)
                    show(2, a2_error_explain_dict, a2_residual_error_dict, k)
            if key not in a1_two_faults_dict:
                a1_two_faults_dict[key] = a1_residual_error_dict[i] ^ a1_residual_error_dict[j]
                a1_two_faults_explain_dict[key] = (i, j)

    print("test one fault on ancilla 1, two faults on ancilla 2, and create a2 two fault dictionary")
    a2_two_faults_dict = {}  # detector pattern -> residual error of the first a2 pair found
    a2_two_faults_explain_dict = {}  # detector pattern -> (i, j) columns of a2_pcm
    for i in range(a2_num_col):
        for j in range(i+1, a2_num_col):
            key = pattern_key(a2_pcm[:, i] ^ a2_pcm[:, j])
            if key in a1_inv_dict:
                k = a1_inv_dict[key]
                final_error = a1_residual_error_dict[k]
                if final_error.sum() > 2:
                    print("final error weight", final_error.sum())
                    print("explained faults:")
                    show(1, a1_error_explain_dict, a1_residual_error_dict, k)
                    show(2, a2_error_explain_dict, a2_residual_error_dict, i)
                    show(2, a2_error_explain_dict, a2_residual_error_dict, j)
            if key not in a2_two_faults_dict:
                a2_two_faults_dict[key] = a2_residual_error_dict[i] ^ a2_residual_error_dict[j]
                a2_two_faults_explain_dict[key] = (i, j)

    print("test two fault on ancilla 1, two faults on ancilla 2")
    for k1, v1 in a1_two_faults_dict.items():
        if k1 in a2_two_faults_dict and v1.sum() > 4:
            print("final error weight", v1.sum())
            f1 = a1_two_faults_explain_dict[k1]
            f2 = a2_two_faults_explain_dict[k1]
            print("explained faults:")
            show(1, a1_error_explain_dict, a1_residual_error_dict, f1[0])
            show(1, a1_error_explain_dict, a1_residual_error_dict, f1[1])
            show(2, a2_error_explain_dict, a2_residual_error_dict, f2[0])
            show(2, a2_error_explain_dict, a2_residual_error_dict, f2[1])

    print("test three faults on ancilla 1 and one fault on ancilla 2, and vice versa")
    for i in range(a1_num_col):
        for j in range(a2_num_col):
            key = pattern_key(a1_pcm[:, i] ^ a2_pcm[:, j])
            if key in a1_two_faults_dict:  # fault i plus a pair on ancilla 1, fault j on ancilla 2
                final_error = a2_residual_error_dict[j]
                if final_error.sum() > 4 and is_malignant(final_error, 4):
                    print("3 faults on A1, 1 fault on A2, final error weight", final_error.sum())
            if key in a2_two_faults_dict:  # fault j plus a pair on ancilla 2, fault i on ancilla 1
                final_error = a1_residual_error_dict[i]
                if final_error.sum() > 4 and is_malignant(final_error, 4):
                    print("3 faults on A2, 1 fault on A1, final error weight", final_error.sum())
# Run the pairwise check on every unordered pair of ancillas: a1a2, a1a3, a1a4, a2a3, a2a4, a3a4.
ancilla_permutes = {"a1": a1_permute, "a2": a2_permute, "a3": a3_permute, "a4": a4_permute}
for (left_name, left_perm), (right_name, right_perm) in itertools.combinations(ancilla_permutes.items(), 2):
    print(f"{left_name}, {right_name}")
    test_pairs(left_perm, right_perm)

a1, a2
test one fault on ancilla 1, one fault on ancilla 2
test two faults on ancilla 1, one fault on ancilla 2, and create a1 two fault dictionary
test one fault on ancilla 1, two faults on ancilla 2, and create a2 two fault dictionary
test two fault on ancilla 1, two faults on ancilla 2
test three faults on ancilla 1 and one fault on ancilla 2, and vice versa
a1, a3
test one fault on ancilla 1, one fault on ancilla 2
test two faults on ancilla 1, one fault on ancilla 2, and create a1 two fault dictionary
test one fault on ancilla 1, two faults on ancilla 2, and create a2 two fault dictionary
test two fault on ancilla 1, two faults on ancilla 2
test three faults on ancilla 1 and one fault on ancilla 2, and vice versa
a1, a4
test one fault on ancilla 1, one fault on ancilla 2
test two faults on ancilla 1, one fault on ancilla 2, and create a1 two fault dictionary
test one fault on ancilla 1, two faults on ancilla 2, and create a2 two fault dictionary
test two fault on ancilla 1, two fa

In [148]:
print("test one fault on a1, one fault on a2, one fault on a3 or a4")
# Column counts come from the matrices themselves, so this cell does not rely on
# a1_num_col / a2_num_col having been set by some other cell.
third_ancillas = (
    (3, a3_inv_dict, a3_error_explain_dict, a3_residual_error_dict),
    (4, a4_inv_dict, a4_error_explain_dict, a4_residual_error_dict),
)
for i in range(a1_pcm.shape[1]):
    for j in range(a2_pcm.shape[1]):
        key = int(''.join((a1_pcm[:,i] ^ a2_pcm[:,j]).astype('str')), 2)
        for label, inv_dict, explain_dict, residual_dict in third_ancillas:
            if key not in inv_dict:
                continue
            k = inv_dict[key]
            # combined residual of the a1 and a2 faults
            final_error = a1_residual_error_dict[i] ^ a2_residual_error_dict[j]
            if final_error.sum() > 2:
                print("final error weight", final_error.sum())
                print("explained faults:")
                print("on ancilla 1,", a1_error_explain_dict[i], "final error at", np.where(a1_residual_error_dict[i])[0])
                print("on ancilla 2,", a2_error_explain_dict[j], "final error at", np.where(a2_residual_error_dict[j])[0])
                print(f"on ancilla {label},", explain_dict[k], "final error at", np.where(residual_dict[k])[0])

print("test one fault on a3, one fault on a4, one fault on a1 or a2")
first_pair_ancillas = (
    (1, a1_inv_dict, a1_error_explain_dict, a1_residual_error_dict),
    (2, a2_inv_dict, a2_error_explain_dict, a2_residual_error_dict),
)
for i in range(a3_pcm.shape[1]):
    for j in range(a4_pcm.shape[1]):
        key = int(''.join((a3_pcm[:,i] ^ a4_pcm[:,j]).astype('str')), 2)
        for label, inv_dict, explain_dict, residual_dict in first_pair_ancillas:
            if key not in inv_dict:
                continue
            k = inv_dict[key]
            # residual of the single fault on a1 or a2
            final_error = residual_dict[k]
            if final_error.sum() > 2:
                print("final error weight", final_error.sum())
                print("explained faults:")
                print("on ancilla 3,", a3_error_explain_dict[i], "final error at", np.where(a3_residual_error_dict[i])[0])
                print("on ancilla 4,", a4_error_explain_dict[j], "final error at", np.where(a4_residual_error_dict[j])[0])
                print(f"on ancilla {label},", explain_dict[k], "final error at", np.where(residual_dict[k])[0])


test one fault on a1, one fault on a2, one fault on a3 or a4
test one fault on a3, one fault on a4, one fault on a1 or a2


In [149]:
def _pattern_key(column):
    """Integer key of a 0/1 detector column read as a binary number (as in get_inv_dict)."""
    return int(''.join(column.astype('str')), 2)


def _one_one_table(pcm_x, residual_x, pcm_y, residual_y):
    """Map the pattern of one fault on ancilla x plus one on ancilla y to its combined residual.

    Returns (residual_table, explain_table); explain_table maps pattern -> (i, j).
    Only the first (i, j) found for each pattern is kept.
    """
    residual_table, explain_table = {}, {}
    for i in range(pcm_x.shape[1]):
        for j in range(pcm_y.shape[1]):
            key = _pattern_key(pcm_x[:, i] ^ pcm_y[:, j])
            if key not in residual_table:
                residual_table[key] = residual_x[i] ^ residual_y[j]
                explain_table[key] = (i, j)
    return residual_table, explain_table


def _two_fault_table(pcm, residual):
    """Map the pattern of two distinct faults on one ancilla to their combined residual.

    Returns (residual_table, explain_table); explain_table maps pattern -> (i, j)
    columns of `pcm`. Only the first pair found for each pattern is kept.
    """
    residual_table, explain_table = {}, {}
    num_col = pcm.shape[1]
    for i in range(num_col):
        for j in range(i + 1, num_col):
            key = _pattern_key(pcm[:, i] ^ pcm[:, j])
            if key not in residual_table:
                residual_table[key] = residual[i] ^ residual[j]
                explain_table[key] = (i, j)
    return residual_table, explain_table


def _report(weight, faults):
    """Print a combination's residual weight and each (label, explain, residual, idx) fault."""
    print("final error weight", weight)
    print("explained faults:")
    for label, explain_dict, residual_dict, idx in faults:
        print(f"on ancilla {label},", explain_dict[idx], "final error at", np.where(residual_dict[idx])[0])


print("test one fault on a1, one fault on a2, one fault on a3, one fault on a4")
a1_one_a2_one_dict, a1_one_a2_one_explain_dict = _one_one_table(
    a1_pcm, a1_residual_error_dict, a2_pcm, a2_residual_error_dict)
a3_one_a4_one_dict, a3_one_a4_one_explain_dict = _one_one_table(
    a3_pcm, a3_residual_error_dict, a4_pcm, a4_residual_error_dict)
# The a1/a2 two-fault tables are built here too: test_pairs only creates local copies,
# so on a fresh kernel nothing else defines them at module level.
a1_two_faults_dict, a1_two_faults_explain_dict = _two_fault_table(a1_pcm, a1_residual_error_dict)
a2_two_faults_dict, a2_two_faults_explain_dict = _two_fault_table(a2_pcm, a2_residual_error_dict)
a3_two_faults_dict, a3_two_faults_explain_dict = _two_fault_table(a3_pcm, a3_residual_error_dict)
a4_two_faults_dict, a4_two_faults_explain_dict = _two_fault_table(a4_pcm, a4_residual_error_dict)

for k1, v1 in a1_one_a2_one_dict.items():
    if v1.sum() <= 4:
        continue
    i1, i2 = a1_one_a2_one_explain_dict[k1]
    pair_faults = [
        (1, a1_error_explain_dict, a1_residual_error_dict, i1),
        (2, a2_error_explain_dict, a2_residual_error_dict, i2),
    ]
    if k1 in a3_one_a4_one_dict:
        i3, i4 = a3_one_a4_one_explain_dict[k1]
        _report(v1.sum(), pair_faults + [
            (3, a3_error_explain_dict, a3_residual_error_dict, i3),
            (4, a4_error_explain_dict, a4_residual_error_dict, i4),
        ])
    if k1 in a3_two_faults_dict:
        p, q = a3_two_faults_explain_dict[k1]
        _report(v1.sum(), pair_faults + [
            (3, a3_error_explain_dict, a3_residual_error_dict, p),
            (3, a3_error_explain_dict, a3_residual_error_dict, q),
        ])
    if k1 in a4_two_faults_dict:
        p, q = a4_two_faults_explain_dict[k1]
        _report(v1.sum(), pair_faults + [
            (4, a4_error_explain_dict, a4_residual_error_dict, p),
            (4, a4_error_explain_dict, a4_residual_error_dict, q),
        ])

for k1, v1 in a3_one_a4_one_dict.items():
    if v1.sum() <= 4:
        continue
    i3, i4 = a3_one_a4_one_explain_dict[k1]
    pair_faults = [
        (3, a3_error_explain_dict, a3_residual_error_dict, i3),
        (4, a4_error_explain_dict, a4_residual_error_dict, i4),
    ]
    if k1 in a1_two_faults_dict:
        p, q = a1_two_faults_explain_dict[k1]
        _report(v1.sum(), pair_faults + [
            (1, a1_error_explain_dict, a1_residual_error_dict, p),
            (1, a1_error_explain_dict, a1_residual_error_dict, q),
        ])
    if k1 in a2_two_faults_dict:
        p, q = a2_two_faults_explain_dict[k1]
        _report(v1.sum(), pair_faults + [
            (2, a2_error_explain_dict, a2_residual_error_dict, p),
            (2, a2_error_explain_dict, a2_residual_error_dict, q),
        ])
            

test one fault on a1, one fault on a2, one fault on a3, one fault on a4
final error weight 8
explained faults:
on ancilla 3, CircuitErrorLocation {
    flipped_pauli_product: Z95*Z70
    Circuit location stack trace:
        (after 1 TICKs)
        at instruction #344 (DEPOLARIZE2) in the circuit
        at targets #1 to #2 of the instruction
        resolving to DEPOLARIZE2(0.001) 95 70
} final error at [ 70  78  87  95 102 110 119]
on ancilla 4, CircuitErrorLocation {
    flipped_pauli_product: Z110*Z111
    Circuit location stack trace:
        (after 1 TICKs)
        at instruction #336 (DEPOLARIZE2) in the circuit
        at targets #1 to #2 of the instruction
        resolving to DEPOLARIZE2(0.001) 110 111
} final error at [ 78  79  94  95 110 111 126]
on ancilla 1, CircuitErrorLocation {
    flipped_pauli_product: Z61*Z77
    Circuit location stack trace:
        (after 5 TICKs)
        at instruction #822 (DEPOLARIZE2) in the circuit
        at targets #1 to #2 of the instructi

In [137]:
decoder = PyDecoder_polar_SCL(3)
def is_malignant(s, order):
    """Decide whether residual error `s` is malignant and print a one-line summary.

    The nonzero positions of `s` are passed to the global `decoder`. The error is
    malignant when the returned flip count exceeds `order`, or when
    `decoder.last_info_bit` is 1 while (state, flip_type) is ('+', 1) or ('0', 0).
    Returns a Python bool.
    """
    flip_positions = list(np.nonzero(s)[0])
    num_flip = decoder.decode(flip_positions)
    class_bit = decoder.last_info_bit
    # NOTE(review): presumably last_info_bit == 1 signals a logical flip relevant to
    # this preparation/flip type; confirm against PyDecoder_polar.
    logical_hit = class_bit == 1 and ((state == '+' and flip_type == 1) or (state == '0' and flip_type == 0))
    malignant = bool(num_flip > order or logical_hit)
    print(f"original wt: {s.sum()}, up to stabilizer: {num_flip}, is malignant: {malignant}")
    return malignant

def _one_one_fault_dicts(pcm_a, residual_a, num_col_a, pcm_b, residual_b, num_col_b):
    """Tabulate pairings of one fault from ancilla A with one fault from ancilla B.

    For every column pair (i, j) the combined syndrome ``pcm_a[:, i] ^ pcm_b[:, j]``
    is packed into an integer key by reading its entries as a binary string
    (assumes the check matrices hold integer 0/1 entries — a bool dtype would
    stringify as 'True'/'False'). Only the first (i, j) that produces a given key is
    kept, iterating i in the outer loop and j in the inner loop.

    NOTE(review): later pairs with the same key but a different residual are dropped,
    so this enumerates one representative per syndrome, not every combination.

    Returns:
        (errors, explain): ``errors[key]`` is the combined residual error
        ``residual_a[i] ^ residual_b[j]``; ``explain[key]`` is the pair ``(i, j)``.
    """
    errors = {}
    explain = {}
    for i in range(num_col_a):
        for j in range(num_col_b):
            syndrome = pcm_a[:, i] ^ pcm_b[:, j]
            key = int(''.join(syndrome.astype('str')), 2)
            if key not in errors:
                # Both faults leave a residual error, so the pair's residual is the
                # XOR of the two individual residuals.
                errors[key] = residual_a[i] ^ residual_b[j]
                explain[key] = (i, j)
    return errors, explain

a1_one_a3_one_dict, a1_one_a3_one_explain_dict = _one_one_fault_dicts(
    a1_pcm, a1_residual_error_dict, a1_num_col,
    a3_pcm, a3_residual_error_dict, a3_num_col)

a1_one_a4_one_dict, a1_one_a4_one_explain_dict = _one_one_fault_dicts(
    a1_pcm, a1_residual_error_dict, a1_num_col,
    a4_pcm, a4_residual_error_dict, a4_num_col)

a2_one_a3_one_dict, a2_one_a3_one_explain_dict = _one_one_fault_dicts(
    a2_pcm, a2_residual_error_dict, a2_num_col,
    a3_pcm, a3_residual_error_dict, a3_num_col)

a2_one_a4_one_dict, a2_one_a4_one_explain_dict = _one_one_fault_dicts(
    a2_pcm, a2_residual_error_dict, a2_num_col,
    a4_pcm, a4_residual_error_dict, a4_num_col)

print("test 1123, 1124")
# Two faults on ancilla 1 combined with one fault on ancilla 2 and one on ancilla 3
# (then ancilla 4) sharing the same syndrome key. Report combinations whose
# combined residual error has weight > 4 (the number of faults) and that
# is_malignant flags; is_malignant is only called once the weight test passes.
for syndrome, a1_pair_error in a1_two_faults_dict.items():
    if syndrome in a2_one_a3_one_dict:
        combined = a1_pair_error ^ a2_one_a3_one_dict[syndrome]
        if combined.sum() > 4 and is_malignant(combined, 4):
            print("final error weight", combined.sum())
            first_pair = a1_two_faults_explain_dict[syndrome]
            second_pair = a2_one_a3_one_explain_dict[syndrome]
            print("explained faults:")
            for prefix, explain, residual, idx in (
                ("on ancilla 1,", a1_error_explain_dict, a1_residual_error_dict, first_pair[0]),
                ("on ancilla 1,", a1_error_explain_dict, a1_residual_error_dict, first_pair[1]),
                ("on ancilla 2,", a2_error_explain_dict, a2_residual_error_dict, second_pair[0]),
                ("on ancilla 3,", a3_error_explain_dict, a3_residual_error_dict, second_pair[1]),
            ):
                print(prefix, explain[idx], "final error at", np.where(residual[idx])[0])
    if syndrome in a2_one_a4_one_dict:
        combined = a1_pair_error ^ a2_one_a4_one_dict[syndrome]
        if combined.sum() > 4 and is_malignant(combined, 4):
            print("final error weight", combined.sum())
            first_pair = a1_two_faults_explain_dict[syndrome]
            second_pair = a2_one_a4_one_explain_dict[syndrome]
            print("explained faults:")
            for prefix, explain, residual, idx in (
                ("on ancilla 1,", a1_error_explain_dict, a1_residual_error_dict, first_pair[0]),
                ("on ancilla 1,", a1_error_explain_dict, a1_residual_error_dict, first_pair[1]),
                ("on ancilla 2,", a2_error_explain_dict, a2_residual_error_dict, second_pair[0]),
                ("on ancilla 4,", a4_error_explain_dict, a4_residual_error_dict, second_pair[1]),
            ):
                print(prefix, explain[idx], "final error at", np.where(residual[idx])[0])

print("test 2214, 2213")
# Two faults on ancilla 2 combined with one fault on ancilla 1 and one on ancilla 4
# (then ancilla 3) sharing the same syndrome key. Report combinations whose
# combined residual error has weight > 4 (the number of faults) and that
# is_malignant flags; is_malignant is only called once the weight test passes.
for syndrome, a2_pair_error in a2_two_faults_dict.items():
    if syndrome in a1_one_a4_one_dict:
        combined = a2_pair_error ^ a1_one_a4_one_dict[syndrome]
        if combined.sum() > 4 and is_malignant(combined, 4):
            print("final error weight", combined.sum())
            first_pair = a2_two_faults_explain_dict[syndrome]
            second_pair = a1_one_a4_one_explain_dict[syndrome]
            print("explained faults:")
            for prefix, explain, residual, idx in (
                ("on ancilla 2,", a2_error_explain_dict, a2_residual_error_dict, first_pair[0]),
                ("on ancilla 2,", a2_error_explain_dict, a2_residual_error_dict, first_pair[1]),
                ("on ancilla 1,", a1_error_explain_dict, a1_residual_error_dict, second_pair[0]),
                ("on ancilla 4,", a4_error_explain_dict, a4_residual_error_dict, second_pair[1]),
            ):
                print(prefix, explain[idx], "final error at", np.where(residual[idx])[0])
    if syndrome in a1_one_a3_one_dict:
        combined = a2_pair_error ^ a1_one_a3_one_dict[syndrome]
        if combined.sum() > 4 and is_malignant(combined, 4):
            print("final error weight", combined.sum())
            first_pair = a2_two_faults_explain_dict[syndrome]
            second_pair = a1_one_a3_one_explain_dict[syndrome]
            print("explained faults:")
            for prefix, explain, residual, idx in (
                ("on ancilla 2,", a2_error_explain_dict, a2_residual_error_dict, first_pair[0]),
                ("on ancilla 2,", a2_error_explain_dict, a2_residual_error_dict, first_pair[1]),
                ("on ancilla 1,", a1_error_explain_dict, a1_residual_error_dict, second_pair[0]),
                ("on ancilla 3,", a3_error_explain_dict, a3_residual_error_dict, second_pair[1]),
            ):
                print(prefix, explain[idx], "final error at", np.where(residual[idx])[0])

print("test 3314, 3324")
# Two faults on ancilla 3 combined with one fault on ancilla 1 (then ancilla 2) and
# one fault on ancilla 4 sharing the same syndrome key. The combined residual error
# must include the residual of the two ancilla-3 faults (v1), just as the
# "1123"/"2214" tests XOR v1 with the paired dictionary's error.
for k1, v1 in a3_two_faults_dict.items():
    if k1 in a1_one_a4_one_dict:
        final_error = v1 ^ a1_one_a4_one_dict[k1]
        # 4 = number of faults in this combination; only heavier residuals are
        # handed to is_malignant.
        if final_error.sum() > 4 and is_malignant(final_error, 4):
            print("final error weight", final_error.sum())
            f1 = a3_two_faults_explain_dict[k1]
            f2 = a1_one_a4_one_explain_dict[k1]
            print("explained faults:")
            print("on ancilla 3,", a3_error_explain_dict[f1[0]], "final error at", np.where(a3_residual_error_dict[f1[0]])[0])
            print("on ancilla 3,", a3_error_explain_dict[f1[1]], "final error at", np.where(a3_residual_error_dict[f1[1]])[0])
            print("on ancilla 1,", a1_error_explain_dict[f2[0]], "final error at", np.where(a1_residual_error_dict[f2[0]])[0])
            print("on ancilla 4,", a4_error_explain_dict[f2[1]], "final error at", np.where(a4_residual_error_dict[f2[1]])[0])
    if k1 in a2_one_a4_one_dict:
        final_error = v1 ^ a2_one_a4_one_dict[k1]
        if final_error.sum() > 4 and is_malignant(final_error, 4):
            print("final error weight", final_error.sum())
            f1 = a3_two_faults_explain_dict[k1]
            f2 = a2_one_a4_one_explain_dict[k1]
            print("explained faults:")
            print("on ancilla 3,", a3_error_explain_dict[f1[0]], "final error at", np.where(a3_residual_error_dict[f1[0]])[0])
            print("on ancilla 3,", a3_error_explain_dict[f1[1]], "final error at", np.where(a3_residual_error_dict[f1[1]])[0])
            print("on ancilla 2,", a2_error_explain_dict[f2[0]], "final error at", np.where(a2_residual_error_dict[f2[0]])[0])
            print("on ancilla 4,", a4_error_explain_dict[f2[1]], "final error at", np.where(a4_residual_error_dict[f2[1]])[0])

            
print("test 4413, 4423")
# Two faults on ancilla 4 combined with one fault on ancilla 1 (then ancilla 2) and
# one fault on ancilla 3 sharing the same syndrome key. The combined residual error
# must include the residual of the two ancilla-4 faults (v1), just as the
# "1123"/"2214" tests XOR v1 with the paired dictionary's error.
for k1, v1 in a4_two_faults_dict.items():
    if k1 in a1_one_a3_one_dict:
        final_error = v1 ^ a1_one_a3_one_dict[k1]
        # 4 = number of faults in this combination; only heavier residuals are
        # handed to is_malignant.
        if final_error.sum() > 4 and is_malignant(final_error, 4):
            print("final error weight", final_error.sum())
            f1 = a4_two_faults_explain_dict[k1]
            f2 = a1_one_a3_one_explain_dict[k1]
            print("explained faults:")
            print("on ancilla 4,", a4_error_explain_dict[f1[0]], "final error at", np.where(a4_residual_error_dict[f1[0]])[0])
            print("on ancilla 4,", a4_error_explain_dict[f1[1]], "final error at", np.where(a4_residual_error_dict[f1[1]])[0])
            print("on ancilla 1,", a1_error_explain_dict[f2[0]], "final error at", np.where(a1_residual_error_dict[f2[0]])[0])
            print("on ancilla 3,", a3_error_explain_dict[f2[1]], "final error at", np.where(a3_residual_error_dict[f2[1]])[0])
    if k1 in a2_one_a3_one_dict:
        final_error = v1 ^ a2_one_a3_one_dict[k1]
        if final_error.sum() > 4 and is_malignant(final_error, 4):
            print("final error weight", final_error.sum())
            f1 = a4_two_faults_explain_dict[k1]
            f2 = a2_one_a3_one_explain_dict[k1]
            print("explained faults:")
            print("on ancilla 4,", a4_error_explain_dict[f1[0]], "final error at", np.where(a4_residual_error_dict[f1[0]])[0])
            print("on ancilla 4,", a4_error_explain_dict[f1[1]], "final error at", np.where(a4_residual_error_dict[f1[1]])[0])
            print("on ancilla 2,", a2_error_explain_dict[f2[0]], "final error at", np.where(a2_residual_error_dict[f2[0]])[0])
            print("on ancilla 3,", a3_error_explain_dict[f2[1]], "final error at", np.where(a3_residual_error_dict[f2[1]])[0])


test 1123, 1124
test 2214, 2213
test 3314, 3324
test 4413, 4423
