In [None]:
import itertools
import os
import re
from hashlib import sha1

from pyasn1.codec.der.decoder import decode
from pyasn1.type.univ import Sequence

def parse_ecdsa_signature(signature_der):
    """Decode a DER-encoded ECDSA signature into the pair ``(r, s)``.

    The bytes are decoded as a generic ASN.1 SEQUENCE and its first two
    components are converted to Python ints. Whatever ``decode`` returns
    as its second value (the undecoded remainder) is discarded.
    """
    sequence, _remainder = decode(signature_der, asn1Spec=Sequence())
    return int(sequence[0]), int(sequence[1])

def sha(value):
    """Return the SHA-1 digest of an integer, itself read as an integer.

    ``value`` is converted with ``int`` and serialised as the shortest
    whole-byte big-endian string (zero becomes the single byte ``0x00``).
    The 20-byte digest is interpreted as a big-endian integer.
    """
    digits = f"{int(value):x}"
    if len(digits) % 2:
        # bytes.fromhex only accepts whole bytes, so left-pad odd lengths.
        digits = "0" + digits
    payload = bytes.fromhex(digits)
    return int.from_bytes(sha1(payload).digest(), "big")


class CounterTest:

    def print_card_name(self, name):
        print(name,"#"*10)

    def open_csv_result(self,filepath):
        try:
            with open(filepath) as f:
                lines = f.readlines()[1:]
        except FileNotFoundError:
            print("Not measured\n")
            return []
        if len(lines)==0:
            print("Something went wrong\n")
            return []
        print(f"Number of runs: {len(lines)}")
        return [line.strip().split(";") for line in lines]

    def load_all_results(self,card,prefix):
        lines = []
        filepath = f"results/{card}/{self.test}"
        pattern = rf"^{prefix}(?:_\d+)?.csv$" #prefix and possibly and underscore with number
        for file in os.listdir(filepath):
            if re.match(pattern,file):
                with open(os.path.join(filepath,file)) as f:
                    lines.extend(f.readlines()[1:])
        if len(lines)==0:
            print("Not measured or something went wrong\n")
            return []
        print(f"Number of runs: {len(lines)}")
        return [line.strip().split(";") for line in lines]

    def load_csv_signatures(self,card,prefix):
        lines = self.load_all_results(card,prefix)
        sigs = []
        for line in lines:
            success,error,sig,valid,_,nonce_hex,key_hex,_,_,_,_,_ = line
            sigs.append({"success":success, "signature":parse_ecdsa_signature(bytes.fromhex(sig)), "nonce": nonce_hex, "key":key_hex, "valid":valid})
        return sigs

    def load_csv_ecdhs(self,card,prefix):
        lines = self.load_all_results(card,prefix)
        ecdhs = []
        for line in lines:
            success,error,secret,key,_,_,_,_,_ = line
            ecdhs.append({"success":success, "secret":secret, "key":key})
        return ecdhs

    def load_csv_keygens(self,card,prefix):
        lines = self.load_all_results(card,prefix)
        keygens = []
        for line in lines:
            success,error,key,point,_,_,_,_ = line
            keygens.append({"success":success, "key":key, "point":point})
        return keygens
        
