## 1. Import libraries and generate the training set

In [1]:
from torch_relukan import ReLUKANLayer, ReLUKAN
from torch_hrkan import HRKANLayer, HRKAN
import torch
import matplotlib.pyplot as plt
from torch import autograd
import torch.nn as nn
import numpy as np
from scipy.integrate._ivp.radau import P

from matplotlib import cm
from matplotlib.ticker import LinearLocator

from kan import KAN, LBFGS
from tqdm import tqdm
import time

from fft_burgers import fft_burgers

# Resolution of the reference/test grid: the test mesh has nx+1 points in x
# and nt+1 points in t.
nx, nt = 199, 199

# Viscosity: passed to the FFT reference solver and multiplies the
# second-derivative term of the PDE residual.
nu = 0.001

# Bounds of the (x, t) rectangle used for sampling and evaluation.
xmin, xmax = 0, 2.5
tmin, tmax = 0, 2.5

# Grid spacings and noise scales derived from them
# (the noise scales are not referenced in the visible cells).
hx = (xmax - xmin) / nx
ht = (tmax - tmin) / nt
noise_xstd = hx / 4.0
noise_tstd = ht / 4.0

alpha = 0.05   # weight of the PDE residual term in the total loss
dim = 2        # not referenced in the visible cells
np_i = 100     # training mesh points per axis; np_i**2 random interior points
np_b = 100     # not referenced in the visible cells

def batch_jacobian(func, x, create_graph=False):
    """Return the per-sample Jacobian of ``func`` evaluated on a batch ``x``.

    The outputs of ``func`` are summed over the batch dimension (dim 0) before
    differentiating, and the batch axis of the resulting Jacobian is then moved
    to the front with ``permute(1, 0, 2)``. This yields one Jacobian per sample
    as long as each output row depends only on its own input row.

    Args:
        func: callable mapping the batch ``x`` to a tensor whose dim 0 is the batch.
        x: input batch; the final ``permute(1, 0, 2)`` requires the Jacobian to
            be 3-D, i.e. a 2-D ``x`` and a 2-D ``func(x)``.
        create_graph: forwarded to ``autograd.functional.jacobian`` so the
            result can itself be differentiated.

    Returns:
        Tensor indexed as (sample, output component, input component).
    """
    summed_over_batch = lambda inputs: func(inputs).sum(dim=0)
    stacked = autograd.functional.jacobian(
        summed_over_batch, x, create_graph=create_graph
    )
    return stacked.permute(1, 0, 2)

# interior
sampling_mode = 'random'


def helper(X, Y):
    """Flatten two equally-shaped coordinate tensors into an (N, 2) point list.

    Column 0 holds the flattened ``X`` values and column 1 the flattened ``Y``
    values.
    """
    return torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)


x_mesh = torch.linspace(xmin, xmax, np_i)
y_mesh = torch.linspace(tmin, tmax, np_i)
X, Y = torch.meshgrid(x_mesh, y_mesh, indexing="ij")
if sampling_mode == 'mesh':
    # Collocation points on the regular np_i x np_i mesh.
    x_i = helper(X, Y)
else:
    # np_i**2 collocation points drawn uniformly from [xmin, xmax] x [tmin, tmax].
    # The offsets keep the sampler correct when the domain does not start at 0
    # (with xmin = tmin = 0 the values are unchanged).
    x_i = torch.hstack([
        xmin + torch.rand((np_i**2, 1)) * (xmax - xmin),
        tmin + torch.rand((np_i**2, 1)) * (tmax - tmin),
    ])

# boundary, 4 sides
# With indexing="ij", X varies along dim 0 and Y along dim 1, so every row of Y
# equals y_mesh and every column of X equals x_mesh.
xb1 = helper(X[0], Y[0])        # x = xmin, all t
xb2 = helper(X[-1], Y[-1])      # x = xmax, all t (Y[-1] has the same values as Y[0])
xb3 = helper(X[:,0], Y[:,0])    # t = tmin, all x
xb4 = helper(X[:,0], Y[:,-1])   # t = tmax, all x (not used in the visible cells)
x_bx = torch.cat([xb1, xb2], dim=0)   # spatial boundary points
x_bt = xb3                            # initial-time points

