Inteiro aleatório

In [1]:
import secrets

def generate_num():
    # Gera um inteiro aleatório de 1024 bits
    n = secrets.randbits(1024)

    # Garante que ele tenha exatamente 1024 bits e seja ímpar:
    n |= (1 << 1023)    # define o bit mais significativo
    n |= 1              # define o bit 0 para torná-lo ímpar

    return n


def primes_up_to(n: int) -> list[int]:
    if n < 2: 
        return []
    
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False

    # Marca múltiplos de cada primo como compostos
    lim = int(n**0.5)
    for p in range(2, lim + 1):
        if is_prime[p]:
            for multiple in range(p * p, n + 1, p):
                is_prime[multiple] = False

    # Coleta aqueles que permaneceram True
    return [i for i, prime in enumerate(is_prime) if prime]


def is_divisible_by_small_primes(n: int, primes: list[int]) -> bool:
    for p in primes:
        if p * p > n:
            break

        if n % p == 0:
            return n != p
        
    return False
    

def likely_prime_miller_rabin(n: int, k: int = 40):
    s, d = 0, n-1
    while d % 2 == 0:
        s += 1
        d //= 2

    for _ in range(k):
        a = secrets.randbelow(n-3) + 2  # base aleatória em [2, n-2]
        x = pow(a, d, n)
        if x in (1, n-1):
            continue
        for _ in range(s-1):
            x = pow(x, 2, n)
            if x == n-1:
                break
        else:
            return False
        
    return True

def generate_prime():
    primes_up_to_2000 = primes_up_to(2000)

    for _ in range(10_000):
        num = generate_num()
        if is_divisible_by_small_primes(num, primes_up_to_2000):
            continue

        if likely_prime_miller_rabin(num, 40):
            return num
        
    raise TimeoutError("Failed to generate prime after 50k iterations.")

In [2]:
num = generate_prime()
num

99039502339928311167911017692805904869657351799754507225836299038483488945891077793379396270992078273771847975267831889726789026400191041905003266300616353557397696354763691625919309798502816656382834356973953067464397789240706231468298086177502287658988735284748592871649583414565453924263591908381205734279

# RSA

In [3]:
class PublicKeyInt:
    def __init__(self, p: int, q: int):
        self.p = p
        self.q = q

        n = p * q
        e = 65537

        self.key = (n, e)

    def verify(self, message: int, signature: int):
        return self.encrypt(signature) == message

    def encrypt(self, message: int):
        n, e = self.key
        return pow(message, e, n)


class PrivateKeyInt:
    def __init__(self, p: int, q: int):
        self.p = p
        self.q = q

        n = p * q
        phi = (p - 1) * (q - 1)
        e = 65537
        d = pow(e, -1, phi)  # e^-1 (mod phi)

        self.key = (n, d)

    def sign(self, message: int):
        return self.decrypt(message)

    def decrypt(self, encrypted: int):
        n, d = self.key
        return pow(encrypted, d, n)

    def derive_public_key(self):
        return PublicKeyInt(self.p, self.q)

    @staticmethod
    def generate():
        return PrivateKeyInt(generate_prime(), generate_prime())

In [4]:
private_key = PrivateKeyInt.generate()
public_key = private_key.derive_public_key()

print("Chave privada:", private_key.key)
print("Chave pública:", public_key.key)

encrypted = public_key.encrypt(42)
print("Valor criptografado:", encrypted)
print("Valor descriptografado:", private_key.decrypt(encrypted))

