In [1]:
import stim
import numpy as np
from numpy.linalg import matrix_power, matrix_rank
from mec import make_circle
from ldpc import bp_decoder, bposd_decoder
import galois
from tqdm import tqdm

In [2]:
code = [3,15,12,1,2,0,14,1,3,15,2,4]
# code = [6,30,21,1,2,3,13,26,6,30,2,2]

In [3]:
def cyclic_shift_matrix(l):
    arr = np.eye(l, dtype=int)
    return np.roll(arr, axis=1, shift=1)

ell = code[1]
m = code[0]

x = np.kron(cyclic_shift_matrix(ell), np.eye(m))
y = np.kron(np.eye(ell), cyclic_shift_matrix(m))

A1 = matrix_power(x, code[2])
A2 = matrix_power(y, code[3])
A3 = matrix_power(y, code[4])
A = ( A1 + A2 + A3 ) % 2

B1 = matrix_power(y, code[5])
B2 = matrix_power(x, code[6])
B3 = matrix_power(x, code[7])
B = ( B1 + B2 + B3 ) % 2

Hx = np.hstack([A, B]).astype(int)
Hz = np.hstack([B.T, A.T]).astype(int)

GF = galois.GF(2)
arr = GF(Hz.T)
k = 2 * (Hz.T.shape[1] - matrix_rank(arr))



In [4]:
def par2gen(H):
    GF = galois.GF(2)
    gfH = GF(H)
    gfH_rank = np.linalg.matrix_rank(gfH)

    rref_H = gfH.row_reduce()

    swaps = []
    col_H = rref_H.copy()
    for i in range(gfH_rank):
        inds = np.where(col_H[i])[0]
        pivot = inds[0]
        col_H[:,[i,pivot]] = col_H[:,[pivot,i]]
        swaps.append((i,pivot))

    col_H = col_H[:gfH_rank]
    col_G = GF(np.hstack([col_H[:,gfH_rank:].T, np.eye(H.shape[1]-gfH_rank, dtype=int)]))

    G = col_G.copy()
    for swap in swaps[::-1]:
        G[:,[swap[1],swap[0]]] = G[:,[swap[0],swap[1]]]

    if (np.any(G @ rref_H[:gfH_rank].T) or np.any(col_G @ col_H.T)):
        print("FAILED")
        return
    return (np.array(G, dtype=int), np.array(col_G, dtype=int))

def commute(x, z, n):
    # 0 if commute, 1 if anticommute
    x1 = x[:n]
    x2 = x[n:]
    z1 = z[:n]
    z2 = z[n:]
    return (x1 @ z2 % 2) ^ (x2 @ z1 % 2)


def SGSOP(Gx, Gz, n):
    # symplectic gram-schmidt orthogonalization procedure
    sym_Gx = np.hstack([Gx, np.zeros(Gx.shape, dtype=int)])
    sym_Gz = np.hstack([np.zeros(Gz.shape, dtype=int), Gz])
    sym_G = np.vstack([sym_Gx, sym_Gz])
    logicals = []
    generators = []

    while(sym_G.shape[0]):
        g1 = sym_G[0]

        commutes = True
        for i in range(1, sym_G.shape[0]-1):
            g2 = sym_G[i]
            if (commute(g1,g2,n)):
                logicals.append((g1, g2))
                sym_G = np.delete(sym_G, [0, i], axis=0)

                for j in range(sym_G.shape[0]):
                    gj = sym_G[j]
                    sym_G[j] = gj ^ (commute(gj,g2,n) * g1) ^ (commute(gj,g1,n) * g2)
                commutes = False
                break

        if commutes:
            generators.append(g1)
            sym_G = np.delete(sym_G, 0, axis=0)

    return (logicals, generators)

In [5]:
def get_logicals(gen_type=False):
    n = Hx.shape[1]
    Gx, col_Gx = par2gen(Hx)
    Gz, col_Gz = par2gen(Hz)
    logicals, generators = SGSOP(Gx, Gz, n)

    logX = np.array([l[1][n:] for l in logicals])
    logZ = np.array([l[0][:n] for l in logicals])

    if gen_type: return logX
    else: return logZ

