In [2]:
import unionfind #do "pip install unionfind" to get this
import itertools

# Data Structures

In [3]:
class edgedict(dict):
    """Dictionary keyed by undirected edges.

    ``(u, v)`` and ``(v, u)`` name the same entry: item access, assignment,
    deletion, ``in`` and ``get`` accept either orientation, and an entry keeps
    the orientation it was first stored under. Keys that cannot be reversed
    (e.g. ints) are used literally.

    Note: ``edgedict.keys()`` views still test membership literally, because
    they use ``dict``'s own lookup.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # dict.__init__ would store items without calling __setitem__, so
        # (u, v) and (v, u) in the input would become two separate entries.
        self.update(*args, **kwargs)

    @staticmethod
    def _reverse(key):
        """Opposite orientation of ``key``, or ``key`` itself if it is not reversible."""
        try:
            return tuple(reversed(key))
        except TypeError:
            return key

    def _stored_key(self, key):
        """The orientation of ``key`` already present, else ``key`` unchanged."""
        if dict.__contains__(self, key):
            return key
        flipped = self._reverse(key)
        if dict.__contains__(self, flipped):
            return flipped
        return key

    def __getitem__(self, key):
        # Looks up whichever orientation is stored; a miss raises KeyError(key)
        # for the key exactly as the caller gave it.
        return super().__getitem__(self._stored_key(key))

    def __setitem__(self, key, value):
        super().__setitem__(self._stored_key(key), value)

    def __delitem__(self, key):
        super().__delitem__(self._stored_key(key))

    def __contains__(self, key):
        return dict.__contains__(self, self._stored_key(key))

    def get(self, key, default=None):
        """Like ``dict.get``, but either orientation of ``key`` matches."""
        return super().get(self._stored_key(key), default)

    def update(self, *args, **kwargs):
        """Like ``dict.update``, but every item goes through ``__setitem__``."""
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

In [4]:
### graph data structure
class graph:
    """Undirected graph with adjacency lists and an up-to-date list of leaves.

    Attributes
    ----------
    vertices : list
        Vertices in insertion order (duplicates dropped).
    edges : list
        One entry per undirected edge, kept in the form it was first added
        (an ``(u, v)`` tuple or a ``{u, v}`` set).
    neighbors : dict
        Maps each vertex to the list of its adjacent vertices.
    leaves : list
        Vertices with exactly one neighbour. Isolated vertices are not leaves.
    """

    def __init__(self, vertexlist, edgelist):
        """Build a graph from ``vertexlist`` and ``edgelist``.

        Both inputs are copied, never mutated. Endpoints missing from
        ``vertexlist`` are appended as new vertices. Repeated edges, in either
        orientation, are stored once.
        """
        self.vertices = list(dict.fromkeys(vertexlist))
        self.neighbors = {v: [] for v in self.vertices}
        self.edges = []
        for e in edgelist:
            (u, v) = e
            self.add_vertex(u)
            self.add_vertex(v)
            # Keep `edges` in one-to-one correspondence with the adjacency lists.
            if v not in self.neighbors[u]:
                self.neighbors[u].append(v)
                self.neighbors[v].append(u)
                self.edges.append(e)
        self.leaves = [w for w in self.vertices if len(self.neighbors[w]) == 1]

    def incident_edges(self, u):
        """Edges touching ``u`` as ``{u, v}`` sets, whatever orientation they were stored in.

        Returned in adjacency order; empty if ``u`` is unknown or isolated.
        """
        return [{u, v} for v in self.neighbors.get(u, [])]

    def find_neighbors(self, u):
        """Neighbours of ``u`` in vertex order (empty if ``u`` is unknown)."""
        adjacent = set(self.neighbors.get(u, ()))
        return [v for v in self.vertices if v in adjacent]

    def add_vertex(self, u):
        """Add ``u`` as an isolated vertex (not a leaf) if it is not already present."""
        if u not in self.neighbors:
            self.vertices.append(u)
            self.neighbors[u] = []

    def _refresh_leaf(self, w):
        """Keep ``w``'s membership in ``self.leaves`` in sync with its degree."""
        is_leaf = len(self.neighbors[w]) == 1
        if is_leaf and w not in self.leaves:
            self.leaves.append(w)
        elif not is_leaf and w in self.leaves:
            self.leaves.remove(w)

    def add_edge(self, e):
        """Add the undirected edge ``e`` (an ``(u, v)`` tuple or ``{u, v}`` set).

        Missing endpoints are created. Does nothing if the edge is already
        present in either orientation.
        """
        (u, v) = e
        self.add_vertex(u)
        self.add_vertex(v)
        if v in self.neighbors[u]:
            return
        self.neighbors[u].append(v)
        self.neighbors[v].append(u)
        self._refresh_leaf(u)
        self._refresh_leaf(v)
        self.edges.append(e)

    def remove_edge(self, e):
        """Remove the undirected edge ``e``, given in either orientation.

        Raises
        ------
        KeyError
            If the first endpoint is not a vertex.
        ValueError
            If the edge is not in the graph. The graph is left unchanged.
        """
        (u, v) = e
        # Validate before mutating so a bad call cannot corrupt leaves/neighbors.
        if v not in self.neighbors[u]:
            raise ValueError(f"edge {e!r} not in graph")
        self.neighbors[u].remove(v)
        self.neighbors[v].remove(u)
        endpoints = {u, v}
        for i, stored in enumerate(self.edges):
            if set(stored) == endpoints:
                del self.edges[i]
                break
        self._refresh_leaf(u)
        self._refresh_leaf(v)

