# Notebook to change the initial conditions (IC) or the force function inside the stiff regime

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# imports
import time
import torch
import matplotlib.pyplot as plt
import numpy as np
import sys
from scipy.integrate import solve_ivp
from collections import defaultdict
from tqdm import tqdm

# Add parent directory to sys.path
from pathlib import Path
current_path = Path.cwd()
parent_dir = current_path.parent.parent
sys.path.append(str(parent_dir))

from src.transfer_learning import compute_H_and_dH_dt, compute_M_inv, compute_force_term, compute_W_with_IC
from src.utils_plot import plot_loss_and_all_solution
from src.load_save import load_run_history

# Disable autograd debugging/profiling helpers (presumably to keep overhead out of
# the timing benchmark further down).
torch.autograd.set_detect_anomaly(False)
# NOTE(review): profile(...) and emit_nvtx(...) are context managers; constructing
# them outside a `with` statement most likely has no effect — confirm and consider
# removing these two lines.
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)

In [None]:
def check_versions_and_device():
    """Print library versions / GPU status and return the torch device to use.

    Replaces the IPython-only ``!nvidia-smi`` shell escape with a plain
    ``subprocess`` call that only runs when CUDA is available, so the cell is
    valid Python and does not fail noisily on CPU-only machines.

    Returns:
        torch.device: ``cuda`` when torch can see a GPU, otherwise ``cpu``.
    """
    import subprocess

    print(f"torch {torch.__version__} | numpy {np.__version__}")
    if not torch.cuda.is_available():
        return torch.device('cpu')

    # Release cached allocator blocks from earlier runs before reporting usage.
    torch.cuda.empty_cache()
    try:
        result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
        print(result.stdout or result.stderr)
    except FileNotFoundError:
        print("nvidia-smi not found on PATH")
    return torch.device('cuda')


# set a global device variable to use in code
dev = check_versions_and_device()
print(dev)

### Load the pretrained model and its training history

In [None]:
# Identify the saved run to restore.
file_name = "final_inference_2081553"
equation_name = "DHO"

# Unpack everything stored for that run. Judging by the names (confirm against
# load_run_history): the trained model, its configuration, the loss history and
# the alpha / A / IC / force lists used during training.
(
    trained_model,
    x_range,
    iterations,
    hid_lay,
    num_equations,
    num_heads,
    loss_hist,
    alpha_list,
    A_list,
    IC_list,
    force_list,
) = load_run_history(equation_name, file_name, dev, prev=False)

# function to get A for alpha value and DHO equation
def get_A(alpha, device=None):
    """Return the 2x2 DHO system matrix ``[[0, -1], [1, 2*alpha]]`` in float64.

    Args:
        alpha: Damping parameter; the (1, 1) entry is ``2 * alpha``.
        device: Torch device for the result. Defaults to the notebook-wide ``dev``.

    Returns:
        torch.Tensor: Shape (2, 2), dtype float64.
    """
    if device is None:
        device = dev
    # Build directly in float64. Creating a float32 tensor and then calling
    # .double() first rounds 2*alpha to single precision (e.g. alpha=0.1).
    return torch.tensor([[0., -1.], [1., 2 * alpha]], dtype=torch.float64, device=device)


### Define numerical solution

In [None]:
# Numerical solution
def _as_numpy(value):
    """Return ``value`` as a float64 NumPy array (torch tensors are detached and moved to CPU)."""
    if hasattr(value, "detach"):
        value = value.detach().cpu().numpy()
    return np.asarray(value, dtype=np.float64)


def _linear_rhs(t, y, A, force):
    """Right-hand side ``force - A @ y`` of the linear system, for plain NumPy inputs."""
    return force - A @ y


def double_coupled_equation(t, y, A, force):
    """Evaluate dy/dt = force - A y for the 2-component system.

    Accepts torch tensors or array-likes for ``A`` (2x2) and ``force`` (2 entries,
    any shape that flattens to 2). Kept for backward compatibility; the solver
    wrappers below convert their inputs once and call ``_linear_rhs`` directly,
    so no torch indexing or ``.item()`` happens inside the solver's hot loop.
    """
    return _linear_rhs(t, np.asarray(y, dtype=np.float64), _as_numpy(A), _as_numpy(force).reshape(-1))


r_tol = 1e-4


