In [172]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [173]:
%cd /content/drive/MyDrive/Github/

/content/drive/MyDrive/Github


In [174]:
!git clone https://github.com/jaysulk/PINO_Applications

fatal: destination path 'PINO_Applications' already exists and is not an empty directory.


In [175]:
%cd /content/drive/MyDrive/Github/PINO_Applications

/content/drive/MyDrive/Github/PINO_Applications


In [176]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"

In [177]:
!pip install functorch



In [178]:
!pip install Mat73



In [179]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
%matplotlib notebook
from argparse import ArgumentParser
import yaml
import os
import torch
# from torch import vmap
from functorch import vmap, grad

from models import FNN2d, FNN2d_AD
from train_utils import Adam
# from train_utils.datasets import BurgersLoader'
# from train_utils.train_2d import train_2d_burger
# from train_utils.eval_2d import eval_burgers

import traceback

import scipy.io
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

import imageio

import torch.nn.functional as F

from tqdm import tqdm
from train_utils.utils import get_grid, save_checkpoint, torch2dgrid, load_checkpoint, load_config, update_config
from train_utils.losses import LpLoss
from train_utils.datasets import DataLoader1D
# from utils import torch2dgrid
from solver.my_random_fields import GRF_Mattern

from importlib import reload

try:
    import wandb
except ImportError:
    wandb = None
import pickle

In [180]:
# Dummy functions for Q and w
def dummy_Q(a, b, T):
    return a + b * T
    #return torch.zeros_like(T)

def dummy_w(c, d, z):
    return c + d * z**2

def hydrostatic_pressure(rho, g, z, p0):
    return p0 - rho * g * z

In [181]:
def central_difference(data, axis, dz):
    data_m1 = torch.roll(data, shifts=1, dims=axis)
    data_p1 = torch.roll(data, shifts=-1, dims=axis)
    data_diff = (data_p1 - data_m1) / (2.0 * dz)
    return data_diff

In [182]:
def FDM_ThermodynamicEnergy(T, D=1, dt=0.1, dz=0.1, a=1, b=0, c=1, d=0, rho=1.0, g=9.81, cp=1005):
    batchsize = T.size(0)
    nt = T.size(1)
    nz = T.size(2)

    # FFT of T to get Fourier representation
    #T_h = torch.fft.fft(T, dim=2)

    # Create wave numbers for FFT
    #k_max = nz // 2
    #k_z = torch.cat((torch.arange(start=0, end=k_max, step=1, device=T.device),
    #                 torch.arange(start=-k_max, end=0, step=1, device=T.device)), 0).reshape(1, 1, nz)

    # Calculate Fourier representation of spatial derivative dT/dz
    #Tz_h = 2j * np.pi * k_z * T_h
    #Tz = torch.fft.irfft(Tz_h[:, :, :k_max+1], dim=2, n=nz)
    Tz = central_difference(T, axis=2, dz=dz)
    #Tz = (T[:, 2:, :] - T[:, :-2, :]) / (2 * dz)

    # Use centered difference to calculate temporal derivative dT/dt
    Tt = central_difference(T, axis=2, dz=dt)
    #Tt = (T[:, 2:, :] - T[:, :-2, :]) / (2 * dt)

    # Calculate Q using dummy_Q
    Q = dummy_Q(a, b, T)

    # Calculate w_values (use a better function in your code)
    w_values = dummy_w(c, d, torch.arange(nz, device=T.device) * dz).unsqueeze(0).unsqueeze(1)
    w_values = w_values.repeat(batchsize, nt, 1)

    # Calculate the Fourier representation of the spatial derivative dw/dz
    #w_h = torch.fft.fft(w_values, dim=2)
    #wz_h = 2j * np.pi * k_z * w_h
    #wz = torch.fft.irfft(wz_h[:, :, :k_max+1], dim=2, n=nz)
    wz = central_difference(w_values, axis=2, dz=dz)
    #wz = (w_values[:, 2:, :] - w_values[:, :-2, :]) / (2 * dz)

    # Equation using hydrostatic balance to determine change in T with respect to t
    DT = Tt + (-w_values * Tz + (g * torch.arange(nz, device=T.device) * dz / cp)[None, None, :] * wz + Q)

    return DT

