# NTRU
### Implementing Standard NTRUEncrypt

This notebook contains everything to implement NTRU, and nothing more.

# Import Statements

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import itertools

In [2]:
def make_square(fill_color):
    """ Draw a solid square in the given colour as a loud visual alarm.

        The test cells call this when one of their checks fails.
        fill_color is a colour string, like 'red' or 'green'. """
    side = 2
    plt.xlim(0, side)
    plt.ylim(0, side)
    plt.fill_between([0, side], 0, side, color=fill_color)

    # hide the tick marks on both axes so only the square is shown
    plt.xticks([])
    plt.yticks([])
    plt.show()

# Number Theory Functions

In [3]:
def EEA(a,b):
    """ Performs the Extended Euclidean Algorithm on the integers a and b.

        Returns a tuple (gcd, s, t) of integers with a*s + b*t == gcd.
        If b is 0 the result is (a, 1, 0)."""
    u, g = 1, a
    x, y = 0, b
    while y != 0:
        q, r = divmod(g, y)
        # (u, x) track the coefficient of a for the remainders (g, y)
        u, x = x, u - q*x
        g, y = y, r

    # g - a*u is an exact multiple of b, so integer division is exact;
    # true division would yield a float and lose precision on large inputs
    v = (g - a*u) // b if b != 0 else 0
    return (g,u,v)

# testing: compare EEA against two hand-checked results
validated = all([
    EEA(101,13) == (1,4, -31),
    EEA(4,6) == (2, -1, 1),
])
if not validated:
    make_square('red')

In [4]:
def modular_inverse(a, p):
    """ Returns the multiplicative inverse of a mod p as a value in
        range(p). Raises ValueError when gcd(a, p) is not 1."""
    gcd, coefficient = EEA(a, p)[:2]
    if gcd == 1:
        # a*coefficient + p*t == 1, so coefficient inverts a modulo p
        return coefficient % p
    raise ValueError("{} is not relatively prime to {}".format(a,p))

# testing: a few small inverses checked by hand
validated = all([
    modular_inverse(2,7) == 4,
    modular_inverse(5,6) == 5,
    modular_inverse(1,4) == 1,
])
if not validated:
    make_square('red')

In [5]:
def factor(N):
    """ Returns the factorization of N as a list of (prime, exponent)
        pairs in increasing order of prime.

        If no prime up to N//2 divides N (N is prime, or N < 4), the
        result is [(N, 1)]."""
    m = int(N)
    factors = []

    # Trial division up to sqrt(m). Composite values of p never divide m
    # because their prime factors have already been divided out.
    p = 2
    while p * p <= m:
        exp = 0

        # find all powers of p dividing N
        while m % p == 0:
            exp += 1
            m = m // p

        # if p divides N, add it to the list
        if exp > 0:
            factors.append((p,exp))
        p += 1

    # if N is prime (or too small to have a proper prime factor)
    if len(factors) == 0:
        return [(N,1)]

    # whatever is left above 1 is the single prime factor larger than sqrt
    if m > 1:
        factors.append((m,1))
    return factors

# Polynomials
We are viewing polynomials as lists. For example, $3x^3 + 2x - 1$ is `[-1, 2, 0, 3]`.
## Contents:
`to_string()`

`poly_format()`

`poly_add()`

`poly_subtract()`

`poly_mult()`

`poly_convolution()`

`poly_division()`

`poly_EEA()`

`reduce_mod()`

`scalar_mult()`

`poly_reversal()`

In [6]:
def poly_format(f):
    """ Returns f as a list with its trailing zero coefficients dropped.
        At least one entry is kept, so the zero polynomial becomes [0]."""
    end = len(f)
    # walk back from the top coefficient past every zero, stopping at index 0
    while end > 1 and f[end - 1] == 0:
        end -= 1
    return list(f[:end])

# testing: nothing to strip, and a run of trailing zeros
validated = all([
    poly_format([1,2,3]) == [1,2,3],
    poly_format([1,0,0,0]) == [1],
])
if not validated:
    make_square('red')

