In [None]:
# Automatically reload edited project modules (simulation, subhalo, ...)
# before each cell runs, so changes to the .py files are picked up live.
%load_ext autoreload
%autoreload 2

# Greedy tab-completion: lets IPython evaluate expressions on TAB to complete
# e.g. dict keys / list elements (NOTE: this actually evaluates code).
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from astropy import units
import plotly.graph_objects as go

import importlib

import simulation
import snapshot_obj
import simulation_tracing
import dataset_compute
import subhalo

In [None]:
# Explicitly re-import the project modules, in the same fixed order.
# %autoreload 2 already reloads edited modules automatically, so this is only
# a belt-and-braces refresh. Looping (rather than ending the cell with a bare
# reload call) also avoids displaying a stray module repr as cell output.
for project_module in (simulation, snapshot_obj, simulation_tracing,
                       dataset_compute, subhalo):
    importlib.reload(project_module)

In [None]:
# Open the "V1_LR_fix" simulation.
sim = simulation.Simulation("V1_LR_fix")
# Tracers for the two main subhalos (named M31 and MW here).
# NOTE(review): the integer arguments look like (snapshot id, group number,
# subgroup number) — i.e. subgroups 0 and 1 of group 1 in snapshot 127;
# confirm against subhalo.SubhaloTracer. Neither tracer is used in the
# cells shown below.
m31 = subhalo.SubhaloTracer(sim, 127, 1, 0)
mw = subhalo.SubhaloTracer(sim, 127, 1, 1)

In [None]:
# First and last snapshot ids handed to the merger-tree builder.
snap_start, snap_stop = 100, 128

In [None]:
# Build the merger tree between snap_start and snap_stop using the
# "BackwardBranching" scheme, then collect all per-snapshot matches.
# heritage is indexed by snapshot id; each entry is a dict whose keys may
# include 'Progenitors' and/or 'Descendants' (the tracing cells below check
# for both) — TODO confirm the full schema in simulation_tracing.
mtree = simulation_tracing.MergerTree(sim, branching="BackwardBranching")
mtree.build_tree(snap_start, snap_stop)
heritage = mtree.get_all_matches()

In [None]:
# Quick look at the tree output: which snapshot ids heritage contains, and
# which match tables are stored for one of them.
# NOTE(review): 120 is a hard-coded example snapshot and must be a key of
# heritage; adjust it if snap_start/snap_stop change.
print(heritage.keys())
print(heritage[120].keys())


In [None]:
# Snapshot ids present in the tree, ascending (i.e. forward in time).
snaps = sorted(heritage)
print(snaps)

In [None]:
# Per-snapshot result tables with one entry per subhalo; the tracing cells
# below fill them in.
#  - formation_time starts at the subhalo's own snapshot id (it exists there);
#    the backward trace lowers it to the earliest snapshot reached.
#  - destruction_time starts at mtree.no_match, meaning "no destruction
#    time assigned"; the forward trace sets it where a branch ends.
# Both hold snapshot ids, so both use an integer dtype.
# NOTE(review): assumes mtree.no_match is an integer sentinel.
subhalo_counts = {snap_id: sim.get_snapshot(snap_id).get_subhalo_number()
                  for snap_id in snaps}  # query each snapshot only once
formation_time = {snap_id: np.full(n_subhalos, snap_id, dtype=int)
                  for snap_id, n_subhalos in subhalo_counts.items()}
destruction_time = {snap_id: np.full(n_subhalos, mtree.no_match, dtype=int)
                    for snap_id, n_subhalos in subhalo_counts.items()}

In [None]:
# Formation times: follow each subhalo's main-progenitor chain backward.
#
# Invariant: when the loop reaches snapshot `snap_id`, for every snapshot
# `succ_snap_id` at or after it, progs_in_snap[succ_snap_id][i] is the index
# within `snap_id` of the ancestor of subhalo i of `succ_snap_id`, or
# mtree.no_match once that chain has been broken.
progs_in_snap = {snap_id: np.arange(sim.get_snapshot(snap_id).get_subhalo_number())
                 for snap_id in snaps}

