In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from astropy import units

import importlib

import simulation
import snapshot_obj
import simulation_tracing
import dataset_compute
import subhalo

In [None]:
# Force a fresh import of the project modules, in the order they were imported.
# With `%autoreload 2` enabled in the first cell this is usually redundant.
for _module in (simulation, snapshot_obj, simulation_tracing,
                dataset_compute, subhalo):
    importlib.reload(_module)

In [None]:
%load_ext line_profiler

In [None]:
def distance_to_central(snap_ids, sim, central):
    """Compute, for each snapshot, the distances to a central galaxy's centre.

    Parameters
    ----------
    snap_ids : sequence of int
        Consecutive, increasing snapshot ids.
    sim : simulation.Simulation
        Simulation providing the snapshots (via ``get_snapshot``).
    central : subhalo.SubhaloTracer
        Traced central whose ``CentreOfPotential`` is used as the reference point.

    Returns
    -------
    list
        One entry per snapshot id: the result of
        ``dataset_compute.distance_to_point(snapshot, central_cop)``.

    Raises
    ------
    ValueError
        If ``snap_ids`` is empty or its ids are not consecutive and increasing.
    """
    # Keep the caller's element types (e.g. plain int vs numpy int).
    snap_ids = list(snap_ids)
    if not snap_ids:
        raise ValueError("snap_ids must not be empty")
    # The central's data is fetched over the contiguous range
    # [first, last + 1); its rows only line up with snap_ids (in the zip
    # below) when the ids are consecutive.
    if any(nxt - cur != 1 for cur, nxt in zip(snap_ids, snap_ids[1:])):
        raise ValueError("snap_ids must be consecutive and increasing")

    central_cop = central.get_halo_data(
        "CentreOfPotential", snap_ids[0], snap_ids[-1] + 1)
    dist = [dataset_compute.distance_to_point(
        sim.get_snapshot(snap_id), cop) for snap_id, cop in zip(snap_ids, central_cop)]

    return dist

# Plotting trajectories of satellites

## Setting variables

Let us first define our simulation and the LG central galaxies. The M31 and MW galaxies have identifiers (1,0) and (1,1) at redshift zero (snap_id=127), respectively:

In [None]:
sim = simulation.Simulation("V1_LR_fix")
# Local Group centrals, identified at snap_id=127 (z = 0) by (1, 0) and (1, 1):
m31, mw = [subhalo.SubhaloTracer(sim, 127, 1, sub_id) for sub_id in (0, 1)]

Next, since we are interested in the past trajectories of the subhalos that exist at $z=0$, we need to set the reference snapshot id to 127. We also set the range of snapshots over which the trajectories are traced, from `snap_start` (inclusive) to `snap_stop` (exclusive). Then, we build the merger tree, trace the central galaxies through it, and construct the snapshot tracer for the subhalos:

In [None]:
# Reference snapshot at which the satellites are selected and from which they
# are traced. 127 is redshift zero, where the centrals are identified above.
snap_ref = 127
# Snapshots are traced over the half-open range [snap_start, snap_stop).
snap_start = 100
snap_stop = 128
assert snap_start <= snap_ref < snap_stop, "snap_ref must lie in the traced range"

In [None]:
# Merger tree over the traced snapshot range, using backward branching.
mtree = simulation_tracing.MergerTree(sim, branching="BackwardBranching")
mtree.build_tree(snap_start, snap_stop)

# Trace both centrals through the tree (M31 first, then the MW):
for _lg_central in (m31, mw):
    _lg_central.trace(mtree)

In [None]:
# Generate snapshot tracer:
snap_tracer = simulation_tracing.SnapshotTracer(snap_ref, mtree)
tracer_arr = snap_tracer.trace(start=snap_start, stop=snap_stop)

## Selecting the satellites

Construct the mask arrays that select the satellites (of M31 and of the MW) and the isolated galaxies at `snap_ref`, and split the subhalos into luminous and dark ones:

In [None]:
snapshot = sim.get_snapshot(snap_ref)
m31_id = m31.get_identifier(snap_ref)
mw_id = mw.get_identifier(snap_ref)

# Satellite masks are returned in the order the centrals are passed (M31, MW).
mask_sat, mask_isol = dataset_compute.split_satellites_by_distance(
    snapshot, m31_id, mw_id)
mask_m31, mask_mw = mask_sat[0], mask_sat[1]

mask_lum, mask_dark = dataset_compute.split_luminous(snapshot)

## Retrieve data

For every snapshot of interest, get the centres of potential and the masses of all subhalos (one dataset per snapshot, collected in a list):

In [None]:
# Snapshot ids of interest: snap_start, ..., snap_stop - 1.
snap_range = np.arange(snap_start, snap_stop)
cops = sim.get_subhalos_in_snapshots(snap_range, "CentreOfPotential")
mass = sim.get_subhalos_in_snapshots(snap_range, "MassType")

We want to plot the trajectories in the reference frame in which the central galaxy (M31) is stationary. For this, get the centres of potential of both centrals in every snapshot:

In [None]:
# Centres of potential of the two centrals over [snap_start, snap_stop):
m31_cop, mw_cop = (
    lg_central.get_halo_data("CentreOfPotential", snap_start, snap_stop)
    for lg_central in (m31, mw)
)

In [None]:
snaps = sim.get_snapshots(snap_start, snap_stop)
# Subhalo positions relative to M31, after periodic wrapping about the M31
# centre in each snapshot.
cops_centered = [
    dataset_compute.periodic_wrap(snaps[idx], m31_cop[idx], cops[idx])
    - m31_cop[idx]
    for idx in range(snaps.size)
]

Convert the coordinates to kpc and the masses to solar masses:

In [None]:
# Conversion factors from the raw cgs datasets.
cm_to_kpc = units.cm.to(units.kpc)
g_to_msun = units.g.to(units.Msun)

# NOTE(review): the lists are converted in place, so re-running this cell
# without first re-running the cells that build `cops_centered` and `mass`
# applies the conversion twice.
for i in range(len(cops)):
    cops_centered[i] = cops_centered[i] * cm_to_kpc
    mass[i] = mass[i] * g_to_msun

Get redshifts:

In [None]:
# Redshifts of the traced snapshots, and of the reference snapshot itself.
redshift = sim.get_redshifts(snap_start, snap_stop)
ref_snapshot = sim.get_snapshot(snap_ref)
z_ref = ref_snapshot.get_attribute("Redshift", "Header")

## Construct trajectory arrays

Finally, we are ready to make the trajectories. First, collect the positions of the M31 satellites in all snapshots into a single array:

In [None]:
sat_tracer = tracer_arr[mask_m31]
n_sat = np.size(sat_tracer, axis=0)
n_col = np.size(sat_tracer, axis=1)

# (satellite, position component, snapshot); snapshots in which a satellite
# is not traced keep their zero placeholder.
sat_cops = np.zeros((n_sat, 3, n_col))

for col in range(n_col):
    traced = sat_tracer[:, col] < snap_tracer.no_match
    sat_cops[traced, :, col] = cops_centered[col][sat_tracer[traced, col]]

Then drop the zero-filled entries of the snapshots in which a satellite was not traced, and keep only the satellites that are traced in every snapshot from `snap_start` to `snap_stop`, storing the trajectory of each one in a list:

In [None]:
# Number of snapshots a fully traced satellite covers (instead of a hard-coded
# 28, which silently breaks when snap_start/snap_stop change).
n_snaps = np.size(sat_cops, axis=2)
# Untraced snapshots were left as all-zero columns. Test all three components
# rather than one, so a traced satellite that happens to lie on a coordinate
# plane of the M31 frame is not dropped.
sat_trajectories = [traj[:, np.any(traj != 0, axis=0)] for traj in sat_cops]
# Keep only satellites traced in every snapshot of the range.
sat_trajectories = [traj for traj in sat_trajectories
                    if np.size(traj, axis=1) == n_snaps]

In [None]:
fig, axes = plt.subplots()
# Rows 1 and 2 of each trajectory are the second and third position
# components (presumably y and z) in the M31-centred frame, in kpc.
for trajectory in sat_trajectories:
    axes.plot(trajectory[1], trajectory[2])
# M31 sits at the origin of this frame.
axes.plot(0, 0, marker="+", c="black")
axes.set_xlabel("$y$ [kpc]")
axes.set_ylabel("$z$ [kpc]")
axes.set_title("Trajectories of M31 satellites (M31 frame)")
axes.set_aspect("equal")

## Distance to central

We are interested in seeing if there is a difference in the evolution of dark and luminous galaxies, so we need to introduce a further division.

First, compute the distance datasets for the satellites of both centrals in every snapshot.

Luminous:

In [None]:
# Luminous satellites of either central, as rows of the snapshot tracer.
sel_lum = np.logical_and(np.logical_or(mask_m31, mask_mw), mask_lum)
sat_tracer_lum = tracer_arr[sel_lum]
sat_r_lum = np.zeros((np.size(sat_tracer_lum, axis=0),
                      np.size(sat_tracer_lum, axis=1)))

