In [7]:
from iminuit import Minuit
import numpy as np
import torch
from sunbird.inference.priors import Yuan23, AbacusSummit
from acm.data.io_tools import *
import matplotlib.pyplot as plt
%matplotlib inline

def get_priors(cosmo=True, hod=True, stats_module='scipy.stats'):
    """Collect prior distributions, parameter ranges and labels.

    Parameters
    ----------
    cosmo : bool
        Include the cosmological parameters provided by ``AbacusSummit``.
    hod : bool
        Include the HOD parameters provided by ``Yuan23``.
    stats_module : str
        Statistics backend name passed to each prior family. Defaults to
        ``'scipy.stats'``, the value that used to be hard-coded.

    Returns
    -------
    priors, ranges, labels : dict
        Dictionaries keyed by parameter name. Cosmological entries are
        inserted before HOD entries; key order matters to callers that build
        parameter vectors by iterating ``priors``.
    """
    priors, ranges, labels = {}, {}, {}
    families = []
    if cosmo:
        families.append(AbacusSummit)
    if hod:
        families.append(Yuan23)
    for family in families:
        # Build each prior family once instead of once per attribute.
        prior_set = family(stats_module)
        priors.update(prior_set.priors)
        ranges.update(prior_set.ranges)
        labels.update(prior_set.labels)
    return priors, ranges, labels

def fill_params(theta):
    """Expand the free-parameter vector ``theta`` into a full parameter vector.

    Walks the global ``priors`` in key order. A parameter found in the
    global ``fixed_parameters`` dict takes its fixed value; every other slot
    consumes the next unused element of ``theta``.

    Parameters
    ----------
    theta : sequence of float
        Free-parameter values, ordered like the non-fixed keys of ``priors``.

    Returns
    -------
    numpy.ndarray
        Float array with one entry per key of ``priors``.
    """
    names = list(priors.keys())
    full = np.ones(len(names))
    n_used = 0
    for slot, name in enumerate(names):
        if name in fixed_parameters.keys():
            full[slot] = fixed_parameters[name]
            continue
        full[slot] = theta[n_used]
        n_used += 1
    return full

def get_model_prediction(theta):
    """Evaluate every emulator in ``theory_model`` at ``theta``.

    Each model is paired with the matching entry of ``model_filters``. Its
    prediction is converted to NumPy, and the per-statistic predictions are
    joined along the last axis.

    Parameters
    ----------
    theta : array_like
        Full parameter vector. It may also be a batch of vectors when called
        from the batched likelihood path; TODO confirm that
        ``get_prediction`` accepts a leading batch axis.

    Returns
    -------
    numpy.ndarray
        Concatenated predictions of all models.
    """
    # NOTE(review): ``torch.Tensor`` casts to torch's default dtype (normally
    # float32). A float32 objective is noisy at the step sizes Minuit uses
    # for numerical derivatives, which may be behind the large EDM / tiny
    # Hesse errors seen in the fit outputs — worth verifying.
    with torch.no_grad():
        per_statistic = [
            emulator.get_prediction(x=torch.Tensor(theta), filters=mask).numpy()
            for emulator, mask in zip(theory_model, model_filters)
        ]
    return np.concatenate(per_statistic, axis=-1)
    
def log_likelihood_minuit(*theta):
    """Minuit cost function: negative Gaussian log-likelihood (``0.5 * chi^2``).

    Minuit passes the free parameters as separate positional arguments and
    minimises the returned value. The arguments are therefore regrouped into
    one vector and the log-likelihood is negated. The ``0.5 * chi^2`` scale
    is the one that matches ``errordef = Minuit.LIKELIHOOD``.
    """
    return -log_likelihood(theta)


