# viz

> Visualization modules

In [None]:
#| default_exp viz

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

import warnings
warnings.simplefilter("ignore", UserWarning)
import seaborn

import numpy as np
import os
import matplotlib as mpl
import matplotlib.pylab as plt
import pandas as pd
import sepia.SepiaPlot as SepiaPlot
from sepia.SepiaModel import SepiaModel
from sepia.SepiaData import SepiaData
from matplotlib import ticker
from itertools import cycle
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import pygtc

In [None]:
#| export

def plot_lines_with_param_color(param_array:np.array=None, # parameter array
                                x_array:np.array=None, # x-axis array
                                y_array_all:np.array=None, # y-axis array
                                title_str:str=None, # Title string
                                xlabel_str:str=None, # x-label string
                                ylabel_str:str=None, # y-label string
                                param_name_str:str=None, # Parameter string,
                                ax: plt.Axes = None, # Axes to draw on; a new figure is created when None
                                y_log_plot_scale:bool=False # Use a logarithmic y-axis
                               ):
    """Plot one curve per entry of `param_array`, coloured by that entry's value.

    Curve `i` is `y_array_all[i]` against `x_array`, drawn with the plasma
    colormap normalised to the min/max of `param_array`. The x-axis is always
    logarithmic; the y-axis is logarithmic only when `y_log_plot_scale` is set.

    When no `ax` is supplied the function creates its own figure and also adds
    a colorbar (titled `param_name_str`) and the optional title. When the
    caller supplies `ax`, only the curves and axis labels are drawn, so the
    caller stays in charge of figure-level decorations.

    Returns the figure that holds the axes.
    """
    # Record ownership *before* `ax` is rebound; testing `ax is None` after
    # the axes have been created is always False and would silently skip the
    # colorbar and the title.
    owns_figure = ax is None
    if owns_figure:
        f, ax = plt.subplots(1, 1, figsize=(8, 5))
    else:
        f = ax.figure

    norm = mpl.colors.Normalize(vmin=param_array.min(), vmax=param_array.max())
    cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.plasma)
    cmap.set_array([])

    for sim_index, param_value in enumerate(param_array):
        ax.plot(x_array, y_array_all[sim_index],
                '-', alpha=0.5, c=cmap.to_rgba(param_value),
                label='Sim: ' + str(sim_index))

    ax.set_xscale('log')
    if y_log_plot_scale:
        ax.set_yscale('log')

    if owns_figure:
        clb = f.colorbar(cmap, ax=ax)
        clb.ax.tick_params(labelsize=15)
        clb.ax.set_title(param_name_str, fontsize=15)
        if title_str:
            # Set the title on the data axes explicitly instead of relying on
            # pyplot's notion of the "current" axes.
            ax.set_title(title_str, fontsize=15)

    if xlabel_str:
        ax.set_xlabel(xlabel_str, fontsize=15)
    if ylabel_str:
        ax.set_ylabel(ylabel_str, fontsize=15)

    return f

In [None]:
#| export

def plot_scatter_matrix(df:pd.DataFrame=None, # data frame whose columns are plotted against each other
                        colors:str=None, # colour specification forwarded to the scatter plots
                       ):
    """Draw a pandas scatter matrix of `df` and return the figure that holds it.

    Histograms are placed on the diagonal, and every axis label of the
    resulting grid is re-applied with a larger font size.
    """
    fig, host_ax = plt.subplots(1, 1, figsize=(10, 10))

    matrix_options = dict(
        color=colors,
        figsize=(10, 10),
        alpha=1.0,
        ax=host_ax,
        grid=False,
        diagonal='hist',
        range_padding=0.1,
        s=80,
    )
    axes_grid = pd.plotting.scatter_matrix(df, **matrix_options)

    # Re-apply the existing labels so only font size and rotation change.
    for cell in axes_grid.ravel():
        x_text = cell.get_xlabel()
        cell.set_xlabel(x_text, fontsize=14, rotation=0)
        y_text = cell.get_ylabel()
        cell.set_ylabel(y_text, fontsize=14, rotation=90)

    return fig

In [None]:
#| export

def plot_train_diagnostics(sepia_model:SepiaModel=None, # Input data in SEPIA format, after PCA
                          ) -> tuple: #Pair-plot and Trace-plot
    """Build the SEPIA diagnostic plots for a model's samples.

    The samples are fetched once via `sepia_model.get_samples()` and handed to
    `SepiaPlot.theta_pairs` and `SepiaPlot.mcmc_trace`; the two results are
    returned in that order.
    """
    samples = sepia_model.get_samples()
    pair_plot = SepiaPlot.theta_pairs(samples)
    trace_plot = SepiaPlot.mcmc_trace(samples)
    return (pair_plot, trace_plot)

