In [None]:
# sagemath version 9.8
# Authors : K.A.Draziotis
# credits : https://latticehacks.cr.yp.to/ntru.html
# Start from a clean session: Sage's reset() removes user-defined variables.
reset()

# Auxiliary functions

In [None]:
# Execute the companion notebook. Helpers that are used below but not defined
# in this notebook (e.g. correction_of_msg, LLL_reduction_of_M_NTRU) are
# presumably provided by it -- TODO confirm.
%run auxiliary.ipynb

 We set, $\Phi_1(x)=x-1$ and  $\Phi_N(x) = D(x)/\Phi_1(x).$<br>
$N$ is prime, and $q\le 16N/3+16$ which is a power of two,
 $D(x)=x^N-1.$<br><br>
`GenKey(seed)`. Takes as input a *seed* and outputs a pair
`(pk,sk)`=$({\textbf h},({\textbf f},{\textbf f}_3,{\textbf h}_q,{\mathbb{S}})).$\
\
`1.` $(f(x),g(x))\xleftarrow{\$} {\mathcal{L}}_f\times {\mathcal{L}}_g$
\# `here we use the seed`\
`2.` $f_q(x)\leftarrow f^{-1}(x)\mod{(q,\Phi_N(x))}$ \# $\Phi_N(x)=x^{N-1}+\cdots+x+1$\
`3.` $f_3(x)\leftarrow f^{-1}(x)\mod{(3,\Phi_N(x))}$\
`4.` $h(x) \leftarrow 3g(x)*f_q(x)\mod{(q,D(x))}$  $\# D(x)=x^N-1$\
`5.` $h_q(x) \leftarrow h^{-1}(x)\mod{(q,\Phi_N(x))}$ \
`6.` ${\mathbb{S}} \xleftarrow{\$} \{0,1\}^{256}$ \
`7.` **return**
 $(pk,sk)=({\textbf h},({\textbf f},{\textbf f}_3,{\textbf h}_q,{\mathbb{S}}))$


We also need to define the sample sets.<br>
With ${\mathcal{T}}_a$ we denote the set of **ternary** polynomials of
${\mathcal{R}}={\mathbb{Z}}[x]/(x^a-1)$ with degree at most $a-1$, and
$\mathcal{T}_a(d_1,d_2)\subset {\mathcal{T}}_{a}$ consists of the elements
of ${\mathcal{T}}_{a}$ with $d_1$ coefficients equal to $1$ and $d_2$
equal to $-1.$ Furthermore, with
${\mathcal{T}}_a(w), w\in{\mathbb{Z}}_{>0}$ we denote the ternary
polynomials which have $w$ non-zero coefficients. We also have four sample spaces,
${\mathcal{L}}_f, {\mathcal{L}}_g, {\mathcal{L}}_r,$ and ${\mathcal{L}}_m.$\
$-$ ${\mathcal{L}}_m={\mathcal{L}}_g=
{\mathcal{T}}_{N-2}(\frac{q}{16}-1,\frac{q}{16}-1),$\
$-$ ${\mathcal{L}}_f={\mathcal{L}}_r={\mathcal{T}}_{N-2}.$


# Key Generation/Encrypt/Decrypt

In [None]:
def random32():
    """Return a random integer n with -2^31 <= n < 2^31 (a signed 32-bit value)."""
    bound = 2**31
    return randrange(-bound, bound)

def randomrange3():
    '''
    Draw a random integer from {0, 1, 2}.

    One sample of the 32-bit generator random32() is masked down to its
    low 30 bits u (0 <= u < 2^30); the result is floor(3*u / 2^30).
    '''
    low30 = random32() & 0x3fffffff
    scaled = low30 * 3
    return scaled >> 30


# We generate a random ternary polynomial of degree at most N with exactly
# d1 coefficients equal to 1 and d2 coefficients equal to -1.
def T(d1,d2,N):
    """
    Sample a ternary polynomial with prescribed weights.

    Input
    -----
    d1 : number of coefficients equal to  1 (an integer value; an exact
         rational such as q/16-1 is accepted and converted to an integer)
    d2 : number of coefficients equal to -1 (same convention as d1)
    N  : maximal degree, so the coefficient list has length N+1

    Output
    ------
    (L, poly) where L is the shuffled coefficient list of length N+1 and
    poly is the corresponding element of ZZ[x].
    If the weights do not fit (negative, or d1+d2 > N+1) the function
    prints "error" and returns (None, None).
    """
    import random
    # The callers pass weights such as q/16-1, which Sage evaluates to a
    # Rational; convert explicitly so they can be used as repetition counts.
    d1, d2, N = ZZ(d1), ZZ(d2), ZZ(N)
    # Validate the weights that were actually requested.
    if d1 < 0 or d2 < 0 or d1 + d2 > N + 1:
        print("error")
        return None, None
    Zx = ZZ['x']
    ones = int(d1)
    minus_ones = int(d2)
    zeros = int(N + 1 - d1 - d2)  # the length must be N+1 to get degree at most N
    L = ones*[1] + minus_ones*[-1] + zeros*[0]
    random.shuffle(L)
    return L, Zx(L)
    
    

