In [None]:
# Reload all imported modules before executing each cell, so edits to the
# local analysis modules (snapshot_obj, dataset_compute, curve_fit) take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Greedy completion: TAB-completion evaluates expressions (e.g. results of
# indexing or calls) to offer attribute suggestions.
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from astropy import units
from pathlib import Path
import os
import time

import snapshot_obj
import dataset_compute
import curve_fit

import importlib

In [None]:
# Force a fresh import of the local analysis modules. `%autoreload 2` above
# already does this before every cell; the explicit reload keeps this cell
# useful when the autoreload extension is not loaded.
for module in (snapshot_obj, dataset_compute, curve_fit):
    importlib.reload(module)

# Subhalo stellar ages

## Construct data dictionary

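Define the simulations to compare and identify the M31 and Milky Way (MW) host galaxies in each. The medium-resolution cell below overrides the low-resolution settings, so run only the configuration you need: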

In [None]:
# Low-resolution (LR) configuration: snapshot, simulation IDs and display names.
snap_id = 127
sim_ids = ["V1_LR_fix", "V1_LR_curvaton_p082_fix", "V1_LR_curvaton_p084_fix"]
names = ["LCDM", "p082", "p084"]

# Directory holding the curvaton runs. Set the LG_SIM_PATH environment
# variable to point elsewhere instead of editing this machine-specific path.
# An empty path is passed to snapshot_obj.Snapshot unchanged (presumably
# meaning its default location — confirm in snapshot_obj).
lg_sim_path = os.environ.get("LG_SIM_PATH",
                             "/media/kassiili/USBFREE/LG_simulations")
paths = ["", lg_sim_path, lg_sim_path]

In [None]:
# Define M31 and MW in each simulation:
# M31 and MW host identifiers, one per simulation in the order of `names`.
# NOTE(review): presumably (group number, subgroup number) pairs — confirm
# against snapshot_obj / dataset_compute.
m31 = [(1, 0)] * 3
mw = [(2, 0)] + [(1, 1)] * 2

In [None]:
# Medium-resolution (MR) configuration. Running this cell overrides the
# low-resolution settings (snap_id, sim_ids, names, paths, m31, mw) above.
snap_id = 126

# One row per simulation: (name, simulation ID, path, M31 id, MW id);
# zip(*rows) transposes the rows into the per-field lists used below.
names, sim_ids, paths, m31, mw = map(list, zip(
    ("LCDM", "V1_MR_fix",               "", (1, 0), (2, 0)),
    ("p082", "V1_MR_curvaton_p082_fix", "", (1, 0), (1, 1)),
))

In [None]:
# Build one entry per simulation: its loaded snapshot plus the identifiers of
# its M31 and MW hosts. The configuration lists must line up one-to-one —
# zip() would otherwise silently drop the surplus entries (easy to trigger
# when mixing cells from the LR and MR configurations).
config_lengths = [len(names), len(sim_ids), len(paths), len(m31), len(mw)]
if len(set(config_lengths)) != 1:
    raise ValueError(
        "Configuration lists differ in length "
        "(names, sim_ids, paths, m31, mw): {}".format(config_lengths))

data = {}
for name, sim_id, sim_path, m31_ns, mw_ns in zip(names, sim_ids, paths, m31, mw):
    data[name] = {"snapshot": snapshot_obj.Snapshot(sim_id, snap_id, name=name,
                                                    sim_path=sim_path),
                  "M31_identifier": m31_ns,
                  "MW_identifier": mw_ns}

Choose how to separate satellites of M31 and the MW from isolated galaxies: `"by_gn"` splits by group number, `"by_r"` by distance:

In [None]:
# Satellite/isolated split method used below:
#   "by_gn" -> dataset_compute.split_satellites_by_group_number
#   "by_r"  -> dataset_compute.split_satellites_by_distance
distinction = 'by_gn'

In [None]:
# For every simulation, select the initial-mass-weighted stellar birth
# redshifts of luminous subhaloes with non-zero vmax, split into satellites
# (of either host) and isolated galaxies. The results are added to the
# existing entry, so the snapshot and host identifiers are kept and this
# cell can be re-run safely.
for name, sim_data in data.items():
    snap = sim_data["snapshot"]
    sf_times = snap.get_subhalos("InitialMassWeightedBirthZ")

    # Split into satellites (one mask per host) and isolated galaxies:
    if distinction == "by_r":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    elif distinction == "by_gn":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    else:
        # Without this, masks_sat would be undefined (or stale from the
        # previous simulation) and the selection silently wrong.
        raise ValueError("Unknown distinction {!r}; expected 'by_r' or 'by_gn'"
                         .format(distinction))

    # A subhalo counts as a satellite if it belongs to either host:
    mask_sat = np.logical_or.reduce(masks_sat)
    print(name, np.sum(mask_sat))

    mask_lum, _ = dataset_compute.split_luminous(snap)
    mask_nonzero_vmax = dataset_compute.prune_vmax(snap)
    mask_keep = np.logical_and(mask_lum, mask_nonzero_vmax)

    sim_data["StellarFormationTime"] = {
        "satellites": sf_times[np.logical_and(mask_sat, mask_keep)],
        "isolated": sf_times[np.logical_and(mask_isol, mask_keep)],
    }


## Plot birth-redshift distributions

In [None]:
# One histogram colour per simulation, in the order of `names`. matplotlib's
# `hist` requires exactly one colour per dataset, so the palette is trimmed to
# the number of configured simulations (2 for MR, 3 for LR).
cols = ["grey", "pink", "lightblue"][:len(names)]

In [None]:
# Construct saving location:
filename = 'star_formation_z_{}'.format(distinction)
for name in names:
    filename += "_{}".format(name)
filename += ".png"
    
home = os.path.dirname(snapshot_obj.__file__)
path = os.path.join(home,"Figures", "MediumResolution")
filename = os.path.join(path, filename)

In [None]:
def _plot_fraction_hist(ax, data, category, colors, n_bins, with_names):
    """Draw per-simulation normalized histograms of one galaxy category on ax.

    Each simulation's values are weighted by 1/N (N = its number of
    subhaloes), so the bars show the fraction of that simulation's galaxies
    falling in each bin. The name and N of each simulation are printed.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    data : dict mapping simulation name -> dict whose "StellarFormationTime"
        entry maps category -> array of birth redshifts.
    category : "satellites" or "isolated".
    colors : list of colours, one per simulation.
    n_bins : number of bins, shared by all simulations on this axis.
    with_names : if True, legend labels are "<name> (<N>)", else "(<N>)".
    """
    sf_times, weights, labels = [], [], []
    for name, sim_data in data.items():
        values = sim_data["StellarFormationTime"][category]
        n_subhalos = values.size
        print(name, n_subhalos)
        sf_times.append(values)
        weights.append(np.ones(n_subhalos) / n_subhalos)
        labels.append("{} ({})".format(name, n_subhalos) if with_names
                      else "({})".format(n_subhalos))
    ax.hist(sf_times, n_bins, weights=weights, color=colors, label=labels)


fig, axes = plt.subplots(ncols=2, figsize=(14, 6))
fig.subplots_adjust(wspace=0.3)
fig.suptitle("Mean initial mass weighted star birth redshifts")

for ax in axes:
    ax.set_xlabel('$z$', fontsize=16)
    ax.set_ylabel('fraction', fontsize=16)
    # Plot redshift decreasing to the right (earlier times on the left):
    ax.invert_xaxis()

axes[0].set_title('Satellite galaxies')
axes[1].set_title('Isolated galaxies')

# Number of histogram bins per panel:
n_bins = 7
_plot_fraction_hist(axes[0], data, "satellites", cols, n_bins, with_names=True)
_plot_fraction_hist(axes[1], data, "isolated", cols, n_bins, with_names=False)

axes[0].legend(loc="upper left")
axes[1].legend(loc="upper left")

fig.savefig(filename, dpi=200)