In [31]:
from math import sqrt, ceil
import os
import copy
import hashlib
import random
from typing import Tuple, Callable
import base64
import warnings
from sage.crypto.util import ascii_integer

bin = BinaryStrings()

#gerador de números primos : int(self.secure_param/2) +1
def genPrime(param):
    return random_prime(2**param-1,True,2**(param-1))


# Successive squaring algorithm
# effectively performs (a ** b) % m
def power(a, b, m):
    d = 1
    k = len(b.bits()) - 1
    for i in range(k, -1, -1):
        d = (d * d) % m
        if (b >> i) & 1:
            d = (d * a) % m
    return d

def os2ip(x):
    '''Converts an octet string to a nonnegative integer'''
    return int.from_bytes(x, byteorder='big')


def i2osp(x, xlen):
    '''Converts a nonnegative integer to an octet string of a specified length'''
    return int(x).to_bytes(xlen, byteorder='big')


def sha1(msg):
    '''SHA-1 hash function'''
    hasher = hashlib.sha1()
    hasher.update(msg)
    return hasher.digest()


def mgf(seed, mlen, f_hash = sha1):
    '''MGF1 mask generation function with SHA-1'''
    t = b''
    hlen = len(f_hash(b''))
    for c in range(0, ceil(mlen / hlen)):
        _c = i2osp(c, 4)
        t += f_hash(seed + _c)
    return t[:mlen]


def xor(data, mask):
    '''Byte-by-byte XOR of two byte arrays'''
    masked = b''
    ldata = len(data)
    lmask = len(mask)
    for i in range(max(ldata, lmask)):
        if i < ldata and i < lmask:
            masked += (data[i] ^^ mask[i]).to_bytes(1, byteorder='big')
        elif i < ldata:
            masked += data[i].to_bytes(1, byteorder='big')
        else:
            break
    return masked

class RSA:

    def __init__(self,secure_param):
        self.secure_param = int(secure_param)
        self.n = 0
        self.e = 0
        self.d = 0
        self.k = random.randint(0,10000)

    def keygen(self):
        # geração de números aletórios p e q
        p = genPrime(int(self.secure_param/2) +1)
        q = genPrime(int(self.secure_param/2))
        while p <= 2*q:
            p = genPrime(int(self.secure_param/2) +1)
            q = genPrime(int(self.secure_param/2))
        #cálculo do parâmetro n
        n = p*q
        self.n = n 
        #Cálculo de phi de n para primos
        phin = (p-1)*(q-1)
        #geração de número com inversa multiplicativamodulo phi de n
        #e tem de satisfazer a igualdade  1 < e < phi(N)
        e = randint(2,phin)
        #este ciclo assegura que depois de gerado E,
        #o mdc entre E e phi(N) tem que ser igual a 1
        while gcd(phin,e) != 1:
            e = randint(2,phin)
        self.e = e
        #cálculo da inversa de e
        d = power_mod(e,-1,phin)
        self.d = d
        #Cálculo das chaves pública e privada
        PubKey = (e,n)
        PrivKey = (d,n) # esta modificação na forma da chave é para reutilizar código (f: get_key_len())
        return PubKey,PrivKey


    def get_pub_key():
        return (self.e,self.n)
    
    def get_priv_key():
        return (self.d,self.n)
    
    #Retornatamanho da chave em bits
    def get_key_len(self, key):
        '''Get the number of octets of the public/private key modulus'''
        _, n = key
        return int(n).bit_length() // 8

    def oaep_encode(self, msg, k, label = b'', f_hash = sha1, f_mgf = mgf):
        '''EME-OAEP encoding'''
        mlen = len(msg)
        lhash = f_hash(label)
        hlen = len(lhash)
        ps = b'\x00' * (k - mlen - 2 * hlen - 2)
        db = lhash + ps + b'\x01' + msg
        seed = os.urandom(hlen)
        db_mask = f_mgf(seed, k - hlen - 1, f_hash)
        masked_db = xor(db, db_mask)
        seed_mask = f_mgf(masked_db, hlen, f_hash)
        masked_seed = xor(seed, seed_mask)
        return b'\x00' + masked_seed + masked_db


    def oaep_decode(self, cypher, k, label = b'', f_hash = sha1, f_mgf = mgf):
        '''EME-OAEP decoding'''
        clen = len(cypher)
        lhash = f_hash(label)
        hlen = len(lhash)
        _, masked_seed, masked_db = cypher[:1], cypher[1:1 + hlen], cypher[1 + hlen:]
        seed_mask = f_mgf(masked_db, hlen, f_hash)
        seed = xor(masked_seed, seed_mask)
        db_mask = f_mgf(seed, k - hlen - 1, f_hash)
        db = xor(masked_db, db_mask)
        _lhash = db[:hlen]
        assert lhash == _lhash
        i = hlen
        while i < len(db):
            if db[i] == 0:
                i += 1
                continue
            elif db[i] == 1:
                i += 1
                break
            else:
                raise Exception()
        m = db[i:]
        return m

    
    # Como gerar um k aleatório se 
    def Enc1(self, pub_key, msg, size):
        '''Encrypt k using RSA public key'''
        e, n = pub_key
        #k é aleatório
        return (pow(self.k, e, n), self.k)
    
    

    def Encrypt1(self, pub_key, msg, a):
        hsh = hashlib.sha1()
        hsh.update(str(msg).encode('ascii'))
        h = hsh.digest()
        b = a+str(msg).encode('ascii')
        size = len(b)
        (enc, k) = self.Enc1(pub_key, str(a) + str(h), size)
        k2 = i2osp(k,len(b))
        return (enc, xor(b,k2))

        
    def Encrypt(self, pub_key, msg, label = b"", f_hash = sha1):
        lhash = f_hash(label)
        hlen = len(lhash)
        seed = os.urandom(hlen)
        a = mgf(seed, hlen, sha1)
        return self.Encrypt1(pub_key, msg, a)


    def encrypt_raw(self, msg, public_key):
        '''Encrypt a byte array without padding'''
        k = self.get_key_len(public_key)
        return self.Encrypt(public_key, os2ip(msg))


    def encrypt_oaep(self, msg, public_key):
        '''Encrypt a byte array with OAEP padding'''
        hlen = 20  # SHA-1 hash length
        k = self.get_key_len(public_key)
        assert len(msg) <= k - hlen - 2
        return self.encrypt_raw(self.oaep_encode(msg, k), public_key)
    
    
    def Rev(self, priv_key, enc):
        '''Decrypt k using RSA private key'''
        d, n = priv_key
        return pow(enc, d, n)
    
    #Decrypt(secret_key, c): c é um par
    # (enc,m') <- c 
    # k        <- Rev(secret_key,enc)
    # a || m   <- xor(m',k) 
    # if (c == Encript1(pub_key,m,a)) then m else Fail 
   
    def Decrypt(self, pub_key, priv_key, c, label = b'', f_hash = sha1):
        lhash = f_hash(label)
        hlen = len(lhash)
        enc, m = c
        
        k = self.Rev(priv_key, enc)
        k2 = i2osp(lift(k),len(m))
        pad_msg = xor(m,k2)
        a, m = pad_msg[:hlen],pad_msg[hlen:]
        #verificação
        if c == self.Encrypt1(pub_key,int(m),a):
            return int(m)
        else:
            return False

        
    def decrypt_raw(self, cypher, private_key, public_key):
        '''Decrypt a cipher byte array without padding'''
        k = self.get_key_len(private_key)
        return i2osp(self.Decrypt(public_key, private_key, cypher),k)


    def decrypt_oaep(self, cypher, private_key, public_key):
        '''Decrypt a cipher byte array with OAEP padding'''
        k = self.get_key_len(private_key)
        (kLinha,enc) = cypher
        hlen = 20  # SHA-1 hash length
        assert k >= 2 * hlen + 2
        return self.oaep_decode(self.decrypt_raw(cypher, private_key, public_key), k)
    
    
X = RSA(512)
PubKey,PrivKey = X.keygen()
(k,cipher_text) = X.encrypt_oaep('hello world!'.encode('ascii'), PubKey)
print('RSA-OAEP Encrypted text is:')
print(base64.encodebytes(cipher_text).decode('ascii'))
print('RSA-OAEP Decrypted text is:')
plain_text = X.decrypt_oaep((k,cipher_text), PrivKey, PubKey)
print(plain_text.decode('ascii'))

RSA-OAEP Encrypted text is:
QNpXf4zniuy6tz0VbSfPm59TB2QzOTc4MTk1MDU0NTM2ODE4MDAxMjA5MzM5NzUxNjY0Mjg0MjYx
NjcxODE0MTgzNzY1Njg4NTYzNzIyMTY1NDExODY3OTk3MTYxNDE1MTc1NDUzMzQ5MjkyODk2Nzc3
MDI0NDQwMjkwMjIwMzM5NTUwOTI4OTA3ODE0Mzk4Njc0MjE0OTY0OTExNTQ3NDI0NjU2NzMxNjkj
iQ==

RSA-OAEP Decrypted text is:
hello world!
