# Ideas for this notebook

- Try looking at evolution histories at, say, snapshot 126: is there a satellite that gets destroyed before snapshot 127? Could it serve as an illustrative satellite?
- Select 10 luminous (solid lines) and 10 dark (dashed lines) satellites
- Plot survival times in mass bins ($v_\mathrm{max} < 15$ km/s, ...) and split each bar into luminous and dark subhalos

## First, imports:

In [None]:
# Automatically reload imported modules before executing code, so that edits
# to the project library are picked up without restarting the kernel:
%load_ext autoreload
%autoreload 2

# Enable IPython's greedy tab completion (note: it may evaluate code on TAB):
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import my library:

In [None]:
import os
import sys

# Make the project modules importable: they live in the `apostletools`
# directory one level above the current working directory.
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
sys.path.append(apt_path)

# Project-local modules (resolved through the path appended above):
import simulation
import simtrace
import match_halo
import dataset_comp

In [None]:
# Explicitly re-import the project modules so that edits made to their
# source after the first import take effect in this kernel:
import importlib
importlib.reload(simulation)
importlib.reload(simtrace)
importlib.reload(match_halo)
importlib.reload(dataset_comp)

# Evolution Histories of the Subhalos Present at $z=0$

In this notebook, I inspect the origins of the subhalos that are satellites of the central galaxies at $z=0$, and also some isolated subhalos. I will look at their trajectories and mass evolution.

---

### plain-LCDM-LR

Set the envelope file path and define the M31 and the MW at redshift $z=0$ (run either this cell or the plain-LCDM cell below, not both):

In [None]:
# Simulation "V1_LR_fix", read through the envelope files in ../test_tracing_inj:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))
sim = simulation.Simulation("V1_LR_fix", env_path=env_path)

# Halo identifiers of the two centrals at z=0 (presumably
# (group number, subgroup number) -- see `index_of_halo` usage):
m31_id_z0, mw_id_z0 = (1, 0), (2, 0)

# The z=0 snapshot also serves as the reference snapshot for this run:
snap_id_z0 = snap_id_ref = 127

---

### plain-LCDM

Set the envelope file path and define the M31 and the MW at redshift $z=0$ (run either this cell or the plain-LCDM-LR cell above, not both):

In [None]:
# Simulation "V1_MR_fix", read through the envelope files in ../test_tracing_inj:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))
sim = simulation.Simulation("V1_MR_fix", env_path=env_path)

# Halo identifiers of the two centrals at z=0 (presumably
# (group number, subgroup number) -- see `index_of_halo` usage):
m31_id_z0, mw_id_z0 = (1, 0), (2, 0)

# z=0 snapshot, and the (earlier) snapshot used as the selection reference:
snap_id_z0, snap_id_ref = 127, 115

---

## Tracing

Set the range of snapshots to be traced:

In [None]:
# Trace over the half-open snapshot range [snap_start, snap_stop):
snap_start, snap_stop = 100, 128
snap_ids = np.arange(snap_start, snap_stop)

In [None]:
# If the simulations are not already linked:
# Build the merger tree over [snap_start, snap_stop). The matcher presumably
# links subhalos between snapshots (n_link_ref / n_matches semantics are
# defined in match_halo.SnapshotMatcher -- TODO confirm there).
matcher = match_halo.SnapshotMatcher(n_link_ref=20, n_matches=1)
mtree = simtrace.MergerTree(sim, matcher=matcher, branching="BackwardBranching")
mtree.build_tree(snap_start, snap_stop)

Get Subhalo objects:

In [None]:
# Trace the subhalos over [snap_start, snap_stop). `sub_dict` is keyed by
# snapshot ID; each value is indexable by the subhalo's index in that snapshot.
sub_dict = sim.trace_subhalos(snap_start, snap_stop)

# Pick out the M31 and MW Subhalo objects at z=0 (snapshot `snap_id_z0`,
# which must lie inside the traced range):
snapshot_z0 = sim.get_snapshot(snap_id_z0)
m31 = sub_dict[snap_id_z0][snapshot_z0.index_of_halo(m31_id_z0[0], m31_id_z0[1])]
mw = sub_dict[snap_id_z0][snapshot_z0.index_of_halo(mw_id_z0[0], mw_id_z0[1])]

Masking arrays selecting the satellites of M31 and the MW, and a reasonable random sample of the isolated galaxies, are defined below, after the datasets have been retrieved.

---

## Retrieve the Datasets

Read all datasets into dictionaries by snapshot:

In [None]:
# Cosmological parameters from the z=0 snapshot header (they should be the
# same for every simulation). HubbleParam is h, so H0 = 100 h:
H0, Om0 = (
    sim.get_snapshot(snap_id_z0).get_attribute(attr, "Header")
    for attr in ("HubbleParam", "Omega0")
)
cosmo = FlatLambdaCDM(H0=H0 * 100, Om0=Om0)