def PINO_loss_ThermodynamicEnergy(T, T0, dz=0.1, a=1, b=0, c=1, d=0, rho=1.0, g=9.81, cp=1005):
    # Get the size of the input tensor T
    batchsize, nt, nz = T.shape

    # Extract boundary temperatures based on the initial time conditions
    boundary_T = T[:, 0, :]

    # Calculate the MSE loss between the boundary temperature and the initial temperature
    loss_T = F.mse_loss(boundary_T, T0)

    # Calculate the change in thermodynamic energy using the FDM_ThermodynamicEnergy function
    # Note: Make sure to pass in the additional parameters for pressure calculation based on hydrostatic equation
    DT = FDM_ThermodynamicEnergy(T, D=1, dt=0.1, dz=dz, a=a, b=b, c=c, d=d, rho=rho, g=g, cp=cp)

    # Calculate the MSE loss for the residuals
    f = torch.zeros_like(DT)
    loss_f = F.mse_loss(DT, f)

    return loss_T, loss_f


In [183]:
import torch
import matplotlib.pyplot as plt

class ThermodynamicsEq1D():
    def __init__(self, w=0.01, a=1.0, b=0.0, c=1.0, d=0.0, g=9.8, cp=1005, zmin=0, zmax=1, Nz=100, dt=1e-3, tend=1.0, device=None, dtype=torch.float64):
        self.zmin = zmin
        self.zmax = zmax
        self.Nz = Nz
        z = torch.linspace(zmin, zmax, Nz + 1, device=device, dtype=dtype)[:-1]
        self.z = z
        self.dz = z[1] - z[0]
        self.u = torch.zeros_like(z, device=device)
        self.u0 = torch.zeros_like(self.u, device=device)
        self.dt = dt
        self.tend = tend
        self.t = 0
        self.it = 0
        self.U = []
        self.T = []
        self.device = device
        self.w = w
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        self.g = g
        self.cp = cp

    def CD_i(self, data):
        data_m1 = torch.roll(data, shifts=1, dims=0)
        data_p1 = torch.roll(data, shifts=-1, dims=0)
        data_diff_i = (data_p1 - data_m1) / (2.0 * self.dz)
        return data_diff_i

    def Dz(self, data):
        data_dz = self.CD_i(data=data, axis=0, dz=self.dz)
        return data_dz

    def w_func(self, z):
        return dummy_w(self.c, self.d, z)

    def Q_func(self, T):
        return dummy_Q(self.a, self.b, T)

    def thermodynamics_calc_RHS(self, T):
        T_z = self.CD_i(T)
        w_z = self.CD_i(torch.full_like(T, self.w))
        T_RHS = -self.w * T_z + self.g * self.z / self.cp * w_z + self.Q_func(T)
        return T_RHS

    def update_field(self, field, RHS, step_frac):
        field_new = field + self.dt * step_frac * RHS
        return field_new

    def rk4_merge_RHS(self, field, RHS1, RHS2, RHS3, RHS4):
        field_new = field + self.dt / 6.0 * (RHS1 + 2 * RHS2 + 2.0 * RHS3 + RHS4)
        return field_new

    def thermodynamics_rk4(self, T, t=0):
        T_RHS1 = self.thermodynamics_calc_RHS(T)
        t1 = t + 0.5 * self.dt
        T1 = self.update_field(T, T_RHS1, step_frac=0.5)

        T_RHS2 = self.thermodynamics_calc_RHS(T1)
        t2 = t + 0.5 * self.dt
        T2 = self.update_field(T, T_RHS2, step_frac=0.5)

        T_RHS3 = self.thermodynamics_calc_RHS(T2)
        t3 = t + self.dt
        T3 = self.update_field(T, T_RHS3, step_frac=1.0)

        T_RHS4 = self.thermodynamics_calc_RHS(T3)

        t_new = t + self.dt
        T_new = self.rk4_merge_RHS(T, T_RHS1, T_RHS2, T_RHS3, T_RHS4)

        return T_new, t_new

    def plot_data(self, cmap='jet', vmin=None, vmax=None, fig_num=0, title='', xlabel='', ylabel=''):
        plt.ion()
        fig = plt.figure(fig_num)
        plt.cla()
        plt.clf()
        plt.plot(self.z, self.u)
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.draw()
        plt.pause(1e-17)
        plt.show()

    def thermodynamics_driver(self, T0, save_interval=10, plot_interval=0):
        self.u0 = T0[:self.Nz]
        self.u = self.u0
        self.t = 0
        self.it = 0
        self.T = []
        self.U = []

        if plot_interval != 0 and self.it % plot_interval == 0:
            self.plot_data(vmin=-1, vmax=1, title=r'\{T}')
        if save_interval != 0 and self.it % save_interval == 0:
            self.U.append(self.u)
            self.T.append(self.t)

        while self.t < self.tend:
            self.u, self.t = self.thermodynamics_rk4(self.u, self.t)
            self.it += 1
            if plot_interval != 0 and self.it % plot_interval == 0:
                self.plot_data(vmin=-1, vmax=1, title=r'\{T}')
            if save_interval != 0 and self.it % save_interval == 0:
                self.U.append(self.u)
                self.T.append(self.t)

        return torch.stack(self.U)

