# Test: what is the best way to quantify the starburst age?
---

In this notebook:
* come up with a few different ways to quantify the "starburst age".
* write necessary functions to compute these quantities given a star formation history.
* Using MC runs (and later on regularisation too), see whether any of these quantities can reliably be measured.

Ideas:
* As a control: previous method - time index at which most recent star formation event drops to 0.
    * *potential issue*: definitely unreliable, given the "noise" in the recovered SFH that persists at even very high S/N.
* Mass weighted age below some age threshold.
    * *potential issue*: choice of age threshold is somewhat arbitrary.
* Counting backwards from $t = 0$, the time index at which the galaxy has built up $X \%$ of its total stellar mass. 
    * *potential issue*: may not be a good proxy for the precise quantity we're looking for.
* In the *most recent* star formation event, find the earliest time index at which the mass in each bin exceeds some minimum value, e.g., 1e7 solar masses or 0.01% of the total stellar mass.
    * *potential issue*: bins are not linearly spaced in age - how to deal with this?
* In the *most recent* star formation event, find the earliest time index at which the SFR exceeds some minimum value, e.g. $1 \rm \, M_\odot \, yr^{-1}$.
    * *potential issue*: this may not represent the actual SFR in the galaxy, if the new stars were obtained via a merger, for instance. 


In [6]:
%matplotlib widget

In [2]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:75% !important; }</style>"))
display(HTML("<style>.output_result { max-width:75% !important; }</style>"))

In [102]:
import numpy as np
from numpy.random import RandomState
from time import time 
from tqdm.notebook import tqdm
import multiprocessing
import pandas as pd

from astropy.io import fits

from ppxftests.run_ppxf import run_ppxf
from ppxftests.ssputils import load_ssp_templates, get_bin_edges_and_widths
from ppxftests.mockspec import create_mock_spectrum
from ppxftests.sfhutils import load_sfh, compute_mw_age, compute_sfr_thresh_age
from ppxftests.ppxf_plot import plot_sfh_mass_weighted

import matplotlib.pyplot as plt
plt.ion()
plt.close("all")

from IPython.core.debugger import Tracer

In [8]:
###########################################################################
# Settings
###########################################################################
isochrones = "Padova"
sigma_star_kms = 250
z = 0.01

# Load the stellar templates so we can get the age & metallicity dimensions
_, _, metallicities, ages = load_ssp_templates(isochrones)
N_ages = len(ages)
N_metallicities = len(metallicities)


In [9]:
###########################################################################
# Definition 1: time index at which most recent star formation event drops 
# to 0
###########################################################################
def compute_sb_zero_age(sfh_mw):
    # Sum the SFH over the metallicity dimension to get the 1D SFH
    sfh_mw_1D = np.nansum(sfh_mw, axis=0) if sfh_mw.ndim > 1 else sfh_mw
        
    first_nonzero_idx = np.argwhere(sfh_mw_1D > 0)[0][0]
    if np.any(sfh_mw_1D[first_nonzero_idx:] == 0):
        first_zero_idx = np.argwhere(sfh_mw_1D[first_nonzero_idx:] == 0)[0][0] + first_nonzero_idx
        return ages[first_zero_idx], first_zero_idx
    else:
        return np.nan, np.nan

In [36]:
###########################################################################
# Definition 2: Mass weighted age below some age threshold
###########################################################################
def compute_mw_age(sfh_mw, 
                   age_thresh=1e9):
    # Sum the SFH over the metallicity dimension to get the 1D SFH
    sfh_mw_1D = np.nansum(sfh_mw, axis=0) if sfh_mw.ndim > 1 else sfh_mw

    # Find the index of the threshold age in the template age array
    age_thresh_idx = np.nanargmin(np.abs(ages - age_thresh))
    
    # Compute the mass-weighted age 
    log_age_mw = np.nansum(sfh_mw_1D[:age_thresh_idx] * np.log10(ages[:age_thresh_idx])) / np.nansum(sfh_mw_1D[:age_thresh_idx])
    
    # Compute the corresponding index in the array (useful for plotting)
    log_age_mw_idx = (log_age_mw - np.log10(ages[0])) / (np.log10(ages[1]) - np.log10(ages[0]))
    
    return 10**log_age_mw, log_age_mw_idx


