## First, imports:

In [None]:
# Reload edited local modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

# Greedy tab completion (IPython may evaluate expressions to offer completions).
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import the local `apostletools` modules:

In [None]:
import os
import sys

# Make the local `apostletools` package importable.
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
# Append only once, so re-running this cell does not keep growing sys.path.
if apt_path not in sys.path:
    sys.path.append(apt_path)

import simulation
import simtrace
import dataset_comp

In [None]:
import importlib

# Explicitly re-import the local modules (largely redundant with %autoreload 2).
for _module in (simulation, simtrace, dataset_comp):
    importlib.reload(_module)

# Accumulation of Satellites Through Time

Here, I plot the distribution of satellites by their fall-in time at each snapshot. This is done with a single stackplot, which shows visually when (at which snapshot) the satellites seen today fell in.



## Setting variables

Let us first define our simulation and the LG central galaxies at $z=0$:

In [None]:
# Simulations available for this analysis, keyed by the short name used in
# figure file names. `m31_ref` / `mw_ref` identify the Local Group centrals at
# z = 0; they are passed to `Snapshot.index_of_halo` (presumably as
# (group number, subgroup number) -- TODO confirm).
SIMULATIONS = {
    "curv-p082": {"sim_id": "V1_MR_curvaton_p082_fix",
                  "m31_ref": (1, 0), "mw_ref": (1, 1)},
    "curv-p082-LR": {"sim_id": "V1_LR_curvaton_p082_fix",
                     "m31_ref": (1, 0), "mw_ref": (1, 1)},
    "plain-LCDM": {"sim_id": "V1_MR_fix",
                   "m31_ref": (1, 0), "mw_ref": (2, 0)},
    "plain-LCDM-LR": {"sim_id": "V1_LR_fix",
                      "m31_ref": (1, 0), "mw_ref": (2, 0)},
}

# Select exactly ONE simulation here (change the key to switch runs):
sim_name = "plain-LCDM-LR"
sim_id = SIMULATIONS[sim_name]["sim_id"]
m31_ref = SIMULATIONS[sim_name]["m31_ref"]
mw_ref = SIMULATIONS[sim_name]["mw_ref"]
snap_ref = 127  # reference snapshot (z = 0) at which the centrals are identified

In [None]:
# Directory passed to the Simulation object as its environment path:
env_path = os.path.abspath(
    os.path.join('..', 'test_tracing_inj')
)
sim = simulation.Simulation(sim_id, env_path=env_path)

# Snapshot IDs available for this simulation:
available_snap_ids = sim.get_snap_ids()
print(available_snap_ids)

Set the range of snapshots considered (`snap_stop` is exclusive); the subhalos are traced over this range below:

In [None]:
# Traced snapshots: snap_start, ..., snap_stop - 1 (snap_stop is exclusive).
snap_start, snap_stop = 100, 128
snap_ids = np.arange(snap_start, snap_stop)

In [None]:
# Quick check: file name of the group data behind snapshot 101.
s = sim.get_snapshot(101)
group_fname = s.group_data.fname
print(group_fname)

### Get Centrals as Subhalo Objects

In [None]:
# Trace subhalos over the snapshot range; sub_dict is keyed by snapshot ID.
sub_dict = sim.trace_subhalos(snap_start, snap_stop)

# Pick out the Local Group centrals at the reference snapshot:
mw_idx = sim.get_snapshot(snap_ref).index_of_halo(*mw_ref)
mw = sub_dict[snap_ref][mw_idx]
m31_idx = sim.get_snapshot(snap_ref).index_of_halo(*m31_ref)
m31 = sub_dict[snap_ref][m31_idx]

## Get Satellite Fall-in Times

For all satellites (in all snapshots), get the ID of the fall-in snapshot and the satellite's index in that snapshot:

In [None]:
# Fall-in snapshot of each satellite, per snapshot, for both centrals:
fallin_snaps_m31, fallin_snaps_mw = simtrace.get_fallin_times_lg(
    sim, m31, mw, snap_start, snap_stop
)

# Index of each satellite within its fall-in snapshot:
fallin_inds_m31, fallin_inds_mw = [
    dataset_comp.index_at_fallin(sub_dict, fallin_snaps)
    for fallin_snaps in (fallin_snaps_m31, fallin_snaps_mw)
]

Get $v_\mathrm{max}$, at the fall-in time, for all satellites (in all snapshots):

In [None]:
# v_max (column 0 of Max_Vcirc) of every subhalo, converted from cm/s to km/s:
cm_to_km = units.cm.to(units.km)
vmax = {
    snap_id: max_vcirc[:, 0] * cm_to_km
    for snap_id, max_vcirc
    in sim.get_subhalos(snap_ids, "Max_Vcirc", "Extended").items()
}

# v_max at the fall-in time of every satellite (in all snapshots):
fallin_vmax_m31, fallin_vmax_mw = (
    dataset_comp.data_at_fallin(snaps, inds, vmax)
    for snaps, inds in ((fallin_snaps_m31, fallin_inds_m31),
                        (fallin_snaps_mw, fallin_inds_mw))
)

At each snapshot, count the satellites by their fall-in time. Only satellites whose $v_\mathrm{max}$ at fall-in lies strictly inside the window $]v_\mathrm{down}, v_\mathrm{up}[$ (the variables `vmax_down` and `vmax_up`, in km/s) are counted; all others are excluded:

In [None]:
# Count only satellites whose v_max at fall-in lies strictly inside
# (vmax_down, vmax_up), in km/s:
vmax_down = 40
vmax_up = 100


def _count_by_fallin(fallin_vmax, fallin_snaps, snap_id, vmax_down, vmax_up,
                     minlength):
    """Count the satellites present at `snap_id` per fall-in snapshot.

    Satellites whose fall-in v_max is NaN, or outside the open interval
    (vmax_down, vmax_up), are not counted.

    Returns an integer array (length >= `minlength`) whose element k is the
    number of selected satellites that fell in at snapshot k.
    """
    vmax_at_fallin = fallin_vmax[snap_id]
    # Start from entries with a defined fall-in v_max, then apply the window:
    mask = ~np.isnan(vmax_at_fallin)
    mask[mask] = np.logical_and(vmax_at_fallin[mask] > vmax_down,
                                vmax_at_fallin[mask] < vmax_up)
    return np.bincount(fallin_snaps[snap_id][mask].astype(int),
                       minlength=minlength)


# Rows: snapshot at which satellites are counted; columns: their fall-in
# snapshot. Both axes are indexed by absolute snapshot ID.
m31_sat_cnts = np.zeros((snap_stop, snap_stop))
mw_sat_cnts = np.zeros((snap_stop, snap_stop))

for snap_id in snap_ids:
    m31_sat_cnts[snap_id] = _count_by_fallin(
        fallin_vmax_m31, fallin_snaps_m31, snap_id, vmax_down, vmax_up,
        minlength=snap_stop)
    mw_sat_cnts[snap_id] = _count_by_fallin(
        fallin_vmax_mw, fallin_snaps_mw, snap_id, vmax_down, vmax_up,
        minlength=snap_stop)

In [None]:
# Local Group totals: M31 satellites plus MW satellites.
all_sat_cnts = np.add(m31_sat_cnts, mw_sat_cnts)

## Plot M31

In [None]:
fig, ax = plt.subplots(figsize=(6, 3), dpi=200)

traced_snaps = np.arange(snap_start, snap_stop)
redshift = [sim.get_snapshot(snap_id).get_attribute("Redshift", "Header")
            for snap_id in traced_snaps]

# Cosmology read from the last traced snapshot (assumed the same in all
# snapshots); HubbleParam is h, so H0 = 100 h km/s/Mpc:
ref_snap = sim.get_snapshot(snap_stop - 1)
H0 = ref_snap.get_attribute("HubbleParam", "Header") * 100
Om0 = ref_snap.get_attribute("Omega0", "Header")
cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)
age = [cosmo.age(z).value for z in redshift]  # Gyr

