In [1]:
import numpy as np
import jax.numpy as jnp
import jax
from jax.experimental import sparse
from polynomial_generator import eval_polynomial, generate_random_polynomial, generate_training_set, eval_polynomial_vectorized

(Array([[-1.        , -1.        , -1.        ],
       [-0.97979796, -0.97979796, -0.97979796],
       [-0.9595959 , -0.9595959 , -0.9595959 ],
       [-0.939394  , -0.939394  , -0.939394  ],
       [-0.91919196, -0.91919196, -0.91919196],
       [-0.8989899 , -0.8989899 , -0.8989899 ],
       [-0.8787879 , -0.8787879 , -0.8787879 ],
       [-0.85858583, -0.85858583, -0.85858583],
       [-0.8383838 , -0.8383838 , -0.8383838 ],
       [-0.81818175, -0.81818175, -0.81818175],
       [-0.79797983, -0.79797983, -0.79797983],
       [-0.7777778 , -0.7777778 , -0.7777778 ],
       [-0.75757575, -0.75757575, -0.75757575],
       [-0.7373737 , -0.7373737 , -0.7373737 ],
       [-0.71717167, -0.71717167, -0.71717167],
       [-0.69696975, -0.69696975, -0.69696975],
       [-0.6767677 , -0.6767677 , -0.6767677 ],
       [-0.65656567, -0.65656567, -0.65656567],
       [-0.6363636 , -0.6363636 , -0.6363636 ],
       [-0.6161616 , -0.6161616 , -0.6161616 ],
       [-0.59595966, -0.59595966, -0.59

In [2]:
max_coeff = 3
polynomial = generate_random_polynomial(3, [2,3,4], max_coeff)
print(polynomial)
print(polynomial.data)
print(polynomial.indices)
data = jnp.array([1,2,3])
print(data)

BCOO(int32[2, 3, 4], nse=3)
[0 2 1]
[[0 3 1]
 [2 1 2]
 [0 2 4]]
[1 2 3]


In [3]:
multy_data = jnp.array([[1,2,3], [1,2,3]])
eval_multy_data = jax.vmap(lambda x: eval_polynomial(polynomial, x), in_axes=0)(multy_data)
print(eval_multy_data)
print(eval_polynomial(polynomial, data))

[360 360]
360


In [4]:
test = jnp.linspace(np.array([0,0,0]), np.array([1,1,1]), 10)
print(test)

[[0.         0.         0.        ]
 [0.11111111 0.11111111 0.11111111]
 [0.22222222 0.22222222 0.22222222]
 [0.33333334 0.33333334 0.33333334]
 [0.44444445 0.44444445 0.44444445]
 [0.5555556  0.5555556  0.5555556 ]
 [0.6666667  0.6666667  0.6666667 ]
 [0.7777778  0.7777778  0.7777778 ]
 [0.8888889  0.8888889  0.8888889 ]
 [1.         1.         1.        ]]


In [5]:
start_samples = np.array([-1, -1, -1])
end_samples = np.array([1, 1, 1])
polynomial = generate_random_polynomial(1, [2, 3, 3], 10)
x, y_pure, y_noisy = generate_training_set(polynomial, 10, start_samples, end_samples, 0.25, 42)
print(x)
print(y_pure)
print(y_noisy)

[[-1.         -1.         -1.        ]
 [-0.7777778  -0.7777778  -0.7777778 ]
 [-0.5555556  -0.5555556  -0.5555556 ]
 [-0.33333328 -0.33333328 -0.33333328]
 [-0.11111113 -0.11111113 -0.11111113]
 [ 0.11111116  0.11111116  0.11111116]
 [ 0.33333337  0.33333337  0.33333337]
 [ 0.5555556   0.5555556   0.5555556 ]
 [ 0.7777778   0.7777778   0.7777778 ]
 [ 1.          1.          1.        ]]
[8.0000000e+00 1.0713571e+00 7.2595574e-02 1.2193248e-03 1.8584487e-07
 1.8584531e-07 1.2193273e-03 7.2595574e-02 1.0713571e+00 8.0000000e+00]
[8.7380085e+00 9.4797015e-01 5.6895085e-02 1.5875878e-03 2.3244849e-07
 1.4538651e-07 1.0978519e-03 6.0490094e-02 1.1615905e+00 8.7964497e+00]


In [19]:
def loss(polynomial_data, polynomial_indices, polynomial_shape, x, y):
    polynomial = sparse.BCOO((polynomial_data, polynomial_indices), shape= polynomial_shape)
    y_pred = eval_polynomial_vectorized(polynomial, x)
    loss = jnp.log(jnp.sum(jnp.square(y_pred - y)))
    return loss
grad_loss = jax.grad(loss, allow_int=True)

# def compute_loss_and_grad(param_w, data, start, stop, num_points=100):
#     param_w_values = jnp.linspace(start, stop, num_points)
#     loss_values = jnp.array([loss(w, data) for w in param_w_values])
#     grad_values = jnp.array([grad_loss(w, data) for w in param_w_values])
#     return param_w_values, loss_values, grad_values

print(polynomial.shape)
print(polynomial.data)
test = loss(polynomial.data, polynomial.indices, polynomial.shape, x, y_noisy)
print(test)
grad_test = grad_loss(polynomial.data, polynomial.indices, polynomial.shape, x, y_noisy)
print(grad_test)
print(grad_test[0])

(2, 3, 3)
[9 8 4]
1.62921
[(b'',) (b'',) (b'',)]
(b'',)


In [18]:
# %% Run stochastic gradient descent
num_epochs = 50
learning_rate = 0.01
param_w = jnp.array([1.0, 1.0, 1.0])
n_points = 10
polynomial = generate_random_polynomial(3, [2, 3, 3], 10)
x, y_pure, y_noisy = generate_training_set(polynomial, n_points, start_samples, end_samples, 0.25, 42)
num_points_per_batch = n_points // 5
print("\n===== Running Stochastic Gradient Descent =====")
for epoch in range(num_epochs):
    # Get points for the current batch
    for i in range(0, n_points, num_points_per_batch):
        batch_x = x[i:i + num_points_per_batch]
        batch_y = y_noisy[i:i + num_points_per_batch]
        grad = grad_loss(polynomial.data, polynomial.indices, polynomial.shape, batch_x, batch_y)
        param_w = polynomial.data - learning_rate * grad

    print(f"Epoch {epoch}: param_w={param_w}, grad={grad}, loss={loss(param_w, train_ds)}")


===== Running Stochastic Gradient Descent =====


UFuncTypeError: ufunc 'multiply' did not contain a loop with signature matching types (dtype('float64'), dtype([('float0', 'V')])) -> None