In [None]:
# sagemath version 9.8
# Authors : K.A.Draziotis
# refactoring : M. Adamoudis
# Start from a clean Sage session: reset() clears variables left over from
# earlier runs in this kernel, so the cells below do not silently reuse stale state.
reset()

# Auxiliary functions

In [None]:
# Execute the companion notebook in this namespace. get_superbasis, get_qij and
# cvp_vfk are used by the_attack_VFK below but are not defined in this notebook;
# presumably they come from here — confirm against ntru-lattice-vfk-cvp.ipynb.
%run ntru-lattice-vfk-cvp.ipynb

In [None]:
def LLL_reduction_of_M_NTRU(init_M_NTRU):
    '''
    LLL-reduce a lattice basis with fpylll.

    Input : init_M_NTRU, an integer Sage matrix whose rows span the lattice.
    Output: (M_NTRU, M_NTRU_fplll) where
            - M_NTRU_fplll is the fpylll IntegerMatrix built from init_M_NTRU
              and reduced in place by LLL (delta = 0.99),
            - M_NTRU is the same reduced basis copied back into a Sage matrix.
    Prints a completion message and the elapsed time.
    '''
    import time
    from fpylll import IntegerMatrix, LLL

    def fptosage(A):
        # copy an fpylll IntegerMatrix row by row into a Sage integer matrix
        # of the same shape (nrows x ncols, so non-square bases work too)
        C = matrix(ZZ, A.nrows, A.ncols)
        for i in range(A.nrows):
            C[i] = list(A[i])
        return C

    start = time.time()
    M_NTRU_fplll = IntegerMatrix.from_matrix(init_M_NTRU)
    LLL.reduction(M_NTRU_fplll, delta=0.99)
    M_NTRU = fptosage(M_NTRU_fplll)
    print("LLL is done")
    # print the elapsed time itself (a Python-2 style `print(...), expr`
    # would only print the label and discard the value)
    print("time for LLL:", time.time() - start)
    return M_NTRU, M_NTRU_fplll

def find_k_and_P(q, k_max=200):
    '''
    Compute the parameters (k, P) used to set up the VFK lattice.

    For k = 1, 2, ... let
        P_k   = floor(k*q / (k^2 + 1))
        g_k(X) = -(1+k^2) X^2 + (-k^2 + 2*k*q - 1) X + q*(k - q).
    Find the first k with g_k(P_k) > 0 and return the previous pair
    (k-1, P_{k-1}).  The result is also printed as "q : k-1 P".

    Input : q     -- positive integer (the modulus q^exponent of the scheme)
            k_max -- search bound; k ranges over 1 .. k_max-1
                     (default 200, the previously hard-coded range)
    Output: (k-1, P_{k-1})
    Raises: ValueError if no k in the search range satisfies g_k(P_k) > 0
            (previously None was returned and the caller failed on unpacking).
    '''
    def _P(k):
        return floor(k*q / (k^2 + 1))

    def _g(k, X):
        # evaluate the quadratic directly instead of building a polynomial in
        # the global generator x, so this works whatever x is bound to
        return -(1 + k^2)*X^2 + (-k^2 + 2*k*q - 1)*X + q*(k - q)

    for k in range(1, k_max):
        if _g(k, _P(k)) > 0:
            kappa, P = k - 1, _P(k - 1)
            print(q, ":", kappa, P)
            return kappa, P
    raise ValueError("no k < %s with g_k(P_k) > 0 for q = %s" % (k_max, q))

def unimodular(P,N):
    '''
    Return the 2N x 2N integer matrix

        [ I_N     0   ]
        [ P*I_N   I_N ]

    N, P: positive integers (see the paper).  It is block lower triangular
    with identity diagonal blocks, hence unimodular.
    '''
    eye = identity_matrix(N)
    block_rows = [
        block_matrix([[eye, matrix(N)]]),   # [ I    0 ]
        block_matrix([[P*eye, eye]]),       # [ P*I  I ]
    ]
    return block_matrix([[row] for row in block_rows])

# Key Generation