In [7]:
def to_string(f, latex=False):
    """ f is a list (or tuple) representing the coefficients of a
        polynomial, constant term first.
        This returns a string displaying f (with powers of x), terms
        joined by " + " in increasing order of degree. Zero terms are
        skipped and a coefficient of 1 is left implicit on x terms.
        Negative coefficients are written as-is, e.g. "1 + -2x".
        An empty f gives "" and the zero polynomial gives "0".
        If latex=True, $'s are put around the string."""
    if not f:
        return ""

    if latex:
        return "$" + to_string(f) + "$"

    terms = []
    for power, coefficient in enumerate(f):
        if coefficient == 0:
            continue

        # the constant term is just the number
        if power == 0:
            terms.append(str(coefficient))
            continue

        variable = "x" if power == 1 else "x^" + str(power)
        if coefficient == 1:
            terms.append(variable)
        else:
            terms.append(str(coefficient) + variable)

    # every coefficient was zero: show the constant term
    if not terms:
        return str(f[0])
    return " + ".join(terms)

# testing
# Each branch compares to_string against a hand-written expected string.
# Because this is an if/elif chain, checking stops at the first failure.
validated = True
if to_string([]) != "":
    validated = False
elif to_string([5]) != "5":
    validated = False
elif to_string([1,0,2]) != "1 + 2x^2":
    validated = False
elif to_string([0,5]) != "5x":
    validated = False
elif to_string([0]) != "0":
    validated = False
elif to_string([0,1,0,1]) != "x + x^3":
    # show the actual output to help debug this particular case
    print(to_string([0,1,0,1]))
    validated = False
elif to_string([0,1]) != "x":
    validated = False
elif to_string([4,5,0]) != "4 + 5x":
    validated = False
elif to_string([0,0,1]) != "x^2":
    validated = False
elif to_string([0,0,0]) != "0":
    validated = False
elif to_string([4,3,2], latex=True) != "$4 + 3x + 2x^2$":
    validated = False
    
# raise the visual alarm if any check failed
if not validated:
    make_square('red')

In [8]:
def poly_add(f, g, p=None, debug=False):
    """ Returns f + g as a coefficient list with trailing zeros removed.
        If p is given (and non-zero), every coefficient is reduced mod p.
        With debug=True the sum is also printed via to_string."""

    # pad the shorter operand with zeros so both have the same length
    shortfall = len(f) - len(g)
    if shortfall > 0:
        g = list(g) + [0] * shortfall
    else:
        f = list(f) + [0] * (-shortfall)

    # coefficient-wise sum, reduced afterwards when a modulus is in play
    result = [a + b for a, b in zip(f, g)]
    if p:
        result = [total % p for total in result]

    if debug:
        print("{} + {} = {}".format(to_string(f), to_string(g), to_string(result)))

    return poly_format(result)

# testing: plain and modular sums; evaluation stops at the first mismatch
validated = (
    poly_add([1], [2]) == [3]
    and poly_add([0,1], [4]) == [4,1]
    and poly_add([5,2], [2,3,4,1]) == [7,5,4,1]
    and poly_add([5],[7], p=4) == [0]
    and poly_add([1,2,3], [4,2], p=3) == [2,1]
)
if not validated:
    make_square('red')

In [9]:
def poly_subtract(f, g, p=None, debug=False):
    """ Returns f - g, computed as f + (-g) through poly_add, so p and
        debug behave exactly as they do there."""
    negated_g = []
    for coefficient in g:
        negated_g.append(-coefficient)
    return poly_add(f, negated_g, p=p, debug=debug)

# testing: equal lengths, either operand longer, and a modular case
validated = all([
    poly_subtract([1], [1]) == [0],
    poly_subtract([2,4,1], [5]) == [-3,4,1],
    poly_subtract([5], [2,4,1]) == [3, -4, -1],
    poly_subtract([5], [6], p=4) == [3],
])
if not validated:
    make_square('red')

In [10]:
def poly_mult(f, g, p=None):
    """ Returns the ordinary polynomial product f*g with trailing zeros
        removed. If p is given (and non-zero), the running coefficients
        are reduced mod p."""
    product = [0] * (len(f) + len(g) - 1)
    for i, a in enumerate(f):
        for j, b in enumerate(g):
            # x^i * x^j contributes to the coefficient of x^(i+j)
            total = product[i + j] + a*b
            product[i + j] = total % p if p else total
    return poly_format(product)

