In [1]:
import random

In [2]:
class MerkleHellmanCryptosystem:
    """
    Implements the Merkle-Hellman knapsack cryptosystem.

    The private key is a superincreasing sequence together with a modulus q
    (larger than the sum of the sequence) and a multiplier w that is coprime
    with q. The public key is the sequence multiplied by w modulo q.

    Attributes:
        length (int): The length of the superincreasing sequence, public key, and messages.
        public_key (list of int): The public key used for encryption.
        private_key (tuple): The private key used for decryption, containing the superincreasing sequence, modulus q, and multiplier w.
    """

    def __init__(self, length):
        """
        Initializes the cryptosystem with a specific length for keys and messages.

        Parameters:
            length (int): The length of the keys and messages (at least 1).

        Raises:
            ValueError: If length is smaller than 1.
        """
        self.length = length
        self.public_key, self.private_key = self.generate_keys(length)

    @staticmethod
    def _extended_gcd(a, b):
        """
        Runs the extended Euclidean algorithm on a and b.

        Only integer floor division, multiplication and subtraction are used,
        so no helper outside this class is required.

        Parameters:
            a (int): First integer.
            b (int): Second integer.

        Returns:
            tuple: (g, x) where g is gcd(a, b) and a*x is congruent to g modulo b.
        """
        old_r, r = a, b
        old_x, x = 1, 0
        while r != 0:
            quotient = old_r // r
            old_r, r = r, old_r - quotient * r
            old_x, x = x, old_x - quotient * x
        return old_r, old_x

    def generate_keys(self, n):
        """
        Generates the public and private keys for the cryptosystem.

        Parameters:
            n (int): The length of the superincreasing sequence (at least 1).

        Returns:
            tuple: A tuple containing the public key (list of int) and private key (tuple of list, int, int).

        Raises:
            ValueError: If n is smaller than 1.
        """
        if n < 1:
            raise ValueError("Key length must be a positive integer.")

        # Superincreasing sequence for the private key: every new term is
        # strictly larger than the sum of all previous ones. A running total
        # avoids re-summing the whole list on each iteration.
        a_n = [random.randint(1, 100)]
        total = a_n[0]
        for _ in range(1, n):
            next_term = total + random.randint(1, 10)
            a_n.append(next_term)
            total += next_term

        # Modulus q, larger than the sum of the superincreasing sequence.
        # The offset is at least 2 so that q >= 3, which guarantees that a
        # multiplier in [2, q - 1] coprime with q exists (q - 1 always is).
        q = total + random.randint(2, 10)

        # Multiplier w from the upper half of [2, q - 1], explicitly checked
        # to be coprime with q so that its modular inverse exists.
        lower_bound = max(2, q // 2)
        while True:
            w = random.randrange(lower_bound, q)
            if self._extended_gcd(w, q)[0] == 1:
                break

        # Public key derived from private key elements
        public_key = [(w * a_i) % q for a_i in a_n]
        private_key = (a_n, q, w)

        return public_key, private_key

    def encrypt(self, message):
        """
        Encrypts a binary message using the public key.

        Parameters:
            message (list of int): The binary message to be encrypted.

        Returns:
            int: The encrypted message as an integer.

        Raises:
            ValueError: If the message length differs from the public key length.
        """
        if len(message) != len(self.public_key):
            raise ValueError("Message length must match the public key length.")
        return sum(m_i * b_i for m_i, b_i in zip(message, self.public_key))

    def decrypt(self, cipher):
        """
        Decrypts an encrypted message using the private key.

        Parameters:
            cipher (int): The encrypted message.

        Returns:
            list of int: The decrypted binary message.

        Raises:
            ValueError: If the multiplier stored in the private key has no inverse modulo q.
        """
        a_n, q, w = self.private_key
        divisor, w_inv = self._extended_gcd(w, q)
        if divisor != 1:
            raise ValueError("Multiplier w is not invertible modulo q.")
        # w_inv may be negative; the reduction modulo q below normalises it.
        c_prim = (cipher * w_inv) % q

        # Greedy subset-sum over the superincreasing sequence, largest first.
        message = []
        for a_i in reversed(a_n):
            if a_i <= c_prim:
                message.append(1)
                c_prim -= a_i
            else:
                message.append(0)
        message.reverse()
        return message

    @staticmethod
    def generate_message(length):
        """
        Generates a random binary message of a specified length.

        Parameters:
            length (int): The length of the message to generate.

        Returns:
            list of int: A random binary message.
        """
        return [random.randint(0, 1) for _ in range(length)]

    @staticmethod
    def error_count(arr1, arr2):
        """
        Counts the number of differences between two arrays.

        Parameters:
            arr1 (list of int): The first array.
            arr2 (list of int): The second array.

        Returns:
            int: The number of differences between the two arrays.

        Raises:
            ValueError: If the arrays have different lengths.
        """
        if len(arr1) != len(arr2):
            raise ValueError("Both arrays must have the exact same length!")
        return sum(1 for x, y in zip(arr1, arr2) if x != y)

    def print_info(self, original_message, encrypted_message, decrypted_message):
        """
        Prints information about the encryption and decryption process.

        Parameters:
            original_message (list of int): The original binary message.
            encrypted_message (int): The encrypted message.
            decrypted_message (list of int): The decrypted binary message.
        """
        print("Length:", self.length)
        print("Public key:", self.public_key)
        print("Private key:", self.private_key)
        print("Original message:", original_message)
        print("Encrypted message:", encrypted_message)
        print("Decrypted message:", decrypted_message)
        print("Errors:", self.error_count(original_message, decrypted_message))

    def run(self, message=None):
        """
        Demonstrates the complete encryption and decryption process.

        Parameters:
            message (list of int): Optional parameter representing a binary message. If it is missing, a random message is generated.

        Raises:
            ValueError: If a message is provided whose length differs from the cryptosystem length.
        """
        if message is None:
            original_message = self.generate_message(self.length)
        else:
            if len(message) != self.length:
                raise ValueError("Provided message length must match the cryptosystem length.")
            original_message = message
        encrypted_message = self.encrypt(original_message)
        decrypted_message = self.decrypt(encrypted_message)
        self.print_info(original_message, encrypted_message, decrypted_message)

    # Added method to prove the LD attack
    def get_public_key(self):
        """
        Returns the public key, i.e. the only key material an attacker needs.

        Returns:
            list of int: The public key (None if it has been unset).
        """
        return self.public_key


In [3]:
class LowDensityAttack:
    """
    Lattice-reduction (low-density) attack on a subset-sum instance.

    Given integers a_1..a_n and a target M, the attack looks for a 0/1 vector
    x with a_1*x_1 + ... + a_n*x_n = M by reducing a lattice basis built from
    the instance and inspecting the reduced rows.
    """

    def __init__(self, a, M):
        """
        Initializes the LowDensityAttack class with a vector 'a' and a target sum 'M'.

        Parameters:
        a (list): The list of integers.
        M (int): The target sum.
        """
        self.a = a
        self.M = M
        self.n = len(a)

    def generate_base_step_1(self, M):
        """
        Generates the initial basis matrix for the lattice.

        Row i (for i < n) is the i-th unit vector extended with -a[i]; the
        last row is all zeros extended with M.

        Parameters:
        M (int): The target sum.

        Returns:
        Matrix: The initial basis matrix, of size (n + 1) x (n + 1).
        """
        basis = []
        for i in range(self.n):
            b_i = [0 for _ in range(self.n)]
            b_i[i] = 1
            b_i.append(-self.a[i])
            basis.append(b_i)
        b_last = [0 for _ in range(self.n)]
        b_last.append(M)
        basis.append(b_last)
        return Matrix(basis)

    def generate_reduced_basis_step_2(self, basis):
        """
        Applies the LLL algorithm to the basis to obtain a reduced basis.

        Parameters:
        basis (Matrix): The initial basis matrix.

        Returns:
        Matrix: The reduced basis matrix.
        """
        return basis.LLL()

    def check_solution(self, x, target=None):
        """
        Checks if a given vector 'x' is a solution to the equation a*x = target.

        Parameters:
        x (list): The vector to check (only its first n entries are used).
        target (int): The sum to compare against. Defaults to the target M
            given at construction time.

        Returns:
        bool: True if 'x' is a solution, False otherwise (including x is None).
        """
        if x is None:
            return False
        if target is None:
            target = self.M
        # Indexed access on purpose: a vector shorter than n is an error,
        # not something to be silently truncated.
        return sum(x[i] * self.a[i] for i in range(self.n)) == target

    def check_all_zero_or_lambda(self, vector, lamb):
        """
        Checks if all components of a vector are either 0 or 'lamb'.

        Parameters:
        vector (list): The vector to check.
        lamb (int): The only non-zero value a component is allowed to take.

        Returns:
        bool: True if all components are either 0 or 'lamb', False otherwise.
        """
        return all(component == 0 or component == lamb for component in vector)

    def check_basis_step_3(self, basis, target=None):
        """
        Checks the reduced basis for potential solutions.

        A row qualifies when every entry (including the last column) is either
        0 or one common non-zero value lambda; the candidate solution then has
        a 1 wherever the row holds lambda among its first n entries.

        Parameters:
        basis (Matrix): The reduced basis matrix; rows 0..n are inspected.
        target (int): The sum the candidate must reach. Defaults to the
            target M given at construction time.

        Returns:
        list: A valid 0/1 solution vector if found, None otherwise.
        """
        for i in range(self.n + 1):
            row = basis[i]
            # Only the first non-zero entry can be lambda: any other choice
            # would already be contradicted by that first entry.
            lamb = next((row[j] for j in range(self.n) if row[j] != 0), None)
            if lamb is None or not self.check_all_zero_or_lambda(row, lamb):
                continue

            # Every entry is 0 or lambda here, so entry / lambda is 0 or 1.
            candidate = [1 if row[j] != 0 else 0 for j in range(self.n)]
            if self.check_solution(candidate, target):
                return candidate
        return None

    def _reduce_and_search(self, target):
        """Builds the lattice for 'target', reduces it and searches its rows."""
        basis = self.generate_base_step_1(target)
        reduced_basis = self.generate_reduced_basis_step_2(basis)
        return self.check_basis_step_3(reduced_basis, target)

    def solve(self):
        """
        Solves the low-density attack problem.

        First the lattice for M is searched. If that fails, the complementary
        problem with target sum(a) - M is searched and any hit is flipped
        (1 - y_i) to obtain a solution of the original problem.

        Returns:
        list: A solution vector if found; otherwise prints "NO SOLUTION FOUND"
            and returns None.
        """
        if self.M == 0:
            return [0 for _ in range(self.n)]

        solution = self._reduce_and_search(self.M)
        if solution is not None:
            return solution

        # Step 4 of the algorithm: attack the complementary instance.
        new_M = sum(self.a) - self.M
        if new_M == 0:
            # M equals the sum of all elements, so every element is selected.
            return [1 for _ in range(self.n)]
        complement = self._reduce_and_search(new_M)
        if complement is None:
            print("NO SOLUTION FOUND")
            return None
        # y solves a*y = sum(a) - M, hence 1 - y solves a*x = M.
        return [1 - y_i for y_i in complement]


In [4]:
def test_LD():
    """
    Runs the low-density attack on one random subset-sum instance.

    A random set of 1 to 10 integers in [1, 100] is drawn together with a
    random 0/1 selection; the selected elements define the target sum. The
    instance is printed and then handed to LowDensityAttack.

    Returns:
        bool: True if the vector returned by the attack reaches the target sum.
    """
    size = randint(1, 10)
    weights = [randint(1, 100) for _ in range(size)]
    chosen_bits = [randint(0, 1) for _ in range(size)]
    # Target is the sum of exactly those weights whose bit is set.
    target = sum(weight * bit for weight, bit in zip(weights, chosen_bits))
    print("Set: ", weights)
    print("Target: ", target)
    attack = LowDensityAttack(weights, target)
    recovered = attack.solve()
    return attack.check_solution(recovered)
test_LD()

Set:  [37, 27, 18, 66, 62, 83, 34, 59]
Target:  182


True

In [5]:
def brokeMH(length):
    """
    Attacks a freshly generated Merkle-Hellman instance using only public data.

    A cryptosystem of the requested length is created, a random binary
    message is encrypted with it, and the low-density attack is run on the
    public key and the ciphertext. Both messages and the verdict are printed.

    Parameters:
        length (int): The length of the keys and of the random message.
    """
    cryptosystem = MerkleHellmanCryptosystem(length)
    plaintext_bits = [randint(0, 1) for _ in range(length)]
    # The attacker only sees the public key and the ciphertext.
    public_weights = cryptosystem.get_public_key()
    ciphertext = cryptosystem.encrypt(plaintext_bits)

    recovered_bits = LowDensityAttack(public_weights, ciphertext).solve()

    print("Original message: ", plaintext_bits)
    print("Message obtained after attacking: ", recovered_bits)
    verdict = (
        "--------SUCCESSFUL ATTACK--------"
        if plaintext_bits == recovered_bits
        else "--------UNSUCCESSFUL ATTACK--------"
    )
    print(verdict)

In [12]:
# Demo run: attack a Merkle-Hellman instance with keys and message of length 30.
brokeMH(30)

Original message:  [0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1]
Message obtained after attacking:  [0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1]
--------SUCCESSFUL ATTACK--------
