In [35]:
import numpy as np
import matplotlib.pyplot as plt
import stim
from lib.stabilizer import measurement_gadgets, StabilizerCode, stabilizer_circuits
from lib.color_compass import *
from lib.decoder import checkmatrix,pL_from_checkmatrix
from lib.stim2pymatching import estimate_pL_noisy_graph
import stimcirq
from typing import *
from cirq.contrib.svg import SVGCircuit
import pymatching

In [36]:
class Lattice2D():
    """
    convention: 
    X coords extend vertically |
    Z coords extend horizontally --
    store the coloring as a list with values in {-1, 0, 1}
    
    Red  ~ -1 ~ Z-type cuts
    Blue ~ +1 ~ X-type cuts
    White ~ 0
    
    preallocate logical X and L as cuts accross the lattice
    """
    def __init__(self, dimX, dimZ):
        self.dimX = dimX
        self.dimZ = dimZ
        self.colors = [0] * (dimX-1)*(dimZ-1)
        self.stabs = bacon_shor_group(dimX, dimZ)
        self.gauge = bsgauge_group(dimX, dimZ)
        self.Lx = ''.join(['X']*dimX+['_']*dimX*(dimZ-1))
        self.Lz = ''.join((['Z']+['_']*(dimX-1))*dimZ)
        self.logicals = [self.Lx,self.Lz]
        
    def size(self):
        return self.dimX*self.dimZ
        
    def __str__(self):
        vertex_rows = []
        face_rows = []
        dimX = self.dimX
        dimZ = self.dimZ
        for i in range(dimX):
            vertex_string = ''
            for j in range(dimZ):
                vertex_string += str(i*dimZ + j).zfill(3)
                if (j != dimZ-1):
                    vertex_string += '---'
            vertex_rows.append(vertex_string)
                
        for i in range(dimX-1):
            face_string = ''
            for j in range(dimZ-1):
                if(self.colors[i*(dimZ-1) + j] == -1):
                    # face_string += ' | '+colored(' # ', 'red')
                    face_string += ' | ' + ' ░ '
                elif(self.colors[i*(dimZ-1) + j] == +1):
                    # face_string += ' | '+colored(' # ', 'blue')
                    face_string += ' | ' + ' ▓ '
                elif(self.colors[i*(dimZ-1) + j] == 0):
                    face_string += ' |    '
                else:
                    raise ValueError(f'Invalid color type {self.colors[i*dimZ+j]}')
                if j == dimZ-2:
                    face_string += ' |'
            face_rows.append(face_string)
        sout = ''
        for idx, row in enumerate(vertex_rows):
            sout += row +'\n'
            if idx != len(vertex_rows)-1:
                sout += face_rows[idx]+'\n'
        return sout
    
    def size(self):
        return self.dimX*self.dimZ
    
    def getG(self):
        return self.gauge[0]+self.gauge[1]
    
    def getGx(self):
        return self.gauge[0]
    
    def getGz(self):
        return self.gauge[1]
    
    def getS(self):
        return self.stabs[0]+self.stabs[1]
    
    def getSx(self):
        return self.stabs[0]
    
    def getSz(self):
        return self.stabs[1]
    
    def getDims(self):
        return (self.dimX, self.dimZ)
    
    def max_stab_number(self):
        return self.dimX*self.dimZ - 1
    
    def pcheckZ(self):
        """returns the Z parity check matrix"""
        return np.vstack([pauli2vector(s) for s in self.getSz()])
        
    def pcheckX(self):
        """returns the X parity check matrix"""
        return np.vstack([pauli2vector(s) for s in self.getSx()])
    
    def display(self, pauli):
        dimX = self.dimX
        dimZ = self.dimZ
        if (len(pauli) != dimX*dimZ):
            raise ValueError("Pauli string dimension mismatch with lattice size")
        sout = ''
        slist = list(pauli)
        for i in range(dimX):
            for j in range(dimZ):
                if slist[i*dimZ+j] == 'X':
                    sout += ' X '
                elif slist[i*dimZ+j] == 'Z':
                    sout += ' Z '
                else:
                    sout += '   '
                if (j != dimZ-1):
                    sout += '---'
            if (i != dimX -1):
                sout += '\n'
                sout += ' |    '*dimZ
            sout += '\n'
        print(sout)
        
    def color_lattice(self, colors):
        """
        replace color state with input and recalculate stab and gauge groups 
        """
        dimX = self.dimX-1
        dimZ = self.dimZ-1
        if(len(colors) != dimX*dimZ):
            raise ValueError("Color dimension mismatch with lattice size")
        
        self.stabs = bacon_shor_group(self.dimX, self.dimZ)
        self.gauge = bsgauge_group(self.dimX, self.dimZ)
        self.colors = colors
        
        
        for cidx, c in enumerate(colors):
            if c == -1:
                self.update_groups((int(np.floor(cidx/dimZ)), cidx%dimZ), -1)
            elif c == +1:
                self.update_groups((int(np.floor(cidx/dimZ)), cidx%dimZ), +1)
        
        
    def update_groups(self, coords, cut_type):
        """
        cut the stabilizer group by coloring the face with the given type
            AND
        update the gauge group 
    
        algo: 
        [0] pick the gauge operator g to cut around
        [1] find s \in S that has weight-2 overlap with g
        [2] divide that s 
        [3] update the gauge group 
        """
        (i, j) = coords
        dimX = self.dimX
        dimZ = self.dimZ
        [Sx, Sz] = self.getSx(), self.getSz()
        [Gx, Gz] = self.getGx(), self.getGz()
        
        if cut_type == -1:
            # -1 = red which is a Z-cut
            g = ['_'] * dimX*dimZ
            g[i*dimZ + j] = 'Z'
            g[i*dimZ + j + 1] = 'Z'
            
            gvec = pauli2vector(''.join(g))
            
            # cut the relevant stabilizer
            for idx, s in enumerate(Sz):
                # find the overlapping stabilizer
                if pauli_weight(np.bitwise_xor(gvec, pauli2vector(s))) == pauli_weight(s) - 2:
                    # cut s into two vertical parts 
                    s1 = ['_'] * dimX*dimZ
                    s2 = ['_'] * dimX*dimZ
                    for k in range(0, i+1):
                        s1[k*dimZ + j] = s[k*dimZ + j]
                        s1[k*dimZ + j+1] = s[k*dimZ + j+1]
                    for k in range(i+1, dimX):
                        s2[k*dimZ + j] = s[k*dimZ + j]
                        s2[k*dimZ + j+1] = s[k*dimZ + j+1]
                    del Sz[idx]
                    Sz.append(''.join(s1))
                    Sz.append(''.join(s2))
                    break
            
            # make new gauge operator and update gauge group 
            gauge = ['_'] * dimX*dimZ
            for k in range(0, j+1):
                gauge[k + i*dimZ] = 'Z'
                gauge[k + i*dimZ + 1] = 'Z'
            Gx_new = []
            for g in Gx:
                if twisted_product(pauli2vector(''.join(g)), pauli2vector(''.join(gauge))) == 0:
                    Gx_new.append(g)
            Gx = Gx_new
                
        elif cut_type == +1:
            # +1 = blue that is a X-cut:
            g = ['_'] * dimX*dimZ
            g[i*dimZ + j] = 'X'
            g[(i+1)*dimZ + j ] = 'X'
            
            gvec = pauli2vector(''.join(g))
            
            # cut the relevant stabilizer
            for idx, s in enumerate(Sx):
                # find the overlapping stabilizer
                if pauli_weight(np.bitwise_xor(gvec, pauli2vector(s))) == pauli_weight(s) - 2:
                    # cut s into two horizontal parts 
                    s1 = ['_'] * dimX*dimZ
                    s2 = ['_'] * dimX*dimZ
                    for k in range(0, j+1):
                        s1[i*dimZ + k] = s[i*dimZ + k]
                        s1[(i+1)*dimZ + k] = s[(i+1)*dimZ + k]
                    for k in range(j+1, dimZ):
                        s2[i*dimZ + k] = s[i*dimZ + k]
                        s2[(i+1)*dimZ + k] = s[(i+1)*dimZ + k]
                    del Sx[idx]
                    Sx.append(''.join(s1))
                    Sx.append(''.join(s2))
                    break
            
            # make new gauge operator and update gauge group 
            gauge = ['_'] * dimX*dimZ
            for k in range(0, j+1):
                gauge[k + i*dimZ] = 'X'
                gauge[k + (i+1)*dimZ] = 'X'
            Gz_new = []
            for g in Gz:
                if twisted_product(pauli2vector(''.join(g)), pauli2vector(''.join(gauge))) == 0:
                    Gz_new.append(g)
            Gz = Gz_new

        # update the groups
        self.stabs = [Sx, Sz]
        self.gauge = [Gx, Gz]
        
    def error_is_corrected(self, syn, l_1, l_2, l_op):
        #syn is a vector of syndrome measurements, which has the Sx syndrome bits first
        #l_1, l_2 are booleans corresponding to the logical operator measurement
        #l_op is the logical operator we're measuring, written as a stim Pauli string

        #also need to know the stabilizers to feed into the decoder

        #check parity of l_1, l_2. If they are the same: no logical error
        #If they are different: logical error
        logical_error = ((l_1+l_2) % 2 == 0)

        #syndrome measurement gives a syndrome s. Feed into decoder to get a correction operator c
        #set up decoder
        Sx = lat.getSx()
        Sz = lat.getSz()
        Hx = np.array([[1 if i != '_' else 0 for i in s] for s in Sx])
        Hz = np.array([[1 if i != '_' else 0 for i in s] for s in Sz])
        Mx = Matching(Hx)
        Mz = Matching(Hz)

        #obtain correction operator
        cx = Mx.decode(syn[:len(Sx)])
        cz = Mz.decode(syn[len(Sx):])
        Rx = stim.PauliString(''.join(['X' if i == 1 else '_' for i in cx]))
        Rz = stim.PauliString(''.join(['Z' if i == 1 else '_' for i in cz]))

        correction_op = Rx*Rz
