# NTT revised

## 问题

$N = 2^{26}$ 个元素的 list 进行 NTT，元素本身是一个 256bits 的大数。

GPU 计算。

经常8-9个同时计算。

想分开在各个GPU里面去局部计算。

## revised

我们的大数应该都在一个很大的数域中完成，可以假定其中有 $N$ 阶单位原根（primitive $N$-th root of unity），也就是存在 $x$ 使得 $x^N = 1$，但对所有 $0 < k < N$ 都有 $x^k \neq 1$。

In [57]:
# Secp256k1 as an example: p below is the secp256k1 base-field prime.
# NOTE(review): the literal ends in ...2F, so p % 4 == 3 and p - 1 contains only
# a single factor of 2; this particular field therefore has no 2^26-th root of
# unity and serves here purely as an example of the element size (256 bits).

p = 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_FFFFFC2F

# 数域是 $F_p$

找个小点的数域做test

In [58]:
p = 257

# The multiplicative group F_p^* has order p - 1 = 256.

# Find a generator of F_p^*, i.e. an element whose multiplicative order is
# exactly 256 (brute force: no proper exponent 0 < e < p - 1 may give 1).

root = 2
while root < p:
    is_generator = pow(root, p - 1, p) == 1 and all(
        pow(root, e, p) != 1 for e in range(1, p - 1)
    )
    if is_generator:
        break
    root += 1

# If the search ran off the end, no generator was found.
assert root < p
print(f"256-th root of unity: {root}")

# All powers root^1 .. root^(p-1); for a generator this enumerates F_p^*.
powers = [pow(root, e, p) for e in range(1, p)]

print(powers)


