In [2]:
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
from tqdm import tqdm

from qec_generator import CircuitParams
from simulate_qec_rounds_stim import experiment_run
from stim_lib.scheduled_circuit import generate_scheduled
from scipy.optimize import curve_fit
import stim

In [236]:
# First flagged extraction round ("xzz", run by meas_flagging_syndromes_xzz).
# Ancilla 8 gets H and then controls CNOTs onto data qubits 4, 1, 2, 3, i.e. an
# X-type check on {1, 2, 3, 4}; ancillas 9 and 10 are CNOT targets of data
# qubits {2, 3, 5, 6} and {3, 4, 6, 7}, i.e. Z-type checks. CNOT 8->9 and
# CNOT 8->10 are the flag couplings. `MR 8 9 10` appends three results in the
# order (X check 0, Z check 1, Z check 2) and resets the ancillas.
# NOTE(review): the `X 8` line is a deliberately injected error that fires on
# every execution of this circuit (including inside qec_cycle); remove it for
# fault-free runs.
first_flag_circ = stim.Circuit('''


    H 8
    CNOT 8 4
    CNOT 3 10
    CNOT 6 9
    CNOT 8 9  # 1st flag
    CNOT 8 1
    CNOT 4 10
    CNOT 5 9
        # inject error
    X 8
    
    CNOT 8 2
    CNOT 7 10
    CNOT 3 9
    CNOT 8 10  # second flag
    CNOT 8 3
    CNOT 6 10
    CNOT 2 9
    H 8
    MR 8 9 10
''')

# Second (complementary) flagged extraction round ("zxx", run by
# meas_flagging_syndromes_zxx). Ancilla 8 is a CNOT target of data qubits
# 4, 1, 2, 3 (Z-type check on {1, 2, 3, 4}); ancillas 9 and 10 get H and control
# CNOTs onto {2, 3, 5, 6} and {3, 4, 6, 7} (X-type checks). CNOT 9->8 and
# CNOT 10->8 are the flag couplings. `MR 8 9 10` appends three results in the
# order (Z check 0, X check 1, X check 2).
# NOTE(review): the `Z 8` line is a deliberately injected error that fires on
# every execution of this circuit; remove it for fault-free runs.
second_flag_circ = stim.Circuit('''
    H 9 10
    CNOT 4 8
    CNOT 10 3
    CNOT 9 6
    CNOT 9 8  # 1st flag
    CNOT 1 8
    CNOT 10 4
    CNOT 9 5
            # inject error
    Z 8
    CNOT 2 8
    CNOT 10 7
    CNOT 9 3
    CNOT 10 8  # 2nd flag
    CNOT 3 8
    CNOT 10 6
    CNOT 9 2
    H 9 10
    MR 8 9 10
''')

# Unflagged measurement of all six checks (run by meas_six_syndromes).
# First half: ancillas 8, 9, 10 get H and control CNOTs onto data supports
# {1,2,3,4}, {2,3,5,6}, {3,4,6,7} -> three X-type results.
# Second half: the same supports control CNOTs onto the ancillas -> three
# Z-type results. Six record entries in total: X checks first, then Z checks.
unflagged = stim.Circuit('''

    
    H 8 9 10
    CNOT 8 1 8 2 8 3 8 4
    CNOT 9 2 9 3 9 5 9 6
    CNOT 10 3 10 4 10 6 10 7
    H 8 9 10
    MR 8 9 10
    CNOT 1 8 2 8 3 8 4 8
    CNOT 2 9 3 9 5 9 6 9
    CNOT 3 10 4 10 6 10 7 10
    MR 8 9 10
''')

# State preparation (used by prep_zero). Resets qubits 1-10, applies H to data
# qubits 1, 5, 7 and a CNOT network on data qubits 1-7 -- presumably preparing
# the encoded logical |0>; TODO confirm. It then measures the parity of Z on
# qubits 2, 4, 6 into ancilla 8 (CNOT 2/4/6 -> 8, MR 8). Given the check supports
# used elsewhere in this notebook, Z2Z4Z6 equals the Z5Z6Z7 readout operator
# up to Z stabilizers. Adds exactly one record entry.
encoding = stim.Circuit('''
    R 1 2 3 4 5 6 7 8 9 10
    H 1 5 7
    CNOT 1 2 5 6
    CNOT 7 4
    CNOT 7 6 5 3
    CNOT 1 4
    CNOT 5 2
    CNOT 4 3
    CNOT 2 8
    CNOT 4 8
    CNOT 6 8
    MR 8
''')

# Destructive Z-basis readout of data qubits 1-7 (seven record entries,
# qubit 1 first), consumed by meas_z_data.
measure = stim.Circuit('''
    M 1 2 3 4 5 6 7
''')

In [237]:
# Sanity check: encode, run one second-flag round (which contains an injected
# Z error on ancilla 8), then one unflagged round, and print the record slices.
# rec[0] is the encoding's MR 8, rec[1:4] the flagged round, rec[4:] the six
# unflagged results.
sim = stim.TableauSimulator()
for circuit in (encoding, second_flag_circ, unflagged):
    sim.do(circuit)
