In [3]:
from scipy.integrate import fixed_quad
from scipy.interpolate import interp1d
from concurrent.futures import ThreadPoolExecutor
from numba import njit
from typing import Callable
import numpy as np

In [25]:
class NonlocalSolverMomentum2ndOrder:
    def __init__(self, f: Callable, dL: Callable, t_span: list, y0: np.array, dy0: np.array, betas: list,
                 alpha: float, omega:float = 0, lambda_:float = 0, verbose: bool = True):
        
        np.random.seed(33)

        self.f = f
        self.t_span = t_span
        self.y0 = y0
        self.dy0 = dy0  
        self.alpha = alpha
        self.omega = omega
        self.t = np.arange(t_span[0], t_span[1], alpha)
        self.betas = betas
        self.lambda_ = lambda_
        self.dL = dL
        self.alpha_t = lambda t: 2 * np.sqrt((1-betas[1]**(t/alpha)) / alpha) / ( alpha * (1 - betas[0]**(t/alpha)))
        
        self.epsilon_t = lambda t: np.sqrt((1-betas[1]**(t/alpha)) / alpha) * 1e-8

        self.smoothing_factor = 0.5
        self.smoothing_factor_max = 0.9999
        self.increments = np.linspace(self.smoothing_factor, self.smoothing_factor_max, num=int(1e2))
        self.max_value_index = False

        self.max_iteration = int(6e2)
        self.global_error_tolerance = 1e-4
        self.verbose = verbose


    def __k__(self, i: int, s: float ) -> float:

        exp_term = np.exp(- s / self.alpha)
        beta_value = self.betas[i - 1]

        if np.all(exp_term != 0.0):
            if beta_value == 0.5:
                return (s / self.alpha) * exp_term
            elif beta_value > 0.5:
                sqrt_term = np.sqrt(2 * beta_value - 1)
                sinh_value = np.sinh(sqrt_term / self.alpha * s)
                if np.any(np.isinf(sinh_value)):
                    log_exp_term = np.log(exp_term)
                    sinh_arg = sqrt_term / self.alpha * s
                    log_sinh_term = sinh_arg - np.log(2)
                    log_values = log_exp_term + log_sinh_term
                    return (2 * (1-beta_value)) / sqrt_term * np.exp(log_values)
                else:
                    return (2 * (1-beta_value)) / sqrt_term * exp_term * np.sinh(sqrt_term / self.alpha * s)
            else:
                sqrt_term = np.sqrt(1 - 2 * beta_value)
                return (2 * (1-beta_value)) / sqrt_term * exp_term * np.sin(sqrt_term / self.alpha * s)
        else:
            return 0.0

    def __initial_solution__(self) -> np.array:
        return self.__solve_ode__(self.__rhs_system__)
    
    def __solve_ode__(self, rhs_ode: Callable) -> np.array:
        t_values = self.t
        y_values = np.zeros((len(t_values), 2)) 
        y_values[0, 0] = self.y0.item() if isinstance(self.y0, np.ndarray) else self.y0
        y_values[0, 1] = self.dy0.item() if isinstance(self.dy0, np.ndarray) else self.dy0   
        for i in range(1, len(t_values)):
            dy = rhs_ode(t_values[i - 1], y_values[i - 1])
            y_values[i, 0] = y_values[i - 1, 0] + self.alpha * dy[0]  
            y_values[i, 1] = y_values[i - 1, 1] + self.alpha * dy[1]
        return y_values
    
    def __rhs_system__(self, t, y_vec):
        y1, y2 = y_vec
        dy1 = y2
        dy2 = - (2 / self.alpha) * (y2 + (self.omega / self.alpha) * y1)
        return np.array([dy1, dy2])
    
    def __rhs_with_integral_part__(self, y: np.array) -> np.array:
        y_interpolated = interp1d(self.t, y[:, 0], kind='cubic', fill_value="extrapolate", assume_sorted=True)

        dy_dt = np.gradient(y_interpolated(self.t), self.t)  
        dy_interpolated = interp1d(self.t, dy_dt, kind='cubic', fill_value="extrapolate", assume_sorted=True)

        self.m = []
        self.v = []

        def integral(t):
            def integrand(i, tp):
                k_value = self.__k__(i, t - tp)
                df_value = self.dL(y_interpolated(tp))
                common_term = df_value + 0.5 * self.lambda_ * y_interpolated(tp)
                return k_value * common_term if i == 1 else k_value * (common_term ** 2)
            
            numerador_func = lambda tp: integrand(1, tp)
            denominador_func = lambda tp: integrand(2, tp)

            with ThreadPoolExecutor() as executor:
                future_numerator = executor.submit(fixed_quad, numerador_func, 0, t, n=int(1e3))
                future_denominator = executor.submit(fixed_quad, denominador_func, 0, t, n=int(1e3))

                value_numerator, _ = future_numerator.result()
                value_denominator, _ = future_denominator.result()

            v_value = value_denominator
            v_sqrt_value = np.sqrt(v_value) 
            m_value = value_numerator

            self.m.append((t, m_value))
            self.v.append((t, v_value))   

            return m_value / (v_sqrt_value + self.epsilon_t(t))

        def rhs(t, y):
            y_vec = np.array([y_interpolated(t), dy_interpolated(t)])
            return self.__rhs_system__(t, y_vec) - self.alpha_t(t) * integral(t)
                
        return self.__solve_ode__(rhs)
    
    @staticmethod
    @njit(parallel=True)
    def __global_error__(y_new: np.array, y_guess: np.array) -> float:
        diff = y_new - y_guess
        return np.sqrt(np.sum(diff ** 2))
    
    @staticmethod
    @njit(parallel=True)
    def __next_y__(smoothing_factor: float, y_current: np.array, y_guess: np.array) -> np.array:
        return (smoothing_factor * y_current) + ((1.0 - smoothing_factor) * y_guess)            
        
    def solve(self):
        self.iteration = 0

        y_current = self.__initial_solution__()
        y_guess = self.__rhs_with_integral_part__(y_current)
        current_error = self.__global_error__(y_current, y_guess)

        if self.verbose:
            print(f"Iteration {self.iteration} advanced. Current error: {current_error}.")

        last_error = current_error
        while current_error > self.global_error_tolerance:
            
            y_new = self.__next_y__(self.smoothing_factor, y_current, y_guess)
            y_guess = self.__rhs_with_integral_part__(y_new)
            current_error = self.__global_error__(y_new, y_guess)

            y_current = y_new
            self.iteration += 1

            if current_error > last_error:
                    if self.max_value_index:
                        print(f'Maximum value of the smoothing factor reached. The algorithm will stop without reaching the desired tolerance. The error is {current_error}.')
                        break

                    try:
                        next_factor = self.increments[np.searchsorted(self.increments, self.smoothing_factor, side='right')]
                    except IndexError:
                        next_factor = self.smoothing_factor_max
                        print(f'Smoothing factor is at maximum value.')
                        self.max_value_index = True

                    self.smoothing_factor = min(self.smoothing_factor_max, next_factor)
            last_error = current_error

            if self.verbose and self.iteration % 1 == 0:
                print(f"Iteration {self.iteration} advanced. Current error: {current_error}.")

            if self.iteration >= self.max_iteration:
                print(f"Maximum number of iterations reached. Current error: {current_error}.") 
                break
            
        print(f'Last iteration: {self.iteration}. Final error: {current_error}')

        self.y = y_guess[:, 0]  
        self.dy = y_guess[:, 1]  
        self.global_error = current_error

        return self.t, self.y, self.dy

