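# Check Inference on Synthetic Observations

First explore how the AGN parameters (`agn_mass`, `agn_eb_v`, `agn_torus_mass`) change the model SED, then look at some existing fits. Finally, generate synthetic photometry from a known parameter vector and check whether MCMC recovers it.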

In [6]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import os

import corner
import h5py
import matplotlib.pyplot as plt
import numpy as np
from prospect.utils.obsutils import fix_obs

from agnfinder.prospector import visualise, main, cpz_builders

In [8]:
os.chdir('/home/mike/repos/agnfinder')

In [None]:
!pwd

/home/mike/repos/agnfinder


In [None]:
galaxy_index = 1
galaxy = main.load_galaxy(galaxy_index)
redshift = galaxy['redshift']

In [None]:
agn_mass = True
agn_eb_v = True
agn_torus_mass = True
igm_absorbtion = True

In [None]:
run_params, obs, model, sps = main.construct_problem(galaxy, redshift=redshift, agn_mass=agn_mass, agn_eb_v=agn_eb_v, agn_torus_mass=agn_torus_mass, igm_absorbtion=igm_absorbtion)

`sps` should be an instance of the custom `CSPSpecBasisAGN` class.

In [None]:
sps  # expected to be the custom CSPSpecBasisAGN instance

`model` should include `agn_mass`, `agn_eb_v` and `agn_torus_mass` among its free parameters.

In [None]:
model.free_params

In [None]:
model.fixed_params

In [None]:
model_spectrum, model_photometry, _ = model.sed(model.theta, obs, sps)  # trigger fsps calculation, takes a couple of minutes - then cached

In [None]:
model_photometry

In [None]:
sps.quasar_flux

In [None]:
assert max(sps.quasar_flux) > min(sps.quasar_flux)  # quasar component should be SOMETHING

## How does the SED change as we vary the AGN parameters?

In [None]:
model_param_index = dict(zip(model.free_params, range(len(model.free_params))))

In [None]:
theta_no_agn = model.theta.copy()
theta_agn = model.theta.copy()

# unobscured agn
theta_agn[model_param_index['agn_mass']] = 1e15
theta_agn[model_param_index['agn_torus_mass']] = 0
theta_no_agn[model_param_index['agn_mass']] = 0
theta_no_agn[model_param_index['agn_torus_mass']] = 0


In [None]:
fig, ax = plt.subplots(figsize=(16, 8))
visualise.plot_model_at_obs(ax, model, theta_no_agn, obs, sps)  # no torus, so looks weirdly blue-only
visualise.plot_model_at_obs(ax, model, theta_agn, obs, sps)  
# plt.loglog(sps.wavelengths, sps.galaxy_flux, label='Galaxy')

In [None]:
observer_wavelengths = visualise.get_observer_frame_wavelengths(model, sps)

In [None]:
plt.loglog(observer_wavelengths, sps.galaxy_flux, label='Galaxy')
plt.loglog(observer_wavelengths, sps.unextincted_quasar_flux, 'b--', label='Unextincted Quasar (not used)')
plt.loglog(observer_wavelengths, sps.extincted_quasar_flux, label='Extincted Quasar')
# plt.loglog(observer_wavelengths, sps.torus_flux, label='Torus')
# plt.loglog(observer_wavelengths, sps.quasar_flux, 'k', label='Net (Quasar)')
plt.legend()

What if we add dust extinction to the AGN (`agn_eb_v = 0.5`)?

In [None]:
theta_agn_extinction = theta_agn.copy()
theta_agn_no_extinction = theta_agn.copy()

theta_agn_extinction[model_param_index['agn_eb_v']] = 0.5
theta_agn_no_extinction[model_param_index['agn_eb_v']] = 0

In [None]:
fig, ax = plt.subplots(figsize=(16, 8))
visualise.plot_model_at_obs(ax, model, theta_agn_extinction, obs, sps)  
visualise.plot_model_at_obs(ax, model, theta_agn_no_extinction, obs, sps)  # exactly as with AGN above

What about adding the torus as well (AGN + extinction + torus)?

In [None]:
theta_agn_extinction_torus = theta_agn_extinction.copy()
theta_agn_extinction_no_torus = theta_agn_extinction.copy()

theta_agn_extinction_torus[model_param_index['agn_torus_mass']] = 1.
theta_agn_extinction_no_torus[model_param_index['agn_torus_mass']] = 0.

In [None]:
fig, ax = plt.subplots(figsize=(16, 8))
# visualise.plot_model_at_obs(ax, model, theta_agn_extinction_no_torus, obs, sps)  
visualise.plot_model_at_obs(ax, model, theta_agn_extinction_torus, obs, sps)

In [None]:
def plot_components(observer_wavelengths, sps):
    """Plot the SED components currently stored on ``sps`` in a new figure.

    Draws the galaxy, unextincted quasar, extincted quasar and torus fluxes, plus
    the net total (``quasar_flux + galaxy_flux``), on log-log axes against the
    rest-frame ``sps.wavelengths``.

    Args:
        observer_wavelengths: unused; kept so existing calls keep working.
        sps: object exposing ``wavelengths``, ``galaxy_flux``,
            ``unextincted_quasar_flux``, ``extincted_quasar_flux``, ``torus_flux``
            and ``quasar_flux`` arrays.

    Returns:
        The matplotlib Axes holding the plot, so callers can adjust limits.
    """
    # No plt.clf() here: it would clear (or create) some other current figure,
    # leaving a stray empty figure next to the one created below.
    fig, ax = plt.subplots(figsize=(16, 6))
    wavelengths = sps.wavelengths
    ax.loglog(wavelengths, sps.galaxy_flux, 'g', label='Galaxy')
    ax.loglog(wavelengths, sps.unextincted_quasar_flux, 'b--', label='Unextincted Quasar (not used)')
    ax.loglog(wavelengths, sps.extincted_quasar_flux, 'b', label='Extincted Quasar')
    ax.loglog(wavelengths, sps.torus_flux, 'orange', label='Torus')
    ax.loglog(wavelengths, sps.quasar_flux + sps.galaxy_flux, 'k', label='Net (All)')
    ax.legend()
    ax.set_ylabel('Flux')
    ax.set_xlabel('Wavelength (A, restframe)')
    fig.tight_layout()
    return ax

In [None]:
plot_components(observer_wavelengths, sps)
plt.ylim([1e-30, 1e-13])
plt.xlim([1e3, 1e7])

In [None]:
theta_agn[model_param_index['mass']]

## What does this look like for a model that we've actually fit?

In [None]:
def load_theta_from_samples(samples_loc):
    """Return the per-parameter median of the samples saved at ``samples_loc``.

    Reads the whole ``samples`` dataset of the HDF5 file (assumed shape
    (n_samples, n_params) - TODO confirm) and takes the median along axis 0.
    """
    with h5py.File(samples_loc, 'r') as samples_file:
        samples_array = samples_file['samples'][...]
    return np.median(samples_array, axis=0)

In [None]:
# Note: compared with older runs, the AGN component is now 1e14 times larger.

### Quasar? (including 1e14 rescaling)

In [None]:
samples_loc = '/home/mike/repos/agnfinder/results/qso_fixed_inclination_bigger_agn/qso_bigger_agn_mass_fixed_inclination_0_1564527506_multinest_samples.h5py'
assert os.path.isfile(samples_loc)
fit_theta = load_theta_from_samples(samples_loc)
dict(zip(model.free_params, fit_theta))

In [None]:
fig, ax = plt.subplots(figsize=(16, 6))
visualise.plot_model_at_obs(ax, model, fit_theta, obs, sps)  
ax.set_ylim([1e-7, 1e-4])
ax.set_xlim([1e3, 1e7])

In [None]:
plot_components(observer_wavelengths, sps)
plt.ylim([1e-17, 1e-10])
plt.xlim([1e3, 1e7])

### Star-forming?

In [None]:
samples_loc = '/home/mike/repos/agnfinder/results/qso_fixed_inclination_bigger_agn/starforming_bigger_agn_mass_fixed_inclination_1_1564532765_multinest_samples.h5py'
assert os.path.isfile(samples_loc), 'missing samples file: {}'.format(samples_loc)
fit_theta = load_theta_from_samples(samples_loc)
# Not the last expression in this cell, so display it explicitly; as a bare
# expression its value would be evaluated and silently discarded.
display(dict(zip(model.free_params, fit_theta)))

fig, ax = plt.subplots(figsize=(16, 6))
visualise.plot_model_at_obs(ax, model, fit_theta, obs, sps)
ax.set_xlim([1e3, 1e7]);

### AGN?

In [None]:
samples_loc = '/home/mike/repos/agnfinder/results/qso_fixed_inclination_bigger_agn/agn_bigger_agn_mass_fixed_inclination_0_1564532822_multinest_samples.h5py'
assert os.path.isfile(samples_loc), 'missing samples file: {}'.format(samples_loc)
fit_theta = load_theta_from_samples(samples_loc)
# Not the last expression in this cell, so display it explicitly; as a bare
# expression its value would be evaluated and silently discarded.
display(dict(zip(model.free_params, fit_theta)))

fig, ax = plt.subplots(figsize=(16, 6))
visualise.plot_model_at_obs(ax, model, fit_theta, obs, sps)
ax.set_xlim([1e3, 1e7]);

## Can we recover the original parameters?

In [None]:
# Deliberate stop for "Run All": the cells below re-evaluate the model and run
# MCMC. An explicit raise (unlike `assert False`) still stops under `python -O`.
raise RuntimeError('Stopping before the parameter-recovery MCMC; skip this cell to continue.')

In [None]:
# Ground truth for the recovery test: the extincted-AGN parameters defined above
# (note: an alias of theta_agn_extinction, not a copy).
theta_to_recover = theta_agn_extinction

In [None]:
# Evaluate the model at the true parameters. This overwrites model_spectrum and
# model_photometry from the earlier model.theta evaluation; model_photometry then
# becomes the synthetic photometry that is fitted below.
# Triggers the fsps calculation (takes a couple of minutes), then cached.
model_spectrum, model_photometry, _ = model.sed(theta_to_recover, obs, sps)

In [None]:
def make_synthetic_obs_from_model(real_obs, maggies, snr=10.):
    """Build a photometry-only obs dict from model fluxes, with no noise added.

    Args:
        real_obs: obs dict whose ``'filters'`` entry is reused as-is.
        maggies: model photometry, one value per filter; stored unchanged.
        snr: signal-to-noise ratio; uncertainties are ``maggies / snr``.

    Returns:
        The obs dict after passing through prospect's ``fix_obs``.
    """
    filters = real_obs['filters']
    obs_dict = {
        "filters": filters,
        "maggies": maggies,
        'maggies_unc': maggies / snr,
        "phot_mask": np.array([True] * len(filters)),
        "phot_wave": np.array([flt.wave_effective for flt in filters]),
        # photometry only: no spectroscopic data
        "wavelength": None,
        "spectrum": None,
        'unc': None,
        'mask': None,
    }
    return fix_obs(obs_dict)


In [None]:
# Noiseless synthetic photometry at theta_to_recover, with uncertainties of
# maggies / 10 (i.e. SNR = 10).
synthetic_obs = make_synthetic_obs_from_model(obs, model_photometry, snr=10.)

### Maximum likelihood (skipped: it would start from the true parameter values, so it is uninformative)

In [None]:
# Skipped: the optimiser would start from the true parameters, so the result is uninformative.
# theta_best, time_elapsed = main.fit_galaxy(run_params, synthetic_obs, model, sps)

### MCMC

In [None]:
samples, mcmc_time_elapsed = main.mcmc_galaxy(run_params, synthetic_obs, model, sps, initial_theta=None, test=False)

In [None]:
synthetic_obs = build_synthetic_obs_from_model(obs, model_photometry, snr=10.)

In [None]:
visualise.visualise_obs(synthetic_obs)

In [None]:
name = 'inference_to_recover_theta'
output_dir = '/home/mike/repos/agnfinder/results'

In [None]:
sample_loc = os.path.join(output_dir, '{}_mcmc_samples.h5py'.format(name))
main.save_samples(samples, model, sample_loc)
corner_loc = os.path.join(output_dir, '{}_mcmc_corner.png'.format(name))
main.save_corner(samples[int(len(samples)/2):], model, corner_loc)  # nested sampling has no burn-in phase, early samples are bad

In [None]:
theta_to_recover

In [None]:
import corner

In [None]:
for index in range(len(theta_to_recover)):
    print('Name: {}'.format(model.free_params[index]))
    print('True value: {:3.1E}'.format(theta_to_recover[index]))
    low, med, up = corner.quantile(samples[:, index], [.1, .5, .9])
    print('Estimate: {:3.1E} (min {:3.2E}, max {:3.3E})'.format(low, med, up))
    print('\n')
    

In [None]:
# SED traces drawn from the final 2000 samples of the chain.
traces_filename = '{}_mcmc_traces.png'.format(name)
traces_loc = os.path.join(output_dir, traces_filename)
main.save_sed_traces(samples[-2000:], synthetic_obs, model, sps, traces_loc)


In [None]:
plt.loglog(sps.wavelengths, sps.extincted_quasar_flux)
_, spectra, _ = sps.get_galaxy_spectrum()
plt.loglog(sps.wavelengths, spectra)

In [None]:
len(sps.wavelengths)

In [None]:
sps.extincted_quasar_flux.sum() / model_spectrum.sum()

In [None]:
len(sps.extincted_quasar_flux)

In [None]:
len(model_spectrum)