In [10]:
from ptlpinns.odes import equations, numerical
from ptlpinns.models import model, transfer
import numpy as np
import time
import torch

Note: computational time scales with N for PINNs but not for the numerical solvers

In [11]:
N_ITER = 800

N = 512
t_span = (0, 10)
t_eval = np.linspace(t_span[0], t_span[1], N)

RK45_time, Radau_time, PTL_PINN_inverting, PTL_PINN_not_inverting = [], [], [], []

In [12]:
zeta_list = [0, 0.1, 0.5, 10, 30, 60]

w_list_transfer = [1, 1, 1, 1, 1, 1]

forcing_names = ['']

def forcing(numpy=False):
    if not numpy:
        def force(t):
            return torch.stack((torch.zeros_like(t), torch.zeros_like(t)), dim=1)
    else:
        def force(t):
            return np.stack((np.zeros_like(t), np.zeros_like(t)), axis=1)
    return force

forcing_list = [forcing(True), forcing(True), forcing(True), forcing(True), forcing(True), forcing(True)]

def zeroes_1D(t):
    return np.zeros_like(t)

forcing_1D = [zeroes_1D, zeroes_1D, zeroes_1D, zeroes_1D, zeroes_1D, zeroes_1D]

ic_list = [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]
p_list = [6, 6, 6, 6, 6, 6]
epsilon_list = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]

### RK45 and Radau times

In [13]:
total_RK45, total_Radau = [], []

for i, zeta in enumerate(zeta_list):

    ode = equations.ode_oscillator_1D(w_0=w_list_transfer[i], zeta=zeta, forcing_1D=forcing_1D[i], q=3, epsilon=0.5)

    print("solving for zeta =", zeta)
    RK45_time_list, Radau_time_list = [], []

    for j in range(N_ITER):

        start_RK45 = time.perf_counter()
        numerical.solve_ode_equation(ode, t_span, t_eval, ic_list[i], method="RK45", rtol=1e-5, atol=1e-5)[0]
        end_RK45 = time.perf_counter()

        start_Radau = time.perf_counter()
        numerical.solve_ode_equation(ode, t_span, t_eval, ic_list[i], method="Radau", rtol=1e-5, atol=1e-5)[0]
        end_Radau = time.perf_counter()

        RK45_time_list.append(end_RK45 - start_RK45)
        Radau_time_list.append(end_Radau - start_Radau)

    total_RK45.append(RK45_time_list)
    total_Radau.append(Radau_time_list)

    Radau_time.append(np.mean(RK45_time_list[i]))
    RK45_time.append(np.mean(Radau_time_list[i]))

solving for zeta = 0
solving for zeta = 0.1
solving for zeta = 0.5
solving for zeta = 10
solving for zeta = 30
solving for zeta = 60


In [14]:
print("RK45 time:", RK45_time)
print("Radau time:", Radau_time)

RK45 time: [np.float64(0.2333092199987732), np.float64(0.015433378001034725), np.float64(0.012522905999503564), np.float64(0.006585603001440177), np.float64(0.0038152439992700238), np.float64(0.0037876049973419867)]
Radau time: [np.float64(0.15488537100100075), np.float64(0.0037842919991817325), np.float64(0.002050667997536948), np.float64(0.012430689999746392), np.float64(0.022984351999184582), np.float64(0.05686163700011093)]


### PTL-PINNs

In [15]:
undamped_path = "/home/dda24/PTL-PINNs/ptlpinns/models/train/undamped_k12"
undamped_name = "model_undamped_k12.pth"
undamped_model, _ = model.load_model(undamped_path, undamped_name)

underdamped_path = "/home/dda24/PTL-PINNs/ptlpinns/models/train/underdamped_k12"
underdamped_name = "model_underdamped_k12.pth"
underdamped_model, _ = model.load_model(underdamped_path, underdamped_name)

overdamped_path = "/home/dda24/PTL-PINNs/ptlpinns/models/train/overdamped_k12"
overdamped_name = "model_overdamped_k12.pth"
overdamped_model, _ = model.load_model(underdamped_path, underdamped_name)