In [16]:
from sklearn.model_selection import ParameterGrid
import numpy as np
import matplotlib.pyplot as plt

In [17]:
dL = lambda y: 2 * (y - 4)
f = lambda x, y: 0.0
t = [1e-12, 5]

In [22]:
param_grid = {'lr': [0.1], 'beta1': [0.9],'beta2': [0.99]}
n_learning_rates = len(param_grid['lr'])
param_list = list(ParameterGrid(param_grid))

In [26]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

fig_theta = make_subplots(rows=1, cols=n_learning_rates, subplot_titles=[f'Learning Rate = {lr}' for lr in param_grid['lr']])
fig_m = make_subplots(rows=1, cols=n_learning_rates, subplot_titles=[f'Learning Rate = {lr}' for lr in param_grid['lr']])
fig_v = make_subplots(rows=1, cols=n_learning_rates, subplot_titles=[f'Learning Rate = {lr}' for lr in param_grid['lr']])

fig_theta.update_layout(title_text='Theta values convergence trajectories for the second-order nonlocal continuous Adam')
fig_m.update_layout(title_text='Numerators (m) over Time for the second-order nonlocal continuous Adam')
fig_v.update_layout(title_text='Denominators (v) over Time for the second-order nonlocal continuous Adam')

for i, lr in enumerate(param_grid['lr']):

    filtered_params = [p for p in param_list if p['lr'] == lr]

    for params in filtered_params:
        print(f'\nNonlocal Continuous Adam Configuration: {params}')

        solver = NonlocalSolverMomentum2ndOrder(f=f, dL=dL, t_span=t, y0=np.array([1.0]),dy0=np.array([0.0]), alpha=params['lr'],
                                        betas=[params['beta1'], params['beta2']])
        t_values, y_values, dy_values = solver.solve()
        
        label = f"beta1={params['beta1']}, beta2={params['beta2']}"

        # Agregar datos a la figura de theta
        fig_theta.add_trace(go.Scatter(
            x=t_values/params['lr'],
            y=y_values,
            mode='lines',
            name=label,
            legendgroup=f'LR={lr}',
        ), row=1, col=i+1)
        
        # Obtener los valores de numerators y denominators
        numerators = solver.m
        denominators = solver.v

        # Agregar datos a la figura de numerators
        fig_m.add_trace(go.Scatter(
            x=[item[0] for item in numerators],
            y=[item[1]/params['lr'] for item in numerators],
            mode='markers',
            marker=dict(size=3),
            name=label,
            legendgroup=f'LR={lr}',
        ), row=1, col=i+1)
        
        # Agregar datos a la figura de denominators
        fig_v.add_trace(go.Scatter(
            x=[item[0] for item in denominators],
            y=[item[1]/params['lr'] for item in denominators],
            mode='markers',
            marker=dict(size=3),
            name=label,
            legendgroup=f'LR={lr}',
        ), row=1, col=i+1)