# One color per fall-in snapshot, from light (early) to dark (late):
colors = plt.cm.Blues(np.linspace(0, 1, traced_snaps.size))

# Rows: snapshot at which satellites are counted; columns: fall-in snapshot.
# Transposed, each stack layer is one fall-in snapshot plotted against age.
sat_cnts = m31_sat_cnts[snap_start:snap_stop, snap_start:snap_stop]
ax.stackplot(age, sat_cnts.T, colors=colors, edgecolor='black',
             linestyle=':', linewidth=0.3)
for a in age:
    ax.axvline(a, c='black', linestyle=':', linewidth=0.3)

# Secondary x-axis showing the redshift of every fourth snapshot:
ax2 = ax.twiny()
ax2.set_xticks(age[::4])
ax2.set_xticklabels(['{:.2f}'.format(z) for z in redshift[::4]])

ax.set_xlim(min(age), max(age))
ax2.set_xlim(min(age), max(age))
# Headroom above the tallest stack in the traced range (at least 1, so the
# limits never collapse when no satellite passes the cut):
peak_cnt = max(np.max(np.sum(sat_cnts, axis=1)), 1)
ax.set_ylim(0, 2 * peak_cnt)

# Raw string: the backslashes belong to the mathtext markup.
text = r"$v_\mathrm{{max}}(z_\mathrm{{fall-in}}) \in ]{},{}[ ~ \mathrm{{km/s}}$".format(
    vmax_down, vmax_up
)
ax.text(0.1, 0.9, text, horizontalalignment='left', verticalalignment='center',
        transform=ax.transAxes)
ax.set_xlabel('Age of the Universe [Gyr]')
ax2.set_xlabel('Redshift')
ax.set_ylabel('Number of M31 satellites')

In [None]:
filename = 'M31_satellite_accumulation_vmaxcut{}-{}_{}.png'.format(vmax_down, vmax_up, sim_name)

# Low-resolution runs (sim_id containing "_LR_") get their own figure folder:
resolution_dir = 'LowResolution' if '_LR_' in sim_id else 'MediumResolution'
path = os.path.abspath(os.path.join('..', 'Figures', resolution_dir))
os.makedirs(path, exist_ok=True)  # savefig does not create missing folders
filename = os.path.join(path, filename)

fig.savefig(filename, dpi=300, bbox_inches='tight')

## Plot MW

In [None]:
fig, ax = plt.subplots(figsize=(6, 3), dpi=200)

traced_snaps = np.arange(snap_start, snap_stop)
redshift = [sim.get_snapshot(snap_id).get_attribute("Redshift", "Header")
            for snap_id in traced_snaps]

# Cosmology read from the last traced snapshot (assumed the same in all
# snapshots); HubbleParam is h, so H0 = 100 h km/s/Mpc:
ref_snap = sim.get_snapshot(snap_stop - 1)
H0 = ref_snap.get_attribute("HubbleParam", "Header") * 100
Om0 = ref_snap.get_attribute("Omega0", "Header")
cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)
age = [cosmo.age(z).value for z in redshift]  # Gyr

# One color per fall-in snapshot, from light (early) to dark (late):
colors = plt.cm.Blues(np.linspace(0, 1, traced_snaps.size))

# Rows: snapshot at which satellites are counted; columns: fall-in snapshot.
# Transposed, each stack layer is one fall-in snapshot plotted against age.
sat_cnts = mw_sat_cnts[snap_start:snap_stop, snap_start:snap_stop]
ax.stackplot(age, sat_cnts.T, colors=colors, edgecolor='black',
             linestyle=':', linewidth=0.3)
