In [1]:
from ptlpinns.odes import equations, numerical
from ptlpinns.models import model, transfer
import numpy as np
import time
import torch

Note: computational time scales with N for PINNs but not for the numerical solvers

In [2]:
N_ITER = 600

N = 512
t_span = (0, 15)
t_eval = np.linspace(t_span[0], t_span[1], N)

RK45_time, Radau_time, PTL_PINN_inverting, PTL_PINN_not_inverting = [], [], [], []

In [3]:
zeta_list = [0, 0.4, 0.8, 10, 30, 60]

w_list_transfer = [1, 1, 1, 1, 1, 1]

forcing_names = ['']

def forcing(numpy=False):
    if not numpy:
        def force(t):
            return torch.stack((torch.zeros_like(t), torch.zeros_like(t)), dim=1)
    else:
        def force(t):
            return np.stack((np.zeros_like(t), np.zeros_like(t)), axis=1)
    return force

forcing_list = [forcing(True), forcing(True), forcing(True), forcing(True), forcing(True), forcing(True)]

def zeroes_1D(t):
    return np.zeros_like(t)

forcing_1D = [zeroes_1D, zeroes_1D, zeroes_1D, zeroes_1D, zeroes_1D, zeroes_1D]

ic_list = [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]
p_list = [6, 6, 6, 6, 6, 6]
epsilon_list = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]

### RK45 and Radau times

In [None]:
total_RK45, total_Radau = [], []

for i, zeta in enumerate(zeta_list):

    ode = equations.ode_oscillator_1D(w_0=w_list_transfer[i], zeta=zeta, forcing_1D=forcing_1D[i], q=3, epsilon=0.5)

    print("solving for zeta =", zeta)
    RK45_time_list, Radau_time_list = [], []

    for j in range(N_ITER):

        start_RK45 = time.perf_counter()
        numerical.solve_ode_equation(ode, t_span, t_eval, ic_list[i], method="RK45", rtol=1e-5, atol=1e-5)[0]
        end_RK45 = time.perf_counter()

        start_Radau = time.perf_counter()
        numerical.solve_ode_equation(ode, t_span, t_eval, ic_list[i], method="Radau", rtol=1e-5, atol=1e-5)[0]
        end_Radau = time.perf_counter()

        RK45_time_list.append(end_RK45 - start_RK45)
        Radau_time_list.append(end_Radau - start_Radau)

    total_RK45.append(RK45_time_list)
    total_Radau.append(Radau_time_list)

    Radau_time.append(np.mean(RK45_time_list[i]))
    RK45_time.append(np.mean(Radau_time_list[i]))

solving for zeta = 0


KeyboardInterrupt: 

In [None]:
print("RK45 time:", RK45_time)
print("Radau time:", Radau_time)

RK45 time: [np.float64(0.049482706001072074), np.float64(0.007889464999607299), np.float64(0.005692542999895522), np.float64(0.0037235119998513255), np.float64(0.0027954460001637926), np.float64(0.0028651949996856274)]
Radau time: [np.float64(0.023131681000450044), np.float64(0.001959080000233371), np.float64(0.0016324119987984886), np.float64(0.007867510001233313), np.float64(0.022304087000520667), np.float64(0.04708660600044823)]


### PTL-PINNs

In [None]:
undamped_path = "/home/dda24/PTL-PINNs/ptlpinns/models/train/undamped_k12"
undamped_name = "model_undamped_k12.pth"
undamped_model, _ = model.load_model(undamped_path, undamped_name)

underdamped_path = "/home/dda24/PTL-PINNs/ptlpinns/models/train/underdamped_k12"
underdamped_name = "model_underdamped_k12.pth"
underdamped_model, _ = model.load_model(underdamped_path, underdamped_name)

overdamped_path = "/home/dda24/PTL-PINNs/ptlpinns/models/train/overdamped_k12"
overdamped_name = "model_overdamped_k12.pth"
overdamped_model, _ = model.load_model(underdamped_path, underdamped_name)

