In [1]:
import random
from random import randint
import sys
from functools import reduce
from Crypto.Cipher import AES
from Crypto.Util import Counter
from Crypto import Random 

# A cryptographically strong version of Python's standard "random" module

# Implementation notes:
# Since AES uses block length of 16 bytes, we use "ctr = Counter.new(128)"
# Since counter is a stateful function, we need 2 objects - one for encoding, one for decoding
class secretshare:
    '''
    Base class for Secret Sharing Schemes (SSS) that implements (k,n)-threshold sharing.
    Encryption: Use AES-256 to encode infile and save into outfile, then split AES key into n keys
    Decryption: Combine k keys into AES key and decrypt outfile
    Other SSS classes extend from this and implement their methods of splitting the AES key and combining keys
    Conventions:
    S = secret = AES key
    After splitting up a "key", we get "keys"/"shares"/"shadows"
    '''
    def __init__(self):
        '''
        Initialises a huge prime p for modulo (if needed), where S < p.
        Since S is 32 bytes, p has to be > 256 bits.
        Source: http://primes.utm.edu/lists/2small/200bit.html
        Verification: http://www.wolframalpha.com/input/?i=is+2%5E257+-+93+prime
        '''
        # You can randomly generate this if you wish to
        self.p = 2**257 - 93
        print("in init")
    def split_key(self, key, n, k):
        '''
        Generates coefficient vector a with a[0] = S
        Sample polynomial q(i) at i = 1, 2, ... , n
        
        Key i = [i, q(i)]
        '''

        # Generate coefficient vector a
        a = [int.from_bytes(key, byteorder = sys.byteorder)]
        for i in range(k-1):
            a.append(random.randint(0, 2**256))

        # Polynomial q(x) = a_0 + a_1 * x + a2 * x^2 + ... + a_(k-1) * x^(k-1) (mod p)
        # Generate q(1), q(2), ... , q(n) (mod p)
        keys = []
        for i in range(1, n+1):
            x = [i ** j for j in range(k)]
            keys.append([i, sum(a[j] * x[j] % self.p for j in range(k)) % self.p])

        # Return keys
        return keys
    def combine_keys(self, keys):
        '''
        Extract x and y = q(x) from keys
        Apply Lagrange interpolation to compute q(0) = S
        
        Key = q(0)
        '''

        # keys[i][0] = x value
        # keys[i][1] = q(x) value
        k = len(keys)
        x = [keys[i][0] for i in range(k)]
        y = [keys[i][1] for i in range(k)]
        print(x,y)
        # Find q(0) by directly applying definition of Lagrange interpolation formula
        # Secret S = AES key = q(0)
        #
        # Implementation notes:
        # Take modulo (2 ** 256) because insufficent/invalid keys may result in S > 256 bits
        # If S > 256 bits, then it will crash in the conversion to 32 byte representation
        S = int(sum(y[j] * basis(x, k, j, self.p) % self.p for j in range(k)) % self.p) % (2 ** 256)
        print("cobine:",S)
        key = S.to_bytes(32, byteorder = sys.byteorder)

        # Return key
        return key

    def encrypt(self, infile, outfile, keysfile, n, k):
        '''
        Encrypts infile to outfile via AES-256 and stores "broken up" key in keysfile
        1) Reads in plaintext from infile
        2) Create AES-256 encoder with 32 random bytes as key
        3) Encrypt plaintext
        4) Store ciphertext in outfile
        5) Split key into via split_key function (Output depends on n and k)
        6) Store keys/shares in keysfile
        '''

        # Read from infile
        with open(infile, 'rb') as f:
            plain = f.read()

        # Create AES-256 encoder with 32 random bytes as key
        key = Random.new().read(32)
        encoder = AES.new(key, AES.MODE_CTR, counter = Counter.new(128))

        # Encrypt plaintext
        cipher = encoder.encrypt(plain)
        
        # Write to outfile
        with open(outfile, 'wb') as f:
            f.write(cipher)

        # Generate n keys
        keys = self.split_key(key, n, k)

        # Store n keys
        with open(keysfile, 'w') as f:
            for key in keys:
                f.write("{}\n".format(key))

    def encrypt_paillier(self,infile,outfile, keysfile1,keysfile2 ,n, k):
        p = 3
        m = infile
        #m = 10
        q = 5
        t = p*q
        gLambda = lcm(p-1,q-1)  #private
        g = t + 1
        r = randint(1,t)
        l = (pow(g, gLambda)% (t*t)-1)//n
        gMu = inverse_of(l, t)
        k1 = pow(g, m) % (t*t)
        k2 = pow(r, t)% (t*t)
        cipher = (k1 * k2) % (t*t)
        #cipher = int(cipher).to_bytes(2,byteorder = sys.byteorder)
        key1 = int(gLambda).to_bytes(32,byteorder = sys.byteorder)
        key2 = int(gMu).to_bytes(32,byteorder = sys.byteorder)
        print("k1=",k1,"k2=",k2,"cipher=",cipher)
        
        

        keys1 = self.split_key(key1,n,k)
        keys2 = self.split_key(key2,n,k)

   
        # Write to outfile
        with open(outfile, 'w') as f:
            f.write(str(cipher))
        # Store n keys
        with open(keysfile1, 'w') as f:
            for key in keys1:
                f.write("{}\n".format(key))
        with open(keysfile2, 'w') as f:
            for key in keys2:
                f.write("{}\n".format(key))


    def decrypt_paillier(self, infile, outfile, keysfile1,keysfile2):
        with open(keysfile1, 'r') as f:
            keys1 = f.read().splitlines()
        keys1 = [[int(num) for num in key[1:-1].replace(' ', '').split(',')] for key in keys1]
        with open(keysfile2, 'r') as f:
            keys2 = f.read().splitlines()
        keys2 = [[int(num) for num in key[1:-1].replace(' ', '').split(',')] for key in keys2]
        self.decrypt_paillier_with_keys(infile,outfile,keys1,keys2)

    def decrypt_paillier_with_keys(self, infile, outfile, keys1,keys2):

        # Read from infile
        with open(infile, 'r') as f:
            cipher = f.read()
            print(cipher)

        try:
            # Combine given keys. May throw exception if < k valid keys are given
            key1 = self.combine_keys(keys1)
            key1=int.from_bytes(key1, byteorder=sys.byteorder)
            key2 = self.combine_keys(keys2)
            key2=int.from_bytes(key2, byteorder=sys.byteorder)
            print("key1= ",key1,"key2= ",key2)
            l = ((pow(int(cipher),key1)%225)-1) // 15
            plain = (l*key2) % 15
            print(plain)
            
        finally:
            # Write to outfile
            with open(outfile, 'w') as f:
                f.write(str(plain))
            #print(plain)

    def decrypt(self, infile, outfile, keysfile):
        '''
        Reads in keys/shares from keysfiles and parse them as a list of keys/shares
        '''

        # Read from keysfile
        with open(keysfile, 'r') as f:
            keys = f.read().splitlines()
        keys = [[int(num) for num in key[1:-1].replace(' ', '').split(',')] for key in keys]
        self.decrypt_with_keys(infile, outfile, keys)

    def decrypt_with_keys(self, infile, outfile, keys):
        '''
        Decrypts infile to outfile via AES-256 with keys
        1) Reads in ciphertext from infile
        2) Combine keys/shares into a AES-256 key
        3) Create AES-256 decoder with combined key
        4) Decrypt ciphertext
        5) Store plaintext in outfile
        '''

        # Read from infile
        with open(infile, 'rb') as f:
            cipher = f.read()

        try:
            # Combine given keys. May throw exception if < k valid keys are given
            key = self.combine_keys(keys)

            # Create AES-256 decoder with key
            decoder = AES.new(key, AES.MODE_CTR, counter = Counter.new(128))

            # Decrypt ciphertext
            plain = decoder.decrypt(cipher)
        except Exception as e:
            plain = str.encode(e.args[0])
        finally:
            # Write to outfile
            with open(outfile, 'wb') as f:
                f.write(plain)

    ####################
    # HELPER FUNCTIONS #
    ####################