class Test3n(CounterTest):
    """Analysis of the "test3n" measurements.

    Every ``print_*`` method tallies which of the scalars ``k``, ``k+n``
    and ``k+2n`` (``n`` read from the ``curve_prime`` file) reproduces the
    value a card returned, using the generator from the ``curve_full``
    file as base point.
    """

    def __init__(self, curve_prime, curve_full, point, key):
        """Load the curve, both generators and the fixed scalar.

        curve_prime: file with seven comma-separated hex values; the first
            six are used as p, a, b, x, y, n.
        curve_full: file in the same layout; only x and y are used.
        point: accepted but not read by this constructor.
        key: file holding the scalar ``k`` in hex.
        """
        self.test = "test3n"
        with open(curve_prime) as handle:
            p, a, b, gx, gy, n, _ = (int(tok, 16) for tok in handle.read().split(","))
        self.n = n
        self.curve = EllipticCurve(GF(p), [a, b])
        self.prime_gen = self.curve(gx, gy)

        with open(curve_full) as handle:
            _, _, _, fx, fy, _, _ = (int(tok, 16) for tok in handle.read().split(","))
        self.n3_gen = self.curve(fx, fy)
        self.public = self.n3_gen
        self.gen3 = n * self.n3_gen

        with open(key) as handle:
            self.k = int(handle.read(), 16)

    def print_ecdh(self, card):
        """Print how many ECDH secrets match ``sha(x(s*public))`` for each
        ``s`` in ``k``, ``k+n``, ``k+2n``, plus the number matching none."""
        self.print_card_name(card)

        labels = {self.k + i * self.n: f"k+{i}*n" for i in range(3)}
        # Expected hashed shared secret -> the scalar that produces it.
        digest_to_scalar = {
            sha((scalar * self.public)[0]): scalar for scalar in labels
        }
        tally = {label: 0 for label in labels.values()}

        secrets = [
            int(row["secret"], 16) for row in self.load_csv_ecdhs(card, "ecdh")
        ]
        if not secrets:
            return

        others = 0
        for secret in secrets:
            scalar = digest_to_scalar.get(secret)
            if scalar is None:
                others += 1
            else:
                tally[labels[scalar]] += 1
        for label, count in tally.items():
            print(label, count)
        print("others", others)
        print()

    def print_ecdsa(self, card, fixed_key = False):
        """Print, per nonce residue mod 3, which shift ``i`` makes
        ``x(nonce*G + i*n*G) mod n`` equal the signature's ``r``.

        Reads the ``ecdsa_fixed`` result files when ``fixed_key`` is true,
        the ``ecdsa`` ones otherwise. Also prints the number of signatures
        matching no shift and the number of distinct ``(r mod n, s)`` pairs.
        """
        self.print_card_name(card)
        filename = "ecdsa_fixed" if fixed_key else "ecdsa"
        sigs = self.load_csv_signatures(card, filename)
        if not sigs:
            return

        G = self.n3_gen
        nG = self.gen3
        counts = {residue: {shift: 0 for shift in range(3)} for residue in range(3)}
        others = 0
        distinct = set()
        for entry in sigs:
            r, s = entry["signature"]
            r %= self.n
            distinct.add((r, s))
            nonce = int(entry["nonce"], 16)
            base = nonce * G
            shift_by_r = {ZZ((base + i * nG)[0]) % self.n: i for i in range(3)}
            shift = shift_by_r.get(r)
            if shift is None:
                others += 1
                continue
            counts[nonce % 3][shift] += 1
            # NOTE(review): one line per matched signature, and it shows the
            # nonce mod 2 while the table below is keyed mod 3 - possibly a
            # leftover debug print; confirm before relying on the output.
            print(nonce % 2, shift)
        for kmod, per_shift in counts.items():
            print(f"For k mod 3 = {kmod}:")
            for shift, count in per_shift.items():
                print(f"\t k+{shift}*n", count)
        print("others", others)
        print("number of diff sigs", len(distinct))
        print()

    def print_keygen(self, card):
        """Print which of ``k``, ``k+n``, ``k+2n`` times the full generator
        equals each generated public point, plus the count matching none."""
        self.print_card_name(card)
        counts = {residue: {shift: 0 for shift in range(3)} for residue in range(3)}
        keys = self.load_csv_keygens(card, "keygen")
        if not keys:
            return

        others = 0
        for record in keys:
            encoded = record["point"]
            # Skip the first two hex characters, then split the rest at
            # len(encoded)//2 - 1 into the X and Y coordinates.
            coords = encoded[2:]
            half = len(encoded) // 2 - 1
            x = int(coords[:half], 16)
            y = int(coords[half:], 16)
            k = int(record["key"], 16)
            shifts = {k + i * self.n: i for i in range(3)}
            P = self.curve(x, y)
            hit = next(
                (scalar for scalar in shifts if scalar * self.n3_gen == P), None
            )
            if hit is None:
                others += 1
            else:
                # NOTE(review): the bucket is (k + i*n) mod 3 although the
                # table header reads "k mod 3" - confirm this is intended.
                counts[hit % 3][shifts[hit]] += 1
        for kmod, per_shift in counts.items():
            print(f"For k mod 3 = {kmod}:")
            for shift, count in per_shift.items():
                print(f"\t k+{shift}*n", count)
        print("others", others)
        print()


