In [1]:
from sympy import *   
import numpy as np    
from math import log
from pandas import *
from math import ceil
import math
# Indeterminate x used for every polynomial in this notebook
# (y and z are not used in the cells shown here).
x, y, z = symbols('x,y,z')


In [3]:
def get_irreducible(deg = 8):
	"""
	Return the first irreducible polynomial of degree deg over F_2 (a Poly in x).

	Candidates are the monic polynomials x**deg + c_{deg-1} x**(deg-1) + ... + c_0,
	enumerated by the bit pattern i of (c_0, ..., c_{deg-1}) in increasing order
	and tested with sympy's is_irreducible. Returns None for deg == 0.
	"""
	# Only odd i: a zero constant term means x divides the candidate, so it is
	# reducible for deg >= 2 (and for deg == 1 the only odd candidate is x + 1).
	# Every coefficient up to x**(deg-1) is enumerated; otherwise e.g. deg == 2
	# would never try x**2 + x + 1 and would return None.
	for i in range(1, 1 << deg, 2):
		f = x**deg
		for j in range(deg):
			if i & (1 << j):
				f += x**j
		f = f.as_poly(domain=FF(2))
		if f.is_irreducible:
			return f
	return None


def is_primitive(f, irred):
	"""
	Return True if f generates every non-zero element of F_2[x]/irred.

	f is primitive exactly when its multiplicative order is 2**deg(irred) - 1.
	The powers f, f**2, ... are walked modulo irred, and the check stops as soon
	as a power equals 1 before that exponent. f is reduced modulo irred first,
	so unreduced inputs are accepted. The zero polynomial is never primitive.
	Assumes irred is irreducible over F_2 (as returned by get_irreducible).
	"""
	one = Poly(1, x, domain=FF(2))
	order = (1 << irred.degree()) - 1
	_, base = div(f, irred, domain=FF(2))
	if base.is_zero:
		return False
	cur = base
	# Invariant at the top of iteration e: cur == f**e mod irred.
	for _ in range(1, order):
		if cur == one:
			# f**e == 1 with e < order, so f's order is too small.
			return False
		_, cur = div(cur * base, irred, domain=FF(2))
	return cur == one


def find_primitive_element(irred):
	"""
	Brute-force search for a primitive element of F_2[x]/irred.

	Candidates are the non-zero polynomials of degree < deg(irred), encoded by a
	bit mask (bit b set -> x**b present) and tried in increasing mask order.
	The first candidate accepted by is_primitive is returned; None otherwise.
	"""
	deg = irred.degree()
	for mask in range(1, 1 << deg):
		candidate = Poly(0, x, domain=FF(2))
		for power in (b for b in range(deg) if (mask >> b) & 1):
			candidate += x**power
		candidate = candidate.as_poly(domain=FF(2))
		if is_primitive(candidate, irred):
			return candidate
	return None


def fast_pow(cur, k, irred):
	"""
	Compute cur**k mod irred over F_2 by square-and-multiply (O(log k) products).

	A negative k is treated as a power of the modular inverse of cur, which must
	then be invertible modulo irred. k == 0 returns the constant Poly 1.
	Returns a Poly in x with domain FF(2).
	"""
	if k < 0:
		# The old recursion used k // 2, which never reaches 0 for k < 0 and
		# recursed without end.
		cur = invert(cur, irred)
		k = -k
	result = Poly(1, x, domain=FF(2))
	base = cur
	while k:
		if k & 1:
			_, result = div(result * base, irred, domain=FF(2))
		k >>= 1
		if k:
			# Square only if more bits remain, to avoid one wasted product.
			_, base = div(base * base, irred, domain=FF(2))
	return result


In [5]:
def gen_vandermonde(n, k, beta, irred):
	"""
	Build an n x k Vandermonde matrix over F_2[x]/irred.

	Entry (i, j) is (beta**i)**j reduced modulo irred (both powers via fast_pow).
	Row i therefore holds the monomials 1, X, ..., X**(k-1) evaluated at beta**i.
	"""
	points = (fast_pow(beta, row_idx, irred) for row_idx in range(n))
	return [[fast_pow(point, col_idx, irred) for col_idx in range(k)] for point in points]

