In [1]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

from ptlpinns.models import train_PDE, model, transfer_kpp
from ptlpinns.odes import numerical

In [2]:
N_ITER: int = 800  # repetitions per solver; the reported timings are means over these runs

### Define space-time domain

In [None]:
# Grid resolution and extents of the (x, t) domain used by the numerical solvers.
Nx, Nt = 50, 50
x_span, t_span = (0, 2), (0, 5)

# Uniform 1-D grids and their 2-D mesh (rows follow t, columns follow x).
x = np.linspace(*x_span, Nx)
t = np.linspace(*t_span, Nt)
mesh_x, mesh_t = np.meshgrid(x, t)

### Equation parameters

In [4]:
# KPP-Fisher problem as handed to the numerical reference solver.
# Every per-equation list below holds one entry for each value in `epsilons`.
epsilon = 0.5
epsilons = [epsilon]
D = 0.1


def u_0_function(x):
    """Initial profile u(x, 0) = sin(pi * x / 2); zero at x = 0 and x = 2."""
    return np.sin(np.pi * x / 2)


u_0 = [u_0_function] * len(epsilons)
forcing = [(lambda x, t: 0) for _ in epsilons]          # no source term
bcs = [[(lambda t: 0), (lambda t: 0)] for _ in epsilons]  # zero value at both ends
polynomial = [(lambda u: -u + u**2) for _ in epsilons]  # reaction polynomial -u + u^2

### PTL-PINN model

In [5]:
# Pretrained multi-head PINN checkpoint. Set the PTL_PINNS_MODEL_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
name = "model_KPP_Fisher_linear.pth"
path = os.environ.get(
    "PTL_PINNS_MODEL_DIR",
    "/home/dda24/PTL-PINNs/ptlpinns/models/train/KPP_Fisher_linear",
)

bias = True
pinn = model.Multihead_model_PDE(k=9, bias=bias)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine;
# load_state_dict then copies the tensors into pinn's existing parameters.
pinn.load_state_dict(torch.load(os.path.join(path, name), map_location="cpu"))

# Loss weights for the PDE residual, boundary and initial conditions.
w_pde, w_bc, w_ic = 1, 10, 10
# Domain length and final time, taken from the spans defined above (both start at 0),
# so the PTL collocation grid and compute_H_dict always see the same domain.
L, T = x_span[1], t_span[1]

def ic_sin(input):
    """PTL-PINN initial condition sin(3 * pi * x / L) ** 2.

    Reads x from column 0 of ``input`` and returns it as a column, i.e. shape (N, 1)
    when ``input`` is (N, 2) — TODO confirm the shape the transfer_kpp helpers pass.
    ``L`` is the module-level domain length.

    NOTE(review): this differs from ``u_0_function`` (sin(pi * x / 2)) given to the
    numerical solvers, so the PTL-PINN and RK45/Radau timings are for different
    initial data — confirm this is intended.
    """
    x_col = input[:, :1]
    return torch.sin(3 * torch.pi * x_col / L) ** 2

def constant_function(constant):
    '''Return a callable mapping an (x, t) torch tensor to a column filled with ``constant``.

    The returned function computes ``constant * torch.ones_like(input[:, :1])``: one row
    per input row, on the same device as ``input``. It can be passed wherever a forcing,
    boundary or initial-condition function is expected. For example,
    ``constant_function(0)`` gives a function that returns zeros for any input.
    '''
    def forcing_function_constant(input):
        ones_column = torch.ones_like(input[:, :1])
        return constant * ones_column
    return forcing_function_constant

# Problem data for the PTL-PINN solve.
ic = ic_sin
# Two zero-valued boundary functions, mirroring the zero `bcs` of the numerical
# solver (presumably left end then right end — TODO confirm ordering in transfer_kpp).
bcs_ptl = [constant_function(0) for _ in range(2)]

# Metadata handed to the transfer_kpp helpers.
training_log = dict(
    name=name,
    bias=bias,
    k=9,
    domain_info=dict(L=L, T=T),
    Nx=Nx,
    Nt=Nt,
    w_pde=w_pde,
    w_bc=w_bc,
    w_ic=w_ic,
)

# Interior grid from train_PDE. NOTE: this rebinds `x` and `t`, which previously held
# the NumPy grids of the domain cell. Column 0 of `input` is x, column 1 is t.
x, t, grid = train_PDE.generate_interior_tensor(
    IG=(Nx, Nt), x_span=(0, L), t_span=(0, T), require_grad=False
)
input = torch.stack((x, t), dim=1)

