In [6]:
import math
import secrets
from random import getrandbits, randrange

def xgcd(a, b):
  """Extended Euclidean algorithm.

  Returns (d, x, y) such that d == a*x + b*y and d == gcd(a, b) >= 0.
  When b == 0 the result is (|a|, +/-1, 0).
  """
  if b == 0:
    # gcd(a, 0) = |a|; without this guard the loop below would divide by zero.
    return (a, 1, 0) if a >= 0 else (-a, -1, 0)

  # Invariants maintained by the loop: c == x1*a + y1*b and d == x2*a + y2*b.
  x1, y1 = 1, 0
  x2, y2 = 0, 1
  c, d = a, b

  while True:
    q, r = divmod(c, d)
    if r == 0:
      # With a negative b, Python's floor division can leave a negative divisor;
      # flip every sign so the gcd is non-negative and the identity still holds.
      if d < 0:
        return -d, -x2, -y2
      return d, x2, y2
    # Tuple assignment evaluates the right-hand side with the old x2/y2.
    x1, x2 = x2, x1 - q * x2
    y1, y2 = y2, y1 - q * y2
    c, d = d, r

def modInverse(a, b):
    """Return the multiplicative inverse of ``a`` modulo ``b``.

    From the extended Euclidean algorithm, a*x + b*y == 1 implies
    a*x == 1 (mod b), so x reduced modulo b is the inverse of a.

    Raises:
        ValueError -- if gcd(a, b) != 1, i.e. a has no inverse modulo b.
    """
    d, x, _ = xgcd(a, b)

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently return a wrong "inverse".
    if d != 1:
        raise ValueError(f"{a} has no inverse modulo {b} (gcd = {d})")

    result = x % b
    # Compare with ``1 % b`` (not ``1``) so the check also holds for b == 1,
    # where every residue, including the inverse 0, is congruent to 1.
    assert result * a % b == 1 % b
    return result

def dPdQ(p, q, N):
  """Return the CRT recombination coefficients (dP, dQ) for moduli p and q.

  Despite the names, these are not the RSA exponents d mod (p-1) / d mod (q-1).
  Assuming gcd(p, q) == 1 and N == p*q:
    dP == 1 (mod p) and dP == 0 (mod q)
    dQ == 0 (mod p) and dQ == 1 (mod q)
  """
  _, coef_p, coef_q = xgcd(p, q)   # coef_p*p + coef_q*q == gcd(p, q)
  return coef_q * q % N, coef_p * p % N

def is_prime(n, k=128):
    """ Probabilistic (Miller-Rabin) primality test.
        Args:
            n -- int -- the number to test
            k -- int -- how many random witnesses to try
        return True if n is probably prime, False if n is certainly composite
    """
    # 2 and 3 first: the witness range [2, n - 2] used below needs n >= 5.
    if n in (2, 3):
        return True
    # Nothing below 2 is prime, and no even number above 2 is.
    if n <= 1 or n % 2 == 0:
        return False

    # Decompose n - 1 as 2**s * r with r odd.
    r, s = n - 1, 0
    while r % 2 == 0:
        r //= 2
        s += 1

    for _ in range(k):
        witness = randrange(2, n - 1)
        x = pow(witness, r, n)
        if x == 1 or x == n - 1:
            continue  # this witness gives no evidence against primality
        # Square up to s - 1 more times, looking for -1 (mod n).
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == 1:
                return False  # non-trivial square root of 1: composite
            if x == n - 1:
                break
        else:
            return False  # -1 never reached: the witness proves n composite
    return True

def generate_prime_candidate(length):
    """ Generate a random odd integer of exactly ``length`` bits
        Args:
            length -- int -- the length of the number to generate, in bits
        return an integer with its most and least significant bits set
        raises ValueError if length < 1
    """
    if length < 1:
        raise ValueError("length must be a positive number of bits")
    # These candidates become RSA secret primes, so draw them from the OS
    # CSPRNG: random.getrandbits (Mersenne Twister) is predictable from its output.
    p = secrets.randbits(length)
    # Set the MSB (so the value has exactly `length` bits) and the LSB (odd).
    p |= (1 << (length - 1)) | 1
    return p

def generate_prime_number(length=1024):
    """ Generate a random (probable) prime
        Args:
            length -- int -- bit length of the prime to generate
        return a prime that passed 128 rounds of is_prime
    """
    # Draw odd `length`-bit candidates until one passes the primality test.
    while True:
        candidate = generate_prime_candidate(length)
        if is_prime(candidate, 128):
            return candidate

def CRT(xp, xq, dP, dQ, N):
  """Combine a residue mod p (xp) and a residue mod q (xq) into one value mod N.

  dP and dQ are the recombination coefficients (as produced by dPdQ); the
  result is xp*dP + xq*dQ reduced modulo N.
  """
  combined = xp * dP + xq * dQ
  return combined % N

def gcd(a, b):
  """Return the greatest common divisor of integers a and b.

  Uses math.gcd, which accepts zero arguments (gcd(a, 0) == |a|) and always
  returns a non-negative result, even for negative inputs.
  """
  return math.gcd(a, b)

