# Coloring Compass gym

## compass.py

In [131]:
'''compass.py'''
import numpy as np
import pymatching
import stim
from termcolor import colored

class Lattice2D():
    """
    convention: 
    X coords extend vertically |
    Z coords extend horizontally --
    store the coloring as a list with values in {-1, 0, 1}
    
    Red ~ -1
    Blue ~ +1
    White ~ 0
    
    preallocate logical X and L as cuts accross the lattice
    """
    def __init__(self, dimX, dimZ):
        self.dimX = dimX
        self.dimZ = dimZ
        self.colors = [0] * (dimX-1)*(dimZ-1)
        self.stabs = bacon_shor_group(dimX, dimZ)
        self.gauge = bsgauge_group(dimX, dimZ)
        self.Lx = ['X']*dimX+['_']*dimX*(dimZ-1)
        self.Lz = (['Z']+['_']*(dimX-1))*dimZ
        
    def __str__(self):
        vertex_rows = []
        face_rows = []
        dimX = self.dimX
        dimZ = self.dimZ
        for i in range(dimX):
            vertex_string = ''
            for j in range(dimZ):
                vertex_string += str(i*dimZ + j).zfill(3)
                if (j != dimZ-1):
                    vertex_string += '---'
            vertex_rows.append(vertex_string)
                
        for i in range(dimX-1):
            face_string = ''
            for j in range(dimZ-1):
                if(self.colors[i*(dimZ-1) + j] == -1):
                    face_string += ' | '+colored(' # ', 'red')
                elif(self.colors[i*(dimZ-1) + j] == +1):
                    face_string += ' | '+colored(' # ', 'blue')
                elif(self.colors[i*(dimZ-1) + j] == 0):
                    face_string += ' |    '
                else:
                    raise ValueError(f'Invalid color type {self.colors[i*dimZ+j]}')
                if j == dimZ-2:
                    face_string += ' |'
            face_rows.append(face_string)
        sout = ''
        for idx, row in enumerate(vertex_rows):
            sout += row +'\n'
            if idx != len(vertex_rows)-1:
                sout += face_rows[idx]+'\n'
        return sout
    
    def size(self):
        return self.dimX*self.dimZ
    
    def getG(self):
        return self.gauge[0]+self.gauge[1]
    
    def getGx(self):
        return self.gauge[0]
    
    def getGz(self):
        return self.gauge[1]
    
    def getS(self):
        return self.stabs[0]+self.stabs[1]
    
    def getSx(self):
        return self.stabs[0]
    
    def getSz(self):
        return self.stabs[1]
    
    def getDims(self):
        return (self.dimX, self.dimZ)
    
    def display(self, pauli):
        dimX = self.dimX
        dimZ = self.dimZ
        if (len(pauli) != dimX*dimZ):
            raise ValueError("Stabilizer dimension mismatch with lattice size")
        sout = ''
        slist = list(pauli)
        for i in range(dimX):
            for j in range(dimZ):
                if slist[i*dimZ+j] == 'X':
                    sout += ' X '
                elif slist[i*dimZ+j] == 'Z':
                    sout += ' Z '
                else:
                    sout += '   '
                if (j != dimZ-1):
                    sout += '---'
            if (i != dimX -1):
                sout += '\n'
                sout += ' |    '*dimZ
            sout += '\n'
        print(sout)
        
    def color_lattice(self, colors):
        """
        replace color state with input and recalculate stab and gauge groups 
        """
        dimX = self.dimX-1
        dimZ = self.dimZ-1
        if(len(colors) != dimX*dimZ):
            raise ValueError("Color dimension mismatch with lattice size")
            
        self.colors = colors
        
        for cidx, c in enumerate(colors):
            if c == -1:
                self.update_groups((int(np.floor(cidx/dimZ)), cidx%dimZ), -1)
            elif c == +1:
                self.update_groups((int(np.floor(cidx/dimZ)), cidx%dimZ), +1)
        
    def color_face(self, loc, cut_type):
        if self.colors[loc] != 0:
            raise ValueError(f'Face already colored')
        self.colors[loc] = cut_type
        dimZ = self.dimZ-1
        self.update_groups((int(np.floor(loc/dimZ)), loc%dimZ), cut_type)
            
    def update_groups(self, coords, cut_type):
        """
        cut the stabilizer group by coloring the face with the given type
            AND
        update the gauge group 
    
        algo: 
        [0] pick the gauge operator g to cut around
        [1] find s \in S that has weight-2 overlap with g
        [2] divide that s 
        [3] update the gauge group 
        """
        (i, j) = coords
        dimX = self.dimX
        dimZ = self.dimZ
        [Sx, Sz] = self.getSx(), self.getSz()
        [Gx, Gz] = self.getGx(), self.getGz()
        
        if cut_type == -1:
            # -1 = red which is a Z-cut
            g = ['_'] * dimX*dimZ
            g[i*dimZ + j] = 'Z'
            g[i*dimZ + j + 1] = 'Z'
            
            gvec = pauli2vector(''.join(g))
            
            # cut the relevant stabilizer
            for idx, s in enumerate(Sz):
                # find the overlapping stabilizer
                if pauli_weight(np.bitwise_xor(gvec, pauli2vector(s))) == pauli_weight(s) - 2:
                    # cut s into two vertical parts 
                    s1 = ['_'] * dimX*dimZ
                    s2 = ['_'] * dimX*dimZ
                    for k in range(0, i+1):
                        s1[k*dimZ + j] = s[k*dimZ + j]
                        s1[k*dimZ + j+1] = s[k*dimZ + j+1]
                    for k in range(i+1, dimX):
                        s2[k*dimZ + j] = s[k*dimZ + j]
                        s2[k*dimZ + j+1] = s[k*dimZ + j+1]
                    del Sz[idx]
                    Sz.append(''.join(s1))
                    Sz.append(''.join(s2))
                    break
            
            # make new gauge operator and update gauge group 
            gauge = ['_'] * dimX*dimZ
            for k in range(0, j+1):
                gauge[i + k*dimZ] = 'Z'
                gauge[i + k*dimZ + 1] = 'Z'
            Gx_new = []
            for g in Gx:
                if twisted_product(pauli2vector(''.join(g)), pauli2vector(''.join(gauge))) == 0:
                    Gx_new.append(g)
            Gx = Gx_new
                
        elif cut_type == +1:
            # +1 = blue that is a X-cut:
            g = ['_'] * dimX*dimZ
            g[i*dimZ + j] = 'X'
            g[(i+1)*dimZ + j ] = 'X'
            
            gvec = pauli2vector(''.join(g))
            
            # cut the relevant stabilizer
            for idx, s in enumerate(Sx):
                # find the overlapping stabilizer
                if pauli_weight(np.bitwise_xor(gvec, pauli2vector(s))) == pauli_weight(s) - 2:
                    # cut s into two horizontal parts 
                    s1 = ['_'] * dimX*dimZ
                    s2 = ['_'] * dimX*dimZ
                    for k in range(0, j+1):
                        s1[i*dimZ + k] = s[i*dimZ + k]
                        s1[(i+1)*dimZ + k] = s[(i+1)*dimZ + k]
                    for k in range(j+1, dimX):
                        s2[i*dimZ + k] = s[i*dimZ + k]
                        s2[(i+1)*dimZ + k] = s[(i+1)*dimZ + k]
                    del Sx[idx]
                    Sx.append(''.join(s1))
                    Sx.append(''.join(s2))
                    break
            
            # make new gauge operator and update gauge group 
            gauge = ['_'] * dimX*dimZ
            for k in range(0, j+1):
                gauge[k + i*dimZ] = 'X'
                gauge[k + (i+1)*dimZ] = 'X'
            Gz_new = []
            for g in Gz:
                if twisted_product(pauli2vector(''.join(g)), pauli2vector(''.join(gauge))) == 0:
                    Gz_new.append(g)
            Gz = Gz_new
            
        # update the groups
        self.stabs = [Sx, Sz]
        self.gauge = [Gx, Gz]
        
