# How accurately can we recover the SFR threshold age?
---
In this notebook, we test how accurately the SFR threshold age can be recovered by running Monte Carlo (MC) simulations with ppxf on mock spectra generated from a variety of star formation histories (SFHs).
When `savefigs = True`, the figure for each individual SFH is saved to a PDF for later reference.


In [2]:
# Use the interactive ipympl ("widget") matplotlib backend for the figures below.
%matplotlib widget

In [3]:
# Widen the notebook container and the output area to 75% of the browser window.
# `display` and `HTML` are imported from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:75% !important; }</style>"))
display(HTML("<style>.output_result { max-width:75% !important; }</style>"))

In [23]:
import os
import numpy as np
from numpy.random import RandomState
from time import time 
from tqdm.notebook import tqdm
import multiprocessing
import pandas as pd

from astropy.io import fits

from ppxftests.run_ppxf import run_ppxf
from ppxftests.ssputils import load_ssp_templates, get_bin_edges_and_widths
from ppxftests.mockspec import create_mock_spectrum
from ppxftests.sfhutils import load_sfh, compute_mw_age, compute_lw_age, compute_sfr_thresh_age
from ppxftests.ppxf_plot import plot_sfh_mass_weighted, plot_sfh_light_weighted

import matplotlib.pyplot as plt
plt.ion()
plt.close("all")

from IPython.core.debugger import Tracer

In [5]:
###########################################################################
# Settings
###########################################################################
isochrones = "Padova"   # SSP isochrone/template set
sigma_star_kms = 250    # stellar velocity dispersion of the mock spectrum (km/s)
z = 0.01                # redshift of the mock spectrum
SNR = 100               # S/N passed to create_mock_spectrum

savefigs = False
# Output directory for figures. Avoid relying on a user-specific absolute path:
# set the PPXFTESTS_FIG_PATH environment variable to override the default.
fig_path = os.environ.get("PPXFTESTS_FIG_PATH", "/priv/meggs3/u5708159/ppxftests/figs/")

# Parallel execution
niters = 20    # number of MC realisations per SFH
nthreads = 20  # number of worker processes in the multiprocessing pool

# Load the stellar templates so we can get the age & metallicity dimensions
_, _, metallicities, ages = load_ssp_templates(isochrones)
N_ages = len(ages)
N_metallicities = len(metallicities)

In [6]:
###########################################################################
# Helper function for running MC simulations
###########################################################################
def ppxf_helper(args):
    """Run ppxf on one Monte Carlo noise realisation of a mock spectrum.

    Parameters
    ----------
    args : sequence
        ``(seed, spec, spec_err, lambda_vals_A)``, packed into a single
        argument so that this function can be mapped with
        ``multiprocessing.Pool.imap``. ``seed`` seeds the RNG used to draw the
        extra noise, ``spec_err`` is used as the per-pixel standard deviation
        of that noise, and ``lambda_vals_A`` are the wavelength values passed
        on to ``run_ppxf``.

    Returns
    -------
    The object returned by ``run_ppxf`` for the noisy spectrum.

    Notes
    -----
    Reads the notebook-level settings ``z`` and ``isochrones`` so the fit uses
    the same template set as the mock spectrum. Worker processes must be able
    to see these globals and this function. NOTE(review): this works with the
    "fork" start method; confirm behaviour on platforms defaulting to "spawn".
    """
    seed, spec, spec_err, lambda_vals_A = args

    # Add "extra" Gaussian noise to the spectrum (creates a new array, so the
    # caller's `spec` is not modified by the edge flagging below).
    rng = RandomState(seed)
    spec_noise = spec + rng.normal(scale=spec_err)

    # Mitigate the "edge effects" of the convolution with the LSF by setting
    # the first and last pixels to a sentinel value (presumably treated as bad
    # pixels by run_ppxf -- TODO confirm).
    spec_noise[0] = -9999
    spec_noise[-1] = -9999

    # Use the configured isochrones rather than a hard-coded "Padova", so the
    # fit always uses the same templates as create_mock_spectrum.
    return run_ppxf(spec=spec_noise, spec_err=spec_err, lambda_vals_A=lambda_vals_A,
                    z=z, ngascomponents=1,
                    regularisation_method="none",
                    isochrones=isochrones,
                    fit_gas=False, tie_balmer=True,
                    plotit=False, savefigs=False, interactive_mode=False)


