In [1]:
# Run tag embedded in the file name of every figure saved by this notebook.
name = "multifidelity_viscous_models_PbPb2760"

In [2]:
#import GPy
import os
#import pickle
import numpy as np
import seaborn as sns
import pandas as pd
import math
import matplotlib.pyplot as plt

#from sklearn.decomposition import PCA
#from numpy.linalg import inv
from sklearn.preprocessing import StandardScaler
from sklearn.gaussian_process import GaussianProcessRegressor as gpr
from sklearn.gaussian_process import kernels as krnl
import scipy.stats as st
from scipy import optimize

#import emcee
#import ptemcee
#import h5py
#from scipy.linalg import lapack
from multiprocessing import Pool
from multiprocessing import cpu_count

### Setup working folders


In [3]:
# Output locations: figures are written below the project results folder,
# data files use their own top-level folder.
PROJECT_ROOT_DIR = "Results_ldr"
FIGURE_ID = f"{PROJECT_ROOT_DIR}/FigureFiles"
DATA_ID = "DataFiles/"

In [4]:
# Create the output folders; exist_ok makes this a no-op when they already exist
# (and avoids the check-then-create race of os.path.exists + mkdir).
os.makedirs(PROJECT_ROOT_DIR, exist_ok=True)
os.makedirs(FIGURE_ID, exist_ok=True)
os.makedirs(DATA_ID, exist_ok=True)


def image_path(fig_id):
    """Return the path of figure ``fig_id`` inside FIGURE_ID (no extension is added)."""
    return os.path.join(FIGURE_ID, fig_id)


def data_path(dat_id):
    """Return the path of data file ``dat_id`` inside DATA_ID."""
    return os.path.join(DATA_ID, dat_id)


def save_fig(fig_id):
    """Save the *current* matplotlib figure as ``<FIGURE_ID>/<fig_id>.png``."""
    plt.savefig(image_path(fig_id) + ".png", format='png')

### Load the simulation data from the four viscosity correction models

Model index → viscosity correction model:

- 0: Grad
- 1: Chapman-Enskog R.T.A.
- 2: Pratt-McNelis
- 3: Pratt-Bernhard

In [5]:
# Prior ranges of the model parameters; the emulator's parameter bounds are the same as these ranges.
prior_df = pd.read_csv("DataFiles/PbPb2760_prior", index_col=0)

In [6]:
# Design points: the parameter sets at which the simulations were run.
design = pd.read_csv("DataFiles/PbPb2760_design")

In [7]:
# Simulation outputs at the design points, one DataFrame per viscosity correction
# model (list index = model index 0-3).
simulation_df = [
    pd.read_csv(filepath_or_buffer=f"DataFiles/PbPb2760_simulation_{idf}")
    for idf in range(4)
]
# Per-model simulation errors are not loaded (the files would be
# DataFiles/PbPb2760_simulation_error_{idf}); this list stays empty.
simulation_sd_df = []

In [8]:
# Observable (column) names; reused when rebuilding the standardized frames.
df_clms = simulation_df[1].columns

### Standardize every model's outputs using the per-observable mean and variance of the Grad model

In [9]:
# Fit one scaler on the low-fidelity Grad model (index 0) and apply that same
# per-column mean/std to every model, so all fidelities share one scale.
x = simulation_df[0].values
s_l = StandardScaler().fit(x)
for idf, frame in enumerate(simulation_df[:4]):
    x_tmp = frame.values
    simulation_df[idf] = pd.DataFrame(s_l.transform(x_tmp), columns=df_clms)

# Disabled alternative: rescale the design by the prior width of each parameter.
#diff = np.array(prior_df.loc['max'].values - prior_df.loc['min'].values ).reshape(1,-1)
#diff_mat = np.repeat(diff,X.shape[0],axis=0)
#print(f'Shape of diff matt {diff_mat.shape}')
#X= np.divide(X,diff_mat)

