In [1]:
import unionfind

# Data Structures

In [121]:
### graph data structure
class graph:
    def __init__(self, vertexlist, edgelist):

        self.vertices = vertexlist
        self.edges = edgelist
        self.neighbors = {v: [] for v in vertexlist}
        self.leaves = vertexlist.copy()
        for (u,v) in edgelist:
            self.neighbors[u].append(v)
            self.neighbors[v].append(u)
            if len(self.neighbors[u]) > 1:
                self.leaves.remove(u)
            if len(self.neighbors[v]) > 1:
                self.leaves.remove(v)

    def incident_edges(self, u):
        incident = []
        for v in self.vertices:
            if (u,v) in self.edges: 
                incident.append({u,v})
        return incident

    def find_neighbors(self, u):
        neighbors = []
        for v in self.vertices:
            if (u,v) in self.edges or (v,u) in self.edges: 
                neighbors.append(v)
        return neighbors
    
    def add_vertex(self, u):
        if u not in self.vertices:
            self.vertices.append(u)
            self.neighbors[u] = []
            self.leaves.append(u)
            
    def add_edge(self, e):
        if e not in self.edges:
            (u,v) = e
            if u not in self.vertices:
                self.add_vertex(u)
            if v not in self.vertices:
                self.add_vertex(v)
            self.neighbors[u].append(v)
            self.neighbors[v].append(u)
            if len(self.neighbors[u]) > 1:
                self.leaves.remove(u)
            if len(self.neighbors[v]) > 1:
                self.leaves.remove(v)
            self.edges.append(e)
        
    def remove_edge(self, e):
        (u,v) = e
        if len(self.neighbors[u]) <= 1:
            self.leaves.remove(u)
            self.leaves.append(v)
        if len(self.neighbors[v]) <= 1:
            self.leaves.remove(v)
            self.leaves.append(u)
        self.neighbors[u].remove(v)
        self.neighbors[v].remove(u)
        self.edges.remove(e)


In [4]:
class edgedict(dict):
    def __getitem__(self, key):
        if key in self:
            return super().__getitem__(key)
        return super().__getitem__(tuple(reversed(key)))
    def __setitem__(self, key, value):
        if tuple(reversed(key)) in self:
            return super().__setitem__(tuple(reversed(key)), value)
        else: 
            return super().__setitem__(key, value)

# Union Find Decoder Part 1 (Construct Modified Erasure)

In [127]:
def construct_modified_erasure(decoder_graph, syndrome, erasure):
    #1: initialize clusers, support list, boundary lists
    clusters = unionfind.UnionFind(syndrome) #I'm assuming the initial erasure is empty here
    support = edgedict({e: 1 if e in erasure else 0 for e in decoder_graph.edges})
    boundaries = {s: [s] for s in syndrome}#{s: decoder_graph.neighbors[s] for s in syndrome}
    #2: list all clusters with an odd number of marked vertices
    L = [clusters[clusters.find(list(v)[0])]  for v in clusters.components() if len(v) % 2 == 1]
    #3 while there are odd clusters
    while L != []:
        #4: initialize empty fusion list
        fusion = []
        print('current L: ',L)
        #5 for all u in L, grow the cluster by half an edge
        for u in L:
            #grow the cluster
            root_u = clusters[clusters.find(u)]
            for b in boundaries[root_u]:
                for n in decoder_graph.find_neighbors(b):
                    #grow each edge from the boundary by 0.5
                    if support[(b,n)] < 1:
                        support[(b, n)] += 0.5
                    #add new edges to fusion list
                    if support[(b,n)] == 1:
                        fusion.append((b,n))
        #6: for all edges in fusion list, union the clusters if needed
        for edge in fusion:
            u = edge[0]
            v = edge[1]
            if u not in clusters._elts:
                clusters.add(u)
                boundaries[u] = [u]
            if v not in clusters._elts:
                clusters.add(v)
                boundaries[v] = [v]
            if not clusters.connected(u,v):
                #7 for all edges in fusion list, update the boundary lists
                #####but didn't we just union the clusters?####
                root_u = clusters[clusters.find(u)]
                root_v = clusters[clusters.find(v)]
                if len(clusters.component(root_u)) > len(clusters.component(root_v)):
                #append boundary list of v to boundary list of u
                    boundaries[root_u].extend(boundaries[root_v])
                    #boundaries[root_u].remove(u)
                    #boundaries[root_u].remove(v)
                else:
                #append boundary list of u to boundary list of v
                    boundaries[root_v].extend(boundaries[root_u])
                    #boundaries[root_v].remove(u)
                    #boundaries[root_v].remove(v)
                clusters.union(u,v)
            else:
                fusion.remove(edge)

        #8 replace each u in L with find(u) (new root)
        Lnew = []
        for u in L:
            u_new = clusters[clusters.find(u)]
            if u_new not in Lnew:
                Lnew.append(u_new)
                ####9: remove vertices in boundary list of u that are not boundary vertices
                for v in boundaries[u_new]:
                    #if all edges have support 1, remove from boundary
                    if all(support[tuple(s)]==1 for s in decoder_graph.incident_edges(v)): boundaries[u_new].remove(v)


        L = Lnew
        #10
        if len(clusters.component(u)) % 2 == 0 and u in L:
            L.remove(u)
    #11
    for edge in support.keys():
        if support[edge] == 1: #fully-grown edge
            erasure.append(edge)
    #result at this point is a list of edges making up the modified erasure
    return erasure

