In [1]:
import alist_loader
import numpy as np
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
import japanize_matplotlib
from jax.example_libraries import optimizers
from tqdm.notebook import trange
from functools import partial

In [2]:
filename = "../../DU-Book/Chapter_5/BCH_31_16_3_strip.alist"

In [3]:
noise_std = 0.5
K = 100
bp_itr = 5

In [4]:
H, U, V = alist_loader.load_alist(filename)
H = jnp.array(H)
U = jnp.array(U)
V = jnp.array(V)

n = 31, m = 15
cmap = 7, rmap = 8


In [5]:
n = H.shape[1]
m = H.shape[0]
esize = U.shape[1]

In [6]:
def bmod(x):
    return x - 2*jnp.floor(x/2)

In [7]:
@jax.jit
def parity_check(x):
    print(jnp.sum(H@x))
    tmp = jnp.where(jnp.sum(bmod(H@x)) > 0, 1, 0)
    return tmp
               

In [8]:
def random_codeword():
    x = jnp.array(np.random.randint(0, 2, n))
    for i in range(1, m+1):
        k = m - i + 1
        parity = 0
        for j in range(k+1, n+1):
            parity += H[k-1,j-1] * x[j-1]
        x = x.at[k-1].set(parity % 2)
    p = parity_check(x)
    if p == 0:
        return x
    else:
        print("encoding error!")
    return x

In [9]:
x = random_codeword()
parity_check(x)

Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>


Array(0, dtype=int32, weak_type=True)

In [10]:
def mini_batch(K):
    x = np.zeros((n, K))
    y = np.zeros((n, K))
    for j in range(1, K+1):
        x[:, j-1] = -2*jax.device_get(random_codeword()) + 1
        y[:, j-1] = x[:, j-1] + noise_std*np.random.randn(n)
    return jnp.array(x), jnp.array(y)

In [11]:
eps = 1e-7

@partial(jax.jit, static_argnums = 1)
def BP_decoding(y, max_itr, xi):
    alpha = jnp.zeros((esize, K))
    beta = jnp.zeros((esize, K))
    Lambda = 2 * y/(noise_std**2)
    for i in range(max_itr):
        beta = (U.T @ U - jnp.eye(U.shape[1])) @ alpha + U.T@Lambda
        beta = jnp.tile(xi[:, i].reshape([-1,1]), (1, beta.shape[1])) * beta
        tmp = jnp.exp((V.T @ V - jnp.eye(V.shape[1])) @ jnp.log(jnp.abs(jnp.tanh(beta/2))))
        alpha_abs = 2*jnp.arctanh((1-eps)*jax.nn.hard_tanh(tmp)) #epsが小さすぎるとinfが出る
        tmp = -2*V.T@bmod(V@((-jnp.sign(beta) + 1)/2)) + 1
        alpha_sign = tmp * jnp.sign(beta)
        alpha = alpha_sign * alpha_abs
    gamma = U @ alpha + Lambda
    return jnp.tanh(gamma)

In [12]:
x, y = mini_batch(K)
xi = jnp.ones((esize, bp_itr))
x_hat = jnp.sign(BP_decoding(y, bp_itr, xi))
x_hat != x

Array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]], dtype=bool)

In [13]:
jnp.sum(x_hat != x)

Array(0, dtype=int32)

In [14]:
jnp.sum(jnp.sign(y) != x)

Array(83, dtype=int32)

In [15]:
@jax.jit
def get_dot(x):
    return x @ x.T
batch_get_dot = jax.vmap(get_dot, in_axes=-1, out_axes=-1)

In [16]:
def loss(x, y, xi):
    x_hat = BP_decoding(y, bp_itr, xi)
    return jnp.sum(batch_get_dot(x_hat - y))/K

In [31]:
train_itr = 500
adam_lr = 1e-5

opt_init, opt_update, get_params = optimizers.adam(adam_lr)

@partial(jax.jit, static_argnums=2)
def step(x, y, step_num, opt_state):
    value, grads = jax.value_and_grad(loss, argnums=-1)(x, y, get_params(opt_state))
    new_opt_state = opt_update(step_num, grads, opt_state)
    return value, new_opt_state

def train(xi):
    opt_state = opt_init(xi)
    for itr in trange(train_itr, leave=False):
        x, y = mini_batch(K)
        value, opt_state = step(x, y, itr, opt_state)
        print("\r"+"\rloss:{}".format(value), end=" ")
    return get_params(opt_state)

In [32]:
xi_init = jnp.ones((esize, bp_itr))
xi_trained = train(xi_init)

  0%|          | 0/500 [00:00<?, ?it/s]

loss:7.838134765625 18  

In [21]:
def get_num_of_error(xi, num_loop = 10):
    total_syms = num_loop * n * K
    error_syms = 0
    for i in range(num_loop):
        x, y = mini_batch(K)
        x_hat = BP_decoding(y, bp_itr, xi)
        error_syms += jnp.sum(jnp.sign(x_hat) != x)
    return error_syms, total_syms

In [22]:
error_syms, total_syms = get_num_of_error(xi_init)

In [23]:
print("total_syms = {}".format(total_syms))
print("error_syms = {}".format(error_syms))
print("symbols error rate = {}".format(error_syms/total_syms))

total_syms = 31000
error_syms = 14
symbols error rate = 0.00045161289745010436


In [33]:
error_syms, total_syms = get_num_of_error(xi_trained)

In [25]:
print("total_syms = {}".format(total_syms))
print("error_syms = {}".format(error_syms))
print("symbols error rate = {}".format(error_syms/total_syms))

total_syms = 31000
error_syms = 14
symbols error rate = 0.00045161289745010436
