In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value
import importlib

import simulation
import snapshot_obj
import simulation_tracing
import dataset_compute
import subhalo

In [None]:
# Explicitly re-import these project modules so that edits to them take
# effect (normally already handled by `%autoreload 2`).
for module_to_refresh in (snapshot_obj, dataset_compute):
    importlib.reload(module_to_refresh)

# Subhalo stellar ages

## Construct data dictionary

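Specify the simulations to compare and, for each one, the identifiers of its M31 and MW galaxies. Each configuration cell below overwrites the variables set by the previous one. Only the last cell executed takes effect; under *Run All* that is the `V1_MR` configuration.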

In [None]:
# Configuration: the V1_LR LCDM run and the two V1_LR curvaton runs.
# The lists are parallel: entry k of each list refers to the same simulation.
snap_id = 127

names = ["LCDM", "p082", "p084"]
sim_ids = [
    "V1_LR_fix",
    "V1_LR_curvaton_p082_fix",
    "V1_LR_curvaton_p084_fix",
]

# Identifiers of the M31 and MW galaxies, one tuple per simulation:
m31 = [(1, 0)] * 3
mw = [(2, 0), (1, 1), (1, 1)]

In [None]:
# Configuration: the V1_LR LCDM run vs. the V1_LR curvaton p082 run.
# Overwrites any configuration set by an earlier cell.
snap_id = 127
names, sim_ids = ["LCDM", "p082"], ["V1_LR_fix", "V1_LR_curvaton_p082_fix"]
paths = [""] * 2  # NOTE(review): not read by any later cell visible here

# Identifiers of the M31 and MW galaxies, one tuple per simulation
# (same order as `names`):
m31 = [(1, 0)] * 2
mw = [(2, 0), (1, 1)]

In [None]:
# Configuration: the V1_MR LCDM run vs. the V1_MR curvaton p082 run.
# As the last configuration cell, this is the one in effect after Run All.
snap_id = 127
names, sim_ids = ["LCDM", "p082"], ["V1_MR_fix", "V1_MR_curvaton_p082_fix"]
paths = [""] * 2  # NOTE(review): not read by any later cell visible here

# Identifiers of the M31 and MW galaxies, one tuple per simulation
# (same order as `names`):
m31 = [(1, 0)] * 2
mw = [(2, 0), (1, 1)]

In [None]:
# One entry per simulation, keyed by its short name: the loaded snapshot plus
# the M31 / MW identifiers from the configuration cell.
data = {
    name: {
        "snapshot": snapshot_obj.Snapshot(sim_id, snap_id, name=name),
        "M31_identifier": m31_ns,
        "MW_identifier": mw_ns,
    }
    for name, sim_id, m31_ns, mw_ns in zip(names, sim_ids, m31, mw)
}

In [None]:
# Reference cosmology, read from the header of the LCDM snapshot.
# "HubbleParam" is stored as h; astropy takes H0 in km/s/Mpc.
Om0 = data["LCDM"]["snapshot"].get_attribute("Omega0", "Header")
H0 = 100 * data["LCDM"]["snapshot"].get_attribute("HubbleParam", "Header")
cosmo = FlatLambdaCDM(Om0=Om0, H0=H0)

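Choose how to separate satellite from isolated galaxies: `"by_gn"` (by group number) or `"by_r"` (by distance). Also choose the $v_\mathrm{max}$ windows (in km/s) used to split the sample; an upper limit of $10^5$ km/s means "no upper limit":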

In [None]:
# Satellite / isolated split: "by_gn" (group number) or "by_r" (distance).
distinction = "by_gn"

# v_max windows in km/s, paired element-wise (20-30 and >30 km/s);
# an upper limit of 10**5 is treated as "no upper limit" in the plot labels.
vmax_down_lim, vmax_up_lim = [20, 30], [30, 10**5]

In [None]:
# Per-simulation analysis: stellar ages of luminous satellite and isolated
# subhalos, split into the configured v_max windows. Results are stored in
# data[name]["separated"] and data[name]["combined"]; the simulation's own
# cosmology is stored in data[name]["cosmology"].


def _as_object_array(arrays):
    """Pack a sequence of (possibly different-length) arrays into a 1-D
    object array with one entry per element of `arrays`.

    ``np.array(ragged_list)`` raises ValueError on NumPy >= 1.24 (and builds a
    2-D array if all lengths happen to match), so fill element by element.
    """
    packed = np.empty(len(arrays), dtype=object)
    for idx, arr in enumerate(arrays):
        packed[idx] = arr
    return packed


def _mass_weighted_mean(weights, values):
    """Return sum(weights * values) / sum(weights), or NaN when the total
    weight is zero (e.g. a subhalo with no selected star particles)."""
    total_weight = np.sum(weights)
    if total_weight == 0:
        return np.nan
    return np.sum(weights * values) / total_weight


for name, sim_data in data.items():
    snap = sim_data["snapshot"]

    # Cosmology of this simulation. Stored per simulation instead of
    # overwriting the notebook-wide `cosmo`, which would otherwise end up
    # holding whichever simulation happened to be processed last.
    sim_cosmo = FlatLambdaCDM(
        H0=snap.get_attribute("HubbleParam", "Header") * 100,
        Om0=snap.get_attribute("Omega0", "Header"))
    sim_data["cosmology"] = sim_cosmo
    age_now = sim_cosmo.age(0).value

    # Read star particle (part_type 4) formation times and initial masses,
    # grouped by subhalo, keeping only particles whose initial mass
    # (converted from g to Msun) lies in (1e3, 1e8) Msun:
    initial_mass = snap.get_particles("InitialMass", part_type=[4]) \
        * units.g.to(units.Msun)
    mask_mass_range = np.logical_and(initial_mass > 10**3,
                                     initial_mass < 10**8)
    grouped_data = dataset_compute.group_selected_particles_by_subhalo(
        snap, "StellarFormationTime", "InitialMass",
        selection_mask=mask_mass_range, part_type=[4])

    # Formation scale factor a -> redshift z = 1/a - 1 -> stellar age
    # (Gyr elapsed since formation); then weight by particle initial mass.
    stellar_age = _as_object_array(
        [age_now - sim_cosmo.age(1 / sft - 1).value
         for sft in grouped_data["StellarFormationTime"]])
    mass_weighted_stellar_age = np.array(
        [_mass_weighted_mean(m, t)
         for m, t in zip(grouped_data["InitialMass"], stellar_age)],
        dtype=float)

    # Split into satellites and isolated galaxies:
    if distinction == "by_r":
        split_satellites = dataset_compute.split_satellites_by_distance
    elif distinction == "by_gn":
        split_satellites = dataset_compute.split_satellites_by_group_number
    else:
        # Fail loudly: otherwise the masks of the previous simulation (or a
        # NameError on the first one) would be used silently.
        raise ValueError("Unknown distinction {!r}; expected 'by_r' or "
                         "'by_gn'".format(distinction))
    masks_sat, mask_isol = split_satellites(
        snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    # Union of the satellite masks (presumably one per host: M31, MW).
    mask_sat_any = np.logical_or.reduce(masks_sat)
    print("{}: {} satellite subhalos".format(name, np.sum(mask_sat_any)))

    mask_lum = dataset_compute.split_luminous(snap)[0]
    mask_vmax = [dataset_compute.prune_vmax(snap, low_lim=down, up_lim=up)
                 for down, up in zip(vmax_down_lim, vmax_up_lim)]

    # One boolean subhalo selection per v_max window, for each population:
    selections = {
        "satellites": [np.logical_and.reduce([mask_sat_any, mask_lum, mask])
                       for mask in mask_vmax],
        "isolated": [np.logical_and.reduce([mask_isol, mask_lum, mask])
                     for mask in mask_vmax],
    }

    # Per-subhalo stellar-age arrays, one list entry per v_max window.
    # NOTE: despite its name, the "StellarFormationTime" key holds stellar
    # ages in Gyr, not formation scale factors.
    separated_ages = {population: [stellar_age[sel] for sel in sels]
                      for population, sels in selections.items()}
    sim_data["separated"] = {"StellarFormationTime": separated_ages}

    sim_data["combined"] = {
        # All v_max windows concatenated into one array of subhalos.
        "StellarFormationTime": {
            population: np.concatenate(per_window)
            for population, per_window in separated_ages.items()},
        # One initial-mass-weighted stellar age (Gyr) per selected subhalo,
        # one array per v_max window.
        "InitialMassWeightedAge": {
            population: [mass_weighted_stellar_age[sel] for sel in sels]
            for population, sels in selections.items()},
    }

## Plot

In [None]:
# Line colour for each simulation, assigned in the order of `data`.
cols = [
    "grey",
    "pink",
    "lightblue",
]

In [None]:
# Two-colour palette; overrides the cell above.
# NOTE(review): has fewer colours than the three-simulation configuration needs.
cols = ["grey",
        "pink"]

In [None]:
# Output path for the figure:
#   <dir of the simulation module>/Figures/MediumResolution/
#       stellar_ages_<name1>_<name2>..._<snap_id>.png
# NOTE(review): the "MediumResolution" subfolder is fixed, even when the
# V1_LR configuration is active.
filename = "_".join(["stellar_ages"] + [str(n) for n in names]
                    + [str(snap_id)]) + ".png"

home = os.path.dirname(simulation.__file__)
path = os.path.join(home, "Figures", "MediumResolution")
# savefig does not create missing directories, so make sure this one exists.
os.makedirs(path, exist_ok=True)
filename = os.path.join(path, filename)

In [None]:
# Distributions of the age of the Universe at each galaxy's initial-mass-
# weighted mean stellar formation time, for satellite (left) and isolated
# (right) luminous subhalos. There is one line per simulation (colour) and
# per v_max window (line style); a top axis converts the ages to redshift.


def _vmax_label(down, up):
    """Legend label for one v_max window; up >= 10**5 means 'no upper limit'."""
    if up < 10**5:
        return (r"${} \mathrm{{km/s}} < v_\mathrm{{max}} < {} \mathrm{{km/s}}$"
                .format(down, up))
    return r"${} \mathrm{{km/s}} < v_\mathrm{{max}}$".format(down)


def _plot_age_distributions(ax, ages_per_window, labels, sim_name, colour,
                            cosmology, bins, line_styles):
    """Plot one normalised histogram per v_max window on `ax`.

    `ages_per_window` holds one array of mass-weighted stellar ages (Gyr) per
    window. Each array is converted to the age of the Universe at formation
    using `cosmology`. Each curve gives the fraction of that window's galaxies
    (within the bin range) that fall in each bin.
    """
    age_now = cosmology.age(0).value
    for k, (stellar_ages, label) in enumerate(zip(ages_per_window, labels)):
        stellar_ages = np.asarray(stellar_ages, dtype=float)
        # Subhalos without selected star particles have a NaN mean age.
        formation_epoch = age_now - stellar_ages[np.isfinite(stellar_ages)]
        counts, bin_edges = np.histogram(formation_epoch, bins)
        n_in_range = counts.sum()
        # Same as density=True times the (uniform) bin width, but without the
        # 0/0 warning for an empty selection.
        fraction = (counts / n_in_range if n_in_range > 0
                    else np.full(counts.shape, np.nan))
        bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        ax.plot(bin_centres, fraction, c=colour,
                linestyle=line_styles[k % len(line_styles)],
                label="{} ({}): ".format(sim_name, formation_epoch.size) + label)


def _add_redshift_axis(ax, cosmology, min_age=0.01, max_age_offset=0.1):
    """Add a top x-axis that labels the bottom axis' age ticks (Gyr) by redshift.

    Ticks are clamped into [min_age, age_now - max_age_offset], because z
    diverges at age 0 and there is no z >= 0 solution beyond the present age.
    """
    age_now = cosmology.age(0).value
    ax_z = ax.twiny()
    # twiny() gives the new axis independent x-limits (and set_xticks would
    # expand them); copy them so each redshift label sits exactly above the
    # age it was computed from.
    ax_z.set_xlim(ax.get_xlim())
    ticks = np.unique(np.clip(ax.get_xticks(), min_age, age_now - max_age_offset))
    ax_z.set_xticks(ticks)
    redshifts = [z_at_value(cosmology.age, t * units.Gyr) for t in ticks]
    # Recent astropy returns a Quantity in redshift units; format its value.
    ax_z.set_xticklabels(['{:.2f}'.format(float(getattr(z, "value", z)))
                          for z in redshifts])
    ax_z.set_xlabel('Redshift')
    return ax_z


fig, axes = plt.subplots(ncols=2, figsize=(14, 6))
fig.subplots_adjust(wspace=0.3)

bin_width = 2  # Gyr
bins = np.arange(0, 16, bin_width)
line_styles = ['-', '--']
labels = [_vmax_label(down, up) for down, up in zip(vmax_down_lim, vmax_up_lim)]
panels = [(axes[0], "satellites", 'Satellite galaxies'),
          (axes[1], "isolated", 'Isolated galaxies')]

for i, (name, sim_data) in enumerate(data.items()):
    # Use each simulation's own cosmology when the analysis cell stored one.
    sim_cosmo = sim_data.get("cosmology", cosmo)
    colour = cols[i % len(cols)]  # cycle if there are more runs than colours
    for ax, population, _title in panels:
        _plot_age_distributions(
            ax, sim_data["combined"]["InitialMassWeightedAge"][population],
            labels, name, colour, sim_cosmo, bins, line_styles)

for ax, _population, title in panels:
    ax.set_title(title)
    # set_xlim also switches x-autoscaling off, so later draws keep this range.
    ax.set_xlim(0, cosmo.age(0).value)
    ax.set_xlabel('Age of the Universe [Gyr]', fontsize=16)
    ax.set_ylabel('Fraction of galaxies', fontsize=16)
    _add_redshift_axis(ax, cosmo)
    ax.legend()

fig.savefig(filename, dpi=200)