In [1]:
%load_ext autoreload
%autoreload 2

import click
import pickle
from tqdm import tqdm
import time
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sbi.inference import FMPE
import os
from utils.data_generation import fetch_grid, sample_joint_distribution, sample_uniform_distribution
from utils.inference import preprocess_inputs, get_prior, get_eval_grid, get_posterior, compute_indicators_sampling_posterior, posterior_and_prior_kdeplot
from lf2i.inference import LF2I
from lf2i.test_statistics.posterior import Posterior
from lf2i.utils.other_methods import hpd_region
from lf2i.plot.parameter_regions import plot_parameter_regions
from lf2i.diagnostics.coverage_probability import compute_indicators_posterior

2024-12-18 09:05:54.150001: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-18 09:05:54.150050: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-18 09:05:54.151841: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-18 09:05:54.159442: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# --- Run configuration -------------------------------------------------------
FLOW_TYPE = 'npe'  # used below to pick the results sub-directory (example_{FLOW_TYPE})
B = 300_000  # num simulations to estimate posterior and test statistics
B_PRIME = 100_000  # num simulations to estimate critical values
B_DOUBLE_PRIME = 30_000  # num simulations to estimate coverage probability
EVAL_GRID_SIZE = 50_000
DISPLAY_GRID_SIZE = 10 # irrelevant now that grid has been defined elsewhere
NORM_POSTERIOR_SAMPLES = None
CONFIDENCE_LEVEL = 0.9
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Reproducibility ---------------------------------------------------------
# Seed every RNG this notebook draws from.
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
# NumPy's global RNG is used further down (np.random.choice for the background
# scatter sample); without this the sampled points change on every run.
np.random.seed(42)

# --- Parameters and plotting labels ------------------------------------------
POI_DIM = 5  # number of entries in POIS
POIS = ['t_eff', 'logg', 'feh_surf', 'logl', 'dist']
# Axis labels, one per entry of POIS (same order).
LABELS = [r"$T_{eff}$ (K)",
            r"$\log g$ (cgs)",
            r"$[\text{Fe/H}]_{\text{surf}}$ (relative to solar)",
            r"$\log L$ ($L_{\odot}$)",
            r"$d$ (kpc)"]
PRIOR_SETTINGS = [2.0, 1.0, 0.0, -1.0, -2.0]
# NOTE(review): bounds are presumably ordered like POIS — confirm against get_prior.
PRIOR_ARGS = {
    'lower_bound' : torch.tensor([2.5e3, 0.0, -4.0, -1.5, 0.0]),
    'upper_bound' : torch.tensor([1.5e4, 5.0, 0.5, 3.5, 1.0e3])
}
PLOT_PRIORS = False # These figures have already been generated

# --- Assets ------------------------------------------------------------------
# Paths are relative to the current working directory, so the notebook must be
# launched from the project root for these to resolve.
assets_dir = f'{os.getcwd()}/assets'
os.makedirs(assets_dir, exist_ok=True)
params, seds = fetch_grid(assets_dir=assets_dir) # POI grid + raw SEDs

## Plot sets

In [3]:
# Prior setting whose saved results are plotted below (one of the PRIOR_SETTINGS values).
PRIOR_SETTING = 2.0
# Results directory for this flow type and prior setting: the saved credible
# regions / confidence sets are read from it and the figures are written beneath it.
example_dir_for_setting = f'{os.getcwd()}/results/results_121624/example_{FLOW_TYPE}/setting_{PRIOR_SETTING}'

In [4]:
# Load the display grid: a pickled 2-tuple unpacked as (theta_p_d, x_p_d).
# theta_p_d supplies the "true" parameter point of each figure further down.
# NOTE: pickle.load can execute arbitrary code — only open files produced by this project.
with open(f'{assets_dir}/tryout_display_grid.pkl', 'rb') as f:
    theta_p_d, x_p_d = pickle.load(f)
# Labels for the display points; used in figure titles and output file names.
with open(f'{assets_dir}/tryout_display_grid_labels.pkl', 'rb') as f:
    display_labels = pickle.load(f)

In [5]:
# Load the tryout set saved for the selected PRIOR_SETTING (pickled 2-tuple).
# NOTE: pickle.load can execute arbitrary code — only open files produced by this project.
with open(f'{assets_dir}/tryout_set_{PRIOR_SETTING}.pkl', 'rb') as f:
    theta_p, x_p = pickle.load(f)

In [6]:
# Load the saved credible regions; the plotting loops take element [1] of each entry.
# NOTE: pickle.load can execute arbitrary code — only open files produced by this project.
with open(f'{example_dir_for_setting}/credible_regions.pkl', 'rb') as f:
    credible_regions = pickle.load(f)

In [7]:
# Load the saved confidence sets; each entry is column-indexed directly when plotted.
# NOTE: pickle.load can execute arbitrary code — only open files produced by this project.
with open(f'{example_dir_for_setting}/confidence_sets.pkl', 'rb') as f:
    confidence_sets = pickle.load(f)

In [8]:
from lf2i.plot.parameter_regions import plot_parameter_region_2D
from matplotlib.gridspec import GridSpec

In [9]:
# Load the evaluation grid (pickled 2-tuple). theta_e_c is converted with
# .numpy() below, so it is presumably a torch tensor — TODO confirm.
# NOTE: pickle.load can execute arbitrary code — only open files produced by this project.
with open(f'{assets_dir}/tryout_evaluation_grid.pkl', 'rb') as f:
    theta_e_c, x_e_c = pickle.load(f)

In [10]:
# Column-wise extrema of the evaluation grid (one value per parameter).
parameter_space_max = np.max(theta_e_c.numpy(), axis=0)
parameter_space_min = np.min(theta_e_c.numpy(), axis=0)
(parameter_space_max, parameter_space_min)

(array([1.4177014e+04, 4.9059000e+00, 4.9828658e-01, 3.4386032e+00,
        1.0000000e+02], dtype=float32),
 array([ 2.7915581e+03, -2.0189059e-01, -3.9302344e+00, -1.4561915e+00,
         9.9999998e-03], dtype=float32))

In [11]:
# Map each parameter name to the {'low', 'high'} range spanned by the evaluation
# grid. Driven by POIS rather than a hard-coded count, so it follows the
# parameter list if that ever changes.
parameter_space_bounds = {
    poi: {'low': low, 'high': high}
    for poi, low, high in zip(POIS, parameter_space_min, parameter_space_max)
}
parameter_space_bounds

{'t_eff': {'low': 2791.558, 'high': 14177.014},
 'logg': {'low': -0.20189059, 'high': 4.9059},
 'feh_surf': {'low': -3.9302344, 'high': 0.49828658},
 'logl': {'low': -1.4561915, 'high': 3.4386032},
 'dist': {'low': 0.01, 'high': 100.0}}

In [19]:
# Show the axis labels (one per entry of POIS) for reference.
LABELS

['$T_{eff}$ (K)',
 '$\\log g$ (cgs)',
 '$[\\text{Fe/H}]_{\\text{surf}}$ (relative to solar)',
 '$\\log L$ ($L_{\\odot}$)',
 '$d$ (kpc)']

In [27]:
# Random subsample (with replacement) of 10,000 evaluation-grid points, drawn
# as a faint background in the scatter figures. A locally seeded generator keeps
# this cell idempotent: re-running it, in any order, yields the same sample.
# The seed matches the torch seed set in the configuration cell.
theta_grid = theta_e_c.numpy()[np.random.default_rng(42).choice(len(theta_e_c), size=10_000)]

In [30]:
# Pairwise scatter view: for every display point, draw the confidence set and the
# credible region in each 2-D parameter projection and save one PNG per point.
scatter_dir = f'{example_dir_for_setting}/scatter_regions'
# savefig does not create missing directories; make sure the target exists.
os.makedirs(scatter_dir, exist_ok=True)

n_params = len(POIS)
set_colors = ["magenta", "green"]  # [credible region, confidence set]

for idx, (true_theta, cr, cs) in enumerate(zip(theta_p_d, credible_regions, confidence_sets)):
    cr = cr[1]  # element [1] of each saved entry holds the region points that are plotted
    fig, axs = plt.subplots(n_params, n_params, figsize=(12, 12))

    for j in range(n_params):
        for i in range(n_params):
            ax = axs[j, i]
            # Only panels with i > j are drawn (x = parameter j, y = parameter i).
            if i <= j:
                ax.axis('off')
                continue

            # Faint background: random subsample of the evaluation grid.
            ax.scatter(x=theta_grid[:, j], y=theta_grid[:, i], alpha=0.01, color='black', s=1)

            # Confidence set first, then the credible region drawn on top of it.
            for region, color in ((cs, set_colors[1]), (cr, set_colors[0])):
                plot_parameter_region_2D(
                    region[:, [j, i]],
                    None,
                    parameter_space_bounds=parameter_space_bounds,
                    param_names=[POIS[j], POIS[i]],
                    labels=[LABELS[j], LABELS[i]],
                    custom_ax=ax,
                    scatter=True,
                    alpha_shape=False,
                    alpha=8,
                    color=color
                )

            # True parameter value for this display point.
            ax.scatter(x=true_theta[j], y=true_theta[i], alpha=1, color='red', marker="*", s=250, zorder=10)
            ax.set_xlabel(LABELS[j])
            ax.set_ylabel(LABELS[i])
            ax.invert_yaxis()
            ax.invert_xaxis()

    fig.suptitle(f"Point {display_labels[idx]}")
    fig.tight_layout()
    fig.savefig(f'{scatter_dir}/parameter_region_{display_labels[idx]}.png')
    # Close explicitly so figures do not accumulate in memory across iterations.
    plt.close(fig)

In [71]:
import torch

In [None]:
# Per-parameter mean and standard deviation of the evaluation grid; used below
# to standardize the regions before drawing alpha shapes.
center = theta_e_c.numpy().mean(axis=0)
std = theta_e_c.numpy().std(axis=0)

In [None]:
# Alpha-shape view: same pairwise layout as the scatter figures, but the regions
# are standardized with (center, std) and drawn as alpha shapes.
alphashape_dir = f'{example_dir_for_setting}/alphashape_regions'
# savefig does not create missing directories; make sure the target exists.
os.makedirs(alphashape_dir, exist_ok=True)

n_params = len(POIS)
set_colors = ["magenta", "green"]  # [credible region, confidence set]

for idx, (true_theta, cr, cs) in enumerate(zip(theta_p_d, credible_regions, confidence_sets)):
    # Standardize everything with the evaluation-grid statistics.
    cr = (cr[1] - center) / std
    cs = (cs - center) / std
    true_theta = (true_theta - center) / std

    fig, axs = plt.subplots(n_params, n_params, figsize=(12, 12))

    for i in range(n_params):
        for j in range(n_params):
            ax = axs[i, j]
            # Only panels with i > j are drawn (x = parameter i, y = parameter j).
            if i <= j:
                ax.axis('off')
                continue

            # Credible region first, then the confidence set. parameter_space_bounds
            # is deliberately not passed here — presumably because it is expressed
            # in original units while the data are standardized (TODO confirm).
            for region, color, shape_alpha in ((cr, set_colors[0], 8), (cs, set_colors[1], 12)):
                plot_parameter_region_2D(
                    region[:, [i, j]],
                    None,
                    param_names=[POIS[i], POIS[j]],
                    labels=[LABELS[i], LABELS[j]],
                    custom_ax=ax,
                    scatter=False,
                    alpha_shape=True,
                    alpha=shape_alpha,
                    color=color
                )

            # True parameter value (standardized) for this display point.
            ax.scatter(x=true_theta[i], y=true_theta[j], alpha=1, color='red', marker="*", s=250, zorder=10)
            ax.set_xlabel(LABELS[i])
            ax.set_ylabel(LABELS[j])

    fig.suptitle(f"Point {display_labels[idx]}")
    fig.tight_layout()
    fig.savefig(f'{alphashape_dir}/parameter_region_{display_labels[idx]}.png')
    # Close explicitly so figures do not accumulate in memory across iterations.
    plt.close(fig)



In [None]:
def plot_paper_figure_2(
    # config,
    # hpd_lpd_obs_id_pair,
    data_sets, # [(high prior hpd, high prior lf2i), (low prior hpd, low prior lf2i)]
    # test_ds, 
    hpd_coverage_estimator,
    lf2i_coverage_estimator,
    test_layout=False, 
    axis_font_size=12, 
    title_font_size=14
):
    """Draw a 2x3 grid comparing HPD credible regions with LF2I confidence sets.

    Row 0 is the "high prior density" observation and row 1 the "low prior
    density" one; each of the three columns shows one 2-D slice of the sets.

    Parameters
    ----------
    data_sets : sequence of two pairs
        ``[(high prior hpd, high prior lf2i), (low prior hpd, low prior lf2i)]``;
        each element is a point set handed to ``pos.slice_param_set``.
    hpd_coverage_estimator, lf2i_coverage_estimator
        Objects exposing ``predict_proba``; column 1 of the result, evaluated at
        the observation's true parameters, is reported in the row title.
    test_layout : bool, default False
        When True, only the empty axes grid and the legend are produced.
    axis_font_size, title_font_size : int
        Font sizes for axis labels and row titles.

    Returns
    -------
    (fig, axs)
        The figure and its 2x3 array of axes.

    Notes
    -----
    NOTE(review): ``hpd_obs``, ``lpd_obs``, ``pos`` and ``config`` are read as
    free (global) names; none of them is defined in the visible part of this
    notebook and the corresponding parameters are commented out in the
    signature. They must exist in the kernel before this function is called.
    NOTE(review): the slices and axis labels (energy / zenith / azimuth) describe
    a 3-parameter problem, whereas POIS in this notebook lists five stellar
    parameters — confirm this function is meant to be used here.
    """
    # Imported locally: no `lines` module is imported in the visible notebook.
    from matplotlib.lines import Line2D

    set_colors = ["magenta", "green"]  # [HPD credible region, LF2I confidence set]

    # One row per observation, one column per 2-D slice. The body indexes
    # axs[row, col], so the grid must be 2-D: a 1x3 grid would give a 1-D axes
    # array and fail with IndexError. squeeze=False guarantees a 2-D array.
    n_rows, n_cols = 2, 3
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 3.5 * n_rows), squeeze=False)

    # Per-column description of the slice: (column, parameter names,
    # [x index, y index], index of the parameter held fixed, axis labels).
    panels = list(zip(
        [0, 1, 2],
        [["log10_energy", "zenith"], ["azimuth", "log10_energy"], ["azimuth", "zenith"]],
        [[0, 1], [2, 0], [2, 1]],
        [2, 1, 0],
        [["Log10 Energy (GeV)", "Zenith Angle (Rad)"], ["Azimuth Angle (Rad)", "Log10 Energy (GeV)"], ["Azimuth Angle (Rad)", "Zenith Angle (Rad)"]]
    ))

    for row, test_obs, data_set_pair, truth_color in zip([0, 1], [hpd_obs, lpd_obs], data_sets, ["blue", "red"]):
        # Estimated coverage at the true parameters of this observation.
        true_params = np.array([test_obs["params"].numpy()])
        hpd_coverage = hpd_coverage_estimator.predict_proba(true_params)[:, 1][0]
        waldo_coverage = lf2i_coverage_estimator.predict_proba(true_params)[:, 1][0]

        # Reversed so the LF2I set is drawn first and the HPD region on top.
        for point_set, color in zip(reversed(data_set_pair), reversed(set_colors)):
            for col, _pair, param_component, slice_component, axis_labels in panels:
                ax = axs[row, col]
                if test_layout:
                    continue

                x_idx, y_idx = param_component
                # Slice the set at the true value of the parameter not shown.
                _slice, _ = pos.slice_param_set(point_set, slice_component, test_obs["params"][slice_component].item())
                slice_points = _slice[:, [x_idx, y_idx]]
                truth_slice = test_obs["params"][[x_idx, y_idx]].numpy()
                ax.scatter(x=truth_slice.reshape(-1,)[0], y=truth_slice.reshape(-1,)[1], alpha=1, color=truth_color, marker="*", s=250, zorder=10)

                plot_parameter_region_2D(
                    slice_points,
                    None,
                    custom_ax=ax,
                    scatter=False,
                    alpha_shape=True,
                    alpha=8,
                    color=color
                )

                # Cross-hairs through the true value, and axes fixed to the evaluation range.
                ax.axvline(test_obs["params"][x_idx].item(), color=truth_color, linestyle="dashed")
                ax.axhline(test_obs["params"][y_idx].item(), color=truth_color, linestyle="dashed")
                ax.set_xlim(config.eval_param_mins[x_idx].item(), config.eval_param_maxes[x_idx].item())
                ax.set_ylim(config.eval_param_mins[y_idx].item(), config.eval_param_maxes[y_idx].item())
                ax.set_xlabel(axis_labels[0], fontsize=axis_font_size)
                ax.set_ylabel(axis_labels[1], fontsize=axis_font_size)

                # The row title is attached to the middle panel.
                if col == 1:
                    ax.set_title(f"{['High', 'Low'][row]} Prior Density Observation | HPD Coverage = {hpd_coverage:0.3f} | LF2I Coverage = {waldo_coverage:0.3f}", fontsize=title_font_size)

    # Proxy line handles so the legend shows one entry per set type.
    handles = [Line2D([0], [0], color=color) for color in set_colors]
    labels = ["90% HPD Credible Region", "90% LF2I Confidence Set"]
    axs[0, 0].legend(
        handles,
        labels
    )
    # fig.suptitle("Panel B: 90% HPD Credible Regions vs 90% LF2I Confidence Sets", fontsize=20, fontweight="bold")
    fig.tight_layout()
    return fig, axs
