
## Imports

In [None]:
%reload_ext autoreload
%autoreload 2
from merkle_tree import MerkleTree
from channel import Channel
from pprint import pprint

In [None]:
# Test: build a MerkleTree over eight small leaves and eyeball it.
# Each node / path entry is truncated to its first four elements for readability.
from pprint import pprint

m = MerkleTree([1, 2, 3, 4, 5, 6, 7, 8])
pprint([[node[:4] for node in level] for level in m.tree])
print(m.root)
pprint([entry[:4] for entry in m.get_path(0)])

### Fibonacci function

In [None]:
def fibonacci(a=1, size=1001) -> list[int]:
    """Return the first `size` terms of the sequence 1, a, 1 + a, ...

    Every term from index 2 onward is the sum of the two terms before it.
    With the defaults (a=1, size=1001) this is the ordinary Fibonacci
    sequence 1, 1, 2, 3, 5, ... with 1001 terms.

    Args:
        a: second term of the sequence (the first term is always 1).
        size: number of terms to return.

    Returns:
        A list of exactly ``max(size, 0)`` integers.
    """
    fib_list = [1, a]
    for _ in range(2, size):
        fib_list.append(fib_list[-1] + fib_list[-2])
    # Slice so that size 0 or 1 does not return both seed terms.
    return fib_list[:max(size, 0)]

In [None]:
# The full 1001-term list contains numbers over 200 digits long; preview it instead.
fib_preview = fibonacci()
print(fib_preview[:10], "...", fib_preview[-1])

### Polynomial f(x)

In [None]:
P: int =3221225473
F = GF(P)
N = 8192

g = F(1734477367)
gamma = g ** 8
gamma_group = [gamma ** i for i in range(1001)]

fibonacci_seq = fibonacci()

R.<x> = PolynomialRing(F)

f = R.lagrange_polynomial(zip(gamma_group, fibonacci_seq))

### Bit Reverse Order

In [None]:
def bit_reverse(n, width):
    """Return the lowest `width` bits of `n`, in reversed order.

    Bit k of `n` (k < width, counting from the least significant bit) becomes
    bit ``width - 1 - k`` of the result; bits of `n` at or above `width` are ignored.
    """
    return sum(((n >> k) & 1) << (width - 1 - k) for k in range(width))


def bit_reverse_permutation(N):
    """Generate bit-reversed order for N elements.

    Entry i is `i` with its log2(N) low bits reversed. N must be a positive
    power of two; for any other N the result would not be a permutation.

    Raises:
        ValueError: if N is not a positive power of two.
    """
    n = int(N)
    if n <= 0 or n & (n - 1):
        raise ValueError(f"N must be a positive power of two, got {N}")
    # Exact integer log2. The float-based math.log(N, 2) can round just below
    # an integer, and int() would then truncate the width by one.
    width = n.bit_length() - 1
    return [bit_reverse(i, width) for i in range(n)]


bit_reverse_order = bit_reverse_permutation(N)
# Bit reversal is an involution: reversing twice gives back every index.
# That also makes it a permutation of range(N).
assert all(bit_reverse_order[r] == i for i, r in enumerate(bit_reverse_order))
# Show only a prefix; the full list has N entries.
bit_reverse_order[:16]

### Coset $w \langle g \rangle$ (LDE evaluation domain)

In [None]:
# Shift the order-N subgroup <g> by a generator w of F^* to get the coset w<g>.
w = F.multiplicative_generator()
# <g> is the unique order-N subgroup of the cyclic group F^*, so w lies in <g>
# iff w^N == 1. A disjoint coset avoids every gamma^i, so the pointwise divisions
# by d(x), x - gamma^0 and x - gamma^1000 performed on it never divide by zero.
assert w ** N != 1, "coset w<g> must be disjoint from <g>"

g_group = [g ** i for i in range(N)]

wg = [w * gi for gi in g_group]
# Bit-reversed order puts x and -x in adjacent leaves (2j, 2j + 1).
# This is the pairing the FRI folding consumes.
wg_merkle_order = [wg[i] for i in bit_reverse_order]

### Commit to f(x) on the LDE domain

In [None]:
channel = Channel(F)

# Evaluate f over the coset, already arranged in Merkle (bit-reversed) order,
# then commit to those evaluations by publishing the tree's root on the channel.
f_wg = [f(point) for point in wg_merkle_order]
merkle_f_wg = MerkleTree(f_wg)
channel.send({'title': 'f(x) on LDE', 'data': merkle_f_wg.root})

### Polynomial h(x)

