In [3]:
import math
import cmath
import numpy as np
import numpy.polynomial.polynomial as polynomial
import fft

In [7]:

def split_poly(p):
    assert(len(p) % 2 == 0)
    even = []
    odd = []
    for i in range(0, len(p), 2):
        even.append(p[i])
        odd.append(p[i + 1])
    return [even, odd]

def align_polynomials(p1 : list, p2 :list):
    n1 = len(p1)
    n2 = len(p2)
    req_deg = max(n1, n2)
    make_required_degree(p1, req_deg - 1)
    make_required_degree(p2, req_deg - 1)

def make_required_degree(p : list, n : int):
    while len(p) <= n:
        p.append(0)

def next_power_of_2(n):
    if n <= 0:
        return 0
    return pow(2, math.ceil(math.log(n)/math.log(2)))

def nth_roots_1(n):
    poly = [0] * (n + 1)
    poly[0] = 1.0
    poly[-1] = -1.0
    return np.roots(poly)

def get_w_1(n):
    assert(n > 0)
    roots = nth_roots_1(n)
    first_root = roots[0]
    for root in roots:
        root_polar = cmath.polar(root)
        first_root_polar = cmath.polar(first_root)
        if (root_polar[1] > 0 and root_polar[1] < first_root_polar[1]):
            first_root = root
    return first_root


In [12]:
def fft_1(p, w):
    if (cmath.isclose(w, complex(1,0))):
        return [p[0]]
    n = len(p)
    [even, odd] = split_poly(p)
    s1 = fft(even, w * w)
    s2 = fft(odd, w * w)
    r = [0] * n
    for j in range(0, (n // 2)):
        r[j] = s1[j] + (w ** j) * s2[j]
        r[j + (n // 2)]  = s1[j] - (w ** j) * s2[j]
    return r

def inv_fft(points, w):
    poly = fft(points, (1 / w))
    n = len(poly)
    res = [coef * (1 / n) for coef in poly]
    return res

In [13]:
f = [1,2,3,4,5,6,7]
w = get_w_1(8)
fft_1(f, w)

NameError: name 'nth_roots' is not defined

In [5]:
def product(poly1, poly2):
    res_deg = (len(poly1) - 1) + (len(poly2) - 1)
    n = next_power_of_2(res_deg + 1)
    make_required_degree(poly1, n - 1)
    make_required_degree(poly2, n - 1)
    w = get_w(n)
    points1 = fft(poly1, w)
    points2 = fft(poly2, w)
    points_res = [ y1 * y2 for (y1, y2) in zip(points1, points2)]
    poly_res = inv_fft(points_res, w)
    return poly_res

In [6]:
def to_integer_poly(p):
    real_poly = [np.real(coef) for coef in p]
    return np.rint(real_poly).astype(int)

In [53]:
def random_poly(deg, low, high):
    poly = np.random.randint(low, high, deg)
    while poly[deg - 1] == 0:
        poly[deg - 1] = np.random.randint(low, high, 1)
    return list(poly)

def test_multiplication(deg1, deg2):
    low = -10
    high = 10
    poly1 = random_poly(deg1, low, high)
    poly2 = random_poly(deg2, low, high)

    fft_poly = list(to_integer_poly(product(poly1.copy(), poly2.copy())))

    poly1_rev = poly1.copy()
    poly1_rev.reverse()
    poly2_rev = poly2.copy()
    poly2_rev.reverse()
    np_poly = list(np.polymul(poly1_rev, poly2_rev))
    np_poly.reverse()
    align_polynomials(fft_poly, np_poly)
    is_eq = (fft_poly == np_poly)
    return [is_eq, poly1, poly2, fft_poly, np_poly]

In [52]:
def test():
    for i in range(1, 10):
        for j in range(1, 10):
            is_eq, poly1, poly2, fft_poly, np_poly = test_multiplication(i, j)
            if not is_eq:
                print("Error!")
                print("poly1 = ", poly1)
                print("poly2 = ", poly2)
                print("fft_poly = ", fft_poly)
                print("np_poly = ", np_poly)
                return
    print("All cases passed successfully")
    return


In [54]:
test()

All cases passed successfully