def lcm(a, b):
  """Least common multiple, using lcm(a, b) = a*b / gcd(a, b).

  a is divided by the gcd before multiplying, so the intermediate product
  never exceeds the result.
  """
  divisor = gcd(a, b)
  return (a // divisor) * b

def keyGen(keySize=2048, e = 65537):
  """Generate an RSA key.

  Args:
    keySize -- int -- target modulus size in bits; each prime gets keySize//2 bits
    e -- int -- public exponent
  Returns:
    (n, p, q, lamda, d) where n = p*q, lamda = lcm(p-1, q-1) and
    d = e^-1 mod lamda.
  """
  half = keySize // 2
  while True:
    p = generate_prime_number(half)
    q = generate_prime_number(half)
    if p == q:
      # n = p**2 would make factoring trivial; draw a fresh pair.
      continue
    lamda = lcm(p - 1, q - 1)
    # e must be invertible modulo lamda, otherwise no private exponent exists.
    if gcd(e, lamda) == 1:
      d = modInverse(e, lamda)
      return p * q, p, q, lamda, d

def rdmGrpElt(n):
  """Return a uniformly random element of the multiplicative group (Z/nZ)*.

  Rejection-samples secrets.randbelow(n) until the draw is coprime to n.
  """
  candidate = secrets.randbelow(n)
  while xgcd(candidate, n)[0] != 1:
    candidate = secrets.randbelow(n)
  return candidate

def enc(m, e, n):
  """Textbook (unpadded) RSA encryption: return m**e mod n."""
  ciphertext = pow(m, e, n)
  return ciphertext

def dec(c, d, n):
  """Textbook RSA decryption: return c**d mod n."""
  plaintext = pow(c, d, n)
  return plaintext

def dec_blinding(c, d, n, e):
  """RSA decryption with multiplicative blinding.

  The ciphertext is multiplied by r**e for a fresh random unit r before the
  private-key exponentiation, so d is never applied to the caller-supplied
  value directly. Because (r**e * c)**d == r * m (mod n) for a matching
  key pair (e, d), multiplying by r^-1 afterwards recovers m.
  """
  r = rdmGrpElt(n)
  r_inv = modInverse(r, n)
  blinded = pow(r, e, n) * c % n
  return dec(blinded, d, n) * r_inv % n

def dec_CRT_demo(c, d, p, q):
  """CRT-decrypt c, deriving the modulus and recombination coefficients from p, q.

  dec_CRT is looked up at call time, so a later redefinition of dec_CRT in the
  notebook changes what this function runs.
  """
  modulus = p * q
  coef_p, coef_q = dPdQ(p, q, modulus)
  return dec_CRT(c, p, q, coef_p, coef_q, d, modulus)


def dec_CRT(c, p, q, dP, dQ, d, n, e=65537):
  """RSA decryption using the Chinese Remainder Theorem.

  Computes the message mod p and mod q with the private exponent reduced
  mod (p-1) and (q-1) (Fermat's little theorem), then recombines the two
  halves with the coefficients dP/dQ.

  ``e`` is currently unused. Checking pow(result, e, n) == c before returning
  would detect a corrupted half-result, which is the usual countermeasure
  against fault attacks on CRT-RSA.
  """
  half_p = pow(c, d % (p - 1), p)
  half_q = pow(c, d % (q - 1), q)
  return CRT(half_p, half_q, dP, dQ, n)

In [7]:
def properrun():
  """Sanity check: print whether CRT decryption recovers a random message."""
  key_size = 1024
  pub_exp = 65537
  n, p, q, _lamda, d = keyGen(key_size, pub_exp)

  # Encrypt a random message in [2, n - 2]...
  message = randrange(2, n-1)
  ciphertext = enc(message, pub_exp, n)

  # ...and decrypt it again through the CRT path.
  recovered = dec_CRT_demo(ciphertext, d, p, q)
  print(message == recovered)

properrun()

True


In [8]:
# Simulated fault injection: redefine dec_CRT so that the mod-q half-result is
# corrupted, as a hardware bit flip (e.g. Rowhammer) during decryption might do.
def dec_CRT(c, p, q, dP, dQ, d, n, e=65537):
  """Faulty CRT decryption: the mod-p half is correct, the mod-q half is forced to 3."""
  mP = pow(c, d % (p - 1), p)
  mQ = 3  # injected fault; the correct value would be pow(c, d % (q - 1), q)
  return CRT(mP, mQ, dP, dQ, n)

In [9]:
def bitfliprun():
  """Fault-attack demo: one faulty CRT decryption reveals a prime factor of n.

  This relies on the fault-injected dec_CRT defined in the previous cell being
  active. The faulty result is still correct mod p but wrong mod q, so p divides
  (faulty - message) while q (almost surely) does not; hence
  gcd(faulty - message, n) == p. An attacker who does not know the message
  can use gcd(pow(faulty, e, n) - ciphertext, n) instead.
  """
  key_size = 1024
  pub_exp = 65537
  n, p, q, _lamda, d = keyGen(key_size, pub_exp)

  # Encrypt a random message in [2, n - 2].
  message = randrange(2, n-1)
  ciphertext = enc(message, pub_exp, n)

  # Decrypt through the (faulty) CRT path and factor n from the error.
  faulty = dec_CRT_demo(ciphertext, d, p, q)
  print(message == faulty)
  factor = gcd(faulty - message, n)
  for value in (factor, p, q):
    print(value)
  print(factor in (p, q))

bitfliprun()

False
12143843270490332773112680676792944104282753533520973085220687416374594136831771110859729050383976105890841526851864823106008465682754958795099818282732177
12143843270490332773112680676792944104282753533520973085220687416374594136831771110859729050383976105890841526851864823106008465682754958795099818282732177
11666642569323084411809691489725549022825813453601251622631989409626096316842573759702763521298792862456230829745219237540475197289160085110362770897537931
True
