In [3]:
#check whether two integers are relatively prime (coprime)
def is_related_prime(k, n):
    """Return True when k and n are relatively prime, i.e. gcd(k, n) == 1.

    Args:
        k: integer.
        n: integer.

    Returns:
        bool: True iff the greatest common divisor of k and n is 1.
    """
    from math import gcd
    return gcd(k, n) == 1

#calculate division over F_q: multiplicative inverse of x modulo q
def cal_div(x , q):
    """Return the multiplicative inverse of ``x`` modulo ``q``.

    ``x`` may be negative; it is reduced modulo ``q`` first.

    Returns:
        int in [0, q): y with (x * y) % q == 1, or 0 when ``x`` is 0 or has
        no inverse modulo ``q`` (gcd(x, q) != 1). Returning 0 instead of
        raising lets callers multiply by the result unconditionally.

    Uses the extended Euclidean algorithm (O(log q)) rather than trying every
    candidate in 1..q-1.
    """
    if x == 0:
        return 0
    if x == 1:
        return 1
    if q < 2:
        # There are no candidate inverses in 1..q-1.
        return 0
    # Invariant throughout the loop: s_prev * x == r_prev and
    # s_cur * x == r_cur (mod q).
    r_prev, r_cur = q, x % q
    s_prev, s_cur = 0, 1
    while r_cur != 0:
        quot = r_prev // r_cur
        r_prev, r_cur = r_cur, r_prev - quot * r_cur
        s_prev, s_cur = s_cur, s_prev - quot * s_cur
    # r_prev is now gcd(x, q); only gcd 1 means x is invertible.
    if r_prev != 1:
        return 0
    return s_prev % q

In [4]:
#point over elliptic curve
class point(object):
    """A point (x, y) on an elliptic curve.

    ``point(0, 0)`` is used in this notebook to represent the point at
    infinity. Points compare equal when both coordinates are equal, so
    ``p == point(0, 0)`` works as an identity test. Comparison with any
    non-point (e.g. a tuple) falls back to Python's default, i.e. False.
    """
    def __init__(self, x, y):
        self.x = x
        self.y = y
        # Negated y-coordinate; only set when y is a plain int (y may also
        # be a marker string such as "not exist").
        if type(y) == int:
            self.my = - self.y

    def __eq__(self, other):
        if not isinstance(other, point):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    # Defining __eq__ would otherwise make instances unhashable.
    def __hash__(self):
        return hash((self.x, self.y))

    def __repr__(self):
        return 'point({!r}, {!r})'.format(self.x, self.y)

    def Print(self):
        """Return the coordinates as a list [x, y]."""
        return [self.x, self.y]
        
# Elliptic curve used to embed plaintext as curve points
class EC(object):
    """Curve y^2 = x^3 + a*x + b over Z_q.

    The constructor asserts 0 < a < q, 0 < b < q, q > 2 and a non-zero
    discriminant (4*a^3 + 27*b^2) mod q; otherwise AssertionError is raised.
    """

    def __init__(self, q, a, b):
        assert 0 < a < q and 0 < b < q and q > 2
        assert (4 * a ** 3 + 27 * b ** 2) % q != 0
        self.q = q
        self.a = a
        self.b = b

    def Print(self):
        """Return the curve parameters as [a, b, q]."""
        return [self.a, self.b, self.q]

    def point_on_curve(self, x):
        """For a given x, find a y with (x, y) on the curve.

        Returns the smallest y in [0, q) with y*y == x^3 + a*x + b (mod q),
        or False when no such y exists. A root of 0 is falsy too, so a
        truthiness test on the result cannot tell y == 0 from "no solution".
        """
        target = (x ** 3 + x * self.a + self.b) % self.q
        return next(
            (root for root in range(self.q) if root * root % self.q == target),
            False,
        )

