## First, imports:

In [None]:
# Load the autoreload extension and use mode 2, so that imported modules
# are reloaded before code is executed:
%load_ext autoreload
%autoreload 2

# Switch on the "greedy" option of the IPython completer:
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm
from scipy.interpolate import interp1d

from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import my library:

In [None]:
import os
import sys

# Make the modules in ../apostletools importable. The path is appended only
# if it is not yet on sys.path, so re-running this cell does not pile up
# duplicate entries:
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
if apt_path not in sys.path:
    sys.path.append(apt_path)

import simulation
import simtrace
import match_halo
import dataset_comp
import curve_fit

In [None]:
# Explicitly reload the project modules, so that edits made to their source
# since they were first imported are picked up in this session:
import importlib
importlib.reload(simulation)
importlib.reload(simtrace)
importlib.reload(match_halo)
importlib.reload(dataset_comp)
importlib.reload(curve_fit)

# $v_\mathrm{max}$ at Fall-in

To find a reasonable low-mass limit for satellite subhalos, I inspect the relation between their $v_\mathrm{max}$ values at $z=0$ and at the time when they fell into orbit around their respective centrals.

In [None]:
# snap_id_ref: snapshot at which the satellite selections are made.
# snap_id_z0:  snapshot in which the M31 and the MW halos are identified.
snap_id_ref, snap_id_z0 = 127, 127

---

### Medium Resolution Simulations

Set the envelope file path, and define the M31 and the MW at redshift $z=0$:

In [None]:
# Settings of each simulation, keyed by the name that is later used in plot
# labels and in the figure file name:
#   "Simulation": the Simulation object,
#   "Color":      a pair of colors (the first one is used for median lines),
#   "Colormap":   colormap for coloring points by infall time,
#   "M31_z0",
#   "MW_z0":      pair of numbers passed to index_of_halo to find the halo.
data = {}

data["plain-LCDM"] = dict(
    Simulation=simulation.Simulation("V1_MR_fix"),
    Color=['blue', 'lightblue'],
    Colormap=cm.Blues,
    M31_z0=(1, 0),
    MW_z0=(2, 0),
)

data["curv-p082"] = dict(
    Simulation=simulation.Simulation("V1_MR_curvaton_p082_fix"),
    Color=['red', 'pink'],
    Colormap=cm.Reds,
    M31_z0=(1, 0),
    MW_z0=(1, 1),
)

---

## Tracing

Set the range of snapshots to be traced:

In [None]:
# Snapshots snap_start, ..., snap_stop - 1 are traced (the stop value itself
# is excluded by np.arange):
snap_start, snap_stop = 100, 128
snap_ids = np.arange(snap_start, snap_stop)

Link all subhalos in the simulation, create Subhalo objects for all the individual subhalos found, and write pointers to these objects for each snapshot:

In [None]:
# A single matcher instance is shared by the merger trees of all simulations:
matcher = match_halo.SnapshotMatcher(n_link_ref=20, n_matches=1)

for sim_data in data.values():
    sim = sim_data["Simulation"]

    # Link the subhalos over the snapshot range.
    # NOTE(review): this is executed unconditionally, i.e. also when the
    # simulation has been linked before.
    mtree = simtrace.MergerTree(sim, branching="BackwardBranching", matcher=matcher)
    mtree.build_tree(snap_start, snap_stop)

    # Trace the subhalos over the same range; the result is indexed by
    # snapshot ID in the cells below:
    sub_dict = sim.trace_subhalos(snap_start, snap_stop)
    sim_data.update(Subhalos=sub_dict)

Get the M31 and the MW halos and compute masking arrays for their satellites (and isolated subhalos) at `snap_id_ref`:

In [None]:
for sim_data in data.values():
    sim = sim_data["Simulation"]
    sub_dict = sim_data["Subhalos"]

    # Identifiers of the two centrals, as configured in `data`:
    m31_id = sim_data["M31_z0"]
    mw_id = sim_data["MW_z0"]

    # Pick the M31 and the MW entries from the subhalos traced at snap_id_z0:
    m31 = sub_dict[snap_id_z0][sim.get_snapshot(snap_id_z0).index_of_halo(*m31_id)]
    mw = sub_dict[snap_id_z0][sim.get_snapshot(snap_id_z0).index_of_halo(*mw_id)]
    sim_data["M31"] = m31
    sim_data["MW"] = mw

    # Masking arrays at snap_id_ref: satellites of either central, and
    # isolated subhalos, split by distance:
    mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
        sim.get_snapshot(snap_id_ref), m31_id, mw_id,
        sat_r=300, isol_r=2000, comov=True
    )
    sim_data["Ref_Selections"] = dict(
        M31_Satellites=mask_m31,
        MW_Satellites=mask_mw,
        LG_Satellites=np.logical_or(mask_m31, mask_mw),
        Isolated=mask_isol,
    )

