In [1]:
from abc import ABC, abstractmethod
from collections import namedtuple

from charm.toolbox.eccurve import prime192v1
from charm.toolbox.ecgroup import ECGroup, G, ZR
from charm.toolbox.integergroup import IntegerGroupQ

In [2]:
class SchnorrGroup(ABC):
    """Abstract interface for the group operations used by the proof code.

    Subclasses must implement all four methods. The abstract versions
    defined here have no behaviour of their own and return ``None``.
    """

    @abstractmethod
    def order(self):
        """Return the order of the underlying group."""

    @abstractmethod
    def random_generator(self):
        """Return a randomly chosen generator of the group."""

    @abstractmethod
    def random_scalar(self):
        """Return a randomly chosen scalar (callers use it as an exponent)."""

    @abstractmethod
    def hash(self, *args):
        """Hash the positional arguments into a single value."""

class IntegerSchnorrGroup(SchnorrGroup):
    """`SchnorrGroup` implementation that delegates to charm's `IntegerGroupQ`."""

    def __init__(self, bits):
        """Create an `IntegerGroupQ` and generate its parameters.

        Args:
            bits: forwarded unchanged to `IntegerGroupQ.paramgen`
                (presumably the bit length of the generated parameters --
                TODO confirm against the charm documentation).
        """
        integer_group = IntegerGroupQ()
        integer_group.paramgen(bits)
        self.group = integer_group

    def order(self):
        """Return the `q` attribute of the wrapped group."""
        wrapped = self.group
        return wrapped.q

    def random_generator(self):
        """Return the result of the wrapped group's `randomGen()`."""
        wrapped = self.group
        return wrapped.randomGen()

    def random_scalar(self):
        """Return the result of the wrapped group's argument-less `random()`."""
        wrapped = self.group
        return wrapped.random()

    def hash(self, *args):
        """Forward the positional arguments, unpacked, to the wrapped group's `hash`."""
        wrapped = self.group
        return wrapped.hash(*args)

class ECSchnorrGroup(SchnorrGroup):
    """`SchnorrGroup` implementation that delegates to charm's `ECGroup`."""

    def __init__(self, curve):
        """Build the wrapped `ECGroup`.

        Args:
            curve: curve identifier handed straight to the `ECGroup`
                constructor (e.g. `prime192v1` from `charm.toolbox.eccurve`).
        """
        ec_group = ECGroup(curve)
        self.group = ec_group

    def order(self):
        """Return the result of the wrapped group's `order()`."""
        wrapped = self.group
        return wrapped.order()

    def random_generator(self):
        """Return a random element requested with the `G` type selector."""
        wrapped = self.group
        return wrapped.random(G)

    def random_scalar(self):
        """Return a random element requested with the `ZR` type selector."""
        wrapped = self.group
        return wrapped.random(ZR)

    def hash(self, *args):
        """Hash the positional arguments with the wrapped group's `hash`.

        The arguments are passed on as ONE tuple (not unpacked), unlike
        `IntegerSchnorrGroup.hash`, which unpacks them.
        """
        wrapped = self.group
        return wrapped.hash(args)

In [3]:
# Public parameters shared by prover and verifier:
#   group -- a SchnorrGroup implementation
#   g     -- the generator returned by group.random_generator()
Params = namedtuple('Params', ['group', 'g'])

# A proof as produced by `prover`:
#   u -- g raised to the prover's random scalar
#   z -- the response r + c * x
Proof = namedtuple('Proof', ['u', 'z'])

def generate_params(group):
    """Bundle `group` with a freshly drawn random generator.

    Args:
        group: object providing `random_generator()`.

    Returns:
        `Params` whose `group` is the argument and whose `g` is the value
        returned by `group.random_generator()`.
    """
    generator = group.random_generator()
    return Params(group=group, g=generator)

def generate_witness_and_statement(params):
    """Draw a random witness and compute the matching statement.

    Args:
        params: object with a `group` (providing `random_scalar()`) and a
            generator `g`.

    Returns:
        Tuple `(x, h)` where `x` is the random scalar and `h = g ** x`.
    """
    witness = params.group.random_scalar()
    statement = params.g ** witness
    return witness, statement

