In [33]:
import sage.all
import time
import random as pyrandom

class NTRU:
    """
    Complete NTRU (NTRUEncrypt) implementation following IEEE P1363.1 standard.
    
    Supports:
    - Standard parameter sets (toy, recommended, high-security)
    - Proper key generation with invertibility checks
    - Encryption/Decryption
    - LLL lattice attack simulation
    - Comprehensive testing utilities
    """
    
    # Standard parameter sets (N, p, q, df, dg, dr)
    PARAM_SETS = {
        'toy': (11, 3, 31, 4, 3, 3),           # For testing only
        'ees251ep3': (251, 3, 127, 50, 24, 16), # Deprecated, educational
        'ees251ep6': (251, 3, 127, 50, 24, 16),
        'moderate': (167, 3, 127, 61, 20, 10),  # Medium security
        'high': (347, 3, 2047, 115, 38, 38),    # Higher security
    }
    
    def __init__(self, N=None, p=None, q=None, df=None, dg=None, dr=None, 
                 param_set='moderate'):
        """
        Initialize NTRU cryptosystem.
        
        Args:
            N: Ring dimension (must be prime for security)
            p: Small modulus (typically 3)
            q: Large modulus (must be much larger than p)
            df: Number of ±1 coefficients in private key f
            dg: Number of ±1 coefficients in g
            dr: Number of ±1 coefficients in blinding polynomial r
            param_set: Use predefined parameter set ('toy', 'moderate', 'high')
        """
        if param_set and param_set in self.PARAM_SETS:
            N, p, q, df, dg, dr = self.PARAM_SETS[param_set]
        
        if N is None or p is None or q is None:
            raise ValueError("Must specify either param_set or (N, p, q)")
        
        self.N = N
        self.p = p
        self.q = q
        self.df = df if df is not None else N // 3
        self.dg = dg if dg is not None else N // 3
        self.dr = dr if dr is not None else N // 3
        
        # Validate parameters
        self._validate_parameters()
        
        # Create polynomial rings
        self.Zx = PolynomialRing(ZZ, 'x')
        x = self.Zx.gen()
        self.R = self.Zx.quotient(x^N - 1, 'x')  # Z[x]/(x^N - 1)
        
        # Storage for keys
        self.public_key = None
        self.private_key = None
    
    def _validate_parameters(self):
        """Validate NTRU parameters for security and correctness."""
        if self.q <= 2 * self.p:
            raise ValueError(f"q must be much larger than p (q={self.q}, p={self.p})")
        if not is_prime(self.N):
            print(f"Warning: N={self.N} is not prime. Use prime N for security.")
        if gcd(self.N, self.q) != 1:
            raise ValueError(f"N and q must be coprime (gcd({self.N}, {self.q}) != 1)")
        if gcd(self.p, self.q) != 1:
            raise ValueError(f"p and q must be coprime (gcd({self.p}, {self.q}) != 1)")
    
    # ========== Helper Methods ==========
    
    def _poly_to_vector(self, poly):
        """Convert polynomial to coefficient vector."""
        coeffs = poly.list() if hasattr(poly, 'list') else list(poly)
        return coeffs + [0] * (self.N - len(coeffs))
    
    def _vector_to_poly(self, vec):
        """Convert coefficient vector to polynomial in R."""
        return self.R(vec)
    
    def _center_lift(self, poly, modulus):
        """
        Reduce polynomial coefficients to range [-(modulus//2), modulus//2].
        Critical for correct NTRU decryption.
        """
        coeffs = self._poly_to_vector(poly)
        centered = []
        for c in coeffs:
            c = Integer(c) % modulus
            if c > modulus // 2:
                c -= modulus
            centered.append(c)
        return self.R(centered)
    
    def _poly_norm(self, poly):
        """Compute Euclidean norm of polynomial as vector."""
        coeffs = self._poly_to_vector(poly)
        return sqrt(sum(c^2 for c in coeffs))
    
    def _random_ternary(self, d_ones, d_neg_ones):
        """
        Generate random ternary polynomial with exactly:
        - d_ones coefficients equal to +1
        - d_neg_ones coefficients equal to -1
        - remaining coefficients equal to 0
        """
        coeffs = [0] * self.N
        positions = list(range(self.N))
        shuffle(positions)
        
        for i in range(d_ones):
            coeffs[positions[i]] = 1
        for i in range(d_ones, d_ones + d_neg_ones):
            coeffs[positions[i]] = -1
        
        return self.R(coeffs)
    
    def _invert_mod(self, f, modulus):
        Zmod = Integers(modulus)

        # 1. Polynomial ring over Zmod
        R = PolynomialRing(Zmod, 'x')
        x = R.gen()

        # 2. Lift f to an actual polynomial
        #    (this extracts coefficients correctly)
        f_lift = R(f.lift())

    # 3. Work in the quotient ring
        Rq = R.quotient(x**self.N - 1, 'xbar')
        f_q = Rq(f_lift)

    # 4. Try inversion
        try:
            return f_q.inverse_of_unit()
        except (ArithmeticError, ZeroDivisionError):
            return None

    
    # ========== Key Generation ==========
    
    def keygen(self, max_attempts=100, verbose=False):
        """
        Generate NTRU key pair with proper invertibility checks.
        
        Returns:
            Dictionary with 'public' and 'private' keys
        """
        for attempt in range(max_attempts):
            # Generate f with df +1's and df -1's
            f_prime = self._random_ternary(self.df, self.df)
            f = self.R(1) + self.p * f_prime
            
            # Check invertibility mod p
            f_p = self._invert_mod(f, self.p)
            if f_p is None:
                if verbose and attempt % 10 == 0:
                    print(f"Attempt {attempt}: f not invertible mod p")
                continue
            
            # Check invertibility mod q
            f_q = self._invert_mod(f, self.q)
            if f_q is None:
                if verbose and attempt % 10 == 0:
                    print(f"Attempt {attempt}: f not invertible mod q")
                continue
            
            # Success! Generate g
            g = self._random_ternary(self.dg, self.dg)
            
            # Compute public key: h = p * f_q * g (mod q)
            h = self.R(self.p * f_q * g)
            h_coeffs = [Integer(c) % self.q for c in h.list()]
            h = self.R(h_coeffs)
            
            # Store keys
            self.public_key = {
                'h': h,
                'N': self.N,
                'q': self.q,
                'p': self.p
            }
            
            self.private_key = {
                'f': f,
                'f_p': f_p,
                'f_q': f_q,
                'g': g,
                'N': self.N,
                'p': self.p,
                'q': self.q
            }
            
            if verbose:
                print(f"Key generation succeeded on attempt {attempt + 1}")
            
            return {
                'public': self.public_key,
                'private': self.private_key
            }
        
        raise RuntimeError(f"Failed to generate valid keys after {max_attempts} attempts")
    
    # ========== Encryption ==========

    def _random_binary(self, dm):
        """
        Generate a random binary polynomial with exactly dm coefficients = 1
        and the rest = 0.
        """
        coeffs = [Integer(0)] * self.N
        ones_positions = pyrandom.sample(range(self.N), int(dm))
        for i in ones_positions:
            coeffs[i] = Integer(1)
        return self.R(coeffs)

    
    def random_message(self, dm=None):
        # Generate a random *binary* message polynomial
        if dm is None:
            dm = self.N // 4
        return self._random_binary(dm)  # coefficients 0 or 1
    
    def encrypt(self, message, public_key=None, return_r=False):
        """
        Encrypt a message polynomial using rejection sampling.
        - message should be a small polynomial (now binary by default).
        - We loop until decryption (using our private key) would succeed.
        """
        if public_key is None:
            public_key = self.public_key or self.keygen()['public']
        h, q = public_key['h'], public_key['q']
        while True:
            r = self._random_ternary(self.dr, self.dr)  # small blinding poly
            c = self.R(r * h + message)
            c_coeffs = [Integer(cc) % q for cc in c.list()]
            c = self.R(c_coeffs)  # reduce mod q
            # If private key available, test decryption:
            if self.private_key is not None:
                try:
                    dec = self.decrypt(c)
                except Exception:
                    continue
                if self._poly_to_vector(dec) != self._poly_to_vector(message):
                    continue  # reject and try again
            # Success: either we matched, or we assume it's OK
            return (c, r) if return_r else c
    
    def decrypt(self, ciphertext, private_key=None):
        """
        Decrypt ciphertext: apply centered reductions before each mod step.
        """
        if private_key is None:
            private_key = self.private_key
        f, f_p, q, p = private_key['f'], private_key['f_p'], private_key['q'], private_key['p']
        # a = f * c mod q, then lift to (-q/2,q/2)
        a = self.R(f * ciphertext)
        a = self._center_lift(a, q)
        # b = f_p * a mod p, then lift to (-p/2,p/2)
        b = self.R(f_p * a)
        m = self._center_lift(b, p)
        return m

    
    # ========== LLL Attack ==========
    
    def construct_attack_lattice(self, public_key=None):
        """
        Construct NTRU lattice basis from public key for LLL attack.
        Adjusted to target (g, f) as short vector.
        
        h' = p^{-1} * h mod q
        Basis matrix:
            B = [ q*I_N |   0   ]
                [  H'   |  I_N ]
        
        where H' is the circulant matrix of h'.
        """
        if public_key is None:
            public_key = self.public_key
        
        h = public_key['h']
        q = public_key['q']
        p = public_key['p']
        N = public_key['N']
        
        # Compute p^{-1} mod q
        p_inv = pow(int(p), -1, int(q))
        
        # Compute h' = p^{-1} * h mod q, centered
        h_coeffs = self._poly_to_vector(h)
        h_prime_coeffs = [(int(c) * p_inv) % q for c in h_coeffs]
        h_prime = self._center_lift(self.R(h_prime_coeffs), q)
        h_prime_coeffs = self._poly_to_vector(h_prime)
        
        # Build circulant (rotation) matrix H'
        H_prime = Matrix(ZZ, N, N)
        for i in range(N):
            for j in range(N):
                H_prime[i, j] = h_prime_coeffs[(j - i) % N]
        
        # Build full basis
        I_N = Matrix.identity(ZZ, N)
        zero_N = Matrix(ZZ, N, N)
        
        # Stack: [q*I_N | 0] over [H' | I_N]
        top = (q * I_N).augment(zero_N)
        bottom = H_prime.augment(I_N)
        basis = top.stack(bottom)
        
        return basis
    
    def lll_attack(self, public_key=None, delta=0.99, verbose=False, 
                   check_all_rows=False):
        """
        LLL attack using adjusted lattice construction targeting (g, f).
        """
        if public_key is None:
            public_key = self.public_key
        if self.private_key is None:
            raise ValueError("Need private key for verification")
        
        h = public_key['h']
        q = public_key['q']
        p = public_key['p']
        N = public_key['N']
        
        # Build adjusted lattice
        basis = self.construct_attack_lattice(public_key)
        
        if verbose:
            print(f"Running BKZ (block_size=10, delta={delta}) on {2*N}×{2*N} lattice...")
        
        start_time = time.time()
        # Use BKZ for better reduction on small dimensions
        reduced_basis = basis.BKZ(block_size=10, delta=delta)
        lll_time = time.time() - start_time
        
        if verbose:
            print(f"BKZ completed in {lll_time:.4f}s")
        
        # Get true private key for comparison
        true_f_vec = self._poly_to_vector(self.private_key['f'])
        true_g_vec = self._poly_to_vector(self.private_key['g'])
        expected_norm = sqrt(sum(x**2 for x in true_g_vec + true_f_vec))
        
        if verbose:
            print(f"\nTarget: ||(g,f)|| = {float(expected_norm):.4f}")
            print("\nShortest vectors found by BKZ:")
            for i in range(min(10, reduced_basis.nrows())):
                row = list(reduced_basis[i])
                norm = sqrt(sum(x**2 for x in row))
                g_part = row[:N]
                f_part = row[N:]
                is_key = (g_part == true_g_vec and f_part == true_f_vec)
                is_neg_key = (g_part == [-x for x in true_g_vec] and f_part == [-x for x in true_f_vec])
                marker = " ← PRIVATE KEY!" if (is_key or is_neg_key) else ""
                print(f"  Row {i}: norm={float(norm):7.2f}, "
                      f"max|g|={max(abs(x) for x in g_part):3d}, "
                      f"max|f|={max(abs(x) for x in f_part):3d}{marker}")
        
        # Try to find private key
        max_rows = reduced_basis.nrows() if check_all_rows else min(100, reduced_basis.nrows())
        
        def verify_candidate(f_vec, g_vec):
            """Verify if candidate is valid private key."""
            try:
                # Structure checks
                f_max = max(abs(x) for x in f_vec)
                g_max = max(abs(x) for x in g_vec)
                
                if f_max > self.p + 2 or g_max > 1:
                    return None
                
                # Convert to polynomials
                f_cand = self.R([Integer(x) for x in f_vec])
                g_cand = self.R([Integer(x) for x in g_vec])
                
                # Check NTRU equation: f*h ≡ p*g (mod q)
                lhs = self.R(f_cand * h)
                rhs = self.R(p * g_cand)
                lhs_mod = [Integer(c) % q for c in lhs.list()[:N]]
                rhs_mod = [Integer(c) % q for c in rhs.list()[:N]]
                
                if lhs_mod != rhs_mod:
                    return None
                
                # Try to invert
                f_p_inv = self._invert_mod(f_cand, p)
                if f_p_inv is None:
                    return None
                
                # Format for decrypt
                coeffs_p = list(f_p_inv.lift().list())
                f_p_in_R = self.R([Integer(int(c) % p) for c in coeffs_p])
                
                # Build candidate key
                cand_key = {
                    'f': f_cand,
                    'f_p': f_p_in_R,
                    'N': N,
                    'p': p,
                    'q': q
                }
                
                # Test decryption
                for _ in range(3):
                    msg = self.random_message()
                    ciph = self.encrypt(msg, public_key)
                    try:
                        dec = self.decrypt(ciph, cand_key)
                        if self._poly_to_vector(dec) != self._poly_to_vector(msg):
                            return None
                    except:
                        return None
                
                return {'f': f_cand, 'g': g_cand}
                
            except:
                return None
        
        # Search reduced basis
        for i in range(max_rows):
            row = list(reduced_basis[i])
            g_vec = row[:N]
            f_vec = row[N:]
            
            # Quick checks
            f_nonzero = sum(1 for x in f_vec if x != 0)
            g_nonzero = sum(1 for x in g_vec if x != 0)
            
            if f_nonzero < N//4 or g_nonzero < N//4:
                continue
            
            # Try candidate
            result = verify_candidate(f_vec, g_vec)
            if result:
                if verbose:
                    print(f"\n✓ SUCCESS! Found private key at row {i}")
                return {
                    'success': True,
                    'recovered_f': result['f'],
                    'recovered_g': result['g'],
                    'row_index': i,
                    'time': lll_time
                }
            
            # Try negation
            result = verify_candidate([-x for x in f_vec], [-x for x in g_vec])
            if result:
                if verbose:
                    print(f"\n✓ SUCCESS! Found -private key at row {i}")
                return {
                    'success': True,
                    'recovered_f': result['f'],
                    'recovered_g': result['g'],
                    'row_index': i,
                    'time': lll_time
                }
        
        if verbose:
            print("\n✗ Attack failed - private key not recovered")
        
        return {'success': False, 'time': lll_time}
    
    # ========== Testing Utilities ==========
    
    def test_correctness(self, num_tests=10, verbose=True):
        """
        Test NTRU encryption/decryption with random binary messages.
        """
        if verbose:
            print(f"Testing NTRU correctness with {num_tests} trials")
        successes, failures = 0, []
        for i in range(num_tests):
            msg = self.random_message()
            cipher = self.encrypt(msg)
            recovered = self.decrypt(cipher)
            if self._poly_to_vector(msg) == self._poly_to_vector(recovered):
                successes += 1
            else:
                failures.append(i)
        if verbose:
            rate = 100*successes/num_tests
            print(f"Success rate: {rate.n():.1f}% ({successes}/{num_tests})")
            if failures:
                print("Failed trials:", failures)
        return {'success_rate': rate, 'successes': successes, 'failures': failures}
    
    def benchmark_lll_attack(self, deltas=[0.75, 0.85, 0.99], trials=10, verbose=True):
        """
        Benchmark LLL attack with different delta values.
        
        Returns:
            DataFrame with results
        """
        import pandas as pd
        
        results = []
        
        for delta in deltas:
            if verbose:
                print(f"\nTesting delta = {delta}")
            
            for trial in range(trials):
                # Generate fresh keys for each trial
                self.keygen()
                
                # Run attack
                result = self.lll_attack(delta=delta, verbose=False)
                
                results.append({
                    'delta': delta,
                    'trial': trial,
                    'success': result['success'],
                    'time': result['time'],
                })
                
                if verbose and trial == 0:
                    print(f"  First trial: success={result['success']}, "
                          f"time={result['time']:.4f}s")
        
        df = pd.DataFrame(results)
        
        if verbose:
            print("\n" + "="*60)
            print("Summary Statistics:")
            print("="*60)
            summary = df.groupby('delta').agg({
                'success': ['sum', 'mean'],
                'time': ['mean', 'std'],
            })
            print(summary)
        
        return df

    def __repr__(self):
        return (f"NTRU(N={self.N}, p={self.p}, q={self.q}, "
                f"df={self.df}, dg={self.dg}, dr={self.dr})")


