# Noisy gradient descent experiments
This notebook includes experiments for noisy gradient descent with partial and full covariance.

In [4]:
import jax.numpy as jnp
from jax import random
from noisy_gradient_descent import *

This test case approximates $y = x$ for $x \in [0,1]$

In [8]:
STEP_SIZE = 0.001
FINAL_TIME = 5.5
LEARNING_RATE = 0.1
SDE_SOLVER_ITERATIONS = 10


x = jnp.linspace(0.0, 2.0, 100).reshape((100, 1))
y = x

key = random.PRNGKey(2)
sizes = (1, 1)
parital_covariance_parameters = initialize_network_parameters(sizes, key)
full_covariance_parameters = initialize_network_parameters(sizes, key)
time = jnp.arange(0.0, FINAL_TIME, STEP_SIZE)
weight_full_covariance = []
weight_parital_covariance = []
# weight_gradient_descent = []
current_time = 0.0

while current_time < FINAL_TIME - 0.0001:
    key, subkey = random.split(key)
    for _ in range(SDE_SOLVER_ITERATIONS):
        key, subkey = random.split(key)
        parital_covariance_parameters = full_covariance_update(
            subkey, parital_covariance_parameters, sizes, x, y, STEP_SIZE, LEARNING_RATE)
        weight_full_covariance.append(parital_covariance_parameters[0][0][0])

        key, subkey = random.split(key)
        full_covariance_parameters = partial_covariance_update(full_covariance_parameters, x, y,
                                        STEP_SIZE, LEARNING_RATE, key)
        weight_parital_covariance.append(full_covariance_parameters[0][0][0])

    current_time += SDE_SOLVER_ITERATIONS * STEP_SIZE
    if current_time >= FINAL_TIME - 0.0001:
        break
    print(current_time)

plt.figure()
plt.plot(time, jnp.array(weight_full_covariance).flatten(), "r-.", label=f"Full covariance weight")
plt.plot(time, jnp.array(weight_parital_covariance).flatten(), label=f"Partial covariance weight")
# plt.plot(t, jnp.array(b_sde).flatten(), label=f"sde b")
plt.legend()
plt.show()

0.01
0.02
0.03
0.04
0.05
0.060000000000000005
0.07
0.08
0.09
0.09999999999999999
0.10999999999999999
0.11999999999999998
0.12999999999999998
0.13999999999999999
0.15
0.16
0.17
0.18000000000000002
0.19000000000000003
0.20000000000000004
0.21000000000000005
0.22000000000000006
0.23000000000000007
0.24000000000000007
0.25000000000000006
0.26000000000000006
0.2700000000000001
0.2800000000000001
0.2900000000000001
0.3000000000000001
0.3100000000000001
0.3200000000000001
0.3300000000000001
0.34000000000000014
0.35000000000000014
0.36000000000000015
0.37000000000000016
0.38000000000000017
0.3900000000000002
0.4000000000000002
0.4100000000000002
0.4200000000000002
0.4300000000000002
0.4400000000000002
0.45000000000000023
0.46000000000000024
0.47000000000000025
0.48000000000000026
0.49000000000000027
0.5000000000000002
0.5100000000000002
0.5200000000000002
0.5300000000000002
0.5400000000000003
0.5500000000000003
0.5600000000000003
0.5700000000000003
0.5800000000000003
0.5900000000000003
0.60000