def CenterLift(f,q,N):
    """
    Center-lift the first N coefficients of f modulo q.

    Each coefficient c = f[i], 0 <= i < N, is replaced by the representative
    ((c + q//2) % q) - q//2, which lies in [-(q//2), q - q//2).
    The result is returned as an element of the global ring Zx.
    """
    half = q//2
    centered = [((f[k] + half) % q) - half for k in range(N)]
    return Zx(centered)

def reduce_mod_PhiN_and_modq(f):
    """
    Reduce f in the global quotient ring S (Zx modulo Phi_N), lift the result
    back to Zx and take it modulo the global q.
    """
    residue = S(f).lift()
    return Zx(residue) % q

def reduce_mod_DN_and_modq(f):
    """
    Reduce f in the global quotient ring R (Zx modulo D = x^N-1), lift the
    result back to Zx and take it modulo the global q.

    If f = G*F this gives the product of G and F in Zx/(q, x^N-1).
    """
    residue = R(f).lift()
    return Zx(residue) % q

def reduce_mod_DN(f):
    """
    Reduce f in the global quotient ring R (Zx modulo D = x^N-1) and lift the
    result back to Zx. No reduction modulo q is performed.

    If f = G*F this gives the product of G and F in Zx/(x^N-1).
    """
    residue = R(f).lift()
    return Zx(residue)

def Convolution_in_S_q(f,g,N,q):
    """
    Multiply f and g modulo Phi_N(x) = 1 + x + ... + x^(N-1) and then take
    the result modulo q.

    Uses the global generator x to build Phi_N and the global ring Zx for
    the returned polynomial.
    """
    cyclotomic = sum(x^k for k in range(N))
    product = (f*g) % cyclotomic
    return Zx(product) % q

def Convolution_in_S(f,g,N):
    """
    Multiply f and g modulo Phi_N(x) = 1 + x + ... + x^(N-1); the
    coefficients are NOT reduced modulo q.

    Uses the global generator x to build Phi_N and the global ring Zx for
    the returned polynomial.
    """
    cyclotomic = sum(x^k for k in range(N))
    product = (f*g) % cyclotomic
    return Zx(product)

def Invertmodprime_and_phiN(f,p,N):        #p must be prime
    """
    Invert f modulo (p, Phi_N(x)) where Phi_N(x) = 1 + x + ... + x^(N-1).

    Input
    -----
    f : polynomial to invert
    p : modulus, must be prime
    N : defines Phi_N

    Output
    ------
    The inverse lifted back to the global ring Zx, or the string "error"
    when p is not prime. Whatever exception the quotient ring raises for a
    non-invertible f is propagated unchanged.
    """
    if not is_prime(p):
        return "error"
    # Build Phi_N from the argument N and use THIS polynomial for the quotient,
    # rather than a session-level PhiN that may belong to a different N.
    phiN = sum(x^i for i in range(N))
    # Named quotient_ring (not T) so the sampler function T is not shadowed.
    quotient_ring = Zx.change_ring(Integers(p)).quotient(phiN)
    return Zx(lift(1/quotient_ring(f)))

def invert_mod_2_and_phiN(f):
    """
    Invert f in the global quotient ring S2 (GF(2)[x] modulo Phi_N) and
    return the inverse lifted to the global ring Zx.
    """
    inverse = 1/S2(f)
    return Zx(inverse.lift())

def Invertmodpowerofprime_mod_PhiN(f,Q,e,N): # Q prime, e: exponent
    """
    Lift an inverse of f modulo Phi_N from a prime modulus to a prime power.

    The starting inverse always comes from invert_mod_2_and_phiN, i.e. it is
    computed modulo 2 whatever Q is. It is then refined with the update
    F <- F*(2 - f*F), taken modulo Q^2, Q^4, Q^8, ...; the remaining
    exponent is halved after each step until it reaches 0.
    For e == 1 the inverse modulo 2 is returned directly.

    The result is not reduced modulo Q^e; the caller does that.
    """
    inverse = invert_mod_2_and_phiN(f)
    if e == 1:
        return inverse
    lift_exponent = 2
    remaining = e
    while remaining > 0:
        residual = Convolution_in_S(inverse, f, N)
        inverse = Convolution_in_S_q(inverse, 2-residual, N, Q^lift_exponent)
        remaining = floor(remaining/2)
        lift_exponent *= 2
    return inverse