In [None]:
# key generation
# Zx = ZZ[x], the integer polynomial ring for all NTRU polynomials below.
# Its generator x is read as a global by Convolution_in_R, Convolution_in_R_p,
# Invertmodprime and matrix_for_the_lattice.
Zx.<x> = ZZ[]
def T(d1,d2,N):
    '''
    Random ternary polynomial: d1 coefficients equal to 1, d2 equal to -1 and
    the remaining N-d1-d2 equal to 0, placed in random order.

    Returns (L, poly): L is the coefficient list of length N (constant term
    first) and poly = L as an element of ZZ[x].
    Raises ValueError unless d1, d2 >= 0 and d1 + d2 <= N (an oversized
    request used to silently produce more than N coefficients).
    '''
    import random
    if d1 < 0 or d2 < 0 or d1 + d2 > N:
        raise ValueError("need d1, d2 >= 0 and d1 + d2 <= N, got d1=%s, d2=%s, N=%s" % (d1, d2, N))
    R = PolynomialRing(ZZ, 'x')   # the same (unique) ring as the notebook's Zx
    L = [1]*d1 + [-1]*d2 + [0]*(N - d1 - d2)
    random.shuffle(L)
    return L, R(L)

# generation of private keys (f,g)
def private_keys(N,d):
    '''
    Draw the NTRU private key pair (f, g) of length-N ternary polynomials:
      f has d+1 coefficients equal to 1 and d equal to -1,
      g has d coefficients equal to 1 and d equal to -1.
    Only the polynomials are returned; T's coefficient lists are discarded.
    '''
    _, f_poly = T(d + 1, d, N)
    _, g_poly = T(d, d, N)
    return f_poly, g_poly
    
# compute the inverses of f in R_p and R_q

def CenterLift(f,q,N):
    '''
    Centered lift modulo q of the first N coefficients of f.

    Each coefficient c becomes ((c + q//2) mod q) - q//2, i.e. a value in
    [-(q//2), q - 1 - q//2].  The result is returned as an element of Zx.
    '''
    half = q//2
    centered = [((f[i] + half) % q) - half for i in range(N)]
    return Zx(centered)

def Invertmodprime(f,p,N):
    '''
    Inverse of f in R_p = (Z/pZ)[x]/(x^N - 1), lifted back to Zx.

    p must be prime.
    Raises ValueError if p is not prime.  Previously the string "error" was
    returned; no caller checked for it, so it failed later in unrelated code.
    When f is not invertible in R_p the inversion is expected to raise
    ZeroDivisionError; gen_keys catches exactly that exception.
    '''
    if not is_prime(p):
        raise ValueError("p must be prime, got %s" % p)
    R_p = Zx.change_ring(Integers(p)).quotient(x^N - 1)
    return Zx(lift(1/R_p(f)))
    
def Invertmodpowerofprime(f,q,e,N):
    '''
    Inverse of f in R_m = (Z/mZ)[x]/(x^N - 1) with m = q^e, q prime, e >= 1.

    Start from the inverse F of f modulo q and apply the Newton/Hensel step
        F <- F*(2 - f*F)
    which doubles the q-adic precision each time: if f*F = 1 mod q^k, the
    new F satisfies f*F = 1 mod q^(2k).  Once the precision reaches at least
    q^e, the result is reduced modulo q^e, so its coefficients are in
    [0, q^e - 1].  (Previously it was only reduced modulo q^(2^t) for the
    last precision reached, which can exceed q^e.)

    Raises ValueError if e < 1.  Invertmodprime's exceptions (e.g. f not
    invertible modulo q) propagate unchanged.
    '''
    if e < 1:
        raise ValueError("the exponent e must be >= 1, got %s" % e)
    F = Invertmodprime(f,q,N)
    precision = 1   # invariant: f*F = 1 in R_{q^precision}
    while precision < e:
        precision *= 2
        F = Convolution_in_R_p(F, 2 - Convolution_in_R(F,f,N), N, q^precision)
    m = q^e
    return Zx([F[i] % m for i in range(N)])
    
def Convolution_in_R(f,g,N):
    '''
    Product of f and g in R = Z[x]/(x^N - 1): the ordinary product reduced
    modulo x^N - 1 (x is the global generator of Zx).
    '''
    modulus = x^N - 1
    return (f*g) % modulus