## 2. Generate test-set

In [2]:
sampling_mode_test = 'mesh'

# Test mesh: every 8th point of an (nx*8)+1 point grid in x (nx+1 points) and
# nt+1 points in t, matching the grid that get_sol() samples.
x_test_mesh = torch.linspace(xmin, xmax, (nx*8)+1)[::8]
y_test_mesh = torch.linspace(tmin, tmax, nt+1)
X_test, Y_test = torch.meshgrid(x_test_mesh, y_test_mesh, indexing="ij")
if sampling_mode_test == 'mesh':
    # Evaluate the PDE residual on the regular test mesh.
    x_test_i = helper(X_test, Y_test)
else:
    # Uniform random points in [xmin, xmax] x [tmin, tmax]. Both columns must
    # have the same number of rows for hstack, so a single count is used
    # (equal to the original nx**2 == nt**2 when nx == nt).
    n_test_random = nx * nt
    x_test_i = torch.hstack([
        xmin + torch.rand((n_test_random, 1)) * (xmax - xmin),
        tmin + torch.rand((n_test_random, 1)) * (tmax - tmin),
    ])

# Boundary points of the test mesh (same layout as the training boundaries).
xb1_test = helper(X_test[0], Y_test[0])         # x = xmin, all t
xb2_test = helper(X_test[-1], Y_test[0])        # x = xmax, all t
xb3_test = helper(X_test[:,0], Y_test[:,0])     # t = tmin, all x
xb4_test = helper(X_test[:,0], Y_test[:,-1])    # t = tmax, all x
x_test_b = torch.cat([xb1_test, xb2_test, xb3_test, xb4_test], dim=0)
x_test_bx = torch.cat([xb1_test, xb2_test], dim=0)
x_test_bt = xb3_test

# NumPy copies of the mesh for plotting, and the flattened mesh points on
# which the models are compared against the reference solution.
X_test_np = X_test.clone().detach().numpy()
Y_test_np = Y_test.clone().detach().numpy()
x_test = helper(X_test, Y_test)

def get_sol(x, t):
    """Compute the reference solution with the FFT-based solver.

    NOTE: the ``x`` and ``t`` arguments are not used; the grid is rebuilt from
    the module-level bounds and resolutions. The solver runs on an
    (nx*8)+1 point x-grid and nt+1 time points, after which every 8th entry
    along the solver output's second axis is kept and the result is transposed.
    """
    x_fine = np.linspace(xmin, xmax, (nx*8)+1)
    t_grid = np.linspace(tmin, tmax, nt+1)
    fine_solution = fft_burgers(x_fine, t_grid, nu)
    return fine_solution[:, ::8].T

sol = get_sol(X_test, Y_test)

## 3. Check whether a GPU is available and move the data there

In [3]:
if torch.cuda.is_available():
    # Move every training and test point set to the GPU; `sol` is left where
    # it is and the MSE against it is computed on the CPU.
    x_i, x_bx, x_bt = x_i.cuda(), x_bx.cuda(), x_bt.cuda()
    x_test, x_test_i = x_test.cuda(), x_test_i.cuda()
    x_test_bx, x_test_bt = x_test_bx.cuda(), x_test_bt.cuda()

## 4. Define the function to train ReLUKAN and HRKAN

