In [1]:

def init(_p, _r, _h, _A, _B):
    """Set up the global curve context used by every other function in this notebook.

    Binds the globals p, r, h, Fp = GF(p), Fr = GF(r), A, B, the curve
    E: y^2 = x^3 + A*x + B over Fp (with its order recorded as r*h), the
    polynomial rings K = Fp[x] and L = K[y], and the curve equation
    eqn = y^2 - x^3 - A*x - B as an element of L.

    Args:
        _p: characteristic of the base field Fp.
        _r: order of the scalar field Fr.
        _h: cofactor; the curve order is set to r * h.
        _A, _B: short Weierstrass coefficients of E.
    """
    # Everything below is rebound at module level so the other helpers can use it
    global p, r, h, Fp, Fr, A, B, E, K, x, L, y, eqn

    ## STEP 1
    # Initialize an elliptic curve
    p = _p
    r = _r
    h = _h
    Fp = GF(p)  # Base Field
    Fr = GF(r)  # Scalar Field
    A = _A
    B = _B
    E = EllipticCurve(Fp, [A, B])
    # Record the group order so Sage does not have to compute it
    E.set_order(r * h)

    ## STEP 2
    # Function-field arithmetic: elements are represented in L = Fp[x][y] and
    # reduced modulo the curve equation eqn where needed.
    K.<x> = Fp[]
    L.<y> = K[]
    eqn = y^2 - x^3 - A * x - B

## STEP 3
# Returns line passing through points, works for all points and returns 1 for O + O = O
def line(A, B):
    """Return the line a*x + b*y + c through A, B and -(A+B).

    When both A and B are the identity the constant 1 is returned (O + O = O).
    """
    if A == 0 and B == 0:
        return 1

    # Stack the three points and read the line coefficients (a, b, c) off a
    # kernel basis vector of the transposed matrix.
    triple = Matrix([A, B, -(A+B)])
    coeffs = triple.transpose().kernel().basis()[0]
    a, b, c = coeffs
    return a*x + b*y + c

## STEP 4
# Works for A == B but not A == -A, as the line has no slope or intercept
def slope_intercept(A, B):
    """Return (slope, intercept) of the line y = slope*x + intercept through A, B, -(A+B).

    Divides by the y-coefficient of the line, so it fails when that line is
    vertical (e.g. B == -A).
    """
    a, b, c = Matrix([A, B, -(A+B)]).transpose().kernel().basis()[0]
    slope = -a/b
    intercept = -c/b
    return (slope, intercept)

## STEP 5
# Fails at 0
def eval_point(f, P):
    """Evaluate f at the affine coordinates of the point P.

    P.xy() has no value for the point at infinity, so this fails at 0.
    """
    px, py = P.xy()
    return f(x=px, y=py)

## STEP 6
# f(x) + y g(x) -> (f(x), g(x)), should reduce mod eqn first
def get_polys(D):
    """Split D = f(x) + y*g(x) into the pair (f, g) of elements of K.

    Only correct when D is at most linear in y, i.e. after reducing mod eqn.
    """
    at_y0 = D(y=0)
    at_y1 = D(y=1)
    return (K(at_y0), K(at_y1 - at_y0))