rec = np.array(sim.current_measurement_record()).astype(int)
print('flagged', rec[1:4])
print('unflagged', rec[4:])

flagged [0 1 1]
unflagged [0 0 1 0 0 0]


In [238]:
def prep_zero():
    """Return a fresh stim tableau simulator after running the `encoding` circuit."""
    simulator = stim.TableauSimulator()
    simulator.do(encoding)
    return simulator


def meas_flagging_syndromes_xzz(state):
    """Run `first_flag_circ` on `state` and return its three fresh results.

    Returns:
        Tuple of three np.uint8 bits in `MR 8 9 10` order, i.e. (X check 0,
        Z check 1, Z check 2) per the circuit definition.
    """
    state.do(first_flag_circ)
    record = np.array(state.current_measurement_record())
    return tuple(record[-3:].astype(np.uint8))


def meas_flagging_syndromes_zxx(state):
    """Run `second_flag_circ` on `state` and return its three fresh results.

    Returns:
        Tuple of three np.uint8 bits in `MR 8 9 10` order, i.e. (Z check 0,
        X check 1, X check 2) per the circuit definition.
    """
    state.do(second_flag_circ)
    newest = np.array(state.current_measurement_record())[-3:]
    return tuple(newest.astype(np.uint8))


def meas_six_syndromes(state):
    """Run the `unflagged` circuit and split its six results by check type.

    Returns:
        (x_syndromes, z_syndromes): two length-3 np.uint8 arrays. The circuit
        measures the X-type checks first, then the Z-type checks.
    """
    state.do(unflagged)
    last_six = np.array(state.current_measurement_record())[-6:].astype(np.uint8)
    x_syndromes = last_six[-6:-3]
    z_syndromes = last_six[-3:]
    return x_syndromes, z_syndromes


def meas_z_data(state):
    """Run the `measure` circuit and return the seven data-qubit bits as np.uint8.

    Index k of the result corresponds to data qubit k + 1.
    """
    state.do(measure)
    record = np.array(state.current_measurement_record())
    return record[-7:].astype(np.uint8)

In [239]:
# Note: only the logical |0> state is tested here.

def qec_cycles_exp(num_cycles, shots):
    """Count how often an encoded |0> is read out correctly after QEC cycles.

    Each shot prepares a fresh state with `prep_zero`, runs `num_cycles` calls
    to `qec_cycle` (which decode syndromes into the Pauli-frame array `pf`
    without touching the data qubits), measures the data, and applies the
    accumulated frame to the logical outcome.

    Args:
        num_cycles: number of QEC cycles per shot.
        shots: number of independent shots.

    Returns:
        (runs, success): the number of shots executed and the number of shots
        for which `expected_result` reported the prepared value (0).
    """
    runs = 0
    success = 0
    for _ in range(shots):
        state = prep_zero()
        last_syndromes_x = np.zeros(3, dtype=np.uint8)
        last_syndromes_z = np.zeros(3, dtype=np.uint8)
        # pf[0] collects decisions from X-check syndromes (Z-type corrections),
        # pf[1] from Z-check syndromes (X-type corrections).
        pf = np.zeros(2, dtype=np.uint8)

        for _ in range(num_cycles):
            state, last_syndromes_x, last_syndromes_z, pf = qec_cycle(
                state, last_syndromes_x, last_syndromes_z, pf)

        raw_outcome = logical_meas(state, last_syndromes_z)
        # Corrections decided during the cycles live only in the frame, so the
        # Z-basis readout must be flipped by the X-type frame bit. Previously
        # pf was discarded here, silently undoing every decoder decision.
        meas_output = (int(raw_outcome) + int(pf[1])) % 2
        # Count the shot as a success if the corrected outcome matches |0>.
        success += expected_result(meas_output)
        runs += 1

    return runs, success