In [None]:
# d(x) vanishes on gamma^0 .. gamma^998: every trace point that has two
# successors, i.e. where f(gamma^2 x) = f(gamma x) + f(x) must hold.
d_x = 1
for gamma_i in gamma_group[:-2]:
    d_x = d_x * (x - gamma_i)

# Transition-constraint polynomial: zero at each root of d(x) iff f encodes the trace.
F_x = f(x) + f(gamma * x) - f((gamma**2) * x)

# Exact polynomial division. `//` would silently discard a nonzero remainder,
# hiding a violated constraint.
h_x, h_remainder = F_x.quo_rem(d_x)
assert h_remainder == 0, "F(x) is not divisible by d(x): transition constraint violated"

In [None]:
# Evaluate h on the coset directly as the pointwise quotient F(pt) / d(pt).
h_wg = [F_x(pt) / d_x(pt) for pt in wg_merkle_order]

# Sanity check: the pointwise quotients must match evaluating the polynomial h(x).
h_wg_by_poly = [h_x(pt) for pt in wg_merkle_order]
print(h_wg_by_poly[:10], h_wg[:10], sep="\n")
assert h_wg == h_wg_by_poly

### Boundary Constraints

In [None]:
Y = fibonacci_seq[-1]

# Boundary constraints: f(gamma^0) = 1 and f(gamma^1000) = Y. Each numerator
# therefore has the matching root, and the division is exact. Use quo_rem for
# polynomial division: in Sage, `/` between polynomials returns an element of
# the fraction field instead of a polynomial.
t0_x, t0_remainder = (f(x) - 1).quo_rem(x - gamma ** 0)
t1_x, t1_remainder = (f(x) - Y).quo_rem(x - gamma ** 1000)
assert t0_remainder == 0, "boundary constraint f(gamma^0) == 1 violated"
assert t1_remainder == 0, "boundary constraint f(gamma^1000) == Y violated"

In [None]:
# Pointwise boundary quotients on the coset.
t0_wg = [(f(pt) - 1) / (pt - gamma ** 0) for pt in wg_merkle_order]
t1_wg = [(f(pt) - Y) / (pt - gamma ** 1000) for pt in wg_merkle_order]

# Sanity check: they must agree with evaluating t0(x) and t1(x) themselves.
t0_x_by_poly = [t0_x(pt) for pt in wg_merkle_order]
t1_x_by_poly = [t1_x(pt) for pt in wg_merkle_order]
print(t0_x_by_poly[:10], t0_wg[:10], "", t1_x_by_poly[:10], t1_wg[:10], sep="\n")
assert t0_wg == t0_x_by_poly
assert t1_wg == t1_x_by_poly

### Composition Polynomial

In [None]:
# Draw three verifier challenges, in order, to combine the constraint quotients.
beta_0, beta_1, beta_2 = (channel.receive_random_field_element() for _ in range(3))

cp0_x = beta_0 * h_x + beta_1 * t0_x + beta_2 * t1_x

In [None]:
# Combine the pointwise quotients with the same challenges, entry by entry.
cp0_wg = [
    beta_0 * h_val + beta_1 * t0_val + beta_2 * t1_val
    for h_val, t0_val, t1_val in zip(h_wg, t0_wg, t1_wg)
]

# Sanity check against evaluating the composition polynomial itself.
cp0_x_by_poly = [cp0_x(pt) for pt in wg_merkle_order]
print(cp0_x_by_poly[:10], cp0_wg[:10], sep="\n")
assert cp0_wg == cp0_x_by_poly

### FRI

In [None]:
def fri_fold(values, domain, beta):
    """Fold one FRI layer with the verifier challenge `beta`.

    Writes p(x) = p_even(x^2) + x * p_odd(x^2) and returns the evaluations
    of p_even + beta * p_odd on the squared domain. This halves the layer size.

    Args:
        values: evaluations p(domain[i]). Entries (2j, 2j + 1) must be the
            evaluations at a point and at its negation.
        domain: the evaluation points, aligned one-to-one with `values`.
        beta: verifier challenge (field element).

    Returns:
        (next_values, next_domain), where next_domain[j] == domain[2 * j] ** 2.
    """
    next_values = []
    for i in range(0, len(values), 2):
        point = domain[i]
        p_plus, p_minus = values[i], values[i + 1]
        p_even = (p_plus + p_minus) / 2
        p_odd = (p_plus - p_minus) / (2 * point)
        next_values.append(p_even + beta * p_odd)
    next_domain = [pt ** 2 for pt in domain[::2]]
    return next_values, next_domain


fri_layers = []
degree = R(cp0_x).degree()