12 True True True 1.0 16 [256, 256, 512]
12 True True True 1.0 16 [128, 128, 256]
12 True True True 1.0 16 [128, 128, 256]


In [None]:
# Compute latent representation: H(t) and derivatives
H_dict_undamped = transfer.compute_H_dict(undamped_model, N=N, bias=True, t_span=(t_span[0], t_span[1]))
H_dict_underdamped = transfer.compute_H_dict(underdamped_model, N=N, bias=True, t_span=(t_span[0], t_span[1]))
H_dict_overdamped = transfer.compute_H_dict(overdamped_model, N=N, bias=True, t_span=(t_span[0], t_span[1]))

training_log = {'w_ode': 1.5, 'w_ic': 1}

In [None]:
total_inverting, total_not_inverting = [], []

for i in range(len(zeta_list)):

    print("solving for zeta =", zeta_list[i])

    inverting, not_inverting = [], []

    for j in range(N_ITER):

        if zeta_list[i] == 0:
            solver = "LPM"
            H_dict = H_dict_undamped
        elif 0 < zeta_list[i] < 1:
            solver = "standard"
            H_dict = H_dict_underdamped
        else:
            solver = "standard"
            H_dict = H_dict_overdamped 

        # invert = True
        _, _, TL_time_inverting = transfer.compute_perturbation_solution([w_list_transfer[i]], [zeta_list[i]], [epsilon_list[i]], [p_list[i]],
                                                                [ic_list[i]], [forcing_list[i]], H_dict,
                                                                t_eval, training_log, all_p=True, comp_time=True,
                                                                solver=solver, w_sol = [], invert=True)
        
        # invert = False
        _, _, TL_time_not_inverting = transfer.compute_perturbation_solution([w_list_transfer[i]], [zeta_list[i]], [epsilon_list[i]], [p_list[i]],
                                                                [ic_list[i]], [forcing_list[i]], H_dict,
                                                                t_eval, training_log, all_p=True, comp_time=True,
                                                                solver=solver, w_sol = [], invert=False)
        
        inverting.append(TL_time_inverting[0])
        not_inverting.append(TL_time_not_inverting[0])

    total_inverting.append(inverting)
    total_not_inverting.append(not_inverting)

    PTL_PINN_inverting.append(np.mean(total_inverting[i]))
    PTL_PINN_not_inverting.append(np.mean(total_not_inverting[i]))

solving for zeta = 0


solving for zeta = 0.4
solving for zeta = 0.8
solving for zeta = 10
solving for zeta = 30
solving for zeta = 60


In [None]:
for i in range(len(zeta_list)):
    print(f"zeta: {zeta_list[i]} | RK45: {RK45_time[i]} | Radau: {Radau_time[i]} | PTL-PINN: {PTL_PINN_inverting[i]} | PTL-PINN no invert: {PTL_PINN_not_inverting[i]}")

zeta: 0 | RK45: 0.049482706001072074 | Radau: 0.023131681000450044 | PTL-PINN: 0.048844318390247284 | PTL-PINN no invert: 0.007226598568352832
zeta: 0.4 | RK45: 0.007889464999607299 | Radau: 0.001959080000233371 | PTL-PINN: 0.025139206404901415 | PTL-PINN no invert: 0.0021651787301652805
zeta: 0.8 | RK45: 0.005692542999895522 | Radau: 0.0016324119987984886 | PTL-PINN: 0.02670808205971298 | PTL-PINN no invert: 0.0024642330199306645
zeta: 10 | RK45: 0.0037235119998513255 | Radau: 0.007867510001233313 | PTL-PINN: 0.022381669343412795 | PTL-PINN no invert: 0.0020303827534705002
zeta: 30 | RK45: 0.0027954460001637926 | Radau: 0.022304087000520667 | PTL-PINN: 0.018605449413198586 | PTL-PINN no invert: 0.0022971531566205764
zeta: 60 | RK45: 0.0028651949996856274 | Radau: 0.04708660600044823 | PTL-PINN: 0.027975613555033002 | PTL-PINN no invert: 0.0025293407734958842