In [184]:
from tqdm import tqdm
# Make sure to import wandb, PINO_loss_ThermodynamicEnergy, LpLoss and other dependencies

def train_ThermodynamicEnergy(model,
                              train_loader,
                              optimizer,
                              scheduler,
                              config,
                              rank=0,
                              log=False,
                              project='PINO-2d-default',
                              group='default',
                              tags=['default'],
                              use_tqdm=True):
    if rank == 0 and log:
        import wandb  # Make sure wandb is installed
        run = wandb.init(project=project,
                         entity='shawngr2',
                         group=group,
                         config=config,
                         tags=tags, reinit=True,
                         settings=wandb.Settings(start_method="fork"))

    # Extract weights and other parameters from the config dictionary
    data_weight = config.get('train', {}).get('xy_loss', 1.0)
    f_weight = config.get('train', {}).get('f_loss', 1.0)
    ic_weight = config.get('train', {}).get('ic_loss', 1.0)
    ckpt_freq = config.get('train', {}).get('ckpt_freq', 10)

    model.train()
    myloss = LpLoss(size_average=True)  # Assuming LpLoss is defined

    pbar = range(config['train']['epochs'])
    if use_tqdm:
        pbar = tqdm(pbar, dynamic_ncols=True, smoothing=0.1)

    for e in pbar:
        model.train()
        train_pino = 0.0
        data_l2 = 0.0
        train_loss = 0.0

        for x, y in train_loader:
            x, y = x.to(rank), y.to(rank)
            out = model(x).reshape(y.shape)
            data_loss = myloss(out, y)

            loss_u, loss_f = PINO_loss_ThermodynamicEnergy(out, x[:, 0, :, 0])  # Ensure w(z) is implemented in the loss function
            total_loss = loss_u * ic_weight + loss_f * f_weight + data_loss * data_weight

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            data_l2 += data_loss.item()
            train_pino += loss_f.item()
            train_loss += total_loss.item()

        scheduler.step()
        data_l2 /= len(train_loader)
        train_pino /= len(train_loader)
        train_loss /= len(train_loader)

        if use_tqdm:
            pbar.set_description(
                (
                    f'Epoch {e}, train loss: {train_loss:.5f} '
                    f'train f error: {train_pino:.5f}; '
                    f'data l2 error: {data_l2:.5f}'
                )
            )

        if log:
            wandb.log(
                {
                    'Train f error': train_pino,
                    'Train L2 error': data_l2,
                    'Train loss': train_loss,
                }
            )

        if e % ckpt_freq == 0:
            save_checkpoint(config['train']['save_dir'],
                            config['train']['save_name'].replace('.pt', f'_{e}.pt'),
                            model, optimizer)

    save_checkpoint(config['train']['save_dir'],
                    config['train']['save_name'],
                    model, optimizer)
    print('Done!')

