## First, imports:

In [None]:
# Automatically reload edited modules (such as the local library imported
# below) before each cell is executed:
%load_ext autoreload
%autoreload 2

# Greedy tab completion (IPython evaluates expressions to offer completions):
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import the local `apostletools` library:

In [None]:
import os
import sys

# Make the local `apostletools` library importable. The membership check keeps
# re-runs of this cell from appending the same path to sys.path again and again:
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
if apt_path not in sys.path:
    sys.path.append(apt_path)

import simulation
import simtrace
import match_halo
import dataset_comp
import dataset_comp

In [None]:
import importlib

# Explicitly reload the local modules (largely redundant while `%autoreload 2`
# is active). The last reload stays a bare expression so the cell still shows
# its result.
for _module in (simulation, simtrace, match_halo):
    importlib.reload(_module)
importlib.reload(dataset_comp)

# The Environmental Influence on Satellites

Here, we demonstrate the effects of tidal stripping on satellite galaxies.

---

### plain-LCDM-LR

Set the envelope file path, and define the centrals at redshift $z=0$:

In [None]:
# Envelope files of the traced simulation, relative to the notebook directory:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))
sim = simulation.Simulation("V1_LR_fix", env_path=env_path)

# Identifiers of the centrals at z=0 (presumably (group number, subgroup number)):
m31_id_z0, mw_id_z0 = (1, 0), (2, 0)
# The z=0 snapshot also serves as the reference snapshot for the selections:
snap_id_z0 = snap_id_ref = 127

---

### plain-LCDM

Set the envelope file path, and define the centrals at redshift $z=0$:

In [None]:
# NOTE: running this cell after the plain-LCDM-LR cell overrides `sim`; run only
# the cell of the simulation to be analysed.
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))
sim = simulation.Simulation("V1_MR_fix", env_path=env_path)

# Identifiers of the centrals at z=0 (presumably (group number, subgroup number)):
m31_id_z0, mw_id_z0 = (1, 0), (2, 0)
# The z=0 snapshot also serves as the reference snapshot for the selections:
snap_id_z0 = snap_id_ref = 127

---

## Tracing

Set the range of snapshots to be traced:

In [None]:
# Trace snapshots snap_start, ..., snap_stop - 1 (snap_stop is exclusive):
snap_start, snap_stop = 100, 128
snap_ids = np.arange(snap_start, snap_stop)

In [None]:
# Link the snapshots into a merger tree over the traced range
# (only needed if the simulations are not already linked):
matcher_options = {"n_link_ref": 20, "n_matches": 1}
matcher = match_halo.SnapshotMatcher(**matcher_options)

mtree = simtrace.MergerTree(sim, matcher=matcher, branching="BackwardBranching")
mtree.build_tree(snap_start, snap_stop)

Get Subhalo objects:

In [None]:
# Trace subhalos over the snapshot range; sub_dict maps snapshot ID -> Subhalo objects:
sub_dict = sim.trace_subhalos(snap_start, snap_stop)


def _subhalo_at_z0(halo_id):
    """ Return the traced Subhalo object of `halo_id` (a 2-tuple, presumably
    (group number, subgroup number)) in the z=0 snapshot. """
    idx = sim.get_snapshot(snap_id_z0).index_of_halo(*halo_id)
    return sub_dict[snap_id_z0][idx]


m31 = _subhalo_at_z0(m31_id_z0)
mw = _subhalo_at_z0(mw_id_z0)

---

## Retrieve the Datasets

First, get redshifts and lookback times at the snapshots:

In [None]:
# Define the cosmology (should be the same for each simulation) from the
# Header of the z=0 snapshot.
def _z0_header(name):
    """ Read the Header attribute `name` of the z=0 snapshot. """
    return sim.get_snapshot(snap_id_z0).get_attribute(name, "Header")


H0 = _z0_header("HubbleParam")
Om0 = _z0_header("Omega0")
# HubbleParam is scaled by 100 to give H0 in km/s/Mpc:
cosmo = FlatLambdaCDM(H0=100 * H0, Om0=Om0)