class Testinverse(CounterTest):
    """Analysis of the "testinverse" measurements.

    ``l`` is used to build the result-file prefixes (``ecdh_<l>``,
    ``ecdsa_<l>``, ``ecdsa_fixed_<l>``, ``keygen_<l>``) and to compute the
    expected fraction of valid signatures.
    """

    def __init__(self, curve, point, key, l):
        """Load the curve, its generator and the fixed scalar.

        curve: file with seven comma-separated hex values; the first six
            are used as p, a, b, x, y, n.
        point: accepted but not read; the generator is used as public point.
        key: file holding the scalar ``k`` in hex.
        l: number appended to the result-file prefixes.
        """
        self.l = l
        self.test = "testinverse"
        with open(curve) as f:
            p, a, b, x, y, n, _ = map(lambda v: int(v, 16), f.read().split(","))
        self.n = n
        self.curve = EllipticCurve(GF(p), [a, b])
        self.gen = self.curve(x, y)
        self.public = self.gen

        with open(key) as f:
            self.k = int(f.read(), 16)

    def print_ecdh(self, card):
        """Print the number of distinct secrets and how many of them equal
        ``sha(x(key*public))`` for the key recorded in the same row."""
        self.print_card_name(card)
        secrets_lines = self.load_csv_ecdhs(card, f"ecdh_{self.l}")
        if not secrets_lines:
            return

        secrets = [int(line["secret"], 16) for line in secrets_lines]
        print("Number of secrets as a set: ", len(set(secrets)), "\n")
        P = self.public
        correct = 0
        for line in secrets_lines:
            secret, privS = map(lambda v: int(v, 16), [line["secret"], line["key"]])
            S = privS * P
            if sha(S[0]) == secret:
                correct += 1
        print(f"Correct secrets: {correct}\n")

    def print_ecdsa(self, card, fixed_key=False):
        """Print the share of signatures flagged valid, the expected valid
        fraction for one, two or three inversions, and the number of
        distinct valid signatures.

        The share is taken over the rows actually loaded, and the expected
        fractions are powers of ``1 - 1/l``.
        """
        self.print_card_name(card)
        filename = f"ecdsa_fixed_{self.l}" if fixed_key else f"ecdsa_{self.l}"
        sigs_lines = self.load_csv_signatures(card, filename)
        if not sigs_lines:
            return

        total = len(sigs_lines)
        valid_sigs = [line for line in sigs_lines if line["valid"] == "1"]
        # Plain floats keep the output independent of Sage's number types.
        percent = float(100 * len(valid_sigs)) / float(total)
        print(f"num of valid sigs (out of {total})", len(valid_sigs), f"{percent}%")
        p_ok = 1 - 1 / float(self.l)
        print(f"expected: {p_ok} (one inverse) or {p_ok * p_ok} (two inverses) or {p_ok * p_ok * p_ok} (three inverses)")
        sigs_set = set(line["signature"] for line in valid_sigs)
        print(f"num of diff sigs {len(sigs_set)}\n")

    def print_keygen(self, card):
        """Print how many recorded private keys ``d`` satisfy
        ``d*gen == public``, as a count and as a share of the loaded rows.

        NOTE(review): ``public`` is the generator itself and the recorded
        public point of each row is not consulted - confirm this is the
        intended check.
        """
        self.print_card_name(card)
        key_lines = self.load_csv_keygens(card, f"keygen_{self.l}")
        if not key_lines:
            return

        total = len(key_lines)
        correct = 0
        for line in key_lines:
            privS = int(line["key"], 16)
            P = privS * self.gen
            correct += (P == self.public)
        percent = float(100 * correct) / float(total)
        print(f"num of runs (out of {total})", correct, f"{percent}%")
        print()


class Testk10(CounterTest):
    """Analysis of the "testk10" ECDH measurements against one fixed
    expected shared secret."""

    def __init__(self, curve, point, key):
        """Load curve, public point and key, and precompute the expected
        hashed secret ``sha(x(k*public))``.

        curve: file with seven comma-separated hex values; the first six
            are used as p, a, b, x, y, n.
        point: file with the public point's x and y as comma-separated hex.
        key: file holding the scalar ``k`` in hex.
        """
        self.test = "testk10"
        with open(curve) as f:
            p, a, b, x, y, n, _ = map(lambda v: int(v, 16), f.read().split(","))
        self.n = n
        self.curve = EllipticCurve(GF(p), [a, b])
        self.gen = self.curve(x, y)
        with open(point) as f:
            px, py = map(lambda v: int(v, 16), f.read().split(","))
        self.public = self.curve(px, py)

        with open(key) as f:
            self.k = int(f.read(), 16)

        S = self.k * self.public
        self.correct_result = sha(S[0])

    def print_testk10_stats(self, card, expected_runs=1000):
        """Print how many ECDH secrets of ``card`` equal the expected one.

        If fewer than ``expected_runs`` rows were recorded, the run is
        reported as failed after that many operations and nothing is
        counted. Nothing further is printed when no rows were loaded.
        """
        self.print_card_name(card)
        secrets_lines = self.load_csv_ecdhs(card, "ecdh")
        if not secrets_lines:
            return
        # The original referenced an undefined name here, raising NameError
        # whenever any results were loaded.
        if len(secrets_lines) < expected_runs:
            print(f"fail after {len(secrets_lines)} ecdhs")
            return
        correct = 0
        for line in secrets_lines:
            if self.correct_result == int(line["secret"], 16):
                correct += 1
        print(f"Correct secrets: {correct}\n")



import cypari2
from tqdm import tqdm


def divisors(primes, powers):
    """Yield every product ``primes[0]**e0 * primes[1]**e1 * ...`` with
    ``0 <= ei <= powers[i]``.

    Exponent tuples are enumerated by ``itertools.product``, so the last
    prime's exponent varies fastest. Empty inputs yield the single value 1.
    """
    exponent_ranges = [range(limit + 1) for limit in powers]
    for exponents in itertools.product(*exponent_ranges):
        divisor = 1
        for base, exponent in zip(primes, exponents):
            divisor *= base**exponent
        yield divisor

def pari_factor(number):
    """Factor ``number`` with PARI and return ``(primes, powers)`` as two
    lists of Python ints.

    NOTE(review): the two constructor arguments are presumably the initial
    and maximum PARI stack sizes in bytes, and indices 0 and 1 of the
    result presumably select the prime and exponent columns - confirm
    against the cypari2 documentation.
    """
    gp = cypari2.Pari(256_000_000, 2_000_000_000)
    factorization = gp.factor(number)
    primes = [int(entry) for entry in factorization[0]]
    powers = [int(entry) for entry in factorization[1]]
    return primes, powers

def pari_dlog(a,b,p, P, G, real_n, facts_str):
    """Return ``elllog(E, P, G, facts)`` as a Python int, where ``E`` is
    built with ``ellinit([a, b], p)``.

    ``facts_str`` is a string that is evaluated by PARI and passed as the
    last argument of ``elllog``. Before the call, entry ``[15][0]`` of the
    curve structure is overwritten with ``real_n``.

    NOTE(review): that entry is presumably PARI's cached group order, and
    ``facts_str`` presumably the factorisation of ``real_n`` - confirm.
    """
    gp = cypari2.Pari(256_000_000, 2_000_000_000)
    curve = gp.ellinit([a, b], p)
    curve[15][0] = real_n
    order_facts = gp(facts_str)
    return int(gp.elllog(curve, P, G, order_facts))