for a in age:
    ax.axvline(a, c='black', linestyle=':', linewidth=0.3)

# Secondary x-axis showing the redshift of every fourth snapshot:
ax2 = ax.twiny()
ax2.set_xticks(age[::4])
ax2.set_xticklabels(['{:.2f}'.format(z) for z in redshift[::4]])

ax.set_xlim(min(age), max(age))
ax2.set_xlim(min(age), max(age))
# Headroom above the tallest stack in the traced range (at least 1, so the
# limits never collapse when no satellite passes the cut):
peak_cnt = max(np.max(np.sum(sat_cnts, axis=1)), 1)
ax.set_ylim(0, 1.2 * peak_cnt)

# Raw string: the backslashes belong to the mathtext markup.
text = r"$v_\mathrm{{max}}(z_\mathrm{{fall-in}}) \in ]{},{}[ ~ \mathrm{{km/s}}$".format(
    vmax_down, vmax_up
)
ax.text(0.1, 0.9, text, horizontalalignment='left', verticalalignment='center',
        transform=ax.transAxes)
ax.set_xlabel('Age of the Universe [Gyr]')
ax2.set_xlabel('Redshift')
ax.set_ylabel('Number of MW satellites')

In [None]:
filename = 'MW_satellite_accumulation_vmaxcut{}-{}_{}.png'.format(vmax_down, vmax_up, sim_name)

# Low-resolution runs (sim_id containing "_LR_") get their own figure folder:
resolution_dir = 'LowResolution' if '_LR_' in sim_id else 'MediumResolution'
path = os.path.abspath(os.path.join('..', 'Figures', resolution_dir))
os.makedirs(path, exist_ok=True)  # savefig does not create missing folders
filename = os.path.join(path, filename)

fig.savefig(filename, dpi=300, bbox_inches='tight')

## Plot LG

In [None]:
fig, ax = plt.subplots(figsize=(6, 3), dpi=200)

traced_snaps = np.arange(snap_start, snap_stop)
redshift = [sim.get_snapshot(snap_id).get_attribute("Redshift", "Header")
            for snap_id in traced_snaps]

# Cosmology read from the last traced snapshot (assumed the same in all
# snapshots); HubbleParam is h, so H0 = 100 h km/s/Mpc:
ref_snap = sim.get_snapshot(snap_stop - 1)
H0 = ref_snap.get_attribute("HubbleParam", "Header") * 100
Om0 = ref_snap.get_attribute("Omega0", "Header")
cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)
age = [cosmo.age(z).value for z in redshift]  # Gyr

# One color per fall-in snapshot, from light (early) to dark (late):
colors = plt.cm.Blues(np.linspace(0, 1, traced_snaps.size))

# Rows: snapshot at which satellites are counted; columns: fall-in snapshot.
# Transposed, each stack layer is one fall-in snapshot plotted against age.
sat_cnts = all_sat_cnts[snap_start:snap_stop, snap_start:snap_stop]
ax.stackplot(age, sat_cnts.T, colors=colors, edgecolor='black',
             linestyle=':', linewidth=0.3)
for a in age:
    ax.axvline(a, c='black', linestyle=':', linewidth=0.3)

# Secondary x-axis showing the redshift of every fourth snapshot:
ax2 = ax.twiny()
ax2.set_xticks(age[::4])
ax2.set_xticklabels(['{:.2f}'.format(z) for z in redshift[::4]])

ax.set_xlim(min(age), max(age))
ax2.set_xlim(min(age), max(age))
# Headroom above the tallest stack in the traced range (at least 1, so the
# limits never collapse when no satellite passes the cut):
peak_cnt = max(np.max(np.sum(sat_cnts, axis=1)), 1)
ax.set_ylim(0, 1.2 * peak_cnt)