In [None]:
# Redshift of each traced snapshot, and the corresponding lookback time in Gyr
# (age of the universe today minus its age at that redshift):
redshift = sim.get_attribute("Redshift", "Header", snap_ids)
age_now = cosmo.age(0).value
ages_at_snaps = np.array([cosmo.age(z_snap).value for z_snap in redshift])
lookback_time = age_now - ages_at_snaps

The following cell may take a while: it reads the given datasets from every snapshot, and file retrieval is slow.

In [None]:
# Get the datasets in dictionaries keyed by snapshot ID. The raw values are in
# cgs units; the conversion factors are computed once:
g_to_msun = units.g.to(units.Msun)
cm_to_km = units.cm.to(units.km)
cm_to_kpc = units.cm.to(units.kpc)

# Subhalo mass [Msun]:
mass_dict = {
    snap: masses * g_to_msun
    for snap, masses in sim.get_subhalos(snap_ids, "Mass").items()
}
# Maximum circular velocity [km/s], taken from column 0 of Extended/Max_Vcirc
# (presumably v_max itself -- TODO confirm what the other column holds):
vmax_dict = {
    snap: vcirc[:, 0] * cm_to_km
    for snap, vcirc in sim.get_subhalos(snap_ids, "Max_Vcirc", h5_group="Extended").items()
}
# Stellar mass [Msun]:
sm_dict = {
    snap: masses * g_to_msun
    for snap, masses in sim.get_subhalos(snap_ids, "Stars/Mass").items()
}

# Separations of the subhalos from M31 and from the MW [kpc]; their norms are
# taken when plotting:
r_m31_dict = {snap: sep * cm_to_kpc for snap, sep in m31.distance_to_self(snap_ids).items()}
r_mw_dict = {snap: sep * cm_to_kpc for snap, sep in mw.distance_to_self(snap_ids).items()}

Make masking arrays for subhalos at `snap_ref`:

In [None]:
# Lower v_max limits for satellites and isolated galaxies (presumably in km/s,
# like v_max above -- TODO confirm in dataset_comp.prune_vmax):
sat_low_lim = 10
isol_low_lim = 10

# Masking arrays for subhalos at snap_ref:
snap_ref = sim.get_snapshot(snap_id_ref)
mask_lum, mask_dark = dataset_comp.split_luminous(snap_ref)
vmax_sat_mask = dataset_comp.prune_vmax(snap_ref, low_lim=sat_low_lim)
vmax_isol_mask = dataset_comp.prune_vmax(snap_ref, low_lim=isol_low_lim)

# Split into M31 satellites, MW satellites and isolated subhalos
# (sat_r=300 with comov=True, presumably a 300 ckpc satellite radius):
m31_id = m31.get_group_number_at_snap(snap_id_ref)
mw_id = mw.get_group_number_at_snap(snap_id_ref)
mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
    sim.get_snapshot(snap_id_ref), m31_id, mw_id, sat_r=300, comov=True
)

ref_masks = {
    "Vmax_Sat": vmax_sat_mask,
    "Vmax_Isol": vmax_isol_mask,
    "Luminous": mask_lum,
    "Dark": mask_dark,
    "M31_Satellites": mask_m31,
    "MW_Satellites": mask_mw,
    "Isolated": mask_isol,
}

In addition, define a function for selecting a random subset from a given masking array:

In [None]:
def random_mask(mask, n, rng=None):
    """ From the selection prescribed by ´mask´, select ´n´ items at random.

    Parameters
    ----------
    mask : array-like of bool
        Selection of candidate items (True = candidate).
    n : int
        Number of items to select. If ``n`` exceeds the number of candidates,
        all candidates are selected; a non-positive ``n`` selects nothing.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. By default the global ``np.random`` state is
        used, so ``np.random.seed`` makes the selection reproducible.

    Returns
    -------
    numpy.ndarray of bool
        Mask with the shape of ``mask``, with at most ``n`` True entries, all
        of them at positions where ``mask`` is True.
    """
    mask = np.asarray(mask, dtype=bool)
    k = int(np.count_nonzero(mask))
    # Clamp to [0, k]: a negative n would otherwise slice from the end and
    # select k - |n| items.
    n_sel = min(max(int(n), 0), k)

    # Mark the first n_sel candidates, then shuffle which candidates they are:
    mask_rand = np.zeros(k, dtype=bool)
    mask_rand[:n_sel] = True
    shuffle = np.random.shuffle if rng is None else rng.shuffle
    shuffle(mask_rand)

    mask_new = np.zeros(mask.shape, dtype=bool)
    mask_new[mask] = mask_rand

    return mask_new

