In [16]:
import stim
import numpy as np
from numpy.linalg import matrix_power, matrix_rank
import matplotlib.pyplot as plt
from mec import make_circle
import galois
from scipy.sparse import lil_matrix
from ldpc import bp_decoder, bposd_decoder
from tqdm import tqdm

In [4]:
# BB code specification (indices as consumed by the cells below):
#   [ell, m, A exponents (x, y, y), B exponents (y, x, x), emb_m, emb_ell, A_ind, B_ind]
code = [
    12, 3,     # ell, m: sizes of the two cyclic shift groups
    9, 1, 2,   # A = x^9 + y^1 + y^2
    0, 7, 2,   # B = y^0 + x^7 + x^2
    3, 12,     # emb_m, emb_ell: lattice half-dimensions passed to embed_code
    2, 1,      # A_ind, B_ind: term-pair choices passed to embed_code
]

In [5]:
def cyclic_shift_matrix(l):
    """Return the l x l integer cyclic-shift permutation matrix.

    Row i has a single 1 in column (i + 1) % l, i.e. the identity with its
    columns rolled one step to the right.
    """
    shift = np.zeros((l, l), dtype=int)
    rows = np.arange(l)
    shift[rows, (rows + 1) % l] = 1
    return shift

ell = code[0]
m = code[1]

# Commuting generators of the group algebra: x = S_ell (x) I_m, y = I_ell (x) S_m.
x = np.kron(cyclic_shift_matrix(ell), np.eye(m))
y = np.kron(np.eye(ell), cyclic_shift_matrix(m))

# A = x^a + y^b + y^c
A1 = matrix_power(x, code[2])
A2 = matrix_power(y, code[3])
A3 = matrix_power(y, code[4])
A = ( A1 + A2 + A3 ) % 2

# B = y^d + x^e + x^f
B1 = matrix_power(y, code[5])
B2 = matrix_power(x, code[6])
B3 = matrix_power(x, code[7])
B = ( B1 + B2 + B3 ) % 2

Hx = np.hstack([A, B]).astype(int)
Hz = np.hstack([B.T, A.T]).astype(int)

# CSS condition: every X check commutes with every Z check
# (Hx Hz^T = AB + BA = 0 mod 2 because A and B commute).
assert not np.any((Hx @ Hz.T) % 2), "Hx and Hz do not commute over GF(2)"

GF = galois.GF(2)
arr = GF(Hz.T)
# Number of logical qubits: k = n - rank(Hx) - rank(Hz) over GF(2).
# For BB codes rank(Hx) == rank(Hz), so this equals 2 * (m*ell - rank(Hz)).
k = Hx.shape[1] - matrix_rank(GF(Hx)) - matrix_rank(arr)

In [6]:
def par2gen(H):
    """Compute a GF(2) generator matrix for the null space of the parity-check matrix H.

    The RREF of H has its columns permuted so that it takes the form [I | P].
    The rows of [P^T | I] then span the null space, and the permutation is undone.

    Args:
        H: 0/1 integer matrix.

    Returns:
        (G, col_G) as int numpy arrays. G's rows span ker(H) in H's original
        column order. col_G is the same basis in the permuted column order.

    Raises:
        ValueError: if the computed basis is not orthogonal to H over GF(2).
    """
    GF = galois.GF(2)
    gfH = GF(H)
    gfH_rank = np.linalg.matrix_rank(gfH)

    rref_H = gfH.row_reduce()

    # Move each RREF pivot onto the diagonal. Pivot columns increase with the
    # row index, so these swaps never disturb earlier rows' pivots.
    swaps = []
    col_H = rref_H.copy()
    for i in range(gfH_rank):
        pivot = np.where(col_H[i])[0][0]
        col_H[:, [i, pivot]] = col_H[:, [pivot, i]]
        swaps.append((i, pivot))

    col_H = col_H[:gfH_rank]
    # For col_H = [I | P], the rows of [P^T | I] are orthogonal to every row.
    col_G = GF(np.hstack([col_H[:, gfH_rank:].T, np.eye(H.shape[1] - gfH_rank, dtype=int)]))

    # Undo the column swaps in reverse order to restore H's column ordering.
    G = col_G.copy()
    for a, b in reversed(swaps):
        G[:, [b, a]] = G[:, [a, b]]

    if np.any(G @ rref_H[:gfH_rank].T) or np.any(col_G @ col_H.T):
        # Fail loudly here instead of returning None, which made callers
        # crash later with an opaque unpacking TypeError.
        raise ValueError("par2gen: generator matrix is not orthogonal to H over GF(2)")
    return (np.array(G, dtype=int), np.array(col_G, dtype=int))

def commute(x, z, n):
    """Symplectic inner product of two length-2n Pauli vectors [X-part | Z-part].

    Returns 0 when the operators commute and 1 when they anticommute.
    """
    x_of_first, z_of_first = x[:n], x[n:]
    x_of_second, z_of_second = z[:n], z[n:]
    overlap_a = (x_of_first @ z_of_second) % 2
    overlap_b = (z_of_first @ x_of_second) % 2
    return overlap_a ^ overlap_b


def SGSOP(Gx, Gz, n):
    """Symplectic Gram-Schmidt orthogonalization procedure.

    Rows of Gx are embedded as [g | 0] and rows of Gz as [0 | g] in symplectic
    form. The first remaining vector is repeatedly paired with the first later
    vector it anticommutes with. The pair is removed and every remaining vector
    is made to commute with both members. A vector with no anticommuting partner
    is recorded as a generator.

    Args:
        Gx, Gz: integer 0/1 matrices with n columns.
        n: number of qubits (half the symplectic vector length).

    Returns:
        (logicals, generators). logicals is a list of (g1, g2) anticommuting
        pairs. generators is a list of vectors that commuted with every vector
        still remaining when they were processed.
    """
    sym_G = np.vstack([
        np.hstack([Gx, np.zeros(Gx.shape, dtype=int)]),
        np.hstack([np.zeros(Gz.shape, dtype=int), Gz]),
    ])
    logicals = []
    generators = []

    while sym_G.shape[0]:
        g1 = sym_G[0]

        # Search every later row, including the last one. Stopping one short
        # would leave the final anticommuting pair unmatched and lose a logical.
        partner = next(
            (i for i in range(1, sym_G.shape[0]) if commute(g1, sym_G[i], n)),
            None,
        )

        if partner is None:
            generators.append(g1)
            sym_G = np.delete(sym_G, 0, axis=0)
            continue

        g2 = sym_G[partner]
        logicals.append((g1, g2))
        # np.delete returns a new array, so g1/g2 (views of the old one) stay intact.
        sym_G = np.delete(sym_G, [0, partner], axis=0)

        # h <- h + <h, g2> g1 + <h, g1> g2 makes h commute with both g1 and g2.
        for j in range(sym_G.shape[0]):
            gj = sym_G[j]
            sym_G[j] = gj ^ (commute(gj, g2, n) * g1) ^ (commute(gj, g1, n) * g2)

    return (logicals, generators)