In [185]:
import torch
import numpy as np
from tqdm import tqdm
# Make sure to import LpLoss, PINO_loss_ThermodynamicEnergy, etc.

def eval_ThermodynamicEnergy(model,
                             dataloader,
                             config,
                             device,
                             use_tqdm=True):
    model.eval()
    model.to(device)
    myloss = LpLoss(size_average=True)  # Assuming LpLoss is defined
    nu = config.get('data', {}).get('nu', 0.01)

    if use_tqdm:
        pbar = tqdm(dataloader, dynamic_ncols=True, smoothing=0.05)
    else:
        pbar = dataloader

    test_err = []
    f_err = []

    with torch.no_grad():
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            out = model(x).reshape(y.shape)
            data_loss = myloss(out, y)

            loss_u, f_loss = PINO_loss_ThermodynamicEnergy(out, x[:, 0, :, 0])
            test_err.append(data_loss.item())
            f_err.append(f_loss.item())

    mean_f_err = np.mean(f_err)
    std_f_err = np.std(f_err, ddof=1) / np.sqrt(len(f_err))

    mean_err = np.mean(test_err)
    std_err = np.std(test_err, ddof=1) / np.sqrt(len(test_err))

    print(f'==Averaged relative L2 error mean: {mean_err}, std error: {std_err}==\n'
          f'==Averaged equation error mean: {mean_f_err}, std error: {std_f_err}==')


In [186]:
config_file = 'configs/custom/TDE-0000.yaml'
config = load_config(config_file)
display(config)

{'data': {'name': 'TDE-0000',
  'total_num': 100,
  'n_train': 90,
  'n_test': 10,
  'nx': 128,
  'nt': 100,
  'sub': 1,
  'sub_t': 1,
  'nu': 0.01},
 'model': {'layers': [16, 24, 24, 32, 32],
  'modes1': [15, 12, 9, 9],
  'modes2': [15, 12, 9, 9],
  'fc_dim': 128,
  'activation': 'gelu'},
 'train': {'batchsize': 20,
  'epochs': 500,
  'milestones': [100, 200, 300, 400, 500],
  'base_lr': 0.001,
  'scheduler_gamma': 0.5,
  'ic_loss': 1.0,
  'f_loss': 1.0,
  'xy_loss': 1.0,
  'save_dir': 'TDE',
  'save_name': 'TDE-0000.pt',
  'ckpt': 'checkpoints/TDE/TDE-0000.pt',
  'ckpt_freq': 100},
 'log': {'project': 'PINO-TDE', 'group': 'TDE-0000'},
 'test': {'batchsize': 1, 'ckpt': 'checkpoints/TDE/TDE-0000.pt'}}

In [187]:
Nsamples = config['data']['total_num']
N = config['data']['nx']
Nt0 = config['data']['nt']
nu = config['data']['nu']
sub_x = config['data']['sub']
sub_t = config['data']['sub_t']
Nx = N // sub_x
Nt = Nt0 // sub_t + 1
dim = 1
l = 0.1
L = 1.0
sigma = 1 #2.0
Nu = None # 2.0
dt = 1.0e-4
tend = 1.0
save_int = int(tend/dt/Nt)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [188]:
Nu

In [189]:
grf = GRF_Mattern(dim, N, length=L, nu=Nu, l=l, sigma=sigma, boundary="periodic", device=device)
#U0 = grf.sample(Nsamples)
#with open('U0.pkl', 'wb') as f:
#  pickle.dump(U0, f)

