## Imports and notebook setup

In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import the local `apostletools` library (added to `sys.path` from `../apostletools`):

In [None]:
import os
import sys

apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
sys.path.append(apt_path)

import simulation
import simtrace
import match_halo
import dataset_comp

In [None]:
import importlib

# Force-reload the local apostletools modules so that source edits are picked up
# without restarting the kernel. NOTE(review): this duplicates `%autoreload 2`.
for _module in (simulation, simtrace, match_halo, dataset_comp):
    importlib.reload(_module)

In [None]:
# Snapshot used as the reference for the subhalo selections and plots, and the
# snapshot taken to be redshift z = 0 (both snapshot 127 here):
snap_id_ref = snap_id_z0 = 127

---

### MR Simulations

Set the path to the envelope files, and identify the M31 and MW halos at redshift $z=0$ (the tuples are the identifiers passed to `index_of_halo`):

In [None]:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))

# MR runs. Each entry holds the Simulation object, a colour pair for plotting, and
# the identifiers passed to `index_of_halo` to locate M31 and the MW at z = 0.
data = dict([
    ("plain-LCDM", dict(
        Simulation=simulation.Simulation("V1_MR_fix", env_path=env_path),
        Color=['black', 'gray'],
        M31_z0=(1, 0),
        MW_z0=(2, 0),
    )),
    ("curv-p082", dict(
        Simulation=simulation.Simulation("V1_MR_curvaton_p082_fix", env_path=env_path),
        Color=['red', 'pink'],
        M31_z0=(1, 0),
        MW_z0=(1, 1),
    )),
])

---

### LR Simulations

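Set the path to the envelope files, and identify the M31 and MW halos at redshift $z=0$ (the tuples are the identifiers passed to `index_of_halo`). Note that every cell in this notebook that defines `data` overwrites the previous one: whichever is executed last (on Run All, the single plain-ΛCDM LR run below) is the one used.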

In [None]:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))

# LR runs. Each entry holds the Simulation object, a colour pair for plotting, and
# the identifiers passed to `index_of_halo` to locate M31 and the MW at z = 0.
# Running this cell replaces any `data` defined by an earlier cell.
data = dict([
    ("plain-LCDM-LR", dict(
        Simulation=simulation.Simulation("V1_LR_fix", env_path=env_path),
        Color=['black', 'gray'],
        M31_z0=(1, 0),
        MW_z0=(2, 0),
    )),
    ("curv-p082-LR", dict(
        Simulation=simulation.Simulation("V1_LR_curvaton_p082_fix", env_path=env_path),
        Color=['red', 'pink'],
        M31_z0=(1, 0),
        MW_z0=(1, 1),
    )),
    ("curv-p084-LR", dict(
        Simulation=simulation.Simulation("V1_LR_curvaton_p084_fix", env_path=env_path),
        Color=['blue', 'lightblue'],
        M31_z0=(1, 0),
        MW_z0=(1, 0),
    )),
])

In [None]:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))

# Restrict the analysis to the plain LCDM LR run only. NOTE: this overwrites any
# `data` defined by the cells above, so on Run All only this run is analysed.
data = dict([
    ("plain-LCDM-LR", dict(
        Simulation=simulation.Simulation("V1_LR_fix", env_path=env_path),
        Color=['black', 'gray'],
        M31_z0=(1, 0),
        MW_z0=(2, 0),
    )),
])

---

## Tracing

Set the range of snapshots to be traced (`snap_stop` is exclusive):

In [None]:
# Snapshots snap_start, ..., snap_stop - 1 are traced (snap_stop is exclusive):
snap_start, snap_stop = 100, 128
snap_ids = np.arange(snap_start, snap_stop)

First, define consistency checks for the star-formation onset times (used further below). Then link all subhalos across snapshots, create a Subhalo object for each individual subhalo found, and store pointers to these objects for each snapshot:

In [None]:
def test_sf_onset1(sf_onset, sm):
    """ Check that, at each snapshot, subhalos with onset time in the past 
    have nonzero stellar mass. """
    
    for sid in sf_onset.keys():
        mask_seton = ~np.isnan(sf_onset[sid])
        mask_lum = (sm[sid] > 0)
        if not np.all(mask_seton == mask_lum):
            return False
    
    return True

def test_sf_onset2(sf_onset, sm, subhalo_dict, test_snap_id):
    mask = ~np.isnan(sf_onset[test_snap_id])
    sm_test = np.full((np.sum(mask), 2), 0)
    
    # Iterate over subhalos, in test_snap_id, which have stars,
    # and their SF onset times:
    for i, (onset_snap_id, sub) in enumerate(zip(sf_onset[test_snap_id][mask],
                                                 subhalo_dict[test_snap_id][mask])):
        
        # If onset_snap_id is not the formation snapshot of the subhalo, get stellar mass 
        # in onset_snap_id and the preceding snapshot:
        if onset_snap_id > sub.index_at_formation()[1]:
            sm_test[i] = [
                sm[onset_snap_id - 1][sub.get_index_at_snap(onset_snap_id - 1)],
                sm[onset_snap_id][sub.get_index_at_snap(onset_snap_id)]
            ]
        # Else, set stellar mass in previous snapshot as 0:
        else:
            sm_test[i] = [0, sm[onset_snap_id][sub.get_index_at_snap(onset_snap_id)]]
    
    # Check that, at onset time, stellar mass is non-zero:
    if not np.all(sm_test[:,1] > 0):
        return False
    
    # Check that, just before onset time, stellar mass is zero:
    if not np.all(sm_test[:,0] == 0):
        return False
    
    return True

In [None]:
# One matcher instance is shared by the merger trees of all simulations.
matcher = match_halo.SnapshotMatcher(n_link_ref=20, n_matches=1)

for sim_data in data.values():
    sim = sim_data["Simulation"]

    # Link subhalos over [snap_start, snap_stop) with backward branching.
    # NOTE(review): this runs unconditionally; nothing here checks whether the
    # simulation is already linked (presumably handled by build_tree — confirm).
    mtree = simtrace.MergerTree(sim, matcher=matcher, branching="BackwardBranching")
    mtree.build_tree(snap_start, snap_stop)

    # Trace the linked subhalos over the same range and store the per-snapshot
    # Subhalo objects:
    sim_data["Subhalos"] = sub_dict = sim.trace_subhalos(snap_start, snap_stop)

In [None]:
# Define the cosmology from the z = 0 snapshot header. A single cosmology is used for
# all simulations below, so first verify that they really share the same parameters.
def _header_cosmology(sim, snap_id):
    """Return ``(HubbleParam, Omega0)`` read from the header of snapshot `snap_id`."""
    snapshot = sim.get_snapshot(snap_id)
    return (snapshot.get_attribute("HubbleParam", "Header"),
            snapshot.get_attribute("Omega0", "Header"))


sim = list(data.values())[0]["Simulation"]
# HubbleParam is multiplied by 100 below, i.e. it is the dimensionless h.
H0, Om0 = _header_cosmology(sim, snap_id_z0)

mismatched = [
    name for name, sim_data in data.items()
    if not np.allclose(_header_cosmology(sim_data["Simulation"], snap_id_z0), (H0, Om0))
]
if mismatched:
    raise ValueError(
        f"Simulations {mismatched} have a cosmology (HubbleParam, Omega0) different "
        f"from the reference ({H0}, {Om0}); a single shared cosmology cannot be used."
    )

cosmo = FlatLambdaCDM(H0=100 * H0, Om0=Om0)

In [None]:
# Set True to validate the SF onset snapshots with test_sf_onset1/test_sf_onset2:
RUN_SF_ONSET_CHECKS = False
SF_ONSET_TEST_SNAP = 110  # snapshot checked by test_sf_onset2 (must be in snap_ids)


def snapshots_to_lookback_times(snaps_by_sid, snap_ids, lookback_time):
    """Map per-subhalo event snapshot IDs to lookback times.

    Parameters
    ----------
    snaps_by_sid : dict
        Snapshot ID -> array of event snapshot IDs, one per subhalo, with NaN where
        the event is undefined. Non-NaN values are assumed to be in `snap_ids`.
    snap_ids : ndarray
        Sorted snapshot IDs, aligned with `lookback_time`.
    lookback_time : ndarray
        Lookback time at each snapshot of `snap_ids`.

    Returns
    -------
    dict
        Same keys as `snaps_by_sid`; each value holds the lookback time of the
        event snapshot, or NaN where the input was NaN.
    """
    times = {}
    for sid, snaps in snaps_by_sid.items():
        snaps = np.asarray(snaps, dtype=float)
        undefined = np.isnan(snaps)
        # Replace NaNs by a valid snapshot before searching so that every index is
        # in bounds; those entries are masked back to NaN below.
        idx = np.searchsorted(snap_ids, np.where(undefined, snap_ids[0], snaps))
        times[sid] = np.where(undefined, np.nan, lookback_time[idx])
    return times


for sim_data in data.values():
    sim = sim_data["Simulation"]
    sub_dict = sim_data["Subhalos"]

    # Snapshot redshifts and the respective lookback times (astropy returns Gyr):
    redshift = sim.get_attribute("Redshift", "Header", snap_ids)
    lookback_time = cosmo.lookback_time(np.asarray(redshift)).value
    sim_data.update({
        "Redshift": redshift,
        "LookbackTime": lookback_time
    })

    # Stellar masses, converted from grams to solar masses:
    sm = {sid: m * units.g.to(units.Msun) for sid, m in
          sim.get_subhalos(snap_ids, "Stars/Mass").items()}
    sim_data["StellarMass"] = sm

    # SF onset times:
    onset_snaps = simtrace.sf_onset_times(sim, snap_start, snap_stop)
    sim_data["SF_Onset_Time"] = snapshots_to_lookback_times(
        onset_snaps, snap_ids, lookback_time
    )
    if RUN_SF_ONSET_CHECKS:
        assert test_sf_onset1(onset_snaps, sm), "SF onset inconsistent with stellar mass"
        assert test_sf_onset2(onset_snaps, sm, sub_dict, SF_ONSET_TEST_SNAP), \
            "stellar mass does not switch on at SF onset"

    # Formation times:
    form_snaps = simtrace.creation_times(sub_dict)
    sim_data["Formation_Time"] = snapshots_to_lookback_times(
        form_snaps, snap_ids, lookback_time
    )

    # Fall-in times of satellites into M31 or the MW:
    snap_z0 = sim.get_snapshot(snap_id_z0)
    m31 = sub_dict[snap_id_z0][snap_z0.index_of_halo(*sim_data["M31_z0"])]
    mw = sub_dict[snap_id_z0][snap_z0.index_of_halo(*sim_data["MW_z0"])]

    fallin_snaps_m31, fallin_snaps_mw = simtrace.get_fallin_times_lg(
        sim, m31, mw, snap_start, snap_stop
    )
    # Merge by snapshot key (not by relying on both dicts sharing iteration order).
    # An M31 fall-in time takes precedence over an MW one; this assumes no subhalo
    # is a satellite of both.
    fallin_snaps = {
        sid: np.where(np.isnan(fs_m31), fallin_snaps_mw[sid], fs_m31)
        for sid, fs_m31 in fallin_snaps_m31.items()
    }
    sim_data["Fallin_Time"] = snapshots_to_lookback_times(
        fallin_snaps, snap_ids, lookback_time
    )

In [None]:
# Lower Vmax limits passed to dataset_comp.prune_vmax for the satellite and isolated
# samples (units as expected there — presumably km/s, TODO confirm):
sat_low_lim = 10
isol_low_lim = 10
# Subhalos that, up to snap_id_ref, have existed for at most `vol_n` snapshots are
# flagged as volatile:
vol_n = 3
# Distance cut passed to split_satellites_by_distance (presumably kpc — TODO confirm):
sat_r = 300

for sim_data in data.values():
    sim = sim_data["Simulation"]
    sub_dict = sim_data["Subhalos"]
    snap_ref = sim.get_snapshot(snap_id_ref)

    # Number of snapshots each subhalo at snap_id_ref has existed for, counting from
    # its formation snapshot up to and including snap_id_ref. (Previously this was
    # computed from sim_data["Redshift"], which holds one scalar redshift per
    # snapshot, giving an all-False, wrongly-sized mask.)
    n_snaps_alive = np.array([
        snap_id_ref - sub.index_at_formation()[1] + 1 for sub in sub_dict[snap_id_ref]
    ])

    # Masking arrays for subhalos at snap_ref:
    mask_lum, mask_dark = dataset_comp.split_luminous(snap_ref)
    sim_data["Ref_Selections"] = {
        "Vmax_Sat": dataset_comp.prune_vmax(snap_ref, low_lim=sat_low_lim),
        "Vmax_Isol": dataset_comp.prune_vmax(snap_ref, low_lim=isol_low_lim),
        "Luminous": mask_lum,
        "Dark": mask_dark,
        "NonVolatile": n_snaps_alive > vol_n
    }

    # Masking arrays for satellites of the M31 and MW (identified at z = 0):
    snap_z0 = sim.get_snapshot(snap_id_z0)
    m31 = sub_dict[snap_id_z0][snap_z0.index_of_halo(*sim_data["M31_z0"])]
    mw = sub_dict[snap_id_z0][snap_z0.index_of_halo(*sim_data["MW_z0"])]
    m31_id = m31.get_group_number_at_snap(snap_id_ref)
    mw_id = mw.get_group_number_at_snap(snap_id_ref)
    mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
        snap_ref, m31_id, mw_id, sat_r=sat_r
    )

    sim_data["Ref_Selections"].update({
        "M31_Satellites": mask_m31,
        "MW_Satellites": mask_mw,
        "LG_Satellites": np.logical_or(mask_m31, mask_mw),
        "Isolated": mask_isol
    })

In [None]:
# Onset, formation and fall-in lookback times of the luminous LG satellites at
# snap_id_ref, against their stellar mass, for the first simulation in `data`:
sim_name, sim_data = next(iter(data.items()))
selections = sim_data["Ref_Selections"]
mask = np.logical_and(selections["LG_Satellites"], selections["Luminous"])

stellar_mass = sim_data["StellarMass"][snap_id_ref][mask]
ms = 8  # marker size
# Small opposite vertical shifts, presumably to keep coinciding onset and
# formation points from hiding each other:
offset = 0.05

fig, ax = plt.subplots()
ax.set_xscale('log')

ax.scatter(stellar_mass,
           sim_data["SF_Onset_Time"][snap_id_ref][mask] + offset,
           s=ms, label="SF Onset")
ax.scatter(stellar_mass,
           sim_data["Formation_Time"][snap_id_ref][mask] - offset,
           s=ms, label="Formation")
ax.scatter(stellar_mass,
           sim_data["Fallin_Time"][snap_id_ref][mask],
           s=ms, label="Fall-in")

ax.set_xlabel(r"$M_*$ [$\mathrm{M}_\odot$]")
ax.set_ylabel("Lookback time [Gyr]")
ax.set_title(f"{sim_name}: luminous LG satellites (N = {np.sum(mask)})")
ax.legend();