def log_likelihood(theta):
    """Gaussian log-likelihood ``-0.5 * chi^2`` of ``data_y`` given ``theta``.

    Here ``chi^2 = diff @ precision_matrix @ diff`` with
    ``diff = data_y - prediction``. Both the single and the batched path use
    the same sign convention.

    Parameters
    ----------
    theta : array_like
        Free-parameter values: a single vector of shape ``(n_free,)`` or a
        batch of shape ``(n_batch, n_free)``. Fixed parameters are filled
        in by ``fill_params``.

    Returns
    -------
    float or list of float
        The log-likelihood, or a list with one value per row when ``theta``
        is batched.
    """
    theta = np.asarray(theta)
    batched = theta.ndim > 1
    if batched:
        if len(theta) == 0:
            return []
        params = np.stack([fill_params(row) for row in theta])
    else:
        params = fill_params(theta)
    diff = data_y - get_model_prediction(params)
    if batched:
        # Per-row chi^2 without a Python loop over the batch.
        chi2 = np.einsum('bi,ij,bj->b', diff, precision_matrix, diff)
        return list(-0.5 * chi2)
    # Same sign as the batched branch; log_likelihood_minuit negates it.
    return -0.5 * (diff @ precision_matrix @ diff)


# --- Set up the inference ------------------------------------------------------
priors, ranges, labels = get_priors(cosmo=True, hod=True)
select_filters = {'cosmo_idx': 0, 'hod_idx': 30}
# Parameters held at their values in the selected LHC sample; all other
# parameters are varied by Minuit.
fixed_parameter_names = ['w0_fld', 'wa_fld', 'N_ur', 'nrun', 's', 'A_cen', 'A_sat',
                         'B_cen', 'B_sat', 'alpha', 'kappa', 'sigma', 'alpha_c', 'alpha_s']
statistics = ['pk']
kmin, kmax = 0.0, 0.5
slice_filters = {'k': [kmin, kmax]}

# --- Covariance matrix ---------------------------------------------------------
covariance_matrix, n_sim = read_covariance(statistics=statistics,
                                           select_filters=select_filters,
                                           slice_filters=slice_filters)
print(f'Loaded covariance matrix with shape: {covariance_matrix.shape}')
# TODO(review): n_sim is never used. If the covariance is estimated from
# n_sim mocks, its inverse is biased and is commonly rescaled by the Hartlap
# factor (n_sim - n_bins - 2) / (n_sim - 1) — confirm whether that applies.
precision_matrix = np.linalg.inv(covariance_matrix)

# --- Data vector ---------------------------------------------------------------
data_x, data_y, data_x_names, model_filters = read_lhc(statistics=statistics,
                                                    select_filters=select_filters,
                                                    slice_filters=slice_filters,
                                                    return_mask=True)
print(f'Loaded LHC x with shape: {data_x.shape}')
print(f'Loaded LHC y with shape {data_y.shape}')

# fill_params builds full parameter vectors from the keys of ``priors``, so
# the LHC sample must describe exactly the same set of parameters.
assert set(priors) == set(data_x_names), 'priors and LHC parameter names differ'

# name -> value of each fixed parameter, read from the selected LHC sample.
fixed_parameters = {key: data_x[data_x_names.index(key)]
                    for key in fixed_parameter_names}

# --- Emulator --------------------------------------------------------------------
models = read_model(statistics=statistics)
theory_model = models

# --- Minuit ----------------------------------------------------------------------
# Free parameters are listed in ``priors`` order because fill_params consumes
# Minuit's positional arguments in that order. Iterating ``data_x_names``
# here would silently mis-assign values whenever the two orderings differ.
init_params = {key: data_x[data_x_names.index(key)]
               for key in priors if key not in fixed_parameters}
# log_likelihood_minuit takes *theta, so Minuit cannot introspect parameter
# names and they must be passed explicitly.
minuit_params = {'name': list(init_params)}

m = Minuit(log_likelihood_minuit, **init_params, **minuit_params)

# Bound every free parameter by its prior range.
for param in init_params:
    m.limits[param] = (ranges[param][0], ranges[param][1])

Loaded covariance matrix with shape: (78, 78)
Loaded LHC x with shape: (20,)
Loaded LHC y with shape (78,)


In [8]:
# Run the minimization.
# log_likelihood_minuit returns 0.5 * chi^2 (a negative log-likelihood), so
# the matching error definition is Minuit.LIKELIHOOD (errordef = 0.5).
# NOTE(review): the saved output reports an INVALID minimum with EDM far above
# the goal — check convergence before using these parameter values.
m.errordef = Minuit.LIKELIHOOD
m.migrad()