In [190]:
# Save the object to a file
with open('../../T0.pkl', 'rb') as f:
  U0 = pickle.load(f)

In [191]:
U0.shape

torch.Size([100, 128])

In [192]:
nu

0.01

In [193]:
TD_eq = ThermodynamicsEq1D(Nz=Nx, dt=dt, device=device)
save_interval = int(1e-2/dt)
U = vmap(TD_eq.thermodynamics_driver, in_dims=(0, None))(U0, save_interval)

In [194]:
a = U0.cpu().float()
u = U.cpu().float()
display(u.shape,a.shape)

torch.Size([100, 101, 128])

torch.Size([100, 128])

In [195]:
dataset = DataLoader1D(a, u, config['data']['nx'], config['data']['nt'])
train_loader = dataset.make_loader(config['data']['n_train'], config['train']['batchsize'], start=0, train=True)
test_loader = dataset.make_loader(config['data']['n_test'], config['test']['batchsize'], start=config['data']['n_train'], train=False)

In [196]:
model = FNN2d(modes1=config['model']['modes1'],
              modes2=config['model']['modes2'],
              fc_dim=config['model']['fc_dim'],
              layers=config['model']['layers'],
              activation=config['model']['activation'],
             ).to(device)

In [197]:
log = False

#optimizer = Adam(model.parameters(), betas=(0.9, 0.999),lr=config['train']['base_lr'])
optimizer = Adam(model.parameters(), betas=(0.9, 0.999),lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=config['train']['milestones'],
                                                 gamma=config['train']['scheduler_gamma'])

In [198]:
#load_checkpoint(model, ckpt_path=config['train']['ckpt'], optimizer=None)

In [None]:
train_ThermodynamicEnergy(model,
              train_loader,
              optimizer,
              scheduler,
              config,
              rank=0,
              log=log,
              project=config['log']['project'],
              group=config['log']['group'])

Epoch 0, train loss: 410581.45630 train f error: 1.00000; data l2 error: 405.36213:   0%|          | 1/500 [00:00<03:55,  2.12it/s]

Checkpoint is saved at checkpoints/TDE/TDE-0000_0.pt


Epoch 100, train loss: 1.74529 train f error: 1.00000; data l2 error: 0.60170:  20%|██        | 101/500 [00:32<02:00,  3.32it/s]

Checkpoint is saved at checkpoints/TDE/TDE-0000_100.pt


Epoch 200, train loss: 1.74443 train f error: 1.00000; data l2 error: 0.60145:  40%|████      | 201/500 [01:01<01:30,  3.29it/s]

Checkpoint is saved at checkpoints/TDE/TDE-0000_200.pt


Epoch 300, train loss: 1.74151 train f error: 1.00000; data l2 error: 0.59538:  60%|██████    | 301/500 [01:31<00:59,  3.32it/s]

Checkpoint is saved at checkpoints/TDE/TDE-0000_300.pt


Epoch 400, train loss: 1.74593 train f error: 1.00000; data l2 error: 0.60033:  80%|████████  | 401/500 [02:01<00:30,  3.27it/s]

Checkpoint is saved at checkpoints/TDE/TDE-0000_400.pt


Epoch 428, train loss: 1.74468 train f error: 1.00000; data l2 error: 0.59884:  86%|████████▌ | 429/500 [02:09<00:21,  3.38it/s]

In [None]:
eval_ThermodynamicEnergy(model, test_loader, config, device)

In [None]:
Nx = config['data']['nx']
Nt = config['data']['nt'] + 1
N = config['data']['n_test']
model.eval()
test_x = np.zeros((N,Nt,Nx,3))
preds_y = np.zeros((N,Nt,Nx))
test_y = np.zeros((N,Nt,Nx))
with torch.no_grad():
    for i, data in enumerate(test_loader):
        data_x, data_y = data
        data_x, data_y = data_x.to(device), data_y.to(device)
        pred_y = model(data_x).reshape(data_y.shape)
        test_x[i] = data_x.cpu().numpy()
        test_y[i] = data_y.cpu().numpy()
        preds_y[i] = pred_y.cpu().numpy()