In [7]:
def get_logicals(gen_type=False):
    """Return logical-operator supports for the global code (Hx, Hz).

    Each call recomputes the null-space bases of Hx and Hz and pairs them with
    SGSOP.

    Args:
        gen_type: if truthy, return the Z-halves of the second element of each
            pair (called logX here). Otherwise return the X-halves of the first
            element (called logZ; these vectors lie in ker(Hx)).

    Returns:
        numpy array with one row of length n per logical pair.
    """
    n = Hx.shape[1]
    Gx, _ = par2gen(Hx)
    Gz, _ = par2gen(Hz)
    pairs, _ = SGSOP(Gx, Gz, n)

    if gen_type:
        return np.array([second[n:] for _, second in pairs])
    return np.array([first[:n] for first, _ in pairs])

In [698]:
def embed_code(code, init):
    """Fill a (2*emb_m, 2*emb_ell) grid with string labels "<kind><index>".

    Cell kinds depend on coordinate parity (see get_nbr): "x" at (even, even),
    "r" at (even, odd), "l" at (odd, even) and "z" at (odd, odd). Starting from
    cell (0, 0), each cell's index is propagated downward through an A-type
    matrix and rightward through a B-type matrix. After the fill, "z" and "r"
    indices are shifted by m*ell.

    Args:
        code: 4-tuple (emb_m, emb_ell, A_ind, B_ind). A_ind and B_ind index the
            `As` / `Bs` lists below. Note: this parameter shadows the
            module-level `code` list.
        init: index written into cell (0, 0) as "x{init}". The wrap-around write
            at the end of row 0 also rewrites cell (0, 0).

    Returns:
        numpy object array of label strings.

    Reads the globals A1..A3, B1..B3, m and ell.
    """
    emb_m, emb_ell, A_ind, B_ind = code

    lattice = np.empty((2*emb_m, 2*emb_ell), dtype=object)
    lattice[0][0] = f"x{init}"

    # As = [[A1, A2.T], [A2, A3.T], [A1, A3.T]]
    # Bs = [[B1, B2.T], [B2, B3.T], [B1, B3.T]]
    # Each entry is a [odd-step, even-step] pair of matrices.
    # Index 1 is used when stepping from an even row/column, index 0 from an odd one.
    As = [[A1, A2.T], [A2, A1.T], [A2, A3.T], [A3, A2.T], [A1, A3.T], [A3, A1.T]]
    Bs = [[B1, B2.T], [B2, B1.T], [B2, B3.T], [B3, B2.T], [B1, B3.T], [B3, B1.T]]

    # Label kind for a cell, determined only by the parity of (row, col).
    def get_nbr(i, j):
        if (i % 2 == 0):
            if (j % 2 == 0):
                return "x"
            else:
                return "r"
        else:
            if (j % 2 == 0):
                return "l"
            else:
                return "z"

    # Row i is complete before row i+1 is read. Each cell writes the cell below
    # (via tmp_A) and the cell to its right (via tmp_B, wrapping in j). The
    # horizontal write overwrites the value the row above put there.
    for i in range(2*emb_m - 1):
        for j in range(2*emb_ell):
            curr_ind = int(lattice[i][j][1:])

            if (i % 2 == 0):
                tmp_A = As[A_ind][1]
            else:
                tmp_A = As[A_ind][0]
            if (j % 2 == 0):
                tmp_B = Bs[B_ind][1]
            else:
                tmp_B = Bs[B_ind][0]

            # tmp @ e_curr selects column curr_ind. np.where(...)[0][0] is its first
            # nonzero row, which is the unique image when tmp is a permutation matrix.
            lattice[(i+1)%(2*emb_m)][j] = f"{get_nbr((i+1)%(2*emb_m), j)}{np.where(tmp_A @ np.eye(m*ell)[curr_ind])[0][0]}"
            lattice[i][(j+1)%(2*emb_ell)] = f"{get_nbr(i, (j+1)%(2*emb_ell))}{np.where(tmp_B @ np.eye(m*ell)[curr_ind])[0][0]}"

    # Offset "z" and "r" indices by m*ell so they do not collide with "x" / "l".
    for i in range(2*emb_m):
        for j in range(2*emb_ell):
            if (lattice[i][j][0] == "z"):
                lattice[i][j] = f"z{int(lattice[i][j][1:]) + m*ell}"
            elif (lattice[i][j][0] == "r"):
                lattice[i][j] = f"r{int(lattice[i][j][1:]) + m*ell}"

    return lattice

lattice = embed_code(tuple(code[8:12]), 0)

n_rows, n_cols = lattice.shape

# Data qubits are the "l"/"r" cells. all_qbts maps lattice (row, col) -> stim
# qubit index. qbts is the inverse map for data qubits (index -> (row, col)).
# Data qubits are inserted into all_qbts before ancillas, and the order matters.
all_qbts = {}
qbts = np.array([None] * (2 * m * ell))
for row in range(n_rows):
    for col in range(n_cols):
        label = lattice[row][col]
        if label[0] in ("r", "l"):
            all_qbts[(row, col)] = int(label[1:])
            qbts[int(label[1:])] = (row, col)