In [11]:
###########################################################################
# Definition 3: Counting backwards from $t = 0$, the time index at which 
# the galaxy has built up $X \%$ of its total stellar mass. 
###########################################################################
def compute_first_massive_bin(sfh_mw,
                              mass_frac_thresh=0.01):
    # Sum the SFH over the metallicity dimension to get the 1D SFH
    sfh_mw_1D = np.nansum(sfh_mw, axis=0) if sfh_mw.ndim > 1 else sfh_mw

    M_tot = np.nansum(sfh_mw_1D) 
    age_idx = np.argwhere(sfh_mw_1D / M_tot > mass_frac_thresh)[0][0]
        
    return ages[age_idx], age_idx


In [12]:
###########################################################################
# Definition 4: In the *most recent* star formation event, find the earliest 
# time index at which the mass in each bin exceeds 
# some minimum value, e.g., 1e7 solar masses or 0.01% of the total stellar mass.
###########################################################################
def compute_cumulative_mass_frac(sfh_mw,
                                 mass_frac_thresh=0.01):
    # Sum the SFH over the metallicity dimension to get the 1D SFH
    sfh_mw_1D = np.nansum(sfh_mw, axis=0) if sfh_mw.ndim > 1 else sfh_mw

    M_tot = np.nansum(sfh_mw_1D) 
    cumsum_M_frac = np.cumsum(sfh_mw_1D) / M_tot
    age_idx = np.nanargmin(np.abs(cumsum_M_frac - mass_frac_thresh))
    
    return ages[age_idx], age_idx


In [13]:
###########################################################################
# Definition 5: counting backwards from t = 0, find the most recent time 
# when the SFR exceeded a given threshold
###########################################################################
def compute_sfr_thresh_age(sfh_mw, sfr_thresh, isochrones):
    
    # Sum the SFH over the metallicity dimension to get the 1D SFH
    sfh_mw_1D = np.nansum(sfh_mw, axis=0) if sfh_mw.ndim > 1 else sfh_mw

    # Compute the bin edges and widths so that we can compute the mean SFR in each bin
    bin_edges, bin_widths = get_bin_edges_and_widths(isochrones=isochrones)
    
    # Compute the mean SFR in each bin
    sfr_avg = sfh_mw_1D / bin_sizes
    
    # Find the first index where the SFR exceed a certain value
    if np.any(sfr_avg > sfr_thresh):
        age_idx = np.argwhere(sfr_avg > sfr_thresh)[0][0]
        return ages[age_idx], age_idx
    else:
        return np.nan, np.nan
    

In [68]:
####################################################################
# Plot to check each of these have worked
####################################################################
sfh = load_sfh(41, plotit=True)
sfh_1D = np.nansum(sfh, axis=0)
M_tot = np.nansum(sfh)

# Parameters for age estimators
mass_frac_thresh = 1e-1
age_thresh = 1e8
sfr_thresh = 1

age_1, age_idx_1 = compute_sb_zero_age(sfh)  # "Starburst" age
age_2, age_idx_2 = compute_mw_age(sfh, age_thresh, isochrones)  # Mass-weighted age
age_3, age_idx_3 = compute_first_massive_bin(sfh, mass_frac_thresh)  # Mass fraction threshold (instantaneous)
age_4, age_idx_4 = compute_cumulative_mass_frac(sfh, mass_frac_thresh)  # Mass fraction threshold (cumulative)
age_5, age_idx_5 = compute_sfr_thresh_age(sfh, sfr_thresh)  # SFR threshold