---

## Retrieve the Datasets

Read all datasets into dictionaries by snapshot:

In [None]:
# Define the cosmology from the header of the last traced snapshot. The
# parameters should be the same for each simulation; this is checked here,
# and a warning is printed if a simulation deviates from the first one.
cosmo_params = []
for sim_name, sim_data in data.items():
    snap_last = sim_data["Simulation"].get_snapshot(snap_stop - 1)
    H0 = snap_last.get_attribute("HubbleParam", "Header")
    Om0 = snap_last.get_attribute("Omega0", "Header")
    cosmo_params.append((H0, Om0))

    if not np.allclose(cosmo_params[0], (H0, Om0)):
        print("Warning: (HubbleParam, Omega0) of {} differ from the first "
              "simulation:".format(sim_name), (H0, Om0), "vs.", cosmo_params[0])

# The values read last define the cosmology; the header value "HubbleParam"
# is scaled by 100 to give the H0 argument:
cosmo = FlatLambdaCDM(H0=100 * H0, Om0=Om0)

In [None]:
# Lower limits used in the selections below: sat_low_lim and isol_low_lim
# are passed to prune_vmax as low_lim, and vmax_lim is compared against the
# v_max values at infall.
sat_low_lim = 10
isol_low_lim = 10
vmax_lim = 10

for sim_data in data.values():
    sim = sim_data["Simulation"]
    sub_dict = sim_data["Subhalos"]
    
    # Get snapshot redshifts and the respective lookback times:
    redshift = sim.get_attribute("Redshift", "Header", snap_ids)
    # Lookback time = age at z=0 minus age at redshift z (unit-stripped with
    # .value; the colorbar label further down assumes Gyr -- TODO confirm):
    lookback_time = cosmo.age(0).value - np.array([cosmo.age(z).value for z in redshift])
    sim_data["Redshift"] = redshift
    sim_data["LookbackTime"] = lookback_time
    
    # Get v_max at snap_ref and at the time of fallin:
    # For every snapshot, keep column 0 of "Max_Vcirc" and multiply it by the
    # cm -> km conversion factor (presumably column 0 holds the v_max value
    # itself -- TODO confirm):
    vmax_dict = {snap_id: vmax_arr[:, 0] * units.cm.to(units.km) for snap_id, vmax_arr in
                 sim.get_subhalos(snap_ids, "Max_Vcirc", "Extended").items()}
    sim_data["Vmax"] = vmax_dict[snap_id_ref]
    
    # Infall information with respect to M31 and to the MW, requested with
    # first_infall=True over the traced snapshot range:
    fallin_m31, fallin_mw = simtrace.get_fallin_times_lg(
        sim, sim_data["M31"], sim_data["MW"], snap_start, snap_stop, first_infall=True
    )    
    # For the subhalos at snap_id_ref, look up v_max at infall and the
    # snapshot ID of infall, separately for each central:
    vmax_fallin_m31, snap_id_fallin_m31 = dataset_comp.get_subhalos_at_fallin(
        sub_dict[snap_id_ref], fallin_m31, vmax_dict
    )    
    vmax_fallin_mw, snap_id_fallin_mw = dataset_comp.get_subhalos_at_fallin(
        sub_dict[snap_id_ref], fallin_mw, vmax_dict
    )
    
    sim_data["Vmax_Fallin_M31"] = vmax_fallin_m31
    sim_data["Vmax_Fallin_MW"] = vmax_fallin_mw
    # Combined array: the M31 value where it is not NaN, the MW value otherwise:
    vmax_infall = np.where(~np.isnan(vmax_fallin_m31), 
                           vmax_fallin_m31,
                           vmax_fallin_mw)
    sim_data["Vmax_Fallin"] = vmax_infall
        
    # Same combination for the snapshot IDs of infall:
    snap_id_fallin = np.where(~np.isnan(snap_id_fallin_m31),
                              snap_id_fallin_m31,
                              snap_id_fallin_mw)
    # Convert infall snapshot IDs into positions in snap_ids. IDs that sort
    # past the last entry (including NaN) get position snap_ids.size; these
    # are flagged with -1 and receive a NaN infall time. (lookback_time[-1]
    # is read for them in passing, but discarded by np.where.)
    inds = np.searchsorted(snap_ids, snap_id_fallin)
    inds[inds == snap_ids.size] = -1
    sim_data["Time_Fallin"] = np.where(inds != -1, lookback_time[inds], np.nan)
    
    # Masking arrays for subhalos at snap_ref:
    snap_ref = sim.get_snapshot(snap_id_ref)
    mask_lum, mask_dark = dataset_comp.split_luminous(snap_ref)
    sim_data["Ref_Selections"].update({
        "Vmax_Sat": dataset_comp.prune_vmax(snap_ref, low_lim=sat_low_lim),
        "Vmax_Isol": dataset_comp.prune_vmax(snap_ref, low_lim=isol_low_lim),
        "Luminous": mask_lum,
        "Dark": mask_dark
    })
    
    # True only where v_max at infall is known (not NaN) and above vmax_lim:
    sim_data["Ref_Selections"]["Vmax_Infall"] = np.where(
        ~np.isnan(vmax_infall), (vmax_infall > vmax_lim), False
    )