In [5]:
def train_model(model, n_epochs=3000):
    """Train ``model`` on the Burgers PINN loss with Adam and track metrics.

    Each epoch minimises ``alpha * pde_loss + bc_loss`` on the training points
    (``x_i``, ``x_bx``, ``x_bt``), then evaluates the same two loss terms on the
    test points and the squared error against the reference solution ``sol``.

    The PDE residual used is ``u_t + u * u_x / 4 - nu * u_xx / 16``
    (the 1/4 and 1/16 factors presumably come from a rescaling of x --
    TODO confirm against the reference solver).

    Args:
        model: network mapping (N, 2) points ``(x, t)`` to the solution value;
            assumed to return an (N, 1) tensor -- TODO confirm.
        n_epochs: number of Adam steps. Defaults to 3000, which is the length
            the plotting cell expects.

    Returns:
        Tuple ``(output, losses, pde_losses, bc_losses, pde_losses_test,
        bc_losses_test, l2_losses_test, l2_losses_std_test, elapsed)`` where
        ``output`` is the final prediction on the test mesh reshaped like
        ``X_test_np``, the lists hold one value per epoch, and ``elapsed`` is
        the wall-clock time of the loop (training plus per-epoch evaluation).
    """
    opt = torch.optim.Adam(model.parameters())

    plt.ion()
    losses = []
    pde_losses = []
    bc_losses = []
    pde_losses_test = []
    bc_losses_test = []
    l2_losses_test = []
    l2_losses_std_test= []

    def sol_D1_fun(x):
        # Gradient of output component 0 w.r.t. the inputs: columns (u_x, u_t).
        return batch_jacobian(model, x, create_graph=True)[:,0,:]

    def sol_D1_fun_x(x):
        # Only the x-derivative, kept as a column so it can be differentiated again.
        return sol_D1_fun(x)[:,[0]]

    start = time.time()

    for e in range(n_epochs):
        opt.zero_grad()

        # One Jacobian pass provides both first derivatives.
        sol_D1 = sol_D1_fun(x_i)
        sol_D1_x = sol_D1[:,[0]]
        sol_D1_t = sol_D1[:,[1]]

        # d(u_x)/dx, i.e. u_xx.
        sol_D2 = batch_jacobian(sol_D1_fun_x, x_i, create_graph=True)[:,:,0]

        pde_loss = torch.mean((sol_D1_t + model(x_i) * sol_D1_x/4 - nu * sol_D2/16)**2)
        bc_loss = torch.mean((model(x_bx))**2) + torch.mean((model(x_bt)-1/torch.cosh(4*x_bt[:,[0]]-5))**2)
        loss = alpha * pde_loss + bc_loss
        loss.backward()
        opt.step()

        with torch.no_grad():

            sol_D1_test = sol_D1_fun(x_test_i)
            sol_D1_x_test = sol_D1_test[:,[0]]
            sol_D1_t_test = sol_D1_test[:,[1]]

            sol_D2_test = batch_jacobian(sol_D1_fun_x, x_test_i, create_graph=False)[:,:,0]

            pde_loss_test = torch.mean((sol_D1_t_test + model(x_test_i) * sol_D1_x_test/4 - nu * sol_D2_test/16)**2)
            bc_loss_test = torch.mean((model(x_test_bx))**2) + torch.mean((model(x_test_bt)-1/torch.cosh(4*x_test_bt[:,[0]]-5))**2)

            # Squared error against the reference solution; one forward pass
            # feeds both the mean and the standard deviation.
            sq_err_test = (model(x_test).cpu().clone().detach() - sol.reshape(-1,1))**2
            l2_test = torch.mean(sq_err_test)
            l2_test_std = torch.std(sq_err_test)

            # `losses` was previously returned empty; record the total
            # training loss so the returned list is meaningful.
            losses.append(loss.cpu().detach().numpy())
            pde_losses.append(pde_loss.cpu().detach().numpy())
            bc_losses.append(bc_loss.cpu().detach().numpy())
            pde_losses_test.append(pde_loss_test.cpu().detach().numpy())
            bc_losses_test.append(bc_loss_test.cpu().detach().numpy())
            l2_losses_test.append(l2_test.cpu().detach().numpy())
            l2_losses_std_test.append(l2_test_std.cpu().detach().numpy())

    elapsed = (time.time() - start)

    with torch.no_grad():
        output = model(x_test).cpu().clone().detach().numpy().reshape(X_test_np.shape)

    return output, losses, pde_losses, bc_losses, pde_losses_test, bc_losses_test, l2_losses_test, l2_losses_std_test, elapsed

## 5. Define the function to train KAN