# testing: integer product, empty operands, and a product mod 5
validated = all([
    poly_mult([1,2], [0,4]) == [0,4,8],
    poly_mult([], []) == [],
    poly_mult([1,2,3,4,5], [0,7], p=5) == [0,2,4,1,3],
])
if not validated:
    make_square('red')

In [11]:
def poly_convolution(f, g, N, p=None):
    """ Returns the cyclic convolution of f and g: their product with
        exponents wrapped modulo N, i.e. computed modulo x^N - 1.
        If p is given (and non-zero), the coefficients are reduced mod p.
        Trailing zeros are removed from the result."""

    # one slot per exponent 0 .. N-1
    wrapped = [0 for _ in range(N)]

    for i, a in enumerate(f):
        for j, b in enumerate(g):

            # x^(i+j) wraps around to x^((i+j) mod N)
            slot = (i + j) % N
            total = wrapped[slot] + a*b
            wrapped[slot] = total % p if p else total
    return poly_format(wrapped)

# testing: identity, an integer convolution, and the same one mod 5
validated = (
    poly_convolution([1], [1], 7) == [1]
    and poly_convolution([1,2], [0,0,3,4], 7) == [0,0,3,10,8]
    and poly_convolution([1,2], [0,0,3,4], 7, p=5) == [0,0,3,0,3]
)
if not validated:
    make_square('red')

In [12]:
def poly_division(f, g, p):
    """ Performs the division algorithm of f by g,
        returning (quotient, remainder) with f = quotient*g + remainder
        and the remainder of smaller degree than g. f and g are
        polynomials in Z/pZ; both are reduced mod p before dividing, so
        the returned coefficients all lie in range(p). This algorithm is
        described on page 99 of Intro to Math Crypto.

        Raises ValueError if g is the zero polynomial mod p, or if the
        leading coefficient of g is not invertible mod p."""

    # Reduce and strip the divisor before the zero check, so that zero
    # written as [], [0, 0] or [p] is caught as well.
    divisor = poly_format([c % p for c in g])
    if not divisor or divisor == [0]:
        s = "Cannot divide by the zero polynomial. You tried doing {}/{} mod {}".format(to_string(f), to_string(g), p)
        raise ValueError(s)

    # starting by setting the quotient to 0 and the remainder to f
    # so this satisfies f = k*divisor + r
    k = [0]
    r = poly_format([c % p for c in f])

    # nothing to do if f already has smaller degree than the divisor
    if len(r) < len(divisor) or r == [0]:
        return (k, r)

    # the inverse of the divisor's leading coefficient is loop-invariant
    d = len(divisor) - 1
    lead_inverse = modular_inverse(divisor[d], p)

    # keep moving x^(e-d) * r[e]/divisor[d] from r into k to maintain the
    # equality and make r have a smaller degree
    while len(r) >= len(divisor) and r != [0]:
        e = len(r) - 1
        single_term = [0 for _ in range(e-d)] + [(r[e]*lead_inverse) % p]
        k = poly_format(poly_add(k, single_term, p=p))
        r = poly_format(poly_subtract(r, poly_mult(single_term, divisor, p=p), p=p))
    return (k, r)

# testing: divide x^5 - 1 by x^3 + 2x - 3 over Z/13Z
expected = ([11,0,1], [6,4,3])
validated = poly_division([-1,0,0,0,0,1], [-3,2,0,1], 13) == expected
if not validated:
    make_square('red')

In [13]:
def poly_EEA(f, g, p):
    """ Performs the extended Euclidean algorithm on the polynomials
        f and g with coefficients mod p.
        Returns (gcd, u, v), where the test cell checks that
        gcd == f*u + g*v mod p."""
    # coefficient-of-f bookkeeping for the previous and current remainder
    prev_coeff, cur_coeff = [1], [0]
    prev_rem, cur_rem = f, list(g)

    while cur_rem != [0]:
        (quotient, remainder) = poly_division(prev_rem, cur_rem, p)
        next_coeff = poly_subtract(prev_coeff, poly_mult(quotient, cur_coeff, p=p), p=p)
        prev_coeff, cur_coeff = cur_coeff, next_coeff
        prev_rem, cur_rem = cur_rem, remainder

    # recover v from gcd - f*u = g*v by one more division
    f_times_u = poly_mult(f, prev_coeff, p=p)
    v = poly_division(poly_subtract(prev_rem, f_times_u, p=p), g, p)[0]
    return (prev_rem, prev_coeff, v)
    
