# Inteiro aleatório

In [130]:
import secrets

def generate_num(bits: int = 1024) -> int:
    """Return a random odd integer of exactly ``bits`` bits (default 1024).

    The two most significant bits are forced to 1, not just the top one.
    This guarantees that the product of two such numbers has exactly
    ``2 * bits`` bits. With only the top bit set, p * q for two 1024-bit
    numbers can be 2047 bits, one short of the 2048-bit modulus that the
    OAEP wrappers use by default. The lowest bit is forced to 1 so the
    candidate is odd.

    Args:
        bits: exact bit length of the result.

    Raises:
        ValueError: if ``bits`` is smaller than 2.
    """
    if bits < 2:
        raise ValueError("bits must be at least 2")

    # Random integer in [0, 2**bits).
    n = secrets.randbits(bits)

    n |= 0b11 << (bits - 2)  # set the two most significant bits
    n |= 1                   # set bit 0 so the number is odd

    return n


def primes_up_to(n: int) -> list[int]:
    """Return all primes p with 2 <= p <= n, in increasing order.

    Uses the Sieve of Eratosthenes. Every prime up to sqrt(n) crosses out
    its multiples, starting from its own square. Whatever remains uncrossed
    is prime.
    """
    if n < 2:
        return []

    sieve = bytearray([1]) * (n + 1)  # sieve[i] == 1  <=>  i is still a prime candidate
    sieve[0] = sieve[1] = 0

    for base in range(2, int(n**0.5) + 1):
        if not sieve[base]:
            continue
        first = base * base
        # Clear base*base, base*base + base, ..., up to n in a single slice write.
        sieve[first::base] = bytes(len(range(first, n + 1, base)))

    return [value for value, flag in enumerate(sieve) if flag]


def is_divisible_by_small_primes(n: int, primes: list[int]) -> bool:
    """Report whether ``n`` has a proper factor among ``primes``.

    ``primes`` is scanned in order. The scan stops as soon as a prime's
    square exceeds ``n``, since no smaller factor was found by then. A prime
    equal to ``n`` itself does not count as a divisor.
    """
    for candidate in primes:
        if candidate * candidate > n:
            return False
        if not n % candidate:
            return n != candidate
    return False
    

def likely_prime_miller_rabin(n: int, k: int = 40) -> bool:
    """Probabilistic Miller-Rabin primality test.

    Args:
        n: integer to test.
        k: number of rounds, each with a fresh random base. A composite
            survives each round with probability at most 1/4.

    Returns:
        False if ``n`` is certainly composite (or < 2), True if ``n`` is
        probably prime.
    """
    # Small and even inputs are decided directly. The random-base step below
    # needs n >= 5, otherwise secrets.randbelow(n - 3) gets a bound <= 0 and raises.
    if n < 2:
        return False
    if n < 4:
        return True  # 2 and 3
    if n % 2 == 0:
        return False

    # Write n - 1 as 2**s * d with d odd.
    s, d = 0, n - 1
    while d % 2 == 0:
        s += 1
        d //= 2

    for _ in range(k):
        a = secrets.randbelow(n - 3) + 2  # random base in [2, n-2]
        x = pow(a, d, n)
        if x in (1, n - 1):
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            # Never reached n - 1: a is a witness that n is composite.
            return False

    return True

def generate_prime(max_attempts: int = 10_000) -> int:
    """Return a random probable prime built from generate_num() candidates.

    Candidates with a small prime factor (<= 2000) are discarded by cheap
    trial division. The survivors then go through 40 rounds of Miller-Rabin.

    Args:
        max_attempts: how many candidates to draw before giving up.

    Raises:
        TimeoutError: if no probable prime is found within ``max_attempts``.
    """
    small_primes = primes_up_to(2000)

    for _ in range(max_attempts):
        candidate = generate_num()
        if is_divisible_by_small_primes(candidate, small_primes):
            continue

        if likely_prime_miller_rabin(candidate, 40):
            return candidate

    # The message reports the actual limit (the old text claimed 50k while the loop ran 10k).
    raise TimeoutError(f"Failed to generate prime after {max_attempts} iterations.")

In [131]:
# Generate a probable prime with generate_prime() and display it as the cell output.
num = generate_prime()
num

135503862888757185693212828405051416767092771735314514063701781350446801135330331938711781237808372107649546810624430584087712880411052242761837650919564909490841990760316799720824759084447053510189611250005628653039961643053462677663175861274809627007051018062308029994187763690351097188013551140295845661009

# RSA

In [132]:
class PublicKeyInt:
    """Textbook (unpadded) RSA public key operating directly on integers.

    Built from the prime factors p and q. Exposes ``key = (n, e)`` with the
    fixed public exponent e = 65537.
    """

    def __init__(self, p: int, q: int):
        # NOTE(review): p and q are kept for backward compatibility, but a public
        # key should not carry the factors; anyone holding this object can rebuild
        # the private key from them.
        self.p = p
        self.q = q

        n = p * q
        e = 65537

        self.key = (n, e)

    def verify(self, message: int, signature: int) -> bool:
        """Return True if ``signature`` is a valid raw RSA signature of ``message``.

        Signatures outside [0, n) are rejected, as in RSAVP1, instead of
        being silently reduced modulo n.
        """
        n, _ = self.key
        if not 0 <= signature < n:
            return False
        return self.encrypt(signature) == message

    def encrypt(self, message: int) -> int:
        """Return ``message ** e mod n``.

        Raises:
            ValueError: if ``message`` is not in [0, n). pow() would otherwise
                reduce it modulo n, and decryption would not give the input back.
        """
        n, e = self.key
        if not 0 <= message < n:
            raise ValueError("message representative out of range")
        return pow(message, e, n)


class PrivateKeyInt:
    """Textbook (unpadded) RSA private key operating directly on integers.

    Stores ``key = (n, d)`` where d is the inverse of e = 65537 modulo
    phi(n) = (p - 1) * (q - 1).
    """

    def __init__(self, p: int, q: int):
        """Build the key from two primes.

        Raises:
            ValueError: if p == q (phi(n) would be wrong for n = p**2) or if
                e = 65537 is not invertible modulo (p - 1) * (q - 1).
        """
        if p == q:
            raise ValueError("p and q must be distinct primes")

        self.p = p
        self.q = q

        n = p * q
        phi = (p - 1) * (q - 1)
        e = 65537
        try:
            d = pow(e, -1, phi)  # e^-1 (mod phi)
        except ValueError:
            raise ValueError("e is not invertible modulo (p-1)*(q-1); choose other primes") from None

        self.key = (n, d)

    def sign(self, message: int) -> int:
        """Return the raw RSA signature ``message ** d mod n``."""
        return self.decrypt(message)

    def decrypt(self, encrypted: int) -> int:
        """Return ``encrypted ** d mod n``.

        Raises:
            ValueError: if ``encrypted`` is not in [0, n). pow() would
                otherwise silently reduce it modulo n.
        """
        n, d = self.key
        if not 0 <= encrypted < n:
            raise ValueError("ciphertext representative out of range")
        return pow(encrypted, d, n)

    def derive_public_key(self):
        """Return the matching PublicKeyInt (same p and q)."""
        return PublicKeyInt(self.p, self.q)

    @staticmethod
    def generate():
        """Generate a key from two fresh random probable primes.

        The constructor rejects pairs where p == q or where 65537 divides
        p - 1 or q - 1 (the latter happens for roughly 1 in 32k pairs).
        Those pairs are discarded and new primes are drawn instead of
        letting the error escape.
        """
        while True:
            try:
                return PrivateKeyInt(generate_prime(), generate_prime())
            except ValueError:
                continue

In [133]:
# Textbook RSA round trip on integers: encrypt 42 with the public key, then
# decrypt with the private key. The second print is expected to show 42 again.
private_key = PrivateKeyInt.generate()
public_key = private_key.derive_public_key()

encrypted = public_key.encrypt(42)
print(encrypted)
print(private_key.decrypt(encrypted))

1109320190365225833453406389234499319043057987857080411113566775728754786632775539482692733861929392901786335312951736435968889738792928693117717356007165216332731593551764084506193117441704077628144170410502675837952450468185623383772794177398753668699977864260972155896396315707843783814953935450737388980924700416744224876539413824158756280413396819086753056188732352299830634505461096594687289018211581529541455152215447652200545980125115734183492458061165808344765709003744958895239884600928346926994909262706084019067936342996575460671035001586541010478525527192697488727027044059674214285667306203984831457495
42