current L:  [1, 2, 3]
current L:  [2]


# Union Find Decoder Part 2 (Apply Peeling Decoder)

In [132]:
####12 apply peeling decoder to the erasure
def peeling_decoder(erasure):
    erasure_vertices = set(list(zip(*erasure))[0])
    erasure_vertices.update(set(list(zip(*erasure))[1]))
    erasure_vertices = list(erasure_vertices)
    #1: construct spanning forest of erasure
    erasure_forest = graph([],[]) #need vertex list for modified erasure
    r = erasure_vertices[0] #pick a root
    remaining_vertices = erasure_vertices[1:]
    visited = set()
    queue = [r]

    while remaining_vertices: 
        #find a connected component
        while queue:
            u = queue.pop(0)
            if remaining_vertices:
                erasure_forest.add_vertex(u)
                neighbors = decoder_graph.neighbors[u]
                for v in neighbors:
                    if v not in visited:
                        visited.add(u)
                        erasure_forest.add_edge({u,v}) #modify add_edge to do this automatically
                        queue.append(v)
                        if v in remaining_vertices:
                            remaining_vertices.remove(v)
                            #need to break out of the queue loop if remaining_vertices is empty, possibly by removing from queue?


    #2 initialize A
    A = []
    #3: while the forest is nonempty
    while erasure_forest.edges != []:
        #pick a leaf and remove e, with pendant vertex u
        #get vertices, id pendant vertex
        print(erasure_forest.edges)
        print(erasure_forest.leaves)
        u = erasure_forest.leaves[0]
        print(u)
        v = erasure_forest.neighbors[u][0]
        e = {u,v}
        print(e)
        erasure_forest.remove_edge(e) #need something that updates the leaves when removing edges
        #4: if u is in the syndrome, add e to A, remove u from syndrome, and flip v in syndrome
        if u in syndrome:
            A.append(e)
            syndrome.remove(u)
            if v in syndrome:
                syndrome.remove(v)
            else:
                syndrome.append(v)    
        #(otherwise, do nothing)
    #return product of Z_e for e in A
    P = str()
    for e in A:
        P += 'Z'+str(e)
    return P


[{0, 1}, {0, 3}, {1, 2}]
[3, 2]
3
{0, 3}
[{0, 1}, {1, 2}]
[2, 0]
2
{1, 2}
[{0, 1}]
[0, 1]
0
{0, 1}
Z{0, 3}Z{1, 2}Z{0, 1}


# Testing

In [123]:
decoder_graph = graph([0, 1, 2, 3], [(0, 1), (1, 2), (2, 3), (3, 0)])


In [124]:
#need: 1) decoder graph 
decoder_graph = graph([0, 1, 2, 3], [(0, 1), (1, 2), (2, 3), (3, 0)])
#print(decoder_graph.vertices)
#syndrome 2) list of -1 syndrome vertices
syndrome = [1,2,3]
#erasure 3) list of erased edges
erasure = []
#produce modified erasure to apply peeling decoder to

In [None]:
erasure = construct_modified_erasure(decoder_graph, syndrome, erasure)
print(erasure)

In [None]:
correction = apply_peeling_decoder(erasure)
print(erasure)