In [1]:
def multiply(f, g):
    """Multiply two polynomials by direct (schoolbook) convolution.

    Coefficients are listed lowest degree first: ``f[i]`` is the coefficient
    of ``x**i``. Every pair of coefficients is multiplied once, so the cost is
    O(len(f) * len(g)).

    Returns:
        list of length ``max(len(f) + len(g) - 1, 0)``; an empty input on one
        side yields an all-zero list (or ``[]`` when both are empty).
    """
    product = [0] * (len(f) + len(g) - 1)
    for i, a in enumerate(f):
        for j, b in enumerate(g):
            # x**i * x**j contributes to the x**(i+j) coefficient.
            product[i + j] += a * b
    return product

# Sample: (1 + 2x)(1 + 2x + 3x^2)
f, g = [1, 2], [1, 2, 3]
print(multiply(f, g))


[1, 4, 7, 6]


In [2]:
# Polynomial multiplication with the FFT
def multiply_fft(f, g):
    """Multiply two integer polynomials via a real-input FFT.

    Coefficients are listed lowest degree first (``f[i]`` is the coefficient
    of ``x**i``). The floating-point product is rounded to the nearest
    integer. The result is therefore exact only while floating-point error
    stays below 0.5; for very large coefficients prefer ``multiply`` or
    ``multiply_ntt``.

    Returns:
        list[int] of length ``max(len(f) + len(g) - 1, 0)``, the same length
        that ``multiply`` returns.
    """
    import numpy as np

    size = len(f) + len(g) - 1
    if len(f) == 0 or len(g) == 0:
        # Zero polynomial. This also stops the ``[:size]`` slice below from
        # becoming ``[:-1]`` (and returning one element) when both are empty.
        return [0] * max(size, 0)

    # A cyclic convolution of length n >= size already equals the linear
    # convolution, so the smallest power of two that fits is enough.
    n = 1
    while n < size:
        n *= 2

    # rfft/irfft skip the redundant negative-frequency half of the spectrum
    # of real input.
    F = np.fft.rfft(f, n)
    G = np.fft.rfft(g, n)
    # Inverse transform of the pointwise product = convolution of f and g.
    h = np.fft.irfft(F * G, n)[:size]
    # int64 rather than the platform ``int`` (32-bit on older NumPy/Windows).
    return np.rint(h).astype(np.int64).tolist()

f = [1, 2]
g = [1, 2, 3]
h_fft = multiply_fft(f, g)
# Cross-check the FFT result against the schoolbook `multiply` from cell 1.
assert h_fft == multiply(f, g)
print(h_fft)

[1, 4, 7, 6]


## 浮動小数点数のまま返す版（整数への丸めなし）

In [3]:
# Polynomial multiplication with the FFT (floating-point result)
def multiply_fft_2(f, g):
    """Multiply two real-coefficient polynomials via FFT, without rounding.

    Coefficients are listed lowest degree first (``f[i]`` is the coefficient
    of ``x**i``). Unlike ``multiply_fft``, the result is left as float64, so
    it carries small floating-point errors. Compare it with a tolerance, not
    ``==``.

    Returns:
        numpy.ndarray of float64 with length ``max(len(f) + len(g) - 1, 0)``.
    """
    import numpy as np

    size = len(f) + len(g) - 1
    if len(f) == 0 or len(g) == 0:
        # Zero polynomial. This also stops the ``[:size]`` slice below from
        # becoming ``[:-1]`` when both inputs are empty.
        return np.zeros(max(size, 0))

    # Any n >= size avoids wrap-around, so no extra padding is needed.
    n = 1
    while n < size:
        n *= 2

    # Real inputs: rfft/irfft return a real result directly, so there is no
    # ``.real`` step.
    F = np.fft.rfft(f, n)
    G = np.fft.rfft(g, n)
    return np.fft.irfft(F * G, n)[:size]

