## First, imports:

In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import my library:

In [None]:
import os
import sys

apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
sys.path.append(apt_path)

import simulation
import simtrace_redo
import match_halo_redo
import dataset_comp
import subhalo

In [None]:
import importlib
importlib.reload(simulation)
importlib.reload(simtrace_redo)
importlib.reload(match_halo_redo)
importlib.reload(dataset_comp)
importlib.reload(subhalo)

For computation time efficiency analysis:

In [None]:
%load_ext line_profiler

## Test match_halo

In [None]:
matcher = match_halo_redo.SnapshotMatcher(n_link_ref=20)

First, study the low-resolution simulation. Create new, blank envelope files in the directory `test_tracing_minmatch` (make sure that no envelope files exist in that directory before testing).

In [None]:
sim_id = "V1_MR_fix"
env_path = os.path.abspath(os.path.join('..', 'test_tracing_minmatch'))
print(env_path)
sim = simulation.Simulation(sim_id, env_path=env_path)
sim.get_snap_ids()

Try matching two snapshots. Inspect the runtime.

In [None]:
# Branching backward in time:
snap_ref = sim.get_snapshot(126)
snap_srch = sim.get_snapshot(127)
%timeit matches_ref, matches_srch = matcher.match_snapshots(snap_ref, snap_srch)

In [None]:
%lprun -f matcher.match_snapshots matcher.match_snapshots(snap_ref, snap_srch)

In [None]:
%lprun -f matcher.is_a_match matcher.match_snapshots(snap_ref, snap_srch)

In [None]:
matches_ref, matches_srch = matcher.match_snapshots(snap_ref, snap_srch)
print(matches_ref.shape)
print(matches_srch.shape)

With ~5s matching time for a pair of snapshots and up to ~120 snapshots to match, the total runtime for tracing the entire LR simulation would be around 10 min. This seems acceptable.
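
The estimate above is plain arithmetic on the two rough figures quoted in the text. A minimal sketch (the variable names are only illustrative):

```python
# Back-of-the-envelope estimate using the rough figures quoted above.
seconds_per_pair = 5   # approximate matching time for one pair of snapshots
n_pairs = 120          # upper bound on the number of snapshot pairs to match

total_minutes = seconds_per_pair * n_pairs / 60
print("Estimated total runtime: {:.0f} min".format(total_minutes))
```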

Of course, the question is: is this fast enough to trace an MR simulation? Most of the computing time is spent in the `np.intersect1d` function calls, i.e. in computing the intersections of the most bound particles of two subhalos. This function is called 468241 times, which means that 468241 pairs of subhalos are tried as matches.

This is probably the most time-efficient function for finding the intersections, i.e. for finding out whether two subhalos match. Thus, the only way to speed up the `match_snapshots` function would be to decrease the number of matching trials. That is quite difficult to do, however: probably the most promising way would be to use spatial and kinematic information.
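
To illustrate what a single matching trial costs, here is a minimal sketch of counting shared particle IDs with `np.intersect1d`. The helper `count_shared_ids` and the toy ID arrays are hypothetical; this is not the actual `is_a_match` implementation, whose acceptance criterion is not shown here.

```python
import numpy as np

def count_shared_ids(ids_a, ids_b):
    """Return the number of particle IDs present in both arrays."""
    return np.intersect1d(ids_a, ids_b).size

# Hypothetical IDs of the most bound particles of two subhalos:
ids_ref = np.array([101, 102, 103, 104, 105])
ids_srch = np.array([104, 105, 106, 107, 101])

n_shared = count_shared_ids(ids_ref, ids_srch)
frac_shared = n_shared / ids_ref.size
print(n_shared, frac_shared)
```

A match criterion could then compare `frac_shared` against some threshold; which threshold the matcher uses is not assumed here.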

Note that the number of matching trials, if every subhalo in snap_ref were tried with every subhalo in snap_srch, would be 1116 * 1125 = 1255500. So, by restricting trials to subhalos within a mass range of a factor of 3 and halting at the first match, we reduce the number of trials only to about 37 % of that maximum (468241 / (1116 * 1125) = 0.373). Most subhalos are in the low-mass range, so the mass cut excludes few candidates for them. Also, the subhalos with the smallest masses are rarely matched, so the search seldom halts early for them.
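
The trial counts quoted above can be checked directly, and the mass cut can be sketched as a simple filter. In the sketch below, `candidates_in_mass_range` is a hypothetical helper, and reading "within a factor of 3" as a mass ratio between 1/3 and 3 is an assumption about the matcher.

```python
import numpy as np

n_ref, n_srch = 1116, 1125   # subhalo counts quoted above
n_trials = 468241            # number of np.intersect1d calls from the profiler

n_all_pairs = n_ref * n_srch
print(n_all_pairs)                        # 1255500
print(round(n_trials / n_all_pairs, 3))   # 0.373

def candidates_in_mass_range(mass_ref, masses_srch, factor=3.0):
    """Indices of search subhalos whose mass is within `factor` of mass_ref."""
    ratio = masses_srch / mass_ref
    return np.flatnonzero(np.logical_and(ratio >= 1 / factor, ratio <= factor))

# Toy masses in arbitrary units:
masses_srch = np.array([1.0, 2.5, 4.0, 10.0, 35.0])
cands = candidates_in_mass_range(10.0, masses_srch)
print(cands)
```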


Let
> $n$ be the average number of subhalos in a snapshot, and \
> $k$ the average number of bound particles in a subhalo.

The time complexity of `match_snapshots` is $\mathcal{O}(n^2)$. 
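
As a rough refinement that brings in $k$: assuming each trial intersects two ID arrays of length at most $k$, and given that `np.intersect1d` is sort-based, a single trial costs $\mathcal{O}(k \log k)$. Under that assumption the total cost would scale as

$$ T_\mathrm{match} = \mathcal{O}(n^2 \, k \log k). $$

If only a fixed number of most bound particles is compared per subhalo (as the argument `n_link_ref=20` suggests, though this is an assumption about the matcher), $k$ is effectively constant and the $n^2$ factor dominates.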

### Inspect the matches

First, simply print the matches:

In [None]:
gns_ref = snap_ref.get_subhalos('GroupNumber')
sgns_ref = snap_ref.get_subhalos('SubGroupNumber')
gns_srch = snap_srch.get_subhalos('GroupNumber')
sgns_srch = snap_srch.get_subhalos('SubGroupNumber')

for i, j in enumerate(matches_ref):
    if j != matcher.no_match:
        print("({}, {}) --> ({}, {})".format(gns_ref[i], sgns_ref[i],
                                             gns_srch[j], sgns_srch[j]))
    else:
        print("({}, {}) NO MATCH".format(gns_ref[i], sgns_ref[i]))

The most massive central halos are always matched. Most of the subhalos of M31 and MW are matched.

Note that subhalos of a central remain subhalos of the same central, even when the group number changes:
> (10.0, 0.0) --> (11.0, 0.0) \
> (10.0, 1.0) --> (11.0, 1.0) \
> (10.0, 2.0) --> (11.0, 2.0) \
> (11.0, 0.0) --> (9.0, 0.0) \
> (11.0, 1.0) --> (9.0, 1.0)

WRITE A DOCUMENT ON THIS TESTING PROCESS!

What do you see in the above output? What is noteworthy? What confirms that the program is working as intended, and that there are no bugs? What is suspicious? Which relevant questions cannot be answered based on the above? What next?

Document the step-by-step process, by which you become convinced that everything works. Then, reflect: is the process thorough and reliable? Is the program well-structured; too complicated and difficult to test, or easy to understand?

No two subhalos in snap_ref are matched with the same subhalo in snap_srch:

In [None]:
vals, cnts = np.unique(matches_ref, return_counts=True)
# Exclude the no-match marker, which naturally occurs more than once:
mask_dup = np.logical_and(cnts > 1, vals != matcher.no_match)
print("Total number of matches: {}".format(
    np.sum(matches_ref != matcher.no_match)))
print("Subhalos in snap_srch that are matched more than once:")
print("Indices: {}".format(vals[mask_dup]))
print("Counts: {}".format(cnts[mask_dup]))

print("Indices in ref: {}".format([list(np.flatnonzero(matches_ref == v))
                                   for v in vals[mask_dup]]))

... or:

In [None]:
print(np.sum(matches_srch[:,1] != matcher.no_match))
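
The uniqueness check can be tried on a toy array to see what a duplicate would look like. This is a minimal sketch: the sentinel value `-1` is hypothetical (the matcher defines its own `no_match`), and the match array is invented for illustration.

```python
import numpy as np

no_match = -1   # hypothetical sentinel; the matcher defines its own value

# Toy match array: entry i holds the index in snap_srch matched to subhalo i.
matches = np.array([3, no_match, 0, 3, no_match, 1])

vals, cnts = np.unique(matches, return_counts=True)
# A duplicate is a real target index (not the sentinel) that occurs more than once:
is_dup = np.logical_and(cnts > 1, vals != no_match)
dup_targets = vals[is_dup]
dup_sources = [np.flatnonzero(matches == v).tolist() for v in dup_targets]
print(dup_targets, dup_sources)
```

Here target 3 is claimed by reference subhalos 0 and 3, which is exactly the situation the check is meant to rule out.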

## Merger Trees

In [None]:
snap_start = 101
snap_stop = 128

In [None]:
sim.get_snapshot(snap_start).group_data.fname

In [None]:
mtree = simtrace_redo.MergerTree(sim, branching="BackwardBranching")
%lprun -f mtree.build_tree_with_back_branch mtree.build_tree(snap_start, snap_stop)

Again, no mergers are found between any pair of snapshots:

In [None]:
snap_stop = 101
# Walk backward in time and count subhalos with more than one progenitor:
for sid in range(127, snap_stop, -1):
    snap = sim.get_snapshot(sid)
    prog = snap.get_subhalos('Progenitors', mtree.h5_group)
    mask_merger = np.logical_or(prog[:,1] != mtree.no_match, 
                                prog[:,2] != mtree.no_match)
    print(sid, np.sum(mask_merger))
    # Number of subhalos, and number with an entry in the first column:
    print(np.size(prog, axis=0), np.sum(prog[:,0] != mtree.no_match))
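
The progenitor check can be illustrated on a toy array. The sketch below assumes that column 0 holds the main progenitor and that any further columns hold additional progenitors, so an entry there signals a merger. That layout is inferred from the cell above, not confirmed, and the sentinel `-1` is hypothetical.

```python
import numpy as np

no_match = -1   # hypothetical sentinel; the merger tree defines its own value

# Toy progenitor table: one row per subhalo, one column per progenitor slot.
prog = np.array([[4, no_match, no_match],
                 [7, 2, no_match],
                 [no_match, no_match, no_match],
                 [1, 5, 6]])

# A subhalo counts as a merger remnant if any slot beyond the first is filled:
has_merger = np.any(prog[:, 1:] != no_match, axis=1)
n_mergers = int(np.sum(has_merger))
n_with_prog = int(np.sum(prog[:, 0] != no_match))
print(n_mergers, n_with_prog)
```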

In [None]:
snap_stop = 101
# Walk forward in time and count descendants shared by more than one subhalo:
for sid in range(snap_stop, 127):
    snap = sim.get_snapshot(sid)
    desc = snap.get_subhalos('Descendants', mtree.h5_group)
    vals, cnts = np.unique(desc, return_counts=True)
    mask_merger = np.logical_and(vals != mtree.no_match, cnts > 1)
    print(sid, np.sum(mask_merger))
    # Number of subhalos, and number with a descendant:
    print(np.size(desc), np.sum(desc != mtree.no_match))

In [None]:
sid = 126
snap = sim.get_snapshot(sid)
desc = snap.get_subhalos('Descendants', mtree.h5_group)
prog = snap.get_subhalos('Progenitors', mtree.h5_group)
# "Shadow" subhalos: neither a progenitor nor a descendant was found.
mask_shadow = np.logical_and(prog[:,0] == mtree.no_match, desc == mtree.no_match)

masks_sat, mask_isol = dataset_comp.split_satellites_by_group_number(
    snap, (1,0), (2,0))
mask_sat = np.logical_or.reduce(masks_sat)

mass = snap.get_subhalos('Mass') * units.g.to(units.Msun)

fig, ax = plt.subplots(ncols=2)
ax[0].set_xscale('log')
ax[1].set_xscale('log')

# Left panel: cumulative counts of all satellites and of unmatched satellites.
m = mass[mask_sat]
ax[0].plot(np.sort(m), np.arange(m.size))
m = mass[np.logical_and(mask_sat, mask_shadow)]
ax[0].plot(np.sort(m), np.arange(m.size))

# Right panel: the same for isolated subhalos.
m = mass[mask_isol]
ax[1].plot(np.sort(m), np.arange(m.size))
m = mass[np.logical_and(mask_isol, mask_shadow)]
ax[1].plot(np.sort(m), np.arange(m.size))
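
The curves above plot sorted masses against a running index, i.e. the number of subhalos less massive than a given one. A minimal sketch of that construction on toy data follows; the masses, the mask and the helper `cumulative_counts` are invented for illustration.

```python
import numpy as np

# Toy masses (in solar masses) and a toy mask of unmatched subhalos:
mass = np.array([3e8, 1e7, 5e9, 2e8, 4e7])
is_shadow = np.array([False, True, False, False, True])

def cumulative_counts(m):
    """Return sorted masses and, for each, the number of less massive entries."""
    m_sorted = np.sort(m)
    return m_sorted, np.arange(m_sorted.size)

m_all, n_all = cumulative_counts(mass)
m_sh, n_sh = cumulative_counts(mass[is_shadow])
print(m_all, n_all)
print(m_sh, n_sh)
```

In this toy case the unmatched subhalos are the two least massive ones, which is the kind of pattern the comparison of the two curves is meant to reveal.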