class EG(object):
    """Group law on the points of an ``EC`` curve.

    The point at infinity (the group identity) is represented by
    ``point(0, 0)``. That pair is never a real curve point, because ``EC``
    asserts 0 < b < q, so y^2 = x^3 + a*x + b does not hold at (0, 0).
    """
    def __init__(self, ec):
        self.ec = ec
        # Legacy sentinel; is_valid still treats this tuple as valid.
        self.zero = (0, 0, 0)

    @staticmethod
    def _is_zero(p):
        """Return True if p is the point at infinity, point(0, 0)."""
        return p.x == 0 and p.y == 0

    def get_point(self, m):
        """Return the point with x-coordinate m.

        y is the root found by ``ec.point_on_curve``; when m has no square
        root on the curve, y is the string "not exist".
        """
        py = self.ec.point_on_curve(m)
        # point_on_curve returns False for "no root", but 0 is a valid root,
        # so compare by identity instead of truthiness.
        if py is False:
            py = "not exist"
        return point(m, py)

    def is_valid(self, p):
        """Return True if p is the point at infinity or lies on the curve."""
        if p == self.zero or self._is_zero(p):
            return True
        if isinstance(p.y, str):
            # Marker such as "not exist" from get_point.
            return False
        q = self.ec.q
        lhs = (p.y ** 2) % q
        rhs = ((p.x ** 3) + self.ec.a * p.x + self.ec.b) % q
        return lhs == rhs

    def add_on_curve(self, p1, p2):
        """Return p1 + p2 under the elliptic-curve group law.

        Either operand may be the point at infinity. Coordinates are reduced
        modulo q before comparison, so an unreduced negative y (e.g. from a
        negation) is handled correctly.
        """
        q = self.ec.q
        if self._is_zero(p1):
            return p2
        if self._is_zero(p2):
            return p1
        p1x, p1y = p1.x % q, p1.y % q
        p2x, p2y = p2.x % q, p2.y % q

        # P + (-P) = O; doubling a point with y == 0 (vertical tangent) is O too.
        if p1x == p2x and (p1y != p2y or p1y == 0):
            return point(0, 0)
        if p1x != p2x:
            # Slope of the chord through two distinct points.
            m = (p2y - p1y) * cal_div(p2x - p1x, q) % q
            p3x = (m * m - p1x - p2x) % q
        else:
            # Doubling: slope of the tangent line at p1.
            m = (3 * (p1x * p1x) + self.ec.a) * cal_div(2 * p1y, q) % q
            p3x = (m * m - 2 * p1x) % q
        p3y = (m * (p1x - p3x) - p1y) % q
        return point(p3x, p3y)

    def mul_on_curve(self, k, p):
        """Return k * p using double-and-add (O(log k) group additions).

        k == 0 yields the point at infinity; negative k multiplies -p.
        """
        if k < 0:
            return self.mul_on_curve(-k, point(p.x, (-p.y) % self.ec.q))
        result = point(0, 0)
        addend = p
        while k:
            if k & 1:
                result = self.add_on_curve(result, addend)
            k >>= 1
            if k:
                addend = self.add_on_curve(addend, addend)
        return result

    def order_of_point(self, p):
        """Return the smallest n >= 1 such that n * p is the point at infinity.

        Raises ValueError if no such n is found within the Hasse bound
        q + 1 + 2*sqrt(q) (only a true bound when q is prime), which happens
        when p is not actually on the curve.
        """
        if self._is_zero(p):
            return 1
        q = self.ec.q
        limit = q + 1 + 2 * (int(q ** 0.5) + 1)
        n = 1
        acc = p
        while not self._is_zero(acc):
            acc = self.add_on_curve(acc, p)
            n += 1
            if n > limit:
                raise ValueError("order_of_point: point does not appear to be on the curve")
        return n