class Lattice2D():
    """
    convention: 
    X coords extend vertically |
    Z coords extend horizontally --
    store the coloring as a list with values in {-1, 0, 1}
    
    Red  ~ -1 ~ Z-type cuts
    Blue ~ +1 ~ X-type cuts
    White ~ 0
    
    preallocate logical X and L as cuts accross the lattice
    """
    def __init__(self, dimX, dimZ):
        self.dimX = dimX
        self.dimZ = dimZ
        self.colors = [0] * (dimX-1)*(dimZ-1)
        self.stabs = bacon_shor_group(dimX, dimZ)
        self.gauge = bsgauge_group(dimX, dimZ)
        self.Lx = ''.join(['X']*dimX+['_']*dimX*(dimZ-1))
        self.Lz = ''.join((['Z']+['_']*(dimX-1))*dimZ)
        self.logicals = [self.Lx,self.Lz]
        
    def size(self):
        return self.dimX*self.dimZ
        
    def __str__(self):
        vertex_rows = []
        face_rows = []
        dimX = self.dimX
        dimZ = self.dimZ
        for i in range(dimX):
            vertex_string = ''
            for j in range(dimZ):
                vertex_string += str(i*dimZ + j).zfill(3)
                if (j != dimZ-1):
                    vertex_string += '---'
            vertex_rows.append(vertex_string)
                
        for i in range(dimX-1):
            face_string = ''
            for j in range(dimZ-1):
                if(self.colors[i*(dimZ-1) + j] == -1):
                    # face_string += ' | '+colored(' # ', 'red')
                    face_string += ' | ' + ' ░ '
                elif(self.colors[i*(dimZ-1) + j] == +1):
                    # face_string += ' | '+colored(' # ', 'blue')
                    face_string += ' | ' + ' ▓ '
                elif(self.colors[i*(dimZ-1) + j] == 0):
                    face_string += ' |    '
                else:
                    raise ValueError(f'Invalid color type {self.colors[i*dimZ+j]}')
                if j == dimZ-2:
                    face_string += ' |'
            face_rows.append(face_string)
        sout = ''
        for idx, row in enumerate(vertex_rows):
            sout += row +'\n'
            if idx != len(vertex_rows)-1:
                sout += face_rows[idx]+'\n'
        return sout
    
    def size(self):
        return self.dimX*self.dimZ
    
    def getG(self):
        return self.gauge[0]+self.gauge[1]
    
    def getGx(self):
        return self.gauge[0]
    
    def getGz(self):
        return self.gauge[1]
    
    def getS(self):
        return self.stabs[0]+self.stabs[1]
    
    def getSx(self):
        return self.stabs[0]
    
    def getSz(self):
        return self.stabs[1]
    
    def getDims(self):
        return (self.dimX, self.dimZ)
    
    def max_stab_number(self):
        return self.dimX*self.dimZ - 1
    
    def pcheckZ(self):
        """returns the Z parity check matrix"""
        return np.vstack([pauli2vector(s) for s in self.getSz()])
        
    def pcheckX(self):
        """returns the X parity check matrix"""
        return np.vstack([pauli2vector(s) for s in self.getSx()])
    
    def display(self, pauli):
        dimX = self.dimX
        dimZ = self.dimZ
        if (len(pauli) != dimX*dimZ):
            raise ValueError("Pauli string dimension mismatch with lattice size")
        sout = ''
        slist = list(pauli)
        for i in range(dimX):
            for j in range(dimZ):
                if slist[i*dimZ+j] == 'X':
                    sout += ' X '
                elif slist[i*dimZ+j] == 'Z':
                    sout += ' Z '
                else:
                    sout += '   '
                if (j != dimZ-1):
                    sout += '---'
            if (i != dimX -1):
                sout += '\n'
                sout += ' |    '*dimZ
            sout += '\n'
        print(sout)
        
    def color_lattice(self, colors):
        """
        replace color state with input and recalculate stab and gauge groups 
        """
        dimX = self.dimX-1
        dimZ = self.dimZ-1
        if(len(colors) != dimX*dimZ):
            raise ValueError("Color dimension mismatch with lattice size")
        
        self.stabs = bacon_shor_group(self.dimX, self.dimZ)
        self.gauge = bsgauge_group(self.dimX, self.dimZ)
        self.colors = colors
        
        
        for cidx, c in enumerate(colors):
            if c == -1:
                self.update_groups((int(np.floor(cidx/dimZ)), cidx%dimZ), -1)
            elif c == +1:
                self.update_groups((int(np.floor(cidx/dimZ)), cidx%dimZ), +1)
        
        
    def update_groups(self, coords, cut_type):
        """
        cut the stabilizer group by coloring the face with the given type
            AND
        update the gauge group 
    
        algo: 
        [0] pick the gauge operator g to cut around
        [1] find s \in S that has weight-2 overlap with g
        [2] divide that s 
        [3] update the gauge group 
        """
        (i, j) = coords
        dimX = self.dimX
        dimZ = self.dimZ
        [Sx, Sz] = self.getSx(), self.getSz()
        [Gx, Gz] = self.getGx(), self.getGz()
        
        if cut_type == -1:
            # -1 = red which is a Z-cut
            g = ['_'] * dimX*dimZ
            g[i*dimZ + j] = 'Z'
            g[i*dimZ + j + 1] = 'Z'
            
            gvec = pauli2vector(''.join(g))
            
            # cut the relevant stabilizer
            for idx, s in enumerate(Sz):
                # find the overlapping stabilizer
                if pauli_weight(np.bitwise_xor(gvec, pauli2vector(s))) == pauli_weight(s) - 2:
                    # cut s into two vertical parts 
                    s1 = ['_'] * dimX*dimZ
                    s2 = ['_'] * dimX*dimZ
                    for k in range(0, i+1):
                        s1[k*dimZ + j] = s[k*dimZ + j]
                        s1[k*dimZ + j+1] = s[k*dimZ + j+1]
                    for k in range(i+1, dimX):
                        s2[k*dimZ + j] = s[k*dimZ + j]
                        s2[k*dimZ + j+1] = s[k*dimZ + j+1]
                    del Sz[idx]
                    Sz.append(''.join(s1))
                    Sz.append(''.join(s2))
                    break
            
            # make new gauge operator and update gauge group 
            gauge = ['_'] * dimX*dimZ
            for k in range(0, j+1):
                gauge[k + i*dimZ] = 'Z'
                gauge[k + i*dimZ + 1] = 'Z'
            Gx_new = []
            for g in Gx:
                if twisted_product(pauli2vector(''.join(g)), pauli2vector(''.join(gauge))) == 0:
                    Gx_new.append(g)
            Gx = Gx_new
                
        elif cut_type == +1:
            # +1 = blue that is a X-cut:
            g = ['_'] * dimX*dimZ
            g[i*dimZ + j] = 'X'
            g[(i+1)*dimZ + j ] = 'X'
            
            gvec = pauli2vector(''.join(g))
            
            # cut the relevant stabilizer
            for idx, s in enumerate(Sx):
                # find the overlapping stabilizer
                if pauli_weight(np.bitwise_xor(gvec, pauli2vector(s))) == pauli_weight(s) - 2:
                    # cut s into two horizontal parts 
                    s1 = ['_'] * dimX*dimZ
                    s2 = ['_'] * dimX*dimZ
                    for k in range(0, j+1):
                        s1[i*dimZ + k] = s[i*dimZ + k]
                        s1[(i+1)*dimZ + k] = s[(i+1)*dimZ + k]
                    for k in range(j+1, dimZ):
                        s2[i*dimZ + k] = s[i*dimZ + k]
                        s2[(i+1)*dimZ + k] = s[(i+1)*dimZ + k]
                    del Sx[idx]
                    Sx.append(''.join(s1))
                    Sx.append(''.join(s2))
                    break
            
            # make new gauge operator and update gauge group 
            gauge = ['_'] * dimX*dimZ
            for k in range(0, j+1):
                gauge[k + i*dimZ] = 'X'
                gauge[k + (i+1)*dimZ] = 'X'
            Gz_new = []
            for g in Gz:
                if twisted_product(pauli2vector(''.join(g)), pauli2vector(''.join(gauge))) == 0:
                    Gz_new.append(g)
            Gz = Gz_new

        # update the groups
        self.stabs = [Sx, Sz]
        self.gauge = [Gx, Gz]
        
    def error_is_corrected(self, syn, l_1, l_2, l_op):
        #syn is a vector of syndrome measurements, which has the Sx syndrome bits first
        #l_1, l_2 are booleans corresponding to the logical operator measurement
        #l_op is the logical operator we're measuring, written as a stim Pauli string

        #also need to know the stabilizers to feed into the decoder

        #check parity of l_1, l_2. If they are the same: no logical error
        #If they are different: logical error
        logical_error = ((l_1+l_2) % 2 == 0)

        #syndrome measurement gives a syndrome s. Feed into decoder to get a correction operator c
        #set up decoder
        Sx = lat.getSx()
        Sz = lat.getSz()
        Hx = np.array([[1 if i != '_' else 0 for i in s] for s in Sx])
        Hz = np.array([[1 if i != '_' else 0 for i in s] for s in Sz])
        Mx = Matching(Hx)
        Mz = Matching(Hz)

        #obtain correction operator
        cx = Mx.decode(syn[:len(Sx)])
        cz = Mz.decode(syn[len(Sx):])
        Rx = stim.PauliString(''.join(['X' if i == 1 else '_' for i in cx]))
        Rz = stim.PauliString(''.join(['Z' if i == 1 else '_' for i in cz]))

        correction_op = Rx*Rz



        #check [c, l]
            #If l_1 = l_2 and [c, l_op] = 0, then the error has been properly corrected
            #in that c keeps the proper eigenstate
            #If l_1 = l_2 and [c, l_op] \neq 0, then the decoder takes the state out of the correct eigenstate
            #If l_1 \neq l_2 and [c,l_op] = 0, then the decoder fails to correct the error
            #If l_1 \neq l_2 and [c, l_op] \neq 0, then the decoder properly corrects the error
        is_corrected = (correction_op.commutes(l_op) != logical_error)

        return is_corrected


        #check [c, l]
            #If l_1 = l_2 and [c, l_op] = 0, then the error has been properly corrected
            #in that c keeps the proper eigenstate
            #If l_1 = l_2 and [c, l_op] \neq 0, then the decoder takes the state out of the correct eigenstate
            #If l_1 \neq l_2 and [c,l_op] = 0, then the decoder fails to correct the error
            #If l_1 \neq l_2 and [c, l_op] \neq 0, then the decoder properly corrects the error
        is_corrected = (correction_op.commutes(l_op) != logical_error)

        return is_corrected