# Raw string: the backslashes belong to the mathtext markup.
text = r"$v_\mathrm{{max}}(z_\mathrm{{fall-in}}) \in ]{},{}[ ~ \mathrm{{km/s}}$".format(
    vmax_down, vmax_up
)
ax.text(0.1, 0.9, text, horizontalalignment='left', verticalalignment='center',
        transform=ax.transAxes)
ax.set_xlabel('Age of the Universe [Gyr]')
ax2.set_xlabel('Redshift')
ax.set_ylabel('Number of LG satellites')

In [None]:
filename = 'LG_satellite_accumulation_vmaxcut{}-{}_{}.png'.format(vmax_down, vmax_up, sim_name)

# Low-resolution runs (sim_id containing "_LR_") get their own figure folder:
resolution_dir = 'LowResolution' if '_LR_' in sim_id else 'MediumResolution'
path = os.path.abspath(os.path.join('..', 'Figures', resolution_dir))
os.makedirs(path, exist_ok=True)  # savefig does not create missing folders
filename = os.path.join(path, filename)

fig.savefig(filename, dpi=300, bbox_inches='tight')

Check that the number of satellites that fell in at any given snapshot never grows with time. NOTE: if you do this check with a limit on the minimum $v_\mathrm{max}$, it will not work: satellites can lose and gain mass, so that they momentarily fall below the minimum mass limit.

In [None]:
# For every fall-in snapshot i < snap, the number of satellites that fell in
# at i may stay constant or drop between snap - 1 and snap, but never grow.
# Column i == snap holds the newly accreted satellites and always "grows"
# (it is zero at snap - 1), so only earlier fall-in snapshots are compared.
# Prints: host, snapshot, number of violations, offending fall-in snapshots.
for host, sat_cnts in (('M31', m31_sat_cnts), ('MW', mw_sat_cnts)):
    for snap in range(snap_start + 1, snap_stop):
        older = slice(snap_start, snap)
        mask = sat_cnts[snap, older] > sat_cnts[snap - 1, older]
        print(host, snap, np.sum(mask), np.flatnonzero(mask) + snap_start)

### Plot the Fall-in Time Distribution at Present

In [None]:
# Define the cosmology (should be the same for each simulation), read from the
# last traced snapshot. HubbleParam is the dimensionless h; FlatLambdaCDM takes
# H0 in km/s/Mpc, so H0 = 100 h (same convention as the plotting cells).
ref_snap = sim.get_snapshot(snap_stop - 1)
hubble_param = ref_snap.get_attribute("HubbleParam", "Header")
Om0 = ref_snap.get_attribute("Omega0", "Header")

H0 = 100 * hubble_param
cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)

In [None]:
fig, ax = plt.subplots()

snap_id = snap_ref  # present-day (z = 0) reference snapshot
# Lookback time (Gyr) of every traced snapshot:
lb_times = np.array([cosmo.age(0).value - cosmo.age(z).value for z in
                     sim.get_attribute("Redshift", "Header", snap_ids)])

# Select satellites with the same fall-in v_max window (vmax_down, vmax_up)
# as the accumulation counts:
mask_vmax_m31 = ~np.isnan(fallin_vmax_m31[snap_id])
mask_vmax_m31[mask_vmax_m31] = np.logical_and(
    fallin_vmax_m31[snap_id][mask_vmax_m31] > vmax_down,
    fallin_vmax_m31[snap_id][mask_vmax_m31] < vmax_up
)
fallin_cnts_m31 = np.array([np.sum(fallin_snaps_m31[snap_id][mask_vmax_m31] == sid)
                            for sid in snap_ids])

mask_vmax_mw = ~np.isnan(fallin_vmax_mw[snap_id])
mask_vmax_mw[mask_vmax_mw] = np.logical_and(
    fallin_vmax_mw[snap_id][mask_vmax_mw] > vmax_down,
    fallin_vmax_mw[snap_id][mask_vmax_mw] < vmax_up
)
fallin_cnts_mw = np.array([np.sum(fallin_snaps_mw[snap_id][mask_vmax_mw] == sid)
                           for sid in snap_ids])

