In [390]:
import matplotlib.pyplot as plt
import numpy as np
from typing import Callable, List, Dict, Any, Set, FrozenSet, Iterable, Tuple
import math
import random
import networkx as nx

import pandas as pd
import os
import galois

from TCd_decoder import *

import itertools

In [391]:
GF=galois.GF(3)

def n_hood_minus_q(T, c,q):
        n_hood_minus_q = []
        for q1 in nx.neighbors(T, c):
            if q1 !=q:
                n_hood_minus_q.append(q1)
        #make sure neighbors are ordered    

        return sorted(n_hood_minus_q)

In [392]:
code = TC3(6,6)
code.generate_code()

In [393]:
sorted(nx.neighbors(code.T,10))

[-12, -6]

In [474]:
class BP:
    def __init__(self, T, p): # initialize the massages
        self.T = T
        self.p = p
    
        #Temp function specific to qutrit
        self.P = np.array([1-2*self.p, self.p,self.p], dtype= float)
        m_qc = {}
        m_cq = {}
        for q in self.T.VD: # sum over data qubits
            for c in nx.neighbors(self.T, q):
                m_qc[q,c] = np.array([1-2*self.p, self.p,self.p], dtype= float)
                
                #m_qc[q,c,0] = 1-2*self.p
                #m_qc[q,c,1] = self.p
                #m_qc[q,c,2] = self.p

        for q in self.T.VD:
            for c in nx.neighbors(self.T, q):
                m_cq[q,c] = np.array([0,0,0],dtype=float)
                
                #m_cq[q,c,0] = 0
                #m_cq[q,c,1] = 0
                #m_cq[q,c,2] = 0
    
        self.mqc = m_qc
        self.mcq = m_cq

    def error(self):
        self.error = np.random.binomial(3,self.p, len(self.T.VD)).astype(int)

    #def syndrome_match(T, c, E):
     #   sc = T.H[-c-1] @ GF(self.error)   

      #  n_hoodSc = T.H[-c-1][sorted(nx.neighbors(T, c))] @ GF(E)

       # return sc == n_hoodSc

    def message_pass(self):

        syndrome_match = lambda c,E: self.T.H[-c-1] @ GF(self.error)  == self.T.H[-c-1][sorted(nx.neighbors(self.T, c))] @ GF(E)


        for c in self.T.VC:
            num_nbrs = len(sorted(nx.neighbors(self.T, c)))
            sc = self.T.H[-c-1] @ GF(self.error)
            #print(num_nbrs)

            for i, qi in enumerate(sorted(nx.neighbors(self.T, c))):
                error_list=[]
                for e in itertools.product(range(3),repeat=num_nbrs):
                    error_list.append(np.asarray(e, dtype=int))
                    #print(np.array(i))
                
                for pauli in range(3):
                    for error in error_list:
                        E = error
                        E[i] = pauli

                        if syndrome_match(c, E):
                            self.mcq[qi,c][pauli] += np.array([self.mqc[qj, c][E[j]] for j,qj in enumerate(sorted(nx.neighbors(self.T, c))) if qj !=qi]).prod()
                            #print(f'successful pass with E = {E} and pauli = {pauli} and syndrome = {sc}')

        #now normalize
        for c in self.T.VC:
            for q in sorted(nx.neighbors(self.T, c)):
                message_norm = np.sqrt((self.mcq[q,c]**2).sum())
                self.mcq[q,c] *= 1/message_norm
        #Now update the message mqc

        for q in self.T.VD:
            num_nbrs_q = len(set(nx.neighbors(self.T, q)))
            for c in sorted(nx.neighbors(self.T, q)):
                for pauli in range(3):
                    self.mqc[q,c][pauli] = self.P[pauli] * np.array([self.mcq[q,c1][pauli] for c1 in sorted(nx.neighbors(self.T, q)) if c1 != c]).prod() 
                
            

        #now normalize
        for q in self.T.VD:
            for c in sorted(nx.neighbors(self.T, q)):
                message_norm = np.sqrt((self.mqc[q,c]**2).sum())
                self.mqc[q,c] *= 1/message_norm
        #Now update the message mqc
        
    def belief(self):
        
        self.bq = {}
        for q in self.T.VD:
            self.bq[q] = np.array([1-2*self.p, self.p,self.p], dtype= float)
            for c in sorted(nx.neighbors(self.T, q)):
                for pauli in range(3):
                    self.bq[q][pauli] *= self.mcq[q,c][pauli]

        self.pred_error = np.zeros(len(self.T.VD), dtype = int)
        for q in self.T.VD:
            self.pred_error[q] = self.bq[q].argmax()
        