def Convolution_in_R_p(f,g,N,p):
    '''
    Product of f and g in Z[x]/(x^N - 1), with each of the N coefficients
    then reduced modulo p.  p need not be prime (it is also called with q^e).
    The result is returned as an element of Zx.
    '''
    reduced = (f*g) % (x^N - 1)
    return Zx([reduced[i] % p for i in range(N)])
        
def gen_keys(N,d,p,q,e,max_attempts=1): # N,p primes
    '''
    Generate an NTRU key pair.

    Input : N, d, p, q, e -- ring degree, private-key weight parameter, small
                             modulus p, prime q and exponent e (the large
                             modulus is q^e)
            max_attempts  -- how many fresh private keys (f, g) to draw when f
                             is not invertible (default 1: a single attempt,
                             as before)
    Output: (f, g, h) with h = Fq*g in R_{q^e}, where Fq is the inverse of f
            modulo q^e.  Returns None if no invertible f was found within
            max_attempts; a message is printed for each failed attempt.
    '''
    for _ in range(max_attempts):
        f,g = private_keys(N,d)
        # f must be invertible both modulo q and modulo p
        try:
            Invertmodprime(f,q,N)
        except ZeroDivisionError:
            print("Oops! there is not inverse of f in R_q")
            continue
        try:
            Invertmodprime(f,p,N)
        except ZeroDivisionError:
            print("Oops! there is not inverse of f in R_p")
            continue
        Fq = Invertmodpowerofprime(f,q,e,N)
        h = Convolution_in_R_p(Fq,g,N,q^e) # public key
        print("checking if f is inverted modq...",Convolution_in_R_p(f,Fq,N,q^e)==1)
        return f,g,h
    return None
    
# some auxiliary functions

def matrix_for_the_lattice(N,q,exponent,h):
    '''
    Build the 2N x 2N NTRU lattice basis

        [ I_N   H     ]
        [ 0     m*I_N ]     with m = q^exponent,

    where row i of H holds the N coefficients of h*x^i mod (x^N - 1),
    i.e. the i-th cyclic rotation of the coefficients of h.
    '''
    m = q^(exponent)
    H = matrix(N)
    I = identity_matrix(N)
    Zero_Matrix = matrix(N)

    for i in range(N):
        # compute the product once per row rather than once per entry
        # (N polynomial products instead of N^2)
        row = Convolution_in_R(h,x^i,N)
        for j in range(N):
            H[i,j] = row[j]

    B_1 = block_matrix([[I,H]])
    B_2 = block_matrix([[Zero_Matrix,m*I]])
    M_NTRU = block_matrix([[B_1],[B_2]])
    return M_NTRU

def LOG2(x):
    '''
    Numerical base-2 logarithm of x, using the convention LOG2(0) = 0.
    '''
    return 0 if x == 0 else log(x,2).n()
        
def bits(L):
    '''
    For every value v in L return floor(log2(v)) + 1, which for a positive
    integer v is its bit length.  Presumably expects positive values
    (log of 0 or of a negative number is not a finite real).
    '''
    return [floor(log(value,2)) + 1 for value in L]

In [None]:
### Generate a random message
def msg(p, N=None):
    '''
    Random message polynomial in Zx with N coefficients, each drawn with
    randint from [ceil(-p/2), floor(p/2)] (for p = 3 this is {-1, 0, 1}).

    N defaults to the notebook-global N, which this function used to read
    implicitly; pass it explicitly to avoid depending on that hidden state.
    '''
    if N is None:
        N = globals()["N"]
    left = ceil(-p/2)
    right = floor(p/2)
    return Zx([randint(left,right) for _ in range(N)])

### Encryption
def enc(N,p,q,exponent,d,msg,h=None):
    '''
    NTRU encryption of the message polynomial msg:

        e = (p*r*h mod q^exponent) + msg

    with a fresh ephemeral key r drawn as T(d,d,N).

    h is the public key.  It defaults to the notebook-global h, which this
    function used to read implicitly.
    Note: msg is added after the modular reduction, so coefficients of e may
    fall outside [0, q^exponent - 1] by up to the size of msg's coefficients.
    '''
    if h is None:
        h = globals()["h"]
    r = T(d,d,N) # ephemeral key: (coefficient list, polynomial)
    e1 = Convolution_in_R_p(h,p*r[1],N,q^exponent)
    return e1 + msg

