# Inteiro aleatório

In [None]:
import secrets

def generate_num(bits: int = 1024):
    """Return a random odd integer with exactly `bits` bits.

    Draws `bits` random bits from the `secrets` CSPRNG, then forces the most
    significant bit (so the bit length is exactly `bits`) and bit 0 (so the
    number is odd), which is the shape wanted for prime candidates.

    Args:
        bits: bit length of the result; defaults to 1024.

    Raises:
        ValueError: if bits < 2 (the top bit and bit 0 would coincide).
    """
    if bits < 2:
        raise ValueError("bits must be at least 2")

    n = secrets.randbits(bits)

    # Guarantee exactly `bits` bits and an odd value:
    n |= 1 << (bits - 1)  # set the most significant bit
    n |= 1                # set bit 0 to make it odd

    return n


def primes_up_to(n: int) -> list[int]:
    """Return every prime p with 2 <= p <= n (sieve of Eratosthenes).

    An empty list is returned when n < 2.
    """
    if n < 2:
        return []

    # sieve[i] stays True while i is still considered prime; 0 and 1 are not.
    sieve = [False, False] + [True] * (n - 1)

    limit = int(n**0.5)
    candidate = 2
    while candidate <= limit:
        if sieve[candidate]:
            # Cross out every multiple from candidate**2 upward in one slice.
            start = candidate * candidate
            sieve[start::candidate] = [False] * len(range(start, n + 1, candidate))
        candidate += 1

    return [value for value, still_prime in enumerate(sieve) if still_prime]


def is_divisible_by_small_primes(n: int, primes: list[int]) -> bool:
    """Trial-divide n by the given primes, in order.

    Scanning stops at the first divisor whose square exceeds n (so `primes`
    is expected to be sorted ascending). Returns True when some scanned
    prime divides n and is not n itself; otherwise False.
    """
    for divisor in primes:
        square_exceeds_n = divisor * divisor > n
        if square_exceeds_n:
            break

        if not n % divisor:
            # A listed prime equal to n does not count as a proper divisor.
            return divisor != n

    return False
    

def likely_prime_miller_rabin(n: int, k: int = 40) -> bool:
    """Probabilistic Miller-Rabin primality test.

    Args:
        n: integer to test.
        k: number of rounds with independent random bases.

    Returns:
        False if n is certainly composite (or < 2); True if n is 2, 3, or
        an odd number that passed all k rounds (for a composite n this
        happens with probability at most 4**-k, the standard bound).

    Small and even inputs are decided up front: the random-base loop needs
    n >= 5 to draw a base in [2, n-2], and n == 1 would make the
    "factor out powers of two" loop spin forever (d == 0).
    """
    if n < 2:
        return False
    if n < 4:
        return True  # 2 and 3 are prime
    if n % 2 == 0:
        return False

    # Write n - 1 = 2**s * d with d odd.
    s, d = 0, n - 1
    while d % 2 == 0:
        s += 1
        d //= 2

    for _ in range(k):
        a = secrets.randbelow(n - 3) + 2  # random base in [2, n-2]
        x = pow(a, d, n)
        if x in (1, n - 1):
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            # No repeated squaring reached n-1: a witnesses that n is composite.
            return False

    return True

def generate_prime(max_attempts: int = 10_000):
    """Return a random probable prime built from generate_num() candidates.

    Each candidate is first trial-divided by the primes up to 2000 (cheap),
    and only survivors go through 40 rounds of Miller-Rabin.

    Args:
        max_attempts: how many candidates to try before giving up.

    Raises:
        TimeoutError: if no candidate passes within max_attempts tries.
    """
    small_primes = primes_up_to(2000)

    for _ in range(max_attempts):
        candidate = generate_num()
        # Skip candidates with a small factor before the costlier test.
        if is_divisible_by_small_primes(candidate, small_primes):
            continue

        if likely_prime_miller_rabin(candidate, 40):
            return candidate

    # The message reports the real limit (it previously claimed 50k).
    raise TimeoutError(f"Failed to generate prime after {max_attempts} iterations.")

In [33]:
# Generate a random probable prime (1024-bit odd candidates, trial division
# plus Miller-Rabin) and display it as the cell output.
num = generate_prime()
num