In [37]:
# Construction of a Pauli noise model

class PauliNoiseModel():
    """
    Constructs noisy Stim circuits from 2D Compass Code stabilizers and logical observables

    A noise model is defined as mapping a perfect operation to an imperfect operation
    """
    def __init__(self, one_qb_gate_rates : List[float] = [0] * 3, two_qb_gate_rates : List[float] = [0] * 15, meas_error_rate : float = 0):
        self.one_qb_gate_rates = one_qb_gate_rates
        assert(len(one_qb_gate_rates) == 3)
        self.two_qb_gate_rates = two_qb_gate_rates
        assert(len(two_qb_gate_rates) == 15)
        self.meas_error_rate = meas_error_rate

    def one_qb_pauli_noise(self) -> str:
        """ 
        Returns a string representing a single qubit Stim Pauli error channel
        """
        channel_str = 'PAULI_CHANNEL_1({},{},{})'.format(*self.one_qb_gate_rates)
        return channel_str 
    
    def two_qb_pauli_noise(self) -> str:
        """ 
        Returns a string representing a two qubit Stim Pauli error channel
        """
        channel_str = 'PAULI_CHANNEL_2({},{},{},{},{},{},{},{},{},{},{},{},{},{},{})'.format(*self.two_qb_gate_rates)
        return channel_str 
    
    def measurement_gadget(self, pauli_observable : str):
        """ 
        Stim gadget to directly measure the specified 'pauli_observable'
        """
        meas_circ = ''
        x_meas_pos = ''
        y_meas_pos = '' 
        z_meas_pos = ''
        pos = {'I' : [], 'X' : [], 'Y' : [], 'Z' : []}
        for i, pauli in enumerate(pauli_observable):
            if pauli == 'X':
                x_meas_pos += f' {i}'
            elif pauli == 'Y':
                y_meas_pos += f' {i}'
            elif pauli == 'Z':
                z_meas_pos += f' {i}'
            else:
                pass 

        meas_circ += f'MX({self.meas_error_rate})' + (x_meas_pos * (len(x_meas_pos) != 0)) + '\n' + f'MY({self.meas_error_rate})' + (y_meas_pos * (len(y_meas_pos) != 0)) + '\n' + f'MZ({self.meas_error_rate})' + (z_meas_pos * (len(z_meas_pos) != 0)) + '\n'
        return stim.Circuit(meas_circ)
    
    def stabilizer_gadget(self, stabilizer_in : str, ancilla_index : int, construction : str = 'cnot'):
        """
        Input:
            stabilizer: a single stabilizer written in terms of {I/_,X,Y,Z}
            construction: direct or hadamard:
                1) `cnot` using only CNOTs from data to ancilla along with single qubit gates
                    - H then S    : rotates Z basis -> Y basis
                    - S_dag then H: rotates Y basis -> Z basis
                    verifiable via checking that: Y stabilizer == kron(S@H,I) @ CNOT @ kron(H@S_dag)
                2) `hadamard` using H gates on ancilla and C-Pauli from ancilla to data
        Output:
            Measurement gadget

        the data qubits that are indicated in the stabilizer appear first
        the ancilla index starts from 0, which is the first ancilla qubit after the data 
        """
         # allow both '_' and 'I' in stabilizers
        stabilizer = stabilizer_in.replace('_','I')
        
        N = len(stabilizer)
        circ_string = ''
        if construction == 'cnot':
            for i, pauli in enumerate(stabilizer):
                if pauli == 'Z':
                    # Z-gates are just cnots from data to ancilla
                    noise_string = self.two_qb_pauli_noise() + f' {i} {N+ancilla_index}\n'
                    circ_string += f'CX {i} {ancilla_index+N} \n' 
                    circ_string += noise_string
                elif pauli == 'X':
                    # X-gates are conjugated by hadamards
                    noise_string_1qb = self.one_qb_pauli_noise() + f' {i}\n'
                    noise_string_2qb = self.two_qb_pauli_noise() + f' {i} {N+ancilla_index}\n'
                    circ_string += f'H {i} \n'
                    circ_string += noise_string_1qb
                    circ_string += f'CX {i} {ancilla_index+N} \n' 
                    circ_string += noise_string_2qb
                    circ_string += f'H {i} \n'
                    circ_string += noise_string_1qb
                elif pauli == 'Y':
                    # Y-gates are conjugated by S-gates and hadamards
                    noise_string_1qb = self.one_qb_pauli_noise() + f' {i}\n'
                    noise_string_2qb = self.two_qb_pauli_noise() + f' {i} {N+ancilla_index}\n'
                    circ_string = f'S_DAG {i} \n'
                    circ_string += noise_string_1qb
                    circ_string += f'H {i} \n'
                    circ_string += noise_string_1qb
                    circ_string += f'CX {i} {ancilla_index+N} \n' 
                    circ_string += noise_string_2qb
                    circ_string += f'H {i} \n'
                    circ_string += noise_string_1qb
                    circ_string += f'S {i} \n'
                    circ_string += noise_string_1qb 

            # noisy ancilla measurement
            circ_string += f'MR({self.meas_error_rate}) {N+ancilla_index}\n'
        return stim.Circuit(circ_string)


    def stabilizer_gadget_v2(self, stabilizer_in : int):
        """
        Use Stim's built in 'MPP' function
        (IS THIS PREFERRED OVER SPLITTING UP MEASUREMENTS INTO CONSTITUENT PARTS AND APPLYING CIRCUIT-LEVEL NOISE?)
        """
        # allow both '_' and 'I' in stabilizers
        stabilizer = stabilizer_in.replace('_','I')
        
        N = len(stabilizer)
        circ_string = f'MPP({self.meas_error_rate}) '
        for i, pauli in enumerate(stabilizer):
            if (pauli != 'I'):
                circ_string += f'{pauli}{i}*'
        circ_string = circ_string[:-1] + '\n'
        return stim.Circuit(circ_string)

