<a href="https://colab.research.google.com/github/kevinrchilders/computational-number-theory/blob/master/cryptography_chapter_6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random

In [None]:
# Fast powering algorithm, gcd algorithms, etc.

def binary(n):
  """Return the base-2 digits of n as a list, most significant digit first.

  binary(0) is [0]; values n <= 1 yield the single digit n % 2.
  """
  digits = []
  while True:
    digits.append(n % 2)   # least significant digit first; reversed below
    if n <= 1:
      break
    n //= 2
  digits.reverse()
  return digits

def power(g, A, N):
  """Return g**A mod N using the fast-powering (square-and-multiply) algorithm.

  The binary digits of the exponent A are scanned from least to most
  significant: g is squared at every step and multiplied into the running
  total whenever the current digit is 1.

  Raises ValueError if A is negative (a negative exponent would need a
  modular inverse, which this routine does not compute).
  """
  if A < 0:
    raise ValueError('power: exponent must be non-negative')
  # Reduce the initial 1 as well, so that N == 1 yields 0 (everything is 0 mod 1).
  total = 1 % N
  g = g % N
  while A > 0:
    if A % 2:
      total = (total * g) % N
    g = (g * g) % N
    A //= 2
  return total

def gcd(a, b):
  """Return the greatest common divisor of a and b (Euclidean algorithm, iterative)."""
  while b != 0:
    a, b = b, a % b
  return a

def extended_gcd(a, b):
  """Extended Euclidean algorithm.

  Returns (g, u, v) with a*u + b*v == g, where g = gcd(a, b) for
  non-negative inputs.
  """
  u, g, x, y = 1, a, 0, b
  while y != 0:
    q, t = divmod(g, y)
    # Simultaneous update: the new x uses the old u and x.
    u, g, x, y = x, y, u - q*x, t
  if b == 0:
    # gcd(a, 0) = a = a*1 + 0*0; the back-substitution below would divide by zero.
    return g, u, 0
  v = (g - a*u) // b
  return g, u, v

def inverse(a, p):
  """Return the inverse of a modulo p, reduced to the range 0..p-1.

  Raises ValueError if a has no inverse modulo p (gcd(a, p) != 1).
  """
  g, u, _ = extended_gcd(a, p)
  if g != 1:
    raise ValueError('inverse: %d is not invertible modulo %d' % (a, p))
  return u % p

def order(a, p):
  """Return the multiplicative order of a modulo p.

  This is the least n >= 1 with a**n == 1 (mod p). Assumes p >= 2.

  Raises ValueError if no such n exists, i.e. a is not a unit modulo p
  (for example a == 0).
  """
  x = a % p   # x holds a**n mod p for the current n
  # The order of a unit divides the size of the unit group, which is at most
  # p - 1, so checking n = 1..p-1 is enough.
  for n in range(1, p):
    if x == 1:
      return n
    x = (x * a) % p
  raise ValueError('order: %d has no multiplicative order modulo %d' % (a, p))

def is_primitive(a, p):
  """Return True if the order of a modulo p is p - 1 (a primitive root when p is prime)."""
  return p - 1 == order(a, p)

def is_mrprime(n, trials=50):
  """Miller-Rabin probable-prime test.

  Returns False if n < 2 or a random base proves n composite, and True if n
  survives `trials` rounds with random bases (n is then probably prime).
  """
  if n < 2:
    return False
  if n < 4:
    return True    # 2 and 3
  if n % 2 == 0:
    return False
  for _ in range(trials):
    # Bases 1 and n-1 are never witnesses. Base n shares the factor n with n
    # and would be reported as a witness even when n is prime. So draw the
    # base from 2..n-2.
    a = random.randint(2, n - 2)
    if is_mrwitness(a, n):
      return False
  return True

def generate_prime(digits, attempts=100):
  """Search for a probable prime of the form 210*K + 1 near 10**digits.

  Tries K = 10**digits // 210, ..., (10**digits // 210) + attempts - 1 and
  returns the first 210*K + 1 that passes is_mrprime. If none passes, prints
  a message and returns None.
  """
  N = 2 * 3 * 5 * 7   # candidates N*K + 1 are never divisible by 2, 3, 5 or 7
  # Integer floor division. 10**digits / N would go through a float: it loses
  # precision once the quotient has more than ~15 significant digits, and it
  # raises OverflowError once the quotient exceeds the float range (~310 digits).
  start = 10**digits // N
  for K in range(start, start + attempts):
    candidate = N*K + 1
    if is_mrprime(candidate):
      return candidate
  print('No primes found. Try more attempts.')
  return None