# testing
# For each (f, g, p) triple, check the identity gcd == f*u + g*v (mod p)
# on the values returned by poly_EEA.
validated = True
test_data = (([-1,0,0,0,0,1], [-3,2,0,1], 13), ([4,3,2,1], [1,0,0,5], 7), ([65,0,1,-4,0,0,0,2], [1,2], 19))
for (f,g,p) in test_data:
    (gcd,u,v) = poly_EEA(f, g, p)
    if gcd != poly_format(poly_add(poly_mult(f, u, p=p), poly_mult(g, v, p=p), p=p)):
        validated = False

# raise the visual alarm if any triple failed
if not validated:
    make_square('red')

In [14]:
def reduce_mod(f, p):
    """ Returns a new list holding each coefficient of f reduced mod p.
        The length of f is kept (trailing zeros are not stripped)."""
    reduced = []
    for coefficient in f:
        reduced.append(coefficient % p)
    return reduced

# testing: reduce a small polynomial mod 3
validated = reduce_mod([2,7,3,4], 3) == [2,1,0,1]

if not validated:
    make_square('red')

In [15]:
def scalar_mult(f, p):
    """ Returns a new list with every coefficient of f multiplied by
        the scalar p. No modular reduction is applied."""
    scaled = []
    for a in f:
        scaled.append(p*a)
    return scaled

In [16]:
def poly_reversal(f):
    """ returns f bar, which is f(x^{-1}) in the convolution ring whose
        size is len(f): the constant term stays in place and the
        remaining coefficients are listed in reverse order.
        The empty polynomial is returned as an empty list."""
    # guard the empty polynomial, which has no constant term to index
    if len(f) == 0:
        return []
    return [f[0]] + [f[-i] for i in range(1, len(f))]

# testing: the constant term stays, the rest is reversed
validated = poly_reversal([1,2,3,4]) == [1,4,3,2]
if not validated:
    make_square('red')

# Convolution Rings
Functions for getting inverses in $R_p(N)$.

### Contents

`has_inverse()`

`get_inverse()`

`lift_inverse()`

`find_inverse_mod_q()`

In [17]:
def has_inverse(f, N, p):
    """ Returns True when the gcd of f and x^N - 1, computed mod p by
        poly_EEA, is a constant, which is the test used here for f being
        invertible in R_p. The zero polynomial is never invertible."""

    # the special case of 0
    if poly_format(f) == [0]:
        return False

    # build x^N - 1 as a coefficient list
    x_N_minus_1 = [-1]
    x_N_minus_1.extend(0 for _ in range(N-1))
    x_N_minus_1.append(1)

    gcd = poly_EEA(f, x_N_minus_1, p)[0]
    return len(gcd) == 1

# testing: known invertible and non-invertible elements; stops at the first mismatch
validated = (
    has_inverse([0,0,1], 3, 2) == True
    and has_inverse([1,1,1], 3, 2) == False
    and has_inverse([1,1,0,0,1], 5, 2) == True
    and has_inverse([6], 7, 11) == True
)

if not validated:
    make_square('red')

In [18]:
def get_inverse(f, N, p):
    """ Finds an inverse of f in R_p(N) using poly_EEA against x^N - 1.
        Returns None when the gcd is not a constant. The returned
        coefficients are not reduced mod p."""

    # x^N - 1 as a coefficient list
    x_N_minus_1 = [-1]
    x_N_minus_1.extend(0 for _ in range(N-1))
    x_N_minus_1.append(1)

    gcd, u, _ = poly_EEA(f, x_N_minus_1, p)
    if len(gcd) <= 1:
        # f*u is the constant gcd modulo x^N - 1, so scale u by gcd^{-1}
        return scalar_mult(u, modular_inverse(gcd[0], p))
    return None