print(f"compute_sb_zero_age(): {age_1 / 1e6:.2f} Myr")
print(f"compute_mw_age(): {age_2 / 1e6:.2f} Myr")
print(f"compute_first_massive_bin(): {age_3 / 1e6:.2f} Myr")
print(f"compute_cumulative_mass_frac(): {age_4 / 1e6:.2f} Myr")
print(f"compute_sfr_thresh_age(): {age_5 / 1e6:.2f} Myr")

####################################################################
# Plot the SFH and the cumulative mass counting from t = 0
####################################################################
fig, ax = plt.subplots(figsize=(10, 4))
fig.subplots_adjust(bottom=0.3, top=0.9)

# Plot the SFH and the cumulative mass from t = 0
ax.step(x=range(N_ages), y=sfh_1D / M_tot, color="black", where="mid", label="1D SFH")
ax.step(x=range(N_ages), y=np.cumsum(sfh_1D) / M_tot, where="mid", label="Cumulative mass", linewidth=0.5)

# Indicate each SB age measure
ax.axvline(age_idx_1, color="grey", label="Starburst age")
ax.axvline(age_idx_2, color="blue", label=f"Mass-weighted mean age (< {age_thresh / 1e6:.0f} Myr)")
ax.axvline(age_idx_3, color="red", label="Mass fraction (instantaneous)")
ax.axvline(age_idx_4, color="orange", label="Mass fraction (cumulative)")
ax.axvline(age_idx_5, color="green", label="SFR")

ax.axhline(mass_frac_thresh, linestyle="--", color="grey", label="Mass fraction threshold")

ax.set_yscale("log")
ax.set_xticks(range(N_ages))
ax.set_xticklabels(ages / 1e6, rotation="vertical", fontsize="x-small")
ax.grid()
ax.legend(fontsize="x-small")
ax.autoscale(axis="x", tight=True, enable=True)
ax.set_ylabel(r"Stellar mass fraction $M/M_{\rm tot}$")
ax.set_xlabel("Bin age (Myr)")

####################################################################
# Plot the mean SFR in each bin
####################################################################
# Compute the bin edges and widths so that we can compute the mean SFR in each bin
bin_widths = np.zeros(len(ages))
bin_edges = np.zeros(len(ages) + 1)
for aa in range(1, len(ages)):
    bin_edges[aa] = 10**(0.5 * (np.log10(ages[aa - 1]) + np.log10(ages[aa])) )
delta_log_age = np.diff(np.log10(ages))[0]
age_0 = 10**(np.log10(ages[0]) - delta_log_age)
age_last = 10**(np.log10(ages[-1]) + delta_log_age)
bin_edges[0] = 10**(0.5 * (np.log10(age_0) + np.log10(ages[0])) )
bin_edges[-1] = 10**(0.5 * (np.log10(ages[-1]) + np.log10(age_last)))
bin_sizes = np.diff(bin_edges)
sfr_avg = sfh_1D / bin_sizes

# Plot the mean SFR in each bin
fig, ax = plt.subplots(figsize=(10, 4))
fig.subplots_adjust(bottom=0.3, top=0.9)
ax.step(x=range(N_ages), y=sfr_avg, where="mid", label="Average SFR")

# Indicate each SB age measure
ax.axvline(age_idx_1, color="grey", label="Starburst age")
ax.axvline(age_idx_2, color="blue", label=f"Mass-weighted mean age (< {age_thresh / 1e6:.0f} Myr)")
ax.axvline(age_idx_3, color="red", label="Mass fraction (instantaneous)")
ax.axvline(age_idx_4, color="orange", label="Mass fraction (cumulative)")
ax.axvline(age_idx_5, color="green", label="SFR")

ax.axhline(sfr_thresh, linestyle="--", color="grey", label="SFR threshold")

