In [2]:
%load_ext autoreload
%autoreload 2

import click
import pickle
from tqdm import tqdm
import time
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sbi.inference import FMPE
import os
from utils.data_generation import fetch_grid, sample_joint_distribution, sample_uniform_distribution
from utils.inference import preprocess_inputs, get_prior, get_eval_grid, get_posterior, compute_indicators_sampling_posterior, posterior_and_prior_kdeplot
from lf2i.inference import LF2I
from lf2i.test_statistics.posterior import Posterior
from lf2i.utils.other_methods import hpd_region
from lf2i.plot.parameter_regions import plot_parameter_regions
from lf2i.diagnostics.coverage_probability import compute_indicators_posterior
FLOW_TYPE = 'npe'

# --- Simulation budgets -----------------------------------------------------
B = 300_000  # num simulations to estimate posterior and test statistics
B_PRIME = 100_000  # num simulations to estimate critical values
B_DOUBLE_PRIME = 30_000  # num simulations to estimate coverage probability

# --- Grids and inference settings --------------------------------------------
EVAL_GRID_SIZE = 50_000
DISPLAY_GRID_SIZE = 10  # irrelevant now that the grid has been defined elsewhere
NORM_POSTERIOR_SAMPLES = None
CONFIDENCE_LEVEL = 0.9

# --- Reproducibility and device selection -------------------------------------
# Seed torch unconditionally; seed CUDA as well whenever a GPU is available.
torch.manual_seed(42)
if torch.cuda.is_available():
    DEVICE = 'cuda'
    torch.cuda.manual_seed(42)
else:
    DEVICE = 'cpu'

# --- Parameters of interest ---------------------------------------------------
POI_DIM = 5
POIS = ['t_eff', 'logg', 'feh_surf', 'logl', 'dist']
# Axis labels, in the same order as POIS.
LABELS = [
    r"$T_{eff}$ (K)",
    r"$\log g$ (cgs)",
    r"$[\text{Fe/H}]_{\text{surf}}$ (relative to solar)",
    r"$\log L$ ($L_{\odot}$)",
    r"$d$ (kpc)",
]

# --- Prior configuration ------------------------------------------------------
PRIOR_SETTINGS = [2.0, 1.0, 0.0, -1.0, -2.0]
# Per-parameter bounds handed to get_prior, ordered like POIS.
PRIOR_ARGS = dict(
    lower_bound=torch.tensor([2.5e3, 0.0, -4.0, -1.5, 0.0]),
    upper_bound=torch.tensor([1.5e4, 5.0, 0.5, 3.5, 1.0e3]),
)
PLOT_PRIORS = True  # These figures have already been generated

# --- Assets -------------------------------------------------------------------
assets_dir = f'{os.getcwd()}/assets'
os.makedirs(assets_dir, exist_ok=True)
params, seds = fetch_grid(assets_dir=assets_dir)  # POI grid + raw SEDs

2024-12-19 09:25:43.742508: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-19 09:25:43.742548: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-19 09:25:43.744370: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-19 09:25:43.751848: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Age-metallicity relationship

In [29]:
prior_samples = []

for PRIOR_SETTING in PRIOR_SETTINGS:
    # Load the cached prior for this hyperparameter value, or build and cache it.
    # NOTE: the cache name is keyed only by the numeric value, so any other cell
    # that caches to prior_<value>.pkl reads and writes these same files.
    prior_path = f'{assets_dir}/prior_{PRIOR_SETTING}.pkl'
    try:
        # pickle.load can execute arbitrary code: only load caches written by this notebook.
        with open(prior_path, 'rb') as f:
            prior = pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # Cache missing or truncated/corrupt: simulate and rebuild it. Any other
        # error (including KeyboardInterrupt) propagates instead of silently
        # triggering an expensive B-simulation regeneration.
        theta, x = sample_joint_distribution(params=params,
                                            seds=seds,
                                            args={'age_feh_hyperparam': PRIOR_SETTING,},
                                            n_samples=B,
                                            assets_dir=assets_dir,)
        theta_p, x_p = preprocess_inputs(theta, x, POIS)
        prior = get_prior(theta_p, prior_args=PRIOR_ARGS)
        with open(prior_path, 'wb') as f:
            pickle.dump(prior, f)

    prior_samples.append(prior.sample((1_000,)))

