## First, imports:

In [None]:
# Reload edited Python modules automatically before executing code:
%load_ext autoreload
%autoreload 2

# Enable IPython's greedy tab completion:
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import the local `apostletools` library:

In [None]:
import os
import sys

# Make the `apostletools` directory (one level above the current working
# directory) importable. The membership check keeps repeated executions of
# this cell from appending the same entry to sys.path again and again.
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
if apt_path not in sys.path:
    sys.path.append(apt_path)

import simulation
import simtrace
import dataset_comp
import subhalo

In [None]:
import importlib

# Re-import the project modules, in the same order as before, so that edits
# to their source take effect without restarting the kernel.
for module in (simulation, simtrace, dataset_comp, subhalo):
    importlib.reload(module)

# Accumulation of Satellite Subhalos



## Setting variables

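Let us first define the simulation and the Local Group (LG) central galaxies. At redshift zero (snap_id=127), M31 has identifier (1,0) in every run, while the MW has identifier (1,1) in the curvaton (p082) runs and (2,0) in the LCDM runs. Each of the following cells selects one simulation; execute only the one you want (if several are run, the last one wins):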

In [None]:
# V1 MR curvaton (p082) run; the MW central is (1,1) here.
sim_id, sim_name = "V1_MR_curvaton_p082_fix", "p082"
snap_ref = 127  # redshift-zero snapshot
m31_ref, mw_ref = (1, 0), (1, 1)

In [None]:
# V1 LR curvaton (p082) run; the MW central is (1,1) here.
sim_id, sim_name = "V1_LR_curvaton_p082_fix", "p082"
snap_ref = 127  # redshift-zero snapshot
m31_ref, mw_ref = (1, 0), (1, 1)

In [None]:
# V1 MR LCDM run; the MW central is (2,0) here.
sim_id, sim_name = "V1_MR_fix", "LCDM"
snap_ref = 127  # redshift-zero snapshot
m31_ref, mw_ref = (1, 0), (2, 0)

In [None]:
# V1 LR LCDM run; the MW central is (2,0) here.
sim_id, sim_name = "V1_LR_fix", "LCDM"
snap_ref = 127  # redshift-zero snapshot
m31_ref, mw_ref = (1, 0), (2, 0)

In [None]:
# Load the selected simulation and list the snapshot ids it provides:
sim = simulation.Simulation(sim_id)
snap_ids = sim.get_snap_ids()
print(snap_ids)

In [None]:
# Tracers for the two LG centrals, anchored at the reference snapshot; the
# two-element identifier tuples are unpacked into the last two arguments.
m31 = subhalo.SubhaloTracer(sim, snap_ref, *m31_ref)
mw = subhalo.SubhaloTracer(sim, snap_ref, *mw_ref)

Set the range of snapshots considered, and build the merger tree:

In [None]:
# Range of traced snapshots. snap_stop is exclusive: it is the upper bound
# passed to np.arange when building traced_snaps.
snap_start, snap_stop = 101, 128

In [None]:
# Print the group-file attribute of the first traced snapshot. Uses
# snap_start instead of a hard-coded 101 so it follows the chosen range.
s = sim.get_snapshot(snap_start)
print(s.grp_file)

In [None]:
# Build a backward-branching merger tree between snap_start and snap_stop:
branching_mode = "BackwardBranching"
mtree = simtrace.MergerTree(sim, branching=branching_mode)
mtree.build_tree(snap_start, snap_stop)

In [None]:
# Trace both LG centrals (M31 first, then the MW) through the merger tree:
for central in (m31, mw):
    central.trace(mtree)

## Merger Events

Find all merger events by iterating backwards in time:

In [None]:
# Confirm which simulation is loaded:
loaded_id = sim.sim_id
print(loaded_id)

In [None]:
# Walk backwards in time over the traced snapshots and report, per snapshot,
# how many subhalos have a match in progenitor column 1 or 2 (presumably
# secondary progenitors, i.e. mergers; confirm against simtrace.MergerTree).
#
# The loop bounds come from snap_start/snap_stop directly. The previous
# version reassigned the global `snap_stop = 100`, which made the later
# `np.arange(snap_start, snap_stop)` empty so that nothing was traced.
for sid in range(snap_stop - 1, snap_start - 1, -1):
    snap = sim.get_snapshot(sid)
    prog = snap.get_subhalos('Progenitors', mtree.h5_group)
    mask_merger = np.logical_or(prog[:, 1] != mtree.no_match,
                                prog[:, 2] != mtree.no_match)
    # Snapshot id and number of subhalos flagged as mergers:
    print(sid, np.sum(mask_merger))
    # Number of subhalos and how many have a match in progenitor column 0:
    print(np.size(prog, axis=0), np.sum(prog[:, 0] != mtree.no_match))

## ... Moving on

In [None]:
# FoF "Group_R_Mean200" radius of M31 and then the MW, converted cm -> kpc:
cm_to_kpc = units.cm.to(units.kpc)
for central in (m31, mw):
    print(central.get_fof_data("Group_R_Mean200") * cm_to_kpc)

In [None]:
# FoF "Group_R_Crit200" radius of M31 and then the MW, converted cm -> kpc:
cm_to_kpc = units.cm.to(units.kpc)
for central in (m31, mw):
    print(central.get_fof_data("Group_R_Crit200") * cm_to_kpc)

In [None]:
# FoF "Group_R_TopHat200" radius of M31 and then the MW, converted cm -> kpc:
cm_to_kpc = units.cm.to(units.kpc)
for central in (m31, mw):
    print(central.get_fof_data("Group_R_TopHat200") * cm_to_kpc)

In [None]:
# Number of subhalos in the reference snapshot. Uses snap_ref instead of a
# hard-coded 127 so it stays consistent with the selected configuration.
print(mtree.simulation.get_snapshot(snap_ref).get_subhalo_number())

In [None]:
# Show the merger tree's group path joined with its branching name, and the
# type of the tree's "no match" sentinel value:
tree_path = "/".join((mtree.h5_group, mtree.branching))
print(tree_path)
print(type(mtree.no_match))

In [None]:
import datafile_oper

# Show the path returned by datafile_oper.path_to_extended():
extended_path = datafile_oper.path_to_extended()
print(extended_path)

Make tracers for the traced snapshots:

In [None]:
# One tracer array per traced snapshot (snap_stop is exclusive).
traced_snaps = np.arange(snap_start, snap_stop)


def _tracer_array_for(sid):
    """Trace the subhalos of snapshot `sid` over the traced range."""
    tracer = simtrace.SnapshotTracer(sid, mtree)
    tracer.trace(start=snap_start, stop=snap_stop)
    return tracer.tracer_array


snap_tracers = {sid: _tracer_array_for(sid) for sid in traced_snaps}

## Selecting the satellites

In [None]:
# Identifiers of M31 and the MW at the first traced snapshot. Uses snap_start
# instead of a hard-coded 101 so it follows the chosen range.
for central in (m31, mw):
    print(central.get_identifier(snap_start))

In [None]:
# Restrict every snapshot's tracer array to rows that pass two masks:
#   * the distance-based satellite split around M31 and the MW (presumably
#     one mask per central, OR-ed together here), and
#   * dataset_comp.prune_vmax with low_lim=15, which excludes the smallest
#     subhalos.
# NOTE(review): snap_tracers is filtered in place, so running this cell a
# second time applies full-snapshot masks to already-filtered arrays.
# Rebuild snap_tracers before re-running.
# TODO: CHECK PERIODIC WRAP: is the h scale exponent right in boxsize?
for sid in snap_tracers:
    print(sid)
    snap_data = sim.get_snapshot(sid)
    sat_masks, _ = dataset_comp.split_satellites_by_distance(
        snap_data, m31.get_identifier(sid), mw.get_identifier(sid))
    is_satellite = np.logical_or.reduce(sat_masks)

    is_large = dataset_comp.prune_vmax(snap_data, low_lim=15)

    keep = np.logical_and(is_satellite, is_large)
    snap_tracers[sid] = snap_tracers[sid][keep]

## Count satellites

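At each snapshot, count the satellites that first appear there (i.e. that were not descended from a satellite of the previous snapshot), and follow how many of them are still satellites at every later snapshot: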

In [None]:
# snap_sat_cnt[i, j] holds the number of satellites that first appear at
# traced_snaps[i] and are still selected as satellites at traced_snaps[j]
# (j >= i). Entries with j < i stay zero. Row i becomes one band of the
# stack plot.
n_snaps = len(traced_snaps)
snap_sat_cnt = np.zeros((n_snaps, n_snaps))

for i, snap in enumerate(traced_snaps):
    tracer = snap_tracers[snap]

    # New satellites at `snap`: those whose identifier in column `snap` is not
    # among the column-`snap` identifiers of the previous snapshot's
    # satellites. At the first snapshot, every satellite counts as new. The
    # previous tracer is only fetched when it exists; the old code silently
    # wrapped around to the last snapshot at i == 0.
    if i == 0:
        mask_new_sat = np.ones(np.size(tracer, axis=0), dtype=bool)
    else:
        prev_tracer = snap_tracers[traced_snaps[i - 1]]
        mask_new_sat = np.logical_not(
            np.isin(tracer[:, snap], prev_tracer[:, snap]))
    snap_sat_cnt[i, i] = np.sum(mask_new_sat)

    # Follow the new satellites through the later snapshots. A satellite
    # survives while its identifier in column `snap_next` (presumably its
    # descendant there) is still in that snapshot's satellite selection.
    mask_surviving = mask_new_sat
    for j, snap_next in enumerate(traced_snaps[i + 1:], i + 1):
        mask_surviving = np.logical_and(
            mask_surviving,
            np.isin(tracer[:, snap_next],
                    snap_tracers[snap_next][:, snap_next]))
        snap_sat_cnt[i, j] = np.sum(mask_surviving)

## Plot

In [None]:
# Path of the output figure:
#   <directory of the simulation module>/Figures/LowResolution/
#       satellite_fates_stack_from_<snap_start>_<sim_name>
# NOTE(review): the "LowResolution" subdirectory is hard-coded even when an
# MR simulation is selected. Confirm this is intended.
home = os.path.dirname(simulation.__file__)
path = os.path.join(home, "Figures", "LowResolution")
filename = os.path.join(
    path, 'satellite_fates_stack_from_{}_{}'.format(snap_start, sim_name))

In [None]:
# Stack plot of the LG satellite population against the age of the
# Universe. Each band (row of snap_sat_cnt) holds the satellites that first
# appeared at one traced snapshot.
fig, ax = plt.subplots(figsize=(6, 3), dpi=200)

# Redshift of every traced snapshot, read from the snapshot headers:
redshift = [sim.get_snapshot(snap_id).get_attribute("Redshift", "Header")
            for snap_id in traced_snaps]

# Cosmology from the reference snapshot's header, used to turn redshift into
# cosmic age (astropy returns Gyr). snap_ref replaces a hard-coded 127.
ref_snapshot = sim.get_snapshot(snap_ref)
H0 = ref_snapshot.get_attribute("HubbleParam", "Header") * 100
Om0 = ref_snapshot.get_attribute("Omega0", "Header")
cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)
age = [cosmo.age(z).value for z in redshift]

# One colour per origin snapshot. The previous Blues colormap assignment
# was dead code because it was immediately overwritten by this one.
colors = plt.cm.viridis(np.linspace(0, 1, traced_snaps.size))

ax.stackplot(age, snap_sat_cnt, colors=colors, edgecolor='white',
             linestyle=':', linewidth=0.3)
# Faint vertical line at every traced snapshot:
for a in age:
    ax.axvline(a, c='black', linestyle=':', linewidth=0.3)

# Secondary x-axis labelled with the redshift of every fourth snapshot:
ax2 = ax.twiny()
ax2.set_xticks(age[::4])
ax2.set_xticklabels(['{:.2f}'.format(z) for z in redshift[::4]])

ax.set_xlim(min(age), max(age))
ax2.set_xlim(min(age), max(age))
ax.set_ylim(0, 100)

# Raw string: "\m" is not a valid Python escape sequence (SyntaxWarning on
# Python 3.12+). The threshold matches prune_vmax(low_lim=15) in the
# satellite selection.
ax.text(0.2, 0.9, r"$v_\mathrm{max} > 15\,\mathrm{km/s}$",
        horizontalalignment='center', verticalalignment='center',
        transform=ax.transAxes)
ax.set_xlabel('Age of the Universe [Gyr]')
ax2.set_xlabel('Redshift')
ax.set_ylabel('Number of LG satellites')

# plt.savefig(filename, dpi=200)

In [None]:
# Distribution of the satellites present at the last traced snapshot over
# the snapshot (cosmic age) at which each of them first appeared.
fig, ax = plt.subplots()

# Last column of snap_sat_cnt: satellites from each origin snapshot that are
# still satellites at the final traced snapshot.
y = snap_sat_cnt[:, -1]
total = np.sum(y)
# Normalise to fractions. Skip the division when nothing survives, which
# would otherwise produce an all-NaN curve (0/0).
y = y / total if total > 0 else y
x = age
ax.plot(x, y)
ax.set_xlabel('Age of the Universe [Gyr]')
ax.set_ylabel('Fraction of final satellites')