ax.set_yscale("log")
ax.set_xticks(range(N_ages))
ax.set_xticklabels(ages / 1e6, rotation="vertical", fontsize="x-small")
ax.grid()
ax.legend(fontsize="x-small")
ax.autoscale(axis="x", tight=True, enable=True)
ax.set_ylabel(r"Mean SFR ($\rm M_\odot \, yr^{-1}$)")
ax.set_xlabel("Bin age (Myr)")

####################################################################
# Plot the mass-weighted mean age as a function of age threshold
####################################################################
mw_mean_ages = [compute_mw_age(sfh, age)[0] for age in ages]
fig, ax = plt.subplots()
ax.scatter(ages, mw_mean_ages)
ax.set_xlabel("Age threshold (Myr)")
ax.set_ylabel("Mass-weighted mean age (Myr)")
ax.set_xscale("log")
ax.set_yscale("log")


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

TypeError: compute_mw_age() missing 1 required positional argument: 'isochrones'

## Can ppxf accurately return the "SFR age"?
---
Run ppxf on a few different input SFHs, using the MC method. 
Plot the SFH with the "SFR ages" indicated for both the input and the output.

In [85]:
###########################################################################
# Settings
###########################################################################
isochrones = "Padova"
sigma_star_kms = 250
z = 0.01
SNR = 100

niters = 1000
nthreads = 20

# Load the stellar templates so we can get the age & metallicity dimensions
_, _, metallicities, ages = load_ssp_templates(isochrones)
N_ages = len(ages)
N_metallicities = len(metallicities)


In [86]:
###########################################################################
# Helper function for running MC simulations
###########################################################################
def ppxf_helper(args):
    # Unpack arguments
    seed, spec, spec_err, lambda_vals_A = args
    
    # Add "extra" noise to the spectrum
    rng = RandomState(seed)
    noise = rng.normal(scale=spec_err)
    spec_noise = spec + noise

    # This is to mitigate the "edge effects" of the convolution with the LSF
    spec_noise[0] = -9999
    spec_noise[-1] = -9999

    # Run ppxf
    pp = run_ppxf(spec=spec_noise, spec_err=spec_err, lambda_vals_A=lambda_vals_A,
                  z=z, ngascomponents=1,
                  regularisation_method="none", 
                  isochrones="Padova",
                  fit_gas=False, tie_balmer=True,
                  plotit=False, savefigs=False, interactive_mode=False)
    return pp


In [87]:
###########################################################################
# Load SFH
###########################################################################
sfh_mw = load_sfh(41, plotit=True)
sfh_mw_1D = np.nansum(sfh_mw, axis=0)
bin_edges, bin_widths = get_bin_edges_and_widths(isochrones=isochrones)
sfr_mw_1D = sfh_mw_1D / bin_widths
M_tot = np.nansum(sfh_mw)

###########################################################################
# Generate spectrum
###########################################################################
spec, spec_err, lambda_vals_A = create_mock_spectrum(
    sfh_mass_weighted=sfh_mw,
    agn_continuum=False,
    isochrones=isochrones, z=z, SNR=SNR, sigma_star_kms=sigma_star_kms,
    plotit=True)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(fig_w, fig_h))


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(fig_w, fig_h))


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  fig = plt.figure(figsize=(10, 3.5))


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [88]:
plt.close("all")

In [89]:
###########################################################################
# Run MC simulations with ppxf
###########################################################################
# Input arguments
seeds = list(np.random.randint(low=0, high=100 * niters, size=niters))
args_list = [[s, spec, spec_err, lambda_vals_A] for s in seeds]

# Run in parallel
print(f"Running ppxf on {nthreads} threads...")
t = time()
with multiprocessing.Pool(nthreads) as pool:
    pp_list = list(tqdm(pool.imap(ppxf_helper, args_list), total=niters))
print(f"Elapsed time in ppxf: {time() - t:.2f} s")

Running ppxf on 20 threads...


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


Elapsed time in ppxf: 2465.69 s


