November 4, 2021

Burton Rosenberg

Discrete logs and easy bits for the Diffie-Hellman problem

In [1]:

class ModP:
    """
    static helpers for arithmetic in the integers modulo p.

    exp_mod, qr_p, gen_p, find_gen and brute_log rely on p being prime
    (fermat's little theorem, cyclic multiplicative group); this is
    not checked by the code.
    """

    @staticmethod
    def exp_mod(a,b,p):
        """
        modular exponentiation. returns a**b mod p.

        the exponent is reduced mod p-1 before use, so negative
        exponents are accepted (b=-1 gives the inverse of a).
        a base that is 0 mod p yields 0 for every exponent.
        """
        a = a%p
        b = b%(p-1)
        if a==0:
            return 0
        # built-in three-argument pow does the square-and-multiply
        return pow(a,b,p)

    @staticmethod
    def qr_p(a,p):
        """
        euler's criterion. True iff a**((p-1)//2) == 1 mod p,
        i.e. a is a nonzero quadratic residue. False when a is 0 mod p.
        """
        return ModP.exp_mod(a,(p-1)//2,p)==1

    @staticmethod
    def extended_gcd(a,b):
        """
        extended GCD algorithm. recursive.
        returns (d,s,t) where d = s*a+t*b 
        and d = gcd(a,b)
        """
        assert a>=0 and b>=0
        if b==0:
            return (a,1,0)
        (q,r) = divmod(a,b)
        (d,s,t) = ModP.extended_gcd(b,r)
        # gcd(a, b) == gcd(b, r) == s*b + t*r == s*b + t*(a - q*b)
        return (d,t,s-q*t)

    @staticmethod
    def invert(a,p):
        """
        multiplicative inverse of a mod p, in the range 0..p-1.
        asserts that gcd(a,p) is 1.
        """
        # reduce first so that negative a passes extended_gcd's a>=0 check
        (d,s,t) = ModP.extended_gcd(a%p,p)
        assert 1==d
        # d = s*a + t*p, hence s*a == 1 mod p
        return s%p

    @staticmethod
    def gen_p(g,p,verbose=False):
        """
        True iff g generates the multiplicative group mod p.

        walks through g, g**2, ..., g**(p-2) and fails as soon as one
        of them is 1 (order too small) or 0 (g is 0 mod p).
        takes on the order of p multiplications.
        verbose prints one 'exponent:<tab>power' line per step.
        """
        r = g%p
        for i in range(1,p-1):
            if verbose: print (f'{i}:\t{r}')
            if r<=1:
                # r is 0 or 1 here
                return False
            r = (r*g)%p
        if verbose: print (f'{p-1}:\t{r}')
        # r is now g**(p-1); it must be 1 (this also rejects g=0 when p==2)
        return r==1

    @staticmethod
    def find_gen(p):
        """
        smallest generator of the multiplicative group mod p,
        found by trying g = 1, 2, ..., p-1 with gen_p.
        raises AssertionError when no candidate is a generator.
        """
        # start at 1 and include p-1: the generator is 1 for p==2
        # and p-1 for p==3
        for g in range(1,p):
            if ModP.gen_p(g,p):
                return g
        raise AssertionError(f'no generator found mod {p}')

    @staticmethod
    def brute_log(x,g,p):
        """
        discrete log by exhaustive search. returns the least i in
        0..p-2 with g**i == x mod p.
        asserts x is not 0 mod p; raises AssertionError when x is
        not a power of g.
        """
        x = x%p
        assert x!=0

        r = 1
        for i in range(p-1):
            # invariant: r == g**i mod p
            if r==x:
                return i
            r = r*g%p
        raise AssertionError(f'{x} is not a power of {g} mod {p}')

  
    

In [2]:
class TonelliShanks:
    """
    square roots modulo p by the tonelli-shanks algorithm.
    p is expected to be prime; this is not checked.
    """
    
    def __init__(self,p):
        self.p = p

    def find_qnr(self):
        """
        smallest i in 2..p-1 that ModP.qr_p reports as a non-residue.
        raises AssertionError when there is none.
        """
        # p-1 must be included: it is the only non-residue when p==3
        for i in range(2,self.p):
            if not ModP.qr_p(i,self.p):
                return i
        raise AssertionError(f'no quadratic non-residue found mod {self.p}')

    def make_one(self,t):
        """
        least i>=1 such that t**(2**i) == 1 mod p, found by repeated
        squaring. does not terminate if no such i exists.
        """
        i = 1
        e = (t*t)%self.p
        while e!=1:
            i += 1
            e = (e*e)%self.p
        return i

    @staticmethod
    def s_q_form(n):
        """
        splits off the powers of two: returns (i,q) with
        n == 2**i * q and q odd.
        raises ValueError for n==0, which has no such form.
        """
        if n==0:
            # 0 stays even forever; the loop below would never end
            raise ValueError('s_q_form is undefined for 0')
        i = 0
        while n%2==0:
            i += 1
            n //= 2
        return (i,n)

    def sq_root(self,n):
        """
        returns r with r*r == n mod p (p-r is the other root),
        or 0 when ModP.qr_p says n is not a nonzero quadratic residue.
        """
        if not ModP.qr_p(n,self.p):
            return 0 

        (s,q) = TonelliShanks.s_q_form(self.p-1)
        m = s
        t = ModP.exp_mod(n,q,self.p)
        r = ModP.exp_mod(n,((q+1)//2),self.p)
        # loop invariant: r*r == t*n mod p
        if t==1:
            # already done; no non-residue needed
            return r

        c = ModP.exp_mod(self.find_qnr(),q,self.p)
        while t!=1:
            if t==0: return 0
            i = self.make_one(t)
            b = ModP.exp_mod(c,2**(m-i-1),self.p)
            m = i
            c = (b*b)%self.p
            t = (t*c)%self.p
            r = (r*b)%self.p
        return r
 
class EasyBits:
    """
    computes the low-order ("easy") bits of the discrete log of n
    to the base g modulo p, using only residuosity tests and
    square roots.

    the number of easy bits is s, where p-1 == 2**s * t with t odd.
    intended for prime p and g a generator mod p; neither is checked.
    """
    
    def __init__(self,p,g):
        self.p = p
        self.g = g
        self.inv_g = ModP.invert(g,p)
        (s,t) = TonelliShanks.s_q_form(p-1)
        self.ts = TonelliShanks(p)
        # p-1 == 2**s * t with t odd
        self.s = s
        self.t = t

    def how_many(self):
        """prints the number of easy bits."""
        print(f'there are {self.s} easy bits')
        
    def easy_bits(self,n):
        """
        returns the s lowest bits of the discrete log of n as a list
        of 0/1 values, most significant of them first.
        raises ValueError when n is 0 mod p (no discrete log exists).
        """
        n = n%self.p
        if n==0:
            raise ValueError('0 has no discrete log')

        eb = []
        for i in range(self.s):
            if ModP.qr_p(n,self.p):
                # even exponent: current bit is 0
                eb.append(0)
            else:
                # odd exponent: current bit is 1; divide by g to make it even
                eb.append(1)
                n = n*self.inv_g % self.p
            if i+1<self.s:
                # halve the exponent; not needed after the last bit
                n = self.ts.sq_root(n)
        # bits were collected lowest first
        eb.reverse()
        return eb
            
      

In [3]:
# demo: a generator mod 11, its powers, and the square roots of each residue
my_prime = 11
my_gen = ModP.find_gen(my_prime)
print(f'<{my_gen}> = Z/{my_prime}Z')
# verbose=True prints the successive powers of the generator
ModP.gen_p(my_gen, my_prime, verbose=True)

tonelli_shanks = TonelliShanks(my_prime)
for i in range(1, my_prime):
    sq = tonelli_shanks.sq_root(i)
    if sq == 0:
        # sq_root returns 0 when i has no square root
        continue
    assert (sq*sq) % my_prime == i
    print(f'sqrt({i})=\t{sq}, {(my_prime-sq%my_prime)}')

<2> = Z/11Z
1:	2
2:	4
3:	8
4:	5
5:	10
6:	9
7:	7
8:	3
9:	6
10:	1
sqrt(1)=	1, 10
sqrt(3)=	5, 6
sqrt(4)=	9, 2
sqrt(5)=	4, 7
sqrt(9)=	3, 8


In [4]:
      
# demo: for every nonzero n mod 13, show the easy bits next to the
# full discrete log (in binary) found by brute force
my_prime = 13
my_gen = ModP.find_gen(my_prime)
eb = EasyBits(my_prime, my_gen)

# verbose=True prints the successive powers of the generator
ModP.gen_p(my_gen, my_prime, verbose=True)

for n in range(1, my_prime):
    low_bits = eb.easy_bits(n)
    log_n = ModP.brute_log(n, my_gen, my_prime)
    print(n, low_bits, bin(log_n))

1:	2
2:	4
3:	8
4:	3
5:	6
6:	12
7:	11
8:	9
9:	5
10:	10
11:	7
12:	1
1 [0, 0] 0b0
2 [0, 1] 0b1
3 [0, 0] 0b100
4 [1, 0] 0b10
5 [0, 1] 0b1001
6 [0, 1] 0b101
7 [1, 1] 0b1011
8 [1, 1] 0b11
9 [0, 0] 0b1000
10 [1, 0] 0b1010
11 [1, 1] 0b111
12 [1, 0] 0b110
