## First, imports:

In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import my library:

In [None]:
import os
import sys

apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
sys.path.append(apt_path)

import simulation
import simtrace
import match_halo
import dataset_comp

In [None]:
import importlib

# Reload the local apostletools modules so that edits to them take effect
# without restarting the kernel. `%autoreload 2` (first cell) already does
# this automatically; this loop is a manual fallback. Looping (instead of a
# trailing bare `importlib.reload(...)` call) keeps a module repr from being
# displayed as the cell output.
for _module in (simulation, simtrace, match_halo, dataset_comp):
    importlib.reload(_module)

# Linking Subhalos at Present Backward in Time

How far back in time are the subhalos that we see at present linked? How strongly does this depend on the following distinctions:
- Isolated subhalos vs. satellites
- Luminous vs. dark
- $v_\mathrm{max} < 15~\mathrm{km/s}$ vs. $v_\mathrm{max} > 15~\mathrm{km/s}$ (or $v_\mathrm{max} < 10~\mathrm{km/s}$ vs. $v_\mathrm{max} > 10~\mathrm{km/s}$ for satellites)?
- power-law inflation initial conditions (ICs) vs. two-period inflation ICs?

These are the questions explored in this notebook. I will make a simple figure showing the fraction of a given subset of subhalos, selected at the reference snapshot `snap_id_ref`, that is still linked at different lookback times.

In [None]:
# Reference snapshot (where the subhalo selections are made) and the snapshot
# identified with redshift z=0 (where M31 and the MW are identified).
snap_id_ref, snap_id_z0 = 100, 127

---

### MR Simulations

Set the envelope file path, and define the M31 and the MW at redshift $z=0$:

In [None]:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))

# Medium-resolution runs. Each entry holds the Simulation object, the plot
# colors as [primary, secondary], and the identifiers of M31 and the MW at
# z=0. The identifier pairs are passed to `index_of_halo` (presumably
# (group number, subgroup number) — confirm against the simulation module).
# NOTE: the LR cell further down also assigns `data`; if both cells run
# (e.g. Run All), the LR selection replaces this one.
data = dict()
data["plain-LCDM"] = {
    "Simulation": simulation.Simulation("V1_MR_fix", env_path=env_path),
    "Color": ['black', 'gray'],
    "M31_z0": (1, 0),
    "MW_z0": (2, 0),
}
data["curv-p082"] = {
    "Simulation": simulation.Simulation("V1_MR_curvaton_p082_fix", env_path=env_path),
    "Color": ['red', 'pink'],
    "M31_z0": (1, 0),
    "MW_z0": (1, 1),
}

---

### LR Simulations

Set the envelope file path, and define the M31 and the MW at redshift $z=0$:

In [None]:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))

# Low-resolution runs. Each entry holds the Simulation object, the plot
# colors as [primary, secondary], and the identifiers of M31 and the MW at
# z=0, which are passed to `index_of_halo` (presumably (group number,
# subgroup number) — confirm against the simulation module).
# Running this cell replaces any previously assigned `data` (e.g. the MR runs).
# NOTE(review): "curv-p084-LR" uses (1, 0) for BOTH M31_z0 and MW_z0, i.e. the
# same halo for M31 and the MW — looks like a copy-paste slip; confirm the MW id.
data = dict()
data["plain-LCDM-LR"] = {
    "Simulation": simulation.Simulation("V1_LR_fix", env_path=env_path),
    "Color": ['black', 'gray'],
    "M31_z0": (1, 0),
    "MW_z0": (2, 0),
}
data["curv-p082-LR"] = {
    "Simulation": simulation.Simulation("V1_LR_curvaton_p082_fix", env_path=env_path),
    "Color": ['red', 'pink'],
    "M31_z0": (1, 0),
    "MW_z0": (1, 1),
}
data["curv-p084-LR"] = {
    "Simulation": simulation.Simulation("V1_LR_curvaton_p084_fix", env_path=env_path),
    "Color": ['blue', 'lightblue'],
    "M31_z0": (1, 0),
    "MW_z0": (1, 0),
}

---

## Tracing

Set the range of snapshots to be traced:

In [None]:
snap_start = 100
snap_stop = 128
snap_ids = np.arange(snap_start, snap_stop)  # excludes snap_stop

# Later cells look up snap_id_ref in `snap_ids` (np.nonzero(...)[0][0]) and
# index the traced subhalo dictionaries with snap_id_ref / snap_id_z0. Fail
# early with a clear message instead of an IndexError/KeyError downstream.
_outside = [s for s in (snap_id_ref, snap_id_z0) if s not in snap_ids]
if _outside:
    raise ValueError(
        "Snapshots {} lie outside the traced range [{}, {})".format(
            _outside, snap_start, snap_stop))