fallin_cnts = fallin_cnts_m31 + fallin_cnts_mw
n_fallin = np.sum(fallin_cnts)
# Avoid 0/0 when no satellite passes the cut:
norm = n_fallin if n_fallin > 0 else 1

ax.plot(lb_times, fallin_cnts / norm, label='LG')
ax.plot(lb_times, fallin_cnts_m31 / norm, label='M31')
ax.plot(lb_times, fallin_cnts_mw / norm, label='MW')
ax.set_xlabel('Lookback time at fall-in [Gyr]')
ax.set_ylabel('Fraction of satellites')
ax.legend()

# Spacing between consecutive snapshots in lookback time, and the total count:
print(np.diff(lb_times))
print(n_fallin)

This is how to get the fall-in times for each snapshot, generally:

In [None]:
# Per-snapshot arrays with one entry per subhalo: the snapshot redshift and the
# corresponding lookback time (Gyr). All subhalos of a snapshot share the same
# value, so the cosmology is evaluated only once per snapshot.
age_today = cosmo.age(0).value
redshift = {}
lookback_time = {}
for snap in sim.get_snapshots(snap_ids):
    n_sub = snap.get_subhalo_number()
    z = snap.get_attribute("Redshift", "Header")
    redshift[snap.snap_id] = np.full(n_sub, z)
    lookback_time[snap.snap_id] = np.full(n_sub, age_today - cosmo.age(z).value)

In [None]:
# Lookback time at fall-in of every satellite (in all snapshots):
fallin_t_m31, fallin_t_mw = (
    dataset_comp.data_at_fallin(snaps, inds, lookback_time)
    for snaps, inds in ((fallin_snaps_m31, fallin_inds_m31),
                        (fallin_snaps_mw, fallin_inds_mw))
)

In [None]:
# Scratch cells below: exploratory checks, not part of the analysis.

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
# Number of M31 satellites at snapshot 120 per fall-in snapshot:
mask = np.logical_not(np.isnan(fallin_snaps_m31[120]))
for i, cnt in enumerate(np.bincount(fallin_snaps_m31[120][mask].astype(int), minlength=snap_stop)):
    print(i, cnt)

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
fallin_snaps_m31[120].astype(int)

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
print(np.bincount(
    np.where(fallin_snaps_m31[120].astype(bool), fallin_snaps_m31[120], 128)
))

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
np.where(fallin_snaps_m31[120].astype(bool), fallin_snaps_m31[120], 0)

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
print(fallin_snaps_m31[120].dtype)

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
print(np.unique(
    np.where(fallin_snaps_m31[120], fallin_snaps_m31[120], np.nan),
    return_counts=True))

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
print(np.unique(fallin_snaps_m31[120][fallin_snaps_m31[120].astype(bool)],
    return_counts=True))

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
print(fallin_snaps_m31[120].astype(bool))

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
print(fallin_snaps_m31[120])

In [None]:
# (`m31_fallin` is not defined in this notebook; use `fallin_snaps_m31`.)
np.where(fallin_snaps_m31[120], fallin_snaps_m31[120], np.nan)

## Merger Events

Find all merger events by iterating backwards in time:

In [None]:
# Which simulation is loaded?
current_sim_id = sim.sim_id
print(current_sim_id)

