In [20]:
# sage -pip install pycryptodome
import os
import numpy
import math
import struct
import random
from Crypto.Util import number
from Crypto.Util.number import getRandomRange, inverse
from Crypto.Util.py3compat import bchr, bord
from Crypto.Hash import SHA256
from Crypto.Util.number import ceil_div, long_to_bytes
from Crypto.Util.strxor import strxor

from collections import namedtuple

def is_coprime(x, y):
    """Tell whether ``x`` and ``y`` share no common divisor other than 1."""
    shared = math.gcd(x, y)
    return shared == 1

def is_prime(n):
    """Report whether ``n`` is (probably) prime, delegating to PyCryptodome's ``isPrime``."""
    verdict = number.isPrime(n)
    return verdict

def augment(c, H, s):
    """Return ``2 * H(c + s) + 1``.

    ``s`` is appended to ``c`` before hashing. When ``H`` returns an int, the
    result is always odd.
    """
    hashed = H(c + s)
    return 2 * hashed + 1

# def largest_prime_factor(n):
#     factor = 1
#     i = 2 # smallest prime
    
#     while i <= (n / i):
#         if n % i == 0:
#             factor = i
#             n = n / i
#         else:
#             i += 1

#     if factor < n:
#         factor = n
#     return int(factor)

def prime_factors(n):
    """Return the distinct prime factors of ``n`` in ascending order.

    Small factors are found by trial division using exact integer arithmetic.
    As soon as the remaining cofactor tests prime (via ``is_prime``), it is
    taken as the last factor.

    Args:
        n: the integer to factor.

    Returns:
        A list of distinct primes dividing ``n``. The list is ``[n]`` when
        ``n`` is itself prime and ``[]`` when ``n < 2``.

    Note: trial division is only practical when every prime factor except the
    largest is small. A product of two large primes will take a very long time.
    """
    n = int(n)
    if n < 2:
        return []
    if is_prime(n):
        return [n]

    factors = []
    # Strip out 2 first so the trial loop can step over even candidates.
    if n % 2 == 0:
        factors.append(2)
        while n % 2 == 0:
            n //= 2
    if n > 1 and is_prime(n):
        factors.append(n)
        return factors

    candidate = 3
    # Integer comparisons only: float division (``n / i``) loses precision
    # for large n and produced wrong factors.
    while candidate * candidate <= n:
        if n % candidate == 0:
            factors.append(candidate)
            while n % candidate == 0:
                n //= candidate
            # Once the cofactor is prime it is the final factor; stop early.
            if n > 1 and is_prime(n):
                break
        candidate += 2

    # Whatever remains above 1 has no divisor <= its square root, so it is prime.
    if n > 1:
        factors.append(n)
    return factors

# Number of random bytes drawn per candidate augmenter in find_augmenter
# (the bytes are hex-encoded, so the resulting string is twice as long).
AUGMENTER_LENGTH = 32
# How many candidates find_augmenter tries before giving up.
MAX_ATTEMPTS = 10_000
# Size of the public auxiliary data set C.
C_LEN = 2
    
def find_augmenter(C, H, L):
    """Search for a random augmenter string ``s`` for the data set ``C``.

    Each candidate is ``AUGMENTER_LENGTH`` random bytes, hex-encoded. Writing
    ``f(c) = augment(c, H, s)``, a candidate is accepted when:

    1. every ``f(c_i)`` is coprime to ``L``; and
    2. for every pair of distinct positions ``i != j``, ``f(c_i)`` has at
       least one prime factor that does not divide ``f(c_j)``.

    NOTE(review): the design note for this search states condition 2 as "the
    largest prime factor of f(c_i) must not divide f(c_j)". That is stronger
    than what is checked here. Confirm which condition the scheme requires.

    Args:
        C: sequence of strings (the public auxiliary data set).
        H: hash function mapping a string to an int.
        L: value every f(c) must be coprime to. It is described as
           lambda = 2((p-1)/2)((q-1)/2).

    Returns:
        The first accepted hex string ``s``.

    Raises:
        Exception: if no candidate is accepted within ``MAX_ATTEMPTS`` tries.
    """

    def is_valid_s(s, C, H, L):
        augmented = [augment(c, H, s) for c in C]
        # Condition 1 is cheap, so check it for every value before factoring.
        if not all(is_coprime(f, L) for f in augmented):
            return False
        # Condition 2 only constrains pairs; skip factoring when there are none.
        if len(augmented) < 2:
            return True
        # Factor each value once instead of once per (i, j) pair.
        factor_lists = [prime_factors(f) for f in augmented]
        for i, primes_i in enumerate(factor_lists):
            for j, fcj in enumerate(augmented):
                # Compare by position, not value: two distinct c that map to
                # the same f(c) must be rejected, not skipped.
                if i != j and all(fcj % pk == 0 for pk in primes_i):
                    return False
        return True

    for _ in range(MAX_ATTEMPTS):
        s = os.urandom(AUGMENTER_LENGTH).hex()
        if is_valid_s(s, C, H, L):
            return s

    raise Exception("Unable to find augmenter")
    