def prover(params, h, x):
    """Produce a proof of knowledge of `x` for the statement `h`.

    Args:
        params: `Params`-like object exposing `group` and `g`.
        h: the statement (callers in this file pass `g ** x`).
        x: the witness.

    Returns:
        `Proof(u, z)` with `u = g ** r` for a fresh random scalar `r` and
        `z = r + c * x`, where `c` is the hash of `(g, group order, h, u)`.
    """
    group = params.group
    nonce = group.random_scalar()
    commitment = params.g ** nonce
    # The challenge is derived by hashing instead of being sent by a verifier.
    challenge = group.hash(params.g, group.order(), h, commitment)
    response = nonce + challenge * x
    return Proof(u=commitment, z=response)

def verifier(params, h, proof):
    """Check a proof against the statement `h`.

    Args:
        params: `Params`-like object exposing `group` and `g`.
        h: the statement being proven.
        proof: object with attributes `u` and `z`.

    Returns:
        The result of comparing `g ** z` with `u * h ** c`, where `c` is
        recomputed as the hash of `(g, group order, h, u)`.
    """
    group = params.group
    challenge = group.hash(params.g, group.order(), h, proof.u)
    left_side = params.g ** proof.z
    right_side = proof.u * (h ** challenge)
    return left_side == right_side

def test_correctness(iterations, group_class, **group_args):
    """Check that honestly generated proofs verify.

    Each iteration builds a fresh group via `group_class(**group_args)`,
    generates parameters, a witness/statement pair and a proof, and runs
    the verifier on it.

    Args:
        iterations: number of independent rounds to run.
        group_class: callable returning a group object.
        **group_args: keyword arguments forwarded to `group_class`.

    Returns:
        `False` as soon as one round fails verification, `True` otherwise
        (including when `iterations` is 0).
    """
    def honest_round_verifies():
        params = generate_params(group_class(**group_args))
        witness, statement = generate_witness_and_statement(params)
        honest_proof = prover(params, statement, witness)
        return verifier(params, statement, honest_proof)

    # all() stops at the first failing round.
    return all(honest_round_verifies() for _ in range(iterations))

def random_z_proof(params):
    """Build a proof whose response is random instead of derived from a witness.

    Args:
        params: `Params`-like object exposing `group` and `g`.

    Returns:
        `Proof(u, z)` with `u = g ** r` for a random scalar `r` and `z`
        a second, independently drawn random scalar.
    """
    group = params.group
    nonce = group.random_scalar()
    commitment = params.g ** nonce
    bogus_response = group.random_scalar()
    return Proof(u=commitment, z=bogus_response)

def test_random_z_proof_rejected(iterations, group_class, **group_args):
    """Check that proofs with a random response are rejected.

    Each iteration builds a fresh group via `group_class(**group_args)`,
    generates parameters and a statement (the witness is discarded), and
    runs the verifier on a proof from `random_z_proof`.

    Args:
        iterations: number of independent rounds to run.
        group_class: callable returning a group object.
        **group_args: keyword arguments forwarded to `group_class`.

    Returns:
        `False` as soon as one such proof is accepted, `True` otherwise
        (including when `iterations` is 0).
    """
    def forged_round_accepted():
        params = generate_params(group_class(**group_args))
        _, statement = generate_witness_and_statement(params)
        forged_proof = random_z_proof(params)
        return verifier(params, statement, forged_proof)

    # any() stops at the first accepted forgery.
    return not any(forged_round_accepted() for _ in range(iterations))

In [4]:
# Time 1000 honest prove/verify rounds with IntegerSchnorrGroup constructed with bits=128.
%time test_correctness(iterations=1000, group_class=IntegerSchnorrGroup, bits=128)

CPU times: user 4.83 s, sys: 75.8 ms, total: 4.91 s
Wall time: 4.91 s


True

In [5]:
# Time 1000 honest prove/verify rounds with ECSchnorrGroup constructed on curve prime192v1.
%time test_correctness(iterations=1000, group_class=ECSchnorrGroup, curve=prime192v1)

CPU times: user 1.9 s, sys: 15.9 ms, total: 1.92 s
Wall time: 1.92 s


True

In [6]:
# Time 1000 rounds checking that random-response proofs are rejected, using IntegerSchnorrGroup with bits=128.
%time test_random_z_proof_rejected(iterations=1000, group_class=IntegerSchnorrGroup, bits=128)

CPU times: user 4.88 s, sys: 72 ms, total: 4.95 s
Wall time: 4.95 s


True

In [7]:
# Time 1000 rounds checking that random-response proofs are rejected, using ECSchnorrGroup on curve prime192v1.
%time test_random_z_proof_rejected(iterations=1000, group_class=ECSchnorrGroup, curve=prime192v1)

CPU times: user 1.91 s, sys: 11.9 ms, total: 1.92 s
Wall time: 1.92 s


True