In [5]:
def make_decoder_graph(S):
    """Build the decoder graph for a list of stabilizers.

    Parameters
    ----------
    S : sequence of str
        Stabilizers as equal-length '0'/'1' strings. Character ``q`` is '1'
        when the stabilizer acts on qubit ``q``.

    Returns
    -------
    (graph, edgedict)
        The graph has one vertex per stabilizer plus the boundary vertex 'b'.
        Two stabilizers are joined when they share a qubit. A stabilizer is
        joined to 'b' for every qubit it acts on that no pair of stabilizers
        shares. The edgedict maps each edge to its qubit indices in
        increasing order. The edge list given to the graph holds each edge once.
    """
    n = len(S[0]) if S else 0
    vertices = list(S)
    vertices.append('b')
    edges_by_qubit = edgedict({})
    edges = []
    shared_qubits = set()

    def connect(u, v, q):
        # Record qubit q on edge (u, v); the first time the edge is seen, list it in `edges`.
        if (u, v) not in edges_by_qubit.keys() and (v, u) not in edges_by_qubit.keys():
            edges_by_qubit[(u, v)] = []
            edges.append((u, v))
        edges_by_qubit[(u, v)].append(q)

    for s1, s2 in itertools.combinations(S, 2):
        for q in range(n):
            if s1[q] == '1' and s2[q] == '1':
                connect(s1, s2, q)
                shared_qubits.add(q)
    # Qubits no two stabilizers share are represented by edges to the boundary.
    for q in range(n):
        if q in shared_qubits:
            continue
        for s in S:
            if s[q] == '1':
                connect(s, 'b', q)
    return graph(vertices, edges), edges_by_qubit

# Union Find Decoder Part 1 (Construct Modified Erasure)

In [9]:
def _merge_clusters(clusters, boundaries, u, v):
    """Union the clusters of ``u`` and ``v``; the surviving root inherits both boundary lists.

    ``clusters`` is the decoder's ``unionfind.UnionFind``. ``boundaries`` maps
    each cluster root to its list of boundary vertices. Unknown vertices are
    first added as singleton clusters. No-op if ``u`` and ``v`` are already
    connected.
    """
    for w in (u, v):
        if w not in clusters._elts:
            clusters.add(w)
            boundaries[w] = [w]
    if clusters.connected(u, v):
        return
    root_u = clusters[clusters.find(u)]
    root_v = clusters[clusters.find(v)]
    clusters.union(u, v)
    # Ask the union-find which root survived instead of predicting its
    # size/tie-breaking rule, so the merged list always sits on the real root.
    new_root = clusters[clusters.find(u)]
    absorbed = root_v if new_root == root_u else root_u
    boundaries[new_root].extend(boundaries[absorbed])