# Measure every satellite from its own central. `cops_centered` is
# M31-centred, so using it would give MW satellites their distance to M31.
# MW satellites are measured from the MW and all others from M31 (a satellite
# flagged for both stays with M31). Positions come from the raw `cops` (cm)
# and are converted to kpc here.
in_mw_lum = np.logical_and(mask_mw, np.logical_not(mask_m31))[sel_lum]
hosts_lum = ((m31_cop, np.logical_not(in_mw_lum)), (mw_cop, in_mw_lum))
cm_to_kpc = units.cm.to(units.kpc)

# Iterate over snapshots:
for i in range(np.size(sat_tracer_lum, axis=1)):
    mask_traced = sat_tracer_lum[:, i] < snap_tracer.no_match
    for host_cop, of_host in hosts_lum:
        rows = np.logical_and(mask_traced, of_host)
        if not np.any(rows):
            continue
        rel_pos = dataset_compute.periodic_wrap(
            snaps[i], host_cop[i], cops[i]) - host_cop[i]
        sat_r_lum[rows, i] = cm_to_kpc * np.linalg.norm(
            rel_pos[sat_tracer_lum[rows, i]], axis=1)

In [None]:
# Per satellite: a (2, k) array of (redshift, distance) over the k snapshots
# whose distance exceeds 10**-2 (untraced snapshots hold a zero distance).
sat_lum = (np.vstack([redshift[r > 10**-2], r[r > 10**-2]]) for r in sat_r_lum)
# Keep only satellites with more than three such snapshots.
sat_lum = [halo_data for halo_data in sat_lum if halo_data.shape[1] > 3]

Dark:

In [None]:
# Dark satellites of either central, as rows of the snapshot tracer.
sel_dark = np.logical_and(np.logical_or(mask_m31, mask_mw), mask_dark)
sat_tracer_dark = tracer_arr[sel_dark]
sat_r_dark = np.zeros((np.size(sat_tracer_dark, axis=0),
                       np.size(sat_tracer_dark, axis=1)))

# Measure every satellite from its own central. `cops_centered` is
# M31-centred, so using it would give MW satellites their distance to M31.
# MW satellites are measured from the MW and all others from M31 (a satellite
# flagged for both stays with M31). Positions come from the raw `cops` (cm)
# and are converted to kpc here.
in_mw_dark = np.logical_and(mask_mw, np.logical_not(mask_m31))[sel_dark]
hosts_dark = ((m31_cop, np.logical_not(in_mw_dark)), (mw_cop, in_mw_dark))
cm_to_kpc = units.cm.to(units.kpc)

# Iterate over snapshots:
for i in range(np.size(sat_tracer_dark, axis=1)):
    mask_traced = sat_tracer_dark[:, i] < snap_tracer.no_match
    for host_cop, of_host in hosts_dark:
        rows = np.logical_and(mask_traced, of_host)
        if not np.any(rows):
            continue
        rel_pos = dataset_compute.periodic_wrap(
            snaps[i], host_cop[i], cops[i]) - host_cop[i]
        sat_r_dark[rows, i] = cm_to_kpc * np.linalg.norm(
            rel_pos[sat_tracer_dark[rows, i]], axis=1)

In [None]:
# Per satellite: a (2, k) array of (redshift, distance) over the k snapshots
# whose distance exceeds 10**-2 (untraced snapshots hold a zero distance).
sat_dark = (np.vstack([redshift[r > 10**-2], r[r > 10**-2]]) for r in sat_r_dark)
# Keep only satellites with more than three such snapshots.
sat_dark = [halo_data for halo_data in sat_dark if halo_data.shape[1] > 3]

Plot:

In [None]:
fig, axes = plt.subplots()

# Raw string: "\m" is an invalid escape sequence in a normal string literal
# (DeprecationWarning/SyntaxWarning); the rendered text is unchanged.
axes.text(0.1, 0.9, r"$z_\mathrm{{ref}} = {:.3f}$".format(z_ref),
          transform=axes.transAxes)

# Draw on `axes` explicitly instead of through the pyplot state machine.
# Only the first line of each group gets a legend entry.
for n, traj in enumerate(sat_lum):
    axes.plot(traj[0], traj[1], c='black',
              label="Luminous" if n == 0 else "_nolegend_")

for n, traj in enumerate(sat_dark):
    axes.plot(traj[0], traj[1], c='gray',
              label="Dark" if n == 0 else "_nolegend_")

axes.set_xlabel("Redshift $z$")
axes.set_ylabel("Distance to central [kpc]")
if sat_lum or sat_dark:
    axes.legend()

# Mass evolution