# JWST-ERS Clusters
## Spectral Population Synthesis Fitting

The goal of this notebook is to fit simple, simulated JWST prism spectra with [prospector](https://github.com/bd-j/prospector).

In [1]:
import os
import time
import h5py

import nestle
import numpy as np

In [2]:
from prospect import fitting
from prospect.io import write_results

ModuleNotFoundError: No module named 'sedpy'

In [None]:
%pylab inline

### Define the key prospector functions


In [None]:
def load_obs(zred=1.0, wave=None, flux=None, ferr=None, mask=None):
    """Build the prospector-style "obs" dictionary for a single object.

    Only spectroscopy is supplied; the photometry entries ('maggies',
    'maggies_unc', 'phot_mask') are all set to None.

    Args:
        zred (float): Galaxy redshift.
        wave (np.array, npix): Wavelength array (micron).
        flux (np.array, npix): Galaxy spectrum (units??).
        ferr (np.array, npix): Uncertainty spectrum corresponding to flux (units??).
        mask (np.array, npix): Mask spectrum (1=good, 0=bad).

    Returns:
        dict: keys, in order, 'maggies', 'maggies_unc', 'phot_mask',
        'wavelength', 'spectrum', 'unc', 'mask', 'zred'. The array arguments
        are stored as-is (no copy, no unit conversion).
    """
    # NOTE(review): prospector conventionally expects obs['wavelength'] in
    # Angstroms and obs['spectrum'] in maggies, whereas the inputs here are
    # documented as microns / unknown flux units -- confirm before fitting.
    obs = dict.fromkeys(('maggies', 'maggies_unc', 'phot_mask'))  # no photometry
    obs.update(
        wavelength=wave,
        spectrum=flux,
        unc=ferr,
        mask=mask,
        zred=zred,
    )
    return obs

In [None]:
def load_model(zred=1.0, seed=None):
    """Initialize the priors on each free and fixed parameter.

    Args:
      zred (float): input (fixed) galaxy redshift.

    Returns:
      sed (prospect.models.sedmodel.SedModel): SED priors and other stuff.

    Notes:
      FSPS parameters are documented here:
        http://dan.iel.fm/python-fsps/current/stellarpop_api/#api-reference

      Initialization parameters:
        * compute_vega_mags (must be set at initialization)
        * vactoair_flag (must be set at initialization)
        * zcontinuous (must be set at initialization)
    
      Metallicity parameters:
        * zmet (default 1, ignored if zcontinuous>0)
        * logzsol (default 0.0, used if zcontinuous>0)
        * pmetals (default 2.0, only used if zcontinuous=2)

      Dust parameters:
        * add_agb_dust_model (default True)
        * add_dust_emission (default True)
        * cloudy_dust (default False)
        * agb_dust (default 1.0)
        * dust_type (default 0=power law)
        * dust_index, dust1_index
        * dust_tesc
        * dust1 (default 0.0) - extra optical depth towards young stars at 5500A
        * dust2 (default 0.0) - diffuse dust optical depth towards all stars at 5500A
        * dust_clumps, frac_nodust, frac_obrun
        * mwr, uvb, wgp1, wgp2, wgp3, 
        * duste_gamma, duste_umin, duste_qpah

      Star formation history parameters:
        * sfh (default 0=SSP, 1=tau, 4=delayed, 5=truncated delayed tau)
        * tau (default 1)
        * const, sf_start, sf_trunc
        * tage (default 0.0)
        * fburst, tburst, sf_slope
    
      Miscellaneous parameters:
        * add_igm_absorption (default False)
        * igm_factor (default 1.0)
        * smooth_velocity (default True)
        * sigma_smooth, min_wave_smooth, max_wave_smooth
        * redshift_colors (default False, do not use)
        * compute_light_ages (default False, do not use)
       
      Stellar population parameters:
        * add_stellar_remnants (default True)
        * tpagb_norm_type (default 2)
        * dell (default 0.0, do not use)
        * delt (default 0.0, do not use)
        * redgb (default 1.0)
        * fcstar (default 1.0)
        * sbss (default 0.0)
        * fbhb (default 0.0)
        * pagb (default 1.0)
        * imf_type (default 2=Kroupa01)
        * imf1, imf2, imf3, vdmc, mdave, masscut
        * evtype (default 1)
        * tpagb_norm_type

      Emission lines:
        * add_neb_emission (default False)
        * add_neb_continuum (default False)
        * gas_logz (default 0.0)
        * gas_logu (default -2)

      Galaxy properties:
        * zred (default 0.0)

      AGN properties:
        * fagn (default 0.0)
        * agn_tau (default 10)

      Calibration parameters:
        * phot_jitter

    """
    from prospect.models import priors, sedmodel

    model_params = []

    ##################################################
    # Fixed priors

    # Galaxy redshift
    model_params.append({
        'name': 'zred',
        'N': 1,
        'isfree': False,
        'init': zred,
        'units': '',
        'prior': None,       
        })

    model_params.append({ # current mass in stars, not integral of SFH
        'name': 'mass_units',
        'N': 1,
        'isfree': False,
        'init': 'mstar', # 'mformed'
        'prior': None,       
        })

    # IMF (Chabrier)
    model_params.append({
        'name': 'imf_type',
        'N': 1,
        'isfree': False,
        'init':   1, # 1 - Chabrier
        'units': '',
        'prior': None,       
        })

    # SFH parameterization (delayed-tau)
    model_params.append({
        'name': 'sfh',
        'N': 1,
        'isfree': False,
        'init':   4, # 4 = delayed tau model
        'units': 'type',
        'prior': None,       
        })

    # Do not include dust emission
    model_params.append({
        'name': 'add_dust_emission',
        'N': 1,
        'isfree': False,
        'init':   False, # do not include dust emission
        'units': 'index',
        'prior': None,       
        })

    ##################################################
    # Free priors / parameters

    # Priors on stellar mass and stellar metallicity
    logmass_prior = priors.TopHat(mini=9.0, maxi=13.0)#, seed=seed)
    logmass_init = np.diff(logmass_prior.range)/2.0 + logmass_prior.range[0] # logmass_prior.sample()
    model_params.append({
        'name': 'logmass',
        'N': 1,
        'isfree': True,
        'init': logmass_init, # mass, 
        'init_disp': 0.5,     # dex
        'units': r'$M_{\odot}$',
        'prior': logmass_prior,
        })
    
    model_params.append({
        'name': 'mass',
        'N': 1,
        'isfree': False,
        'init': 10**logmass_init,
        'units': '',
        'prior': None,
        'depends_on': logmass2mass,
        })

    logzsol_prior = priors.TopHat(mini=np.log10(0.004/0.019), maxi=np.log10(0.04/0.019))#, seed=seed)
    logzsol_init = np.diff(logzsol_prior.range)/2.0 + logzsol_prior.range[0] # logzsol_prior.sample(), # logzsol,
    model_params.append({
        'name': 'logzsol',
        'N': 1,
        'isfree': True,
        'init': logzsol_init,
        'init_disp': 0.3, # logzsol_prior.range[1] * 0.1,
        'units': r'$\log_{10}\, (Z/Z_\odot)$',
        'prior': logzsol_prior, # roughly (0.2-2)*Z_sun
        })

    # Prior(s) on dust content
    dust2_prior = priors.TopHat(mini=0.0, maxi=3.0)#, seed=seed)
    dust2_init = np.diff(dust2_prior.range)/2.0 + dust2_prior.range[0] # dust2_prior.sample(), # dust2,
    model_params.append({
        'name': 'dust2',
        'N': 1,
        'isfree': True,
        'init': dust2_init,
        'init_disp': 0.5, # dust2_prior.range[1] * 0.1,
        'units': '', # optical depth
        'prior': dust2_prior,
        })
    
    # Priors on tau and age
    #tau_prior = priors.TopHat(mini=0.1, maxi=10.0)#, seed=seed)
    tau_prior = priors.LogUniform(mini=0.1, maxi=10.0)#, seed=seed)
    tau_init = np.diff(tau_prior.range)/2.0 + tau_prior.range[0] # tau_prior.sample(), # tau,
    model_params.append({
        'name': 'tau',
        'N': 1,
        'isfree': True,
        'init': tau_init,
        'init_disp': 1.0, # tau_prior.range[1] * 0.1,
        'units': 'Gyr',
        'prior': tau_prior,
        })

    tage_prior = priors.TopHat(mini=0.5, maxi=15)#, seed=seed)
    tage_init = np.diff(tage_prior.range) / 2.0 + tage_prior.range[0] # tage_prior.sample(), # tage,
    model_params.append( {
        'name': 'tage',
        'N': 1,
        'isfree': True,
        'init': tage_init,
        'init_disp': 2.0, # tage_prior.range[1] * 0.1,
        'units': 'Gyr',
        'prior': tage_prior,
        })

    model = sedmodel.SedModel(model_params)
    
    return model

In [None]:
def lnprobfn(theta, model, obs, sps, spec_noise=None, phot_noise=None, verbose=False):
    """Return the ln-posterior probability of a parameter vector.

    Given a parameter vector, a dictionary of observational data and a model
    object, return ln(prior) + ln(likelihood of the photometry) + ln(likelihood
    of the spectrum). This requires an instantiated sps object (and, if
    modeling the noise, noise-model objects).

    Args:
        theta (array-like): free-parameter vector, ordered as ``model`` expects.
        model (prospect.models.sedmodel.SedModel): the SED model.
        obs (dict): prospector-style observation dictionary (see ``load_obs``).
        sps: stellar population synthesis object passed to ``model.mean_model``.
        spec_noise, phot_noise: optional noise models; when truthy they are
            updated with the current ``model.params`` before the likelihood
            is computed.
        verbose (bool): if True, log the evaluation via
            ``prospect.likelihood.write_log``.

    Returns:
        float: the ln-posterior, or ``-np.inf`` when the prior is not finite
        or ``model.mean_model`` raises ``ValueError``.
    """
    # NOTE(review): in prospector, nested=True is intended for nested sampling
    # (the prior shape is handled by the prior transform); confirm this is
    # the intended behavior when sampling with emcee.
    lnp_prior = model.prior_product(theta, nested=True)
    if not np.isfinite(lnp_prior):
        return -np.inf

    # Generate the mean model (the expensive step; computed only once).
    t1 = time.time()
    try:
        model_spec, model_phot, _ = model.mean_model(theta, obs, sps=sps)
    except ValueError:
        return -np.inf
    d1 = time.time() - t1

    # Imported here so the early-exit paths above do not require prospect.
    from prospect.likelihood import lnlike_spec, lnlike_phot

    # Noise modeling--
    if spec_noise:
        spec_noise.update(**model.params)
    if phot_noise:
        phot_noise.update(**model.params)

    vectors = {
        'spec': model_spec,    # model spectrum
        'phot': model_phot,    # model photometry
        'sed': model._spec,    # object spectrum
        'cal': model._speccal, # object calibration spectrum
    }

    # Calculate likelihoods--
    t2 = time.time()
    lnp_spec = lnlike_spec(model_spec, obs=obs, spec_noise=spec_noise, **vectors)
    lnp_phot = lnlike_phot(model_phot, obs=obs, phot_noise=phot_noise, **vectors)
    d2 = time.time() - t2  # was `time() - t2`, which calls the module and raises TypeError
    if verbose:
        from prospect.likelihood import write_log
        write_log(theta, lnp_prior, lnp_spec, lnp_phot, d1, d2)

    return lnp_prior + lnp_phot + lnp_spec

### Read the simulated galaxy spectrum to fit.

In [None]:
def read_jwst(prefix='jwst', zred=1.0, datadir=None):
    """Read a simulated JWST spectrum and its noise spectrum.

    Reads extension 1 of ``{prefix}_extracted_flux.fits`` and
    ``{prefix}_extracted_noise.fits`` (column names upper-cased).

    Args:
        prefix (str): file-name prefix of the simulated spectrum.
        zred (float): redshift returned alongside the spectrum. It is not
            read from the files; the default 1.0 is the previously
            hard-coded value.
        datadir (str, optional): directory holding the files. Defaults to
            the notebook-level ``jwstdir`` variable (defined in another cell).

    Returns:
        tuple: ``(wave, flux, ferr, zred)``; ``wave`` and ``flux`` are the
        ``WAVELENGTH`` and ``EXTRACTED_FLUX`` columns of the flux file and
        ``ferr`` is the ``EXTRACTED_NOISE`` column of the noise file.
    """
    import fitsio

    if datadir is None:
        datadir = jwstdir  # notebook-level global; keeps the old call style working

    def _read(kind):
        # kind is 'flux' or 'noise'
        path = os.path.join(datadir, '{}_extracted_{}.fits'.format(prefix, kind))
        return fitsio.read(path, ext=1, upper=True)

    fitsflux = _read('flux')
    fitsferr = _read('noise')
    wave, flux, ferr = fitsflux['WAVELENGTH'], fitsflux['EXTRACTED_FLUX'], fitsferr['EXTRACTED_NOISE']
    return wave, flux, ferr, zred

In [None]:
# Root directory of the project data; requires the IM_PROJECTS_DIR environment
# variable (os.path.join(None, ...) would otherwise fail with an obscure TypeError).
_projects_dir = os.getenv('IM_PROJECTS_DIR')
if _projects_dir is None:
    raise RuntimeError('Set the IM_PROJECTS_DIR environment variable; '
                       'the data are expected in $IM_PROJECTS_DIR/jwstclusters.')
jwstdir = os.path.join(_projects_dir, 'jwstclusters')

In [None]:
# Spectrum to fit and random seed.
prefix = 'lineplot'
seed = 123
# Output locations inside jwstdir, keyed by the spectrum prefix.
outroot = os.path.join(jwstdir, '{}_prospect'.format(prefix))
hfilename = os.path.join(jwstdir, '{}_prospect_mcmc.h5'.format(prefix))

In [None]:
# Specify the run parameters, including the options later used to initialize
# the SPS object (the SPS object itself is not created in this cell).
run_params = {
    'prefix':  prefix,
    'verbose': True,
    'seed':    seed,
    # initial optimization choices (nmin is only for L-M optimization)
    'do_levenburg': True,
    'nmin': 10,
    # emcee fitting parameters
    'nwalkers': 128,
    'nburn': [32, 32, 64], 
    'niter': 256, # 512,
    'interval': 0.1, # save 10% of the chains at a time
    # Nestle fitting parameters
    'nestle_method': 'single',
    'nestle_npoints': 200,
    'nestle_maxcall': int(1e6),
    # Multiprocessing
    'nthreads': 10,
    # SPS initialization parameters
    'compute_vega_mags': False,
    'vactoair_flag': False, # keep vacuum wavelengths (True would convert to air)
    'zcontinuous': 1,      # interpolate in metallicity
    }

In [None]:
# Load the simulated spectrum, its uncertainty, and the (fixed) redshift.
wave, flux, ferr, zred = read_jwst(prefix=prefix)

In [None]:
# Quick-look plot of the input spectrum with its 1-sigma uncertainties.
fig, ax = plt.subplots()
ax.errorbar(wave, flux, yerr=ferr)
ax.set_xlabel('Observed-frame Wavelength (micron)')
ax.set_ylabel('')

In [None]:
# Package the spectrum into the prospector "obs" dictionary and show its keys.
obs = load_obs(wave=wave, flux=flux, ferr=ferr, zred=zred)
obs.keys()

In [None]:
# Build the SED model at the redshift stored in obs.
model = load_model(seed=seed, zred=obs['zred'])

In [None]:
# Open the output HDF5 file and write out some basic info.
# Opened in append mode ('a'), so an existing file is reused rather than
# truncated. The handle is not closed in this cell.
# NOTE(review): presumably later cells write the fit results to `hfile`;
# make sure `hfile.close()` is called once writing is finished.
hfile = h5py.File(hfilename, 'a')
# Header: run parameters and model description.
write_results.write_h5_header(hfile, run_params, model)
# The observational data dictionary built above.
write_results.write_obs_to_h5(hfile, obs)