def qec_cycle(state, last_syndrome_x, last_syndrome_z, pf):
    """Run one flagged QEC cycle on `state` and update the Pauli frame.

    The first flagged round (X check 0, Z checks 1 and 2) always runs. The
    complementary round (Z check 0, X checks 1 and 2) runs only if the first
    round reproduced the last known syndrome. If either round shows a change,
    all six checks are re-measured without flags, and the syndrome change is
    decoded with `decoder_2d` plus the flag-conditioned `decoder_flag_update`.
    The last-syndrome arrays are then replaced by the fresh measurement.

    Args:
        state: simulator handed to the `meas_*` helpers.
        last_syndrome_x: length-3 X-check syndrome from the last full extraction.
        last_syndrome_z: length-3 Z-check syndrome from the last full extraction.
        pf: length-2 Pauli-frame bit array, updated in place. pf[0] collects
            decisions from X-check syndromes, pf[1] from Z-check syndromes.

    Returns:
        (state, last_syndrome_x, last_syndrome_z, pf)
    """
    flag_diff_x = np.zeros(3, dtype=np.uint8)
    flag_diff_z = np.zeros(3, dtype=np.uint8)

    # Round 1: ancilla 8 -> X check 0; ancillas 9, 10 -> Z checks 1, 2.
    fx0, fz1, fz2 = meas_flagging_syndromes_xzz(state)
    flag_diff_x[0] = fx0 ^ last_syndrome_x[0]
    flag_diff_z[1] = fz1 ^ last_syndrome_z[1]
    flag_diff_z[2] = fz2 ^ last_syndrome_z[2]

    # Round 2 is only needed when round 1 saw nothing new.
    if not flag_diff_x.any() and not flag_diff_z.any():
        fz0, fx1, fx2 = meas_flagging_syndromes_zxx(state)
        flag_diff_z[0] = fz0 ^ last_syndrome_z[0]
        flag_diff_x[1] = fx1 ^ last_syndrome_x[1]
        flag_diff_x[2] = fx2 ^ last_syndrome_x[2]

    if flag_diff_x.any() or flag_diff_z.any():
        syndromes_x, syndromes_z = meas_six_syndromes(state)
        syndrome_diff_x = (syndromes_x + last_syndrome_x) % 2
        syndrome_diff_z = (syndromes_z + last_syndrome_z) % 2
        # Reduce mod 2 so frame entries stay parity bits instead of
        # accumulating counts (and eventually overflowing uint8).
        pf[0] = (pf[0] + decoder_2d(syndrome_diff_x)
                 + decoder_flag_update(syndrome_diff_x, flag_diff_x)) % 2
        pf[1] = (pf[1] + decoder_2d(syndrome_diff_z)
                 + decoder_flag_update(syndrome_diff_z, flag_diff_z)) % 2

        last_syndrome_x = syndromes_x
        last_syndrome_z = syndromes_z

    return state, last_syndrome_x, last_syndrome_z, pf


def logical_meas(state, last_syndrome_z):
    """Measure the data qubits in Z and return the decoded logical bit.

    The raw logical value is the parity of data qubits 5, 6, 7. The Z-check
    syndrome recomputed from the seven data bits is compared with
    `last_syndrome_z`, and `decoder_2d` decides whether the residual change
    flips that parity.
    """
    data_bits = meas_z_data(state)
    raw_logical = data_bits[4:7].sum() % 2

    # Rows are the Z-check supports {1,2,3,4}, {2,3,5,6}, {3,4,6,7}.
    z_checks = np.array([
        [1, 1, 1, 1, 0, 0, 0],
        [0, 1, 1, 0, 1, 1, 0],
        [0, 0, 1, 1, 0, 1, 1],
    ])
    final_syndrome = z_checks.dot(data_bits) % 2

    correction = decoder_2d((final_syndrome + last_syndrome_z) % 2)
    return (raw_logical + correction) % 2


def expected_result(meas_output):
    """Return 1 when `meas_output` equals the prepared |0> value (0), else 0."""
    flipped = meas_output + 1
    return flipped % 2



def decoder_2d(syndrome_diff):
    """Lookup decoder: does the syndrome change flip the Z5Z6Z7 readout?

    With check supports {1,2,3,4}, {2,3,5,6}, {3,4,6,7}, the syndromes
    [0,1,0], [0,1,1] and [0,0,1] are those of a single error on data qubit 5,
    6 or 7 respectively. The corresponding correction therefore flips the
    logical parity. Returns 1 for those syndromes, 0 otherwise.
    """
    logical_flipping = ((0, 1, 0), (0, 1, 1), (0, 0, 1))
    return 1 if tuple(syndrome_diff.tolist()) in logical_flipping else 0
    
    
def decoder_flag_update(syndrome_diff, flag_diff):
    """Extra logical flip implied by a flag pattern plus syndrome change.

    Returns 1 for exactly these (flag_diff, syndrome_diff) combinations:
    ([1,0,0], [0,1,0]), ([1,0,0], [0,0,1]) and ([0,1,1], [0,0,1]); otherwise 0.
    """
    triggers = (
        ((1, 0, 0), (0, 1, 0)),
        ((1, 0, 0), (0, 0, 1)),
        ((0, 1, 1), (0, 0, 1)),
    )
    hit = any(
        np.all(flag_diff == list(flag_pattern)) and np.all(syndrome_diff == list(syndrome_pattern))
        for flag_pattern, syndrome_pattern in triggers
    )
    return 1 if hit else 0
    

In [241]:
# One shot of ten QEC cycles; displays (runs, success).
qec_cycles_exp(num_cycles=10, shots=1)

(0, 1, 1)
(0, 1, 0)
(0, 1, 1)
(0, 1, 0)
(0, 1, 1)
(0, 1, 0)
(0, 1, 1)
(0, 1, 0)
(0, 1, 1)
(0, 1, 0)


(1, 1.0)

In [204]:
# Scratch: iterating a 1-D array yields its elements as numpy scalars.
uu1, uu2 = np.array([1, 2])

In [205]:
# Scratch check: display the first unpacked element.
uu1

1