Link all subhalos in the simulation, create Subhalo objects for all the individual subhalos found, and write pointers to these objects for each snapshot:

In [None]:
# One matcher shared by all simulations; its settings are passed straight
# through to match_halo.SnapshotMatcher.
matcher = match_halo.SnapshotMatcher(n_link_ref=20, n_matches=1)

for sim_data in data.values():
    sim = sim_data["Simulation"]

    # Build the merger tree between snap_start and snap_stop. This runs
    # unconditionally, whether or not the simulation was linked before.
    merger_tree = simtrace.MergerTree(sim, matcher=matcher, branching="BackwardBranching")
    merger_tree.build_tree(snap_start, snap_stop)

    # Trace the subhalos over the same range and keep the result; the M31
    # and MW Subhalo objects are picked out in the next cell.
    sim_data["Subhalos"] = sim.trace_subhalos(snap_start, snap_stop)

Get the M31 and the MW halos and compute masking arrays for their satellites (and for isolated subhalos) at `snap_id_ref`:

In [None]:
def _halo_at_z0(sim, sub_dict, halo_id):
    """Return the Subhalo object at snapshot `snap_id_z0` for an id pair.

    `halo_id` is the pair passed to `index_of_halo` of the z=0 snapshot; the
    resulting index selects from the traced subhalos at that snapshot.
    """
    subhalos_z0 = sub_dict[snap_id_z0]
    halo_index = sim.get_snapshot(snap_id_z0).index_of_halo(halo_id[0], halo_id[1])
    return subhalos_z0[halo_index]


for sim_data in data.values():
    sim = sim_data["Simulation"]
    sub_dict = sim_data["Subhalos"]

    # M31 and the MW, identified at z=0:
    m31 = _halo_at_z0(sim, sub_dict, sim_data["M31_z0"])
    sim_data["M31"] = m31
    mw = _halo_at_z0(sim, sub_dict, sim_data["MW_z0"])
    sim_data["MW"] = mw

    # Their group numbers at the reference snapshot, where the satellite /
    # isolated split is made (sat_r, isol_r and comov are passed through to
    # dataset_comp.split_satellites_by_distance).
    m31_group_ref = m31.get_group_number_at_snap(snap_id_ref)
    mw_group_ref = mw.get_group_number_at_snap(snap_id_ref)
    mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
        sim.get_snapshot(snap_id_ref), m31_group_ref, mw_group_ref,
        sat_r=300, isol_r=2000, comov=True
    )
    print(np.sum(mask_m31), np.sum(mask_mw), np.sum(mask_isol))

    sim_data["Ref_Selections"] = {
        "M31_Satellites": mask_m31,
        "MW_Satellites": mask_mw,
        "LG_Satellites": np.logical_or(mask_m31, mask_mw),
        "Isolated": mask_isol,
    }

---

## Retrieve the Datasets

Read all datasets into dictionaries by snapshot:

In [None]:
# Define the cosmology from the snapshot headers. A single `cosmo` is used for
# every simulation below, so all runs must agree on these parameters; collect
# them from each run and check, instead of silently keeping the last one.
cosmo_params = []
for sim_data in data.values():
    header_snap = sim_data["Simulation"].get_snapshot(snap_stop - 1)
    cosmo_params.append((header_snap.get_attribute("HubbleParam", "Header"),
                         header_snap.get_attribute("Omega0", "Header")))

if not np.allclose(cosmo_params, cosmo_params[0]):
    raise ValueError(
        "Simulations disagree on (HubbleParam, Omega0): {}".format(cosmo_params))
H0, Om0 = cosmo_params[0]

# HubbleParam is presumably the dimensionless h, hence H0 = 100 h (km/s/Mpc).
cosmo = FlatLambdaCDM(H0=100 * H0, Om0=Om0)

In [None]:
# v_max thresholds [km/s], passed to dataset_comp.prune_vmax as `low_lim`.
# NOTE(review): the introduction and the plotting section state v* = 15 km/s
# for isolated subhalos, but isol_low_lim is 10 here — confirm which is intended.
sat_low_lim = 10
isol_low_lim = 10

