<a href="https://colab.research.google.com/github/orenbara/braude_crypto_project/blob/master/colab/Salsa20_EC_ELGAMAL_RSA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pre-run setup (mount Google Drive and enter the project folder)

In [21]:
# @title Connect Google Drive
from google.colab import drive

mount_path = "/content/drive" # @param {type:"string"}
# Mount Google Drive at mount_path so the project files are reachable from this runtime.
drive.mount(mount_path)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [22]:
cd /content/drive/MyDrive/Colab Notebooks/Crypto Course/

/content/drive/MyDrive/Colab Notebooks/Crypto Course


# Sender Side

In [23]:
# @title Imports and definitions
import hashlib
import math
import os
import random
import secrets
import struct
# Elliptic curve parameters for NIST P-256 (secp256r1):
#   curve  y^2 = x^3 + a*x + b  over the prime field GF(p)
#   G = (Gx, Gy) is the base point and n is the order of G.
p = 2**256 - 2**224 + 2**192 + 2**96 - 1   # field prime
a = p - 3                                  # a = -3 (mod p)
b = 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b
Gx = 0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296
Gy = 0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5
n = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551


In [24]:
# @title Generate RSA Key Pair

def mod_inverse(a, m):
    """Return the inverse of ``a`` modulo ``m``, i.e. ``x`` with ``a * x % m == 1``.

    Iterative extended Euclidean algorithm. ``a`` may be negative or larger
    than ``m``; it is reduced modulo ``m`` first.

    Returns 0 when ``a`` is a multiple of ``m`` (including ``a == 0``), which
    has no inverse; this keeps the convention used for ``a == 0``.

    Raises:
        ValueError: if ``gcd(a, m) != 1`` for any other ``a``, instead of
            returning a value that is not actually an inverse.
    """
    a %= m
    if a == 0:
        return 0
    # Invariant: old_s * a == old_r (mod m) and s * a == r (mod m).
    old_r, r = a, m
    old_s, s = 1, 0
    while r:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_s, s = s, old_s - quotient * s
    # old_r is now gcd(a, m); an inverse exists only when it is 1.
    if old_r != 1:
        raise ValueError(f"{a} has no inverse modulo {m}")
    return old_s % m


"""
Miller-Rabin primality test, a probabilistic algorithm used to determine whether a given number is prime.
"""
def is_prime(n, k=5):
    if n < 2:
        return False
    for p in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]:
        if n % p == 0:
            return n == p
    s, d = 0, n - 1
    while d % 2 == 0:
        s, d = s + 1, d // 2
    for _ in range(k):
        a = random.randrange(2, n - 1)
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True

"""
simple function that will generate a prime number based on number of bits
"""
def generate_prime(bits):
    while True:
        p = random.getrandbits(bits)
        if is_prime(p):
            return p