132969991645413115086731765141296196556073117841327563220755524396879773796629336394602678758629813152277058823390705691896160054431466674540597801714896636350325448064334899567341498138503752296158413913701894368368686943963066281660834317286923214700737732647223160305236049350625574239539860494953737927343

# RSA

In [52]:
class PublicKey:
    """Textbook RSA public key (n, e) with n = p * q and e = 65537."""

    def __init__(self, p: int, q: int):
        # NOTE(review): p and q are kept only for backward compatibility.
        # A public-key object should not carry the factorization of n:
        # whoever holds it can rebuild the private key.
        self.p = p
        self.q = q

        n = p * q
        e = 65537

        self.key = (n, e)

    def verify(self, message: int, signature: int):
        """Return True if `signature` is a valid raw-RSA signature of `message`.

        Signatures outside [0, n) are rejected instead of being silently
        reduced modulo n.
        """
        n, _ = self.key
        if not 0 <= signature < n:
            return False
        return self.encrypt(signature) == message

    def encrypt(self, message: int):
        """Raw (unpadded) RSA encryption: return message**e mod n.

        Raises:
            ValueError: if message is not in [0, n); such a value would
                otherwise decrypt to message mod n, not to message.
        """
        n, e = self.key
        if not 0 <= message < n:
            raise ValueError("message representative out of range")
        return pow(message, e, n)


class PrivateKey:
    """Textbook RSA private key (n, d) for public exponent e = 65537."""

    def __init__(self, p: int, q: int):
        self.p = p
        self.q = q

        n = p * q
        phi = (p - 1) * (q - 1)
        e = 65537
        # d = e^-1 (mod phi); pow raises ValueError when gcd(e, phi) != 1.
        d = pow(e, -1, phi)

        self.key = (n, d)

    def sign(self, message: int):
        """Raw RSA signature: return message**d mod n (see decrypt)."""
        return self.decrypt(message)

    def decrypt(self, encrypted: int):
        """Raw (unpadded) RSA decryption: return encrypted**d mod n.

        Raises:
            ValueError: if encrypted is not in [0, n).
        """
        n, d = self.key
        if not 0 <= encrypted < n:
            raise ValueError("ciphertext representative out of range")
        return pow(encrypted, d, n)

    def derive_public_key(self):
        """Return the PublicKey built from the same primes."""
        return PublicKey(self.p, self.q)

    @staticmethod
    def generate():
        """Create a key from two freshly generated probable primes.

        Retries until p != q and e = 65537 is invertible modulo
        (p - 1) * (q - 1). Because 65537 is prime, invertibility means it
        divides neither p - 1 nor q - 1. Without the retry, pow(e, -1, phi)
        in __init__ would raise for an unlucky draw.
        """
        e = 65537
        while True:
            p, q = generate_prime(), generate_prime()
            if p != q and (p - 1) % e != 0 and (q - 1) % e != 0:
                return PrivateKey(p, q)

In [53]:
# Round-trip demo: create a fresh key pair, raw-RSA-encrypt the integer 42
# with the public key, then decrypt it with the private key (expected: 42).
private_key = PrivateKey.generate()
public_key = private_key.derive_public_key()

encrypted = public_key.encrypt(42)
print(encrypted)
print(private_key.decrypt(encrypted))

9802800770874020263726328074971390033149990873237056087711617728222194915529507993867258910533701886328964328626983837967592099330677210671439232867589501727063617132224544587097525582064592395048314834251028834582194346275539865243522500937865312843989872827663825136844465443254576784032171925052393445719723676076086770376908106410776378517346477196044721130385370626103157834925044942115706598483090390075293371833754461787480545634412055083140544821534239414562582044549363675436645539107093717045174526681071335075256219423680379200678772303848447378619313137196513863548725756184261772836271167143908927280290
42


In [58]:
# Signature demo: sign the integer 42 with the private key (raw RSA, no
# padding or hashing), then check it with the public key (expected: True).
private_key = PrivateKey.generate()
public_key = private_key.derive_public_key()

signature = private_key.sign(42)
print(signature)
print(public_key.verify(42, signature))

8851725765535431813821541256208142707732301542596334628163123986410958208183189863455739180316042691342676365241388379514685595391265252859647639175620594984199699946133720975540183889559669966026918193785598684146818402345942346343386068703781939875619071527405757843442729918138632330795000171521318003675132581506788596840634257298113586531730525266053454320412931226657509802558827418170615178079895329209096695642051583562888587121729520573152611125470648081810369323537592454392928582477123004066084135261440660353332043724626182026186633297145561617806876684141324455529471892836766469045155744349746184334085
True


# OAEP

In [116]:
import hashlib


def hash_sha3(data: bytes):
    """Return the SHA3-256 digest (32 raw bytes) of `data`."""
    return hashlib.sha3_256(data).digest()

In [None]:
import math


def i2osp(x: int, size: int) -> bytes:
    """Integer-to-Octet-String primitive: encode x as `size` big-endian bytes.

    Like int.to_bytes, raises OverflowError when x does not fit in `size`
    bytes or is negative.
    """
    encoded = x.to_bytes(length=size, byteorder="big")
    return encoded


def mgf(seed: bytes, mask_len: int, hash_bytes: int, hash_func=None) -> bytes:
    """MGF1 mask generation function (RFC 8017, appendix B.2.1).

    Concatenates hash_func(seed || I2OSP(counter, 4)) for counter = 0, 1, ...
    and returns the first `mask_len` bytes.

    Args:
        seed: bytes the mask is derived from.
        mask_len: desired mask length in bytes.
        hash_bytes: output length of the hash in bytes.
        hash_func: bytes -> bytes digest function; defaults to SHA3-256
            (the same hash as hash_sha3). The Python builtin `hash` must
            not be used here: it returns an int, not a digest.

    Raises:
        ValueError: if mask_len exceeds 2**32 * hash_bytes.
    """
    # Safety limit from the spec: mask_len must not exceed 2^32 * hLen.
    if mask_len > (2**32) * hash_bytes:
        raise ValueError("Mask too long")

    if hash_func is None:
        def digest(data: bytes) -> bytes:
            return hashlib.sha3_256(data).digest()
    else:
        digest = hash_func

    # Number of hash blocks needed: ceil(mask_len / hash_bytes), in integers.
    blocks = (mask_len + hash_bytes - 1) // hash_bytes
    output = b"".join(
        # counter.to_bytes(4, "big") is I2OSP(counter, 4).
        digest(seed + counter.to_bytes(4, "big"))
        for counter in range(blocks)
    )
    return output[:mask_len]

In [126]:
import secrets
from typing import Callable, Any


def generate_db(
    message: bytes,
    rsa_bits: int = 1024,
    hash_func: Callable[[bytes], bytes] = hash_sha3,
    label: bytes = b"",
):
    """Build the OAEP data block DB = lHash || PS || 0x01 || message.

    lHash is hash_func(label), and PS is the run of zero bytes that pads DB
    to k - hLen - 1 bytes, where k = rsa_bits // 8.

    Raises:
        ValueError: if message is longer than k - 2*hLen - 2 bytes (PS would
            have negative length and DB would be malformed).
    """
    # The label hash must come from hash_func. The builtin `hash` returns an
    # int and cannot be concatenated with bytes.
    l_hash = hash_func(label)
    hash_bytes = len(l_hash)
    k = rsa_bits // 8
    ps_size = k - len(message) - 2 * hash_bytes - 2
    if ps_size < 0:
        raise ValueError("message too long for OAEP")

    return l_hash + b"\x00" * ps_size + b"\x01" + message


def generate_seed(size: int):
    """Return `size` bytes from the `secrets` CSPRNG, used as the OAEP seed."""
    seed = secrets.token_bytes(nbytes=size)
    return seed