In [34]:
###########################################################################
# Plotting helpers used by the per-galaxy diagnostic figure below
###########################################################################
def _marker_height(ax, frac, log_scale):
    """Return the y-coordinate at which to draw a row of age markers on `ax`.

    Based on the current upper y-limit ``ymax`` of `ax`: returns
    ``ymax**frac`` when `log_scale` is True, otherwise ``frac * ymax``.
    """
    ymax = ax.get_ylim()[1]
    return 10**(frac * np.log10(ymax)) if log_scale else frac * ymax


def _age_marker(ax, x, y, xerr, marker, color, label, capsize=None):
    """Draw one marker at (x, y) with a horizontal error bar of half-width `xerr`."""
    ax.errorbar(x=x, y=y, xerr=xerr, yerr=0,
                marker=marker, mfc=color, mec=color, ecolor=color,
                linestyle="none", capsize=capsize, label=label)


def _plot_mc_fits(ax, input_1D, fit_1D_list, fit_1D_mean, input_label):
    """Plot an input 1D age distribution (filled), every MC fit (thin pink) and the MC mean (red)."""
    x = range(N_ages)
    ax.fill_between(x, input_1D, step="mid", alpha=1.0, color="lightblue", label=input_label)
    for jj, fit_1D in enumerate(fit_1D_list):
        ax.step(x, fit_1D, color="pink", alpha=0.2, where="mid", linewidth=0.25,
                label="ppxf fits (MC simulations)" if jj == 0 else None)
    ax.step(x, fit_1D_mean, color="red", where="mid", label="Mean ppxf fit (MC simulations)")


def _plot_sfr_age_markers(ax, y, median, mean, std, input_idx):
    """Mark the SFR threshold age at height `y`: MC median and mean (each +/- MC std) and the input value."""
    _age_marker(ax, median, y, std, "^", "red", "SFR age (median, MC simulations)", capsize=10)
    _age_marker(ax, mean, y, std, "^", "orange", "SFR age (mean, MC simulations)", capsize=10)
    _age_marker(ax, input_idx, y, 0, "^", "lightblue", "SFR age (input)")


def _plot_mean_age_markers(ax, y, marker, weighting, median, std, input_idx, age_thresh):
    """Mark the mean `weighting`-weighted ("MW"/"LW") age below `age_thresh` (MC median +/- its own MC std, and input)."""
    desc = f"Mean {weighting} age < {age_thresh / 1e9:g} Gyr"
    _age_marker(ax, median, y, std, marker, "red", f"{desc} (MC simulations)", capsize=10)
    _age_marker(ax, input_idx, y, 0, marker, "lightblue", f"{desc} (input)")


def _decorate_age_axis(ax, ylabel, log_scale):
    """Label the x-axis with the template ages (Myr) and add the y label, legend, y scale and grid."""
    ax.set_xticks(range(N_ages))
    ax.set_xticklabels(["{:}".format(age / 1e6) for age in ages], rotation="vertical", fontsize="x-small")
    ax.set_xlabel("Age (Myr)")
    ax.autoscale(axis="x", enable=True, tight=True)
    ax.set_ylabel(ylabel)
    ax.legend(fontsize="small", loc="center left", bbox_to_anchor=(1.01, 0.5))
    if log_scale:
        ax.set_yscale("log")
    ax.grid()


# Base seed for the MC noise realisations. Galaxy `gal` draws its per-run seeds
# from RandomState(MC_BASE_SEED + gal), so re-running the notebook reproduces
# the same results.
MC_BASE_SEED = 0