def bacon_shor_group(dimX, dimZ):
    """
    return the BS code stabilizers
    """
    Xs = []
    Zs = []
    for i in range(dimZ-1):
        s = list('_'*dimZ*dimX)
        for j in range(dimX):
            s[j*dimZ+i] = 'Z'
            s[j*dimZ+i+1] = 'Z'
        Zs.append(''.join(s))
    for i in range(dimX-1):
        s = list('_'*dimX*dimZ)
        for j in range(dimZ):
            s[j+i*dimZ] = 'X'
            s[j+(i+1)*dimZ] = 'X'
        Xs.append(''.join(s))
    return [Xs, Zs]

def bsgauge_group(dimX, dimZ):
    """
    return the gauge group generators for the BS code
    """
    Xs = []
    Zs = []
    # make X-type gauge ops
    for i in range(0, dimZ):
        for j in range(dimX-1):
            s = list('_'*dimX*dimZ)
            s[i+j*dimZ] = 'X'
            s[i+(j+1)*dimZ] = 'X'
            Xs.append(''.join(s))
    # make Z-type gauge ops
    for j in range(dimZ-1):
        for i in range(0, dimX): 
            s = list('_'*dimX*dimZ)
            s[i*dimZ+j] = 'Z'
            s[i*dimZ+j+1] = 'Z'
            Zs.append(''.join(s))
    return [Xs, Zs]