def is_mrwitness(a, n):
  """Return True if a is a Miller-Rabin witness for the compositeness of n.

  True proves n is composite. False means a gives no evidence either way.
  """
  # A common factor of a and n already proves n is composite.
  if gcd(a, n) != 1:
    return True

  # Factor out the powers of two: n - 1 == 2**k * q with q odd.
  q, k = n - 1, 0
  while q % 2 == 0:
    q, k = q // 2, k + 1

  x = power(a, q, n)
  if x == 1:
    return False   # a^q == 1 (mod n): not a witness

  # Not a witness if one of a^q, a^(2q), ..., a^(2^(k-1) q) is -1 (mod n).
  for _ in range(k):
    if x == n - 1:
      return False
    x = power(x, 2, n)

  return True

def pollard(N, a=2, maxn=1000000):
  """Pollard's p-1 method: try to find a nontrivial factor of N.

  Replaces a by a**j mod N for j = 2, 3, ..., maxn - 1. After step j this is
  the original base raised to j!. Each step checks gcd(a - 1, N).

  Returns a divisor d with 1 < d < N. On failure, prints a hint and
  returns None.
  """
  for j in range(2, maxn):
    a = power(a, j, N)
    d = gcd(a - 1, N)
    if d == N:
      # a is now 1 mod N, so every later power is 1 and d stays N.
      # Further iterations cannot help; only a different base can.
      print('Test failed, try a different base a.')
      return None
    if d != 1:
      return d
  print('Test failed, try a larger maxn.')
  return None

def find_primitive(p):
  """Return the smallest a >= 2 for which is_primitive(a, p) holds."""
  candidate = 2
  while True:
    if is_primitive(candidate, p):
      return candidate
    candidate += 1

# Addition on elliptic curves

In [None]:
# An algorithm for adding points P and Q on an elliptic curve
# with Weierstrass equation Y^2 = X^3 + A*X + B

def ec_add(P, Q, A, B):
  """Add two points on the real curve Y^2 = X^3 + A*X + B.

  Points are (x, y) pairs or the string 'O' for the point at infinity (the
  identity). Distinct x-coordinates use the chord slope. Equal x-coordinates
  give either O (when Q = -P) or the tangent slope (doubling). B does not
  appear in the formulas.
  """
  # The identity on either side leaves the other point unchanged.
  if P == 'O' or Q == 'O':
    return Q if P == 'O' else P
  (px, py), (qx, qy) = P, Q
  if px != qx:
    m = (qy - py)/(qx - px)        # chord through P and Q
  elif py == -qy:
    return 'O'                      # vertical line: Q = -P
  else:
    m = (3*px*px + A)/2/py          # tangent line at P
  rx = m**2 - px - qx
  return (rx, m*(px - rx) - py)

In [None]:
# A plot of Y^2 = X^3 - X + 6
# X^3 - X + 6 = (X + 2)(X^2 - 2X + 3) is zero at X = -2 and positive for X > -2,
# so both square-root branches are real on the sampled interval [-2, 4].

A = -1
B = 6
X1 = np.linspace(-2,4,601)
# The x-values reversed, minus the final -2. The lower branch then runs from
# x = 4 back toward -2 and joins the upper branch without repeating (-2, 0).
X2 = np.flip(X1)[:-1]
X = np.concatenate((X2, X1))
Y1 = np.sqrt(X1**3 + A*X1 + B)     # upper branch, y >= 0
Y2 = -1*np.sqrt(X2**3 + A*X2 + B)  # lower branch, y <= 0
Y = np.concatenate((Y2, Y1))

plt.figure(figsize=(12,8))
plt.plot(X, Y)

In [None]:
# Some sample additions
# The point at infinity (the group identity) is represented by the string 'O'.
O = 'O'
ec_add(O, O, A, B)  # O + O = O

In [None]:
# (-2, 0) is on the curve: (-2)^3 - (-2) + 6 = 0.
P = (-2, 0)
ec_add(P, O, A, B)  # P + O = P

In [None]:
# O is the identity on the left as well: O + P = P.
ec_add(O, P, A, B)

In [None]:
# P has y = 0, so P = -P and ec_add returns O for P + P.
ec_add(P, P, A, B)

In [None]:
# Q is the point with x = 3 on the lower (y < 0) branch.
Q = (3, -1*np.sqrt(3**3 + A*3 + B))
# NOTE(review): `sum` shadows the builtin sum() from here on. The plotting cell
# below reads this name, so rename both together if you change it.
sum = ec_add(P, Q, A, B)