if PLOT_PRIORS:
    # One labelled DataFrame per prior, tagged with its 1-based position in the sweep.
    theta_dfs = []
    for i, prior_sample in enumerate(prior_samples):
        theta_df_i = pd.DataFrame(prior_sample.numpy(), columns=LABELS)
        theta_df_i['set'] = str(i+1)
        theta_dfs.append(theta_df_i)

    # Stack all sets; ignore_index avoids duplicate row labels in the combined frame.
    theta_df = pd.concat(theta_dfs, ignore_index=True)

    # Pairwise KDE plots (histograms on the diagonal), one colour per prior setting
    palette = ["#ca0020", "#f4a582", "#f7f7f7", "#92c5de", "#0571b0"]
    g = sns.pairplot(data=theta_df,
                        hue='set',
                        palette=sns.color_palette(palette, 5),
                        kind='kde',
                        diag_kind='hist')

    # Pairplot axes share x per column and y per row, so calling invert_*axis()
    # on every panel toggles the shared limits once per panel and only ends up
    # inverted when the panel count is odd. Invert only when not yet inverted.
    for ax in g.axes.ravel():
        if not ax.xaxis_inverted():
            ax.invert_xaxis()
        if not ax.yaxis_inverted():
            ax.invert_yaxis()

    plt.savefig(f'{assets_dir}/priors_age_metallicity.png')
    plt.close()

## Halo number density

In [None]:
prior_samples = []

for PRIOR_SETTING in np.linspace(0.0, 1.0, 5):
    # Load the cached prior for this hyperparameter value, or build and cache it.
    # The cache is keyed by hyperparameter name as well as value: a file named
    # only prior_<value>.pkl cannot be told apart from the cache of a different
    # hyperparameter swept over the same values (0.0 and 1.0 also appear in
    # PRIOR_SETTINGS), which would silently load the wrong prior.
    prior_path = f'{assets_dir}/prior_halo_{PRIOR_SETTING}.pkl'
    try:
        # pickle.load can execute arbitrary code: only load caches written by this notebook.
        with open(prior_path, 'rb') as f:
            prior = pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # Cache missing or truncated/corrupt: simulate and rebuild it. Any other
        # error (including KeyboardInterrupt) propagates instead of silently
        # triggering an expensive B-simulation regeneration.
        theta, x = sample_joint_distribution(params=params,
                                            seds=seds,
                                            args={'halo_hyperparam': PRIOR_SETTING,},
                                            n_samples=B,
                                            assets_dir=assets_dir,)
        theta_p, x_p = preprocess_inputs(theta, x, POIS)
        prior = get_prior(theta_p, prior_args=PRIOR_ARGS)
        with open(prior_path, 'wb') as f:
            pickle.dump(prior, f)

    prior_samples.append(prior.sample((1_000,)))

if PLOT_PRIORS:
    # One labelled DataFrame per prior, tagged with its 1-based position in the sweep.
    theta_dfs = []
    for i, prior_sample in enumerate(prior_samples):
        theta_df_i = pd.DataFrame(prior_sample.numpy(), columns=LABELS)
        theta_df_i['set'] = str(i+1)
        theta_dfs.append(theta_df_i)

    # Stack all sets; ignore_index avoids duplicate row labels in the combined frame.
    theta_df = pd.concat(theta_dfs, ignore_index=True)

    # Pairwise KDE plots (histograms on the diagonal), one colour per prior setting
    palette = ["#ca0020", "#f4a582", "#f7f7f7", "#92c5de", "#0571b0"]
    g = sns.pairplot(data=theta_df,
                        hue='set',
                        palette=sns.color_palette(palette, 5),
                        kind='kde',
                        diag_kind='hist')

    # Pairplot axes share x per column and y per row, so calling invert_*axis()
    # on every panel toggles the shared limits once per panel and only ends up
    # inverted when the panel count is odd. Invert only when not yet inverted.
    for ax in g.axes.ravel():
        if not ax.xaxis_inverted():
            ax.invert_xaxis()
        if not ax.yaxis_inverted():
            ax.invert_yaxis()

    plt.savefig(f'{assets_dir}/priors_halo.png')
    plt.close()

## IMF

In [None]:
prior_samples = []