In [None]:
# Redshift of every traced snapshot, and its lookback time: the age of the
# universe today minus its age at that redshift (astropy `age` is in Gyr).
redshift = sim.get_attribute("Redshift", "Header", snap_ids)
lookback_time = cosmo.age(0).value - np.array(
    [cosmo.age(z_snap).value for z_snap in redshift]
)

The following cell is likely to take some time (it needs to read the given datasets from all the snapshots, and file retrievals take time):

In [None]:
# Per-snapshot datasets, as dicts keyed by snapshot ID, converted from the
# cgs units they are stored in:
mass_dict = {
    snap: masses * units.g.to(units.Msun)  # g -> Msun
    for snap, masses in sim.get_subhalos(snap_ids, "Mass").items()
}
vmax_dict = {
    snap: vcirc[:, 0] * units.cm.to(units.km)  # first column, cm/s -> km/s
    for snap, vcirc in sim.get_subhalos(
        snap_ids, "Max_Vcirc", h5_group="Extended"
    ).items()
}
# cop_dict = {sid: c * units.cm.to(units.kpc) for sid, c in
#        sim.get_subhalos(snap_ids, "CentreOfPotential").items()}

# Separation vectors to M31 and to the MW, cm -> kpc:
r_m31_dict = {
    snap: sep * units.cm.to(units.kpc)
    for snap, sep in m31.distance_to_self(snap_ids).items()
}
r_mw_dict = {
    snap: sep * units.cm.to(units.kpc)
    for snap, sep in mw.distance_to_self(snap_ids).items()
}

Make the selections and write the datasets that are ready for plotting to a dictionary:

In [None]:
def to_object_array(items):
    """Pack a list of per-subhalo arrays into a 1-D object array.

    ``np.array(items, dtype=object)`` silently builds a multi-dimensional
    array when all items happen to share the same shape, and can raise when
    they agree only in their leading dimensions. Filling an empty 1-D object
    array element by element avoids both, so ``arr[i]`` is always ``items[i]``.
    """
    arr = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        arr[i] = item
    return arr


# All subhalos present at the reference snapshot (selections such as
# "M31 satellite" are applied later through masking arrays):
subs = sub_dict[snap_id_ref]

# For each subhalo, the indices of its traced snapshots in the `snap_ids` array:
inds = [np.searchsorted(snap_ids, sub.get_snap_ids()) for sub in subs]

# One entry per subhalo, each holding that subhalo's history (ragged lengths).
# `subhalo_dataset_from_dict(...)[0]` is taken to be the dataset values --
# see dataset_comp for the full return value.
data = {
    "Redshift": to_object_array([redshift[idx] for idx in inds]),
    "LookbackTime": to_object_array([lookback_time[idx] for idx in inds]),
    "Mass": to_object_array(
        [dataset_comp.subhalo_dataset_from_dict(sub, mass_dict)[0] for sub in subs]
    ),
    "Vmax": to_object_array(
        [dataset_comp.subhalo_dataset_from_dict(sub, vmax_dict)[0] for sub in subs]
    ),
    "M31_Distance": to_object_array(
        [dataset_comp.subhalo_dataset_from_dict(sub, r_m31_dict)[0] for sub in subs]
    ),
    "MW_Distance": to_object_array(
        [dataset_comp.subhalo_dataset_from_dict(sub, r_mw_dict)[0] for sub in subs]
    ),
}

Now, define masking arrays for these:

In [None]:
# Selection parameters:
sat_low_lim = 10   # lower v_max limit for satellites (prune_vmax low_lim; presumably km/s)
isol_low_lim = 10  # lower v_max limit for isolated subhalos
vol_n = 3          # subhalos traced in <= vol_n snapshots count as "volatile"

# Masking arrays for the subhalos at the reference snapshot `snap_id_ref`:
snap_ref = sim.get_snapshot(snap_id_ref)
mask_lum, mask_dark = dataset_comp.split_luminous(snap_ref)
data["Ref_Selections"] = {
    "Vmax_Sat": dataset_comp.prune_vmax(snap_ref, low_lim=sat_low_lim),
    "Vmax_Isol": dataset_comp.prune_vmax(snap_ref, low_lim=isol_low_lim),
    "Luminous": mask_lum,
    "Dark": mask_dark,
    # Keep subhalos traced in more than `vol_n` snapshots. With vol_n = 3 this
    # also leaves the >= 4 points needed by the cubic interpolation used in
    # the plots. Explicit dtype keeps this boolean even for zero subhalos.
    "NonVolatile": np.array(
        [z_arr.size > vol_n for z_arr in data["Redshift"]], dtype=bool
    ),
}