# First argument of the perturbation solver (presumably the number of perturbation
# terms — TODO confirm against transfer_kpp).
p = 5
# Presumably [coefficient, power] pairs encoding -u + u**2, matching `polynomial`.
polynomial_ptl = [[-1, 1], [1, 2]]
forcing_ptl = constant_function(0)

In [6]:
# One-off precomputation for the loaded PINN: compute_H_dict differentiates H
# (its progress messages appear in the cell output), then the IC and BC terms
# are added, and finally compute_M is run once before timing.
H_dict = transfer_kpp.compute_H_dict(
    model=pinn,
    IG=(Nx, Nt),
    Nic=Nx,
    Nbc=Nt,
    bias=bias,
    x_span=x_span,
    t_span=t_span,
    D=D,
    log=training_log,
)
# Each helper hands back an updated H_dict as its last return value.
_, H_dict = transfer_kpp.compute_R_ic(H_dict, ic_function=ic, w_ic=w_ic, log=training_log)
_, _, _, H_dict = transfer_kpp.compute_R_bcs(
    H_dict, boundary_functions=bcs_ptl, w_bc=w_bc, log=training_log
)
# Return values are discarded; presumably compute_M stores its result in H_dict — TODO confirm.
_, _, _ = transfer_kpp.compute_M(H_dict=H_dict, w_pde=w_pde, w_ic=w_ic, w_bc=w_bc)

Differentiating H w.r.t. x now...
Finished computing H2x.
Differentiating H w.r.t. t now...
Finished computing Ht


### Measure solve times (PTL-PINN vs. RK45 / Radau)

In [7]:
time_ptl_invert, time_ptl_no_invert = [], []

for j in range(N_ITER):
    if j % 200 == 0:
        print(f"{j}/{N_ITER}")

    # The solver's second return value is used as the solve time without the
    # compute_M step (presumably measured inside the helper — TODO confirm).
    _, elapsed_solve = transfer_kpp.compute_perturbation_solution_polynomial_complete(
        p,
        epsilon,
        H_dict=H_dict,
        input=input,
        training_log=training_log,
        forcing=forcing_ptl,
        Polynomial=polynomial_ptl,
        boundary_functions=bcs_ptl,
        ic_function=ic,
    )

    # Time compute_M separately and add it on to obtain the "with inversion" figure.
    tic = time.perf_counter()
    _, _, _ = transfer_kpp.compute_M(H_dict=H_dict, w_pde=w_pde, w_ic=w_ic, w_bc=w_bc)
    elapsed_invert = time.perf_counter() - tic

    time_ptl_no_invert.append(elapsed_solve)
    time_ptl_invert.append(elapsed_solve + elapsed_invert)

0/800
200/800
400/800
600/800


In [8]:
def _time_kpp_solve(method, atol=1e-3, rtol=1e-3):
    """Return the wall-clock seconds of one numerical KPP solve.

    Parameters
    ----------
    method : str
        Integrator name forwarded to ``numerical.solution_KPP`` (e.g. "RK45", "Radau").
    atol, rtol : float
        Absolute/relative tolerances forwarded to the solver.

    Only the solver call is inside the timed region; the solution itself is discarded.
    """
    start = time.perf_counter()
    numerical.solution_KPP(
        epsilons, D, polynomial, x_span, t_span, Nx, Nt, u_0, forcing, bcs,
        method=method, atol=atol, rtol=rtol,
    )
    return time.perf_counter() - start


time_RK45_KPP, time_Radau_KPP = [], []

for j in range(N_ITER):
    if j % 200 == 0:
        print(f"{j}/{N_ITER}")

    time_RK45_KPP.append(_time_kpp_solve("RK45"))
    time_Radau_KPP.append(_time_kpp_solve("Radau"))

0/800
200/800
400/800
600/800


In [9]:
# Mean wall-clock time (seconds) per solve for each approach.
print(" | ".join(
    f"{label}: {np.mean(samples)}"
    for label, samples in (
        ("PTL-PINN-INVERT", time_ptl_invert),
        ("PTL-PINN-NO-INVERT", time_ptl_no_invert),
        ("RK45", time_RK45_KPP),
        ("Radau", time_Radau_KPP),
    )
))

PTL-PINN-INVERT: 0.3111405244289153 | PTL-PINN-NO-INVERT: 0.015012532303908302 | RK45: 0.04776761804623675 | Radau: 0.006203890212473198