for PRIOR_SETTING in np.linspace(0.0, 1.0, 5):
    # Load the cached prior for this hyperparameter value, or build and cache it.
    # The cache is keyed by hyperparameter name as well as value: a file named
    # only prior_<value>.pkl cannot be told apart from the cache of a different
    # hyperparameter swept over the same values (0.0 and 1.0 also appear in
    # PRIOR_SETTINGS), which would silently load the wrong prior.
    prior_path = f'{assets_dir}/prior_imf_{PRIOR_SETTING}.pkl'
    try:
        # pickle.load can execute arbitrary code: only load caches written by this notebook.
        with open(prior_path, 'rb') as f:
            prior = pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # Cache missing or truncated/corrupt: simulate and rebuild it. Any other
        # error (including KeyboardInterrupt) propagates instead of silently
        # triggering an expensive B-simulation regeneration.
        theta, x = sample_joint_distribution(params=params,
                                            seds=seds,
                                            args={'imf_hyperparam': PRIOR_SETTING,},
                                            n_samples=B,
                                            assets_dir=assets_dir,)
        theta_p, x_p = preprocess_inputs(theta, x, POIS)
        prior = get_prior(theta_p, prior_args=PRIOR_ARGS)
        with open(prior_path, 'wb') as f:
            pickle.dump(prior, f)

    prior_samples.append(prior.sample((1_000,)))

if PLOT_PRIORS:
    # One labelled DataFrame per prior, tagged with its 1-based position in the sweep.
    theta_dfs = []
    for i, prior_sample in enumerate(prior_samples):
        theta_df_i = pd.DataFrame(prior_sample.numpy(), columns=LABELS)
        theta_df_i['set'] = str(i+1)
        theta_dfs.append(theta_df_i)

    # Stack all sets; ignore_index avoids duplicate row labels in the combined frame.
    theta_df = pd.concat(theta_dfs, ignore_index=True)

    # Pairwise KDE plots (histograms on the diagonal), one colour per prior setting
    palette = ["#ca0020", "#f4a582", "#f7f7f7", "#92c5de", "#0571b0"]
    g = sns.pairplot(data=theta_df,
                        hue='set',
                        palette=sns.color_palette(palette, 5),
                        kind='kde',
                        diag_kind='hist')

    # Pairplot axes share x per column and y per row, so calling invert_*axis()
    # on every panel toggles the shared limits once per panel and only ends up
    # inverted when the panel count is odd. Invert only when not yet inverted.
    for ax in g.axes.ravel():
        if not ax.xaxis_inverted():
            ax.invert_xaxis()
        if not ax.yaxis_inverted():
            ax.invert_yaxis()

    plt.savefig(f'{assets_dir}/priors_imf.png')
    plt.close()

## Heatmap view of priors

In [3]:
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