In [None]:
# Count merger events per snapshot, going backwards from snap_ref to
# snap_start + 1. The loop bounds come from existing variables instead of
# reassigning `snap_stop`, which later cells use to build `traced_snaps`.
# NOTE(review): `mtree` (merger-tree object) is not defined in this notebook;
# it must be constructed before this cell runs.
for sid in range(snap_ref, snap_start, -1):
    snap = sim.get_snapshot(sid)
    prog = snap.get_subhalos('Progenitors', mtree.h5_group)
    # Presumably columns 1 and 2 hold secondary progenitors; a match in either
    # flags the subhalo as a merger product -- TODO confirm.
    mask_merger = np.logical_or(prog[:, 1] != mtree.no_match,
                                prog[:, 2] != mtree.no_match)
    print(sid, np.sum(mask_merger))
    print(np.size(prog, axis=0), np.sum(prog[:, 0] != mtree.no_match))

## ... Moving on

In [None]:
# Group_R_Mean200 of the M31 and MW FoF groups, converted from cm to kpc:
for central in (m31, mw):
    print(central.get_fof_data("Group_R_Mean200") * units.cm.to(units.kpc))

In [None]:
# Group_R_Crit200 of the M31 and MW FoF groups, converted from cm to kpc:
for central in (m31, mw):
    print(central.get_fof_data("Group_R_Crit200") * units.cm.to(units.kpc))

In [None]:
# Group_R_TopHat200 of the M31 and MW FoF groups, converted from cm to kpc:
for central in (m31, mw):
    print(central.get_fof_data("Group_R_TopHat200") * units.cm.to(units.kpc))

In [None]:
# NOTE(review): `mtree` is not defined in this notebook.
n_subhalos_127 = mtree.simulation.get_snapshot(127).get_subhalo_number()
print(n_subhalos_127)

In [None]:
# NOTE(review): `mtree` is not defined in this notebook.
branching_path = "/".join([mtree.h5_group, mtree.branching])
print(branching_path)
print(type(mtree.no_match))

In [None]:
# NOTE: consider moving this import to the import cell at the top.
import datafile_oper

extended_path = datafile_oper.path_to_extended()
print(extended_path)

Make tracers for the traced snapshots:

In [None]:
# NOTE(review): `mtree` is not defined in this notebook; it must exist before
# this cell runs.
traced_snaps = np.arange(snap_start, snap_stop)
snap_tracers = {}
for snap_id in traced_snaps:
    tracer_obj = simtrace.SnapshotTracer(snap_id, mtree)
    tracer_obj.trace(start=snap_start, stop=snap_stop)
    snap_tracers[snap_id] = tracer_obj.tracer_array

## Selecting the satellites

In [None]:
# Identifiers of the two centrals at snapshot 101:
for central in (m31, mw):
    print(central.get_identifier(101))

In [None]:
# Keep only tracer rows of satellites of M31 or MW (as split by
# dataset_comp.split_satellites_by_distance) that pass prune_vmax(low_lim=15).
# TODO: check the periodic wrap -- is the h scale exponent right in boxsize?
# NOTE(review): `snap_tracers` is filtered in place, so re-running this cell
# on its own output will not give the same result.
for snap_id, tracer in snap_tracers.items():
    print(snap_id)
    snapshot = sim.get_snapshot(snap_id)
    sat_masks, _ = dataset_comp.split_satellites_by_distance(
        snapshot, m31.get_identifier(snap_id), mw.get_identifier(snap_id))
    is_satellite = np.logical_or.reduce(sat_masks)

    # Exclude the smallest subhalos:
    above_vmax_floor = dataset_comp.prune_vmax(snapshot, low_lim=15)

    snap_tracers[snap_id] = tracer[np.logical_and(is_satellite, above_vmax_floor)]

## Count satellites

At each snapshot, count the number of satellites originating from the previous snapshots:

In [None]:
# snap_sat_cnt[i, j]: number of satellites present at traced_snaps[j] that
# were new satellites at traced_snaps[i] (entries with j < i stay zero).
# Presumably tracer[:, s] holds each traced subhalo's index at snapshot s.
n_traced = len(traced_snaps)
snap_sat_cnt = np.zeros((n_traced, n_traced))

