In [1]:
import random
import math

class NTRU:
    """
    Textbook NTRU public-key cryptosystem over the ring R = Z[x]/(x^N - 1).

    Written for SageMath: it relies on the Sage globals ``PolynomialRing``,
    ``ZZ``, ``GF``, ``Integers`` and ``randint``. All polynomials live in
    ``self.R``; ``reduce_mod`` maps them to balanced coefficient ranges.
    """

    def __init__(self, N, p, q, d):
        """
        Initializes an NTRU cryptosystem instance with specified parameters and generates a key pair.

        Parameters:
        - N (int): The degree of the polynomial ring, must be a prime number.
        - p (int): Small modulus (message space); must be a prime power coprime to q.
        - q (int): Large modulus (ciphertext space); must be a prime power (e.g. 2^k).
        - d (int): Density parameter: f gets d coefficients equal to 1 and d-1 equal to -1;
          g and the encryption blinding polynomial get d of each.

        Raises:
        - AssertionError: If any of the parameter conditions are not met.
        - ValueError: If p or q is not a prime power.
        """
        assert q > (6*d + 1) * p, "Parameters must hold: q > (6*d + 1) * p"
        # With d = 0, f would be the zero polynomial and key generation would never terminate.
        assert d >= 1, "Parameters must hold: d >= 1"
        # g and phi need 2*d non-zero coefficients out of N.
        assert 2*d <= N, "Parameters must hold: 2*d <= N"
        # ZZ(...) accepts plain Python ints as well as Sage Integers.
        assert ZZ(N).is_prime(), "N must be a large prime"
        assert ZZ(p).gcd(q) == 1, "p and q must be relatively prime"
        # Inverses mod p and q are computed over GF(prime) and lifted, so both must be prime powers.
        self._prime_power(p)
        self._prime_power(q)

        self.N = N
        self.p = p
        self.q = q
        self.d = d

        self.Zx = PolynomialRing(ZZ, 'x')
        self.x = self.Zx.gen()
        # Build the modulus from the ring's own generator, not from a global symbolic ``x``.
        self.R = self.Zx.quotient(self.x**self.N - 1, 'x')

        # Integers(m) is Z/mZ. GF(m) coincides with it only for prime m; for m = 32 it
        # would be the 32-element field of characteristic 2, a different ring.
        self.Rq = self._quotient_ring(Integers(self.q))
        self.Rp = self._quotient_ring(Integers(self.p))

        self.public_key, self.private_key = self.generate_keys()

    def _quotient_ring(self, base):
        """Return base[x]/(x^N - 1), with the quotient generator named 'x'."""
        ring = PolynomialRing(base, 'x')
        return ring.quotient(ring.gen()**self.N - 1, 'x')

    @staticmethod
    def _prime_power(modulus):
        """
        Split ``modulus`` into ``(prime, exponent)`` with ``modulus == prime**exponent``.

        Raises:
        - ValueError: If ``modulus`` is not a prime power (including values below 2).
        """
        m = int(modulus)
        if m < 2:
            raise ValueError(f"{modulus} is not a prime power")
        # The smallest divisor >= 2 is the only candidate prime; trial division up to sqrt(m).
        prime = 2
        while prime * prime <= m and m % prime != 0:
            prime += 1
        if prime * prime > m:
            prime = m  # no divisor up to sqrt(m): m itself is prime
        exponent = 0
        while m % prime == 0:
            m //= prime
            exponent += 1
        if m != 1:
            raise ValueError(f"{modulus} is not a prime power")
        return prime, exponent

    def generate_pol(self, d1, d2):
        """
        Generates a random polynomial where exactly d1 coefficients are 1, d2 coefficients are -1 and the rest are 0.

        Parameters:
        - d1 (int): Number of coefficients to be set to 1.
        - d2 (int): Number of coefficients to be set to -1.

        Returns:
        - Polynomial: A polynomial in self.R with specified number of 1's and -1's.

        Raises:
        - AssertionError: If d1 + d2 exceeds the degree N.
        """
        assert d1 + d2 <= self.N, "d1 + d2 must not exceed N"
        coeffs = self.N * [0]
        # Place the +1s first, then the -1s, each in a random still-empty slot (rejection sampling).
        for value, count in ((1, d1), (-1, d2)):
            for _ in range(count):
                j = randint(0, self.N - 1)
                while coeffs[j] != 0:
                    j = randint(0, self.N - 1)
                coeffs[j] = value
        return self.R(coeffs)

    def generate_message(self):
        """
        Generates a random message polynomial whose coefficients are drawn independently
        from the balanced range [-(p//2), p - 1 - p//2].

        This is exactly the range produced by ``reduce_mod(..., p)``, so a decrypted message
        compares equal to the original. For odd p it is [-(p-1)/2, (p-1)/2]. For even p the
        top value p/2 is excluded, because it is congruent to -p/2 mod p.

        Returns:
        - Polynomial: A message polynomial in self.R.
        """
        lower_bound = -(self.p // 2)
        upper_bound = self.p - 1 - self.p // 2
        coeffs = [randint(lower_bound, upper_bound) for _ in range(self.N)]
        return self.R(coeffs)

    def star_product(self, p1, p2):
        """
        Computes the 'star' product of two polynomials, which is their product reduced by the ring's modulus.

        Parameters:
        - p1 (Polynomial): First polynomial operand.
        - p2 (Polynomial): Second polynomial operand.

        Returns:
        - Polynomial: The product of p1 and p2 within the ring self.R.
        """
        return self.R(p1) * self.R(p2)

    def reduce_mod(self, pol, mod):
        """
        Reduces the first N coefficients of a polynomial modulo 'mod' into the balanced
        range [-(mod//2), mod - 1 - mod//2].

        Parameters:
        - pol (Polynomial): The polynomial to be reduced.
        - mod (int): The modulus to be used for reduction.

        Returns:
        - Polynomial: The reduced and balanced polynomial in self.R.
        """
        half = mod // 2
        coeffs = [((pol[i] + half) % mod) - half for i in range(self.N)]
        return self.R(coeffs)

    def _invert_mod(self, pol, modulus):
        """
        Return the inverse of ``pol`` in Z[x]/(x^N - 1) modulo ``modulus``, or None if it has none.

        The inverse is computed over GF(prime) and then Newton-lifted to prime**exponent,
        so ``modulus`` may be any prime power (e.g. q = 2^k), not only a prime.
        """
        prime, exponent = self._prime_power(modulus)
        field_ring = self._quotient_ring(GF(prime))
        try:
            inverse = self.R(1 / field_ring(pol))
        except ZeroDivisionError:
            return None
        # If pol*b = 1 (mod prime^k) then b*(2 - pol*b) inverts pol modulo prime^(2k);
        # reducing mod `modulus` after each step keeps that congruence up to `modulus`.
        precision = 1
        while precision < exponent:
            correction = 2 - self.star_product(pol, inverse)
            inverse = self.reduce_mod(self.star_product(inverse, correction), modulus)
            precision *= 2
        return inverse

    def generate_keys(self):
        """
        Generates the public and private keys for the NTRU cryptosystem.

        Samples g (d ones, d minus-ones) and f (d ones, d-1 minus-ones) until f is
        invertible both modulo p and modulo q.

        Returns:
        - Tuple: (h, (f, fp)) where h = fq * g reduced mod q is the public key and
          fp is the inverse of f modulo p.
        """
        while True:
            g = self.generate_pol(self.d, self.d)
            f = self.generate_pol(self.d, self.d - 1)
            fp = self._invert_mod(f, self.p)
            if fp is None:
                continue
            fq = self._invert_mod(f, self.q)
            if fq is not None:
                break

        h = self.reduce_mod(self.star_product(fq, g), self.q)

        public_key = h
        private_key = (f, fp)

        return public_key, private_key

    def encrypt(self, m):
        """
        Encrypts a message polynomial using the public key.

        Computes e = p * phi * h + m reduced mod q, where phi is a fresh random
        polynomial with d ones and d minus-ones.

        Parameters:
        - m (Polynomial): The message polynomial to be encrypted.

        Returns:
        - Polynomial: The encrypted message.
        """
        h = self.public_key
        phi = self.generate_pol(self.d, self.d)
        encrypted_m = self.reduce_mod(self.star_product(self.p*phi, h) + m, self.q)
        return encrypted_m

    def decrypt(self, encrypted_m):
        """
        Decrypts an encrypted message polynomial using the private key.

        Computes a = f * e reduced mod q, then fp * a reduced mod p.

        Parameters:
        - encrypted_m (Polynomial): The encrypted message polynomial.

        Returns:
        - Polynomial: The decrypted message.
        """
        f, fp = self.private_key
        a = self.reduce_mod(self.star_product(f, encrypted_m), self.q)
        return self.reduce_mod(self.star_product(fp, a), self.p)

    def run(self):
        """
        A test function to run an encryption and decryption cycle to verify the implementation.

        Prints:
        - Message, Encrypted Message, Decrypted Message and verification result.
        """
        m = self.generate_message()
        print("Message: ", m)
        encrypted_m = self.encrypt(m)
        print("Encrypted message: ", encrypted_m)
        decrypted_m = self.decrypt(encrypted_m)
        print("Decripted message: ", decrypted_m)
        print(m == decrypted_m)
        
        

In [2]:
# Toy parameters: q > (6*d + 1)*p holds (32 > 21). With d = 1 the private polynomial f
# has a single +1 and no -1 coefficients, i.e. it is a monomial.
ntru = NTRU(N=11, p=3, q=32, d=1)

In [3]:
# Encrypts a freshly generated random message, decrypts it, and prints all three
# polynomials followed by whether the round trip recovered the message.
ntru.run()

Message:  x^10 - x^9 - x^7 + x^5 - x^4 - x^2 + x + 1
Encrypted message:  x^10 + 2*x^9 - x^7 + 4*x^5 - x^4 - 3*x^3 - x^2 + x - 2
Decripted message:  x^10 - x^9 - x^7 + x^5 - x^4 - x^2 + x + 1
True