#     data_loss = myloss(out, y)

In [None]:
Nx = config['data']['nx']
Nt = config['data']['nt'] + 1
N = config['data']['n_test']
model.eval()
test_x = np.zeros((N,Nt,Nx,3))
preds_y = np.zeros((N,Nt,Nx))
test_y = np.zeros((N,Nt,Nx))
with torch.no_grad():
    for i, data in enumerate(test_loader):
        data_x, data_y = data
        data_x, data_y = data_x.to(device), data_y.to(device)
        pred_y = model(data_x).reshape(data_y.shape)
        test_x[i] = data_x.cpu().numpy()
        test_y[i] = data_y.cpu().numpy()
        preds_y[i] = pred_y.cpu().numpy()
#     data_loss = myloss(out, y)

In [None]:
data_x.shape

In [None]:
use_train_data = False
padding = 5
batch_size = config['test']['batchsize']
Nx = config['data']['nx']
# Ny = config['data']['nx']
Nt = config['data']['nt'] + 1
Ntest = config['data']['n_test']
Ntrain = config['data']['n_train']
loader = test_loader
if use_train_data:
    Ntest = Ntrain
    loader = train_loader
# in_dim = config['model']['in_dim']
# out_dim = config['model']['out_dim']

model.eval()
# model.to('cpu')
test_x = np.zeros((Ntest,Nt,Nx,3))
preds_y = np.zeros((Ntest,Nt,Nx))
test_y = np.zeros((Ntest,Nt,Nx))


with torch.no_grad():
    for i, data in enumerate(loader):
#     for i, data in enumerate(train_loader):
        data_x, data_y = data
        data_x, data_y = data_x.to(device), data_y.to(device)
#         data_x_pad = F.pad()
#         display(data_x.shape)
        data_x_pad = F.pad(data_x, (0, 0, 0, 0, 0, padding), "constant", 0)
        pred_y_pad = model(data_x_pad).reshape(batch_size, Nt + padding, Nx)
#         out = out[..., :-padding, :]
#         pred_y_pad = model(data_x_pad).reshape(batch_size, Nx, Ny, Nt + padding, out_dim)
        pred_y = pred_y_pad[..., :-padding, :].reshape(data_y.shape)
#         pred_y = model(data_x).reshape(data_y.shape)
        test_x[i] = data_x.cpu().numpy()
        test_y[i] = data_y.cpu().numpy()
#         test_y0[i] = data_x[..., 0, -out_dim:].cpu().numpy() # same way as in training code
        preds_y[i] = pred_y.cpu().numpy()
#     data_loss = myloss(out, y)

In [None]:
len(preds_y)

In [None]:
key = 0
pred = preds_y[key]
true = test_y[key]


a = test_x[key]
Nt, Nx, _ = a.shape
u0 = a[0,:,0]
T = a[:,:,2]
X = a[:,:,1]
x = X[0]

In [None]:
fig = plt.figure(figsize=(24,5))
plt.subplot(1,4,1)

plt.plot(x, u0)
plt.xlabel('$x$')
plt.ylabel('$u$')
plt.title('Intial Condition $u(x)$')
plt.xlim([0,1])
plt.tight_layout()

plt.subplot(1,4,2)
# plt.pcolor(XX,TT, S_test, cmap='jet')
plt.pcolormesh(X, T, true, cmap='jet', shading='gouraud')
plt.colorbar()
plt.xlabel('$x$')
plt.ylabel('$t$')
plt.title(f'Exact $s(x,t)$')
plt.tight_layout()
plt.axis('square')

plt.subplot(1,4,3)
# plt.pcolor(XX,TT, S_pred, cmap='jet')
plt.pcolormesh(X, T, pred, cmap='jet', shading='gouraud')
plt.colorbar()
plt.xlabel('$x$')
plt.ylabel('$t$')
plt.title(f'Predict $s(x,t)$')
plt.axis('square')