# testing
# Draw random polynomials with coefficients in range(p) and, whenever one
# is reported invertible, check that f convolved with its inverse is 1 mod p.
validated = True
p = 2; N = 5
for _ in range(10):
    f = list(np.random.randint(p, size=N))
    if has_inverse(f, N, p) and poly_convolution(f, get_inverse(f, N, p), N, p=p) != [1]:
        # show the offending product to help debugging
        print(poly_convolution(f, get_inverse(f, N, p), N, p=p))
        validated = False
# same check with a larger ring and p = 3
p = 3; N = 101
for _ in range(5):
    f = list(np.random.randint(p, size=N))
    if has_inverse(f, N, p) and poly_convolution(f, get_inverse(f, N, p), N, p=p) != [1]:
        print(poly_convolution(f, get_inverse(f, N, p), N, p=p))
        validated = False
        
if not validated:
    make_square('red')

In [19]:
def lift_inverse(f, F, N, q):
    """ Computes G = F * (2 - f*F) mod q in the convolution ring of
        size N, where q = p^{2i} and F is the inverse of f mod p^i."""
    f_times_F = poly_convolution(f, F, N, p=q)
    correction = poly_subtract([2], f_times_F, p=q)
    return poly_convolution(F, correction, N, p=q)

# testing: lift an inverse mod 3 to an inverse mod 9 (N = 5)
validated = lift_inverse([1,0,1], [2,1,1,2,2], 5, 9) == [5,4,4,5,5]

if not validated:
    make_square('red')

In [20]:
def find_inverse_mod_q(f, N, p, alpha):
    """ Computes f inverse mod q = p^alpha in the convolution ring of
        size N.

        Starts from the inverse mod p given by get_inverse and applies
        lift_inverse, doubling the exponent each time, until the
        exponent is at least alpha. Returns None if f is not invertible
        mod p; otherwise returns the inverse with coefficients in
        range(q) and trailing zeros removed."""
    q = p**alpha
    F = get_inverse(reduce_mod(f,p), N, p)

    # if f is not invertible mod p, return None
    if F is None:
        return None

    # F is an inverse mod p^exponent; stop as soon as exponent >= alpha
    # so no lift is done beyond what q requires
    exponent = 1
    while exponent < alpha:
        exponent *= 2
        F = lift_inverse(f, F, N, p**exponent)

    # reduce from p^exponent down to q and return the canonical form
    return poly_format([coefficient % q for coefficient in F])


# testing on random polynomials
# For each parameter set, keep drawing random f with coefficients in
# range(p**alpha) until five invertible ones have been checked; each
# inverse F must satisfy f * F == 1 in the ring mod p**alpha.
validated = True

N = 5; p = 2; alpha = 7
i = 0
while i < 5:
    f = np.random.randint(p**alpha, size=N)
    F = find_inverse_mod_q(f, N, p, alpha)
    # only invertible draws count towards the five checks
    if F:
        i += 1
        if poly_convolution(f, F, N, p=p**alpha) != [1]:
            validated = False
            
N = 8; p = 3; alpha = 6
i = 0
while i < 5:
    f = np.random.randint(p**alpha, size=N)
    F = find_inverse_mod_q(f, N, p, alpha)
    if F:
        i += 1
        if poly_convolution(f, F, N, p=p**alpha) != [1]:
            validated = False

if not validated:
    make_square('red')

# NTRUEncrypt
### Contents

`random_T()`

`invert()`

`make_f()`

`make_g()`

`make_h()`

`make_m()`

`make_r()`

`center_lift()`

In [21]:
def random_T(d_1, d_2, N):
    """ Returns a random element of T(d_1, d_2) as a list of N
        coefficients: d_1 of them are 1, d_2 are -1 and the rest are 0,
        in a random order chosen by np.random.permutation."""
    entries = []
    entries.extend(1 for _ in range(d_1))
    entries.extend(-1 for _ in range(d_2))
    entries.extend(0 for _ in range(N - d_1 - d_2))
    shuffled = np.random.permutation(entries)
    return list(shuffled)