for gal in range(1):
    plt.close("all")

    ###########################################################################
    # Load SFH
    ###########################################################################
    sfh_mw = load_sfh(gal, plotit=False)
    sfh_mw_1D = np.nansum(sfh_mw, axis=0)  # summed over the metallicity dimension
    bin_edges, bin_widths = get_bin_edges_and_widths(isochrones=isochrones)
    sfr_1D = sfh_mw_1D / bin_widths  # average SFR in each age bin
    M_tot = np.nansum(sfh_mw)

    ###########################################################################
    # Generate spectrum
    ###########################################################################
    spec, spec_err, lambda_vals_A = create_mock_spectrum(
        sfh_mass_weighted=sfh_mw,
        agn_continuum=False,
        isochrones=isochrones, z=z, SNR=SNR, sigma_star_kms=sigma_star_kms,
        plotit=False)

    ###########################################################################
    # Run MC simulations with ppxf
    ###########################################################################
    # Draw `niters` *distinct* seeds. Sampling with replacement (as
    # np.random.randint does) can repeat a seed, which duplicates a noise
    # realisation and silently reduces the number of independent MC runs.
    seeds = RandomState(MC_BASE_SEED + gal).choice(100 * niters, size=niters, replace=False).tolist()
    args_list = [[s, spec, spec_err, lambda_vals_A] for s in seeds]

    # Run in parallel
    print(f"Running ppxf on {nthreads} threads...")
    t = time()
    with multiprocessing.Pool(nthreads) as pool:
        pp_list = list(tqdm(pool.imap(ppxf_helper, args_list), total=niters))
    print(f"Elapsed time in ppxf: {time() - t:.2f} s")

    ###########################################################################
    # Compute mean quantities from the MC runs
    ###########################################################################
    # Light-weighted input SFH, using the template normalisations from the
    # first ppxf run (presumably identical for all runs -- TODO confirm).
    sfh_lw = sfh_mw * pp_list[0].stellar_template_norms
    sfh_lw_1D = np.nansum(sfh_lw, axis=0)

    sfh_MC_mw_list = [pp.weights_mass_weighted for pp in pp_list]
    sfh_MC_mw_1D_list = [pp.sfh_mw_1D for pp in pp_list]

    sfh_MC_lw_list = [pp.weights_light_weighted for pp in pp_list]
    sfh_MC_lw_1D_list = [pp.sfh_lw_1D for pp in pp_list]

    sfr_MC_mw_list = [pp.sfr_mean for pp in pp_list]

    # Compute the mean SFH from the MC runs (NaNs contribute zero to the sum)
    sfh_MC_mw_mean = np.nansum(np.array(sfh_MC_mw_list), axis=0) / len(sfh_MC_mw_list)
    sfh_MC_mw_1D_mean = np.nansum(sfh_MC_mw_mean, axis=0)

    sfh_MC_lw_mean = np.nansum(np.array(sfh_MC_lw_list), axis=0) / len(sfh_MC_lw_list)
    sfh_MC_lw_1D_mean = np.nansum(sfh_MC_lw_mean, axis=0)

    sfr_MC_mw_mean = np.nansum(np.array(sfr_MC_mw_list), axis=0) / len(sfr_MC_mw_list)

    # Compute the mean mass- and light-weighted ages below the age threshold
    age_thresh = 1e9  # yr
    # Fractional position of the threshold on the template-index x-axis,
    # interpolated in log(age). This matches the old linear formula for a
    # uniformly log-spaced grid, but is also correct for non-uniform grids.
    age_thresh_idx = np.interp(np.log10(age_thresh), np.log10(ages), np.arange(N_ages))

    age_mw_idx_MC_list = [compute_mw_age(sfh, age_thresh, isochrones)[1] for sfh in sfh_MC_mw_list]
    age_mw_idx_MC_median = np.nanmedian(age_mw_idx_MC_list)
    age_mw_idx_MC_std = np.nanstd(age_mw_idx_MC_list)
    age_mw_idx_input = compute_mw_age(sfh_mw, age_thresh, isochrones)[1]

    age_lw_idx_MC_list = [compute_lw_age(sfh, age_thresh, isochrones)[1] for sfh in sfh_MC_lw_list]
    age_lw_idx_MC_median = np.nanmedian(age_lw_idx_MC_list)
    age_lw_idx_MC_std = np.nanstd(age_lw_idx_MC_list)
    age_lw_idx_input = compute_lw_age(sfh_lw, age_thresh, isochrones)[1]

    # Compute the "SFR age"
    sfr_thresh = 1.0  # Msun/yr
    sfr_age_idx_MC_list = [compute_sfr_thresh_age(sfh, sfr_thresh, isochrones)[1] for sfh in sfh_MC_mw_list]
    sfr_age_idx_MC_median = np.nanmedian(sfr_age_idx_MC_list)
    sfr_age_idx_MC_mean = np.nanmean(sfr_age_idx_MC_list)
    sfr_age_idx_MC_std = np.nanstd(sfr_age_idx_MC_list)
    sfr_age_idx_input = compute_sfr_thresh_age(sfh_mw, sfr_thresh, isochrones)[1]

    ###########################################################################
    # Plot the input mass- and light-weighted SFHs
    ###########################################################################
    fig, axs = plt.subplots(nrows=5, ncols=1, figsize=(12, 25))
    fig.subplots_adjust(hspace=0.35)
    title = f"Galaxy {gal:004} " + r"- $M_{\rm tot} = %.4e\,\rm M_\odot$" % M_tot

    plot_sfh_mass_weighted(sfh_mw, ages, metallicities, ax=axs[0])
    axs[0].set_title(title)

    plot_sfh_light_weighted(sfh_lw, ages, metallicities, ax=axs[1])
    axs[1].set_title(title)

    ###########################################################################
    # Plot the mass- and light-weighted weights, summed over the metallicity dimension
    ###########################################################################
    log_scale = True
    weight_panels = [
        # (axis, weighting, input 1D SFH, MC 1D fits, MC mean, lower y-limit, y label)
        (axs[2], "Mass-weighted", sfh_mw_1D, sfh_MC_mw_1D_list, sfh_MC_mw_1D_mean,
         1e2, r"Template weight ($\rm M_\odot$)"),
        (axs[3], "Light-weighted", sfh_lw_1D, sfh_MC_lw_1D_list, sfh_MC_lw_1D_mean,
         1e35, r"Template weight ($\rm M_\odot\,erg\,s^{-1}\,Å^{-1}$)"),
    ]
    for ax, weighting, input_1D, fit_1D_list, fit_1D_mean, ymin, ylabel in weight_panels:
        ax.set_title(f"{weighting} template weights (S/N = {SNR})")

        # Plot the SFHs from each ppxf run, plus the "truth" SFH
        _plot_mc_fits(ax, input_1D, fit_1D_list, fit_1D_mean, input_label="Input SFH")
        ax.set_ylim([ymin, None])

        # Horizontal error bars for the SFR threshold age and the mean MW / LW
        # ages, each row at its own height and each with its OWN MC scatter.
        _plot_sfr_age_markers(ax, _marker_height(ax, 0.9, log_scale),
                              sfr_age_idx_MC_median, sfr_age_idx_MC_mean,
                              sfr_age_idx_MC_std, sfr_age_idx_input)
        _plot_mean_age_markers(ax, _marker_height(ax, 0.8, log_scale), "D", "MW",
                               age_mw_idx_MC_median, age_mw_idx_MC_std, age_mw_idx_input, age_thresh)
        _plot_mean_age_markers(ax, _marker_height(ax, 0.7, log_scale), "X", "LW",
                               age_lw_idx_MC_median, age_lw_idx_MC_std, age_lw_idx_input, age_thresh)

        ax.axvline(age_thresh_idx, color="black", linestyle="--", label="Age threshold")
        _decorate_age_axis(ax, ylabel, log_scale)

    ###########################################################################
    # Plot the average SFR in each bin
    ###########################################################################
    axs[4].set_title(f"Mean SFR (S/N = {SNR})")
    _plot_mc_fits(axs[4], sfr_1D, sfr_MC_mw_list, sfr_MC_mw_mean, input_label="Input mean SFR")
    _plot_sfr_age_markers(axs[4], _marker_height(axs[4], 0.9, log_scale),
                          sfr_age_idx_MC_median, sfr_age_idx_MC_mean,
                          sfr_age_idx_MC_std, sfr_age_idx_input)
    axs[4].axhline(sfr_thresh, color="black", linestyle=":", label="SFR threshold")
    _decorate_age_axis(axs[4], r"Mean SFR ($\rm M_\odot \, yr^{-1}$)", log_scale)

    # Save to file (creating the output sub-directory if needed)
    if savefigs:
        out_dir = os.path.join(fig_path, "sfr_threshold")
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f"{gal:004}.pdf"), format="pdf", bbox_inches="tight")

Running ppxf on 20 threads...


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))


Elapsed time in ppxf: 40.65 s