def pauli_weight(pauli):
    """Get weight of pauli operator"""
    if type(pauli) is list:
        pvec = np.array(pauli)
    elif type(pauli) is str:
        pvec = pauli2vector(pauli)
    else:
        pvec = pauli
    return np.sum(np.bitwise_or(pvec[:int(len(pvec)/2)], pvec[int(len(pvec)/2):]))

def pauli2binary(pstr):
    """
    convert pstr to a binary vector
    """
    bstr = [0]*2*len(pstr)
    for idx, c in enumerate(pstr):
        if c == 'X':
            bstr[idx] = 1
        elif c == 'Z':
            bstr[idx+len(pstr)] = 1
        elif c == 'Y':
            bstr[idx] = 1
            bstr[idx+len(pstr)] = 1
    return np.array(bstr)

def twisted_product(stab_binary, pauli_binary):
    """
    take twisted product of stabilizer with pauli to calculate commutator 
    """
    L = int(len(stab_binary)/2)
    return (stab_binary[:L]@pauli_binary[L:] + stab_binary[L:]@pauli_binary[:L]) % 2

def parity_check(stabs, pauli):
    if len(stabs) == 0:
        return np.array([0])
    if type(pauli[0]) is str:
        bvec = pauli2vector(pauli)
    else: 
        bvec = pauli
    if type(stabs[0]) is str:
        return np.array([twisted_product(pauli2vector(s), bvec) for s in stabs])
    else:
        return np.array([twisted_product(s, bvec) for s in stabs])
    
class StimWrapper():
    def __init__(self, lattice):
        self.lat = lattice
        self.stabs = []
        for s in lattice.getS():
            self.stabs.append(stim.PauliString(s))
    def syndrome(self, error):
        syn = []
        for s in self.stabs:
            if s.commutes(stim.PauliString(error)):
                syn.append(0)
            else:
                syn.append(1)
        return syn
    def logical_check(self, pauli):
        if ( pauli.commutes(stim.PauliString(''.join(self.lat.Lx))) and 
            pauli.commutes(stim.PauliString(''.join(self.lat.Lz)))):
            return True
        else: 
            return False
    def test(self, hist, decoder = 'matching'):
        
        if decoder == 'matching':
            Decoder = pymatching.Matching
        elif decoder == 'uf':
            Decoder = UnionFindPy.Decoder
            
        # setup MWPM decoder
        Sx = self.lat.getSx()
        Sz = self.lat.getSz()
        Hx = np.array([[1 if i != '_' else 0 for i in s] for s in Sx])
        Hz = np.array([[1 if i != '_' else 0 for i in s] for s in Sz])
        print(Hx)
        print(Hz)
        Mx = Decoder(Hx)
        Mz = Decoder(Hz)
        
        r_failure = 0
        
        for error in hist.keys():
            syn = self.syndrome(error)
            cx = Mx.decode(syn[:len(Sx)])
            cz = Mz.decode(syn[len(Sx):])
            # cx = Mx.decode(np.array(syn[:len(Sx)]))
            # cz = Mz.decode(np.array(syn[len(Sx):]))
            Rx = stim.PauliString(''.join(['Z' if i == 1 else '_' for i in cx]))
            Rz = stim.PauliString(''.join(['X' if i == 1 else '_' for i in cz]))
            recovery = Rx*Rz
            corrected = recovery*stim.PauliString(error)
            if self.logical_check(corrected) is False:
                r_failure += hist[error]
        return r_failure
    