def _solve_numerically(x, v, A, force, method, t_span=None):
    """Integrate y' = force - A y from initial value ``v`` and sample it at times ``x``.

    Args:
        x: Evaluation times (any shape that flattens to 1-D).
        v: Initial value (flattens to the state size).
        A, force: System matrix and constant forcing (torch or array-like).
        method: ``solve_ivp`` method name.
        t_span: Integration interval; defaults to the notebook's ``x_range``.

    Returns:
        np.ndarray: Solution of shape (n_states, len(x)).
    """
    if t_span is None:
        t_span = [x_range[0], x_range[1]]
    # Convert once here; doing it inside the RHS would be repeated on every
    # solver step and would distort the timing benchmark.
    A_np = _as_numpy(A)
    force_np = _as_numpy(force).reshape(-1)
    return solve_ivp(_linear_rhs, t_span, _as_numpy(v).reshape(-1),
                     args=(A_np, force_np), t_eval=_as_numpy(x).reshape(-1),
                     method=method, rtol=r_tol).y


def numerical_sol_fct_radau(x, v, A, force, t_span=None):
    """Reference solution with SciPy's implicit Radau method (suited to stiff problems)."""
    return _solve_numerically(x, v, A, force, "Radau", t_span)


def numerical_sol_fct_rk45(x, v, A, force, t_span=None):
    """Reference solution with SciPy's explicit RK45 method."""
    return _solve_numerically(x, v, A, force, "RK45", t_span)


numerical_methods = {"RK45": numerical_sol_fct_rk45, "Radau": numerical_sol_fct_radau}

### Plot training results

In [None]:
# Visual check of the pretrained model. Presumably this shows the loss history and
# the PINN solutions against the Radau reference; see plot_loss_and_all_solution.
plot_loss_and_all_solution(
    x_range=x_range,
    trained_model=trained_model,
    train_losses=loss_hist,
    true_functs=numerical_sol_fct_radau,
    IC_list=IC_list,
    A_list=A_list,
    force=force_list,
    device=dev,
)

## Transfer learning inside the stiff domain
### Extract H 

In [None]:
# Run the trained network once over [x_range[0], x_range[1]] and keep its hidden
# representation H, H at the initial point (H_0), its time derivative, and the
# evaluation times.
size = 512  # sample count handed to compute_H_and_dH_dt (assumed: number of time points — confirm)
H, H_0, dH_dt_new, t_eval = compute_H_and_dH_dt(
    x_range[0],
    x_range[1],
    trained_model,
    num_equations,
    hid_lay,
    size,
    dev,
)

### 1. Change IC in a stiff regime
### 2. Change force in a stiff regime
### 3. Change IC and force in a stiff regime

In [None]:
def _random_column(first_bound, second_bound, device):
    """Draw one uniform sample from each (low, high) bound; return a (2, 1) float64 tensor."""
    # Draw order (first, then second) is kept so seeded runs reproduce exactly.
    first = np.random.uniform(first_bound[0], first_bound[1], 1)
    second = np.random.uniform(second_bound[0], second_bound[1], 1)
    if device is None:
        device = dev
    # Stack into one ndarray first: torch.tensor() on a list of ndarrays is slow
    # and emits a UserWarning.
    return torch.tensor(np.stack([first, second]), device=device)


def random_IC(x_bound=(0, 5), y_bound=(0, 5), device=None):
    """Return a random initial condition as a (2, 1) float64 tensor.

    Args:
        x_bound, y_bound: (low, high) ranges for the two components.
        device: Target device; defaults to the notebook-wide ``dev``.
    """
    return _random_column(x_bound, y_bound, device)


def random_force(force1_bound=(0, 2), force2_bound=(0, 2), device=None):
    """Return a random constant force vector as a (2, 1) float64 tensor.

    Args:
        force1_bound, force2_bound: (low, high) ranges for the two components.
        device: Target device; defaults to the notebook-wide ``dev``.
    """
    return _random_column(force1_bound, force2_bound, device)

### Choose what to change (initial conditions and/or force function)

In [None]:
# Which quantities are randomized in the transfer experiment (at least one must be True).
change_IC, change_force = True, False

### Precompute the inverse of the M matrix in the stiff regime

In [None]:
# Target system for the transfer: the DHO matrix at alpha = 50.
alpha_transfer = 50
A_transfer = get_A(alpha=alpha_transfer).double()

# Defaults used when the corresponding quantity is not randomized.
IC_transfer = IC_list[0]
force_transfer = force_list[0]

# Cast the extracted network quantities to float64 before building M_inv.
dH_dt_new, H, H_0 = dH_dt_new.double(), H.double(), H_0.double()

M_inv = compute_M_inv(dH_dt_new, H, H_0, t_eval, A_transfer)

# Anything that stays fixed across the transfer solves is computed once here.
if not change_force:
    force_terms = compute_force_term(t_eval, A_transfer, force_transfer, H, dH_dt_new)
