In [2]:
# Euclidean algorithm for finding GCD(a,b)
# GCD(a,b) = GCD(b, a mod b)
# O(log N) time, where N = min(a,b)
# O(1) space complexity

def euclidean_gcd(a: int, b: int) -> int:
    """Return the greatest common divisor of ``a`` and ``b``.

    Repeatedly applies gcd(x, y) = gcd(y, x mod y) until the second value
    reaches 0, then returns the absolute value of the first so the result
    is non-negative even for negative inputs.
    """
    x, y = a, b
    while y:
        x, y = y, x % y
    return abs(x)

result = euclidean_gcd(48,180)
print("GCD(48,180): {}".format(result))

GCD(48,180): 12


In [10]:
# type:ignore
# extended euclidean algorithm
# finding x,y such that ax + by = GCD(a,b)

def extended_gcd(a: int, b: int) -> tuple:
    """Extended Euclidean algorithm.

    Returns a 3-tuple ``(g, x, y)`` of ints such that ``a*x + b*y == g``,
    where ``g`` is gcd(a, b) and is always non-negative.

    Implemented iteratively, so large inputs (whose Euclid step count can
    exceed Python's default recursion limit of 1000) do not raise
    RecursionError. The coefficients are the same ones the classic
    recursive formulation produces.
    """
    # Invariants: old_r == a*old_x + b*old_y and r == a*x + b*y.
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
        old_y, y = y, old_y - q * y
    # Python's floor division leaves negative remainders when b < 0, so the
    # final remainder may be negative; flip all signs to keep g >= 0
    # (the identity a*x + b*y == g still holds after negating every term).
    if old_r < 0:
        old_r, old_x, old_y = -old_r, -old_x, -old_y
    return old_r, old_x, old_y

g,x,y = extended_gcd(48, 180)

print("GCD: {}".format(g))
print("x: {}".format(x))
print("y: {}".format(y))
print("check: 48x + 180y = {}".format(48*x+180*y))

GCD: 12
x: 4
y: -1
check: 48x + 180y = 12


In [11]:
# type:ignore
# modular inverse
# find x such that a*x ≡ 1 (mod m)

def mod_inverse(a: int, m: int) -> int:
    """Return the multiplicative inverse of ``a`` modulo ``m``.

    Uses the Bezout coefficient from ``extended_gcd``: if a*x + m*y == 1
    then a*x ≡ 1 (mod m). The returned value lies in the range [0, m).

    Raises:
        ValueError: if ``m`` is not positive, or if gcd(a, m) != 1 so no
            inverse exists.
    """
    # Reject m <= 0 up front: m == 0 would otherwise reach `x % 0`
    # (ZeroDivisionError), and a negative m can make the gcd come back as
    # -1 or the result negative.
    if m <= 0:
        raise ValueError("modulus must be positive")
    g,x,_ = extended_gcd(a,m)
    if g != 1:
        raise ValueError("inverse does not exist")
    # x may be negative; reduce it into the canonical range [0, m).
    return x%m

print("inverse of 3 mod 11 : {}".format(mod_inverse(3,11)))

inverse of 3 mod 11 : 4


In [19]:
# type: ignore
# primality test using trial division up to sqrt(n)

def is_prime(n: int) -> bool:
    """Decide whether ``n`` is prime by trial division.

    Values <= 1 are not prime and 2 and 3 are prime. Any other even number
    is rejected at once. Remaining candidates are checked against the odd
    divisors d = 3, 5, 7, ... with d*d <= n.
    """
    if n <= 1:
        return False
    if n <= 3:
        return True
    if n % 2 == 0:
        return False

    def odd_divisors():
        # Odd trial divisors, stopping once the square passes n.
        d = 3
        while d * d <= n:
            yield d
            d += 2

    return all(n % d != 0 for d in odd_divisors())

print("97:  {}".format(is_prime(97)))
print("1:   {}".format(is_prime(1)))
print("0:   {}".format(is_prime(0)))
print("15:  {}".format(is_prime(15)))
print("122: {}".format(is_prime(122)))
print("3:   {}".format(is_prime(3)))

97:  True
1:   False
0:   False
15:  False
122: False
3:   True


In [20]:
# Sieve of Eratosthenes: generate all primes <= n

def sieve(n: int) -> list:
    """Return all primes p with 2 <= p <= n, in increasing order.

    Implements the Sieve of Eratosthenes. For n < 2, including negative n,
    the result is an empty list.
    """
    import math  # local import: math.isqrt gives an exact integer sqrt

    # Guard small/negative n: previously n < 0 crashed because n**0.5 is
    # complex and int() rejects it.
    if n < 2:
        return []

    # flags[k] stays True while k is still a prime candidate. Named so it
    # does not shadow the module-level is_prime() function.
    flags = [True] * (n + 1)
    flags[0] = flags[1] = False

    # Every composite <= n has a prime factor <= isqrt(n).
    for p in range(2, math.isqrt(n) + 1):
        if flags[p]:
            # Start at p*p: smaller multiples of p were already crossed off
            # by smaller primes. Slice assignment marks them all at once.
            flags[p * p :: p] = [False] * len(range(p * p, n + 1, p))

    return [k for k, prime in enumerate(flags) if prime]

print("all primes < 30: {}".format(sieve(30)))

all primes < 30: [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


In [21]:
# type: ignore
# fast modular exponentiation
# compute a^b mod m

def mod_pow(a: int, b: int, m: int) -> int:
    """Compute a**b mod m by binary (square-and-multiply) exponentiation.

    Uses O(log b) modular multiplications and keeps intermediate values
    reduced mod m. For positive m the result lies in [0, m) and agrees with
    the built-in pow(a, b, m).

    Raises:
        ValueError: if b is negative (negative exponents are not supported).
        ZeroDivisionError: if m == 0.
    """
    # Previously a negative exponent skipped the loop and silently returned 1.
    if b < 0:
        raise ValueError("exponent must be non-negative")
    # Start from 1 % m rather than 1 so that m == 1 yields 0 even when b == 0.
    result = 1 % m
    base = a % m
    while b > 0:
        if b & 1:  # lowest bit set: multiply this power of a into the result
            result = result * base % m
        base = base * base % m
        b >>= 1
    return result

print("3^13 mod 17 = {}".format(mod_pow(3,13,17)))


3^13 mod 17 = 12