def Invert_mod3_and_PhiN(f):
    """
    Return the inverse of f in the global quotient ring S3
    (GF(3)[x] modulo Phi_N). The result is an element of S3, not lifted.
    """
    image = S3(f)
    return 1/image

def private_keys(N):
    """
    Sample the private polynomials (f, g).

    f : ternary with N-1 random coefficients (degree at most N-2), resampled
        until its image in the global ring S3 is a unit.
    g : drawn with T(d, d, N-2) where d = q/16 - 1 and q is the global modulus.

    Returns the pair (f, g) of polynomials.
    """
    def sample_f():
        return Zx([randomrange3()-1 for k in range(N-1)])

    # L_f = T_{N-2}: keep sampling until f is invertible modulo (3, Phi_N)
    f = sample_f()
    while not S3(f).is_unit():
        f = sample_f()
    # L_g: d coefficients equal to 1 and d equal to -1, with d = q/16-1
    d = q/16-1
    g_sample = T(d,d,N-2)
    return f, g_sample[1]

def gen_keys(N,q): 
    """
    Generate an NTRU key pair.

    Input
    -----
    N : prime defining D = x^N-1 and Phi_N = D/(x-1)
    q : modulus, a power of two (q = 2^exponent)

    Output
    ------
    (h, f, f3, hq, g, fq) where
      h  : public key, 3*g*fq reduced modulo (q, x^N-1)
      f  : private ternary polynomial
      f3 : inverse of f modulo (3, Phi_N), lifted to Zx
      hq : inverse of h modulo (q, Phi_N)
      g  : private ternary polynomial
      fq : inverse of f modulo (q, Phi_N)
    If one of the inversions raises ZeroDivisionError a message is printed
    and a tuple of six None values is returned, so that the usual 6-way
    unpacking of the result still works on every failure path.

    Note: private_keys() and the reduce_* helpers read the session-level q
    and quotient rings, so the arguments should match the session values.
    """
    f,g = private_keys(N)
    # We set D = x^N-1 and phiN=D/(x-1)
    base_of_q = 2
    exponent  = int(log(q,2))
    try:
        fq = Invertmodpowerofprime_mod_PhiN(f,base_of_q,exponent,N)
    except ZeroDivisionError:
        print("Oops! there is not inverse of f in R_{0}".format(q))
        return None,None,None,None,None,None
    try:
        f3 = Invert_mod3_and_PhiN(f)
    except ZeroDivisionError:
        print("Oops! there is not inverse of f in R_{0}".format(3))
        return None,None,None,None,None,None
    fq = fq%q
    # public key: multiply by 3 BEFORE reducing so that h itself is reduced
    # modulo <q, x^N-1>
    h = reduce_mod_DN_and_modq(3*fq*g)
    try: 
        hq = Invertmodpowerofprime_mod_PhiN(h,base_of_q,exponent,N)
    except ZeroDivisionError:
        print("Oops! there is not inverse of h in R_{0}".format(q))
        return None,None,None,None,None,None
    hq = hq%q
    return h,f,Zx(f3.lift()),hq,g,fq

### Encryption
def enc(N,q,msg,h):
    """
    Encrypt the polynomial msg under the public key h.

    An ephemeral ternary polynomial r with N-1 random coefficients
    (degree at most N-2) is sampled, and the ciphertext is
    ct = h*r + msg reduced with reduce_mod_DN_and_modq, i.e. modulo
    <q, x^N-1> for the session-level q and ring R.

    Returns the pair (ct, r).
    """
    ephemeral = [randomrange3()-1 for k in range(N-1)]
    r = Zx(ephemeral)
    ct = reduce_mod_DN_and_modq(reduce_mod_DN_and_modq(h*r) + msg)
    return ct, r

def decryption(ct,f,m):
    """
    Decrypt ct with the private key f and compare the result with m.

    Computes a = ct*f modulo <q, x^N-1>, center-lifts it, multiplies by the
    session-level f3 and maps the product into S3. Prints "decryption OK"
    when this equals the image of m in S3.

    Reads the globals q, N, f3 and S3. Returns the recovered message lifted
    to Zx.
    """
    lifted = CenterLift(reduce_mod_DN_and_modq(ct * f), q, N)
    recovered = S3(lifted*f3)
    if recovered == S3(m):
        print("decryption OK")
    return Zx(recovered.lift())

