# Checking local vs. global variation of the likelihood

In [33]:
import tensorflow as tf
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from pitchfuncs_unit import WMSE
from pitchfuncs_unit import InversePCA
from pitchfuncs_unit import emulator
from pitchfuncs_unit import ns
import pandas as pd
import scipy
import numpy as np
import corner
import time
from scipy import constants
from scipy import stats
import astropy.constants
import pickle
from matplotlib.pyplot import cm
import json
import matplotlib.pyplot as plt

def calc_Teff(luminosity, radius):
    """Return effective temperature from the Stefan-Boltzmann law.

    Computes ``Teff = (L / (4 * pi * sigma * R**2)) ** 0.25``.

    Parameters
    ----------
    luminosity : array-like or pandas Series
        Luminosity in units of ``astropy.constants.L_sun``.
    radius : array-like or pandas Series
        Radius in units of ``astropy.constants.R_sun``.

    Returns
    -------
    numpy.ndarray
        Effective temperature in kelvin, as plain floats.
    """
    # Work in plain SI floats: scipy's sigma is a bare SI number, so mixing it
    # with astropy Quantities only produced a meaningless compound unit that
    # np.array() then stripped. Using the SI values keeps the arithmetic explicit.
    l_sun_w = astropy.constants.L_sun.si.value  # W
    r_sun_m = astropy.constants.R_sun.si.value  # m

    # np.asarray accepts pandas Series (as before) as well as ndarrays/scalars.
    lum_w = np.asarray(luminosity, dtype=float) * l_sun_w
    rad_m = np.asarray(radius, dtype=float) * r_sun_m

    return np.array((lum_w / (4 * np.pi * constants.sigma * rad_m**2)) ** 0.25)

def rescale_preds(preds, df, column):
    """Map a standardised emulator prediction back to physical values.

    The z-score is inverted as ``z * std + mean``, using the statistics of the
    matching column in ``df``:

    * columns whose name contains ``'star_feh'`` are handled linearly, reading
      ``preds[column + "_std"]`` and ``df[column]``;
    * every other column is handled in log10 space, reading
      ``preds["log_" + column + "_std"]`` and ``df["log_" + column]``, and the
      result is raised as a power of 10.
    """
    if 'star_feh' not in column:
        log_key = "log_" + column
        z_scores = preds[log_key + "_std"]
        return 10 ** (z_scores * df[log_key].std() + df[log_key].mean())

    z_scores = preds[column + "_std"]
    return z_scores * df[column].std() + df[column].mean()


## Load the emulator

In [34]:
# Build the emulator and load its metadata (parameter ranges) from the
# matching pickle in "pickle jar/".
pitchfork_name = "nu6-40_elu_nonorm_feh"
pitchfork = emulator(pitchfork_name)

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(f"pickle jar/{pitchfork_name}.pkl", 'rb') as fp:
    pitchfork_info = pickle.load(fp)
pitchfork_ranges = pitchfork_info['parameter_ranges']

initial_mass range: [min = 0.8, max = 1.2]
initial_Zinit range: [min = 0.003869061466818601, max = 0.0389797119014747]
initial_Yinit range: [min = 0.24, max = 0.32]
initial_MLT range: [min = 1.7, max = 2.5]
star_age range: [min = 0.029664111540787196, max = 13.999973871651315]


## Define a synthetic reference star (`emu`) from the emulator's predictions

In [41]:
# Emulator input parameters and the chosen values for the reference star
# (values listed in the same order as `inputs`).
inputs = ['initial_mass', 'initial_Zinit', 'initial_Yinit', 'initial_MLT', 'star_age']
emu_inps = [1, 0.014, 0.26, 2, 5]

# Emulator outputs: Teff, luminosity, surface [Fe/H], then radial modes nu_0_6..nu_0_40.
outputs = ['calc_effective_T', 'luminosity', 'star_feh'] + [f'nu_0_{n}' for n in range(6, 41)]
emu_outs = pitchfork.predict([emu_inps])[0]

# One-row frame holding both the inputs and the emulated outputs.
emu = pd.DataFrame([[*emu_inps, *emu_outs]], columns=[*inputs, *outputs])

