# Ideas for this notebook

- Try looking at evolution histories at, say, snapshot 126: is there a satellite that is destroyed before snapshot 127? If so, could it serve as an illustrative example?
- Select 10 luminous (solid lines) and 10 dark (dashed lines) subhalos to plot
- Plot survival times in mass bins (v_max < 15 km/s, ...) and split each bar into luminous and dark

## First, imports:

In [None]:
# Load IPython's autoreload extension and set it to mode 2, so that edits to
# the imported library modules are picked up without restarting the kernel:
%load_ext autoreload
%autoreload 2

# Use IPython's greedy tab-completion mode:
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import my library:

In [None]:
import os
import sys

# Make the local `apostletools` library importable. The path is resolved
# relative to the current working directory (typically the notebook's own
# directory). The membership check keeps re-runs of this cell from appending
# the same entry to `sys.path` again and again.
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
if apt_path not in sys.path:
    sys.path.append(apt_path)

import simulation
import simtrace
import match_halo
import dataset_comp

In [None]:
# Explicitly re-import the local library modules so that any edits made
# since they were first imported take effect in this session:
import importlib

for _module in (simulation, simtrace, match_halo, dataset_comp):
    importlib.reload(_module)

# Evolution Histories of the Subhalos Present at $z=0$

In this notebook, I inspect the origins of the subhalos that are satellites of the central galaxies at $z=0$, and also of some isolated subhalos. I will look at their trajectories and mass evolution.

---

### plain-LCDM-LR

Set the envelope file path, and define the identifiers of the M31 and the MW halos at redshift $z=0$:

In [None]:
# Directory containing the simulation envelope file:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))
sim = simulation.Simulation("V1_LR_fix", env_path=env_path)

# Halo identifiers of the two central galaxies at z=0. These are passed to
# `index_of_halo` as its two arguments -- presumably
# (GroupNumber, SubGroupNumber); verify against the `simulation` module.
m31_id_z0, mw_id_z0 = (1, 0), (2, 0)
snap_z0 = 127  # snapshot id corresponding to z=0

---

## Tracing

Set the range of snapshots to be traced:

In [None]:
# Snapshots to trace: from `snap_start` up to, but excluding, `snap_stop`
# (i.e. snapshots 100, ..., 127):
snap_start, snap_stop = 100, 128
snap_ids = np.arange(snap_start, snap_stop)

In [None]:
# Link the snapshots into a merger tree over [snap_start, snap_stop).
# Only needed if the snapshots have not been linked already -- otherwise
# this cell can be skipped.
matcher = match_halo.SnapshotMatcher(n_link_ref=20, n_matches=1)
mtree = simtrace.MergerTree(
    sim,
    matcher=matcher,
    branching="BackwardBranching",
)
mtree.build_tree(snap_start, snap_stop)

Get Subhalo objects:

In [None]:
# Trace the subhalos through the snapshot range. `sub_dict` is indexed by
# snapshot id and then by the subhalo's position within that snapshot:
sub_dict = sim.trace_subhalos(snap_start, snap_stop)

# Look up the Subhalo objects of the M31 and the MW at z=0:
subs_z0 = sub_dict[snap_z0]
m31 = subs_z0[sim.get_snapshot(snap_z0).index_of_halo(*m31_id_z0)]
mw = subs_z0[sim.get_snapshot(snap_z0).index_of_halo(*mw_id_z0)]

---

## Retrieve the Datasets

Read all datasets into dictionaries by snapshot:

In [None]:
# Redshift of each snapshot, keyed by snapshot id:
redshift = dict(zip(
    snap_ids, sim.get_attribute("Redshift", "Header", snap_ids)
))

# Conversion factors from the cgs units of the raw datasets:
g_to_msun = units.g.to(units.Msun)
cm_to_km = units.cm.to(units.km)
cm_to_kpc = units.cm.to(units.kpc)

# Per-snapshot subhalo datasets (dicts keyed by snapshot id):
mass = {snap: arr * g_to_msun
        for snap, arr in sim.get_subhalos(snap_ids, "Mass").items()}
# Only the first column of "Max_Vcirc" is kept (assumed to be v_max -- verify):
vmax = {snap: arr[:, 0] * cm_to_km
        for snap, arr in sim.get_subhalos(
            snap_ids, "Max_Vcirc", h5_group="Extended").items()}
cop = {snap: arr * cm_to_kpc
       for snap, arr in sim.get_subhalos(snap_ids, "CentreOfPotential").items()}

# Distances (kpc) of all subhalos to the M31 and to the MW, per snapshot:
m31_dist = {snap: np.linalg.norm(vec, axis=1) * cm_to_kpc
            for snap, vec in m31.distance_to_self(snap_ids).items()}
mw_dist = {snap: np.linalg.norm(vec, axis=1) * cm_to_kpc
           for snap, vec in mw.distance_to_self(snap_ids).items()}

Define masking arrays to select the satellites of the M31 and the MW, and a random sample of isolated galaxies:

In [None]:
# Get masking arrays for satellites (at z=0):
mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
    sim.get_snapshot(snap_z0), m31, mw, sat_r=300
)

# Randomly select ´n_isol´ isolated galaxies:
n_isol = 30
mask_rand = np.full(np.sum(mask_isol), False)
mask_rand[:n_isol] = True
np.random.shuffle(mask_rand)

