# NEWHOPE

In [116]:
import os
import array
import math
import hashlib
import pickle

class New_Hope_PKE_IND_CPA:
    """NewHope IND-CPA public-key encryption over R_q = Z_q[X]/(X^n + 1).

    Written for SageMath: ``GF``, ``PolynomialRing`` and ``crt_basis`` are
    Sage globals, and ``__init__`` relies on ``q`` being a Sage Integer
    (it calls ``is_prime``). Encoded keys and ciphertexts are Python lists
    of ints in ``range(256)``.

    Args:
        n: Ring dimension; must be 512 or 1024.

    Raises:
        ValueError: If ``n`` is not 512 or 1024.
    """

    def __init__(self, n=512):
        if n not in (512, 1024):
            raise ValueError("Not Accepted Value of N")
        self.n = n

        self.sigma = None

        # Smallest prime q with q = 1 (mod 2n), so that X^n + 1 splits into
        # linear factors over GF(q); this gives q = 12289 for n = 512 and 1024.
        self.q = 1 + 2 * n
        while not (self.q).is_prime():
            self.q += 2 * n

        # Defining the ring Z_q[X]/(X^n + 1).
        self.F = GF(self.q)
        self.R_original = PolynomialRing(self.F, 'x')
        w = self.R_original.gen()
        self.w = w
        # Build the modulus from the ring's own generator rather than a
        # global symbolic ``x`` that may have been rebound in the notebook.
        self.R = self.R_original.quotient(w**self.n + 1, "a")

        # Any root xi of X^n + 1 is a primitive 2n-th root of unity. ``ntt``
        # evaluates at xi^(2i+1) for i = 0..n-1, and ``ntt_inverse``
        # interpolates with a CRT basis over the same points in the same order.
        g = w**n + 1
        xi = g.roots(multiplicities=False)[-1]
        self.xi = xi
        rs = [xi**(2 * i + 1) for i in range(n)]
        self.base = crt_basis([(w - r) for r in rs])

    def b2i(self, number):
        """Byte-to-int conversion; bytes already index as ints, so identity."""
        return number

    def PolyBitRev(self, poly):
        """Return the ring element whose coefficient ``BitRev(i)`` is ``poly[i]``."""
        v_array = [0] * self.n
        for i in range(self.n):
            v_array[self.BitRev(i)] = poly[i]
        return self.R(v_array)

    def BitRev(self, index):
        """Reverse the low log2(n) bits of ``index``."""
        # Exact integer log2 (n is a power of two); avoids float math.log.
        bits = int(self.n).bit_length() - 1
        result = 0
        for i in range(bits):
            result |= ((index >> i) & 1) << (bits - 1 - i)
        return result

    def hamming_weight(self, n):
        """Return the number of set bits in the non-negative integer ``n``."""
        c = 0
        while n:
            c += 1
            n &= n - 1  # clears the lowest set bit
        return c

    def Sample(self, seed, nonce):
        """Sample a polynomial with centered-binomial coefficients.

        Each coefficient is ``HW(a) - HW(b) mod q`` for two bytes a, b taken
        from ``SHAKE256(128, seed || nonce || i)``.

        Args:
            seed: 32-byte seed.
            nonce: Domain separator; only its low 8 bits are used.

        Returns:
            ``self.R`` applied to the n sampled coefficients.
        """
        extseed = [0] * 34
        extseed[:32] = seed
        extseed[32] = nonce & 0xFF
        r_array = []
        for i in range(self.n // 64):
            extseed[33] = i
            buf = self.shake256(128, extseed)
            for j in range(64):
                a = buf[2 * j]
                b = buf[2 * j + 1]
                r_array.append(
                    (self.hamming_weight(a) + self.q - self.hamming_weight(b)) % self.q
                )
        return self.R(r_array)

    def ntt(self, f):
        """Recursive radix-2 NTT: return ``[f(xi^(2i+1)) for i in range(n)]``.

        Args:
            f: Element of ``self.R`` (anything with a ``list()`` of coefficients).

        Returns:
            List of n elements of ``self.F``.
        """
        def _expand_(f):
            # Pad the coefficient list to exactly n entries.
            u = f.list()
            return u + [0] * (self.n - len(u))

        def _ntt_(xi, N, f):
            # Evaluate the length-N coefficient list f at xi^(2i+1), relying on
            # f(y) = f_even(y^2) + y * f_odd(y^2) and xi^N = -1.
            if N == 1:
                return f
            half = N // 2
            xi2 = xi**2
            ff0 = _ntt_(xi2, half, f[0::2])
            ff1 = _ntt_(xi2, half, f[1::2])

            s = xi  # runs through xi^(2i+1)
            ff = [self.F(0)] * N
            for i in range(half):
                a = ff0[i]
                b = s * ff1[i]
                ff[i] = a + b
                ff[i + half] = a - b
                s = s * xi2
            return ff

        return _ntt_(self.xi, self.n, _expand_(f))

    def ntt_inverse(self, ff):
        """Interpolate the values ``ff`` back to a polynomial via the CRT basis."""
        return sum(ff[i] * self.base[i] for i in range(self.n))

    def component_multiplication(self, poly_one, poly_two):
        """Return ``self.R`` of the coefficient-wise product of two n-vectors."""
        # Index-based on purpose: ring elements return 0 for missing
        # high-order coefficients, whereas iterating them could stop early.
        return self.R([poly_one[i] * poly_two[i] for i in range(self.n)])

    def shake256(self, length, message):
        """Return ``length`` bytes of SHAKE256 over ``message`` (bytes-like or int list)."""
        m = hashlib.shake_256()
        m.update(bytearray(message))
        return m.digest(int(length))

    def shake128_Absorb(self, message):
        """Start a SHAKE128 instance over ``message`` and reset the squeeze counter."""
        self.state_number = 0
        m = hashlib.shake_128()
        m.update(bytearray(message))
        return m

    def shake128_Squeeze(self, j, state):
        """Squeeze the next ``j`` 168-byte SHAKE128 blocks from ``state``.

        hashlib cannot squeeze incrementally, so the whole prefix is recomputed
        and the blocks already consumed since ``shake128_Absorb`` are dropped.

        Returns:
            ``(blocks, state)``, where ``blocks`` holds ``168 * j`` bytes.
        """
        start = self.state_number * 168
        self.state_number += j
        return (state.digest(int(self.state_number * 168))[start:], state)

    def GenA(self, seed):
        """Generate the public NTT-domain polynomial ``a_hat`` from a 32-byte seed.

        For each of n/64 chunks, 16-bit little-endian values are read from
        ``SHAKE128(seed || i)``. Values below 5q are kept (``self.R`` reduces
        them mod q) until 64 have been collected.
        """
        a_array = []
        extseed = [0] * 33
        extseed[:32] = seed
        for i in range(self.n // 64):
            ctr = 0
            extseed[32] = i
            state = self.shake128_Absorb(extseed)
            while ctr < 64:
                (buf, state) = self.shake128_Squeeze(1, state)
                j = 0
                while j < 168 and ctr < 64:
                    val = self.b2i(buf[j]) | (self.b2i(buf[j + 1]) << 8)
                    if val < 5 * self.q:
                        a_array.append(val)
                        ctr += 1
                    j += 2
        return self.R(a_array)

    def Decode(self, coded_message):
        """Recover a 32-byte message from a noisy encoded polynomial.

        Bit i is 1 when the n/256 copies of it (at i, i+256, ...) lie, in
        total, closer to q/2 than to 0.

        Returns:
            ``bytearray`` of length 32.
        """
        half = (int(self.q) - 1) // 2
        message = [0] * 32
        for i in range(256):
            t = abs(int(coded_message[i]) - half)
            t += abs(int(coded_message[i + 256]) - half)
            if self.n == 1024:
                t += abs(int(coded_message[i + 512]) - half)
                t += abs(int(coded_message[i + 768]) - half)
                t -= int(self.q)
            else:
                t -= int(self.q) // 2
            # |t| < 2^15 here, so the arithmetic shift yields -1 (bit set) or 0.
            bit = -(t >> 15)
            message[i >> 3] += bit << (i % 8)
        return bytearray(message)

    def Encode(self, message):
        """Encode 32 bytes as a polynomial; each bit becomes n/256 coefficients of 0 or q/2."""
        message = bytearray(message)
        half = int(self.q) // 2
        copies = self.n // 256
        v = [0] * self.n
        for i in range(32):
            for j in range(8):
                mask = -((message[i] >> j) & 1)  # 0 or all ones
                value = mask & half
                for c in range(copies):
                    v[8 * i + j + 256 * c] = value
        return self.R(v)

    def Compress(self, poly):
        """Round each coefficient to 3 bits and pack 8 of them into 3 bytes.

        Returns:
            List of 3*n/8 ints in ``range(256)``.
        """
        q = int(self.q)
        h = []
        for l in range(self.n // 8):
            i = 8 * l
            # round(8 * c / q) mod 8, using integer arithmetic only.
            t = [((((int(poly[i + j]) % q) << 3) + q // 2) // q) & 7 for j in range(8)]
            h.append((t[0] | (t[1] << 3) | (t[2] << 6)) & 0xFF)
            h.append(((t[2] >> 2) | (t[3] << 1) | (t[4] << 4) | (t[5] << 7)) & 0xFF)
            h.append(((t[5] >> 1) | (t[6] << 2) | (t[7] << 5)) & 0xFF)
        return h

    def Decompress(self, a):
        """Approximate inverse of ``Compress``: unpack 3-bit values and scale them to Z_q."""
        q = int(self.q)
        r = [0] * self.n
        for l in range(self.n // 8):
            i = 8 * l
            k = 3 * l
            r[i] = a[k] & 7
            r[i + 1] = (a[k] >> 3) & 7
            r[i + 2] = (a[k] >> 6) | ((a[k + 1] << 2) & 4)
            r[i + 3] = (a[k + 1] >> 1) & 7
            r[i + 4] = (a[k + 1] >> 4) & 7
            r[i + 5] = (a[k + 1] >> 7) | ((a[k + 2] << 1) & 6)
            r[i + 6] = (a[k + 2] >> 2) & 7
            r[i + 7] = a[k + 2] >> 5
            for j in range(8):
                r[i + j] = (r[i + j] * q + 4) >> 3
        return self.R(r)

    def EncodeC(self, c):
        """Serialize ciphertext ``(u_hat, h)`` as a list: 7n/4 polynomial bytes, then h."""
        (u_hat, h) = c
        return self.EncodePolynomial(u_hat) + list(h)

    def DecodeC(self, enc):
        """Inverse of ``EncodeC``: return ``(u_hat_coefficients, h)``."""
        k = 7 * self.n // 4
        return (self.DecodePolynomial(enc[:k]), enc[k:])

    def EncodePolynomial(self, s_hat):
        """Pack n 14-bit coefficients into 7n/4 bytes (4 coefficients per 7 bytes)."""
        r = []
        for i in range(self.n // 4):
            t0 = int(s_hat[4 * i + 0])
            t1 = int(s_hat[4 * i + 1])
            t2 = int(s_hat[4 * i + 2])
            t3 = int(s_hat[4 * i + 3])
            r.append(t0 & 0xFF)
            r.append(((t0 >> 8) | (t1 << 6)) & 0xFF)
            r.append((t1 >> 2) & 0xFF)
            r.append(((t1 >> 10) | (t2 << 4)) & 0xFF)
            r.append((t2 >> 4) & 0xFF)
            r.append(((t2 >> 12) | (t3 << 2)) & 0xFF)
            r.append((t3 >> 6) & 0xFF)
        return r

    def DecodePolynomial(self, v):
        """Unpack 7n/4 bytes into a list of n 14-bit integer coefficients."""
        r = []
        for i in range(self.n // 4):
            b = [int(v[7 * i + k]) for k in range(7)]
            r.append(b[0] | ((b[1] & 0x3F) << 8))
            r.append((b[1] >> 6) | (b[2] << 2) | ((b[3] & 0x0F) << 10))
            r.append((b[3] >> 4) | (b[4] << 4) | ((b[5] & 0x03) << 12))
            r.append((b[5] >> 2) | (b[6] << 6))
        return r

    def EncodePK(self, b, publicSeed):
        """Serialize the public key: encoded ``b_hat`` (7n/4 bytes) followed by the seed."""
        return self.EncodePolynomial(b) + list(publicSeed)

    def DecodePK(self, pk):
        """Inverse of ``EncodePK``: return ``(b_hat_coefficients, publicSeed)``."""
        k = 7 * self.n // 4
        return (self.DecodePolynomial(pk[:k]), pk[k:])

    def KeyGen(self):
        """Generate a key pair.

        Returns:
            ``(pk, sk)``: ``pk`` is the encoded ``b_hat`` plus the 32-byte
            public seed; ``sk`` is the encoded ``s_hat`` (7n/4 bytes).
        """
        seed = os.urandom(32)
        z = self.shake256(64, b"\x01" + seed)
        publicSeed = z[:32]
        noiseSeed = z[32:]
        a_hat = self.GenA(publicSeed)
        s = self.PolyBitRev(self.Sample(noiseSeed, 0))
        s_hat = self.R(self.ntt(s))
        e = self.PolyBitRev(self.Sample(noiseSeed, 1))
        e_hat = self.R(self.ntt(e))
        b_hat = self.component_multiplication(a_hat, s_hat) + e_hat
        return (self.EncodePK(b_hat, publicSeed), self.EncodePolynomial(s_hat))

    def Encrypt(self, pk, message, coin):
        """Encrypt a 32-byte ``message`` under ``pk`` using the 32-byte ``coin``.

        Returns:
            Encoded ciphertext list (7n/4 bytes of ``u_hat`` then 3n/8 bytes of ``h``).

        Raises:
            ValueError: If ``message`` is not 32 bytes long.
        """
        if len(message) != 32:
            raise ValueError("Message must be 32 bytes long")
        (b_hat, publicSeed) = self.DecodePK(pk)
        a_hat = self.GenA(publicSeed)
        s_prime = self.PolyBitRev(self.Sample(coin, 0))
        e_prime = self.PolyBitRev(self.Sample(coin, 1))
        e_prime2 = self.Sample(coin, 2)
        t_hat = self.ntt(s_prime)
        u_hat = self.component_multiplication(a_hat, t_hat) + self.R(self.ntt(e_prime))
        v = self.Encode(message)
        v_prime = self.ntt_inverse(self.component_multiplication(b_hat, t_hat)) + e_prime2 + v
        h = self.Compress(v_prime)
        return self.EncodeC((u_hat, h))

    def Decrypt(self, cripto, sk):
        """Decrypt an encoded ciphertext with the encoded secret key; return a 32-byte bytearray."""
        (u_hat, h) = self.DecodeC(cripto)
        s_hat = self.DecodePolynomial(sk)
        v_prime = self.Decompress(h)
        mess_poly = v_prime - self.ntt_inverse(self.component_multiplication(u_hat, s_hat))
        return self.Decode(mess_poly)

class New_Hope_KEM_IND_CPA:
    """NewHope IND-CPA key-encapsulation mechanism on top of the CPA PKE."""

    def __init__(self, n=512):
        self.pke = New_Hope_PKE_IND_CPA(n)

    def KeyGen(self):
        """Delegate to the PKE key generation; returns ``(pk, sk)``."""
        return self.pke.KeyGen()

    def Encapsulate(self, pk):
        """Return ``(ciphertext, shared_secret)`` for public key ``pk``.

        Fresh randomness is expanded with SHAKE256 into 64 bytes: the first
        32 form the encrypted value K, the last 32 the encryption coin. The
        shared secret is SHAKE256(32, K).
        """
        expanded = self.pke.shake256(64, b"\x02" + os.urandom(32))
        key_material, enc_coin = expanded[:32], expanded[32:]
        ciphertext = self.pke.Encrypt(pk, key_material, enc_coin)
        return (ciphertext, self.pke.shake256(32, key_material))

    def Decapsulate(self, c, sk):
        """Return SHAKE256(32, Decrypt(c, sk)) as the shared secret."""
        recovered = self.pke.Decrypt(c, sk)
        return self.pke.shake256(32, recovered)

class New_Hope_KEM_IND_CCA():
    """NewHope IND-CCA key-encapsulation mechanism built from the CPA PKE.

    Secret-key layout (byte lengths in parentheses):
        sk_pke (7n/4) || pk (7n/4 + 32) || SHAKE256(32, pk) (32) || s (32)
    Decapsulation re-encrypts the decrypted value and, on any mismatch,
    derives the shared secret from the random value ``s`` instead.
    """

    def __init__(self, n=512):
        self.n = n
        self.pke = New_Hope_PKE_IND_CPA(n)

    def KeyGen(self):
        """Return ``(pk, sk_bar)`` with ``sk_bar`` laid out as in the class docstring."""
        (pk, sk) = self.pke.KeyGen()
        s = os.urandom(32)
        return (pk, bytearray(sk) + bytearray(pk) + self.pke.shake256(32, pk) + s)

    def Encapsulate(self, pk):
        """Return ``(c_bar, shared_secret)``; ``c_bar`` is the pickled ciphertext plus ``d``."""
        coin = os.urandom(32)
        miu = self.pke.shake256(32, b"\x04" + coin)
        result = self.pke.shake256(96, b"\x08" + miu + self.pke.shake256(32, pk))
        K = result[:32]
        coin_prime = result[32:64]
        d = result[64:]
        c = self.pke.Encrypt(pk, miu, coin_prime)

        ss = self.pke.shake256(32, K + self.pke.shake256(32, pickle.dumps(c) + d))

        return (pickle.dumps(c) + d, ss)

    def Decapsulate(self, c_bar, sk_bar):
        """Recover the shared secret from ``c_bar`` using ``sk_bar``.

        Returns:
            32-byte shared secret. On a failed re-encryption check it is
            derived from ``s`` (implicit rejection) rather than raising.
        """
        # SECURITY NOTE(review): pickle.loads on a ciphertext received from a
        # peer can execute arbitrary code; use a fixed byte encoding before
        # exposing this to untrusted input.
        c = pickle.loads(c_bar[:-32])
        d = c_bar[-32:]

        pke_len = 7 * self.n // 4      # encoded PKE secret key
        pk_end = 2 * pke_len + 32      # == 7n/2 + 32, end of the embedded pk
        sk = sk_bar[:pke_len]
        pk = sk_bar[pke_len:pk_end]
        h = sk_bar[pk_end:pk_end + 32]
        # s is the trailing 32 random bytes appended by KeyGen.
        s = sk_bar[pk_end + 32:pk_end + 64]

        miu = self.pke.Decrypt(c, sk)

        result = self.pke.shake256(96, b"\x08" + miu + h)
        K_prime = result[:32]
        coin_prime2 = result[32:64]
        d_prime = result[64:]
        # NOTE(review): these == comparisons are not constant-time.
        if c == self.pke.Encrypt(pk, miu, coin_prime2) and d == d_prime:
            K = K_prime
        else:
            K = s
        return self.pke.shake256(32, K + self.pke.shake256(32, pickle.dumps(c) + d))
        
        
# Demo: run each scheme once and print its round-trip result.
pke = New_Hope_PKE_IND_CPA()

pk, sk = pke.KeyGen()
cripto = pke.Encrypt(pk, b"3a3a2222222222622222222222222222", os.urandom(32))
print("PKE-IND-CPA")
print(pke.Decrypt(cripto, sk))
print("\n")

kem = New_Hope_KEM_IND_CPA()
pk, sk = kem.KeyGen()
c, ss = kem.Encapsulate(pk)
rss = kem.Decapsulate(c, sk)
# Both shared secrets should print identically.
print("KEM-IND-CPA", ss, rss, sep="\n")
print("\n")

print("KEM_IND_CCA")
kem_cca = New_Hope_KEM_IND_CCA()
pk, sk = kem_cca.KeyGen()
enc, ss = kem_cca.Encapsulate(pk)
rss = kem_cca.Decapsulate(enc, sk)
print(ss, rss, sep="\n")

print("\n")


PKE-IND-CPA
bytearray(b'3a3a2222222222622222222222222222')


KEM-IND-CPA
b"r\xe4\x99'i\xe9`s\x833\xac)U%Kv{\x08\xd1(e\xfc4dU\x9c#]\x08\xd2\x1e\xc8"
b"r\xe4\x99'i\xe9`s\x833\xac)U%Kv{\x08\xd1(e\xfc4dU\x9c#]\x08\xd2\x1e\xc8"


KEM_IND_CCA
b'\xfcP\xf3\xce\xaa\xafB=\xea\xac]q8"\xdf,\x97 \x8e\xb2\xde/\xa4R\x90\xc2@j\xfc\xbc\xf0\x06'
b'\xfcP\xf3\xce\xaa\xafB=\xea\xac]q8"\xdf,\x97 \x8e\xb2\xde/\xa4R\x90\xc2@j\xfc\xbc\xf0\x06'