f = [1, 2]
g = [1, 2, 3]
h_float = multiply_fft_2(f, g)
# The float result carries rounding error, so compare with a tolerance.
h_ref = multiply(f, g)
assert len(h_float) == len(h_ref)
assert all(abs(x - y) < 1e-9 for x, y in zip(h_float, h_ref))
print(h_float)

[1. 4. 7. 6.]


In [4]:
## Polynomial multiplication with the NTT (Number Theoretic Transform)
def multiply_ntt(f, g, mod=998244353, root=3):
    """Multiply two integer polynomials with the Number Theoretic Transform.

    All arithmetic is done modulo ``mod``, so every returned coefficient lies
    in ``[0, mod)``. The result equals the exact product only when every true
    coefficient of f*g is in that range. For example, a true coefficient of -1
    comes back as ``mod - 1``.

    Args:
        f, g: integer coefficient sequences, lowest degree first (``f[i]`` is
            the coefficient of ``x**i``). Values may be negative or >= mod;
            they are reduced during the transform.
        mod: a prime modulus; the default 998244353 = 119 * 2**23 + 1.
        root: a primitive root modulo ``mod`` (3 for the default prime).
            Change it together with ``mod``.

    Returns:
        list[int] of length ``max(len(f) + len(g) - 1, 0)``, the same length
        that ``multiply`` returns.

    Raises:
        ValueError: if the transform length does not divide ``mod - 1``. No
            suitable root of unity exists then. For the default prime this
            means the product has more than 2**23 coefficients.
    """
    size = len(f) + len(g) - 1
    if len(f) == 0 or len(g) == 0:
        # Zero polynomial. This also stops the ``[:size]`` slice below from
        # becoming ``[:-1]`` when both inputs are empty.
        return [0] * max(size, 0)

    # The smallest power of two holding the whole linear convolution. A
    # larger transform only adds work and lowers the maximum supported size.
    n = 1
    while n < size:
        n *= 2
    if (mod - 1) % n != 0:
        raise ValueError(
            f"transform length {n} does not divide mod - 1 = {mod - 1}"
        )

    def ntt(a, invert=False):
        """Iterative radix-2 NTT of ``a`` zero-padded to length n (returns a new list)."""
        a = list(a) + [0] * (n - len(a))

        # Bit-reversal permutation so the butterflies below can run in place.
        rev = 0
        for i in range(1, n):
            bit = n >> 1
            while rev & bit:
                rev ^= bit
                bit >>= 1
            rev |= bit
            if i < rev:
                a[i], a[rev] = a[rev], a[i]

        length = 2
        while length <= n:
            half = length // 2
            # A length-th root of unity, given that root is a primitive root.
            # For the inverse transform, take its Fermat inverse (mod is prime).
            wlen = pow(root, (mod - 1) // length, mod)
            if invert:
                wlen = pow(wlen, mod - 2, mod)
            for start in range(0, n, length):
                w = 1
                for k in range(start, start + half):
                    u = a[k]
                    v = a[k + half] * w % mod
                    a[k] = (u + v) % mod
                    # Python's % is already non-negative for a positive modulus.
                    a[k + half] = (u - v) % mod
                    w = w * wlen % mod
            length <<= 1

        if invert:
            # Scale by n^{-1} to finish the inverse transform.
            inv_n = pow(n, mod - 2, mod)
            a = [x * inv_n % mod for x in a]
        return a

    F = ntt(f)
    G = ntt(g)
    H = [x * y % mod for x, y in zip(F, G)]
    return ntt(H, invert=True)[:size]

f = [1, 2]
g = [1, 2, 3]
h_ntt = multiply_ntt(f, g)
# The NTT result is reduced modulo 998244353; these coefficients are small
# and non-negative, so it must match the schoolbook `multiply` from cell 1.
assert h_ntt == multiply(f, g)
print(h_ntt)

[1, 4, 7, 6]
