## Imports

In [None]:
# Automatically reload edited modules (e.g. the apostletools helpers imported
# below) before each cell is executed.
%load_ext autoreload
%autoreload 2

# Greedy tab-completion: lets IPython evaluate expressions (e.g. dict lookups)
# to offer completions.
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy import units
from astropy.cosmology import FlatLambdaCDM

In [None]:
import os
import sys

# Make the sibling `apostletools` directory importable. Only append it once,
# so re-running this cell does not keep growing sys.path with duplicates.
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
if apt_path not in sys.path:
    sys.path.append(apt_path)

import snapshot
import dataset_comp
import curve_fit

In [None]:
import importlib

# Explicitly re-import the project modules (in this order) so that edits made
# since their first import take effect.
for module in (dataset_comp, snapshot, curve_fit):
    importlib.reload(module)

# Subhalo stellar ages

## Construct data dictionary

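Add an entry for each simulation, specifying the identifiers of its M31 and MW galaxies. The second cell (medium resolution) overwrites the `data` dictionary built by the first (low resolution), so run only the one you need: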

In [None]:
# Low-resolution (LR) runs. Each entry holds the snapshot handle, the
# identifiers of the M31 and MW galaxies, and a [primary, light] color pair.
# NOTE(review): the next cell rebinds `data` to the medium-resolution runs,
# so this dictionary is only used when that cell is skipped.
snap_id = 127
lr_runs = [
    # (label, simulation, M31 identifier, MW identifier, colors)
    ("plain-LCDM-LR", "V1_LR_fix", (1, 0), (2, 0), ["black", "gray"]),
    ("curv-p082-LR", "V1_LR_curvaton_p082_fix", (1, 0), (1, 1), ["red", "pink"]),
    ("curv-p084-LR", "V1_LR_curvaton_p084_fix", (1, 0), (1, 0), ["blue", "lightblue"]),
]
data = {
    run_label: {
        "Snapshot": snapshot.Snapshot(sim_dir, snap_id),
        "M31_ID": m31_id,
        "MW_ID": mw_id,
        "Color": colors,
    }
    for run_label, sim_dir, m31_id, mw_id, colors in lr_runs
}

In [None]:
# Medium-resolution (MR) runs: plain LCDM and the p082 curvaton variant.
# Same entry layout as the LR dictionary: snapshot handle, M31/MW
# identifiers and a [primary, light] color pair.
snap_id = 127
mr_runs = [
    # (label, simulation, M31 identifier, MW identifier, colors)
    ("plain-LCDM", "V1_MR_fix", (1, 0), (2, 0), ["black", "gray"]),
    ("curv-p082", "V1_MR_curvaton_p082_fix", (1, 0), (1, 1), ["red", "pink"]),
]
data = {
    run_label: {
        "Snapshot": snapshot.Snapshot(sim_dir, snap_id),
        "M31_ID": m31_id,
        "MW_ID": mw_id,
        "Color": colors,
    }
    for run_label, sim_dir, m31_id, mw_id, colors in mr_runs
}

In [None]:
# v_max ranges [lower, upper] used to split the subhalos (passed as
# low_lim / up_lim to dataset_comp.prune_vmax); presumably in km/s.
vmax_bin_edges = (10, 20, 40, 200)
vmax_ranges = [[low, high] for low, high in zip(vmax_bin_edges[:-1], vmax_bin_edges[1:])]

In [None]:
def star_formation_onset_lookback(snap, cosmo):
    """Return the star-formation onset of every subhalo as a lookback time [Gyr].

    The onset is the lookback time at which the subhalo's earliest star
    particle formed. ``StellarFormationTime`` holds formation scale factors,
    so the earliest star is the one with the smallest scale factor. Lookback
    time grows monotonically as the scale factor decreases, so only that one
    particle needs a (costly) cosmology evaluation, not every star particle.

    Subhalos without star particles get ``np.nan``.
    """
    sf_a = dataset_comp.group_particles_by_subhalo(
        snap, "StellarFormationTime", part_type=[4]
    )["StellarFormationTime"]

    onset = np.full(len(sf_a), np.nan)
    for idx, a in enumerate(sf_a):
        if a.size > 0:
            onset[idx] = cosmo.lookback_time(1 / np.min(a) - 1).value
    return onset


for name, sim_data in data.items():
    snap = sim_data["Snapshot"]

    H0 = snap.get_attribute("HubbleParam", "Header") * 100
    Om0 = snap.get_attribute("Omega0", "Header")
    cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)

    sim_data["StarFormationOnset"] = star_formation_onset_lookback(snap, cosmo)

    # Split into M31 satellites, MW satellites and isolated subhalos, and into
    # luminous / dark subhalos:
    mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
        snap, sim_data["M31_ID"], sim_data["MW_ID"]
    )
    mask_lum, mask_dark = dataset_comp.split_luminous(snap)

    # Selections (masking arrays), plus one v_max selection per range, keyed
    # as "<low>-<high>":
    selections = {
        'M31': mask_m31,
        'MW': mask_mw,
        'Satellite': np.logical_or(mask_m31, mask_mw),
        'Isolated': mask_isol,
        'Luminous': mask_lum,
        'Dark': mask_dark
    }
    for vmax_down, vmax_up in vmax_ranges:
        selections["{}-{}".format(vmax_down, vmax_up)] = \
            dataset_comp.prune_vmax(snap, low_lim=vmax_down, up_lim=vmax_up)
    sim_data["Selections"] = selections

In [None]:
# Sanity check: the "Luminous" selection should coincide with the subhalos
# that got a (non-NaN) star-formation onset, i.e. those with at least one
# star particle. Prints True per simulation when the two agree.
for sim_name, sim_entry in data.items():
    has_onset = ~np.isnan(sim_entry["StarFormationOnset"])
    agrees = np.all(sim_entry["Selections"]["Luminous"] == has_onset)
    print(sim_name, agrees)

In [None]:
# Alternative onset definition: age of the universe [Gyr] when the subhalo's
# first star particle formed, using only star particles whose initial mass is
# in (1e3, 1e8) Msun. Results are stored per entry of `vmax_ranges`, for
# satellites and for isolated galaxies.
# NOTE(review): this replaces the lookback-time array that the earlier cell
# stores under "StarFormationOnset"; re-run that cell before the
# lookback-time plot.

# How satellites are identified: "by_r" (distance to the hosts) or "by_gn"
# (group number).
distinction = "by_r"

for name, sim_data in data.items():
    snap = sim_data["Snapshot"]

    H0 = snap.get_attribute("HubbleParam", "Header") * 100
    Om0 = snap.get_attribute("Omega0", "Header")
    cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)

    # Star particle initial masses, converted from grams to solar masses;
    # keep only particles in the mass window:
    initial_mass = snap.get_particles("InitialMass", part_type=[4]) \
        * units.g.to(units.Msun)
    mask_mass_range = np.logical_and(initial_mass > 10**3,
                                     initial_mass < 10**8)
    grouped_data = dataset_comp.group_selected_particles_by_subhalo(
        snap, "StellarFormationTime", "InitialMass",
        selection_mask=mask_mass_range, part_type=[4])
    formation_a = grouped_data["StellarFormationTime"]

    # Age of the universe at the formation of each subhalo's first selected
    # star (smallest scale factor; age is monotonic in it). Subhalos with no
    # selected stars get NaN rather than the present-day age, which would
    # otherwise be counted in the last histogram bin.
    sf_onset = np.full(len(formation_a), np.nan)
    for idx, sft in enumerate(formation_a):
        if sft.size > 0:
            sf_onset[idx] = cosmo.age(1 / np.min(sft) - 1).value

    # Split into satellites:
    if distinction == "by_r":
        mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
            snap, sim_data["M31_ID"], sim_data["MW_ID"])
        masks_sat = [mask_m31, mask_mw]
    elif distinction == "by_gn":
        # NOTE(review): assumes this helper still returns
        # (list of satellite masks, isolated mask) -- confirm in dataset_comp.
        masks_sat, mask_isol = dataset_comp.split_satellites_by_group_number(
            snap, sim_data["M31_ID"], sim_data["MW_ID"])
    else:
        raise ValueError("Unknown distinction: {}".format(distinction))

    mask_sat = np.logical_or.reduce(masks_sat)
    print(name, np.sum(mask_sat))

    mask_lum, mask_dark = dataset_comp.split_luminous(snap)
    # Only luminous subhalos with a defined onset enter the samples:
    mask_valid = np.logical_and(mask_lum, ~np.isnan(sf_onset))
    mask_vmax = [dataset_comp.prune_vmax(snap, low_lim=down, up_lim=up)
                 for down, up in vmax_ranges]

    # One onset array per v_max range, in the order of `vmax_ranges`:
    sim_data["StarFormationOnset"] = {
        "satellites": [sf_onset[np.logical_and.reduce([mask_sat, mask_valid, mask])]
                       for mask in mask_vmax],
        "isolated": [sf_onset[np.logical_and.reduce([mask_isol, mask_valid, mask])]
                     for mask in mask_vmax],
    }

## Plot

In [None]:
# Line colors, one per simulation in the order of `data` (three-run variant).
cols = "grey pink lightblue".split()

In [None]:
# Line colors, one per simulation, in the same order as `data` (the plots
# index this list by position). Taking the light color of each entry keeps
# the list length in sync with whichever data dictionary (LR or MR) is active.
cols = [sim_data["Color"][1] for sim_data in data.values()]

In [None]:
# One line style per v_max range, in the same order as `vmax_ranges`.
linestyles = list(("solid", "dashed", "dotted"))

In [None]:
# Construct the saving location: the file name combines the satellite
# selection method and the names of all compared simulations, e.g.
# "SF_onset_by_r_plain-LCDM_curv-p082.png".
# `distinction` is only defined by the age-of-universe onset cell; the
# lookback-time selection cell splits satellites by distance, so default to
# "by_r" when it has not been set.
split_tag = globals().get("distinction", "by_r")
filename = "_".join(["SF_onset", split_tag, *data]) + ".png"

home = os.path.dirname(snapshot.__file__)
# NOTE(review): figures always go to "MediumResolution", even when the
# low-resolution data dictionary is active -- confirm this is intended.
path = os.path.join(home, "Figures", "MediumResolution")
os.makedirs(path, exist_ok=True)  # savefig does not create missing directories
filename = os.path.join(path, filename)

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(14, 6))
plt.subplots_adjust(wspace=0.3)

# Lookback time grows into the past; invert x so early times are on the right
# side of the axis origin as in the original figure.
for ax in axes:
    ax.invert_xaxis()
    ax.set_xlabel('Lookback Time [Gyr]', fontsize=16)
    # Counts normalized by the number of binned galaxies: fraction per bin.
    ax.set_ylabel('Fraction per bin', fontsize=16)

axes[0].set_title('Satellite galaxies')
axes[1].set_title('Isolated galaxies')

# Onset-time bins [Gyr]:
bin_width = 2
bins = np.arange(0, 16, bin_width)
bin_centers = (bins[:-1] + bins[1:]) / 2

# (panel axis, selection key) -- the same procedure is applied to both panels.
panels = [(axes[0], "Satellite"), (axes[1], "Isolated")]

for i, (name, sim_data) in enumerate(data.items()):
    print("\n {} \n".format(name))

    sf_onset = sim_data["StarFormationOnset"]
    selections = sim_data["Selections"]

    for ax, population in panels:
        mask = np.logical_and(selections[population], selections["Luminous"])
        for (vmax_down, vmax_up), ls in zip(vmax_ranges, linestyles):
            vran_str = "{}-{}".format(vmax_down, vmax_up)
            mask_vmax = np.logical_and(mask, selections[vran_str])
            n_subhalos = np.sum(mask_vmax)
            print("{} ({}):  {}".format(vran_str, ls, n_subhalos))

            counts, _ = np.histogram(sf_onset[mask_vmax], bins)
            if counts.sum() == 0:
                # Nothing to normalize; skip rather than plot a NaN curve.
                continue
            ax.plot(bin_centers, counts / counts.sum(), c=cols[i], linestyle=ls,
                    label="{} ({}): {} km/s".format(name, n_subhalos, vran_str))

for ax in axes:
    ax.legend()

# NOTE: `filename` is shared with the age-of-the-universe figure; give this
# figure its own name before saving it.

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(14, 6))
plt.subplots_adjust(wspace=0.3)

for ax in axes:
    ax.set_xlabel('Age of the Universe [Gyr]', fontsize=16)
    # Counts normalized by the number of binned galaxies: fraction per bin.
    ax.set_ylabel('Fraction per bin', fontsize=16)

axes[0].set_title('Satellite galaxies')
axes[1].set_title('Isolated galaxies')

# Onset-time bins [Gyr]:
bin_width = 2
bins = np.arange(0, 16, bin_width)
bin_centers = (bins[:-1] + bins[1:]) / 2

# Legend text per v_max range (raw strings: "\m" is not a valid escape);
# an upper limit >= 1e5 is treated as "no upper limit".
vmax_labels = [r"${} \mathrm{{km/s}} < v_\mathrm{{max}} < {} \mathrm{{km/s}}$"
               .format(down, up) if up < 10**5 else
               r"${} \mathrm{{km/s}} < v_\mathrm{{max}}$".format(down)
               for down, up in vmax_ranges]

# (panel axis, key in sim_data["StarFormationOnset"]); each key holds one
# onset array per entry of `vmax_ranges`.
panels = [(axes[0], "satellites"), (axes[1], "isolated")]

for i, (name, sim_data) in enumerate(data.items()):
    sf_onset = sim_data["StarFormationOnset"]
    for ax, population in panels:
        # `linestyles` has one entry per v_max range, so no range is dropped.
        for ages, vmax_label, ls in zip(sf_onset[population], vmax_labels, linestyles):
            n_subhalos = ages.size
            print(name, population, n_subhalos)

            counts, _ = np.histogram(ages, bins)
            if counts.sum() == 0:
                # Nothing to normalize; skip rather than plot a NaN curve.
                continue
            ax.plot(bin_centers, counts / counts.sum(), c=cols[i], linestyle=ls,
                    label="{} ({}): {}".format(name, n_subhalos, vmax_label))

for ax in axes:
    ax.legend()

fig.savefig(filename, dpi=200)

In [None]:
# This cell was a verbatim copy of the previous plotting cell: it drew the
# same age-of-the-universe figure and saved it to the same `filename`. The
# duplicate was removed; re-run the previous cell to regenerate the figure.