In [None]:
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
import os

import scripts
from scripts import sbi_tools
from scripts import plot_utils

# Automatically reload edited modules (such as scripts.sbi_tools and
# scripts.plot_utils) before each cell runs, so changes to the helper code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render figures inline, reset rcParams to matplotlib's library defaults and
# close any figures left open by a previous run, so re-running this cell starts
# from a clean plotting state.
%matplotlib inline
mpl.pyplot.style.use('default')
mpl.pyplot.close('all')

# Project plotting configuration from plot_utils: a font dict (fed to
# mpl.rc('font', ...)) and a dict of further rcParams. These are applied after
# the 'default' reset above so that reset does not overwrite them; the
# 'tableau-colorblind10' style sheet is applied last, so any rcParams it sets
# take precedence over rcnew.
font, rcnew = plot_utils.matplotlib_default_config()
mpl.rc('font', **font)
mpl.pyplot.rcParams.update(rcnew)
mpl.pyplot.style.use('tableau-colorblind10')
# Render inline figures at high resolution on HiDPI displays.
%config InlineBackend.figure_format = 'retina'

# Request 6 threads from the sbi_tools helpers; the return value is kept in N_threads.
# NOTE(review): what set_N_threads actually limits (torch / BLAS threads?) is
# defined in scripts/sbi_tools — confirm.
N_threads = sbi_tools.set_N_threads(6)

In [None]:
# Toy simulator for testing the SBI pipeline. Each of the n_train simulations has:
#   - a scalar parameter theta_i ~ U(0, 1) (rng.uniform's default range), and
#   - a data vector Pk_i of n_dim independent draws from N(theta_i, 1).
# Resulting shapes: theta is (n_train, 1) and Pk is (n_train, n_dim).
rng = np.random.default_rng(seed=533)
n_train = 1000
n_dim = 10
theta = np.expand_dims(rng.uniform(size=n_train), axis=1)
print(theta.shape)
# Broadcast the (n_train, 1) theta column against size=(n_train, n_dim) so that
# row i is centred on theta[i]. Using theta[0] here would give every row the same
# mean (the first sample's parameter), making Pk independent of theta.
Pk = rng.normal(loc=theta, size=(n_train, n_dim))
print(Pk.shape)
assert Pk.shape == (n_train, n_dim)

In [None]:
# Hold out the last 20% of the simulations as a test set; the first 80% are
# used for training. Parameters and data are cut at the same row index so they
# stay aligned.
train_val_split = int(theta.shape[0] * 0.8)

theta_train, theta_test = theta[:train_val_split], theta[train_val_split:]
print(theta_train.shape, theta_test.shape)

Pk_train, Pk_test = Pk[:train_val_split], Pk[train_val_split:]

In [None]:
# Prior bounds [lower, upper] for each inferred parameter, keyed by name.
# NOTE(review): theta is generated with rng.uniform(), i.e. U(0, 1), while this
# prior spans [-5, 5]. Confirm the wider prior is intended; rank statistics
# assume the test parameters were drawn from the prior used for inference.
dict_bounds = dict(mu=[-5, 5])

In [None]:
# class Scaler:

#     def __init__(self):
#           pass
        
#     def fit(self, x_train):
#         self.x_train_min = np.min(x_train)
#         self.x_train_max = np.max(x_train)
           
#     def scale(self, x):
#         log_x = np.log10(x)
#         log_x_norm = (log_x - np.log10(self.x_train_min)) / (np.log10(self.x_train_max) - np.log10(self.x_train_min))
#         return log_x_norm
    
#     def unscale(self, x_scaled):
#         x = x_scaled * (np.log10(self.x_train_max) - np.log10(self.x_train_min)) + np.log10(self.x_train_min)
#         return 10**x  

In [None]:
# Feature scaling is currently a no-op: the *_scaled names are plain aliases of
# the raw arrays (no copy is made). To switch to the log min-max Scaler defined
# (commented out) above, uncomment it and use:
#     scaler = Scaler()
#     scaler.fit(Pk_train)
#     Pk_train_scaled, Pk_test_scaled = scaler.scale(Pk_train), scaler.scale(Pk_test)
# NOTE(review): Pk holds Gaussian draws, including negative values, for which
# np.log10 returns NaN, so that Scaler would need a different transform here.
Pk_train_scaled, Pk_test_scaled = Pk_train, Pk_test

In [None]:
# Sanity check: min / max of the raw and "scaled" data, train split first, then test.
for pk_arr in (Pk_train, Pk_train_scaled, Pk_test, Pk_test_scaled):
    print(np.min(pk_arr), np.max(pk_arr))

In [None]:
# Shapes fed to training: data (n_samples, n_dim), then parameters (n_samples, 1).
print(Pk_train.shape, theta_train.shape, sep="\n")