In [None]:
#| export

def sensitivity_plot(k_all:np.array=None, # all wavenumbers
                     params_all:np.array=None, # all parameters
                     sepia_model:SepiaModel=None, # SEPIA emulator model
                     emulator_function=None, # function which takes in sepia model and parameters
                     param_name:tuple=None, # Parameter name
                     xy_lims:np.array=(1e-5, 2e5, 1, 1e4), # axis limits as (x_min, x_max, y_min, y_max)
                     y_log_plot_scale:bool=False # Use a logarithmic y-axis
                    ):
    """Plot a one-at-a-time sensitivity sweep of the emulator, one panel per parameter.

    For each column of `params_all`, 300 parameter vectors are built by holding
    every parameter at its column mean and sweeping the selected one linearly
    from its column minimum to its column maximum. Each vector is passed to
    `emulator_function(sepia_model, params)`, whose first return value is
    plotted against `k_all`, coloured by the swept value (coolwarm colormap).
    Every panel gets its own colorbar labelled `param_name[i]` and a grey band
    over the region `k_all > 1.2`.

    Returns the figure.
    """
    params_all = np.asarray(params_all)
    k_all = np.asarray(k_all)  # needed for the boolean mask passed to fill_between

    all_max = np.max(params_all, axis=0)
    all_min = np.min(params_all, axis=0)
    all_mean = np.mean(params_all, axis=0)

    num_subplots = params_all.shape[-1]
    num_curves = 300
    colormap = cm.coolwarm

    # squeeze=False keeps a 2-D axes array even for a single parameter, so the
    # per-panel indexing below never hits a bare Axes object.
    fig, ax_grid = plt.subplots(num_subplots, 1, figsize=(7, 15),
                                sharex='col', squeeze=False)
    axes = ax_grid[:, 0]
    fig.subplots_adjust(wspace=0.25, hspace=0.05)
    fig.suptitle('Sensitivity analysis using emulator', fontsize=18, y=0.95)

    for param_no, panel in enumerate(axes):
        para_range = np.linspace(all_min[param_no], all_max[param_no], num_curves)
        normalize = mcolors.Normalize(vmin=all_min[param_no], vmax=all_max[param_no])

        for swept_value in para_range:
            # Work on a copy so the mean vector is not modified between curves.
            para_plot = np.copy(all_mean)
            para_plot[param_no] = swept_value

            color = colormap(normalize(para_plot[param_no]))
            prediction, _ = emulator_function(sepia_model, para_plot)
            panel.plot(k_all, prediction, lw=1.5, color=color)

        # Axis styling is loop-invariant, so apply it once per panel rather
        # than once per curve.
        panel.set_xscale('log')
        if y_log_plot_scale:
            panel.set_yscale('log')
        panel.set_ylabel('P(k)', fontsize=18)
        panel.set_yticks([], minor=True)
        panel.set_xlim(xy_lims[0], xy_lims[1])
        panel.set_ylim(xy_lims[2], xy_lims[3])

        # Colorbar for the swept parameter; the mappable only carries the
        # norm and colormap, it has no data of its own.
        s_map = cm.ScalarMappable(norm=normalize, cmap=colormap)
        s_map.set_array([])
        cbar = fig.colorbar(s_map, spacing='proportional', ax=panel)
        cbar.set_label(param_name[param_no], fontsize=20)

        # Shade the k > 1.2 region over the full *vertical* extent of the
        # panel (the y-limits, not the x-limits).
        panel.fill_between(k_all, xy_lims[2], xy_lims[3],
                           where=(k_all > 1.2), color='k', alpha=0.15)

    # Panels share the x-axis, so only the bottom one is labelled.
    axes[-1].set_xlabel('k[h/Mpc]', fontsize=18)

    return fig

In [None]:
#| export 

