In [None]:
# Automatically reload edited modules (e.g. the local snapshot_obj /
# dataset_compute helpers) before each cell is executed.
%load_ext autoreload
%autoreload 2

# Greedy tab completion (completion may evaluate expressions to find attributes/keys).
# NOTE(review): deprecated in recent IPython (8.8+) in favour of
# Completer.evaluation / Completer.auto_close_dict_keys — confirm IPython version.
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from astropy import units
from pathlib import Path
import os

import snapshot_obj
import dataset_compute

import importlib

In [None]:
# Force a fresh import of the local helper modules, in the same order as they
# were imported. Largely redundant while `%autoreload 2` is active, but makes
# the reload explicit when that magic is disabled.
for _module in (snapshot_obj, dataset_compute):
    importlib.reload(_module)

In [None]:
# Simulation run and snapshot to analyse.
sim_id = "V1_MR_fix"
snap_id = 127

# Host haloes, each given as a (GroupNumber, SubGroupNumber) pair:
m31, mw = (1, 0), (2, 0)

In [None]:
# Load the snapshot and read the circular-velocity maximum of every subhalo.
snap = snapshot_obj.Snapshot(sim_id, snap_id)
max_point = snap.get_subhalos("Max_Vcirc", "Extended")
# Column 0 is taken as v_max; the factor converts cm -> km
# (presumably stored in cm/s, giving km/s — TODO confirm units).
vmax = max_point[:, 0] * units.cm.to(units.km)

# Sanity check: v_max of the two hosts, matched on (GroupNumber, SubGroupNumber).
gn = snap.get_subhalos("GroupNumber")
sgn = snap.get_subhalos("SubGroupNumber")
print(vmax[np.logical_and(gn == m31[0], sgn == m31[1])])
print(vmax[np.logical_and(gn == mw[0], sgn == mw[1])])

# All datasets will be sorted by vmax, in descending order:
sort_idx = np.argsort(vmax)[::-1]

# Get selections (masking arrays), still in catalogue order:
mask_lum, mask_dark = dataset_compute.split_luminous(snap)
# Alternative satellite definition:
#   mask_sat, mask_isol = dataset_compute.split_satellites_by_group_number(snap, m31, mw)
mask_sat, mask_isol = dataset_compute.split_satellites_by_distance(snap, m31, mw)
selections = {
    "lum": mask_lum,
    "dark": mask_dark,
    "m31": mask_sat[0],
    "mw": mask_sat[1],
    "isol": mask_isol,
}

# Prune out potential spurious subhalos (mask is in catalogue order, so sort it too):
mask_nonzero_vmax = dataset_compute.prune_vmax(snap)[sort_idx]

# Minimal-vmax dummy halo, included in every selection, so that the
# cumulative curves continue to the y-axis:
DUMMY_VMAX = 0.01

# Apply the same sort -> prune -> append-dummy pipeline to vmax and to every
# selection mask, so no mask can be accidentally left misaligned.
vmax = np.concatenate([vmax[sort_idx][mask_nonzero_vmax], [DUMMY_VMAX]])
selections = {
    key: np.concatenate([mask[sort_idx][mask_nonzero_vmax], [True]])
    for key, mask in selections.items()
}
assert all(mask.size == vmax.size for mask in selections.values()), \
    "selection masks are misaligned with vmax"

mask_lum = selections["lum"]
mask_dark = selections["dark"]
mask_m31 = selections["m31"]
mask_mw = selections["mw"]
mask_isol = selections["isol"]


In [None]:
# Cumulative v_max function of the selected subhalos. vmax is sorted in
# descending order, so the i-th selected entry (1-based) has i selected
# entries at or above its v_max.
fig, ax = plt.subplots()

ax.set_ylim(1, 5000)
ax.set_xlim(5, 110)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel(r"$v_\mathrm{max}$ [km/s]")
ax.set_ylabel(r"$N(>v_\mathrm{max})$")
ax.set_title("{}, snapshot {}".format(sim_id, snap_id))

# Alternative selections:
#   luminous, not a satellite of either host:
#     mask_plot = np.logical_and(mask_lum, np.logical_not(np.logical_or(mask_mw, mask_m31)))
#   luminous and isolated:
#     mask_plot = np.logical_and(mask_lum, mask_isol)
#   luminous satellites of either host:
#     mask_plot = np.logical_and(mask_lum, np.logical_or(mask_mw, mask_m31))
mask_plot = np.ones(vmax.size, dtype=bool)  # all subhalos
ax.plot(vmax[mask_plot], np.arange(1, np.sum(mask_plot) + 1))

# savefig does not create missing directories; make sure the output directory
# exists so the cell also works on a fresh checkout.
out_dir = Path("WTFFFF")
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_dir / "wtf_all_{}.png".format(snap_id), dpi=200)