In [None]:
# Load IPython's autoreload extension and use mode 2, which re-imports
# modules before each cell is executed so that edits to the project's
# .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Enable IPython's "greedy" tab-completion mode.
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

import importlib

import simulation
import snapshot_obj
import simulation_tracing
import dataset_compute
import subhalo
import halo_matching

In [None]:
# Force a fresh import of every project module so that edits made on disk
# are picked up without restarting the kernel.  Reloading in a loop (rather
# than ending the cell on a bare reload call) also keeps the module repr
# out of the cell output.
for _module in (simulation, snapshot_obj, simulation_tracing,
                dataset_compute, subhalo, halo_matching):
    importlib.reload(_module)

In [None]:
# Per-simulation bookkeeping, keyed by a human-readable simulation label.
# The insertion order of the labels matters: the keys are zipped
# positionally with `sim_ids` and with the colour palettes further down.
#
# For each central galaxy, "halo_id_z0" is an index pair whose two entries
# are passed as separate arguments to subhalo.SubhaloTracer at the last
# snapshot (presumably group number and subgroup number -- TODO confirm).
data = {
    "plain-LCDM": {
        "M31": {"halo_id_z0": (1, 0)},
        "MW": {"halo_id_z0": (2, 0)},
    },
    "spec-p082": {
        "M31": {"halo_id_z0": (1, 0)},
        "MW": {"halo_id_z0": (1, 1)},
    },
}

# Simulation identifiers, listed in the same order as the keys of `data`
# (the two are paired positionally when the simulations are loaded).
sim_ids = ["V1_MR_fix", "V1_MR_curvaton_p082_fix"]

# First and last snapshot of the range that is analysed.
snap_start = 101
snap_end = 127

matcher = halo_matching.SnapshotMatcher()
cosmo = FlatLambdaCDM(H0=70, Om0=0.3)

# Line style per particle type; shared by all simulations and galaxies.
linestyle = dict(All=':', Gas='-', DM='--', Stars='-', BH='--')

# One (dark, light) colour pair per simulation, in the key order of `data`.
palettes = [("black", "gray"), ("red", "pink")]

for sim_name, colors in zip(data, palettes):
    dark, light = colors
    color = dict(All=light, Gas=light, DM=dark, Stars=dark, BH=light)

    # Attach the plotting style to every galaxy entry of this simulation.
    for galaxy in data[sim_name]:
        data[sim_name][galaxy]["Color"] = color
        data[sim_name][galaxy]["Linestyle"] = linestyle

In [None]:
def _mass_by_particle_type(tracer, start, stop):
    """Return the mass history of a traced halo, split by particle type.

    The "MassType" dataset is fetched once for snapshots ``start`` to
    ``stop`` (as passed to ``tracer.get_halo_data``) and converted from
    grams to solar masses.

    Returns a dict with the keys 'all', 'gas', 'dm', 'stars', 'bh', in that
    order.  'all' is the sum over the second axis of the dataset; the other
    entries are columns 0, 1, 4 and 5, respectively.  The key order is
    relied upon by the plotting cell, which pairs these values positionally
    with the colour and line-style dictionaries.
    """
    g_to_msun = units.g.to(units.Msun)
    mass_type = tracer.get_halo_data("MassType", start, stop)

    masses = {'all': np.sum(mass_type, axis=1) * g_to_msun}
    for pt, pt_num in zip(['gas', 'dm', 'stars', 'bh'], [0, 1, 4, 5]):
        masses[pt] = mass_type[:, pt_num] * g_to_msun
    return masses


# Exclusive upper bound of the snapshot range used by all range queries below.
snap_stop = snap_end + 1
cm_to_km = units.cm.to(units.km)

