In [8]:
from math import sqrt, ceil
import os
import copy
import hashlib
import random
from typing import Tuple, Callable
import base64
import warnings

def genPrime(param):
    '''Return a random prime with exactly ``param`` bits.

    The prime is drawn by Sage's ``random_prime`` from the closed interval
    [2**(param - 1), 2**param - 1]. ``proof=True`` requests a provable
    primality test rather than a pseudo-primality check (per Sage docs).
    '''
    upper = 2**param - 1
    lower = 2**(param - 1)
    return random_prime(upper, proof=True, lbound=lower)


# Successive squaring (left-to-right binary exponentiation):
# effectively computes (a ** b) % m without forming the full power a ** b.
def power(a, b, m):
    '''Return (a ** b) % m for a non-negative integer exponent ``b``.

    The bits of ``b`` are scanned from most to least significant. The
    accumulator is squared at every step and multiplied by ``a`` when the
    current bit is set. Works for plain Python ints as well as Sage Integers.

    :raises ValueError: if ``b`` is negative.
    '''
    if b < 0:
        raise ValueError("exponent must be non-negative")
    # 1 % m (not 1) so that power(a, 0, 1) == 0, matching pow(a, 0, 1).
    d = 1 % m
    for i in range(int(b).bit_length() - 1, -1, -1):
        d = (d * d) % m
        if (b >> i) & 1:
            d = (d * a) % m
    return d

def os2ip(x):
    '''Octet-String-to-Integer Primitive (RFC 8017, section 4.2).

    Reads the octets of ``x`` as one big-endian unsigned integer;
    an empty octet string maps to 0.
    '''
    value = int.from_bytes(x, "big")
    return value


def i2osp(x, xlen):
    '''Integer-to-Octet-String Primitive (RFC 8017, section 4.1).

    Encodes ``x`` as a big-endian byte string of exactly ``xlen`` octets,
    left-padded with zero bytes. ``int.to_bytes`` raises OverflowError
    when ``x`` does not fit in ``xlen`` octets.
    '''
    as_int = int(x)
    return as_int.to_bytes(length=xlen, byteorder="big")


def sha1(msg):
    '''Return the 20-byte SHA-1 digest of the byte string ``msg``.'''
    return hashlib.sha1(msg).digest()