In [38]:
def construct_decoder_graph_weighted(pc_mat : np.array, weights : np.array, num_rounds: int = 1):
    decode_graph = nx.MultiGraph()
    color_map = []

    for k in range(num_rounds):
        curr_decode_graph = nx.MultiGraph()
        for i, stab in enumerate(pc_mat):
            curr_decode_graph.add_node(str(i) + "," + str(k))

        for j, qubit in enumerate(pc_mat.T):
            stabs = [str(i) + "," + str(k) for i in range(len(qubit)) if qubit[i] == 1]
            if (len(stabs) == 1):
                stabs = tuple(stabs + ['B' + "," + str(k)])
                curr_decode_graph.add_edge(*stabs, weight=weights[j])
            else:
                curr_decode_graph.add_edge(*tuple(stabs), weight=weights[j])

        for node in curr_decode_graph:
            if(str(node) == 'B' + "," + str(k)):
                color_map.append('tab:red')
            else:
                color_map.append('tab:blue')

        if (k > 0):
            for node in curr_decode_graph:
                decode_graph.add_edge(node[0] + "," + str(k), 
                                           node[0] + ',' + str(k - 1))
        decode_graph = nx.compose(decode_graph, curr_decode_graph)

    return decode_graph, color_map

def construct_decoder_graph(pc_mat : np.array, num_rounds: int = 1):
    decode_graph = nx.MultiGraph()
    color_map = []

    for k in range(num_rounds):
        curr_decode_graph = nx.MultiGraph()
        for i, stab in enumerate(pc_mat):
            curr_decode_graph.add_node(str(i) + "," + str(k))

        for j, qubit in enumerate(pc_mat.T):
            stabs = [str(i) + "," + str(k) for i in range(len(qubit)) if qubit[i] == 1]
            if (len(stabs) == 1):
                stabs = tuple(stabs + ['B' + "," + str(k)])
                curr_decode_graph.add_edge(*stabs)
            else:
                curr_decode_graph.add_edge(*tuple(stabs))

        for node in curr_decode_graph:
            if(str(node) == 'B' + "," + str(k)):
                color_map.append('tab:red')
            else:
                color_map.append('tab:blue')

        if (k > 0):
            for node in curr_decode_graph:
                decode_graph.add_edge(node[0] + "," + str(k), 
                                           node[0] + ',' + str(k - 1))
        decode_graph = nx.compose(decode_graph, curr_decode_graph)

    return decode_graph, color_map

