In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from astropy import units
from pathlib import Path
import os
import time

import snapshot_obj
import dataset_compute
import simulation_tracing
import curve_fit

import importlib

In [None]:
importlib.reload(snapshot_obj)
importlib.reload(dataset_compute)
importlib.reload(simulation_tracing)
importlib.reload(curve_fit)

# Fraction of halos traced

## Construct data dictionary

Add an entry for each simulation and identify its M31 and MW host galaxies:

In [None]:
# Snapshot ID at z = 0 and the simulations to compare:
snap_z0_id = 127
sim_ids = ["V1_LR_fix", "V1_LR_curvaton_p082_fix", "V1_LR_curvaton_p084_fix"]
names = ["LCDM", "p082", "p084"]

# Directory of each simulation, passed to snapshot_obj.Snapshot as `sim_path`.
# NOTE(review): "" presumably selects Snapshot's default location — confirm.
# For data on the external drive, use e.g.
#     paths = ["", "/media/kassiili/USBFREE/LG_simulations",
#              "/media/kassiili/USBFREE/LG_simulations"]
paths = ["", "", ""]

# Define M31 and MW in each simulation (identifiers are passed on to the
# satellite-splitting functions in dataset_compute):
m31 = [(1, 0), (1, 0), (1, 0)]
mw = [(2, 0), (1, 1), (1, 1)]

# zip() below would silently drop simulations if these lists got out of sync:
assert len(sim_ids) == len(names) == len(paths) == len(m31) == len(mw), \
    "Per-simulation lists must all have the same length"

data = {}
for name, sim_id, sim_path, m31_ns, mw_ns in zip(names, sim_ids, paths, m31, mw):
    print(name)  # progress indicator: loading a snapshot can be slow
    snap_z0 = snapshot_obj.Snapshot(sim_id, snap_z0_id, name=name, sim_path=sim_path)
    data[name] = {"snapshot_z0": snap_z0,
                  "tracer": simulation_tracing.MergerTree(snap_z0),
                  "M31_identifier": m31_ns,
                  "MW_identifier": mw_ns}

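Choose how to distinguish satellite galaxies from isolated ones: `"by_r"` (by distance, using `dataset_compute.split_satellites_by_distance`) or `"by_gn"` (by group number, using `dataset_compute.split_satellites_by_group_number`):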

In [None]:
# Satellite/isolated split used in the computation cell below:
#   "by_r"  -> dataset_compute.split_satellites_by_distance
#   "by_gn" -> dataset_compute.split_satellites_by_group_number
distinction = "by_gn"

In [None]:
def repeat_columns(arr, n):
    """Stack ``n`` copies of a 1-D array side by side as columns.

    Row ``i`` of the result holds ``arr[i]`` repeated ``n`` times, so the
    output has shape ``(len(arr), n)``.
    """
    n_rows = np.size(arr, axis=0)
    # np.repeat emits each element n times consecutively, so a row-major
    # reshape places every copy of arr[i] in row i.
    return np.reshape(np.repeat(arr, n), (n_rows, n))

def compute_fraction(mask_traced, mask):
    """Fraction of the subhalos selected by ``mask`` that are traced at each snapshot.

    Parameters
    ----------
    mask_traced : array_like of bool, shape (n_subhalos, n_snapshots)
        ``mask_traced[i, j]`` is True if subhalo ``i`` is traced at snapshot
        column ``j``.
    mask : array_like of bool, shape (n_subhalos,)
        Selection of subhalos to include.

    Returns
    -------
    numpy.ndarray of float, shape (n_snapshots,)
        Number of selected, traced subhalos in each column divided by the
        number of selected subhalos. All NaN if ``mask`` selects nothing.
    """
    mask_traced = np.asarray(mask_traced, dtype=bool)
    mask = np.asarray(mask, dtype=bool)

    n_selected = np.count_nonzero(mask)
    if n_selected == 0:
        # The fraction is undefined for an empty selection; return NaN
        # explicitly instead of triggering a 0/0 RuntimeWarning.
        return np.full(mask_traced.shape[1], np.nan)

    # Keep only the selected rows, then count traced entries per snapshot column:
    n_traced = np.count_nonzero(mask_traced[mask], axis=0)

    return n_traced / n_selected

In [None]:
# Argument passed to MergerTree.trace_all.
# NOTE(review): presumably the snapshot ID the merger trees are traced back to —
# confirm against simulation_tracing.
trace_snap_start = 100

for name, sim_data in data.items():
    snap_z0 = sim_data["snapshot_z0"]
    tracer = sim_data["tracer"]
    tracer.trace_all(trace_snap_start)

    sim_data["redshift"] = tracer.get_redshifts()

    # Split into satellites (one mask per host) and isolated subhalos. Fail
    # loudly on an unknown choice instead of reusing masks from a previous
    # iteration (or hitting a NameError on the first one).
    if distinction == "by_r":
        split_satellites = dataset_compute.split_satellites_by_distance
    elif distinction == "by_gn":
        split_satellites = dataset_compute.split_satellites_by_group_number
    else:
        raise ValueError(
            "Unknown distinction {!r}; expected 'by_r' or 'by_gn'".format(distinction))
    masks_sat, mask_isol = split_satellites(
        snap_z0, sim_data["M31_identifier"], sim_data["MW_identifier"])
    # A subhalo is a satellite if it is a satellite of any host:
    mask_sat = np.logical_or.reduce(masks_sat)

    mask_lum, mask_dark = dataset_compute.split_luminous(snap_z0)
    mask_nonzero_vmax = dataset_compute.prune_vmax(snap_z0)

    # z = 0 selection masks, keyed exactly like the output dictionaries.
    # Every selection is also restricted by the prune_vmax mask.
    selections = {
        group: {
            "luminous": np.logical_and.reduce([group_mask, mask_lum, mask_nonzero_vmax]),
            "dark": np.logical_and.reduce([group_mask, mask_dark, mask_nonzero_vmax]),
        }
        for group, group_mask in (("satellites", mask_sat), ("isolated", mask_isol))
    }
    selections["all"] = {
        "luminous": np.logical_and(mask_lum, mask_nonzero_vmax),
        "dark": np.logical_and(mask_dark, mask_nonzero_vmax),
    }

    # Tracer entries >= tracer.no_match are treated as untraced:
    mask_traced = tracer.get_tracer_array() < tracer.no_match

    # Per population and type: fraction traced at each snapshot, and z = 0 count.
    sim_data["fraction_traced"] = {
        group: {kind: compute_fraction(mask_traced, sel) for kind, sel in by_kind.items()}
        for group, by_kind in selections.items()
    }
    sim_data["count_z0"] = {
        group: {kind: np.sum(sel) for kind, sel in by_kind.items()}
        for group, by_kind in selections.items()
    }

## Plot

In [None]:
# Axis limits shared by the fraction-traced panels:
x_down, x_up = 0, 1
y_down, y_up = 0, 1.2

# Styling lists, indexed by each simulation's position in `data`.
fcolor = ["black", "red", "blue", "green"]
mcolor = ["gray", "pink", "lightblue", "lightgreen"]
# Matplotlib marker specs; the integer 1 is matplotlib's TICKRIGHT marker.
marker = ['+', "o", "^", 1]

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(14, 6))

# One panel per population; keys index data[name]["fraction_traced"].
panels = (("satellites", "Satellite galaxies"),
          ("isolated", "Isolated galaxies"))

for ax, (group, title) in zip(axes, panels):
    ax.set_xlim(x_down, x_up)
    ax.set_ylim(y_down, y_up)
    ax.set_title(title)
    ax.set_xlabel("Redshift $z$", fontsize=16)
    ax.set_ylabel("Fraction traced", fontsize=16)

    for i, (name, entry) in enumerate(data.items()):
        # Luminous curves use the lighter `mcolor` shade, dark ones `fcolor`.
        # The z = 0 sample size goes in the label so that fractions from
        # small samples can be recognised.
        for kind, colors in (("luminous", mcolor), ("dark", fcolor)):
            n_z0 = entry["count_z0"][group][kind]
            ax.plot(entry["redshift"], entry["fraction_traced"][group][kind],
                    c=colors[i], label="{} {} (N={})".format(name, kind, n_z0))

    ax.legend(loc="upper right")

# tight_layout computes the panel spacing itself.
plt.tight_layout()

#plt.savefig(filename, dpi=200)

In [None]:
# Fraction of all prune_vmax-selected luminous and dark subhalos traced, per
# simulation, with redshift increasing to the left. Uses only data computed
# above, so the cell survives Restart & Run All.
fig, axes = plt.subplots()
axes.invert_xaxis()
axes.set_ylim(0, 1.2)
for name, entry in data.items():
    for kind in ("luminous", "dark"):
        axes.plot(entry["redshift"], entry["fraction_traced"]["all"][kind],
                  label="{} {}".format(name, kind))
axes.set_xlabel("Redshift $z$")
axes.set_ylabel("Fraction traced")
axes.legend();