In [10]:
# Design matrix as a plain ndarray (rows = design points, columns = model parameters).
X = design.to_numpy()

In [11]:
# Print the design shape next to each model's output shape so the row counts can be compared.
for idf in range(4):
    Y = simulation_df[idf].to_numpy()
    print(f"X.shape : {X.shape}")
    print(f"Y.shape : {Y.shape}")

X.shape : (485, 17)
Y.shape : (485, 110)
X.shape : (485, 17)
Y.shape : (485, 110)
X.shape : (485, 17)
Y.shape : (485, 110)
X.shape : (485, 17)
Y.shape : (485, 110)


### Labels for plotting purposes

In [12]:
# Model parameter names as matplotlib mathtext labels, in design-column order.
model_param_dsgn = [
    '$N$[$2.76$TeV]',
    '$p$',
    r'$\sigma_k$',
    '$w$ [fm]',
    r'$d_{\mathrm{min}}$ [fm]',
    r'$\tau_R$ [fm/$c$]',
    r'$\alpha$',
    r'$T_{\eta,\mathrm{kink}}$ [GeV]',
    r'$a_{\eta,\mathrm{low}}$ [GeV${}^{-1}$]',
    r'$a_{\eta,\mathrm{high}}$ [GeV${}^{-1}$]',
    r'$(\eta/s)_{\mathrm{kink}}$',
    r'$(\zeta/s)_{\max}$',
    r'$T_{\zeta,c}$ [GeV]',
    r'$w_{\zeta}$ [GeV]',
    r'$\lambda_{\zeta}$',
    r'$b_{\pi}$',
    r'$T_{\mathrm{sw}}$ [GeV]',
]

In [13]:
# Observable labels using \frac-style mathtext (stacked fractions).
observables_latex_2 = [
    r'$\frac{dN_{ch}}{d\eta}$',
    r'$\frac{dE_T}{d\eta}$',
    r'$\frac{dN_{\pi}}{dy}$',
    r'$\frac{dN_{K}}{dy}$',
    r'$\frac{dN_{P}}{dy}$',
    r'$\langle pT_{\pi} \rangle$',
    r'$\langle pT_{K} \rangle$',
    r'$\langle pT_{P} \rangle$',
    r'$\frac{\delta p_T}{\langle p_T \rangle}$',
] + [f'$v_{n}${{2}}' for n in (2, 3, 4)]

In [14]:
# Observable labels using inline "a / b" mathtext (more compact than \frac).
observables_latex = [
    r'$dN_{ch} / d\eta$',
    r'$dE_T / d\eta$',
    r'${dN_{\pi}} / {dy}$',
    '${dN_{K}} / {dy}$',
    '${dN_{P}} / {dy}$',
    r'$\langle p_{T, \pi} \rangle$',
    r'$\langle p_{T, K} \rangle$',
    r'$\langle p_{T, P} \rangle$',
    r'${\delta p_T} / {\langle p_T \rangle}$',
] + [f'$v_{n}${{2}}' for n in (2, 3, 4)]

### Observables considered in this analysis

In [15]:
# Peek at a slice of the available observable columns.
simulation_df[1].columns[80:100]

Index(['pT_fluct[10 15]', 'pT_fluct[15 20]', 'pT_fluct[20 25]',
       'pT_fluct[25 30]', 'pT_fluct[30 35]', 'pT_fluct[35 40]',
       'pT_fluct[40 45]', 'pT_fluct[45 50]', 'pT_fluct[50 55]',
       'pT_fluct[55 60]', 'v22[0 5]', 'v22[ 5 10]', 'v22[10 20]', 'v22[20 30]',
       'v22[30 40]', 'v22[40 50]', 'v22[50 60]', 'v22[60 70]', 'v32[0 5]',
       'v32[ 5 10]'],
      dtype='object')