def construct_modified_erasure(decoder_graph, syn, erasure):
    """Grow clusters from the syndrome until none is odd (union-find decoder, part 1).

    Every cluster holding an odd number of marked vertices repeatedly grows by
    half an edge along its boundary. Clusters joined by a fully grown edge are
    merged.

    Parameters
    ----------
    decoder_graph : graph
        Decoder graph, e.g. from ``make_decoder_graph``. Only its
        ``vertices``, ``edges`` and ``neighbors`` are read.
    syn : dict
        Syndrome bit (0/1) for every vertex except the boundary 'b'. Not modified.
    erasure : list of edges
        Initial erasure. Its edges start fully grown, and their endpoints start
        in one cluster. Not modified.

    Returns
    -------
    list
        The modified erasure: the initial erasure followed by every other
        decoder-graph edge that became fully grown.

    Raises
    ------
    RuntimeError
        If an odd cluster has no edge left to grow (the loop would never end).
    """
    # 1: initialize syndrome, clusters, support and boundary lists.
    syndrome = syn.copy()
    # 'b' takes the overall parity, so the total number of marked vertices is even.
    syndrome['b'] = sum(syndrome.values()) % 2

    clusters = unionfind.UnionFind(decoder_graph.vertices)
    # Orientation-free view of the initial erasure.
    erased = {frozenset(e) for e in erasure}
    support = edgedict({e: 1 if frozenset(e) in erased else 0 for e in decoder_graph.edges})
    boundaries = {s: [s] for s in decoder_graph.vertices}
    # Initial erasure edges are already fully grown, so merge their endpoints now.
    for (u, v) in erasure:
        _merge_clusters(clusters, boundaries, u, v)

    def is_odd(vertex):
        """True if the cluster of ``vertex`` has an odd number of marked vertices."""
        return sum(syndrome[w] for w in clusters.component(vertex)) % 2 == 1

    def is_interior(vertex):
        """True once ``vertex`` has at least one edge and every one is fully grown."""
        nbrs = decoder_graph.neighbors.get(vertex, [])
        return bool(nbrs) and all(support[(vertex, n)] == 1 for n in nbrs)

    # 2: roots of the clusters with an odd number of marked vertices.
    odd_roots = []
    for component in clusters.components():
        if sum(syndrome[w] for w in component) % 2 == 1:
            odd_roots.append(clusters[clusters.find(next(iter(component)))])

    # 3: repeat while any cluster has an odd number of marked vertices.
    while odd_roots:
        # 4-5: grow each odd cluster by half an edge along its boundary,
        # collecting the edges that become fully grown.
        fusion = []
        grew = False
        for root in odd_roots:
            for b in boundaries[clusters[clusters.find(root)]]:
                for n in decoder_graph.neighbors.get(b, []):
                    if support[(b, n)] < 1:
                        support[(b, n)] += 0.5
                        grew = True
                        if support[(b, n)] == 1:
                            fusion.append((b, n))
        if not grew:
            raise RuntimeError(
                "an odd cluster has no edge left to grow; check that every connected "
                "component of the decoder graph has even syndrome parity")

        # 6-7: merge the clusters joined by each newly full edge.
        for (u, v) in fusion:
            _merge_clusters(clusters, boundaries, u, v)

        # 8-10: keep each still-odd cluster once, under its current root, and
        # prune its boundary to vertices that still have an edge to grow.
        # Clusters that became even are simply not carried over.
        next_roots = []
        for old_root in odd_roots:
            root = clusters[clusters.find(old_root)]
            if root in next_roots or not is_odd(root):
                continue
            next_roots.append(root)
            # Rebuild rather than remove while iterating; dict.fromkeys drops repeats.
            boundaries[root] = [w for w in dict.fromkeys(boundaries[root]) if not is_interior(w)]
        odd_roots = next_roots

    # 11: modified erasure = initial erasure + every other fully grown edge.
    modified_erasure = list(erasure)
    for edge in support.keys():
        if support[edge] == 1 and frozenset(edge) not in erased:
            modified_erasure.append(edge)
    return modified_erasure

# Union Find Decoder Part 2 (Apply Peeling Decoder)