class GSRmask(CounterTest):
    """Mask recovery on the "testdn" measurements.

    For each recorded scalar and the point a card derived from it, a
    discrete logarithm ``d`` is computed and the code checks whether
    ``(±d - scalar) mod real_n`` is a multiple of ``n - real_n``; the
    quotient is reported as the mask.
    """

    def __init__(self, curve, realn, point, key):
        """Load curve parameters, the value ``real_n`` and the public point.

        curve: file with seven comma-separated hex values; the first six
            are used as p, a, b, x, y, n.
        realn: file holding ``real_n`` in hex.
        point: file with the public point's x and y as comma-separated hex.
        key: accepted but not read by this constructor.
        """
        self.test = "testdn"
        with open(curve) as f:
            p, a, b, x, y, n, _ = map(lambda v: int(v, 16), f.read().split(","))
        with open(realn) as f:
            real_n = int(f.read(), 16)
        with open(point) as f:
            px, py = map(lambda v: int(v, 16), f.read().split(","))
        self.n = n
        self.real_n = real_n
        self.curve = EllipticCurve(GF(p), [a, b])
        self.gen = self.curve(x, y)
        self.public = self.curve(px, py)

        # Stored as text so it can be re-evaluated by the PARI instance
        # that pari_dlog creates.
        self.pari_real_n_facts = repr(pari.factor(real_n))
        self.a = a
        self.b = b
        self.p = p

    def compute_mask(self, scalar, point_candidates, G):
        """Print every mask consistent with ``scalar`` and return the logs.

        For each candidate point ``P`` the discrete log ``d`` of ``P`` to
        base ``G`` is computed; for ``dp`` in ``d`` and ``real_n - d`` a
        mask is reported when ``(dp - scalar) mod real_n`` is divisible by
        ``n - real_n``. Prints "No mask found" when no candidate produced
        one.

        Returns the list of computed logs, one per candidate.
        """
        scalar = ZZ(scalar)
        step = self.n - self.real_n
        base = [int(G[0]), int(G[1])]
        ds = []
        found = False
        for P in point_candidates:
            d = pari_dlog(self.a, self.b, self.p, [int(P[0]), int(P[1])], base, self.real_n, self.pari_real_n_facts)
            ds.append(d)
            # Both d and real_n - d are tried, i.e. both signs of the point.
            for dp in (d, self.real_n - d):
                dp = ZZ(dp)
                diff = (dp - scalar) % self.real_n
                if diff % step == 0:
                    mask = ZZ(diff / step)
                    print(f"Mask has {mask.nbits()} bits ({mask}), {scalar}, {dp}")
                    found = True
        if not found:
            print("No mask found")
        return ds

    def recover_ecdsa(self, card, N = 3):
        """Run mask recovery on the first ``N`` signatures of ``card``.

        For each signature, every ``r + j*n`` below ``p`` is lifted to a
        curve point and passed to ``compute_mask`` with the recorded nonce.

        Returns a list of ``(nonce, candidate_r_values, logs)`` tuples.
        """
        self.print_card_name(card)
        sig_lines = self.load_csv_signatures(card, "ecdsa")
        params = []
        for line in sig_lines[:N]:
            nonce = int(line["nonce"], 16)
            print(nonce)
            r, s = line["signature"]
            r = r % self.n
            candidates = []
            rss = []
            # NOTE(review): lift_x presumably raises when a candidate is not
            # an x-coordinate on the curve - confirm and handle if needed.
            while r < self.p:
                R = self.curve.lift_x(self.curve.base_field()(r))
                candidates.append(R)
                rss.append(r)
                r += self.n
            ds = self.compute_mask(nonce, candidates, self.gen)
            params.append((nonce, rss, ds))
        print()
        return params

    def recover_keygen(self, card, N=3):
        """Run mask recovery on the first ``N`` generated key pairs of
        ``card``, using the recorded private key and public point."""
        self.print_card_name(card)
        keygen_lines = self.load_csv_keygens(card, "keygen")
        for line in keygen_lines[:N]:
            pubW, privS = line["point"], line["key"]
            # Skip the first two hex characters, then split the rest at
            # len(pubW)//2 - 1 into the X and Y coordinates.
            coords = pubW[2:]
            half = len(pubW) // 2 - 1
            pubWx, pubWy, privS = map(lambda v: int(v, 16), [coords[:half], coords[half:], privS])

            P = self.curve(pubWx, pubWy)
            self.compute_mask(privS, [P], self.gen)
        print()

    def _recover_ecdh(self, card, prefix, N, show_parity):
        """Run mask recovery on the first ``N`` ECDH rows stored under
        ``prefix``, lifting each secret to a point as the x-coordinate."""
        self.print_card_name(card)
        secret_lines = self.load_csv_ecdhs(card, prefix)
        for line in secret_lines[:N]:
            secret, privS = map(lambda v: int(v, 16), [line["secret"], line["key"]])
            if show_parity:
                print(privS % 2)
            R = self.curve.lift_x(self.curve.base_field()(secret))
            self.compute_mask(privS, [R], self.public)
        print()

    def recover_ecdh_plain(self, card, N = 3):
        """Mask recovery on the first ``N`` rows of the ``ecdh_randomkey``
        results; also prints each private key's parity."""
        self._recover_ecdh(card, "ecdh_randomkey", N, True)

    def recover_ecdh_plain_good_gen(self, card, N=3):
        """Mask recovery on the first ``N`` rows of the
        ``ecdh_plain_good_gen`` results."""
        self._recover_ecdh(card, "ecdh_plain_good_gen", N, False)