In [5]:
class ECDH(object):
    """Elliptic-curve Diffie-Hellman used to mask a message point.

    The parties hold secret scalars kA and kB and publish kA*G and kB*G.
    The shared secret kA*kB*G is added to the message point by ``encode``
    and subtracted again by ``decode``.
    """
    def __init__(self, ec, eg, kA, kB, G):
        self.ec = ec    # elliptic curve
        self.eg = eg    # group operations on ec
        self.kA = kA
        self.kB = kB
        self.G = G      # base point
        self.kAG = self.eg.mul_on_curve(self.kA, self.G)
        self.kBG = self.eg.mul_on_curve(self.kB, self.G)

    def encode(self, m):
        """Embed x-coordinate m as point P_m and return P_m + kA*(kB*G).

        Raises AssertionError when get_point cannot give an int y for m.
        """
        P_m = self.eg.get_point(m)
        assert type(P_m.y) == int
        kAkBG = self.eg.mul_on_curve(self.kA, self.kBG)
        return self.eg.add_on_curve(P_m, kAkBG)

    def decode(self, encode_m):
        """Return encode_m - kB*(kA*G), recovering the P_m passed to encode."""
        kBkAG = self.eg.mul_on_curve(self.kB, self.kAG)
        # Negate modulo q so the y-coordinate stays in [0, q); an unreduced
        # negative y would not compare equal to reduced coordinates.
        inverse_kBkAG = point(kBkAG.x, (-kBkAG.y) % self.ec.q)
        return self.eg.add_on_curve(encode_m, inverse_kBkAG)

In [6]:
# Find a random prime of a given bit length (e.g. 14, 20 or 32 bits)
import random
import numpy as np

def get_prime(bit_long):
    """Return a random prime n with 2**(bit_long - 1) < n < 2**bit_long.

    Candidates are drawn with ``random.randint`` from
    [2**(bit_long - 1) + 1, 2**bit_long] until one passes trial division.
    The upper end 2**bit_long is even, so it is never returned. The divisor
    bound uses integer arithmetic (d * d <= n) rather than a floating-point
    square root, which loses precision for large bit lengths.

    Raises AssertionError if bit_long < 3.
    """
    assert bit_long >= 3
    while True:
        n = random.randint((2 ** (bit_long - 1)) + 1, 2 ** bit_long)
        d = 2
        while d * d <= n:
            if n % d == 0:
                break
            d += 1
        else:
            # No divisor up to sqrt(n): n is prime.
            return n

In [7]:
def _random_curve_with_generator(bit_long):
    """Draw a random curve over a bit_long-bit prime with a base point G of order > 16.

    Returns (ec, eg, G, order).
    """
    while True:
        p = get_prime(bit_long)
        a = random.randint(0, p - 1)
        b = random.randint(0, p - 1)
        try:
            ec = EC(p, a, b)
        except AssertionError:
            # EC rejects a == 0, b == 0 and singular curves; draw again.
            continue
        eg = EG(ec)
        G = eg.get_point(random.randint(0, p - 1))
        if type(G.y) != int:
            # The drawn x has no square root on this curve.
            continue
        order = eg.order_of_point(G)
        if order > 16:
            return ec, eg, G, order


def _embedding_xs(ec, K, count=16):
    """For each message m in range(count), return the first x in [m*K, m*K + K) that is on the curve.

    A bin with no such x is skipped, so the result may be shorter than count.
    """
    xs = []
    for m in range(count):
        for x in range(m * K, m * K + K):
            # A root of 0 is falsy, so such x are skipped like "no root".
            if ec.point_on_curve(x):
                xs.append(x)
                break
    return xs