def plot_pointwise_coverage_full(
    parameter_set: np.ndarray,
    log_probs_1: np.ndarray,
    log_probs_2: np.ndarray,
    title="Coverage Diagnostics: 90% HPD Credible Regions",
    axis_font_size=16,
    title_font_size=20
):
    """Plot, for every pair of parameters, a heatmap of a binned log-density difference.

    The points in ``parameter_set`` are binned in 2D for each parameter pair
    (bin edges from ``np.histogram_bin_edges(..., bins='auto')``). Each bin shows
    ``log(mean(exp(log_probs_1)) + 1e-10) - log(mean(exp(log_probs_2)) + 1e-10)``
    where the means run over the points in that bin (each point's weight is
    additionally floored by 1e-10). Bins containing no points evaluate to NaN.
    The colour scale is fixed to [-12, 12].

    Despite the default ``title``, nothing in this function computes coverage;
    callers should pass a title describing the two densities being compared.

    Args:
        parameter_set: Array-like with one row per point and one column per
            entry of ``LABELS``.
        log_probs_1: Log-density of the first distribution at each row of
            ``parameter_set``.
        log_probs_2: Log-density of the second distribution at each row of
            ``parameter_set``.
        title: Figure suptitle.
        axis_font_size: Font size of the axis labels.
        title_font_size: Font size of the suptitle.

    Returns:
        The matplotlib figure: a ``POI_DIM`` x ``POI_DIM`` grid whose lower
        triangle holds the heatmaps; the diagonal and upper triangle are blank.
    """
    cmap = plt.cm.inferno
    eps = 1e-10

    theta_diag_df = pd.DataFrame(parameter_set, columns=LABELS)
    # Per-point weights are loop-invariant, so compute them once.
    weights_1 = np.exp(log_probs_1) + eps
    weights_2 = np.exp(log_probs_2) + eps

    # squeeze=False keeps axs two-dimensional whatever POI_DIM is.
    fig, axs = plt.subplots(POI_DIM, POI_DIM, figsize=(30, 30), squeeze=False)

    heatmap = None
    for col in range(POI_DIM - 1):
        for row in range(col + 1, POI_DIM):
            ax = axs[row, col]
            x_label, y_label = LABELS[col], LABELS[row]
            x_values, y_values = theta_diag_df[x_label], theta_diag_df[y_label]

            x_bins = np.histogram_bin_edges(x_values, bins='auto')
            y_bins = np.histogram_bin_edges(y_values, bins='auto')
            bins = [x_bins, y_bins]
            binned_sum_proba, xedges, yedges = np.histogram2d(x_values, y_values, bins=bins, weights=weights_1)
            binned_sum_proba_2, _, _ = np.histogram2d(x_values, y_values, bins=bins, weights=weights_2)
            bin_counts, _, _ = np.histogram2d(x_values, y_values, bins=bins)

            # Empty bins give 0/0 = NaN, which is the intended "no data" value;
            # silence the accompanying RuntimeWarnings rather than the result.
            with np.errstate(divide='ignore', invalid='ignore'):
                heatmap_values = (
                    np.log(binned_sum_proba / bin_counts + eps)
                    - np.log(binned_sum_proba_2 / bin_counts + eps)
                )

            heatmap = ax.imshow(
                heatmap_values.T,
                origin='lower',
                extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
                aspect='auto',
                cmap=cmap,
                vmin=-12,
                vmax=12,
            )
            ax.invert_xaxis()
            ax.invert_yaxis()
            ax.set_xlabel(x_label, fontsize=axis_font_size)
            ax.set_ylabel(y_label, fontsize=axis_font_size)

    # All panels share vmin/vmax, so one colorbar (attached to panel [1, 0]) serves them all.
    if heatmap is not None:
        divider = make_axes_locatable(axs[1, 0])
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(heatmap, cax=cax, orientation='vertical')

    # Blank the diagonal and the upper triangle, which hold no heatmap.
    for row in range(POI_DIM):
        for col in range(row, POI_DIM):
            axs[row, col].axis('off')

    plt.suptitle(title, fontsize=title_font_size)
    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    return fig

In [10]:
# Draw B samples with sample_uniform_distribution from the fetched grid and SEDs,
# then preprocess them for the parameters of interest listed in POIS.
# NOTE(review): presumably `grid` is a (theta, x) pair, like the result of
# sample_joint_distribution used in earlier cells — confirm.
grid = sample_uniform_distribution(params, seds, B)
grid_torch = preprocess_inputs(*grid, POIS)

In [12]:
# Keep the first element of the raw and of the preprocessed sample.
# NOTE: this rebinds `params`, which until now held the grid returned by
# fetch_grid; re-running earlier cells after this one will see the new value.
params = grid[0]
params_torch = grid_torch[0]

In [15]:
# Reference log-probabilities at hyperparameter 2.0, used as the first term of
# every comparison plotted below.
# NOTE(review): `prior` is called here as a function taking `age_feh_hyperparam`,
# whereas the earlier cells bind `prior` to an object they only call .sample() on.
# Confirm which `prior` is intended; this may rely on state from a cell not shown.
log_prob_inil = prior(params, age_feh_hyperparam=2.0)

In [None]:
# Compare the reference prior (hyperparameter 2.0) with each alternative setting
# and save one full pairwise figure per comparison.
for PRIOR_SETTING in (-2.0, -1.0, 0.0, 1.0):
    log_prob = prior(params, age_feh_hyperparam=PRIOR_SETTING)
    out_path = f'{assets_dir}/prior_comparison_2.0_{PRIOR_SETTING}.png'
    fig = plot_pointwise_coverage_full(
        params_torch,
        log_prob_inil,
        log_prob,
        title=f"Log Prior 2.0 - Log Prior {PRIOR_SETTING}",
        axis_font_size=16,
        title_font_size=20,
    )
    # The figure just returned is the current figure, so save and close it directly.
    fig.savefig(out_path)
    plt.close(fig)

## Zoom in on the above

In [22]:
def plot_pointwise_coverage_full_zoom_in(
    parameter_set: np.ndarray,
    log_probs_1: np.ndarray,
    log_probs_2: np.ndarray,
    title="Coverage Diagnostics: 90% HPD Credible Regions",
    axis_font_size=16,
    title_font_size=20,
    dims=(0, 1),
):
    """Plot a single large heatmap of a binned log-density difference for one parameter pair.

    The points in ``parameter_set`` are binned in 2D over the two parameters
    selected by ``dims`` (bin edges from ``np.histogram_bin_edges(..., bins='auto')``).
    Each bin shows
    ``log(mean(exp(log_probs_1)) + 1e-10) - log(mean(exp(log_probs_2)) + 1e-10)``
    where the means run over the points in that bin (each point's weight is
    additionally floored by 1e-10). Bins containing no points evaluate to NaN.
    The colour scale is fixed to [-12, 12].

    Despite the default ``title``, nothing in this function computes coverage;
    callers should pass a title describing the two densities being compared.

    Args:
        parameter_set: Array-like with one row per point and one column per
            entry of ``LABELS``.
        log_probs_1: Log-density of the first distribution at each row of
            ``parameter_set``.
        log_probs_2: Log-density of the second distribution at each row of
            ``parameter_set``.
        title: Figure suptitle.
        axis_font_size: Font size of the axis labels and colorbar tick labels.
        title_font_size: Font size of the suptitle.
        dims: Indices into ``LABELS`` of the (x, y) parameters to plot.
            Defaults to the first two parameters.

    Returns:
        The matplotlib figure containing the heatmap and its colorbar.
    """
    cmap = plt.cm.inferno
    eps = 1e-10

    theta_diag_df = pd.DataFrame(parameter_set, columns=LABELS)
    fig, ax = plt.subplots(1, 1, figsize=(30, 30))

    x_label, y_label = LABELS[dims[0]], LABELS[dims[1]]
    x_values, y_values = theta_diag_df[x_label], theta_diag_df[y_label]

    x_bins = np.histogram_bin_edges(x_values, bins='auto')
    y_bins = np.histogram_bin_edges(y_values, bins='auto')
    bins = [x_bins, y_bins]
    binned_sum_proba, xedges, yedges = np.histogram2d(x_values, y_values, bins=bins, weights=np.exp(log_probs_1) + eps)
    binned_sum_proba_2, _, _ = np.histogram2d(x_values, y_values, bins=bins, weights=np.exp(log_probs_2) + eps)
    bin_counts, _, _ = np.histogram2d(x_values, y_values, bins=bins)

    # Empty bins give 0/0 = NaN, which is the intended "no data" value;
    # silence the accompanying RuntimeWarnings rather than the result.
    with np.errstate(divide='ignore', invalid='ignore'):
        heatmap_values = (
            np.log(binned_sum_proba / bin_counts + eps)
            - np.log(binned_sum_proba_2 / bin_counts + eps)
        )

    heatmap = ax.imshow(
        heatmap_values.T,
        origin='lower',
        extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
        aspect='auto',
        cmap=cmap,
        vmin=-12,
        vmax=12,
    )
    ax.invert_xaxis()
    ax.invert_yaxis()
    ax.set_xlabel(x_label, fontsize=axis_font_size)
    ax.set_ylabel(y_label, fontsize=axis_font_size)

    # Colorbar in its own axes to the right of the heatmap.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cbar = fig.colorbar(heatmap, cax=cax, orientation='vertical')
    cbar.ax.tick_params(labelsize=axis_font_size)

    plt.suptitle(title, fontsize=title_font_size)
    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    return fig

In [25]:
# Same comparisons as the full pairwise figures, restricted to the single
# zoomed-in panel and rendered with larger fonts.
for PRIOR_SETTING in (-2.0, -1.0, 0.0, 1.0):
    log_prob = prior(params, age_feh_hyperparam=PRIOR_SETTING)
    out_path = f'{assets_dir}/prior_comparison_2.0_{PRIOR_SETTING}_zoom_in.png'
    fig = plot_pointwise_coverage_full_zoom_in(
        params_torch,
        log_prob_inil,
        log_prob,
        title=f"Log Prior 2.0 - Log Prior {PRIOR_SETTING}",
        axis_font_size=30,
        title_font_size=44,
    )
    # The figure just returned is the current figure, so save and close it directly.
    fig.savefig(out_path)
    plt.close(fig)