# fisher

In [1]:
#| default_exp fisher

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export

import numpy as np
import emcee
import time
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

from scipy.integrate import simps
from EarlyDarkEmu.load import *
from EarlyDarkEmu.viz import *
from EarlyDarkEmu.pca import *
from EarlyDarkEmu.gp import *
from EarlyDarkEmu.emu import *
from EarlyDarkEmu.mcmc import *


(210,)


In [4]:
#| export
class FisherMatrix:
    """Fisher information matrix for an observable sampled on a 1-D grid.

    Element ``(i, j)`` is the Simpson-rule integral over ``xvals`` of
    ``derivatives[i] * derivatives[j] / y_errors**2``.

    Parameters
    ----------
    xvals : array_like
        Sample points used as the integration grid.
    yvals : array_like
        Observable at the fiducial parameters (stored; not used in the
        integral itself).
    y_errors : array_like
        Per-point uncertainty on the observable; enters squared.
    params : array_like
        Fiducial parameter values; its length sets the matrix size.
    """

    def __init__(self, xvals, yvals, y_errors, params):
        self.xvals = np.array(xvals)
        self.yvals = np.array(yvals)
        self.y_errors = np.array(y_errors)
        self.params = np.array(params)
        self.n_params = len(params)
        # Filled in by ``compute_fmat``.
        self.fmat = np.zeros((self.n_params, self.n_params))

    def compute_fmat(self, derivatives):
        """Fill and return ``self.fmat`` from per-parameter derivatives.

        Parameters
        ----------
        derivatives : array_like
            Indexable by parameter; ``derivatives[i]`` must broadcast against
            ``self.y_errors`` (presumably one value per ``xvals`` point --
            confirm against callers).

        Returns
        -------
        numpy.ndarray
            The ``(n_params, n_params)`` Fisher matrix (the same object as
            ``self.fmat``).
        """
        variance = self.y_errors**2
        for i in range(self.n_params):
            # The integrand is symmetric in (i, j), so integrate the upper
            # triangle once and mirror it instead of integrating twice.
            for j in range(i, self.n_params):
                integrand = derivatives[i] * derivatives[j] / variance
                value = simps(integrand, self.xvals)
                self.fmat[i, j] = value
                self.fmat[j, i] = value
        return self.fmat

    def uncertainty(self):
        """Return the marginalised 1-sigma uncertainty ``sqrt(diag(F^-1))``.

        Raises
        ------
        ValueError
            If the Fisher matrix is singular (rank-deficient) and cannot be
            inverted.
        """
        # An SVD-based rank test is scale-invariant. An exact ``det == 0``
        # check misses near-singular matrices and can also fire spuriously
        # when the determinant of a well-conditioned matrix underflows to 0.
        if self.fmat.size and np.linalg.matrix_rank(self.fmat) < self.fmat.shape[0]:
            raise ValueError("Fisher is singular.")
        try:
            inv_fisher = np.linalg.inv(self.fmat)
        except np.linalg.LinAlgError as err:
            raise ValueError("Fisher is singular.") from err
        return np.sqrt(np.diagonal(inv_fisher))
    

def deriv(params_array, redshift, forward_func, perturb = 1e-5):
    """Central finite-difference derivatives of ``forward_func`` per parameter.

    Parameters
    ----------
    params_array : array_like
        Fiducial parameter vector.
    redshift : object
        Passed through unchanged as the second argument of ``forward_func``.
    forward_func : callable
        ``forward_func(params, redshift)`` must return a 2-tuple whose first
        element is the model prediction; the second element is ignored.
    perturb : float, optional
        Absolute step applied to each parameter in turn (default ``1e-5``).

    Returns
    -------
    derivative : numpy.ndarray
        Shape ``(n_params, len(y))``; row ``i`` is ``(y(p + h e_i) -
        y(p - h e_i)) / (2 h)``.
    yvals_up, yvals_down
        Predictions at ``+perturb`` / ``-perturb`` for the *last* parameter
        only.

    Raises
    ------
    ValueError
        If ``params_array`` is empty.
    """
    n_par = len(params_array)
    if n_par == 0:
        raise ValueError("params_array must contain at least one parameter.")

    # The fiducial prediction is only used to size the output array.
    yvals, _ = forward_func(params_array, redshift)
    derivative = np.zeros((n_par, len(yvals)))

    # Perturb a float copy: with an integer array, ``+= perturb`` would be
    # truncated back to the original value and every derivative would be 0.
    base = np.array(params_array, dtype=float)

    for i in range(n_par):
        params_up = base.copy()
        params_down = base.copy()

        params_up[i] += perturb
        params_down[i] -= perturb

        yvals_up, _ = forward_func(params_up, redshift)
        yvals_down, _ = forward_func(params_down, redshift)

        derivative[i] = (yvals_up - yvals_down) / (2 * perturb)

    return derivative, yvals_up, yvals_down


