In [None]:
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Enable IPython's greedy tab completion.
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from astropy import units
import plotly.graph_objects as go

import importlib

import simulation
import snapshot_obj
import simulation_tracing
import dataset_compute
import subhalo

In [None]:
# Explicitly re-import every project module used below (normally redundant
# with `%autoreload 2`, but guarantees fresh code after large edits).
for module in (simulation, snapshot_obj, simulation_tracing,
               dataset_compute, subhalo):
    importlib.reload(module)

# Plotting trajectories of satellites

## Setting variables

Let us first define our simulation and the central galaxies of the Local Group (LG). At redshift zero (snap_id=127), the M31 and MW galaxies have the identifiers (1,0) and (1,1), respectively:

In [None]:
# Snapshot at redshift zero, where the LG centrals are identified.
z0_snap = 127

sim = simulation.Simulation("V1_LR_fix")
# M31 has identifier (1, 0) and the MW (1, 1) at z = 0:
m31 = subhalo.SubhaloTracer(sim, z0_snap, 1, 0)
mw = subhalo.SubhaloTracer(sim, z0_snap, 1, 1)

Next, since we are interested in the past trajectories of the subhalos that exist at $z=0$, we set the reference snapshot id to 127. We also set the range of snapshots, from `snap_start` to `snap_stop` (exclusive), over which the trajectories are traced. Then, we build the merger tree, trace the central galaxies, and construct the snapshot tracers for the subhalos:

In [None]:
# Reference snapshot (z = 0) whose subhalos we are interested in. It is
# used later to read the reference redshift `z_ref`.
snap_ref = 127
# Snapshots are traced over the half-open range [snap_start, snap_stop).
snap_start = 100
snap_stop = 128

In [None]:
mtree = simulation_tracing.MergerTree(sim, branching="BackwardBranching")
mtree.build_tree(snap_start, snap_stop)

# Follow both LG centrals (M31 first, then MW) through the merger tree:
for central in (m31, mw):
    central.trace(mtree)

Make tracers for the traced snapshots:

In [None]:
# Snapshots whose satellite populations become the nodes of the Sankey
# diagram, and their redshifts:
traced_snaps = np.array([100, 110, 120, 127])
traced_z = list(map(
    lambda sid: sim.get_snapshot(sid).get_attribute("Redshift", "Header"),
    traced_snaps))

# One forward tracer per node, following its subhalos up to snap_stop.
# The loop variable keeps the name `snap_tracer` on purpose: later cells
# read `snap_tracer.no_match` from the last tracer built here.
snap_tracers = dict()
for snap_id in traced_snaps:
    snap_tracer = simulation_tracing.SnapshotTracer(snap_id, mtree)
    snap_tracer.trace(start=snap_id, stop=snap_stop)
    snap_tracers.update({snap_id: snap_tracer.tracer_array})

## Selecting the satellites

In [None]:
# Keep only the tracer rows of subhalos that are satellites of M31 or the MW
# at the tracer's starting snapshot.
# NOTE: this overwrites `snap_tracers`, so re-running this cell without
# first re-running the tracing cell applies the mask a second time.
for sid in list(snap_tracers):
    masks_by_central, _ = dataset_compute.split_satellites_by_distance(
        sim.get_snapshot(sid), m31.get_identifier(sid), mw.get_identifier(sid))
    is_satellite = np.logical_or.reduce(masks_by_central)
    snap_tracers[sid] = snap_tracers[sid][is_satellite]

In [None]:
# Number of satellites of each node that have a match at their own snapshot:
for sid in snap_tracers:
    n_matched = np.sum(snap_tracers[sid][:, sid] != mtree.no_match)
    print(sid, n_matched)

In [None]:
# Check whether several satellites of snapshot 100 end up at the same
# subhalo in snapshot 110.
descendants = snap_tracers[100][:, 110]
unique, counts = np.unique(descendants, return_counts=True)
print(counts)
print(unique)
# Descendants shared by more than one satellite. The original selected
# `unique > 1` (subhalo indices above 1) instead of `counts > 1`. The
# no-match sentinel is excluded because every untraced satellite maps
# onto it.
is_shared = (counts > 1) & (unique != mtree.no_match)
print(unique[is_shared], counts[is_shared])

## Construct Sankey

In [None]:
# Alternative node labels using redshifts instead of snapshot ids.
# NOTE: superseded. The next cell reassigns `node`, and the link-building
# code looks up sources/targets by those labels.
node = dict(label=["Accreted satellites"]
                  + ["$z={:.2f}$".format(z) for z in traced_z]
                  + ["Destroyed satellites"])

In [None]:
# Sankey node labels: an inflow node, one node per traced snapshot (in the
# insertion order of `snap_tracers`), and an outflow node. Links below
# refer to nodes by these exact strings.
node = {"label": ["Fallen in", *map(str, snap_tracers.keys()), "Destroyed"]}

In [None]:
print(f"{node}")

Compute link values, sources and targets:

In [None]:
# Node snapshot ids in ascending (chronological) order:
node_snaps = sorted(snap_tracers.keys())

In [None]:
# For the satellites of each node, find their creation node: the earliest
# node at which the same subhalo was already a satellite. Satellites that
# were not satellites at any earlier node get the node's own snapshot id.
creation_node = {}
for snap_id, tracer in snap_tracers.items():
    sats = tracer[:, snap_id]
    first_seen = np.full(np.size(tracer, axis=0), -1, dtype=int)
    # Earlier nodes are visited in ascending order, so the first assignment
    # to a row is the earliest node.
    for prev_node in node_snaps[:node_snaps.index(snap_id)]:
        # Satellites of prev_node, traced forward to snap_id:
        prev_sats = snap_tracers[prev_node][:, snap_id]
        prev_sats = prev_sats[prev_sats != mtree.no_match]
        # Rows of `sats` that were already satellites of prev_node:
        _, idx_places, _ = np.intersect1d(sats, prev_sats,
                                          return_indices=True)
        unset = first_seen[idx_places] == -1
        first_seen[idx_places[unset]] = prev_node
    creation_node[snap_id] = np.where(first_seen != -1, first_seen, snap_id)

In [None]:
# For the satellites of each node, find their destruction node: the last
# node at which the subhalo is still traced, i.e. the node just before the
# first later node where its tracer entry is `mtree.no_match`. Entries stay
# -1 for satellites that are traced at every later node.
destruction_node = {}
for snap_id, tracer in snap_tracers.items():
    node_idx = node_snaps.index(snap_id)
    dr_node = -1 * np.ones(np.size(tracer, axis=0), dtype=int)
    for idx in range(node_idx + 1, len(node_snaps)):
        # Only the first loss is recorded (`dr_node == -1` guard):
        lost = np.logical_and(tracer[:, node_snaps[idx]] == mtree.no_match,
                              dr_node == -1)
        dr_node = np.where(lost, node_snaps[idx - 1], dr_node)
    destruction_node[snap_id] = dr_node

In [None]:
# Sankey links as parallel lists: link k goes from source[k] to target[k]
# (node labels) and carries value[k] satellites.
value = []
source = []
target = []

# Creation links: satellites seen for the first time at a node. The first
# node has no inflow link.
for snap_id, cr_arr in creation_node.items():
    node_idx = np.nonzero(traced_snaps == snap_id)[0][0]
    if node_idx == 0:
        continue
    value.append(np.sum(cr_arr == snap_id))
    source.append("Fallen in")
    target.append(str(snap_id))

# Destruction links. A subhalo that is a satellite at several nodes appears
# in the destruction array of each of them, so it is counted only at the
# node where it was first seen. Otherwise one destroyed satellite would be
# counted once per node it belonged to.
for snap_id, dr_arr in destruction_node.items():
    first_seen_here = creation_node[snap_id] == snap_id
    lost = dr_arr[(dr_arr != -1) & first_seen_here]
    nodes, counts = np.unique(lost, return_counts=True)
    for n, count in zip(nodes, counts):
        value.append(count)
        source.append(str(n))
        target.append("Destroyed")

# Links between nodes: satellites of a node that were already satellites at
# an earlier node, all attributed to the immediately preceding node.
for snap_id, cr_arr in creation_node.items():
    node_idx = np.nonzero(traced_snaps == snap_id)[0][0]
    if node_idx == 0:
        continue
    n_carried = np.sum(cr_arr != snap_id)
    if n_carried:
        value.append(n_carried)
        source.append(str(traced_snaps[node_idx - 1]))
        target.append(str(snap_id))
        

In [None]:
# One line per link: value, source label, target label.
for link in zip(value, source, target):
    print(*link)

In [None]:
# Plotly expects node indices, so replace each label by its position in the
# node label list.
# NOTE: this overwrites `source`/`target`; re-running the cell without
# rebuilding the links raises ValueError.
labels = node["label"]
source = list(map(labels.index, source))
target = list(map(labels.index, target))

In [None]:
print(source, target, sep="\n")

In [None]:
# Appearance of the Sankey nodes; labels come from the node definition.
sankey_nodes = dict(
    pad=15,
    thickness=20,
    line=dict(color="black", width=0.5),
    label=node["label"],
    color="blue",
)
# `source` and `target` hold indices into the node labels.
sankey_links = dict(source=source, target=target, value=value)
fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])


In [None]:
# Render the Sankey diagram of satellite accretion and destruction.
fig.show()

In [None]:
# NOTE: the following cell is an earlier attempt at computing the Sankey
# links and is known to be incorrect; it is kept only for reference.

In [None]:
# Earlier, incorrect attempt at the Sankey links. Its `values`, `sources`
# and `targets` are not used by any later cell.
# NOTE(review): known problems, kept for reference:
#  - `snap_stop-1` (127) is already a key of `snap_tracers`, so `snaps`
#    contains 127 twice.
#  - `range(idx+1, len(snaps)-1)` never visits the last entry of `snaps`.
#  - `prev_snap` is reassigned from a snapshot id to a list index below.
#  - links are appended for every starting node, so satellites shared by
#    several nodes are counted more than once.
#  - `foll_snaps` is unused, and the `print` calls are debug output.
values = []
sources = []
targets = []
snaps = list(snap_tracers.keys()) + [snap_stop-1]
snaps.sort()
for snap_id, tracer in snap_tracers.items():
    # Iterate through the following nodes, saving the number of satellites
    # left over from snap_id:
    sats = tracer[:, snap_id]
    sats = sats[sats != mtree.no_match]
    idx = snaps.index(snap_id)
    foll_snaps = snaps[idx+1:]
    for idx_foll in range(idx+1, len(snaps)-1):
        foll_snap = snaps[idx_foll]
        
        # Get indices of satellites in foll_snap:
        sats_in_foll = snap_tracers[snap_id][:, foll_snap]
        print(sats_in_foll.size)
        sats_in_foll = sats_in_foll[sats_in_foll != mtree.no_match]
        print(sats_in_foll.size)
        
        # Add link for satellites originating from snap_id:
        prev_snap = snaps[idx_foll - 1] 
        values.append(np.size(sats_in_foll))
        sources.append(str(prev_snap))
        targets.append(str(foll_snap))
        
        # Get indices of satellites in the node before foll_snap:
        sats_from_prev = snap_tracers[snap_id][:, prev_snap]
        sats_from_prev = sats_from_prev[sats_from_prev != mtree.no_match]
        
        # Add link for satellites that get destroyed between prev_snap
        # and foll_snap (here `prev_snap` becomes a list index into `snaps`):
        prev_snap = idx_foll - 1 
        values.append(sats_from_prev.size - sats_in_foll.size)
        sources.append(str(snaps[prev_snap]))
        targets.append("Destroyed")
        
    # Add newly accreted satellites (the difference between the satellite
    # count at snap_id and those carried over from the previous node):
    if idx != 0:
        sats_from_prev = snap_tracers[snaps[idx-1]][:, snap_id]
        sats_from_prev = sats_from_prev[sats_from_prev != mtree.no_match]        
        values.append(sats.size - sats_from_prev.size)
        sources.append("Accreted")
        targets.append(str(snap_id))

## Retrieve data

For every snapshot of interest, get the centres of potential and the masses of all subhalos (each as a list with one dataset per snapshot):

In [None]:
# Snapshot ids snap_start, ..., snap_stop - 1:
snap_range = np.arange(snap_start, snap_stop)
cops = sim.get_subhalos_in_snapshots(snap_range, "CentreOfPotential")
mass = sim.get_subhalos_in_snapshots(snap_range, "MassType")

We want to plot the trajectories in the rest frame of the central galaxy (M31), i.e. a frame in which the central is stationary:

In [None]:
# Same dataset and snapshot window for both centrals:
cop_query = ("CentreOfPotential", snap_start, snap_stop)
m31_cop = m31.get_halo_data(*cop_query)
mw_cop = mw.get_halo_data(*cop_query)

In [None]:
snaps = sim.get_snapshots(snap_start, snap_stop)

# Per snapshot: wrap positions periodically around M31, then shift the
# origin to M31's centre of potential.
cops_centered = []
for i in range(snaps.size):
    wrapped = dataset_compute.periodic_wrap(snaps[i], m31_cop[i], cops[i])
    cops_centered.append(wrapped - m31_cop[i])

Convert coordinates to kpc:

In [None]:
# Conversion factors from cgs to astrophysical units:
cm_to_kpc = units.cm.to(units.kpc)
g_to_msun = units.g.to(units.Msun)

# NOTE: converts the lists in place, so re-running this cell without
# re-running the cells that create them applies the factors twice.
for i in range(len(cops)):
    cops_centered[i] = cops_centered[i] * cm_to_kpc
    mass[i] = mass[i] * g_to_msun

Get redshifts:

In [None]:
redshift = sim.get_redshifts(snap_start, snap_stop)

# Redshift of the reference snapshot. Requires `snap_ref` (the reference
# snapshot id) to be defined in an earlier cell.
ref_snapshot = sim.get_snapshot(snap_ref)
z_ref = ref_snapshot.get_attribute("Redshift", "Header")

## Construct trajectory arrays

Finally, we are ready to make the trajectories. First, get the coordinate positions of the satellites in all snapshots in an array:

In [None]:
# Positions (relative to M31, kpc) of the selected satellites at every
# traced snapshot. `sat_cops` has shape (n_satellites, 3, n_columns) and
# stays zero where a satellite is not traced.
# NOTE(review): `tracer_arr` and `mask_m31` are not defined anywhere in this
# notebook (leftovers from an earlier version?), so this cell fails on a
# fresh kernel.
sat_tracer = tracer_arr[mask_m31]
sat_cops = np.zeros((np.size(sat_tracer, axis=0), 3, np.size(sat_tracer, axis=1)))

# Iterate over snapshots:
# NOTE(review): `snap_tracer` is the last tracer left over from the tracing
# loop. This assumes tracer column i corresponds to cops_centered[i]
# (snapshot snap_start + i), although earlier cells index tracer columns by
# absolute snapshot id. It also assumes untraced entries are >= no_match.
# TODO confirm both.
for i in range(np.size(sat_tracer, axis=1)):
    mask_traced = sat_tracer[:,i] < snap_tracer.no_match
    sat_cops[mask_traced,:,i] = cops_centered[i][sat_tracer[mask_traced, i]]

Then, for satellites that are not traced over the whole range from snap_start to snap_stop, remove the empty (zero) entries of the untraced snapshots. Keep only trajectories with more than three points, and save each trajectory as an array in a list:

In [None]:
# Drop untraced snapshots (second coordinate exactly zero) and keep only
# trajectories that still have more than three points.
sat_trajectories = []
for sat_traj in sat_cops:
    traced_part = sat_traj[:, sat_traj[1, :] != 0]
    if np.size(traced_part, axis=1) > 3:
        sat_trajectories.append(traced_part)

In [None]:
fig, axes = plt.subplots()
# Satellite paths projected onto the plane of the second and third position
# components, in kpc, relative to M31's centre of potential.
for trajectory in sat_trajectories:
    axes.plot(trajectory[1], trajectory[2])
axes.set_xlabel("$y$ [kpc]")
axes.set_ylabel("$z$ [kpc]")
axes.set_title("Satellite trajectories in the M31 frame")
# Equal scaling so the spatial paths are not distorted:
axes.set_aspect("equal")

## Distance to central

We are interested in seeing if there is a difference in the evolution of dark and luminous galaxies, so we need to introduce a further division.

First, get datasets for satellites.

Luminous:

In [None]:
# Distance of every luminous satellite (of either central) from M31's
# centre of potential at each traced snapshot; 0 where it is not traced.
# NOTE(review): `tracer_arr`, `mask_m31`, `mask_mw` and `mask_lum` are not
# defined anywhere in this notebook, so this cell fails on a fresh kernel.
sat_tracer_lum = tracer_arr[np.logical_and(
    np.logical_or(mask_m31, mask_mw), mask_lum)]
sat_r_lum = np.zeros((np.size(sat_tracer_lum, axis=0), 
                      np.size(sat_tracer_lum, axis=1)))

# Iterate over snapshots:
# NOTE(review): assumes tracer column i matches cops_centered[i] and that
# untraced entries are >= no_match. TODO confirm.
for i in range(np.size(sat_tracer_lum, axis=1)):
    mask_traced = sat_tracer_lum[:,i] < snap_tracer.no_match
    sat_r_lum[mask_traced,i] = np.linalg.norm(
        cops_centered[i][sat_tracer_lum[mask_traced, i]], axis=1)

In [None]:
# Pair redshift with distance, dropping (near-)zero distances of untraced
# snapshots, and keep satellites with more than three remaining points.
sat_lum = []
for r in sat_r_lum:
    is_traced = r > 10**-2
    halo_data = np.vstack([redshift[is_traced], r[is_traced]])
    if np.size(halo_data[0]) > 3:
        sat_lum.append(halo_data)

Dark:

In [None]:
# Distance of every dark satellite (of either central) from M31's centre of
# potential at each traced snapshot; 0 where it is not traced.
# NOTE(review): `tracer_arr`, `mask_m31`, `mask_mw` and `mask_dark` are not
# defined anywhere in this notebook, so this cell fails on a fresh kernel.
# This cell duplicates the luminous one; a shared helper would remove the
# copy.
sat_tracer_dark = tracer_arr[np.logical_and(
    np.logical_or(mask_m31, mask_mw), mask_dark)]
sat_r_dark = np.zeros((np.size(sat_tracer_dark, axis=0), 
                      np.size(sat_tracer_dark, axis=1)))

# Iterate over snapshots:
# NOTE(review): assumes tracer column i matches cops_centered[i] and that
# untraced entries are >= no_match. TODO confirm.
for i in range(np.size(sat_tracer_dark, axis=1)):
    mask_traced = sat_tracer_dark[:,i] < snap_tracer.no_match
    sat_r_dark[mask_traced,i] = np.linalg.norm(
        cops_centered[i][sat_tracer_dark[mask_traced, i]], axis=1)

In [None]:
# Pair redshift with distance, dropping (near-)zero distances of untraced
# snapshots, and keep satellites with more than three remaining points.
sat_dark = []
for r in sat_r_dark:
    is_traced = r > 10**-2
    halo_data = np.vstack([redshift[is_traced], r[is_traced]])
    if np.size(halo_data[0]) > 3:
        sat_dark.append(halo_data)

Plot:

In [None]:
fig, axes = plt.subplots()

# Raw string: "\m" in a normal string is an invalid escape sequence (a
# warning in recent Python versions). The rendered text is unchanged.
axes.text(0.1, 0.9, r"$z_\mathrm{{ref}} = {:.3f}$".format(z_ref),
          transform=axes.transAxes)

# Distance vs. redshift per satellite: luminous in black, dark in gray.
# NOTE(review): positions are centred on M31 (`cops_centered`), so for MW
# satellites this is their distance to M31, not to the MW.
for traj in sat_lum:
    axes.plot(traj[0], traj[1], c='black')

for traj in sat_dark:
    axes.plot(traj[0], traj[1], c='gray')

# Empty proxy lines provide one legend entry per population.
axes.plot([], [], c='black', label='Luminous')
axes.plot([], [], c='gray', label='Dark')
axes.legend()
axes.set_xlabel("Redshift $z$")
axes.set_ylabel("Distance to M31 [kpc]")

# Mass evolution