mask_rand_isol = np.full(mask_isol.size, False)
mask_rand_isol[mask_isol] = mask_rand

Make the selections and write the datasets that are ready for plotting into a dictionary:

In [None]:
def _evolution_data(subs, dist, dist_key="Distance"):
    """Collect the evolution histories of the subhalos in `subs`.

    Uses the notebook-level `redshift`, `mass` and `vmax` dictionaries.

    Parameters
    ----------
    subs : iterable of Subhalo
        The subhalos (selected at z=0) whose histories are collected.
    dist : dict
        Snapshot id -> distances (kpc) of all subhalos to a reference galaxy,
        e.g. `m31_dist` or `mw_dist`.
    dist_key : str, optional
        Key under which the distance histories are stored.

    Returns
    -------
    dict
        Keys "Snap_id", "Redshift", "Mass", "Vmax" and `dist_key`, each
        mapping to a list with one array per subhalo (in the order of `subs`).
    """
    out = {"Snap_id": [], "Redshift": [], "Mass": [], "Vmax": [], dist_key: []}
    for sub in subs:
        # Read the tracing indices only once per subhalo. The second row holds
        # the ids of the snapshots the subhalo was traced through.
        indices = sub.get_indices()
        out["Snap_id"].append(np.array(indices)[1])
        out["Redshift"].append(
            np.array([redshift[sid] for _, sid in zip(*indices)])
        )
        out["Mass"].append(dataset_comp.subhalo_dataset_from_dict(sub, mass)[0])
        out["Vmax"].append(dataset_comp.subhalo_dataset_from_dict(sub, vmax)[0])
        out[dist_key].append(
            dataset_comp.subhalo_dataset_from_dict(sub, dist)[0]
        )
    return out


# Subhalo selections at z=0:
m31_sats = sub_dict[snap_z0][mask_m31]
mw_sats = sub_dict[snap_z0][mask_mw]
isol_subs = sub_dict[snap_z0][mask_rand_isol]

# Evolution histories, ready for plotting. Satellites store the distance to
# their host; the isolated subhalos store their distance to the MW:
data = {
    "M31_Satellites": _evolution_data(m31_sats, m31_dist),
    "MW_Satellites": _evolution_data(mw_sats, mw_dist),
    "Isolated": _evolution_data(isol_subs, mw_dist, dist_key="MW_Distance"),
}

## Subhalo Survival Times

In [None]:
# Survival time of every subhalo, per snapshot: the number of entries in its
# `indices` (presumably the number of snapshots it is traced through).
surv_time = {}
for _snap, _subs in sub_dict.items():
    surv_time[_snap] = np.array([len(sub.indices) for sub in _subs])

Plot the $v_\mathrm{max}$ of each subhalo against the number of snapshots through which it is traced. Below, plot the number of subhalos for each survival time.

We see that, by far, most subhalos survive through the entire time range over which we have done the linking. Only a small fraction is traced for fewer than 5 snapshots.

In [None]:
# Top: v_max vs. survival time of the subhalos present at z=0.
# Bottom: number of subhalos for each survival time.
# Both panels share the survival-time x-axis.
fig, axes = plt.subplots(nrows=2, sharex=True)

st = surv_time[snap_z0]
st_unique, st_cnt = np.unique(st, return_counts=True)
v = vmax[snap_z0]

axes[0].scatter(st, v)
axes[0].set_ylabel(r"$v_\mathrm{max}$ [km/s]")

axes[1].plot(st_unique, st_cnt)
axes[1].set_xlabel("Survival time [number of snapshots]")
axes[1].set_ylabel("Number of subhalos")

Let us look more closely at the masses (via $v_\mathrm{max}$) of the subhalos that survive longest vs. those that are lost early. I divide the subhalos into those that survive the whole linking period, those that are traced through 3 or fewer snapshots, and all in between:

In [None]:
# Only consider subhalos with v_max below 100 km/s, and split them by the
# number of snapshots they are traced through:
mask_prune = v < 100
st_max = st.max()

# Traced through the whole linking period:
mask_long_surv = (st == st_max) & mask_prune
print(mask_long_surv.sum())
# Traced through more than 3 snapshots, but not the whole period:
mask_inter_surv = (st > 3) & (st < st_max) & mask_prune
print(mask_inter_surv.sum())
# Traced through at most 3 snapshots:
mask_short_surv = (st <= 3) & mask_prune
print(mask_short_surv.sum())

Below are histograms of $v_\mathrm{max}$ for each of these categories.

In [None]:
# 50 evenly spaced v_max bin edges (i.e. 49 bins) from 10 to 100 km/s:
bin_edges = np.linspace(start=10, stop=100, num=50)

In [None]:
# Normalised v_max distributions of the short-lived, intermediate and
# long-lived subhalos (left to right), with titled and labelled panels:
fig, axes = plt.subplots(ncols=3, sharey=True, sharex=True, figsize=(10,4))

_surv_groups = [
    (r"Traced $\leq 3$ snapshots", mask_short_surv),
    (r"Traced $> 3$ snapshots, not all", mask_inter_surv),
    ("Traced through all snapshots", mask_long_surv),
]
for _ax, (_title, _mask) in zip(axes, _surv_groups):
    _ = _ax.hist(v[_mask], bins=bin_edges, density=True)
    _ax.set_title(_title)
    _ax.set_xlabel(r"$v_\mathrm{max}$ [km/s]")
axes[0].set_ylabel("Probability density")