def oaep_encode(
    message: bytes,
    rsa_bits: int = 1024,
    hash_func: Callable[[bytes], bytes] = hash_sha3,
    label: bytes = b"",
):
    """Apply OAEP padding to `message`.

    Returns a block of rsa_bits // 8 bytes laid out as
    0x00 || maskedSeed || maskedDB, where maskedDB = DB XOR mgf(seed) and
    maskedSeed = seed XOR mgf(maskedDB).

    Raises:
        ValueError: if message is longer than k - 2*hLen - 2 bytes.

    NOTE(review): mgf is called without a hash_func argument, so the masks
    use whatever hash mgf applies internally, independent of `hash_func`.
    """
    h_len = len(hash_func(b""))
    block_len = rsa_bits // 8
    max_message_len = block_len - 2 * h_len - 2
    if len(message) > max_message_len:
        raise ValueError("message too long for OAEP")

    db = generate_db(message, rsa_bits, hash_func, label)
    seed = generate_seed(h_len)

    # Mask the data block with a mask derived from the seed...
    db_mask = mgf(seed, block_len - h_len - 1, h_len)
    masked_db = bytes(m ^ b for m, b in zip(db_mask, db))

    # ...then mask the seed with a mask derived from the masked data block.
    seed_mask = mgf(masked_db, h_len, h_len)
    masked_seed = bytes(m ^ b for m, b in zip(seed_mask, seed))

    return b"".join((b"\x00", masked_seed, masked_db))


def oaep_decode(
    encoded_block: bytes, rsa_bits: int = 1024, hash_func=hash_sha3, label: bytes = b""
) -> bytes:
    """Remove OAEP padding from `encoded_block` and return the message.

    Expects 0x00 || maskedSeed || maskedDB, with
    DB = lHash || PS (zero bytes) || 0x01 || message.

    Raises:
        ValueError("Incorrect block size"): if len(encoded_block) != rsa_bits // 8.
        ValueError("Decoding error"): for any padding failure, namely a
            nonzero first byte, a label-hash mismatch, a nonzero byte inside
            PS, or a missing 0x01 delimiter. These deliberately share one
            message: RFC 8017 sec. 7.1.2 warns that distinguishable errors
            give an attacker a decryption oracle.
    """
    hash_bytes = len(hash_func(b""))
    k = rsa_bits // 8

    if len(encoded_block) != k:
        raise ValueError("Incorrect block size")

    # Split off maskedSeed and maskedDB.
    masked_seed = encoded_block[1 : 1 + hash_bytes]
    masked_db = encoded_block[1 + hash_bytes :]

    # 1) seed = maskedSeed XOR MGF1(maskedDB, hash_bytes)
    seed_mask = mgf(masked_db, hash_bytes, hash_bytes)
    seed = bytes(ms ^ sm for ms, sm in zip(masked_seed, seed_mask))

    # 2) DB = maskedDB XOR MGF1(seed, k - hash_bytes - 1)
    db_mask = mgf(seed, k - hash_bytes - 1, hash_bytes)
    db = bytes(md ^ dm for md, dm in zip(masked_db, db_mask))

    # 3) Evaluate every check before deciding, and fail with one message.
    first_byte_ok = encoded_block[0] == 0
    label_ok = secrets.compare_digest(db[:hash_bytes], hash_func(label))
    # After lHash, PS must contain only 0x00 bytes, immediately followed by
    # the 0x01 delimiter. Stripping the zeros leaves 0x01 || message.
    rest = db[hash_bytes:].lstrip(b"\x00")
    delimiter_ok = rest[:1] == b"\x01"

    if not (first_byte_ok & label_ok & delimiter_ok):
        raise ValueError("Decoding error")

    return rest[1:]

In [127]:
# OAEP-pad the UTF-8 bytes of "hello" with the default parameters
# (rsa_bits=1024, hash_sha3) and display the padded block.
message = oaep_encode("hello".encode("utf-8"))
message

b'\x00\xa4\xael\xf4\x1cn\xf2\xd6\xd0\xb7B~]\x99\x03\xdd\x8a7\xa2N\xd6aU\x07.\xd1\x1f\x98\xab:\x89b-\x9b\xdb\x8b\xea=\x06\x12\x99`\xcak\xc8\xd6c_i\x8d{\x8b\xe8\x9c\x99\x06?\xfd\xbd\xd3%\xc5=\xa4\x00>@\xdc\x0bG\x98\x9a\x8d\xb6\x93TH\x90\xed\xdf\xa7\x1e\x1fYh\x90\xd3%oQ!\x0b\\\xf6\x168&&\xa2\xbc7\x89(eO}\xd9\xf2\xe2\x18\xe0\x14\x0f\x0f\xf7\xfa\xaaN\x8c\xba[AD\xb4\xf0\xa1\x1e'

