## First, imports:

In [None]:
# Reload imported modules automatically before executing code, so edits to
# the local library are picked up without restarting the kernel:
%load_ext autoreload
%autoreload 2

# Greedy tab completion (IPython completer option):
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import my library:

In [None]:
import os
import sys

# Make the local `apostletools` library (in ../apostletools) importable.
# Only add the path once, so re-running this cell does not keep appending
# duplicate entries to sys.path:
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
if apt_path not in sys.path:
    sys.path.append(apt_path)

import simulation
import simtrace
import match_halo

In [None]:
import importlib

# Explicitly re-import the local modules (a manual fallback to %autoreload):
for module in (simulation, simtrace, match_halo):
    importlib.reload(module)

# Evolution Histories of the Milky Way and the M31

Here, I compare the mass evolution and the relative motion of the MW and the M31 across the simulations. This helps me understand potential differences between the environments of the satellites in the different simulations.

---

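## Set Parameters for the Plots

Each of the following subsections defines a different set of simulations and overwrites `data`, `m31_id_z0` and `mw_id_z0`; run only one of them.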

### plain-LCDM and curv-p082

Choose the simulations, and define the M31 and the MW at redshift $z=0$:

In [None]:
# Plot label -> simulation name:
sim_dirs = {
    "plain-LCDM": "V1_MR_fix",
    "curv-p082": "V1_MR_curvaton_p082_fix",
}
data = {label: {"Simulation": simulation.Simulation(dirname)}
        for label, dirname in sim_dirs.items()}

# Halo identifiers of the centrals at z = 0, one entry per simulation, in the
# same order as `data` (presumably (group number, subgroup number) pairs, as
# passed to `index_of_halo` below):
m31_id_z0 = [(1, 0), (1, 0)]
mw_id_z0 = [(2, 0), (1, 1)]

Define plotting style:

In [None]:
# Line style per particle type (shared by all simulations):
linestyle = {'All': ':', 'Gas': '-', 'DM': '--', 'Stars': '-', 'BH': '--'}

# (light, dark) colour per simulation: 'All', 'Gas' and 'BH' use the light
# colour, 'DM' and 'Stars' the dark one.
palettes = {"plain-LCDM": ("gray", "black"), "curv-p082": ("pink", "red")}
for sim_name, (light, dark) in palettes.items():
    data[sim_name].update({
        "Color": {'All': light, 'Gas': light, 'DM': dark,
                  'Stars': dark, 'BH': light},
        "Linestyle": linestyle,
    })

---

### plain-LCDM-LR

Choose the simulations, and define the M31 and the MW at redshift $z=0$:

In [None]:
# The low-resolution runs are read from a separate environment directory:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))

lr_sim = simulation.Simulation("V1_LR_fix", env_path=env_path)
data = {"plain-LCDM-LR": {"Simulation": lr_sim}}

# Halo identifiers of the centrals at z = 0 (one entry per simulation):
m31_id_z0, mw_id_z0 = [(1, 0)], [(2, 0)]

Define plotting style:

In [None]:
# Line style per particle type:
linestyle = dict(All=':', Gas='-', DM='--', Stars='-', BH='--')

lr_entry = data["plain-LCDM-LR"]
lr_entry["Color"] = dict(All='gray', Gas='gray', DM='black',
                         Stars='black', BH='gray')
lr_entry["Linestyle"] = linestyle

---

### Low Resolution Simulations

Choose the simulations, and define the M31 and the MW at redshift $z=0$:

In [None]:
# The low-resolution runs are read from a separate environment directory:
env_path = os.path.abspath(os.path.join('..', 'test_tracing_inj'))

# Plot label -> simulation name:
lr_runs = {
    "plain-LCDM-LR": "V1_LR_fix",
    "curv-p082-LR": "V1_LR_curvaton_p082_fix",
    "curv-p084-LR": "V1_LR_curvaton_p084_fix",
}
data = {label: {"Simulation": simulation.Simulation(run, env_path=env_path)}
        for label, run in lr_runs.items()}

# Halo identifiers of the centrals at z = 0, in the same order as `data`:
m31_id_z0 = [(1, 0)] * 3
mw_id_z0 = [(2, 0), (1, 1), (1, 0)]

Define plotting style:

In [None]:
# Line style per particle type (shared by all simulations):
linestyle = {'All': ':', 'Gas': '-', 'DM': '--', 'Stars': '-', 'BH': '--'}

# (light, dark) colour per simulation: 'All', 'Gas' and 'BH' use the light
# colour, 'DM' and 'Stars' the dark one.
palettes = {
    "plain-LCDM-LR": ("gray", "black"),
    "curv-p082-LR": ("pink", "red"),
    "curv-p084-LR": ("lightblue", "blue"),
}
for sim_name, (light, dark) in palettes.items():
    data[sim_name].update({
        "Color": {'All': light, 'Gas': light, 'DM': dark,
                  'Stars': dark, 'BH': light},
        "Linestyle": linestyle,
    })

---

## Tracing

Set the range of snapshots to be traced:

In [None]:
# Traced snapshots: snap_start, ..., snap_stop - 1 (the stop is exclusive).
snap_start, snap_stop = 100, 128
snap_ids = snap_start + np.arange(snap_stop - snap_start)

In [None]:
# Link the snapshots of every simulation by building its merger tree over the
# traced range (only needed if the simulations are not already linked):
matcher = match_halo.SnapshotMatcher(n_link_ref=20, n_matches=1)
for sim_data in data.values():
    merger_tree = simtrace.MergerTree(
        sim_data["Simulation"], matcher=matcher, branching="BackwardBranching"
    )
    merger_tree.build_tree(snap_start, snap_stop)

Get Subhalo objects corresponding to the M31 and the MW:

In [None]:
snap_z0 = 127
for m31_id, mw_id, sim_data in zip(m31_id_z0, mw_id_z0, data.values()):
    sim = sim_data["Simulation"]

    # Trace the subhalos over the snapshot range and pick, at snap_z0, the
    # Subhalo objects of the M31 and the MW; store them in `sim_data`:
    subhalos_z0 = sim.trace_subhalos(snap_start, snap_stop)[snap_z0]
    snapshot_z0 = sim.get_snapshot(snap_z0)
    sim_data["M31"] = {"Subhalo": subhalos_z0[snapshot_z0.index_of_halo(*m31_id)]}
    sim_data["MW"] = {"Subhalo": subhalos_z0[snapshot_z0.index_of_halo(*mw_id)]}

---

## Retrieve and Compute Datasets for Plotting

Below we add to the data dictionaries of each simulation, at each traced snapshot,
- the masses (of different types) of the centrals
- their relative distance
- the Hubble expansion of that distance
- the radial and tangential components of their relative velocity (peculiar velocity plus Hubble flow)
- and the corresponding redshifts and lookback times

In [None]:
# Define the cosmology from the header of the last traced snapshot.
# `HubbleParam` is the dimensionless h, so H0 = 100 h km/s/Mpc.
cosmo_params = []
for sim_name, sim_data in data.items():
    snap = sim_data["Simulation"].get_snapshot(snap_stop - 1)
    h = snap.get_attribute("HubbleParam", "Header")
    Om = snap.get_attribute("Omega0", "Header")
    print(sim_name, h, Om)
    cosmo_params.append((h, Om))

# A single cosmology is used for all lookback times below, so the simulations
# must actually agree on it:
H0, Om0 = cosmo_params[0]  # note: H0 holds the dimensionless h
assert all(np.isclose(h, H0) and np.isclose(Om, Om0) for h, Om in cosmo_params), \
    "Simulations have different cosmologies: {}".format(cosmo_params)
cosmo = FlatLambdaCDM(H0=100 * H0, Om0=Om0)

In [None]:
# Sanity check: age of the universe today (z = 0) and at z = 1.
for z_check in (0, 1):
    print(cosmo.age(z_check).value)

In [None]:
def _masses_by_type(subhalo, snap_ids):
    """Return the mass evolution of `subhalo` split by particle type.

    Reads the `MassType` halo data at `snap_ids` (assumed to be in grams)
    and converts it to solar masses. Returns a dict with keys 'Gas', 'DM',
    'Stars', 'BH' (MassType columns 0, 1, 4, 5), followed by 'All', the sum
    over all columns.
    """
    masses = subhalo.get_halo_data("MassType", snap_ids) * units.g.to(units.Msun)
    by_type = {pt: masses[:, col]
               for pt, col in zip(['Gas', 'DM', 'Stars', 'BH'], [0, 1, 4, 5])}
    by_type["All"] = np.sum(masses, axis=1)
    return by_type


for sim_data in data.values():
    sim = sim_data["Simulation"]
    m31 = sim_data["M31"]["Subhalo"]
    mw = sim_data["MW"]["Subhalo"]

    # Redshifts of the traced snapshots and the corresponding lookback times
    # (astropy's `age` is evaluated on the whole redshift array at once):
    z = sim.get_attribute("Redshift", "Header", snap_ids)
    lookback_time = cosmo.age(0).value - cosmo.age(z).value
    sim_data["Redshift"] = z
    sim_data["LookbackTime"] = lookback_time

    # Mass evolution of the centrals, per particle type:
    sim_data["M31"]["Mass"] = _masses_by_type(m31, snap_ids)
    sim_data["MW"]["Mass"] = _masses_by_type(mw, snap_ids)

    # Position of the MW relative to the M31, in kpc (presumably an
    # (n_snaps, 3) array — it is normed along axis 1 below):
    r = mw.distance_to_central(m31, snap_ids, centre_name="CentreOfMass") \
        * units.cm.to(units.kpc)
    sim_data["Separation"] = r

    # Hubble flow H(z) * r between the centrals.
    # NOTE(review): H(z) from the header is presumably in 1/s, which makes
    # H_flow km/s with r in km — confirm.
    H = sim.get_attribute("H(z)", "Header", snap_ids)
    r_km = r * units.kpc.to(units.km)
    H_flow = np.multiply(H, r_km.T).T
    sim_data["Expansion"] = np.linalg.norm(H_flow, axis=1)

    # Total relative velocity of the centrals: Hubble flow plus the
    # difference of their peculiar velocities (so NOT purely peculiar).
    # NOTE(review): if `r` is x_MW - x_M31 (as stated above), the consistent
    # peculiar term would be v_MW - v_M31; the original sign is kept here —
    # confirm against `distance_to_central`.
    v_m31 = m31.get_halo_data("Velocity", snap_ids) * units.cm.to(units.km)
    v_mw = mw.get_halo_data("Velocity", snap_ids) * units.cm.to(units.km)
    v = H_flow + v_m31 - v_mw

    # Radial (along r) and tangential components of that relative velocity:
    r_unit = np.multiply(r_km, 1 / np.linalg.norm(r_km, axis=1)[:, np.newaxis])
    v_rad = np.sum(v * r_unit, axis=1)
    v_rot = np.linalg.norm(v - np.multiply(r_unit, v_rad[:, np.newaxis]),
                           axis=1)
    sim_data["V_rad"] = v_rad
    sim_data["V_rot"] = v_rot

    # Top-hat r_200 (FoF data `Group_R_TopHat200`) of the centrals, in kpc:
    for central_name, central in (("M31", m31), ("MW", mw)):
        sim_data[central_name]["r_200"] = central.get_fof_data(
            "Group_R_TopHat200", snap_ids
        ) * units.cm.to(units.kpc)

--- 

## Plot the Evolution of Centrals

In four subfigures, plot
- the mass of the M31
- the mass of the MW
- their distance
- the radial and tangential components of their relative velocity, together with the Hubble expansion of their separation

First, set figure parameters:

In [None]:
# Font sizes for the figures below (fed to plt.rcParams.update):
parameters = dict([
    ('axes.titlesize', 10),
    ('axes.labelsize', 10),
    ('xtick.labelsize', 9),
    ('ytick.labelsize', 9),
    ('legend.fontsize', 10),
])

Create the blank frame figure for the plots:

In [None]:
# Set fonts (before the figure is created, so it picks them up):
plt.rcParams.update(parameters)

fig, axes = plt.subplots(nrows=4, sharex=True, figsize=(6, 9))
fig.subplots_adjust(hspace=0.03)

# Y-axes
# ------

# Share the y-axis between the two mass panels (M31 and MW). `Axes.sharey`
# replaces `get_shared_y_axes().join(...)`, which is deprecated since
# Matplotlib 3.6 and no longer works in 3.8+:
axes[1].sharey(axes[0])

for ax in axes:
    ax.yaxis.set_ticks_position('both')

# Y-axis limits (the mass panels show log10 of the mass in Msun):
axes[0].set_ylim(np.log10(5e8), np.log10(5e12))
axes[2].set_ylim(0, 990)
axes[3].set_ylim(-120, 120)


# X-axes
# ------

# Lookback-time range of the first simulation, padded by 3 % on both sides:
x_points = next(iter(data.values()))["LookbackTime"]
x_ws = (max(x_points) - min(x_points)) * 0.03
axes[3].set_xlim(min(x_points) - x_ws, max(x_points) + x_ws)

# Let time run from the past (left) to the present (right):
axes[3].invert_xaxis()

# Add a redshift axis above the figure, with ticks at z = 0, 0.1, 0.2, ...
z_ax = axes[0].twiny()
z_ax.set_xlim(axes[0].get_xlim())  # same (inverted) lookback-time range

# With the inverted axis, the left limit is the largest lookback time:
t_max, _ = axes[0].get_xlim()
z_max = z_at_value(cosmo.age, cosmo.age(0) - t_max * units.Gyr)
z_tick_values = 0.1 * np.arange(int(float(z_max) * 10) + 1)
z_tick_locations = cosmo.age(0).value - cosmo.age(z_tick_values).value

z_ax.set_xticks(z_tick_locations)
z_ax.set_xticklabels(["%.1f" % z_tick for z_tick in z_tick_values])


# Set axis labels (raw strings keep the LaTeX backslashes intact):
axes[3].set_xlabel("Lookback Time [Gyr]")
z_ax.set_xlabel("Redshift")
axes[0].set_ylabel(r'$\log(M_\mathrm{M31} / \mathrm{M_\odot})$')
axes[1].set_ylabel(r'$\log(M_\mathrm{MW} / \mathrm{M_\odot})$')
axes[2].set_ylabel('Distance [kpc]')
axes[3].set_ylabel('Relative velocity [km/s]')

In [None]:
# Line2D objects returned by the plot functions, and their labels, for the
# construction of the legends (one inner list per simulation; only the M31
# mass lines are stored):
mass_plot_lines = []
mass_plot_labels_pt = []
mass_plot_labels_sim = []

for i_sim, (sim_name, sim_data) in enumerate(data.items()):

    # Lookback times of the traced snapshots (computed with the data above):
    time = sim_data["LookbackTime"]

    # M31 mass evolution per particle type (black holes are not plotted):
    sim_plot_lines, pt_labels, sim_labels = [], [], []
    for part_type, mass in sim_data["M31"]["Mass"].items():
        if part_type == 'BH':
            continue
        line, = axes[0].plot(time, np.log10(mass),
                             c=sim_data["Color"][part_type],
                             linestyle=sim_data["Linestyle"][part_type])
        sim_plot_lines.append(line)
        pt_labels.append(part_type)
        sim_labels.append(sim_name)

    mass_plot_lines.append(sim_plot_lines)
    mass_plot_labels_pt.append(pt_labels)
    mass_plot_labels_sim.append(sim_labels)

    # MW mass evolution per particle type (black holes are not plotted):
    for part_type, mass in sim_data["MW"]["Mass"].items():
        if part_type == 'BH':
            continue
        axes[1].plot(time, np.log10(mass),
                     c=sim_data["Color"][part_type],
                     linestyle=sim_data["Linestyle"][part_type])

    # Separation of the centrals:
    axes[2].plot(time, np.linalg.norm(sim_data["Separation"], axis=1),
                 c=sim_data["Color"]['DM'])

    # Velocity panel. Only the first simulation's lines get legend labels;
    # the expansion line uses each simulation's own light colour.
    is_first = i_sim == 0
    axes[3].plot(time, sim_data["Expansion"], linestyle='dashdot',
                 c=sim_data["Color"]['All'],
                 label="Expansion" if is_first else None)
    axes[3].plot(time, sim_data["V_rad"], linestyle='solid',
                 c=sim_data["Color"]['DM'],
                 label="Radial" if is_first else None)
    axes[3].plot(time, sim_data["V_rot"], linestyle='dotted',
                 c=sim_data["Color"]['DM'],
                 label="Tangential" if is_first else None)

fig

In [None]:
# Simulation legend: each simulation is represented by its M31 mass line at
# index 2 ('Stars', since the mass dict is ordered Gas, DM, Stars, BH, All and
# black holes are skipped when plotting), labelled with the simulation name.
sim_handles = [lines[2] for lines in mass_plot_lines]
sim_names = [labels[2] for labels in mass_plot_labels_sim]
sim_legend = axes[0].legend(sim_handles, sim_names, loc="lower right")

# A second `legend` call on the same axes replaces the first one, so keep the
# simulation legend by adding it as an extra artist:
axes[0].add_artist(sim_legend)

# Particle-type legend (lines of the first simulation) and velocity legend:
axes[0].legend(mass_plot_lines[0], mass_plot_labels_pt[0], loc="upper left")
axes[3].legend(loc="lower left")

# NOTE(review): plt.tight_layout() acts on pyplot's *current* figure, which
# may not be `fig` at this point — consider fig.tight_layout() (note it would
# also override the hspace set when creating the figure).
plt.tight_layout()

fig

---

## Save the Figure

In [None]:
filename = "time_evolution_of_centrals.png"

# Output directory; savefig does not create missing directories, so make
# sure it exists first:
path = os.path.abspath(os.path.join('..', 'Figures', 'MediumResolution'))
os.makedirs(path, exist_ok=True)
filename = os.path.join(path, filename)

fig.savefig(filename, dpi=300, bbox_inches='tight')

---

## Further ideas

- Add a break to the y-axis of the distance plot to better highlight how the separation evolves (https://pypi.org/project/brokenaxes/)

---

## Plot $r_{200}$

As a further check, plot the top-hat $r_{200}$ values of the centrals and compare them with the physical distance corresponding to 300 ckpc, i.e. $300\,a$ kpc, where $a$ is the expansion factor.

In [None]:
fig, ax = plt.subplots()

ax.invert_xaxis()
ax.set_xlabel("Redshift")
ax.set_ylabel("Radius [kpc]")

# Top-hat r_200 of the M31 and the MW in each simulation (raw strings keep the
# LaTeX backslashes; `{{ }}` are escaped braces for str.format):
for sim_name, sim_data in data.items():
    z = sim_data["Redshift"]
    ax.plot(z, sim_data["M31"]["r_200"], c=sim_data["Color"]["Stars"],
            linestyle=sim_data["Linestyle"]["Stars"],
            label=r"{}: $r_{{200, \mathrm{{M31}}}}$".format(sim_name))
    ax.plot(z, sim_data["MW"]["r_200"], c=sim_data["Color"]["Gas"],
            linestyle=sim_data["Linestyle"]["DM"],
            label=r"{}: $r_{{200, \mathrm{{MW}}}}$".format(sim_name))

# Physical size of 300 comoving kpc, using the expansion factor
# a = 1 / (1 + z) of the plotted snapshots (instead of relying on a `sim`
# object left over from an earlier cell):
z_ref = np.asarray(next(iter(data.values()))["Redshift"])
a = 1 / (1 + z_ref)
ax.plot(z_ref, a * 300, c='blue', linestyle='dotted',
        label=r"$300\,a\ \mathrm{kpc}$")

ax.legend(loc="upper left")

In [None]:
filename = "centrals_r200.png"

# Output directory; savefig does not create missing directories, so make
# sure it exists first:
path = os.path.abspath(os.path.join('..', 'Figures', 'MediumResolution'))
os.makedirs(path, exist_ok=True)
filename = os.path.join(path, filename)

fig.savefig(filename, dpi=300, bbox_inches='tight')