In [None]:
# Identifiers of the measured cards; each is used as a directory name
# under results/ by the loaders.
cards = [
    "A1",
    "F1", "F2",
    "G1",
    "I1", "I2",
    "N1", "N10", "N2", "N3", "N4", "N5", "N6", "N7", "N8", "N9",
    "S1", "S2",
]


In [None]:
# Build the test3n analyser from its curve, point and key files, then print
# the ECDH tally for every card.
test3n = Test3n("tests/test3n/curve_prime_gen.csv","tests/test3n/curve.csv" ,"tests/test3n/point_3n.csv","tests/test3n/key.csv")

for card in cards:
    test3n.print_ecdh(card)

In [None]:
# ECDSA tally for test3n, read from the "ecdsa" result files.
for card in cards:
    test3n.print_ecdsa(card)

In [None]:
# Same tally, read from the "ecdsa_fixed" result files.
for card in cards:
    test3n.print_ecdsa(card, fixed_key = True)

In [None]:
# Key-generation tally for test3n.
for card in cards:
    test3n.print_keygen(card)

### Test inverse

In [None]:
# Build the testinverse analyser; the last argument (11) is appended to the
# result-file prefixes. Note that the key file is taken from tests/test3n/.
testinverse = Testinverse("tests/testinverse/cofactor256p11_full.csv","tests/testinverse/point_11n.csv","tests/test3n/key.csv",11)


In [None]:
# ECDH check for testinverse: distinct secrets and correct secrets per card.
for card in cards:
    testinverse.print_ecdh(card)

In [None]:
# Valid-signature statistics for testinverse (non-fixed-key result files).
for card in cards:
    testinverse.print_ecdsa(card)

In [None]:
# Valid-signature statistics for testinverse (fixed-key result files).
for card in cards:
    testinverse.print_ecdsa(card, fixed_key = True)

In [None]:
# Key-generation statistics for testinverse.
for card in cards:
    testinverse.print_keygen(card)

### Recover GSR mask

In [None]:
# Build the mask-recovery helper from the testdn curve, real_n, point and
# key files.
gsrmask = GSRmask("tests/testdn/weakcurve_32_n_1.csv","tests/testdn/realn.csv", "tests/testdn/weakcurve_32_n_1_point.csv","tests/testdn/key.csv")

In [None]:
# Mask recovery from key-generation results. recover_keygen has no return
# statement, so params is None; the findings are printed.
for card in cards:
    params = gsrmask.recover_keygen(card,N=10)

In [None]:
# Mask recovery from ECDSA results. params is overwritten on every
# iteration, so only the last card's list survives the loop.
for card in cards:
    params = gsrmask.recover_ecdsa(card,N=10)

In [None]:
# Mask recovery from the "ecdh_randomkey" results. recover_ecdh_plain has
# no return statement, so params is None; the findings are printed.
for card in cards:
    params = gsrmask.recover_ecdh_plain(card,N=10)