Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ufunc and function dispatchers #354

Merged
merged 9 commits into from May 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions benchmarks/test_field_arithmetic.py
Expand Up @@ -81,15 +81,15 @@ class Test_GF257_calculate(Base):
N = 100_000


@pytest.mark.benchmark(group="GF(3^5) Array Arithmetic: shape=(1_000,), ufunc_mode='jit-lookup'")
@pytest.mark.benchmark(group="GF(3^5) Array Arithmetic: shape=(100_000,), ufunc_mode='jit-lookup'")
class Test_GF3_5_lookup(Base):
order = 3**5
ufunc_mode = "jit-lookup"
N = 1_000
N = 100_000


@pytest.mark.benchmark(group="GF(3^5) Array Arithmetic: shape=(1_000,), ufunc_mode='jit-calculate'")
@pytest.mark.benchmark(group="GF(3^5) Array Arithmetic: shape=(10_000,), ufunc_mode='jit-calculate'")
class Test_GF3_5_calculate(Base):
order = 3**5
ufunc_mode = "jit-calculate"
N = 1_000
N = 10_000
5 changes: 4 additions & 1 deletion docs/conf.py
Expand Up @@ -184,7 +184,10 @@
html_domain_indices = True

# Sphinx Immaterial API config
include_object_description_fields_in_toc = False
object_description_options = [
("py:function", dict(include_fields_in_toc=False)),
("py:method", dict(include_fields_in_toc=False)),
]