In [134]:
# Raw RSA signature: sign 42 with a freshly generated private key, then verify
# it with the matching public key. Expected to print True.
private_key = PrivateKeyInt.generate()
public_key = private_key.derive_public_key()

signature = private_key.sign(42)
print(signature)
print(public_key.verify(42, signature))

9366872236560745850060376424351834871765556244900581842670549270610694848226889795841396547374433219203873688317609664752272437620541548040342886314053495676556964020408538008615770730752812665666141660705800515316946372158969444592864906595660526598614217358922153963593697066648530019117909785419491086778107573706479753082640127080879514672797556709263201269297499937643720903888686096883316089334907145548348909884521163193435008553912537486632810514808563178981540194211504950935184661289555748919066333358818713800537625323257525499122439157438782335667705724532486809248867249689332764164847450791992930879823
True


# OAEP

In [135]:
import hashlib


def hash_sha3(data: bytes):
    """Return the 32-byte SHA3-256 digest of ``data``."""
    return hashlib.sha3_256(data).digest()

In [136]:
import math


def i2osp(x: int, size: int) -> bytes:
    """I2OSP: write ``x`` as exactly ``size`` big-endian bytes.

    int.to_bytes raises OverflowError when x is negative or does not fit.
    """
    encoded = x.to_bytes(length=size, byteorder="big")
    return encoded