In [16]:
# Observables emulated below: each at its most central and most peripheral
# available centrality bin (pT_fluct columns stop at the 55-60% bin).
# Kaon and proton yields / mean pT are currently excluded.
observables_choosen = [
    f'{observable}[{centrality}]'
    for observable, centralities in (
        ('dNch_deta', ('0 5', '60 70')),
        ('dN_dy_pion', ('0 5', '60 70')),
        ('mean_pT_pion', ('0 5', '60 70')),
        ('pT_fluct', ('0 5', '55 60')),
        ('v22', ('0 5', '60 70')),
    )
    for centrality in centralities
]

### Linear multifidelity modeling using Emukit

In [17]:
import GPy
import emukit.multi_fidelity
from emukit.model_wrappers.gpy_model_wrappers import GPyMultiOutputWrapper
from emukit.multi_fidelity.models import GPyLinearMultiFidelityModel
## Convert lists of arrays to ndarrays augmented with fidelity indicators
from sklearn.model_selection import train_test_split
from emukit.multi_fidelity.convert_lists_to_array import convert_x_list_to_array, convert_xy_lists_to_arrays
from sklearn.model_selection import KFold
#from emukit.multi_fidelity.models.non_linear_multi_fidelity_model import make_non_linear_kernels, NonLinearMultiFidelityModel
from sklearn.metrics import r2_score, mean_squared_error
import matplotlib

In [18]:

def _fit_linear_mf_model(x_train_l, x_train_h, y_train_l, y_train_h, n_dim, n_opt):
    """Build and optimize Emukit's two-fidelity linear GP (ARD Matern-5/2 kernels); return the wrapper."""
    X_train, Y_train = convert_xy_lists_to_arrays([x_train_l, x_train_h], [y_train_l, y_train_h])
    kernels = [GPy.kern.Matern52(n_dim, ARD=True), GPy.kern.Matern52(n_dim, ARD=True)]
    lin_mf_kernel = emukit.multi_fidelity.kernels.LinearMultiFidelityKernel(kernels)
    gpy_lin_mf_model = GPyLinearMultiFidelityModel(X_train, Y_train, lin_mf_kernel, n_fidelities=2)
    # Pin the second noise term (presumably the high-fidelity likelihood) to zero;
    # the first noise term is still optimized.
    gpy_lin_mf_model.mixed_noise.Gaussian_noise_1.fix(0)
    lin_mf_model = GPyMultiOutputWrapper(gpy_lin_mf_model, 2, n_optimization_restarts=n_opt)
    lin_mf_model.optimize()
    return lin_mf_model


def _fit_standard_gp(x_train_h, y_train_h, n_dim, n_opt):
    """Build and optimize a single-fidelity GPy regression (ARD RBF kernel) on high-fidelity data only."""
    kernel = GPy.kern.RBF(input_dim=n_dim, ARD=True)
    high_gp_model = GPy.models.GPRegression(x_train_h, y_train_h, kernel)
    high_gp_model.optimize_restarts(n_opt, verbose=True, parallel=True, num_processes=5)
    return high_gp_model


def _corner_grid(n_dim, title, hide_diagonal=False):
    """Create an n_dim x n_dim axes grid for pairwise parameter plots.

    The upper triangle (and the diagonal if ``hide_diagonal``) is switched off; the
    lower-triangle axes in the last row / first column get labels from ``model_param_dsgn``.
    Returns ``(fig, axs)`` with ``axs`` always 2-D.
    """
    fig, axs = plt.subplots(n_dim, n_dim, figsize=(70, 70), squeeze=False)
    fig.suptitle(title)
    for row in range(n_dim):
        for clmn in range(n_dim):
            ax = axs[row, clmn]
            if row < clmn or (row == clmn and hide_diagonal):
                ax.axis('off')
            elif row > clmn:
                if row == n_dim - 1:
                    ax.set_xlabel(model_param_dsgn[clmn], fontsize=40)
                if clmn == 0:
                    ax.set_ylabel(model_param_dsgn[row], fontsize=40)
    return fig, axs


