In [2]:
# Base field modulus used for all arithmetic below.
# NOTE(review): this p together with y^2 = x^3 + 3 appears to be the BN254 / alt_bn128
# base field and curve; confirm against the protocol specification.
p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
F = GF(p)
# E: y^2 = x^3 + 3 over F (short Weierstrass invariants [a4, a6] = [0, 3])
E = EllipticCurve(F, [0,3])
# Bivariate polynomials are modelled as F[x][y]: polynomials in y whose coefficients are
# polynomials in x. The helpers below rely on .list() of a Pxy element returning its
# coefficients in y (elements of Px).
Px.<x> = PolynomialRing(F)
Pxy.<y> = PolynomialRing(Px)

# Curve equation as an element of Pxy; it is zero exactly at the affine points of E.
EQ = y^2 - x^3 - 3


In [119]:
# We only use plain polynomial computations here, not Sage's native divisor machinery.
# This cell defines helper functions for working with regular functions on the curve
# (polynomials in x, y reduced modulo the curve equation). We could probably use a standard
# Sage construction instead, but that would obscure matters a bit.


# Reduction modulo the curve equation: the result only contains y^0 and y^1 terms and is the
# canonical representative of the class of the input in F[x, y] / (y^2 - x^3 - 3).

def normal_form(poly):
    """Return the canonical representative of ``poly`` modulo ``EQ``.

    Every power y^k is rewritten as (x^3 + 3)^(k // 2) * y^(k % 2).
    Multiplying by ``y`` before calling ``.list()`` forces ``poly`` into Pxy, so the list
    holds its coefficients in y shifted by one position; that leading 0 is skipped.
    """
    y_squared = y^2 - EQ  # equals x^3 + 3
    y_coeffs = (y*poly).list()[1:]
    result = 0
    for power, coeff in enumerate(y_coeffs):
        result += coeff * y_squared^(power//2) * y^(power % 2)
    return result


def hcoeff(poly):
    """Return the leading coefficient of a normal-form polynomial a0(x) + a1(x)*y.

    "Leading" is taken w.r.t. the weighted degree in which x counts 2 and y counts 3,
    i.e. the y-term wins iff 2*deg(a1) + 3 > 2*deg(a0). The coefficient is only returned;
    ``reduce`` performs the division that turns it into 1.

    Prints a message and fails an assertion if ``poly`` has y^2 or higher terms or is zero.
    """
    coeffs = (y*poly).list()  # same trick as normal_form: force poly into Pxy
    if len(coeffs) == 2:
        # no y-term: leading coefficient of the x-polynomial a0
        return coeffs[1].list()[-1]
    if len(coeffs) == 3:
        # multiplying by x shifts each list by one, so len(part) == deg + 2 even for a zero part
        part0 = (coeffs[1]*x).list()
        part1 = (coeffs[2]*x).list()
        y_term_wins = 2*len(part1) + 3 > 2*len(part0)
        return part1[-1] if y_term_wins else part0[-1]
    print("not normal polynomials are unsupported")
    assert(False)
    
def reduce(poly):
    """Return the normal form of ``poly`` scaled so that its leading coefficient (see hcoeff) is 1.

    This makes the representative independent of any constant factor in the input.
    """
    nf = normal_form(poly)
    lead = hcoeff(nf)
    return nf // lead

In [120]:

# Line through a pair of points (the tangent when they coincide); returns None if both
# points are at infinity.

def linefunc(a, b):
    """Return a polynomial vanishing on the line through the points ``a`` and ``b`` of E.

    * one point at infinity: the vertical line through the other point;
    * b == -a: the vertical line x - ax;
    * a == b: the tangent line at a;
    * otherwise: the chord through a and b.
    The result is not normalized (see ``reduce``). Returns None if both points are at infinity.
    """
    inf = E(0)

    def through(line, px, py):
        # shift by a constant so that the line vanishes at (px, py)
        return line - line(x=px, y=py)

    if a == inf and b == inf:
        return None
    if a == inf or b == inf:
        finite = b if a == inf else a
        fx, fy = finite.xy()
        return through(x, fx, fy)

    ax, ay = a.xy()
    bx, by = b.xy()
    if ax != bx:
        line = (bx-ax) * y - (by-ay) * x        # chord
    elif ay == by:
        line = - 3 * ax^2 * x + 2 * ay * y     # tangent, slope 3x^2 / (2y)
    else:
        line = x                                # b == -a
    return through(line, ax, ay)

class PointReplacer:
    """A polynomial together with the list of points it was built from.

    replacer_from_point(P) starts with inputs [P], output -P and the vertical line
    through P and -P. ``a * b`` merges two replacers: inputs are concatenated, outputs
    A, B are added, and the polynomials are multiplied by the line through -A and -B
    before the vertical lines at A and B are divided out. For replacers built this way
    ``output`` equals minus the sum of ``inputs``.
    """

    def __init__(self, inputs, output, poly):
        self.inputs, self.output, self.poly = inputs, output, poly

    def __mul__(self, other):
        A = self.output
        B = other.output
        merged_inputs = self.inputs + other.inputs
        total = A + B
        ax, _ = A.xy()
        bx, _ = B.xy()

        def drop_vertical(f, px):
            # Divide f by (x - px). f vanishes at both curve points with x-coordinate px,
            # so f(px, y) is a multiple of EQ(px, y); subtracting that multiple of EQ makes
            # f vanish on the whole vertical line x = px.
            q = f(x=px) // EQ(x=px)
            return normal_form((f - q*EQ) // (x - px))

        f = self.poly * other.poly * linefunc(-A, -B)
        f = drop_vertical(f, ax)
        f = drop_vertical(f, bx)
        return PointReplacer(merged_inputs, total, f)
        
def replacer_from_point(point):
    """Wrap one point P as PointReplacer([P], -P, line through P and -P).

    For the point at infinity linefunc returns None, so the resulting ``poly`` is None.
    """
    neg = -point
    return PointReplacer([point], neg, linefunc(point, neg))

# cartier(points), defined below, computes the regular function whose zeros are exactly the
# given collection of points (its only pole is at infinity); it throws if the points do not sum to zero


def list_prod(factors):
    """Multiply ``factors`` as a balanced binary tree, reducing each partial product with normal_form.

    Returns 1 for an empty list and the only factor unchanged for a one-element list.
    On each level an odd leftover (the last factor) is carried to the front of the next level.
    """
    layer = factors
    if len(layer) == 0:
        return 1
    while len(layer) > 1:
        carried = [layer[-1]] if len(layer) % 2 == 1 else []
        paired = [normal_form(layer[k] * layer[k + 1]) for k in range(0, len(layer) - 1, 2)]
        layer = carried + paired
    return layer[0]

def cartier(points):
    """Return a regular function on E whose zeros are exactly ``points`` (with multiplicity).

    Points are wrapped in PointReplacers and merged pairwise in rounds. A replacer whose
    output has reached the point at infinity is finished; its polynomial is set aside, and
    all finished polynomials are multiplied with list_prod.

    Points at infinity in the input are skipped: they add no affine zero, and linefunc
    returns None for them, which would otherwise crash the final product.

    :param points: any iterable of points of E
    :returns: the product polynomial, or 1 if no finite points were given
    :raises AssertionError: if the points do not sum to the identity of E
    """
    inf = E(0)
    pts = [replacer_from_point(pt) for pt in points if pt != inf]
    polys = []

    while len(pts) > 1:
        pending = []
        for rep in pts:
            if rep.output == inf:
                polys.append(rep.poly)
            else:
                pending.append(rep)
        pts = []
        if len(pending) % 2 == 1:
            pts.append(pending[-1])
        for i in range(len(pending)//2):
            pts.append(pending[2*i] * pending[2*i+1])

    if len(pts) > 0:
        assert pts[0].output == inf, "cartier: the points do not sum to zero"
        polys.append(pts[0].poly)

    return list_prod(polys)

In [121]:
def gen_random_felt(F):
    """Return F(m) for a random integer m in [0, F.order()) (uniform when F is a prime field)."""
    return F(randrange(F.order()))
    
def gen_random_point(E):
    """Return a random affine point of the curve ``E``.

    Random x-coordinates are drawn from the base field until one lifts to a point.
    ``E.lift_x`` raises ValueError when x is not the x-coordinate of a point of E, and only
    that error is retried. A bare ``except:`` would also swallow KeyboardInterrupt and
    genuine bugs, turning this into an uninterruptible infinite loop.
    """
    while True:
        m = gen_random_felt(E.base())
        try:
            return E.lift_x(m)
        except ValueError:
            continue

In [122]:
# Checking the identity at a point; refer to page 5 of the paper https://eprint.iacr.org/2022/596.pdf
# For the normalized line z through a, b and c = -(a+b), and D with zeros exactly at pts,
# this cell checks that D(a)*D(b)*D(c) == product over pt in pts of -z(pt).

# multiplicative version

pts = []
for _ in range(10):
    pts.append(gen_random_point(E))

# close the list with -(sum) so that the points sum to zero, as cartier requires
acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

# c = -(a+b) is the third intersection of the line through a and b with E
a = gen_random_point(E)
b = gen_random_point(E)
c = -(a+b)

ax, ay = a.xy()
bx, by = b.xy()
cx, cy = c.xy()


z = linefunc(a, b)
z = reduce(z)

lhs = D(x=ax,y=ay)*D(x=bx,y=by)*D(x=cx,y=cy)

prod = 1

for pt in pts:
    ptx, pty = pt.xy()
    prod *= -z(x=ptx,y=pty)

rhs = prod

print("D(A)D(B)D(C) = product over points: {b}".format(b=(rhs==lhs)))

D(A)D(B)D(C) = product over points: True


In [123]:
# Log-derivative (additive) version of the check: for the normalized line z through
# a, b and c = -(a+b), compare the sum over a, b, c of (dD/dz)/D, with the derivative
# taken along the curve, against the sum over pt in pts of 1/(-z(pt)).

pts = []
for _ in range(10):
    pts.append(gen_random_point(E))

acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

a = gen_random_point(E)
b = gen_random_point(E)
c = -(a+b)

ax, ay = a.xy()
bx, by = b.xy()
cx, cy = c.xy()

z = linefunc(a, b)
z = reduce(z)

# now we need to compute d/dz D *along* the curve
# the answer is in Liam's paper, sanity check here

# we have differentials

Dx = diff(D, x)
Dy = diff(D, y)

EQx = diff(EQ, x)
EQy = diff(EQ, y)

zx = diff(z, x)
zy = 1 # diff(z, y)   -- reduce() scaled z so that its y-coefficient is 1

l = []

for (sx, sy) in [a.xy(), b.xy(), c.xy()]:
    Dxp = Dx(x=sx,y=sy)
    Dyp = Dy(x=sx,y=sy)
    
    EQxp = EQx(x=sx,y=sy)
    EQyp = EQy(x=sx,y=sy)
    
    # zx, zy are constants already
    # we now need to find such t that (dD - t dEQ) is collinear with dz
    # (Dxp - t EQxp)/(Dyp - t EQyp) = zx/zy = zx
    # Dxp - t EQxp = zx Dyp - t zx EQyp
    # Dxp - zx Dyp = t (EQxp - zx EQyp)
    # t = (Dxp - zx Dyp) / (EQxp - zx EQyp)    -- the denominator vanishes only if z is tangent to E at this point

    t = (Dxp - zx * Dyp)/(EQxp - zx * EQyp)
    
    # now we need to find proportionality coeff between the result and dz
    
    deriv = (Dyp - t * EQyp) # / zy
    
    l.append(deriv/ D(x=sx,y=sy))

lhs = l[0] + l[1] + l[2]

acc = 0

for pt in pts:
    ptx, pty = pt.xy()
    acc += 1/(-z(x=ptx,y=pty))
    
rhs = acc

print("Our description is correct: {b}".format(b = (lhs == rhs)))

Our description is correct: True


In [171]:
# This is testing Liam Eagen's protocol for the case when z is tangent to the curve at the point A
# This allows us to query only 2 points, not three - A, A, -2A
# But log derivative becomes annoying
# Liam's paper suggests some formula, which is based on the idea that we could vary not only z, but also the slope
# The drawback is that we get nontrivial numerators, equal to x(A)-x(P) for a point P

# I have devised a different formula, which does not vary the slope - but it, sadly, involves second derivatives
# So Liam's formula is better.

# I'm leaving it here as a reminder of my struggle (and that I shouldn't drop third order terms in such problems
# even if it looks like I could :) ) 

# -------------------- Let's go!!! --------------------


pts = []

for _ in range(10):
    pts.append(gen_random_point(E))

acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

a = gen_random_point(E)
b = -(2*a)

ax, ay = a.xy()
bx, by = b.xy()

z = linefunc(a, a)
z = reduce(z)

# after reduce(), z = y + t*x + const, so t is the x-coefficient of the tangent line
t = diff(z, x)

Dx = diff(D, x)
Dy = diff(D, y)
Dxx = diff(Dx, x)
Dxy = diff(Dx, y)
Dyy = diff(Dy, y)

a00 = D(x=ax, y=ay)
a10 = Dy(x=ax, y=ay)
a01 = Dx(x=ax, y=ay) - t*Dy(x=ax,y=ay)
a02 = (Dxx(x=ax, y=ay) -  2 * t * Dxy(x=ax,y=ay) + t^2 * Dyy(x=ax, y=ay))/2

fr10 = a10/a00
fr01 = a01/a00
fr02 = a02/a00

sp = 1 / (t^2 - 3*ax)

logderiv_a = 2*(fr10 +  sp *  ((t - ay*sp)* fr01 + (fr01^2 - 2*fr02) * ay))

# and a small cleanup of the previous computation for the point b
# newer notation and maybe a cleaner outlook on a circuit

b00 = D(x=bx, y=by)
b10 = Dy(x=bx, y=by)
b01 = Dx(x=bx, y=by) - t*b10    # this was called Dxp - zx Dyp

# frac = (Dx - t*Dy) / (EQx - t*EQy)
# deriv = Dy - frac * EQy

deriv_b = b10 + b01 * (2*by)/(3*bx^2 + 2*t*by)

logderiv_b = deriv_b / b00

lhs = logderiv_a + logderiv_b


acc = 0


for pt in pts:
    ptx, pty = pt.xy()
    acc += 1/(-z(x=ptx,y=pty))
    
rhs = acc

print("Such Puiseux series much wow: {b}".format(b = (lhs == rhs)))

Such Puiseux series much wow: True


In [178]:
# Now, real Liam Eagen's argument for coinciding points:
# checks that c2 * (dD/dx)(b)/D(b) - (c2 + 2t) * (dD/dx)(a)/D(a) equals
# the sum over pt in pts of (x(a) - x(pt)) / (-z(pt)), with z the normalized tangent at a.

pts = []

for _ in range(10):
    pts.append(gen_random_point(E))

acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

a = gen_random_point(E)
b = -(2*a)

ax, ay = a.xy()
bx, by = b.xy()

z = linefunc(a, a)
z = reduce(z)

# note the sign: here t is minus the x-coefficient of the normalized line z
t = -diff(z, x)

Dx = diff(D, x)
Dy = diff(D, y)

# ax dt/dz = -1

dtdz = -1/ax      # this essentially comes from the second derivative of EQ


dydx = (3*x^2) / (2*y)

dydx_b = dydx(x=bx, y=by)

# NOTE(review): dbdx is not used below; algebraically c2 == ax * dbdx
dbdx = (1 + bx * dtdz)/(dydx_b - t)   # (1/ax) (ax - bx)(2*by)/(3*bx^2 - 2*by*t) - on par with the paper

# dD/dx = Dx + dydx Dy
dDdx = Dx + dydx * Dy

dDdx_a = dDdx(x=ax, y=ay)
dDdx_b = dDdx(x=bx, y=by)

c2 = 2*by*(ax-bx)/(3*bx^2 - 2*t*by)

lhs = c2 * dDdx_b / D(x=bx, y=by) - (c2 + 2*t) * dDdx_a / D(x=ax, y=ay)

rhs = 0
for pt in pts:
    ptx, pty = pt.xy()
    rhs += (ax - ptx)/(-z(x=ptx, y=pty))

print("Liam's formulas work:{b}".format(b=(lhs==rhs)))

Liam's formulas work:True


In [105]:
# Log-derivative check for a generic line through a, b, c = -a-b, using the closed-form
# derivative along the curve: deriv = Dy + (Dx - t*Dy) * 2y / (3x^2 + 2*t*y).
# Prints lhs and rhs, which should coincide.
pts = []
for _ in range(10):
    pts.append(gen_random_point(E))

acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

a = gen_random_point(E)
b = gen_random_point(E)
c = -a-b

z = linefunc(a,b)
z = reduce(z)

Dx = diff(D,x)
Dy = diff(D,y)


# x-coefficient of the normalized line z (its y-coefficient is 1 after reduce)
t = diff(z, x)

acc = 0

for pt in [a,b,c]:
    ptx, pty = pt.xy()
    d00 = D(x=ptx, y=pty)
    d10 = Dy(x=ptx, y=pty)
    d01 = Dx(x=ptx, y=pty) - t*d10
    deriv = d10 + d01 * (2*pty)/(3*ptx^2 + 2*t*pty)
    acc += deriv/d00

lhs = acc

acc = 0

for pt in pts:
    ptx, pty = pt.xy()
    acc += 1/(-z(x=ptx,y=pty))
    
rhs = acc

print(lhs)
print(rhs)

13285663120324467967862642209245345986190024243767896025012081799683985507056
13285663120324467967862642209245345986190024243767896025012081799683985507056


In [293]:
# Circuit builder emulating circom-like behavior.
# Call initialize() first.
# Always supply advice values inside `if WITH_WITNESS:` blocks, otherwise you risk reading
# a runtime variable at circuit-construction time.



def initialize(warnings=True, with_witness=True):
    """Reset the global circuit-builder state; call this before building a circuit.

    :param warnings: stored in the global WARNINGS.
    :param with_witness: stored in the global WITH_WITNESS; when False, reading a signal's
        value with ``~`` fails its assertion (circuit-construction mode).

    Both arguments used to be ignored (the flags were hard-coded to True); the defaults
    keep that behavior for existing ``initialize()`` calls.
    """
    global WARNINGS, SIGNALS, CONSTRAINTS, WITH_WITNESS, EQUATIONS
    WARNINGS = warnings
    SIGNALS = []
    CONSTRAINTS = []
    WITH_WITNESS = with_witness
    EQUATIONS = []


class Equation:
    """Node of an arithmetic expression DAG over F, registered in the global EQUATIONS list.

    (kind, stuff) pairs:
      ("mul", [left_id, right_id])   product of two earlier nodes
      ("add", [left_id, right_id])   sum of two earlier nodes
      ("const", felt)                a field constant
      ("signal", signal_id)          the value of a Signal

    ``self.id`` is the node's index in EQUATIONS. Arithmetic operators create new nodes;
    ``==`` records the constraint ``self - other == 0`` in CONSTRAINTS and returns None
    (it is not a comparison).
    """

    @staticmethod
    def _operand(other):
        # Wrap an Integer / field element / Signal as a node; Equations pass through unchanged.
        if type(other) == Integer:
            other = F(other)
        if other in F:
            other = Equation("const", other)
        if type(other) == Signal:
            other = Equation("signal", other)
        return other

    def __init__(self, kind, stuff):
        if kind in ("mul", "add"):
            self.kind = kind
            self.stuff = stuff
            for node_id in stuff:
                assert node_id < len(EQUATIONS)
        if kind == "const":
            self.kind = kind
            self.stuff = stuff
            assert stuff in F
        if kind == "signal":
            self.kind = kind
            self.stuff = stuff.id
            assert stuff.id < len(SIGNALS)
        self.id = len(EQUATIONS)
        EQUATIONS.append(self)

    def _binary(self, kind, other):
        # the operand node is created before the result node, so ids stay in creation order
        rhs = Equation._operand(other)
        return Equation(kind, [self.id, rhs.id])

    def __add__(self, other):
        return self._binary("add", other)

    # no __iadd__: `eq += other` rebinds the name to the new node built by __add__

    def __mul__(self, other):
        return self._binary("mul", other)

    def __radd__(self, other):
        return self + other

    def __rmul__(self, other):
        return self * other

    def __sub__(self, other):
        return self + other * (-1)

    def __neg__(self):
        return self * (-1)

    def __rsub__(self, other):
        return self * (-1) + other

    def __eq__(self, other):
        (self - other).constrain()

    def compute(self):
        kind = self.kind
        if kind == "const":
            return self.stuff
        if kind == "signal":
            return ~SIGNALS[self.stuff]
        if kind in ("add", "mul"):
            left = EQUATIONS[self.stuff[0]].compute()
            right = EQUATIONS[self.stuff[1]].compute()
            return left + right if kind == "add" else left * right

    def __repr__(self):
        kind = self.kind
        if kind == "const":
            value = Integer(self.stuff)
            # print the signed representative: values above p/2 are shown as negative
            return str(value - (0 if value < F.order()//2 + 1 else p))
        if kind == "signal":
            return SIGNALS[self.stuff].__repr__()
        left = EQUATIONS[self.stuff[0]] if kind in ("add", "mul") else None
        right = EQUATIONS[self.stuff[1]] if kind in ("add", "mul") else None
        if kind == "add":
            return "{l} + {r}".format(l = left.__repr__(), r = right.__repr__())
        if kind == "mul":
            return "({l}) * ({r})".format(l = left.__repr__(), r = right.__repr__())

    def constrain(self):
        CONSTRAINTS.append(self.id)
    
class Signal:
    """A circuit wire (circom-style signal), registered in the global SIGNALS list.

    Operators:
      s << v   assign an advice/input value (Integer, field element, Signal or Equation)
      s == e   record the constraint ``e - s == 0`` (returns None, not a comparison)
      s <= e   constrain and assign in one step (circom's <==)
      ~s       read the current value (asserts WITH_WITNESS and that a value was assigned)
    Arithmetic on signals builds Equation nodes.
    """

    def __init__(self, n=None):
        self.value = None   # zero-argument callable producing the value, or None if unassigned
        self.id = len(SIGNALS)
        self.name = n
        SIGNALS.append(self)

    def __repr__(self):
        # repr() must return a str; unnamed signals fall back to an id-based name
        if self.name is None:
            return "signal#{}".format(self.id)
        return self.name

    def __lshift__(self, other):                  # <--
        assert self.value is None, "FATAL ERROR: attempt to assign signal value twice"
        if type(other) == Signal:
            # Previously this recursed with a lambda that matched no branch, so the value
            # was silently never set. Read the source signal lazily instead.
            self.value = lambda: ~other
            return
        if type(other) == Equation:
            self.value = other.compute    # bound method, evaluated lazily
            return
        if type(other) == Integer:
            other = F(other)
        if other in F:
            value = F(other)              # store plain ints as field elements too
            self.value = lambda: value
            return
        raise TypeError("cannot assign a value of type {} to a signal".format(type(other).__name__))

    def __eq__(self, other):                     # ===
        (other - self).constrain()

    def __le__(self, other):                     # <==
        if type(other) == Equation or type(other) == Signal:
            self == other
            self << other
        else:
            assert False, "FATAL ERROR: can not use <= with variable on rhs, use << instead if you want to give an advice or supply input"

    # ~     extract value from a signal, throws if value is None; must be used for advices
    # this will not always lead to an error at runtime, so make sure to run circuit without values first
    def __invert__(self):
        assert WITH_WITNESS, "FATAL ERROR: trying to extract values in circuit construction mode"
        assert self.value is not None, "FATAL ERROR: trying to read runtime variable at compile-time"
        return (self.value)()

    def __add__(self, other):
        return Equation("signal", self) + other

    def __mul__(self, other):
        return Equation("signal", self) * other

    def __radd__(self, other):
        return self + other

    def __rmul__(self, other):
        return self*other

    def __iadd__(self, other):
        assert False, "FATAL ERROR: signals are immutable"

    def __imul__(self, other):
        assert False, "FATAL ERROR: signals are immutable"

    def __sub__(self, other):
        return self + (-1)*other

    def __rsub__(self, other):
        # mirrors Equation.__rsub__ so that `3 - s` works
        return self*(-1) + other

    def __isub__(self, other):
        assert False, "FATAL ERROR: signals are immutable"

    def __neg__(self):
        return self*(-1)

    
def alloc(*args, n=None, prefix=''):
    """Allocate one Signal or a nested list of Signals.

    alloc(n="A")            -> a single Signal named "A"
    alloc(k, n="A")         -> [A[0], ..., A[k-1]]
    alloc(k, m, n="A")      -> k lists of m signals named A[i][j] (and so on for more dims)

    ``prefix`` carries the index path during recursion. When ``n`` is None, the names consist
    of the index path only; previously this crashed with TypeError (None + str).
    """
    if len(args) == 0:
        return Signal(n)

    base = '' if n is None else n
    if len(args) == 1:
        return [Signal(n = base + prefix + '[' + str(i) + ']') for i in range(args[0])]
    return [alloc(*args[1:], n=n, prefix=prefix + '[' + str(i) + ']') for i in range(args[0])]
    

In [305]:
# We assume we are actually secretly in a fixed-base scenario with a randomly chosen base.
# Otherwise the protocol will sometimes be incomplete, at least according to Liam
# (I'm not sure why normalizing is necessary)

# ---- RANDOM SETUP ----

SIZE = 10
SC_BITSIZE = 128
# Number of signed base -3 digits for SC_BITSIZE-bit signed scalars (81 for 128 bits):
# the smallest T with 3^T > 2^SC_BITSIZE.
SC_TRITSIZE = ceil(SC_BITSIZE*log(2, 3))


scalars = []
points = []
for _ in range(SIZE):
    # signed scalars in [-2^(SC_BITSIZE-1), 2^(SC_BITSIZE-1))
    scalars.append(randrange(2**SC_BITSIZE)-2**(SC_BITSIZE-1))
    points.append(gen_random_point(E))

# ---- THE REST IS IN CIRCUIT, AND WE PRETEND THAT CHALLENGES ARE GIVEN BY THE VERIFIER ----
# convention for signals - I will use CamelCase for them
# in the current version it is very easy to make a mistake; I really need a separation of
# runtime / construction variables, but this is a separate project, for now this should be fine
# with enough testing

# also, I do not have any particular arithmetization here - I could in theory allow only specific equations
# but I am mainly aiming at generic arithmetization for this prototype
# I will write this code using "R1CS", which has some consequences, it might be that PlonKish is a bit larger
# I am not sure. Though realistically we will be most likely using it in PlonK (unless anyone is willing
# to use my Groth16 + lookups lol).


initialize()

# Trit decomposition of the scalars

Scalars = alloc(SIZE, n="Scalars")
ScTrits = alloc(SIZE, SC_TRITSIZE, n="ScTrits")

for i in range(SIZE):
    Scalars[i] << scalars[i]

for i in range(SIZE):
    # advice: signed base -3 digits of the scalar, least significant first
    trunc = Integer(~Scalars[i])
    if trunc > p//2:
        trunc -= p                      # map the field element back to a signed integer
    for j in range(SC_TRITSIZE):
        trit = trunc % 3
        if trit == 2:
            trit -= 3

        ScTrits[i][j] << trit
        # constrain trit^3 == trit, i.e. trit in {-1, 0, 1}
        Tmp = alloc(n="_[{}][{}]".format(i, j))
        Tmp <= ScTrits[i][j] * ScTrits[i][j]
        Tmp*ScTrits[i][j] == ScTrits[i][j]

        trunc = -(trunc-trit) // 3   # we are doing base -3 decomposition
    assert trunc == 0, "scalar does not fit into SC_TRITSIZE trits"

    acc = 0
    for j in range(SC_TRITSIZE):
        acc += ScTrits[i][j] * (-3)^j

    acc == Scalars[i]

print(EQUATIONS[-1].compute())

# This ensures that each scalar equals sum_j t_j * (-3)^j with t_j in {-1, 0, 1}, j = 0..SC_TRITSIZE-1.


Points = alloc(SIZE+2, 2, n="Points")

for i in range(SIZE):
    px, py = points[i].xy()
    Points[i][0] << px
    Points[i][1] << py

for i in range(SIZE):
    _Px = Points[i][0]
    _Py = Points[i][1]
    YSq = alloc(n="YSq[{}]".format(i))
    YSq <= (_Py * _Py)
    XSq = alloc(n="XSq[{}]".format(i))
    XSq <= (_Px * _Px)
    XCb = alloc(n="XCb[{}]".format(i))
    XCb <= (_Px * XSq)

    XCb + 3 == YSq           # constrain that points are on curve

print(EQUATIONS[-1].compute())

# SIZE+2 is the number of points in each row. DComm0 holds the x-coefficients of the y^0 part
# of each row's D, DComm1 those of the y^1 part (presumably sized for the degree bounds of a
# function with SIZE+2 zeros; the assertion below checks that they suffice).
DComm0 = alloc(SC_TRITSIZE, (SIZE+2)//2 + 1, n="Dcomm0")
DComm1 = alloc(SC_TRITSIZE, ((SIZE+2)-1)//2 + 1, n="Dcomm1")

# ---- prover commits to D: ----

points_table =  [[None for j in range(SIZE+2)] for i in range(SC_TRITSIZE)]


acc = E(0)
for i in range(SC_TRITSIZE):
    acc = -(3*acc)
    points_table[-1-i][SIZE] = acc    # we go from the top trits
    for j in range(SIZE):
        # beware the wraparound: Integer(field element) lies in [0, p), so negative trits
        # come back as p - 1 and have to be shifted down
        trit = Integer(~ScTrits[j][-1-i])
        if trit > 1:
            trit -= F.order()
        assert trit in [-1, 0, 1]
        pt = trit*points[j]
        points_table[-1-i][j] = pt    # btw none of this would have happened if we would write bits
        acc += pt
    points_table[-1-i][SIZE+1] = -acc

D = []

for i in range(SC_TRITSIZE):
    row = [pt for pt in points_table[i] if pt != E(0)]
    tmp = Pxy(cartier(row))        # cartier returns the plain integer 1 for an empty row
    D.append(tmp // tmp(x=0,y=0))  # Normalize by the constant term (value at x = y = 0).
                                   # With negligible probability it is zero and this fails.
                                   # Though it only affects completeness, not soundness
                                   # For fixed-base random basis MSM it is definitely not an issue

# Split each D[i] = a0(x) + a1(x)*y into the coefficient lists of a0 and a1, zero-padded to the
# width of the corresponding commitment row. Multiplying by y before .list() forces D[i] into
# Pxy (same trick as in normal_form), giving [0, a0, a1] (shorter when trailing parts are zero).

coeffs_D0 = []
coeffs_D1 = []
for i in range(SC_TRITSIZE):
    y_parts = (y*D[i]).list()
    c0 = y_parts[1].list() if len(y_parts) > 1 else []
    c1 = y_parts[2].list() if len(y_parts) > 2 else []
    assert len(c0) <= len(DComm0[i]) and len(c1) <= len(DComm1[i]), "D[i] exceeds the committed degree"
    coeffs_D0.append(c0 + [0]*(len(DComm0[i]) - len(c0)))
    coeffs_D1.append(c1 + [0]*(len(DComm1[i]) - len(c1)))

# ---- and now prover actually commits to D: ----

for i in range(SC_TRITSIZE):
    for Sig, coeff in zip(DComm0[i], coeffs_D0[i]):
        Sig << coeff
    for Sig, coeff in zip(DComm1[i], coeffs_D1[i]):
        Sig << coeff

0
0
[1, 3998323545570997444631250062439575693723331727781417002441436645132682045853, 13348852193750329910350358210506691038609707436366000021134197128234394059718, 10996008046362909479722242753913022229334167743060790891964829133810716261690, 5207589946405459093434574764459367526186209265935090951294093332960107903114, 15617454328755373929738001400592669480318853695170209367087923801521137412538, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[649181552773467447935017139018900450492815479061387086447874070123319872313, 20149255581399687715118917458940184117484496697135953509964949784522001344867, 12258107353818200145983148177589059016210192644024409579255723320363466209455, 17659165837465983880760485745009357555302470156001607160305342811400073545876, 33220692442711000194023643644850555703451055771728435366546

In [271]:
# Demonstrates the `y*poly` trick used by normal_form / hcoeff: multiplying by y forces the
# value into Pxy, and .list() then returns its coefficients in y with a leading 0.
poly = x^3 + x*y
(poly*y).list()

[0, x^3, x]