In [68]:
def closest_points(point, points, n):
    """Return the ``n`` rows of ``points`` nearest to ``point``, nearest first.

    Distances are Euclidean (the default metric of
    ``scipy.spatial.distance.cdist``).

    Parameters
    ----------
    point : array-like, shape (d,) or (1, d)
        Query point. Only the first row is used if more are given.
    points : array-like, shape (m, d)
        Candidate points (ndarray, DataFrame, or nested lists).
    n : int
        Number of neighbours to return. Fewer rows are returned when ``m < n``.

    Returns
    -------
    numpy.ndarray, shape (min(n, m), d)
    """
    # Import explicitly: `import scipy` alone does not guarantee that
    # scipy.spatial is loaded on every SciPy version.
    from scipy.spatial.distance import cdist

    query = np.atleast_2d(point)  # accept a flat 1-D point as well
    candidates = np.asarray(points)  # DataFrames would otherwise break row indexing
    order = cdist(query, candidates)[0].argsort()
    return candidates[order[:n]]

# Load the held-out test grid and keep the 10000 models whose outputs lie
# closest (Euclidean, in raw output units) to the reference star `emu`.
test_data = pd.read_hdf('test_data.h5')
# Pass `outputs` directly: wrapping it as `[outputs]` makes pandas build a
# one-level MultiIndex for the columns instead of a flat Index.
closest_points_df = pd.DataFrame(
    closest_points(emu[outputs].values, test_data[outputs].values, n=10000),
    columns=outputs,
)
closest_points_df

Unnamed: 0,calc_effective_T,luminosity,star_feh,nu_0_6,nu_0_7,nu_0_8,nu_0_9,nu_0_10,nu_0_11,nu_0_12,...,nu_0_31,nu_0_32,nu_0_33,nu_0_34,nu_0_35,nu_0_36,nu_0_37,nu_0_38,nu_0_39,nu_0_40
0,5779.359553,1.050687,2.980126e-02,920.489488,1056.770620,1194.210497,1329.385146,1462.355639,1591.949034,1719.195830,...,4169.417099,4299.201081,4427.995815,4552.596865,4678.993401,4802.886754,4923.414238,5041.009909,5157.896684,5277.329841
1,5784.719676,1.053612,-7.472621e-02,921.076515,1057.562240,1195.057612,1329.965125,1462.792013,1592.020008,1719.297789,...,4169.492000,4299.133685,4426.679337,4551.959434,4677.838325,4800.855558,4920.577178,5037.760929,5155.263484,5275.757681
2,5794.104926,1.030893,2.210767e-02,921.978192,1058.027662,1195.339758,1330.478215,1463.040093,1592.416656,1719.552258,...,4171.824857,4301.484389,4429.455034,4554.594764,4680.489753,4803.698945,4923.439974,5040.512049,5158.057490,5278.403553
3,5768.284789,1.046869,1.269560e-01,924.204522,1061.846339,1198.403241,1334.529180,1466.919148,1596.323676,1724.589075,...,4169.947979,4299.422378,4427.493507,4551.526201,4676.977481,4799.664408,4919.086752,5035.869489,5153.247045,5273.724446
4,5759.288227,1.017723,-1.809659e-01,919.903640,1056.726175,1193.761904,1328.549224,1461.393337,1589.965073,1717.496603,...,4171.423323,4301.067379,4427.579269,4554.367538,4680.171847,4803.270287,4922.790675,5039.913314,5157.632327,5278.405718
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,6143.657526,1.644365,5.000000e-01,878.095693,1008.877328,1140.316678,1270.693457,1401.077726,1528.769845,1653.621287,...,3998.608105,4123.294796,4247.917446,4372.389187,4495.393797,4616.651241,4739.017379,4859.893559,4978.787644,5095.174427
9996,6190.869055,1.598614,1.338362e-09,881.151487,1013.312314,1144.018432,1275.388346,1404.464598,1531.393261,1654.469591,...,4009.523761,4134.234969,4256.241078,4379.772099,4502.307595,4623.244757,4741.835141,4857.603413,4971.216076,5084.753676
9997,5373.805372,0.845560,3.170573e-01,890.462032,1023.579250,1155.005680,1284.680488,1412.095819,1536.126240,1659.093095,...,4009.582544,4134.449027,4258.676848,4380.996234,4499.271465,4618.132196,4733.920350,4846.498556,4958.460841,5072.745636
9998,6055.293726,1.456307,5.308619e-02,878.048893,1009.906789,1140.276301,1271.173671,1399.625692,1525.592806,1648.185422,...,3997.609064,4122.202166,4245.745220,4366.824598,4489.138311,4609.814740,4728.067999,4843.481203,4956.476937,5069.564118