In [5]:
#| export

def plot_contours(hessian, pos, nstd=1., ax=None, **kwargs):
    """Draw the ``nstd``-sigma ellipse implied by a 2x2 Hessian.

    The covariance is the pseudo-inverse of ``-hessian``. The grid plotters in
    this module pass the negated Fisher matrix, so ``-hessian`` is the Fisher
    sub-block. If ``-hessian`` is not positive definite, ``1e-5`` is added to
    its diagonal before inversion.

    Parameters
    ----------
    hessian : array_like
        2x2 matrix, assumed symmetric.
    pos : sequence of float
        Ellipse centre ``(x, y)``; also where the guide lines are drawn.
    nstd : float, optional
        Number of standard deviations for the ellipse and axis limits.
    ax : matplotlib.axes.Axes, optional
        Target axes; defaults to ``plt.gca()``.
    **kwargs
        Forwarded to ``matplotlib.patches.Ellipse``.

    Returns
    -------
    matplotlib.patches.Ellipse
        The ellipse that was added to ``ax``.

    Raises
    ------
    ValueError
        If the covariance contains NaN/Inf or has a negative eigenvalue.
    """
    def eigsorted(cov):
        # Eigenpairs of a symmetric matrix, eigenvalues in descending order.
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        return vals[order], vecs[:, order]

    neg_hessian = -np.asarray(hessian)

    # ``-hessian`` must be positive definite to yield a valid covariance.
    if np.all(np.linalg.eigvalsh(neg_hessian) > 0):
        mat = neg_hessian
    else:
        # Regularise ``-hessian`` (not ``+hessian``): shifting its spectrum
        # upwards keeps the resulting covariance positive.
        epsilon = 1e-5
        mat = neg_hessian + epsilon * np.eye(neg_hessian.shape[0])
    cov = np.linalg.pinv(mat)

    if not np.all(np.isfinite(cov)):
        raise ValueError("Covariance matrix contains NaN or Inf.")

    if ax is None:
        ax = plt.gca()

    vals, vecs = eigsorted(cov)
    if np.any(vals < 0):
        raise ValueError("Negative eigenvalues found in covariance matrix.")

    # Orientation of the major axis (largest-eigenvalue eigenvector).
    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))

    # Width and height are "full" widths, not radii.
    width, height = 2 * nstd * np.sqrt(np.abs(vals))
    ellip = Ellipse(xy=pos, width=width, height=height, angle=theta, **kwargs)
    ax.add_artist(ellip)

    # Axis limits span 1.5x the marginalised nstd-sigma range per parameter.
    s1 = 1.5 * nstd * np.sqrt(cov[0, 0])
    s2 = 1.5 * nstd * np.sqrt(cov[1, 1])

    ax.axhline(pos[1], color='blue')
    ax.axvline(pos[0], color='blue')

    ax.set_xlim(pos[0] - s1, pos[0] + s1)
    ax.set_ylim(pos[1] - s2, pos[1] + s2)
    plt.draw()
    return ellip