def _standardized_residuals(mean, var, y_true):
    """Return (mean - y_true) / sqrt(var) flattened to a column vector."""
    return np.divide(np.array(mean - y_true).reshape(-1, 1), np.sqrt(var.reshape(-1, 1)))


def run_train_and_validation(x_train_l, x_train_h, x_test_h, y_train_l, y_train_h, y_test_h, obs_name):
    """Train two emulators for one observable and validate them on a held-out set.

    Emulators:
        1. Linear multi-fidelity GP (Emukit) trained on the low-fidelity
           (x_train_l, y_train_l) and high-fidelity (x_train_h, y_train_h) data.
        2. Standard GP (GPy, ARD RBF kernel) trained on the high-fidelity data only.

    Both are evaluated on (x_test_h, y_test_h). Side effects: prints diagnostics and
    saves three figures via ``save_fig`` (worst validation points, largest discrepancy
    ratio, QQ plot). Each figure is closed after saving so figures do not accumulate
    across repeated calls.

    Parameters
    ----------
    x_train_l, x_train_h, x_test_h : 2-D arrays, rows = design points, columns = model
        parameters. NOTE(review): the number of columns must not exceed
        ``len(model_param_dsgn)``, which supplies the corner-plot labels.
    y_train_l, y_train_h, y_test_h : 2-D column arrays (the caller passes ``.reshape(-1, 1)``).
    obs_name : str, used in printed output and saved figure names.

    Returns
    -------
    metrics : ndarray of shape (3, 2); columns are [linear MF, standard GP]
        row 0: mean squared error on the test set
        row 1: R2 score on the test set
        row 2: Gaussian noise variance read from each model's ``param_array``
    rho : float, the MF model's ``param_array[-3]`` (its fidelity scaling parameter).
    """
    print('########################')
    print(obs_name)
    n_opt = 5                      # optimizer restarts for both emulators
    n_dim = x_train_h.shape[1]     # number of model parameters
    n_train_h = x_train_h.shape[0]
    n_test = x_test_h.shape[0]

    ## Fit both emulators (MF first, then standard GP)
    lin_mf_model = _fit_linear_mf_model(x_train_l, x_train_h, y_train_l, y_train_h, n_dim, n_opt)
    high_gp_model = _fit_standard_gp(x_train_h, y_train_h, n_dim, n_opt)

    ## Predictions on the test set
    # convert_x_list_to_array appends a fidelity-index column: the first n_test rows
    # are tagged with index 0 (low fidelity), the next n_test rows with index 1 (high).
    x_temp = convert_x_list_to_array([x_test_h, x_test_h])
    x_test_l_idf_index = x_temp[:n_test, :]
    x_test_h_idf_index = x_temp[n_test:, :]
    hf_mean_lin_mf_model, hf_var_lin_mf_model = lin_mf_model.predict(x_test_h_idf_index)
    lf_mean_lin_mf_model, _lf_var_lin_mf_model = lin_mf_model.predict(x_test_l_idf_index)
    hf_mean_high_gp_model, hf_var_high_gp_model = high_gp_model.predict(x_test_h)

    ## Scores
    print(hf_mean_lin_mf_model.shape)
    r2_1 = r2_score(y_test_h, hf_mean_lin_mf_model)
    mse_1 = mean_squared_error(y_test_h, hf_mean_lin_mf_model)
    print(f'r2 score for multifidelity linear {r2_1}')
    print(f'mse for multifidelity linear {mse_1}')

    print(hf_mean_high_gp_model.shape)
    r2_3 = r2_score(y_test_h, hf_mean_high_gp_model)
    mse_3 = mean_squared_error(y_test_h, hf_mean_high_gp_model)
    print(f'r2 score for standard GP {r2_3}')
    print(f'mse for standard GP {mse_3}')

    ## Plots
    matplotlib.rcParams.update({'font.size': 20})
    # Number of worst test points highlighted (alternative: int(np.ceil(0.1 * n_train_h))).
    fraction_points = 10
    print(f'Fraction of points {fraction_points}')

    # Test points with the largest absolute error of the linear MF model.
    diff_test_lin = np.absolute(hf_mean_lin_mf_model - y_test_h)
    worst = np.argsort(diff_test_lin.flatten())[-fraction_points:]
    print(f'Worst {fraction_points} of validations')
    print('######################')
    print(diff_test_lin[worst])

    fig1, axs = _corner_grid(n_dim, f'{obs_name}_Worst validation for batch size {n_train_h}')
    for row in range(n_dim):
        for clmn in range(row + 1):
            ax = axs[row, clmn]
            if row == clmn:
                ax.hist(x_train_h[:, clmn], color='blue', density=True, alpha=0.5)
                ax.hist(x_test_h[worst, clmn], color='red', density=True, alpha=0.5)
            else:
                ax.scatter(x_train_h[:, clmn], x_train_h[:, row], label='training')
                # All test points, coloured by the linear MF model's absolute error.
                ax.scatter(x_test_h[:, clmn], x_test_h[:, row], cmap='Reds',
                           c=diff_test_lin.ravel(), label='worst validation')
    fig1.tight_layout()
    save_fig(f'{obs_name}_{name}_{n_train_h}_worst_validation')
    plt.close(fig1)

    # Discrepancy contribution relative to the low-fidelity prediction: (HF - LF) / LF.
    # NOTE(review): if the targets are standardized (mean ~0), LF predictions near zero
    # make this ratio very large, so the mean |ratio| can be dominated by a few points.
    discrepency_prediction_lin_mf = np.divide(hf_mean_lin_mf_model - lf_mean_lin_mf_model, lf_mean_lin_mf_model)
    print('Average discrepency ratio for linear mf model')
    print('#######################################')
    print(np.mean(np.absolute(discrepency_prediction_lin_mf)))

    disc_flat = discrepency_prediction_lin_mf.flatten()
    worst_disc_lin = np.argsort(np.absolute(disc_flat))[-fraction_points:]
    fig2, axs = _corner_grid(
        n_dim,
        f'{obs_name}_biggest discrepancy multifidelity GP for batch size {n_train_h}',
        hide_diagonal=True,
    )
    for row in range(n_dim):
        for clmn in range(row):
            axs[row, clmn].scatter(x_test_h[worst_disc_lin, clmn], x_test_h[worst_disc_lin, row],
                                   cmap='Reds', c=disc_flat[worst_disc_lin], label='linear')
    fig2.tight_layout()
    save_fig(f'{obs_name}_{name}_{n_train_h}_discrepency')
    plt.close(fig2)

    # QQ plot: quantiles of standardized residuals vs. standard-normal quantiles.
    lin_mf_normalize = _standardized_residuals(hf_mean_lin_mf_model, hf_var_lin_mf_model, y_test_h)
    standard_gp_normalize = _standardized_residuals(hf_mean_high_gp_model, hf_var_high_gp_model, y_test_h)

    fig3, ax = plt.subplots(figsize=(10, 10))
    percs = np.linspace(5, 95, 19)
    qn_a = np.percentile(lin_mf_normalize, percs)
    qn_c = np.percentile(standard_gp_normalize, percs)
    qn_th = st.norm.ppf(q=0.01 * percs)
    ax.plot(qn_th, qn_a, ls="", marker="o", label='Linear')
    ax.plot(qn_th, qn_c, ls="", marker="o", label='Standard')
    x = np.linspace(np.min((qn_a.min(), qn_c.min())), np.max((qn_a.max(), qn_c.max())))
    ax.plot(x, x, color="k", ls="--")
    ax.legend()
    ax.set_xlabel('Theoretical quantiles')
    ax.set_ylabel(f'{obs_name}_Emulation quantiles')
    fig3.tight_layout()
    save_fig(f'{obs_name}_{name}_{n_train_h}_QQ_plot')
    plt.close(fig3)

    # NOTE(review): positional lookup relies on GPy's parameter ordering. For the MF
    # model, [-1] is the pinned noise term, [-2] the fitted noise term and [-3] the
    # kernel scaling (rho). For GPRegression, [-1] is the Gaussian noise variance.
    params_mf = lin_mf_model.gpy_model.param_array
    r2s = [r2_1, r2_3]
    mses = [mse_1, mse_3]
    Wn = [params_mf[-2], high_gp_model.param_array[-1]]
    rho = params_mf[-3]
    return np.array([mses, r2s, Wn]), rho

