In [1]:
import numpy as np
from itertools import chain, combinations
from scipy import sparse

In [4]:
# configuration model
# https://en.wikipedia.org/wiki/Configuration_model

n = 16
deg_v = 3 # w_c. Every bit is in this many checks
deg_c = 4 # w_r. Every check has this many bits in it
num_checks = (n*deg_v)//deg_c
k = n - num_checks

vs = np.array([[j for i in range(deg_v)] for j in range(n)]).flatten()
cs = np.array([[j for i in range(deg_c)] for j in range(num_checks)]).flatten()

H = np.zeros((num_checks, n), dtype=bool)

while (vs.size and cs.size):
    # choose random 'stub' from each array
    double_edge = True
    while(double_edge):
        v_ind = np.random.randint(0, len(vs))
        c_ind = np.random.randint(0, len(cs))

        if (H[cs[c_ind]][vs[v_ind]] != 1):
            double_edge = False
            H[cs[c_ind]][vs[v_ind]] = 1
            vs = np.delete(vs, v_ind)
            cs =np.delete(cs, c_ind)

H = sparse.csc_matrix(H)

In [None]:
from classical_code import *

write_code('./ldpc_codes/16_12_3_4.txt', )

In [6]:
class HypergraphProduct(object):
    def __init__(self, H):
        hx1 = sparse.kron(H, np.eye(H.shape[1], dtype=bool))
        hx2 = sparse.kron(np.eye(H.shape[0], dtype=bool), H.T)
        self.Hx = sparse.csr_matrix(sparse.hstack([hx1, hx2]))

        hz1 = sparse.kron(np.eye(H.shape[1], dtype=bool), H)
        hz2 = sparse.kron(H.T, np.eye(H.shape[0], dtype=bool))
        self.Hz = sparse.csr_matrix(sparse.hstack([hz1, hz2]))

        self.n = self.Hx.shape[1]
        self.k = self.Hx.shape[0]
        self.stabilizers = set(np.arange(self.k))

        self.Fx = np.array([list(powerset(self.Hx[i].indices))[1:] for i in range(self.k)], dtype=object)
        self.Fz = np.array([list(powerset(self.Hz[i].indices))[1:] for i in range(self.k)], dtype=object)

        self.sigma_Fx = np.array([[syn_from_F(g, self.Hz) for g in F] for F in self.Fx], dtype=object)
        self.sigma_Fz = np.array([[syn_from_F(g, self.Hx) for g in F] for F in self.Fz], dtype=object)

    def remove_stabilizers(self, indices=None, num=None):
        if (num):
            indices = set(np.random.choice(self.k, num, replace=False))
        indices = list(self.stabilizers ^ indices)

        return (self.Hx[indices], self.Hz[indices], 
            self.Fx[indices].flatten(), self.Fz[indices].flatten(), 
            self.sigma_Fx[indices].flatten(), self.sigma_Fz[indices].flatten())

    def get_all(self):
        return (self.Hx, self.Hz, self.Fx.flatten(), self.Fz.flatten(), self.sigma_Fx.flatten(), self.sigma_Fz.flatten())

    def powerset(iterable):
        s = list(iterable)
        return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

    def syn_from_F(F, H):
        eF = np.zeros(H.shape[1], dtype=bool)
        np.put(eF, F, [1])
        return set(np.where(H.dot(eF) % 2)[0])

In [7]:
hx1 = sparse.kron(H, np.eye(H.shape[1], dtype=bool))
hx2 = sparse.kron(np.eye(H.shape[0], dtype=bool), H.T)
Hx = sparse.csr_matrix(sparse.hstack([hx1, hx2], ))

hz1 = sparse.kron(np.eye(H.shape[1], dtype=bool), H)
hz2 = sparse.kron(H.T, np.eye(H.shape[0], dtype=bool))
Hz = sparse.csr_matrix(sparse.hstack([hz1, hz2]))

In [8]:
2**Hx[0].nnz*Hx.shape[0] # number of error syndromes we have to check