# ========== Example Usage ==========
print("="*70)
print("TESTING NTRU LLL ATTACK")
print("="*70)
# Test with toy parameters
print("\n>>> Testing with toy parameters (N=11)")
ntru_toy = NTRU(param_set='toy')
ntru_toy.keygen(verbose=True)
print("\n>>> Running attack...")
result = ntru_toy.lll_attack(delta=0.99, verbose=True)
print(f"\n{'='*60}")
print(f"Attack result: {'SUCCESS' if result['success'] else 'FAILED'}")
print(f"{'='*60}")

# Also test with even weaker parameters
print("\n\n>>> Testing with ultra-weak parameters (N=7)")
ntru_weak = NTRU(N=7, p=3, q=17, df=2, dg=2, dr=2)
ntru_weak.keygen(verbose=True)
print("\n>>> Running attack...")
result = ntru_weak.lll_attack(delta=0.99, verbose=True)
print(f"\nAttack result: {'SUCCESS' if result['success'] else 'FAILED'}")

TESTING NTRU LLL ATTACK

>>> Testing with toy parameters (N=11)
Key generation succeeded on attempt 1

>>> Running attack...
Running BKZ (block_size=10, delta=0.990000000000000) on 22×22 lattice...
BKZ completed in 0.0812s

Target: ||(g,f)|| = 8.8882

Shortest vectors found by BKZ:
  Row 0: norm=   3.32, max|g|=  0, max|f|=  1
  Row 1: norm=   7.87, max|g|=  3, max|f|=  4
  Row 2: norm=   7.87, max|g|=  3, max|f|=  4
  Row 3: norm=   7.87, max|g|=  3, max|f|=  4
  Row 4: norm=   8.00, max|g|=  3, max|f|=  3
  Row 5: norm=   8.77, max|g|=  3, max|f|=  4
  Row 6: norm=   8.49, max|g|=  6, max|f|=  2
  Row 7: norm=   7.87, max|g|=  3, max|f|=  4
  Row 8: norm=   8.89, max|g|=  4, max|f|=  4
  Row 9: norm=   8.77, max|g|=  4, max|f|=  4

✓ SUCCESS! Found private key at row 17

Attack result: SUCCESS


>>> Testing with ultra-weak parameters (N=7)
Key generation succeeded on attempt 1

>>> Running attack...
Running BKZ (block_size=10, delta=0.990000000000000) on 334×334 lattice...
BKZ comple