### Decryption
def decryption(N,p,q,exponent,d,e,msg,f=None):
    '''
    NTRU decryption of e, checked against the expected message msg:

        a   = CenterLift(f*e mod q^exponent)
        dec = CenterLift(Fp*a mod p),  Fp the inverse of f modulo p

    Prints and returns whether dec == msg.  d is unused and kept for
    interface compatibility.  f is the private key.  It defaults to the
    notebook-global f, which this function used to read implicitly.
    '''
    if f is None:
        f = globals()["f"]
    m = q^exponent
    a = CenterLift(Convolution_in_R_p(f,e,N,m), m, N)
    Fp = Invertmodprime(f,p,N)
    b = Convolution_in_R_p(Fp,a,N,p)
    dec = CenterLift(b,p,N)
    print("Is decryption valid?",dec==msg)   # we check if we find the message m(x) after decryption
    return dec==msg

In [None]:
# step 1-4

def initial_param(N,q,exponent,y,alpha_vector,e=None):
    '''
    Input : (N,q,exponent,y,alpha_vector[,e])

    Output: (A,k,alpha_vector,M_NTRU,Blist)
    Where,
    - A is the polynomial Zx(alpha_vector),
    - k is the constant (N-1)/2,
    - alpha_vector is the vector that defines the matrix L_{a} (returned unchanged),
    - M_NTRU is the NTRU lattice matrix built from A, used in the attack,
    - Blist is the list of the N coefficients of B = A*e mod q^exponent,
      zero-padded to length N.

    e is the ciphertext.  It defaults to the notebook-global e, which this
    function used to read implicitly.  y is only used by the commented-out
    alternative choices of alpha_vector below.
    '''
    if e is None:
        e = globals()["e"]
    # step 1
    m = q^(exponent)
    k = (N-1)/2
    # step 2: alpha_vector is supplied by the caller.  Choices used in earlier
    # experiments:
    #   N = 677, 557: [i for i in range(-k,0)] + [i for i in range(1,k+1)] + [floor(N*m^(1/y)) + 1]
    #   random:       [randint(-k,k) for i in range(N-1)] + [floor(N*m^(1/y)) + 1]
    #   N = 509, q = 2048: loaded with numpy.loadtxt(file_a, dtype='int')
    A = Zx(alpha_vector)

    # step 3
    M_NTRU = matrix_for_the_lattice(N,q,exponent,A)

    # step 4
    B = Convolution_in_R_p(A,e,N,m)
    # coefficients(sparse=False) drops trailing zero coefficients, which made
    # Blist shorter than N (target_vector then failed); pad to length N
    Blist = B.padded_list(N)

    return A,k,alpha_vector,M_NTRU,Blist

In [None]:
def corrections(N,q,exponent,p,h,A,r,M=None):
    '''
    Compute the correction vector C and the message vector used by the attack.

        C = (-p*A) * r * h  mod q^exponent

    r is the pair returned by T; r[1] is the ephemeral polynomial.
    Returns (C_vector, M_vector): the N coefficients of C and of the message
    M, each zero-padded to length N.  M defaults to the notebook-global
    message M, which this function used to read implicitly.
    Prints the list lengths before and after padding.
    '''
    if M is None:
        M = globals()["M"]
    m = q^exponent
    C1 = Convolution_in_R_p(-p*A,r[1],N,m)
    C = Convolution_in_R_p(C1,h,N,m)
    # E must be a "good" approximation of C_vector in order to get a successful attack
    print(len(M.coefficients(sparse=False)),len(C.coefficients(sparse=False)))
    # coefficients(sparse=False) drops trailing zeros; pad both lists to length N
    M_vector = M.padded_list(N)
    C_vector = C.padded_list(N)
    print(len(M_vector),len(C_vector))
    return C_vector,M_vector
    

In [None]:
# step 5
# Our oracle, which on each call returns an approximation of u(x)