In [6]:
def embed_code(code, init):
    emb_m, emb_ell, A_ind, B_ind = code

    lattice = np.empty((2*emb_m, 2*emb_ell), dtype=object)
    lattice[0][0] = f"x{init}"

    # As = [[A1, A2.T], [A2, A3.T], [A1, A3.T]]
    # Bs = [[B1, B2.T], [B2, B3.T], [B1, B3.T]]
    As = [[A1, A2.T], [A2, A1.T], [A2, A3.T], [A3, A2.T], [A1, A3.T], [A3, A1.T]]
    Bs = [[B1, B2.T], [B2, B1.T], [B2, B3.T], [B3, B2.T], [B1, B3.T], [B3, B1.T]]

    def get_nbr(i, j):
        if (i % 2 == 0):
            if (j % 2 == 0):
                return "x"
            else:
                return "r"
        else:
            if (j % 2 == 0):
                return "l"
            else:
                return "z"

    for i in range(2*emb_m - 1):
        for j in range(2*emb_ell):
            curr_ind = int(lattice[i][j][1:])

            if (i % 2 == 0):
                tmp_A = As[A_ind][1]
            else:
                tmp_A = As[A_ind][0]
            if (j % 2 == 0):
                tmp_B = Bs[B_ind][1]
            else:
                tmp_B = Bs[B_ind][0]

            lattice[(i+1)%(2*emb_m)][j] = f"{get_nbr((i+1)%(2*emb_m), j)}{np.where(tmp_A @ np.eye(m*ell)[curr_ind])[0][0]}"
            lattice[i][(j+1)%(2*emb_ell)] = f"{get_nbr(i, (j+1)%(2*emb_ell))}{np.where(tmp_B @ np.eye(m*ell)[curr_ind])[0][0]}"

    for i in range(2*emb_m):
        for j in range(2*emb_ell):
            if (lattice[i][j][0] == "z"):
                lattice[i][j] = f"z{int(lattice[i][j][1:]) + m*ell}"
            elif (lattice[i][j][0] == "r"):
                lattice[i][j] = f"r{int(lattice[i][j][1:]) + m*ell}"

    return lattice

lattice = embed_code((code[8],code[9],code[10],code[11]), 0)

In [7]:
all_qbts = {}

qbts = np.array([None for i in range(2*m*ell)])
for i in range(lattice.shape[0]):
    for j in range(lattice.shape[1]):
        if lattice[i][j][0] == "r" or lattice[i][j][0] == "l":
            all_qbts[(i,j)] = int(lattice[i][j][1:])
            qbts[int(lattice[i][j][1:])] = (i, j)
x_checks = np.array([None for i in range(m*ell)])
z_checks = np.array([None for i in range(m*ell)])

for i in range(lattice.shape[0]):
    for j in range(lattice.shape[1]):
        if lattice[i][j][0] == "x":
            all_qbts[(i,j)] = int(lattice[i][j][1:]) + 2*m*ell
            x_checks[int(lattice[i][j][1:])] = (i, j)
        elif lattice[i][j][0] == "z":
            all_qbts[(i,j)] = int(lattice[i][j][1:]) + 2*m*ell
            z_checks[int(lattice[i][j][1:])-(m*ell)] = (i, j)

x_rs = []
z_rs = []
for i in range(m*ell):
    gen_qbts = qbts[np.where(Hx[i])[0]]
    x_rs.append(make_circle(gen_qbts)[2])
for i in range(m*ell):
    gen_qbts = qbts[np.where(Hz[i])[0]]
    z_rs.append(make_circle(gen_qbts)[2])

lr_x_checks = []
sr_x_checks = []
lr_z_checks = []
sr_z_checks = []

for i, x_check in enumerate(x_checks):
    gen_qbts = qbts[np.where(Hx[i])[0]]

    nonlocal_qbts = []
    if (x_rs[i] > (min(x_rs)+np.std(x_rs))):
        lr_x_checks.append(i)
    else:
        sr_x_checks.append(i)

for i, z_check in enumerate(z_checks):
    gen_qbts = qbts[np.where(Hz[i])[0]]

    nonlocal_qbts = []
    if (z_rs[i] > min(z_rs)+np.std(z_rs)):
        lr_z_checks.append(i)
    else:
        sr_z_checks.append(i)

In [8]:
def lr_bell_pair(paths):
    size = len(path)
    c = stim.Circuit()

    for path in paths:
        c.append("H", path[:-1][::2])

    c.append("CNOT", path[:size-(size%2)])
    c.append("TICK")
    c.append("CNOT", path[1:size-1+(size%2)])

    c.append("H", path[:-1][1::2])
    c.append("MR", path[1:-1])

    for i in range(2 - (size%2), size-1, 2):
        c.append("CZ", [stim.target_rec(-i), path[0]])
    for i in range(1 + (size%2), size-1, 2):
        c.append("CX", [stim.target_rec(-i), path[-1]])

    return c

def lr_CNOT_bell(control, target, bell_pair):
    c = stim.Circuit()

    c.append("TICK")
    c.append("CNOT", [control, bell_pair[0], bell_pair[1], target])
    c.append("MR", bell_pair[0])
    c.append("MRX", bell_pair[1])
    c.append("CX", [stim.target_rec(-2), target])
    c.append("CZ", [stim.target_rec(-1), control])

    return c

