In [4]:
from Solver.NonlocalSolver import NonlocalSolverMomentum
from sklearn.model_selection import ParameterGrid
import numpy as np
import matplotlib.pyplot as plt

In [5]:
def dL(y):
    """Return 2 * (y - 4), the derivative of (y - 4)**2, which vanishes at y = 4.

    Passed to the solver as ``dL`` -- presumably the loss gradient driving the
    optimizer (TODO confirm against NonlocalSolverMomentum).
    """
    return 2 * (y - 4)


def f(x, y):
    """Term passed to the solver as ``f``; identically zero in this experiment."""
    return 0.0


# Integration interval handed to the solver as ``t_span``. It starts just above
# zero rather than at 0 -- presumably to avoid a singular kernel at t = 0 (TODO confirm).
t = [1e-12, 10]

In [6]:
# Hyper-parameter sweep: every combination of learning rate and the two Adam betas.
param_grid = {
    'lr': [0.1, 0.01],
    'beta1': [0.9, 0.0],
    'beta2': [0.99, 0.999],
}
# Expand the grid into a flat list of {'lr', 'beta1', 'beta2'} dicts.
param_list = list(ParameterGrid(param_grid))
# One subplot column per learning rate in the figures below.
n_learning_rates = len(param_grid['lr'])

In [7]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# All three figures share one layout: one subplot column per learning rate.
subplot_titles = [f'Learning Rate = {lr}' for lr in param_grid['lr']]
fig_theta = make_subplots(rows=1, cols=n_learning_rates, subplot_titles=subplot_titles)
fig_m = make_subplots(rows=1, cols=n_learning_rates, subplot_titles=subplot_titles)
fig_v = make_subplots(rows=1, cols=n_learning_rates, subplot_titles=subplot_titles)

fig_theta.update_layout(title_text='Theta values convergence trajectories for the first-order nonlocal continuous Adam')
fig_m.update_layout(title_text='Numerators (m) over Time for the first-order nonlocal continuous Adam')
fig_v.update_layout(title_text='Denominators (v) over Time for the first-order nonlocal continuous Adam')


def _add_history_trace(fig, history, lr, col, name, group):
    """Plot a solver history as small markers in column ``col`` of ``fig``.

    ``history`` is read as a sequence of ``(time, value)`` pairs via ``item[0]`` /
    ``item[1]`` -- presumably how the solver records ``m`` and ``v`` (TODO confirm).
    Time is divided by ``lr`` so the x-axis matches its "Time/lr" label and the
    theta figure. The value is plotted unscaled to match the "... value (m)" /
    "... value (v)" y-axis labels.
    """
    fig.add_trace(go.Scatter(
        x=[item[0] / lr for item in history],
        y=[item[1] for item in history],
        mode='markers',
        marker=dict(size=3),
        name=name,
        legendgroup=group,
    ), row=1, col=col)


for i, lr in enumerate(param_grid['lr']):
    col = i + 1
    group = f'LR={lr}'

    filtered_params = [p for p in param_list if p['lr'] == lr]

    for params in filtered_params:
        print(f'\nNonlocal Continuous Adam Configuration: {params}')

        solver = NonlocalSolverMomentum(f=f, dL=dL, t_span=t, y0=np.array([1.0]), alpha=params['lr'],
                                        betas=[params['beta1'], params['beta2']])
        t_values, y_values = solver.solve()

        # The same (beta1, beta2) pair appears once per learning-rate subplot, so the
        # learning rate must be part of the name or the legend shows duplicate entries.
        label = f"lr={lr}, beta1={params['beta1']}, beta2={params['beta2']}"

        # Parameter trajectory, with time rescaled by the learning rate.
        fig_theta.add_trace(go.Scatter(
            x=t_values / params['lr'],
            y=y_values,
            mode='lines',
            name=label,
            legendgroup=group,
        ), row=1, col=col)

        # Histories of the Adam numerator (m) and denominator (v) recorded by the solver.
        numerators = solver.m
        denominators = solver.v
        _add_history_trace(fig_m, numerators, params['lr'], col, label, group)
        _add_history_trace(fig_v, denominators, params['lr'], col, label, group)

# Axis labels (applied to every subplot of each figure).
fig_theta.update_xaxes(title_text="Time/lr")
fig_theta.update_yaxes(title_text="Theta value")

fig_m.update_xaxes(title_text="Time/lr")
fig_m.update_yaxes(title_text="Numerator value (m)")

fig_v.update_xaxes(title_text="Time/lr")
fig_v.update_yaxes(title_text="Denominator value (v)")

# Render the figures.
fig_theta.show()
fig_m.show()
fig_v.show()


Nonlocal Continuous Adam Configuration: {'beta1': 0.9, 'beta2': 0.99, 'lr': 0.1}
Iteration 0 advanced. Current error: 56.709267707766905.
Iteration 1 advanced. Current error: 12.433545737230553.
Iteration 2 advanced. Current error: 13.721282935189565.
Iteration 3 advanced. Current error: 6.92683854691461.
Iteration 4 advanced. Current error: 2.8806798548280748.
Iteration 5 advanced. Current error: 2.757854362604732.
Iteration 6 advanced. Current error: 1.6699844646602793.
Iteration 7 advanced. Current error: 0.7093479020625347.
Iteration 8 advanced. Current error: 0.35304424437066706.
Iteration 9 advanced. Current error: 0.27734840429728086.
Iteration 10 advanced. Current error: 0.19552991544021417.
Iteration 11 advanced. Current error: 0.11450180111199434.
Iteration 12 advanced. Current error: 0.05779594186344797.
Iteration 13 advanced. Current error: 0.026783216926622652.
Iteration 14 advanced. Current error: 0.013566291590842895.
Iteration 15 advanced. Current error: 0.008704794130