In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

from templates import multiplication
from templates import addition
from templates import permutation
from templates import utils

# Enable 64-bit floats in JAX (it defaults to float32).
# NOTE(review): presumably needed so that outputs meant to be 0 stay below the
# tiny `eps` threshold used to binarize predictions -- confirm.
jax.config.update("jax_enable_x64", True)


def get_ntk_predictor(X_train, Y_train):
    """Fit an infinite-width NTK regressor to ``(X_train, Y_train)``.

    The architecture is Dense -> ReLU -> Dense, but only its ``kernel_fn``
    (the infinite-width limit) is used, so the finite widths passed to
    ``stax.Dense`` do not change the kernel.

    Args:
        X_train: training inputs, one example per row.
        Y_train: training targets, one example per row.

    Returns:
        A function ``f(x)`` returning the NTK mean prediction of gradient
        descent on MSE run to convergence (``t`` left at its default ``None``)
        for the FIRST row of ``x`` only.
    """
    n0 = X_train.shape[0]

    kernel_fn = stax.serial(
        stax.Dense(1024),  # hidden layer; width is ignored by kernel_fn
        stax.Relu(),       # ReLU activation
        stax.Dense(n0),    # readout layer; width is ignored by kernel_fn
    )[-1]
    predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, X_train, Y_train)

    # Only the mean is used, so skip the covariance: with compute_cov=False
    # predict_fn returns the mean array directly instead of a Gaussian.
    return lambda x: predict_fn(x_test=x, get='ntk', compute_cov=False)[0]


# Threshold used to binarize NTK predictions: values > eps become 1, else 0.
eps = 1e-10

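### Permutation

Sample a random permutation matrix $P$ over $m$ bits and fit the NTK predictor to the corresponding dataset. Then check that a single prediction step maps a random $m$-bit input $x$ to its permuted bits $P^\top x$.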

In [None]:
m = 10  # number of bits

# Sample a random m x m permutation matrix.
P = np.eye(m)[np.random.permutation(m)]

# Training set: inputs are the one-hot rows of P.T.
# NOTE(review): the matmul below requires len(X) == m.
X, Y_train = permutation.get_dataset(m)
X_train = jnp.array(P.T @ np.eye(len(X)), dtype=jnp.float64)
Y_train = jnp.array(Y_train, dtype=jnp.float64)
predict_fn = get_ntk_predictor(X_train, Y_train)

# Random m-bit input (MSB first) and its expected permuted bits, as Python ints.
number = np.random.randint(0, 2**m)
number_bin = np.array([int(x) for x in np.binary_repr(number, width=m)])
out_bin = (P.T @ number_bin).astype(int).tolist()
out = int(''.join(str(x) for x in out_bin), 2)
print(f"Input: {number}")

X_test = permutation.get_sample(m)[1]
X_test['p'] = number_bin.tolist()

print("Iteration 0:")
print(X_test)
print()

# One NTK step: encode the state, predict, binarize, decode back to a state.
X_test = utils.encode_data(X_test, X)
X_test = jnp.array(X_test, dtype=jnp.float64).reshape(1, -1)

y_pred = predict_fn(X_test)
y_pred_round = np.where(y_pred > eps, 1, 0).tolist()
X_test = permutation.unflatten_sample(y_pred_round, m)

print("Iteration 1:")
print(X_test)
print()

# Decode the predicted bits (MSB first) back to an integer.
prediction = int(''.join(str(int(x)) for x in X_test['p']), 2)
print(f"Expected Output: {out}")
print(f"Predicted Output: {prediction}")

# Check only after reporting, so a failure still shows both values.
assert out_bin == X_test['p'], f"expected bits {out_bin}, got {X_test['p']}"

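### Addition

Add two random $m$-bit numbers. Each step encodes the current state, predicts the next state with the NTK predictor, and thresholds the prediction to bits. The loop stops early once `sum_q` and all carries except the final carry-out are zero.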

In [None]:
m = 4  # number of bits per operand

X, Y_train = addition.get_dataset(m)
X_train = jnp.eye(len(X))
Y_train = jnp.array(Y_train, dtype=jnp.float64)
predict_fn = get_ntk_predictor(X_train, Y_train)

p = np.random.randint(0, 2**m)
q = np.random.randint(0, 2**m)
out = p + q

# Operands are stored least-significant bit first.
X_test = addition.get_sample(m)[1]
X_test['sum_p'] = [int(x) for x in np.binary_repr(p, width=m)][::-1]
X_test['sum_q'] = [int(x) for x in np.binary_repr(q, width=m)][::-1]

print(f"Input: {p} + {q}")
print("Iteration 0:")
print(X_test)
print()

for i in range(2*m):
    X_test_old = X_test.copy()

    # One NTK step: encode the state, predict, binarize, decode back to a state.
    X_test = utils.encode_data(X_test, X)
    X_test = jnp.array(X_test, dtype=jnp.float64).reshape(1, -1)
    y_pred = predict_fn(X_test)
    y_pred_round = np.where(y_pred > eps, 1, 0).tolist()
    X_test = addition.unflatten_sample(y_pred_round, m)

    print("Iteration:", i+1)
    print("Updated variables:")
    for key in X_test.keys():
        if X_test[key] != X_test_old[key]:
            print(f"{key}: {X_test[key]}")
    print()

    # Early exit to avoid unnecessary printing: done once sum_q and every carry
    # except the final carry-out are zero. Remove to see all iterations.
    if X_test['sum_c'][:-1] == [0]*(m-1) and X_test['sum_q'] == [0]*m:
        break

# Expected result, LSB first, with one extra bit for the final carry-out.
out_bin = [int(x) for x in np.binary_repr(out, width=m+1)][::-1]

# Low m bits are in sum_p; the carry-out is the last bit of sum_c (LSB first).
predicted_bits = X_test['sum_p'] + X_test['sum_c'][-1:]
prediction = int(''.join(str(int(x)) for x in predicted_bits)[::-1], 2)
print(f"Expected Output: {out}")
print(f"Predicted Output: {prediction}")

# Check only after reporting, so a failure still shows both values.
assert predicted_bits == out_bin, f"expected bits {out_bin}, got {predicted_bits}"

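### Multiplication

Multiply two random $m$-bit numbers with the same encode–predict–threshold loop. Here the training inputs are padded with $m$ extra one-hot rows whose targets are all zero. The loop stops once the `multiplier` register is zero.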

In [None]:
m = 3  # number of bits per operand

# Training set, padded with `pad` extra one-hot inputs whose targets are all zero.
X, Y_train = multiplication.get_dataset(m)
pad = m
X_train = jnp.eye(len(X) + pad)
Y_pad = np.zeros((pad, Y_train.shape[1]), dtype=np.float64)
Y_train = jnp.array(np.vstack((Y_train, Y_pad)), dtype=jnp.float64)
predict_fn = get_ntk_predictor(X_train, Y_train)

# Random operands in [0, 2**m).
multiplier = np.random.randint(0, 2**m)
multiplicand = np.random.randint(0, 2**m)
out = multiplier*multiplicand
print(f"Multiplier: {multiplier}, Multiplicand: {multiplicand}")

# Initial state; operands are stored least-significant bit first, and the
# multiplicand gets 2*m bits.
X_test = multiplication.get_sample(m)[1]
X_test['multiplier'] = [int(x) for x in np.binary_repr(multiplier, width=m)][::-1]
X_test['multiplicand'] = [int(x) for x in np.binary_repr(multiplicand, width=2*m)][::-1]
X_test['to_check_lsb'][0] = 1

print("Iteration: 0")
print(X_test)
print()

max_steps = 4*(m**2) + 3*m  # upper bound on loop iterations
for i in range(max_steps):
    X_test_old = X_test.copy()

    # One NTK step: encode the state, predict, binarize, decode back to a state.
    X_test = utils.encode_data(X_test, X)
    # Append zeros for the `pad` extra training coordinates.
    # NOTE(review): assumes encode_data returns len(X) features -- confirm.
    X_test = np.concatenate([X_test, np.zeros(pad)])
    X_test = jnp.array(X_test, dtype=jnp.float64).reshape(1, -1)
    y_pred = predict_fn(X_test)
    y_pred_round = np.where(y_pred > eps, 1, 0).tolist()
    X_test = multiplication.unflatten_sample(y_pred_round, m)

    print("Iteration:", i+1)
    print("Updated variables:")
    for key in X_test.keys():
        if X_test[key] != X_test_old[key]:
            print(f"{key}: {X_test[key]}")
    print()

    # Early exit to avoid unnecessary printing: done once the multiplier is
    # all zeros. Remove to see all iterations.
    if X_test['multiplier'] == [0]*m:
        break

# Expected product, 2*m bits, LSB first.
out_bin = [int(x) for x in np.binary_repr(out, width=2*m)][::-1]

# Decode the predicted product (LSB first) back to an integer.
prediction = int(''.join(str(int(x)) for x in X_test['sum_p'])[::-1], 2)
print(f"Expected Output: {out}")
print(f"Predicted Output: {prediction}")

# Check only after reporting, so a failure still shows both values.
assert out_bin == X_test['sum_p'], f"expected bits {out_bin}, got {X_test['sum_p']}"