## 2 Approaches
* Circuit-Agnostic Approach - 
    * Assume measurements aren't perfect
    * Define detector error model based on this
* Circuit-Level Approach - 
    * Differs from circuit-level noise only in that there is no noise associated with each logical gate (i.e. gates assumed perfect, measurements noisy)

### Circuit-Agnostic Phenomenological Simulation

In [39]:
"""Sample a random Pauli error
    for now, I'll assume biased, symmetric, uncorrelated Pauli channel
"""
import random

def random_pauli(num_qubits : int, rates : list):
    assert rates[0] + rates[1] + rates[2] <= 1, "Error rate must not exceed 1"""
    paulis = []
    for i in range(num_qubits):
        x = random.uniform(0, 1)
        if x <= rates[0]: 
            paulis.append('X')
        elif x <= rates[0] + rates[1]:
            paulis.append('Y')
        elif x <= rates[0] + rates[1] + rates[2]:
            paulis.append('Z')
        else:
            paulis.append('_')
    return ''.join(paulis)

def pcheck_clipZ(pcheck):
    """
    clip (remove) the 1st half of the parity check matrix
    """
    L = int(pcheck.shape[1]/2)
    return pcheck[:, L:]

def pcheck_clipX(pcheck):
    """
    clip (remove) the 2nd half of the parity check matrix
    """
    L = int(pcheck.shape[1]/2)
    return pcheck[:, :L]