def mgf(seed, mlen, f_hash = sha1):
    '''MGF1 mask generation function (RFC 8017, appendix B.2.1).

    Concatenates ``f_hash(seed || I2OSP(counter, 4))`` for counter = 0, 1, ...
    and truncates the result to ``mlen`` bytes.

    :param seed: byte string the mask is derived from
    :param mlen: requested mask length in bytes
    :param f_hash: hash function mapping bytes to a digest (SHA-1 by default)
    :raises ValueError: if ``mlen`` exceeds 2**32 * hLen, the MGF1 limit
    '''
    hlen = len(f_hash(b''))
    if mlen > (2 ** 32) * hlen:
        raise ValueError("mask too long")
    # Integer ceiling division: avoids float rounding from ceil(mlen / hlen).
    blocks = -(-mlen // hlen)
    t = b''.join(f_hash(seed + i2osp(c, 4)) for c in range(blocks))
    return t[:mlen]


def xor(data, mask):
    '''Byte-by-byte XOR of two byte arrays.

    The result always has the same length as ``data``:
    - The first ``min(len(data), len(mask))`` bytes are XORed with ``mask``.
    - Bytes of ``data`` beyond the end of ``mask`` are copied through unchanged.
    - Surplus bytes of ``mask`` are ignored.
    '''
    # operator.xor instead of an infix operator: under the Sage preparser
    # ``^`` means exponentiation, while ``^^`` is not valid plain Python.
    from operator import xor as _bxor
    n = min(len(data), len(mask))
    head = bytes(_bxor(a, b) for a, b in zip(data, mask))
    return head + bytes(data[n:])

class RSA:
    '''Textbook RSA with RSAES-OAEP padding (SHA-1 / MGF1), after RFC 8017.

    Keys are ``(exponent, modulus)`` tuples, so the same helpers
    (e.g. ``get_key_len``) work for both the public and the private key.
    '''

    def __init__(self,secure_param):
        # Requested modulus size in bits.
        self.secure_param = int(secure_param)
        # Fixed value reserved for the planned Fujisaki-Okamoto transform
        # (see the TODO notes below); unused by the implemented methods.
        self.r = 788255724614721016190591162463944054696650907899
        # Key material; (0, 0, 0) until keygen() is called.
        self.n = 0
        self.e = 0
        self.d = 0

    def keygen(self):
        '''Generate a fresh RSA key pair and store it on the instance.

        :returns: ``(PubKey, PrivKey)`` with ``PubKey = (e, n)`` and
                  ``PrivKey = (d, n)``.
        '''
        half = int(self.secure_param/2)
        # Draw p (half+1 bits) and q (half bits), retrying until p > 2q.
        p = genPrime(half + 1)
        q = genPrime(half)
        while p <= 2*q:
            p = genPrime(half + 1)
            q = genPrime(half)
        n = p*q
        self.n = n
        # Euler's totient of n = p*q (p != q is guaranteed by p > 2q).
        phin = (p-1)*(q-1)
        # e must be invertible modulo phi(n): pick 1 < e < phi(n) with
        # gcd(e, phi(n)) == 1 (e == phi(n) is rejected by the gcd test).
        e = randint(2,phin)
        while gcd(phin,e) != 1:
            e = randint(2,phin)
        self.e = e
        # Private exponent d = e^-1 mod phi(n).
        d = power_mod(e,-1,phin)
        self.d = d
        # The private key is shaped (d, n) so get_key_len() works for both keys.
        PubKey = (e,n)
        PrivKey = (d,n)
        return PubKey,PrivKey

    def get_pub_key(self):
        '''Return the public key ``(e, n)``; ``(0, 0)`` before keygen().'''
        return (self.e,self.n)

    def get_priv_key(self):
        '''Return the private key ``(d, n)``; ``(0, 0)`` before keygen().'''
        return (self.d,self.n)

    def get_key_len(self, key):
        '''Return k, the length in octets of the key's modulus n.

        RFC 8017 defines k as ceil(bitlen(n) / 8). Rounding down under-sizes k
        whenever bitlen(n) is not a multiple of 8, and i2osp() then overflows
        on ciphertexts >= 2**(8k).
        '''
        _, n = key
        return (int(n).bit_length() + 7) // 8

    def oaep_encode(self, msg, k, label = b'', f_hash = sha1, f_mgf = mgf):
        '''EME-OAEP encoding (RFC 8017, section 7.1.1, step 2).

        Builds EM = 0x00 || maskedSeed || maskedDB of exactly k octets,
        where DB = lHash || PS || 0x01 || msg and PS is zero padding.

        :raises ValueError: if ``msg`` is longer than k - 2*hLen - 2 octets.
        '''
        mlen = len(msg)
        lhash = f_hash(label)
        hlen = len(lhash)
        ps_len = k - mlen - 2 * hlen - 2
        if ps_len < 0:
            # b'\x00' * negative would silently give b'' and an oversized EM.
            raise ValueError("message too long")
        db = lhash + b'\x00' * ps_len + b'\x01' + msg
        # A fresh random seed per call makes the encoding probabilistic.
        seed = os.urandom(hlen)
        db_mask = f_mgf(seed, k - hlen - 1, f_hash)
        masked_db = xor(db, db_mask)
        seed_mask = f_mgf(masked_db, hlen, f_hash)
        masked_seed = xor(seed, seed_mask)
        return b'\x00' + masked_seed + masked_db

    def oaep_decode(self, cypher, k, label = b'', f_hash = sha1, f_mgf = mgf):
        '''EME-OAEP decoding (RFC 8017, section 7.1.2, step 3).

        :param cypher: the encoded message EM (k octets)
        :returns: the recovered message
        :raises ValueError: "decryption error" if EM is malformed
        '''
        lhash = f_hash(label)
        hlen = len(lhash)
        y, masked_seed, masked_db = cypher[:1], cypher[1:1 + hlen], cypher[1 + hlen:]
        seed_mask = f_mgf(masked_db, hlen, f_hash)
        seed = xor(masked_seed, seed_mask)
        db_mask = f_mgf(seed, k - hlen - 1, f_hash)
        db = xor(masked_db, db_mask)
        # RFC 8017 reports every failure with the same error so callers
        # cannot learn which check failed (cf. Manger's attack).
        if y != b'\x00' or db[:hlen] != lhash:
            raise ValueError("decryption error")
        # Skip the zero padding PS; the next byte must be the 0x01 separator.
        rest = db[hlen:]
        sep = len(rest) - len(rest.lstrip(b'\x00'))
        if sep >= len(rest) or rest[sep] != 1:
            raise ValueError("decryption error")
        return rest[sep + 1:]

    # TODO: Fujisaki-Okamoto (FO) transform, not implemented yet.
    # Notation: G and H are hash functions; EncSY is a symmetric cipher
    # (sketched as an identity placeholder); EncASY is asymmetric (RSA)
    # encryption; M^asy is the asymmetric message space.
    # Planned fo_encrypt(msg, public_key):
    #   r  <- M^asy                        (sketch: r = self.r)
    #   k  := G(r)
    #   cm := EncSY(msg)
    #   h  := H(r, cm)
    #   cr := EncASY(r + h, public_key)    (sketch: pow(r + h, e, n))
    #   return c := cm || cr

    def encrypt(self, msg, public_key):
        '''Encrypt an integer using RSA public key'''
        e, n = public_key
        return pow(msg, e, n)

    def encrypt_raw(self, msg, public_key):
        '''Encrypt a byte array without padding'''
        k = self.get_key_len(public_key)
        c = self.encrypt(os2ip(msg), public_key)
        return i2osp(c, k)

    def encrypt_oaep(self, msg, public_key):
        '''RSAES-OAEP encryption with SHA-1 / MGF1 (RFC 8017, section 7.1.1).

        :raises ValueError: if ``msg`` exceeds k - 2*hLen - 2 octets.
        '''
        hlen = len(sha1(b''))  # 20 for SHA-1
        k = self.get_key_len(public_key)
        # RFC 8017 length bound: mLen <= k - 2*hLen - 2.
        if len(msg) > k - 2 * hlen - 2:
            raise ValueError("message too long")
        return self.encrypt_raw(self.oaep_encode(msg, k), public_key)

    # TODO: planned fo_decrypt(c, secret_key, public_key):
    #   parse c into (cm, cr); return False if parsing fails
    #   r'  <- DecASY(cr, secret_key)      (sketch: pow(cr, d, n))
    #   if r' not in M^asy: return False   (sketch: r' != self.r)
    #   h'  := H(r', cm)
    #   cr' := EncASY(r', h', public_key)  (sketch: pow(r' + h', e, n))
    #   if cr != cr': return False
    #   k'  := G(r')
    #   return DecSY(cm, k')

    def decrypt(self, cypher, private_key):
        '''Decrypt an integer using RSA private key'''
        d, n = private_key
        return pow(cypher, d, n)

    def decrypt_raw(self, cypher, private_key):
        '''Decrypt a cipher byte array without padding'''
        k = self.get_key_len(private_key)
        msg = self.decrypt(os2ip(cypher), private_key)
        return i2osp(msg, k)

    def decrypt_oaep(self, cypher, private_key):
        '''RSAES-OAEP decryption with SHA-1 / MGF1 (RFC 8017, section 7.1.2).

        :raises ValueError: "decryption error" on a wrong-length ciphertext,
                            a too-small key, or malformed padding.
        '''
        k = self.get_key_len(private_key)
        hlen = len(sha1(b''))  # 20 for SHA-1
        # Explicit checks instead of assert, which is stripped under -O.
        if len(cypher) != k or k < 2 * hlen + 2:
            raise ValueError("decryption error")
        return self.oaep_decode(self.decrypt_raw(cypher, private_key), k)
    
    
# Demo: generate a 512-bit key pair, OAEP-encrypt a short message,
# print the base64 ciphertext, then decrypt and print the plaintext.
X = RSA(512)
PubKey,PrivKey = X.keygen()
cipher_text = X.encrypt_oaep('hello world!'.encode('ascii'), PubKey)
print('RSA-OAEP Encrypted text is:')
# base64.encodebytes replaces encodestring, a deprecated alias removed in
# Python 3.9; the output is identical and no DeprecationWarning is raised.
print(base64.encodebytes(cipher_text).decode('ascii'))
print('RSA-OAEP Decrypted text is:')
plain_text = X.decrypt_oaep(cipher_text, PrivKey)
print(plain_text.decode('ascii'))

9239066499756987458227262338963950071925878317442853934906198023720840428602683575604291727750116316153769813037223752340739527965923831320764101707788743
RSA-OAEP Encrypted text is:
pTRnOSdpnDpV06HAGkYnm0fMm1G8J924reVrzWanb7hHrunfza7/frK5CGD0rWs9dn2jVoXLpX7Q
/qfjEyY4rA==

RSA-OAEP Decrypted text is:
hello world!