def oracle(N,q,exponent,Range,M_vector,C_vector,Blist):
    '''
    Simulated oracle (step 5).

    Returns E = (0_N, C_vector + noise) as a list of length 2N, where every
    noise coefficient is drawn with randint(-Range, Range), so that
    |E[N+i] - C_vector[i]| <= Range.
    q, exponent, M_vector and Blist do not enter the computation; they are
    kept for interface compatibility.
    '''
    noise = vector([randint(-Range,Range) for _ in range(N)])
    return [0]*N + list(vector(C_vector) + noise)

def target_vector(N,Blist,E):
    '''
    Target vector t = (0_N, b) + E for the CVP step, returned as a list.

    Blist holds the coefficients of b.  A shorter list (trailing zero
    coefficients dropped) is zero-padded to length N.  Previously such input
    left t unbound and raised UnboundLocalError.
    Raises ValueError if Blist has more than N entries.
    '''
    if len(Blist) > N:
        raise ValueError("Blist has %s entries, expected at most N = %s" % (len(Blist), N))
    b = list(Blist) + [0]*(N - len(Blist))
    t = vector([0]*N + b) + vector(E)
    return list(t) # the target vector t = (0_N,b) + E
    

In [None]:
def _fptosage(A):
    '''
    Copy an fpylll IntegerMatrix A into a Sage integer matrix of the same
    shape (A.nrows x A.ncols).
    '''
    C = matrix(ZZ, A.nrows, A.ncols)
    for i in range(A.nrows):
        C[i] = list(A[i])
    return C


def _hits(L,M):
    '''
    Compare L and M position by position over the first len(L) entries.
    Print the percentage of agreeing positions and the number of differing ones.
    '''
    Len = len(L)
    Len1 = Integer(len([i for i in range(Len) if L[i]!=M[i]]))
    percentage = ((Len-Len1)/Len) * 100
    print("percentage:",float(percentage))
    print("the lists differs in ",Len1," elements")


def _print_attack_parameters(N,d,p,q,exponent,Range):
    '''Print the parameter banner shared by both attacks.'''
    print("N=",N)
    print("d=",d)
    print("p=",p)
    print("q,e,q^e=",q,exponent,q^exponent)
    print("range : |e_i-c_i|<= ",Range)


def the_attack(N,q,exponent,p,h,A,r,Range,Blist,init_M_NTRU,M_NTRU1,counts,flag,M_vector,C_vector,d=None,y=None):
    '''
    Attack using Babai's nearest plane algorithm (fpylll GSO.Mat.babai).

    Up to `counts` times: ask the oracle for an approximation E, build the
    target t = (0_N, Blist) + E, run Babai on the LLL-reduced basis and check
    whether the first N coordinates of the lattice vector found equal M_vector.

    flag = 1: LLL-reduce init_M_NTRU here (LLL_reduction_of_M_NTRU).
    flag = 2: M_NTRU1 is an fpylll IntegerMatrix that was already reduced
              in fpylll (useful for large N, where sagemath raises errors).
    Any other flag raises ValueError; previously it failed later with an
    unbound M_NTRU.

    h, A, r are not used.  d and y are only printed and default to the
    notebook globals of the same name.  "time for babai" is measured from the
    start of the current iteration, not from the start of the attack (which
    also included the LLL reduction).
    '''
    from fpylll import GSO
    import time

    if flag not in (1, 2):
        raise ValueError("flag must be 1 or 2, got %s" % flag)
    if d is None:
        d = globals().get("d")
    if y is None:
        y = globals().get("y")

    _print_attack_parameters(N,d,p,q,exponent,Range)
    print("y=",y)
    start = time.time()
    if flag == 1:
        M_NTRU,M_NTRU_fplll = LLL_reduction_of_M_NTRU(init_M_NTRU)
    else:
        # flag == 2: the NTRU matrix has already been reduced in fpylll,
        # because sagemath raises errors for large values of N
        M_NTRU_fplll = M_NTRU1
        M_NTRU = _fptosage(M_NTRU_fplll)

    M_GSO = GSO.Mat(M_NTRU_fplll)
    M_GSO.update_gso()

    for i in range(counts):
        start1 = time.time()
        print("\n",i+1)
        print("=======")

        # the oracle
        E = oracle(N,q,exponent,Range,M_vector,C_vector,Blist)
        # the target vector
        t = target_vector(N,Blist,E)
        # Babai (fpylll): L holds the coordinates w.r.t. the reduced basis
        L = M_GSO.babai(t)
        w = sum(L[j]*M_NTRU[j] for j in range(M_NTRU_fplll.nrows)).list()

        print("babai done")
        print("time for babai:",time.time()-start1)
        recovered = list(w[0:N])
        success = recovered == M_vector
        print("success/fail:",success)
        _hits(recovered,M_vector)
        if success:
            print("total time for the loop:",time.time()-start1)
            break
    print("total time for the attack:",time.time()-start)