def lr_CNOT_no_bell(paths):
    # path[0] is control, path[-1] is target. Reverse path to get reverse CNOT
    c = stim.Circuit()

    for path in paths:
        size = len(path)
        c.append("H", path[:-1][2::2])
        c.append("DEPOLARIZE1", path[:-1][2::2], 0.0005)
    c.append("TICK")

    for path in paths:
        size = len(path)
        c.append("CNOT", path[:size-(size%2)])
        c.append("DEPOLARIZE2", path[:size-(size%2)], 0.0005)
    c.append("TICK")

    for path in paths:
        size = len(path)
        c.append("CNOT", path[1:size-1+(size%2)])
        c.append("DEPOLARIZE2", path[1:size-1+(size%2)], 0.0005)
    c.append("TICK")

    for path in paths:
        c.append("H", path[:-1][1::2])
        c.append("DEPOLARIZE1", path[:-1][2::2], 0.0005)
    c.append("TICK")

    for path in paths:
        c.append("X_ERROR", path[1:-1], 0.0005)
        c.append("MR", path[1:-1])
        c.append("X_ERROR", path[1:-1], 0.0005)
    c.append("TICK")

    for j, path in enumerate(paths):
        tot_len = sum([len(p[1:-1]) for p in paths[j:]])
        size = len(path)
        for i in range(1 + (size%2), size-1, 2):
            c.append("CZ", [stim.target_rec(-tot_len+i-1), path[0]])
        c.append("DEPOLARIZE1", path[0], 0.0005)
        for i in range(2 - (size%2), size-1, 2):
            c.append("CX", [stim.target_rec(-tot_len+i-1), path[-1]])
        c.append("DEPOLARIZE1", path[-1], 0.0005)
    c.append("TICK")

    return c, sum([len(p[1:-1]) for p in paths])

def direct_CNOT(paths):
    c = stim.Circuit()
    qbt_paths = []
    for path in paths:
        qbt_paths += [all_qbts[path[0]], all_qbts[path[-1]]]
    c.append("CNOT", qbt_paths)
    c.append("DEPOLARIZE2", qbt_paths, 0.0005)
    c.append("TICK")
    return c

def bell_CNOT(paths):
    c = stim.Circuit()
    path_qbts = []
    for path in paths:
        path_qbts.append([all_qbts[qbt] + 4*m*ell if 0 < i < len(path)-1 else all_qbts[qbt] for i, qbt in enumerate(path)])
    c2, l = lr_CNOT_no_bell(path_qbts)
    c += c2
    return c, l

