In [1]:
!pip install pycryptodome


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.1.2[0m[39;49m -> [0m[32;49m22.2.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import random
import hashlib
import time

HASH_LEN = 512
HASH_BYTE_LEN = int(HASH_LEN/8)
RANDP_LIMIT = 2**17
DEFAULT_KEY_RANGE = 2**1032

In [3]:
def toBytes(data):
    if (type(data) == bytes):
        return data
    elif (type(data) == list):
        return bytes(data)
    elif (type(data) == int):
        b = []
        iterations = 0
        
        while (data != 0):
            b = [data % 256] + b
            data = int(data / 256)
            iterations += 1
        
        return bytes(b + [iterations])
    else:
        return bytes([data])

def toHex(data):
    return "".join([(("0" if i < 0x10 else "") + hex(i)[2:]) for i in data])

def hash(data):
    if isinstance(data, int):
        d = []
        
        while data > 0:
            d = [data%2] + d
            data //= 2
        
        data = d
    
    return list(SHA512.new(bytes(data)).digest())

def flatten(a):
    out = []
    
    for sub in a:
        out += sub
    
    return out

def chunkList(a, chunk):
    if (len(a)%chunk != 0):
        raise Exception("Can only chunk lists that are a multiple of the chunk size (got size %d, chunk %d)" % (len(a), chunk))
    
    out = []
    
    for i in range(int(len(a)/chunk)):
        out.append(a[(chunk*i):(chunk*(i+1))])
    
    return out

def pairify(a):
    return chunkList(a, 2)

def hashify(a):
    return chunkList(a, HASH_BYTE_LEN)

def genRange(n):
    out = []
    for i in range(n):
        out.append(i)
    
    return out

In [4]:
class Randp:
    def __init__(self, n):
        self.numsLeft = n
        self.nums = genRange(n)
    
    def nextInt(self):
        if self.numsLeft < 0:
            return 0
        
        index = random.randrange(self.numsLeft)
        ret = self.nums[index]
        self.nums[index] = self.nums[self.numsLeft-1]
        self.numsLeft -= 1
        return ret

class Rand:
    def __init__(self, n):
        self.n = n
    
    def nextInt(self):
        return random.randrange(self.n)

def createRandom(n):
    return Rand(n) if n > RANDP_LIMIT else Randp(n)

In [5]:
class PrivateKey:
    def __init__(self, keySize, keyRange=DEFAULT_KEY_RANGE):
        self.keySize = keySize
        self.keyRange = keyRange
        self.inputCapacity = keySize-HASH_BYTE_LEN
        self.public = PublicKey(self.regenKey())
    
    def regenKey(self):
        rand = createRandom(self.keyRange)
        self.privateData = [[rand.nextInt(), rand.nextInt()] for i in range(self.keySize)]
        return [[hash(i) for i in b] for b in self.privateData]
    
    def encrypt(self, msg):
        if (8*len(msg) > self.inputCapacity):
            raise Exception("Got %d byte message for encryption, size limit is %d" % (len(msg), self.inputCapacity))
        
        priv = self.privateData
        nextPub = self.regenKey()
        
        out = []
        
        pubFlat = flatten(flatten(nextPub))
        out += pubFlat
        
        self.encryptBuf(priv, out, hash(pubFlat), 0)
        
        self.encryptBuf(priv, out, msg, HASH_LEN)
        
        return out
    
    def encryptBuf(self, priv, out, buf, off):
        for i in range(len(buf)):
            for j in range(8):
                out.append(priv[off + 8*i + j][(buf[i] >> j) % 2])
    
    def sign(self, msg):
        return self.encrypt(hash(msg))

In [6]:
class PublicKey:
    def __init__(self, publicData):
        self.publicData = publicData
        self.keySize = len(publicData)
        self.flatKeySize = self.keySize*HASH_BYTE_LEN*2
    
    def decrypt(self, msg, errWhenUnknownHash=True):
        if (len(msg) < self.flatKeySize):
            raise Exception("Encrypted message must be at least %d bytes, got %d" % (self.flatKeySize, len(msg)))
        
        flatPub = msg[:self.flatKeySize]
        nextPub = pairify(hashify(flatPub))
        encrypted = msg[self.flatKeySize:]
        
        if (len(encrypted) > self.flatKeySize):
            raise Exception("Got %d byte message for decryption, size limit is %d" % (len(encrypted), len(self.publicData)))
        
        if (len(encrypted) % 8 != 0):
            raise Exception("Encrypted message should be a multiple of 8 in length, got %d" % len(msg))
        
        decrypted = []
        for i in range(int(len(encrypted)/8)):
            num = 0
            for j in range(8):
                index = 8*i + j
                publicByte = hash(encrypted[index])
                
                if not (publicByte in self.publicData[index]):
                    if errWhenUnknownHash:
                        raise Exception("Message not encrypted from real source (index %d). Acceptable hashes are %s and %s, got %s" % (index, toHex(self.publicData[index][0]), toHex(self.publicData[index][1]), toHex(publicByte)))
                    else:
                        return
                
                num += (self.publicData[index].index(publicByte)) << j
            
            decrypted.append(num)
        
        publicHash = decrypted[:HASH_BYTE_LEN]
        if (publicHash != hash(flatPub)):
            raise Exception("Hash mismatch for next public key!")
        
        self.publicData = nextPub
        return decrypted[HASH_BYTE_LEN:]
    
    def verify(self, msg, sig):
        return self.decrypt(sig, True) == hash(msg)

In [7]:
currentTest = None

def generateTestPayload():
    testPayload = []

    for i in range(32):
        testPayload.append(random.randrange(256))

    return testPayload

class Test:
    def __init__(self):
        global currentTest

        self.testPayloads = [generateTestPayload() for i in range(6)]

        currentTest = self

        self.sigs = []
        self.signN = 0
        self.verifyN = 0

class LamportCertTest(Test):
    def __init__(self, keyRange):
        super().__init__()

        self.private = PrivateKey(2048, keyRange)
        self.public = self.private.public

    def testSign(self):
        self.sigs.append(self.private.sign(self.testPayloads[self.signN]))
        self.signN += 1

    def testVerify(self):
        assert self.public.verify(self.testPayloads[self.verifyN], self.sigs[self.verifyN]), "Failed LamportCert signature verification"
        self.verifyN += 1

from Cryptodome.PublicKey import RSA, ECC, DSA
from Cryptodome.Cipher import PKCS1_OAEP
from Cryptodome.Signature import DSS
from Cryptodome.Hash import SHA512

def bytesToInt(b):
    res = 0

    for byte in b:
        res *= 2
        res += byte

    return res

rsaKeyPair = None
rsaPublic = None

class RsaTest(Test):
    def __init__(self, k):
        super().__init__()

        self.padToBytes = k//8

        self.keyPair = RSA.generate(k, e=2**(k-1) + 1)
        self.public = self.keyPair.publickey()
  
    def testSign(self):
        h = hash(self.testPayloads[self.signN])

        self.sigs.append(self.keyPair._decrypt(bytesToInt(h)))
        self.signN += 1

    def testVerify(self):
        h = hash(self.testPayloads[self.verifyN])

        assert bytesToInt(h) == self.public._encrypt(self.sigs[self.verifyN]), "Failed RSA signature verification"
        self.verifyN += 1

class DssTest(Test):
    def initSignVerify(self):
        self.signer = DSS.new(self.key, "fips-186-3")
        self.verifier = DSS.new(self.key, "fips-186-3")
    
    def testSign(self):
        h = self.hashFunc.new(bytes(self.testPayloads[self.signN]))
        self.sigs.append(self.signer.sign(h))
        self.signN += 1

    def testVerify(self):
        h = self.hashFunc.new(bytes(self.testPayloads[self.verifyN]))
        self.verifier.verify(h, self.sigs[self.verifyN])
        self.verifyN += 1

class DsaTest(DssTest):
    def __init__(self, keySize):
        super().__init__()
        
        self.hashFunc = SHA512
        self.key = DSA.generate(keySize)
        self.initSignVerify()

class EccTest(DssTest):
    def __init__(self, keySize):
        super().__init__()

        self.hashFunc = SHA512
        self.key = ECC.generate(curve=("p" + str(keySize)))
        self.initSignVerify()

schemes = [LamportCertTest, RsaTest, DsaTest, EccTest]

def initTest(scheme, param):
    schemes[scheme](param)

def testSign():
    currentTest.testSign()

def testVerify():
    currentTest.testVerify()

### The next cell takes roughly 20 minutes to run.

In [8]:
import time
import sys
tStart = time.time()

def testSignVerify(scheme, param):
    initTest(scheme, param)
    signResult = %timeit -q -r2 -n3 -o testSign()
    verifyResult = %timeit -q -r2 -n3 -o testVerify()

    return (signResult.best, verifyResult.best)

class TestCase:
    def __init__(self, security, cert, rsa, dsa=-1, ecdsa=-1):
        self.security = security
        self.cert = cert
        self.rsa = rsa
        self.dsa = dsa
        self.ecdsa = ecdsa
  
    def doTest(self):
        (self.certSign, self.certVerify) = testSignVerify(0, 2**self.cert)
        (self.rsaSign, self.rsaVerify) = testSignVerify(1, self.rsa)

        if self.dsa > 0:
            (self.dsaSign, self.dsaVerify) = testSignVerify(2, self.dsa)

        if self.ecdsa > 0:
            (self.ecdsaSign, self.ecdsaVerify) = testSignVerify(3, self.ecdsa)

testCases = [
    TestCase(80, 72, 1024, dsa=1024),
    TestCase(112, 104, 2048, dsa=2048),
    TestCase(128, 120, 3072, dsa=3072, ecdsa=256),
    TestCase(192, 184, 7680, ecdsa=384),
    TestCase(256, 248, 15360, ecdsa=521),
]
data = {}

for test in testCases:
    test.doTest()
    data[test.security] = {"lamport_cert_sign_time": test.certSign, "lamport_cert_verify_time": test.certVerify, "rsa_sign_time": test.rsaSign, "rsa_verify_time": test.rsaVerify}
    
    if test.dsa > 0:
        data[test.security]["dsa_sign_time"] = test.dsaSign
        data[test.security]["dsa_verify_time"] = test.dsaVerify
    
    if test.ecdsa > 0:
        data[test.security]["ecdsa_sign_time"] = test.ecdsaSign
        data[test.security]["ecdsa_verify_time"] = test.ecdsaVerify

tEnd = time.time()

data

{80: {'lamport_cert_sign_time': 0.11562282366667394,
  'lamport_cert_verify_time': 0.032888640000000656,
  'rsa_sign_time': 0.0008064549999933964,
  'rsa_verify_time': 0.0003080560000038683,
  'dsa_sign_time': 0.00027196466666623564,
  'dsa_verify_time': 0.0001905379999982415},
 112: {'lamport_cert_sign_time': 0.1774642979999991,
  'lamport_cert_verify_time': 0.047535477999995614,
  'rsa_sign_time': 0.004878503666664831,
  'rsa_verify_time': 0.0020477323333333666,
  'dsa_sign_time': 0.0005253876666699853,
  'dsa_verify_time': 0.0006290976666605275},
 128: {'lamport_cert_sign_time': 0.29987889033333204,
  'lamport_cert_verify_time': 0.08333095699999642,
  'rsa_sign_time': 0.014872875666668506,
  'rsa_verify_time': 0.005778042666662486,
  'dsa_sign_time': 0.0007915460000068227,
  'dsa_verify_time': 0.0012151223333300247,
  'ecdsa_sign_time': 0.0008238996666705134,
  'ecdsa_verify_time': 0.00159256699999825},
 192: {'lamport_cert_sign_time': 0.35259488066666717,
  'lamport_cert_verify_tim

In [9]:
transposedData = {"security_level": [], "lamport_cert_sign_time": [], "lamport_cert_verify_time": [], "rsa_sign_time": [], "rsa_verify_time": [], "dsa_sign_time": [], "dsa_verify_time": [], "ecdsa_sign_time": [], "ecdsa_sign_time": [], "ecdsa_verify_time": []}

for k in data.keys():
    transposedData["security_level"].append(k)

    for l in transposedData.keys():
        if l != "security_level":
            transposedData[l].append(data[k][l] if l in data[k] else "N/A")

transposedData

{'security_level': [80, 112, 128, 192, 256],
 'lamport_cert_sign_time': [0.11562282366667394,
  0.1774642979999991,
  0.29987889033333204,
  0.35259488066666717,
  0.5387471136666591],
 'lamport_cert_verify_time': [0.032888640000000656,
  0.047535477999995614,
  0.08333095699999642,
  0.09390976766666388,
  0.1370143803333311],
 'rsa_sign_time': [0.0008064549999933964,
  0.004878503666664831,
  0.014872875666668506,
  0.16312619766665648,
  0.9361354423333145],
 'rsa_verify_time': [0.0003080560000038683,
  0.0020477323333333666,
  0.005778042666662486,
  0.07413186700000551,
  0.39958166633334713],
 'dsa_sign_time': [0.00027196466666623564,
  0.0005253876666699853,
  0.0007915460000068227,
  'N/A',
  'N/A'],
 'dsa_verify_time': [0.0001905379999982415,
  0.0006290976666605275,
  0.0012151223333300247,
  'N/A',
  'N/A'],
 'ecdsa_sign_time': ['N/A',
  'N/A',
  0.0008238996666705134,
  0.0013133549999982581,
  0.0023568619999802345],
 'ecdsa_verify_time': ['N/A',
  'N/A',
  0.0015925669999

In [10]:
from tabulate import tabulate
import pandas as pd

def lookupKeys(data, keyMap):
    out = {}
    
    for k in keyMap.keys():
        out[keyMap[k]] = data[k]
    
    return out

dfSign = pd.DataFrame(lookupKeys(transposedData, {"security_level": "Security_Level", "lamport_cert_sign_time": "LC", "rsa_sign_time": "RSA", "dsa_sign_time": "DSA", "ecdsa_sign_time": "ECDSA"}))
dfVerify = pd.DataFrame(lookupKeys(transposedData, {"security_level": "Security_Level", "lamport_cert_verify_time": "LC", "rsa_verify_time": "RSA", "dsa_verify_time": "DSA", "ecdsa_verify_time": "ECDSA"}))

print("Signing times:")
print(tabulate(dfSign, headers="keys", tablefmt="psql"))

print("Verification times:")
print(tabulate(dfVerify, headers="keys", tablefmt="psql"))

Signing times:
+----+------------------+----------+-------------+------------------------+-----------------------+
|    |   Security_Level |       LC |         RSA | DSA                    | ECDSA                 |
|----+------------------+----------+-------------+------------------------+-----------------------|
|  0 |               80 | 0.115623 | 0.000806455 | 0.00027196466666623564 | N/A                   |
|  1 |              112 | 0.177464 | 0.0048785   | 0.0005253876666699853  | N/A                   |
|  2 |              128 | 0.299879 | 0.0148729   | 0.0007915460000068227  | 0.0008238996666705134 |
|  3 |              192 | 0.352595 | 0.163126    | N/A                    | 0.0013133549999982581 |
|  4 |              256 | 0.538747 | 0.936135    | N/A                    | 0.0023568619999802345 |
+----+------------------+----------+-------------+------------------------+-----------------------+
Verification times:
+----+------------------+-----------+-------------+--------------