Chave privada: (17631324864090333423883312230479075892941306243114681948164588717202910649125611972634851398058974549680616159285815331048222023357237688100578883790172618752576490884210222982994560515642287499314228197274337953023526563772951205936142296973936144522510888730897139855194432554584657971516208092896661972714309836523375461402100050868090676874451440127356723169068266052466746929762672821710097782641206169089268510298049584135652971282059267211747412825282183963936791120878195965083926774881763786622079723091709580831408978532622162508570985879120923961081870023950433015675792047922240117495569280413016169378077, 10580355284113472281222559827277432545243079358670891430145062255979920808227597657197971158015523308815467175982918738109228004556104853710416502764848997517453167424420093221771061781271341116819963784463816821284301275616256847689955357361352528980290648519290213409602005638138844144586865921196269619648411275323411603685002905860100494422350560199226261397704

In [5]:
private_key = PrivateKeyInt.generate()
public_key = private_key.derive_public_key()

signature = private_key.sign(42)
print(signature)
print(public_key.verify(42, signature))

1253795199040440088767614308298817598717295358778549032857305244912163498291340505627525733019343614177894747272518776771915202064041437448219674355416141987094037122856482425730865887056624416097464408284535945648361803204979237015356904014038823145036559365274229197505438264562712955994231611172974391144332681967838744289766034528564931949574596289067926551131532642368979492079250833187084763252385453370059955328860673165643539442393902846194129127622486855983814679174082873436265485366161350877549752291579138229527165673690534327172762051601160387414269832185104589343215887362979914348828958221668254227182
True


# OAEP

In [6]:
import hashlib


def hash_sha3(data: bytes):
    sha3_256_hasher = hashlib.sha3_256()
    sha3_256_hasher.update(data)

    return sha3_256_hasher.digest()

In [21]:
import math
from typing import Callable


def i2osp(x: int, size: int) -> bytes:
    # Converte o inteiro x em um array de bytes big-endian de comprimento size.
    return x.to_bytes(size, byteorder="big")


def mgf(seed: bytes, mask_len: int, hash_func: Callable[[bytes], bytes] = hash_sha3) -> bytes:
    hash_bytes = len(hash_func(b""))
    
    # Limite de segurança: mask_len não pode exceder 2^32 * hash_bytes
    if mask_len > (2**32) * hash_bytes:
        raise ValueError("Mask too long")

    T = bytearray()
    # número de iterações necessárias
    for counter in range(math.ceil(mask_len / hash_bytes)):
        C = i2osp(counter, 4)
        T.extend(hash_func(seed + C))

    return bytes(T[:mask_len])

In [20]:
import secrets
from typing import Callable, Any


def generate_db(
    message: bytes,
    rsa_bits: int = 1024,
    hash_func: Callable[[bytes], bytes] = hash_sha3,
    label: bytes = b"",
):
    hash_bytes = len(hash_func(b""))
    k = rsa_bits // 8
    ps_size = k - len(message) - 2 * hash_bytes - 2

    return hash_func(label) + (b"\x00" * ps_size) + b"\x01" + message


def generate_seed(size: int):
    return secrets.token_bytes(size)


def oaep_encode(
    message: bytes,
    rsa_bits: int = 1024,
    hash_func: Callable[[bytes], bytes] = hash_sha3,
    label: bytes = b"",
):
    hash_bytes = len(hash_func(b""))
    k = rsa_bits // 8
    if len(message) > k - 2 * hash_bytes - 2:
        raise ValueError("message too long for OAEP")

    datablock = generate_db(message, rsa_bits, hash_func, label)
    seed = generate_seed(hash_bytes)

    db_mask = mgf(seed, k - hash_bytes - 1, hash_func)
    datablock_xor = bytes(a ^ b for a, b in zip(db_mask, datablock))

    seed_mask = mgf(datablock_xor, hash_bytes, hash_func)
    seed_xor = bytes(a ^ b for a, b in zip(seed_mask, seed))

    return b"\x00" + seed_xor + datablock_xor


def oaep_decode(
    encoded_block: bytes, rsa_bits: int = 1024, hash_func=hash_sha3, label: bytes = b""
) -> bytes:
    hash_bytes = len(hash_func(b""))
    k = rsa_bits // 8

    if len(encoded_block) != k:
        raise ValueError("Incorrect block size")
    
    if encoded_block[0] != 0:
        raise ValueError("Decoding error: first byte must be 0x00")

    # split off maskedSeed and maskedDB
    maskedSeed = encoded_block[1 : 1 + hash_bytes]
    maskedDB = encoded_block[1 + hash_bytes :]

    # 1) recover seed = maskedSeed ⊕ MGF1(maskedDB, hash_bytes)
    seedMask = mgf(maskedDB, hash_bytes, hash_func)
    seed = bytes(ms ^ sm for ms, sm in zip(maskedSeed, seedMask))

    # 2) recover DB   = maskedDB   ⊕ MGF1(seed, k-hash_bytes-1)
    dbMask = mgf(seed, k - hash_bytes - 1, hash_func)
    DB = bytes(md ^ dm for md, dm in zip(maskedDB, dbMask))

    # 3) verify lHash
    lHash = hash_func(label)
    if DB[:hash_bytes] != lHash:
        raise ValueError("Decoding error: label hash mismatch")

    # 4) find the 0x01 delimiter and extract M
    #    DB = lHash || PS (zeros) || 0x01 || message
    try:
        idx = DB.index(b"\x01", hash_bytes)
    except ValueError:
        raise ValueError("Decoding error: 0x01 delimiter not found")

    return DB[idx + 1 :]

In [22]:
message = oaep_encode("hello".encode("utf-8"))
message

b'\x00\xf1\x06\xdc\x04\xeb\xd9=\xa0\xdd\xfb\xd5\x87)\x15dwm\x9cgA\xca\xb3\xb2\xe7\xa5"|m\x8e\xeb\xe2\x18\n/F\xac\xf2\xbd\n\xc1(\xe1\xc7p\x19\xe2x\xc5\x90\xe6+}&\xff\xf9X\x06\x89\xd1\xbb\xe6,5\x8f\x83\x0f\x1f\xae\xe8&\xa5:s\xd7\xa2\x95\x0f\xbfr\xe8\xd1]\x1c\xcd\xe5\x9dg\xc6%:\xb8\x1f$\xea\xa1\x83s\xac\x7fB\xd6\x83\xec\xe2\xe4\x9f$\t\x1c\xc5\x84 \xce*\xbb\xbd%D\xfa!g}^M\x1d\xfbE'

In [23]:
oaep_decode(message).decode("utf-8")

'hello'

In [24]:
class PublicKey:
    def __init__(self, public_key: PublicKeyInt, rsa_bits: int = 2048):
        self.public_key = public_key
        self.rsa_bits = rsa_bits

    def encrypt(self, message: bytes):
        encoded = oaep_encode(message, self.rsa_bits)
        encrypted = self.public_key.encrypt(int.from_bytes(encoded, byteorder="big"))
        return encrypted.to_bytes(self.rsa_bits // 8, byteorder="big")


class PrivateKey:
    def __init__(self, private_key: PrivateKeyInt, rsa_bits: int = 2048):
        self.private_key = private_key
        self.rsa_bits = rsa_bits

    def decrypt(self, encrypted: bytes):
        decrypted = self.private_key.decrypt(int.from_bytes(encrypted, byteorder="big"))
        decrypted_bytes = decrypted.to_bytes(self.rsa_bits // 8, byteorder="big")
        return oaep_decode(decrypted_bytes, self.rsa_bits)

    def derive_public_key(self):
        return PublicKey(self.private_key.derive_public_key())
    
    @staticmethod
    def generate():
        return PrivateKey(PrivateKeyInt.generate())

In [25]:
private_key = PrivateKey.generate()
public_key = private_key.derive_public_key()

In [26]:
message = "Attack at Dawn"
encrypted = public_key.encrypt(message.encode("utf-8"))
encrypted

b's-\xd8\xca\xb0\xea\x18\x17p\x11\xc2"\xa8n\xf8\tf\xda\xf4\xcf\xe4y\t5\xa9\xac5\x1c\xd2\x88w>\xa7F\xf1\xdc\x02\x14l\xe5\xea\x0f\xd6\x19\xf0D\x15\x10\xd5\xc9-[\xf4\x10\x1f?\xce\xae\xd9\xbck\xa5\x16\xc1e\x8a\xf6\x83\x05\xfd\xe5%\xb6\xa6\xd8^\xc5\xfc\xb0\x94\x0b=s\x9a\x91\xa2\x98\xc0\x81\xb9\xb4`\xc5\xb1\xf1\xc7i\x0c\x89\xadNR\xe0+Y\x05KL\xbet\x8c\xab\xb2\xbe\x9d\x88\x88\xf0\x8bP\xcfp\xb56\xed\x15\xf2\x866\x8d\x0eC\xf3dT\xe7\xcc\xd5\xe0i\xde\x04(II\x08\x0for\xcf\xfb\xae\t\xc1d\xfdT\xa8\r\xdce\xdb1K\xdd\x8a)\xed`w5\x88\xd0"\xb5\x0b_\x89\xc8;\xdd\x1f\x14\x82\xd2\x04\xf7{\x0c\xf8(\xe1\x92\xc4u\xe1\x97\xa6\xc8\xe8\xea\xd4\xfc\xaa\xda\xec2\xd8\xe8\xd97~`\x82\x85\xee]\xb5\xfa\x95\xcfr\r\xb9\x92\xa8\x00\xebI\xb50"\xfe\r\xef\xf6\xea\xa9\x01\xc1\xe3\x0c\xc6\x17t\x1f\xd6\x7f+LF*\xdbcx\xb5'

In [27]:
decrypted = private_key.decrypt(encrypted)
decrypted.decode("utf-8")

'Attack at Dawn'