# Pohlig-Hellman algorithm for discrete logarithms


This program was written for an assignment in the course "Cryptography" of my master’s program:  
"Use the Pohlig-Hellman algorithm to calculate the discrete logarithm $x$ inside $\mathbb{Z}_N=\mathbb{Z}_{1693}$, where $17^x \equiv 101 \, mod \, 1693$."
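For a group this small, the answer can also be found by exhaustive search, which is a useful cross-check for the result at the end of the notebook. Below is a minimal sketch that tries $x = 0, 1, \dots, N-2$ in turn. The helper `brute_force_log` is hypothetical and not part of the assignment. It assumes $N$ is prime, so the order of $g$ divides $N-1$.

```python
def brute_force_log(N, g, a):
    """Smallest x >= 0 with g^x ≡ a (mod N), found by trying every exponent (hypothetical helper)."""
    value = 1  # g^0 mod N
    # For prime N the order of g divides N - 1, so exponents 0..N-2 cover every power of g
    for x in range(N - 1):
        if value == a % N:
            return x
        value = value * g % N
    return None  # a is not a power of g modulo N

x = brute_force_log(1693, 17, 101)
print(x)  # 1060
```

This takes up to $N-1$ multiplications, which is exactly the cost that Pohlig-Hellman avoids when $N-1$ has only small prime factors.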

<u>**Pohlig-Hellman algorithm**</u>  
$\cdot$ Input: A cyclic group $G=\langle g \rangle$ with $|G|=n>1$, a prime factorization $n=p_1^{e_1}\dots p_k^{e_k}$ and $a \in G$  
$\cdot$ Output: $x=\log_g a$
1. For $i=1, \dots, k$ we compute the quantities:  
$$n_i=n/p_i^{e_i}, \, g_i = g^{n_i}, \, a_i=a^{n_i}, \, \gamma_i = g_i^{p_i^{e_i-1}}$$

2. For $i=1, \dots, k$ we compute the discrete logarithms:   
\begin{equation*}
\begin{aligned}
x_{i0} & =\log_{\gamma_i} a_i^{p_i^{e_i-1}} \\  
x_{ij} & =\log_{\gamma_i} \left( a_i g_i^{-(x_{i0}+x_{i1}p_i+\dots+x_{i,j-1}p_i^{j-1})} \right)^{p_i^{e_i-1-j}} \; (j=1,\dots,e_i-1)
\end{aligned}
\end{equation*}

3. For $i=1, \dots, k$ we calculate the sum:  
$$x_i=x_{i0}+x_{i1}p_i + \dots + x_{i,e_i-1}p_i^{e_i-1}$$

4. We calculate $x \in \mathbb{Z}$ with $0 \leq x \leq n-1$ such that:
$$x \equiv x_i \, (mod \, p_i^{e_i}) \; (i=1,\dots,k)$$

5. We extract $x=\log_g a$.  
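As a worked example, here are steps 1, 3 and 4 for the assignment's problem ($N=1693$, $n=1692=2^2\cdot 3^2\cdot 47$, $g=17$, $a=101$). The digits $x_{ij}$ are the ones the notebook computes in its final output: $[0,0]$ for $p_1=2$, $[1,2]$ for $p_2=3$ and $[26]$ for $p_3=47$.

$$n_1 = \frac{1692}{2^2} = 423, \quad n_2 = \frac{1692}{3^2} = 188, \quad n_3 = \frac{1692}{47} = 36$$

$$x_1 = 0 + 0\cdot 2 = 0, \quad x_2 = 1 + 2\cdot 3 = 7, \quad x_3 = 26$$

$$x \equiv 0 \pmod{4}, \quad x \equiv 7 \pmod{9}, \quad x \equiv 26 \pmod{47} \;\Longrightarrow\; x = 1060$$

Indeed $1060 = 4\cdot 265 = 9\cdot 117 + 7 = 47\cdot 22 + 26$.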

Author: Florias Papadopoulos

## Importing modules

We start by importing the modules that we will use.

```python
import math
```

## Defining the functions

### Starter functions

#### (a) shanksAlgorithm

We use our implementation of Shanks' (baby-step giant-step) algorithm to compute the discrete logarithms needed in step 2 of the Pohlig-Hellman algorithm.    

As mentioned before, our goal is to find $x=\log_g a$, given a cyclic group $G=\langle g \rangle$ with $|G|=n$ and $g^x \equiv a \pmod N$.  
Therefore, the function below takes as input the values $N, n, g, a$ and returns $x$, along with the other values computed in each step of the algorithm.
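The implementation below keeps the baby steps in a list and searches it with `in`/`index`, which costs $O(m)$ per giant step. A common variant stores the baby steps in a dictionary, so each lookup is $O(1)$ on average. The sketch below shows this variant. It is not the notebook's method, and the name `bsgs_log` is hypothetical. It assumes $n$ is the order of $g$ modulo $N$, or a multiple of it.

```python
import math

def bsgs_log(N, n, g, a):
    """Baby-step giant-step with a dict (hypothetical helper).

    Returns the smallest x >= 0 with g^x ≡ a (mod N), or None, assuming n is
    (a multiple of) the order of g modulo N.
    """
    m = math.isqrt(n) + 1  # m^2 > n, so every x in [0, n) is i*m + j with 0 <= i, j < m
    # Baby steps: map g^j mod N -> smallest such j
    baby = {}
    value = 1
    for j in range(m):
        baby.setdefault(value, j)
        value = value * g % N
    # Giant steps: compare a * g^(-i*m) with the stored baby steps
    factor = pow(g, -m, N)  # g^(-m) mod N (negative exponents need Python 3.8+)
    current = a % N
    for i in range(m):
        if current in baby:
            return i * m + baby[current]
        current = current * factor % N
    return None

print(bsgs_log(1693, 1692, 17, 101))  # 1060
```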

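```python
# More on Shanks' algorithm can be found in the corresponding .ipynb
def shanksAlgorithm(N, n, g, a):
    """Baby-step giant-step: return x with g^x ≡ a (mod N), where n is the order of g.

    Also returns the intermediate values d, d^q, q, r, m and B_list.
    """
    # Step 1: m = floor(sqrt(n)) + 1, so that m^2 > n
    m = math.isqrt(n) + 1

    # Step 2 (baby steps): B_list[r] = a * g^(-r) mod N for r = 0, ..., m-1
    g_inv = pow(g, -1, N)  # modular inverse of g (Python 3.8+)
    B_list = []
    for r in range(m):
        agr = a * pow(g_inv, r, N) % N
        B_list.append(agr)
        if agr == 1:
            # a = g^r, so x = r (this corresponds to q = 0)
            return r, 0, 0, 0, 0, m, B_list

    # Step 3 (giant steps): look for d^q = g^(q*m) in B_list, for q = 1, ..., m
    d = pow(g, m, N)
    for q in range(1, m + 1):
        dq = pow(d, q, N)
        if dq in B_list:
            r = B_list.index(dq)
            return q * m + r, d, dq, q, r, m, B_list

    # Reached only if a is not in the subgroup generated by g
    raise ValueError("a is not a power of g modulo N")
```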

#### (b) chinese_remainder 

We also need a function that can calculate the $x \in \mathbb{Z}$ with $0 \leq x \leq n-1$ such that: 
$\; \; x \equiv x_i \, (mod \, p_i^{e_i}) \; (i=1,\dots,k)$.

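Theoretically, this is done using the Chinese Remainder Theorem, which applies because the moduli $p_i^{e_i}$ are pairwise coprime.  
In practice, the function below takes as input $n_{list}=[n_1, \dots, n_k]$ (pairwise coprime) and $x_{list}=[x_1,\dots,x_k]$ and returns the unique $x$ with $0 \leq x < n_1 \cdots n_k$ such that:  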
$$x \equiv x_i \, (mod \, n_i) \; (i=1,\dots,k)$$
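Since Python 3.8 the modular inverse is also available as `pow(p, -1, n_i)`, and the product of a list as `math.prod`. Together they give a shorter equivalent of the `chinese_remainder` function that follows. The name `crt` is hypothetical, and the moduli are assumed to be pairwise coprime.

```python
import math

def crt(n_list, x_list):
    """Unique x in [0, prod(n_list)) with x ≡ x_i (mod n_i); the n_i must be pairwise coprime."""
    prod = math.prod(n_list)  # Python 3.8+
    x = 0
    for n_i, x_i in zip(n_list, x_list):
        p = prod // n_i
        # p * pow(p, -1, n_i) is ≡ 1 (mod n_i) and ≡ 0 (mod n_j) for every j != i
        x += x_i * p * pow(p, -1, n_i)
    return x % prod

print(crt([4, 9, 47], [0, 7, 26]))  # 1060
```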

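```python
# Found online (Python 3.6)
from functools import reduce

def chinese_remainder(n_list, x_list):
    """Unique x in [0, n_1*...*n_k) with x ≡ x_i (mod n_i), for pairwise coprime n_i."""
    prod = reduce(lambda u, v: u * v, n_list)
    total = 0
    for n_i, x_i in zip(n_list, x_list):
        p = prod // n_i
        # p * mul_inv(p, n_i) is ≡ 1 (mod n_i) and ≡ 0 (mod n_j) for j != i
        total += x_i * mul_inv(p, n_i) * p
    return total % prod

def mul_inv(a, b):
    """Inverse of a modulo b via the extended Euclidean algorithm (assumes gcd(a, b) = 1)."""
    b0 = b
    x0, x1 = 0, 1
    if b == 1:
        return 1
    while a > 1:
        q = a // b
        a, b = b, a % b
        x0, x1 = x1 - q * x0, x0
    if x1 < 0:
        x1 += b0
    return x1
```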

#### (c) generatePrimeFactors

Finally, we need a function that takes the number $n$ and returns its prime factors together with their exponents.  
It first uses the Sieve of Eratosthenes to find the smallest prime factor of every integer up to $n$, and then repeatedly divides $n$ by its smallest prime factor.
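For a single $n$, a full sieve is not strictly necessary. Trial division up to $\sqrt{n}$ returns the same two lists in $O(\sqrt{n})$ time and avoids the sieve's two lists of length $n+1$. Below is a minimal sketch of this alternative. The name `prime_factors_trial` is hypothetical, and the notebook itself keeps the sieve.

```python
def prime_factors_trial(n):
    """Return [primes, exponents] of n >= 2 by trial division, e.g. 12 -> [[2, 3], [2, 1]]."""
    primes, exponents = [], []
    d = 2
    while d * d <= n:
        if n % d == 0:
            e = 0
            while n % d == 0:
                n //= d
                e += 1
            primes.append(d)
            exponents.append(e)
        d += 1 if d == 2 else 2  # after 2, only odd candidates
    if n > 1:  # whatever remains (> 1) has no factor <= its square root, so it is prime
        primes.append(n)
        exponents.append(1)
    return [primes, exponents]

print(prime_factors_trial(1692))  # [[2, 3, 47], [2, 2, 1]]
```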

```python
# Python 3 program computing prime factors and their powers using the Sieve of Eratosthenes
# (found online, with small changes by me -F).
# The sieve stores in s[i] the smallest prime factor of i. For example, for N = 10:
# s[2] = s[4] = s[6] = s[8] = s[10] = 2, s[3] = s[9] = 3, s[5] = 5, s[7] = 7

def generatePrimeFactors(N):
    # s[i] will store the smallest prime factor of i
    s = [0] * (N + 1)
    # is_composite[i] becomes True once i is found to be a multiple of a smaller odd prime
    is_composite = [False] * (N + 1)

    # The smallest prime factor of every even number is 2
    for i in range(2, N + 1, 2):
        s[i] = 2
    # Odd numbers up to N
    for i in range(3, N + 1, 2):
        if not is_composite[i]:
            # i is prime, so it is its own smallest prime factor
            s[i] = i
            # Odd multiples i*j (odd j >= i) not yet marked have i as smallest prime factor
            for j in range(i, N // i + 1, 2):
                if not is_composite[i * j]:
                    is_composite[i * j] = True
                    s[i * j] = i

    # Divide N repeatedly by its smallest prime factor, counting repeats (my own change)
    curr = s[N]  # current prime factor of N
    cnt = 1      # power of the current prime factor
    curr_set = []
    cnt_set = []
    while N > 1:
        N //= s[N]
        # If the new N has the same smallest prime factor, increment the power
        if curr == s[N]:
            cnt += 1
            continue
        curr_set.append(curr)
        cnt_set.append(cnt)
        # Move on to the next prime factor and reset its count
        curr = s[N]
        cnt = 1

    return [curr_set, cnt_set]
```

### Pohlig-Hellman function

We will now create our main function based on the Pohlig-Hellman algorithm, albeit with a twist.  
In particular, as mentioned before, we use Shanks' algorithm to compute the discrete logarithms in step 2 of the algorithm.  
Moreover, our function outputs not only the discrete logarithm but also all the values calculated in each step.
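To cross-check `pohligHellmanAlgorithm`, here is a compact, self-contained sketch of the same four steps. It is not the notebook's implementation. It uses trial division instead of the sieve and a brute-force search instead of Shanks' algorithm inside each subgroup of prime order $p_i$, where only $p_i$ candidates exist. It also accumulates $x_i$ digit by digit instead of storing the $x_{ij}$. The name `pohlig_hellman_sketch` is hypothetical, $N$ is assumed prime, and $g$ is assumed to have order exactly $n$ modulo $N$.

```python
def pohlig_hellman_sketch(N, n, g, a):
    """x in [0, n) with g^x ≡ a (mod N), assuming N prime and g of order n (hypothetical helper)."""
    # Factor n by trial division: factors[p] = e
    factors, rest, d = {}, n, 2
    while d * d <= rest:
        while rest % d == 0:
            factors[d] = factors.get(d, 0) + 1
            rest //= d
        d += 1
    if rest > 1:
        factors[rest] = factors.get(rest, 0) + 1

    residues, moduli = [], []
    for p, e in factors.items():
        n_i = n // p**e
        g_i, a_i = pow(g, n_i, N), pow(a, n_i, N)   # step 1
        gamma = pow(g_i, p**(e - 1), N)              # gamma_i has order p
        x_i = 0                                      # x_{i0} + ... + x_{i,j-1} p^(j-1) so far
        for j in range(e):                           # step 2: digit x_{ij}
            target = pow(a_i * pow(g_i, -x_i, N) % N, p**(e - 1 - j), N)
            # Brute-force log base gamma (raises StopIteration if a is not a power of g)
            digit = next(t for t in range(p) if pow(gamma, t, N) == target)
            x_i += digit * p**j                      # step 3: accumulate x_i
        residues.append(x_i)
        moduli.append(p**e)

    # Step 4: Chinese Remainder Theorem (the moduli multiply to n)
    x = 0
    for r, q in zip(residues, moduli):
        c = n // q
        x += r * c * pow(c, -1, q)
    return x % n

print(pohlig_hellman_sketch(1693, 1692, 17, 101))  # 1060
```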

```python
def pohligHellmanAlgorithm(N, n, g, a):

    # Prime factorization n = p_1^e_1 * ... * p_k^e_k
    # (attention: Python lists are indexed from 0, so p_list[0] is p_1)
    p_list, e_list = generatePrimeFactors(n)
    k = len(p_list)

    # Step 1: n_i = n / p_i^e_i, g_i = g^n_i, a_i = a^n_i, gamma_i = g_i^(p_i^(e_i - 1))
    # (n_i is an exponent, so it is not reduced modulo N)
    n_list = [n // p_list[i] ** e_list[i] for i in range(k)]
    g_list = [pow(g, n_list[i], N) for i in range(k)]
    a_list = [pow(a, n_list[i], N) for i in range(k)]
    gamma_list = [pow(g_list[i], p_list[i] ** (e_list[i] - 1), N) for i in range(k)]

    # Step 2: the base-p_i digits x_{i0}, ..., x_{i,e_i-1} of each x_i
    g_inv_list = [pow(g_list[i], -1, N) for i in range(k)]  # g_i^(-1) mod N
    x_ij_list = []
    for i in range(k):
        p, e = p_list[i], e_list[i]
        x_j_list = [0] * e
        for j in range(e):
            # Exponent x_{i0} + x_{i1} p_i + ... + x_{i,j-1} p_i^(j-1) of the digits found so far
            exp_of_g_i = sum(x_j_list[j0] * p ** j0 for j0 in range(j))
            # (a_i * g_i^(-exp_of_g_i))^(p_i^(e_i-1-j)); for j = 0 this is a_i^(p_i^(e_i-1))
            num_j = pow(a_list[i] * pow(g_inv_list[i], exp_of_g_i, N), p ** (e - 1 - j), N)
            # gamma_i has order p_i, so Shanks' algorithm only needs the group order p_i
            x_j_list[j] = shanksAlgorithm(N, p, gamma_list[i], num_j)[0]
        x_ij_list.append(x_j_list)

    # Step 3: x_i = x_{i0} + x_{i1} p_i + ... + x_{i,e_i-1} p_i^(e_i-1)
    x_list = [sum(x_ij_list[i][j] * p_list[i] ** j for j in range(e_list[i])) for i in range(k)]

    # Step 4: combine x ≡ x_i (mod p_i^e_i) with the Chinese Remainder Theorem
    last_list = [p_list[i] ** e_list[i] for i in range(k)]
    x = chinese_remainder(last_list, x_list)

    return p_list, e_list, n_list, g_list, a_list, gamma_list, x_ij_list, x_list, last_list, x
```

## Solving the problem

We now use the function above to solve the problem and print a report of the values computed in each step of the algorithm.  
In our problem $N=1693$, which is prime, so $\mathbb{Z}_{1693}^*$ is cyclic of order $n=N-1=1692$; moreover $g=17$ and $a=101$.

```python
# Input: N is prime, n = N - 1 is the order of Z_N^*, g is the base and a the target
N, n, g, a = 1693, 1692, 17, 101

(p_list, e_list, n_list, g_list, a_list, gamma_list,
 x_ij_list, x_list, last_list, x) = pohligHellmanAlgorithm(N, n, g, a)
k = len(p_list)

print(f"We wanted to find the discrete logarithm log_{g}({a})")
print("For this, we used the Pohlig-Hellman algorithm and did the following:")
print()
print(f"First, we found the prime factorization n = N-1 = p_1**e_1 * ... * p_k**e_k, "
      f"with the p_i being {p_list} and the e_i being {e_list}.")
print(f"In step 1, we calculated the n_i, g_i, a_i and γ_i (for i=1,...,{k}), getting:")
print("- The set of n_i", n_list)
print("- The set of g_i", g_list)
print("- The set of a_i", a_list)
print("- The set of γ_i", gamma_list)
print("In step 2, we calculated the discrete logarithms x_ij, getting the digits [x_i0, x_i1, ...] for each i:")
for digits in x_ij_list:
    print(digits)
print("In step 3, we calculated the x_i, getting", x_list)
print("In step 4, we used the Chinese Remainder Theorem to solve the system:")
for x_i, modulus in zip(x_list, last_list):
    print(f"x ≡ {x_i} (mod {modulus})")
print(f"and finally got the discrete logarithm x = log_{g}({a}) = {x}")
```

```
We wanted to find the discrete logarithm log_17(101)
For this, we used the Pohlig-Hellman algorithm and did the following:

First, we found the prime factorization n = N-1 = p_1**e_1 * ... * p_k**e_k, with the p_i being [2, 3, 47] and the e_i being [2, 2, 1].
In step 1, we calculated the n_i, g_i, a_i and γ_i (for i=1,...,3), getting:
- The set of n_i [423, 188, 36]
- The set of g_i [92, 85, 1241]
- The set of a_i [1, 1252, 642]
- The set of γ_i [1692, 1259, 1241]
In step 2, we calculated the discrete logarithms x_ij, getting the digits [x_i0, x_i1, ...] for each i:
[0, 0]
[1, 2]
[26]
In step 3, we calculated the x_i, getting [0, 7, 26]
In step 4, we used the Chinese Remainder Theorem to solve the system:
x ≡ 0 (mod 4)
x ≡ 7 (mod 9)
x ≡ 26 (mod 47)
and finally got the discrete logarithm x = log_17(101) = 1060
```
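The result can be checked independently with Python's built-in three-argument `pow`. A second check confirms that $17$ generates all of $\mathbb{Z}_{1693}^*$: it tests that $17^{(N-1)/p} \not\equiv 1$ for each prime $p \in \{2, 3, 47\}$ dividing $N-1$. This makes the logarithm unique modulo $1692$.

```python
N, g, a, x = 1693, 17, 101, 1060

# g^x mod N should give back a
print(pow(g, x, N) == a)  # True

# g is a generator iff g^((N-1)/p) != 1 (mod N) for every prime p dividing N - 1 = 2^2 * 3^2 * 47
print(all(pow(g, (N - 1) // p, N) != 1 for p in (2, 3, 47)))  # True
```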
