## First, imports:

In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import the project library (`apostletools`), which is first added to the module search path:

In [None]:
import os
import sys

apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
sys.path.append(apt_path)

import simulation
import simtrace
import match_halo
import dataset_comp

In [None]:
# Explicitly reload the project modules so that edits to apostletools are
# picked up without restarting the kernel.
# NOTE(review): `%autoreload 2` in the first cell should already do this;
# this cell is presumably kept as a manual fallback.
import importlib
importlib.reload(simulation)
importlib.reload(simtrace)
importlib.reload(match_halo)
importlib.reload(dataset_comp)

# Evolution Histories of the M31 and MW satellites around the time of fall-in


---

### plain-LCDM-LR

Choose the simulation and identify M31 and the MW at redshift $z=0$:

In [None]:
# Simulation environment directory, resolved relative to the current working directory.
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))
sim = simulation.Simulation("V1_LR_fix", env_path=env_path)

# The z = 0 snapshot, and the two-component halo identifiers of M31 and the MW
# in it (presumably (GroupNumber, SubGroupNumber) -- TODO confirm).
snap_z0 = 127
m31_id_z0, mw_id_z0 = (1, 0), (2, 0)

---

## Tracing

Set the range of snapshots to be traced:

In [None]:
# Snapshots snap_start, ..., snap_stop - 1 are traced (the stop is exclusive,
# like `range`), so snap_stop must exceed the z = 0 snapshot.
snap_start, snap_stop = 100, 128
snap_ids = np.arange(snap_start, snap_stop)

In [None]:
# Link subhalos between consecutive snapshots in [snap_start, snap_stop) by
# building the merger tree. Only needed if the simulations are not already linked.
# NOTE(review): the meaning of n_link_ref / n_matches and of the
# "BackwardBranching" option is defined in match_halo / simtrace -- confirm there.
matcher = match_halo.SnapshotMatcher(n_link_ref=20, n_matches=1)
mtree = simtrace.MergerTree(sim, matcher=matcher, branching="BackwardBranching")
mtree.build_tree(snap_start, snap_stop)

In [None]:
# Trace all subhalos through the snapshot range; `sub_dict` maps each snapshot
# id to its subhalos. The hosts are then looked up at z = 0 by halo identifier.
sub_dict = sim.trace_subhalos(snap_start, snap_stop)
m31 = sub_dict[snap_z0][sim.get_snapshot(snap_z0).index_of_halo(*m31_id_z0)]
mw = sub_dict[snap_z0][sim.get_snapshot(snap_z0).index_of_halo(*mw_id_z0)]

In [None]:
# Fall-in snapshot of every subhalo w.r.t. M31 and the MW, over the traced range.
# Both results are indexed by snapshot id and hold one entry per subhalo
# (NaN where there is no fall-in; see the checks below).
# NOTE(review): `first_infall=True` presumably records the first crossing
# into the host -- confirm in simtrace.get_fallin_times_lg.
fallin_snaps_m31, fallin_snaps_mw = simtrace.get_fallin_times_lg(
    sim, m31, mw, snap_start, snap_stop, first_infall=True
)

In [None]:
# Per-host indices locating each subhalo's entry at its fall-in snapshot;
# consumed by `dataset_comp.data_at_fallin` further below.
fallin_inds_m31, fallin_inds_mw = (
    dataset_comp.index_at_fallin(sub_dict, fallin_snaps)
    for fallin_snaps in (fallin_snaps_m31, fallin_snaps_mw)
)

## Check Fallin Times

First, check that each fall-in array has exactly one entry per subhalo of its snapshot:

In [None]:
# Each per-snapshot fall-in array must have exactly one entry per subhalo of
# that snapshot. The subhalo count is queried only once per snapshot.
n_subhalos = {snap_id: sim.get_snapshot(snap_id).get_subhalo_number()
              for snap_id in snap_ids}
shapes_match = [
    fallin_snaps_m31[snap_id].size == n_subhalos[snap_id]
    and fallin_snaps_mw[snap_id].size == n_subhalos[snap_id]
    for snap_id in snap_ids
]
print(np.all(shapes_match))
# Snapshots whose arrays have the wrong size (empty list when all match):
print([int(snap_id) for snap_id, ok in zip(snap_ids, shapes_match) if not ok])

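The fall-in snapshot arrays should have non-NaN values for exactly those elements that represent satellites. This need not hold exactly when the first infall is used (as above): a subhalo that fell in and later left the host presumably keeps its fall-in time even though it is no longer a satellite.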

In [None]:
# A sample of snapshots (all within `snap_ids`) at which the satellite
# selection is compared against the fall-in arrays.
snap_check = [101, 104, 105, 111, 120, 127]

In [None]:
# For each checked snapshot, select the satellites of each host (within 300 ckpc,
# as `comov=True`) and compare them with the non-NaN entries of the fall-in
# arrays. The loop uses `check_`-prefixed names so that (re-)running this cell
# does not overwrite the z = 0 masks `mask_m31` / `mask_mw` / `mask_isol`
# defined below and used by later cells.
for snap_id in snap_check:
    check_snap = sim.get_snapshot(snap_id)
    check_m31_id = m31.get_group_number_at_snap(snap_id)
    check_mw_id = mw.get_group_number_at_snap(snap_id)
    check_mask_m31, check_mask_mw, check_mask_isol = \
        dataset_comp.split_satellites_by_distance(
            check_snap, check_m31_id, check_mw_id, sat_r=300, comov=True
        )

    m31_ok = np.all(check_mask_m31 == ~np.isnan(fallin_snaps_m31[snap_id]))
    mw_ok = np.all(check_mask_mw == ~np.isnan(fallin_snaps_mw[snap_id]))
    print(f"snap {snap_id}: M31 {m31_ok}, MW {mw_ok}")

### Inspect coordinates at fall-in

At fall-in, satellites should trivially lie on a sphere of radius 300 ckpc around the central. The exceptions are satellites that formed as satellites, and those that were already satellites at `min(snap_ids)`.

In [None]:
# Satellite selection at z = 0: subhalos within 300 (comoving, since comov=True)
# of M31 and of the MW, plus the remaining subhalos (`mask_isol`, presumably
# the isolated ones). These masks are used by all the cells below.
mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
    sim.get_snapshot(snap_z0), m31_id_z0, mw_id_z0, sat_r=300, comov=True
)
# Same selection with the older implementation, for comparison.
# NOTE(review): mask_sats2 / mask_isol_2 are not used anywhere else in this
# notebook -- drop this call if the comparison is no longer needed.
mask_sats2, mask_isol_2 = dataset_comp.split_satellites_by_distance_old(
    sim.get_snapshot(snap_z0), m31_id_z0, mw_id_z0, max_dist_sat=300
)

In [None]:
# Distance of every subhalo to each host, per traced snapshot, in kpc.
# Assumes `distance_to_self` returns per-subhalo offset vectors in cm
# (hence the cm -> kpc factor) -- TODO confirm.
dist_m31, dist_mw = (
    {sid: np.linalg.norm(d, axis=1) * units.cm.to(units.kpc)
     for sid, d in host.distance_to_self(snap_ids).items()}
    for host in (m31, mw)
)

Compare the two methods of extracting the distance at fall-in:

In [None]:
# Method 1: take the z = 0 satellites of M31 and look up their distances at
# their fall-in snapshots via `get_subhalos_at_fallin`.
m31_sats_z0 = sub_dict[snap_z0][mask_m31]
print(dataset_comp.get_subhalos_at_fallin(m31_sats_z0, fallin_snaps_m31, dist_m31))

In [None]:
# Method 2: pick the distances at fall-in using the precomputed fall-in
# indices; the result is indexed by snapshot like `dist_m31`. Show the values
# for the z = 0 satellites of M31.
fallin_dist_m31 = dataset_comp.data_at_fallin(fallin_snaps_m31, fallin_inds_m31, dist_m31)
print(fallin_dist_m31[snap_z0][mask_m31])

In [None]:
# Method 2 for the MW: distances at fall-in, shown for the z = 0 MW satellites.
fallin_dist_mw = dataset_comp.data_at_fallin(fallin_snaps_mw, fallin_inds_mw, dist_mw)
print(fallin_dist_mw[snap_z0][mask_mw])

## Plot radial evolution with fall-in

Define masking arrays that select a random sample of the satellites of M31 and of the MW:

In [None]:
def random_mask(mask, n, rng=None):
    """Randomly subsample a boolean selection.

    From the items selected by ``mask``, keep ``n`` of them chosen uniformly at
    random (or all of them, if ``mask`` selects fewer than ``n`` items).

    Parameters
    ----------
    mask : array_like of bool
        The initial selection.
    n : int
        Number of items to keep. Must be non-negative.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state, so
        results are reproducible after ``np.random.seed(...)``.

    Returns
    -------
    numpy.ndarray of bool
        Mask with the same shape as ``mask`` and ``min(n, mask.sum())`` True
        entries, each of which is also True in ``mask``.

    Raises
    ------
    ValueError
        If ``n`` is negative (a negative slice bound would otherwise silently
        select ``k - |n|`` items).
    """
    mask = np.asarray(mask, dtype=bool)
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if rng is None:
        rng = np.random

    # Shuffle a vector with min(n, k) True values over the k selected slots...
    k = np.count_nonzero(mask)
    mask_rand = np.zeros(k, dtype=bool)
    mask_rand[:min(n, k)] = True
    rng.shuffle(mask_rand)

    # ...and scatter it back onto the selected positions of the full mask.
    mask_new = np.zeros(mask.shape, dtype=bool)
    mask_new[mask] = mask_rand
    return mask_new

In [None]:
# For every snapshot, an array repeating that snapshot's redshift once per
# subhalo, so it can be indexed like the other per-subhalo datasets.
redshift = dict(
    (snap.snap_id,
     np.full(snap.get_subhalo_number(), snap.get_attribute("Redshift", "Header")))
    for snap in sim.get_snapshots(snap_ids)
)
# Redshift of each subhalo at its fall-in into M31 / the MW.
fallin_z_m31, fallin_z_mw = (
    dataset_comp.data_at_fallin(fallin_snaps, fallin_inds, redshift)
    for fallin_snaps, fallin_inds in ((fallin_snaps_m31, fallin_inds_m31),
                                      (fallin_snaps_mw, fallin_inds_mw))
)

In [None]:
# Fix the global NumPy seed so that the random satellite samples (drawn by
# `random_mask` through the global `np.random` state) are identical on every run.
SAMPLE_SEED = 42
N_SAMPLE_SATS = 15  # number of satellites per host to plot
np.random.seed(SAMPLE_SEED)

m31_sats = sub_dict[snap_z0][random_mask(mask_m31, N_SAMPLE_SATS)]
mw_sats = sub_dict[snap_z0][random_mask(mask_mw, N_SAMPLE_SATS)]


def _satellite_datasets(sats, dist, fallin_dist, fallin_z):
    """Collect the per-satellite histories that are plotted below.

    For each Subhalo in ``sats``, extract with
    ``dataset_comp.subhalo_dataset_from_dict`` its redshift history, its
    distance history to the host, and its distance and redshift at fall-in.
    Only element ``[0]`` of each extraction is kept (presumably the data
    array itself -- TODO confirm against dataset_comp).
    """
    def extract(data_dict):
        return [dataset_comp.subhalo_dataset_from_dict(sat, data_dict)[0]
                for sat in sats]

    return {
        "Redshift": extract(redshift),
        "Distance": extract(dist),
        "D_at_fallin": extract(fallin_dist),
        "Z_at_fallin": extract(fallin_z),
    }


data = {
    "M31_Satellites": _satellite_datasets(
        m31_sats, dist_m31, fallin_dist_m31, fallin_z_m31
    ),
    "MW_Satellites": _satellite_datasets(
        mw_sats, dist_mw, fallin_dist_mw, fallin_z_mw
    ),
}

In [None]:
fig, ax = plt.subplots()
ax.invert_xaxis()  # high redshift on the left, so time runs left to right

# Reference line: 300 ckpc converted to physical kpc by multiplying with the
# header 'Time' (presumably the expansion factor a).
a = sim.get_attribute('Time', 'Header', snap_ids)
z = sim.get_attribute('Redshift', 'Header', snap_ids)
ax.plot(z, a * 300, c='gray', linestyle='--', label='300 ckpc')

# Distance history of each sampled satellite. The loop variables have distinct
# names so that the snapshot redshifts `z` above are not overwritten.
for sat_dist, sat_z in zip(data["M31_Satellites"]["Distance"],
                           data["M31_Satellites"]["Redshift"]):
    ax.plot(sat_z, sat_dist)

# Distance at fall-in of each sampled satellite:
for fallin_dist, fallin_z in zip(data["M31_Satellites"]["D_at_fallin"],
                                 data["M31_Satellites"]["Z_at_fallin"]):
    ax.scatter(fallin_z, fallin_dist, c='black', s=15)

ax.set_xlabel('Redshift')
ax.set_ylabel('Distance to M31 [kpc]')
ax.set_title('Sampled M31 satellites: radial evolution and fall-in')
ax.legend();

In [None]:
fig, ax = plt.subplots()
ax.invert_xaxis()  # high redshift on the left, so time runs left to right

# Reference line: 300 ckpc converted to physical kpc by multiplying with the
# header 'Time' (presumably the expansion factor a).
a = sim.get_attribute('Time', 'Header', snap_ids)
z = sim.get_attribute('Redshift', 'Header', snap_ids)
ax.plot(z, a * 300, c='gray', linestyle='--', label='300 ckpc')

# Distance history of each sampled satellite. The loop variables have distinct
# names so that the snapshot redshifts `z` above are not overwritten.
for sat_dist, sat_z in zip(data["MW_Satellites"]["Distance"],
                           data["MW_Satellites"]["Redshift"]):
    ax.plot(sat_z, sat_dist)

# Distance at fall-in of each sampled satellite:
for fallin_dist, fallin_z in zip(data["MW_Satellites"]["D_at_fallin"],
                                 data["MW_Satellites"]["Z_at_fallin"]):
    ax.scatter(fallin_z, fallin_dist, c='black', s=15)

ax.set_xlabel('Redshift')
ax.set_ylabel('Distance to MW [kpc]')
ax.set_title('Sampled MW satellites: radial evolution and fall-in')
ax.legend();