def the_attack_VFK(N,q,exponent,p,h,A,r,Range,Blist,init_M_NTRU,counts,P,M_vector,C_vector,d=None):
    '''
    Same attack as the_attack, but with the exact CVP routine cvp_vfk.

    The basis is unimodular(P,N)*init_M_NTRU.  Its superbasis and Selling
    parameters (get_superbasis, get_qij) are computed once, outside the loop.
    get_superbasis, get_qij and cvp_vfk are not defined in this notebook
    (presumably provided by the %run companion notebook).

    h, A, r are not used.  d is only printed and defaults to the notebook
    global d.  "time for CVP" is measured from the start of the current iteration.
    '''
    import time

    if d is None:
        d = globals().get("d")

    _print_attack_parameters(N,d,p,q,exponent,Range)
    start = time.time()

    # computation of the superbasis and Selling parameters (outside the loop)
    M_NTRU_VFK = unimodular(P,N)*init_M_NTRU
    basis = M_NTRU_VFK.rows()
    super_basis = get_superbasis(basis)
    Q = get_qij(super_basis)

    for i in range(counts):
        start1 = time.time()
        print("\n",i+1)
        print("=======")

        # the oracle
        E = oracle(N,q,exponent,Range,M_vector,C_vector,Blist)
        # the target vector
        t = target_vector(N,Blist,E)
        # exact CVP (VFK) in place of Babai.  The lattice must be set up with
        # the vector a = (-k,0,...,0) beforehand (see alpha_vector_vfk).
        w = cvp_vfk(N,Q,basis,super_basis,t)
        print("CVP done")
        print("time for CVP:",time.time()-start1)

        recovered = list(w[0:N])
        success = recovered == M_vector
        print("success/fail:",success)
        _hits(recovered,M_vector)
        if success:
            print("total time for the loop:",time.time()-start1)
            break
    print("total time for the attack:",time.time()-start)

In [None]:
# Parameter sets (N, d, p, q, exponent); the large modulus is q^exponent.
# Examples used previously are kept commented out.  Uncomment one to use it
# instead of the `experiment` switch below.

#N,d,p,q,exponent = nth_prime(10),11,3,2,8
#N,d,p,q,exponent = nth_prime(45),11,3,2,11 # nth_prime(45) = 197
#N,d,p,q,exponent = nth_prime(45),11,3,2,9  
#N,d,p,q,exponent = nth_prime(50),11,3,2,8
#N,d,p,q,exponent = nth_prime(46),41,3,2,8
#N,d,p,q,exponent = nth_prime(52),71,3,2,8
#N,d,p,q,exponent = nth_prime(55),91,3,2,8
#N,d,p,q,exponent = nth_prime(57),99,3,2,8
#N,d,p,q,exponent = nth_prime(61),99,3,2,10
#N,d,p,q,exponent = nth_prime(40),5,3,nth_prime(40),1
#N,d,p,q,exponent = nth_prime(51),11,3,2,9
#N,d,p,q,exponent  = 197,10,3,197,1
#N,d,p,q,exponent  = 509,10,3,2,11
#N,d,p,q,exponent  = 677,20,3,2,11
#N,d,p,q,exponent  = 31,6,3,2,7

#N,d,p,q,exponent = nth_prime(63),15,3,2,10
#f,g,h=gen_keys(N,d,p,q,exponent)
#[f,g];h;[N,d,p,q,exponent]
p = 3
# selects which parameter set below defines N, q, exponent and d; a value
# other than 1, 2 or 3 leaves them undefined and later cells fail
experiment = 2

if experiment==1:
    N,q,exponent = 239,2,8
    d = 71