In [114]:
###########################################################################
# Compute mean quantities from the MC runs
###########################################################################
sfh_MC_mw_list = [pp.weights_mass_weighted for pp in pp_list]
sfh_MC_mw_1D_list = [pp.sfh_mw_1D for pp in pp_list]
sfr_MC_mw_list = [pp.sfr_mean for pp in pp_list]

# Compute the mean SFH from the MC runs
sfh_MC_mw_mean = np.nansum(np.array(sfh_MC_mw_list), axis=0) / len(sfh_MC_mw_list)
sfh_MC_mw_1D_mean = np.nansum(sfh_MC_mw_mean, axis=0)
sfr_MC_mw_mean = np.nansum(np.array(sfr_MC_mw_list), axis=0) / len(sfr_MC_mw_list)

# Compute the mean mass-weighted age below 1 Gyr
age_thresh = 1e9  # yr
age_mw_idx_MC_list = [compute_mw_age(sfh, age_thresh, isochrones)[1] for sfh in sfh_MC_mw_list]
age_mw_idx_MC_median = np.nanmedian(age_mw_idx_MC_list)
age_mw_idx_MC_std = np.nanstd(age_mw_idx_MC_list)
age_mw_idx_input = compute_mw_age(sfh_mw, age_thresh, isochrones)[1]

# Compute the "SFR age"
sfr_thresh = 1.0  # Msun/yr
sfr_age_idx_MC_list = [compute_sfr_thresh_age(sfh, sfr_thresh, isochrones)[1] for sfh in sfh_MC_mw_list]
sfr_age_idx_MC_median = np.nanmedian(sfr_age_idx_MC_list)
sfr_age_idx_MC_mean = np.nanmean(sfr_age_idx_MC_list)
sfr_age_idx_MC_std = np.nanstd(sfr_age_idx_MC_list)
sfr_age_idx_input = compute_sfr_thresh_age(sfh_mw, sfr_thresh, isochrones)[1]

In [125]:
sfr_age_idx_input

16

In [121]:
np.nanmedian(sfr_age_idx_MC_list)

12.0

In [124]:
###########################################################################
# Plot the mass-weighted weights, summed over the metallicity dimension
###########################################################################
fig = plt.figure(figsize=(13, 4))
ax = fig.add_axes([0.1, 0.2, 0.7, 0.7])
ax.set_title(f"Mass-weighted template weights (S/N = {SNR})")

# Plot the SFHs from each ppxf run, plus the "truth" SFH
ax.fill_between(range(N_ages), sfh_mw_1D, step="mid", alpha=1.0, color="lightblue", label="Input SFH")
for jj in range(niters):
    ax.step(range(N_ages), sfh_MC_mw_1D_list[jj], color="pink", alpha=0.2, where="mid", linewidth=0.25, label="ppxf fits (MC simluations)" if jj == 0 else None)
ax.step(range(N_ages), sfh_MC_mw_1D_mean, color="red", where="mid", label="Mean ppxf fit (MC simulations)")
# ax.step(range(N_ages), sfh_fit_mw_1D_regul, color="lightgreen", where="mid", label="ppxf fit (regularised)")

# Plot horizontal error bars indicating the SFR threshold age from the MC simulations
y = 10**(0.9 * np.log10(ax.get_ylim()[1])) if log_scale else 0.9 * ax.get_ylim()[1]
ax.errorbar(x=sfr_age_idx_MC_median, y=y, xerr=sfr_age_idx_MC_std, yerr=0, 
            marker="^", mfc="red", mec="red", ecolor="red", linestyle="none",
            label="SFR age (median, MC simulations)")
ax.errorbar(x=sfr_age_idx_MC_mean, y=y, xerr=sfr_age_idx_MC_std, yerr=0, 
            marker="^", mfc="orange", mec="orange", ecolor="orange", linestyle="none",
            label="SFR age (mean, MC simulations)")
ax.errorbar(x=sfr_age_idx_input, y=y, xerr=0, yerr=0, 
            marker="^", mfc="lightblue", mec="lightblue", ecolor="lightblue", linestyle="none", 
            label="SFR age (input)")

