In [7]:
from typing import List, Dict
from warnings import warn
from math import log, tanh, atanh
import numpy as np

## Lookup table decoder

In [6]:
def lookup_decode(syndrome: np.ndarray) -> np.ndarray:
    """A lookup table decoder for a distance 3 repetition code."""

    assert syndrome.size == 2

    if np.all(syndrome == np.array([False, False])):
        return np.array([False, False, False])
    elif np.all(syndrome == np.array([False, True])):
        return np.array([False, False, True])
    elif np.all(syndrome == np.array([True, False])):
        return np.array([True, False, False])
    elif np.all(syndrome == np.array([True, True])):
        return np.array([False, True, False])
    else:
        raise ValueError("Unrecognized syndrome.")

In [7]:
H = np.array([[True, True, False], [False, True, True]])
err = np.array([True, False, True])
syndrome = err @ H.T
print(lookup_decode(syndrome))

[False  True False]


## Hard input decoder

In [8]:
def hard_input_decode(syndrome: np.ndarray, H: np.ndarray) -> np.ndarray:
    """Uses the Hard input algorithm."""

    assert syndrome.size == H.shape[0]

    found = False
    for i in range(H.shape[1]):
        if np.all(syndrome == H[:, i]):
            found = True
            break
    if not found:
        raise Exception(f"Matching column not found for syndrome {syndrome}.")
    correction = np.array([False] * H.shape[1])
    correction[i] = True
    return correction

In [9]:
H = np.array([
    [1, 1, 0, 1, 1, 0, 0],
    [1, 0, 1, 1, 0, 1, 0],
    [0, 1, 1, 1, 0, 0, 1]
]).astype(bool)
syndrome = np.array([True, False, True])
correction = hard_input_decode(syndrome, H)
print(correction)

[False  True False False False False False]


## Belief propagation

In [3]:
def log_likelihoods(received_message: np.ndarray, flip_probabilities: np.ndarray) -> np.ndarray:
    """Compute posterior log likelihoods of a received message given the probability that
    each bit will flip. See Eqn. 15.9 of Moon.
    
    received_message - The vector y_n received.
    flip_probabilities - The probability that a given bit will flip.
    
    Returns:
    l_n - The vector of posterior likelihoods for the sent message b_n."""

    l_n = []
    for n, y_n in enumerate(received_message):
        if y_n:
            p_0 = flip_probabilities[n]
            p_1 = 1. - flip_probabilities[n]
        else:
            p_0 = 1. - flip_probabilities[n]
            p_1 = flip_probabilities[n]
        l_n.append(log(p_0 / p_1, 10))
    return np.array(l_n)

In [4]:
message = np.array([True, True, False])
flip_probabilities = np.array([0.1, 0.1, 0.1])
print(log_likelihoods(message, flip_probabilities))

[-0.95424251 -0.95424251  0.95424251]


In [41]:
def bp_decode(check_matrix: np.ndarray, log_likelihoods: np.ndarray, max_iter: int = 1_000) -> np.ndarray:
    """Decode with belief propagation. Algorithm 15.1 from Moon."""

    num_checks, num_bits = check_matrix.shape

    l_nm: Dict[Dict[float]] = {}
    for n in range(num_bits):
        l_nm[n] = {}
        for m in range(num_checks):
            if check_matrix[m, n]:
                l_nm[n][m] = log_likelihoods[n]

    l_mn: Dict[Dict[float]] = {}
    for m in range(num_checks):
        l_mn[m] = {}
        for n in range(num_bits):
            if check_matrix[m, n]:
                l_mn[m][n] = 0.
    
    l_out = np.zeros(num_bits)
    
    i = 0
    while i < max_iter:
        # Horizontal step.
        for m in range(num_checks):
            n_vals = sorted(l_mn[m].keys())
            for n in n_vals:
                ls = np.array([l_nm[np][m] for np in n_vals if np != n])
                tanh_prod = np.prod(np.tanh(ls / 2.))
                l_mn[m][n] = 2.0 * atanh(tanh_prod)
        # Vertical step.
        for n in range(num_bits):
            m_vals = sorted(l_nm[n].keys())
            for m in m_vals:
                ls = np.array([l_mn[mp][n] for mp in m_vals if mp != m])
                l_nm[n][m] = log_likelihoods[n] + np.sum(ls)
            ls = [l_mn[mp][n] for mp in m_vals]
            l_out[n] = log_likelihoods[n] + np.sum(ls)
        c = np.array([ln <= 0. for ln in l_out])
        s = ((check_matrix.astype(int) @ c.astype(int)) % 2).astype(bool)
        print(f"i={i}, c={c}, s={s}")
        if np.all(np.invert(s)):
            break
        if i > max_iter:
            warn("Maximum iterations exceeded.")
            break
        i += 1

    return c

In [42]:
h = np.array([
    [True, True, False],
    [False, True, True]
])
message = np.array([True, True, False])
flip_probabilities = np.array([0.1, 0.1, 0.1])
likelihoods = log_likelihoods(message, flip_probabilities)
codeword = bp_decode(h, likelihoods)
print("c=", codeword)

i=0, c=[ True  True False], s=[False  True]
i=1, c=[ True  True  True], s=[False False]
c= [ True  True  True]