# construction of M_h
def matrix_for_the_lattice(N,q,h):
    """
    Build the 2N x 2N basis matrix

        [ I_N | C(h) ]
        [  0  | q*I_N ]

    where row i of C(h) holds the first N coefficients of h*x^i reduced
    with reduce_mod_DN (i.e. modulo x^N-1 in the session-level ring R).

    NOTE(review): reduce_mod_DN uses the global ring R, so N is expected to
    equal the session-level N -- confirm at call sites.
    """
    H = matrix(N)
    I=identity_matrix(N)
    Zero_Matrix=matrix(N)
    for i in range(N):
        # the shifted polynomial does not depend on j: reduce it once per row
        shifted = reduce_mod_DN(h*x^i)
        for j in range(N):
            H[i,j] = shifted[j]
       
    B_1=block_matrix([[I,H]])        
    B_2=block_matrix([[Zero_Matrix,q*I]])
    M_NTRU=block_matrix([[B_1],[B_2]])
    return M_NTRU

# choose the target vector
def target_vector(N,ct):
    '''
    Input
    -----
    N   : initial param of NTRU, the length every coefficient vector must have
    ct  : the ciphertext (a polynomial)

    The function also reads the session-level ephemeral key r (the second
    value returned by enc) and passes it to correction_of_msg, which is
    presumably provided by auxiliary.ipynb -- TODO confirm.
    
    Output
    ------
    the target vector of the form
    (0_N,ct) + E
    where E = (r_l, e): r_l = correction_of_msg(N,r) and e has N random
    entries from {-1,0,1}.
    Returns None (after printing a message) if ct does not have exactly
    N coefficients.
    '''  
    ct_list=ct.coefficients(sparse=False)
    r_l = correction_of_msg(N,r)
    print(r_l)
    E = vector(r_l + [randint(-1,1) for i in range(N)])
    if len(ct_list)!=N:
        print("the length of ct <N")
        return None
    return vector([0]*N + ct_list) + E # the target vector t = (0_N,b) + E    

In [None]:
#parameters for NTRU-HPS KEM
# returns N,q
def ntruhps(x):
    """
    Return the pair (N, q) of the NTRU-HPS parameter set selected by x:
    1 -> (509, 2048), 2 -> (677, 2048), 3 -> (821, 4096).
    Any other value of x yields None.
    """
    parameter_sets = (
        (1, (509, 2048)),  # ntruhps2048509
        (2, (677, 2048)),
        (3, (821, 4096)),
    )
    for selector, params in parameter_sets:
        if x == selector:
            return params
    return None
# Select parameter set 2; N and q are reassigned by the explicit
# "N,q=41,2^6" in the setup cell further down.
N,q=ntruhps(2)

# The Attack<br>
The first step is to define the matrix 
$$ 
	M_{\bf h}=
    \left[\begin{array}{c|c}
	I_N & C({\bf h})  \\
	\hline
	{\textbf 0}_N & qI_N   \\
	\end{array}\right].
    $$