In [None]:
# Train the posterior estimator on the training split, with the prior built from
# dict_bounds. validation_fraction=0.2 presumably holds out 20% of the training
# simulations inside sbi_tools for validation / early stopping — confirm there.
inference, posterior = sbi_tools.train_model(
    theta_train,
    Pk_train_scaled,
    validation_fraction=0.2,
    training_batch_size=16,
    prior=sbi_tools.get_prior(dict_bounds),
)

In [None]:

# Pick one training example to check that the posterior recovers a point the
# network was trained on. Alternative: choose it at random with
#     idx_train_check = rng.choice(np.arange(len(theta_train)))
idx_train_check = 42
print(idx_train_check)

# Indexing with a one-element list returns a copy that keeps a leading batch
# axis of size 1: theta -> (1, 1), data -> (1, n_dim).
theta_train_check = theta_train[[idx_train_check]]
print(theta_train_check)
Pk_train_scaled_check = Pk_train_scaled[[idx_train_check]]

In [None]:
# ------------- posterior inference: training check example ------------- #

# Draw 1000 posterior samples conditioned on the selected training example.
inferred_theta_train_check = sbi_tools.sample_posteriors_theta_test(
    posterior,
    Pk_train_scaled_check,
    dict_bounds,
    N_samples=1000,
)

# ------------------ rank statistics ------------------ #

# Presumably the rank of the true theta among its posterior draws (SBC-style);
# see sbi_tools.compute_ranks for the exact definition.
ranks_train_check = sbi_tools.compute_ranks(
    theta_train_check,
    inferred_theta_train_check,
)

In [None]:
# LaTeX panel titles, one per inferred parameter (same order as dict_bounds).
custom_titles = [r'$\mu$']

In [None]:
# Corner plot of the posterior samples against the true parameter for the
# training check example(s); N_examples must not exceed len(theta_train_check).
N_examples = 1
colors = plot_utils.get_N_colors(N_examples, mpl.colormaps['prism'])

for ii_sample in range(N_examples):
    fig, axs = plot_utils.corner_plot(
        theta_train_check[ii_sample],
        inferred_theta_train_check[ii_sample],
        custom_titles,
        dict_bounds,
        color_infer=colors[ii_sample],
    )
    plt.show()

# To save the last figure, use matplotlib's Figure.savefig (Figure has no .save):
# fig.savefig("popopo.png")

In [None]:
# ------------------ posterior inference: test set ------------------ #

# Draw 1000 posterior samples for every held-out test data vector.
inferred_theta_test = sbi_tools.sample_posteriors_theta_test(
    posterior,
    Pk_test_scaled,
    dict_bounds,
    N_samples=1000,
)

# ------------------ rank statistics ------------------ #

# Presumably the rank of each true test theta among its posterior draws
# (SBC-style); see sbi_tools.compute_ranks for the exact definition.
ranks = sbi_tools.compute_ranks(
    theta_test,
    inferred_theta_test,
)

In [None]:
# Corner plot of the posterior samples against the true parameter for the first
# N_examples test simulations.
N_examples = 1
colors = plot_utils.get_N_colors(N_examples, mpl.colormaps['prism'])

for ii_sample in range(N_examples):
    fig, axs = plot_utils.corner_plot(
        theta_test[ii_sample],
        inferred_theta_test[ii_sample],
        custom_titles,
        dict_bounds,
        color_infer=colors[ii_sample],
    )
    plt.show()

# To save the last figure, use matplotlib's Figure.savefig (Figure has no .save):
# fig.savefig("popopo.png")

In [None]:
# Posterior predictions vs. true parameter values on the test set.
fig, axs = plot_utils.plot_parameter_prediction_vs_truth(
    inferred_theta_test, theta_test, custom_titles
)
# Lay out the figure returned by the helper explicitly, rather than whatever
# pyplot currently considers the "current" figure (assumes plot_utils returns a
# matplotlib Figure). wspace is set after tight_layout so it overrides the
# horizontal spacing tight_layout chose.
fig.tight_layout()
fig.subplots_adjust(wspace=0.6)
plt.show()

In [None]:
# Rank-statistics histograms for the test set. The second argument is the size
# of axis 1 of the posterior samples; presumably that is the number of draws per
# test point, i.e. the range of possible ranks (confirm in plot_utils).
fig, axs = plot_utils.plot_rank_statistcis(
    ranks, inferred_theta_test.shape[1], custom_titles
)
# Lay out the returned figure explicitly instead of relying on pyplot's current
# figure (assumes plot_utils returns a matplotlib Figure). The tight wspace is
# applied after tight_layout so it takes precedence.
fig.tight_layout()
fig.subplots_adjust(wspace=0.05)
plt.show()