def hasher(x, l):
    """Hash ``x`` with SHA-256 and return the first ``l`` digest bytes as an int.

    Args:
        x: text to hash; it is UTF-8 encoded first.
        l: number of leading digest bytes to keep (0..32).

    Returns:
        The big-endian integer value of the truncated digest.

    Raises:
        ValueError: if ``l`` is outside 0..32.
    """
    # Raise instead of assert so the check survives ``python -O``. Negative
    # lengths are rejected too: they would silently slice from the digest's end.
    if not 0 <= l <= 32:
        raise ValueError("hash truncation length must be between 0 and 32 bytes, got %r" % (l,))
    digest = SHA256.new(data=x.encode('utf-8')).digest()
    return int.from_bytes(digest[:l], 'big')

print("Configuring parameters...")

# RSA parameters for the augmenter search: two PRIME_LENGTH-bit primes and
# L = (p - 1)(q - 1).
PRIME_LENGTH = 256
p = number.getPrime(PRIME_LENGTH)
q = number.getPrime(PRIME_LENGTH)
L = (p - 1) * (q - 1)
print(p, q, L)

# Public auxiliary data set: the strings "0", "1", ..., str(C_LEN - 1).
C = list(map(str, range(C_LEN)))


# Augmentation hash functions: SHA-256 truncated to 8, 16 or 32 bytes.
def H_8(x):
    return hasher(x, 8)


def H_16(x):
    return hasher(x, 16)


def H_32(x):
    return hasher(x, 32)

# print("Starting...")
# print("Found s:", find_augmenter(C, H_8, L))
#print(prime_factors(7393980126450832332490962757833565705214980010614030376803353873452916004836474833291037699316343345457410482720141970433747787863766971947932456105096487))

def random_integer(m, n):
    """Return a random integer ``x`` with ``m <= x < n``.

    Delegates to PyCryptodome's ``getRandomRange``. A NumPy variant,
    ``numpy.random.randint(m, n)``, was used at one point.
    """
    lower, upper = m, n
    return getRandomRange(lower, upper)

def random_bytes(n):
    """Return ``n`` fresh random bytes from ``os.urandom``."""
    chunk = os.urandom(n)
    return chunk

def inverse_mod(a, m):
    """Return the multiplicative inverse of ``a`` modulo ``m`` (PyCryptodome ``inverse``)."""
    result = inverse(a, m)
    return result

def I2OSP(val, length):
    """Integer-to-Octet-String primitive (RFC 8017, Section 4.1).

    Args:
        val: non-negative integer (anything accepted by ``int()``).
        length: desired output length in bytes.

    Returns:
        ``val`` as a big-endian byte string of exactly ``length`` bytes.

    Raises:
        ValueError: if ``val`` is negative or does not fit in ``length`` bytes.
    """
    val = int(val)
    if val < 0 or val >= (1 << (8 * length)):
        raise ValueError("bad I2OSP call: val=%d length=%d" % (val, length))
    # int.to_bytes replaces the hand-rolled byte loop and its OS2IP
    # round-trip self-check.
    return val.to_bytes(length, 'big')

def OS2IP(octets, skip_assert=False):
    """Octet-String-to-Integer primitive (RFC 8017, Section 4.2).

    Args:
        octets: big-endian byte string.
        skip_assert: accepted for backward compatibility and ignored. It used
            to disable an I2OSP round-trip self-check that can never fail for
            byte input.

    Returns:
        The non-negative integer encoded by ``octets`` (0 for empty input).
    """
    return int.from_bytes(octets, 'big')