# Plot horizontal error bars indicating the mean mass-weighted age from the MC simulations
y = 10**(0.8 * np.log10(ax.get_ylim()[1])) if log_scale else 0.8 * ax.get_ylim()[1]
ax.errorbar(x=age_mw_idx_MC_median, y=y, xerr=sfr_age_idx_MC_std, yerr=0, 
            marker="D", mfc="red", mec="red", ecolor="red", linestyle="none",
            label="Mean MW age < 1 Gyr (MC simulations)")
ax.errorbar(x=age_mw_idx_input, y=y, xerr=0, yerr=0, 
            marker="D", mfc="lightblue", mec="lightblue", ecolor="lightblue", linestyle="none", 
            label="Mean MW age < 1 Gyr (input)")

# Decorations 
ax.set_xticks(range(N_ages))
ax.set_xlabel("Age (Myr)")
ax.set_xticklabels(["{:}".format(age / 1e6) for age in ages], rotation="vertical", fontsize="x-small")
ax.autoscale(axis="x", enable=True, tight=True)
ax.set_ylim([1, None])
ax.set_ylabel(r"Template weight ($\rm M_\odot$)")
ax.legend(fontsize="x-small", loc="center left", bbox_to_anchor=(1.01, 0.5))
ax.set_xlabel("Age (Myr)")
ax.set_yscale("log") if log_scale else None

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [123]:
###########################################################################
# Plot the average SFR in each bin
###########################################################################
fig = plt.figure(figsize=(13, 4))
ax = fig.add_axes([0.1, 0.2, 0.7, 0.7])
ax.set_title(f"Mean SFR (S/N = {SNR})")

# Plot the SFHs from each ppxf run, plus the "truth" SFH
ax.fill_between(range(N_ages), sfr_mw_1D, step="mid", alpha=1.0, color="lightblue", label="Input mean SFR")
for jj in range(niters):
    ax.step(range(N_ages), sfr_MC_mw_list[jj], color="pink", alpha=0.2, where="mid", linewidth=0.25, label="ppxf fits (MC simluations)" if jj == 0 else None)
ax.step(range(N_ages), sfr_MC_mw_mean, color="red", where="mid", label="Mean ppxf fit (MC simulations)")

# Plot horizontal error bars indicating the SFR threshold age from the MC simulations
y = 10**(0.9 * np.log10(ax.get_ylim()[1])) if log_scale else 0.9 * ax.get_ylim()[1]
ax.errorbar(x=sfr_age_idx_MC_median, y=y, xerr=sfr_age_idx_MC_std, yerr=0, 
            marker="^", mfc="red", mec="red", ecolor="red", linestyle="none",
            label="SFR age (median, MC simulations)")
ax.errorbar(x=sfr_age_idx_MC_mean, y=y, xerr=sfr_age_idx_MC_std, yerr=0, 
            marker="^", mfc="orange", mec="orange", ecolor="orange", linestyle="none",
            label="SFR age (mean, MC simulations)")
ax.errorbar(x=sfr_age_idx_input, y=y, xerr=0, yerr=0, 
            marker="D", mfc="lightblue", mec="lightblue", ecolor="lightblue", linestyle="none",
            label="SFR age (input)")

# Decorations 
ax.set_xticks(range(N_ages))
ax.set_xlabel("Age (Myr)")
ax.set_xticklabels(["{:}".format(age / 1e6) for age in ages], rotation="vertical", fontsize="x-small")
ax.autoscale(axis="x", enable=True, tight=True)
ax.set_ylabel(r"Mean SFR ($\rm M_\odot \, yr^{-1}$")
ax.legend(fontsize="x-small", loc="center left", bbox_to_anchor=(1.01, 0.5))
ax.set_xlabel("Age (Myr)")
ax.set_yscale("log") if log_scale else None

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …