In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from astropy import units
from pathlib import Path
import os
import time
from astropy import cosmology

import snapshot_obj
import dataset_compute
import curve_fit

import importlib

In [None]:
# Force a fresh import of the project modules so edits made to them are
# picked up even without relying on %autoreload.
for _module in (snapshot_obj, dataset_compute, curve_fit):
    importlib.reload(_module)

# Subhalo stellar ages

## Construct data dictionary

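Add an entry for each simulation and specify its M31 and MW galaxies. Each of the following cells defines an alternative set of simulations; only the last one executed is used: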

In [None]:
# Low-resolution runs: LCDM plus the p082 and p084 curvaton variants.
snap_id = 127
sim_ids = ["V1_LR_fix", "V1_LR_curvaton_p082_fix", "V1_LR_curvaton_p084_fix"]
names = ["LCDM", "p082", "p084"]

# The curvaton runs are read from an external directory. Set the environment
# variable LG_SIMULATIONS_PATH to point elsewhere instead of editing this cell;
# setting it to an empty string uses the default path "" for every simulation.
external_sim_path = os.environ.get("LG_SIMULATIONS_PATH",
                                   "/media/kassiili/USBFREE/LG_simulations")
paths = ["", external_sim_path, external_sim_path]

# Define M31 and MW in each simulation (one identifier per entry of `names`):
m31 = [(1,0), (1,0), (1,0)]
mw = [(2,0), (1,1), (1,1)]

assert len(names) == len(sim_ids) == len(paths) == len(m31) == len(mw), \
    "every simulation needs a name, ID, path and M31/MW identifier"

In [None]:
# Medium-resolution runs: LCDM and the p082 curvaton variant.
# One row per simulation: (name, simulation ID, data path, M31 id, MW id).
snap_id = 126
_sim_table = [
    ("LCDM", "V1_MR_fix",               "", (1,0), (2,0)),
    ("p082", "V1_MR_curvaton_p082_fix", "", (1,0), (1,1)),
]
# Transpose the table into one list per quantity:
names, sim_ids, paths, m31, mw = (list(column) for column in zip(*_sim_table))
del _sim_table

In [None]:
# Low-resolution runs: LCDM and the p082 curvaton variant.
# One row per simulation: (name, simulation ID, data path, M31 id, MW id).
snap_id = 127
_sim_table = [
    ("LCDM", "V1_LR_fix",               "", (1,0), (2,0)),
    ("p082", "V1_LR_curvaton_p082_fix", "", (1,0), (1,1)),
]
# Transpose the table into one list per quantity:
names, sim_ids, paths, m31, mw = (list(column) for column in zip(*_sim_table))
del _sim_table

In [None]:
# One entry per simulation, keyed by its display name: the Snapshot object
# plus the M31/MW identifiers used later to split satellites from isolated
# galaxies.
data = {
    name: {
        "snapshot": snapshot_obj.Snapshot(sim_id, snap_id, name=name,
                                          sim_path=sim_path),
        "M31_identifier": m31_ns,
        "MW_identifier": mw_ns,
    }
    for name, sim_id, sim_path, m31_ns, mw_ns in zip(names, sim_ids, paths, m31, mw)
}

Choose how to distinguish between satellite and isolated galaxies (`"by_r"`: by distance, `"by_gn"`: by group number), and which $v_\mathrm{max}$ bins to plot:

In [None]:
# Satellite/isolated split: "by_r" (by distance) or "by_gn" (by group number).
distinction = "by_gn"

# v_max selection bins in km/s: bin k is (vmax_down_lim[k], vmax_up_lim[k]).
# An upper limit of 10**5 is used as "no upper bound" by the plot labels.
vmax_down_lim, vmax_up_lim = [20, 30], [30, 10**5]

In [None]:
# For each simulation: select luminous subhalos, split them into satellites and
# isolated galaxies, and bin them by v_max. The resulting age arrays are stored
# under data[name]["StellarAge"] WITHOUT discarding the snapshot and host
# identifiers, so this cell can be re-run safely.
for name, sim_data in data.items():
    snap = sim_data["snapshot"]
    # Convert ages to Gyr (assumes the catalogue values are in seconds — TODO confirm).
    sf_times = snap.get_subhalos("InitialMassWeightedStellarAge") * units.s.to(units.Gyr)

    # Split into satellites and isolated galaxies:
    if distinction == "by_r":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    elif distinction == "by_gn":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    else:
        # Without this, masks from a previous simulation/run would be reused silently.
        raise ValueError("Unknown distinction {!r}; expected 'by_r' or 'by_gn'"
                         .format(distinction))

    # masks_sat holds several masks (presumably one per host); a subhalo is a
    # satellite if any of them selects it.
    mask_sat_any = np.logical_or.reduce(masks_sat)
    mask_lum, _ = dataset_compute.split_luminous(snap)
    mask_vmax = [dataset_compute.prune_vmax(snap, low_lim=down, up_lim=up)
                 for down, up in zip(vmax_down_lim, vmax_up_lim)]

    mask_sat_lum = np.logical_and(mask_sat_any, mask_lum)
    mask_isol_lum = np.logical_and(mask_isol, mask_lum)
    print("{} satellite galaxies: {}".format(name, np.sum(mask_sat_lum)))
    print("{} isolated galaxies: {}".format(name, np.sum(mask_isol_lum)))

    # One age array per v_max bin, for satellites and isolated galaxies:
    sim_data["StellarAge"] = {
        "satellites": [sf_times[np.logical_and(mask_sat_lum, mask)]
                       for mask in mask_vmax],
        "isolated": [sf_times[np.logical_and(mask_isol_lum, mask)]
                     for mask in mask_vmax],
    }


## Plot

In [None]:
# One line colour per simulation, used in the order of `names`. Three entries
# cover the largest simulation set configured above (LCDM, p082, p084); the
# previous two-colour override made that configuration fail with IndexError.
cols = ["grey", "pink", "lightblue"]

In [None]:
# Construct saving location:
filename = 'stellar_age_{}'.format(distinction)
for name in names:
    filename += "_{}".format(name)
filename += ".png"
    
home = os.path.dirname(snapshot_obj.__file__)
path = os.path.join(home,"Figures", "LowResolution")
filename = os.path.join(path, filename)

In [None]:
def _plot_age_fractions(ax, age_sets, sim_name, color, labels, line_styles, bins):
    """Plot, on `ax`, the fraction of galaxies in each stellar-age bin.

    One line is drawn per v_max selection: `age_sets`, `labels` and
    `line_styles` are zipped together, so the shortest of them limits the
    number of lines. Each y value is (count in bin) / (count in all bins),
    which equals np.histogram(..., density=True) * bin width for uniform bins.
    An empty selection is drawn as zeros instead of the NaNs (and
    divide-by-zero warning) that density=True would produce.
    """
    for ages, vmax_label, lstyle in zip(age_sets, labels, line_styles):
        counts, bin_edges = np.histogram(ages, bins)
        total = counts.sum()
        fractions = counts / total if total > 0 else np.zeros(counts.shape)
        bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        label = "{} ({}): ".format(sim_name, ages.size) + vmax_label
        ax.plot(bin_centres, fractions, c=color, label=label, linestyle=lstyle)


fig, axes = plt.subplots(ncols=2, figsize=(14,6))
fig.subplots_adjust(wspace=0.3)

fig.suptitle("Mean initial mass weighted stellar ages")

# Set axes (raw strings: the LaTeX backslashes are not Python escapes):
for ax, title in zip(axes, ["Satellite galaxies", "Isolated galaxies"]):
    ax.set_xlabel(r'Age $[\mathrm{Gyr}]$', fontsize=16)
    ax.set_ylabel('fraction', fontsize=16)
    ax.set_title(title)

# Set bins:
bin_width = 2
bins = np.arange(0, 16, bin_width)

# Legend text per v_max selection; an upper limit of 10**5 means "no upper bound".
labels = [r"${} \mathrm{{km/s}} < v_\mathrm{{max}} < {} \mathrm{{km/s}}$".format(down, up)
          if up < 10**5 else
          r"${} \mathrm{{km/s}} < v_\mathrm{{max}}$".format(down)
          for down, up in zip(vmax_down_lim, vmax_up_lim)]
line_styles = ['-', '--']

assert len(cols) >= len(data), "need one colour in `cols` per simulation"

# Iterate over simulations: satellites on the left, isolated on the right.
for sim_index, (name, sim_data) in enumerate(data.items()):
    age_sets = sim_data["StellarAge"]
    _plot_age_fractions(axes[0], age_sets["satellites"], name, cols[sim_index],
                        labels, line_styles, bins)
    _plot_age_fractions(axes[1], age_sets["isolated"], name, cols[sim_index],
                        labels, line_styles, bins)

for ax in axes:
    ax.legend(loc="upper left")

#fig.savefig(filename, dpi=200)

In [None]:
# Add twin axis with fractions (only works with one dataset):
# NOTE(review): the histogram cell above already plots fractions
# (np.histogram(..., density=True) multiplied by the bin width), so dividing
# by n_subhalos again here looks stale. The whole cell is commented out;
# consider deleting it.
# n_subhalos = sf_times[0].size
# print(np.sum(y / n_subhalos))
# ydown, yup = axes[0].get_ylim()
# ydown = ydown / n_subhalos
# yup = yup / n_subhalos
# yticks = axes[0].get_yticks()[1:-1] / n_subhalos
# print([str(t) for t in yticks])
# twinx = axes[0].twinx()
# twinx.set_ylim(ydown, yup)
# twinx.set_yticks(yticks)