From f0103bb1e350a6c7915790a1030585f1f268d149 Mon Sep 17 00:00:00 2001 From: Andrei Lapets Date: Sun, 24 Jul 2022 01:50:35 -0400 Subject: [PATCH] Refactor/simplify interpolation calculation. --- src/lagrange/lagrange.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/lagrange/lagrange.py b/src/lagrange/lagrange.py index a5f5356..3dc4c1d 100644 --- a/src/lagrange/lagrange.py +++ b/src/lagrange/lagrange.py @@ -4,8 +4,9 @@ from __future__ import annotations from functools import reduce from typing import Union, Optional, Sequence, Iterable -import collections.abc import doctest +import collections.abc +import itertools def _inv(a: int, prime: int) -> int: """ @@ -228,19 +229,16 @@ def interpolate( # Compute the value of each unique Lagrange basis polynomial at ``0``, # then sum them all up to get the resulting value at ``0``. return sum( - mul( - values[x], - reduce( - mul, - ( - # Extrapolate using the fact that *y* = ``1`` if - # ``x`` = ``x_known``, and *y* = ``0`` for the other - # known values in the domain. - div(0 - x_known, x - x_known) - for x_known in xs if x_known is not x - ), - 1 - ) + reduce( + mul, + itertools.chain([values[x]], ( + # Extrapolate using the fact that *y* = ``1`` if + # ``x`` = ``x_known``, and *y* = ``0`` for the other + # known values in the domain. + div(0 - x_known, x - x_known) + for x_known in xs if x_known is not x + )), + 1 ) for x in xs ) % modulus