# RSA public key: modulus n and public exponent e.
PublicKey = namedtuple("PublicKey", ["n", "e"])
# RSA private key: modulus n, prime factors p and q, and private exponent d.
PrivateKey = namedtuple("PrivateKey", ["n", "p", "q", "d"])

def RSASP1(skS, m):
    """RSA signature primitive RSASP1 (RFC 8017, Section 5.2.1).

    Args:
        skS: private key exposing ``n`` (modulus) and ``d`` (private exponent).
        m: message representative, an integer in [0, n - 1].

    Returns:
        The signature representative ``m^d mod n``.

    Raises:
        ValueError: if ``m`` is outside [0, n - 1].
    """
    # Without this check, m and m + n give the same output, so out-of-range
    # inputs would be accepted silently.
    if not 0 <= m < skS.n:
        raise ValueError("message representative out of range")
    return pow(m, skS.d, skS.n)

def RSAVP1(pkS, s):
    """RSA verification primitive RSAVP1 (RFC 8017, Section 5.2.2).

    Args:
        pkS: public key exposing ``n`` (modulus) and ``e`` (public exponent).
        s: signature representative, an integer in [0, n - 1].

    Returns:
        The message representative ``s^e mod n``.

    Raises:
        ValueError: if ``s`` is outside [0, n - 1].
    """
    # Without this check, s and s + n would both map to the same message,
    # which makes signatures malleable.
    if not 0 <= s < pkS.n:
        raise ValueError("signature representative out of range")
    return pow(s, pkS.e, pkS.n)

def rsassa_pss_sign_encode(n, msg_hash):
    """PSS-encode a pre-hashed message for the modulus ``n``.

    The encoding uses MGF1 over the same hash as ``msg_hash``, a zero-length
    salt, and an encoded-message bit length of ``size(n) - 1``.
    """
    def mgf(seed, length):
        return MGF1(seed, length, msg_hash)

    em_bits = number.size(n) - 1
    return EMSA_PSS_ENCODE(msg_hash, em_bits, random_bytes, mgf, 0)

def rsassa_pss_sign_blind(pkS, msg_hash):
    """
    Client side of the blind signature protocol: blind a message for signing.

    1. encoded_message = EMSA-PSS-ENCODE(msg, k_bits - 1)
    2. If EMSA-PSS-ENCODE outputs an error, output an error and stop.
    3. m = OS2IP(encoded_message)
    4. r = random_integer(1, n - 1), i.e. 1 <= r < n - 1
    5. r_inv = inverse_mod(r, n)
    6. x = RSAVP1(pkS, r)
    7. z = m * x mod n
    8. blinded_message = I2OSP(z, k)
    9. blind_inv = I2OSP(r_inv, k)
    10. return blinded_message, blind_inv

    Returns:
        A pair ``(blinded_message, blind_inv)`` of k-byte strings, where k is
        the byte length of the modulus.
    """
    n = pkS.n
    k = number.ceil_div(number.size(n), 8)

    m = OS2IP(rsassa_pss_sign_encode(n, msg_hash))

    # Blinding factor and its inverse modulo n.
    r = random_integer(1, n - 1)
    r_inv = inverse_mod(r, n)
    assert (r * r_inv) % n == 1

    z = (m * RSAVP1(pkS, r)) % n  # m * r^e mod n

    return I2OSP(z, k), I2OSP(r_inv, k)

def rsassa_pss_sign_evaluate(skS, blinded_msg):
    '''
    Signer side of the blind signature protocol: sign a blinded message.

    1. m = OS2IP(blinded_msg)
    2. s = RSASP1(skS, m)
    3. evaluated_message = I2OSP(s, k)
    4. return evaluated_message

    Args:
        skS: private key used to sign.
        blinded_msg: blinded message produced by rsassa_pss_sign_blind.

    Returns:
        The evaluated (blind-signed) message as k bytes, where k is the byte
        length of ``skS.n``.
    '''
    # Size the output from the key actually passed in. The previous code read
    # the module-level ``pkS`` global, which breaks for any other key.
    k_bits = number.size(skS.n)
    k = number.ceil_div(k_bits, 8)

    m = OS2IP(blinded_msg)
    s = RSASP1(skS, m)  # (m*r^e)^d = r*m^d mod n
    return I2OSP(s, k)