In [6]:
#| export 
def plot_fisher_grid_aligned(param_names, param_fiducial, fisher_matrix_6x6):
    """Corner plot of pairwise 1-sigma Fisher ellipses (lower triangle only).

    Parameters
    ----------
    param_names : sequence of str
        Axis labels, one per parameter; its length sets the grid size.
    param_fiducial : sequence of float
        Fiducial values used as each ellipse's centre.
    fisher_matrix_6x6 : numpy.ndarray
        Square Fisher matrix matching ``param_names`` (any size, despite the
        name). It is negated before the 2x2 sub-blocks go to
        ``plot_contours``.
    """
    dim = len(param_names)
    fig, grid = plt.subplots(nrows=dim, ncols=dim, figsize=(12, 10),
                             sharex='col', sharey='row')
    fig.subplots_adjust(top=0.9, right=0.9, left=0.0, bottom=0.0,
                        hspace=0.01, wspace=0.01)

    hess = -fisher_matrix_6x6

    for row in range(dim):
        for col in range(dim):
            # Diagonal and upper-triangle panels stay empty.
            if col >= row:
                grid[row, col].axis('off')
                continue

            panel = grid[row, col]
            # x-axis parameter is ``col``, y-axis parameter is ``row``.
            centre = np.array([param_fiducial[col], param_fiducial[row]])
            block = np.array([[hess[col, col], hess[col, row]],
                              [hess[row, col], hess[row, row]]])

            plot_contours(block, centre, nstd=1., ax=panel, alpha=0.4,
                          fill=True, edgecolor='red', linewidth=2)

            if col == 0:
                panel.set_ylabel(param_names[row], fontsize=12)
            if row == dim - 1:
                panel.set_xlabel(param_names[col], fontsize=12)

    fig.tight_layout()
    fig.show()


In [7]:
#| export

def plot_fisher_grid_multiple(param_names, param_fiducial, fisher_matrices, matrix_labels, colors=None):
    """Overlay 1-sigma Fisher ellipses from several matrices on one corner grid.

    Parameters
    ----------
    param_names : sequence of str
        Axis labels, one per parameter; its length sets the grid size.
    param_fiducial : sequence of float
        Fiducial values used as each ellipse's centre.
    fisher_matrices : sequence of numpy.ndarray
        Square Fisher matrices matching ``param_names``; each is negated
        before its 2x2 sub-blocks go to ``plot_contours``.
    matrix_labels : sequence of str
        Legend label per matrix; paired with ``fisher_matrices`` via ``zip``
        (the shorter of the two determines how many are drawn).
    colors : sequence of matplotlib colours, optional
        Edge colour per matrix, cycled if there are more matrices than
        colours. Defaults to ``['red', 'blue', 'green', 'purple']``.
    """
    n_params = len(param_names)
    if colors is None or len(colors) == 0:
        palette = ['red', 'blue', 'green', 'purple']
    else:
        palette = list(colors)

    fig, axes = plt.subplots(nrows=n_params, ncols=n_params, figsize=(12, 10), sharex='col', sharey='row')
    plt.subplots_adjust(top=0.90, right=0.85, left=0.08, bottom=0.08, hspace=0.4, wspace=0.4)

    legend_handles = []

    for i in range(n_params):
        for j in range(n_params):
            if i <= j:
                axes[i, j].axis('off')
                continue

            ax = axes[i, j]
            param_duo = np.array([param_fiducial[j], param_fiducial[i]])

            for k, (fisher_matrix, label) in enumerate(zip(fisher_matrices, matrix_labels)):
                neg_fmat = -fisher_matrix
                fish2x2 = np.array([[neg_fmat[j, j], neg_fmat[j, i]],
                                    [neg_fmat[i, j], neg_fmat[i, i]]])

                # Cycle colours so more matrices than colours never IndexErrors.
                ellip = plot_contours(fish2x2, param_duo, nstd=1., ax=ax,
                                      alpha=0.3, fill=True,
                                      edgecolor=palette[k % len(palette)],
                                      linewidth=2)

                # One legend entry per matrix, taken from the bottom-left panel.
                if i == n_params - 1 and j == 0:
                    legend_handles.append((ellip, label))

            if j == 0:
                ax.set_ylabel(param_names[i], fontsize=12)
            if i == n_params - 1:
                ax.set_xlabel(param_names[j], fontsize=12)

    # With fewer than two parameters there are no panels and hence no handles;
    # ``fig.legend()`` with no arguments would then warn / pick up stray artists.
    if legend_handles:
        fig.legend(*zip(*legend_handles), loc='center', fontsize=14,
                   bbox_to_anchor=(0.7, 0.7), ncol=1, frameon=True)

    plt.tight_layout()
    plt.show()