In [22]:
def invert(f, N, p):
    """ This finds the inverse of f in R_p(N) if p is a prime power.

        Returns the inverse as a coefficient list, or None if f is not
        invertible (as reported by find_inverse_mod_q).
        Raises ValueError if p is not a prime power; previously a
        warning was printed and an inverse modulo only the first
        prime-power factor of p was returned."""
    factors = factor(p)

    # check if p is not a prime power
    if len(factors) > 1:
        raise ValueError("Invert() only works on prime powers, not on {}".format(p))

    prime, exponent = factors[0]
    return find_inverse_mod_q(f, N, prime, exponent)

In [23]:
def make_f(N, p, q):
    """ Gets f for NTRU: a random element of T(d+1, d), with d = N//3,
        that invert() can invert in both R_p(N) and R_q(N).
        Returns the tuple (f, F_p, F_q) of f and its two inverses."""
    d = N//3

    # keep drawing candidates until one is invertible in both rings
    while True:
        f = random_T(d+1, d, N)
        F_p = invert(f, N, p); F_q = invert(f, N, q)
        if F_p and F_q:
            break

    # just checking: report (but do not raise on) an inverse that fails
    for (inverse, modulus) in ((F_p, p), (F_q, q)):
        if poly_convolution(f, inverse, N, p=modulus) != [1]:
            print("Inversion Failure: {}, {}, {}, {}".format(f, inverse, N, modulus))

    return (f, F_p, F_q)

def make_g(N):
    """ Returns a random g from T(d, d) with d = N//3, as a list of
        N coefficients."""
    weight = N//3
    g = random_T(weight, weight, N)
    return g

def make_h(F_q, g, N, q):
    """ Returns h, the convolution of F_q and g in the ring of size N
        with coefficients reduced mod q."""
    h = poly_convolution(F_q, g, N, p=q)
    return h

def make_m(N, p):
    """ Returns a random message: N coefficients drawn by
        np.random.randint from -p//2 + 1 up to and including p//2,
        with trailing zeros then removed by poly_format."""
    lowest = -p//2 + 1
    one_past_highest = p//2 + 1
    coefficients = np.random.randint(lowest, high=one_past_highest, size=N)
    return poly_format(list(coefficients))

def make_r(N):
    """ Returns a random r from T(d, d) with d = N//3, as a list of
        N coefficients."""
    weight = N//3
    r = random_T(weight, weight, N)
    return r

In [24]:
def center_lift(f, p):
    """ Returns the center lift of f mod p: each coefficient is replaced
        by the value congruent to it mod p that is greater than -p/2 and
        at most p/2. Trailing zeros are removed from the result."""
    lifted = []
    for coefficient in f:
        residue = coefficient % p

        # residues above p/2 are represented by their negative counterpart
        lifted.append(residue if residue <= p/2 else residue - p)
    return poly_format(lifted)

In [25]:
# Worked example in the convolution ring with N = 7: print
# (p*g) * r + f * m with integer (unreduced) coefficients.
# Note that this overwrites the notebook-level names N, p, g, r, f and m.
N = 7
p = 3
g = [1,1,-1,-1,0,0,0]
r = [1,0,0,0,-1,-1,1]
f = [1,1,1,-1,-1,0,0]
m = [1,-1,-1,-1,-1,1,1]
print(poly_add(poly_convolution(scalar_mult(g,p),r,N), poly_convolution(f,m,N)))

[17, 4, -9, -8, -6, -5, 6]


In [26]:
# Convolve f with its reversal f_bar in the ring of size 7 and print the result.
f = [2,1,4,5,0,2,1]
f_bar = poly_reversal(f)
print(poly_convolution(f, f_bar, 7))

[51, 30, 28, 29, 29, 28, 30]


In [27]:
# g is f rotated cyclically one place to the right.
f = [1,1,3,2]
g = [f[-1]] + f[0:-1]
# sum of squares of f's coefficients
print(sum([a*a for a in f]))
# dot product of f with its rotation g
print(sum([f[i]*g[i] for i in range(len(f))]))

15
12


# Testing NTRU