# -- Extension configuration -------------------------------------------------
Expand Down
10 changes: 6 additions & 4 deletions docs/performance/benchmarks.rst
Expand Up @@ -129,23 +129,25 @@ advised to pass extra arguments to format the display `--benchmark-columns=min,m
Compare with a previous benchmark
---------------------------------

If you would like to compare the performance impacts of a branch, first run a benchmark on `master` using the `--benchmark-save` option.
If you would like to compare the performance impact of a branch, first run a benchmark on `master` using the `--benchmark-save` option.
This will save the file `.benchmarks/0001_master.json`.

.. code-block::

$ git checkout master
$ python3 -m pytest benchmarks/test_field_arithmetic.py --benchmark-save=master --benchmark-columns=min,max,mean,stddev,median --benchmark-sort=name

Next, checkout your branch and run another benchmark. This will save the file `.benchmarks/0001_branch.json`.
Next, run a benchmark on the branch under test while comparing against the benchmark from `master`.

.. code-block::

$ git checkout branch
$ python3 -m pytest benchmarks/test_field_arithmetic.py --benchmark-save=branch --benchmark-columns=min,max,mean,stddev,median --benchmark-sort=name
$ python3 -m pytest benchmarks/test_field_arithmetic.py --benchmark-compare=0001_master --benchmark-columns=min,max,mean,stddev,median --benchmark-sort=name

And finally, compare the two benchmarks.
Or, save a benchmark run from `branch` and compare it explicitly against the one from `master`. This benchmark run will save the file `.benchmarks/0001_branch.json`.

.. code-block::

$ git checkout branch
$ python3 -m pytest benchmarks/test_field_arithmetic.py --benchmark-save=branch --benchmark-columns=min,max,mean,stddev,median --benchmark-sort=name
$ python3 -m pytest-benchmark compare 0001_master 0001_branch
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -1,5 +1,5 @@
sphinx>=3
git+https://github.com/jbms/sphinx-immaterial@0e0dcc8c9cc88f2d160ec80a38410547c8a904d3
git+https://github.com/jbms/sphinx-immaterial@98d0ab6b618fb9a9fe6a0f33ac6875bbbbb05ad2
myst-parser
sphinx-design
sphinxcontrib-details-directive
Expand Down
172 changes: 85 additions & 87 deletions galois/_codes/_bch.py
Expand Up @@ -11,10 +11,12 @@
from numba import int64
import numpy as np

from .. import _lfsr
from .._domains._function import Function
from .._fields import Field, FieldArray, GF2 # pylint: disable=unused-import
from .._lfsr import berlekamp_massey_jit
from .._overrides import set_module
from .._polys import Poly, matlab_primitive_poly
from .._polys._dense import roots_jit, divmod_jit
from .._prime import factors
from ..typing import PolyLike

Expand Down Expand Up @@ -256,21 +258,6 @@ def __init__(
self._is_primitive = True
self._is_narrow_sense = True

# Pre-compile the arithmetic methods
self._add_jit = self.field._func_calculate("add")
self._subtract_jit = self.field._func_calculate("subtract")
self._multiply_jit = self.field._func_calculate("multiply")
self._reciprocal_jit = self.field._func_calculate("reciprocal")
self._power_jit = self.field._func_calculate("power")

# Pre-compile the JIT functions
self._berlekamp_massey_jit = _lfsr.jit_calculate("berlekamp_massey")
self._poly_roots_jit = self.field._function("poly_roots")
self._poly_divmod_jit = GF2._function("poly_divmod")

# Pre-compile the JIT decoder
self._decode_jit = numba.jit(DECODE_CALCULATE_SIG.signature, nopython=True, cache=True)(_decode_calculate)

def __repr__(self) -> str:
"""
A terse representation of the BCH code.
Expand Down Expand Up @@ -593,7 +580,7 @@ def detect(self, codeword: Union[np.ndarray, "GF2"]) -> Union[np.bool_, np.ndarr
return detected

@overload
def decode(self, codeword: Union[np.ndarray, "GF2"], errors: Literal[False]) -> Union[np.ndarray, "GF2"]:
def decode(self, codeword: Union[np.ndarray, "GF2"], errors: Literal[False] = False) -> Union[np.ndarray, "GF2"]:
...
@overload
def decode(self, codeword: Union[np.ndarray, "GF2"], errors: Literal[True]) -> Tuple[Union[np.ndarray, "GF2"], Union[np.integer, np.ndarray]]:
Expand Down Expand Up @@ -779,7 +766,6 @@ def decode(self, codeword, errors=False):
raise ValueError(f"For a non-systematic code, argument `codeword` must be a 1-D or 2-D array with last dimension equal to {self.n}, not shape {codeword.shape}.")

codeword_1d = codeword.ndim == 1
dtype = codeword.dtype
ns = codeword.shape[-1] # The number of input codeword bits (could be less than self.n for shortened codes)
ks = self.k - (self.n - ns) # The equivalent number of input message bits (could be less than self.k for shortened codes)

Expand All @@ -789,18 +775,14 @@ def decode(self, codeword, errors=False):
# Compute the syndrome by matrix multiplying with the parity-check matrix
syndrome = codeword.view(self.field) @ self.H[:,-ns:].T

if self.field.ufunc_mode != "python-calculate":
dec_codeword = self._decode_jit(codeword.astype(np.int64), syndrome.astype(np.int64), self.t, int(self.field.primitive_element), self._add_jit, self._subtract_jit, self._multiply_jit, self._reciprocal_jit, self._power_jit, self._berlekamp_massey_jit, self._poly_roots_jit, self.field.characteristic, self.field.degree, int(self.field.irreducible_poly))
N_errors = dec_codeword[:, -1]

if self.systematic:
message = dec_codeword[:, 0:ks]
else:
message, _ = GF2._poly_divmod(dec_codeword[:, 0:self.n].view(GF2), self.generator_poly.coeffs)
message = message.astype(dtype).view(type(codeword))
# Invoke the JIT compiled function
dec_codeword, N_errors = decode_jit(self.field)(codeword, syndrome, self.t, int(self.field.primitive_element))

if self.systematic:
message = dec_codeword[:, 0:ks]
else:
raise NotImplementedError("BCH codes haven't been implemented for extremely large Galois fields.")
message, _ = divmod_jit(GF2)(dec_codeword[:, 0:ns].view(GF2), self.generator_poly.coeffs)
message = message.view(type(codeword)) # TODO: Remove this

if codeword_1d:
message, N_errors = message[0,:], N_errors[0]
Expand Down Expand Up @@ -989,67 +971,83 @@ def is_narrow_sense(self) -> bool:
return self._is_narrow_sense


###############################################################################
# JIT-compiled implementation of the specified functions
###############################################################################

DECODE_CALCULATE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64, FieldArray._BINARY_CALCULATE_SIG, FieldArray._BINARY_CALCULATE_SIG, FieldArray._BINARY_CALCULATE_SIG, FieldArray._UNARY_CALCULATE_SIG, FieldArray._BINARY_CALCULATE_SIG, _lfsr.BERLEKAMP_MASSEY_CALCULATE_SIG, FieldArray._POLY_ROOTS_CALCULATE_SIG, int64, int64, int64))

def _decode_calculate(codeword, syndrome, t, primitive_element, ADD, SUBTRACT, MULTIPLY, RECIPROCAL, POWER, BERLEKAMP_MASSEY, POLY_ROOTS, CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY): # pragma: no cover
class decode_jit(Function):
"""
Performs BCH decoding.

References
----------
* Lin, S. and Costello, D. Error Control Coding. Section 7.4.
"""
args = CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY
dtype = codeword.dtype

N = codeword.shape[0] # The number of codewords
n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes)

# The last column of the returned decoded codeword is the number of corrected errors
dec_codeword = np.zeros((N, n + 1), dtype=dtype)
dec_codeword[:, 0:n] = codeword[:,:]

for i in range(N):
if not np.all(syndrome[i,:] == 0):
# The syndrome vector is S = [S0, S1, ..., S2t-1]

# The error pattern is defined as the polynomial e(x) = e_j1*x^j1 + e_j2*x^j2 + ... for j1 to jv,
# implying there are v errors. And δi = e_ji is the i-th error value and βi = α^ji is the i-th error-locator
# value and ji is the error location.

# The error-locator polynomial σ(x) = (1 - β1*x)(1 - β2*x)...(1 - βv*x) where βi are the inverse of the roots
# of σ(x).

# Compute the error-locator polynomial's v-reversal σ(x^-v), since the syndrome is passed in backwards
# TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial
sigma_rev = BERLEKAMP_MASSEY(syndrome[i,::-1], ADD, SUBTRACT, MULTIPLY, RECIPROCAL, *args)[::-1]
v = sigma_rev.size - 1 # The number of errors

if v > t:
dec_codeword[i, -1] = -1
continue

# Compute βi, the roots of σ(x^-v) which are the inverse roots of σ(x)
degrees = np.arange(sigma_rev.size - 1, -1, -1)
results = POLY_ROOTS(degrees, sigma_rev, primitive_element, ADD, MULTIPLY, POWER, *args)
beta = results[0,:] # The roots of σ(x^-v)
error_locations = results[1,:] # The roots as powers of the primitive element α

if np.any(error_locations > n - 1):
# Indicates there are "errors" in the zero-ed portion of a shortened code, which indicates there are actually
# more errors than alleged. Return failure to decode.
dec_codeword[i, -1] = -1
continue

if beta.size != v:
dec_codeword[i, -1] = -1
continue

for j in range(v):
# δi can only be 1
dec_codeword[i, n - 1 - error_locations[j]] ^= 1
dec_codeword[i, -1] = v # The number of corrected errors

return dec_codeword
def __call__(self, codeword, syndrome, t, primitive_element):
if self.field.ufunc_mode != "python-calculate":
y = self.jit(codeword.astype(np.int64), syndrome.astype(np.int64), t, primitive_element)
else:
y = self.python(codeword.view(np.ndarray), syndrome.view(np.ndarray), t, primitive_element)

dec_codeword, N_errors = y[:,0:-1], y[:,-1]
dec_codeword = dec_codeword.astype(codeword.dtype)
dec_codeword = dec_codeword.view(self.field)

return dec_codeword, N_errors

def set_globals(self):
# pylint: disable=global-variable-undefined
global POLY_ROOTS, BERLEKAMP_MASSEY
POLY_ROOTS = roots_jit(self.field).function
BERLEKAMP_MASSEY = berlekamp_massey_jit(self.field).function

_SIGNATURE = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64))

@staticmethod
def implementation(codeword, syndrome, t, primitive_element): # pragma: no cover
dtype = codeword.dtype
N = codeword.shape[0] # The number of codewords
n = codeword.shape[1] # The codeword size (could be less than the design n for shortened codes)

# The last column of the returned decoded codeword is the number of corrected errors
dec_codeword = np.zeros((N, n + 1), dtype=dtype)
dec_codeword[:, 0:n] = codeword[:,:]

for i in range(N):
if not np.all(syndrome[i,:] == 0):
# The syndrome vector is S = [S0, S1, ..., S2t-1]

# The error pattern is defined as the polynomial e(x) = e_j1*x^j1 + e_j2*x^j2 + ... for j1 to jv,
# implying there are v errors. And δi = e_ji is the i-th error value and βi = α^ji is the i-th error-locator
# value and ji is the error location.

# The error-locator polynomial σ(x) = (1 - β1*x)(1 - β2*x)...(1 - βv*x) where βi are the inverse of the roots
# of σ(x).

# Compute the error-locator polynomial's v-reversal σ(x^-v), since the syndrome is passed in backwards
# TODO: Re-evaluate these equations since changing BMA to return characteristic polynomial, not feedback polynomial
sigma_rev = BERLEKAMP_MASSEY(syndrome[i,::-1])[::-1]
v = sigma_rev.size - 1 # The number of errors

if v > t:
dec_codeword[i, -1] = -1
continue

# Compute βi, the roots of σ(x^-v) which are the inverse roots of σ(x)
degrees = np.arange(sigma_rev.size - 1, -1, -1)
results = POLY_ROOTS(degrees, sigma_rev, primitive_element)
beta = results[0,:] # The roots of σ(x^-v)
error_locations = results[1,:] # The roots as powers of the primitive element α

if np.any(error_locations > n - 1):
# Indicates there are "errors" in the zero-ed portion of a shortened code, which indicates there are actually
# more errors than alleged. Return failure to decode.
dec_codeword[i, -1] = -1
continue

if beta.size != v:
dec_codeword[i, -1] = -1
continue

for j in range(v):
# δi can only be 1
dec_codeword[i, n - 1 - error_locations[j]] ^= 1
dec_codeword[i, -1] = v # The number of corrected errors

return dec_codeword