In [None]:
import sys
sys.path.append('./data')
sys.path.append('./utils')
sys.path.append('./losses')
sys.path.append('./models')
sys.path.append('./saved_models')

from data import *
from utils import *
from losses import *
from models import *
from training_scripts import *

from matplotlib.colors import LinearSegmentedColormap

# for interactive figure
# %matplotlib widget 
# %config InlineBackend.figure_format = 'svg'
# for normal figure
%matplotlib inline 

# Automatic module reload
%reload_ext autoreload
%autoreload 2

# Random Seeds
torch.set_default_dtype(torch.float32)
torch.manual_seed(0)
np.random.seed(0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Loading Saved Models

In [None]:

# ============================================================================================
# Final Model Selection
#
# `casename` selects the processed dataset; `label` selects which group of
# trained models is loaded and is also used as a suffix in result/figure names.

casename     = 'data_20241125_alpha_reshuffle_3_normalised'
data_path    = './data/processed/%s/'%(casename)

label = '_final'
# label = '_PI_models'

# =======================================================================
if label == '_final':
    # data-driven (L2 loss) and physics-informed (PILoss) variant of each architecture
    model_paths = [
                'saved_models/%s/conv_ae/model_final%s.pt'%(casename,'_with_L2_loss_with_conv'), # working ones
                'saved_models/%s/conv_ae/model_final%s.pt'%(casename,'_PILoss_scaled_optimised_with_conv_multi-latent-layer'),# working ones
                'saved_models/%s/unet/model_final%s.pt'%(casename,'_with_L2_loss'),
                'saved_models/%s/unet/model_final%s.pt'%(casename,'_PILoss_scaled_optimised'),
                'saved_models/%s/fno2d/model_final%s.pt'%(casename,'_with_L2_loss_30_mode'),
                'saved_models/%s/fno2d/model_final%s.pt'%(casename,'_PILoss_scaled_optimised_30_mode'),
                    ]
    model_name_list = [
                    r'$\mathcal{M}_{AE-data}$',
                    r'$\mathcal{M}_{AE-PI}$',
                    r'$\mathcal{M}_{UNET-data}$',
                    r'$\mathcal{M}_{UNET-PI}$',
                    r'$\mathcal{M}_{FNO-data}$',
                    r'$\mathcal{M}_{FNO-PI}$'
                    ]

# =======================================================================
elif label == '_PI_models':
    # physics-informed variants only
    model_paths = [
                'saved_models/%s/conv_ae/model_final%s.pt'%(casename,'_PILoss_scaled_optimised_with_conv_multi-latent-layer'),
                'saved_models/%s/unet/model_final%s.pt'%(casename,'_PILoss_scaled_optimised'),
                'saved_models/%s/fno2d/model_final%s.pt'%(casename,'_PILoss_scaled_optimised_30_mode'),
                    ]
    model_name_list = [
                    r'$\mathcal{M}_{AE-PI}$',
                    r'$\mathcal{M}_{UNET-PI}$',
                    r'$\mathcal{M}_{FNO-PI}$'
                    ]

else:
    # Fail fast here instead of with a NameError on `model_paths` below.
    raise ValueError('Unrecognized model group label %s'%(label))

model_list   = model_loader(model_paths, device)

# Display names are paired with models by position, so the lists must align.
assert len(model_list) == len(model_name_list), \
    'model_list and model_name_list must have the same length'

# Attach the dataset name and the (LaTeX) display name to each loaded model.
for loaded_model, loaded_model_name in zip(model_list, model_name_list):
    loaded_model.casename   = casename
    loaded_model.model_name = loaded_model_name

# Input/output layout is taken from the first model (assumed shared by all models).
dim          = model_list[0].dim
in_channels  = model_list[0].in_channels
out_channels = model_list[0].out_channels

# Data Loading, Predictions and Error Metrics

In [None]:
# data loading
#
# `dataset` selects the split evaluated below:
#   ''          -> test split
#   '_train'    -> training split
#   '_valid'    -> validation split
#   '_res_256'  -> 'test_256_resized.mat' read from data_path

dataset = '' # '', '_train', '_valid', '_res_256'

# Append the dataset suffix to the result/figure label. The endswith guard keeps
# this cell idempotent: re-running it no longer appends the suffix a second time.
if not label.endswith(dataset):
    label = label+dataset

if dataset in ['', '_train', '_valid']:

    # original dataset reading
    train_a, train_u, \
        valid_a, valid_u, \
            test_a, test_u = data_loader(processed_save_path = data_path, 
                                         in_channels = in_channels, out_channels = out_channels, 
                                         dim = dim)

    # Evaluate on the requested split by re-binding test_a / test_u to it.
    if dataset == '_train':
        test_a, test_u = train_a, train_u
    elif dataset == '_valid':
        test_a, test_u = valid_a, valid_u

elif dataset == '_res_256':

    # dataset with resolution of 256
    data_file = 'test_256_resized.mat'
    test_a, test_u = read_from_mat(data_path+data_file)

else:
    raise ValueError('Unrecognized dataset %s'%(dataset))

# prediction (normalised data); error_out holds errors computed with error_mode='relative'
model_out, error_out = model_predictions(model_list, test_a, test_u, device, 
                                         error_mode = 'relative')
print(model_out.shape)

# original data: keep columns 2 onward of norm_info.csv as float normalisation
# statistics (presumably the first two columns are labels -- TODO confirm).
df = pd.read_csv(data_path+'norm_info.csv')
norm_info = np.array(df)[:,2:].astype(float)
test_a_origin, test_u_origin = denormalisation(test_a, test_u, norm_info)

# Same predictions, recovered to original units via norm_info (origin_recover=True)
model_out_origin, _ = model_predictions(model_list, test_a, test_u, device, 
                                        error_mode = 'relative',
                                        origin_recover = True, norm_info = norm_info)
print(model_out_origin.shape)

# Prediction error metrics (printed and saved under results/<casename>/):
print_error_statistics( casename, model_list, model_out, test_a, test_u, 
                        out_channel_names = ['P', 'U_x', 'U_y'],
                        eval_mode_list = [ 'R2', 'MSE', 'L2_rel'], PI_Loss=True,
                        if_print = True, if_save = 1, overwrite = True,
                        file_path = 'results/%s/'%(casename),
                        file_name = 'metric_summary'+label,
                        )


## Prediction Visualisation

In [None]:
# Figure output settings
fig_save            = 1
fig_overwrite       = True
fig_file_path       = 'results/%s/figures/%s/'%(casename,label)
fig_format_list     = ['.png', '.eps']

# Panel layout and titles
ax_side_length      = 1.2
wspace              = 0.01
hspace              = 0.02
group_col_sep       = 0.1
lvl2_title_pad      = 0.4
lvl2_title_font     = 15
border_off          = True

in_title    = [r'$\alpha$'] # ['Gamma']
out_title   = [r'$P$', r'$U_x$', r'$U_y$']

# Input (alpha) panel settings
same_in     = True
cmap_in     = ['gray_r']
vmin_in     = None
vmax_in     = None

# Colormaps for the output panels, in the order P, U_x, U_y:
# [pressure_cmap, cmap2, velocity_cmap]. cmap2 (U_x) is kept as a separate
# name so it can be replaced by e.g. the upper half of velocity_cmap if needed.
pressure_cmap = 'coolwarm'
velocity_cmap = 'coolwarm'  # 'coolwarm' 'bwr' #'Spectral_r' get_parula_map()
cmap2 = velocity_cmap

# forwarded as `rejection=` to prediction_visualisation
rejection           = 3

# Samples shown in the single prediction/error figures. The numbers are
# presumably 1-based sample ids, hence the -1 to get 0-based indices.
# Earlier selections, kept for reference:
#   plot_index = range(0,300,30)
#   shuffle-0 data (data_20241125_alpha_reshuffle_normalised):
#   plot_index = np.array([5,6,8,17,18,38,44,53,64,78,93])-1
#   plot_index = np.array([6, 53, 64, 5, 44, 18, 8, 17, 38, 93])-1
# Current selection, shuffle-3 data (data_20241125_alpha_reshuffle_3_normalised):
plot_index = np.array([25, 298, 136, 20 ,248, 126, 191, 294, 38, 195])-1

# Sample groups, one figure per group; `group_name` becomes part of the output folder.
# Earlier 5-per-group selection, kept for reference:
#   [25,112,93,14,298], [289,186,88,260,175], [53,49,1,281,195],
#   [4,136,159,19,20], [60,102,7,215,28], [248,114,126,294,38]
group_name = '_3_fig'
index_list = [np.array([25 , 93 , 298])-1,
              np.array([289, 260, 175])-1,
              np.array([53 , 281, 195])-1,
              np.array([4  , 136, 159])-1,
              np.array([60 , 102, 7  ])-1,
              np.array([248, 114, 126])-1]

In [None]:
# Predictions (original units) vs. ground truth for the samples in `plot_index`,
# with fixed per-channel colour limits for P, U_x, U_y and a colour bar.
fig, ax = prediction_visualisation(model_list, 
                                   model_out_origin[:,plot_index,:,:,:], 
                                   test_a[plot_index,:,:,:], 
                                   test_u_origin[plot_index,:,:,:], 
                                   ax_side_length = ax_side_length, wspace = wspace, hspace = hspace, 
                                   group_col_sep = group_col_sep,
                                   lvl2_title_pad = lvl2_title_pad, 
                                   lvl2_title_font = lvl2_title_font, 
                                   border_off = border_off,
                                   in_title = in_title, out_title = out_title,
                                   same_in = same_in, cmap_in = cmap_in, 
                                   # [1.0] as in the other prediction figures ([1,0] was a typo)
                                   vmin_in = vmin_in, vmax_in = [1.0],
                                   same_out = False, cmap_out = [pressure_cmap, cmap2, velocity_cmap], 
                                   vmin_out = [0,-0.006, -0.0006], 
                                   vmax_out = [0.25, 0.006, 0.0006], 
                                   norm_per_img = [0, 0, 0, 0], 
                                   cmap_center  = [0, 0, 1, 1],
                                   rejection = rejection,
                                   plot_cbar = True,
                                   cbar_left = 0.18,
                                   cbar_bottom = 0.07, #0.05
                                   cbar_width = 0.72,
                                   cbar_height = 0.015, #0.2
                                   cbar_label_x = 0.15,
                                   cbar_sep    = 3.0, # 2.5
                                   save = fig_save, fig_file_path=fig_file_path, fig_file_name='pred'+label+'_origin_with_cbar', 
                                   fig_format_list=fig_format_list, fig_overwrite=fig_overwrite)

In [None]:
# Grouped prediction figures on the shuffle-3 data in original units, selected
# to show a variety of cases. Each group in `index_list` gets its own colour
# limits: [0, P_max] for P and symmetric ranges +/-u_max, +/-v_max for U_x, U_y.

P_max_list = [0.4, 0.04, 0.04, 0.02, 0.02, 0.02]
u_max_list = [0.01]*len(index_list)
v_max_list = [0.001]*len(index_list)

# P_max_list is hand-tuned per group, so it must stay in step with index_list.
assert len(P_max_list) == len(index_list), \
    'P_max_list needs one entry per group in index_list'

cmap_out = ['coolwarm', 'coolwarm', 'coolwarm'] # OrRd PuBu RdYlBu_r Spectral_r
for i, plot_index_group in enumerate(index_list):
    fig, ax = prediction_visualisation(model_list, 
                                       model_out_origin[:,plot_index_group,:,:,:], 
                                       test_a[plot_index_group,:,:,:], 
                                       test_u_origin[plot_index_group,:,:,:], 
                                       ax_side_length = ax_side_length, wspace = wspace, hspace = hspace, 
                                       group_col_sep = group_col_sep,
                                       lvl2_title_pad = lvl2_title_pad, 
                                       lvl2_title_font = lvl2_title_font, 
                                       border_off = border_off,
                                       in_title = in_title, out_title = out_title,
                                       same_in = same_in, cmap_in = cmap_in, 
                                       vmin_in = vmin_in, vmax_in = [1.0],
                                       same_out = False, cmap_out = cmap_out, 
                                       vmin_out = [0, -u_max_list[i], -v_max_list[i]], 
                                       vmax_out = [P_max_list[i], u_max_list[i], v_max_list[i]], 
                                       rejection = rejection,
                                       norm_per_img = [0, 0, 0, 0], 
                                       cmap_center  = [0, 0, 1, 1],
                                       plot_cbar = True,
                                       cbar_left = 0.18,
                                       cbar_bottom = 0.025, #0.05
                                       cbar_width = 0.72,
                                       cbar_height = 0.015*2, #0.2
                                       cbar_label_x = 0.15,
                                       cbar_sep    = 4, # 2.5
                                       save = fig_save, fig_file_path=fig_file_path+'Groups%s/'%(group_name), fig_file_name='pred_origin_%i_with_cbar'%(i+1), 
                                       fig_format_list=fig_format_list, fig_overwrite=fig_overwrite)

## Error Visualisation

In [None]:
# Physics residuals: PI_error_out holds the models' residuals (error_mode='PI');
# physics_error_evaluation is presumably the same residual evaluated on the
# reference fields test_u -- TODO confirm.
error_mode = 'PI' # 'relative' 'abs_relative'  'continuity' 'NS'
model_out, PI_error_out = model_predictions(model_list, test_a, test_u, device,
                                            error_mode = error_mode)

div_residual, NS_x_res, NS_y_res = physics_error_evaluation(casename,
                                                            test_a, test_u, eval_mode='tensor')

# Reshape every residual map to (loss_dim[0], 1, loss_dim[1], loss_dim[2]) and
# concatenate along dim 1, giving channel order: continuity, x, y.
loss_dim = div_residual.shape
residual_maps = (div_residual, NS_x_res, NS_y_res)
error_original = torch.cat(
    [res_map.view(loss_dim[0], 1, loss_dim[1], loss_dim[2]) for res_map in residual_maps],
    1,
)

In [None]:
# Physics-residual maps for the samples in `plot_index` (coolwarm variant).
# This figure's settings use dedicated names instead of re-assigning the shared
# plot settings (lvl2_title_font, cmap_out, ...), so re-running the prediction
# figure cells afterwards still uses their own values.
residual_out_title   = [r'$\mathcal{R}_{continuity}$',
                        r'$\mathcal{R}_{x \,momentum}$',
                        r'$\mathcal{R}_{y \,momentum}$']
residual_cmap_out    = ['coolwarm']*3
residual_title_pad   = 0.4
residual_title_font  = 18

fig, ax = prediction_visualisation(model_list, PI_error_out[:,plot_index,:,:,:], 
                                   test_a[plot_index,:,:,:], 
                                   error_original[plot_index,:,:,:], 
                                   plot_truth = False,
                                   ax_side_length = ax_side_length, wspace = wspace, hspace = hspace, 
                                   group_col_sep = group_col_sep,
                                   lvl2_title_pad = residual_title_pad, 
                                   lvl2_title_font = residual_title_font, 
                                   border_off = border_off,
                                   in_title = in_title, 
                                   out_title = residual_out_title,
                                   same_in = same_in, cmap_in = cmap_in, 
                                   vmin_in = vmin_in, vmax_in = [1.0],
                                   same_out = False, cmap_out = residual_cmap_out,
                                   # symmetric colour limits: +/-0.01 continuity, +/-0.2 momentum
                                   vmin_out = [-0.01, -0.2, -0.2], 
                                   vmax_out = [0.01, 0.2, 0.2], 
                                   norm_per_img = [0, 0, 0, 0], 
                                   cmap_center  = [0, 1, 1, 1],
                                   rejection = 2, 
                                   plot_cbar=True,
                                   cbar_left = 0.2,
                                   cbar_bottom = 0.07, #0.05
                                   cbar_width = 0.7,
                                   cbar_height = 0.015, #0.2
                                   cbar_label_x = 0.135,
                                   cbar_sep    = 3.0, # 2.5
                                   save = fig_save, fig_file_path=fig_file_path, 
                                   fig_file_name='error'+label+'_'+error_mode+'_with_cbar_2', 
                                   fig_format_list=fig_format_list, fig_overwrite=fig_overwrite)


In [None]:
# Physics-residual maps for the samples in `plot_index`, one diverging
# colormap per residual. Other candidates tried: 'BrBG', 'RdGy', 'RdBu',
# 'seismic', 'twilight_shifted'.
# This figure's settings use dedicated names instead of re-assigning the shared
# plot settings (lvl2_title_font, cmap_out, ...), so re-running the prediction
# figure cells afterwards still uses their own values.
residual_out_title   = [r'$\mathcal{R}_{continuity}$',
                        r'$\mathcal{R}_{x \,momentum}$',
                        r'$\mathcal{R}_{y \,momentum}$']
residual_cmap_out    = ['PiYG', 'PuOr', 'bwr',]
residual_title_pad   = 0.4
residual_title_font  = 18

fig, ax = prediction_visualisation(model_list, PI_error_out[:,plot_index,:,:,:], 
                                   test_a[plot_index,:,:,:], 
                                   error_original[plot_index,:,:,:], 
                                   plot_truth = False,
                                   ax_side_length = ax_side_length, wspace = wspace, hspace = hspace, 
                                   group_col_sep = group_col_sep,
                                   lvl2_title_pad = residual_title_pad, 
                                   lvl2_title_font = residual_title_font, 
                                   border_off = border_off,
                                   in_title = in_title, 
                                   out_title = residual_out_title,
                                   same_in = same_in, cmap_in = cmap_in, 
                                   vmin_in = vmin_in, vmax_in = [1.0],
                                   same_out = False, cmap_out = residual_cmap_out,
                                   # symmetric colour limits: +/-0.01 continuity, +/-0.2 momentum
                                   vmin_out = [-0.01, -0.2, -0.2], 
                                   vmax_out = [0.01, 0.2, 0.2], 
                                   norm_per_img = [0, 0, 0, 0], 
                                   cmap_center  = [0, 1, 1, 1],
                                   rejection = 2, 
                                   plot_cbar=True,
                                   cbar_left = 0.2,
                                   cbar_bottom = 0.07, #0.05
                                   cbar_width = 0.7,
                                   cbar_height = 0.015, #0.2
                                   cbar_label_x = 0.135,
                                   cbar_sep    = 3.0, # 2.5
                                   save = fig_save, fig_file_path=fig_file_path, 
                                   fig_file_name='error'+label+'_'+error_mode+'_with_cbar', 
                                   fig_format_list=fig_format_list, fig_overwrite=fig_overwrite)


# Scalability Figure

In [None]:
# Resolution scalability: evaluate the models on the `paper_<resolution>.mat`
# samples stored alongside the selected processed dataset.
resolution = 256 # 128 256 400

# Built from data_path so the file always belongs to the dataset selected by
# `casename`, i.e. the same dataset whose norm_info is used below.
test_a_paper, test_u_paper = read_from_mat(data_path+'paper_%i.mat'%(resolution))

# Normalisation statistics: columns 2 onward of norm_info.csv, as float
df = pd.read_csv(data_path+'norm_info.csv')
norm_info = np.array(df)[:,2:].astype(float)
test_a_paper_origin, test_u_paper_origin = denormalisation(test_a_paper, test_u_paper, norm_info)

print(test_a_paper.shape)
print(test_u_paper.shape)

# Models to evaluate; use a slice such as model_list[1:3] for a subset.
model_list_paper = model_list

# prediction in mini-batches of `paper_batch_size` samples; shuffle=False keeps
# the outputs in the same sample order as test_a_paper / test_u_paper.
paper_batch_size = 10
batch_data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(test_a_paper, test_u_paper),
    batch_size=paper_batch_size, shuffle=False, drop_last=False)

model_out_paper = []
error_out_paper = []
model_out_paper_origin = []
for x_test, y_test in batch_data_loader:

    # predictions on normalised data + errors with error_mode='relative'
    pred, error = model_predictions(model_list_paper, x_test, y_test, device, 
                                    error_mode = 'relative')

    # same predictions recovered to original units via norm_info
    pred_origin, _ = model_predictions(model_list_paper, x_test, y_test, device, 
                                       error_mode = 'relative',
                                       origin_recover = True, norm_info = norm_info)

    model_out_paper.append(pred)
    error_out_paper.append(error)
    model_out_paper_origin.append(pred_origin)

# Join the batches along dim 1 -- presumably the sample axis, with dim 0
# indexing the models (outputs are indexed as [:, sample_idx] elsewhere).
model_out_paper  = torch.cat(model_out_paper, dim=1)
error_out_paper  = torch.cat(error_out_paper, dim=1)

model_out_paper_origin = torch.cat(model_out_paper_origin, dim=1)

print(model_out_paper.shape)
print(error_out_paper.shape)
# Prediction error metrics.
# NOTE(review): if_print and if_save are both False, so unless the helper has
# other side effects this call produces no visible output -- enable one of them
# to inspect or store the metrics.
print_error_statistics( casename, model_list_paper, model_out_paper, test_a_paper, test_u_paper, 
                        out_channel_names = ['P', 'U_x', 'U_y'],
                        eval_mode_list = [ 'R2', 'MSE', 'L2_rel'], PI_Loss=True,
                        if_print = False, if_save = False, overwrite = True,
                        file_path = 'results/%s/'%(casename),
                        file_name = 'metric_summary_invariant%s_%i'%(label,resolution),
                        )

# Plot settings for the scalability figure. They restate the prediction-figure
# defaults so this section does not depend on values left behind by earlier
# cells (e.g. lvl2_title_font); only the output folder differs.
fig_save            = 1
fig_overwrite       = True
fig_file_path       = 'results/%s/figures/Paper_Figure/'%(casename)
fig_format_list     = ['.png', '.eps']

# Panel layout and titles
ax_side_length      = 1.2
wspace              = 0.01
hspace              = 0.02
group_col_sep       = 0.1
lvl2_title_pad      = 0.4
lvl2_title_font     = 15
border_off          = True

in_title    = [r'$\alpha$'] # ['Gamma']
out_title   = [r'$P$', r'$U_x$', r'$U_y$']

# Input (alpha) panel settings
same_in     = True
cmap_in     = ['gray_r']
vmin_in     = None
vmax_in     = None

rejection           = 3

# Colormaps for the output panels, in the order P, U_x, U_y:
# [pressure_cmap, cmap2, velocity_cmap].
pressure_cmap = 'coolwarm'
velocity_cmap = 'coolwarm'  # 'coolwarm' 'bwr' #'Spectral_r' get_parula_map()
cmap2 = velocity_cmap

# To plot U_x with only the upper half of the velocity colormap
# (LinearSegmentedColormap is imported in the setup cell):
# cmap = plt.get_cmap(velocity_cmap)
# colors = cmap(np.linspace(0.5, 1, cmap.N // 2))
# cmap2 = LinearSegmentedColormap.from_list('Upper Half', colors)


In [None]:
# Scalability figure: predictions (original units) vs. ground truth for every
# sample of the paper_<resolution> set. The plotting options are collected in a
# dict, grouped by purpose, and unpacked into the call.
paper_fig_kwargs = dict(
    # layout and titles
    ax_side_length = ax_side_length,
    wspace = wspace,
    hspace = hspace,
    group_col_sep = group_col_sep,
    lvl2_title_pad = lvl2_title_pad,
    lvl2_title_font = lvl2_title_font,
    border_off = border_off,
    in_title = in_title,
    out_title = out_title,
    # input panel
    same_in = same_in,
    cmap_in = cmap_in,
    vmin_in = vmin_in,
    vmax_in = [1.0],
    # output panels (P, U_x, U_y): colormaps and colour limits
    same_out = False,
    cmap_out = [pressure_cmap, cmap2, velocity_cmap],
    vmin_out = [0, -0.006, -0.0006],
    vmax_out = [0.1, 0.006, 0.0006],
    norm_per_img = [0, 0, 0, 0],
    cmap_center = [0, 0, 1, 1],
    rejection = 3,
    # colour bar placement
    plot_cbar = True,
    cbar_left = 0.18,
    cbar_bottom = 0.07,   # 0.05
    cbar_width = 0.72,
    cbar_height = 0.015,  # 0.2
    cbar_label_x = 0.15,
    cbar_sep = 3.5,       # 2.5
    # saving
    save = fig_save,
    fig_file_path = fig_file_path,
    fig_file_name = 'paper_pred_%i%s_origin_with_cbar' % (resolution, label),
    fig_format_list = fig_format_list,
    fig_overwrite = fig_overwrite,
)

fig, ax = prediction_visualisation(model_list_paper,
                                   model_out_paper_origin,
                                   test_a_paper,
                                   test_u_paper_origin,
                                   **paper_fig_kwargs)