In [39]:
def _logl_curve(emulator, emu_inps, emu_obs, x_label, x_min, x_max, sigma, points):
    """Return ``(x, logl)`` for a 1-D sweep of one emulator input, others held fixed."""
    inputs = ['initial_mass', 'initial_Zinit', 'initial_Yinit', 'initial_MLT', 'star_age']
    if x_label not in inputs:
        raise ValueError(f"x_label must be one of {inputs}, got {x_label!r}")

    # Accept either a flat list of input values or a single-row list of lists.
    base = np.atleast_2d(np.asarray(emu_inps, dtype=float))
    if base.shape[0] != 1:
        raise ValueError(f"emu_inps must describe a single input point, got {base.shape[0]} rows")
    grid = pd.DataFrame(np.repeat(base, points, axis=0), columns=inputs)

    x = np.linspace(x_min, x_max, points)
    grid[x_label] = x

    model = np.asarray(emulator.predict(grid.values))
    residuals = model - np.asarray(emu_obs)
    # Dimensionality of the Gaussian = number of observables (columns),
    # not the number of grid points.
    n_obs = residuals.shape[1]

    sign, log_sigma_det = np.linalg.slogdet(sigma)
    if sign <= 0:
        raise np.linalg.LinAlgError("sigma must be positive definite")

    # Row-wise r_i^T Sigma^{-1} r_i, using solve instead of an explicit inverse.
    mahalanobis_sq = np.einsum('ij,ji->i', residuals, np.linalg.solve(sigma, residuals.T))

    logl = -0.5 * (n_obs * np.log(2 * np.pi) + log_sigma_det + mahalanobis_sq)
    return x, logl


def logl_plot(emulator, emu_inps, emu_obs, x_label, x_min, x_max, sigma, points=1000):
    """Plot the Gaussian log-likelihood of ``emu_obs`` along one input parameter.

    The input ``x_label`` is swept over ``points`` evenly spaced values in
    ``[x_min, x_max]``, with the other inputs fixed at ``emu_inps``. Each grid
    point is passed through ``emulator.predict`` and scored as
    ``-k/2 log(2 pi) - 1/2 log|sigma| - 1/2 r^T sigma^{-1} r``, where
    ``r = prediction - emu_obs`` and ``k`` is the number of observables.

    Parameters
    ----------
    emulator : object
        Must provide ``predict(array)`` that returns one row of outputs per
        input row.
    emu_inps : sequence
        Values for ``initial_mass, initial_Zinit, initial_Yinit, initial_MLT,
        star_age``, given either flat or as a single-row nested list.
    emu_obs : array-like, shape (k,)
        Observed outputs, in the same order as the emulator outputs.
    x_label : str
        Name of the input to sweep.
    x_min, x_max : float
        Sweep range.
    sigma : array-like, shape (k, k)
        Total covariance matrix. Must be positive definite.
    points : int, optional
        Number of grid points (default 1000).

    Raises
    ------
    ValueError
        If ``x_label`` is not a known input or ``emu_inps`` has several rows.
    numpy.linalg.LinAlgError
        If ``sigma`` is not positive definite.
    """
    x, y = _logl_curve(emulator, emu_inps, emu_obs, x_label, x_min, x_max, sigma, points)
    plt.plot(x, y)
    plt.xlabel(x_label)
    plt.ylabel('logl')

notebook_x_label = 'initial_mass'

# Assumed 1-sigma observational uncertainties.
teff_unc = 70  # K
luminosity_unc = 0.04  # L\odot
surface_feh_unc = 0.1  # dex
frequency_unc = 0.5  # \muHz

# Alternative (unused): frequency uncertainty growing away from the centre of the radial-order range.
# obs_unc = np.array([teff_unc, luminosity_unc, surface_feh_unc]+[frequency_unc+0.01*np.abs(((n_max-n_min)/2+n_min)-i) for i in range(n_min,n_max+1)])
obs_unc = np.array([teff_unc, luminosity_unc, surface_feh_unc] + [frequency_unc] * (40 - 6 + 1))

pitchfork_ranges = pitchfork_info['parameter_ranges']

# Emulator residuals: relabel the JSON columns positionally, then order them like `outputs`.
emulator_errors = pd.read_json('pickle jar/emulator_errors2.json')
emulator_errors.columns = ['calc_effective_T', 'luminosity', 'star_feh'] + [f'nu_0_{n}' for n in range(6, 41)]
emulator_errors = emulator_errors[outputs]

# Covariance across output columns (np.cov treats rows as variables, hence the transpose).
emulator_cov = np.cov(emulator_errors.to_numpy().T)
sigma_nn = emulator_cov

# Observational covariance: diagonal matrix of squared uncertainties.
sigma_obs = np.identity(len(emulator_cov)) * obs_unc**2

# Total covariance = emulator + observational.
sigma = sigma_nn + sigma_obs