curr_cp = [cp0_x(x) for x in wg_merkle_order]
# The domain point behind each entry of curr_cp. It is squared after every fold,
# so each layer divides by its own x rather than the original coset point.
curr_domain = list(wg_merkle_order)
while degree > 0:
    # Commit curr layer
    merkle_curr_cp = MerkleTree(curr_cp)
    fri_layers.append((curr_cp, merkle_curr_cp))
    channel.send({
        "title": f"commit for layer {len(fri_layers) - 1} of FRI",
        "data": merkle_curr_cp.root,
    })
    random_x = channel.receive_random_field_element()
    curr_cp, curr_domain = fri_fold(curr_cp, curr_domain, random_x)
    # Folding maps a polynomial of degree d to one of degree at most d // 2.
    degree = degree // 2

constant = curr_cp[0]
channel.send({"title": "constant of last FRI layer",
              "data": str(constant),
              })

assert all(constant == v for v in curr_cp), "last FRI layer is not constant"

In [None]:
# Summarize each committed FRI layer instead of dumping every evaluation.
for layer_no, (layer_values, layer_tree) in enumerate(fri_layers):
    print(f"layer {layer_no}: {len(layer_values)} values, root {layer_tree.root}")
# curr_cp[0] is the constant of the last FRI layer, not a Merkle root.
print("Constant of last FRI layer")
print(curr_cp[0])

In [None]:
# Sanity check: the last FRI layer must be a single repeated constant.
print(curr_cp)
assert all(v == curr_cp[0] for v in curr_cp), "last FRI layer is not constant"

In [None]:
# Pretty-print everything the channel has recorded so far.
transcript = channel.get_all_messages()
pprint(transcript)

### Decommit Phase

# TIOTA


## Decommit Phase

In [None]:
def send_f_of_x_and_gammas(idx_of_x, idx_of_gamma_x, idx_of_gamma_2_x):
    """Send f's value and Merkle path at three leaf indices of `merkle_f_wg`.

    For each index, in the order x, gamma*x, gamma^2*x, two messages are sent:
    first the value `f_wg[idx]`, then `merkle_f_wg.get_path(idx)`.
    """
    decommitments = (
        ('f of x', 'path of f of x', idx_of_x),
        ('f of gamma x', 'path of f of gamma x', idx_of_gamma_x),
        ('f of gamma squared x', 'path of f of gamma squared x', idx_of_gamma_2_x),
    )
    for value_title, path_title, leaf in decommitments:
        channel.send({'title': value_title, 'data': f_wg[leaf]})
        channel.send({'title': path_title, 'data': merkle_f_wg.get_path(leaf)})

def query(idx):
    idx_of_x = bit_reverse_order.index(idx)
    idx_of_gamma_x = bit_reverse_order.index(idx + 8)
    idx_of_gamma_2_x = bit_reverse_order.index(idx + 16)
    send_f_of_x_and_gammas(idx_of_x, idx_of_gamma_x, idx_of_gamma_2_x)
    for cp, cp_merkle in fri_layers:
        idx_of_minus_x = idx_of_x ^^ 1
        channel.send({ # DEBUG
            'title': 'cp of x',
            'data': cp[idx_of_x],
        })
        assert(cp_merkle.get_path(idx_of_x)[2:] == cp_merkle.get_path(idx_of_minus_x)[2:])
        channel.send({
            'title': 'path of cp of x',
            'data': cp_merkle.get_path(idx_of_x),
        })
        channel.send({ # DEBUG
            'title': 'cp of -x',
            'data': cp[idx_of_minus_x],
        })
        channel.send({
            'title': 'path of cp of x',
            'data': cp_merkle.get_path(idx_of_minus_x),
        })
        
query(0)

In [None]:
from random import randint, seed
# int(): presumably needed because the Sage preparser turns 42069 into a Sage
# Integer, which random.seed does not accept as a seed type.
seed(int(42069))

# randint is inclusive at both ends, so the last valid index is len(wg) - 1.
rand_idx = randint(0, len(wg) - 1)
rand_x = wg[rand_idx]
f_x_at_idx = f(rand_x)
f_x_at_g_idx = f(gamma * rand_x)
f_x_at_g2_idx = f(gamma**2 * rand_x)

# wg is in natural order, but merkle_f_wg was built over wg_merkle_order.
# Bit reversal is its own inverse, so f(wg[rand_idx]) sits at leaf
# bit_reverse_order[rand_idx].
rand_leaf = bit_reverse_order[rand_idx]
channel.send({'title': 'decommit phase 0 ', 'data': {
             'result': f_x_at_idx, 'path': merkle_f_wg.get_path(rand_leaf)}})

for i in range(len(cps)):
    cp = cps[i]
    cp_at_idx = cp(rand_x)
    cp_at_neg_idx = cp(-rand_idx)