In [8]:
#| export
# TODO: maybe pass the Fisher class / helper functions in as arguments --
#       check with Nesar.
def compute_fisher_k(k_vals, param_fid, redshift, forward_func, pk_fid, pk_error_fid, perturb=1e-5, rng=None):
    """Fisher matrix from central-difference derivatives, integrated over ``k``.

    All arrays are truncated to their first ``len(k_vals) - 1`` entries
    before use. The errors are the log of ``pk_error_fid`` plus a random term.

    Parameters
    ----------
    k_vals : array_like
        Integration grid.
    param_fid : array_like
        Fiducial parameter vector.
    redshift : object
        Passed through unchanged to ``forward_func``.
    forward_func : callable
        ``forward_func(params, redshift)`` returning ``(pk, _)``.
    pk_fid, pk_error_fid : array_like
        Fiducial spectrum and its errors; the log of each is taken.
    perturb : float, optional
        Finite-difference step per parameter (default ``1e-5``).
    rng : numpy.random.Generator or RandomState, optional
        Source of the random error terms. Defaults to the global
        ``np.random`` state; pass a seeded generator for reproducible output.

    Returns
    -------
    numpy.ndarray
        ``(n_params, n_params)`` Fisher matrix.
    """
    rand = np.random if rng is None else rng
    stop = len(k_vals) - 1  # the last k sample is excluded everywhere

    pk_slice = np.log(pk_fid[0:stop])
    pk_error_slice = np.log(pk_error_fid[0:stop])
    k_slice = k_vals[0:stop]
    n_k = pk_slice.shape[0]

    # NOTE(review): the spectrum and errors are in log space here, while the
    # derivatives below are of the linear spectrum, and the noise term scales
    # with log(pk) -- confirm this mix of units is intended.
    pk_error_added = (1 * rand.uniform(low=0.5, high=2.2, size=n_k)
                      + 0.2 * pk_slice * rand.uniform(low=0.9, high=1.1, size=n_k))
    pk_error_total = pk_error_slice + pk_error_added

    # Float copy so ``+= perturb`` is not truncated for integer inputs.
    base = np.array(param_fid, dtype=float)
    n_params = len(base)

    # One row per parameter, one column per retained k sample.
    derivs = np.zeros((n_params, n_k))

    for i in range(n_params):
        params_up = base.copy()
        params_down = base.copy()

        params_up[i] += perturb
        params_down[i] -= perturb

        pk_up, _ = forward_func(params_up, redshift)
        pk_down, _ = forward_func(params_down, redshift)

        derivs[i] = (pk_up[0:stop] - pk_down[0:stop]) / (2 * perturb)

    fisher_matrix_instance = FisherMatrix(k_slice, pk_slice, pk_error_total, param_fid)
    return fisher_matrix_instance.compute_fmat(derivs)

In [9]:
#| export
def computs_fisher_mat(params, redshift, forward_func, perturb = 1e-5, k_values=None, rng=None):
    """Fisher matrix at ``params`` using central-difference derivatives.

    Parameters
    ----------
    params : array_like
        Fiducial parameter vector.
    redshift : object
        Passed through unchanged to ``forward_func``.
    forward_func : callable
        ``forward_func(params, redshift)`` returning ``(pk, pk_error)`` as
        arrays.
    perturb : float, optional
        Finite-difference step per parameter (default ``1e-5``).
    k_values : array_like, optional
        Integration grid. Defaults to the module-level ``k_all``, looked up
        at call time.
    rng : numpy.random.Generator or RandomState, optional
        Source of the random error terms. Defaults to the global
        ``np.random`` state.

    Returns
    -------
    numpy.ndarray
        ``(n_params, n_params)`` Fisher matrix.
    """
    rand = np.random if rng is None else rng

    pk_fid, pk_error_fid = forward_func(params, redshift)
    n_k = pk_fid.shape[0]
    pk_error_added = (1 * rand.uniform(low=0.5, high=2.2, size=n_k)
                      + 0.2 * pk_fid * rand.uniform(low=0.9, high=1.1, size=n_k))
    pk_error_total = pk_error_fid + pk_error_added

    derivs, _, _ = deriv(params, redshift, forward_func, perturb=perturb)

    if k_values is None:
        k_values = k_all
    fisher = FisherMatrix(k_values, pk_fid, pk_error_total, params)
    return fisher.compute_fmat(derivs)


In [10]:
#| hide
# Run nbdev's export step so the `#| export` cells are written out to the
# library module named by `#| default_exp` (presumably EarlyDarkEmu/fisher.py).
import nbdev; nbdev.nbdev_export()