### Perturbation transfer learning for Duffing equation
$$
\begin{align*}
\frac{dx}{dt} - y &= 0,\quad \frac{dy}{dt} + \alpha y + \delta x + \beta x^3 &= \gamma \cos(\omega t)
\end{align*}
$$
with initial conditions $x(0)=x_0=1$ and $y(0)=y_0=0.5$, parameters $\delta=0.1$, $\beta=0.1$, $\gamma=1$, $\omega=1$ (the experiments below also vary $\beta$), and stiffness parameter $\alpha>1$.

Substituting the $\beta$-perturbation expansions $\tilde{x}=\sum_{i=0}^{\infty}\beta^iX_i$ and $\tilde{y}=\sum_{i=0}^{\infty}\beta^iY_i$ and collecting the terms of each power $\beta^i$, $i=0, 1, 2, \dots$, gives the following systems (in practice the series is truncated at $i=p$):
$$
\begin{align*}
\frac{dX_i}{dt} - Y_i &= 0, \quad \frac{dY_i}{dt} + \delta X_i + \alpha Y_i  &= f_i \\
\end{align*} \\
\begin{align*}
f_i = \begin{cases}
\gamma \cos(\omega t) & \text{for}\ i=0 \\
-\sum_{\substack{a+b+c=i-1 \\ a\ge b\ge c\ge 0}}\phi(a,b,c)\,X_aX_bX_c & \text{for}\ i\ge 1, \quad \text{with} \ \phi(a,b,c)=\begin{cases}
6 & \text{if}\ a>b>c \\
1 & \text{if}\ a=b=c \\
3 & \text{otherwise}
\end{cases}
\end{cases}
\end{align*}
$$

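So each order $i$ is a linear system $\frac{d\mathbf{u}_i}{dt} + A\,\mathbf{u}_i = \mathbf{f}$ for $\mathbf{u}_i = \begin{bmatrix} X_i \\ Y_i \end{bmatrix}$, with $A = \begin{bmatrix}
  0 & -1 \\
  \delta & \alpha
\end{bmatrix}$ and $\mathbf{f} = \begin{bmatrix}
  0 \\
  f_i
\end{bmatrix}$.

Here $\phi(a,b,c)$ counts the distinct orderings of $(a,b,c)$, so every unordered triple appears exactly once in the sum.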


In [None]:
# IPython autoreload: mode 2 reloads imported modules (e.g. the local `src`
# package) before each cell runs, so edits to them are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
# imports
import torch
import matplotlib.pyplot as plt
import numpy as np
import sys
import json
import os
from scipy.integrate import solve_ivp
from tqdm.auto import tqdm
from collections import defaultdict
import time

# Make the directory two levels above the working directory importable;
# the `src.*` imports that follow are resolved from there.
from pathlib import Path
current_path = Path.cwd()
parent_dir = current_path.parent.parent
sys.path.append(os.fspath(parent_dir))

# Import necessary modules
from src.utils_plot import plot_loss_and_all_solution
from src.load_save import load_run_history
from src.transfer_learning import compute_H_and_dH_dt, analytically_compute_weights
from src.nonlinear_transfer_learning import solve_perturbation_TL

from src.utils_plot import plot_transfer_learned_and_analytical

# Keep autograd anomaly detection off: it is a debugging aid that slows down
# every backward pass.
torch.autograd.set_detect_anomaly(False)
# NOTE: the former `torch.autograd.profiler.profile(False)` and
# `torch.autograd.profiler.emit_nvtx(False)` calls only built (disabled)
# context-manager objects and discarded them without entering a `with` block,
# so they had no effect. Profiling is off unless such a manager is entered.

In [None]:
def check_versions_and_device():
    """Pick the torch device for this notebook and print version / GPU info.

    Returns:
        torch.device: ``cuda`` when a CUDA device is available, else ``cpu``.
    """
    # Local imports keep this cell self-contained.
    import shutil
    import subprocess

    # set the device to the GPU if it is available, otherwise use the CPU
    current_dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"torch {torch.__version__}")
    # Query the driver only when a GPU is usable and the tool exists. Plain
    # subprocess (instead of the IPython-only `!nvidia-smi`) also works in scripts.
    if current_dev.type == 'cuda' and shutil.which('nvidia-smi'):
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=False)
        print(result.stdout or result.stderr)
    return current_dev

# set a global device variable to use in code
dev = check_versions_and_device()
print(dev)

### Load the pretrained model and its training history

In [None]:
file_name = "base_model_1311309"
equation_name = "Duffing"

# Restore the pre-trained model together with the settings it was trained with.
(trained_model, x_range, iterations, hid_lay, num_equations, num_heads,
 loss_hist, alpha_list, A_list, IC_list, force_list) = load_run_history(
    equation_name, file_name, dev, prev=False)

# create the A matrix associated with the ODE
def get_A(alpha, delta=0.1, device=None):
    """Return the matrix ``A`` of the linear system ``du/dt + A u = f``.

    For ``x' = y`` and ``y' = -delta*x - alpha*y + f``, this gives
    ``A = [[0, -1], [delta, alpha]]``.

    Args:
        alpha: stiffness coefficient multiplying ``y``.
        delta: coefficient multiplying ``x``. The default 0.1 matches the
            notebook's Duffing equation.
        device: torch device of the result. Defaults to the global ``dev``.

    Returns:
        A float64 tensor of shape (2, 2).
    """
    if device is None:
        device = dev
    # Build directly in float64. Creating in the default float32 and then
    # calling .double() rounds 0.1 to 0.10000000149..., which disagrees with
    # the float64 numerical reference solutions.
    return torch.tensor([[0., -1.], [delta, alpha]], dtype=torch.float64, device=device)

# One forcing f(t) = [0, cos(t)] per head (this replaces the loaded force_list).
# A python scalar t gives a length-2 numpy array, which is how the scipy
# right-hand side below calls it. Otherwise t is treated as a tensor of times
# (assumed 1-D) and an (N, 2) float64 tensor is returned.
force_list = [
    (lambda t: np.array([0, np.cos(t)]).T
     if isinstance(t, (float, int))
     else torch.cat([torch.zeros(len(t), device=dev).unsqueeze(1),
                     torch.cos(t).unsqueeze(1)], dim=1).double())
    for _head in range(num_heads)
]

# function to numerically compute the solution to any set of two coupled, linear first-order ODES
def double_coupled_equation(t, y, A, force):
    """Right-hand side of ``du/dt = f(t) - A u`` for a two-component state.

    Args:
        t: current time, forwarded to ``force``.
        y: current state ``[u0, u1]``.
        A: 2x2 matrix indexable as ``A[row][col]``.
        force: callable returning ``[f0, f1]`` at time ``t``.

    Returns:
        numpy array ``[du0/dt, du1/dt]``.
    """
    f_t = force(t)
    du0 = f_t[0] - A[0][1] * y[1] - A[0][0] * y[0]
    du1 = f_t[1] - A[1][0] * y[0] - A[1][1] * y[1]
    return np.array([du0, du1])

def numerical_sol_fct(x, v, A, force):
    """Radau solution of ``du/dt = f(t) - A u`` over ``x_range``.

    ``v`` (initial state) and ``x`` (evaluation times) are squeezed before
    being passed to ``solve_ivp``. Returns the ``.y`` array: one row per
    state component, one column per evaluation time.
    """
    result = solve_ivp(double_coupled_equation, [x_range[0], x_range[1]],
                       v.squeeze(), args=(A, force), t_eval=x.squeeze(),
                       method="Radau")
    return result.y

plot_loss_and_all_solution(x_range=x_range, true_functs=numerical_sol_fct,
                           trained_model=trained_model, IC_list=IC_list, A_list=A_list,
                           force=force_list, train_losses=loss_hist, device=dev)

## Perturbation Transfer Learning

### Extract H 

In [None]:
# Forward pass of the pre-trained model (presumably over `size` sample times in
# x_range) returning H, H_0, dH/dt and the time grid, all promoted to float64.
size = 512
H, H_0, dH_dt_new, t_eval = (
    tensor.double()
    for tensor in compute_H_and_dH_dt(x_range[0], x_range[1], trained_model,
                                      num_equations, hid_lay, size, dev)
)

### Transfer learning on the linear system

In [None]:
# Transfer on the linear system: choose alpha, IC and force to transfer on
alpha_transfer = 20
A_transfer = get_A(alpha=alpha_transfer)
IC_transfer, force_transfer = IC_list[0], force_list[0]

# Output weights computed by the analytical solve from H, H_0 and dH/dt
M_inv, W_out, force_terms, total_time = analytically_compute_weights(
    dH_dt_new, H, H_0, t_eval, IC_transfer, A_transfer, force_transfer)

# function to numerically compute the solution to any set of two coupled, linear first-order ODES
# (same definition as in the model-loading cell)
def double_coupled_equation(t, y, A, force):
    """Right-hand side of ``du/dt = f(t) - A u`` for a two-component state.

    Args:
        t: current time, forwarded to ``force``.
        y: current state ``[u0, u1]``.
        A: 2x2 matrix indexable as ``A[row][col]``.
        force: callable returning ``[f0, f1]`` at time ``t``.

    Returns:
        numpy array ``[du0/dt, du1/dt]``.
    """
    f_t = force(t)
    du0 = f_t[0] - A[0][1] * y[1] - A[0][0] * y[0]
    du1 = f_t[1] - A[1][0] * y[0] - A[1][1] * y[1]
    return np.array([du0, du1])

def numerical_sol_fct(x, v, A, force):
    """Radau solution of ``du/dt = f(t) - A u`` over ``x_range``.

    ``v`` (initial state) and ``x`` (evaluation times) are squeezed before
    being passed to ``solve_ivp``. Returns the ``.y`` array.
    """
    result = solve_ivp(double_coupled_equation, [x_range[0], x_range[1]],
                       v.squeeze(), args=(A, force), t_eval=x.squeeze(),
                       method="Radau")
    return result.y

plot_transfer_learned_and_analytical(H, W_out, t_eval, IC_transfer, A_transfer,
                                     force_transfer, num_equations, numerical_sol_fct)

# Perturbation transfer learning

### Choose the transfer equation parameters

-   stiffness parameter $\alpha$
-   nonlinear parameter $\beta$
-   initial conditions $IC$
-   force function $f$

In [None]:
alpha_transfer = 40
A_transfer = get_A(alpha=alpha_transfer)
beta = 0.5
# Initial state [x0, y0], built directly in float64
IC_transfer = torch.tensor([[1.], [0.5]], dtype=torch.float64, device=dev)


def force_transfer(t):
    """Forcing f(t) = [0, cos(t)] of the transfer problem.

    A python scalar ``t`` gives a length-2 numpy array (the form used inside
    scipy right-hand sides). Otherwise ``t`` is treated as a tensor of times
    (assumed 1-D, length N) and an (N, 2) float64 tensor is returned.
    """
    if isinstance(t, (float, int)):
        return np.array([0, np.cos(t)]).T
    return torch.cat([torch.zeros(len(t), device=dev).unsqueeze(1),
                      torch.cos(t).unsqueeze(1)], dim=1).double()

### Compute numerical solution

In [None]:
domain = (x_range[0], x_range[1])  # integration interval for the numerical reference solvers

def numerical_non_linear_solution(alpha, beta, u0, domain, t_eval, method="Radau",
                                  *, delta=0.1, gamma=1.0, omega=1.0):
    """Numerically solve the forced Duffing system with scipy's ``solve_ivp``.

        x' = y
        y' = -alpha*y - delta*x - beta*x**3 + gamma*cos(omega*t)

    Args:
        alpha: coefficient of ``y`` (stiffness parameter).
        beta: coefficient of the cubic term.
        u0: initial state ``[x0, y0]``.
        domain: integration interval ``(t_start, t_end)``.
        t_eval: times at which the solution is stored.
        method: ``solve_ivp`` integrator name, e.g. "Radau" or "RK45".
        delta, gamma, omega: linear coefficient, forcing amplitude and forcing
            frequency. The defaults reproduce the notebook's equation.

    Returns:
        The ``solve_ivp`` result object (solution array in ``.y``).
    """
    def F(t, y):
        return [y[1],
                -alpha*y[1] - delta*y[0] - beta*y[0]**3 + gamma*np.cos(omega*t)]
    return solve_ivp(F, domain, u0, t_eval=t_eval, method=method)

# Reference solution of the full non-linear system on the transfer time grid
non_linear_num_sol = numerical_non_linear_solution(
    alpha_transfer,
    beta,
    IC_transfer.detach().cpu().squeeze(),
    domain,
    t_eval.detach().cpu().numpy().squeeze(),
)

### Components for perturbation transfer learning

-   $\phi$ function
-   numerical solution of each system

In [None]:
# functions to calculate the force function of each system p
def force_func_index(n):
    """List the index triples of the cubic term for a + b + c == n.

    Per the markdown derivation, these are the X_a X_b X_c products in the
    forcing f_i with i = n + 1. Each triple is enumerated once with
    a >= b >= c >= 0, together with phi, its number of distinct orderings:
    1 if a == b == c, 6 if all three differ, 3 otherwise.

    Returns:
        list of ``[a, b, c, coeff]``, ordered by increasing a, then b.
        Empty for negative n.
    """
    solution_index = []  # entries: [ind1, ind2, ind3, coeff]
    for a in range(n + 1):
        for b in range(a + 1):
            # c is fixed by a + b + c == n; keep it only when b >= c >= 0.
            c = n - a - b
            if not 0 <= c <= b:
                continue
            if a == b == c:
                coeff = 1
            elif a != b and b != c:
                coeff = 6
            else:
                coeff = 3
            solution_index.append([a, b, c, coeff])
    return solution_index

def force_function_PINNS(i, alpha, list_force_index, PINNS_list):
    """Forcing of perturbation system ``i`` built from the PINN solutions.

    Sums ``coeff * (X_a * X_b * X_c)`` over the ``[a, b, c, coeff]`` entries
    of ``list_force_index[i - 1]``, where ``X_k = PINNS_list[k][:, 0, :]``.
    Returns the columns ``[0, -sum]`` stacked horizontally. ``alpha`` is
    not used.
    """
    cubic = sum(
        coeff * (PINNS_list[a][:, 0, :] * PINNS_list[b][:, 0, :] * PINNS_list[c][:, 0, :])
        for a, b, c, coeff in list_force_index[i - 1]
    )
    return torch.hstack((torch.zeros_like(cubic), -cubic))

def force_function_numerical(i, alpha, list_force_index, numerical_pert_list):
    """NumPy version of the system-``i`` forcing, built from numerical solutions.

    ``X_k`` is row 0 of ``numerical_pert_list[k]``. Sums
    ``coeff * (X_a * X_b * X_c)`` over ``list_force_index[i - 1]`` and returns
    ``[0, -sum]`` with one row per time point. ``alpha`` is not used.
    """
    cubic = 0
    for a, b, c, coeff in list_force_index[i - 1]:
        triple = (numerical_pert_list[a][0, :]
                  * numerical_pert_list[b][0, :]
                  * numerical_pert_list[c][0, :])
        cubic += coeff * triple
    stacked = np.vstack((np.zeros_like(cubic), -cubic))
    return stacked.T


def solve_numericaly_perturbation(t, y, A, force, t_eval):
    """Right-hand side of ``du/dt = force(t) - A u`` with a tabulated force.

    ``force`` holds one ``[f0, f1]`` row per entry of ``t_eval``. The row whose
    time is nearest to ``t`` is used (piecewise-constant lookup).
    """
    nearest = np.argmin(np.abs(t_eval - t))
    f_t = force[nearest]
    du0 = f_t[0] - A[0][1] * y[1] - A[0][0] * y[0]
    du1 = f_t[1] - A[1][0] * y[0] - A[1][1] * y[1]
    return np.array([du0, du1])

def numerical_perturbation_fct(x, v, A, force):
    """Radau solution over ``x_range`` of one perturbation system with a tabulated force.

    ``force`` is sampled at the times ``x``; ``solve_numericaly_perturbation``
    looks up the nearest sample. Returns the full ``solve_ivp`` result object,
    not only ``.y``.
    """
    return solve_ivp(solve_numericaly_perturbation, [x_range[0], x_range[1]],
                     v.squeeze(), args=(A, force, x), t_eval=x.squeeze(),
                     method="Radau")


### Solve the $p+1$ systems ($i = 0, \dots, p$) for perturbation transfer learning

In [None]:
p = 10  # highest perturbation order kept; systems i = 0..p are plotted below
compute_numerical_pert = False  # when True, numerical perturbation solutions are also produced

(solution_PINNS, solution_numerical, PINNS_list,
 numerical_pert_list, total_time) = solve_perturbation_TL(
    beta=beta, p=p, t_eval=t_eval,
    alpha=alpha_transfer, A=A_transfer,
    force=force_transfer, IC=IC_transfer,
    H=H, H_0=H_0, dH_dt=dH_dt_new, dev=dev,
    force_func_index=force_func_index,
    numerical_sol_fct=numerical_sol_fct,
    force_function_PINNS=force_function_PINNS,
    force_function_numerical=force_function_numerical,
    compute_numerical_pert=compute_numerical_pert,
    numerical_perturbation_fct=numerical_perturbation_fct,
    verbose=True,
)

### Plot results

In [None]:
t_numpy = t_eval.detach().cpu().numpy()

fig, ax = plt.subplots(1, 2, figsize=(18, 5), tight_layout=False)

# Left panel: PINNs perturbation-transfer solution as markers, then the
# reference numerical solution as lines. The call order fixes the default
# colour assignment.
for _k in (1, 2):
    ax[0].plot(t_numpy, solution_PINNS[:, _k - 1], 'x', markersize=8,
               label=f'PINNS $y_{_k}$', linewidth=3.5)
for _k in (1, 2):
    ax[0].plot(t_numpy, non_linear_num_sol.y[_k - 1],
               label=f'Numerical $y_{_k}$', linewidth=2.5)

# Optional cross-check: numerical solution of the perturbation systems (every 10th point)
if compute_numerical_pert:
    ax[0].plot(t_numpy[::10], solution_numerical[::10, 0], '*', c='blue',
               label='Numerical perturbation x', markersize=5)
    ax[0].plot(t_numpy[::10], solution_numerical[::10, 1], '*', c='orange',
               label='Numerical perturbation y')

ax[0].set_title("$y(t)$ for PINNs Transfer and Numerical Solutions", fontsize=20)
ax[0].set_xlabel("t", fontsize=16)
ax[0].set_ylabel("$y(t)$", fontsize=16)
for _axis in ('x', 'y'):
    ax[0].tick_params(axis=_axis, labelsize=16)
ax[0].grid()
ax[0].legend()

# Right panel: pointwise absolute error against the reference, on a log scale
for _k in (1, 2):
    ax[1].plot(t_numpy, np.abs(solution_PINNS[:, _k - 1] - non_linear_num_sol.y[_k - 1]),
               label=f'Error $y_{_k}$')
ax[1].set_title("Absolute Error", fontsize=20)
ax[1].set_xlabel("$t$", fontsize=16)
ax[1].set_yscale('log')
ax[1].set_ylabel('Error Value', fontsize=16)
for _axis in ('x', 'y'):
    ax[1].tick_params(axis=_axis, labelsize=16)
ax[1].grid()
ax[1].legend()


### Plot the PINNs and numerical solutions of each system

In [None]:
# plot Xi and Yi: one figure per state component, one panel per order i
nb_row_plot = (p + 4) // 4  # ceil((p + 1) / 4) rows of 4 panels
for comp, name in ((0, "x"), (1, "y")):
    # squeeze=False keeps `ax` 2-D even when one row suffices (p < 4).
    # With the default squeeze, ax[j][k] would fail in that case.
    fig, ax = plt.subplots(nb_row_plot, 4, figsize=(15, nb_row_plot * 2), squeeze=False)
    for i in range(p + 1):
        j, k = divmod(i, 4)
        pert_coeff = beta ** i
        ax[j][k].plot(t_numpy, PINNS_list[i][:, comp] * pert_coeff, label="PINNS")
        if compute_numerical_pert:
            # NOTE(review): force_function_numerical reads X_k as
            # numerical_pert_list[k][0, :] (component-major), while this indexes
            # [:, comp] (time-major). Confirm which layout solve_perturbation_TL stores.
            ax[j][k].plot(t_numpy, numerical_pert_list[i][:, comp] * pert_coeff, '--', label="Numerical")
        # Braces keep multi-digit orders (i >= 10) as one exponent/subscript in mathtext
        ax[j][k].set_title(rf"$\beta^{{{i}}} {name}_{{{i}}}$")
        ax[j][k].legend(loc="best")
    # Hide the unused panels of the last row
    for idx in range(p + 1, nb_row_plot * 4):
        ax[idx // 4][idx % 4].axis("off")
    fig.suptitle(fr"Linear systems {name.upper()} solution with $\alpha={alpha_transfer}$, $\beta={beta:.2f}$, $p={p}$")
    fig.tight_layout()

# Hyperparameter optimization

## First experiment
-   Fix $\alpha$
-   Changing $p$ and $\beta$

In [None]:
alpha_transfer = 100
A_transfer = get_A(alpha=alpha_transfer)

p_list = list(range(1, 30))
beta_list = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

compute_numerical_pert = False
# Per beta: one [y1, y2] error pair per p. Despite its name, MSE_list stores
# the mean *absolute* error; the name is kept because the plotting cells read it.
MSE_list = []
MaxAE_list = []
for beta in tqdm(beta_list):
    mae_per_p, maxae_per_p = [], []
    MSE_list.append(mae_per_p)
    MaxAE_list.append(maxae_per_p)
    non_linear_num_sol_beta = numerical_non_linear_solution(
        alpha_transfer, beta, IC_transfer.detach().cpu().squeeze(), domain,
        t_eval.detach().cpu().numpy().squeeze())
    for p in tqdm(p_list):
        sp, sn, _, _, tt = solve_perturbation_TL(
            beta=beta, p=p, t_eval=t_eval, alpha=alpha_transfer, A=A_transfer,
            force=force_transfer, IC=IC_transfer, H=H, H_0=H_0,
            dH_dt=dH_dt_new, dev=dev, force_func_index=force_func_index,
            numerical_sol_fct=numerical_sol_fct,
            force_function_PINNS=force_function_PINNS,
            force_function_numerical=force_function_numerical,
            compute_numerical_pert=compute_numerical_pert,
            numerical_perturbation_fct=numerical_perturbation_fct,
            verbose=False)
        abs_err = [np.abs(non_linear_num_sol_beta.y[c] - sp[:, c]) for c in (0, 1)]
        mae_per_p.append([np.mean(e) for e in abs_err])
        maxae_per_p.append([np.max(e) for e in abs_err])

### Set up colors for the plots

In [None]:
def _rgb_gradient(cmap, n):
    """Sample ``n`` evenly spaced colours from ``cmap`` as RGB tuples (alpha dropped)."""
    return [tuple(rgba[:3]) for rgba in cmap(np.linspace(0, 1, n))]


def generate_blue_gradient(n):
    """Return ``n`` RGB colours going from light to dark blue."""
    return _rgb_gradient(plt.cm.Blues, n)


def generate_red_gradient(n):
    """Return ``n`` RGB colours going from light to dark red."""
    return _rgb_gradient(plt.cm.Reds, n)


n_colors = 20
blue_gradient = generate_blue_gradient(n_colors)
red_gradient = generate_red_gradient(n_colors)

### Plot MAE

In [None]:
# Create subplots: MAE of y1 (left, blues) and y2 (right, reds) against p
fig, ax = plt.subplots(1, 2, figsize=(15, 4))

for comp, name, gradient in ((0, "y_1", blue_gradient), (1, "y_2", red_gradient)):
    for i, beta in enumerate(beta_list):
        # MSE_list[i] holds one [MAE y1, MAE y2] pair per p; +4 skips the palest shades
        ax[comp].plot(p_list, np.array(MSE_list[i])[:, comp], label=rf"$\beta={beta}$",
                      color=gradient[i + 4], linewidth=2)
    # Both panels use the same font sizes (as the other figures do)
    ax[comp].set_title(rf"MAE of ${name}$ vs $p$ for $\beta \in [0, 1]$", fontsize=20)
    ax[comp].set_xlabel("Number of system $p$", fontsize=16)
    ax[comp].set_yscale('log')
    ax[comp].tick_params(axis='x', labelsize=16)
    ax[comp].tick_params(axis='y', labelsize=16)
    ax[comp].grid()
    ax[comp].legend()

## Second experiment
-   Fix $\beta$
-   Changing $p$ and $\alpha$

In [None]:
beta = 0.5
p_list = list(range(1, 5))
alpha_list = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150]
compute_numerical_pert = False
# Per alpha: one [y1, y2] error pair per p. Despite its name, MSE_list stores
# the mean *absolute* error; the name is kept because the plotting cells read it.
MSE_list = []
MaxAE_list = []
for alpha_transfer in tqdm(alpha_list):
    mae_per_p, maxae_per_p = [], []
    MSE_list.append(mae_per_p)
    MaxAE_list.append(maxae_per_p)
    A_transfer = get_A(alpha=alpha_transfer)
    non_linear_num_sol_beta = numerical_non_linear_solution(
        alpha_transfer, beta, IC_transfer.detach().cpu().squeeze(), domain,
        t_eval.detach().cpu().numpy().squeeze())
    for p in tqdm(p_list):
        sp, sn, _, _, tt = solve_perturbation_TL(
            beta=beta, p=p, t_eval=t_eval, alpha=alpha_transfer, A=A_transfer,
            force=force_transfer, IC=IC_transfer, H=H, H_0=H_0,
            dH_dt=dH_dt_new, dev=dev, force_func_index=force_func_index,
            numerical_sol_fct=numerical_sol_fct,
            force_function_PINNS=force_function_PINNS,
            force_function_numerical=force_function_numerical,
            compute_numerical_pert=compute_numerical_pert,
            numerical_perturbation_fct=numerical_perturbation_fct,
            verbose=False)
        abs_err = [np.abs(non_linear_num_sol_beta.y[c] - sp[:, c]) for c in (0, 1)]
        mae_per_p.append([np.mean(e) for e in abs_err])
        maxae_per_p.append([np.max(e) for e in abs_err])

### Plot MAE

In [None]:
# Create subplots: MAE of y1 (left, blues) and y2 (right, reds) against p
fig, ax = plt.subplots(1, 2, figsize=(15, 5.5))

# Title range derived from alpha_list so it matches the experiment that was run
alpha_span = rf"$\alpha \in [{min(alpha_list)}, {max(alpha_list)}]$"
# At least one legend column, even with fewer than three alpha values
legend_ncol = max(1, len(alpha_list) // 3)

for comp, name, gradient in ((0, "y_1", blue_gradient), (1, "y_2", red_gradient)):
    for i, alpha in enumerate(alpha_list):
        # +4 skips the palest shades; requires len(alpha_list) + 4 <= len(gradient)
        ax[comp].plot(p_list, np.array(MSE_list[i])[:, comp], label=rf"$\alpha={alpha}$",
                      color=gradient[i + 4], linewidth=2)
    ax[comp].set_title(f"MAE of ${name}$ vs $p$ for {alpha_span}", fontsize=20)
    ax[comp].set_xlabel("Number of system $p$", fontsize=16)
    ax[comp].set_yscale('log')
    ax[comp].tick_params(axis='x', labelsize=16)
    ax[comp].tick_params(axis='y', labelsize=16)
    ax[comp].grid()
    # Legend placement in a box under the plot
    ax[comp].legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=legend_ncol, fontsize=12)

plt.tight_layout()
plt.show()

## Third experiment
-   Fix $\beta$ and, for each $\alpha$, keep the best $p$ (results of the second experiment)
-   Changing $\alpha$

In [None]:
# For every alpha of the second experiment keep the best error over p.
# Both arrays are indexed [alpha, p, component] (component 0 = y1, 1 = y2).
MSE_list = np.array(MSE_list)
MaxAE_list = np.array(MaxAE_list)

MAE_min = MSE_list.min(axis=1)
MaxAE_min = MaxAE_list.min(axis=1)

fig, ax = plt.subplots(1, figsize=(15, 4))

# MAE as circles in the default colours, MaxAE as crosses in the same colours
for _values, _fmt, _label, _extra in (
        (MAE_min[:, 0], "-o", "$MAE$ ${y_1}$", {"markersize": 6}),
        (MAE_min[:, 1], "-o", "$MAE$ ${y_2}$", {"markersize": 6}),
        (MaxAE_min[:, 0], "-x", "$MaxAE$ ${y_1}$", {"color": "#1f77b4", "markersize": 8}),
        (MaxAE_min[:, 1], "-x", "$MaxAE$ ${y_2}$", {"color": "#ff7f0e", "markersize": 8}),
):
    ax.plot(alpha_list, _values, _fmt, label=_label, linewidth=2, **_extra)

ax.set_yscale("log")
ax.set_title(r"Mean and Max Absolute Error with increasing Stiffness", fontsize=20)
ax.set_xlabel(r'Stiffness parameter $\alpha$ and ratio $SR$', fontsize=16)
ax.set_ylabel('Absolute Error', fontsize=16)
# NOTE(review): SR is labelled alpha**2, but the eigenvalue ratio of
# A = [[0, -1], [0.1, alpha]] is roughly alpha**2 / 0.1. Confirm the intended definition.
ax.set_xticks(alpha_list, [rf"$\alpha$={a}" + "\n" + f"$SR$={a**2}" for a in alpha_list])
ax.set_yticks([0.1, 0.01, 0.001, 0.0001],
              [rf"$10^{{-{e}}}$" for e in range(1, 5)])
ax.grid()
for _axis, _labelsize in (('x', 9.5), ('y', 16)):
    ax.tick_params(axis=_axis, labelsize=_labelsize)
ax.legend(loc='best', fontsize=14)

### Save MAE and MaxAE results over several alpha values

In [None]:
# Best-over-p errors for each alpha, stored as plain lists for JSON
history = {
    "alpha_list": alpha_list,
    "mae_y1": MAE_min[:, 0].tolist(),
    "mae_y2": MAE_min[:, 1].tolist(),
    "maxae_y1": MaxAE_min[:, 0].tolist(),
    "maxae_y2": MaxAE_min[:, 1].tolist(),
}

# Write to <two levels above cwd>/result_history, creating the folder if it is
# missing (a missing folder used to raise FileNotFoundError).
current_path = Path.cwd().parent.parent
path = current_path / "result_history"
path.mkdir(parents=True, exist_ok=True)
with open(path / "Duffing_Error_Transfer.json", "w", encoding="utf-8") as fp:
    json.dump(history, fp)

# Comparative analysis with numerical methods on several alpha
- Solve iteratively for several alpha values
- Solve with:
    - PINNS transfer
    - RK45
    - Radau

In [None]:
def numerical_sol_fct_radau(x, IC, alpha, beta=beta, domain=domain):
    """Radau solution array (``.y``) of the Duffing system at times ``x``.

    The ``beta`` and ``domain`` defaults are bound when this cell runs.
    """
    return numerical_non_linear_solution(alpha, beta, IC, domain, x,
                                         method="Radau").y


def numerical_sol_fct_rk45(x, IC, alpha, beta=beta, domain=domain):
    """RK45 solution array (``.y``) of the Duffing system at times ``x``.

    The ``beta`` and ``domain`` defaults are bound when this cell runs.
    """
    return numerical_non_linear_solution(alpha, beta, IC, domain, x,
                                         method="RK45").y


numerical_methods = {"RK45": numerical_sol_fct_rk45, "Radau": numerical_sol_fct_radau}

In [None]:
alpha_list_transfer = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150]
force_transfer = force_list[0]
log_scale = False  # NOTE(review): not read anywhere in this notebook
p = 10

computational_time = defaultdict(list)
max_error = defaultdict(list)   # NOTE(review): never filled below
mean_error = defaultdict(list)  # NOTE(review): never filled below

solution = defaultdict(list)

length = t_eval.shape[0]

# Convert the solver inputs once, outside the timed region, so the measured
# time covers only the numerical solve itself.
t_eval_np = t_eval.detach().cpu().numpy().squeeze()
IC_np = IC_transfer.detach().cpu().numpy().squeeze()

for alpha in tqdm(alpha_list_transfer):
    A_transfer = get_A(alpha)

    pinns_sol, _, _, _, total_time = solve_perturbation_TL(beta=beta, p=p, t_eval=t_eval,
                                                           alpha=alpha, A=A_transfer,
                                                           force=force_transfer, IC=IC_transfer,
                                                           H=H, H_0=H_0, dH_dt=dH_dt_new, dev=dev,
                                                           force_func_index=force_func_index,
                                                           numerical_sol_fct=numerical_sol_fct,
                                                           force_function_PINNS=force_function_PINNS,
                                                           force_function_numerical=force_function_numerical,
                                                           compute_numerical_pert=False,
                                                           numerical_perturbation_fct=numerical_perturbation_fct,
                                                           verbose=False)
    solution["PINNS"].append(pinns_sol.T)
    computational_time["PINNS"].append(total_time)

    # solve with numerical methods; perf_counter is monotonic and high-resolution
    for method, fct in numerical_methods.items():
        start = time.perf_counter()
        numerical_sol = fct(t_eval_np, IC_np, alpha)
        end = time.perf_counter()
        computational_time[method].append(end - start)
        solution[method].append(numerical_sol)

### Plot the computational time

In [None]:
# Colour per method; the LSODA and True entries are not used by the runs above
color = {"PINNS": 'orange', "RK45": 'b', "Radau": 'g', 'LSODA': 'm', "True": (1, 0, 0, 0.5)}
fig, ax = plt.subplots(1, tight_layout=True, figsize=(15, 4))

for method in computational_time:
    ax.plot(alpha_list_transfer, computational_time[method], "-o",
            color=color[method], label=f"{method}")

tick_labels = [rf"$\alpha$={a}" + "\n" + rf"$SR$={a**2}" for a in alpha_list_transfer]

ax.set_title("Computational time solving stiff equation", fontsize=20)
ax.set_xlabel(r'Stiffness parameter $\alpha$ and ratio $SR$', fontsize=16)
ax.set_ylabel('Time', fontsize=16)
ax.set_xticks(alpha_list_transfer, tick_labels)
ax.tick_params(axis='x', labelsize=11.5)
ax.tick_params(axis='y', labelsize=16)
ax.legend(loc='best', fontsize=16)
ax.grid()