def rsassa_pss_sign_finalize(pkS, msg_hash, evaluated_message, blind_inv):
    '''
    Client side: unblind the signer's response and check the result.

    1. z = OS2IP(evaluated_message)
    2. r_inv = OS2IP(blind_inv)
    3. s = z * r_inv mod n
    4. sig = I2OSP(s, k)
    5. If rsassa_pss_sign_verify(pkS, msg, sig) is true, return sig;
       otherwise raise "invalid signature".
    '''
    n = pkS.n
    k = number.ceil_div(number.size(n), 8)

    # (r * m^d) * r^{-1} = m^d mod n
    s = (OS2IP(evaluated_message) * OS2IP(blind_inv)) % n
    sig = I2OSP(s, k)

    if not rsassa_pss_sign_verify(pkS, msg_hash, sig):
        raise Exception("invalid signature")
    return sig
        
def rsassa_pss_sign(skS, msg_hash):
    """Sign a pre-hashed message with RSASSA-PSS (zero-length salt), non-blind.

    Args:
        skS: private key (n, p, q, d).
        msg_hash: hash object holding the message digest.

    Returns:
        The signature as k bytes, where k is the byte length of ``skS.n``.
    """
    # Use the modulus of the key we were given, not the module-level ``pkS``.
    k_bits = number.size(skS.n)
    k = number.ceil_div(k_bits, 8)
    EM = rsassa_pss_sign_encode(skS.n, msg_hash)
    m = OS2IP(EM)
    s = RSASP1(skS, m)
    return I2OSP(s, k)

def rsassa_pss_sign_verify(pkS, msg_hash, sig):
    '''
    Verify an RSASSA-PSS (zero-length salt) signature over a pre-hashed message.

    1. If len(sig) != k, output false
    2. s = OS2IP(sig)
    3. If s is not below n ("signature representative out of range"), output false
    4. m = RSAVP1(pkS, s)
    5. encoded_message = I2OSP(m, emLen) with emLen = ceil((k_bits - 1) / 8);
       if m does not fit, output false
    6. result = EMSA-PSS-VERIFY(msg, encoded_message, k_bits - 1)
    7. If result = "consistent", output true, otherwise output false

    Returns:
        True if the signature is valid for ``msg_hash``, False otherwise.
    '''
    k_bits = number.size(pkS.n)
    k = number.ceil_div(k_bits, 8)

    if len(sig) != k:
        return False

    s = OS2IP(sig)
    # Reject s >= n; otherwise s and s + n would both verify.
    if s >= pkS.n:
        return False
    m = RSAVP1(pkS, s)

    # The encoding is emLen = ceil((k_bits - 1) / 8) bytes. That is one byte
    # shorter than k when k_bits - 1 is a multiple of 8, so encoding to k bytes
    # broke verification for such moduli.
    em_len = number.ceil_div(k_bits - 1, 8)
    try:
        EM = I2OSP(m, em_len)
    except ValueError:
        return False

    def mgf(seed, length):
        return MGF1(seed, length, msg_hash)

    salt_length = 0
    return EMSA_PSS_VERIFY(msg_hash, EM, k_bits - 1, mgf, salt_length)

def MGF1(mgfSeed, maskLen, hash):
    """Mask Generation Function MGF1 (RFC 8017, Appendix B.2.1).

    Args:
        mgfSeed: seed bytes.
        maskLen: desired mask length in bytes.
        hash: hash object selecting the underlying hash function. Objects with
            a ``new(data)`` method (PyCryptodome) are used directly. Otherwise
            the object's ``name`` is passed to ``hashlib.new`` (hashlib objects).

    Returns:
        The first ``maskLen`` bytes of Hash(seed || C) for C = 0, 1, 2, ...,
        with each counter encoded as 4 big-endian bytes.

    Raises:
        ValueError: if ``maskLen`` is negative or exceeds 2**32 * hLen
            ("mask too long").
    """
    import hashlib

    h_len = hash.digest_size
    if maskLen < 0:
        raise ValueError("Non positive values")
    if maskLen > (1 << 32) * h_len:
        raise ValueError("mask too long")

    def digest(data):
        # The old fallback called an undefined ``Hash_new`` (NameError).
        # hashlib objects expose ``name`` instead of ``new``.
        if hasattr(hash, "new"):
            return hash.new(data).digest()
        return hashlib.new(hash.name, data).digest()

    block_count = (maskLen + h_len - 1) // h_len
    # Collect the blocks and join once, instead of repeatedly concatenating.
    T = b"".join(
        digest(mgfSeed + counter.to_bytes(4, "big")) for counter in range(block_count)
    )
    assert len(T) >= maskLen
    return T[:maskLen]