def mat_vec_prod(M,v,irred):
	"""
	Multiply a matrix by a vector of polynomials over F_2[x] mod irred.

	M is a list of rows (lists of Poly). v is a list of Poly whose length must
	equal the length of M's first row (checked with assert). Every partial sum
	is reduced modulo irred. Returns one reduced Poly per row of M, or [] when
	M is empty.
	"""
	if not M:
		return []
	assert(len(M[0]) == len(v))
	res = []
	for row in M:
		# Start from a zero Poly rather than the int 0, so every entry is a
		# Poly even when v is empty.
		acc = Poly(0, x, domain=FF(2))
		for j, vj in enumerate(v):
			_, acc = div(acc + row[j] * vj, irred, domain=FF(2))
		res.append(acc)
	return res


In [9]:
print(
"""
#################################################################
# Q3 RS decoding using the Welch Berlekamp Algorithm            #
#################################################################
"""
)

n = int(input("Enter n "))
k = int(input("Enter k "))
# An RS code needs 1 <= k <= n; otherwise dmin and t below are meaningless.
assert 1 <= k <= n
# Work in F_{2^siz} with the smallest siz such that 2**siz > n. This gives at
# least n distinct non-zero evaluation points beta^0, ..., beta^(n-1).
# int.bit_length computes this exactly. int(log(n, 2)) + 1 goes through
# floating point and can be off near powers of two.
siz = n.bit_length()
f = get_irreducible(siz)
beta = find_primitive_element(f)

message = [Poly(0,x,domain=FF(2)) for i in range(k)]

non_zero = int(input("Enter the number of non-zero coefficients "))
for i in range(non_zero):
	j = int(input("Enter power of X "))
	# Reject negative indices too; they would silently address from the end.
	assert(0 <= j < len(message))
	v = int(input("Enter p, where coefficient in the form beta^p (beta is the primitive element) "))
	message[j] = fast_pow(beta,v,f)

print(
"""
####################
# Message Vector   #
####################
"""
)
print(np.array(message)[:,None])

# Encoding: codeword[i] = m(beta^i), i.e. the Vandermonde matrix times the coefficients.
G = gen_vandermonde(n,k,beta,f)
codeword = mat_vec_prod(G,message,f)
print(
"""
####################
# Codeword         #
####################
"""
)
print(np.array(codeword)[:,None])

dmin = n-k+1
t = (dmin-1)//2
e = int(input("Enter the number of errors(at most {}): ".format(t)))
assert(0 <= e <= t)
for i in range(e):
	j = int(input("Enter position "))
	assert(0 <= j < len(codeword))
	v = int(input("Enter p, used to derive a polynomial from F_2[x] of the form beta^p (beta is the primitive element) "))
	codeword[j] = fast_pow(beta,v,f)

print(
"""
####################
# Received Vector  #
####################
"""
)

print(np.array(codeword)[:,None])


#################################################################
# Q3 RS decoding using the Welch Berlekamp Algorithm            #
#################################################################

Enter n 10
Enter k 2
Enter the number of non-zero coefficients 2
Enter power of X 0
Enter p, where coefficient in the form beta^p (beta is the primitive element) 1
Enter power of X 1
Enter p, where coefficient in the form beta^p (beta is the primitive element) 2

####################
# Message Vector   #
####################

[[Poly(x, x, modulus=2)]
 [Poly(x**2, x, modulus=2)]]

####################
# Codeword         #
####################

[[Poly(x**2 + x, x, modulus=2)]
 [Poly(x**3 + x, x, modulus=2)]
 [Poly(1, x, modulus=2)]
 [Poly(x**2, x, modulus=2)]
 [Poly(x**3 + x**2 + x, x, modulus=2)]
 [Poly(x**3 + 1, x, modulus=2)]
 [Poly(x**2 + x + 1, x, modulus=2)]
 [Poly(x**3, x, modulus=2)]
 [Poly(x**2 + 1, x, modulus=2)]
 [Poly(x**3 + x**2, x, modulus=2)]]