### Circuit-Level Approach

In [128]:
def compile_compass_circuit(compass_code : Lattice2D, pauli_noise_model : PauliNoiseModel, noiseless_model : PauliNoiseModel, rounds : int):
    """ 
    We compile a compass code lattice into stim circuits with detectors between subsequent stabilizer measurements

    Params:
    * compass_code - Instance of 'Lattice2D' class that defines compass code
    * pauli_noise_model - Instance of 'PauliNoiseModel' that defines Pauli noise model
    * rounds - Number of rounds of stabilizer measurements we look to perform
    """
    
    compass_circuit = stim.Circuit()

    # Perform encoding into logical all-zeros state
    encoding_circ = StabilizerCode(compass_code.getS()).encoding_circuit(stim=True)
    compass_circuit += encoding_circ

    # Add dummy measurements at start of circuit (X stabs)
    num_X_stabs = len(compass_code.getSx())
    num_Z_stabs = len(compass_code.getSz())

    for idx, sx in enumerate(compass_code.getSx()):
        compass_circuit += noiseless_model.stabilizer_gadget(sx, idx)

    for idz, sz in enumerate(compass_code.getSz()):
        compass_circuit += noiseless_model.stabilizer_gadget(sz, idz)

    

    # Perform n rounds of stabilizer measurements and add detector
    for n in range(rounds):
        if (n > rounds - 1):
            for idx, sx in enumerate(compass_code.getSx()):
                compass_circuit += noiseless_model.stabilizer_gadget(sx, idx)
            for idz, sz in enumerate(compass_code.getSz()):
                compass_circuit += noiseless_model.stabilizer_gadget(sz, idz)
            for idx, sx in enumerate(compass_code.getSx()):
                compass_circuit += stim.Circuit(f"DETECTOR({idx}, {n + 1}, 0) rec[{-1 - idx - num_Z_stabs}] rec[{-1 - num_X_stabs - (2 * num_Z_stabs) - idx}]")
            for idz, sz in enumerate(compass_code.getSz()):
                compass_circuit += stim.Circuit(f"DETECTOR({idz}, {n + 1}, 1) rec[{-1 - idz}] rec[{-1 - num_Z_stabs - num_X_stabs - idz}]")
        else:
            for idx, sx in enumerate(compass_code.getSx()):
                compass_circuit += pauli_noise_model.stabilizer_gadget(sx, idx)
            for idz, sz in enumerate(compass_code.getSz()):
                compass_circuit += pauli_noise_model.stabilizer_gadget(sz, idz)
            for idx, sx in enumerate(compass_code.getSx()):
                compass_circuit += stim.Circuit(f"DETECTOR({idx}, {n + 1}, 0) rec[{-1 - idx - num_Z_stabs}] rec[{-1 - num_X_stabs - (2 * num_Z_stabs) - idx}]")
            for idz, sz in enumerate(compass_code.getSz()):
                compass_circuit += stim.Circuit(f"DETECTOR({idz}, {n + 1}, 1) rec[{-1 - idz}] rec[{-1 - num_Z_stabs - num_X_stabs - idz}]")
    
    return compass_circuit 