In [506]:
BP_test = BP(code.T, .02)

In [507]:
BP_test.error()
BP_test.error


for i in range(10):
    BP_test.message_pass()

BP_test.belief()

In [508]:
BP_test.mcq

{(0, -1): array([1.00000000e+00, 5.76860377e-06, 5.46137240e-06]),
 (0, -7): array([9.99999998e-01, 6.99502528e-06, 6.00952311e-05]),
 (1, -1): array([1.00000000e+00, 8.57554733e-06, 3.78110954e-06]),
 (1, -2): array([1.00000000e+00, 2.46712313e-05, 1.15560876e-05]),
 (2, -2): array([1.00000000e+00, 7.96606860e-06, 2.25330246e-05]),
 (2, -8): array([9.99999999e-01, 2.78659974e-05, 4.71243158e-05]),
 (3, -2): array([1.00000000e+00, 7.82567815e-06, 5.62903801e-06]),
 (3, -3): array([9.99999981e-01, 1.90278676e-04, 4.44933392e-05]),
 (4, -3): array([9.99999983e-01, 2.52129913e-05, 1.84466588e-04]),
 (4, -9): array([9.99999977e-01, 8.87720633e-05, 1.93508547e-04]),
 (5, -3): array([9.99999999e-01, 4.16291401e-05, 8.88013576e-06]),
 (5, -4): array([9.99998285e-01, 1.85190211e-03, 2.33689969e-05]),
 (6, -4): array([9.99998383e-01, 2.41762131e-05, 1.79824652e-03]),
 (6, -10): array([9.99999670e-01, 8.12120084e-04, 2.52476614e-05]),
 (7, -4): array([9.99998285e-01, 2.51730748e-05, 1.85186049e-

In [509]:
BP_test.bq

{0: array([9.59999998e-01, 8.07030585e-13, 6.56404872e-12]),
 1: array([9.60000000e-01, 4.23138623e-12, 8.73896661e-13]),
 2: array([9.59999998e-01, 4.43964894e-12, 2.12370673e-11]),
 3: array([9.59999982e-01, 2.97811936e-11, 5.00909394e-12]),
 4: array([9.59999962e-01, 4.47641853e-11, 7.13917229e-10]),
 5: array([9.59998353e-01, 1.54186185e-09, 4.15039730e-12]),
 6: array([9.59998131e-01, 3.92679764e-10, 9.08030387e-10]),
 7: array([9.59998353e-01, 5.13283253e-12, 8.15936270e-10]),
 8: array([9.59999976e-01, 6.39172720e-11, 2.35023072e-11]),
 9: array([9.59999982e-01, 1.21906453e-12, 1.74102127e-11]),
 10: array([9.60000000e-01, 1.86861077e-12, 4.97296177e-12]),
 11: array([9.60000000e-01, 5.21913006e-13, 3.85486761e-12]),
 12: array([9.59999690e-01, 3.06850510e-12, 7.77093698e-11]),
 13: array([9.59999997e-01, 3.33901259e-11, 4.30449021e-12]),
 14: array([9.59999956e-01, 8.49930286e-11, 5.49323264e-11]),
 15: array([9.59999976e-01, 5.88592963e-11, 8.54004245e-11]),
 16: array([9.5999

In [510]:
code.H @ (GF(BP_test.pred_error)-GF(BP_test.error))

GF([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], order=3)

In [511]:
GF(BP_test.pred_error)-GF(BP_test.error)

GF([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0], order=3)