12 True True True 1.0 16 [256, 256, 512]
12 True True True 1.0 16 [128, 128, 256]
12 True True True 1.0 16 [128, 128, 256]


In [16]:
# Compute latent representation: H(t) and derivatives
H_dict_undamped = transfer.compute_H_dict(undamped_model, N=N, bias=True, t_span=(t_span[0], t_span[1]))
H_dict_underdamped = transfer.compute_H_dict(underdamped_model, N=N, bias=True, t_span=(t_span[0], t_span[1]))
H_dict_overdamped = transfer.compute_H_dict(overdamped_model, N=N, bias=True, t_span=(t_span[0], t_span[1]))

training_log = {'w_ode': 1.5, 'w_ic': 1}

In [17]:
total_inverting, total_not_inverting = [], []

for i in range(len(zeta_list)):

    print("solving for zeta =", zeta_list[i])

    inverting, not_inverting = [], []

    for j in range(N_ITER):

        if zeta_list[i] == 0:
            solver = "LPM"
            H_dict = H_dict_undamped
        elif 0 < zeta_list[i] < 1:
            solver = "standard"
            H_dict = H_dict_underdamped
        else:
            solver = "standard"
            H_dict = H_dict_overdamped 

        # invert = True
        _, _, TL_time_inverting = transfer.compute_perturbation_solution([w_list_transfer[i]], [zeta_list[i]], [epsilon_list[i]], [p_list[i]],
                                                                [ic_list[i]], [forcing_list[i]], H_dict,
                                                                t_eval, training_log, all_p=True, comp_time=True,
                                                                solver=solver, w_sol = [], invert=True)
        
        # invert = False
        _, _, TL_time_not_inverting = transfer.compute_perturbation_solution([w_list_transfer[i]], [zeta_list[i]], [epsilon_list[i]], [p_list[i]],
                                                                [ic_list[i]], [forcing_list[i]], H_dict,
                                                                t_eval, training_log, all_p=True, comp_time=True,
                                                                solver=solver, w_sol = [], invert=False)
        
        inverting.append(TL_time_inverting[0])
        not_inverting.append(TL_time_not_inverting[0])

    total_inverting.append(inverting)
    total_not_inverting.append(not_inverting)

    PTL_PINN_inverting.append(np.mean(total_inverting[i]))
    PTL_PINN_not_inverting.append(np.mean(total_not_inverting[i]))

solving for zeta = 0


solving for zeta = 0.1
solving for zeta = 0.5
solving for zeta = 10
solving for zeta = 30
solving for zeta = 60


In [18]:
for i in range(len(zeta_list)):
    print(f"zeta: {zeta_list[i]} | RK45: {RK45_time[i]} | Radau: {Radau_time[i]} | PTL-PINN: {PTL_PINN_inverting[i]} | PTL-PINN no invert: {PTL_PINN_not_inverting[i]}")

zeta: 0 | RK45: 0.2333092199987732 | Radau: 0.15488537100100075 | PTL-PINN: 0.03836736403021405 | PTL-PINN no invert: 0.0050096903914936776
zeta: 0.1 | RK45: 0.015433378001034725 | Radau: 0.0037842919991817325 | PTL-PINN: 0.0173258756307996 | PTL-PINN no invert: 0.001589093969851092
zeta: 0.5 | RK45: 0.012522905999503564 | Radau: 0.002050667997536948 | PTL-PINN: 0.0187820573575209 | PTL-PINN no invert: 0.0015131778085969926
zeta: 10 | RK45: 0.006585603001440177 | Radau: 0.012430689999746392 | PTL-PINN: 0.011046555688612897 | PTL-PINN no invert: 0.0014413058691934565
zeta: 30 | RK45: 0.0038152439992700238 | Radau: 0.022984351999184582 | PTL-PINN: 0.011444397787358866 | PTL-PINN no invert: 0.0014432244375075242
zeta: 60 | RK45: 0.0037876049973419867 | Radau: 0.05686163700011093 | PTL-PINN: 0.018960542153763527 | PTL-PINN no invert: 0.0014522519152251334