def EMSA_PSS_ENCODE(mhash, emBits, randFunc, mgf, sLen):
    """
    Implement the ``EMSA-PSS-ENCODE`` function, as defined
    in PKCS#1 v2.1 (RFC3447, 9.1.1).

    The original ``EMSA-PSS-ENCODE`` takes the message ``M`` as input and
    hashes it internally. This version expects the message to have been
    hashed already.

    :Parameters:
     mhash : hash object
            The hash object that holds the digest of the message being signed.
            If it has a ``new(data)`` method (PyCryptodome), that method builds
            the inner hash. Otherwise ``hashlib.new(mhash.name, ...)`` is used
            (hashlib objects).
     emBits : int
            Maximum length of the final encoding, in bits.
     randFunc : callable or None
            An RNG function that takes an int as its only parameter and returns
            that many random bytes, used as the salt. Required when ``sLen > 0``.
     mgf : callable
            A mask generation function taking two parameters: a byte string to
            use as seed, and the length of the mask to generate, in bytes.
     sLen : int
            Length of the salt, in bytes.
    :Return: An ``emLen`` byte long string that encodes the hash,
            with emLen = ceil(emBits / 8).
    :Raise ValueError:
        When digest or salt length are too big for ``emBits``, when a salt is
        requested without ``randFunc`` (or ``randFunc`` returns the wrong
        length), or when ``mgf`` returns a mask of the wrong length.
    """
    import hashlib

    def new_hash(data):
        # The old fallback called an undefined ``Hash_new`` (NameError).
        if hasattr(mhash, "new"):
            return mhash.new(data)
        return hashlib.new(mhash.name, data)

    h_len = mhash.digest_size
    emLen = (emBits + 7) // 8

    # Mask covering the 8*emLen - emBits leftmost bits, which must be zero.
    lmask = 0
    for _ in range(8 * emLen - emBits):
        lmask = lmask >> 1 | 0x80

    # Steps 1 and 2 (hashing M) have already been done by the caller.
    # Step 3
    if emLen < h_len + sLen + 2:
        raise ValueError("Digest or salt length are too long for given key size.")
    # Step 4
    salt = b""
    if sLen > 0:
        if not randFunc:
            raise ValueError("A random function is required for a non-empty salt.")
        salt = randFunc(sLen)
        if len(salt) != sLen:
            raise ValueError("Random function returned a salt of the wrong length.")
    # Steps 5 and 6: H = Hash(0x00 * 8 || mHash || salt)
    h = new_hash(b"\x00" * 8 + mhash.digest() + salt).digest()
    # Steps 7 and 8: DB = PS || 0x01 || salt
    db = b"\x00" * (emLen - sLen - h_len - 2) + b"\x01" + salt
    # Step 9
    dbMask = mgf(h, emLen - h_len - 1)
    if len(dbMask) != len(db):
        raise ValueError("Mask generation function returned a mask of the wrong length.")
    # Step 10
    maskedDB = bytes(a ^ b for a, b in zip(db, dbMask))
    # Step 11: clear the leftmost 8*emLen - emBits bits.
    maskedDB = bytes([maskedDB[0] & ~lmask]) + maskedDB[1:]
    # Step 12
    return maskedDB + h + b"\xbc"

