In [None]:
# Load the autoreload extension and switch it to mode 2:
%load_ext autoreload
%autoreload 2

# Turn on IPython's greedy tab-completion option:
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from astropy import units
from pathlib import Path
import os
import time
from astropy.cosmology import FlatLambdaCDM

import snapshot_obj
import dataset_compute
import curve_fit

import importlib

In [None]:
# Re-import the project-local modules so that this kernel session uses
# their current source:
for project_module in (snapshot_obj, dataset_compute):
    importlib.reload(project_module)

# Kept as the final expression of the cell, so the cell output is unchanged:
importlib.reload(curve_fit)

# Subhalo stellar ages

## Construct data dictionary

Add entries for each simulation, and specify M31 and MW galaxies:

In [None]:
snap_id = 127

# One row per simulation:
#   (short name, simulation id, M31 identifier, MW identifier)
# Keeping the values that belong to one simulation on a single row prevents
# the parallel lists below from getting out of step.
# NOTE: the next cell reassigns all of these names with a two-simulation
# set-up; whichever of the two cells was run last wins.
simulation_table = [
    ("LCDM", "V1_LR_fix", (1, 0), (2, 0)),
    ("p082", "V1_LR_curvaton_p082_fix", (1, 0), (1, 1)),
    ("p084", "V1_LR_curvaton_p084_fix", (1, 0), (1, 1)),
]

# Unpack the rows into the parallel lists used by the following cells:
names, sim_ids, m31, mw = (list(column) for column in zip(*simulation_table))

In [None]:
snap_id = 127

# One row per simulation:
#   (short name, simulation id, M31 identifier, MW identifier)
# This two-simulation set-up reassigns the names defined in the previous
# cell (snap_id, sim_ids, names, m31, mw).
simulation_table = [
    ("LCDM", "V1_LR_fix", (1, 0), (2, 0)),
    ("p082", "V1_LR_curvaton_p082_fix", (1, 0), (1, 1)),
]

# Unpack the rows into the parallel lists used by the following cells:
names, sim_ids, m31, mw = (list(column) for column in zip(*simulation_table))

# One (empty) path entry per simulation; not read by any cell visible in
# this notebook.
paths = [""] * len(simulation_table)

In [None]:
# Build the per-simulation data dictionary, keyed by the short simulation
# name. Each entry holds the snapshot object together with the identifiers
# of the M31 and MW galaxies in that simulation.
data = {
    name: {
        "snapshot": snapshot_obj.Snapshot(sim_id, snap_id, name=name),
        "M31_identifier": m31_ns,
        "MW_identifier": mw_ns,
    }
    for name, sim_id, m31_ns, mw_ns in zip(names, sim_ids, m31, mw)
}

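Choose how to distinguish between satellite and isolated galaxies (`"by_r"`: by distance, `"by_gn"`: by group number), and set the $v_\mathrm{max}$ ranges (in km/s) used to bin the subhalos: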

In [None]:
# How satellites are told apart from isolated galaxies:
#   "by_r"  -> dataset_compute.split_satellites_by_distance
#   "by_gn" -> dataset_compute.split_satellites_by_group_number
distinction = "by_gn"

# v_max selection ranges as (lower, upper) pairs. An upper limit of 10**5 is
# treated as "no upper limit" when the plot labels are built.
vmax_ranges = [(20, 30), (30, 10**5)]
vmax_down_lim = [lower for lower, _upper in vmax_ranges]
vmax_up_lim = [upper for _lower, upper in vmax_ranges]

In [None]:
# Compute, for every simulation in `data`, the onset of star formation of
# each subhalo (the age of the universe at which its first selected star
# particle formed), and store it split into satellites / isolated galaxies
# and into the configured v_max ranges under
# data[name]["StarFormationOnset"].

# Validate the splitting mode once, before any snapshot data is read. Without
# this check an unknown value would either raise a NameError further down or,
# worse, silently reuse masks left over in the kernel from an earlier run.
if distinction not in ("by_r", "by_gn"):
    raise ValueError(
        "Unknown distinction {!r}: expected 'by_r' or 'by_gn'".format(
            distinction))

for name, sim_data in data.items():
    snap = sim_data["snapshot"]
    
    # Cosmology of this simulation, built from the snapshot header:
    H0 = snap.get_attribute("HubbleParam", "Header") * 100
    Om0 = snap.get_attribute("Omega0", "Header")
    cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)
    
    # Read star particle formation times for star particles of 
    # each subhalo. Select only star particles with masses in the
    # given range.
    # NOTE(review): the conversion factor assumes get_particles returns
    # masses in grams -- confirm against snapshot_obj.
    initial_mass = snap.get_particles("InitialMass", part_type=[4]) \
                      * units.g.to(units.Msun)
    mask_mass_range = np.logical_and(initial_mass > 10**3, 
                                     initial_mass < 10**8)
    grouped_data = dataset_compute.group_selected_particles_by_subhalo(
        snap, "StellarFormationTime", "InitialMass", selection_mask=mask_mass_range, 
        part_type=[4])
    
    # Sanity output: total number of selected star particles, followed by
    # the particle counts of the first ten subhalos.
    print(np.concatenate(grouped_data["StellarFormationTime"]).size)
    for sft in grouped_data["StellarFormationTime"][:10]:
        print("    ", sft.size)
    
    # Convert formation time scale factor to age of the universe
    # (redshift z = 1/a - 1), and for each subhalo, find the formation time
    # of its first star particle. Subhalos without any selected star
    # particle get the present-day age of the universe as a placeholder.
    star_form_time = [cosmo.age(1/sft - 1).value 
                      for sft in grouped_data["StellarFormationTime"]]
    sf_onset = np.array([np.min(subhalo_sft) if subhalo_sft.size > 0 else 
                         cosmo.age(0).value for subhalo_sft in star_form_time])
        
    # Split into satellites and isolated galaxies:
    if distinction == "by_r":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    else:  # "by_gn" (validated above)
        masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    
    # A subhalo counts as a satellite if it is selected by any of the
    # satellite masks. Computed once here, as it is the same for every
    # v_max range.
    mask_any_sat = np.logical_or.reduce(masks_sat)
    print(name, np.sum(mask_any_sat))
    
    mask_lum, mask_dark = dataset_compute.split_luminous(snap)
    mask_vmax = [dataset_compute.prune_vmax(snap, low_lim=down, up_lim=up) 
                 for down, up in zip(vmax_down_lim, vmax_up_lim)]
    
    # Add to the data dictionary: one array of onset times per v_max range,
    # restricted to luminous subhalos, for satellites and isolated galaxies.
    sim_data["StarFormationOnset"] = {
        "satellites": [
            sf_onset[np.logical_and.reduce([mask_any_sat, mask_lum, mask])]
            for mask in mask_vmax],
        "isolated": [
            sf_onset[np.logical_and.reduce([mask_isol, mask_lum, mask])]
            for mask in mask_vmax],
    }

## Plot

In [None]:
# Line colours for the plot, picked by the position of each simulation in
# `data` (three-simulation variant):
cols = ["grey", "pink", "lightblue"]

In [None]:
# Line colours for the plot, picked by the position of each simulation in
# `data` (two-simulation variant; reassigns `cols`):
cols = ["grey", "pink"]

In [None]:
# Plot the distribution of star formation onset times: satellites on the
# left panel, isolated galaxies on the right. One colour per simulation,
# one line style per v_max range.
fig, axes = plt.subplots(ncols=2, figsize=(14,6))
plt.subplots_adjust(wspace=0.3)

# Set axis:
for ax in axes:
    ax.set_xlabel('Age of the Universe [Gyr]', fontsize=16)
    ax.set_ylabel('Number density', fontsize=16)

axes[0].set_title('Satellite galaxies')
axes[1].set_title('Isolated galaxies')

# Set bins (the curves are drawn through the bin centres):
bin_width = 2
bins = np.arange(0, 16, bin_width)
bin_centres = (bins[:-1] + bins[1:]) / 2

# Labels and line styles of the v_max selections. They do not depend on the
# simulation, so they are built once. Raw strings keep the LaTeX backslashes
# without relying on invalid escape sequences.
labels = [r"${} \mathrm{{km/s}} < v_\mathrm{{max}} < {} \mathrm{{km/s}}$"
          .format(down, up) if up < 10**5 else 
          r"${} \mathrm{{km/s}} < v_\mathrm{{max}}$".format(down)
          for down, up in zip(vmax_down_lim, vmax_up_lim)]
line_styles = ['-', '--']

# Each panel shows one population stored under "StarFormationOnset":
panels = [(axes[0], "satellites"), (axes[1], "isolated")]

# Iterate over simulations:
for sim_idx, (name, sim_data) in enumerate(data.items()):
    colour = cols[sim_idx]
    
    for ax, population in panels:
        sf_onset = sim_data["StarFormationOnset"][population]
        
        # Iterate over v_max selections:
        for age, label, lstyle in zip(sf_onset, labels, line_styles):
            n_subhalos = age.size
            label = "{} ({}): ".format(name, n_subhalos) + label
            print(name, n_subhalos)
            
            # Fraction of the binned subhalos that falls into each bin
            # (equal to the normalised density times the bin width). An
            # empty selection gives a flat zero line instead of NaNs.
            counts = np.histogram(age, bins)[0]
            n_binned = counts.sum()
            if n_binned > 0:
                y = counts / n_binned
            else:
                y = np.zeros(counts.size)
            ax.plot(bin_centres, y, c=colour, label=label, linestyle=lstyle)

for ax in axes:
    ax.legend()