def get_encode_message(bit_long):
    """Generate one random ECDH-encoded message in range(16).

    Steps:
      1. Pick a random curve over a bit_long-bit prime p and a base point G
         of order > 16.
      2. Split x-coordinates into 16 bins of width K = p // 17. Message m is
         embedded as the first x in bin m that has a curve point.
      3. Encode a random message with random keys kA, kB, then check that
         decoding recovers the same bin index.

    Returns:
        [m, encoded_point, kA*G, kB*G]

    Raises:
        AssertionError: some bin has no curve point, or the decode check fails.
    """
    ec, eg, G, order = _random_curve_with_generator(bit_long)

    K = ec.q // 17
    x_m = _embedding_xs(ec, K)
    assert len(x_m) == 16

    ka = random.randint(1, order - 1)
    kb = random.randint(1, order - 1)
    msg = random.randint(0, 15)

    # Build the key exchange once; ECDH.__init__ already computes kA*G and kB*G.
    ecdh = ECDH(ec, eg, ka, kb, G)
    encode_m = ecdh.encode(x_m[msg])
    decode_m = ecdh.decode(encode_m)
    # The message is the bin index of the decoded x-coordinate.
    assert msg == decode_m.x // K

    return [msg,
            encode_m,
            ecdh.kAG,
            ecdh.kBG]
            
def get_encode_messages(n, bitlong):
    """Generate n samples with get_encode_message(bitlong).

    A failed attempt (any Exception, typically AssertionError from an
    unusable random draw) is retried until it succeeds. If
    get_encode_message can never succeed (e.g. bitlong < 3, which get_prime
    rejects), this loops forever.

    Returns:
        Seven parallel lists: (m, Pmx, Pmy, kaGx, kaGy, kbGx, kbGy). They are
        empty when n <= 0.
    """
    m, Pmx, Pmy = [], [], []
    kaGx, kaGy = [], []
    kbGx, kbGy = [], []
    for k in range(n):
        print("进度:{}, {}%".format(k, round(k * 100 / n)), end="\r")
        while True:
            try:
                g = get_encode_message(bitlong)
            except Exception:
                # Catch Exception (not a bare except) so KeyboardInterrupt
                # can still stop the loop.
                continue
            break
        msg, encoded, kaG, kbG = g
        m.append(msg)
        Pmx.append(encoded.x)
        Pmy.append(encoded.y)
        kaGx.append(kaG.x)
        kaGy.append(kaG.y)
        kbGx.append(kbG.x)
        kbGy.append(kbG.y)
    print('Finish!!')
    return m, Pmx, Pmy, kaGx, kaGy, kbGx, kbGy

In [8]:
%%time
for i in range(100):
    gg = get_encode_message(8)

AssertionError: 

In [9]:
# Generate 200 samples over random 14-bit prime fields; unpack the seven parallel columns.
m, Pmx, Pmy, kaGx, kaGy, kbGx, kbGy = get_encode_messages(n=200, bitlong=14)

Finish!!100%


In [10]:
# Every column returned by get_encode_messages must have one entry per sample.
assert len({len(col) for col in (m, Pmx, Pmy, kaGx, kaGy, kbGx, kbGy)}) == 1
len(Pmy)

200

In [11]:
# Labels: parity of each message. Features: encoded point and both public keys, one row per sample.
y = np.array(m) % 2
X = np.array([Pmx, Pmy, kaGx, kaGy, kbGx, kbGy]).T

In [12]:
import pandas as pd
# One row per sample: message, encoded point P_m, and the public keys kA*G and kB*G.
data = pd.DataFrame(
    np.column_stack([m, Pmx, Pmy, kaGx, kaGy, kbGx, kbGy]),
    columns=['m', 'Pmx', 'Pmy', 'kaGx', 'kaGy', 'kbGx', 'kbGy'],
)

In [13]:
# Preview the first rows of the sample table.
data.head(n=5)

Unnamed: 0,m,Pmx,Pmy,kaGx,kaGy,kbGx,kbGy
0,11,4321,3878,9501,7412,7848,670
1,9,7543,5717,2119,8161,7945,1607
2,10,5230,3337,8541,3036,2240,7909
3,15,912,7125,822,5187,8681,2189
4,11,3009,6442,3947,1948,1626,837


In [14]:
# Persist the samples (generated with 14-bit primes above) without the index column.
data.to_csv(path_or_buf='p_14_test.csv', index=False)