def pauli2vector(pstr):
    """
    convert pstr to a binary vector
    """
    bstr = [0]*2*len(pstr)
    for idx, c in enumerate(pstr):
        if c == 'X':
            bstr[idx] = 1
        elif c == 'Z':
            bstr[idx+len(pstr)] = 1
        elif c == 'Y':
            bstr[idx] = 1
            bstr[idx+len(pstr)] = 1
    return np.array(bstr)


def check_distribution_globalbias(qubits, wt_min, wt_max, rx, ry, rz, N):
    """
    sample from global biased noise model and
    partition into a group of error checks along with rates below and above wt_min and wt_max
    """
    assert rx+ry+rz <= 1, "dephasing rates > 1"
    hist = dict()
    amp_below = 0
    amp_above = 0
    for n in range(N):
        e = []
        for q in range(qubits):
            s = np.random.uniform(0, 1)
            if 0 <= s < rx:
                e.append('X')
            elif rx <= s < rx+ry:
                e.append('Y')
            elif rx+ry <= s < rx+ry+rz:
                e.append('Z')
            else:
                e.append('I')
        estr = ''.join(e)
        wt = pauli_weight(estr)
        if (wt_min <= wt) and (wt <= wt_max):
            if estr not in hist:
                hist[estr] = 1/N
            else:
                hist[estr] += 1/N
        elif wt < wt_min:
            amp_below += 1/N
        elif wt > wt_max: 
            amp_above += 1/N
    return [hist, amp_below, amp_above]

## colorcompass.py

In [167]:
'''colorcompass.py'''
import gym
from gym import spaces

class ColorCompassCodeEnv(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self, **kw):
        self.nrow,self.ncol,self.p_fail_threshold,noise_params = kw.values()
        self.min_correctable_weight = (min(self.nrow,self.ncol)-1)//2
        if noise_params['fixed']:
            self.hist,self.amp_below,self.amp_above = check_distribution_globalbias(self.nrow*self.ncol,
                                                                                   self.min_correctable_weight+1,
                                                                                   self.nrow*self.ncol,
                                                                                   noise_params['px'],
                                                                                   noise_params['py'],
                                                                                   noise_params['pz'],
                                                                                   noise_params['num_samples'])
            self.hist1,_,_ = check_distribution_globalbias(self.nrow*self.ncol,1,1,
                                                       noise_params['px'],
                                                       noise_params['py'],
                                                       noise_params['pz'],
                                                       noise_params['num_samples'])
        else:
            print('Varying noise models not implemented')
            raise NotImplementedError
        self.fixed_noise = noise_params['fixed']
        
        self.reset()
        
#         self.action_space = spaces.Discrete(len(self.actions))
#         self.observation_space = spaces.Box(-2,2,[len(qstate2state(self.init_state))])
        
    def reset(self):
        self.lat = Lattice2D(self.nrow,self.ncol)
        slat = StimWrapper(self.lat)
        self.p_fail = self.amp_above + slat.test(self.hist)
        return self.lat.colors
        
    def step(self, action):
        self.lat.color_face(*action)
        slat = StimWrapper(self.lat)
        if self.fixed_noise:
            p_fail = self.amp_above + slat.test(self.hist)
        else:
            print('Varying noise models not implemented')
            raise NotImplementedError
        done = True if (0 not in self.lat.colors or p_fail<self.p_fail_threshold) else False
        reward = self.p_fail-p_fail
        self.p_fail = p_fail
        return self.lat.colors,reward,done,{}
    
    def render(self, mode='human'):
        print(self.lat)

# Test gym

## Test pymatch on bands of error weight

In [119]:
hists = []
for i in range(10):
    hist,amp_below,amp_above = check_distribution_globalbias(25,i,i,0.1,0.1,0.1,int(3e4))
    hists.append(hist)
    print(amp_below,amp_above)

0 0.9999666666665663
0.00013333333333333334 0.9983999999998998
0.0016000000000000005 0.9910999999999006
0.009366666666666624 0.9651999999999035
0.03403333333333343 0.9106333333332428
0.08916666666666662 0.8069999999999209
0.1930999999999885 0.6602666666666037
0.34116666666663886 0.4905666666666224
0.51216666666662 0.32183333333330766
0.6773333333332685 0.19026666666665548


In [68]:
noise_params= {'fixed':True,
               'px':0.1,
               'py':0.1,
               'pz':0.1,
               'num_samples':int(1e4)}
kw = {'nrow':5,
      'ncol':5,
      'p_fail_threshold':0.01,
      'noise_params':noise_params}