In [6]:
def train_kan(model, n_epochs=3000):
    """Train a KAN ``model`` on the Burgers PINN loss with Adam and track metrics.

    Same procedure as ``train_model``, with one addition: during the first 50
    epochs the KAN grid is refreshed from the training points every 5 epochs
    via ``model.update_grid_from_samples``.

    The PDE residual used is ``u_t + u * u_x / 4 - nu * u_xx / 16``
    (the 1/4 and 1/16 factors presumably come from a rescaling of x --
    TODO confirm against the reference solver).

    Args:
        model: KAN mapping (N, 2) points ``(x, t)`` to the solution value;
            assumed to return an (N, 1) tensor -- TODO confirm.
        n_epochs: number of Adam steps. Defaults to 3000, which is the length
            the plotting cell expects.

    Returns:
        Tuple ``(output, losses, pde_losses, bc_losses, pde_losses_test,
        bc_losses_test, l2_losses_test, l2_losses_std_test, elapsed)`` where
        ``output`` is the final prediction on the test mesh reshaped like
        ``X_test_np``, the lists hold one value per epoch, and ``elapsed`` is
        the wall-clock time of the loop (training plus per-epoch evaluation).
    """
    opt = torch.optim.Adam(model.parameters())

    plt.ion()
    losses = []
    pde_losses = []
    bc_losses = []
    pde_losses_test = []
    bc_losses_test = []
    l2_losses_test = []
    l2_losses_std_test= []

    def sol_D1_fun(x):
        # Gradient of output component 0 w.r.t. the inputs: columns (u_x, u_t).
        return batch_jacobian(model, x, create_graph=True)[:,0,:]

    def sol_D1_fun_x(x):
        # Only the x-derivative, kept as a column so it can be differentiated again.
        return sol_D1_fun(x)[:,[0]]

    start = time.time()

    for e in range(n_epochs):

        if e % 5 == 0 and e < 50:
            model.update_grid_from_samples(x_i)

        opt.zero_grad()

        # One Jacobian pass provides both first derivatives.
        sol_D1 = sol_D1_fun(x_i)
        sol_D1_x = sol_D1[:,[0]]
        sol_D1_t = sol_D1[:,[1]]

        # Differentiate u_x only. Differentiating the full gradient (as was
        # done before) yields an (N, 2) tensor [u_xx, u_tx], and the mixed
        # derivative was then broadcast into the PDE residual.
        sol_D2 = batch_jacobian(sol_D1_fun_x, x_i, create_graph=True)[:,:,0]

        pde_loss = torch.mean((sol_D1_t + model(x_i) * sol_D1_x/4 - nu * sol_D2/16)**2)
        bc_loss = torch.mean((model(x_bx))**2) + torch.mean((model(x_bt)-1/torch.cosh(4*x_bt[:,[0]]-5))**2)
        loss = alpha * pde_loss + bc_loss
        loss.backward()
        opt.step()

        with torch.no_grad():

            sol_D1_test = sol_D1_fun(x_test_i)
            sol_D1_x_test = sol_D1_test[:,[0]]
            sol_D1_t_test = sol_D1_test[:,[1]]

            sol_D2_test = batch_jacobian(sol_D1_fun_x, x_test_i, create_graph=False)[:,:,0]

            pde_loss_test = torch.mean((sol_D1_t_test + model(x_test_i) * sol_D1_x_test/4 - nu * sol_D2_test/16)**2)
            bc_loss_test = torch.mean((model(x_test_bx))**2) + torch.mean((model(x_test_bt)-1/torch.cosh(4*x_test_bt[:,[0]]-5))**2)

            # Squared error against the reference solution; one forward pass
            # feeds both the mean and the standard deviation.
            sq_err_test = (model(x_test).cpu().clone().detach() - sol.reshape(-1,1))**2
            l2_test = torch.mean(sq_err_test)
            l2_test_std = torch.std(sq_err_test)

            # `losses` was previously returned empty; record the total
            # training loss so the returned list is meaningful.
            losses.append(loss.cpu().detach().numpy())
            pde_losses.append(pde_loss.cpu().detach().numpy())
            bc_losses.append(bc_loss.cpu().detach().numpy())
            pde_losses_test.append(pde_loss_test.cpu().detach().numpy())
            bc_losses_test.append(bc_loss_test.cpu().detach().numpy())
            l2_losses_test.append(l2_test.cpu().detach().numpy())
            l2_losses_std_test.append(l2_test_std.cpu().detach().numpy())

    elapsed = (time.time() - start)

    with torch.no_grad():
        output = model(x_test).cpu().clone().detach().numpy().reshape(X_test_np.shape)

    return output, losses, pde_losses, bc_losses, pde_losses_test, bc_losses_test, l2_losses_test, l2_losses_std_test, elapsed