In [None]:
kf = KFold(n_splits=5)
n_batch = 10
# Quick-look settings: evaluate only the first CV fold, and at most two
# high-fidelity training sizes per fold -- the first batch
# (n_train // n_batch points) and the full training split.
n_folds_to_run = 1
n_batches_to_run = min(n_batch, 2)

r2_ar_obs = []
mse_ar_obs = []
batch_ar_obs = []
wn_ar_obs = []
rho_ar_obs = []

for selected_observable in observables_choosen:
    # Low fidelity: Grad model (index 0); high fidelity: Pratt-Bernhard model (index 3).
    Y_l = simulation_df[0][selected_observable].values.reshape(-1, 1)
    Y_h = simulation_df[3][selected_observable].values.reshape(-1, 1)

    r2_ar_crs = []
    mse_ar_crs = []
    batch_ar_crs = []
    wn_ar_crs = []
    rho_ar_crs = []
    for split_i, (train_index, test_index) in enumerate(kf.split(X)):
        if split_i >= n_folds_to_run:
            break
        r2_ar = []
        mse_ar = []
        batch_ar = []
        wn_ar = []
        rho_ar = []
        x_train_h, x_test_h, y_train_h, y_test_h = X[train_index, :], X[test_index], Y_h[train_index, :], Y_h[test_index, :]
        x_train_l, x_test_l, y_train_l, y_test_l = X[train_index, :], X[test_index], Y_l[train_index, :], Y_l[test_index, :]

        n_train = train_index.shape[0]
        for i in range(n_batches_to_run):
            # The low-fidelity data always covers the whole training split; the
            # high-fidelity emulator sees only the first h training points.
            l = 0
            h = n_train if i == 1 else (n_train // n_batch) * (i + 1)
            metrics, rho = run_train_and_validation(x_train_l, x_train_h[l:h, :], x_test_h, y_train_l,
                                                    y_train_h[l:h, :], y_test_h, selected_observable)
            # Rows of `metrics` are [MSE, R2, noise variance] (see run_train_and_validation).
            mse_ar.append(metrics[0, :])
            r2_ar.append(metrics[1, :])
            wn_ar.append(metrics[2, :])
            batch_ar.append(h)
            rho_ar.append(rho)
        r2_ar = np.array(r2_ar)
        mse_ar = np.array(mse_ar)
        batch_ar = np.array(batch_ar)
        wn_ar = np.array(wn_ar)
        rho_ar = np.array(rho_ar)

        r2_ar_crs.append(r2_ar)
        mse_ar_crs.append(mse_ar)
        batch_ar_crs.append(batch_ar)
        wn_ar_crs.append(wn_ar)
        rho_ar_crs.append(rho_ar)

    r2_ar_obs.append(np.array(r2_ar_crs))
    mse_ar_obs.append(np.array(mse_ar_crs))
    batch_ar_obs.append(np.array(batch_ar_crs))
    wn_ar_obs.append(np.array(wn_ar_crs))
    rho_ar_obs.append(np.array(rho_ar_crs))

# Final arrays are indexed [observable, fold, batch, ...]; metric arrays end with
# a column axis ordered [linear MF, standard GP].
r2_ar_obs = np.array(r2_ar_obs)
mse_ar_obs = np.array(mse_ar_obs)
batch_ar_obs = np.array(batch_ar_obs)
wn_ar_obs = np.array(wn_ar_obs)
rho_ar_obs = np.array(rho_ar_obs)


########################
dNch_deta[0 5]




Optimization restart 1/5, f = -39.61621197862621
Optimization restart 2/5, f = -38.85965567645863
Optimization restart 3/5, f = -39.40649184002132
Optimization restart 4/5, f = -39.218473523219984
Optimization restart 5/5, f = -40.193526626088214
Optimization restart 1/5, f = 16.272004381765207
Optimization restart 2/5, f = 50.95435476800239
Optimization restart 3/5, f = 16.271994985890498
Optimization restart 4/5, f = 55.31045222934547
Optimization restart 5/5, f = 16.271999013749646
(97, 1)
r2 score for multifidelity linear 0.9608275894226321
mse for multifidelity linear 0.028281994473427233
(97, 1)
r2 score for standard GP 0.8644554412230872
mse for standard GP 0.0978614899039834
Fraction of points 10
Worst 10 of validations
######################
[[0.25992273]
 [0.27694627]
 [0.29871173]
 [0.32882033]
 [0.33822959]
 [0.35977354]
 [0.44856092]
 [0.52483031]
 [0.59074491]
 [0.62695502]]
Average discrepency ratio for linear mf model
#######################################
2.3855234600



Optimization restart 1/5, f = -485.07051994074357
Optimization restart 2/5, f = 637.797250832107
Optimization restart 3/5, f = 638.6678498810529
Optimization restart 4/5, f = 261.15443907792337
Optimization restart 5/5, f = -486.3731450234674
Optimization restart 1/5, f = 69.6284072337483
Optimization restart 2/5, f = 69.55995798842343
Optimization restart 3/5, f = 69.56261371368828
Optimization restart 4/5, f = 69.56189907586517
Optimization restart 5/5, f = 69.6023975201976
(97, 1)
r2 score for multifidelity linear 0.9799469755156973
mse for multifidelity linear 0.014478034904704513
(97, 1)
r2 score for standard GP 0.9812765675970834
mse for standard GP 0.013518085916540953
Fraction of points 10
Worst 10 of validations
######################
[[0.1923161 ]
 [0.19425618]
 [0.19873719]
 [0.23149728]
 [0.25269447]
 [0.28782254]
 [0.28818825]
 [0.29791694]
 [0.32736981]
 [0.45427265]]
Average discrepency ratio for linear mf model
#######################################
2.5092920707297997




In [None]:
# NOTE(review): disabled plotting cell, kept for reference only. It reads columns
# 0-2 of r2_ar / mse_ar (linear MF, nonlinear MF, standard GP), but
# run_train_and_validation now returns only two columns (linear MF, standard GP),
# and `val_i` is not defined in any visible cell. Update both before re-enabling.
# fig,ax =plt.subplots(nrows=1, ncols=2, figsize=(20,10))
# ax1, ax2 = ax
# ax1.plot(batch_ar,r2_ar[:,0], c ='r', label='multifidelity linear')
# ax1.plot(batch_ar,r2_ar[:,1], c ='y', label='multifidelity nonlinear')
# ax1.plot(batch_ar,r2_ar[:,2], c ='c', label='Standard GP')
#     #l,h=ax.get_ylim()
#     #line_1d = np.linspace(l,h,100)
#     #ax.plot(line_1d,line_1d)
# ax1.set_xlabel('Number of high fidelity points used in training')
# ax1.set_ylabel('R2 score')
# ax1.legend()
    
# ax2.plot(batch_ar,mse_ar[:,0], c ='r', label='multifidelity linear')
# ax2.plot(batch_ar,mse_ar[:,1], c ='y', label='multifidelity nonlinear')
# ax2.plot(batch_ar,mse_ar[:,2], c ='c', label='Standard GP')
# ax2.set_xlabel('Number of high fidelity points used in training')
# ax2.set_ylabel('MSE')
# ax2.legend()
# save_fig(f'{val_i}_batchsize_{n_batch}_ARD_PTB')