---

## Plot Satellite Radii and Mass Evolution

For each subhalo in `snap_ref`, collect the time series of each dataset and store them in the dictionary `subh_arrs`:

In [None]:
# From the full datasets, collect the time series of every subhalo in snap_ref:
subs = sub_dict[snap_id_ref]


def _ragged(items):
    """ Pack `items` (one array per subhalo, lengths may differ) into a 1D object array.

    ``np.array(items, dtype=object)`` is not safe here: if every item happens to
    have the same length it silently builds a multi-dimensional array instead
    of a 1D array holding one array per subhalo.
    """
    packed = np.empty(len(items), dtype=object)
    for j, item in enumerate(items):
        packed[j] = item
    return packed


# For each subhalo, get the indices of its snapshots in the `snap_ids` array:
inds = [np.searchsorted(snap_ids, sub.get_snap_ids()) for sub in subs]
subh_arrs = {
    "Redshift": _ragged([redshift[idx_list] for idx_list in inds]),
    "LookbackTime": _ragged([lookback_time[idx_list] for idx_list in inds]),
    "Mass": _ragged([dataset_comp.subhalo_dataset_from_dict(sub, mass_dict)[0]
                     for sub in subs]),
    "Vmax": _ragged([dataset_comp.subhalo_dataset_from_dict(sub, vmax_dict)[0]
                     for sub in subs]),
    "M31_Distance": _ragged([dataset_comp.subhalo_dataset_from_dict(sub, r_m31_dict)[0]
                             for sub in subs]),
}

# Also, add a selection for non-volatile subhalos, i.e. those that exist in more
# than `vol_n` snapshots (cubic interpolation when plotting needs at least 4 points).
# dtype=bool keeps this a valid boolean mask even when there are no subhalos.
vol_n = 3
ref_masks["NonVolatile"] = np.array(
    [z_arr.size > vol_n for z_arr in subh_arrs["Redshift"]], dtype=bool
)

---

## Plot M31 Satellites

First, select the random sample of luminous M31 satellites that are plotted:

In [None]:
# random_mask draws from NumPy's global RNG; seed it so that the randomly
# selected satellites (and hence the saved figures) are reproducible:
np.random.seed(42)

# Number of luminous M31 satellites to highlight:
n_highlight = 10

mask_m31_lum = random_mask(np.logical_and.reduce([
    ref_masks["M31_Satellites"],
    ref_masks["Vmax_Sat"],
    ref_masks["NonVolatile"],
    ref_masks["Luminous"]
]), n_highlight)

In [None]:
# Font sizes used by the time-series figure:
parameters = {
    'axes.titlesize': 10,
    'axes.labelsize': 12,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 10,
}

# Qualitative "tab10" colour map; cmap(i) is its i-th colour:
cmap = plt.get_cmap("tab10")

In [None]:
# Set fonts:
plt.rcParams.update(parameters)

# Two panels sharing the lookback-time axis: distance to M31 (top), v_max (bottom).
fig, axes = plt.subplots(nrows=2, figsize=(4, 6), sharex=True)
fig.subplots_adjust(hspace=0.05)

# Inverted lookback time: the past on the left, the present on the right
# (the x-axis is shared, so this applies to both panels).
axes[0].invert_xaxis()
axes[1].set_xlabel("Lookback Time [Gyr]")

axes[0].set_ylabel("Distance to M31 [kpc]")
# Raw string: keeps the LaTeX backslash without an invalid-escape warning.
axes[1].set_ylabel(r"$v_\mathrm{max}$ [km/s]")

# Plot position of snap_ref:
# idx_ref = np.searchsorted(snap_ids, snap_id_ref)
# axes[0].axvline(lookback_time[idx_ref], c='black', linestyle='dotted', alpha=0.5)
# axes[1].axvline(lookback_time[idx_ref], c='black', linestyle='dotted', alpha=0.5)

