# NEWHOPE

In [38]:
import os
import array
import math
import hashlib

class New_Hope_PKE_IND_CPA:
    """IND-CPA public-key encryption in the style of NewHope (SageMath code).

    Arithmetic happens in R_q = Z_q[X]/(X^n + 1) with n in {512, 1024}.
    Public keys are ``(b_hat, publicSeed)``, secret keys are ``s_hat`` and
    ciphertexts are ``(u_hat, v_prime)``. Messages are exactly 32 bytes.

    NOTE(review): ``ntt``/``ntt_inverse`` evaluate/interpolate at the odd powers
    of ``xi`` in natural order, and ciphertexts are not compressed, so outputs
    are presumably not byte-compatible with the NewHope reference code --
    confirm before comparing against known-answer tests.
    """

    def __init__(self, n=512):
        """Set up q, the ring R_q, the root ``xi`` and the CRT basis.

        Raises:
            ValueError: if ``n`` is not 512 or 1024.
        """
        if n not in (512, 1024):
            raise ValueError("Not Accepted Value of N")
        self.n = n

        self.sigma = None  # not read anywhere in this class

        # Smallest prime q with q = 1 (mod 2n), so X^n + 1 splits into n
        # distinct linear factors over F_q (q = 12289 for both supported n).
        self.q = 1 + 2 * n
        while not self.q.is_prime():
            self.q += 2 * n

        # The ring Z_q[X]/(X^n + 1).
        self.F = GF(self.q)
        self.R_original = PolynomialRing(self.F, 'x')
        w = self.R_original.gen()
        self.w = w
        # Build the modulus from the ring's own generator rather than Sage's
        # global symbolic ``x``, which any notebook cell may rebind.
        self.R = self.R_original.quotient(w**n + 1, "a")

        # xi is a root of X^n + 1, i.e. a primitive 2n-th root of unity; the
        # points xi^(2i+1), i = 0..n-1, are exactly the n roots of X^n + 1.
        self.xi = (w**n + 1).roots(multiplicities=False)[-1]
        points = [self.xi**(2 * i + 1) for i in range(n)]
        # CRT basis polynomials for interpolation at those points (ntt_inverse).
        self.base = crt_basis([w - p for p in points])

    def b2i(self, number):
        """Return ``number`` unchanged (byte values are already ints)."""
        return number

    def PolyBitRev(self, poly):
        """Return ``poly`` with coefficient i moved to index BitRev(i), as an R element."""
        permuted = [0] * self.n
        for i in range(self.n):
            permuted[self.BitRev(i)] = poly[i]
        return self.R(permuted)

    def BitRev(self, index):
        """Reverse the low log2(n) bits of ``index``.

        The bit count is taken exactly from ``bit_length`` instead of a
        truncated floating-point logarithm.
        """
        bits = int(self.n).bit_length() - 1
        result = 0
        for i in range(bits):
            result |= ((index >> i) & 1) << (bits - 1 - i)
        return result

    def hamming_weight(self, n):
        """Return the number of set bits of the non-negative integer ``n``."""
        return bin(int(n)).count("1")

    def Sample(self, seed, nonce):
        """Sample a noise polynomial from a 32-byte ``seed`` and a one-byte ``nonce``.

        Each coefficient is HW(a) - HW(b) for two SHAKE256 output bytes (a
        centered binomial value in [-8, 8]), stored modulo q.
        """
        extseed = [0] * 34
        extseed[:32] = seed
        extseed[32] = nonce & 0xFF
        coeffs = []
        for i in range(self.n // 64):
            extseed[33] = i
            buf = self.shake256(128, extseed)
            for j in range(64):
                a = buf[2 * j]
                b = buf[2 * j + 1]
                coeffs.append((self.hamming_weight(a) + self.q - self.hamming_weight(b)) % self.q)
        return self.R(coeffs)

    def ntt(self, f):
        """Return the list [f(xi^(2i+1)) for i in 0..n-1] of evaluations of ``f``.

        Recursive radix-2 evaluation: f(X) = f0(X^2) + X*f1(X^2), and
        xi^n = -1 gives the second half of the outputs with a sign flip.
        """
        def _expand_(f):
            # Pad the coefficient list to length n (trailing zeros are dropped by .list()).
            u = f.list()
            return u + [0] * (self.n - len(u))

        def _ntt_(xi, N, f):
            if N == 1:
                return f
            N_ = N // 2
            xi2 = xi**2
            ff0 = _ntt_(xi2, N_, [f[2 * i] for i in range(N_)])
            ff1 = _ntt_(xi2, N_, [f[2 * i + 1] for i in range(N_)])

            ff = [self.F(0)] * N
            s = xi  # runs through xi^(2i+1)
            for i in range(N_):
                a = ff0[i]
                b = s * ff1[i]
                ff[i] = a + b
                ff[i + N_] = a - b
                s = s * xi2
            return ff

        return _ntt_(self.xi, self.n, _expand_(f))

    def ntt_inverse(self, ff):
        """Interpolate the evaluations ``ff`` (as produced by ``ntt``) back to a polynomial."""
        return sum(ff[i] * self.base[i] for i in range(self.n))

    def component_multiplication(self, poly_one, poly_two):
        """Multiply two length-n coefficient sequences position-wise; return an R element."""
        return self.R([poly_one[i] * poly_two[i] for i in range(self.n)])

    def shake256(self, length, message):
        """Return ``length`` bytes of SHAKE256 over ``message`` (bytes or list of byte ints)."""
        m = hashlib.shake_256()
        m.update(bytearray(message))
        return m.digest(int(length))

    def shake128_Absorb(self, message):
        """Start a SHAKE128 stream over ``message`` and reset the squeezed-block counter."""
        self.state_number = 0
        m = hashlib.shake_128()
        m.update(bytearray(message))
        return m

    def shake128_Squeeze(self, j, state):
        """Return the next ``j`` 168-byte blocks of the SHAKE128 stream and the state.

        The first call after ``shake128_Absorb`` yields block 0. hashlib has no
        incremental squeeze, so the prefix is recomputed and sliced.
        """
        start = int(self.state_number) * 168
        self.state_number += j
        end = int(self.state_number) * 168
        return (state.digest(end)[start:], state)

    def GenA(self, seed):
        """Expand a 32-byte public ``seed`` into the transform-domain polynomial a_hat.

        Uses rejection sampling on 16-bit little-endian values: only values below
        5q are kept, so each coefficient reduced mod q is uniform.
        """
        extseed = [0] * 33
        extseed[:32] = seed
        limit = 5 * self.q
        coeffs = []
        for i in range(self.n // 64):
            extseed[32] = i
            state = self.shake128_Absorb(extseed)
            ctr = 0
            while ctr < 64:
                (buf, state) = self.shake128_Squeeze(1, state)
                j = 0
                while j < 168 and ctr < 64:
                    val = self.b2i(buf[j]) | (self.b2i(buf[j + 1]) << 8)
                    if val < limit:
                        coeffs.append(val)
                        ctr += 1
                    j += 2
        return self.R(coeffs)

    def Decode(self, coded_message):
        """Recover 32 message bytes from a noisy encoded polynomial.

        Bit i is 1 when the copies of coefficient i (at i, i+256, and for
        n=1024 also i+512, i+768) lie close to q/2 in total.
        """
        center = (int(self.q) - 1) // 2
        copies = 4 if self.n == 1024 else 2
        threshold = int(self.q) if self.n == 1024 else int(self.q) // 2
        message = [0] * 32
        for i in range(256):
            t = sum(abs(int(coded_message[i + 256 * k]) - center) for k in range(copies))
            # A negative value means the coefficients sit near q/2, i.e. the bit is 1.
            bit = 1 if t - threshold < 0 else 0
            message[i >> 3] |= bit << (i & 7)
        return bytearray(message)

    def Encode(self, message):
        """Map a 32-byte message to a polynomial: each bit becomes 0 or floor(q/2).

        Every bit is repeated in 2 (n=512) or 4 (n=1024) coefficients spaced 256 apart.
        """
        message = bytearray(message)
        half = int(self.q) // 2
        copies = 4 if self.n == 1024 else 2
        v = [0] * self.n
        for i in range(32):
            for j in range(8):
                mask = -((message[i] >> j) & 1)  # -1 (all ones) if the bit is set, else 0
                for k in range(copies):
                    v[8 * i + j + 256 * k] = mask & half
        return self.R(v)

    def KeyGen(self):
        """Generate a key pair.

        Returns:
            ``((b_hat, publicSeed), s_hat)`` with b_hat = a_hat*s_hat + e_hat
            taken position-wise in the transform domain.
        """
        seed = os.urandom(32)
        z = self.shake256(64, b"\x01" + seed)
        publicSeed = z[:32]
        noiseSeed = z[32:]
        a_hat = self.GenA(publicSeed)
        s = self.PolyBitRev(self.Sample(noiseSeed, 0))
        s_hat = self.R(self.ntt(s))
        e = self.PolyBitRev(self.Sample(noiseSeed, 1))
        e_hat = self.R(self.ntt(e))
        b_hat = self.component_multiplication(a_hat, s_hat) + e_hat
        return ((b_hat, publicSeed), s_hat)

    def Encrypt(self, pk, message, coin):
        """Encrypt a 32-byte ``message`` under ``pk`` using the 32-byte ``coin``.

        Returns:
            ``(u_hat, v_prime)``.

        Raises:
            ValueError: if ``message`` is not 32 bytes long.
        """
        if len(message) != 32:
            raise ValueError("Message must 32 bytes long")
        (b_hat, publicSeed) = pk
        a_hat = self.GenA(publicSeed)
        s_prime = self.PolyBitRev(self.Sample(coin, 0))
        e_prime = self.PolyBitRev(self.Sample(coin, 1))
        e_prime2 = self.Sample(coin, 2)  # added in the normal domain, so no bit reversal
        t_hat = self.ntt(s_prime)
        u_hat = self.component_multiplication(a_hat, t_hat) + self.R(self.ntt(e_prime))
        v = self.Encode(message)
        v_prime = self.ntt_inverse(self.component_multiplication(b_hat, t_hat)) + e_prime2 + v
        return (u_hat, v_prime)

    def Decrypt(self, cripto, sk):
        """Decrypt ``cripto = (u_hat, v_prime)`` with the secret key ``sk`` (s_hat); return a bytearray."""
        (u_hat, v_prime) = cripto
        s_hat = sk
        mess_poly = v_prime - self.ntt_inverse(self.component_multiplication(u_hat, s_hat))
        return self.Decode(mess_poly)

class New_Hope_KEM_IND_CPA:
    """CPA-secure key encapsulation layered on ``New_Hope_PKE_IND_CPA``."""

    def __init__(self, n=512):
        # Every lattice operation is delegated to this PKE instance.
        self.pke = New_Hope_PKE_IND_CPA(n)

    def KeyGen(self):
        """Return ``(pk, sk)`` exactly as the underlying PKE produces them."""
        return self.pke.KeyGen()

    def Encapsulate(self, pk):
        """Create a fresh shared secret for ``pk``.

        A random 32-byte coin is hashed (with prefix byte 0x02) into a 32-byte
        key plus a 32-byte encryption coin; the key is encrypted and its
        SHAKE256 hash is the shared secret.

        Returns:
            ``(ciphertext, shared_secret)``.
        """
        expanded = self.pke.shake256(64, b"\x02" + os.urandom(32))
        key_material, encryption_coin = expanded[:32], expanded[32:]
        ciphertext = self.pke.Encrypt(pk, key_material, encryption_coin)
        return ciphertext, self.pke.shake256(32, key_material)

    def Decapsulate(self, c, sk):
        """Decrypt ``c`` with ``sk`` and hash the recovered key into the shared secret."""
        return self.pke.shake256(32, self.pke.Decrypt(c, sk))
    
# Demo 1: PKE round trip. Decrypt is expected to print back the 32-byte
# plaintext passed to Encrypt (a fresh random 32-byte coin is used).
pke = New_Hope_PKE_IND_CPA()

(pk,sk) = pke.KeyGen()
cripto = pke.Encrypt(pk,b"3a3a2222222222622222222222222222",os.urandom(32))
print("PKE-IND-CPA")
print(pke.Decrypt(cripto,sk))
print("\n")
# Demo 2: KEM. ``ss`` (encapsulator's secret) and ``rss`` (decapsulator's
# secret) are printed one after the other and are expected to be equal.
kem = New_Hope_KEM_IND_CPA()
(pk,sk) = kem.KeyGen()
(c,ss) = kem.Encapsulate(pk)
rss = kem.Decapsulate(c,sk)
print("KEM-IND-CPA")
print(ss)
print(rss)
print("\n")


PKE-IND-CPA
bytearray(b'3a3a2222222222622222222222222222')


KEM-IND-CPA
b'\xafD\xa0\xf585\x9d\x16}{\xb9{\x987HR3\xbdik\xcd\xc9CR\x10\xd0.\xec;E\x89t'
b'\xafD\xa0\xf585\x9d\x16}{\xb9{\x987HR3\xbdik\xcd\xc9CR\x10\xd0.\xec;E\x89t'




In [None]:


            
# Scratch check: find the 2n-th root of unity xi the same way
# New_Hope_PKE_IND_CPA.__init__ does.
# ``n`` and ``q`` were never defined at notebook level (the class keeps them as
# attributes), so define them here. Note that assigning ``n`` rebinds Sage's
# global ``n()`` function for the rest of this session.
n = 512
q = 1 + 2 * n
while not q.is_prime():  # smallest prime q = 1 (mod 2n)
    q += 2 * n

F = GF(q)
R = PolynomialRing(F, name="w")

w = R.gen()
print(w)

g = w**n + 1
print(g)
xi = g.roots(multiplicities=False)[-1]

print(xi)