def validation_plot(k_all:np.array=None, # x-axis values shared by all curves
                    target_vals:np.array=None, # reference curves, indexed as target_vals[i]
                    pred_mean:np.array=None, # emulated means, indexed as pred_mean[:, i]
                    pred_std:np.array=None, # emulated standard deviations, indexed as pred_std[:, i]
                    xy_lims:np.array=(2e-2, 1e1, 0.98, 1.3), # top-panel limits as (x_min, x_max, y_min, y_max)
                    y_log_plot_scale:bool=False # Use a logarithmic y-axis on the top panel
                    ):
    """Compare emulator predictions with reference curves in a two-panel figure.

    Top panel: for each test case `i` (one per column of `pred_mean`), the
    reference `target_vals[i]` is drawn solid, the emulated mean
    `pred_mean[:, i]` dashed, and a band of one `pred_std` either side of the
    mean is shaded. Bottom panel: the relative error
    `pred_mean[:, i] / target_vals[i] - 1`, with fixed y-limits of about
    +/- 0.04.

    Returns the figure.
    """
    delta_y_lims = (-0.0401, 0.0401)

    f, a = plt.subplots(2, 1, figsize=(8, 6), gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
    f.subplots_adjust(hspace=0.05)

    # Cycle through the palette so more than six test cases do not raise an
    # IndexError; the first six colours are unchanged.
    base_colors = ['b', 'lightgreen', 'g', 'orange', 'cyan', 'r']
    num_cases = pred_mean.shape[1]
    colors = [base_colors[idx % len(base_colors)] for idx in range(num_cases)]
    styles = ['-', '--']
    styles_label = ['Sim', 'Emulated mean']

    for one_index, color in enumerate(colors):
        target = target_vals[one_index]
        mean = pred_mean[:, one_index]
        std = pred_std[:, one_index]

        a[0].plot(k_all, target, c=color, ls=styles[0])
        a[0].plot(k_all, mean, c=color, ls=styles[1])
        a[0].fill_between(k_all, mean - std, mean + std, color=color, alpha=0.2)

        a[1].plot(k_all, (mean / target) - 1, c=color)

    # Invisible proxy lines that only exist to populate the legends.
    # `np.nan` is used because the `np.NaN` alias was removed in NumPy 2.0.
    for cc, color in enumerate(colors):
        a[0].plot(np.nan, np.nan, c=color, label='T' + str(cc + 1))

    # A twin axis hosts the second legend (line styles) independently of the
    # first one (colours).
    ax2 = a[0].twinx()
    for style, style_label in zip(styles, styles_label):
        ax2.plot(np.nan, np.nan, ls=style, label=style_label, c='black')

    ax2.get_yaxis().set_visible(False)

    a[0].legend(loc=2, title='Test cosmologies', ncol=2)
    ax2.legend(loc=3)
    a[1].set_xlabel('k[h/Mpc]')
    a[1].set_ylabel(r'$\delta P(k)/P(k)$')
    a[0].set_ylabel(r'$P(k) [(Mpc/h)^3]$')
    a[0].set_xscale('log')
    if y_log_plot_scale:
        a[0].set_yscale('log')
    a[0].set_xlim(xy_lims[0], xy_lims[1])
    a[0].set_ylim(xy_lims[2], xy_lims[3])
    a[1].set_ylim(delta_y_lims[0], delta_y_lims[1])

    return f

In [None]:
#| export 

def plot_mcmc(samples:np.array, # samples forwarded to pygtc.plotGTC
              params_list:list, # one entry per parameter: (name, truth, ...)
              if_truth_know:bool=False # mark the truth values on the plot
              ):
    """Draw a pygtc triangle plot of `samples` and return the figure.

    Each entry of `params_list` supplies the parameter name at index 0 and,
    when `if_truth_know` is set, the truth value at index 1. Any further
    fields (e.g. range bounds) are ignored: ranges are not passed to pygtc, so
    entries are no longer required to carry them.
    """
    param_names = [param[0] for param in params_list]
    truths = [param[1] for param in params_list] if if_truth_know else None

    fig = pygtc.plotGTC(samples,
                        paramNames=param_names,
                        truths=truths,
                        figureSize=8,
                        plotDensity=True,
                        filledPlots=True,
                        smoothingKernel=3,
                        nContourLevels=3,
                        customLabelFont={'family': 'DejaVu Sans', 'size': 12},
                        customTickFont={'family': 'DejaVu Sans', 'size': 12},
                        )

    return fig

In [None]:
#| hide

# from getdist import plots, MCSamples
# import re

# def latex_to_plain(text):
#     # Replace LaTeX specific symbols
#     text = re.sub(r'\$', '', text)  # Remove $ signs
#     text = re.sub(r'\\', '', text)  # Remove backslashes
#     text = re.sub(r'\{|\}', '', text)  # Remove braces
#     return text


# def plot_mcmc_getdist(samples:np.array, 
#               params_list:list, 
#               if_truth_know:bool=False):


#     # from getdist.gaussian_mixtures import GaussianND
#     # mat = - h
#     # cov = np.linalg.pinv(mat)

#     # covariance = cov #[[0.001**2, 0.0006*0.05, 0], [0.0006*0.05, 0.05**2, 0.2**2], [0, 0.2**2, 2**2]]
#     # mean =np.array([para4[1], para5[1]]) #params #[0.02, 1, -2] 
#     # gauss=GaussianND(mean, covariance, names=['logfr0','n'], labels=[para4[0], para5[0]], label = 'Fisher')
#     # # g = plots.get_subplot_plotter()
#     # # g.triangle_plot(gauss,filled=True)

#     # names = ['a', 'b', 'c', 'd', 'e']
#     # s1 = samples_plot[samples_plot[:, 0] > -5.3]

#     s1 = samples
#     param_names = [param[0] for param in params_list]
#     PARAM_NAME_trial = np.array([latex_to_plain(item) for item in param_names]) #['Omega_m', 'n_s', '10^{9} A_s', 'h', 'f_\phi']

#     samples1 = MCSamples(samples=s1, 
#                         names=PARAM_NAME_trial, 
#                         labels=PARAM_NAME_trial, 
#                         label='MCMC' 
#                         )#, ranges={'logfr0':(-5.2, -4.8), 'n':(0.5, 1.5)})
#     g = plots.get_subplot_plotter(subplot_size=4)
#     g.settings.axes_fontsize=27
#     g.settings.axes_labelsize = 27
#     g.settings.legend_fontsize = 27
#     g.settings.fontsize = 27
#     g.settings.alpha_filled_add=0.6
#     # g.settings.title_limit_fontsize = 27
#     g.settings.solid_contour_palefactor = 0.5
#     g.triangle_plot([samples1],  
#                     # ['logfr0','n'], 
#                     filled=True, 
#                     # markers={'logfr0':para4[1], 'n':para5[1]}, 
#                     markers=[param[1] for param in params_list],
#                     marker_args={'lw':2, 'ls':'dashed', 'color':'k'}
#                     )

#     # g.triangle_plot([samples1, gauss],  ['logfr0','n'], filled=True, markers={'logfr0':para4[1], 'n':para5[1]}, marker_args={'lw':2, 'ls':'dashed', 'color':'k'})
#     # g.plot_contours(h, params, fill=True, alpha=0.5, label = 'Fisher info', facecolor = 'red')

#     # g.export('../../../Plots/triangle_plot.png')

In [None]:
#| export 

def generate_param_grid_with_fixed(param_name:list=None, # names of all parameters, in column order
                                   param_indices:np.array=None, # indices of the two parameters to vary
                                   fixed_params:np.array=None, # mapping from parameter name to its fixed value
                                   param_min:np.array=None, # lower bound of every parameter
                                   param_max:np.array=None, # upper bound of every parameter
                                   steps:int=40 # number of grid points per varying parameter
                                   ):
    """Build a `(steps * steps, n_params)` grid that sweeps two parameters.

    The two parameters listed in `param_indices` are each sampled linearly
    between their `param_min` and `param_max` and combined with `np.meshgrid`.
    Every other column `i` is filled with `fixed_params[param_name[i]]`.
    """
    n_params = len(param_min)

    # One linear axis per varying parameter.
    sweep_axes = []
    for idx in param_indices:
        sweep_axes.append(np.linspace(param_min[idx], param_max[idx], steps))
    mesh_first, mesh_second = np.meshgrid(*sweep_axes)

    grid = np.empty((steps * steps, n_params))

    # Varying columns come from the flattened mesh.
    grid[:, param_indices[0]] = mesh_first.ravel()
    grid[:, param_indices[1]] = mesh_second.ravel()

    # Remaining columns are held constant at their fixed value.
    for column in range(n_params):
        if column in param_indices:
            continue
        grid[:, column] = fixed_params[param_name[column]]

    return grid


# Plot the heatmap of errors
def plot_error_heatmap(errors:np.array=None, # 2-D array of values shown as the image
                       param_names:list=None, # labels for the x- and y-axis, in that order
                       param_range:tuple=None # ((x_min, x_max), (y_min, y_max)) extent of the image
                       ):
    """Show `errors` as a heatmap over two parameter ranges and return the figure."""
    fig = plt.figure(figsize=(5, 4))

    x_bounds = param_range[0]
    y_bounds = param_range[1]
    extent = (x_bounds[0], x_bounds[1], y_bounds[0], y_bounds[1])

    image = plt.imshow(errors, extent=extent, origin='lower', aspect='auto', cmap='YlOrRd')
    plt.colorbar(image, label='Mean variance across k[h/Mpc]')

    # Label the axes that hold the image.
    heat_ax = image.axes
    heat_ax.set_xlabel(param_names[0])
    heat_ax.set_ylabel(param_names[1])
    heat_ax.set_title('Heatmap of emulator variance')
    return fig

In [None]:
#| hide
# Run nbdev's export step for this notebook (cells marked `#| export`).
import nbdev; nbdev.nbdev_export()