Skip to content

Commit

Permalink
Always dynamically compile ufuncs
Browse files Browse the repository at this point in the history
This means ufuncs don't reference other JIT functions, but other specific ufuncs. This provides more flexibility, and runtime speed. This comes at a slight increase in JIT compilation time during the first invocation.

JIT functions also reference specific ufuncs rather than general JIT
arithmetic functions. This allows for using lookup tables in many common
JIT functions, such as dense polynomial arithmetic.
  • Loading branch information
mhostetter committed May 3, 2022
1 parent c0d306b commit f44002d
Show file tree
Hide file tree
Showing 17 changed files with 1,380 additions and 1,621 deletions.
96 changes: 64 additions & 32 deletions galois/_codes/_bch.py
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

import math
from typing import Tuple, List, Optional, Union, Type, overload
from typing import Tuple, List, Optional, Union, Type, Any, overload
from typing_extensions import Literal

import numba
Expand Down Expand Up @@ -256,21 +256,6 @@ def __init__(
self._is_primitive = True
self._is_narrow_sense = True

# Pre-compile the arithmetic methods
self._add_jit = self.field._func_calculate("add")
self._subtract_jit = self.field._func_calculate("subtract")
self._multiply_jit = self.field._func_calculate("multiply")
self._reciprocal_jit = self.field._func_calculate("reciprocal")
self._power_jit = self.field._func_calculate("power")

# Pre-compile the JIT functions
self._berlekamp_massey_jit = _lfsr.jit_calculate("berlekamp_massey")
self._poly_roots_jit = self.field._function("poly_roots")
self._poly_divmod_jit = GF2._function("poly_divmod")

# Pre-compile the JIT decoder
self._decode_jit = numba.jit(DECODE_CALCULATE_SIG.signature, nopython=True, cache=True)(_decode_calculate)

def __repr__(self) -> str:
"""
A terse representation of the BCH code.
Expand Down Expand Up @@ -790,17 +775,20 @@ def decode(self, codeword, errors=False):
syndrome = codeword.view(self.field) @ self.H[:,-ns:].T

if self.field.ufunc_mode != "python-calculate":
dec_codeword = self._decode_jit(codeword.astype(np.int64), syndrome.astype(np.int64), self.t, int(self.field.primitive_element), self._add_jit, self._subtract_jit, self._multiply_jit, self._reciprocal_jit, self._power_jit, self._berlekamp_massey_jit, self._poly_roots_jit, self.field.characteristic, self.field.degree, int(self.field.irreducible_poly))
N_errors = dec_codeword[:, -1]

if self.systematic:
message = dec_codeword[:, 0:ks]
else:
message, _ = GF2._poly_divmod(dec_codeword[:, 0:self.n].view(GF2), self.generator_poly.coeffs)
message = message.astype(dtype).view(type(codeword))
codeword_ = codeword.astype(np.int64)
syndrome_ = syndrome.astype(np.int64)
y = function("decode", self.field)(codeword_, syndrome_, self.t, int(self.field.primitive_element))
else:
codeword_ = codeword.view(np.ndarray)
syndrome_ = syndrome.view(np.ndarray)
y = function("decode", self.field)(codeword_, syndrome_, self.t, int(self.field.primitive_element))

if self.systematic:
message = y[:, 0:ks]
else:
raise NotImplementedError("BCH codes haven't been implemented for extremely large Galois fields.")
message, _ = GF2._poly_divmod(y[:, 0:ns].view(GF2), self.generator_poly.coeffs)
message = message.astype(dtype).view(type(codeword))
N_errors = y[:, -1]

if codeword_1d:
message, N_errors = message[0,:], N_errors[0]
Expand Down Expand Up @@ -990,20 +978,64 @@ def is_narrow_sense(self) -> bool:


###############################################################################
# JIT-compiled implementation of the specified functions
# JIT functions
###############################################################################

DECODE_CALCULATE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64, FieldArray._BINARY_CALCULATE_SIG, FieldArray._BINARY_CALCULATE_SIG, FieldArray._BINARY_CALCULATE_SIG, FieldArray._UNARY_CALCULATE_SIG, FieldArray._BINARY_CALCULATE_SIG, _lfsr.BERLEKAMP_MASSEY_CALCULATE_SIG, FieldArray._POLY_ROOTS_CALCULATE_SIG, int64, int64, int64))
POLY_ROOTS: Any
BERLEKAMP_MASSEY: Any


def function(name: str, field: Type[FieldArray]):
"""
Returns a function implemented over the given field and ufunc mode.
"""
if field.ufunc_mode != "python-calculate":
return function_jit(name, field)
else:
return function_python(name, field)


def function_jit(name: str, field: Type[FieldArray]):
"""
Returns a JIT-compiled function implemented over the given field.
"""
key = (name, field.characteristic, field.degree, int(field.irreducible_poly), int(field.primitive_element))
if key not in function_jit.cache:
# Set the globals once before JIT compiling the function
eval(f"set_{name}_globals")(field)
sig = eval(f"{name.upper()}_SIG")
function_jit.cache[key] = numba.jit(sig.signature, nopython=True)(eval(f"{name}_jit"))

return function_jit.cache[key]

function_jit.cache = {}


def function_python(name: str, field: Type[FieldArray]):
"""
Returns a pure-Python function.
"""
# Set the globals each time before invoking the pure-Python ufunc
eval(f"set_{name}_globals")(field)
return eval(f"{name}_jit")


def set_decode_globals(field: Type[FieldArray]):
global POLY_ROOTS, BERLEKAMP_MASSEY
POLY_ROOTS = field._function("poly_roots")
BERLEKAMP_MASSEY = _lfsr.function("berlekamp_massey", field)


def _decode_calculate(codeword, syndrome, t, primitive_element, ADD, SUBTRACT, MULTIPLY, RECIPROCAL, POWER, BERLEKAMP_MASSEY, POLY_ROOTS, CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY): # pragma: no cover
DECODE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64))


def decode_jit(codeword, syndrome, t, primitive_element): # pragma: no cover
"""
References
----------
* Lin, S. and Costello, D. Error Control Coding. Section 7.4.
"""
args = CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY
dtype = codeword.dtype

N = codeword.shape[0] # The number of codewords
n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes)

Expand All @@ -1024,7 +1056,7 @@ def _decode_calculate(codeword, syndrome, t, primitive_element, ADD, SUBTRACT, M

# Compute the error-locator polynomial's v-reversal σ(x^-v), since the syndrome is passed in backwards
# TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial
sigma_rev = BERLEKAMP_MASSEY(syndrome[i,::-1], ADD, SUBTRACT, MULTIPLY, RECIPROCAL, *args)[::-1]
sigma_rev = BERLEKAMP_MASSEY(syndrome[i,::-1])[::-1]
v = sigma_rev.size - 1 # The number of errors

if v > t:
Expand All @@ -1033,7 +1065,7 @@ def _decode_calculate(codeword, syndrome, t, primitive_element, ADD, SUBTRACT, M

# Compute βi, the roots of σ(x^-v) which are the inverse roots of σ(x)
degrees = np.arange(sigma_rev.size - 1, -1, -1)
results = POLY_ROOTS(degrees, sigma_rev, primitive_element, ADD, MULTIPLY, POWER, *args)
results = POLY_ROOTS(degrees, sigma_rev, primitive_element)
beta = results[0,:] # The roots of σ(x^-v)
error_locations = results[1,:] # The roots as powers of the primitive element α

Expand Down
133 changes: 89 additions & 44 deletions galois/_codes/_reed_solomon.py
Expand Up @@ -3,7 +3,7 @@
"""
from __future__ import annotations

from typing import Tuple, Optional, Union, Type, overload
from typing import Tuple, Optional, Union, Type, Any, overload
from typing_extensions import Literal

import numba
Expand Down Expand Up @@ -144,23 +144,6 @@ def __init__(

self._is_narrow_sense = c == 1

# Pre-compile the arithmetic methods
self._add_jit = self.field._func_calculate("add")
self._subtract_jit = self.field._func_calculate("subtract")
self._multiply_jit = self.field._func_calculate("multiply")
self._reciprocal_jit = self.field._func_calculate("reciprocal")
self._power_jit = self.field._func_calculate("power")

# Pre-compile the JIT functions
self._berlekamp_massey_jit = _lfsr.jit_calculate("berlekamp_massey")
self._poly_divmod_jit = self.field._function("poly_divmod")
self._poly_roots_jit = self.field._function("poly_roots")
self._poly_eval_jit = self.field._function("poly_evaluate")
self._convolve_jit = self.field._function("convolve")

# Pre-compile the JIT decoder
self._decode_jit = numba.jit(DECODE_CALCULATE_SIG.signature, nopython=True, cache=True)(_decode_calculate)

def __repr__(self) -> str:
"""
A terse representation of the Reed-Solomon code.
Expand Down Expand Up @@ -684,17 +667,20 @@ def decode(self, codeword, errors=False):
syndrome = codeword.view(self.field) @ self.H[:,-ns:].T

if self.field.ufunc_mode != "python-calculate":
dec_codeword = self._decode_jit(codeword.astype(np.int64), syndrome.astype(np.int64), self.c, self.t, int(self.field.primitive_element), self._add_jit, self._subtract_jit, self._multiply_jit, self._reciprocal_jit, self._power_jit, self._berlekamp_massey_jit, self._poly_roots_jit, self._poly_eval_jit, self._convolve_jit, self.field.characteristic, self.field.degree, int(self.field.irreducible_poly))
N_errors = dec_codeword[:, -1]

if self.systematic:
message = dec_codeword[:, 0:ks]
else:
message, _ = self.field._poly_divmod(dec_codeword[:, 0:self.n].view(self.field), self.generator_poly.coeffs)
message = message.astype(dtype).view(type(codeword))
codeword_ = codeword.astype(np.int64)
syndrome_ = syndrome.astype(np.int64)
y = function("decode", self.field)(codeword_, syndrome_, self.c, self.t, int(self.field.primitive_element))
else:
codeword_ = codeword.view(np.ndarray)
syndrome_ = syndrome.view(np.ndarray)
y = function("decode", self.field)(codeword_, syndrome_, self.c, self.t, int(self.field.primitive_element))

if self.systematic:
message = y[:, 0:ks]
else:
raise NotImplementedError("Reed-Solomon codes haven't been implemented for extremely large Galois fields.")
message, _ = self.field._poly_divmod(y[:, 0:ns].view(self.field), self.generator_poly.coeffs)
message = message.astype(dtype).view(type(codeword))
N_errors = y[:, -1]

if codeword_1d:
message, N_errors = message[0,:], N_errors[0]
Expand Down Expand Up @@ -892,23 +878,82 @@ def is_narrow_sense(self) -> bool:


###############################################################################
# JIT-compiled implementation of the specified functions
# JIT functions
###############################################################################

DECODE_CALCULATE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64, int64, FieldArray._BINARY_CALCULATE_SIG, FieldArray._BINARY_CALCULATE_SIG, FieldArray._BINARY_CALCULATE_SIG, FieldArray._UNARY_CALCULATE_SIG, FieldArray._BINARY_CALCULATE_SIG, _lfsr.BERLEKAMP_MASSEY_CALCULATE_SIG, FieldArray._POLY_ROOTS_CALCULATE_SIG, FieldArray._POLY_EVALUATE_CALCULATE_SIG, FieldArray._CONVOLVE_CALCULATE_SIG, int64, int64, int64))
CHARACTERISTIC: int
ORDER: int
SUBTRACT = np.subtract
MULTIPLY = np.multiply
RECIPROCAL = np.reciprocal
POWER = np.power
CONVOLVE = np.convolve
POLY_ROOTS: Any
POLY_EVALUATE: Any
BERLEKAMP_MASSEY: Any


def _decode_calculate(codeword, syndrome, c, t, primitive_element, ADD, SUBTRACT, MULTIPLY, RECIPROCAL, POWER, BERLEKAMP_MASSEY, POLY_ROOTS, POLY_EVAL, CONVOLVE, CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY): # pragma: no cover
def function(name: str, field: Type[FieldArray]):
"""
Returns a function implemented over the given field and ufunc mode.
"""
if field.ufunc_mode != "python-calculate":
return function_jit(name, field)
else:
return function_python(name, field)


def function_jit(name: str, field: Type[FieldArray]):
"""
Returns a JIT-compiled function implemented over the given field.
"""
key = (name, field.characteristic, field.degree, int(field.irreducible_poly), int(field.primitive_element))
if key not in function_jit.cache:
# Set the globals once before JIT compiling the function
eval(f"set_{name}_globals")(field)
sig = eval(f"{name.upper()}_SIG")
function_jit.cache[key] = numba.jit(sig.signature, nopython=True)(eval(f"{name}_jit"))

return function_jit.cache[key]

function_jit.cache = {}


def function_python(name: str, field: Type[FieldArray]):
"""
Returns a pure-Python function.
"""
# Set the globals each time before invoking the pure-Python ufunc
eval(f"set_{name}_globals")(field)
return eval(f"{name}_jit")


def set_decode_globals(field: Type[FieldArray]):
global CHARACTERISTIC, ORDER, SUBTRACT, MULTIPLY, RECIPROCAL, POWER, CONVOLVE, POLY_ROOTS, POLY_EVALUATE, BERLEKAMP_MASSEY
CHARACTERISTIC = field.characteristic
ORDER = field.order
SUBTRACT = field._ufunc("subtract")
MULTIPLY = field._ufunc("multiply")
RECIPROCAL = field._ufunc("reciprocal")
POWER = field._ufunc("power")
CONVOLVE = field._function("convolve")
POLY_ROOTS = field._function("poly_roots")
POLY_EVALUATE = field._function("poly_evaluate")
BERLEKAMP_MASSEY = _lfsr.function("berlekamp_massey", field)


DECODE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64, int64))

def decode_jit(codeword, syndrome, c, t, primitive_element): # pragma: no cover
"""
References
----------
* Lin, S. and Costello, D. Error Control Coding. Section 7.4.
"""
args = CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY
dtype = codeword.dtype

N = codeword.shape[0] # The number of codewords
n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes)
design_n = CHARACTERISTIC**DEGREE - 1 # The designed codeword size
design_n = ORDER - 1 # The designed codeword size

# The last column of the returned decoded codeword is the number of corrected errors
dec_codeword = np.zeros((N, n + 1), dtype=dtype)
Expand All @@ -927,7 +972,7 @@ def _decode_calculate(codeword, syndrome, c, t, primitive_element, ADD, SUBTRACT

# Compute the error-locator polynomial σ(x)
# TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial
sigma = BERLEKAMP_MASSEY(syndrome[i,:], ADD, SUBTRACT, MULTIPLY, RECIPROCAL, *args)[::-1]
sigma = BERLEKAMP_MASSEY(syndrome[i,:])[::-1]
v = sigma.size - 1 # The number of errors, which is the degree of the error-locator polynomial

if v > t:
Expand All @@ -936,7 +981,7 @@ def _decode_calculate(codeword, syndrome, c, t, primitive_element, ADD, SUBTRACT

# Compute βi^-1, the roots of σ(x)
degrees = np.arange(sigma.size - 1, -1, -1)
results = POLY_ROOTS(degrees, sigma, primitive_element, ADD, MULTIPLY, POWER, *args)
results = POLY_ROOTS(degrees, sigma, primitive_element)
beta_inv = results[0,:] # The roots βi^-1 of σ(x)
error_locations_inv = results[1,:] # The roots βi^-1 as powers of the primitive element α
error_locations = -error_locations_inv % design_n # The error locations as degrees of c(x)
Expand All @@ -955,21 +1000,21 @@ def _decode_calculate(codeword, syndrome, c, t, primitive_element, ADD, SUBTRACT
sigma_prime = np.zeros(v, dtype=dtype)
for j in range(v):
degree = v - j
sigma_prime[j] = MULTIPLY(degree % CHARACTERISTIC, sigma[j], *args) # Scalar multiplication
sigma_prime[j] = MULTIPLY(degree % CHARACTERISTIC, sigma[j]) # Scalar multiplication

# The error-value evaluator polynomial Z0(x) = S0*σ0 + (S1*σ0 + S0*σ1)*x + (S2*σ0 + S1*σ1 + S0*σ2)*x^2 + ...
# with degree v-1
Z0 = CONVOLVE(sigma[-v:], syndrome[i,0:v][::-1], ADD, MULTIPLY, *args)[-v:]
Z0 = CONVOLVE(sigma[-v:], syndrome[i,0:v][::-1])[-v:]

# The error value δi = -1 * βi^(1-c) * Z0(βi^-1) / σ'(βi^-1)
for j in range(v):
beta_i = POWER(beta_inv[j], c - 1, *args)
Z0_i = POLY_EVAL(Z0, np.array([beta_inv[j]], dtype=dtype), ADD, MULTIPLY, *args)[0] # NOTE: poly_eval() expects a 1-D array of values
sigma_prime_i = POLY_EVAL(sigma_prime, np.array([beta_inv[j]], dtype=dtype), ADD, MULTIPLY, *args)[0] # NOTE: poly_eval() expects a 1-D array of values
delta_i = MULTIPLY(beta_i, Z0_i, *args)
delta_i = MULTIPLY(delta_i, RECIPROCAL(sigma_prime_i, *args), *args)
delta_i = SUBTRACT(0, delta_i, *args)
dec_codeword[i, n - 1 - error_locations[j]] = SUBTRACT(dec_codeword[i, n - 1 - error_locations[j]], delta_i, *args)
beta_i = POWER(beta_inv[j], c - 1)
Z0_i = POLY_EVALUATE(Z0, np.array([beta_inv[j]], dtype=dtype))[0] # NOTE: poly_eval() expects a 1-D array of values
sigma_prime_i = POLY_EVALUATE(sigma_prime, np.array([beta_inv[j]], dtype=dtype))[0] # NOTE: poly_eval() expects a 1-D array of values
delta_i = MULTIPLY(beta_i, Z0_i)
delta_i = MULTIPLY(delta_i, RECIPROCAL(sigma_prime_i))
delta_i = SUBTRACT(0, delta_i)
dec_codeword[i, n - 1 - error_locations[j]] = SUBTRACT(dec_codeword[i, n - 1 - error_locations[j]], delta_i)

dec_codeword[i,-1] = v # The number of corrected errors

Expand Down

0 comments on commit f44002d

Please sign in to comment.