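# Adam Optimizer Simulations

Trajectories of $\theta_k$ and of Adam's first ($m_k$) and second ($v_k$) moment estimates on the gradient $L'(\theta) = 2(\theta - 4)$, starting from $\theta_0 = 1$, for every combination of learning rate, $\beta_1$ and $\beta_2$ in the grid below.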

In [1]:
from Solver.Solver import AdamMomentum
from sklearn.model_selection import ParameterGrid
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Hyper-parameter grid: every (lr, beta1, beta2) combination is simulated once.
param_grid = {
    'lr': [0.1, 0.01],
    'beta1': [0.9, 0.0],
    'beta2': [0.99, 0.999],
}

# Expand the grid into a flat list of configuration dicts.
param_list = list(ParameterGrid(param_grid))

# One subplot column per learning rate in the plotting cell.
n_learning_rates = len(param_grid['lr'])

def dL(y):
    """Gradient of the quadratic loss L(y) = (y - 4)**2, which is minimised at y = 4.

    Works element-wise on scalars and on numpy arrays alike.
    """
    return 2 * (y - 4)

In [6]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Iteration budget for each learning rate: the smaller step needs more iterations
# to approach the minimum. Every lr in param_grid must have an entry here.
epochs_by_lr = {0.1: 100, 0.01: 1000}

# Starting point shared by every run.
theta_initial = 1.0

# Create one figure per recorded quantity, each with one subplot column per learning rate.
subplot_titles = [f'Learning Rate = {lr}' for lr in param_grid['lr']]
fig_theta = make_subplots(rows=1, cols=n_learning_rates, subplot_titles=subplot_titles)
fig_m = make_subplots(rows=1, cols=n_learning_rates, subplot_titles=subplot_titles)
fig_v = make_subplots(rows=1, cols=n_learning_rates, subplot_titles=subplot_titles)

for i, lr in enumerate(param_grid['lr']):
    # Fail loudly rather than raising a NameError (first lr) or silently reusing the
    # previous iteration's epoch count when a new learning rate is added to the grid.
    if lr not in epochs_by_lr:
        raise ValueError(f'No epoch budget configured for lr={lr}; add it to epochs_by_lr.')
    epochs = epochs_by_lr[lr]

    filtered_params = [p for p in param_list if p['lr'] == lr]

    for params in filtered_params:
        print(f'\nAdam Configuration: {params}')
        solver = AdamMomentum(dL=dL, lr=lr, beta1=params['beta1'], beta2=params['beta2'], epochs=epochs)
        solver.solve(theta_initial=theta_initial)

        # Include lr so legend entries from the two subplot columns are distinguishable.
        label = f"lr={lr}, beta1={params['beta1']}, beta2={params['beta2']}"

        # (target figure, recorded series, trace styling) for each quantity of interest.
        traces = (
            (fig_theta, solver.theta_result, dict(mode='lines')),
            (fig_m, solver.m_result, dict(mode='markers', marker=dict(size=3))),
            (fig_v, solver.v_result, dict(mode='markers', marker=dict(size=3))),
        )
        for fig, series, style in traces:
            # Derive x from the series itself so x and y always have the same length,
            # whether or not the solver also records the initial iterate.
            fig.add_trace(go.Scatter(
                x=list(range(len(series))),
                y=series,
                name=label,
                legendgroup=f'LR={lr}',
                **style,
            ), row=1, col=i+1)

# Titles, axis labels and size (pixels) for each figure.
# The y tick format is None for m_k and v_k: with beta2 close to 1 the second moment
# may be far below 0.1, and a fixed '.1f' format would collapse its tick labels to "0.0".
layout_specs = (
    (fig_theta, 'Theta values convergence trajectories for the Adam Optimizer', 'Theta_k', '.1f'),
    (fig_m, 'First moment (m) convergence trajectories for the Adam Optimizer', 'm_k', None),
    (fig_v, 'Second moment (v) convergence trajectories for the Adam Optimizer', 'v_k', None),
)
for fig, title, y_title, tickformat in layout_specs:
    fig.update_layout(title_text=title, showlegend=True, width=1500, height=600)
    fig.update_xaxes(title_text="k")
    y_axis_kwargs = {'title_text': y_title}
    if tickformat is not None:
        y_axis_kwargs['tickformat'] = tickformat
    fig.update_yaxes(**y_axis_kwargs)

# Show the figures.
fig_theta.show()
fig_m.show()
fig_v.show()


Adam Configuration: {'beta1': 0.9, 'beta2': 0.99, 'lr': 0.1}
Epoch: 50, Error: 0.004799330898878296.
Epoch: 100, Error: 0.0021703457886625976.
Last epoch: 101, Error: 0.0021703457886625976.

Adam Configuration: {'beta1': 0.9, 'beta2': 0.999, 'lr': 0.1}
Epoch: 50, Error: 0.006239302049674045.
Epoch: 100, Error: 0.0011988458789149448.
Last epoch: 101, Error: 0.0011988458789149448.

Adam Configuration: {'beta1': 0.0, 'beta2': 0.99, 'lr': 0.1}
Epoch: 50, Error: 0.016781864495888144.
Epoch: 100, Error: 0.00020620802432569363.
Last epoch: 101, Error: 0.00020620802432569363.

Adam Configuration: {'beta1': 0.0, 'beta2': 0.999, 'lr': 0.1}
Epoch: 50, Error: 0.01684481567661056.
Epoch: 100, Error: 0.0003506600043072794.
Last epoch: 101, Error: 0.0003506600043072794.

Adam Configuration: {'beta1': 0.9, 'beta2': 0.99, 'lr': 0.01}
Epoch: 50, Error: 0.009498909793875132.
Epoch: 100, Error: 0.008722426780408243.
Epoch: 150, Error: 0.007945940231400606.
Epoch: 200, Error: 0.0071417128212911685.
Epoch: