In [1]:
import os
from glob import glob
import numpy as np
import pandas as pd

In [2]:
experiment_paths = [os.path.realpath("./Pre-experiment1"), os.path.realpath("./Pre-experiment2")]

# load multirun yaml Omgeaconf
from omegaconf import OmegaConf

multirun_yaml_paths = []
multi_cfgs = []
sweep_cfgs = []
for experiment_path in experiment_paths:
    multirun_yaml_path = glob(f"{experiment_path}/**/multirun.yaml", recursive=True)[0]

    multi_cfg = OmegaConf.load(multirun_yaml_path)
    sweep_cfg = multi_cfg["hydra"]["sweeper"]["params"]

    multirun_yaml_paths.append(multirun_yaml_path)
    multi_cfgs.append(multirun_yaml_path)
    sweep_cfgs.append(sweep_cfg)

In [3]:
sweep_cfgs

[{'model': 'MLP_5x60', '++model.n_layers': 'choice(4,6,8)', '++model.n_nodes': 'choice(40,100,200)'},
 {'model': 'MLP_5x60', '++model.n_layers': 'choice(3,4)', '++model.n_nodes': 'choice(200,300)', 'optimizer': 'opt_v1', '++optimizer.batch': 'choice(8096, 16384, 32768)'}]

In [4]:
# use regex to find instances where prepend is ++
# and store keys minus the ++ in a list
import re

sweep_keys_list = []

for sweep_cfg, experiment_path in zip(sweep_cfgs, experiment_paths):
    sweep_keys = []
    for k, v in sweep_cfg.items():
        if re.match(r"\+\+", k):
            sweep_keys.append(k[2:])
        elif re.match(r"\+", k):
            sweep_keys.append(k[2:])
        elif re.match(r"choice", v):
            if k == "optimizer":
                sweep_keys.append(f"{k}.loss_balancing.type")
            elif k == "model":
                sweep_keys.append(k)
        elif re.match(r"glob", v):
            if k == "optimizer":
              sweep_keys.append(k)
            #   checkpoint_dir_paths = glob(f"{experiment_path}*/.hydra/overrides.yaml", recursive = True)
            
    sweep_keys_list.append(sweep_keys)

sweep_keys_list

[['model.n_layers', 'model.n_nodes'],
 ['model.n_layers', 'model.n_nodes', 'optimizer.batch']]

In [5]:
best_model_paths_list = []
for experiment_path in experiment_paths:
    checkpoint_dir_paths = glob(f"{experiment_path}/**/checkpoints*", recursive = True)

    best_model_paths = []
    for checkpoint_dir_path in checkpoint_dir_paths:
        subdirs = np.array(os.listdir(checkpoint_dir_path)).astype(int)
        if len(subdirs) == 0:
            best_model_path = checkpoint_dir_path+"/empty/empty"
        else:
            best_model_path = os.path.join(checkpoint_dir_path, f"{subdirs.max()}/default")
        best_model_paths.append(best_model_path)

    best_model_paths.sort()
    best_model_paths_list.append(best_model_paths)
best_model_paths_list

