In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from astropy import units
from pathlib import Path
import os
import time

import snapshot_obj
import simulation
import subhalo
import dataset_compute
import simulation_tracing
import curve_fit

import importlib

In [None]:
# Explicitly re-import the project modules so edits to them take effect
# without restarting the kernel.
# NOTE(review): likely redundant with `%autoreload 2` enabled in the first cell.
importlib.reload(snapshot_obj)
importlib.reload(simulation)
importlib.reload(subhalo)
importlib.reload(dataset_compute)
importlib.reload(simulation_tracing)
importlib.reload(curve_fit)

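# Fraction of halos traced backward in time

For each simulation, the subhalos identified at snapshot `snap_id` are traced back through the merger tree to snapshot `snap_start`. At each snapshot, the fraction that is still matched is plotted against redshift, separately for satellite and isolated galaxies and for luminous and dark subhalos.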

## Construct data dictionary

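Add an entry for each simulation, and specify the M31 and MW host galaxies by their group and subgroup numbers at snapshot `snap_id`. The configuration lists are combined with `zip`, so `sim_ids`, `paths`, `m31` and `mw` need one entry per name in `names`. Simulations beyond the shortest list are skipped: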

In [None]:
snap_z0_id = 127  # NOTE(review): not referenced elsewhere in this notebook
sim_ids = ["V1_LR_fix"]
names = ["LCDM", "p082", "p084"]
# One data path per simulation. An earlier (overridden) setting pointed the
# p082 and p084 runs to "/media/kassiili/USBFREE/LG_simulations".
paths = ["", "", ""]

# Define M31 and MW in each simulation as (group number, subgroup number)
# pairs at snap_id; they are passed to SubhaloTracer in that order.
snap_id = 126
snap_start = 101
m31 = [(1,0), (1,0), (1,0)]
mw = [(2,0), (1,1), (1,1)]

# zip() stops at the shortest list, so simulations lacking an entry in any of
# the configuration lists would be skipped silently -- make that visible.
config_lengths = {len(cfg) for cfg in (names, sim_ids, paths, m31, mw)}
if len(config_lengths) > 1:
    print("Warning: configuration lists differ in length; only the first "
          "{} simulation(s) will be processed.".format(min(config_lengths)))

data = {}
for name, sim_id, sim_path, m31_ns, mw_ns in zip(names, sim_ids, paths, m31, mw):
    print(name)
    sim = simulation.Simulation(sim_id, sim_path=sim_path)
    mtree = simulation_tracing.MergerTree(sim, branching='BackwardBranching')
    # Follow each host galaxy through the merger tree:
    m31_tracer = subhalo.SubhaloTracer(sim, snap_id, m31_ns[0], m31_ns[1])
    m31_tracer.trace(mtree)
    mw_tracer = subhalo.SubhaloTracer(sim, snap_id, mw_ns[0], mw_ns[1])
    mw_tracer.trace(mtree)

    data[name] = {"simulation": sim,
                  "merger_tree": mtree,
                  "M31": m31_tracer,
                  "MW": mw_tracer}

Choose how to distinguish between satellite and isolated galaxies: `"by_gn"` (by group number) or `"by_r"` (by distance):

In [None]:
# Satellite/isolated split used below: "by_gn" (group number) or "by_r" (distance).
distinction = "by_gn"

In [None]:
def repeat_columns(arr, n):
    """Stack `n` copies of a 1-D array side by side as columns.

    Row i of the result holds arr[i] repeated n times, i.e. the output has
    shape (len(arr), n).
    """
    n_rows = np.size(arr, axis=0)
    return np.repeat(arr, n).reshape(n_rows, n)

def compute_fraction(mask_traced, mask):
    """Fraction of the subhalos selected by `mask` that are traced to each snapshot.

    Parameters
    ----------
    mask_traced : array_like of bool, 2-D
        One row per subhalo, one column per snapshot; True where the subhalo
        was traced to that snapshot.
    mask : array_like of bool, 1-D
        Selection of subhalos, one entry per row of `mask_traced`.

    Returns
    -------
    numpy.ndarray of float
        One fraction per column of `mask_traced`. If `mask` selects no
        subhalos the fraction is undefined and every entry is NaN.
    """
    mask_traced = np.asarray(mask_traced, dtype=bool)
    mask = np.asarray(mask, dtype=bool)

    # Number of selected subhalos traced to each snapshot (column-wise count):
    n_traced = np.count_nonzero(mask_traced & mask[:, np.newaxis], axis=0)
    n_selected = np.count_nonzero(mask)

    if n_selected == 0:
        # Empty selection: return NaN explicitly instead of dividing 0 by 0,
        # which would emit a RuntimeWarning.
        return np.full(n_traced.shape, np.nan)

    return n_traced / n_selected

In [None]:
# Validate the split method up front: an unknown value would otherwise leave
# `masks_sat`/`mask_isol` undefined, or silently reuse stale values from a
# previous iteration or run.
if distinction not in ("by_r", "by_gn"):
    raise ValueError(
        "distinction must be 'by_r' or 'by_gn', got {!r}".format(distinction))

for name, sim_data in data.items():
    sim = sim_data["simulation"]
    snap = sim.get_snapshot(snap_id)
    mtree = sim_data["merger_tree"]
    # Distinct names so the `m31`/`mw` configuration lists are not overwritten:
    m31_tracer = sim_data["M31"]
    mw_tracer = sim_data["MW"]

    # (GroupNumber, SubGroupNumber) of the two hosts at snap_id:
    m31_id = (m31_tracer.get_halo_data("GroupNumber", snap_id),
              m31_tracer.get_halo_data("SubGroupNumber", snap_id))
    mw_id = (mw_tracer.get_halo_data("GroupNumber", snap_id),
             mw_tracer.get_halo_data("SubGroupNumber", snap_id))

    # Split into satellite masks (OR-combined below) and isolated subhalos:
    if distinction == "by_r":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(
            snap, m31_id, mw_id)
    else:  # "by_gn", guaranteed by the check above
        masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(
            snap, m31_id, mw_id)

    mask_sat = np.logical_or.reduce(masks_sat)
    mask_lum, mask_dark = dataset_compute.split_luminous(snap)
    mask_nonzero_vmax = dataset_compute.prune_vmax(snap)

    # Every selection is additionally restricted by the prune_vmax mask:
    mask_sat_lum = np.logical_and.reduce([mask_sat, mask_lum, mask_nonzero_vmax])
    mask_sat_dark = np.logical_and.reduce([mask_sat, mask_dark, mask_nonzero_vmax])
    mask_isol_lum = np.logical_and.reduce([mask_isol, mask_lum, mask_nonzero_vmax])
    mask_isol_dark = np.logical_and.reduce([mask_isol, mask_dark, mask_nonzero_vmax])
    mask_all_lum = np.logical_and(mask_lum, mask_nonzero_vmax)
    mask_all_dark = np.logical_and(mask_dark, mask_nonzero_vmax)

    selections = {
        "satellites": {"luminous": mask_sat_lum, "dark": mask_sat_dark},
        "isolated": {"luminous": mask_isol_lum, "dark": mask_isol_dark},
        "all": {"luminous": mask_all_lum, "dark": mask_all_dark},
    }

    # Trace the subhalos at snap_id back to snap_start; values below
    # `snap_tracer.no_match` count as traced.
    snap_tracer = simulation_tracing.SnapshotTracer(snap_id, mtree)
    mask_traced = snap_tracer.trace(snap_start, snap_id + 1) < snap_tracer.no_match

    # Add separate datasets for each selection to the data dictionary:
    sim_data["fraction_traced"] = {
        env: {kind: compute_fraction(mask_traced, sel)
              for kind, sel in by_kind.items()}
        for env, by_kind in selections.items()}

    sim_data["count_at_snap"] = {
        env: {kind: np.sum(sel) for kind, sel in by_kind.items()}
        for env, by_kind in selections.items()}

    sim_data["redshift"] = sim.get_redshifts(snap_start=snap_start,
                                             snap_stop=snap_id + 1)

## Plot

In [None]:
# Axis limits: x is redshift, y is the fraction of subhalos traced.
x_down, x_up = 0, 1
y_down, y_up = 0, 1.2

# Per-simulation styles, indexed by simulation order: `fcolor` holds the dark
# shades and `mcolor` the light shades; `marker` is not referenced by the plot cell.
fcolor = ["black", "red", "blue", "green"]
mcolor = ["gray", "pink", "lightblue", "lightgreen"]
marker = ['+', "o", "^", 1]

In [None]:
# Construct saving location. The file name lists the simulations actually
# present in `data`: zip() in the data-construction cell skips entries of
# `names` without a matching sim_id, so `names` can overstate what is plotted.
filename = "_".join(["fraction_traced_{}".format(distinction)] + list(data))

home = os.path.dirname(snapshot_obj.__file__)
path = os.path.join(home, "Figures")
# savefig does not create missing directories:
os.makedirs(path, exist_ok=True)
filename = os.path.join(path, filename)

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(14, 6))

# Common axes: redshift on x, inverted so redshift decreases to the right
# (time runs forward), and the traced fraction on y.
for ax in axes:
    ax.set_xlim(x_down, x_up)
    ax.set_ylim(y_down, y_up)
    ax.invert_xaxis()
    ax.set_xlabel('Redshift', fontsize=16)
    ax.set_ylabel('Fraction traced', fontsize=16)

axes[0].set_title('Satellite galaxies')
axes[1].set_title('Isolated galaxies')

# One panel per environment. Light colours (mcolor) mark luminous and dark
# colours (fcolor) mark dark subhalos; each legend entry shows the number of
# subhalos in the selection at snap_id.
panels = {'satellites': axes[0], 'isolated': axes[1]}
for i, (name, entry) in enumerate(data.items()):
    z = entry['redshift']
    for env, ax in panels.items():
        fractions = entry['fraction_traced'][env]
        counts = entry['count_at_snap'][env]
        ax.plot(z, fractions['luminous'], c=mcolor[i],
                label='{} luminous (N={})'.format(name, counts['luminous']))
        ax.plot(z, fractions['dark'], c=fcolor[i],
                label='{} dark (N={})'.format(name, counts['dark']))

# Counts differ between panels, so each panel gets its own legend.
for ax in axes:
    ax.legend(loc='upper right')

# tight_layout determines the subplot spacing (it overrides any earlier
# subplots_adjust call).
fig.tight_layout()

fig.savefig(filename, dpi=200)