for i, snap in enumerate(traced_snaps):
    tracer = snap_tracers[snap]

    # Satellites at `snap` that were not among the previous snapshot's ones:
    if i == 0:
        mask_new_sat = np.ones(np.size(tracer, axis=0), dtype=bool)
    else:
        prev_tracer = snap_tracers[traced_snaps[i - 1]]
        mask_new_sat = ~np.isin(tracer[:, snap], prev_tracer[:, snap])
    snap_sat_cnt[i, i] = np.sum(mask_new_sat)

    # Follow these new satellites forward, keeping only those that are still
    # satellites at every later snapshot:
    mask_surviving = mask_new_sat
    for j in range(i + 1, n_traced):
        snap_next = traced_snaps[j]
        mask_surviving = np.logical_and(
            mask_surviving,
            np.isin(tracer[:, snap_next], snap_tracers[snap_next][:, snap_next])
        )
        snap_sat_cnt[i, j] = np.sum(mask_surviving)

## Plot

In [None]:
# Saving location: Figures/LowResolution next to the `simulation` module.
home = os.path.dirname(simulation.__file__)
path = os.path.join(home, "Figures", "LowResolution")
filename = os.path.join(
    path, 'satellite_fates_stack_from_{}_{}'.format(snap_start, sim_name)
)

In [None]:
fig, ax = plt.subplots(figsize=(6, 3), dpi=200)

redshift = [sim.get_snapshot(snap_id).get_attribute("Redshift", "Header")
            for snap_id in traced_snaps]

# Cosmology read from the last traced snapshot (assumed the same in all
# snapshots); HubbleParam is h, so H0 = 100 h km/s/Mpc:
ref_snap = sim.get_snapshot(snap_stop - 1)
H0 = ref_snap.get_attribute("HubbleParam", "Header") * 100
Om0 = ref_snap.get_attribute("Omega0", "Header")
cosmo = FlatLambdaCDM(H0=H0, Om0=Om0)
age = [cosmo.age(z).value for z in redshift]  # Gyr

# One color per origin snapshot:
colors = plt.cm.viridis(np.linspace(0, 1, traced_snaps.size))

# Each row of snap_sat_cnt (snapshot where the satellites first appear) is one
# stack layer; its columns are the later snapshots where survivors are counted.
ax.stackplot(age, snap_sat_cnt, colors=colors, edgecolor='white',
             linestyle=':', linewidth=0.3)
for a in age:
    ax.axvline(a, c='black', linestyle=':', linewidth=0.3)

# Secondary x-axis showing the redshift of every fourth snapshot:
ax2 = ax.twiny()
ax2.set_xticks(age[::4])
ax2.set_xticklabels(['{:.2f}'.format(z) for z in redshift[::4]])

ax.set_xlim(min(age), max(age))
ax2.set_xlim(min(age), max(age))
# Headroom above the tallest stack (at least 1, so limits never collapse):
peak_cnt = max(np.max(np.sum(snap_sat_cnt, axis=0)), 1)
ax.set_ylim(0, 1.2 * peak_cnt)

# Raw string: the backslashes belong to the mathtext markup.
ax.text(0.2, 0.9, r"$v_\mathrm{max} > 15 \mathrm{km/s}$", horizontalalignment='center',
        verticalalignment='center', transform=ax.transAxes)
ax.set_xlabel('Age of the Universe [Gyr]')
ax2.set_xlabel('Redshift')
ax.set_ylabel('Number of LG satellites')

In [None]:
# Fraction of the satellites present at the last traced snapshot, by the
# snapshot (plotted as age of the Universe) at which they first appeared.
fig, ax = plt.subplots()

y = snap_sat_cnt[:, -1]
total = np.sum(y)
# Avoid 0/0 when there are no surviving satellites:
y = y / total if total > 0 else y
x = age
ax.plot(x, y)
ax.set_xlabel('Age of the Universe at first appearance [Gyr]')
ax.set_ylabel('Fraction of surviving satellites')