In [None]:
# Re-import edited project modules automatically before executing each cell.
%load_ext autoreload
%autoreload 2

# Greedy tab completion (may evaluate expressions on TAB to offer completions).
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from astropy import units
from pathlib import Path
import os
import time

import snapshot_obj
import simulation
import dataset_compute
import simulation_tracing
import subhalo
import curve_fit

import importlib

In [None]:
# Explicitly re-import every project module used below, so that edits to them
# are picked up even when %autoreload misses a change.
for module in (snapshot_obj, simulation, dataset_compute,
               simulation_tracing, subhalo, curve_fit):
    importlib.reload(module)

# Fraction of halos traced

## Construct data dictionary

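Add an entry for each simulation and identify the M31 and MW halos in it. Run only one of the two configuration cells below: both assign the same variables, so whichever runs last is the one used.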

In [None]:
# Low-resolution configuration.
# NOTE: the next cell assigns the same variables, so when the notebook is run
# top to bottom the configuration in that cell is the one actually used.
sim_ids = ["V1_LR_fix"]
names = ["LCDM"]
# One data path per simulation. An empty string presumably selects the default
# location of `simulation.Simulation` (TODO confirm). Alternative location used
# for entries 2 and 3: "/media/kassiili/USBFREE/LG_simulations".
paths = ["", "", ""]

# Define M31 and MW in each simulation at the reference snapshot `snap_id`,
# presumably as (GroupNumber, SubGroupNumber) pairs. These lists are zipped
# with `sim_ids`, so entries beyond its length are ignored.
snap_id = 126
m31 = [(1, 0), (1, 0), (1, 0)]
mw = [(2, 0), (1, 1), (1, 1)]

In [None]:
# Medium-resolution configuration, one row per simulation:
#   (simulation ID, label, data path, M31 identifier, MW identifier)
# The identifiers are presumably (GroupNumber, SubGroupNumber) at `snap_id`.
sim_ids, names, paths, m31, mw = map(list, zip(
    ("V1_MR_fix",               "LCDM", "", (1, 0), (2, 0)),
    ("V1_MR_curvaton_p082_fix", "p082", "", (1, 0), (1, 1)),
))

# Reference snapshot at which M31 and MW are defined:
snap_id = 127

In [None]:
# One entry per simulation: the simulation object, its merger tree, and the
# traced M31 and MW subhalos identified at the reference snapshot `snap_id`.
data = {}
for label, sim_name, location, m31_halo, mw_halo in zip(names, sim_ids, paths,
                                                         m31, mw):
    print(label)
    sim = simulation.Simulation(sim_name, sim_path=location)
    mtree = simulation_tracing.MergerTree(sim, branching='BackwardBranching')

    tracers = {}
    for galaxy, halo in (("M31", m31_halo), ("MW", mw_halo)):
        tracer = subhalo.SubhaloTracer(sim, snap_id, halo[0], halo[1])
        tracer.trace(mtree)
        tracers[galaxy] = tracer

    data[label] = {"simulation": sim, "merger_tree": mtree, **tracers}

Choose how to separate satellites from isolated galaxies (`"by_gn"`: by group number; `"by_r"`: by distance):

In [None]:
# Satellite/isolated split used in the tracing loop:
#   'by_gn' -> dataset_compute.split_satellites_by_group_number
#   'by_r'  -> dataset_compute.split_satellites_by_distance
distinction = 'by_gn'

Select the starting snapshots from which subhalos are traced forward, and the snapshot they are traced up to (`snap_stop`):

In [None]:
# Starting snapshots: every 5th snapshot from 101 to 121 inclusive
# (use a single one, e.g. [101], for a quick run).
snap_id_traced = list(range(101, 121 + 1, 5))
# Snapshot that the subhalos are traced forward to:
snap_stop = 128

In [None]:
def repeat_columns(arr, n):
    """Return a 2-D array whose i-th row is ``arr[i]`` repeated ``n`` times.

    ``arr`` is expected to be 1-D: ``np.repeat`` flattens its input before
    repeating, and the result is reshaped to ``(len(arr), n)``.
    """
    n_rows = np.size(arr, 0)
    return np.reshape(np.repeat(arr, n), (n_rows, n))

def compute_fraction(mask_traced, mask):
    """Return, per column of ``mask_traced``, the fraction of the entries
    selected by ``mask`` that are True in that column.

    Parameters
    ----------
    mask_traced : numpy.ndarray of bool, 2-D
        Iterated column by column; each column is combined elementwise with
        ``mask`` (presumably shape (n_subhalos, n_snapshots), one column per
        snapshot the subhalos are traced to -- TODO confirm).
    mask : array_like of bool, 1-D
        Selects the population whose traced fraction is computed.

    Returns
    -------
    numpy.ndarray of float
        One fraction per column. If ``mask`` selects nothing the fraction is
        undefined and every entry is NaN; this is returned explicitly instead
        of dividing by zero (which emitted a RuntimeWarning).
    """
    # Number of selected items that are traced, one count per column:
    counts = np.array([np.sum(np.logical_and(traced_in_column, mask))
                       for traced_in_column in mask_traced.T])

    n_selected = np.sum(mask)
    if n_selected == 0:
        return np.full(counts.shape, np.nan)

    # Divide by the total number of items in the selection:
    return counts / n_selected

In [None]:
# Trace subhalos forward from each starting snapshot in `snap_id_traced` to
# `snap_stop`, and store per-population results in data[name][snap_id_start].
# The loop variables deliberately avoid the names `snap_id`, `m31` and `mw`:
# those globals hold the configuration cells' values (reference snapshot and
# M31/MW identifiers) and are reused elsewhere, e.g. in the figure filename
# and when the data-construction cell is re-run.
for name, sim_data in data.items():
    sim = sim_data["simulation"]
    mtree = sim_data["merger_tree"]
    m31_tracer = sim_data["M31"]
    mw_tracer = sim_data["MW"]

    for snap_id_start in snap_id_traced:
        snap = sim.get_snapshot(snap_id_start)
        # Identify M31 and MW at this snapshot from their traced data:
        m31_id = (m31_tracer.get_halo_data("GroupNumber", snap_id_start),
                  m31_tracer.get_halo_data("SubGroupNumber", snap_id_start))
        mw_id = (mw_tracer.get_halo_data("GroupNumber", snap_id_start),
                 mw_tracer.get_halo_data("SubGroupNumber", snap_id_start))

        # Split into satellites and isolated subhalos:
        if distinction == "by_r":
            masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(
                snap, m31_id, mw_id)
        elif distinction == "by_gn":
            masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(
                snap, m31_id, mw_id)
        else:
            # Otherwise masks_sat / mask_isol would be undefined, or silently
            # left over from a previous iteration.
            raise ValueError("Unknown distinction {!r}: expected 'by_r' or "
                             "'by_gn'".format(distinction))

        # A subhalo counts as a satellite if any of the returned satellite
        # masks selects it (presumably one mask per host -- TODO confirm):
        mask_sat = np.logical_or.reduce(masks_sat)
        mask_lum, mask_dark = dataset_compute.split_luminous(snap)

        # Exclude the smallest subhalos:
        mask_vmax = dataset_compute.prune_vmax(snap, low_lim=15)

        # Selection masks keyed by environment, then by luminosity; every
        # selection includes the vmax cut.
        env_masks = {"satellites": np.logical_and(mask_sat, mask_vmax),
                     "isolated": np.logical_and(mask_isol, mask_vmax),
                     "all": mask_vmax}
        selections = {
            env: {"all": mask_env,
                  "luminous": np.logical_and(mask_env, mask_lum),
                  "dark": np.logical_and(mask_env, mask_dark)}
            for env, mask_env in env_masks.items()}

        # True where the trace towards snap_stop found a match:
        snap_tracer = simulation_tracing.SnapshotTracer(snap_id_start, mtree)
        mask_traced = (snap_tracer.trace(snap_stop) < snap_tracer.no_match)

        fraction_traced = {
            env: {lum: compute_fraction(mask_traced, mask)
                  for lum, mask in by_lum.items()}
            for env, by_lum in selections.items()}
        count_at_snap = {
            env: {lum: np.sum(mask) for lum, mask in by_lum.items()}
            for env, by_lum in selections.items()}

        z = sim.get_redshifts(snap_start=snap_id_start, snap_stop=snap_stop)

        sim_data[snap_id_start] = {"fraction_traced": fraction_traced,
                                   "count_at_snap": count_at_snap,
                                   "redshift": z}

In [None]:
# Inspect the structure of the results, using the first simulation and the
# first starting snapshot instead of hard-coded keys, so this cell keeps
# working when the configuration changes.
example_name = next(iter(data))
example_snap = snap_id_traced[0]
print(data.keys())
print(data[example_name].keys())
print(data[example_name][example_snap].keys())
print(data[example_name][example_snap]["count_at_snap"])

## Plot

In [None]:
# Axis limits (x is redshift in the plot below; the y limits are not applied):
x_down, x_up = 0, 1
y_down, y_up = 0, 1.2

# Per-simulation styles, indexed by the simulation's position in `data`:
# fcolor is used for luminous and mcolor for dark subhalos; marker is unused.
fcolor = ["black", "red", "blue", "green"]
mcolor = ["gray", "pink", "lightblue", "lightgreen"]
marker = ['+', "o", "^", 1]

In [None]:
# Figure path: <directory of the simulation module>/Figures/MediumResolution/
#   fraction_traced_forward_<name1>_<name2>..._<snap_id>.png
home = os.path.dirname(simulation.__file__)
path = os.path.join(home, "Figures", "MediumResolution")
figure_name = "_".join(["fraction_traced_forward", *names,
                        "{}.png".format(snap_id)])
filename = os.path.join(path, figure_name)

In [None]:
# Fraction of subhalos (passing the vmax cut) that remain traced, against
# redshift, for every starting snapshot in `snap_id_traced`. Left panel:
# satellites; right panel: isolated subhalos. fcolor marks luminous and
# mcolor dark subhalos.
show_all = False  # set True to also draw the combined populations, dotted

fig, axes = plt.subplots(ncols=2, figsize=(14, 6))
panels = [("satellites", 'Satellite galaxies'),
          ("isolated", 'Isolated galaxies')]

for ax, (_, title) in zip(axes, panels):
    ax.set_xlim(x_down, x_up)
    ax.invert_xaxis()  # redshift decreases from left to right
    ax.set_xlabel('Redshift', fontsize=16)
    ax.set_ylabel('Fraction traced', fontsize=16)
    ax.set_title(title)

for i, (name, entry) in enumerate(data.items()):
    # Wrap around so more simulations than colours do not raise IndexError:
    lum_color = fcolor[i % len(fcolor)]
    dark_color = mcolor[i % len(mcolor)]

    # The loop variable is not called `snap_id`, leaving that global intact.
    for snap_id_start in snap_id_traced:
        snap_data = entry[snap_id_start]
        z = snap_data['redshift']
        for ax, (region, _) in zip(axes, panels):
            fractions = snap_data['fraction_traced'][region]
            if show_all:
                ax.plot(z, fractions['all'], c=lum_color, linestyle=':')
            ax.plot(z, fractions['luminous'], c=lum_color)
            ax.plot(z, fractions['dark'], c=dark_color)

    # Empty plots give one legend entry per simulation and population:
    axes[0].plot([], [], c=lum_color, label="{} luminous".format(name))
    axes[0].plot([], [], c=dark_color, label="{} dark".format(name))
    if show_all:
        axes[0].plot([], [], c=lum_color, linestyle=':',
                     label="{} all".format(name))

axes[0].legend(loc='lower left')
fig.tight_layout()

# Make sure the output directory exists, otherwise savefig fails:
out_dir = os.path.dirname(filename)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
fig.savefig(filename, dpi=200)