In [None]:
# Load IPython's autoreload extension and set it to mode 2, so that imported
# modules are reloaded before each cell is executed.
%load_ext autoreload
%autoreload 2

# Turn on the greedy tab-completion option of the IPython completer.
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value
import importlib

import simulation
import simulation_tracing
import dataset_compute
import subhalo

In [None]:
# Explicitly re-import the four project modules so that the kernel uses their
# current on-disk versions.
# NOTE(review): with `%autoreload 2` enabled this cell is presumably
# redundant — confirm before removing it.
importlib.reload(simulation)
importlib.reload(simulation_tracing)
importlib.reload(dataset_compute)
importlib.reload(subhalo)

# Stackplot of satellite fates

## Setting variables

Let us first define our simulation and the central galaxies of the Local Group (LG). The M31 and MW galaxies have identifiers (1,0) and (2,0) at redshift zero (snap_id=127), respectively:

In [None]:
# Simulation handle plus one tracer per central galaxy. Both tracers are
# anchored at the same snapshot; the last two arguments are the identifier
# of the central at that snapshot (presumably group and subgroup number —
# verify against subhalo.SubhaloTracer).
sim_name = "V1_LR_fix"
anchor_snap_id = 127

sim = simulation.Simulation(sim_name)
m31 = subhalo.SubhaloTracer(sim, anchor_snap_id, 1, 0)
mw = subhalo.SubhaloTracer(sim, anchor_snap_id, 2, 0)

Set the range of snapshots considered, and build the merger tree:

In [None]:
# Bounds of the snapshot range handed to the merger tree and the tracers
# (presumably start-inclusive / stop-exclusive — confirm against
# MergerTree.build_tree).
snap_start, snap_stop = 100, 128

In [None]:
# Create the merger tree for `sim` using the "BackwardBranching" option and
# build it over the snapshot range given by snap_start and snap_stop.
# NOTE(review): whether snap_stop is inclusive is decided inside
# MergerTree.build_tree — confirm there.
mtree = simulation_tracing.MergerTree(sim, branching="BackwardBranching")
mtree.build_tree(snap_start, snap_stop)

In [None]:
# Trace centrals: hand the merger tree to both central tracers so that
# their identifiers can be queried at other snapshots (get_identifier is
# called per snapshot further down).
m31.trace(mtree)
mw.trace(mtree)

In [None]:
# Show the Python type of the first component of M31's identifier at
# snapshot 126.
# NOTE(review): looks like a leftover debugging check — consider removing.
print(type(m31.get_identifier(126)[0]))

In [None]:
# Print the "Group_R_Mean200" FoF data of both centrals, multiplied by the
# numeric cm -> kpc conversion factor (assumes the stored values are in
# cm — TODO confirm).
cm_to_kpc = units.cm.to(units.kpc)
for central in (m31, mw):
    print(central.get_fof_data("Group_R_Mean200") * cm_to_kpc)

Make tracers for the traced snapshots:

In [None]:
# Snapshots to trace. Derived from snap_start / snap_stop instead of
# repeating the literal bounds, so that the traced snapshots cannot silently
# drift away from the range the merger tree was built for.
traced_snaps = np.arange(snap_start, snap_stop)

# Map snapshot id -> tracer array of that snapshot. The arrays are indexed
# further down as `tracer[:, snap_id]`, i.e. one row per subhalo of the
# snapshot and one column per snapshot id.
snap_tracers = {}
for snap_id in traced_snaps:
    snap_tracer = simulation_tracing.SnapshotTracer(snap_id, mtree)
    snap_tracer.trace(start=snap_start, stop=snap_stop)
    snap_tracers[snap_id] = snap_tracer.tracer_array

## Selecting the satellites

In [None]:
# Restrict every tracer array to the rows that are satellites of M31 or the
# MW and that pass the vmax cut.
# NOTE(review): this overwrites snap_tracers in place, so the cell is not
# idempotent — re-running it without rebuilding snap_tracers would apply the
# masks to already-filtered arrays.
for snap_id in snap_tracers:
    snapshot = sim.get_snapshot(snap_id)

    # First return value: satellite masks (presumably one per central);
    # they are OR-combined into a single "satellite of either" mask.
    per_central_masks, _ = dataset_compute.split_satellites_by_distance(
        snapshot,
        m31.get_identifier(snap_id),
        mw.get_identifier(snap_id),
    )
    is_satellite = np.logical_or.reduce(per_central_masks)

    # Exclude the smallest subhalos:
    passes_vmax = dataset_compute.prune_vmax(snapshot, low_lim=15)

    keep = np.logical_and(is_satellite, passes_vmax)
    snap_tracers[snap_id] = snap_tracers[snap_id][keep]

## Count satellites

At each snapshot, count the number of satellites originating from the previous snapshots:

In [None]:
# Initialize satellite (contribution) counters:
snap_sat_cnt = np.zeros((len(traced_snaps), len(traced_snaps)))

for i, snap in enumerate(traced_snaps):
    tracer = snap_tracers[snap]
    prev_tracer = snap_tracers[traced_snaps[i-1]]
    
    # Count new, accumulated satellites at snap:
    if i == 0:
        mask_new_sat = np.array([True] * np.size(tracer, axis=0))
    else:
        mask_new_sat = np.logical_not(np.isin(
            tracer[:, snap], prev_tracer[:, snap]
        ))
    snap_sat_cnt[i, i] = np.sum(mask_new_sat)
     
    # Iterate through the followings snapshots, keeping track of
    # the surviving satellites that originate from snap:
    mask_surviving = mask_new_sat
    for j, snap_next in enumerate(traced_snaps[i+1:], i+1):
        mask_surviving = np.logical_and(
            mask_surviving, 
            np.isin(tracer[:, snap_next], 
                    snap_tracers[snap_next][:, snap_next])
        )
    
        snap_sat_cnt[i, j] = np.sum(mask_surviving)

## Plot

In [None]:
# Stackplot of the satellite counts: one layer per snapshot of origin
# (rows of snap_sat_cnt), plotted against the age of the Universe, with a
# secondary top axis labelled in redshift.
fig, ax = plt.subplots(figsize=(6,3), dpi=200)

# Redshift of every traced snapshot, read from the snapshot headers.
redshift = [sim.get_snapshot(snap_id).get_attribute("Redshift", "Header")
            for snap_id in traced_snaps]

# Cosmology from the header of snapshot 127; HubbleParam is multiplied by
# 100 to obtain H0 (presumably the dimensionless h -> km/s/Mpc — confirm).
header_snapshot = sim.get_snapshot(127)
H0 = header_snapshot.get_attribute("HubbleParam", "Header") * 100
Om0 = header_snapshot.get_attribute("Omega0", "Header")
cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)
age = [cosmo.age(z).value for z in redshift]

# One colour per snapshot of origin.
colors = plt.cm.viridis(np.linspace(0, 1, traced_snaps.size))

ax.stackplot(age, snap_sat_cnt, colors=colors, edgecolor='white',
             linestyle=':', linewidth=0.3)
# Mark the position of every snapshot with a thin vertical line.
for a in age:
    ax.axvline(a, c='black', linestyle=':', linewidth=0.3)

# Top axis: same x-range as the age axis, with every third snapshot
# labelled by its redshift.
ax2 = ax.twiny()
ax2.set_xticks(age[::3])
ax2.set_xticklabels(['{:.2f}'.format(z) for z in redshift[::3]])

ax.set_xlim(min(age), max(age))
ax2.set_xlim(min(age), max(age))
ax.set_ylim(0, 160)

# Raw string: the LaTeX backslashes are not valid Python escape sequences.
# The text itself is unchanged.
ax.text(0.2, 0.9, r"$v_\mathrm{{max}} > 15 \mathrm{{km/s}}$", horizontalalignment='center',
        verticalalignment='center', transform=ax.transAxes)
ax.set_xlabel('Age of the Universe [Gyr]')
ax2.set_xlabel('Redshift')
ax.set_ylabel('Number of LG satellites');

In [None]:
# For every snapshot of origin, plot which fraction of the satellites
# counted at the last traced snapshot first appeared at that snapshot.
fig, ax = plt.subplots()

# Last column of snap_sat_cnt: survivors at the last traced snapshot, one
# entry per snapshot of origin; normalised so that the values sum to one.
y = snap_sat_cnt[:, -1]
y = y / np.sum(y)
x = age
ax.plot(x, y)

ax.set_xlabel('Age of the Universe [Gyr]')
ax.set_ylabel('Fraction of satellites at the last snapshot');