# 高速累乗計算アルゴリズム
繰り返し二乗法による累乗計算の高速化

参考:
[高速累乗計算(python3)](https://qiita.com/b1ueskydragon/items/0b8e0c382d782423c6d3)

In [9]:
import sys
# Raise the recursion limit: the naive recursive power function defined later
# recurses once per unit of the exponent, so large exponents need a deep stack.
sys.setrecursionlimit(50000)

In [10]:
def show(func):
    """Wrap a two-argument function so each call prints its arguments first.

    Args:
        func: A callable taking two positional arguments ``(n, i)``.

    Returns:
        A wrapper that prints ``n`` and ``i`` and then returns ``func(n, i)``.
        The wrapper keeps ``func``'s name and docstring.
    """
    # Imported locally so this cell does not rely on another import cell.
    from functools import wraps

    @wraps(func)  # preserve __name__/__doc__ of the decorated function
    def wrapper(n, i):
        print(n, i)
        return func(n, i)

    return wrapper

### 普通に実装
愚直な再帰
計算量: O(i)(掛け算の回数は指数 i に比例する)

In [11]:
@show
def pow_repeat(n, i):
    """Return n**i by naive recursion, one multiplication per unit of i.

    The recursion stops only when ``i`` reaches exactly 0. Because the name
    ``pow_repeat`` refers to the decorated function, every recursive call is
    printed by ``show`` as well.
    """
    # Base case first; otherwise multiply by n and recurse on i - 1.
    return 1 if i == 0 else n * pow_repeat(n, i - 1)

In [12]:
# Trace the naive recursion: one call is printed for every exponent from 20 down to 0.
pow_repeat(2, 20)

2 20
2 19
2 18
2 17
2 16
2 15
2 14
2 13
2 12
2 11
2 10
2 9
2 8
2 7
2 6
2 5
2 4
2 3
2 2
2 1
2 0


1048576

### 掛け算を減らす
iが偶数の場合はnを2乗してiをi//2にし、奇数の場合はnを1回掛けてiを1減らす
計算量: O(log i)

In [13]:
@show
def pow_squere(n, i):
    """Return n**i by recursive exponentiation by squaring.

    Every call (including recursive ones) is printed by the ``show``
    decorator, which makes the shrinking exponent visible.
    """
    if i == 0:
        return 1
    if i & 1:
        # Odd exponent: peel off one factor of n so the rest is even.
        return n * pow_squere(n, i - 1)
    # Even exponent: n**i == (n**2)**(i // 2).
    return pow_squere(n ** 2, i // 2)

In [14]:
# Trace the squaring version: each printed line is one call with its (n, i) arguments.
pow_squere(2, 20)

2 20
4 10
16 5
16 4
256 2
65536 1
65536 0


1048576

### スタックメモリ領域を減らす
指数を割り続ける方式

In [15]:
def pow_sq_not_rec(n, i):
    """Return n**i by iterative exponentiation by squaring, printing a trace.

    At the start of every loop iteration the current ``(K, n, i)`` state is
    printed, where ``K`` accumulates the factors split off on odd exponents.

    Args:
        n: The base.
        i: The exponent; must be a non-negative integer.

    Returns:
        ``n`` raised to the power ``i``.

    Raises:
        ValueError: If ``i`` is negative (the loop would otherwise be skipped
            and ``n`` returned unchanged, which is not ``n**i``).
    """
    if i < 0:
        raise ValueError("exponent i must be non-negative")
    if i == 0:
        return 1
    K = 1
    while i > 1:
        print(K, n, i)
        if i & 1:
            K *= n  # odd exponent: set the leftover factor aside in K
        # Both parities square the base and halve the exponent;
        # for odd i, i // 2 equals (i - 1) // 2.
        n **= 2
        i //= 2

    # Here i == 1, so the answer is the saved factors times the current base.
    return K * n

In [16]:
# Trace the loop version: each printed line is the (K, n, i) state at the start of an iteration.
pow_sq_not_rec(2, 10)

1 2 10
1 4 5
4 16 2


1024

## 時間計測

In [17]:
# Naive implementation
def pow_repeat(n, i):
    """Return n**i by plain recursion, multiplying by n once per unit of i.

    The recursion terminates only when ``i`` reaches exactly 0, so ``i`` is
    expected to be a non-negative integer.
    """
    return 1 if i == 0 else n * pow_repeat(n, i - 1)

# Faster version: recursive exponentiation by squaring
def pow_squere(n, i):
    """Return n**i, squaring the base whenever the exponent is even."""
    if i == 0:
        return 1
    if i & 1:
        # Odd exponent: take out one factor of n, leaving an even exponent.
        return n * pow_squere(n, i - 1)
    # Even exponent: n**i == (n**2)**(i // 2).
    return pow_squere(n ** 2, i // 2)

# Loop-based implementation
def pow_sq_not_rec(n, i):
    """Return n**i by iterative exponentiation by squaring.

    Args:
        n: The base.
        i: The exponent; must be a non-negative integer.

    Returns:
        ``n`` raised to the power ``i``.

    Raises:
        ValueError: If ``i`` is negative (the loop would otherwise be skipped
            and ``n`` returned unchanged, which is not ``n**i``).
    """
    if i < 0:
        raise ValueError("exponent i must be non-negative")
    if i == 0:
        return 1
    K = 1
    while i > 1:
        if i & 1:
            K *= n  # odd exponent: set the leftover factor aside in K
        # Both parities square the base and halve the exponent;
        # for odd i, i // 2 equals (i - 1) // 2.
        n **= 2
        i //= 2

    # Here i == 1, so the answer is the saved factors times the current base.
    return K * n

In [18]:
# Benchmark 3**200: naive recursion, recursive squaring, loop-based squaring, then built-in pow.
%timeit pow_repeat(3, 200)
%timeit pow_squere(3, 200)
%timeit pow_sq_not_rec(3, 200)
%timeit pow(3, 200)

160 µs ± 2.74 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
13.9 µs ± 23.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10 µs ± 24.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1.91 µs ± 30.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [19]:
# Benchmark 2**101 (an odd exponent) with the same four implementations.
%timeit pow_repeat(2, 101)
%timeit pow_squere(2, 101)
%timeit pow_sq_not_rec(2, 101)
%timeit pow(2, 101)

75.9 µs ± 474 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
13.4 µs ± 333 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
8.53 µs ± 23.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1.58 µs ± 23.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [20]:
# Benchmark 7**256 (the exponent is a power of two) with the same four implementations.
%timeit pow_repeat(7, 256)
%timeit pow_squere(7, 256)
%timeit pow_sq_not_rec(7, 256)
%timeit pow(7, 256)

215 µs ± 4.73 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
14.5 µs ± 187 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11.8 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
2.13 µs ± 10.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


## 結論
O(log i)のアルゴリズムは強い
再帰よりもループによる実装の方が基本的に速い(ただし今回の計測では組み込みの pow が最も速い)

# POW_MOD
整数 n の i 乗を p で割った剰余(n^i mod p)を高速に求める

In [23]:
def pow_mod(n, i, p):
    """Return (n ** i) % p using iterative exponentiation by squaring.

    The running values are reduced modulo ``p`` after each multiplication
    instead of computing the full power first.

    Args:
        n: The base.
        i: The exponent; must be a non-negative integer.
        p: The modulus.

    Returns:
        The remainder of ``n ** i`` divided by ``p``.

    Raises:
        ValueError: If ``i`` is negative (the loop would otherwise be skipped
            and ``n % p`` returned, which is not the modular power).
    """
    if i < 0:
        raise ValueError("exponent i must be non-negative")
    if i == 0:
        # 1 % p rather than 1 so that a modulus of 1 yields 0.
        return 1 % p
    K = 1
    while i > 1:
        if i & 1:
            K = K * n % p  # odd exponent: set the leftover factor aside in K
        # Both parities square the base and halve the exponent;
        # for odd i, i // 2 equals (i - 1) // 2.
        n = n ** 2 % p
        i //= 2

    # Here i == 1, so combine the saved factors with the current base.
    return (K * n) % p

In [30]:
from random import randint

In [None]:
for _ in range(1000):
    n, i, p = randint(10, 30), randint(20, 40), randint(100, 9999)
    