In [37]:
####12 apply peeling decoder to the erasure
def peeling_decoder(erasure, edges_to_qubits, num_qs, decoder_graph, syn):
    """Peel a spanning forest of the erasure to get a correction (union-find decoder, part 2).

    Parameters
    ----------
    erasure : list of edges
        Modified erasure, e.g. from ``construct_modified_erasure``. Each edge is
        a 2-element pair of decoder-graph vertices.
    edges_to_qubits : mapping
        Maps each edge to the list of qubit indices it stands for, as returned
        by ``make_decoder_graph``. It must accept either orientation of an edge
        (e.g. an ``edgedict``). The first qubit listed is used.
    num_qs : int
        Number of physical qubits; the length of the returned correction.
    decoder_graph : graph
        Unused; kept for interface compatibility.
    syn : dict
        Syndrome bit (0/1) for every vertex except the boundary 'b'. Not modified.

    Returns
    -------
    list of int
        ``correction[q] == 1`` when qubit ``q`` is flipped. An empty erasure
        gives an all-zero list.
    """
    from collections import deque

    syndrome = syn.copy()
    # 'b' carries the overall syndrome parity.
    syndrome['b'] = sum(syndrome.values()) % 2

    # Always a 0/1 list, so callers get the same type whether or not anything is peeled.
    correction = [0] * num_qs
    if not erasure:
        return correction

    # Vertices in order of first appearance, so the result does not depend on
    # string-hash randomization (iterating a set of strings would).
    erasure_vertices = list(dict.fromkeys(w for edge in erasure for w in edge))
    erasure_graph = graph(erasure_vertices, erasure)

    # 1: BFS spanning forest of the erasure, one tree per connected component.
    forest = graph([], [])
    reached = set()
    for root in erasure_vertices:
        if root in reached:
            continue
        reached.add(root)
        forest.add_vertex(root)
        queue = deque([root])
        while queue:
            u = queue.popleft()
            for v in erasure_graph.neighbors[u]:
                if v not in reached:
                    reached.add(v)
                    forest.add_edge({u, v})
                    queue.append(v)

    # 2-4: repeatedly remove a pendant edge. If its leaf is marked, the edge
    # goes into the correction and the mark moves to the other endpoint.
    flipped_qubits = []
    while forest.edges:
        u = forest.leaves[0]          # pendant vertex
        v = forest.neighbors[u][0]    # its only neighbour
        forest.remove_edge({u, v})
        if syndrome[u] == 1:
            flipped_qubits.append(edges_to_qubits[(u, v)][0])
            syndrome[u] = 0
            syndrome[v] ^= 1
    # NOTE(review): a qubit reached through two peeled edges is set, not toggled - confirm intended.
    for q in flipped_qubits:
        correction[q] = 1
    return correction


# Testing

In [19]:
# Five stabilizers on 42 qubits; consecutive stabilizers overlap on 7 qubits.
stabilizers = [
    '111111111111110000000000000000000000000000',
    '000000011111111111111000000000000000000000',
    '000000000000001111111111111100000000000000',
    '000000000000000000000111111111111110000000',
    '000000000000000000000000000011111111111111',
]
decoder_graph, edges_by_qubit = make_decoder_graph(stabilizers)

In [16]:
# Qubit indices behind each decoder-graph edge (a stabilizer pair, or a stabilizer and the boundary 'b').
edges_by_qubit

{('111111111111110000000000000000000000000000',
  '000000011111111111111000000000000000000000'): [7, 8, 9, 10, 11, 12, 13],
 ('000000011111111111111000000000000000000000',
  '000000000000001111111111111100000000000000'): [14, 15, 16, 17, 18, 19, 20],
 ('000000000000001111111111111100000000000000',
  '000000000000000000000111111111111110000000'): [21, 22, 23, 24, 25, 26, 27],
 ('000000000000000000000111111111111110000000',
  '000000000000000000000000000011111111111111'): [28, 29, 30, 31, 32, 33, 34],
 ('111111111111110000000000000000000000000000', 'b'): [0, 1, 2, 3, 4, 5, 6],
 ('000000000000000000000000000011111111111111', 'b'): [35,
  36,
  37,
  38,
  39,
  40,
  41]}

In [17]:
# Only the first stabilizer reports a -1 outcome (is marked); the others read 0.
syndrome = dict.fromkeys([
    '111111111111110000000000000000000000000000',
    '000000011111111111111000000000000000000000',
    '000000000000001111111111111100000000000000',
    '000000000000000000000111111111111110000000',
    '000000000000000000000000000011111111111111',
], 0)
syndrome['111111111111110000000000000000000000000000'] = 1

In [20]:
# Grow clusters starting from an empty initial erasure.
new_erasure = construct_modified_erasure(decoder_graph, syn=syndrome, erasure=[])


In [21]:
# Edges of the modified erasure (fully grown edges returned by construct_modified_erasure).
new_erasure

[('111111111111110000000000000000000000000000', 'b')]

In [22]:
# Number of qubits = length of a stabilizer string. Read it from a stabilizer
# vertex of the decoder graph. new_erasure[0][0] raises IndexError on an empty
# erasure and gives len('b') == 1 when the first edge starts at the boundary vertex.
num_qs = len(next((v for v in decoder_graph.vertices if v != 'b'), ''))

In [38]:
# Correction vector: entry q is 1 when qubit q is flipped.
peeling_decoder(
    erasure=new_erasure,
    edges_to_qubits=edges_by_qubit,
    num_qs=num_qs,
    decoder_graph=decoder_graph,
    syn=syndrome,
)

111111111111110000000000000000000000000000
neighbors ['b']
b
neighbors ['111111111111110000000000000000000000000000']
[{'111111111111110000000000000000000000000000', 'b'}]


[1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]