def EMSA_PSS_VERIFY(mhash, em, emBits, mgf, sLen):
    """
    Implement the ``EMSA-PSS-VERIFY`` function, as defined
    in PKCS#1 v2.1 (RFC3447, 9.1.2).

    The original ``EMSA-PSS-VERIFY`` takes the message ``M`` as input and
    hashes it internally. This version expects the message to have been
    hashed already.

    :Parameters:
     mhash : hash object
            The hash object that holds the digest of the message to be verified.
            If it has a ``new(data)`` method (PyCryptodome), that method builds
            the inner hash. Otherwise ``hashlib.new(mhash.name, ...)`` is used
            (hashlib objects).
     em : bytes
            The encoded message recovered from the signature.
     emBits : int
            Length of the final encoding (em), in bits.
     mgf : callable
            A mask generation function taking two parameters: a byte string to
            use as seed, and the length of the mask to generate, in bytes.
     sLen : int
            Length of the salt, in bytes.

    :Return: ``True`` if the encoding is consistent, ``False`` otherwise.

    :Raise ValueError:
        When ``mgf`` returns a mask of the wrong length.
    """
    import hashlib

    def new_hash(data):
        # The old fallback called an undefined ``Hash_new`` (NameError).
        if hasattr(mhash, "new"):
            return mhash.new(data)
        return hashlib.new(mhash.name, data)

    h_len = mhash.digest_size
    emLen = (emBits + 7) // 8

    # Mask covering the 8*emLen - emBits leftmost bits, which must be zero.
    lmask = 0
    for _ in range(8 * emLen - emBits):
        lmask = lmask >> 1 | 0x80

    # Steps 1 and 2 (hashing M) have already been done by the caller.
    # Step 3
    if emLen < h_len + sLen + 2:
        return False
    # The slicing below assumes em is exactly emLen bytes long.
    if len(em) != emLen:
        return False
    # Step 4
    if em[-1] != 0xBC:
        return False
    # Step 5
    maskedDB = em[:emLen - h_len - 1]
    h = em[emLen - h_len - 1:-1]
    # Step 6
    if lmask & em[0]:
        return False
    # Step 7
    dbMask = mgf(h, emLen - h_len - 1)
    if len(dbMask) != len(maskedDB):
        raise ValueError("Mask generation function returned a mask of the wrong length.")
    # Step 8
    db = bytes(a ^ b for a, b in zip(maskedDB, dbMask))
    # Step 9
    db = bytes([db[0] & ~lmask]) + db[1:]
    # Step 10
    if not db.startswith(b"\x00" * (emLen - h_len - sLen - 2) + b"\x01"):
        return False
    # Step 11
    salt = db[-sLen:] if sLen else b""
    # Steps 12 and 13
    hp = new_hash(b"\x00" * 8 + mhash.digest() + salt).digest()
    # Step 14
    return h == hp

# Generate key pair.
PRIME_LENGTH = 4096
# Use the conventional public exponent 65537. The previous code drew e from
# the non-cryptographic ``random`` module anywhere in [3, phi]; such an
# exponent makes every public-key operation about as slow as a private one.
# In the rare case gcd(e, phi) != 1, regenerate the primes.
e = 65537
while True:
    p = number.getPrime(PRIME_LENGTH)
    q = number.getPrime(PRIME_LENGTH)
    phi = (p - 1) * (q - 1)
    if p != q and math.gcd(e, phi) == 1:
        break
d = inverse_mod(e, phi)
n = p * q

skS = PrivateKey(n, p, q, d)
pkS = PublicKey(n, e)

# Sanity check: RSAVP1 undoes RSASP1.
m = random.randint(3, phi)
s = RSASP1(skS, m)
mm = RSAVP1(pkS, s)
assert m == mm

msg = 'hello world'.encode("utf-8")
msg_hash = SHA256.new()
msg_hash.update(msg)

# Run the non-blind variant
sig = rsassa_pss_sign(skS, msg_hash)
valid = rsassa_pss_sign_verify(pkS, msg_hash, sig)
assert valid

# Run the blind variant
blinded_message, blind_inv = rsassa_pss_sign_blind(pkS, msg_hash)
evaluated_message = rsassa_pss_sign_evaluate(skS, blinded_message)
sig = rsassa_pss_sign_finalize(pkS, msg_hash, evaluated_message, blind_inv)
valid = rsassa_pss_sign_verify(pkS, msg_hash, sig)
assert valid



Configuring parameters...
88941359103639885804192383930543479456847975490366807609742431183810135234043 89185430141802147643862416020835505421375645430526105925317418733115858953793 7932273369054613515857923242086221429094335605525697403082075278987149550343390119434439878047320184111338679426785491416433163440945068425368143183387264
579262824287160600490017856873020029624865881124890591554956580858299694802873315908766919618436701081453691058627457721285392376173639905083772598319178813802415279517287822715377136013289836420886653998698756385402107310976037261166550457432945438193220810370896254873819344149848004284654616326159639634431871492131664762318437696496810376521058789050262259202679807519044532567368959561396376248073430208729360195878050493323308526925282227607575249613395430128383606684380410304532303786406867658213902061553202941662821764417107733128785886139743677587129857920376366063307655982069777946050024831992695112828567572998625659428664878715071636285000182176824