## STEP 7
# Accepts arbitrary list of points, including duplicates and inverses, and constructs function
# intersecting exactly those points if they form a principal divisor (i.e. sum to zero).
def construct_function(Ps):
    """Build the function whose zeros are exactly the points of a principal divisor.

    Ps may contain duplicates, inverses and the identity, but must sum to the
    identity. The points are merged pairwise in a balanced tree. Each working
    entry (S, f) holds a partial sum S and a function f that vanishes on the
    points merged so far and on -S.

    Returns the function reduced mod eqn and scaled so that its value at
    (x, y) = (0, 0) is 1. When Ps has no points other than the identity, the
    constant function L(1) is returned.
    """
    # Drop identity points; each remaining P starts as the vertical line
    # through P and -P.
    xs = [(P, line(P, -P)) for P in Ps if P != 0]

    # Without this guard an empty list never reaches length 1 and the loop
    # below never terminates (e.g. an all-zero digit row in row_function).
    if not xs:
        return L(1)

    while len(xs) != 1:
        assert(sum(P for (P, _) in xs) == 0)
        merged = []

        # With an odd number of entries, carry the first one forward unchanged
        if len(xs) % 2 == 1:
            carried = xs[0]
            xs = xs[1:]
        else:
            carried = None

        # Combine the functions for all pairs
        for i in range(len(xs) // 2):
            (P0, f0) = xs[2*i]
            (P1, f1) = xs[2*i+1]

            # Multiply by the line through P0, P1, -(P0+P1), then divide out
            # the intermediate (P, -P) vertical-line factors.
            num = L((f0 * f1 * line(P0, P1)).mod(eqn))
            den = line(P0, -P0) * line(P1, -P1)
            merged.append((P0 + P1, num / K(den)))

        if carried is not None:
            merged.append(carried)

        xs = merged

    (total, D) = xs[0]
    assert(total == 0)

    # Normalize; this divides by D(0, 0), which could be zero but has negligible
    # probability for random points. A fixed normalization is needed for ZKPs,
    # although any coefficient could be chosen.
    D = D / D(x=0, y=0)

    # Reduce mod eqn as well, before computing the dlog
    return D.mod(eqn)

def random_element():
    """Return a random point of E multiplied by the cofactor h.

    Clearing the cofactor keeps the point in the order-r subgroup. Depending on
    the application, other curves may need different handling.
    """
    pt = E.random_element()
    return pt * h

## STEP 8
# Random principal divisor with n points
def random_principal(n):
    """Return n points summing to the identity: n-1 random points plus minus their sum."""
    Ps = [random_element() for _ in range(n - 1)]
    closing = -sum(Ps)
    return Ps + [closing]

## STEP 9
# Random principal divisor with points with given multiplicities
def random_principal_mults(ms):
    """Return a flat point list in which the i-th point appears ms[i] times.

    The points are chosen so that sum(ms[i] * P_i) is the identity: all but
    the last are random, and the last is solved for.
    """
    # The last point is solved for, so its multiplicity is inverted in Fr
    last_mult = ms[-1]
    last_inv = ZZ(Fr(last_mult)^(-1))

    Ps = [random_element() for _ in range(len(ms) - 1)]
    partial = sum(m * P for (m, P) in zip(ms[:-1], Ps))
    Ps.append(-last_inv * partial)

    assert(sum(m * P for (m, P) in zip(ms, Ps)) == 0)
    return [P for (m, P) in zip(ms, Ps) for _ in range(m)]

## STEP 10
# Test at a random principal divisor
def test_at_random_principal_divisor(uses_dlog=False):
    """Check construct_function and the challenge evaluations on a random 33-point divisor."""
    points = random_principal(33)
    fn = construct_function(points)
    (f, g) = get_polys(fn)

    # (f + y*g)(f - y*g) must match prod(x - x_P) up to a constant
    norm = f^2 - (x^3 + A * x + B) * g^2
    x_roots = product(x - P.xy()[0] for P in points)
    assert(norm / x_roots in Fp)
    assert(all(eval_point(fn, P) == 0 for P in points))

    ## STEP 16
    # Both challenge forms must match the per-point sums (same checks as the multiplicity test)
    [A0, A1] = [random_element() for _ in range(2)]
    if uses_dlog:
        fn = dlog(fn)
    mixed = eval_function_challenge_mixed(A0, A1, fn, uses_dlog)
    assert(mixed == sum(eval_point_challenge(A0, A1, P) for P in points))
    dupl = eval_function_challenge_dupl(A0, fn, uses_dlog)
    assert(dupl == sum(eval_point_challenge(A0, A0, P) for P in points))

## STEP 11
# Test at random principal divisor with multiplicity. For a divisor that does not contain
# both P and -P for any P, it is sufficient to check the previous conditions and that
# gcd(f, g) = 1
def test_at_random_principal_divisor_with_multiplicity(uses_dlog=False):
    """Run the principal-divisor checks on a divisor with multiplicities 1..6, plus gcd(f, g) == 1."""
    points = random_principal_mults([1,2,3,4,5,6])
    fn = construct_function(points)
    (f, g) = get_polys(fn)

    # (f + y*g)(f - y*g) must match prod(x - x_P) up to a constant
    norm = f^2 - (x^3 + A * x + B) * g^2
    x_roots = product(x - P.xy()[0] for P in points)
    assert(norm / x_roots in Fp)
    assert(all(eval_point(fn, P) == 0 for P in points))
    # With no (P, -P) pair in the divisor, coprime f and g complete the check
    assert(gcd(f, g) == 1)

    ## STEP 16
    # Both challenge forms must match the per-point sums
    [A0, A1] = [random_element() for _ in range(2)]
    if uses_dlog:
        fn = dlog(fn)
    mixed = eval_function_challenge_mixed(A0, A1, fn, uses_dlog)
    assert(mixed == sum(eval_point_challenge(A0, A1, P) for P in points))
    dupl = eval_function_challenge_dupl(A0, fn, uses_dlog)
    assert(dupl == sum(eval_point_challenge(A0, A0, P) for P in points))

# The test that a function vanishes exactly on a given set of points uses
# Weil reciprocity: the product of one function over the points of the other
# function's divisor equals the product of the other function over the points
# of the first function's divisor, up to leading coefficients. Taking the
# logarithmic derivative with respect to a coordinate of one divisor turns these
# products into sums of rational functions; that sum is what is checked here.
# Although the proof evaluates the dlog function at the points, note that the
# result is also a rational function in the coefficients of the other function.

## STEP 12
# Return logarithmic derivative wrt x
def dlog(D):
    """Return the logarithmic derivative D'/D along the curve, as Num/Den.

    The derivative is with respect to x, using dy/dx = (3x^2 + A)/(2y) on the
    curve. Den is a polynomial in x alone.
    """
    # Total derivative along the curve, from the partials
    dD = D.differentiate(x) + D.differentiate(y) * ((3*x^2 + A) / (2*y))

    # Sage fails when dividing by D directly. Instead, scale both sides by 2y and
    # multiply through by the conjugate V(-y), so the denominator is even in y.
    U = L(2*y * dD)
    V = L(2*y * D)
    conj = V(y=-y)

    Den = K((V * conj).mod(eqn))
    Num = L((U * conj).mod(eqn))

    # Clear the denominator first so that mod(eqn) is well defined
    assert(L(y * (Num * D - Den * dD)).mod(eqn) == 0)

    return Num/Den  # == dD/D

## STEP 13
# Given a pair of distinct challenge points/line evaluate the function field element
def eval_function_challenge_mixed(A0, A1, D, uses_dlog=False):
    """Evaluate the function side of the check for two distinct challenge points.

    Sums dlog(D) * 1/(dy/dx - m) over A0, A1 and A2 = -(A0 + A1), where m is
    the slope of the line through them. Pass uses_dlog=True if D is already a dlog.
    """
    assert(A0 != A1)
    challenge = [A0, A1, -(A0 + A1)]
    (m, _) = slope_intercept(A0, A1)
    DLog = D if uses_dlog else dlog(D)

    # Per-point weight
    weight = 1/((3 * x^2 + A) / (2 * y) - m)
    weighted = DLog * weight

    # From the paper: the weights alone must sum to 0 (slope derivative wrt intercept)
    assert(sum(eval_point(weight, P) for P in challenge) == 0)

    return sum(eval_point(weighted, P) for P in challenge)

## STEP 14
# Given a duplicated challenge point/line evaluate the function field element
def eval_function_challenge_dupl(A0, D, uses_dlog=False):
    """Evaluate the function side of the check for a doubled challenge point.

    The challenge line passes through A0 twice and through A2 = -2*A0. The dlog
    of D is weighted separately at A0 and A2. Pass uses_dlog=True if D is
    already a dlog.
    """
    A2 = -(2*A0)
    m = slope_intercept(A0, A2)[0]
    DLog = D if uses_dlog else dlog(D)

    x0 = A0.xy()[0]
    (x2, y2) = A2.xy()
    # Weight at A2; the weight at A0 differs from it by 2*m
    w2 = (2 * y2) * (x0 - x2) / (3 * x2^2 + A - 2 * m * y2)
    w0 = (w2 + 2 * m)

    return eval_point(DLog * w0, A0) - eval_point(DLog * w2, A2)

## STEP 15
# Given a pair of challenge points, detect if duplicate/mixed and modify numerator
def eval_point_challenge(A0, A1, P, mult=1):
    """Return mult * num / (yP - m*xP - b) for the challenge line y = m*x + b.

    num is (x_A0 - xP) when the challenge is doubled (A0 == A1), otherwise -1.
    """
    (m, b) = slope_intercept(A0, A1)
    (xP, yP) = P.xy()
    num = (A0.xy()[0] - xP) if A0 == A1 else -1
    den = yP - m * xP - b
    return mult*num/den

# The ECIP takes advantage of the linearity of the right-hand sides of the
# equations to sum the multiplicities of the same point across different
# functions. The following shows how this works in base -3 with scalars that
# are half the bit length of the field. This is important: if the scalars can
# exceed the field size, the protocol can fail to be sound. The approach also
# works for random linear combinations.

## STEP 17
# Return base -3 digits from {-1, 0, 1}, starting with the least significant
def base_neg3(n, k):
    """Return exactly k balanced base -3 digits of n, least significant first.

    Every digit is -1, 0 or 1, and sum(d * (-3)**i for i, d in enumerate(ds)) == n.

    Raises:
        ValueError: if n cannot be represented with k digits.
    """
    remaining = n
    ds = []
    for _ in range(k):
        # remaining == -3*q + digit with digit in {0, 1, 2}; a 2 becomes -1 by
        # adjusting the (negated) quotient by one.
        digit = remaining % 3
        q = -(remaining // 3)
        if digit == 2:
            q -= 1
            digit = -1
        ds.append(digit)
        remaining = q

    if remaining != 0:
        raise ValueError(f"{n} does not fit in {k} base -3 digits")
    # Compare the reconstruction with n; truthiness alone would reject n == 0
    # and accept any wrong non-zero result.
    assert sum(d * (-3)**i for (i, d) in enumerate(ds)) == n

    return ds


def neg_3_base_le(scalar):
    """
    Decompose an integer into balanced base -3 digits.

    :param scalar: The integer to be decomposed.
    :return: List of digits in {-1, 0, 1}, least significant digit first, with
        sum(d * (-3)**i) == scalar. Zero is returned as [0].
    """
    if scalar == 0:
        return [0]
    out = []
    value = scalar
    while value != 0:
        digit = value % 3
        if digit == 2:
            # Write 2 as -1 and carry one into the remaining value
            digit, value = -1, value + 1
        out.append(digit)
        value = -(value // 3)  # divide by -3 for the next digit
    return out

## STEP 18
# P and -P are counted separately in basis
def pos_neg_mults(ds):
    """Split base -3 digits into (positive, negative) multiplicities.

    positive sums the weights (-3)^i of the +1 digits; negative sums those of
    the -1 digits. P and -P are counted separately.
    """
    pos = 0
    neg = 0
    for (i, d) in enumerate(ds):
        weight = (-3)^i
        if d == 1:
            pos += weight
        elif d == -1:
            neg += weight
    return (pos, neg)

## STEP 19
# Construct the principal divisor for each row given sum from previous row
def row_function(A0, ds, Ps, Q, uses_dlog=False):
    """Build and check the function for one digit row.

    The new accumulator is Q2 = -3*Q + sum(d*P). The row divisor is
    [-Q, -Q, -Q, -Q2] plus the d*P terms, with identity entries removed.
    Returns (D, Q2, div), where D is the row function (its dlog if uses_dlog).
    """
    scaled = [d * P for (d, P) in zip(ds, Ps)]
    Q2 = -3*Q + sum(scaled)
    div = [P for P in [-Q, -Q, -Q, -Q2] + scaled if P != 0]
    assert(sum(div) == 0)

    # The row function must satisfy the doubled-challenge check on its own
    D = construct_function(div)
    if uses_dlog:
        D = dlog(D)
    lhs = eval_function_challenge_dupl(A0, D, uses_dlog)
    rhs = sum(eval_point_challenge(A0, A0, P) for P in div)
    assert(lhs == rhs)

    return (D, Q2, div)

## STEP 20
# Compute the function for each row using Shamir's trick and -3
def ecip_functions(A0, Bs, dss, uses_dlog=False):
    """Run row_function over the digit rows from last to first, threading the accumulator.

    Returns (Q, Ds), where Q is the final accumulator and Ds lists the row
    functions with the first row of dss (lowest order) first.
    """
    Q = 0
    fns = []
    for ds in reversed(list(dss)):
        (D, Q, _) = row_function(A0, ds, Bs, Q, uses_dlog)
        fns.append(D)
    return (Q, fns[::-1])

## STEP 21
# Construct digit vectors, note scalars are smaller than characteristic by construction
def construct_digit_vectors(es):
    """Return (epns, dss) for the scalars es.

    epns holds the (positive, negative) multiplicities per scalar. dss is a
    matrix whose row i holds the i-th base -3 digit of every scalar; shorter
    digit lists are zero-padded to a common length.
    """
    digit_lists = [neg_3_base_le(e) for e in es]  # Base -3 digits
    width = max(len(ds) for ds in digit_lists)
    padded = [ds + [0]*(width - len(ds)) for ds in digit_lists]
    epns = [pos_neg_mults(ds) for ds in padded]  # P and -P mults per e
    return (epns, Matrix(padded).transpose())

def prover(A0, Bs, es, uses_dlog=False):
    """Compute the digit multiplicities, row functions and final sum Q = sum(e * B).

    Returns (epns, Q, Ds).
    """
    assert len(Bs) == len(es)
    (mults, digit_rows) = construct_digit_vectors(es)

    ## STEP 22
    # Kinda slow: construct_function runs once per digit row
    (total, row_fns) = ecip_functions(A0, Bs, digit_rows, uses_dlog)

    ## STEP 23
    # The final accumulator must equal the multi-scalar product, both directly
    # and via the (positive, negative) multiplicities
    assert(total == sum(e * B for (e, B) in zip(es, Bs)))
    assert(total == sum((pos - neg) * B for ((pos, neg), B) in zip(mults, Bs)))
    return (mults, total, row_fns)

## STEP 24
# Takes two mults and evaluates both P and -P
def eval_point_challenge_signed(A0, A1, P, mp, mn):
    """Evaluate the challenge term for P (multiplicity mp) plus -P (multiplicity mn)."""
    positive = eval_point_challenge(A0, A1, P, mult=mp)
    negative = eval_point_challenge(A0, A1, -P, mult=mn)
    return positive + negative

## STEP 25
# Sides should equal, remember to account for result point (-Q)
def verifier(A0, Bs, epns, Q, Ds, uses_dlog=False):
    """Check the ECIP relation for the doubled challenge point A0.

    LHS weights each row's challenge evaluation by (-3)^i. RHS sums the
    per-point challenge terms of every base B, using its (positive, negative)
    multiplicities, plus the term for the result point -Q.

    Ds holds the raw row functions, or their logarithmic derivatives when
    uses_dlog is True. Returns (LHS == RHS, LHS, RHS).
    """
    # Take the dlog of each row once; evaluating with uses_dlog=True then gives
    # the same value as passing the raw function with uses_dlog=False.
    dlogs = Ds if uses_dlog else [dlog(D) for D in Ds]

    # Evaluate each row on its own, then combine with base -3 weights
    LHS = sum((-3)**i * eval_function_challenge_dupl(A0, dl, uses_dlog=True)
              for (i, dl) in enumerate(dlogs))

    # Cross-check: combine first, then evaluate once. The logarithmic derivative
    # is not linear, so the weighted sum must be formed from the row dlogs, not
    # from the raw row functions.
    LHS_D = sum((-3)**i * dl for (i, dl) in enumerate(dlogs))
    LHS_2 = eval_function_challenge_dupl(A0, LHS_D, uses_dlog=True)
    assert LHS == LHS_2

    basisSum = sum(eval_point_challenge_signed(A0, A0, B, ep, en) for ((ep, en), B) in zip(epns, Bs))
    # Remember to account for the result point (-Q)
    RHS = basisSum + eval_point_challenge(A0, A0, -Q)
    return LHS == RHS, LHS, RHS




In [2]:
# Curve y^2 = x^3 + 3 (A = 0, B = 3) with cofactor h = 1.
# NOTE(review): p and r appear to be the BN254 (alt_bn128) base/scalar field moduli -- confirm.
init(0x30644E72E131A029B85045B68181585D97816A916871CA8D3C208C16D87CFD47, 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001,1,0,3)  
def testt(n, scalar_bits=127):
    """Build the ECIP row functions for n random bases and print their sizes.

    Scalars are drawn from [1, 2**scalar_bits); the default keeps them around
    half the field size. Prints the number of rows, the maximum degrees of the
    f and g parts of the row functions, and a coefficient/limb count. Returns
    (Q, Ds) from ecip_functions.
    """
    Bs = [random_element() for _ in range(n)]
    es = [ZZ.random_element(1, 2**scalar_bits) for _ in range(n)]
    _, dss = construct_digit_vectors(es)
    A0 = random_element()
    Q, Ds = ecip_functions(A0, Bs, dss)
    print(len(Ds))
    Ds_poly = [get_polys(d) for d in Ds]
    max_x = max(f.degree() for (f, _) in Ds_poly)
    max_y = max(g.degree() for (_, g) in Ds_poly)
    # Per row: (max_x + 1) + (max_y + 1) coefficients, minus one -- presumably the
    # constant term fixed to 1 by the normalization in construct_function.
    coeffs = len(Ds) * (max_x + max_y + 1)
    print(f"Max x {max_x}, Max y {max_y} Total coeffs {n}_points = {coeffs} = {coeffs*4} limbs")
    return Q, Ds
# Module-level challenge point used by the exploration cells below (In [9], In [5])
A0=random_element()


In [3]:
# Run the size experiment twice, each time with 2 random bases
for i in range(2):
    testt(2)

81
Max x 3, Max y 1 Total coeffs 2_points = 405 = 1620 limbs
81
Max x 3, Max y 1 Total coeffs 2_points = 405 = 1620 limbs


In [4]:
# Keep Q and the row functions of a 3-base run for the exploration cells below
Q, Ds = testt(3)

81
Max x 3, Max y 2 Total coeffs 3_points = 486 = 1944 limbs


In [9]:
# Evaluate the logarithmic derivative (wrt x) of D at the point (x0, y0)
def eval_dlog(D, x0, y0):
    """Debug helper: print intermediate values and return (D'/D)(x0, y0).

    Uses the same Num/Den construction as dlog, then evaluates it at (x0, y0).
    """
    _, g = get_polys(D)
    # For D = f(x) + y*g(x), the partial derivative in y is g
    assert D.differentiate(y) == g

    dz = (D.differentiate(x) + g * ((3*x^2 + A) / (2*y)))
    print("Dz", dz)
    print("Dz(P)", dz(x=x0, y=y0))

    # Sage fails when dividing by D directly, so rationalize as in dlog
    U = L(2*y * dz)
    V = L(2*y * D)
    conj = V(y=-y)
    Num = L((U * conj).mod(eqn))
    Den = K((V * conj).mod(eqn))

    # Clear the denominator first so that mod(eqn) is well defined
    assert(L(y * (Num * D - Den * dz)).mod(eqn) == 0)

    at_P = D(x=x0, y=y0)
    print("Dz/D(P)", dz(x=x0, y=y0) / at_P)
    print("D(P)", at_P)
    return Num(x=x0, y=y0)/Den(x=x0)  # == Dz/D at (x0, y0)

def eval_dlog2(D, x0, y0):
    """Debug helper: evaluate (D'/D)(x0, y0) directly from D = a(x) + y*b(x).

    Along the curve dy/dx = (3x^2 + A)/(2y), so by the chain rule
        D' = a'(x) + ((3x^2 + A)/(2y)) * b(x) + y * b'(x).
    Prints intermediate values and returns D'(x0, y0) / D(x0, y0).
    """
    a, b = get_polys(D)
    print("a", a)
    print("b", b)
    Da = a.differentiate(x)
    Db = b.differentiate(x)
    print("Da", Da)
    print("Db", Db)
    # Symbolic numerator D', printed for inspection
    num_p = Da + b * (3*x^2 + A) * (2*y)^(-1) + y*Db
    print(type(num_p))
    print("num_p", num_p)
    print("num_p(P)", num_p(x=x0, y=y0))
    # The same numerator, evaluated directly at (x0, y0) in Fp
    num = Da(x0) + b(x0)*(3*x0*x0 + A) * (Fp(2*y0)^(-1)) + y0*Db(x0)
    print(f"Num(P): {num}")
    den = D(x=x0, y=y0)
    print(f"Den(P): {den}")
    return num/den
    
def eval_dlog3(D, x0, y0):
    """Debug helper: evaluate (D'/D)(x0, y0) for D = a(x) + y*b(x).

    Behaves exactly like eval_dlog2, including its printed output, and delegates
    to it rather than duplicating the body.
    """
    return eval_dlog2(D, x0, y0)
    

# Exploration: compare dlog with the hand-written evaluators on a small test function
d_in=Ds[0]
# NOTE(review): the Ds[0] assignment above is immediately overwritten by this fixed test polynomial
d_in= (3*x + 11)*y + x^2 + x + 1
print("input", d_in)
print("dlog", dlog(d_in).mod(eqn))
print("True eval:", dlog(d_in)(x=A0[0], y=A0[1]))
print("Second", eval_dlog(d_in, A0[0], A0[1]))

print("\n\n")
print("Third", eval_dlog2(d_in, A0[0], A0[1]))



input (3*x + 11)*y + x^2 + x + 1
dlog ((3648040478639879203707734290876212514782718526216303943781506315774204368097*x^5 + 19456215886079355753107916218006466745507832139820287700168033684129089963185*x^4 + 9728107943039677876553958109003233372753916069910143850084016842064544981591*x^3 + 18240202393199396018538671454381062573913592631081519718907531578871021840485*x^2 + 14592161914559516814830937163504850059130874104865215775126025263096817472396*x + 7296080957279758407415468581752425029565437052432607887563012631548408736197)/(x^8 + 17024188900319436283969426690755658402319353122342751737647029473612953717794*x^7 + 17024188900319436283969426690755658402319353122342751737647029473612953717800*x^6 + 7296080957279758407415468581752425029565437052432607887563012631548408736200*x^5 + 12160134928799597345692447636254041715942395087387679812605021052580681227034*x^4 + 2432026985759919469138489527250808343188479017477535962521004210516136245478*x^3 + 8*x^2 + 145921619145595168148309371635048

81


In [5]:
# Linearity check: evaluating each row and weighting by (-3)^i must equal
# weighting the row dlogs first and evaluating once with uses_dlog=True
dlogs = [dlog(D) for D in Ds]
sum_dlog = sum((-3)^i *dl for (i, dl) in enumerate(dlogs))
print(f"Sum dlog: {sum_dlog}")
LHS = sum((-3)^i * eval_function_challenge_dupl(A0, D) for (i, D) in enumerate(Ds))
LHS_2 = eval_function_challenge_dupl(A0, sum_dlog, uses_dlog=True)
print(f"LHS: {LHS}")
print(f"LHS_2: {LHS_2}")
print(f"LHS == LHS_2: {LHS == LHS_2}")



Sum dlog: ((21422882318070656530151345494506221804740209662840182291642722119843293027870*x^4 + 3122511079805592727717723271926566277163617956811936165377677951366344747726*x^3 + 16049139882603089494283805724267123698012133335446039385752066535571653796964*x^2 + 891178554665645623370609996036944720127802112244720218775669741564345853649*x + 6573918638837752092119019315388415011832191442772085680381840050956468258054)/(x^7 + 20704459468082423530280781160287040120326860814222375708909666781559467369310*x^6 + 8204363097683289283274310411126674022729595294576828391401057789297176926024*x^5 + 15218034260443896329175285644511749269523085190240202134534104869016742840584*x^4 + 7027188473410110289285095950782436270720576716688607476507044412650165279727*x^3 + 2724846421210592627576525488122746979492474726432661511514135473246304569489*x^2 + 1877617037653138543033045443020697631176633256124959078224238817759776104577*x + 98473731822027208732995033718221484387904720804470303508464353610770991840

In [20]:
# Inspect the combined dlog and one coefficient drilled out of it (first entries
# of the coefficients() lists)
print(sum_dlog)
print(sum_dlog.coefficients()[0].numerator().coefficients()[0])


((21422882318070656530151345494506221804740209662840182291642722119843293027870*x^4 + 3122511079805592727717723271926566277163617956811936165377677951366344747726*x^3 + 16049139882603089494283805724267123698012133335446039385752066535571653796964*x^2 + 891178554665645623370609996036944720127802112244720218775669741564345853649*x + 6573918638837752092119019315388415011832191442772085680381840050956468258054)/(x^7 + 20704459468082423530280781160287040120326860814222375708909666781559467369310*x^6 + 8204363097683289283274310411126674022729595294576828391401057789297176926024*x^5 + 15218034260443896329175285644511749269523085190240202134534104869016742840584*x^4 + 7027188473410110289285095950782436270720576716688607476507044412650165279727*x^3 + 2724846421210592627576525488122746979492474726432661511514135473246304569489*x^2 + 1877617037653138543033045443020697631176633256124959078224238817759776104577*x + 9847373182202720873299503371822148438790472080447030350846435361077099184055))*y + (