In [None]:
# Color normalization spanning the lookback times of the traced snapshots,
# taken from the first simulation in `data` and rounded to 5 decimals:
first_sim = list(data.values())[0]
min_time = round(min(first_sim['LookbackTime']), 5)
max_time = round(max(first_sim['LookbackTime']), 5)
print(min_time, max_time)

norm_func = plt.Normalize(vmin=min_time, vmax=max_time)

In [None]:
# (Re)assign the base colormap of each simulation:
data["plain-LCDM"].update(Colormap=cm.Blues)
data["curv-p082"].update(Colormap=cm.Reds)


In [None]:
# Restrict the colormap of every simulation to the part between minval and
# maxval of its range, sampled at 100 points. The truncated map replaces the
# "Colormap" entry, so running this cell again truncates it once more.
minval = 0.2
maxval = 1

for sim_data in data.values():
    cmap = sim_data["Colormap"]

    trunc_name = 'trunc({},{:.2f},{:.2f})'.format(cmap.name, minval, maxval)
    trunc_colors = cmap(np.linspace(minval, maxval, 100))
    sim_data["Colormap"] = LinearSegmentedColormap.from_list(trunc_name, trunc_colors)

### Set Plot Parameters

In [None]:
# Font sizes, applied through plt.rcParams when the figure is created:
parameters = {}
parameters['axes.titlesize'] = 10
parameters['axes.labelsize'] = 9
parameters['xtick.labelsize'] = 6
parameters['ytick.labelsize'] = 6
parameters['legend.fontsize'] = 8

# Marker sizes: ms for the filled (dark) points, msl for the open (luminous)
# points; a is the transparency (alpha) of both:
ms, msl = 10, 15
a = 0.75

### Check that colormaps will align

In [None]:
# Print the range of infall lookback times of the dark and of the luminous
# LG satellites of each simulation. "Time_Fallin" holds NaN for subhalos
# without an infall time, and the built-in min()/max() give order-dependent
# results on data with NaN, so the NaN-aware NumPy reductions are used.
for sim_data in data.values():
    mask = np.logical_and(sim_data['Ref_Selections']['LG_Satellites'],
                          sim_data['Ref_Selections']['Dark'])
    fi_dark = sim_data['Time_Fallin'][mask]

    mask = np.logical_and(sim_data['Ref_Selections']['LG_Satellites'],
                          sim_data['Ref_Selections']['Luminous'])
    fi_lum = sim_data['Time_Fallin'][mask]

    print(np.nanmin(fi_dark), np.nanmax(fi_dark))
    print(np.nanmin(fi_lum), np.nanmax(fi_lum))

In [None]:
# Color normalization spanning the infall lookback times of all dark and
# luminous LG satellites, over all simulations.
min_time = 10**10
max_time = 0

for sim_data in data.values():
    mask_dark = np.logical_and(sim_data['Ref_Selections']['LG_Satellites'],
                               sim_data['Ref_Selections']['Dark'])
    mask_lum = np.logical_and(sim_data['Ref_Selections']['LG_Satellites'],
                              sim_data['Ref_Selections']['Luminous'])

    # Infall times of both groups together. "Time_Fallin" holds NaN for
    # subhalos without an infall time; these are dropped, since NaN would
    # otherwise propagate into (or be skipped order-dependently by) the
    # minimum and maximum:
    time_fallin = sim_data['Time_Fallin'][np.logical_or(mask_dark, mask_lum)]
    time_fallin = time_fallin[~np.isnan(time_fallin)]

    # An empty selection leaves the running limits untouched:
    if time_fallin.size > 0:
        min_time = min(min_time, time_fallin.min())
        max_time = max(max_time, time_fallin.max())

min_time = round(min_time, 5)
max_time = round(max_time, 5)
print(min_time, max_time)

norm_func = plt.Normalize(vmin=min_time, vmax=max_time)

In [None]:
# Normalize colors over the full lookback-time range of the traced
# snapshots (first simulation in `data`), rounded to 5 decimals. This
# overrides any norm_func defined in an earlier cell.
lookback = list(data.values())[0]['LookbackTime']
min_time, max_time = (round(extreme(lookback), 5) for extreme in (min, max))
print(min_time, max_time)

norm_func = plt.Normalize(vmin=min_time, vmax=max_time)

### Create Blank Plot

In [None]:
# Set fonts:
plt.rcParams.update(parameters)

# Create the figure. (No plt.tight_layout() call before this point: with no
# figure open yet it would only create and display a stray empty figure.)
fig, ax = plt.subplots(figsize=(5, 5))
fig.subplots_adjust(wspace=0.05)

# Set axis:
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_box_aspect(0.9) # Set subfigure box side aspect ratio

ax.set_xlim(5, 110)
ax.set_ylim(0.8, 4)

ax.yaxis.set_ticks_position('both')
# Raw strings keep the LaTeX backslashes without relying on invalid
# escape sequences such as "\m":
ax.set_xlabel(r"$v_\mathrm{max}(z=0) ~ [\mathrm{km/s}]$")
ax.set_ylabel(r"$v_\mathrm{max}(z=z_\mathrm{infall}) ~/~ v_\mathrm{max}(z=0)$")

# ax.set_ylabel(r"$N_{[\cdot]}(z) ~/~ N_\mathrm{tot}(z=0)$")
# ax.set_title("Satellite Subhalos")

### ... And Plot

In [None]:
# Scatter v_max(infall) / v_max(z=0) against v_max(z=0) for the LG satellites
# of each simulation, colored by the lookback time at infall: dark subhalos
# as filled markers, luminous ones as open markers. The scatter handles of
# the luminous sets are collected in `sc` (one per simulation, in the order
# of `data`) and serve as mappables for the colorbars below.
sc = []
for sim_name, sim_data in data.items():

    cmap = sim_data['Colormap'] 

    # Plot dark
    # ---------

    mask = np.logical_and(sim_data['Ref_Selections']['LG_Satellites'],
                          sim_data['Ref_Selections']['Dark'])
    time = sim_data['Time_Fallin'][mask]
    x = sim_data['Vmax'][mask]
    y = sim_data['Vmax_Fallin'][mask] / x
    # Colors are computed explicitly from the normalized infall times.
    # NOTE(review): labels are assigned, but no legend is drawn in this cell.
    ax.scatter(x, y, s=ms, edgecolor='none', alpha=a, c=cmap(norm_func(time)),
               label="{} non-SF".format(sim_name))

    # Plot luminous
    # -------------

    mask = np.logical_and(sim_data['Ref_Selections']['LG_Satellites'],
                          sim_data['Ref_Selections']['Luminous'])
    time = sim_data['Time_Fallin'][mask]

    x = sim_data['Vmax'][mask]
    y = sim_data['Vmax_Fallin'][mask] / x
    # Save output for colorbar:
    sc.append(ax.scatter(x, y, s=msl, facecolor='none', alpha=a, 
                         edgecolor=cmap(norm_func(time)),          
                         label="{} SF".format(sim_name)))

    # ax.axhline(15, c="gray", linestyle="dotted")    

# Two narrow colorbars side by side, one per simulation. sc[0] and sc[1]
# are paired with the "plain-LCDM" and "curv-p082" colormaps, which relies
# on the order of the entries in `data`. The first bar has its ticks
# removed; the second one carries the ticks and the common label.
sc[0].set_clim(vmin=min_time, vmax=max_time)
sc[0].set_cmap(data["plain-LCDM"]["Colormap"])
cax = fig.add_axes([0.92, 0.25, 0.03, 0.5])
cbar = fig.colorbar(sc[0], cax=cax, orientation='vertical')
cbar.set_ticks([])

sc[1].set_clim(vmin=min_time, vmax=max_time)
sc[1].set_cmap(data["curv-p082"]["Colormap"])
cax = fig.add_axes([0.95, 0.25, 0.03, 0.5])
cbar = fig.colorbar(sc[1], cax=cax, orientation='vertical')
cbar.ax.set_ylabel("Lookback Time at Infall [Gyr]")

# Display the updated figure:
fig

In [None]:
# Dotted curve on which v_max at infall equals vmax_lim: with
# y = v_max(infall) / v_max(z=0) and x = v_max(z=0) this is y = vmax_lim / x.
# x is sampled at 10000 points, uniformly in log10(x) between 0 and 2.
x = np.power(10, np.linspace(0, 2, 10000))
y = vmax_lim / x

ax.plot(x, y, linestyle='dotted', color='black')
fig

### Add Median Curves

In [None]:
# For each simulation, overplot the binned median trend of the v_max ratio
# for the LG satellites that pass the "Vmax_Infall" selection.
n_median_bins = 5
for sim_name, sim_data in data.items():
    
    # mask = sim_data['Ref_Selections']['LG_Satellites']  
    mask = np.logical_and(sim_data['Ref_Selections']['LG_Satellites'],
                          sim_data['Ref_Selections']['Vmax_Infall'])
    
    x = sim_data['Vmax'][mask]
    y = sim_data['Vmax_Fallin'][mask] / x
    
    # The median is computed on the log10 values and converted back to
    # linear values for plotting:
    x = np.log10(x)
    y = np.log10(y)
    median = curve_fit.median_trend_fixed_bin_width(
        x, y, n_bins=n_median_bins
    )
    if median is not None:
        ax.plot(10**median[0], 10**median[1], 
                sim_data['Color'][0], linestyle='--')
    else:
        # Report the simulation by its key in `data` (the loop variable):
        print("Could not fit median for:", sim_name)
        
fig

In [None]:
# For each simulation, overplot the binned median trend of the v_max ratio
# for the LG satellites that pass the "Vmax_Infall" selection.
# NOTE(review): this cell repeats the median cell directly above; running
# both draws every median line twice on the same axes.
n_median_bins = 5
for sim_name, sim_data in data.items():
    
    # mask = sim_data['Ref_Selections']['LG_Satellites']  
    mask = np.logical_and(sim_data['Ref_Selections']['LG_Satellites'],
                          sim_data['Ref_Selections']['Vmax_Infall'])
    
    x = sim_data['Vmax'][mask]
    y = sim_data['Vmax_Fallin'][mask] / x
    
    # The median is computed on the log10 values and converted back to
    # linear values for plotting:
    x = np.log10(x)
    y = np.log10(y)
    median = curve_fit.median_trend_fixed_bin_width(
        x, y, n_bins=n_median_bins
    )
    if median is not None:
        ax.plot(10**median[0], 10**median[1], 
                sim_data['Color'][0], linestyle='--')
    else:
        # Report the simulation by its key in `data` (the loop variable):
        print("Could not fit median for:", sim_name)
        
fig

### Save the Figures

In [None]:
# File name: 'vmax_at_infall', followed by the names of all simulations in
# `data`, joined by underscores:
filename = '_'.join(['vmax_at_infall'] + list(data.keys())) + '.png'

# Output directory; it is created if missing, since savefig does not create
# directories and would fail otherwise:
path = os.path.abspath(os.path.join('..', 'Figures', 'MediumResolution'))
os.makedirs(path, exist_ok=True)
filename = os.path.join(path, filename)

fig.savefig(filename, dpi=300, bbox_inches='tight')