# Satellite / isolated masks at the reference snapshot (which is z=0 only
# when snap_id_ref == snap_id_z0). sat_r is presumably in kpc -- confirm in
# dataset_comp.split_satellites_by_distance.
m31_id = m31.get_group_number_at_snap(snap_id_ref)
mw_id = mw.get_group_number_at_snap(snap_id_ref)
mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
    snap_ref, m31_id, mw_id, sat_r=300
)

data["Ref_Selections"].update({
    "M31_Satellites": mask_m31,
    "MW_Satellites": mask_mw,
    "Isolated": mask_isol,
})

In addition, define a function for selecting a random subset from a given masking array:

In [None]:
def random_mask(mask, n, rng=None):
    """Randomly keep at most `n` of the items selected by a boolean mask.

    Parameters
    ----------
    mask : array_like of bool
        The selection to draw from (lists are accepted too).
    n : int
        Number of selected items to keep. If `n` exceeds the number of
        selected items, the whole selection is kept; if `n` <= 0, nothing is.
    rng : numpy.random.Generator, optional
        Source of randomness. If None (default), the global NumPy random
        state is used, so `np.random.seed` makes the draw reproducible.

    Returns
    -------
    numpy.ndarray of bool
        Mask with the same shape as `mask` that is True only for the chosen
        items (always a subset of `mask`).
    """
    mask = np.asarray(mask, dtype=bool)
    k = int(np.count_nonzero(mask))
    # Clamp to [0, k]: a negative n must not act as a slice-from-the-end.
    n_keep = min(max(int(n), 0), k)

    chosen = np.zeros(k, dtype=bool)
    chosen[:n_keep] = True
    # Randomise which of the k selected items are kept:
    if rng is None:
        np.random.shuffle(chosen)
    else:
        rng.shuffle(chosen)

    mask_new = np.zeros(mask.shape, dtype=bool)
    mask_new[mask] = chosen

    return mask_new

## Plot M31 Satellites

In [None]:
# Matplotlib's qualitative "tab10" map (its default colour cycle); cmap(i) is the i-th colour:
cmap = plt.get_cmap(name="tab10")

In [None]:
# Draw up to 10 luminous and then up to 10 dark M31 satellites at random,
# among those passing the v_max and non-volatility cuts:
mask_m31_lum, mask_m31_dark = (
    random_mask(
        np.logical_and.reduce([
            data["Ref_Selections"]["M31_Satellites"],
            data["Ref_Selections"]["Vmax_Sat"],
            data["Ref_Selections"]["NonVolatile"],
            data["Ref_Selections"][kind],
        ]),
        10,
    )
    for kind in ("Luminous", "Dark")
)

In [None]:
fig, ax = plt.subplots()

ax.invert_xaxis()
ax.set_xlabel("Lookback Time [Gyr]")
ax.set_ylabel("Distance to M31 [kpc]")

# Dark satellites first, as thin translucent gray curves (same style as the
# MW plot), so that the luminous ones are drawn on top:
for r, t in zip(data["M31_Distance"][mask_m31_dark],
                data["LookbackTime"][mask_m31_dark]):
    # Cubic interpolation of |r| over lookback time (needs >= 4 points,
    # which the "NonVolatile" selection is meant to guarantee):
    f = interp1d(t, np.linalg.norm(r, axis=1), kind='cubic')
    t_new = np.linspace(min(t), max(t), num=1000)
    ax.plot(t_new, f(t_new), c='gray', alpha=0.5, lw=0.5)

# Luminous satellites, one colour each:
for i, (r, t) in enumerate(zip(data["M31_Distance"][mask_m31_lum],
                               data["LookbackTime"][mask_m31_lum])):
    f = interp1d(t, np.linalg.norm(r, axis=1), kind='cubic')
    t_new = np.linspace(min(t), max(t), num=1000)
    ax.plot(t_new, f(t_new), c=cmap(i), lw=2)

# Mark the reference snapshot:
idx_ref = np.searchsorted(snap_ids, snap_id_ref)
ax.axvline(lookback_time[idx_ref], c='black', linestyle='dotted', alpha=0.5)

In [None]:
fig, ax = plt.subplots()

ax.invert_xaxis()
ax.set_xlabel("Lookback Time [Gyr]")
# Raw string: "\m" is not a valid escape sequence in a normal string literal.
ax.set_ylabel(r"$v_\mathrm{max}$ [km/s]")

# Dark satellites: thin translucent gray lines (same style as the MW plot):
for vmax, time in zip(data["Vmax"][mask_m31_dark],
                      data["LookbackTime"][mask_m31_dark]):
    ax.plot(time, vmax, c='gray', alpha=0.5, lw=0.5)

