<a href="https://colab.research.google.com/github/chenshuo/notes/blob/master/notebooks/ReedSolomonErasureCodes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reed-Solomon Erasure Codes in Python

By Shuo Chen (chenshuo_at_chenshuo.com)

https://github.com/chenshuo/notes/blob/master/notebooks/ReedSolomonErasureCodes.ipynb

https://colab.research.google.com/github/chenshuo/notes/blob/master/notebooks/ReedSolomonErasureCodes.ipynb

https://chenshuo.com/notes/reed-solomon/

## Polynomial codes

References
* https://tomverbeure.github.io/2022/08/07/Reed-Solomon.html
* https://innovation.vivint.com/introduction-to-reed-solomon-bc264d0794f8

### Introductory example: Shannon's birthday

This example uses a polynomial code over the field of real numbers.

We want to transmit 6 digits $\boldsymbol d = [d_0, d_1, d_2, \ldots, d_5]$,
using 9 numbers $\boldsymbol c = [c_0, c_1, \ldots, c_8]$, and be able to tolerate any 3 losses (erasures).

In [1]:
import math
import numpy as np

In [35]:
d = np.array([1, 6, 0, 4, 3, 0])
d.size

6

First, we make a polynomial $p(x) = d_0 + d_1 x +  d_2 x^2 + \cdots + d_5 x^5$.

Then we evaluate it at 9 points: $x \in \{0, 1, 2, 3, \ldots, 8\}$

$\boldsymbol c = [p(0), p(1), \ldots, p(8)]$

$$V\cdot\boldsymbol d = \boldsymbol c$$

$$\begin{bmatrix}
1 & 0^1 & 0^2 & 0^3 & 0^4 & 0^5 \\
1 & 1^1 & 1^2 & 1^3 & 1^4 & 1^5 \\
1 & 2^1 & 2^2 & 2^3 & 2^4 & 2^5 \\
1 & 3^1 & 3^2 & 3^3 & 3^4 & 3^5 \\
1 & 4^1 & 4^2 & 4^3 & 4^4 & 4^5 \\
1 & 5^1 & 5^2 & 5^3 & 5^4 & 5^5 \\
1 & 6^1 & 6^2 & 6^3 & 6^4 & 6^5 \\
1 & 7^1 & 7^2 & 7^3 & 7^4 & 7^5 \\
1 & 8^1 & 8^2 & 8^3 & 8^4 & 8^5 \\
\end{bmatrix}\cdot\begin{pmatrix}
d_0\\
d_1\\
d_2\\
d_3\\
d_4\\
d_5\\
\end{pmatrix}=\begin{pmatrix}
1\\14\\93\\370\\1049\\2406\\4789\\8618\\14385\\
\end{pmatrix}$$

In [36]:
x = np.arange(9)
print('x =', x)

V = np.vander(x, N=d.size, increasing=True)
print(V)

x = [0 1 2 3 4 5 6 7 8]
[[    1     0     0     0     0     0]
 [    1     1     1     1     1     1]
 [    1     2     4     8    16    32]
 [    1     3     9    27    81   243]
 [    1     4    16    64   256  1024]
 [    1     5    25   125   625  3125]
 [    1     6    36   216  1296  7776]
 [    1     7    49   343  2401 16807]
 [    1     8    64   512  4096 32768]]


In [37]:
c = V @ d
print(c)

[    1    14    93   370  1049  2406  4789  8618 14385]
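The matrix product above is the same as evaluating $p(x)$ at each point. A minimal sketch with Horner's rule in plain Python (`horner` is a hypothetical helper name, not part of the notebook):

```python
def horner(coeffs, x):
  """Evaluate coeffs[0] + coeffs[1]*x + ... + coeffs[-1]*x**k at x."""
  y = 0
  for coef in reversed(coeffs):
    y = y * x + coef
  return y

d = [1, 6, 0, 4, 3, 0]
c = [horner(d, x) for x in range(9)]
print(c)  # [1, 14, 93, 370, 1049, 2406, 4789, 8618, 14385]
```

Note how quickly the values grow beyond a single digit; finite field arithmetic, introduced later, keeps every symbol within one byte.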


**Recover from loss**

Assuming $c_0, c_4, c_8$ are lost in transmission, we can solve a system of linear equations using the remaining 6 symbols of the codeword.

$$\begin{bmatrix}
1 & 1^1 & 1^2 & 1^3 & 1^4 & 1^5 \\
1 & 2^1 & 2^2 & 2^3 & 2^4 & 2^5 \\
1 & 3^1 & 3^2 & 3^3 & 3^4 & 3^5 \\
1 & 5^1 & 5^2 & 5^3 & 5^4 & 5^5 \\
1 & 6^1 & 6^2 & 6^3 & 6^4 & 6^5 \\
1 & 7^1 & 7^2 & 7^3 & 7^4 & 7^5 \\
\end{bmatrix}\cdot\begin{pmatrix}
d_0\\
d_1\\
d_2\\
d_3\\
d_4\\
d_5\\
\end{pmatrix}=\begin{pmatrix}
14\\93\\370\\2406\\4789\\8618\\
\end{pmatrix}$$

In [40]:
recv = np.concatenate((c[1:4], c[5:8]))
print(recv)

M = np.concatenate((V[1:4], V[5:8]))
print(M)

[  14   93  370 2406 4789 8618]
[[    1     1     1     1     1     1]
 [    1     2     4     8    16    32]
 [    1     3     9    27    81   243]
 [    1     5    25   125   625  3125]
 [    1     6    36   216  1296  7776]
 [    1     7    49   343  2401 16807]]


In [42]:
msg = np.linalg.solve(M, recv)
print(msg.reshape(-1, 1))

recover = np.rint(msg).astype(int)
print(recover)

[[ 1.00000000e+00]
 [ 6.00000000e+00]
 [ 1.34292577e-12]
 [ 4.00000000e+00]
 [ 3.00000000e+00]
 [-3.78956126e-15]]
[1 6 0 4 3 0]


In [46]:
assert (recover == d).all()
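The same recovery should work for any 3 erasures, because any 6 rows of a Vandermonde matrix built from distinct points are linearly independent. The sketch below checks every pattern with exact rational arithmetic instead of floating point; `solve_exact` is a hypothetical helper written for this illustration.

```python
from fractions import Fraction
from itertools import combinations

def solve_exact(A, y):
  """Solve A x = y by Gauss-Jordan elimination over the rationals."""
  n = len(A)
  aug = [[Fraction(v) for v in row] + [Fraction(t)] for row, t in zip(A, y)]
  for i in range(n):
    # Raises StopIteration if the matrix is singular.
    pivot = next(r for r in range(i, n) if aug[r][i] != 0)
    aug[i], aug[pivot] = aug[pivot], aug[i]
    p = aug[i][i]
    aug[i] = [v / p for v in aug[i]]
    for r in range(n):
      if r != i and aug[r][i] != 0:
        f = aug[r][i]
        aug[r] = [v - f * w for v, w in zip(aug[r], aug[i])]
  return [row[n] for row in aug]

d = [1, 6, 0, 4, 3, 0]
V = [[x ** j for j in range(len(d))] for x in range(9)]
c = [sum(v * dj for v, dj in zip(row, d)) for row in V]

# Keep any 6 of the 9 symbols, i.e. erase any 3.
patterns = list(combinations(range(9), 6))
for keep in patterns:
  recovered = solve_exact([V[i] for i in keep], [c[i] for i in keep])
  assert recovered == d
print(len(patterns), 'erasure patterns recovered')  # 84 erasure patterns recovered
```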

## Galois field arithmetic

References
* https://www.kernel.org/pub/linux/kernel/people/hpa/raid6.pdf
* https://research.swtch.com/field

Here we implement one particular $GF(256)$, with primitive element $\alpha = 2$ and irreducible polynomial `0x11d` ($x^8 + x^4 + x^3 + x^2 + 1$).

Future optimizations:
* https://www.academia.edu/89610567/Fast_software_implementation_of_finite_field_operations
* https://stackoverflow.com/a/30460874

In [31]:
class GF8bit:
  bits = 8
  order = 2 ** bits
  max = order - 1

  def __init__(self):
    self._GenerateTables()

  @classmethod
  def _Check(cls, x):
    assert isinstance(x, int)
    assert 0 <= x and x < cls.order

  @classmethod
  def _CheckNZ(cls, x):
    cls._Check(x)
    assert x != 0

  @classmethod
  def _Mul2(cls, x):
    cls._Check(x)
    y = (x & 0x7f) << 1
    if x & 0x80:
      y ^= 0x1d
    cls._Check(y)
    return y

  def _GenerateTables(self):
    self.exp = [None] * (self.max * 2)
    self.log = [None] * self.order
    x = 1
    for i in range(self.max):
      self.exp[i] = x
      self.exp[i+self.max] = x
      self.log[x] = i
      x = self._Mul2(x)

    assert x == 1  # alpha ** 255 == 1
    assert 0 not in self.exp
    assert self.exp[255] == 1
    assert self.exp[:255] == self.exp[255:]
    assert len(set(self.exp)) == self.order - 1
    assert self.log[0] is None
    assert self.log[1] == 0
    assert None not in self.log[1:]
    assert max(self.log[1:]) == 254

  def Add(self, a, b):
    self._Check(a)
    self._Check(b)
    return a ^ b

  def Sub(self, a, b):
    return self.Add(a, b)

  def Exp(self, x):
    self._Check(x)
    return self.exp[x]

  def Log(self, x):
    self._CheckNZ(x)
    return self.log[x]

  def Mul(self, a, b):
    self._Check(a)
    self._Check(b)
    if a == 0 or b == 0:
      return 0
    return self.exp[self.log[a] + self.log[b]]

  def Inv(self, a):
    self._CheckNZ(a)
    return self.exp[self.order-1 - self.log[a]]

  def Div(self, a, b):
    """Mul(a, Inv(b))"""
    self._Check(a)
    self._CheckNZ(b)
    if a == 0:
      return 0
    return self.exp[self.log[a] + self.order-1 - self.log[b]]

  def Dot(self, a, b):
    assert len(a) == len(b)
    result = 0
    for i in range(len(a)):
      # result += a[i] * b[i]
      result = self.Add(result, self.Mul(a[i], b[i]))
    return result

gf = GF8bit()

In [5]:
gf.Mul(3, 3)

5

In [7]:
gf.Mul(5, 5)

17
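These two products can be checked by hand: $3 = x + 1$, so $3 \cdot 3 = x^2 + 2x + 1 = x^2 + 1 = 5$ because coefficients are taken modulo 2; likewise $5 \cdot 5 = (x^2 + 1)^2 = x^4 + 1 = 17$. As an independent check that does not use the tables, here is a sketch of shift-and-xor multiplication reduced by `0x11d`; `gf_mul_slow` and `gf_mul_table` are hypothetical names for this illustration.

```python
def gf_mul_slow(a, b, poly=0x11d):
  """Multiply in GF(2^8) bit by bit, reducing by poly whenever a overflows 8 bits."""
  result = 0
  while b:
    if b & 1:
      result ^= a
    b >>= 1
    a <<= 1
    if a & 0x100:
      a ^= poly
  return result

# Rebuild exp/log tables for alpha = 2, with exp doubled to avoid a modulo.
exp = [0] * 510
log = [None] * 256
x = 1
for i in range(255):
  exp[i] = exp[i + 255] = x
  log[x] = i
  x = gf_mul_slow(x, 2)

def gf_mul_table(a, b):
  if a == 0 or b == 0:
    return 0
  return exp[log[a] + log[b]]

assert all(gf_mul_slow(a, b) == gf_mul_table(a, b)
           for a in range(256) for b in range(256))
print(gf_mul_slow(3, 3), gf_mul_slow(5, 5))  # 5 17
```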

### Cross verification

In [None]:
!pip install pyfinite galois

In [20]:
import pyfinite.ffield
import time

In [34]:
ff = pyfinite.ffield.FField(8)
start = time.time()

# GF8 = galois.GF(2**8)
for a in range(256):
  if a != 0:
    assert gf.Inv(a) == ff.Inverse(a), gf.Inv(a)
    # assert gf.Inv(a) == GF8(1) / GF8(a)
  for b in range(256):
    assert gf.Mul(a, b) == ff.Multiply(a, b)
    # assert gf.Mul(a, b) == GF8(a) * GF8(b)
    if b != 0:
      assert gf.Div(a, b) == ff.Divide(a, b)
      # assert gf.Div(a, b) == GF8(a) / GF8(b)
end = time.time()
print('%.3f sec' % (end - start))

0.163 sec


## Original view

This view is easy to understand: it is basic linear algebra, carried out over a finite field instead of the real numbers. It works well for erasure codes and is used in some storage systems.

### Galois matrix

In [88]:
class GaloisMatrix:
  def __init__(self, rows, cols):
    self.gf = GF8bit()
    self.rows = rows
    self.cols = cols
    self.mat = [[0] * cols for r in range(rows)]
    assert len(self.mat) == rows

  def __str__(self):
    return '\n'.join(' '.join('%4d' % x for x in row) for row in self.mat)

  def Dot(self, vec:list) -> list:
    assert self.cols == len(vec)
    result = [None] * self.rows
    for i in range(self.rows):
      result[i] = self.gf.Dot(self.mat[i], vec)
    return result

  def SubRows(self, row_idx: list):
    mat = GaloisMatrix(len(row_idx), self.cols)
    for i, r in enumerate(row_idx):
      mat.mat[i] = self.mat[r][:]
    return mat

  def Equals(self, other: 'GaloisMatrix'):
    if self.rows != other.rows or self.cols != other.cols:
      return False
    return self.mat == other.mat

  # Row helpers called by InverseMatrix(); all of them operate in place.
  def SwapRow(self, i, j):
    self.mat[i], self.mat[j] = self.mat[j], self.mat[i]

  def MulRow(self, row, x):
    r = self.mat[row]
    for c in range(self.cols):
      r[c] = self.gf.Mul(r[c], x)

  def AddRow(self, to, src, factor):
    # mat[to] += factor * mat[src]
    for c in range(self.cols):
      self.mat[to][c] ^= self.gf.Mul(factor, self.mat[src][c])

  def IsIdentity(self):
    return self.rows == self.cols and all(
        self.mat[r][c] == (1 if r == c else 0)
        for r in range(self.rows) for c in range(self.cols))
    
  @classmethod
  def Vander(cls, rows, cols, start=0):
    mat = GaloisMatrix(rows, cols)
    for i in range(rows):
      a = i + start
      x = 1
      row = mat.mat[i]
      for j in range(cols):
        row[j] = x
        x = gf.Mul(x, a)
    return mat

  @classmethod
  def Identity(cls, n):
    mat = GaloisMatrix(n, n)
    for i in range(n):
      mat.mat[i][i] = 1
    return mat

### Vandermonde matrix

In [48]:
V = GaloisMatrix.Vander(9, 6)
print(V)

   1    0    0    0    0    0
   1    1    1    1    1    1
   1    2    4    8   16   32
   1    3    5   15   17   51
   1    4   16   64   29  116
   1    5   17   85   28  108
   1    6   20  120   13   46
   1    7   21  107   12   36
   1    8   64   58  205   38


In [49]:
V.Dot([1, 6, 0, 4, 3, 0])

[1, 0, 29, 4, 35, 114, 255, 182, 147]

In [50]:
V.Dot([1, 1, 4, 5, 1, 4])

[1, 4, 171, 248, 200, 237, 103, 20, 147]

In [51]:
msg = [ord(ch) for ch in "Hello!"]
V.Dot(msg)

[72, 99, 130, 180, 48, 38, 244, 255, 13]
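Each codeword symbol above is the message polynomial evaluated at one of the field elements $0, 1, \ldots, 8$, just as in the real-number example, only with $GF(256)$ arithmetic. A minimal standalone sketch (`gf_mul` and `gf_horner` are hypothetical helpers, independent of the `GF8bit` class):

```python
def gf_mul(a, b, poly=0x11d):
  """Shift-and-xor multiplication in GF(2^8) modulo poly."""
  result = 0
  while b:
    if b & 1:
      result ^= a
    b >>= 1
    a <<= 1
    if a & 0x100:
      a ^= poly
  return result

def gf_horner(coeffs, x):
  """Evaluate coeffs[0] + coeffs[1]*x + ... in GF(2^8); addition is xor."""
  y = 0
  for coef in reversed(coeffs):
    y = gf_mul(y, x) ^ coef
  return y

codeword = [gf_horner([1, 6, 0, 4, 3, 0], x) for x in range(9)]
print(codeword)  # [1, 0, 29, 4, 35, 114, 255, 182, 147]
```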

**Cross check with** `pyfinite`

In [52]:
import pyfinite.rs_code

In [53]:
rs96 = pyfinite.rs_code.RSCode(n=9, k=6, log2FieldSize=8, systematic=0)
rs96

<RSCode (n,k) = (9, 6)  over GF(2^8)
<matrix
   1   0   0   0   0   0
   1   1   1   1   1   1
   1   2   4   8  16  32
   1   3   5  15  17  51
   1   4  16  64  29 116
   1   5  17  85  28 108
   1   6  20 120  13  46
   1   7  21 107  12  36
   1   8  64  58 205  38>
>

In [54]:
rs96.Encode([1, 6, 0, 4, 3, 0])

[1, 0, 29, 4, 35, 114, 255, 182, 147]

In [55]:
rs96.Encode([1, 1, 4, 5, 1, 4])

[1, 4, 171, 248, 200, 237, 103, 20, 147]

In [56]:
rs96.Encode(msg)

[72, 99, 130, 180, 48, 38, 244, 255, 13]

### Systematic encoding

In [57]:
V = GaloisMatrix.Vander(9, 6)
print(V)

   1    0    0    0    0    0
   1    1    1    1    1    1
   1    2    4    8   16   32
   1    3    5   15   17   51
   1    4   16   64   29  116
   1    5   17   85   28  108
   1    6   20  120   13   46
   1    7   21  107   12   36
   1    8   64   58  205   38


In [66]:
def Systematic(V):
  assert V.rows >= V.cols

  def SwapColumn(V, i, j):
    for r in range(V.rows):
      V.mat[r][i], V.mat[r][j] = V.mat[r][j], V.mat[r][i]

  def MulColumn(V, col, x):
    for i in range(V.rows):
      V.mat[i][col] = gf.Mul(V.mat[i][col], x)

  def AddColumn(V, to, src, factor):
    # TODO: skip some rows
    for i in range(V.rows):
      V.mat[i][to] ^= V.gf.Mul(factor, V.mat[i][src])

  for i in range(V.cols):
    row = V.mat[i]
    if row[i] == 0:
      # find and swap with column j
      for j in range(i+1, V.cols):
        if row[j] != 0:
          SwapColumn(V, i, j)
          # print("Swapped column %d and %d" % (i, j))
          break
    assert row[i] != 0
    if row[i] != 1:
      # print('  row', i, row[i], gf.Inv(row[i]))
      # unity pivot
      MulColumn(V, i, gf.Inv(row[i]))
    assert row[i] == 1
    # Eliminate other columns
    for j in range(V.cols):
      if i != j and row[j] != 0:
        factor = V.mat[i][j]
        AddColumn(V, j, i, factor)
    #print("Row", i , 'done')
    #print(V)

In [96]:
V = GaloisMatrix.Vander(9, 6)
Systematic(V)
print(V)

   1    0    0    0    0    0
   0    1    0    0    0    0
   0    0    1    0    0    0
   0    0    0    1    0    0
   0    0    0    0    1    0
   0    0    0    0    0    1
   7    6    5    4    3    2
   6    7    4    5    2    3
 160  223  223  183  254  232


In [100]:
I = GaloisMatrix.Identity(6)
assert I.Equals(V.SubRows(list(range(6))))

In [68]:
rs96s = pyfinite.rs_code.RSCode(n=9, k=6, log2FieldSize=8, systematic=1)
rs96s

<RSCode (n,k) = (9, 6)  over GF(2^8)
<matrix
   1   0   0   0   0   0
   0   1   0   0   0   0
   0   0   1   0   0   0
   0   0   0   1   0   0
   0   0   0   0   1   0
   0   0   0   0   0   1
   7   6   5   4   3   2
   6   7   4   5   2   3
 160 223 223 183 254 232>
>

In [69]:
V.Dot([1, 6, 0, 4, 3, 0])

[1, 6, 0, 4, 3, 0, 6, 6, 161]

In [82]:
rs96s.Encode([1, 6, 0, 4, 3, 0])

[1, 6, 0, 4, 3, 0, 6, 6, 161]
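One way to read the systematic codeword: the column operations leave the set of codewords unchanged, so the codeword is still a degree-5 polynomial evaluated at $0, 1, \ldots, 8$, but now the polynomial is chosen so that its first 6 values equal the message. The parity symbols are then its values at 6, 7 and 8. The sketch below checks this with Lagrange interpolation over $GF(256)$; all helper names are hypothetical.

```python
def gf_mul(a, b, poly=0x11d):
  """Shift-and-xor multiplication in GF(2^8) modulo poly."""
  result = 0
  while b:
    if b & 1:
      result ^= a
    b >>= 1
    a <<= 1
    if a & 0x100:
      a ^= poly
  return result

def gf_inv(a):
  """Inverse of a nonzero element, computed as a**254."""
  assert a != 0
  result = 1
  for _ in range(254):
    result = gf_mul(result, a)
  return result

def lagrange_eval(xs, ys, x):
  """Evaluate at x the unique polynomial of degree < len(xs) through (xs, ys)."""
  total = 0
  for i, (xi, yi) in enumerate(zip(xs, ys)):
    num, den = 1, 1
    for j, xj in enumerate(xs):
      if j != i:
        num = gf_mul(num, x ^ xj)    # (x - xj); subtraction is xor
        den = gf_mul(den, xi ^ xj)   # (xi - xj)
    total ^= gf_mul(yi, gf_mul(num, gf_inv(den)))
  return total

msg = [1, 6, 0, 4, 3, 0]
parity = [lagrange_eval(list(range(6)), msg, x) for x in (6, 7, 8)]
print(parity)  # [6, 6, 161]
```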

**Encoder class**

In [89]:
class RSEncoder:
  """Reed-Solomon erasure codes systematic encoder."""

  def __init__(self, n, k):
    self.n = n
    self.k = k
    V = GaloisMatrix.Vander(n, k)
    Systematic(V)
    top = V.SubRows(list(range(k)))
    assert top.Equals(GaloisMatrix.Identity(k))
    self.G = V.SubRows(list(range(k, n)))
    assert self.G.cols == self.k
    assert self.G.rows == self.n - self.k

  def Encode(self, message: list):
    """Return the codeword of n bytes."""
    parity = self.Parity(message)
    return message + parity

  def Parity(self, message: list):
    """Return the parity of (n-k) bytes."""
    assert len(message) == self.k
    return self.G.Dot(message)

In [90]:
rs96 = RSEncoder(9, 6)
print(rs96.G)

   7    6    5    4    3    2
   6    7    4    5    2    3
 160  223  223  183  254  232


In [83]:
msg = [1, 6, 0, 4, 3, 0]

print('Message: ', msg)
print('Parity:  ', rs96.Parity(msg))
print('Codeword:', rs96.Encode(msg))

Message:  [1, 6, 0, 4, 3, 0]
Parity:   [6, 6, 161]
Codeword: [1, 6, 0, 4, 3, 0, 6, 6, 161]


In [84]:
msg = [1, 1, 4, 5, 1, 4]

print('Message: ', msg)
print('Parity:  ', rs96.Parity(msg))
print('Codeword:', rs96.Encode(msg))

Message:  [1, 1, 4, 5, 1, 4]
Parity:   [10, 14, 12]
Codeword: [1, 1, 4, 5, 1, 4, 10, 14, 12]


In [85]:
msg = [ord(ch) for ch in "Hello!"]

print('Message: ', msg)
print('Parity:  ', rs96.Parity(msg))
print('Codeword:', rs96.Encode(msg))

Message:  [72, 101, 108, 108, 111, 33]
Parity:   [57, 90, 253]
Codeword: [72, 101, 108, 108, 111, 33, 57, 90, 253]


In [86]:
rs96s.Encode(msg)

[72, 101, 108, 108, 111, 33, 57, 90, 253]

### Erasure decoding

In [87]:
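def InverseMatrix(m: GaloisMatrix):
  """Return the inverse of m by Gauss-Jordan elimination; m itself is reduced to the identity in place."""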
  assert m.rows == m.cols
  result = GaloisMatrix.Identity(m.rows)
  for i in range(m.rows):
    row = m.mat[i]
    if row[i] == 0:
      # find and swap with row j
      for j in range(i+1, m.rows):
        if m.mat[j][i] != 0:
          m.SwapRow(i, j)
          result.SwapRow(i, j)
          row = m.mat[i]
          # print("Swapped row %d and %d" % (i, j))
          break
    assert row[i] != 0
    if row[i] != 1:
      inv = m.gf.Inv(row[i])
      m.MulRow(i, inv)
      result.MulRow(i, inv)
    assert row[i] == 1
    # Eliminate other rows
    for j in range(m.rows):
      if i != j and m.mat[j][i] != 0:
        factor = m.mat[j][i]
        m.AddRow(j, i, factor)
        result.AddRow(j, i, factor)
  assert m.IsIdentity()
  return result
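As a sketch of how erasure decoding could proceed (not necessarily the decoder this notebook will end up with): keep the rows of the systematic matrix that correspond to the surviving symbols, then solve the resulting $6 \times 6$ system for the message. The helpers `gf_mul`, `gf_inv` and `gf_solve` are hypothetical standalone functions; the matrix rows and the codeword are copied from the output printed earlier, and the erased positions 0, 4 and 8 mirror the introductory example.

```python
def gf_mul(a, b, poly=0x11d):
  """Shift-and-xor multiplication in GF(2^8) modulo poly."""
  result = 0
  while b:
    if b & 1:
      result ^= a
    b >>= 1
    a <<= 1
    if a & 0x100:
      a ^= poly
  return result

def gf_inv(a):
  """Inverse of a nonzero element, computed as a**254."""
  assert a != 0
  result = 1
  for _ in range(254):
    result = gf_mul(result, a)
  return result

def gf_solve(A, y):
  """Solve A x = y over GF(2^8) by Gauss-Jordan elimination on a copy."""
  n = len(A)
  aug = [row[:] + [v] for row, v in zip(A, y)]
  for i in range(n):
    # Raises StopIteration if A is singular.
    pivot = next(r for r in range(i, n) if aug[r][i] != 0)
    aug[i], aug[pivot] = aug[pivot], aug[i]
    inv = gf_inv(aug[i][i])
    aug[i] = [gf_mul(inv, v) for v in aug[i]]
    for r in range(n):
      if r != i and aug[r][i] != 0:
        f = aug[r][i]
        aug[r] = [v ^ gf_mul(f, w) for v, w in zip(aug[r], aug[i])]
  return [row[n] for row in aug]

# Systematic (9, 6) matrix: identity on top, parity rows below.
G = [[1 if r == c else 0 for c in range(6)] for r in range(6)] + [
    [7, 6, 5, 4, 3, 2],
    [6, 7, 4, 5, 2, 3],
    [160, 223, 223, 183, 254, 232],
]
codeword = [1, 6, 0, 4, 3, 0, 6, 6, 161]

keep = [1, 2, 3, 5, 6, 7]  # symbols 0, 4 and 8 were erased
message = gf_solve([G[i] for i in keep], [codeword[i] for i in keep])
print(message)  # [1, 6, 0, 4, 3, 0]
```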

## BCH view

This is the most commonly used view; it supports both erasure and error correction.

### Generator polynomial

### Systematic encoding (10 LOC)
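The two headings above are still empty, so the following is only a sketch under stated assumptions, not the author's implementation. It assumes the generator polynomial $g(x) = \prod_{i=0}^{n-k-1} (x - \alpha^i)$ with $\alpha = 2$ (other texts start the roots at $\alpha^1$), stores coefficients highest degree first, and takes the parity as the remainder of $m(x) \cdot x^{n-k}$ divided by $g(x)$. All function names are hypothetical. The resulting codewords differ from those of the original view.

```python
def gf_mul(a, b, poly=0x11d):
  """Shift-and-xor multiplication in GF(2^8) modulo poly."""
  result = 0
  while b:
    if b & 1:
      result ^= a
    b >>= 1
    a <<= 1
    if a & 0x100:
      a ^= poly
  return result

def alpha_pow(i):
  """alpha**i for alpha = 2."""
  x = 1
  for _ in range(i):
    x = gf_mul(x, 2)
  return x

def generator_poly(nparity):
  """Assumed g(x) = (x - alpha^0)(x - alpha^1)...; highest degree first."""
  g = [1]
  for i in range(nparity):
    root = alpha_pow(i)
    nxt = g + [0]                        # g(x) * x
    for j, coef in enumerate(g):
      nxt[j + 1] ^= gf_mul(coef, root)   # plus root * g(x); minus is xor
    g = nxt
  return g

def bch_encode(msg, nparity):
  """Systematic encoding: message followed by (msg(x) * x^nparity) mod g(x)."""
  g = generator_poly(nparity)
  rem = list(msg) + [0] * nparity
  for i in range(len(msg)):
    coef = rem[i]
    if coef:
      for j in range(1, len(g)):         # g[0] == 1, so position i is cancelled
        rem[i + j] ^= gf_mul(g[j], coef)
  return list(msg) + rem[len(msg):]

def poly_eval(p, x):
  """Horner evaluation, highest degree first."""
  y = 0
  for c in p:
    y = gf_mul(y, x) ^ c
  return y

cw = bch_encode([1, 6, 0, 4, 3, 0], 3)
# A valid codeword is divisible by g(x), so it vanishes at every root of g.
assert all(poly_eval(cw, alpha_pow(i)) == 0 for i in range(3))
print(generator_poly(3))  # [1, 7, 14, 8]
```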