for sim_data in data.values():
    sim = sim_data["Simulation"]

    # Snapshot redshifts and lookback times, age(z=0) - age(z); cosmo.age
    # accepts an array of redshifts, so no per-snapshot Python loop is needed.
    redshift = sim.get_attribute("Redshift", "Header", snap_ids)
    lookback_time = cosmo.age(0).value - cosmo.age(np.asarray(redshift)).value
    sim_data["Redshift"] = redshift
    sim_data["LookbackTime"] = lookback_time

    # For each subhalo present at snap_id_ref: the latest (destruction) and
    # earliest (formation) snapshot in which it exists — element [1] of the
    # tuples returned by index_at_destruction / index_at_formation.
    subhalos_ref = sim_data["Subhalos"][snap_id_ref]
    sim_data["DestructionSnapshot"] = np.array(
        [sub.index_at_destruction()[1] for sub in subhalos_ref])
    sim_data["FormationSnapshot"] = np.array(
        [sub.index_at_formation()[1] for sub in subhalos_ref])

    # Masking arrays over the subhalos at snap_id_ref:
    snap_ref = sim.get_snapshot(snap_id_ref)
    mask_lum, mask_dark = dataset_comp.split_luminous(snap_ref)
    sim_data["Ref_Selections"].update({
        "Vmax_Sat": dataset_comp.prune_vmax(snap_ref, low_lim=sat_low_lim),
        "Vmax_Isol": dataset_comp.prune_vmax(snap_ref, low_lim=isol_low_lim),
        "Luminous": mask_lum,
        "Dark": mask_dark
    })

---

## Plot Counts of Linked Subhalos

The linking counts are plotted separately for satellites and isolated subhalos. Below, in the case of the different plot lines, $N_{[\cdot]}(z)$ stands for 
- the total number $N_\mathrm{tot}$ of (satellite or isolated) subhalos
- \# subhalos with $v_\mathrm{max} > v^*$
- \# (luminous) galaxies with $v_\mathrm{max} > v^*$, 

where $v^*$ = 10 km/s for satellites and $v^*$ = 15 km/s for isolated.

### Set Plot Parameters

Define a function for counting:

In [None]:
def count_subhalos(formation_snap, destruction_snap, absolute=True,
                   snaps=None, ref_index=-1):
    """Count the subhalos that exist at each snapshot.

    A subhalo is counted at snapshot ``s`` if it formed at or before ``s``
    and was destroyed at or after ``s``.

    Parameters
    ----------
    formation_snap, destruction_snap : array_like of int
        Formation and destruction snapshot IDs, one entry per subhalo.
    absolute : bool, optional
        If True (default), return raw counts; otherwise return the counts
        divided by the count at ``ref_index``.
    snaps : array_like of int, optional
        Snapshot IDs at which to count. Defaults to the global ``snap_ids``.
    ref_index : int, optional
        Index into ``snaps`` of the normalizing snapshot when ``absolute`` is
        False. Defaults to -1 (the last snapshot), the previous behavior.

    Returns
    -------
    numpy.ndarray
        One count (or fraction) per entry of ``snaps``.
    """
    if snaps is None:
        snaps = snap_ids
    snaps = np.asarray(snaps)
    formation_snap = np.asarray(formation_snap)
    destruction_snap = np.asarray(destruction_snap)

    # (n_subhalos, 1) broadcast against (n_snaps,) -> (n_subhalos, n_snaps)
    alive = ((formation_snap[:, np.newaxis] <= snaps)
             & (destruction_snap[:, np.newaxis] >= snaps))
    counts = alive.sum(axis=0)
    if absolute:
        return counts

    return counts / counts[ref_index]

In [None]:
# Font sizes for titles, axis labels, tick labels and legends; these are
# applied to plt.rcParams when the figure is created.
parameters = dict([
    ('axes.titlesize', 12),
    ('axes.labelsize', 10),
    ('xtick.labelsize', 8),
    ('ytick.labelsize', 8),
    ('legend.fontsize', 10),
])

### Create Blank Plot

In [None]:
# Set fonts:
plt.rcParams.update(parameters)

# Two panels sharing both axes: satellites (left) and isolated subhalos
# (right). No plt.tight_layout() here: called before a figure exists it acts
# on an implicitly created, empty figure, and the final figure is saved with
# bbox_inches='tight' anyway.
fig, axes = plt.subplots(ncols=2, figsize=(7, 4), sharey=True, sharex=True)
fig.subplots_adjust(wspace=0.05)

# Shared axes: these settings apply to both panels.
axes[0].invert_xaxis()
axes[0].set_yscale('log')
axes[0].set_ylim(0.005, 1.25)

for ax in axes:
    ax.yaxis.set_ticks_position('both')
    ax.set_xlabel("Lookback Time [Gyr]")

# Raw strings: "\c" and "\m" are invalid escape sequences in normal strings.
axes[0].set_ylabel(r"$N_{[\cdot]}(z) ~/~ N_\mathrm{tot}(z_\mathrm{ref})$")
axes[0].set_title("Satellite Subhalos")
axes[1].set_title("Isolated Subhalos")

### ... And Plot

In [None]:
# Index of the reference snapshot within `snap_ids` (identical for every
# simulation, so computed once).
idx_ref = np.nonzero(snap_ids == snap_id_ref)[0][0]


def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else np.nan


def _plot_population(ax, sim_data, population_key, vmax_key):
    """Plot linked-subhalo fractions of one subhalo population on `ax`.

    Three curves are drawn against lookback time, each divided by the number
    of subhalos in the population at the reference snapshot:
      - all subhalos in `population_key` (dotted, secondary color),
      - those also passing the `vmax_key` cut (solid, secondary color),
      - those passing the cut that are also luminous (solid, primary color).

    `population_key` and `vmax_key` are keys into sim_data["Ref_Selections"].
    Returns the three absolute count arrays, in the order listed above.
    """
    selections = sim_data["Ref_Selections"]
    form_snap = sim_data["FormationSnapshot"]
    dest_snap = sim_data["DestructionSnapshot"]
    colors = sim_data["Color"]

    masks = [
        selections[population_key],
        np.logical_and(selections[population_key], selections[vmax_key]),
        np.logical_and.reduce([selections[population_key],
                               selections[vmax_key],
                               selections["Luminous"]]),
    ]
    styles = [
        {"c": colors[1], "linestyle": "dotted"},
        {"c": colors[1]},
        {"c": colors[0]},
    ]

    counts = [count_subhalos(form_snap[mask], dest_snap[mask]) for mask in masks]
    n_ref = counts[0][idx_ref]
    for cnt, style in zip(counts, styles):
        ax.plot(sim_data["LookbackTime"], cnt / n_ref, **style)
    return counts


def _print_summary(header, counts):
    """Print the total at the reference snapshot and, per selection, the
    fraction still present at the last traced snapshot."""
    n_all, n_vmax, n_lum = counts
    print("{}: \n".format(header) +
          "\t Total number: {} \n".format(n_all[idx_ref]) +
          "\t Fraction traced: {} \n".format(_ratio(n_all[-1], n_all[idx_ref])) +
          "\t ...of massive: {} \n".format(_ratio(n_vmax[-1], n_vmax[idx_ref])) +
          "\t ...of massive and luminous: {} \n".format(_ratio(n_lum[-1], n_lum[idx_ref])))


for sim_name, sim_data in data.items():
    # Satellites (left panel):
    sat_counts = _plot_population(axes[0], sim_data, "LG_Satellites", "Vmax_Sat")
    _print_summary("{} Satellites".format(sim_name), sat_counts)

    # Isolated subhalos (right panel):
    isol_counts = _plot_population(axes[1], sim_data, "Isolated", "Vmax_Isol")
    _print_summary("{} Isolated galaxies".format(sim_name), isol_counts)

fig

### Add Legends

In [None]:
# Satellite panel legend: empty proxy lines showing the three line styles,
# drawn in the first simulation's colors.
leg_col = list(data.values())[0]["Color"]
style_lines = [
    axes[0].plot([], [], c=leg_col[1], linestyle='dotted')[0],
    axes[0].plot([], [], c=leg_col[1])[0],
    axes[0].plot([], [], c=leg_col[0])[0],
]
# Raw strings: "\m" is an invalid escape sequence in a normal string literal.
style_labels = [
    "Total",
    r"$v_\mathrm{{max}} > {} ~\mathrm{{km/s}}$".format(sat_low_lim),
    r"$v_\mathrm{{max}} > {} ~\mathrm{{km/s}}$ and luminous".format(sat_low_lim),
]
axes[0].legend(style_lines, style_labels, loc="lower left")

# Isolated panel legend: one proxy line per simulation, in its primary color.
sim_lines = [axes[1].plot([], [], c=sim_data["Color"][0])[0]
             for sim_data in data.values()]
axes[1].legend(sim_lines, list(data.keys()), loc="best")

fig

---

## Save the Figure

In [None]:
# NOTE(review): the subfolder is fixed to 'MediumResolution', but if the LR
# cell has run (it reassigns `data` after the MR cell) the figure shows LR
# runs — confirm the intended output folder.
fig_dir = os.path.abspath(os.path.join('..', 'Figures', 'MediumResolution'))
# savefig does not create missing directories.
os.makedirs(fig_dir, exist_ok=True)

filename = os.path.join(fig_dir, "linking_counts_refsnap{}.png".format(snap_id_ref))
fig.savefig(filename, dpi=300, bbox_inches='tight')