for sim_name, sim_id in zip(data.keys(), sim_ids):
    sim = simulation.Simulation(sim_id)
    data[sim_name]["Redshift"] = sim.get_redshifts(snap_start, snap_stop)

    # Build merger tree and trace centrals:
    mtree = simulation_tracing.MergerTree(sim, matcher, branching="BackwardBranching")
    mtree.build_tree(snap_start, snap_stop)

    m31_id = data[sim_name]["M31"]["halo_id_z0"]
    m31 = subhalo.SubhaloTracer(sim, snap_end, m31_id[0], m31_id[1])
    mw_id = data[sim_name]["MW"]["halo_id_z0"]
    mw = subhalo.SubhaloTracer(sim, snap_end, mw_id[0], mw_id[1])
    m31.trace(mtree)
    mw.trace(mtree)

    # Mass histories of both centrals, in solar masses:
    for galaxy, tracer in (("M31", m31), ("MW", mw)):
        data[sim_name][galaxy]["Mass"] = _mass_by_particle_type(
            tracer, snap_start, snap_stop)

    # Separation between the two centrals, one row per snapshot.  Stored in
    # kpc for plotting, converted to km below for the velocity arithmetic.
    r = mw.distance_to_central(m31, snap_start, snap_stop, centre_name="CentreOfMass")
    data[sim_name]["Separation"] = r * units.cm.to(units.kpc)

    v_m31 = m31.get_halo_data("Velocity", snap_start, snap_stop) * cm_to_km
    v_mw = mw.get_halo_data("Velocity", snap_start, snap_stop) * cm_to_km

    # Hubble-flow term: each separation row scaled by the Hubble parameter
    # of the corresponding snapshot.
    r = r * cm_to_km
    H = sim.get_hubble(snap_start, snap_stop)
    H_flow = np.multiply(H, r.T).T
    data[sim_name]["Expansion"] = np.linalg.norm(H_flow, axis=1)

    # Relative velocity of the pair, including the Hubble-flow term:
    v = H_flow + v_m31 - v_mw

    # Split v into its component along the separation direction and the
    # magnitude of the remaining, perpendicular part:
    r_unit = np.multiply(r, 1/np.linalg.norm(r, axis=1)[:, np.newaxis])
    v_rad = np.sum(v * r_unit, axis=1)
    v_rot = np.linalg.norm(v - np.multiply(r_unit, v_rad[:, np.newaxis]),
                           axis=1)

    data[sim_name]["V_rad"] = v_rad
    data[sim_name]["V_rot"] = v_rot

In [None]:
# Construct saving location: the figure goes into Figures/MediumResolution
# next to the snapshot_obj module.
home = os.path.dirname(snapshot_obj.__file__)
path = os.path.join(home, "Figures", "MediumResolution")

# Create the output directory up front so that saving the figure does not
# fail on a checkout where the folder does not exist yet.
os.makedirs(path, exist_ok=True)

# Full path of the output image (directory included).
filename = os.path.join(path, "time_evolution_of_centrals.png")

In [None]:
# Four stacked panels with a common x-axis:
#   0: M31 mass, 1: MW mass, 2: separation, 3: relative velocity.
fig, axes = plt.subplots(nrows=4, sharex=True, figsize=(6,9))
plt.subplots_adjust(hspace=0.03)

# Let the two mass panels share their y-axis.  Axes.sharey is the supported
# API; joining through get_shared_y_axes() was deprecated in Matplotlib 3.6
# and later removed, so it is only used as a fallback on old versions.
if hasattr(axes[1], "sharey"):
    axes[1].sharey(axes[0])
else:
    axes[1].get_shared_y_axes().join(axes[0], axes[1])

for ax in axes:
    ax.yaxis.set_ticks_position('both')

axes[0].set_yscale('log')
axes[1].set_yscale('log')

# The x-axis is shared, so inverting it on the bottom panel flips all panels.
axes[3].invert_xaxis()
axes[3].set_xlabel("Time")
# Raw strings: the labels contain TeX backslashes (\mathrm, \odot) that are
# not valid Python escape sequences.
axes[0].set_ylabel(r'$M_\mathrm{M31} [\mathrm{M_\odot}]$')
axes[1].set_ylabel(r'$M_\mathrm{MW} [\mathrm{M_\odot}]$')
axes[2].set_ylabel('Distance [kpc]')
axes[3].set_ylabel('Relative velocity [km/s]')

# Line2D objects (returned by the plot calls) and the matching labels from
# the M31 mass panel, collected for the construction of legends.  Each list
# holds one inner list per simulation.
mass_plot_lines = []
mass_plot_labels_pt = []
mass_plot_labels_sim = []

for i, sim_name in enumerate(data.keys()):
    sim_data = data[sim_name]

    # x-coordinate: age of the universe today minus the age at each snapshot
    # redshift, evaluated for all redshifts in one call.
    z = np.asarray(sim_data["Redshift"])
    time = cosmo.age(0).value - cosmo.age(z).value

    sim_plot_lines = []
    pt_labels = []
    sim_labels = []

    # Mass panels: M31 on axes[0], MW on axes[1].
    for ax, galaxy in zip(axes[:2], ("M31", "MW")):
        galaxy_data = sim_data[galaxy]

        # The Color, Linestyle and Mass dictionaries list the particle types
        # in the same order, so their entries are paired by position.
        for (pt, col), ls, mass in zip(galaxy_data["Color"].items(),
                                       galaxy_data["Linestyle"].values(),
                                       galaxy_data["Mass"].values()):

            # Do not plot black holes:
            if pt == 'BH':
                continue

            line, = ax.plot(time, mass, c=col, label=pt, linestyle=ls)

            # Only the M31 panel carries legends, so only its lines are kept:
            if galaxy == "M31":
                sim_plot_lines.append(line)
                pt_labels.append(pt)
                sim_labels.append(sim_name)

    mass_plot_lines.append(sim_plot_lines)
    mass_plot_labels_pt.append(pt_labels)
    mass_plot_labels_sim.append(sim_labels)

    # Dark / light colour of this simulation's palette:
    dark = sim_data["M31"]["Color"]['DM']
    light = sim_data["M31"]["Color"]['All']

    r = sim_data["Separation"]
    axes[2].plot(time, np.linalg.norm(r, axis=1), c=dark)

    # Velocity panel.  Only the first simulation's curves get labels, so the
    # legend shows each velocity component exactly once.
    if i == 0:
        labels = ("Expansion", "Radial", "Tangential")
    else:
        labels = (None, None, None)

    axes[3].plot(time, sim_data["Expansion"], label=labels[0],
                 linestyle='dashdot', c=light)
    axes[3].plot(time, sim_data["V_rad"], label=labels[1],
                 linestyle='solid', c=dark)
    axes[3].plot(time, sim_data["V_rot"], label=labels[2],
                 linestyle='dotted', c=dark)
        
# Add legends:
# Simulation legend (lower right): one entry per simulation, represented by
# the last line that was recorded for it in the M31 mass panel.
sim_legend = axes[0].legend([lines[-1] for lines in mass_plot_lines],
                            [labels[-1] for labels in mass_plot_labels_sim],
                            loc="lower right")
# Creating another legend on the same axes replaces the current one, so the
# simulation legend is re-added as a plain artist before the particle-type
# legend (upper left, built from the first simulation's lines) is made.
axes[0].add_artist(sim_legend)
axes[0].legend(mass_plot_lines[0], mass_plot_labels_pt[0], loc="upper left")
axes[3].legend(loc="lower left")

# Make the y-range of the relative velocity plot symmetric about zero,
# keeping the current upper limit:
v_max = axes[3].get_ylim()[1]
axes[3].set_ylim(-v_max, v_max)
    
# Add x-axis above the figure for redshift:
z_ax = axes[0].twiny()
z_ax.invert_xaxis()
# Left edge of the (inverted) time axis, i.e. the largest time value shown.
time_start, _ = axes[0].get_xlim()

# Set z-ticks at 0, 0.1, 0.2, 0.3, ...
# The number of 0.1-wide steps is kept as an integer and reused for the tick
# count, instead of recovering it from z_start by a float division.
n_z_steps = int(z_at_value(cosmo.age, cosmo.age(0) - time_start * units.Gyr) * 10)
z_start = 0.1 * n_z_steps
# The first tick sits at z = 1e-6 rather than exactly 0 (presumably so that
# the inverse lookup in z_tick_func has a valid solution -- TODO confirm).
z_tick_locations = [cosmo.age(0).value - cosmo.age(z).value
                    for z in np.linspace(0.000001, z_start, n_z_steps + 1)]

def z_tick_func(time):
    """Return redshift tick labels for the given time-axis positions.

    Each entry of ``time`` is interpreted as Gyr before the present age of
    the universe in ``cosmo`` and converted back to a redshift with
    ``z_at_value``.  The result is a list of strings with one decimal.
    """
    z = [z_at_value(cosmo.age, cosmo.age(0) - t * units.Gyr)
         for t in time]
    return ["%.1f" % zi for zi in z]

# Give the twin axis the same limits as the time axis so the ticks line up.
z_ax.set_xlim(axes[0].get_xlim())
z_ax.set_xticks(z_tick_locations)
z_ax.set_xticklabels(z_tick_func(z_tick_locations))
z_ax.set_xlabel("Redshift")

plt.tight_layout()
# `filename` already holds the full output path (directory included), so it
# is passed as is; joining it with `path` again would duplicate the
# directory whenever `path` is relative.
plt.savefig(filename, dpi=300)