# Iterate through snapshots backward in time:
for snap_pos in reversed(range(len(snaps))):
    snap_id = snaps[snap_pos]
    tracked_snaps = snaps[snap_pos:]  # snap_id and every later snapshot

    if 'Progenitors' not in heritage[snap_id]:
        # Without progenitor data the chains cannot be followed past this
        # snapshot. Break them explicitly: the indices kept here point into
        # `snap_id`, and reusing them on the next iteration would index the
        # progenitor table of a *different* (earlier) snapshot.
        print(f"Snapshot {snap_id}: no 'Progenitors' entry, "
              f"backward tracing stops here.")
        for succ_snap_id in tracked_snaps:
            progs_in_snap[succ_snap_id] = np.full(
                progs_in_snap[succ_snap_id].size, mtree.no_match, dtype=int)
        continue

    # Column 0 is used as the main progenitor.
    # NOTE(review): entries presumably index subhalos of snapshot snap_id - 1.
    progs = heritage[snap_id]['Progenitors']

    for succ_snap_id in tracked_snaps:
        ancestors = progs_in_snap[succ_snap_id]
        has_ancestor = ancestors != mtree.no_match

        # Step every live chain one snapshot further back.
        progs_in_prev = np.full(ancestors.size, mtree.no_match, dtype=int)
        progs_in_prev[has_ancestor] = progs[ancestors[has_ancestor], 0]

        # Chains that continue existed at least as early as snap_id - 1.
        # NOTE(review): assumes consecutive snapshot numbering.
        formation_time[succ_snap_id][progs_in_prev != mtree.no_match] = snap_id - 1

        progs_in_snap[succ_snap_id] = progs_in_prev

In [None]:
# Destruction times: follow each subhalo's descendant chain forward.
#
# Invariant: when the loop reaches snapshot `snap_id`, for every snapshot
# `prev_snap_id` at or before it, descs_in_snap[prev_snap_id][i] is the index
# within `snap_id` of the descendant of subhalo i of `prev_snap_id`, or
# mtree.no_match once that chain has ended or been broken.
descs_in_snap = {snap_id: np.arange(sim.get_snapshot(snap_id).get_subhalo_number())
                 for snap_id in snaps}

# Iterate through snapshots forward in time:
for snap_pos, snap_id in enumerate(snaps):
    tracked_snaps = snaps[:snap_pos + 1]  # every snapshot up to snap_id

    if 'Descendants' not in heritage[snap_id]:
        # Without descendant data the chains cannot be followed past this
        # snapshot. Break them explicitly: the indices kept here point into
        # `snap_id`, and reusing them on the next iteration would index the
        # descendant table of a *different* (later) snapshot. Broken chains
        # keep destruction_time == mtree.no_match (unknown / not destroyed).
        print(f"Snapshot {snap_id}: no 'Descendants' entry, "
              f"forward tracing stops here.")
        for prev_snap_id in tracked_snaps:
            descs_in_snap[prev_snap_id] = np.full(
                descs_in_snap[prev_snap_id].size, mtree.no_match, dtype=int)
        continue

    # NOTE(review): presumably maps each subhalo of snap_id to its descendant
    # index in the next snapshot (mtree.no_match if none).
    descs = heritage[snap_id]['Descendants']

    for prev_snap_id in tracked_snaps:
        current = descs_in_snap[prev_snap_id]
        alive = current != mtree.no_match

        # Step every live chain one snapshot forward.
        descs_in_succ = np.full(current.size, mtree.no_match, dtype=int)
        descs_in_succ[alive] = descs[current[alive]]

        # Alive in snap_id but with no descendant in the next snapshot: the
        # branch ends here, so snap_id is recorded as its destruction time.
        ended = alive & (descs_in_succ == mtree.no_match)
        destruction_time[prev_snap_id][ended] = snap_id

        descs_in_snap[prev_snap_id] = descs_in_succ

In [None]:
# Spot-check: formation and destruction times of the first 100 subhalos
# of a single snapshot.
snap_id = 110
for time_table in (formation_time, destruction_time):
    print(time_table[snap_id][:100])