## 6. Define the function to plot ground-truth solution, learnt solutions and their residual difference

In [9]:
def plot_fig(relu_kan_, hrkan_, kan_, i):
    """Save a 2x4 grid of 3-D surfaces comparing the three models with the truth.

    Row 0 shows the ground-truth solution followed by the ReLU-KAN, HRKAN and
    KAN predictions; row 1 shows, under each prediction, the residual
    ``truth - prediction`` (the panel under the ground truth is left empty).

    Args:
        relu_kan_, hrkan_, kan_: result tuples whose element 0 is the
            prediction on the test mesh, shaped like ``X_test_np``.
        i: run index, used only in the output file name
            ``Burgersv_fig_{i}.png``.
    """
    def _draw_surface(ax, values, title):
        # One surface panel with the z-range, ticks and labels shared by all panels.
        surface = ax.plot_surface(X_test_np, Y_test_np, values, cmap=cm.coolwarm,linewidth=0, antialiased=False)
        ax.set_zlim(-1.01, 1.01)
        ax.zaxis.set_major_locator(LinearLocator(10))
        ax.zaxis.set_major_formatter('{x:.02f}')
        ax.set(xlabel='x', ylabel='t', zlabel='u(x,t)', title=title)
        return surface

    fig, axs = plt.subplots(2, 4, figsize=(25,8), subplot_kw={"projection": "3d"})
    fig.suptitle('Solutions and their residual difference')

    # Reshape the reference once so the ground-truth panel and every residual
    # use the same array (the residuals previously used the un-reshaped `sol`).
    truth = sol.reshape(X_test_np.shape)
    truth_surface = _draw_surface(axs[0,0], truth, "Ground-truth")

    model_panels = (
        ("ReLU-Kan", relu_kan_[0]),
        ("HRKan", hrkan_[0]),
        ("Kan", kan_[0]),
    )
    for column, (label, prediction) in enumerate(model_panels, start=1):
        _draw_surface(axs[0,column], prediction, f"{label} solution")
        _draw_surface(axs[1,column], truth - prediction, f"{label} residual")

    fig.colorbar(truth_surface, ax=axs, orientation='vertical')

    # Save and close this specific figure rather than the implicit current
    # figure, so figures do not accumulate across runs.
    fig.savefig(f'Burgersv_fig_{i}.png')
    plt.close(fig)



## 7. Define the function to record the losses, MSE, MSE std. and training time of each run

In [10]:
# Per-model accumulators; each gets one entry per training run.
relu_kan_loss, relu_kan_loss_test, relu_kan_L2s, relu_kan_L2s_std, relu_kan_time = [], [], [], [], []
hrkan_loss, hrkan_loss_test, hrkan_L2s, hrkan_L2s_std, hrkan_time = [], [], [], [], []
kan_loss, kan_loss_test, kan_L2s, kan_L2s_std, kan_time = [], [], [], [], []
def cal_error(relu_kan_, hrkan_, kan_, i):
    """Append one run's metrics for each model to the module-level accumulators.

    Each result tuple is read positionally: elements 2/3 are the training
    PDE/boundary losses, 4/5 their test counterparts, and the last three are
    the test MSE history, its std. history and the elapsed time. The total
    loss histories are rebuilt as ``alpha * pde + bc``.

    Args:
        relu_kan_, hrkan_, kan_: result tuples returned by the training functions.
        i: run index (not used).
    """
    def _weighted_total(pde_terms, bc_terms):
        return [alpha * pde + bc for pde, bc in zip(pde_terms, bc_terms)]

    destinations = (
        (relu_kan_, relu_kan_loss, relu_kan_loss_test, relu_kan_L2s, relu_kan_L2s_std, relu_kan_time),
        (hrkan_, hrkan_loss, hrkan_loss_test, hrkan_L2s, hrkan_L2s_std, hrkan_time),
        (kan_, kan_loss, kan_loss_test, kan_L2s, kan_L2s_std, kan_time),
    )
    for results, train_log, test_log, mse_log, mse_std_log, time_log in destinations:
        train_log.append(_weighted_total(results[2], results[3]))
        test_log.append(_weighted_total(results[4], results[5]))
        mse_log.append(results[-3])
        mse_std_log.append(results[-2])
        time_log.append(results[-1])