Where $C({\bf h})$ is the cyclic matrix generated by the public key ${\bf h}.$ Then we consider the lattice generated by the rows of the matrix, and we consider the target vector ${\bf t}=({\bf 0}_N,{\bf c}).$ We know that the distance of the lattice point ${\bf x}=(3{\bf r},{\bf c}-{\bf m})$ from ${\bf t}$ is at most $\sqrt{10N}.$ Indeed,
$$d({\bf x},{\bf t})=\|(3{\bf r},{\bf c}-{\bf m})-({\bf 0}_N,{\bf c}) \|=\|(3{\bf r},-{\bf m})\|=\sqrt{9{\|\bf r}\|^2 + \|{\bf m}\|^2}\le \sqrt{9N+N}=\sqrt{10N}.$$
This is for unweighted messages. If the weight is $q/8-2$ then
$$d({\bf x},{\bf t})=\|(3{\bf r},{\bf c}-{\bf m})-({\bf 0}_N,{\bf c}) \|=\|(3{\bf r},-{\bf m})\|=\sqrt{9{\|\bf r}\|^2 + \|{\bf m}\|^2}\le \sqrt{9N+q/8-2}.$$
For NTRU-HPS $q\approx 4N$ so
$$d({\bf x},{\bf t})=\|(3{\bf r},{\bf c}-{\bf m})-({\bf 0}_N,{\bf c}) \|=\|(3{\bf r},-{\bf m})\|=\sqrt{9{\|\bf r}\|^2 + \|{\bf m}\|^2}\le \sqrt{9N+N/2}< 3.1\sqrt{N}.$$
*Why ${\bf x}\in{\mathcal{L}}_{\bf h}$?*
From the encryption equation $c(x) = 3h(x)*r(x)+m(x)\mod(q,x^N-1)$ we get
$c(x)=3h(x)*r(x)+m(x)+qv(x)$ for some polynomial $v(x).$ Written in vector form, this equation becomes
${\bf c}=3{\bf r}C({\bf h})+{\bf m}+q{\bf v}.$ Now,
$$(3{\bf r},{\bf v})M_{\bf h}=(3{\bf r},3{\bf r}C({\bf h})+q{\bf v})=(3{\bf r},{\bf c}-{\bf m})$$
Gaussian heuristic suggests $GH({\mathcal{L}}_{\bf h})=0.35\sqrt{q}.$ Now, on average $q\approx 4N$ (for HPS),
so $GH({\mathcal{L}}_{\bf h})=0.7\sqrt{N}$, that is, $||{\bf x}-{\bf t}||\approx 4.5GH({\mathcal{L}}_{\bf h})$ for unweighted messages. If the weight is $q/8-2$ then we get a similar bound.


In [None]:
N,q=41,2^6
#N,p,q,d=239,3,2^8,71
print("N,q=",N,q)
print(q/16 - 1<=2*N/3)

# Define Z[x] BEFORE x is used, so that D, PhiN and Phi1 are built from the
# polynomial generator of Zx on a fresh kernel as well (and not from whatever
# x happened to be bound to beforehand).
Zx.<x> = ZZ[]
D,PhiN,Phi1=x^N-1,sum(x^i for i in [0..N-1]),x-1

# R = Z[x]/(x^N-1) and S = Z[x]/(Phi_N)
R.<xN> = Zx.quotient(D)
S.<XN> = Zx.quotient(PhiN)

# S3 = GF(3)[x]/(Phi_N), used for inversion modulo 3
F3 = GF(3); 
F3x.<x3> = F3[]; 
Phi3N = sum(x3^i for i in [0..N-1])
S3.<X3> = F3x.quotient(Phi3N)

# S2 = GF(2)[x]/(Phi_N), used for the starting inverse modulo 2
F2 = GF(2); F2x.<x2> = F2[]
Phi2N=sum(x2^i for i in [0..N-1])
S2.<X2> = F2x.quotient(Phi2N)

# key pair for the parameters above (the helper functions read these globals)
h,f,f3,hq,g,fq = gen_keys(N=N,q=q)

In [None]:
from fpylll import IntegerMatrix,LLL,GSO
# The lattice basis depends only on the public key h, so build it, reduce it
# and compute its Gram-Schmidt data once instead of once per trial.
# NOTE(review): assumes LLL_reduction_of_M_NTRU (not defined in this notebook)
# is deterministic and free of per-call side effects -- confirm.
M_NTRU,M_NTRU_fplll=LLL_reduction_of_M_NTRU(matrix_for_the_lattice(N,q,h))
M_GSO = GSO.Mat(M_NTRU_fplll)
M_GSO.update_gso()

# "trial" instead of "_" so IPython's last-output variable is not clobbered
for trial in range(5):
    # new message and ciphertext
    msg = T(q/16-1,q/16-1,N-2) # Since we want messages of weight q/8-2.
    m = msg[1]
    ct,r=enc(N,q,m,h)          # r must stay a global: target_vector reads it
    decryption(ct,f,m)         # sanity check, prints "decryption OK" on success
    m_list=correction_of_msg(N,m)
    m_vector=vector(m_list)
    ct_list=ct.coefficients(sparse=False)
    ct_vector=vector(ct_list)
    # Babai's nearest-plane step on the reduced basis
    target=target_vector(N,ct)
    L_babai = M_GSO.babai(target)
    w_babai = sum(L_babai[i]*M_NTRU[i] for i in range(M_NTRU_fplll.nrows)).list()
    # the last N coordinates of the close lattice vector approximate ct - m
    bab=w_babai[N:2*N]
    bab_vector=vector(bab)
    output=ct_vector-bab_vector
    print(output)
    print(target)
    print("Success of the attack:",output==m_vector) # we check if our attack found the message
    print("distance:",(m_vector-output).norm().n() )