def prod(lst):
    '''
    Returns the product of all values in the list
    '''
    return reduce(lambda x, y: x * y, lst)

# Source: https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm
def xgcd(b, n):
    '''
    Extended gcd (Iterative form)
    '''
    x0, x1, y0, y1 = 1, 0, 0, 1
    while n != 0:
        q, b, n = b // n, n, b % n
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return  b, x0, y0

# Source: https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm
def mulinv(b, n):
    '''
    Returns the modulo inverse of b in mod n
    i.e. x = mulinv(b) mod n, (x * b) % n == 1
    '''
    g, x, _ = xgcd(b, n)
    if g == 1:
        return x % n

# Source: https://en.wikipedia.org/wiki/Lagrange_polynomial
def basis(x, k, j, p):
    '''
    Computes the basis for Lagrange interpolating polynomial based on the formula
    '''
    terms = [(0-x[m])*mulinv(x[j] - x[m], p) for m in range(k) if m != j]
    return prod(terms) % p


#source : https://asecuritysite.com/encryption/pal_ex
def gcd(a,b):
    """Compute the greatest common divisor of a and b"""
    while b > 0:
        a, b = b, a % b
    return a
    
def lcm(a, b):
    """Compute the lowest common multiple of a and b"""
    return a * b / gcd(a, b)

def extended_euclidean_algorithm(a, b):
    """
    Returns a three-tuple (gcd, x, y) such that
    a * x + b * y == gcd, where gcd is the greatest
    common divisor of a and b.

    This function implements the extended Euclidean
    algorithm and runs in O(log b) in the worst case.
    """
    s, old_s = 0, 1
    t, old_t = 1, 0
    r, old_r = b, a

    while r != 0:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_s, s = s, old_s - quotient * s
        old_t, t = t, old_t - quotient * t

    return old_r, old_s, old_t


def inverse_of(n, p):
    """
    Returns the multiplicative inverse of
    n modulo p.

    This function returns an integer m such that
    (n * m) % p == 1.
    """
    gcd, x, y = extended_euclidean_algorithm(n, p)
    assert (n * x + p * y) % p == gcd

    if gcd != 1:
        # Either n is 0, or p is not a prime number.
        raise ValueError(
            '{} has no multiplicative inverse '
            'modulo {}'.format(n, p))
    else:
        return x % p

def L(x,n):
	return ((x-1)//n)


ModuleNotFoundError: ignored

In [0]:
t = secretshare()

in init


In [0]:
t.encrypt_paillier(10,"shit.txt","keys1.txt","keys2.txt",7,5)

k1= 151 k2= 224 cipher= 74


In [0]:
t.decrypt_paillier("shit.txt","god.txt","keys1.txt","keys2.txt")

74
[1, 2, 3, 4, 5, 6, 7] [183955446488165767267367997706103673361795404446810101512524357174500072754761, 5147509824276657956811803417685109131518703376726393776790902369977894952772, 200058817576392741870652364337278041328964832023799492501135833588783946444141, 169303715038371831502157395918213165832964716149698909869812988498002812167427, 61885871501560999659819112021000481740244996025243207253230356571889515667655, 80445744830899406925928314861201844245808118436556610971204575044390741256979, 18092224515653519962290454974679333229852398022506756135130592877492315455124]
cobine: 4
[1, 2, 3, 4, 5, 6, 7] [193647337535349443534896601772479830857983117366977708713701394943661338595370, 205588319757890212682304196542104660493783711590180596508113603211552699705801, 227930514909827222211127773406574292432673278807026707569445540083614936686566, 64915089375727923818993076625141616670044068317676340937528466738364358794978, 5925352158105186952081448568811587249756895991015656282147666231619

In [0]:
t.encrypt("lenna.png","cipher.png","2020_1_21.txt",10,5)

In [0]:
t.decrypt("cipher.png","new1.png","2020_1_21.txt")

In [0]:
t.decrypt_paillier("output.png","jacky.png","first.txt","second.txt")


b"unsupported operand type(s) for ** or pow(): 'bytes' and 'bytes'"


In [2]:
sudo pip install pycrypto

SyntaxError: ignored