if experiment==2:
    N,q,exponent = 509,2,11 
    d = 10
# NOTE(review): experiment 3 currently repeats the parameters of experiment 2;
# confirm whether a different set was intended.
if experiment==3:
    N,q,exponent = 509,2,11 
    d = 10

In [None]:
# Key generation.  gen_keys returns None when f is not invertible modulo q
# or p, so retry with fresh keys a bounded number of times instead of failing
# with a TypeError when unpacking None.
MAX_KEYGEN_ATTEMPTS = 10
keys = None
for _attempt in range(MAX_KEYGEN_ATTEMPTS):
    keys = gen_keys(N,d,p,q,exponent)
    if keys is not None:
        break
if keys is None:
    raise RuntimeError("no invertible private key f after %s attempts" % MAX_KEYGEN_ATTEMPTS)
f,g,h = keys

# we choose a random message M with coefficients in [ceil(-p/2), floor(p/2)]
left = ceil(-p/2)
right = floor(p/2)
M = Zx([randint(left,right) for i in range(N)])
#print("the message M:",M) # the message
r = T(d,d,N) # ephemeral key: (coefficient list, polynomial)
#show("the ephemeral key r:",r)

### Encryption: e = (p*r*h mod q^exponent) + M
e1 = Convolution_in_R_p(h,p*r[1],N,q^exponent)
e = e1 + M

### Decryption: a = CenterLift(f*e mod m), dec = CenterLift(Fp*a mod p)
m = q^exponent
a = Convolution_in_R_p(f,e,N,m)
a = CenterLift(a,m,N)
Fp = Invertmodprime(f,p,N)
b = Convolution_in_R_p(Fp,a,N,p)
dec = CenterLift(b,p,N)
print("is decryption correct?",dec==M)   # we check if we find the message m(x)

In [None]:
# this is the first attack.

# y is passed to initial_param (only its commented-out alpha-vector choices
# use it) and is printed by the_attack
y = 2.3
#alpha_vector_classic = [randint(0,1) for i in range(N-1)] + [floor(N*m^(1/y))] # here we choose our alpha vector (type:list)
#alpha_vector = random.shuffle(alpha_vector[:N-1])
# NOTE(review): random.shuffle shuffles in place and returns None, so the
# line above would bind alpha_vector to None if it were uncommented.
# (kappa, P) parameterize the VFK lattice; see find_k_and_P
kappa,P = find_k_and_P(q^exponent)
# VFK alpha vector a = (-kappa, 0, ..., 0) of length N
alpha_vector_vfk = [-kappa] + [0 for i in range(N-1)] 
# NOTE(review): the classic attack currently reuses the VFK vector, so both
# initial_param calls (and both corrections calls) compute the same data.
alpha_vector_classic = alpha_vector_vfk
A_1,k_1,vector_a_1,init_M_NTRU,Blist_1     = initial_param(N,q,exponent,y,alpha_vector_classic) 
A_2,k_2,vector_a_2,init_M_NTRU_vfk,Blist_2 = initial_param(N,q,exponent,y,alpha_vector_vfk)
C_vector_1,M_vector_1 = corrections(N,q,exponent,p,h,A_1,r)
C_vector_2,M_vector_2 = corrections(N,q,exponent,p,h,A_2,r)
# define the range R and counts:
# Range bounds the oracle noise (each noise coefficient is randint(-Range, Range));
# counts is the maximum number of oracle calls / CVP attempts per attack
Range  = 32
counts = 100

In [None]:
# using Babai
# flag = 1 makes the_attack LLL-reduce init_M_NTRU itself; M_NTRU1 (an
# already reduced fpylll matrix) is only read when flag == 2
M_NTRU1 = []
flag = 1
the_attack(N,q,exponent,p,h,A_1,r,Range,Blist_1,init_M_NTRU,M_NTRU1,counts,flag,M_vector_1,C_vector_1)

In [None]:
# using exact CVP
# VFK variant: cvp_vfk on the basis unimodular(P,N)*init_M_NTRU_vfk
the_attack_VFK(N,q,exponent,p,h,A_2,r,Range,Blist_2,init_M_NTRU_vfk,counts,P,M_vector_2,C_vector_2)