In [10]:
from typing import List
from math import ceil
import numpy as np
import networkx as nx

## Lookup table decoder

In [6]:
def lookup_decode(syndrome: np.ndarray) -> np.ndarray:
    """Decode a 2-entry syndrome of the distance-3 repetition code by table lookup.

    Every recognised syndrome is paired with the correction it implies: no
    flip for the all-zero syndrome, otherwise a single flipped bit. The table
    is rebuilt on each call, so the returned array is always a fresh object
    the caller may modify.

    Args:
        syndrome: Array containing exactly two entries (the two check outcomes).

    Returns:
        A length-3 boolean correction vector.

    Raises:
        AssertionError: If ``syndrome`` does not contain exactly two entries.
        ValueError: If ``syndrome`` matches none of the table patterns.
    """
    assert syndrome.size == 2

    # (syndrome pattern, correction) pairs, tried in order.
    table = (
        ([False, False], [False, False, False]),
        ([False, True], [False, False, True]),
        ([True, False], [True, False, False]),
        ([True, True], [False, True, False]),
    )
    for pattern, correction in table:
        if np.all(syndrome == np.array(pattern)):
            return np.array(correction)
    raise ValueError("Unrecognized syndrome.")

In [7]:
H = np.array([[True, True, False], [False, True, True]])
err = np.array([True, False, True])
# `@` on boolean arrays combines products with logical OR, not XOR, so it is
# not a GF(2) syndrome. Multiply over the integers, reduce mod 2, then go back
# to booleans.
syndrome = ((err.astype(int) @ H.T.astype(int)) % 2).astype(bool)
print(lookup_decode(syndrome))

[False  True False]


## Hard input decoder

In [8]:
def hard_input_decode(syndrome: np.ndarray, H: np.ndarray) -> np.ndarray:
    """Correct at most one bit with the hard-input syndrome algorithm.

    An all-zero syndrome means every parity check is satisfied, so no bit is
    flipped. Otherwise the syndrome is compared against each column of ``H``
    and the bit whose column matches first is flipped.

    Args:
        syndrome: Syndrome vector with one entry per row of ``H``.
        H: Parity-check matrix of shape ``(num_checks, num_bits)``.

    Returns:
        Boolean correction vector of length ``H.shape[1]`` with at most one
        ``True`` entry.

    Raises:
        AssertionError: If ``syndrome`` does not have one entry per row of ``H``.
        ValueError: If the syndrome is non-zero and matches no column of ``H``.
            ``ValueError`` subclasses ``Exception``, so handlers written for
            the previous bare ``Exception`` still catch it.
    """
    assert syndrome.size == H.shape[0]

    correction = np.zeros(H.shape[1], dtype=bool)

    # Zero syndrome: no detectable error, so nothing to flip.
    if not np.any(syndrome):
        return correction

    # Compare the syndrome with every column at once; column j matches when
    # all of its entries equal the corresponding syndrome entries.
    column_matches = np.all(H == np.reshape(syndrome, (-1, 1)), axis=0)
    matching_columns = np.flatnonzero(column_matches)
    if matching_columns.size == 0:
        raise ValueError(f"Matching column not found for syndrome {syndrome}.")
    correction[matching_columns[0]] = True
    return correction

In [9]:
# Parity-check matrix of the (7, 4) Hamming code, built directly as booleans.
H = np.array(
    [
        [1, 1, 0, 1, 1, 0, 0],
        [1, 0, 1, 1, 0, 1, 0],
        [0, 1, 1, 1, 0, 0, 1],
    ],
    dtype=bool,
)
syndrome = np.array([True, False, True])
correction = hard_input_decode(syndrome, H)
print(correction)

[False  True False False False False False]


## LDPC codes

In [21]:
def permutation_ldpc_matrix(
    m: int, n: int, w_r: int, w_c: int, permutations: List[List[int]]
) -> np.ndarray:
    """Construct the check matrix by permutations. See Example 15.3 of Moon.

    A base block ``h0`` of shape ``(n // w_r, n)`` is built whose row ``i`` is
    ``True`` in columns ``i * w_r`` through ``(i + 1) * w_r - 1``. Each entry
    of ``permutations`` reorders the columns of ``h0``: column ``j`` of the
    permuted block is column ``permutation[j]`` of ``h0``. The ``w_c``
    permuted blocks are then stacked vertically.

    m - The number of checks.
    n - The number of bits.
    w_r - The number of non-zero elements in a row.
    w_c - The number of non-zero elements in a column.
    permutations - A list of w_c permutations of range(n), e.g. [1, 2, 0].

    Returns a boolean matrix of shape (m, n).

    Raises AssertionError if the sizes are inconsistent or if an entry of
    ``permutations`` is not a permutation of range(n)."""

    assert n * w_c == m * w_r, "Need n * w_c == m * w_r."
    assert n % w_r == 0, "n must be a multiple of w_r."
    assert len(permutations) == w_c, "Need exactly w_c permutations."
    for permutation in permutations:
        # Repeated, missing or out-of-range indices would silently change the
        # row/column weights, so reject anything that is not a permutation.
        assert sorted(permutation) == list(range(n)), (
            f"Not a permutation of range({n}): {permutation}"
        )

    # Row i covers the i-th consecutive run of w_r bits (bit j lies in row j // w_r).
    h0 = np.repeat(np.eye(n // w_r, dtype=bool), w_r, axis=1)
    # Fancy indexing already returns a new array, so no explicit copy is needed.
    return np.vstack([h0[:, permutation] for permutation in permutations])

In [28]:
# Small example: each check covers w_r = 2 bits and each bit appears in
# w_c = 2 checks; the number of checks m follows from n * w_c == m * w_r.
w_r = 2
w_c = 2
n = 4
m = (n * w_c) // w_r
permutations = [
    list(range(n)),  # identity permutation: [0, 1, 2, 3]
    [0, 3, 2, 1],
]
h = permutation_ldpc_matrix(m, n, w_r, w_c, permutations)
print(h)

[[ True  True False False]
 [False False  True  True]
 [ True False False  True]
 [False  True  True False]]


In [29]:
def tanner_graph_from_check_matrix(check_matrix: np.ndarray) -> nx.Graph:
    """Convert the check matrix to a Tanner graph.

    Check ``i`` becomes node ``"c{i}"`` and bit ``j`` becomes node ``"b{j}"``.
    An edge joins them whenever ``check_matrix[i, j]`` is non-zero. Nodes carry
    the NetworkX ``bipartite`` attribute (0 for checks, 1 for bits), the
    convention used by ``networkx.algorithms.bipartite``. This lets callers
    recover the two node sets even when the graph is disconnected.

    Args:
        check_matrix: 2-D array of shape ``(num_checks, num_bits)``.

    Returns:
        Undirected graph with ``num_checks + num_bits`` nodes.
    """

    num_checks, num_bits = check_matrix.shape

    tanner_graph = nx.Graph()
    tanner_graph.add_nodes_from((f"c{i}" for i in range(num_checks)), bipartite=0)
    tanner_graph.add_nodes_from((f"b{j}" for j in range(num_bits)), bipartite=1)
    # np.argwhere yields the (row, column) index pairs of non-zero entries in
    # row-major order, the same order as a nested row/column scan.
    tanner_graph.add_edges_from(
        (f"c{i}", f"b{j}") for i, j in np.argwhere(check_matrix)
    )
    return tanner_graph

In [33]:
# Tanner graph of the small LDPC example above.
tanner_graph = tanner_graph_from_check_matrix(check_matrix=h)