In [9]:
def sr_x_up(gen_index, mod, col_ind):
    x_checks_i = [x for x in sr_x_checks if ((x_checks[x][1]//2)%mod==gen_index) and (x_checks[x][0]//2 in col_ind)]
    x_paths = [[(x_checks[x][0], x_checks[x][1]-i) for i in range(4)] for x in x_checks_i]
    return x_paths

def x_down_local(checks, gen_index, mod, col_ind):
    x_checks_i = [x for x in checks if ((z_checks[x][1]//2)%mod==gen_index) and (x_checks[x][0]//2 in col_ind)]
    x_paths = [[(x_checks[x][0], x_checks[x][1]+i) for i in range(2)] for x in x_checks_i]
    return x_paths

def x_left_local(checks, gen_index, mod, col_ind): # only col_ind 1 or 2
    x_checks_i = [x for x in checks if ((z_checks[x][1]//2)%mod==gen_index) and (x_checks[x][0]//2 in col_ind)]
    x_paths = [[(x_checks[x][0]-i, x_checks[x][1]) for i in range(2)] for x in x_checks_i]
    return x_paths

def x_right_local(checks, gen_index, mod, col_ind):
    x_checks_i = [x for x in checks if ((z_checks[x][1]//2)%mod==gen_index) and (x_checks[x][0]//2 in col_ind)]
    x_paths = [[(x_checks[x][0]+i, x_checks[x][1]) for i in range(2)] for x in x_checks_i]
    return x_paths

def x_up_local(checks, gen_index, mod, col_ind):
    x_checks_i = [x for x in checks if ((z_checks[x][1]//2)%mod==gen_index) and (x_checks[x][0]//2 in col_ind)]
    x_paths = [[(x_checks[x][0], x_checks[x][1]-i) for i in range(2)] for x in x_checks_i]
    return x_paths

def x_left_bdy(checks, gen_index, mod, col_ind=[0]):
    x_checks_i = [x for x in checks if ((x_checks[x][1]//2)%mod==gen_index) and (x_checks[x][0]//2 in col_ind)]
    x_paths = [[(x_checks[x][0]+i, x_checks[x][1]) for i in range(6)] for x in x_checks_i]
    return x_paths

def sr_x_up_left(gen_index, mod, col_ind=2):
    x_checks_i = [x for x in sr_x_checks if ((x_checks[x][1]//2)%mod==gen_index) and (x_checks[x][0]//2 == col_ind)]
    x_paths = []
    for x in x_checks_i:
        tmp_path = []
        for i in range(4): tmp_path.append((x_checks[x][0], x_checks[x][1]-i))
        tmp_path.append((x_checks[x][0]-1, x_checks[x][1]-3))
        for i in range(3): tmp_path.append((x_checks[x][0]-1, x_checks[x][1]-4-i))
        for i in range(2): tmp_path.append((x_checks[x][0]-2-i, x_checks[x][1]-6))
        x_paths.append(tmp_path)
    return x_paths

def sr_x_up_right(gen_index, mod, col_ind):
    x_checks_i = [x for x in sr_x_checks if ((x_checks[x][1]//2)%mod==gen_index) and (x_checks[x][0]//2 in col_ind)]
    x_paths = []
    for x in x_checks_i:
        if (x_checks[x][0]//2 == 0):
            tmp_path = []
            for i in range(2): tmp_path.append((x_checks[x][0], x_checks[x][1]-i))
            tmp_path.append((x_checks[x][0]+1, x_checks[x][1]-1))
            for i in range(4): tmp_path.append((x_checks[x][0]+1, x_checks[x][1]-2-i))
            for i in range(2): tmp_path.append((x_checks[x][0]+2+i, x_checks[x][1]-5))
            tmp_path.append((x_checks[x][0]+3, x_checks[x][1]-6))
            x_paths.append(tmp_path)
        else:
            tmp_path = []
            for i in range(3): tmp_path.append((x_checks[x][0]+i, x_checks[x][1]))
            tmp_path.append((x_checks[x][0]+2, x_checks[x][1]-1))
            for i in range(6): tmp_path.append((x_checks[x][0]+3, x_checks[x][1]-1-i))
            x_paths.append(tmp_path)
    return x_paths



def measure_sr_x_checks():
    c = stim.Circuit()
    tot_meas = 0

    c2, l = bell_CNOT(sr_x_up_left(0,2))
    c += c2
    tot_meas += l
    c2, l = bell_CNOT(sr_x_up_left(1,2))
    c += c2
    tot_meas += l

    c2,l = bell_CNOT(sr_x_up_right(1,3,[0]) + sr_x_up_right(2,3,[1])
                 + sr_x_up(1,3,[2]) + sr_x_up(1,3,[1]) + sr_x_up(0,3,[0]))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(sr_x_up_right(2,3,[0]) + sr_x_up_right(0,3,[1])
                 + sr_x_up(2,3,[2]) + sr_x_up(2,3,[1]) + sr_x_up(1,3,[0]))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(sr_x_up_right(0,3,[0]) + sr_x_up_right(1,3,[1])
                 + sr_x_up(0,3,[2]) + sr_x_up(0,3,[1]) + sr_x_up(2,3,[0]))
    c += c2
    tot_meas += l


    c += direct_CNOT(x_down_local(sr_x_checks, 0,1,[0,1,2]))
    c += direct_CNOT(x_left_local(sr_x_checks, 0,1,[1,2]))
    c += direct_CNOT(x_right_local(sr_x_checks, 0,1,[0,1,2]))
    c2,l = bell_CNOT(x_left_bdy(sr_x_checks, 0,1))
    c += c2
    tot_meas += l
    c += direct_CNOT(x_up_local(sr_x_checks, 0,1,[0,1,2]))

    return c, tot_meas

c = stim.Circuit()
for key, value in all_qbts.items():
    c.append("QUBIT_COORDS", value, (key[0],key[1],0))
    c.append("QUBIT_COORDS", value+(4*m*ell), (key[0],key[1],1))
c += measure_sr_x_checks()[0].without_noise()

with open("tmp.svg", "w") as f:
    f.write(str(c.without_noise().diagram("timeslice-svg")))

In [10]:
def sr_z_down(gen_index, mod, col_ind):
    z_checks_i = [z for z in sr_z_checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0], z_checks[z][1]+i) for i in range(4)][::-1] for z in z_checks_i]
    return z_paths

def sr_z_down_right(gen_index, mod, col_ind=0):
    z_checks_i = [z for z in sr_z_checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 == col_ind)]
    z_paths = []
    for z in z_checks_i:
        tmp_path = []
        for i in range(4): tmp_path.append((z_checks[z][0], z_checks[z][1]+i))
        tmp_path.append((z_checks[z][0]+1, z_checks[z][1]+3))
        for i in range(3): tmp_path.append((z_checks[z][0]+1, z_checks[z][1]+4+i))
        for i in range(2): tmp_path.append((z_checks[z][0]+2+i, z_checks[z][1]+6))
        z_paths.append(tmp_path[::-1])
    return z_paths

def sr_z_down_left(gen_index, mod, col_ind):
    z_checks_i = [z for z in sr_z_checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
    z_paths = []
    for z in z_checks_i:
        if (z_checks[z][0]//2 == 2):
            tmp_path = []
            for i in range(2): tmp_path.append((z_checks[z][0], z_checks[z][1]+i))
            tmp_path.append((z_checks[z][0]-1, z_checks[z][1]+1))
            for i in range(4): tmp_path.append((z_checks[z][0]-1, z_checks[z][1]+2+i))
            for i in range(2): tmp_path.append((z_checks[z][0]-2-i, z_checks[z][1]+5))
            tmp_path.append((z_checks[z][0]-3, z_checks[z][1]+6))
            z_paths.append(tmp_path[::-1])
        else:
            tmp_path = []
            for i in range(3): tmp_path.append((z_checks[z][0]-i, z_checks[z][1]))
            tmp_path.append((z_checks[z][0]-2, z_checks[z][1]+1))
            for i in range(6): tmp_path.append((z_checks[z][0]-3, z_checks[z][1]+1+i))
            z_paths.append(tmp_path[::-1])
    return z_paths


def z_down_local(checks, gen_index, mod, col_ind):
    z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0], z_checks[z][1]+i) for i in range(2)][::-1] for z in z_checks_i]
    return z_paths

def z_left_local(checks, gen_index, mod, col_ind):
    z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0]-i, z_checks[z][1]) for i in range(2)][::-1] for z in z_checks_i]
    return z_paths

def z_up_local(checks, gen_index, mod, col_ind):
    z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0], z_checks[z][1]-i) for i in range(2)][::-1] for z in z_checks_i]
    return z_paths

def z_right_local(checks, gen_index, mod, col_ind):
    z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0]+i, z_checks[z][1]) for i in range(2)][::-1] for z in z_checks_i]
    return z_paths

def z_right_bdy(checks, gen_index, mod, col_ind=[2]):
    z_checks_i = [z for z in checks if ((z_checks[z][1]//2)%mod==gen_index) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0]-i, z_checks[z][1]) for i in range(6)][::-1] for z in z_checks_i]
    return z_paths



def lr_z_down_up(col_ind): # long dim boundary conditions
    z_checks_i = [z for z in lr_z_checks if (z_checks[z][1]==(ell*2-1)) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0], z_checks[z][1]-i) for i in range(30)][::-1] for z in z_checks_i]
    return z_paths

def lr_z_down(col_ind):
    z_checks_i = [z for z in lr_z_checks if (z_checks[z][1]==(ell*2-5)) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0], z_checks[z][1]+i) for i in range(4)][::-1] for z in z_checks_i]
    return z_paths

def lr_z_up(gen_ind, mod, col_ind):
    z_checks_i = [z for z in lr_z_checks if ((z_checks[z][1]//2)%mod==gen_ind) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0], z_checks[z][1]-i) for i in range(28)][::-1] for z in z_checks_i]
    return z_paths

def lr_z_up_left(gen_ind, mod, col_ind): # only col_ind 1 or 2
    z_checks_i = [z for z in lr_z_checks if ((z_checks[z][1]//2)%mod==gen_ind) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0]-i, z_checks[z][1]-24) for i in range(4)][::-1]+[(z_checks[z][0], z_checks[z][1]-i) for i in range(24)][::-1] for z in z_checks_i]
    return z_paths

def lr_z_up_right(gen_ind, mod, col_ind=[0]):
    z_checks_i = [z for z in lr_z_checks if ((z_checks[z][1]//2)%mod==gen_ind) and (z_checks[z][0]//2 in col_ind)]
    z_paths = [[(z_checks[z][0]+i, z_checks[z][1]-24) for i in range(4)][::-1]+[(z_checks[z][0], z_checks[z][1]-i) for i in range(24)][::-1] for z in z_checks_i]
    return z_paths


def measure_sr_z_checks():
    c = stim.Circuit()
    tot_meas = 0

    c2, l = bell_CNOT(sr_z_down_right(0,2))
    c += c2
    tot_meas += l
    c2, l = bell_CNOT(sr_z_down_right(1,2))
    c += c2
    tot_meas += l

    c2,l = bell_CNOT(sr_z_down_left(0,3,[2]) + sr_z_down_left(2,3,[1])
                 + sr_z_down(1,3,[2]) + sr_z_down(0,3,[1]) + sr_z_down(0,3,[0]))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(sr_z_down_left(1,3,[2]) + sr_z_down_left(0,3,[1])
                    + sr_z_down(2,3,[2]) + sr_z_down(1,3,[1]) + sr_z_down(1,3,[0]))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(sr_z_down_left(2,3,[2]) + sr_z_down_left(1,3,[1])
                    + sr_z_down(0,3,[2]) + sr_z_down(2,3,[1]) + sr_z_down(2,3,[0]))
    c += c2
    tot_meas += l

    c += direct_CNOT(z_down_local(sr_z_checks, 0,1,[0,1,2]))
    c += direct_CNOT(z_left_local(sr_z_checks, 0,1,[0,1,2]))
    c += direct_CNOT(z_right_local(sr_z_checks, 0,1,[0,1]))
    c2,l = bell_CNOT(z_right_bdy(sr_z_checks, 0,1))
    c += c2
    tot_meas += l
    c += direct_CNOT(z_up_local(sr_z_checks, 0,1,[0,1,2]))

    return c, tot_meas

def measure_lr_z_checks():
    c = stim.Circuit()
    tot_meas = 0

    c2,l = bell_CNOT(lr_z_up_right(0,3))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(lr_z_up_right(1,3))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(lr_z_up_right(2,3))
    c += c2
    tot_meas += l


    c2,l = bell_CNOT(lr_z_up_left(0,3,[2]) + lr_z_up_left(1,3,[1]))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(lr_z_up_left(1,3,[2]) + lr_z_up_left(2,3,[1]))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(lr_z_up_left(2,3,[2]))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(lr_z_up_left(0,3,[1]))
    c += c2
    tot_meas += l

    c2,l = bell_CNOT(lr_z_down([0,1,2]))
    c += c2
    tot_meas += l

    c2,l = bell_CNOT(lr_z_up(1,3,[0,1,2]))
    c += c2
    tot_meas += l
    c2,l = bell_CNOT(lr_z_up(2,3,[0,1,2]))
    c += c2
    tot_meas += l

    c2,l = bell_CNOT(lr_z_down_up([0,1,2]))
    c += c2
    tot_meas += l
    c += direct_CNOT(z_down_local(lr_z_checks, 0,3,[0,1,2]) + z_down_local(lr_z_checks, 1,3,[0,1,2]))
    c += direct_CNOT(z_left_local(lr_z_checks, 0,1,[0,1,2]))
    c += direct_CNOT(z_right_local(lr_z_checks, 0,1,[0,1]))
    c2,l = bell_CNOT(z_right_bdy(lr_z_checks, 0,1))
    c += c2
    tot_meas += l
    c += direct_CNOT(z_up_local(lr_z_checks, 0,1,[0,1,2]))

    return c, tot_meas

In [11]:
def measure_x_checks(checks):
    c = stim.Circuit()
    c.append("H", [all_qbts[x_checks[x_check]] for x_check in checks])
    c.append("DEPOLARIZE1", [all_qbts[x_checks[x_check]] for x_check in checks], 0.0005)
    path_qbts = []
    for x in checks:
        gen_qbts = qbts[np.where(Hx[x])[0]]
        for qbt in gen_qbts:
            path_qbts += [all_qbts[x_checks[x]], all_qbts[qbt]]
    c.append("CNOT", path_qbts)
    c.append("DEPOLARIZE2", path_qbts, 0.0005)
    c.append("H", [all_qbts[x_checks[x_check]] for x_check in checks])
    c.append("DEPOLARIZE1", [all_qbts[x_checks[x_check]] for x_check in checks], 0.0005)
    return c

def measure_z_checks(checks):
    c = stim.Circuit()
    path_qbts = []
    for z in checks:
        gen_qbts = qbts[np.where(Hz[z])[0]]
        for qbt in gen_qbts:
            path_qbts += [all_qbts[qbt], all_qbts[z_checks[z]]]

    # pairs = [path_qbts[i:i + 2] for i in range(0, len(path_qbts), 2)]
    # np.random.shuffle(pairs)
    # path_qbts = [item for pair in pairs for item in pair]
    c.append("CNOT", path_qbts)
    c.append("DEPOLARIZE2", path_qbts, 0.0005)
    return c

In [12]:
def init_detectors():
    c = stim.Circuit()
    for i, z_check in enumerate(z_checks):
        coord = z_check
        c.append("DETECTOR", [stim.target_rec(-(m*ell)+i)], (coord[0], coord[1], 0))
    return c

def inter_detectors(checks, meas_offset=0):
    c = stim.Circuit()
    # for i, x_check in enumerate(x_checks):
    #     coord = x_check
    #     c.append("DETECTOR", [stim.target_rec(-(2*m*ell)+i), stim.target_rec(-(4*m*ell)-meas_offset+i)], (coord[0], coord[1], 0))
    for i, z_check in enumerate(checks):
        coord = z_checks[z_check]
        c.append("DETECTOR", [stim.target_rec(-len(checks)+i), stim.target_rec(-meas_offset+i)], (coord[0], coord[1], 0))
    return c

def observables():
    c = stim.Circuit()
    for i, logical in enumerate(get_logicals(False)):
        incl_qbts = np.where(logical)[0]
        incl_qbts = [-j-1 for j in incl_qbts]
        c.append("OBSERVABLE_INCLUDE", [stim.target_rec(j) for j in incl_qbts], i)
    return c

def final_detectors():
    c = stim.Circuit()

    for i, z_check in enumerate(z_checks):
        coord = z_check
        incl_qbts = np.where(Hz[i])[0]
        incl_qbts = [-j-1 for j in incl_qbts]
        c.append("DETECTOR", [stim.target_rec(-(3*m*ell)+i)]+[stim.target_rec(j) for j in incl_qbts], (coord[0], coord[1], 1))
    c += observables()
    return c

In [21]:
num_rounds = 3
lr_time = 50
num_meas = []
num_gen_meas = []

c = stim.Circuit()
for key, value in all_qbts.items():
    c.append("QUBIT_COORDS", value, (key[0],key[1],0))
    c.append("QUBIT_COORDS", value+(4*m*ell), (key[0],key[1],1))
c.append("R", [qbt for qbt in all_qbts.values()])
c.append("R", [qbt+(4*m*ell) for qbt in all_qbts.values()])

c += measure_z_checks(sr_z_checks+lr_z_checks).without_noise()
c += measure_x_checks(sr_x_checks+lr_x_checks).without_noise()
c.append("MR", [all_qbts[z_checks[z_check]] for z_check in sr_z_checks+lr_z_checks])
c.append("MR", [all_qbts[x_checks[x_check]] for x_check in sr_x_checks+lr_x_checks])
num_meas.append(measure_sr_z_checks()[1])
num_gen_meas.append(2*m*ell)

def sr_round():
    c = stim.Circuit()
    # c += measure_z_checks(sr_z_checks)
    # l = 0
    c2,l = measure_sr_z_checks()
    c += c2
    c.append("X_ERROR", [all_qbts[z_checks[z_check]] for z_check in sr_z_checks], 0.0005)
    c.append("MR", [all_qbts[z_checks[z_check]] for z_check in sr_z_checks])
    c.append("X_ERROR", [all_qbts[z_checks[z_check]] for z_check in sr_z_checks], 0.0005)
    c += inter_detectors(sr_z_checks, num_meas[-1]+num_gen_meas[-1]+len(sr_z_checks))

    # c += measure_x_checks(sr_x_checks)
    # l2 = 0
    c2,l2 = measure_sr_x_checks()
    c += c2
    c.append("X_ERROR", [all_qbts[x_checks[x_check]] for x_check in sr_x_checks], 0.0005)
    c.append("MR", [all_qbts[x_checks[x_check]] for x_check in sr_x_checks])
    c.append("X_ERROR", [all_qbts[x_checks[x_check]] for x_check in sr_x_checks], 0.0005)

    num_meas.append(l+l2)
    num_gen_meas.append(len(sr_z_checks+sr_x_checks))
    return c

def lr_round():
    c = stim.Circuit()
    # c += measure_z_checks(sr_z_checks)
    # l=0
    c2, l = measure_sr_z_checks()
    c += c2

    # c += measure_z_checks(lr_z_checks)
    # l2 = 0
    c2,l2 = measure_lr_z_checks()
    c += c2

    last_lr = len(num_meas) - num_gen_meas[::-1].index(2*m*ell) - 1
    c.append("X_ERROR", [all_qbts[z_checks[z_check]] for z_check in sr_z_checks], 0.0005)
    c.append("MR", [all_qbts[z_checks[z_check]] for z_check in sr_z_checks])
    c.append("X_ERROR", [all_qbts[z_checks[z_check]] for z_check in sr_z_checks], 0.0005)
    c += inter_detectors(sr_z_checks, num_gen_meas[-1]+l+l2+len(sr_z_checks))

    c.append("X_ERROR", [all_qbts[z_checks[z_check]] for z_check in lr_z_checks], 0.0005)
    c.append("MR", [all_qbts[z_checks[z_check]] for z_check in lr_z_checks])
    c.append("X_ERROR", [all_qbts[z_checks[z_check]] for z_check in lr_z_checks], 0.0005)

    # print(last_lr, sum(num_meas[last_lr+1:])+2*m*ell+len(lr_z_checks)+l+l2)
    c += inter_detectors(lr_z_checks, l+l2+sum(num_meas[last_lr+1:])+2*m*ell+len(lr_z_checks))

    c += measure_x_checks(sr_x_checks+lr_x_checks)
    c.append("X_ERROR", [all_qbts[x_checks[x_check]] for x_check in sr_x_checks+lr_x_checks], 0.0005)
    c.append("MR", [all_qbts[x_checks[x_check]] for x_check in sr_x_checks+lr_x_checks])
    c.append("X_ERROR", [all_qbts[x_checks[x_check]] for x_check in sr_x_checks+lr_x_checks], 0.0005)

    num_meas.append(2*m*ell+l+l2)
    num_gen_meas.append(2*m*ell)
    return c

for i in range(1,num_rounds+1):
    c.append("SHIFT_COORDS", [], (0,0,1))
    c.append("DEPOLARIZE1", [all_qbts[qbt] for qbt in qbts], 0.0005)
    if (i%lr_time==0): c += lr_round()
    else: c += sr_round()

# c += lr_round().without_noise()
# c.append("M",[all_qbts[qbt] for qbt in qbts[::-1]])
# c += observables()

# with open("tmp.svg", "w") as f:
#     f.write(str(c.without_noise().diagram("timeslice-svg")))
print(num_meas)
print(c)

[408, 816, 816, 816]
QUBIT_COORDS(0, 1, 0) 48
QUBIT_COORDS(0, 1, 1) 228
QUBIT_COORDS(0, 3, 0) 51
QUBIT_COORDS(0, 3, 1) 231
QUBIT_COORDS(0, 5, 0) 54
QUBIT_COORDS(0, 5, 1) 234
QUBIT_COORDS(0, 7, 0) 57
QUBIT_COORDS(0, 7, 1) 237
QUBIT_COORDS(0, 9, 0) 60
QUBIT_COORDS(0, 9, 1) 240
QUBIT_COORDS(0, 11, 0) 63
QUBIT_COORDS(0, 11, 1) 243
QUBIT_COORDS(0, 13, 0) 66
QUBIT_COORDS(0, 13, 1) 246
QUBIT_COORDS(0, 15, 0) 69
QUBIT_COORDS(0, 15, 1) 249
QUBIT_COORDS(0, 17, 0) 72
QUBIT_COORDS(0, 17, 1) 252
QUBIT_COORDS(0, 19, 0) 75
QUBIT_COORDS(0, 19, 1) 255
QUBIT_COORDS(0, 21, 0) 78
QUBIT_COORDS(0, 21, 1) 258
QUBIT_COORDS(0, 23, 0) 81
QUBIT_COORDS(0, 23, 1) 261
QUBIT_COORDS(0, 25, 0) 84
QUBIT_COORDS(0, 25, 1) 264
QUBIT_COORDS(0, 27, 0) 87
QUBIT_COORDS(0, 27, 1) 267
QUBIT_COORDS(0, 29, 0) 45
QUBIT_COORDS(0, 29, 1) 225
QUBIT_COORDS(1, 0, 0) 2
QUBIT_COORDS(1, 0, 1) 182
QUBIT_COORDS(1, 2, 0) 5
QUBIT_COORDS(1, 2, 1) 185
QUBIT_COORDS(1, 4, 0) 8
QUBIT_COORDS(1, 4, 1) 188
QUBIT_COORDS(1, 6, 0) 11
QUBIT_COORDS(1, 6, 

In [28]:
np.set_printoptions(linewidth=200)
detector_sampler = c.compile_detector_sampler()
one_sample = detector_sampler.sample(shots=1, append_observables=True)[0]
print(len(one_sample))

ind = 0
for i in range(1,num_rounds+1):
    if (i%lr_time==0):
        timeslice = one_sample[ind:ind+len(z_checks)]
        ind += len(z_checks)
    else:
        timeslice = one_sample[ind:ind+len(sr_z_checks)]
        ind += len(sr_z_checks)
    print("".join("!" if e else "_" for e in timeslice))

timeslice = one_sample[ind:-k]
print("".join("!" if e else "_" for e in timeslice))
timeslice = one_sample[-k:]
print("".join("!" if e else "_" for e in timeslice))

108
_____________________!______________
!_________!!_________!______________
____________________________________

________


In [2011]:
from scipy.sparse import lil_matrix

dem = c.detector_error_model()
pcm = lil_matrix((dem.num_detectors, dem.num_errors), dtype=np.uint8)
lcm = lil_matrix((dem.num_observables, dem.num_errors), dtype=np.uint8)

errors = []
channel_probs = [e.args_copy()[0] for e in c.detector_error_model() if e.type=="error"]
for i, error_event in enumerate(c.explain_detector_error_model_errors()):
    dets = [det.dem_target.val for det in error_event.dem_error_terms if det.dem_target.is_relative_detector_id()]
    obs = [ob.dem_target.val for ob in error_event.dem_error_terms if ob.dem_target.is_logical_observable_id()]
    pcm[[dets],i] = 1
    lcm[[obs],i] = 1

print(pcm.shape)
print(lcm.shape)

(405, 2646)
(8, 2646)


In [2012]:
bp_dec = bp_decoder(
    pcm,
    channel_probs=channel_probs,
    max_iter=pcm.shape[1],
    bp_method="msl",
    ms_scaling_factor=0
)

bposd_dec = bposd_decoder(
    pcm, # the parity check matrix
    channel_probs=channel_probs, #assign error_rate to each qubit. This will override "error_rate" input variable
    max_iter=pcm.shape[1], #the maximum number of iterations for BP)
    bp_method="msl",
    ms_scaling_factor=0, #min sum scaling factor. If set to zero the variable scaling factor method is used
    osd_method="osd_cs", #the OSD method. Choose from:  1) "osd_e", "osd_cs", "osd0"
    osd_order=min(pcm.shape[0],10) #the osd search depth
)

In [2013]:
count = 0
num_iters = 10000

sampler = c.compile_detector_sampler()
for i in tqdm(range(num_iters)):
    detection_events, observable_flips = sampler.sample(1, separate_observables=True)
    # guessed_errors = bp_dec.decode(detection_events[0])
    guessed_errors = bposd_dec.decode(detection_events[0])
    guessed_obs = (lcm @ guessed_errors) % 2

    if not np.all(observable_flips[0].astype(int) == guessed_obs):
        count += 1
print(count/num_iters)

100%|██████████| 10000/10000 [00:28<00:00, 353.36it/s]

0.0