## 8. Train the models

In [1]:
relu_kans = []
hrkans = []
kans = []

# Use the GPU only when one is present, matching the guarded data transfer
# in section 3; otherwise the models stay on the CPU alongside the data.
use_cuda = torch.cuda.is_available()
kan_device = 'cuda' if use_cuda else 'cpu'

for i in range(10):
    print(i)

    relu_kan = ReLUKAN([2,3,3,3,1], 7, 3, 0, 2.5)
    if use_cuda:
        relu_kan = relu_kan.cuda()
    relu_kan_results = train_model(relu_kan)
    del relu_kan  # drop the reference before building the next model

    hrkan = HRKAN([2,3,3,3,1], 7, 3, 0, 2.5, 4)
    if use_cuda:
        hrkan = hrkan.cuda()
    hrkan_results = train_model(hrkan)
    del hrkan

    kan = KAN(width=[2,3,3,3,1], grid=7, k=3, grid_eps=1.0, device=kan_device)
    kan_results = train_kan(kan)
    # This used to be a second `del hrkan`, which raises NameError because
    # hrkan was already deleted above; the KAN is what must be released here.
    del kan

    plot_fig(relu_kan_results, hrkan_results, kan_results, i)
    cal_error(relu_kan_results, hrkan_results, kan_results, i)

## 8. Plot the median losses and MSE

In [14]:
# Row index (after sorting the runs per epoch) used as the "median" curve, and
# the first/last sorted rows used as the min-max band.
k_median=5
q1_quantile = 0
q2_quantile = -1
fig, axs = plt.subplots(3, 3, figsize=(25,18))
fig.suptitle('Loss and accuracy (median and max-min-band of 10 runs)')

# One figure row per metric; models always in the order HRKan, ReLU-Kan, Kan.
band_rows = (
    ((hrkan_loss, relu_kan_loss, kan_loss), 'loss', "Training loss"),
    ((hrkan_loss_test, relu_kan_loss_test, kan_loss_test), 'loss', "Test loss"),
    ((hrkan_L2s, relu_kan_L2s, kan_L2s), 'MSE', "Test MSE"),
)
# One figure column per starting epoch, zooming in on the tail of training.
band_first_epochs = (0, 1000, 2000)

for band_row, (band_histories, band_ylabel, band_title) in enumerate(band_rows):
    # Sort across runs independently at every epoch (axis 0 is the run axis).
    band_sorted = [np.sort(np.array(history), axis=0) for history in band_histories]
    for band_col, band_start in enumerate(band_first_epochs):
        band_ax = axs[band_row, band_col]
        band_epochs = np.arange(band_start, 3000)
        # Lines first, then bands, so the three legend labels attach to the lines.
        for band_runs in band_sorted:
            band_ax.plot(band_epochs, band_runs[k_median, band_start:])
        for band_runs in band_sorted:
            band_ax.fill_between(band_epochs, band_runs[q1_quantile, band_start:], band_runs[q2_quantile, band_start:], alpha=0.3)
        band_ax.legend(['HRKan', 'ReLU-Kan', 'Kan'])
        band_ax.set(xlabel='epoch', ylabel=band_ylabel, title=band_title)

plt.savefig(f'Burgersv_loss_band.png', dpi=400)
plt.close()

## 9. Calculate mean MSE, mean MSE std. and mean training time

In [2]:
# For each model print: mean final-epoch test MSE, mean final-epoch MSE std.,
# and mean training time, each averaged over all runs. The time was previously
# indexed with [-1], which reported only the last run's time, not the mean.
for final_l2_runs, final_l2_std_runs, run_times in (
    (relu_kan_L2s, relu_kan_L2s_std, relu_kan_time),
    (hrkan_L2s, hrkan_L2s_std, hrkan_time),
    (kan_L2s, kan_L2s_std, kan_time),
):
    print(
        "{0:.5g}".format(np.array(final_l2_runs)[:,-1].mean()),
        "{0:.5g}".format(np.array(final_l2_std_runs)[:,-1].mean()),
        "{0:.5g}".format(np.array(run_times).mean()),
    )