24576

In [9]:
def powerset(iterable):
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def syn_from_F(F, H):
    eF = np.zeros(Hx.shape[1], dtype=bool)
    np.put(eF, F, [True])
    return set(np.where(H.dot(eF) % 2)[0])

Fx = [list(powerset(Hx[i].indices))[1:] for i in range(Hx.shape[0])]
Fx = list(chain(*Fx)) # can take set of this list for slight reduction in size
# Fx_coords = np.array([(item, i) for i, j in enumerate(Fx) for item in j])
# Fx_arr = sparse.coo_matrix(([True for i in range(len(Fx_coords))], (Fx_coords[:,0], Fx_coords[:,1])), dtype=bool).tocsc()
# sigma_Fx = (Hz.tocsc() @ Fx_arr).transpose()

# Fx = [set(g) for g in Fx] 
Fz = [list(powerset(Hz[i].indices))[1:] for i in range(Hz.shape[0])]
Fz = list(chain(*Fz))
# Fz_coords = np.array([(item, i) for i, j in enumerate(Fz) for item in j])
# Fz_arr = sparse.coo_matrix(([True for i in range(len(Fz_coords))], (Fz_coords[:,0], Fz_coords[:,1])), dtype=bool).tocsc()
# sigma_Fz = (Hx.tocsc() @ Fz_arr).transpose()

# Fz = [set(g) for g in Fz]

sigma_Fx = [syn_from_F(g, Hz) for g in Fx] # set of indices where syndrome is 1
sigma_Fz = [syn_from_F(g, Hx) for g in Fz]

In [10]:
def ssf(syn, x_z):
    # given a syndrome, syn: sigma_x or sigma_z 
    # x_z false for x stabilizers, true for z stabilizers
    s = set(np.where(syn.copy())[0])
    e = set()
    F = Fx if x_z else Fz
    sigma_F = sigma_Fx if x_z else sigma_Fz
    
    while True:
        max = -1
        max_gen = None
        max_sigma_gen = None
        for g, sigma_g in zip(F, sigma_F):
            s_i = s ^ sigma_g
            if (len(s_i) < len(s)):
                rel_weight = (len(s) - len(s_i)) / len(g)
                if (rel_weight > max):
                    max = rel_weight
                    max_gen = g
                    max_sigma_gen = sigma_g
            else:
                continue

        if (max == -1):
            if (len(s) == 0):
                return e
            else:
                return "FAIL"
        else:
            e = e ^ set(max_gen)
            s = s ^ max_sigma_gen

In [152]:
# hgp = HypergraphProduct(H)
# Hx, Hz, Fx, Fz, sigma_Fx, sigma_Fz = hgp.remove_stabilizers(num=2)

In [11]:
p = 0.01

# how is this able to decode an error with more than (d-1)/2 errors?
sum = 0

for i in range(1):
    eX = [1 if np.random.uniform() < p else 0 for i in range(Hx.shape[1])]
    # eZ = [True if np.random.uniform() < p else False for i in range(Hz.shape[1])]
    sigma_eX = Hx.dot(eX) % 2
    # s = set(np.where(sigma_eX.copy())[0])
    # sigma_eZ = np.dot(Hz, eZ)
    e1 = ssf(sigma_eX, False)
    # e2 = ssf(sigma_eZ, True)
    if(e1 == set(np.where(eX)[0])):
        sum += 1
    # new_e = e1 ^ set(np.where(eX)[0])
    # print(len(np.where(eX)[0]), len(np.where(eZ)[0]))

In [12]:
sum

1

In [182]:
from math import comb
import math
n = Hx.shape[1]
delta_g = 7
print(n, int(p*n))
num_possible_errors = comb(n, int(p*n))
print(n-delta_g, int(p*n))
disjoint_errors = comb(n-delta_g, int(p*n))
print(disjoint_errors/num_possible_errors)

1225 12
1218 12
0.93325230169843