plt.tight_layout()

plt.subplot(1,4,4)
# plt.pcolor(XX,TT, S_pred - S_test, cmap='jet')
plt.pcolormesh(X, T, pred - true, cmap='jet', shading='gouraud')
plt.colorbar()
plt.xlabel('$x$')
plt.ylabel('$t$')
plt.title('Absolute error')
plt.tight_layout()
plt.axis('square')

# plt.show()

In [None]:
def save_data(data_path, test_x, test_y, preds_y):
    data_dir, data_filename = os.path.split(data_path)
    os.makedirs(data_dir, exist_ok=True)
    np.savez(data_path, test_x=test_x, test_y=test_y, preds_y=preds_y)

def load_data(data_path):
    data = np.load(data_path)
    test_x = data['test_x']
    test_y = data['test_y']
    preds_y = data['preds_y']
    return test_x, test_y, preds_y

In [None]:
data_dir = 'data/Burgers1D'
data_filename = 'data.npz'
data_path = os.path.join(data_dir, data_filename)
# os.makedirs(data_dir, exist_ok=True)


In [None]:
save_data(data_path, test_x, test_y, preds_y)

In [None]:
test_x, test_y, preds_y = load_data(data_path)

In [None]:
def plot_predictions(key, test_x, test_y, preds_y, print_index=False, save_path=None, font_size=None):
    if font_size is not None:
        plt.rcParams.update({'font.size': font_size})
    pred = preds_y[key]
    true = test_y[key]


    a = test_x[key]
    Nt, Nx, _ = a.shape
    u0 = a[0,:,0]
    T = a[:,:,2]
    X = a[:,:,1]
    x = X[0]

    # Plot
    fig = plt.figure(figsize=(23,5))
    plt.subplot(1,4,1)

    plt.plot(x, u0)
    plt.xlabel('$z$')
    plt.ylabel('$T$')
    plt.title('Intial Condition $u(x)$')
    plt.xlim([0,1])
    plt.ylim([-1,1])
    plt.tight_layout()

    plt.subplot(1,4,2)
    # plt.pcolor(XX,TT, S_test, cmap='jet')
    plt.pcolormesh(X, T, true, cmap='jet', shading='gouraud')
    cbar = plt.colorbar()
    #cbar.mappable.set_clim(-2, 2)
    plt.xlabel('$z$')
    plt.ylabel('$t$')
    plt.title(f'Exact $T(z,t)$')
    plt.tight_layout()
    plt.axis('square')

    plt.subplot(1,4,3)
    # plt.pcolor(XX,TT, S_pred, cmap='jet')
    plt.pcolormesh(X, T, pred, cmap='jet', shading='gouraud')
    cbar = plt.colorbar()
    #cbar.mappable.set_clim(-2, 2)
    plt.xlabel('$z$')
    plt.ylabel('$t$')
    plt.title(f'Predict $T(z,t)$')
    plt.axis('square')

    plt.tight_layout()

    plt.subplot(1,4,4)
    # plt.pcolor(XX,TT, S_pred - S_test, cmap='jet')
    plt.pcolormesh(X, T, pred - true, cmap='jet', shading='gouraud')
    cbar = plt.colorbar()
    #cbar.mappable.set_clim(-0.2, 0.2)
    plt.xlabel('$z$')
    plt.ylabel('$t$')
    plt.title('Absolute Error')
    plt.tight_layout()
    plt.axis('square')

    if save_path is not None:
        plt.savefig(f'{save_path}.png', bbox_inches='tight')
    plt.show()

In [None]:
%matplotlib inline
figures_dir = '../../TDE1D/FNO/figures/'
os.makedirs(figures_dir, exist_ok=True)
font_size = 12
for key in range(len(preds_y)):
    save_path = os.path.join(figures_dir, f'TDE1D_{key}')
    plot_predictions(key, test_x, test_y, preds_y, print_index=True, save_path=save_path, font_size=font_size)