Enter the number of errors(at mos

In [11]:
# Welch-Berlekamp linear system. For each received symbol y_i at a_i = beta^i:
#     E(a_i) * y_i - N(a_i) = 0,  E monic of degree t,  deg N < t + k.
# Columns: e_0..e_{t-1} (coefficients of E), then n_0..n_{t+k-1} (coefficients
# of N). The last column is the right-hand side -(a_i^t * y_i), which comes from
# E's leading 1. "f - p" reduced mod f is the additive inverse of p.
G1 = gen_vandermonde(n, t + 1, beta, f)
G2 = gen_vandermonde(n, t + k, beta, f)

constraints = []
for i in range(n):
	y_i = codeword[i]
	e_cols = [div(G1[i][j] * y_i, f, domain=FF(2))[1] for j in range(t)]
	n_cols = [div(f - G2[i][j], f, domain=FF(2))[1] for j in range(t + k)]
	_, rhs = div(G1[i][t] * y_i, f, domain=FF(2))
	_, rhs = div(f - rhs, f, domain=FF(2))
	constraints.append(e_cols + n_cols + [rhs])


def inv(a, irr_poly):
	"""
	Return the multiplicative inverse of a modulo irr_poly (thin wrapper around sympy's invert).
	"""
	inverse = invert(a, irr_poly)
	return inverse

def print_mat(a):
	"""
	Print matrix a (a list of rows whose entries have .as_expr()) as a pandas DataFrame.

	The number of columns shown is taken from the first row.
	"""
	rendered = [[row[col].as_expr() for col in range(len(a[0]))] for row in a]
	print(DataFrame(rendered))

def GaussElim(A, irred, silent = False):
	"""
	Solve the linear system given by the augmented matrix A over F_2[x] mod irred.

	A is a list of rows of Poly; the last column holds the right-hand side.
	Gauss-Jordan elimination runs on a row-wise copy of A, so the caller's
	matrix is not modified. Columns without a pivot are free variables and
	are set to 0.

	Returns (ans, row):
	  ans -- one Poly per unknown (every column except the last).
	  row -- row[c] is the index of the pivot row for column c, or -1 if
	         column c has no pivot. It has len(A[0]) entries; the entry for
	         the right-hand-side column is always -1.

	An inconsistent system is not detected; ans is then not a solution.
	When silent is False, every elimination step is printed.
	"""
	zero = Poly(0, x, domain=FF(2))
	# Copy the rows: they are swapped and their entries rewritten below.
	mat = [list(r) for r in A]
	row = [-1 for _ in range(len(A[0]))]
	R = len(mat)
	C = len(mat[0]) - 1
	r = 0
	for c in range(C):
		if not silent:
			print("Eliminating Column {}\n".format(c))
			print_mat(mat)
			print("\n")
		pivot = r
		while pivot < R and mat[pivot][c] == zero:
			pivot += 1
		if pivot == R:
			if not silent:
				print("No pivot found")
			continue

		if not silent:
			print("Pivot: ", pivot, mat[pivot][c].as_expr())
		mat[pivot], mat[r] = mat[r], mat[pivot]

		mod_inv = inv(mat[r][c], irred)
		for i in range(R):
			# Rows already zero in this column would get multiplier 0; skip them.
			if i == r or mat[i][c] == zero:
				continue
			# w = -(mat[i][c] / pivot), so adding w * (pivot row) clears mat[i][c].
			_, w = div(mat[i][c] * mod_inv, irred, domain=FF(2))
			_, w = div(irred - w, irred, domain=FF(2))
			for j in range(C + 1):
				_, mat[i][j] = div(mat[i][j] + mat[r][j] * w, irred, domain=FF(2))

		row[c] = r
		r += 1

	if not silent:
		print("Final Constraint Matrix\n")
		print_mat(mat)
		print("\n")

	# Fully reduced form: each pivot variable = rhs / pivot (free variables = 0).
	ans = []
	for c in range(C):
		pr = row[c]
		if pr == -1:
			ans.append(zero)
		else:
			_, val = div(mat[pr][C] * inv(mat[pr][c], irred), irred, domain=FF(2))
			ans.append(val)

	return ans, row


eandn, _ = GaussElim(constraints, f)

# Unknowns are ordered [e_0..e_{t-1}, n_0..n_{t+k-1}] (see the constraint cell).
# E(x) is monic of degree t, so its leading coefficient 1 is appended explicitly.
ex = eandn[:t] + [Poly(1, x, domain=FF(2))]
nx = eandn[t:]

print(
"""
####################
# Error Polynomial #
####################
"""
)
print(np.array(ex)[:, None])

print(
"""
####################
# n(x) Polynomial  #
####################
"""
)
print(np.array(nx)[:, None])


Eliminating Column 0

              0             1                2                    3  4  \
0      x**2 + x      x**2 + x         x**2 + x             x**2 + x  1   
1      x**3 + x  x**2 + x + 1  x**3 + x**2 + x  x**3 + x**2 + x + 1  1   
2             1          x**2            x + 1          x**3 + x**2  1   
3          x**2      x**2 + x         x**2 + 1      x**3 + x**2 + x  1   
4          x**3  x**3 + x + 1  x**3 + x**2 + x                    1  1   
5      x**3 + 1         x + 1         x**3 + x             x**3 + 1  1   
6  x**2 + x + 1             x     x**3 + x + 1      x**3 + x**2 + 1  1   
7          x**3  x**2 + x + 1             x**2             x**3 + x  1   
8      x**2 + 1             x         x**3 + x                 x**2  1   
9   x**3 + x**2             1         x**3 + x                 x**3  1   

              5                    6                    7  \
0             1                    1                    1   
1             x                 x**2     

9         x**2 + x + 1             x**2      x**3 + x**2 + 1  


Pivot:  4 x**3 + x**2 + x
Eliminating Column 5

          0                1                2  3                4  \
0  x**2 + x                0                0  0                0   
1         0  x**3 + x**2 + 1                0  0                0   
2         0                0  x**3 + x**2 + 1  0                0   
3         0                0                0  x                0   
4         0                0                0  0  x**3 + x**2 + x   
5         0                0                0  0                0   
6         0                0                0  0                0   
7         0                0                0  0                0   
8         0                0                0  0                0   
9         0                0                0  0                0   

                     5                6                7                8  \
0          x**3 + x**2         x**3 + x           

In [12]:
def poly_div(u,v,irred):
	"""
	Long division of polynomials whose coefficients lie in F_2[x]/irred.

	u and v are coefficient lists in increasing degree order (u[i] multiplies X**i).
	Returns (q, r) such that u = q*v + r:
	  q -- list of len(u) - deg(v) quotient coefficients (empty when len(u) <= deg(v)).
	  r -- list of len(u) remainder coefficients; entries at index >= deg(v) are zero.
	Zero high-order entries of v are ignored when finding deg(v).
	The input lists are not modified.

	Raises ZeroDivisionError if v is the zero polynomial.
	"""
	zero = Poly(0, x, domain=FF(2))
	# Effective degree of the divisor.
	j = len(v) - 1
	while j >= 0 and v[j] == zero:
		j -= 1
	if j < 0:
		raise ZeroDivisionError("poly_div: divisor is the zero polynomial")

	rem = list(u)  # work on a copy so the caller's list is left untouched
	quot = [zero for _ in range(max(len(u) - j, 0))]
	lead_inv = inv(v[j], irred)  # the divisor's leading coefficient never changes
	for i in range(len(u) - 1, j - 1, -1):
		if rem[i] == zero:
			continue
		_, coef = div(rem[i] * lead_inv, irred, domain=FF(2))
		quot[i - j] = coef
		# rem -= coef * X**(i-j) * v, where only v[0..j] can be non-zero.
		for p in range(j + 1):
			_, term = div(coef * v[p], irred, domain=FF(2))
			_, neg = div(irred - term, irred, domain=FF(2))
			_, rem[p + i - j] = div(rem[p + i - j] + neg, irred, domain=FF(2))

	return quot, rem

# Welch-Berlekamp: the message polynomial is N(x) / E(x); mx holds its coefficients.
(mx, rem) = poly_div(nx, ex, f)


In [13]:
print(
"""
####################
# Decoded Message  #
####################
"""
)
print(np.array(mx)[:,None])

print(
"""
####################
# Original Message #
####################
"""
)
print(np.array(message)[:,None])

# Sanity check: with at most t injected errors, decoding must reproduce the
# message exactly. Report the comparison instead of leaving it to the reader.
print("\nDecoded message matches original:", mx == message)



####################
# Decoded Message  #
####################

[[Poly(x, x, modulus=2)]
 [Poly(x**2, x, modulus=2)]]

####################
# Original Message #
####################

[[Poly(x, x, modulus=2)]
 [Poly(x**2, x, modulus=2)]]
