## First, imports:

In [None]:
# Load the autoreload extension; mode 2 reloads all modules before each cell
# is executed, so edits to the local library are picked up automatically.
%load_ext autoreload
%autoreload 2

# Enable IPython's greedy tab completion.
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import my library:

In [None]:
import os
import sys

# Absolute path of the sibling `apostletools` directory, which holds the
# project modules imported below.
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
# Only add the directory once: without this guard every re-run of the cell
# would append another duplicate entry to sys.path.
if apt_path not in sys.path:
    sys.path.append(apt_path)

import simulation
import simtrace
import match_halo
import dataset_comp

In [None]:
import importlib

# Reload the project modules, in the same order in which they were imported,
# so that the kernel uses their current source.
for module in (simulation, simtrace, match_halo, dataset_comp):
    importlib.reload(module)

In [None]:
# Snapshot at which the M31 and the MW are identified (the "z0" snapshot).
snap_id_z0 = 127
# Reference snapshot for the satellite selections; currently the same snapshot.
snap_id_ref = snap_id_z0

---

### LR Simulations

Set the envelope file path, and define the M31 and the MW at redshift $z=0$:

In [None]:
# Directory that is handed to Simulation as its envelope path.
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))

# Settings for the plain LCDM low-resolution run. "M31_z0" and "MW_z0" are the
# identifier pairs later passed to index_of_halo() to find the two hosts at
# snap_id_z0.
plain_lcdm_lr = {
    "Simulation": simulation.Simulation("V1_LR_fix", env_path=env_path),
    "ColorMap": plt.cm.Blues,
    "Color": ['black', 'gray'],
    "M31_z0": (1, 0),
    "MW_z0": (2, 0),
}

# All per-simulation data, keyed by a short simulation label.
data = {"plain-LCDM-LR": plain_lcdm_lr}

In [None]:
# Directory that is handed to Simulation as its envelope path.
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))

# Settings for the curvaton (p082) low-resolution run. "M31_z0" and "MW_z0" are
# the identifier pairs later passed to index_of_halo() to find the two hosts at
# snap_id_z0.
curv_p082_lr = {
    "Simulation": simulation.Simulation("V1_LR_curvaton_p082_fix", env_path=env_path),
    "ColorMap": plt.cm.Reds,
    "Color": ['red', 'pink'],
    "M31_z0": (1, 0),
    "MW_z0": (2, 0),
}

# NOTE(review): this rebinds `data`, so any entry defined by an earlier cell is
# dropped and only this simulation is analysed below.
data = {"curv-p082-LR": curv_p082_lr}

---

## Tracing

Set the range of snapshots to be traced:

In [None]:
# Snapshot IDs to trace: snap_start is included, snap_stop is excluded.
snap_start, snap_stop = 100, 128
snap_ids = np.arange(snap_start, snap_stop)

Link all subhalos in the simulation, create Subhalo objects for all the individual subhalos found, and write pointers to these objects for each snapshot:

In [None]:
# A single matcher instance is shared by the merger trees of all simulations.
matcher = match_halo.SnapshotMatcher(n_link_ref=20, n_matches=1)

for sim_data in data.values():
    sim = sim_data["Simulation"]

    # Build the merger tree between snap_start and snap_stop.
    # NOTE(review): nothing here checks whether the simulation is already
    # linked, so the tree is rebuilt every time this cell runs.
    mtree = simtrace.MergerTree(sim, matcher=matcher, branching="BackwardBranching")
    mtree.build_tree(snap_start, snap_stop)

    # Trace the subhalos over the same snapshot range and store the result;
    # later cells index it by snapshot ID to look up individual subhalos.
    sub_dict = sim.trace_subhalos(snap_start, snap_stop)
    sim_data["Subhalos"] = sub_dict

Get the M31 and the MW halos and compute masking arrays for their satellites (and isolated subhalos) at `snap_id_ref`:

In [None]:
# For every simulation: look up the M31 and MW entries among the traced
# subhalos and store masking arrays for the satellite/isolated selections.
for sim_data in data.values():
    sim = sim_data["Simulation"]
    sub_dict = sim_data["Subhalos"]
    
    # Get the M31 subhalo: convert its identifier pair into an index at
    # snap_id_z0 and pick that entry from the traced subhalos of the snapshot.
    m31_id = sim_data["M31_z0"]
    m31 = sub_dict[snap_id_z0][
        sim.get_snapshot(snap_id_z0).index_of_halo(m31_id[0], m31_id[1])
    ]
    sim_data["M31"] = m31 
    
    # Get the MW subhalo in the same way:
    mw_id = sim_data["MW_z0"]
    mw = sub_dict[snap_id_z0][
        sim.get_snapshot(snap_id_z0).index_of_halo(mw_id[0], mw_id[1])
    ]
    sim_data["MW"] = mw
    
    # Get masking arrays for satellites, evaluated at snap_id_ref.
    # NOTE(review): sat_r and isol_r are presumably the distance limits for the
    # satellite and isolated selections; their units are not visible here,
    # confirm in dataset_comp.
    mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
        sim.get_snapshot(snap_id_ref), m31_id, mw_id, sat_r=300, isol_r=2000, comov=True
    )
    # "LG_Satellites" is the element-wise OR of the M31 and MW satellite masks.
    sim_data["Ref_Selections"] = {"M31_Satellites": mask_m31,
                                  "MW_Satellites": mask_mw,
                                  "LG_Satellites": np.logical_or(mask_m31, mask_mw),
                                  "Isolated": mask_isol}

---

## Retrieve the Datasets

Read all datasets into dictionaries by snapshot:

In [None]:
# Define the cosmology (should be the same for each simulation).
# NOTE(review): H0 and Om0 are overwritten on every iteration, so the header
# values of the last simulation in `data` are the ones used; nothing checks
# that the simulations actually agree. The header is read from the last traced
# snapshot (snap_stop - 1).
for sim_data in data.values():
    H0 = sim_data["Simulation"].get_snapshot(snap_stop-1)\
        .get_attribute("HubbleParam", "Header")
    Om0 = sim_data["Simulation"].get_snapshot(snap_stop-1)\
        .get_attribute("Omega0", "Header")
#     print(H0, Om0)
# "HubbleParam" is scaled by 100 to form H0 (presumably it is the dimensionless
# h, TODO confirm).
cosmo = FlatLambdaCDM(H0=100 * H0, Om0=Om0) 

In [None]:
# Lower limits handed to prune_vmax() for the satellite and the isolated
# selections, respectively (presumably in the same units as "Vmax" below).
sat_low_lim = 10
isol_low_lim = 15

for sim_data in data.values():
    sim = sim_data["Simulation"]
    sub_dict = sim_data["Subhalos"]
    
    # Get snapshot redshifts and the respective lookback times
    # (age at z=0 minus the age at each snapshot redshift):
    redshift = sim.get_attribute("Redshift", "Header", snap_ids)
    lookback_time = cosmo.age(0).value - np.array([cosmo.age(z).value for z in redshift])
    sim_data["Redshift"] = redshift
    sim_data["LookbackTime"] = lookback_time
    
    # Get v_max at snap_id_ref and at the time of fall-in. Only column 0 of
    # "Max_Vcirc" is kept and scaled by the cm-to-km conversion factor
    # (presumably column 0 is the velocity itself, TODO confirm):
    vmax_dict = {snap_id: vmax_arr[:, 0] * units.cm.to(units.km) for snap_id, vmax_arr in
                 sim.get_subhalos(snap_ids, "Max_Vcirc", "Extended").items()}
    sim_data["Vmax"] = vmax_dict[snap_id_ref]
    
    # Fall-in information with respect to the M31 and the MW over the traced
    # snapshot range:
    fallin_m31, fallin_mw = simtrace.get_fallin_times_lg(
        sim, sim_data["M31"], sim_data["MW"], snap_start, snap_stop
    )
    
    # v_max at fall-in and the fall-in snapshot ID, for the subhalos of
    # snap_id_ref, relative to the M31 ...
    vmax_fallin_m31, snap_id_fallin_m31 = dataset_comp.get_subhalos_at_fallin(
        sub_dict[snap_id_ref], fallin_m31, vmax_dict
    )
    
    # ... and relative to the MW:
    vmax_fallin_mw, snap_id_fallin_mw = dataset_comp.get_subhalos_at_fallin(
        sub_dict[snap_id_ref], fallin_mw, vmax_dict
    )
    
    sim_data["Vmax_Fallin_M31"] = vmax_fallin_m31
    sim_data["Vmax_Fallin_MW"] = vmax_fallin_mw
    # Combined array: take the M31 value wherever it is not NaN, otherwise
    # fall back to the MW value.
    sim_data["Vmax_Fallin"] = np.where(~np.isnan(vmax_fallin_m31), 
                                       vmax_fallin_m31,
                                       vmax_fallin_mw)
                                       
    # Same M31-first merge for the fall-in snapshot IDs:
    snap_id_fallin = np.where(~np.isnan(snap_id_fallin_m31),
                              snap_id_fallin_m31,
                              snap_id_fallin_mw)
    # Convert the fall-in snapshot IDs into positions within snap_ids. Values
    # that sort past the end of snap_ids (including NaN, which NumPy orders
    # last) come back as snap_ids.size and are flagged with -1. Indexing
    # lookback_time with -1 is still valid, and np.where then replaces those
    # entries with NaN.
    inds = np.searchsorted(snap_ids, snap_id_fallin)
    inds[inds == snap_ids.size] = -1
    sim_data["Time_Fallin"] = np.where(inds != -1, lookback_time[inds], np.nan)
    
    # Get subhalo formation times for the subhalos of snap_id_ref:
    sim_data["Time_Formation"] = simtrace.creation_times(sub_dict)[snap_id_ref]
    
    # Masking arrays for subhalos at snap_id_ref; they are added to the
    # selections stored earlier under "Ref_Selections":
    snap_ref = sim.get_snapshot(snap_id_ref)
    mask_lum, mask_dark = dataset_comp.split_luminous(snap_ref)
    sim_data["Ref_Selections"].update({
        "Vmax_Sat": dataset_comp.prune_vmax(snap_ref, low_lim=sat_low_lim),
        "Vmax_Isol": dataset_comp.prune_vmax(snap_ref, low_lim=isol_low_lim),
        "Luminous": mask_lum,
        "Dark": mask_dark
    })

In [None]:
fig, ax = plt.subplots()

# Only the first simulation in `data` is plotted.
sim_name, sim_data = list(data.items())[0]

# Restrict the scatter plot to the subhalos selected as isolated at snap_id_ref.
mask = sim_data["Ref_Selections"]["Isolated"]
ax.scatter(sim_data["Time_Formation"][mask], sim_data["Vmax"][mask])

# Label the figure so that it can be read on its own; the simulation label
# doubles as the title.
# NOTE(review): the units of both axes are not stated here; confirm them and
# add them to the labels.
ax.set(
    title=f"{sim_name}: isolated subhalos",
    xlabel="Formation time",
    ylabel=r"$v_\mathrm{max}$",
);

**TODO:** Trace back further? (This needs more data.)