Migrad,Migrad.1
FCN = 5.326,Nfcn = 2234
EDM = 8.41e+04 (Goal: 0.0001),time = 1.3 sec
INVALID Minimum,ABOVE EDM threshold (goal x 10)
SOME parameters at limit,Below call limit
Hesse ok,Covariance APPROXIMATE

0,1,2,3,4,5,6,7,8
,Name,Value,Hesse Error,Minos Error-,Minos Error+,Limit-,Limit+,Fixed
0.0,omega_b,0.0222,0.0010,,,0.0207,0.0243,
1.0,omega_cdm,0.120,0.006,,,0.103,0.14,
2.0,sigma8_m,0.81,0.04,,,0.678,0.938,
3.0,n_s,0.97,0.04,,,0.901,1.02,
4.0,logM_cut,12.54,0.14,,,12.5,13.7,
5.0,logM_1,14.0,0.4,,,13.6,15.1,

0,1,2,3,4,5,6
,omega_b,omega_cdm,sigma8_m,n_s,logM_cut,logM_1
omega_b,1.21e-06,-5.4e-6 (-0.768),30.3e-6 (0.639),0.7e-6 (0.016),0,-9.1e-6 (-0.018)
omega_cdm,-5.4e-6 (-0.768),4.04e-05,-0 (-0.007),0.15e-3 (0.627),0,0.07e-3 (0.025)
sigma8_m,30.3e-6 (0.639),-0 (-0.007),0.00186,0.0013 (0.760),0.0000,0.0026 (0.133)
n_s,0.7e-6 (0.016),0.15e-3 (0.627),0.0013 (0.760),0.00149,-0.0000,-0.0006 (-0.033)
logM_cut,0,0,0.0000,-0.0000,0.0219,0.000
logM_1,-9.1e-6 (-0.018),0.07e-3 (0.025),0.0026 (0.133),-0.0006 (-0.033),0.000,0.211


In [9]:
# Run Hesse error analysis at the current parameter values.
# NOTE(review): the saved output shows "Covariance FORCED pos. def." and
# errors of order 1e-9 for most parameters, which suggests the curvature
# estimate is unreliable here — investigate before quoting these errors.
m.hesse()

Migrad,Migrad.1
FCN = 5.326,Nfcn = 2292
EDM = 693 (Goal: 0.0001),
INVALID Minimum,ABOVE EDM threshold (goal x 10)
No parameters at limit,Below call limit
Hesse ok,Covariance FORCED pos. def.

0,1,2,3,4,5,6,7,8
,Name,Value,Hesse Error,Minos Error-,Minos Error+,Limit-,Limit+,Fixed
0.0,omega_b,22.2152834e-3,0.0000006e-3,,,0.0207,0.0243,
1.0,omega_cdm,0.1200,0.0012,,,0.103,0.14,
2.0,sigma8_m,809.969379e-3,0.000006e-3,,,0.678,0.938,
3.0,n_s,965.0048460e-3,0.0000008e-3,,,0.901,1.02,
4.0,logM_cut,12.5376843147,0.0000000026,,,12.5,13.7,
5.0,logM_1,13.978998209,0.000000008,,,13.6,15.1,

0,1,2,3,4,5,6
,omega_b,omega_cdm,sigma8_m,n_s,logM_cut,logM_1
omega_b,3.5e-19,-698.78458e-15 (-0.999),-3.61e-18 (-0.966),-0,-0,0.01e-18 (0.002)
omega_cdm,-698.78458e-15 (-0.999),1.4e-06,7.21501e-12 (0.967),0,702e-18,-23.56e-15 (-0.002)
sigma8_m,-3.61e-18 (-0.966),7.21501e-12 (0.967),3.99e-17,0,0e-18,-0 (-0.002)
n_s,-0,0,0,6.09e-19,0,-0
logM_cut,-0,702e-18,0e-18,0,6.74e-18,-0e-18
logM_1,0.01e-18 (0.002),-23.56e-15 (-0.002),-0 (-0.002),-0,-0e-18,6.54e-17


In [10]:
# Current parameter uncertainties held by Minuit (from the Hesse run above).
m.errors

<ErrorView omega_b=5.917412566891489e-10 omega_cdm=0.001181291635413334 sigma8_m=6.312876277281276e-09 n_s=7.802687940205999e-10 logM_cut=2.596999060244798e-09 logM_1=8.088211878032325e-09>