if not change_IC:
    IC_term = torch.matmul(H_0.T, IC_transfer)

### Compute 1000 solutions with varying IC and/or force

In [None]:
np.random.seed(42)
nb_transfer_equation = 1000

# Fail fast: with both flags False no PINN branch would run, and the loop would
# crash on (or silently reuse stale) `start` / `pinns_sol`.
if not (change_IC or change_force):
    raise ValueError("At least one of change_IC / change_force must be True.")

# All ICs are drawn before all forces, so the seeded RNG stream is consumed in a
# fixed order. Quantities that do not change reuse the precompute-cell defaults.
IC_transfer_list = ([random_IC() for _ in range(nb_transfer_equation)] if change_IC
                    else [IC_transfer] * nb_transfer_equation)
force_transfer_list = ([random_force() for _ in range(nb_transfer_equation)] if change_force
                       else [force_transfer] * nb_transfer_equation)

computational_time = defaultdict(list)
max_error = defaultdict(list)
mean_error = defaultdict(list)
solution = defaultdict(list)

# Loop-invariant solver inputs, converted once and outside every timer.
t_eval_cpu = t_eval.detach().cpu().numpy()
A_transfer_cpu = A_transfer.cpu()

# Use distinct loop names so the module-level IC_transfer / force_transfer
# defaults survive; otherwise re-running this cell would pick up the last sample.
for ic, force in tqdm(zip(IC_transfer_list, force_transfer_list), total=nb_transfer_equation):
    # PINNs transfer: only the output weights W_out are recomputed per system.
    start = time.perf_counter()  # monotonic, high resolution (sub-ms solves)
    if change_IC and not change_force:
        # Force fixed: reuse the precomputed force_terms.
        W_out, _ = compute_W_with_IC(M_inv, force_terms, ic, H_0)
    elif change_force and not change_IC:
        # IC fixed: reuse the precomputed IC_term.
        force_terms = compute_force_term(t_eval, A_transfer, force, H, dH_dt_new)
        rhs_terms = force_terms + IC_term
        W_out = torch.matmul(M_inv, rhs_terms)
    else:
        # Both IC and force change.
        force_terms = compute_force_term(t_eval, A_transfer, force, H, dH_dt_new)
        W_out, _ = compute_W_with_IC(M_inv, force_terms, ic, H_0)
    pinns_sol = torch.matmul(H, W_out)
    if pinns_sol.is_cuda:
        # CUDA kernels are asynchronous; wait for them before stopping the timer.
        torch.cuda.synchronize(pinns_sol.device)
    computational_time["PINNS"].append(time.perf_counter() - start)
    solution["PINNS"].append(np.swapaxes(pinns_sol.detach().cpu().numpy().squeeze(), 0, 1))

    # Numerical reference solvers, fed CPU copies prepared before timing starts.
    ic_cpu = ic.detach().cpu()
    force_cpu = force.detach().cpu()
    for method, fct in numerical_methods.items():
        start = time.perf_counter()
        numerical_sol = fct(t_eval_cpu, ic_cpu, A_transfer_cpu, force_cpu)
        computational_time[method].append(time.perf_counter() - start)
        solution[method].append(numerical_sol)

### Plot average computational time

In [None]:
color = {"PINNS": 'orange', "RK45": 'b', "Radau": 'g', "True": (1, 0, 0, 0.5)}
fig, ax = plt.subplots(1, tight_layout=True, figsize=(8, 4))

height = 0.6  # thickness of each horizontal bar

# Mean wall-clock time per solve, computed once per method.
mean_time = {method: sum(times) / len(times) for method, times in computational_time.items()}

for i, (method, avg_time) in enumerate(mean_time.items()):
    ax.barh(i, avg_time, height=height, color=color.get(method), label=f"{method}")
    # Print the value just right of the bar end.
    ax.annotate(f'{avg_time:1.2e}', (avg_time, i), ha='left', va='center', fontsize=12)

# Log scale so methods with very different costs remain readable side by side.
ax.set_xscale("log")
change_title = "IC and force" if (change_force and change_IC) else ("force" if change_force else "IC")
# change_title was previously computed but never shown; use it to label the plot.
ax.set_title(f"Average time per solve (changing {change_title})", fontsize=16)
ax.set_xlabel('Time (s)', fontsize=16)
ax.set_yticks(list(range(len(mean_time))))
ax.set_yticklabels(list(mean_time.keys()))
ax.tick_params(axis='y', labelsize=14)
ax.tick_params(axis='x', labelsize=16)
fig.tight_layout()
ax.margins(x=0.15)  # leave room on the right for the annotations
plt.show()