# Plot the 300 ckpc satellite radius (in the background), converted to physical
# kpc with the scale factor a (the Header 'Time' attribute). `lookback_time`
# already holds the lookback times of `snap_ids`, so it is reused here.
scale_factor = sim.get_attribute('Time', 'Header', snap_ids)
axes[0].plot(lookback_time, scale_factor * 300, c='gray', linestyle='--')

Plot radii of some dark satellites in the background:

In [None]:
# Number of dark M31 satellites drawn in the background:
n_dark = 15

dark_candidates = np.logical_and.reduce([
    ref_masks["M31_Satellites"],
    ref_masks["Vmax_Sat"],
    ref_masks["NonVolatile"],
    ref_masks["Dark"],
])
mask_m31_dark = random_mask(dark_candidates, n_dark)

In [None]:
# Plot dark satellites as thin grey lines:
dark_separations = subh_arrs["M31_Distance"][mask_m31_dark]
dark_times = subh_arrs["LookbackTime"][mask_m31_dark]

for sep_vecs, lbt_sat in zip(dark_separations, dark_times):
    # Smooth curve: cubic interpolant through the distances at the snapshots.
    dist_of_time = interp1d(lbt_sat, np.linalg.norm(sep_vecs, axis=1), kind='cubic')
    lbt_fine = np.linspace(min(lbt_sat), max(lbt_sat), num=1000)
    axes[0].plot(lbt_fine, dist_of_time(lbt_fine), c='gray', alpha=0.5, lw=0.5)

fig

Plot radii of some luminous satellites:

In [None]:
# Highlighted luminous satellites; the i-th one is drawn with colour cmap(i):
lum_separations = subh_arrs["M31_Distance"][mask_m31_lum]
lum_times = subh_arrs["LookbackTime"][mask_m31_lum]

for color_idx, (sep_vecs, lbt_sat) in enumerate(zip(lum_separations, lum_times)):
    # Smooth curve: cubic interpolant through the distances at the snapshots.
    dist_of_time = interp1d(lbt_sat, np.linalg.norm(sep_vecs, axis=1), kind='cubic')
    lbt_fine = np.linspace(min(lbt_sat), max(lbt_sat), num=1000)
    axes[0].plot(lbt_fine, dist_of_time(lbt_fine), c=cmap(color_idx))

fig

Plot the maximum circular velocities of these luminous satellites in the lower panel:

In [None]:
# v_max histories of the same luminous satellites, with matching colours:
vmax_series = subh_arrs["Vmax"][mask_m31_lum]
time_series = subh_arrs["LookbackTime"][mask_m31_lum]

for color_idx, (vmax_sat, lbt_sat) in enumerate(zip(vmax_series, time_series)):
    axes[1].plot(lbt_sat, vmax_sat, c=cmap(color_idx))

fig

In [None]:
# Disabled alternative: plot the subhalo mass (on a log scale) instead of v_max
# in the lower panel. To use it, uncomment the lines below. Note that the loop
# variable named `vmax` then holds masses.
# axes[1].set_yscale('log')
# axes[1].set_ylim(5*10**7, 3*10**10)

# for i, (vmax, time) in enumerate(zip(subh_arrs["Mass"][mask_m31_lum], 
#                                      subh_arrs["LookbackTime"][mask_m31_lum])):

#     axes[1].plot(time, vmax, c=cmap(i))
    
# fig

Set the lower y-axis limits to zero (and cap the distance panel at 700 kpc):

In [None]:
# Start both y-axes at zero; the distance panel is also capped at 700 kpc:
axes[0].set_ylim(0, 700)
_, vmax_panel_top = axes[1].get_ylim()
axes[1].set_ylim(0, vmax_panel_top)

fig

---

## Save the Figure

In [None]:
filename = "m31_tidal_stripping.png"

path = os.path.abspath(os.path.join('..', 'Figures', 'MediumResolution'))
# savefig does not create missing directories, so make sure it exists:
os.makedirs(path, exist_ok=True)
filename = os.path.join(path, filename)

fig.savefig(filename, dpi=300, bbox_inches='tight')

---

## Plot $v_\mathrm{max}$ at Infall

In [None]:
# Infall snapshots of the subhalos into M31 and into the MW (first infall only):
fallin_m31, fallin_mw = simtrace.get_fallin_times_lg(
    sim, m31, mw, snap_start, snap_stop, first_infall=True
)