In [28]:
def is_prime(n):
    """ The naive primality test by trial division.

        Returns True if the integer n is prime and False otherwise;
        every n below 2 (including 0, 1 and negatives) is not prime."""
    if n < 2:
        return False

    # compare i*i with n instead of taking a floating-point square root,
    # so the bound stays exact for large n
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True

def next_prime(p):
    """ Returns the smallest number greater than p that is_prime
        accepts. Very inefficient: candidates are tried one by one."""
    candidate = p + 1
    while True:
        if is_prime(candidate):
            return candidate
        candidate += 1

In [29]:
def nearby_random_prime(n):
    """ Returns a random prime within 20% of n, chosen uniformly from
        the primes in range(int(n - n/5), int(n + n/5)).

        Raises ValueError if that range contains no prime."""
    # List the primes in the window first, rather than probing random
    # values: this cannot miss a prime that exists and needs no retry cap.
    candidates = [x for x in range(int(n-n/5), int(n + n/5)) if is_prime(x)]
    if not candidates:
        raise ValueError("No nearby primes found for {}".format(n))
    return np.random.choice(candidates)

In [30]:
def next_power_of_two(n):
    """ Returns the smallest power of 2 strictly greater than n.
        The result is never below 2, so any n < 2 gives 2."""
    power = 2
    # use <= so that an exact power of two is stepped past, as documented
    while power <= n:
        power *= 2
    return power

In [31]:
def test_NTRU(N, p, q, d, debug=False):
    """ Tests a random example of NTRU with the given parameter set.

        Generates a key (f, g, h), a random message m and a random r,
        encrypts m as e = p*h*r + m mod q, decrypts e, and returns True
        if the decryption equals m. On a mismatch the decrypted and
        original messages are printed. With debug=True every
        intermediate polynomial is printed.

        d is accepted for interface compatibility but not used; the
        make_* helpers derive their own d from N."""
    (f, F_p, F_q) = make_f(N, p, q)
    g = make_g(N)
    h = make_h(F_q, g, N, q)
    m = make_m(N, p)
    r = make_r(N)

    # Encrypt: the ciphertext lives mod q. Reducing it mod p instead would
    # wipe out the p*h*r blinding term and make the round trip trivial.
    e = poly_add(poly_convolution(scalar_mult(h, p), r, N, p=q) , m, p=q)

    # Decrypt: a = f*e mod q, center lifted; then b = F_p*a mod p, center lifted
    a = poly_convolution(f, e, N, p=q)
    a = center_lift(a, q)
    b = poly_convolution(F_p, a, N, p=p)
    b = center_lift(b, p)

    if debug:
        print('f: ' + str(f))
        print('F_p: ' + str(F_p))
        print('F_q: ' + str(F_q))
        print('g: ' + str(g))
        print('h: ' + str(h))
        print('m: ' + str(m))
        print('r: ' + str(r))
        print('e: ' + str(e))
        print('a: ' + str(a))
        print('b: ' + str(b))

    if b != m:
        print(b,m)
    return b == m

In [32]:
def test_NTRU_examples(num_Ns, max_N):
    """ Runs test_NTRU on num_Ns random examples and prints whether all
        of them succeeded.

        The target sizes start at 30 and step by max_N // num_Ns; each N
        is a random prime near its target, with d = N//3, p = 3 and
        q = next_power_of_two((6*d+1)*p). """
    p = 3
    smallest_target = 30
    step = max_N // num_Ns

    successes = 0
    for index in range(num_Ns):
        N = nearby_random_prime(smallest_target + step*index)
        d = N//3

        # q is a power of 2 derived from d and p
        q = next_power_of_two((6*d+1)*p)
        if test_NTRU(N, p, q, d):
            successes += 1

    print("All tests successful." if successes == num_Ns else "Not all tests successful.")

In [33]:
# Run 30 random NTRU round trips with max_N = 500.
test_NTRU_examples(30, 500)

KeyboardInterrupt: 

In [75]:
# Single round trip with N = 743, p = 3, q = 2048.
test_NTRU(743,3,2048, 743//3)

True

In [35]:
# Single round trip with N = 443, p = 3, q = 1024.
test_NTRU(443,3,1024, 443//3)

True