In [128]:
# Strip the OAEP padding and decode the recovered bytes back to text.
oaep_decode(message).decode("utf-8")

'hello'

In [108]:
# Generate the OAEP-padded block (oaep_encode expects bytes, not str).
message = oaep_encode("hello".encode("utf-8"))

# Turn it into an integer (big-endian!).
ct_int = int.from_bytes(message, byteorder="big")

# Back to bytes: use the same length and byte order. The block length is
# rsa_bits // 8, with oaep_encode's default rsa_bits of 1024.
rsa_size = 1024
k = rsa_size // 8
recovered = ct_int.to_bytes(k, byteorder="big")

assert recovered == message
message

b'\x00\xcc\xc1\xe1\xbe\x07\xea\x88\x14\xdc47]NRp&\xfbD|R\xf0\xcf\x92k\xc6\x07\xbf\x02J\x80\xf5G\x08\x0c\xb3r\xfa]\xf0%\xf5\x17\x97\xb5\xed\xfa\x10\xcc\r\xa3\x9d|\xaa\x80\x02\x14\x93\x9f\xf4Y\xa8\xe5\xb6\xe4+\xe0:\xd7U\xf1\xf8{-\xbd\xef\xbd\xf6\x17S-[\xdf\xd6\xe8elS\x1d\xe1\xeb\xa6\xbf{\xdc\xe7\x89\xf81\xd3\xe3\x15\x12o\xa9@=\xed\xfb\xe3)\x8agI\x13\xe7\x92\xe4;2\x97P\xc13*\xb1\xe1\x9c'

In [113]:
private_key = PrivateKey.generate()
public_key = private_key.derive_public_key()
# OAEP-pad "hello" (oaep_encode expects bytes, not str).
message = oaep_encode("hello".encode("utf-8"))

# RSA-encrypt the padded block, read as a big-endian integer.
encrypted = public_key.encrypt(int.from_bytes(message, byteorder="big"))
message

b'\x00\xef\xc7\xb1\xd9\x8d\xc6l\x11i\xbaX(\xfbAi\\9\xcd\x81\x8b%\x92q(\x1f\xa1\xcc\xb6\x15$e\xac\xfa\x13\x8b\x1e^\xc4\xa3\xa2\xce\x9d\x18"I>\xc4\xe4h\xf5\x03p\xfb\xe2~E}\xbbH\x98\x8d\xdb\xe5`\x08qj\xb5\xe2 \xda e(\xb2w[\x1c^<%V\xe9\xf0\x07h\xea-\xe8\x07[\x15\x16\x173\xf2\xf6\xb4wT\x96\xa8J\xd9\xb1\xc8?\x1c\xcf\xcbM\x10O\x8c\xe4\xc8\x0e\x08\x8f\x96_m\xab\'a\xec6'

In [114]:
decrypted = private_key.decrypt(encrypted)

# Convert back to the OAEP block length: rsa_bits // 8, with oaep_encode's
# default rsa_bits of 1024.
rsa_size = 1024
k = rsa_size // 8
message = decrypted.to_bytes(k, byteorder="big")
message

b'\x00\xef\xc7\xb1\xd9\x8d\xc6l\x11i\xbaX(\xfbAi\\9\xcd\x81\x8b%\x92q(\x1f\xa1\xcc\xb6\x15$e\xac\xfa\x13\x8b\x1e^\xc4\xa3\xa2\xce\x9d\x18"I>\xc4\xe4h\xf5\x03p\xfb\xe2~E}\xbbH\x98\x8d\xdb\xe5`\x08qj\xb5\xe2 \xda e(\xb2w[\x1c^<%V\xe9\xf0\x07h\xea-\xe8\x07[\x15\x16\x173\xf2\xf6\xb4wT\x96\xa8J\xd9\xb1\xc8?\x1c\xcf\xcbM\x10O\x8c\xe4\xc8\x0e\x08\x8f\x96_m\xab\'a\xec6'