def _values_at_fallin(data_dict):
    """ Read `data_dict` at infall into M31 and into the MW for the subhalos of
    snap_ref. Returns (M31 values, MW values, merged values), where the merged
    array takes the M31 value wherever it is not NaN and the MW value otherwise.
    """
    at_m31, _ = dataset_comp.get_subhalos_at_fallin(
        sub_dict[snap_id_ref], fallin_m31, data_dict
    )
    at_mw, _ = dataset_comp.get_subhalos_at_fallin(
        sub_dict[snap_id_ref], fallin_mw, data_dict
    )
    merged = np.where(np.isnan(at_m31), at_mw, at_m31)
    return at_m31, at_mw, merged


vmax_fallin_m31, vmax_fallin_mw, vmax_fallin = _values_at_fallin(vmax_dict)
sm_fallin_m31, sm_fallin_mw, sm_fallin = _values_at_fallin(sm_dict)

In [None]:
# Font sizes used by the stellar mass vs. v_max figure:
parameters = {
    'axes.titlesize': 10,
    'axes.labelsize': 12,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 10,
}

# Qualitative "tab10" colour map; cmap(i) is its i-th colour:
cmap = plt.get_cmap("tab10")

# Marker sizes of highlighted (s) and background (s_back) points, the background
# colour, and the transparency used for values at infall:
s, s_back = 15, 3
c_back = "black"
a_infall = 0.3

In [None]:
# Set fonts:
plt.rcParams.update(parameters)

fig, ax = plt.subplots(figsize=(4, 4))

# Stellar mass (log scale) against v_max. Raw strings keep the LaTeX
# backslashes (\m, \o) without invalid-escape warnings.
ax.set_yscale('log')
ax.set_xlabel(r"$v_\mathrm{max}$ [km/s]")
ax.set_ylabel(r"$M_* [\mathrm{M}_\odot]$")

Plot all the other luminous satellites of M31 and the MW in the background:

In [None]:
# Luminous satellites of M31 or the MW above the v_max limit, excluding the
# highlighted sample:
mask = np.logical_and.reduce([
    ~mask_m31_lum,
    ref_masks["M31_Satellites"] | ref_masks["MW_Satellites"],
    ref_masks["Vmax_Sat"],
    ref_masks["Luminous"],
])

# Values at infall (transparent):
ax.scatter(vmax_fallin[mask], sm_fallin[mask], c=c_back, alpha=a_infall, s=s_back)

# Values at snap_ref (opaque):
ax.scatter(vmax_dict[snap_id_ref][mask], sm_dict[snap_id_ref][mask], c=c_back, s=s_back)

fig

Plot the selected galaxies:

In [None]:
# Colour the highlighted satellites exactly as in the time-series figure, where
# the i-th selected satellite is drawn with cmap(i). Explicit colours avoid
# assuming exactly 10 selected satellites and avoid cmap normalization.
n_highlighted = np.count_nonzero(mask_m31_lum)
colors = [cmap(i) for i in range(n_highlighted)]

# Values at snap_ref (opaque):
ax.scatter(vmax_dict[snap_id_ref][mask_m31_lum], sm_dict[snap_id_ref][mask_m31_lum],
           c=colors, s=s)

# Values at infall into M31 (transparent):
ax.scatter(vmax_fallin_m31[mask_m31_lum], sm_fallin_m31[mask_m31_lum],
           c=colors, s=s, alpha=a_infall)

fig

Add legend:

In [None]:
# Legend entries: empty scatters carry the marker styles of the background points.
ax.scatter([], [], c=c_back, s=s_back, label="$z=0$")
# Raw string: keeps the LaTeX backslash without an invalid-escape warning.
ax.scatter([], [], c=c_back, alpha=a_infall, s=s_back, label=r"$z_\mathrm{infall}$")
ax.legend(loc="lower right")
fig

---

## Save the Figure

In [None]:
filename = "sm_vs_vmax_with_stripping.png"

path = os.path.abspath(os.path.join('..', 'Figures', 'MediumResolution'))
# savefig does not create missing directories, so make sure it exists:
os.makedirs(path, exist_ok=True)
filename = os.path.join(path, filename)

fig.savefig(filename, dpi=300, bbox_inches='tight')