# Luminous satellites, one colour each:
for i, (vmax, time) in enumerate(zip(data["Vmax"][mask_m31_lum],
                                     data["LookbackTime"][mask_m31_lum])):
    ax.plot(time, vmax, c=cmap(i), lw=2)

# Mark the reference snapshot:
idx_ref = np.searchsorted(snap_ids, snap_id_ref)
ax.axvline(lookback_time[idx_ref], c='black', linestyle='dotted', alpha=0.5)

In [None]:
fig, ax = plt.subplots()

# x-y projection of the separation vectors to M31 (converted to kpc).
# Equal aspect so that the orbits are not visually distorted:
ax.set_xlabel("$x$ [kpc]")
ax.set_ylabel("$y$ [kpc]")
ax.set_aspect('equal')

# Dark satellites: thin gray lines (same style as the MW orbit plot):
for r in data["M31_Distance"][mask_m31_dark]:
    ax.plot(r[:, 0], r[:, 1], c='gray', lw=0.5)

# Luminous satellites, one colour each:
for i, r in enumerate(data["M31_Distance"][mask_m31_lum]):
    ax.plot(r[:, 0], r[:, 1], c=cmap(i))

## Plot MW Satellites

In [None]:
# Draw up to 10 luminous and then up to 10 dark MW satellites at random,
# among those passing the v_max and non-volatility cuts:
mask_mw_lum, mask_mw_dark = (
    random_mask(
        np.logical_and.reduce([
            data["Ref_Selections"]["MW_Satellites"],
            data["Ref_Selections"]["Vmax_Sat"],
            data["Ref_Selections"]["NonVolatile"],
            data["Ref_Selections"][kind],
        ]),
        10,
    )
    for kind in ("Luminous", "Dark")
)

In [None]:
fig, ax = plt.subplots()

ax.invert_xaxis()
ax.set_xlabel("Lookback Time [Gyr]")
ax.set_ylabel("Distance to MW [kpc]")

# Dark satellites first (thin translucent gray) so luminous ones end up on top.
# Each curve is a cubic interpolation of |r| over lookback time.
for sep, t in zip(data["MW_Distance"][mask_mw_dark],
                  data["LookbackTime"][mask_mw_dark]):
    dist_of_t = interp1d(t, np.linalg.norm(sep, axis=1), kind='cubic')
    t_fine = np.linspace(min(t), max(t), num=1000)
    ax.plot(t_fine, dist_of_t(t_fine), c='gray', alpha=0.5, lw=0.5)

# Luminous satellites, one colour each:
for k, (sep, t) in enumerate(zip(data["MW_Distance"][mask_mw_lum],
                                 data["LookbackTime"][mask_mw_lum])):
    dist_of_t = interp1d(t, np.linalg.norm(sep, axis=1), kind='cubic')
    t_fine = np.linspace(min(t), max(t), num=1000)
    ax.plot(t_fine, dist_of_t(t_fine), c=cmap(k), lw=2)

# Dotted vertical line at the reference snapshot:
ax.axvline(lookback_time[np.searchsorted(snap_ids, snap_id_ref)],
           c='black', linestyle='dotted', alpha=0.5)

In [None]:
fig, ax = plt.subplots()

ax.invert_xaxis()
ax.set_xlabel("Lookback Time [Gyr]")
# Raw string: "\m" is not a valid escape sequence in a normal string literal.
ax.set_ylabel(r"$v_\mathrm{max}$ [km/s]")

# Dark satellites: thin translucent gray lines:
for vmax, time in zip(data["Vmax"][mask_mw_dark],
                      data["LookbackTime"][mask_mw_dark]):
    ax.plot(time, vmax, c='gray', alpha=0.5, lw=0.5)

# Luminous satellites, one colour each:
for i, (vmax, time) in enumerate(zip(data["Vmax"][mask_mw_lum],
                                     data["LookbackTime"][mask_mw_lum])):
    ax.plot(time, vmax, c=cmap(i), lw=2)

# Mark the reference snapshot:
idx_ref = np.searchsorted(snap_ids, snap_id_ref)
ax.axvline(lookback_time[idx_ref], c='black', linestyle='dotted', alpha=0.5)

In [None]:
fig, ax = plt.subplots()

# x-y projection of the separation vectors to the MW (converted to kpc).
# Equal aspect so that the orbits are not visually distorted:
ax.set_xlabel("$x$ [kpc]")
ax.set_ylabel("$y$ [kpc]")
ax.set_aspect('equal')

# Dark satellites: thin gray lines:
for r in data["MW_Distance"][mask_mw_dark]:
    ax.plot(r[:, 0], r[:, 1], c='gray', lw=0.5)

# Luminous satellites, one colour each:
for i, r in enumerate(data["MW_Distance"][mask_mw_lum]):
    ax.plot(r[:, 0], r[:, 1], c=cmap(i))