# Check ancillas are the "x"/"z" cells. Their stim indices start at 2*m*ell.
x_checks = np.array([None] * (m * ell))
z_checks = np.array([None] * (m * ell))
for row in range(n_rows):
    for col in range(n_cols):
        label = lattice[row][col]
        if label[0] == "x":
            all_qbts[(row, col)] = int(label[1:]) + 2*m*ell
            x_checks[int(label[1:])] = (row, col)
        elif label[0] == "z":
            all_qbts[(row, col)] = int(label[1:]) + 2*m*ell
            z_checks[int(label[1:]) - (m*ell)] = (row, col)

# Enclosing-circle size of each check's data-qubit coordinates.
# make_circle(...)[2] is presumably the radius (mec returns (cx, cy, r)); TODO confirm.
x_rs = [make_circle(qbts[np.where(Hx[i])[0]])[2] for i in range(m*ell)]
z_rs = [make_circle(qbts[np.where(Hz[i])[0]])[2] for i in range(m*ell)]

# A check is long-range when its radius exceeds (minimum radius + one standard
# deviation). The threshold is loop-invariant, so it is computed once per type.
x_radius_cut = min(x_rs) + np.std(x_rs)
z_radius_cut = min(z_rs) + np.std(z_rs)

x_is_long = np.asarray(x_rs) > x_radius_cut
z_is_long = np.asarray(z_rs) > z_radius_cut

# Complementing with `~` (rather than `<=`) sends anything that fails `>` to
# short-range, matching the original if/else split.
lr_x_checks = np.flatnonzero(x_is_long)
sr_x_checks = np.flatnonzero(~x_is_long)
lr_z_checks = np.flatnonzero(z_is_long)
sr_z_checks = np.flatnonzero(~z_is_long)

In [699]:
def lr_bell_pair(paths, p):
    """Create a Bell pair between the two endpoints of each path by entanglement swapping.

    Args:
        paths: list of qubit-index lists. path[0] and path[-1] end up entangled,
            and the interior qubits path[1:-1] are measured and reset.
        p: depolarizing / flip probability attached to every operation.

    Returns:
        stim.Circuit. It appends len(path) - 2 measurements per path, in path
        order, followed by classically controlled Pauli corrections on the endpoints.
    """
    c = stim.Circuit()

    # Layer 1: CNOTs on (path[0], path[1]), (path[2], path[3]), ...
    for path in paths:
        size = len(path)
        targets = path[:size - (size % 2)]
        c.append("CNOT", targets)
        c.append("DEPOLARIZE2", targets, p)
    c.append("TICK")

    # Layer 2: CNOTs on (path[1], path[2]), (path[3], path[4]), ...
    for path in paths:
        size = len(path)
        targets = path[1:size - 1 + (size % 2)]
        c.append("CNOT", targets)
        c.append("DEPOLARIZE2", targets, p)
    c.append("TICK")

    # Layer 3: H on path[1], path[3], ... (the last qubit is excluded). The noise
    # must hit the same qubits as the gate. Using [2::2] here put it on the wrong
    # qubits and dropped it entirely for 3-qubit paths.
    for path in paths:
        h_targets = path[:-1][1::2]
        c.append("H", h_targets)
        c.append("DEPOLARIZE1", h_targets, p)
    c.append("TICK")

    # Measure and reset the interior qubits, with flips before and after.
    for path in paths:
        interior = path[1:-1]
        c.append("X_ERROR", interior, p)
        c.append("MR", interior)
        c.append("X_ERROR", interior, p)
    c.append("TICK")

    # Corrections. tot_len counts the measurements from this path to the end of
    # the record, so -tot_len + (i - 1) addresses interior qubit i of this path.
    for j, path in enumerate(paths):
        tot_len = sum(len(q[1:-1]) for q in paths[j:])
        size = len(path)
        for i in range(1 + (size % 2), size - 1, 2):
            c.append("CZ", [stim.target_rec(-tot_len + i - 1), path[0]])
        c.append("DEPOLARIZE1", path[0], p)
        for i in range(2 - (size % 2), size - 1, 2):
            c.append("CX", [stim.target_rec(-tot_len + i - 1), path[-1]])
        c.append("DEPOLARIZE1", path[-1], p)
    c.append("TICK")

    return c

def lr_CNOT_bell(paths, p):
    """Teleported CNOT from control to target that consumes one Bell pair per path.

    Args:
        paths: list of [control, target, (bell_a, bell_b)] entries, where
            bell_a/bell_b are assumed to already share a Bell pair.
        p: depolarizing probability attached to the gates and corrections.

    Returns:
        stim.Circuit. It measures bell_a in Z and bell_b in X (two records per
        path, in path order) and applies the feed-forward Pauli corrections.
    """
    circuit = stim.Circuit()

    # One layer: CNOT(control -> bell_a) and CNOT(bell_b -> target) on every path.
    for path in paths:
        control, target, pair = path[0], path[1], path[2]
        gate_targets = [control, pair[0], pair[1], target]
        circuit.append("CNOT", gate_targets)
        circuit.append("DEPOLARIZE2", gate_targets, p)
    circuit.append("TICK")

    for path in paths:
        circuit.append("MR", path[2][0])
        circuit.append("MRX", path[2][1])
    circuit.append("TICK")

    # Walk the paths last-to-first so the k-th path from the end has its Z
    # result at rec[-2k-2] and its X result at rec[-2k-1].
    for offset, path in enumerate(reversed(paths)):
        z_result = stim.target_rec(-2 * offset - 2)
        x_result = stim.target_rec(-2 * offset - 1)
        circuit.append("CX", [z_result, path[1]])
        circuit.append("CZ", [x_result, path[0]])
        circuit.append("DEPOLARIZE1", [path[0], path[1]], p)
    circuit.append("TICK")
    return circuit

def lr_CNOT_no_bell(paths, p=0.001):
    """Long-range CNOT along a chain of qubits, without a pre-shared Bell pair.

    path[0] is the control and path[-1] the target. Reverse a path to get the
    reverse CNOT.

    Args:
        paths: list of qubit-index lists.
        p: noise probability for every operation. The default 0.001 matches the
            value that was previously hard-coded.

    Returns:
        (circuit, num_measurements). num_measurements is the total number of
        interior path qubits, i.e. how many records the circuit appends.
    """
    c = stim.Circuit()

    # H on path[2], path[4], ... (the last qubit is excluded).
    for path in paths:
        targets = path[:-1][2::2]
        c.append("H", targets)
        c.append("DEPOLARIZE1", targets, p)
    c.append("TICK")

    # CNOTs on (path[0], path[1]), (path[2], path[3]), ...
    for path in paths:
        size = len(path)
        targets = path[:size - (size % 2)]
        c.append("CNOT", targets)
        c.append("DEPOLARIZE2", targets, p)
    c.append("TICK")

    # CNOTs on (path[1], path[2]), (path[3], path[4]), ...
    for path in paths:
        size = len(path)
        targets = path[1:size - 1 + (size % 2)]
        c.append("CNOT", targets)
        c.append("DEPOLARIZE2", targets, p)
    c.append("TICK")

    # H on path[1], path[3], ... (the last qubit is excluded). Noise must act on
    # the same qubits as the gate, as in the first H layer above.
    for path in paths:
        targets = path[:-1][1::2]
        c.append("H", targets)
        c.append("DEPOLARIZE1", targets, p)
    c.append("TICK")

    for path in paths:
        interior = path[1:-1]
        c.append("X_ERROR", interior, p)
        c.append("MR", interior)
        c.append("X_ERROR", interior, p)
    c.append("TICK")

    # Feed-forward corrections. tot_len counts the measurements from this path
    # to the end of the record.
    for j, path in enumerate(paths):
        tot_len = sum(len(q[1:-1]) for q in paths[j:])
        size = len(path)
        for i in range(1 + (size % 2), size - 1, 2):
            c.append("CZ", [stim.target_rec(-tot_len + i - 1), path[0]])
        c.append("DEPOLARIZE1", path[0], p)
        for i in range(2 - (size % 2), size - 1, 2):
            c.append("CX", [stim.target_rec(-tot_len + i - 1), path[-1]])
        c.append("DEPOLARIZE1", path[-1], p)
    c.append("TICK")

    return c, sum(len(q[1:-1]) for q in paths)

def direct_CNOT(paths, p):
    """Apply one CNOT per path directly between its endpoints, ignoring intermediate nodes.

    Args:
        paths: list of lattice-coordinate lists. path[0] is the control and
            path[-1] the target, both looked up in the global all_qbts.
        p: two-qubit depolarizing probability for every CNOT.

    Returns:
        stim.Circuit with one CNOT layer followed by a TICK.
    """
    endpoint_targets = [
        qubit
        for path in paths
        for qubit in (all_qbts[path[0]], all_qbts[path[-1]])
    ]
    circuit = stim.Circuit()
    circuit.append("CNOT", endpoint_targets)
    circuit.append("DEPOLARIZE2", endpoint_targets, p)
    circuit.append("TICK")
    return circuit


In [683]:
def measure_sr_z_checks_direct(p):
    """Build the CNOT layers for the short-range Z checks using direct endpoint CNOTs.

    Each helper selects a subset of Z checks by lattice position and emits one
    `direct_CNOT` layer. Paths are reversed so that the far data-qubit endpoint
    is the control and the Z-check ancilla is the target. The non-local helpers
    use hard-coded error probabilities (see the table after this function), and
    `p` is used only for the nearest-neighbour helpers.

    Args:
        p: depolarizing probability for the local (two-node) CNOTs.

    Returns:
        stim.Circuit made of the concatenated helper layers.

    NOTE(review): the coordinate offsets are specific to the lattice produced by
    `embed_code` for this code. In every helper, `col_ind` filters on the ROW
    coordinate (z_checks[z][0] // 2) despite its name.
    """
    # Z checks whose column pair index (mod `mod`) equals gen_index. Each acts on
    # the data qubit 3 columns to the right.
    def sr_z_down(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0], z_checks[z][1]+i) for i in range(4)][::-1] for z in z_checks_i]
        # pur_paths = [[all_qbts[(z_checks[z][0]-1, z_checks[z][1]+i)]+4*m*ell for i in range(4)] for z in z_checks_i]
        return direct_CNOT(z_paths, 0.0015451111874128982)
    # Diagonal interaction: 3 rows up (col_ind [1,2]) or down (col_ind [0]) and 6 columns right.
    # Any other col_ind falls through and returns None.
    def sr_z_across(checks, gen_index, mod, col_ind):
        if col_ind == [1,2]:
            z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
            z_paths = [[(z_checks[z][0], z_checks[z][1]), (z_checks[z][0]-3, z_checks[z][1]+6)][::-1] for z in z_checks_i]
            return direct_CNOT(z_paths, 0.0016733051369441137)
        elif col_ind == [0]:
            z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
            z_paths = [[(z_checks[z][0], z_checks[z][1]), (z_checks[z][0]+3, z_checks[z][1]+6)][::-1] for z in z_checks_i]
            return direct_CNOT(z_paths, 0.0016733051369441137)
    # Interaction with the data qubit 5 rows up (only row block 2 by default).
    def z_right_bdy(checks, gen_index, mod, col_ind=[2]):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0]-i, z_checks[z][1]) for i in range(6)][::-1] for z in z_checks_i]
        # pur_paths = [[all_qbts[(z_checks[z][0]-i, z_checks[z][1]-1)]+4*m*ell for i in range(6)] for z in z_checks_i]
        return direct_CNOT(z_paths, 0.0013810611830413811)

    # Nearest-neighbour interactions (distance-1 paths) with noise p.
    def z_down_local(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0], z_checks[z][1]+i) for i in range(2)][::-1] for z in z_checks_i]
        return direct_CNOT(z_paths, p)
    def z_left_local(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0]-i, z_checks[z][1]) for i in range(2)][::-1] for z in z_checks_i]
        return direct_CNOT(z_paths, p)
    def z_up_local(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0], z_checks[z][1]-i) for i in range(2)][::-1] for z in z_checks_i]
        return direct_CNOT(z_paths, p)
    def z_right_local(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0]+i, z_checks[z][1]) for i in range(2)][::-1] for z in z_checks_i]
        return direct_CNOT(z_paths, p)

    c = stim.Circuit()
    c += sr_z_down(sr_z_checks, 0, 1, [0,1,2])
    c += sr_z_across(sr_z_checks, 0, 2, [0])
    c += sr_z_across(sr_z_checks, 0, 1, [1,2])
    c += z_right_bdy(sr_z_checks, 0, 1)

    c += z_down_local(sr_z_checks, 0, 1, [0,1,2])
    c += z_left_local(sr_z_checks, 0, 1, [0,1,2])
    c += z_up_local(sr_z_checks, 0, 1, [0,1,2])
    c += z_right_local(sr_z_checks, 0, 1, [0,1])
    return c

# Table of hard-coded probabilities. The second column matches the literals used above:
# 4-node path (sr_z_down), 6-node path (z_right_bdy), 10-node path (sr_z_across).
# The third column's meaning is not stated here (presumably a fidelity); TODO confirm.
# 4 0.0015451111874128982 0.99022
# 6 0.0013810611830413811 0.98475
# 10 0.0016733051369441137 0.97412

def measure_sr_z_checks_bell(p):
    """Build the short-range Z-check interactions, with non-local CNOTs teleported via Bell pairs.

    For each non-local interaction, a Bell pair is first created along the path
    on the auxiliary copies of the path qubits (stim index + 4*m*ell) using
    `lr_bell_pair`. `lr_CNOT_bell` then performs CNOT(data -> Z ancilla). After
    each such helper, every auxiliary qubit is reset. Nearest-neighbour
    interactions use `direct_CNOT`.

    Args:
        p: noise probability passed to every helper.

    Returns:
        stim.Circuit made of the concatenated helper layers.

    NOTE(review): the coordinate offsets are specific to the lattice produced by
    `embed_code` for this code. `col_ind` filters on the ROW coordinate
    (z_checks[z][0] // 2) despite its name.
    """
    # 4-node horizontal path from the Z ancilla to the data qubit 3 columns right.
    def sr_z_down(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        bell_paths = [[all_qbts[(z_checks[z][0], z_checks[z][1]+i)]+4*m*ell for i in range(4)][::-1] for z in z_checks_i]
        # [control, target, (bell_a, bell_b)]: physical endpoints plus their auxiliary copies.
        z_paths = [[z[0]-4*m*ell,z[-1]-4*m*ell,(z[0], z[-1])] for z in bell_paths]
        # pur_paths = [[all_qbts[(z_checks[z][0]-1, z_checks[z][1]+i)]+4*m*ell for i in range(4)] for z in z_checks_i]

        c = stim.Circuit()
        c += lr_bell_pair(bell_paths, p)
        c += lr_CNOT_bell(z_paths, p)
        c.append("R", [qbt+(4*m*ell) for qbt in all_qbts.values()])
        return c
    # 6-node vertical path to the data qubit 5 rows up (row block 2 by default).
    def z_right_bdy(checks, gen_index, mod, col_ind=[2]):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        bell_paths = [[all_qbts[(z_checks[z][0]-i, z_checks[z][1])]+4*m*ell for i in range(6)][::-1] for z in z_checks_i]
        z_paths = [[z[0]-4*m*ell,z[-1]-4*m*ell,(z[0], z[-1])] for z in bell_paths]
        # pur_paths = [[all_qbts[(z_checks[z][0]-i, z_checks[z][1]-1)]+4*m*ell for i in range(6)] for z in z_checks_i]

        c = stim.Circuit()
        c += lr_bell_pair(bell_paths, p)
        c += lr_CNOT_bell(z_paths, p)
        c.append("R", [qbt+(4*m*ell) for qbt in all_qbts.values()])
        return c
    # 10-node L-shaped path: 3 columns right, 3 rows up (row blocks 1/2) or down
    # (row block 0), then 3 more columns right.
    # NOTE(review): any other col_ind leaves bell_paths/z_paths unbound (UnboundLocalError).
    def sr_z_across(checks, gen_index, mod, col_ind):
        c = stim.Circuit()
        if col_ind == [1,2] or col_ind == [1] or col_ind == [2]:
            z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
            bell_paths = []
            for z in z_checks_i:
                tmp_path = []
                tmp_path += [(z_checks[z][0], z_checks[z][1]+ii) for ii in range(4)]
                tmp_path += [(z_checks[z][0]-ii, z_checks[z][1]+3) for ii in range(1,4)]
                tmp_path += [(z_checks[z][0]-3, z_checks[z][1]+3+ii) for ii in range(1,4)]
                bell_paths.append([all_qbts[node]+4*m*ell for node in tmp_path[::-1]])
            # [[(z_checks[z][0], z_checks[z][1]), (z_checks[z][0]+3, z_checks[z][1]+6)][::-1] for z in z_checks_i]
            z_paths = [[z[0]-4*m*ell,z[-1]-4*m*ell,(z[0], z[-1])] for z in bell_paths]
        elif col_ind == [0]:
            z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
            bell_paths = []
            for z in z_checks_i:
                tmp_path = []
                tmp_path += [(z_checks[z][0], z_checks[z][1]+ii) for ii in range(4)]
                tmp_path += [(z_checks[z][0]+ii, z_checks[z][1]+3) for ii in range(1,4)]
                tmp_path += [(z_checks[z][0]+3, z_checks[z][1]+3+ii) for ii in range(1,4)]
                bell_paths.append([all_qbts[node]+4*m*ell for node in tmp_path[::-1]])
            # [[(z_checks[z][0], z_checks[z][1]), (z_checks[z][0]+3, z_checks[z][1]+6)][::-1] for z in z_checks_i]
            z_paths = [[z[0]-4*m*ell,z[-1]-4*m*ell,(z[0], z[-1])] for z in bell_paths]

        c += lr_bell_pair(bell_paths, p)
        c += lr_CNOT_bell(z_paths, p)
        c.append("R", [qbt+(4*m*ell) for qbt in all_qbts.values()])
        return c

    # Nearest-neighbour interactions (distance-1 paths) as direct CNOTs.
    def z_down_local(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0], z_checks[z][1]+i) for i in range(2)][::-1] for z in z_checks_i]
        return direct_CNOT(z_paths, p)
    def z_left_local(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0]-i, z_checks[z][1]) for i in range(2)][::-1] for z in z_checks_i]
        return direct_CNOT(z_paths, p)
    def z_up_local(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0], z_checks[z][1]-i) for i in range(2)][::-1] for z in z_checks_i]
        return direct_CNOT(z_paths, p)
    def z_right_local(checks, gen_index, mod, col_ind):
        z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
        z_paths = [[(z_checks[z][0]+i, z_checks[z][1]) for i in range(2)][::-1] for z in z_checks_i]
        return direct_CNOT(z_paths, p)

    c = stim.Circuit()

    # Non-local layers are split by column-pair parity (mod 2), presumably so that
    # paths in the same layer do not overlap. TODO confirm.
    c += sr_z_down(sr_z_checks, 0, 2, [0,1,2])
    c += sr_z_down(sr_z_checks, 1, 2, [0,1,2])
    c += sr_z_across(sr_z_checks, 0, 2, [0])
    c += sr_z_across(sr_z_checks, 1, 2, [0])
    c += sr_z_across(sr_z_checks, 0, 2, [1])
    c += sr_z_across(sr_z_checks, 1, 2, [1])
    c += sr_z_across(sr_z_checks, 0, 2, [2])
    c += sr_z_across(sr_z_checks, 1, 2, [2])
    c += z_right_bdy(sr_z_checks, 0, 1)

    c += z_down_local(sr_z_checks, 0, 1, [0,1,2])
    c += z_left_local(sr_z_checks, 0, 1, [0,1,2])
    c += z_up_local(sr_z_checks, 0, 1, [0,1,2])
    c += z_right_local(sr_z_checks, 0, 1, [0,1])

    return c

# c = stim.Circuit()
# for key, value in all_qbts.items():
#     c.append("QUBIT_COORDS", value, (key[0],key[1],0))
#     c.append("QUBIT_COORDS", value+(4*m*ell), (key[0],key[1],1))
# c.append("R", [qbt for qbt in all_qbts.values()])
# c.append("R", [qbt+(4*m*ell) for qbt in all_qbts.values()])
# # c += measure_sr_z_checks_direct(0.001)
# c += measure_sr_z_checks_bell(0.001)

# with open("tmp.svg", "w") as f:
#     f.write(str(c.without_noise().diagram("timeslice-svg")))

In [649]:
def manhattan(qbts):
    """Taxicab (L1) distance between the two (row, col) coordinates in `qbts`."""
    first, second = qbts
    row_gap = first[0] - second[0]
    col_gap = first[1] - second[1]
    return np.abs(row_gap) + np.abs(col_gap)

def measure_x_checks(checks, p, scale=False):
    """Build the entangling part of an X-check measurement (no measurement itself).

    H is applied to every selected X-check ancilla, then CNOT(ancilla -> data)
    for each data qubit in the check's row of Hx, then H again.

    Args:
        checks: iterable of X-check indices (rows of Hx).
        p: noise strength. DEPOLARIZE1 after each H layer, DEPOLARIZE2 after each CNOT.
        scale: if true, each CNOT's noise is p * manhattan(ancilla, data) / 2.

    Returns:
        stim.Circuit.
    """
    ancillas = [all_qbts[x_checks[idx]] for idx in checks]

    circuit = stim.Circuit()
    circuit.append("H", ancillas)
    circuit.append("DEPOLARIZE1", ancillas, p)
    for idx in checks:
        anc_coord = x_checks[idx]
        for data_coord in qbts[np.where(Hx[idx])[0]]:
            pair = [all_qbts[anc_coord], all_qbts[data_coord]]
            circuit.append("CNOT", pair)
            noise = p * manhattan([anc_coord, data_coord]) / 2 if scale else p
            circuit.append("DEPOLARIZE2", pair, noise)
    circuit.append("H", ancillas)
    circuit.append("DEPOLARIZE1", ancillas, p)
    return circuit

def measure_z_checks(checks, p, scale=False):
    """Build the entangling part of a Z-check measurement (no measurement itself).

    For each selected Z check, CNOT(data -> ancilla) is applied for every data
    qubit in the check's row of Hz.

    Args:
        checks: iterable of Z-check indices (rows of Hz).
        p: DEPOLARIZE2 strength after each CNOT.
        scale: if true, each CNOT's noise is p * manhattan(data, ancilla) / 2.

    Returns:
        stim.Circuit.
    """
    circuit = stim.Circuit()
    for idx in checks:
        anc_coord = z_checks[idx]
        for data_coord in qbts[np.where(Hz[idx])[0]]:
            pair = [all_qbts[data_coord], all_qbts[anc_coord]]
            circuit.append("CNOT", pair)
            noise = p * manhattan([data_coord, anc_coord]) / 2 if scale else p
            circuit.append("DEPOLARIZE2", pair, noise)
    return circuit

def all_checks(p=0.0):
    """Entangling layers for every check: short-range Z, long-range Z, short-range X, long-range X.

    Args:
        p: noise probability passed to the measurement helpers. The default 0.0
            gives the same zero-strength noise as before, when `False` was passed
            positionally in the `p` slot (it read like the `scale` flag).

    Returns:
        stim.Circuit (without the ancilla measurements themselves).
    """
    c = stim.Circuit()
    c += measure_z_checks(sr_z_checks, p=p, scale=False)
    c += measure_z_checks(lr_z_checks, p=p, scale=False)
    c += measure_x_checks(sr_x_checks, p=p, scale=False)
    c += measure_x_checks(lr_x_checks, p=p, scale=False)
    return c

In [805]:
class Simulation:
    """Assemble a stim memory-experiment circuit for the BB code defined above.

    Rounds whose index is a multiple of `lr_time` measure every check
    (`lr_round`). All other rounds measure only short-range checks (`sr_round`),
    where each short-range Z check is kept with probability `route_success_prob`
    and long-range Z checks are skipped. Only Z-type detectors are emitted; the
    X branch of `detectors` is a placeholder.

    Args:
        num_rounds: number of noisy rounds before the final noiseless round.
        lr_time: period, in rounds, of the full long-range rounds.
        p: noise strength for gates, idle data qubits and measurement flips
            (default 0.001, the value previously hard-coded).
        route_success_prob: probability that a short-range Z check is measured in
            a short-range round (default 0.94988, previously hard-coded).
        seed: optional seed for the routing draws. When None, numpy's global RNG
            is used, so `np.random.seed(...)` keeps controlling the draws.
    """

    def __init__(self, num_rounds, lr_time, p=0.001, route_success_prob=0.94988, seed=None):
        self.num_rounds = num_rounds
        self.lr_time = lr_time
        self.p = p
        self.route_success_prob = route_success_prob
        self._rng = np.random if seed is None else np.random.default_rng(seed)

        # 1-based measurement indices of each check's latest result. The initial
        # noiseless round below records the Z checks first (1..m*ell), then the X checks.
        self.prev_meas_z = np.arange(1, m*ell+1, dtype=int)
        self.prev_meas_x = np.arange(m*ell+1, 2*m*ell+1, dtype=int)
        # 0 means "not measured since the last detector was emitted".
        self.curr_meas_z = np.zeros(m*ell, dtype=int)
        self.curr_meas_x = np.zeros(m*ell, dtype=int)

        # 1 where a Z check is measured in the current round.
        self.route_confirmation_z = np.ones(m*ell, dtype=int)
        self.route_confirmation_z[lr_z_checks] = 0
        # One row per round is appended in simulate(); row 0 is a zero placeholder.
        self.detector_history = np.zeros(m*ell)

        self.c = stim.Circuit()
        for key, value in all_qbts.items():
            # Physical qubits at z=0; their +4*m*ell auxiliary copies at z=1.
            self.c.append("QUBIT_COORDS", value, (key[0], key[1], 0))
            self.c.append("QUBIT_COORDS", value + (4*m*ell), (key[0], key[1], 1))
        self.c.append("R", [qbt for qbt in all_qbts.values()])
        self.c.append("R", [qbt + (4*m*ell) for qbt in all_qbts.values()])

        # Noiseless initial round that provides the reference measurements.
        self.c += all_checks().without_noise()
        self.c.append("MR", [all_qbts[z_check] for z_check in z_checks])
        self.c.append("MR", [all_qbts[x_check] for x_check in x_checks])

    def detectors(self, type):
        """Emit one DETECTOR per Z check measured since the last call.

        Each detector compares the check's current and previous results. A falsy
        `type` selects Z checks; X-type detectors are not implemented yet.
        """
        num_meas = self.c.num_measurements
        if not type:
            for i, z_check in enumerate(self.curr_meas_z):
                coord = z_checks[i]
                if z_check:
                    # Absolute 1-based index idx maps to stim lookback idx - num_meas - 1.
                    self.c.append("DETECTOR", [stim.target_rec(self.curr_meas_z[i]-num_meas-1), stim.target_rec(self.prev_meas_z[i]-num_meas-1)], (coord[0], coord[1], 0))
                    self.prev_meas_z[i] = self.curr_meas_z[i]
                    self.curr_meas_z[i] = 0
        else:
            pass  # TODO: X-type detectors are not implemented.

    def observables(self, type):
        """Attach one OBSERVABLE_INCLUDE per logical from get_logicals(type).

        Assumes the most recent measurements are the data qubits in reverse index
        order, so rec[-j-1] is data qubit j.
        """
        for i, logical in enumerate(get_logicals(type)):
            incl_qbts = np.where(logical)[0]
            incl_qbts = [-j-1 for j in incl_qbts]
            self.c.append("OBSERVABLE_INCLUDE", [stim.target_rec(j) for j in incl_qbts], i)

    def sr_round(self, with_noise=True):
        """Short-range round: confirmed short-range Z checks and all short-range X checks."""
        p = self.p if with_noise else 0
        curr_sr_z_checks = sr_z_checks[self.route_confirmation_z[sr_z_checks]==1]
        self.c += measure_z_checks(curr_sr_z_checks, p, False)
        # self.c += measure_sr_z_checks_direct(0.001)
        # self.c += measure_sr_z_checks_bell(0.001)
        self.c += measure_x_checks(sr_x_checks, p, False)

        z_ancillas = [all_qbts[z_checks[z_check]] for z_check in curr_sr_z_checks]
        x_ancillas = [all_qbts[x_checks[x_check]] for x_check in sr_x_checks]

        if with_noise:
            self.c.append("X_ERROR", z_ancillas, self.p)
            self.c.append("X_ERROR", x_ancillas, self.p)

        for z_check in curr_sr_z_checks:
            self.c.append("MR", all_qbts[z_checks[z_check]])
            self.curr_meas_z[z_check] = self.c.num_measurements
        # Unconfirmed short-range Z ancillas are reset too.
        for z_check in sr_z_checks:
            self.c.append("R", all_qbts[z_checks[z_check]])
        for x_check in sr_x_checks:
            self.c.append("MR", all_qbts[x_checks[x_check]])
            self.curr_meas_x[x_check] = self.c.num_measurements

        if with_noise:
            self.c.append("X_ERROR", z_ancillas, self.p)
            self.c.append("X_ERROR", x_ancillas, self.p)

    def lr_round(self, with_noise=True):
        """Full round: every Z check and every X check, short- and long-range."""
        p = self.p if with_noise else 0
        all_z_checks = np.concatenate([sr_z_checks, lr_z_checks])
        all_x_checks = np.concatenate([sr_x_checks, lr_x_checks])
        self.c += measure_z_checks(sr_z_checks, p, False)
        # self.c += measure_sr_z_checks_direct(0.001 if with_noise else 0)
        self.c += measure_z_checks(lr_z_checks, p, False)
        self.c += measure_x_checks(sr_x_checks, p, False)
        self.c += measure_x_checks(lr_x_checks, p, False)

        z_ancillas = [all_qbts[z_checks[z_check]] for z_check in all_z_checks]
        x_ancillas = [all_qbts[x_checks[x_check]] for x_check in all_x_checks]

        if with_noise:
            self.c.append("X_ERROR", z_ancillas, self.p)
            self.c.append("X_ERROR", x_ancillas, self.p)

        for z_check in all_z_checks:
            self.c.append("MR", all_qbts[z_checks[z_check]])
            self.curr_meas_z[z_check] = self.c.num_measurements

        for x_check in all_x_checks:
            self.c.append("MR", all_qbts[x_checks[x_check]])
            self.curr_meas_x[x_check] = self.c.num_measurements

        if with_noise:
            self.c.append("X_ERROR", z_ancillas, self.p)
            self.c.append("X_ERROR", x_ancillas, self.p)

    def simulate(self):
        """Append num_rounds noisy rounds, a final noiseless full round, and data readout."""
        for i in range(1, self.num_rounds+1):
            self.c.append("SHIFT_COORDS", [], (0, 0, 1))
            # Idle depolarizing noise on every data qubit at the start of the round.
            self.c.append("DEPOLARIZE1", [all_qbts[qbt] for qbt in qbts], self.p)
            if i % self.lr_time == 0:
                self.route_confirmation_z = np.ones(m*ell)
                self.detector_history = np.vstack([self.detector_history, self.route_confirmation_z])
                self.lr_round()
            else:
                # Each short-range Z check succeeds independently; long-range Z checks are skipped.
                self.route_confirmation_z[sr_z_checks] = [1 if self._rng.random() < self.route_success_prob else 0 for z in sr_z_checks]
                self.route_confirmation_z[lr_z_checks] = 0
                self.detector_history = np.vstack([self.detector_history, self.route_confirmation_z])
                self.sr_round()
            self.detectors(False)

        # Final noiseless round measuring every check.
        self.route_confirmation_z = np.ones(m*ell)
        self.detector_history = np.vstack([self.detector_history, self.route_confirmation_z])
        self.lr_round(with_noise=False)
        self.detectors(False)

        # Measure data qubits in reverse so rec[-j-1] is data qubit j (see observables()).
        self.c.append("M", [all_qbts[qbt] for qbt in qbts[::-1]])
        self.observables(False)

In [819]:
# 20 noisy rounds. Since lr_time (500) > num_rounds, every noisy round is
# short-range and only the final noiseless round measures the long-range checks.
s = Simulation(num_rounds=20, lr_time=500)
s.simulate()
c = s.c

In [820]:
# Draw one shot and print it round by round: "!" = fired detector,
# "_" = quiet detector, " " = check not measured in that round.
detector_sampler = c.compile_detector_sampler()
one_sample = detector_sampler.sample(shots=1, append_observables=True)[0]

offset = 0
print("round\tnum\t")
for rnd in range(1, s.num_rounds + 2):
    measured_mask = s.detector_history[rnd]
    n_active = np.count_nonzero(measured_mask)
    bits = iter(one_sample[offset:offset + n_active])
    row = "".join(
        ("!" if next(bits) else "_") if flag else " "
        for flag in measured_mask
    )
    print(f"{rnd}\t{n_active}\t{row}")
    offset += n_active
print()
# The last k entries of the sample are the observable flips.
print("observables\t" + "".join("!" if e else "_" for e in one_sample[-k:]))

round	num	
1	26	______   ____ _____!_   _!_   __!___
2	23	__!_!_    __ ____!___   ___    ____ 
3	26	______   _ __________   ___   ______
4	26	____!_   ___!________   ___    _____
5	27	____!_   ___!________   ___   ______
6	25	______    ___________   ___   ___ __
7	26	______   _____ ______   ___   ______
8	26	___ __   ____________   ___   ______
9	27	__!___   ________!___   ___   ______
10	24	______   ____________   ___    _ __ 
11	25	______   ____________   _ _   _____ 
12	25	______   __ _______ _   ___   ______
13	24	______   ___ ____  __   ___   ______
14	25	______   _____ ______   __    ______
15	26	______   __ _________   ___   ______
16	26	______   ______ _____   ___   ______
17	27	______   ____________   ___   ______
18	25	______   _ __________   ___   ____ _
19	27	______   ____!_______   ___   ______
20	27	__!___   ________!___   ___   __!!_!
21	36	_______!_______________!_!__!_______

observables	__!!__!_


In [821]:
# Build the decoder inputs from the detector error model (DEM).
# Column i of pcm/lcm and channel_probs[i] all come from the same i-th error
# instruction of the flattened DEM, so they are aligned by construction.
# (explain_detector_error_model_errors() is not documented to follow the DEM's
# instruction order, so columns built from it could misalign with the priors.)
dem = c.detector_error_model()
pcm = lil_matrix((dem.num_detectors, dem.num_errors), dtype=np.uint8)
lcm = lil_matrix((dem.num_observables, dem.num_errors), dtype=np.uint8)

channel_probs = []
col = 0
for instruction in dem.flattened():
    if instruction.type != "error":
        continue
    for target in instruction.targets_copy():
        if target.is_relative_detector_id():
            pcm[target.val, col] = 1
        elif target.is_logical_observable_id():
            lcm[target.val, col] = 1
    channel_probs.append(instruction.args_copy()[0])
    col += 1

assert col == dem.num_errors, "DEM error count mismatch"

print(pcm.shape)
print(lcm.shape)

(549, 5249)
(8, 5249)


In [822]:
# BP+OSD decoder over the circuit-level DEM, with per-column priors from channel_probs.
bposd_dec = bposd_decoder(
    pcm,                              # parity-check matrix (detectors x error mechanisms)
    channel_probs=channel_probs,      # per-mechanism priors; overrides "error_rate"
    max_iter=pcm.shape[1],            # BP iteration cap: one per error mechanism
    bp_method="ms",                   # min-sum belief propagation
    ms_scaling_factor=0,              # 0 selects the variable scaling-factor method
    osd_method="osd_cs",              # OSD variant; choices: "osd_e", "osd_cs", "osd0"
    osd_order=min(pcm.shape[0], 10),  # OSD search depth
)

In [823]:
# Monte-Carlo estimate of the logical error rate: decode each shot and count
# shots where the predicted observable flips disagree with the true ones.
num_iters = 1000

sampler = c.compile_detector_sampler()
# One batched call: stim samples many shots far faster than repeated single-shot calls.
detection_events, observable_flips = sampler.sample(num_iters, separate_observables=True)

count = 0
for syndrome, actual_obs in tqdm(zip(detection_events, observable_flips), total=num_iters):
    guessed_errors = bposd_dec.decode(syndrome)
    guessed_obs = (lcm @ guessed_errors) % 2
    if not np.array_equal(actual_obs.astype(int), guessed_obs):
        count += 1
print(count/num_iters)

100%|██████████| 1000/1000 [00:11<00:00, 90.31it/s]

0.005



