In [39]:
p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
F = GF(p)
E = EllipticCurve(F, [0,3])
Px.<x> = PolynomialRing(F)
Pxy.<y> = PolynomialRing(Px)

EQ = y^2 - x^3 - 3


In [40]:
# we will only use computations with polynomials and not use Sage's native divisor machinery
# this will define a wrapper class that we will use to work with regular functions on a curve
# we probably could use standard Sage construction, but this would obfuscate matters a bit


# this reduces the polynomial to the form only having 0, 1 powers of y;
# it is a canonical representative of the quotient ring F[x,y]/(y^2 = x^3 + 3)

def normal_form(poly): 
    acc = 0
    replace = y^2 - EQ
    
    poly2 = y*poly # ugly hack to ensure it does not get coerced to poly(x)
    l = poly2.list()
    for i in range(len(l)-1):
        acc += l[i+1] * replace^(i//2) * y^(i%2)
    return acc


def hcoeff(poly): #computes highest coeff. of a normal form and sets it to 1
    poly2 = y*poly # same ugly hack
    l = poly2.list()
    if len(l) == 2:
        return l[1].list()[-1]
    
    if len(l) == 3:
        l0 = (l[1]*x).list() # and once again to ensure they are not coerced to integers
        l1 = (l[2]*x).list() # I hate this so much :(
        if 2*len(l1) + 3 > 2*len(l0):
            return l1[-1]
        else:
            return l0[-1]
    print("not normal polynomials are unsupported")
    assert(False)
    
def reduce(poly):
    tmp = normal_form(poly)
    return (tmp // hcoeff(tmp))

In [41]:

# this returns a line passing through a pair of points. will return None if both points are at infinity

def linefunc(a, b):
    temp = None
    if a == E(0):
        temp = b
    if b == E(0):
        temp = a
    if temp == E(0):
        return None
    if temp is not None:
        line = x
        tx, ty = temp.xy()
        return line - line(x=tx, y=ty)
    ax, ay = a.xy()
    bx, by = b.xy()
    if ax == bx:
        if ay == by:
            line = - 3 * ax^2 * x + 2 * ay * y
        else:
            line = x
    else:
        line = (bx-ax) * y - (by-ay) * x    
    return line - line(x=ax, y=ay)

class PointReplacer:
    
    def __init__(self, inputs, output, poly):
        self.inputs = inputs
        self.output = output
        self.poly = poly
    
    def __mul__(self, other):
        inputs = self.inputs + other.inputs # this merges the lists of inputs
        A = self.output
        B = other.output
        output = A+B # this adds up outputs
        Ax, Ay = A.xy()
        Bx, By = B.xy()
        
        numerator = self.poly * other.poly * linefunc(-A, -B)
        num_res = numerator(x=Ax) 
        eq_res = EQ(x=Ax)                  # numerator vanishes on the vertical line x=Ax mod eq
        
        quot = num_res//eq_res
        
        numerator = normal_form((numerator - quot*EQ)//(x-Ax))
        num_res = numerator(x=Bx)
        eq_res = EQ(x=Bx)
        
        quot = num_res//eq_res
        
        numerator = normal_form((numerator - quot*EQ)//(x-Bx))
        return(PointReplacer(inputs, output, numerator))
        
def replacer_from_point(point):
    return(PointReplacer([point], -point, linefunc(point, -point)))

# computes the regular function with divisor = given collection of points
# throws if the sum is nonzero


def list_prod(factors):
    if len(factors)==0:
        return 1
    if len(factors)==1:
        return factors[0]
    factors_new = []
    if len(factors)%2 == 1:
        factors_new.append(factors[-1])
    for i in range(len(factors)//2):
        factors_new.append(normal_form(factors[2*i] * factors[2*i + 1]))
    return list_prod(factors_new)

def cartier(points):
    polys = []
    pts = points.copy()

    # the procedure below does not work for 2 points because of how sage treats multivariable polynomials
    # this exception shouldn't be there in the real code
    if len(pts) == 2:
        assert pts[0] == -pts[1]
        return linefunc(pts[0], pts[1])
    
    # initialize trivial replacers
    for i in range(len(pts)):
        pts[i] = replacer_from_point(pts[i])    

    
    while len(pts) > 1:
        tmp = []
        i = 0
        for pt in pts:
            if pt.output == E(0):
                polys.append(pt.poly)
            else:
                tmp.append(pt)
        pts = []
        if len(tmp)%2 == 1:
            pts.append(tmp[-1])
        for i in range(len(tmp)//2):
            pts.append(tmp[2*i]*tmp[2*i+1])
    
    if len(pts)>0:
        assert(pts[0].output == E(0))
        polys.append(pts[0].poly)

    return list_prod(polys)

In [42]:
def gen_random_felt(F):
        m = randrange(F.order())
        return F(m)
    
def gen_random_point(E):
    while True:
        m = gen_random_felt(E.base())
        try:
            n = E.lift_x(m)
            return n
        except:
            continue

In [43]:
# choosing non-standard norm, checking in a point; refer to page 5 of the paper https://eprint.iacr.org/2022/596.pdf

# multiplicative version

pts = []
for _ in range(10):
    pts.append(gen_random_point(E))

acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

a = gen_random_point(E)
b = gen_random_point(E)
c = -(a+b)

ax, ay = a.xy()
bx, by = b.xy()
cx, cy = c.xy()


z = linefunc(a, b)
z = reduce(z)

lhs = D(x=ax,y=ay)*D(x=bx,y=by)*D(x=cx,y=cy)

prod = 1

for pt in pts:
    ptx, pty = pt.xy()
    prod *= -z(x=ptx,y=pty)

rhs = prod

print("D(A)D(B)D(C) = product over points: {b}".format(b=(rhs==lhs)))

D(A)D(B)D(C) = product over points: True


In [44]:
# computing log-derivatives, doing log-derivatives check

pts = []
for _ in range(10):
    pts.append(gen_random_point(E))

acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

a = gen_random_point(E)
b = gen_random_point(E)
c = -(a+b)

ax, ay = a.xy()
bx, by = b.xy()
cx, cy = c.xy()

z = linefunc(a, b)
z = reduce(z)

# now we need to compute d/dz D *along* the curve
# the answer is in Liam's paper, sanity check here

# we have differentials

Dx = diff(D, x)
Dy = diff(D, y)

EQx = diff(EQ, x)
EQy = diff(EQ, y)

zx = diff(z, x)
zy = 1 # diff(z, y)

l = []

for (sx, sy) in [a.xy(), b.xy(), c.xy()]:
    Dxp = Dx(x=sx,y=sy)
    Dyp = Dy(x=sx,y=sy)
    
    EQxp = EQx(x=sx,y=sy)
    EQyp = EQy(x=sx,y=sy)
    
    # zx, zy are constants already
    # we now need to find such t that (dD - t dEQ) is collinear with dz
    # (Dxp - t EQxp)/(Dyp - t EQyp) = zx/zy = zx
    # Dxp - t EQxp = zx Dyp - t zx EQyp
    # Dxp - zx Dyp = t (EQxp - zx EQ yp)
    # t = (Dxp - zx Dyp) / (EQxp - zx EQyp)    -- note that this is divisible unless line is tangent

    t = (Dxp - zx * Dyp)/(EQxp - zx * EQyp)
    
    # now we need to find proportionality coeff between the result and dz
    
    deriv = (Dyp - t * EQyp) # / zy
    
    l.append(deriv/ D(x=sx,y=sy))

lhs = l[0] + l[1] + l[2]

acc = 0

for pt in pts:
    ptx, pty = pt.xy()
    acc += 1/(-z(x=ptx,y=pty))
    
rhs = acc

print("Our description is correct: {b}".format(b = (lhs == rhs)))

Our description is correct: True


In [45]:
# This is testing Liam Eagen's protocol for the case when z is tangent to the point A
# This allows us to query only 2 points, not three - A, A, -2A
# But log derivative becomes annoying
# Liam's paper suggests some formula, which is based on the idea that we could vary not only z, but also the slope
# The drawback is that we get nontrivial numerators, equal to x(A)-x(P) for a point P

# I have devised a different formula, which does not vary the slope - but it, sadly, involves second derivatives
# So Liam's formula is better.

# I'm leaving it here as a reminder for my struggle (and that I shouldn't drop third order terms in such problem
# even if it looks like I could :) ) 

# -------------------- Let's go!!! --------------------


pts = []

for _ in range(10):
    pts.append(gen_random_point(E))

acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

a = gen_random_point(E)
b = -(2*a)

ax, ay = a.xy()
bx, by = b.xy()

z = linefunc(a, a)
z = reduce(z)

t = diff(z, x)

Dx = diff(D, x)
Dy = diff(D, y)
Dxx = diff(Dx, x)
Dxy = diff(Dx, y)
Dyy = diff(Dy, y)

a00 = D(x=ax, y=ay)
a10 = Dy(x=ax, y=ay)
a01 = Dx(x=ax, y=ay) - t*Dy(x=ax,y=ay)
a02 = (Dxx(x=ax, y=ay) -  2 * t * Dxy(x=ax,y=ay) + t^2 * Dyy(x=ax, y=ay))/2

fr10 = a10/a00
fr01 = a01/a00
fr02 = a02/a00

sp = 1 / (t^2 - 3*ax)

logderiv_a = 2*(fr10 +  sp *  ((t - ay*sp)* fr01 + (fr01^2 - 2*fr02) * ay))

# and a small cleanup of the previous computation for the point b
# newer notation and mb cleaner outlook on a circuit

b00 = D(x=bx, y=by)
b10 = Dy(x=bx, y=by)
b01 = Dx(x=bx, y=by) - t*b10    # this was called Dxp - zx Dyp

# frac = (Dx - t*Dy) / (EQx - t*EQy)
# deriv = Dy - frac * EQy

deriv_b = b10 + b01 * (2*by)/(3*bx^2 + 2*t*by)

logderiv_b = deriv_b / b00

lhs = logderiv_a + logderiv_b


acc = 0


for pt in pts:
    ptx, pty = pt.xy()
    acc += 1/(-z(x=ptx,y=pty))
    
rhs = acc

print("Such Puiseux series much wow: {b}".format(b = (lhs == rhs)))

Such Puiseux series much wow: True


In [46]:
# Now, real Liam Eagen's argument for coinciding points:

pts = []

for _ in range(10):
    pts.append(gen_random_point(E))

acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

a = gen_random_point(E)
b = -(2*a)

ax, ay = a.xy()
bx, by = b.xy()

z = linefunc(a, a)
z = reduce(z)

t = -diff(z, x)

Dx = diff(D, x)
Dy = diff(D, y)

# ax dt/dz = -1

dtdz = -1/ax      # this essentially comes from the second derivative of EQ


dydx = (3*x^2) / (2*y)

dydx_b = dydx(x=bx, y=by)

dbdx = (1 + bx * dtdz)/(dydx_b - t)   # (1/ax) (ax - bx)(2*by)/(3*bx^2 - 2*by*t) - on parity with the paper

# dD/dx = Dx + dydx Dy
dDdx = Dx + dydx * Dy

dDdx_a = dDdx(x=ax, y=ay)
dDdx_b = dDdx(x=bx, y=by)

c2 = 2*by*(ax-bx)/(3*bx^2 - 2*t*by)

lhs = c2 * dDdx_b / D(x=bx, y=by) - (c2 + 2*t) * dDdx_a / D(x=ax, y=ay)

rhs = 0
for pt in pts:
    ptx, pty = pt.xy()
    rhs += (ax - ptx)/(-z(x=ptx, y=pty))

print("Liam's formulas work:{b}".format(b=(lhs==rhs)))

Liam's formulas work:True


In [47]:
pts = []
for _ in range(10):
    pts.append(gen_random_point(E))

acc = E(0)
for pt in pts:
    acc += pt
pts.append(-acc)

D = cartier(pts)
D = reduce(D)

a = gen_random_point(E)
b = gen_random_point(E)
c = -a-b

z = linefunc(a,b)
z = reduce(z)

Dx = diff(D,x)
Dy = diff(D,y)


t = diff(z, x)

acc = 0

for pt in [a,b,c]:
    ptx, pty = pt.xy()
    d00 = D(x=ptx, y=pty)
    d10 = Dy(x=ptx, y=pty)
    d01 = Dx(x=ptx, y=pty) - t*d10
    deriv = d10 + d01 * (2*pty)/(3*ptx^2 + 2*t*pty)
    acc += deriv/d00

lhs = acc

acc = 0

for pt in pts:
    ptx, pty = pt.xy()
    acc += 1/(-z(x=ptx,y=pty))
    
rhs = acc

print(lhs)
print(rhs)

6671819443761759626519327952020306558381434144537570644755055548012546081268
6671819443761759626519327952020306558381434144537570644755055548012546081268


In [48]:
# we want to emulate circom-ish behavior
# circuit builder implementation
# call initialize first
# always give advices in "if WITH_WITNESS", or you risk using runtime variable



def initialize(warnings=True, with_witness=True):
    global WARNINGS
    WARNINGS = True
    global SIGNALS
    SIGNALS = []
    global CONSTRAINTS
    CONSTRAINTS = []
    global WITH_WITNESS
    WITH_WITNESS = True
    global EQUATIONS
    EQUATIONS = []


class Equation:
    # ("mul", [...])
    # ("add", [...])
    # ("const", felt)
    # ("signal", signal_id)
    
    def __init__(self, kind, stuff):
        if kind == "mul" or kind == "add":
            self.kind = kind
            self.stuff = stuff
            for s in stuff:
                assert s < len(EQUATIONS)
        if kind == "const":
            self.kind = kind
            self.stuff = stuff
            assert stuff in F
        if kind == "signal":
            self.kind = kind
            self.stuff = stuff.id
            assert stuff.id < len(SIGNALS)
        self.id = len(EQUATIONS)
        EQUATIONS.append(self)

    def __add__(self, other):
        if type(other) == Integer:
            other = F(other)
        if other in F:
            other = Equation("const", other)
        if type(other) == Signal:
            other = Equation("signal", other)
        kind = "add"
        stuff = [self.id, other.id]
        return Equation(kind, stuff)
    
    # __iadd__ will work as expected here
    
    def __mul__(self, other):
        if type(other) == Integer:
            other = F(other)
        if other in F:
            other = Equation("const", other)
        if type(other) == Signal:
            other = Equation("signal", other)
        kind = "mul"
        stuff = [self.id, other.id]
        return Equation(kind, stuff)
    
    def __radd__(self, other):
        return self+other
    
    def __rmul__(self, other):
        return self*other
    
    def __sub__(self, other):
        return self + other * (-1)
    
    def __neg__(self):
        return self*(-1)
    
    def __rsub__(self, other):
        return self*(-1) + other
    
    def __eq__(self,other):
        (self-other).constrain()
    
    def compute(self):
        if self.kind == "const":
            return self.stuff
        if self.kind == "signal":
            return ~SIGNALS[self.stuff]
        if self.kind == "add":
            return EQUATIONS[self.stuff[0]].compute()+EQUATIONS[self.stuff[1]].compute()
        if self.kind == "mul":
            return EQUATIONS[self.stuff[0]].compute()*EQUATIONS[self.stuff[1]].compute()

    def __repr__(self):
        if self.kind == "const":
            return str(Integer(self.stuff) - (0 if Integer(self.stuff) < F.order()//2 + 1 else p))
        if self.kind == "signal":
            return SIGNALS[self.stuff].__repr__()
        if self.kind == "add":
            return "{l} + {r}".format(l = EQUATIONS[self.stuff[0]].__repr__(), r = EQUATIONS[self.stuff[1]].__repr__())
        if self.kind == "mul":
            return "({l}) * ({r})".format(l = EQUATIONS[self.stuff[0]].__repr__(), r = EQUATIONS[self.stuff[1]].__repr__())
        
    def constrain(self):
        CONSTRAINTS.append(self.id)
    
class Signal:
    def __init__(self, n=None):
        self.value = None
        self.id = len(SIGNALS)
        self.name = n
        SIGNALS.append(self)
        self.known_value = None
    
    def __repr__(self):
        return self.name
    
    def __lshift__(self, other):                  # <--
        assert self.value is None, "FATAL ERROR: attempt to assign signal value twice"
#        if not WITH_WITNESS:
#            return
        if type(other) == Integer:
            other = F(other)
        if other in F:
            self.value = lambda : other
            return
        if type(other) == Signal:
            self << (lambda : ~other)
        if type(other) == Equation:
            self.value = other.compute # this is a function!!
                
    def __eq__(self, other):                     # ===
        (other - self).constrain()
        
    def __le__(self, other):                     # <==
        if type(other) == Equation or type(other) == Signal:
            self == other
            self << other
        else:
            assert False, "FATAL ERROR: can not use <= with variable on rhs, use << instead if you want to give an advice or supply input"

    # ~     extract value from a signal, throws if value is None; must be used for advices
    # this will not always lead to an error at runtime, so make sure to run circuit without values first
    def __invert__(self):
        assert WITH_WITNESS, "FATAL ERROR: trying to extract values in circuit construction mode"
        assert self.value is not None, "FATAL ERROR: trying to read runtime variable at compile-time"
        if self.known_value is None:
            self.known_value = (self.value)()
        return self.known_value
        
    def __add__(self, other):
        return Equation("signal", self) + other
    
    def __mul__(self, other):
        return Equation("signal", self) * other
    
    def __radd__(self, other):
        return self + other
    
    def __rmul__(self, other):
        return self*other
    
    def __iadd__(self, other):
        assert False, "FATAL ERROR: signals are immutable"
    
    def __imul__(self, other):
        assert False, "FATAL ERROR: signals are immutable"
        
    def __sub__(self, other):
        return self + (-1)*other

    def __isub__(self, other):
        assert False, "FATAL ERROR: signals are immutable"

    def __neg__(self):
        return self*(-1)    

    
def alloc(*args, n=None, prefix=''):
    
    if len(args) == 0:
        return Signal(n)
    
    ret = []
    if len(args) == 1:
        for i in range(args[0]):
            ret.append(Signal(n = n+prefix+'['+str(i)+']'))
    else:
        last = args[0]
        for i in range(last):
            ret.append(alloc(*args[1:], n=n, prefix=prefix+'['+str(i)+']'))
    return ret
    

In [49]:
def OnCurve(point, n=None):
    assert len(point) == 2
    assert type(point[0]) == Signal
    assert type(point[1]) == Signal
    XSq = alloc(n=n+".XSq")
    XCb = alloc(n=n+".XCb")
    YSq = alloc(n=n+".YSq")
    
    XSq <= point[0]*point[0]
    XCb <= point[0]*XSq
    YSq <= point[1]*point[1]
    
    YSq - XCb - 3 == 0

    
def AddIncomplete(a, b, n=None):
    assert len(a) == 2
    assert len(b) == 2
    assert type(a[0]) == Signal
    assert type(a[1]) == Signal
    assert type(b[0]) == Signal
    assert type(b[1]) == Signal
    
    slope = alloc(n=n+".slope")
    slope << (~b[1] - ~a[1]) / (~b[0] - ~a[0])
    slope*(b[0] - a[0]) == b[1] - a[1] # this breaks if b = a, leading to arbitrary result
    
    c = alloc(2, n=n+".output")
    
    c[0] <= slope*slope - a[0] - b[0]
    c[1] <= slope*(a[0]-c[0]) - a[1]
    
    return c

def Double(a, n=None):
    assert len(a) == 2
    assert type(a[0]) == Signal
    assert type(a[1]) == Signal

    slope = alloc(n=n+".slope")
    slope << 3*(~a[0])^2 / (2*(~a[1]))
    
    XSq = alloc(n=n+".XSq")
    
    XSq <= a[0]*a[0]
    
    2*slope*a[1] == 3*XSq
    
    c = alloc(2, n=n+".output")
    c[0] <= slope*slope - 2*a[0]
    c[1] <= slope*(a[0]-c[0])-a[1]
    
    return c

def MinusDouble(a, n=None): # 
    assert len(a) == 2
    assert type(a[0]) == Signal
    assert type(a[1]) == Signal

    slope = alloc(n=n+".slope")
    slope << 3*(~a[0])^2 / (2*(~a[1]))
    
    XSq = alloc(n=n+".XSq")
    
    XSq <= a[0]*a[0]
    
    2*slope*a[1] == 3*XSq
    
    c = alloc(2, n=n+".output")
    c[0] <= slope*slope - 2*a[0]
    c[1] <= slope*(c[0]-a[0])+a[1]
    
    return c, slope
    
def TritsDecompose(Scalar, n=None): # for convenience, it will also return squares of the trits
    assert type(Scalar) == Signal
    
    Trits = alloc(TRITSIZE, n=n+".output")
    TritsSq = alloc(TRITSIZE, n=n+".squares")
    
    trunc = Integer(~Scalar)
    if trunc > p//2:
        trunc -= p
        
    acc=0
    for i in range(TRITSIZE):
        trit = trunc % 3
        if trit>1:
            trit -= 3
            
        Trits[i] << trit
        TritsSq[i] <= Trits[i]*Trits[i]
        Trits[i]*TritsSq[i] == Trits[i]
        
        trunc = -(trunc-trit)//3
        
        acc += Trits[i]*(-3)^i
        
    acc == Scalar
    
    assert(Scalar.value() == acc.compute())
    
    return (Trits, TritsSq)

In [77]:
# ---- RANDOM SETUP ----

SIZE = 100
BITSIZE = 128
TRITSIZE = 1 #ceil(128*log(2, 3))


scalars = []
points = []
for _ in range(SIZE-1):
    scalars.append(F(randrange(3))-1)
    #scalars.append(randrange(2**128)-2**127)
acc=0
for i in range(SIZE-1):
    points.append(gen_random_point(E))
    acc+=scalars[i]*points[i]
points.append(acc)
scalars.append(F(-1))
    

# ---- THE REST IS IN CIRCUIT, AND WE PRETEND THAT CHALLENGES ARE GIVEN BY THE VERIFIER ----
# convention for signals - I will use SnakeCase for them
# in the current version it is very easy to make a mistake; I really need a separation of
# runtime / construction variables, but this is a separate project, for now this should be fine
# with enough testing

# also, I do not have any particular arithmetization here - I could in theory allow only specific equations
# but I am mainly aiming at generic arithmetization for this prototype
# I will write this code using "R1CS", which has some consequences, it might be that PlonKish is a bit larger
# I am not sure. Though realistically we will be most likely using it in PlonK (unless anyone is willing
# to use my Groth16 + lookups lol).


initialize()

# Trit decomposition of the scalars

Scalars = alloc(SIZE, n="Scalars")

for i in range(SIZE):
    Scalars[i] << scalars[i]


ScTrits = []
ScTritsAbs = []

for i in range(SIZE):
    trits, trits_abs = TritsDecompose(Scalars[i], n="ScTrits[{i}]".format(i=i))
    ScTrits.append(trits)
    ScTritsAbs.append(trits_abs)
    
Scalars2 = alloc(SIZE, n="Scalars2") # this will contain the value obtained from scalar by doing
                                     #abs(trits)
for i in range(SIZE):
    acc = 0
    for j in range(TRITSIZE):
        acc += (-3)^j * ScTritsAbs[i][j]
    Scalars2[i] <= acc

ScPlus = []
ScMinus = []
    
for i in range(SIZE):
    ScPlus.append((Scalars2[i] + Scalars[i])*(F(1)/F(2)))
    ScMinus.append((Scalars2[i] - Scalars[i])*(F(1)/F(2)))
    
    assert ScPlus[-1].compute() - ScMinus[-1].compute() == Scalars[i].value()
    
# this will filter trits with +1 and -1 respectively
    
Points = alloc(SIZE, 2, n="Points")

for i in range(SIZE):
    px, py = points[i].xy()
    Points[i][0] << px
    Points[i][1] << py

for i in range(SIZE):
    OnCurve(Points[i], n="OnCurve[{i}]".format(i=i))

# Constrain that points are on curve

DComm0 = alloc(TRITSIZE, (SIZE+2)//2 + 1, n="Dcomm0")    # SIZE+2 is the amount of points in each row 
DComm1 = alloc(TRITSIZE, ((SIZE+2)-1)//2 + 1, n="Dcomm1") 

# ---- prover commits to D: ----

# this should be moved to a separate helper function
# notice that these points_table are not signals, we
# will not be using them explicitly in the verifier check

points_table =  [[None for j in range(SIZE+2)] for i in range(TRITSIZE)]

acc = E(0)
for i in range(TRITSIZE):
    acc = -(3*acc)
    points_table[-1-i][SIZE] = acc    # we go from the top bits
    for j in range(SIZE):
        trit = Integer(~ScTrits[j][-1-i])
        if trit > 1:
            trit -= F.order()
#        assert trit in [-1, 0, 1]
        pt = trit*points[j]
        assert pt == points[j] or pt == E(0) or pt == -points[j]
        points_table[-1-i][j] = pt    # btw none of this would have happened if we would write bits
        acc += pt
    points_table[-1-i][SIZE+1] = -acc


# beware the wraparound; sage converts field coefficients from a field to the integer scalar automatically
# how nice of it...
# sadly it breaks everything because it treats negative values as positive :<


D = []
for i in range(TRITSIZE):
    l = []
    for pt in points_table[i]:
        if pt != E(0):
            l.append(pt)
    tmp = cartier(l)
    D.append(tmp)  
    
# I will use ugly trick again to extract the coefficients
# They are also padded with 0s 

coeffs_D0 = []
coeffs_D1 = []
for i in range(TRITSIZE):
    tmp_poly = y*D[i]
    coeffs_D0.append(tmp_poly.list()[1].list())
#    assert len(coeffs_D0[-1]) <= len(DComm0[0]), "something is wrong with the padding"
    while len(coeffs_D0[-1]) < len(DComm0[0]):
        coeffs_D0[-1].append(0)
    if len(tmp_poly.list())>2:
        coeffs_D1.append(tmp_poly.list()[2].list())
    else:
        coeffs_D1.append([])
#    assert len(coeffs_D1[-1]) <= len(DComm1[0]), "something is wrong with the padding"    
    while len(coeffs_D1[-1]) < len(DComm1[0]):
        coeffs_D1[-1].append(0)
        
# ---- and now prover actually commits to D: ----

for i in range(TRITSIZE):
    for j in range(len(DComm0[0])):
        DComm0[i][j] << coeffs_D0[i][j]
    for j in range(len(DComm1[0])):
        DComm1[i][j] << coeffs_D1[i][j]
    
# ---- now we need to constrain it. we pretend that verifier sends us challenge ----
# ---- this should be replaced by Fiat-Shamir, but we are only mock proving right now ----

a = alloc(2, n="CHALLENGE_POINT_a")

tmp_a = gen_random_point(E)

ax,ay = tmp_a.xy()

a[0] << ax
a[1] << ay

(b, slope) = MinusDouble(a, n="MinusDouble")

# TESTING
test_b = -2*tmp_a
assert test_b.xy() == (~b[0], ~b[1])

offset = alloc(n="offset")
offset <= (-1)*(a[1] - slope*a[0])

# TESTING
line = reduce(linefunc(tmp_a,tmp_a))
assert line == (y - ~slope*x + ~offset)

# z = y + slope*x - offset
# offset is found from z(a)=0

# we need to compute LHS and RHS
# compute RHS first (it is simpler)
# the same formula as before, but rescaled by scalars

rhs = 0
for i, Point in enumerate(Points):
    zPointP = alloc(n="zPointP[{i}]".format(i=i))
    zPointP <= (-1)*(offset + Point[1] - slope*Point[0])
    
    toAddP = alloc(n="toAddP[{i}]".format(i=i))
    toAddP << (~a[0] - ~Point[0])/(~zPointP)
    toAddP*zPointP == (~a[0] - ~Point[0])               # note that this division is unchecked; however, the numerator
                                                      # is (ax - ptx), which is nonzero as a is chosen randomly
    toAddRescaledP = alloc(n="toAddRescaled[{i}]".format(i=i)) # so there is no possibility of having 0/0 here
    toAddRescaledP <= toAddP*ScPlus[i]            # scalar might be 0, but it does not participate in this division
    
    # TESTING
    assert ((~a[0]-~Point[0])/(-line(x=~Point[0], y=~Point[1]))) == (~toAddP)

    zPointM = alloc(n="zPointM[{i}]".format(i=i))
    zPointM <= (-1)*(offset - Point[1] - slope*Point[0])   # negation of a point
    
    toAddM = alloc(n="toAddM[{i}]".format(i=i))
    toAddM << (~a[0] - ~Point[0])/(~zPointM)
    toAddM*zPointM == (~a[0] - ~Point[0])             # note that this division is unchecked; however, the numerator
                                                      # is (ax - ptx), which is nonzero as a is chosen randomly
    toAddRescaledM = alloc(n="toAddRescaled[{i}]".format(i=i)) # so there is no possibility of having 0/0 here
    toAddRescaledM <= toAddM*ScMinus[i]         # scalar might be 0, but it does not participate in this division
    
    print ((~a[0] - ~Point[0])/(-line(x=~Point[0],y=-~Point[1])))
    print (~toAddM)
    
    rhs += (toAddRescaledP + toAddRescaledM) # Scalars[i] * (ax - ptx)/(-z(x=ptx, y=pty))

    
# to compute lhs, we need to compute logderivative of every DComm

# this boils down to rewriting this piece of code:

#dtdz = -1/ax 
#dydx = (3*x^2) / (2*y)
#dydx_b = dydx(x=bx, y=by)
#dbdx = (1 + bx * dtdz)/(dydx_b - t)   # (1/ax) (ax - bx)(2*by)/(3*bx^2 - 2*by*t) - on parity with the paper
#dDdx = Dx + dydx * Dy

#dDdx_a = dDdx(x=ax, y=ay)
#dDdx_b = dDdx(x=bx, y=by)

#c2 = 2*by*(ax-bx)/(3*bx^2 - 2*t*by)

#lhs = c2 * dDdx_b / D(x=bx, y=by) - (c2 + 2*t) * dDdx_a / D(x=ax, y=ay)

logderiv = []

dtdz = alloc(n="dtdz")
dtdz << -1/(~a[0])
dtdz * a[0] == -1

dydx_a = alloc(n="dydx_a")
dydx_a << 3*(~a[0])^2 / (2*~a[1])
dydx_a * 2 * a[1] == 3 * a[0] * a[0]

dydx_b = alloc(n="dydx_b")
dydx_b << 3*(~b[0])^2 / (2*~b[1])
dydx_b * 2 * b[1] == 3 * b[0] * b[0]

dbdx = alloc(n="dbdx")
dbdx << (1 + ~b[0]*~dtdz)/(~dydx_b - ~slope)
dbdx*(dydx_b - slope) == (1 + b[0]*dtdz)

c2 = alloc(n="c2")
c2 << 2*(~b[1])*(~a[0]-~b[0])/(3*(~b[0])^2 - 2*(~slope)*~b[1])
c2*(3*b[0]*b[0] - 2*slope*b[1]) == 2*b[1]*(a[0]-b[0])

axk = alloc(len(DComm0[0]), n="axk")       # powers a_x^k
bxk = alloc(len(DComm0[0]), n="bxk")

axk[0] <= Equation("const", 1)
bxk[0] <= Equation("const", 1)

for i in range(len(DComm0[0])-1):
    axk[i+1] <= axk[i]*a[0]
    bxk[i+1] <= bxk[i]*b[0]

for i in range(TRITSIZE):
        
        D_a = alloc(n="D_a")    
        D_a_1 = alloc(len(DComm1[i]), n="D_a_1")         
        D_a_2 = alloc(len(DComm0[i]), n="D_a_2")         
        acc = 0
        for j in range(len(DComm1[i])):
            D_a_1[j] <= DComm0[i][j]+DComm1[i][j]*a[1]    # d0_k + y d1_k
        
        for j in range(len(DComm0[i])):
            D_a_2[j] <= (D_a_1[j] * axk[j] if j<len(DComm1[i]) else DComm0[i][j]*axk[j])
            acc += D_a_2[j]
        D_a <= acc                              # this computes DComm(a)

        # TESTING
        assert D[i](x=~a[0], y=~a[1]) == ~D_a
        
        D_b = alloc(n="D_b")    
        D_b_1 = alloc(len(DComm1[i]), n="D_b_1")         
        D_b_2 = alloc(len(DComm0[i]), n="D_b_2")         
        acc = 0
        for j in range(len(DComm1[i])):
            D_b_1[j] <= DComm0[i][j]+DComm1[i][j]*b[1]    # d0_k + y d1_k
        
        for j in range(len(DComm0[i])):
            D_b_2[j] <= (D_b_1[j] * bxk[j] if j<len(DComm1[i]) else DComm0[i][j]*bxk[j])
            acc += D_b_2[j]
        D_b <= acc                           

        # TESTING
        assert D[i](x=~b[0], y=~b[1]) == ~D_b
        
        Dx_a = alloc(n="Dx_a")
        acc = 0
        for j in range(len(DComm0[i])-1):
            acc += (j+1)*D_a_2[j+1]        
        Dx_a * a[0] == acc
        Dx_a << acc.compute()/(~a[0])

        # TESTING
        assert (diff(D[i], x))(x=~a[0], y=~a[1]) == ~Dx_a
        
        Dx_b = alloc(n="Dx_b")
        acc = 0
        for j in range(len(DComm0[i])-1):
            acc += (j+1)*D_b_2[j+1]        
        Dx_b * b[0] == acc
        Dx_b << acc.compute()/(~b[0])
                 
        # TESTING
        assert (diff(D[i], x))(x=~b[0], y=~b[1]) == ~Dx_b    
        
        Dy_a = alloc(n="Dy_a")    
        Dy_a_1 = alloc(len(DComm1[i]), n="Dy_a_1")          
        acc = 0
        for j in range(len(DComm1[i])):
            Dy_a_1[j] <= DComm1[i][j] * axk[j]
            acc += Dy_a_1[j]
        Dy_a <= acc
                 
        # TESTING
        assert (diff(D[i], y))(x=~a[0], y=~a[1]) == ~Dy_a
            
        Dy_b = alloc(n="Dy_b")    
        Dy_b_1 = alloc(len(DComm1[i]), n="Dy_b_1")          
        acc = 0
        for j in range(len(DComm1[i])):
            Dy_b_1[j] <= DComm1[i][j] * bxk[j]
            acc += Dy_b_1[j]
        Dy_b <= acc
        
        # TESTING
        assert (diff(D[i], y))(x=~b[0], y=~b[1]) == ~Dy_b
        
        D_a_inv = alloc(n="D_a_inv[{i}]".format(i=i))
        D_a_inv << 1/(~D_a)
        D_a_inv * D_a == 1
        
        D_b_inv = alloc(n="D_b_inv[{i}]".format(i=i))
        D_b_inv << 1/(~D_b)                          # this actually could be skipped, because if D does not vanish
        D_b_inv * D_b == 1                          # in one random point, it is not vanishing in other
                                                   # but the gains are miniscule and im too lazy at this point
        dDdx_a = alloc(n="dDdx_a[{i}]".format(i=i))
        dDdx_a <= Dx_a + dydx_a * Dy_a
        dDdx_b = alloc(n="dDdx_b[{i}]".format(i=i))
        dDdx_b <= Dx_b + dydx_b * Dy_b

        s1 = alloc(n="s1[{i}]".format(i=i))
        s2 = alloc(n="s2[{i}]".format(i=i))
        s3 = alloc(n="s3[{i}]".format(i=i))
        s4 = alloc(n="s4[{i}]".format(i=i))
        
        s1 <= dDdx_b * D_b_inv
        s2 <= s1 * c2
        
        s3 <= dDdx_a * D_a_inv
        s4 <= (c2 + 2*slope)*s3
        
        logderiv.append(s2 - s4)        #this will push c2 * dDdx_b / D_b - (c2 + 2t) * dDdx_a / D_a

lhs = 0

for i in range(TRITSIZE):
    lhs += (-3)^i * logderiv[i]

    
lhs == rhs

print(EQUATIONS[-1].compute())

print(len(CONSTRAINTS))
print(len(SIGNALS))
print(len(EQUATIONS))

760063645055206506756556187314859609375646350568906621228753998293721024299
760063645055206506756556187314859609375646350568906621228753998293721024299
14361767126239636420196462233755932726290628579331319783274535335213750855128
14361767126239636420196462233755932726290628579331319783274535335213750855128
13972425220397600858818978998658521118662470861307597298065172016635129811005
13972425220397600858818978998658521118662470861307597298065172016635129811005
19140133972322983693952354941826209136489019798609603248681493537179206388944
19140133972322983693952354941826209136489019798609603248681493537179206388944
11839597085364108686342781107647275062146888381511093468134406301646929044978
11839597085364108686342781107647275062146888381511093468134406301646929044978
14818417686618866426883645091096922570973636033163690103143530423766969157559
14818417686618866426883645091096922570973636033163690103143530423766969157559
67983456999731138235330541050971075707111622639058561546224223192289

902385215843167942139655471513841393259988719077957204018585229436875419095
902385215843167942139655471513841393259988719077957204018585229436875419095
9324153366767344687285213095923865910756254321303985099178491009571750925136
9324153366767344687285213095923865910756254321303985099178491009571750925136
3818426451795747982974197820815073593658176215268967249575385696386691341711
3818426451795747982974197820815073593658176215268967249575385696386691341711
10092827343961733613356865894239574938010832254355667501228077685028893179300
10092827343961733613356865894239574938010832254355667501228077685028893179300
16183326431330938746927211272653746146527272979706700058595202951034356826179
16183326431330938746927211272653746146527272979706700058595202951034356826179
4795574770474244102892647027656095753747753657126977179144754128912212831027
4795574770474244102892647027656095753747753657126977179144754128912212831027
87828665402193970783138542927203739422625328093408314334677666698237545884

In [51]:
poly = x^3 + x*y
(poly*y).list()

[0, x^3, x]