# Actualización de los ejes
fig_theta.update_xaxes(title_text="Time/lr")
fig_theta.update_yaxes(title_text="Theta value")

fig_m.update_xaxes(title_text="Time/lr")
fig_m.update_yaxes(title_text="Numerator value (m)")

fig_v.update_xaxes(title_text="Time/lr")
fig_v.update_yaxes(title_text="Denominator value (v)")

# Mostrar las figuras
fig_theta.show()
fig_m.show()
fig_v.show()


Nonlocal Continuous Adam Configuration: {'beta1': 0.9, 'beta2': 0.99, 'lr': 0.01}
Iteration 0 advanced. Current error: 18160.82872453616.
Iteration 1 advanced. Current error: 1305002.4430603937.
Iteration 2 advanced. Current error: 976593.2373752543.
Iteration 3 advanced. Current error: 704491.4077236553.
Iteration 4 advanced. Current error: 388130.2466014872.
Iteration 5 advanced. Current error: 970427.8862883415.
Iteration 6 advanced. Current error: 573167.5244853057.
Iteration 7 advanced. Current error: 509143.3831725429.
Iteration 8 advanced. Current error: 719011.8313175759.
Iteration 9 advanced. Current error: 614894.9479529598.
Iteration 10 advanced. Current error: 466782.38105830515.
Iteration 11 advanced. Current error: 531858.7357102872.
Iteration 12 advanced. Current error: 649250.5172547246.
Iteration 13 advanced. Current error: 627647.3555535309.
Iteration 14 advanced. Current error: 514738.2976168963.
Iteration 15 advanced. Current error: 423681.95949173975.
Iteration 16

KeyboardInterrupt: 