def mgf(seed: bytes, mask_len: int, hash_bytes: int, hash_func=None) -> bytes:
    """MGF1 mask generation function (RFC 8017, appendix B.2.1).

    Concatenates ``hash_func(seed || counter)`` for counter = 0, 1, ...,
    with the counter encoded as 4 big-endian bytes. The result is truncated
    to ``mask_len`` bytes.

    Args:
        seed: seed the mask is derived from.
        mask_len: desired mask length in bytes.
        hash_bytes: output length of ``hash_func`` in bytes.
        hash_func: callable mapping bytes to a digest. Defaults to
            ``hash_sha3``, looked up at call time. The hash is chosen
            explicitly so it does not depend on the module-level name
            ``hash``, which is Python's builtin (it returns an int) unless
            shadowed.

    Raises:
        ValueError: if ``mask_len`` exceeds 2**32 * hash_bytes.
    """
    if hash_func is None:
        hash_func = hash_sha3

    # Safety limit from the spec: mask_len may not exceed 2^32 * hash_bytes.
    if mask_len > (2**32) * hash_bytes:
        raise ValueError("Mask too long")

    # Exact integer ceil(mask_len / hash_bytes); avoids float rounding for large values.
    iterations = -(-mask_len // hash_bytes)

    T = bytearray()
    for counter in range(iterations):
        C = counter.to_bytes(4, byteorder="big")  # I2OSP(counter, 4)
        T.extend(hash_func(seed + C))

    return bytes(T[:mask_len])

In [137]:
import secrets
from typing import Callable, Any


def generate_db(
    message: bytes,
    rsa_bits: int = 1024,
    hash_func: Callable[[bytes], bytes] = hash_sha3,
    label: bytes = b"",
):
    """Build the OAEP data block DB = lHash || PS || 0x01 || message.

    lHash is ``hash_func(label)``. PS is a run of zero bytes sized so that
    the whole encoded block is ``rsa_bits // 8`` bytes long.

    Raises:
        ValueError: if the message is too long to fit. Previously a negative
            PS size silently produced an empty PS and a malformed block.
    """
    # Use the supplied hash_func rather than the module-level name `hash`.
    l_hash = hash_func(label)
    hash_bytes = len(l_hash)
    k = rsa_bits // 8
    ps_size = k - len(message) - 2 * hash_bytes - 2
    if ps_size < 0:
        raise ValueError("message too long for OAEP")

    return l_hash + b"\x00" * ps_size + b"\x01" + message


def generate_seed(size: int):
    """Return ``size`` cryptographically secure random bytes for the OAEP seed."""
    seed = secrets.token_bytes(size)
    return seed


def oaep_encode(
    message: bytes,
    rsa_bits: int = 1024,
    hash_func: Callable[[bytes], bytes] = hash_sha3,
    label: bytes = b"",
):
    """OAEP-encode ``message`` into a block of ``rsa_bits // 8`` bytes.

    Layout: 0x00 || maskedSeed || maskedDB, where
    maskedDB   = DB   XOR mgf(seed, k - hLen - 1) and
    maskedSeed = seed XOR mgf(maskedDB, hLen).

    Raises:
        ValueError: if ``message`` is longer than k - 2*hLen - 2 bytes.
    """
    h_len = len(hash_func(b""))
    k = rsa_bits // 8
    max_message_len = k - 2 * h_len - 2
    if len(message) > max_message_len:
        raise ValueError("message too long for OAEP")

    db = generate_db(message, rsa_bits, hash_func, label)
    seed = generate_seed(h_len)

    # NOTE(review): hash_func is not forwarded to mgf() here, so a non-default
    # hash_func only affects lHash and the hash length. Confirm this is intended.
    masked_db = bytes(x ^ y for x, y in zip(db, mgf(seed, k - h_len - 1, h_len)))
    masked_seed = bytes(x ^ y for x, y in zip(seed, mgf(masked_db, h_len, h_len)))

    return b"".join((b"\x00", masked_seed, masked_db))


def oaep_decode(
    encoded_block: bytes, rsa_bits: int = 1024, hash_func=hash_sha3, label: bytes = b""
) -> bytes:
    """Reverse oaep_encode and return the embedded message.

    Expected layout:
        encoded_block = 0x00 || maskedSeed || maskedDB
        DB            = lHash || PS (all zero bytes) || 0x01 || message

    Raises:
        ValueError("Incorrect block size"): if len(encoded_block) != rsa_bits // 8.
        ValueError("Decoding error"): if the leading byte is not 0x00, lHash
            does not match, PS contains a non-zero byte, or the 0x01 separator
            is missing. A single message is used on purpose: RFC 8017
            section 7.1.2 warns that letting an attacker distinguish these
            cases enables chosen-ciphertext attacks.
    """
    hash_bytes = len(hash_func(b""))
    k = rsa_bits // 8

    if len(encoded_block) != k:
        raise ValueError("Incorrect block size")

    # Split off maskedSeed and maskedDB.
    masked_seed = encoded_block[1 : 1 + hash_bytes]
    masked_db = encoded_block[1 + hash_bytes :]

    # 1) Recover seed = maskedSeed XOR MGF(maskedDB, hash_bytes).
    seed_mask = mgf(masked_db, hash_bytes, hash_bytes)
    seed = bytes(ms ^ sm for ms, sm in zip(masked_seed, seed_mask))

    # 2) Recover DB = maskedDB XOR MGF(seed, k - hash_bytes - 1).
    db_mask = mgf(seed, k - hash_bytes - 1, hash_bytes)
    db = bytes(md ^ dm for md, dm in zip(masked_db, db_mask))

    # 3) After lHash, PS must be zeros only, immediately followed by 0x01.
    #    Stripping the leading zeros makes any non-zero PS byte show up as
    #    the first remaining byte, which then fails the 0x01 check.
    rest = db[hash_bytes:].lstrip(b"\x00")

    valid = (
        encoded_block[:1] == b"\x00"
        and db[:hash_bytes] == hash_func(label)
        and rest[:1] == b"\x01"
    )
    if not valid:
        raise ValueError("Decoding error")

    return rest[1:]

In [138]:
# OAEP-encode the UTF-8 bytes of "hello" with oaep_encode's defaults
# (rsa_bits=1024, so a 128-byte block starting with 0x00) and display it.
message = oaep_encode("hello".encode("utf-8"))
message

b'\x00|\xeaK\xbc\x8d\xd1^t\x95LK\xd1f\xfdK\x9c\xd0\x1b\xd5\xda\xb6\xb1$\x1f&\xd0\xd5\x02\xa3\xfb\x05\x1f0\xc3\xe3\x1f\x82\xcdC\xb3\xed8\xb3\xd5\xd9\x9ae$\xd7c\x0crL\xe7\xf9\xee\x15\xb0\xdb#K\x88x-\xbb\x03\x8a\x1a\xbf^/>\xf6w\x18\xa6\xd1\xdf|\xaa*\xfe\xc1\xfb6\x9e\x1f\xb5rh\x05d\t6nfP=~!\xd9\x13\xed\xb6\x0c]\xce\x013W\x1b\xe91\xf9\xcaI\x1e\x8e\xeaw\xe5\xb0\xc6\x8ae|\x9d'

In [139]:
# Decode the block from the previous cell and turn it back into text (expected: "hello").
oaep_decode(message).decode("utf-8")

'hello'

In [149]:
class PublicKey:
    """RSA-OAEP encryption key wrapping an integer PublicKeyInt."""

    def __init__(self, public_key: PublicKeyInt, rsa_bits: int = 2048):
        self.public_key = public_key
        self.rsa_bits = rsa_bits

    def encrypt(self, message: bytes):
        """OAEP-pad ``message``, apply raw RSA, and return ``rsa_bits // 8`` big-endian bytes."""
        block_len = self.rsa_bits // 8
        padded = oaep_encode(message, self.rsa_bits)
        ciphertext = self.public_key.encrypt(int.from_bytes(padded, byteorder="big"))
        return ciphertext.to_bytes(block_len, byteorder="big")


class PrivateKey:
    """RSA-OAEP decryption key wrapping an integer PrivateKeyInt."""

    def __init__(self, private_key: PrivateKeyInt, rsa_bits: int = 2048):
        self.private_key = private_key
        self.rsa_bits = rsa_bits

    def decrypt(self, encrypted: bytes):
        """Apply raw RSA decryption, then OAEP-decode; return the plaintext bytes.

        Errors from oaep_decode (ValueError) propagate to the caller.
        """
        decrypted = self.private_key.decrypt(int.from_bytes(encrypted, byteorder="big"))
        decrypted_bytes = decrypted.to_bytes(self.rsa_bits // 8, byteorder="big")
        return oaep_decode(decrypted_bytes, self.rsa_bits)

    def derive_public_key(self):
        """Return the matching PublicKey, using this key's rsa_bits.

        rsa_bits is forwarded explicitly. Otherwise a PrivateKey created with
        a non-default size would produce a PublicKey that pads and
        serializes for 2048 bits.
        """
        return PublicKey(self.private_key.derive_public_key(), self.rsa_bits)

    @staticmethod
    def generate():
        """Generate a fresh key pair with the default rsa_bits."""
        return PrivateKey(PrivateKeyInt.generate())

In [150]:
# Create an OAEP-wrapped key pair (PrivateKey/PublicKey default to rsa_bits=2048).
private_key = PrivateKey.generate()
public_key = private_key.derive_public_key()

In [151]:
# Encrypt a UTF-8 message with RSA-OAEP and display the ciphertext
# (rsa_bits // 8 bytes).
message = "Attack at Dawn"
encrypted = public_key.encrypt(message.encode("utf-8"))
encrypted

b'4\x80C\xe4\xe6s\x92K\x95\xb2\x9b/c\xa4\xecU\x80\xe7B\xc4m\xa3\xb0o\xa3\x06\xb7\xc6(\xd0s\xf4\x02?8}\xd3\xd3%\xc96\xc2\xc2i\xcb;z\x12\'\xb4B)\xc2\xce"\x92\xa7\x17t_l+L5\xed\xb1-\xee\x89\xd8[n\xf8\xcdhI\xee\xeb\n"\x0f\x92H\x8e\xf5\xee\xfe\xef\xbdt\x0f\x91G\xee\xc73\nX\x82\xb8>\xf1\x16{\xb9\xbc\xa2\xb9k\x0e\x83\xd3\x83 Y\xb0\x15D\xfb\xb3\xed\xd6\xed\xcb\x971\xd9\xd17\xef\xa1\x8cC\x85>\xbe\x0b\xfe\x12\xc7\xca\xdb6\x1a\x8c\x0f\xc0\xd8@\xc2\x05M\xb4\xad><\x12\xdf\x0cR\x83\x12\x8f%-\x04-\x7f\x14K/\xe2\xa3\x05\x7f\xdc\x0f\xe0~^\xc5v\xc4\xcf@\x84:Z\xe300\xbe\xf5\x04F\x0fX6$Y\xdf\x0bC\xfd\xa7}_\x0c_\xe1V/\x86\xd8\xbc\xee\xe9\x016(\xcdz[YZ/\xc2a\xa7G1fc\x97b\xe3\x8d\xa3\xe4\xb4~\xe1\x9a3/-x\xe8\x11\xeb\xb9T7.\x0f\x8e'

In [152]:
# Decrypt the ciphertext and decode it back to text (expected: "Attack at Dawn").
decrypted = private_key.decrypt(encrypted)
decrypted.decode("utf-8")

'Attack at Dawn'