[['/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/0/checkpoints/962/default',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/1/checkpoints/726/default',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/2/checkpoints/1011/default',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/3/checkpoints/1152/default',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/4/checkpoints/602/default',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/5/checkpoints/21/default',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/6/checkpoints/1011/default',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/7/checkpoints/440/default',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/8/checkpoints/515/default'],
 ['/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/0/checkpoints/1310/default',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/1/che

In [6]:
config_paths_list = []
overrides_paths_list = []
for best_model_paths in best_model_paths_list:
    config_paths = [best_model_path.split("/")[:-3]+[".hydra/config.yaml"] for best_model_path in best_model_paths]
    config_paths = ["/".join(config_path) for config_path in config_paths]
    config_paths_list.append(config_paths)
    overrides_paths = [best_model_path.split("/")[:-3]+[".hydra/overrides.yaml"] for best_model_path in best_model_paths]
    overrides_paths = ["/".join(overrides_path) for overrides_path in overrides_paths]
    overrides_paths_list.append(overrides_paths)
print(len(config_paths_list), config_paths_list)
print(len(overrides_paths_list), overrides_paths_list)

2 [['/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/0/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/1/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/2/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/3/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/4/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/5/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/6/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/7/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/8/.hydra/config.yaml'], ['/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/0/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/1/.hydra/config.yaml', '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-exp

In [11]:
from benedict import benedict

sweep_values_df_list = []


for config_paths, overrides_paths, sweep_keys in zip(config_paths_list, overrides_paths_list, sweep_keys_list):
    sweep_values_df = pd.DataFrame(columns=sweep_keys)
    for i, (config_path, overrides_path) in enumerate(zip(config_paths, overrides_paths)):
        cfg = OmegaConf.load(config_path)
        model_values = []
        for k in sweep_keys:
            if "." in k:
                n_dots = k.count(".")
                val = cfg[k.split(".")[0]]
                for j in range(n_dots):
                    val = val[k.split(".")[j+1]]
                model_values.append(val)
            else:
                model_values.append(cfg[k])
        sweep_values_df.loc[i] = model_values
    

    sweep_values_df_list.append(sweep_values_df)
# print(len(sweep_values_df_list), sweep_values_df_list)

In [13]:
import sys
sys.path.append("../")
from utils.jax_flax import load_model


loaded_models_list = []
best_models_list = []
best_params_list = []
best_cfgs_list = []

for best_model_paths in best_model_paths_list:
    loaded_models = []
    best_models = []
    best_params = []
    best_cfgs = []
    for model_path in best_model_paths:
        if "empty" in model_path:
            loaded_model = None
            best_model = None
            best_param = None
            best_cfg = None
        else:
            loaded_model = load_model(model_path)
            best_model = loaded_model[0]
            best_param = loaded_model[1]
            best_cfg = loaded_model[2]

        best_models.append(best_model)
        best_params.append(best_param)
        best_cfgs.append(best_cfg)
        loaded_models.append(loaded_model)
    loaded_models_list.append(loaded_models)
    best_models_list.append(best_models)
    best_params_list.append(best_params)
    best_cfgs_list.append(best_cfgs)

In [15]:
model_nos_list = []

# exp_categories = ["A", "B", "C"]
exp_names_list = []

for exp_main_path, best_model_paths in zip(experiment_paths, best_model_paths_list):
    model_nos = []
    exp_names = []
    for best_model_path in best_model_paths:
        model_no = best_model_path.split("/")[-4]
        # model_nos.append(f"{exp_no}_{int(model_no)}")
        model_nos.append(model_no)
        exp_names.append(exp_main_path.split("/")[-1])
    model_nos = np.array(model_nos).astype(int)
    model_nos_list.append(model_nos)
    exp_names_list.append(exp_names)
# print(model_nos_list)
# print(exp_names_list)

In [16]:
sweep_values_df

Unnamed: 0,model.n_layers,model.n_nodes,optimizer.batch
0,3,200,8096
1,3,200,16384
2,4,300,16384
3,4,300,32768
4,3,200,32768
5,3,300,8096
6,3,300,16384
7,3,300,32768
8,4,200,8096
9,4,200,16384


In [18]:
from utils.data import DataManager
from utils.plotting import Plotter
main_path = "../"
from tqdm import tqdm


nc_path_old = ""

model_comparison_df_list = []
metric_dfs_list = []
plot_classes_list = []


for best_model_paths, best_cfgs, sweep_values_df, model_nos, exp_names in zip(best_model_paths_list, best_cfgs_list, sweep_values_df_list, model_nos_list, exp_names_list):
    plot_classes = []
    metrics_dfs = []
    model_comparison_df = pd.DataFrame()
    # for best_model, best_param, best_cfg in loaded_models:
    i=0
    for best_model_path, cfg in tqdm(zip(best_model_paths, best_cfgs)):
        # evaluate
        if "empty" in best_model_path:
            plot_class = None
            # metrics_df = pd.concat([sweep_values_df.loc[i], pd.Series({"total": [np.nan]})])
            model_row = sweep_values_df.loc[i]
    
        else:
            nc_path = os.path.abspath(os.path.join(main_path, cfg.data.data_path))
            if nc_path != nc_path_old:
                DM = DataManager(nc_path = nc_path, 
                            exclusion_radius = 1., 
                            input_coords = ["z_cyl", "r", "CT", "TI_amb"],
                            output_vars = ["U_z", "U_r", "P"], 
                            val_split=0.1, 
                            development_mode=False)
                nc_path_old = nc_path
            plot_class = Plotter(DM, best_model_path)
            sweep_vals = []
            metrics_df, _ = plot_class.make_metric_df()
            model_row = pd.concat([sweep_values_df.loc[i], metrics_df.loc['total']])
        metrics_dfs.append(metrics_df)
        plot_classes.append(plot_class)
        model_comparison_df = pd.concat([model_comparison_df, model_row], axis=1)
        i+=1
    model_comparison_df = model_comparison_df.T.reset_index(drop=True)
    tup_idx = [(model_no, exp_name) for  model_no, exp_name in zip(model_nos, exp_names)]
    model_comparison_df.index = pd.Index(tup_idx)
    # model_comparison_df.index = model_nos
    model_comparison_df["paths"] = best_model_paths
    
    model_comparison_df_list.append(model_comparison_df)
    metric_dfs_list.append(metrics_dfs)
    plot_classes_list.append(plot_classes)
# model_comparison_df_list

9it [20:39, 137.70s/it]
12it [35:05, 175.48s/it]


In [19]:
# final_model_paths = [best_model_path.split("/")[:-3]+["final_model"] for best_model_path in best_model_paths]

final_model_paths_list = []

for best_model_paths in best_model_paths_list:
    final_model_paths = []
    for best_model_path in best_model_paths:
        if "empty" in best_model_path:
            final_model_path = best_model_path.split("/")[:-3]+["empty"]
            final_model_paths.append(final_model_path)
        else:
            final_model_path = best_model_path.split("/")[:-3]+["final_model"]
            final_model_paths.append(final_model_path)

    final_model_paths = ["/".join(final_model_path) for final_model_path in final_model_paths]

    final_model_paths_list.append(final_model_paths)
final_model_paths_list

[['/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/0/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/1/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/2/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/3/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/4/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/5/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/6/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/7/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/8/final_model'],
 ['/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/0/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/1/final_model',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/10/final_model',
  '/home/jpsch/code/jax

In [20]:
finished_bools_list = []
for final_model_paths, model_comparison_df in zip(final_model_paths_list, model_comparison_df_list):
    finished_bools = []
    for final_model_path in final_model_paths:
        finished_bools.append(os.path.exists(final_model_path))

    model_comparison_df["finished"] = finished_bools
    finished_bools_list.append(finished_bools)

In [21]:
# tb_paths = [final_model_path.split("/")[:-1]+["tensorboard"] for final_model_path in final_model_paths]

tb_paths_list = []

for final_model_paths in final_model_paths_list:
    tb_paths = []
    for final_model_path in final_model_paths:
        if "empty" in final_model_path:
            tb_path = final_model_path.split("/")[:-1]+["empty"]
            tb_paths.append(tb_path)
        else:
            tb_path = final_model_path.split("/")[:-1]+["tensorboard"]
            tb_paths.append(tb_path)

    tb_paths = ["/".join(tb_path) for tb_path in tb_paths]
    tb_paths_list.append(tb_paths)
tb_paths_list

[['/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/0/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/1/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/2/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/3/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/4/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/5/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/6/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/7/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment1/8/tensorboard'],
 ['/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/0/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/1/tensorboard',
  '/home/jpsch/code/jax-flax-wake-pinn/Results/Pre-experiment2/01-16-34/10/tensorboard',
  '/home/jpsch/code/jax

In [22]:
# Construct dataframe of timing s with tensorboard logs
import sys
sys.path.append("../")
from utils.data import TensorBoardLoader

tb_dfs_list = []
for i, (tb_paths, model_comparison_df, model_nos) in enumerate(zip(tb_paths_list, model_comparison_df_list, model_nos_list)):

    tb_dfs = []
    timings = pd.DataFrame(columns=["total_time[h]", "epochs", "time_per_epoch[m]"])
    for j, tb_path in enumerate(tb_paths):
        if "empty" in tb_path:
            timings.loc[j] = [np.nan, np.nan, np.nan]
            df = None
        else:
            tb_class = TensorBoardLoader(tb_path)
            time_data, epochs_data, loss_data = tb_class.load_scalar_events("Loss/data")
            total_epochs = epochs_data[-1]
            total_time = (time_data[-1]-time_data[0])/60/60 # total time in hours
            time_per_epoch = (total_time*60)/total_epochs # time per epoch in minutes
            timings.loc[j] = [total_time, total_epochs, time_per_epoch]
            df = tb_class.load_df()
        tb_dfs.append(df)

    timings.index = model_comparison_df_list[i].index
    model_comparison_df_list[i] = model_comparison_df_list[i].join(timings)
    # model_comparison_df_list[i] = pd.concat([model_comparison_df, timings])
    tb_dfs_list.append(tb_dfs)

In [23]:
model_comparison_df

Unnamed: 0,Unnamed: 1,model.n_layers,model.n_nodes,optimizer.batch,MSE,MAE,RMSE,MAPE,R2,MSE_pinn,paths,finished
0,Pre-experiment2,3,200,8096,0.0,7.4e-05,0.000145,30.333341,1.0,0.287227,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,True
1,Pre-experiment2,3,200,16384,0.0,8.1e-05,0.000156,22.168833,1.0,7.892798,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,True
10,Pre-experiment2,4,300,16384,0.0,0.000184,0.000579,97.794485,0.999998,6678758.5,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,True
11,Pre-experiment2,4,300,32768,0.0,0.000104,0.000255,32.347095,1.0,27.815332,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,True
2,Pre-experiment2,3,200,32768,0.0,8.7e-05,0.000179,23.007352,1.0,0.547626,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,True
3,Pre-experiment2,3,300,8096,0.0,7.6e-05,0.000195,18.100477,1.0,327.1651,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,True
4,Pre-experiment2,3,300,16384,0.0,8.7e-05,0.000173,47.706488,1.0,5.739264,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,True
5,Pre-experiment2,3,300,32768,0.0,7.6e-05,0.000193,19.304,1.0,155.843597,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,False
6,Pre-experiment2,4,200,8096,0.0,7.9e-05,0.000159,28.876644,1.0,0.00683,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,True
7,Pre-experiment2,4,200,16384,0.0,8.2e-05,0.000153,55.12012,1.0,0.002147,/home/jpsch/code/jax-flax-wake-pinn/Results/Pr...,True


In [24]:
model_comparison_df = pd.concat(model_comparison_df_list)
hl_df = model_comparison_df.reset_index()

# Dropping cases where the batch size is not 8096
# Replace nan with 0.0 for colloc_data_rati
# hl_df = hl_df[hl_df["optimizer.batch"]==np.nan]

pre_exp1_mask = hl_df["level_1"]=="Pre-experiment1"
pre_exp2_mask = hl_df["level_1"]=="Pre-experiment2"

hl_df.loc[pre_exp1_mask, "optimizer.batch"] = 8096
hl_df = hl_df.where(hl_df["optimizer.batch"]==8096, inplace=False).dropna(axis=0)
hl_df = hl_df[hl_df["level_1"]=="Pre-experiment1"]
hl_df.drop(columns=["level_0", "level_1", "paths", "optimizer.batch", "time_per_epoch[m]", "finished", "R2"], inplace=True)

for column in hl_df.columns:
    if column == "epochs":
       hl_df[column] = hl_df[column].astype(int)
    elif column == "total_time[h]":
        hl_df[column] = hl_df[column].astype(float).round(3).astype(str)

    else:
        pass
        # hl_df[column] = hl_df[column].astype(float).round(5).astype(str)

hl_df.columns = [r'$n_\textrm{layers}$', r'$m_\textrm{neurons}$', r"$\textrm{MSE}$", r"$\textrm{MAE}$", r"$\textrm{RMSE}$", r"$\textrm{MAPE}$", r"$\textrm{MSE}_\phi$",
       'Wall time [h]', 'epochs']


In [25]:
# Define a custom formatter for scientific notation
scientific_formatter = "{:.2e}".format
hl_df.to_latex("../figures/Pre-experiment1.tex", 
                float_format="{:.2e}".format,
                escape=False)

hl_df.to_latex("../figures/Pre-experiment1_f_formatted.tex", 
                float_format="{:.2f}".format,
                escape=False)

In [26]:
# load latex table
with open("../figures/Pre-experiment1.tex", 'r') as f:
    lines = f.readlines()

with open("../figures/Pre-experiment1_f_formatted.tex", 'r') as f:
    lines2 = f.readlines()

# Change inital lines and line indicatiors to match Journal of Physics and own style

for i, _ in enumerate(lines):
    splitted_line = lines[i].split("&")
    for j, splited in enumerate(splitted_line):
        if (r"e-" in splited or r"e+" in splited):
            power = int(splited.split(r"e")[1])
            if np.abs(power) < 3:
                splitted2 = lines2[i].split("&")
                splitted_line[j] = splitted2[j]
            else:
                splitted_line[j] = r"\num{"+splited+r"}"

    if r"\\" in lines[i]:
        lines[i] = lines[i].replace(r"\\", r"\cr")
    if "&" in lines[i]:
        lines[i] = "&".join(splitted_line[1:])

start_lines = [r"\begin{tabular}{cccccccccc}"+"\n", r"\br"+"\n"]
lines[:2] = start_lines
lines[3] = r"\mr"+"\n"
lines[-2] = r"\br"+"\n"
lines[-1] = r"\end{tabular}"+"\n"

# write as latex file
with open("../figures/Pre-experiment1.tex", 'w') as f:
    f.writelines(lines)

In [27]:
hl_df = hl_df.sort_values(by=r"$\textrm{MSE}$", ascending=True)
hl_df = hl_df.style.highlight_min(subset=[r"$\textrm{MSE}$", r"$\textrm{MAE}$", r"$\textrm{RMSE}$", r"$\textrm{MAPE}$", r"$\textrm{MSE}_\phi$"], color = 'darkgrey')
# hl_df = hl_df.highlight_max(subset=[r"$\textrm{R}^2$"], color = 'darkgrey', axis = 0)

hl_df

Unnamed: 0,$n_\textrm{layers}$,$m_\textrm{neurons}$,$\textrm{MSE}$,$\textrm{MAE}$,$\textrm{RMSE}$,$\textrm{MAPE}$,$\textrm{MSE}_\phi$,Wall time [h],epochs
1,4,100,0.0,7.1e-05,0.000159,15.321644,0.042449,14.196,897
2,4,200,0.0,7.9e-05,0.000159,28.876644,0.00683,27.603,1177
6,8,40,0.0,6.5e-05,0.00017,9.196713,0.293689,16.863,1214
0,4,40,0.0,9.3e-05,0.000193,40.012421,24.793255,13.162,1213
3,6,40,0.0,0.000144,0.000283,82.688969,0.034803,14.571,1152
4,6,100,0.0,0.000157,0.000467,149.544946,3.212239,16.843,853
7,8,100,0.0,0.000249,0.000626,303.905756,23321764.0,16.54,691
8,8,200,1e-06,0.000549,0.001083,360.843594,47267.316406,28.675,766
5,6,200,2e-06,0.000403,0.001294,114.580328,0.141783,7.807,272


In [28]:
# highlighted_df.to_latex("model_comparison.tex")

In [29]:
# from matplotlib import pyplot as plt
# for i, (plot_classes_list, tb_dfs_list, plot_classes_list) in enumerate(zip(plot_classes_list, tb_dfs_list, plot_classes_list)):
#     for j, (df, plot_class) in enumerate(zip(tb_dfs, plot_classes)):
#         if plot_class is None:
#             continue
#         title = final_model_paths[j].split("/")[-2]
#         fig, axes = plt.subplots(1, 2, figsize=(20, 6))
#         plt.title(f"Exp: {i}, Model {title}")
#         if "Loss/pinn" in df.columns:
#             df.plot("step",["Loss/tot", "loss", "Loss/pinn", "Loss/val"], logy=True, ax=axes[0])
#             df.plot("step",["Loss/alpha_data", "Loss/alpha_pinn"], logy=False, ax=axes[1])
#         else:
#             df.plot("step",["loss", "Loss/val"], logy=True, ax=axes[0])

#         plot_class.plot_pred_triplet("U_z", flowcase=0)
#         plot_class.plot_pred_triplet("U_r", flowcase=0)
#         plot_class.plot_pred_triplet("P", flowcase=0)