In [None]:
# A plot to show that P + Q = sum
plt.figure(figsize=(12,8))
plt.plot(X, Y)
plt.plot(*zip(P, Q))  # segment from P to Q
# Vertical segment from P + Q = (x3, y3) to its reflection (x3, -y3), which is
# the third point where the line through P and Q meets the curve.
plt.plot([sum[0], sum[0]], [sum[1],-1*sum[1]])
plt.scatter(*zip(P, Q, sum))

In [None]:
# Check of associativity: compare P + (Q + R) with (P + Q) + R.
# The coordinates are floats, so the two results may differ by rounding error.
R = (1, np.sqrt(1**3 + A*1 + B))
ec_add(P, ec_add(Q, R, A, B), A, B), ec_add(ec_add(P, Q, A, B), R, A, B)

# Elliptic curves over finite fields

In [None]:
# An algorithm for adding points P and Q on an elliptic curve
# with Weierstrass equation Y^2 = X^3 + A*X + B (mod p)

def ecp_add(P, Q, A, B, p):
  """Add points P and Q on the curve Y^2 = X^3 + A*X + B over F_p.

  Points are (x, y) integer pairs or the string 'O' for the point at infinity.
  Coordinates are reduced mod p before they are compared. Unreduced
  representatives (e.g. y = -5 for 8 mod 13) therefore give the same result
  as reduced ones. B does not appear in the formulas.

  Returns the sum as a reduced (x, y) pair, or 'O'. If either input is 'O',
  the other input is returned as given.
  """
  if P == 'O':
    return Q
  if Q == 'O':
    return P
  x1, y1 = P[0] % p, P[1] % p
  x2, y2 = Q[0] % p, Q[1] % p
  if x1 == x2:
    # Q = -P (this also covers doubling a point with y == 0).
    if y1 == (-y2) % p:
      return 'O'
    # For points on the curve, equal x and y2 != -y1 means Q == P: tangent slope.
    slope = ((3*x1*x1 + A)*inverse(2*y1, p)) % p
  else:
    slope = ((y2 - y1)*inverse(x2 - x1, p)) % p
  x3 = (slope**2 - x1 - x2) % p
  y3 = (slope*(x1 - x3) - y1) % p
  return (x3, y3)

In [None]:
# Y^2 = X^3 + 3*X + 8 over F_13
p = 13
A = 3
B = 8

# P = (1, 5) is on the curve (25 = 12 = 1 + 3 + 8 mod 13). Q = (1, 8) is -P
# because 8 = -5 (mod 13), so P + Q should be O.
# O is the 'O' point-at-infinity name bound in an earlier cell.
P = (1, 5)
Q = (1, 8)
ecp_add(O, O, A, B, p), ecp_add(P, O, A, B, p), ecp_add(P, Q, A, B, p)

In [None]:
# Here's a plot.  It looks strangely similar to a few points on a characteristic 0 elliptic curve.
# E lists the points of Y^2 = X^3 + 3*X + 8 over F_13, ending with the point at infinity 'O'.

E = [(1, 5), (1, 8), (2, 3), (2, 10), (9, 6), (9, 7), (12, 2), (12, 11), 'O']
plt.figure(figsize=(12,8))
plt.scatter(*zip(*E[:-1]))  # E[:-1] drops 'O', which has no coordinates to plot

In [None]:
# An algorithm for finding all points on a curve Y^2 = X^3 + A*X + B over F_p

def find_points(A, B, p):
  """Return every point of the curve Y^2 = X^3 + A*X + B over F_p.

  Affine points (x, y) are listed in increasing x, and for each x in
  increasing y. The point at infinity 'O' comes last.
  """
  # Map each square mod p to its square roots, in increasing order.
  # Enumerating every y (not only 1..(p-1)/2) also covers p == 2,
  # where 1 is its own square root.
  roots = {}
  for y in range(p):
    roots.setdefault((y * y) % p, []).append(y)

  points = []
  for x in range(p):
    for y in roots.get((x**3 + A*x + B) % p, ()):
      points.append((x, y))
  points.append('O')
  return points

In [None]:
# Should reproduce the hand-written list of F_13 points used for the earlier plot.
find_points(3, 8, 13)

In [None]:
# With a bigger example, we can see that the above example was just a coincidence.

E = find_points(3, 8, 1009)
plt.figure(figsize=(12,8))
plt.scatter(*zip(*E[:-1]))  # omit the final 'O' (point at infinity) returned by find_points