#### Initial Testing

In [129]:
dim = 5
lat = Lattice2D(dim, dim)
coloring = np.random.randint(-1, 2, size=(dim - 1)**2)
lat.color_lattice(coloring)
print(lat)

000---001---002---003---004
 |  ░  |     |     |  ░  |
005---006---007---008---009
 |     |  ░  |     |  ░  |
010---011---012---013---014
 |  ▓  |  ░  |  ▓  |  ░  |
015---016---017---018---019
 |  ▓  |     |     |     |
020---021---022---023---024



In [130]:
# Fix a noise model for phenomenological noise model
one_qb_rates = [0.01] * 3
two_qb_rates = [0.05] * 15
meas_rate = 0.01
pauli_nm = PauliNoiseModel(one_qb_rates, two_qb_rates, meas_rate)

# Define a noiseless model
noiseless_nm = PauliNoiseModel()

# Construct the encoding and measurement circuit for our chosen Compass code
num_stab_meas_rounds = 2
circ = compile_compass_circuit(lat, pauli_nm, noiseless_nm, num_stab_meas_rounds) 

In [131]:
model = circ.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True, ignore_decomposition_failures=True)

In [132]:
model.diagram("matchgraph-3d")

In [117]:
model

stim.DetectorErrorModel('''
    error(0.1544) D0
    error(0.1544) D0 D7
    error(0.147273) D0 D7 ^ D10
    error(0.0101021) D0 D7 ^ D10 D14
    error(0.05) D0 D7 ^ D10 ^ D14
    error(0.05) D0 D7 ^ D14
    error(0.1472) D0 D23
    error(0.05) D0 D23 ^ D10
    error(0.1) D0 D23 ^ D10 D14
    error(0.05) D0 D23 ^ D14
    error(0.05) D0 D23 ^ D14 ^ D10
    error(0.261121) D0 D30
    error(0.18896) D0 D30 ^ D10
    error(0.116) D0 D30 ^ D10 D14
    error(0.05) D0 D30 ^ D10 ^ D14
    error(0.095) D0 D30 ^ D14
    error(0.05) D0 D30 ^ D14 ^ D10
    error(0.108082) D0 ^ D9
    error(0.05) D0 ^ D14
    error(0.1544) D1
    error(0.1112) D1 D7
    error(0.095) D1 D7 ^ D12
    error(0.0101021) D1 D7 ^ D12 D14
    error(0.0101021) D1 D7 ^ D12 D17
    error(0.05) D1 D7 ^ D12 ^ D17
    error(0.05) D1 D7 ^ D14
    error(0.05) D1 D7 ^ D14 ^ D12
    error(0.05) D1 D7 ^ D17
    error(0.1472) D1 D24
    error(0.05) D1 D24 ^ D12 D14
    error(0.14) D1 D24 ^ D12 D17
    error(0.05) D1 D24 ^ D14 D17
    

In [118]:
matching = pymatching.Matching.from_detector_error_model(model)
sampler = circ.compile_detector_sampler()
syndrome, actual_observables = sampler.sample(shots = 100, separate_observables=True)

In [119]:
matching.decode_batch(syndrome)

array([], shape=(100, 0), dtype=uint8)

In [126]:
def count_logical_errors(circuit: stim.Circuit, num_shots: int) -> int:
    # Sample the circuit.
    sampler = circuit.compile_detector_sampler()
    detection_events, observable_flips = sampler.sample(num_shots, separate_observables=True)

    # Configure a decoder using the circuit.
    detector_error_model = circuit.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True)
    matcher = pymatching.Matching.from_detector_error_model(detector_error_model)

    # Run the decoder.
    predictions = matcher.decode_batch(detection_events)
    print(matcher.num_fault_ids)

    # Count the mistakes.
    num_errors = 0
    for shot in range(num_shots):
        actual_for_shot = observable_flips[shot]
        predicted_for_shot = predictions[shot]
        if not np.array_equal(actual_for_shot, predicted_for_shot):
            num_errors += 1
    return num_errors

In [127]:
count_logical_errors(circ, 1000)

0


0