256-th root of unity: 3
[3, 9, 27, 81, 243, 215, 131, 136, 151, 196, 74, 222, 152, 199, 83, 249, 233, 185, 41, 123, 112, 79, 237, 197, 77, 231, 179, 23, 69, 207, 107, 64, 192, 62, 186, 44, 132, 139, 160, 223, 155, 208, 110, 73, 219, 143, 172, 2, 6, 18, 54, 162, 229, 173, 5, 15, 45, 135, 148, 187, 47, 141, 166, 241, 209, 113, 82, 246, 224, 158, 217, 137, 154, 205, 101, 46, 138, 157, 214, 128, 127, 124, 115, 88, 7, 21, 63, 189, 53, 159, 220, 146, 181, 29, 87, 4, 12, 36, 108, 67, 201, 89, 10, 30, 90, 13, 39, 117, 94, 25, 75, 225, 161, 226, 164, 235, 191, 59, 177, 17, 51, 153, 202, 92, 19, 57, 171, 256, 254, 248, 230, 176, 14, 42, 126, 121, 106, 61, 183, 35, 105, 58, 174, 8, 24, 72, 216, 134, 145, 178, 20, 60, 180, 26, 78, 234, 188, 50, 150, 193, 65, 195, 71, 213, 125, 118, 97, 34, 102, 49, 147, 184, 38, 114, 85, 255, 251, 239, 203, 95, 28, 84, 252, 242, 212, 122, 109, 70, 210, 116, 91, 16, 48, 144, 175, 11, 33, 99, 40, 120, 103, 52, 156, 211, 119, 100, 43, 129, 130, 133, 142, 169, 250, 23

Revise the brute force NTT and INTT

In [59]:
# Sample input: 0..31 reduced into F_p.
list_to_ntt = [x % p for x in range(32)]

# `root` generates F_p^* (order p - 1), so root^((p - 1) / n) has order exactly
# n = len(list_to_ntt). Deriving the exponent from p and n (instead of
# hard-coding 8 = 256 / 32) keeps this correct if p or the length changes.
assert (p - 1) % len(list_to_ntt) == 0
root_of_32 = pow(root, (p - 1) // len(list_to_ntt), p)
# Modular inverse of the transform root, needed by the inverse transform.
reverse_of_root = pow(root_of_32, -1, p)

print(f'list_to_ntt = {list_to_ntt}')

def ntt(list, root):
    """Brute-force O(n^2) number-theoretic transform modulo the global `p`.

    Args:
        list: sequence of n integers (the coefficients to transform).
        root: an element with root^n == 1 (mod p); asserted on entry.

    Returns:
        A new list `out` of length n where
        out[j] = sum_i list[i] * root^(i*j) (mod p), each entry in [0, p).
    """
    n = len(list)
    assert pow(root, n, p) == 1
    spectrum = []
    for j in range(n):
        # Evaluation point for output index j.
        point = pow(root, j, p)
        total = sum(
            coeff * pow(point, i, p) for i, coeff in enumerate(list)
        )
        spectrum.append(total % p)
    return spectrum

def intt(list, root):
    """Inverse number-theoretic transform modulo the global `p`.

    Runs the forward `ntt` with `root` (callers pass the modular inverse of the
    root used for the forward transform) and scales every entry by n^-1 mod p.

    Args:
        list: sequence of n transformed values.
        root: inverse of the forward-transform root, with root^n == 1 (mod p).

    Returns:
        A new list of n values in [0, p); an empty input yields an empty list.

    Raises:
        ValueError: if n is non-zero but not invertible modulo p.
    """
    transformed = ntt(list, root)
    if not transformed:
        # Nothing to scale; also avoids asking for the inverse of 0.
        return transformed
    # The inverse of n is loop-invariant: compute it once, not per element.
    n_inverse = pow(len(list), -1, p)
    return [value * n_inverse % p for value in transformed]

# Forward transform of the sample input using the order-32 root.
ntt_result = ntt(list_to_ntt, root_of_32)
print(f'ntt_result = {ntt_result}')

# Inverse transform: same routine driven by the inverse root, then scaled by 1/n.
intt_result = intt(ntt_result, reverse_of_root)
print(f'intt_result = {intt_result}')

# The round trip must reproduce the original input exactly.
assert intt_result == list_to_ntt


list_to_ntt = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
ntt_result = [239, 84, 25, 168, 180, 190, 32, 39, 240, 227, 150, 115, 182, 72, 147, 2, 241, 223, 78, 153, 43, 110, 75, 255, 242, 186, 193, 35, 45, 57, 200, 141]
intt_result = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]


Why is NTT important?

In [60]:
# Two sample length-32 sequences (multiples of 3 and of 4, reduced mod p)
# used to compare NTT-based and direct cyclic convolution.
a_list = [x * 3 % p for x in range(0, 32)]
b_list = [x * 4 % p for x in range(0, 32)]

def conv_mul(a, b, root):
    """Cyclic convolution of `a` and `b` modulo the global `p`, via the NTT.

    Transforms both inputs, multiplies them pointwise, and transforms back
    using the modular inverse of `root`.

    Args:
        a: sequence of n integers.
        b: sequence of n integers (must have the same length as `a`).
        root: an element with root^n == 1 (mod p).

    Returns:
        A list of n values in [0, p).
    """
    # Same precondition as conv_simple; without it a longer `b` would be
    # silently truncated and a shorter one would fail with an IndexError.
    assert len(a) == len(b)
    a_ntt = ntt(a, root)
    b_ntt = ntt(b, root)
    pointwise = [x * y % p for x, y in zip(a_ntt, b_ntt)]
    return intt(pointwise, pow(root, -1, p))

def conv_simple(a, b):
    """Direct O(n^2) cyclic convolution of `a` and `b` modulo the global `p`.

    Args:
        a: sequence of n integers.
        b: sequence of n integers; its length is asserted to equal len(a).

    Returns:
        A list `out` of length n where out[k] is the sum of a[i] * b[j] over
        all index pairs with (i + j) % n == k, reduced mod p.
    """
    assert len(a) == len(b)
    size = len(a)
    out = [0] * size
    for ia in range(size):
        left = a[ia]
        for ib in range(size):
            # Indices wrap around, which is what makes the convolution cyclic.
            slot = (ia + ib) % size
            out[slot] = (out[slot] + left * b[ib]) % p
    return out

# Convolution computed through the NTT (transform, pointwise product, inverse).
conv_mul_result = conv_mul(a_list, b_list, root_of_32)
print(f'conv_mul_result = {conv_mul_result}')

# Reference result from the direct double loop.
conv_simple_result = conv_simple(a_list, b_list)
print(f'conv_simple_result = {conv_simple_result}')

# Both methods must agree element for element.
assert conv_mul_result == conv_simple_result

conv_mul_result = [194, 108, 152, 69, 116, 36, 86, 9, 62, 245, 44, 230, 32, 221, 26, 218, 26, 221, 32, 230, 44, 245, 62, 9, 86, 36, 116, 69, 152, 108, 194, 153]
conv_simple_result = [194, 108, 152, 69, 116, 36, 86, 9, 62, 245, 44, 230, 32, 221, 26, 218, 26, 221, 32, 230, 44, 245, 62, 9, 86, 36, 116, 69, 152, 108, 194, 153]


NTT matters only when it is faster than the brute-force version.

In [61]:
def faster_ntt(list, root):
    """Recursive radix-2 number-theoretic transform modulo the global `p`.

    Splits the input into even- and odd-indexed halves, transforms each with
    root^2, and combines them with butterflies. Produces the same values as
    the brute-force `ntt` for power-of-two lengths.

    Args:
        list: sequence of n integers, n a power of two (or empty).
        root: an element with root^n == 1 (mod p); asserted at every level.

    Returns:
        A new list of n values in [0, p).

    Raises:
        AssertionError: if root^n != 1 (mod p) or a sub-problem has odd
            length greater than one (n is not a power of two).
    """
    n = len(list)
    assert pow(root, n, p) == 1
    if n == 0:
        # Without this guard an empty input keeps slicing to [] and recurses
        # until RecursionError.
        return []
    if n == 1:
        # Return a fresh, reduced value so the caller's list is never aliased
        # and the result matches the brute-force ntt.
        return [list[0] % p]
    assert n % 2 == 0

    root_squared = pow(root, 2, p)
    even = faster_ntt(list[0::2], root_squared)
    odd = faster_ntt(list[1::2], root_squared)

    half = n // 2
    ret = [0] * n
    twiddle = 1  # root^i, maintained incrementally instead of pow() per use
    for i in range(half):
        term = twiddle * odd[i] % p
        ret[i] = (even[i] + term) % p
        # Python's % already yields a value in [0, p) for a negative operand.
        ret[i + half] = (even[i] - term) % p
        twiddle = twiddle * root % p
    return ret

# Run the recursive transform on the same input as the brute-force version.
faster_ntt_result = faster_ntt(list_to_ntt, root_of_32)
print(f'faster_ntt_result = {faster_ntt_result}')

# It must match the brute-force result exactly.
assert faster_ntt_result == ntt_result

faster_ntt_result = [239, 84, 25, 168, 180, 190, 32, 39, 240, 227, 150, 115, 182, 72, 147, 2, 241, 223, 78, 153, 43, 110, 75, 255, 242, 186, 193, 35, 45, 57, 200, 141]


This is an $O(N \lg N)$ algorithm, which should be much faster than the $O(N^2)$ brute-force algorithm.


What if the last 3/4 are all zero?

e.g. 

list_to_ntt[32] = [1, 2, 3, 4, 5, 6, 7, 8, 0, ..., 0]

BTW, can we first convert the recursive version to a loop version?

In [62]:
def reverse_of_bits(v, width):
    """Return the lowest `width` bits of `v` in reversed bit order.

    Bits of `v` at positions >= `width` are ignored; a `width` of zero or
    less yields 0.
    """
    mirrored = 0
    for _ in range(width):
        # Pop the lowest bit of v and push it onto the low end of the result.
        mirrored = (mirrored << 1) | (v & 1)
        v >>= 1
    return mirrored

# Sanity checks; in the second case only the low 3 bits (0b010) take part.
assert reverse_of_bits(0b1010, 4) == 0b0101
assert reverse_of_bits(0b1010, 3) == 0b010

def loop_version_ntt(list, root):
    """Iterative radix-2 number-theoretic transform modulo the global `p`.

    Applies the bit-reversal permutation to a copy of the input and then runs
    log2(n) levels of butterflies in place on that copy. Produces the same
    values as the brute-force `ntt`.

    Args:
        list: sequence of n integers; n must be a power of two greater than 1.
            The sequence itself is not modified.
        root: an element with root^n == 1 (mod p).

    Returns:
        A new list of n values in [0, p).

    Raises:
        AssertionError: if root^n != 1 (mod p) or n is not a power of two > 1.
    """
    n = len(list)
    assert pow(root, n, p) == 1
    assert n > 1 and n & (n - 1) == 0  # power of 2, in fact it is 32 here

    # For a power of two, n == 1 << (bit_length - 1).
    bits_width_of_n = n.bit_length() - 1
    assert n == 1 << bits_width_of_n

    # Rearrange into bit-reversed order so each level combines adjacent blocks.
    data = [list[reverse_of_bits(i, bits_width_of_n)] for i in range(n)]

    for level in range(1, bits_width_of_n + 1):
        half = 1 << (level - 1)
        # Root of order 2 * half for this level. The modulus argument keeps
        # every intermediate below p instead of building huge exact powers.
        level_root = pow(root, 1 << (bits_width_of_n - level), p)
        # twiddles[j] == level_root^j, shared by every block on this level.
        twiddles = [1] * half
        for j in range(1, half):
            twiddles[j] = twiddles[j - 1] * level_root % p

        for start in range(0, n, half * 2):
            for j in range(half):
                a = data[start + j]
                term = twiddles[j] * data[start + j + half] % p
                data[start + j] = (a + term) % p
                # Python's % already yields a value in [0, p) here.
                data[start + j + half] = (a - term) % p

    return data


# Run the iterative transform on the same input as the brute-force version.
loop_ntt_result = loop_version_ntt(list_to_ntt, root_of_32)
print(f'loop_ntt_result = {loop_ntt_result}')

# It must match the brute-force result exactly.
assert loop_ntt_result == ntt_result


loop_ntt_result = [239, 84, 25, 168, 180, 190, 32, 39, 240, 227, 150, 115, 182, 72, 147, 2, 241, 223, 78, 153, 43, 110, 75, 255, 242, 186, 193, 35, 45, 57, 200, 141]