"""
Generates a public and a private keys which together satisfy the modular equation
e * d = 1 mod (phi(n))
where public key is (n,e) and private key is (n,d)
"""
def generate_rsa_keypair(bits=2048):
    # Since n = p * q, and p and q are roughly the same size, each prime should be about half the bit length of n.
    p = generate_prime(bits // 2)
    q = generate_prime(bits // 2)
    n = p * q
    phi = (p - 1) * (q - 1)
    e = 65537
    d = mod_inverse(e, phi)
    return (n, e), (n, d)


print("Generating RSA keypair...")
# 2048-bit modulus (the function's default), spelled out for the reader.
rsa_public_key, rsa_private_key = generate_rsa_keypair(bits=2048)
print("RSA keypair generated successfully.")


Generating RSA keypair...
RSA keypair generated successfully.


In [25]:
# @title Generate an ECC key pair

def mod_inverse(a, m):
    """Return the inverse of ``a`` modulo ``m``, i.e. ``x`` with ``a * x % m == 1``.

    Iterative extended Euclidean algorithm. ``a`` may be negative or larger
    than ``m``; it is reduced modulo ``m`` first.

    Returns 0 when ``a`` is a multiple of ``m`` (including ``a == 0``), which
    has no inverse; this keeps the convention used for ``a == 0``.

    Raises:
        ValueError: if ``gcd(a, m) != 1`` for any other ``a``, instead of
            returning a value that is not actually an inverse.
    """
    a %= m
    if a == 0:
        return 0
    # Invariant: old_s * a == old_r (mod m) and s * a == r (mod m).
    old_r, r = a, m
    old_s, s = 1, 0
    while r:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_s, s = s, old_s - quotient * s
    # old_r is now gcd(a, m); an inverse exists only when it is 1.
    if old_r != 1:
        raise ValueError(f"{a} has no inverse modulo {m}")
    return old_s % m

def point_add(P1, P2):
    """Add two points on the curve y^2 = x^3 + a*x + b over GF(p).

    Points are (x, y) tuples; ``None`` is the point at infinity (the group
    identity). Uses the module-level curve parameters ``a`` and ``p``.

    Returns:
        P1 + P2 as an (x, y) tuple, or ``None`` for the point at infinity.
    """
    if P1 is None:
        return P2
    if P2 is None:
        return P1
    if P1[0] == P2[0]:
        # Same x: either P2 == -P1, or a doubling of a point with y == 0
        # (vertical tangent). Both sum to the point at infinity.
        if P1[1] != P2[1] or P1[1] % p == 0:
            return None
        # Doubling: slope of the tangent line at P1.
        lam = (3 * P1[0] * P1[0] + a) * mod_inverse(2 * P1[1], p) % p
    else:
        # Addition: slope of the chord through P1 and P2.
        lam = (P2[1] - P1[1]) * mod_inverse(P2[0] - P1[0], p) % p
    x3 = (lam * lam - P1[0] - P2[0]) % p
    y3 = (lam * (P1[0] - x3) - P1[1]) % p
    return (x3, y3)


"""
This uses the double-and-add algorithm, an efficient method for scalar multiplication
Scalar multiplication involves computing the product of a scalar (integer) k and a point P on an elliptic curve
to obtain another point Q on the curve, denoted as Q=kP This operation is analogous to repeated addition in the
context of elliptic curves.
"""
def scalar_mult(k, P):
    Q = None
    for i in range(256):
        if k & (1 << i):
            Q = point_add(Q, P)
        P = point_add(P, P)
    return Q


"""
Generate an ECC key pair
"""
def generate_keypair():
    # Chooses a random integer as the private key
    private_key = random.randint(1, n - 1)

    # Computes the public key by multiplying the generator point (Gx, Gy) by the private key.
    public_key = scalar_mult(private_key, (Gx, Gy))
    return private_key, public_key


print("Generating EC keypair...")
ec_private_key, ec_public_key = generate_keypair()
# NOTE: the private key is printed only to make the demo transparent;
# never print or log a real private key.
print("EC private key:", ec_private_key)
print("EC public key:", ec_public_key)
print("EC keypair generated successfully.")

Generating EC keypair...
EC private key: 83457324207508227216638540510889016463399203744345636264554733957196249720241
EC public key: (24058615692104224531602981659803775712463775818505693915990587731976572846408, 10752262227775125893086188636912150501714532404023915046642728880706600471427)
EC keypair generated successfully.


In [26]:
# @title Generate Salsa20 key

print("Generating SALSA20 key...")
# 32 bytes from the OS CSPRNG form one 256-bit Salsa20 key.
salsa_key = os.urandom(32)
# NOTE: printing the key is for demonstration only; never log real keys.
print("SALSA20 key:", salsa_key.hex())
print("SALSA20 key generated successfully.")

Generating SALSA20 key...
SALSA20 key: 3dc428f1fd93591daf415d45611ab51c3c88af7f5f34d67f2e9ef048ab74be8c
SALSA20 key generated successfully.


In [27]:
# @title Encrypt Salsa20 Key with EC-Elgamal
def point_add(P1, P2):
    """Add two points on the curve y^2 = x^3 + a*x + b over GF(p).

    Points are (x, y) tuples; ``None`` is the point at infinity (the group
    identity). Uses the module-level curve parameters ``a`` and ``p``.

    Returns:
        P1 + P2 as an (x, y) tuple, or ``None`` for the point at infinity.
    """
    if P1 is None:
        return P2
    if P2 is None:
        return P1
    if P1[0] == P2[0]:
        # Same x: either P2 == -P1, or a doubling of a point with y == 0
        # (vertical tangent). Both sum to the point at infinity.
        if P1[1] != P2[1] or P1[1] % p == 0:
            return None
        # Doubling: slope of the tangent line at P1.
        lam = (3 * P1[0] * P1[0] + a) * mod_inverse(2 * P1[1], p) % p
    else:
        # Addition: slope of the chord through P1 and P2.
        lam = (P2[1] - P1[1]) * mod_inverse(P2[0] - P1[0], p) % p
    x3 = (lam * lam - P1[0] - P2[0]) % p
    y3 = (lam * (P1[0] - x3) - P1[1]) % p
    return (x3, y3)

"""
This uses the double-and-add algorithm, an efficient method for scalar multiplication
Scalar multiplication involves computing the product of a scalar (integer) k and a point P on an elliptic curve
to obtain another point Q on the curve, denoted as Q=kP This operation is analogous to repeated addition in the
context of elliptic curves.
"""
def scalar_mult(k, P):
    Q = None
    for i in range(256):
        if k & (1 << i):
            Q = point_add(Q, P)
        P = point_add(P, P)
    return Q



"""
Encrypt a message (in this case, the SALSA20 key) using ECC El-Gamal encryption
"""
def encrypt_key(public_key, plaintext):
    # Chooses a random k for this encryption.
    k = random.randint(1, n - 1)
    C1 = scalar_mult(k, (Gx, Gy))
    S = scalar_mult(k, public_key)
    C2 = (plaintext * S[0]) % p
    return (C1, C2)


print("Encrypting SALSA20 key with EC El-Gamal...")
# The key bytes are read as one big-endian integer so they can be masked in GF(p).
encrypted_key = encrypt_key(ec_public_key, int.from_bytes(salsa_key, byteorder='big'))
print("Encrypted key:", encrypted_key)
print("SALSA20 key encrypted successfully.")

Encrypting SALSA20 key with EC El-Gamal...
Encrypted key: ((38726271288457331098030453662435535331255936851490323255812499400869760638734, 16973498864417693547636103216548996059411422914589199023320990725470600386508), 66202916957528646828573328501534199698719535913990585878234073917140344870498)
SALSA20 key encrypted successfully.


In [28]:
# @title Signing The Encrypted Key With RSA
"""
Create a digital signature for a message using the RSA private key.
"""
def rsa_sign(private_key, message):
    n, d = private_key

    # Hash the message using SHA-1 - we get the MD1 (Message Digest)
    # hash function has many benefits like fix sized, improved security, but it is not mandatory
    hash_object = hashlib.sha1(message)
    hashed_message = hash_object.digest()

    # Convert the hashed message from bytes to an integer - so we could do math
    message_int = int.from_bytes(hashed_message, 'big')

    # Compute the signature as message^d mod n (pow does modulo when provided with 3 params) - this is the encryption
    signature = pow(message_int, d, n)

    # Return the signature as bytes (playing with the numbers to make sure there is no overflow)
    return signature.to_bytes((signature.bit_length() + 7) // 8, 'big')



print("Signing the encrypted key with RSA...")
# The (C1, C2) tuple is serialized through its str() form; the receiver must
# rebuild exactly the same bytes to verify the signature.
signature = rsa_sign(rsa_private_key, str(encrypted_key).encode('utf-8'))
print("Signature:", signature.hex())
print("Signature created successfully.")

Signing the encrypted key with RSA...
Signature: 0d2986cbf6d8b6324c828aed969de87e4e20f8e8a5526576747c4c70b895ce512eddeefb7a257628f5834192c9946efa6473147881dc47a57049b7897544b8ff04956bbeaa3c998a63840e6b566689b0f522d348e2031d400bea7b7df7920cc3d736dafafb10f613af606073258cf8ae95d4308dad8843f51dc94672731b003ed0613fcdf552484d45283bfc6043f2b8b87e36c83998a4f83e1d3bb2dd971cbc2fca5d4b8fa8d7c64472bf76f20e0d73f50776564cb075e290e132d03db6c36cac795a2bdf5de27ed0af1b981109157a1a027104a490befa60f2d8220f01e3386296351e356e95d1b865c3b8eed3ab9a0c9162e1ace5191afc11b0c81b620a28
Signature created successfully.


In [29]:
# @title Encrypt File With Salsa20

import struct
import os

# Constants for SALSA20
# Salsa20/20 is specified with 20 rounds in total, i.e. 10 double rounds
# (each double round = one column round followed by one row round).
ROUNDS = 20


def rotl32(v, c):
    """Rotate the 32-bit word ``v`` left by ``c`` bits (0 <= c <= 32).

    ``v`` is reduced modulo 2**32 first. Otherwise carry bits above bit 31
    (e.g. from an unreduced sum ``a + d``) would be shifted into the low bits,
    and the result could exceed 32 bits.
    """
    v &= 0xffffffff
    return ((v << c) & 0xffffffff) | (v >> (32 - c))


def quarter_round(a, b, c, d):
    """Salsa20 quarter-round on four 32-bit words; returns the updated (a, b, c, d).

    Every addition is taken modulo 2**32, as the Salsa20 specification requires.
    """
    b ^= rotl32((a + d) & 0xffffffff, 7)
    c ^= rotl32((b + a) & 0xffffffff, 9)
    d ^= rotl32((c + b) & 0xffffffff, 13)
    a ^= rotl32((d + c) & 0xffffffff, 18)
    return a, b, c, d


def salsa20_block(input_block):
    """Run the Salsa20 core function on one 64-byte block.

    The block is unpacked into sixteen little-endian 32-bit words and put
    through ``ROUNDS`` rounds. These are applied as ``ROUNDS // 2`` double
    rounds, each a column round followed by a row round. The original words
    are then added back modulo 2**32.

    Args:
        input_block: at most 64 bytes; shorter input is right-padded with
            zero bytes.

    Returns:
        The 64-byte output block (one block of keystream).
    """
    # Pad the input to 64 bytes if necessary.
    input_block = input_block.ljust(64, b'\x00')

    # Convert the 64-byte input block into a list of sixteen 32-bit integers.
    x = list(struct.unpack('<16I', input_block))
    orig_x = x[:]

    # One iteration = one double round = two rounds, so ROUNDS // 2
    # iterations give ROUNDS rounds in total.
    for _ in range(ROUNDS // 2):
        # Column round
        x[0], x[4], x[8], x[12] = quarter_round(x[0], x[4], x[8], x[12])
        x[5], x[9], x[13], x[1] = quarter_round(x[5], x[9], x[13], x[1])
        x[10], x[14], x[2], x[6] = quarter_round(x[10], x[14], x[2], x[6])
        x[15], x[3], x[7], x[11] = quarter_round(x[15], x[3], x[7], x[11])

        # Row round
        x[0], x[1], x[2], x[3] = quarter_round(x[0], x[1], x[2], x[3])
        x[5], x[6], x[7], x[4] = quarter_round(x[5], x[6], x[7], x[4])
        x[10], x[11], x[8], x[9] = quarter_round(x[10], x[11], x[8], x[9])
        x[15], x[12], x[13], x[14] = quarter_round(x[15], x[12], x[13], x[14])

    # Feed-forward: adding the input back makes the core one-way (the rounds
    # alone are invertible). Decryption never inverts this function; it
    # regenerates the same keystream and XORs it again.
    return struct.pack('<16I', *((w + o) & 0xffffffff for w, o in zip(x, orig_x)))


def salsa20_encrypt(key, nonce, plaintext):
    """XOR ``plaintext`` with the Salsa20 keystream for ``key`` and ``nonce``.

    Each keystream block comes from a 64-byte state laid out as in the
    Salsa20 specification for 256-bit keys::

        sigma[0:4] | key[0:16] | sigma[4:8] | nonce | counter | sigma[8:12] | key[16:32] | sigma[12:16]

    Here sigma is the constant b"expand 32-byte k", and the block counter is
    a little-endian 64-bit integer starting at 0. The constants keep the
    state from ever being all zeros. Because the keystream is simply XORed
    in, the same call also decrypts.

    Args:
        key: 32-byte key.
        nonce: 8-byte nonce; must never be reused with the same key.
        plaintext: bytes to encrypt (or ciphertext to decrypt).

    Returns:
        bytes of the same length as ``plaintext``.

    Raises:
        ValueError: if the key or nonce has the wrong length.
    """
    if len(key) != 32 or len(nonce) != 8:
        raise ValueError("Key must be 32 bytes and nonce must be 8 bytes")

    sigma = b"expand 32-byte k"
    blocks_needed = -(-len(plaintext) // 64)  # ceil(len / 64)
    # Build all keystream blocks and join once (repeated bytes += is quadratic).
    keystream = b"".join(
        salsa20_block(
            sigma[0:4] + key[:16] + sigma[4:8] + nonce
            + struct.pack('<Q', counter)
            + sigma[8:12] + key[16:] + sigma[12:16]
        )
        for counter in range(blocks_needed)
    )
    # XOR the plaintext with the keystream to produce the ciphertext.
    return bytes(x ^ y for x, y in zip(plaintext, keystream))


# OFB-style chaining of Salsa20 keystream blocks
def ofb_mode_encrypt(key, iv, plaintext):
    """Encrypt ``plaintext`` with Salsa20 keystream blocks chained OFB-style.

    For every 64-byte chunk, one Salsa20 keystream block is generated with
    the current 8-byte feedback value as nonce. The first 8 bytes of that
    keystream become the nonce for the next chunk. XOR is its own inverse,
    so the same function also decrypts.

    Args:
        key: 32-byte Salsa20 key.
        iv: initial nonce; only its first 8 bytes are used.
        plaintext: bytes to process.

    Returns:
        bytes of the same length as ``plaintext``.
    """
    block_size = 64
    zero_block = b'\x00' * block_size
    chunks = []
    feedback = iv
    for start in range(0, len(plaintext), block_size):
        block = plaintext[start:start + block_size]
        # Encrypting zeros yields the raw keystream block for this nonce.
        keystream = salsa20_encrypt(key, feedback[:8], zero_block)[:len(block)]
        chunks.append(bytes(x ^ y for x, y in zip(block, keystream)))
        feedback = keystream[:8]  # Use the first 8 bytes as the next IV
    # Join once instead of repeated bytes concatenation (quadratic on large files).
    return b''.join(chunks)


# File encryption and decryption functions
def encrypt_file(input_file, output_file, key):
    """Encrypt ``input_file`` into ``output_file``, stored as ``iv || ciphertext``.

    A fresh random 8-byte IV from os.urandom is drawn for every call. It is
    stored unencrypted in front of the ciphertext so the receiver can read it
    back. The input is fully read and encrypted before the output file is
    opened. Encrypting a file in place therefore works, and a failure does
    not leave a truncated output file.
    """
    iv = os.urandom(8)
    with open(input_file, 'rb') as f_in:
        plaintext = f_in.read()
    ciphertext = ofb_mode_encrypt(key, iv, plaintext)
    with open(output_file, 'wb') as f_out:
        f_out.write(iv + ciphertext)


input_file = 'input.txt'
print("Encrypting file...")
# Create a small sample input on first run so the demo is self-contained.
if not os.path.exists(input_file):
    print(f"'{input_file}' not found. Creating a test file.")
    with open(input_file, mode='w') as f:
        f.write("This is a test file for encryption and decryption.")

encrypt_file(input_file, 'encrypted.bin', salsa_key)
print("File encrypted successfully.")

Encrypting file...
'input.txt' not found. Creating a test file.
File encrypted successfully.


# Receiver Side

In [30]:
"""
Verify an RSA signature using the public key.
"""
def rsa_verify(public_key, message, signature):
    n, e = public_key

    # Hash the message using SHA-1 - same as in the signing, the receiver want's the same MD
    hash_object = hashlib.sha1(message)
    hashed_message = hash_object.digest()

    # Convert the hashed message and signature to integers
    message_int = int.from_bytes(hashed_message, 'big')
    signature_int = int.from_bytes(signature, 'big')

    # Compute signature^e mod n - this is the decryption part
    decrypted = pow(signature_int, e, n)

    # Check if the result matches the hashed message, if the private key which the sender used, and the publickey
    # which the receiver uses match then using the SHA1 on the same message will result it
    return decrypted == message_int


print("\nSimulating receiver's side:")
print("Verifying signature...")
# The receiver re-serializes the received ciphertext exactly as the sender did.
if not rsa_verify(rsa_public_key, str(encrypted_key).encode(), signature):
    # Stop here: an unauthenticated key must not be decrypted or used.
    raise RuntimeError("Signature verification failed: the encrypted key cannot be trusted")
print("Signature verified successfully")


Simulating receiver's side:
Verifying signature...
Signature verified successfully


In [31]:
# @title Decrypting With Salsa20
def salsa20_decrypt(key, nonce, ciphertext):
    """Decrypt data produced by ``salsa20_encrypt``.

    Salsa20 is a stream cipher: the ciphertext is plaintext XOR keystream,
    so XORing the same keystream in again recovers the plaintext.
    """
    plaintext = salsa20_encrypt(key, nonce, ciphertext)
    return plaintext


# OFB-style chaining of Salsa20 keystream blocks
def ofb_mode_encrypt(key, iv, plaintext):
    """Encrypt ``plaintext`` with Salsa20 keystream blocks chained OFB-style.

    For every 64-byte chunk, one Salsa20 keystream block is generated with
    the current 8-byte feedback value as nonce. The first 8 bytes of that
    keystream become the nonce for the next chunk. XOR is its own inverse,
    so the same function also decrypts.

    Args:
        key: 32-byte Salsa20 key.
        iv: initial nonce; only its first 8 bytes are used.
        plaintext: bytes to process.

    Returns:
        bytes of the same length as ``plaintext``.
    """
    block_size = 64
    zero_block = b'\x00' * block_size
    chunks = []
    feedback = iv
    for start in range(0, len(plaintext), block_size):
        block = plaintext[start:start + block_size]
        # Encrypting zeros yields the raw keystream block for this nonce.
        keystream = salsa20_encrypt(key, feedback[:8], zero_block)[:len(block)]
        chunks.append(bytes(x ^ y for x, y in zip(block, keystream)))
        feedback = keystream[:8]  # Use the first 8 bytes as the next IV
    # Join once instead of repeated bytes concatenation (quadratic on large files).
    return b''.join(chunks)


def ofb_mode_decrypt(key, iv, ciphertext):
    """Decrypt data produced by ``ofb_mode_encrypt``.

    The keystream depends only on ``key`` and ``iv``, and XOR undoes itself,
    so decryption is the encryption routine applied once more.
    """
    plaintext = ofb_mode_encrypt(key, iv, ciphertext)
    return plaintext


def decrypt_file(input_file, output_file, key):
    """Decrypt an ``iv || ciphertext`` file written by ``encrypt_file``.

    The first 8 bytes are the IV and the rest is ciphertext. The whole input
    is read before the output file is opened, so in-place decryption works
    and a failure does not leave a truncated output file.

    Raises:
        ValueError: if the input is too short to contain the 8-byte IV.
    """
    with open(input_file, 'rb') as f_in:
        iv = f_in.read(8)
        ciphertext = f_in.read()
    if len(iv) != 8:
        raise ValueError(f"'{input_file}' is too short to contain the 8-byte IV")
    plaintext = ofb_mode_decrypt(key, iv, ciphertext)
    with open(output_file, 'wb') as f_out:
        f_out.write(plaintext)

def mod_inverse(a, m):
    """Return the inverse of ``a`` modulo ``m``, i.e. ``x`` with ``a * x % m == 1``.

    Iterative extended Euclidean algorithm. ``a`` may be negative or larger
    than ``m``; it is reduced modulo ``m`` first.

    Returns 0 when ``a`` is a multiple of ``m`` (including ``a == 0``), which
    has no inverse; this keeps the convention used for ``a == 0``.

    Raises:
        ValueError: if ``gcd(a, m) != 1`` for any other ``a``, instead of
            returning a value that is not actually an inverse.
    """
    a %= m
    if a == 0:
        return 0
    # Invariant: old_s * a == old_r (mod m) and s * a == r (mod m).
    old_r, r = a, m
    old_s, s = 1, 0
    while r:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_s, s = s, old_s - quotient * s
    # old_r is now gcd(a, m); an inverse exists only when it is 1.
    if old_r != 1:
        raise ValueError(f"{a} has no inverse modulo {m}")
    return old_s % m

def is_on_curve(point):
    """Return True if ``point`` satisfies y^2 = x^3 + a*x + b (mod p).

    ``None`` (the point at infinity) counts as on the curve. Uses the
    module-level curve parameters ``a``, ``b`` and ``p``.
    """
    if point is None:
        return True
    x, y = point
    lhs = pow(y, 2, p)
    rhs = (pow(x, 3, p) + a * x + b) % p
    return lhs == rhs

def point_add(P1, P2):
    """Add two points on the curve y^2 = x^3 + a*x + b over GF(p).

    Points are (x, y) tuples; ``None`` is the point at infinity (the group
    identity). Uses the module-level curve parameters ``a`` and ``p``.

    Returns:
        P1 + P2 as an (x, y) tuple, or ``None`` for the point at infinity.
    """
    if P1 is None:
        return P2
    if P2 is None:
        return P1
    if P1[0] == P2[0]:
        # Same x: either P2 == -P1, or a doubling of a point with y == 0
        # (vertical tangent). Both sum to the point at infinity.
        if P1[1] != P2[1] or P1[1] % p == 0:
            return None
        # Doubling: slope of the tangent line at P1.
        lam = (3 * P1[0] * P1[0] + a) * mod_inverse(2 * P1[1], p) % p
    else:
        # Addition: slope of the chord through P1 and P2.
        lam = (P2[1] - P1[1]) * mod_inverse(P2[0] - P1[0], p) % p
    x3 = (lam * lam - P1[0] - P2[0]) % p
    y3 = (lam * (P1[0] - x3) - P1[1]) % p
    return (x3, y3)


"""
This uses the double-and-add algorithm, an efficient method for scalar multiplication
Scalar multiplication involves computing the product of a scalar (integer) k and a point P on an elliptic curve
to obtain another point Q on the curve, denoted as Q=kP This operation is analogous to repeated addition in the
context of elliptic curves.
"""
def scalar_mult(k, P):
    Q = None
    for i in range(256):
        if k & (1 << i):
            Q = point_add(Q, P)
        P = point_add(P, P)
    return Q


def decrypt_key(private_key, ciphertext):
    """Recover the integer encrypted by ``encrypt_key``.

    Recomputes the shared point S = private_key * C1 and divides C2 by its
    x-coordinate modulo p.

    Args:
        private_key: receiver's EC private scalar.
        ciphertext: (C1, C2) pair, where C1 is a curve point and C2 an integer.

    Returns:
        The plaintext integer.

    Raises:
        ValueError: if C1 is not a point on the curve, or the shared point
            cannot be used to unmask C2.
    """
    C1, C2 = ciphertext
    # Reject points off the curve: multiplying an attacker-chosen invalid
    # point by the private key can leak information about that key.
    if C1 is None or not is_on_curve(C1):
        raise ValueError("C1 is not a valid curve point")
    S = scalar_mult(private_key, C1)
    if S is None or S[0] % p == 0:
        raise ValueError("degenerate shared point; cannot decrypt")
    return (C2 * mod_inverse(S[0], p)) % p


print("Decrypting SALSA20 key...")
decrypted_key = decrypt_key(ec_private_key, encrypted_key)
recovered_salsa_key = decrypted_key.to_bytes(32, 'big')
print("SALSA20 key decrypted successfully.")

print("Decrypting file...")
decrypt_file('encrypted.bin', 'decrypted.txt', recovered_salsa_key)
print("File decrypted successfully")


# Verify the decrypted content byte-for-byte. Binary mode avoids newline
# translation and UnicodeDecodeError when the input is not UTF-8 text.
with open(input_file, 'rb') as f_orig, open('decrypted.txt', 'rb') as f_dec:
    original = f_orig.read()
    decrypted = f_dec.read()
if original == decrypted:
    print("Decryption successful: original and decrypted files match.")
else:
    print("Decryption failed: original and decrypted files do not match.")

Decrypting SALSA20 key...
SALSA20 key decrypted successfully.
Decrypting file...
File decrypted successfully
Decryption successful: original and decrypted files match.