env = ColorCompassCodeEnv(**kw)
sequence = np.stack([range(16),[1,-1,-1,1,-1,1,1,-1,1,-1,-1,1,-1,1,1,-1]]).T
env.reset()
for action in sequence:
    state,reward,done,info = env.step(action)
    print()
    env.render()
    slat = StimWrapper(env.lat)
    for i in range(10):
        p_fail = slat.test(hists[i])
        p_total = sum(hists[i].values())
        # ratio = np.nan_to_num([p_fail/p_total])
        print(f'Prob fail from weight {i}: {p_fail:.7f}/{p_total:.7f} = {p_fail/p_total*100:.1f}%')


000---001---002---003---004
 | [34m # [0m |     |     |     |
005---006---007---008---009
 |     |     |     |     |
010---011---012---013---014
 |     |     |     |     |
015---016---017---018---019
 |     |     |     |     |
020---021---022---023---024

Prob fail from weight 0: 0.0000000/0.0002000 = 0.0%
Prob fail from weight 1: 0.0000000/0.0017000 = 0.0%
Prob fail from weight 2: 0.0000000/0.0079000 = 0.0%
Prob fail from weight 3: 0.0065667/0.0222667 = 29.5%
Prob fail from weight 4: 0.0261000/0.0580667 = 44.9%
Prob fail from weight 5: 0.0589000/0.1007667 = 58.5%
Prob fail from weight 6: 0.0967000/0.1497333 = 64.6%
Prob fail from weight 7: 0.1189333/0.1703000 = 69.8%
Prob fail from weight 8: 0.1166333/0.1624333 = 71.8%
Prob fail from weight 9: 0.1010333/0.1385333 = 72.9%

000---001---002---003---004
 | [34m # [0m | [31m # [0m |     |     |
005---006---007---008---009
 |     |     |     |     |
010---011---012---013---014
 |     |     |     |     |
015---016---017---018---019
 |

## Test non-square lattice

In [168]:
noise_params= {'fixed':True,
               'px':0.1,
               'py':0.1,
               'pz':0.1,
               'num_samples':int(1e4)}
kw = {'nrow':2,
      'ncol':3,
      'p_fail_threshold':0.01,
      'noise_params':noise_params}
env = ColorCompassCodeEnv(**kw)
sequence = np.stack([range(16),[1,-1,-1,1,-1,1,1,-1,1,-1,-1,1,-1,1,1,-1]]).T
env.reset()
for action in sequence:
    print('action',action)
    state,reward,done,info = env.step(action)
    print()
    env.render()
    slat = StimWrapper(env.lat)
    for i in range(10):
        p_fail = slat.test(hists[i])
        p_total = sum(hists[i].values())
        # ratio = np.nan_to_num([p_fail/p_total])
        print(f'Prob fail from weight {i}: {p_fail:.7f}/{p_total:.7f} = {p_fail/p_total*100:.1f}%')

[[1 1 1 1 1 1]]
[[1 1 0 1 1 0]
 [0 1 1 0 1 1]]
[[1 1 1 1 1 1]]
[[1 1 0 1 1 0]
 [0 1 1 0 1 1]]
action [0 1]
[[1 0 0 1 0 0]
 [0 1 0 0 1 0]]
[[1 1 0 1 1 0]
 [0 1 1 0 1 1]]


ValueError: Each qubit must be contained in either 1 or 2 check operators, not [0 1]

In [184]:
noise_params= {'fixed':True,
               'px':0.1,
               'py':0.1,
               'pz':0.1,
               'num_samples':int(1e4)}
kw = {'nrow':2,
      'ncol':3,
      'p_fail_threshold':0.01,
      'noise_params':noise_params}
env = ColorCompassCodeEnv(**kw)

env.reset()
env.render()
print(env.lat.getS())
print()
env.step([0,1])
env.render()
print(env.lat.getS())

[[1 1 1 1 1 1]]
[[1 1 0 1 1 0]
 [0 1 1 0 1 1]]
000---001---002
 |     |     |
003---004---005

['XXXXXX', 'ZZ_ZZ_', '_ZZ_ZZ']

[[1 0 0 1 0 0]
 [0 1 0 0 1 0]]
[[1 1 0 1 1 0]
 [0 1 1 0 1 1]]


ValueError: Each qubit must be contained in either 1 or 2 check operators, not [0 1]