In [None]:
# Load IPython's autoreload extension; mode 2 re-imports edited project
# modules (e.g. dataset_compute, snapshot_obj) before each cell executes.
%load_ext autoreload
%autoreload 2

# Enable IPython's greedy tab-completion.
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from astropy import units
from astropy.constants import G
import importlib

import dataset_compute
import snapshot_obj

In [None]:
# Force a fresh import of the project modules. This largely overlaps with
# `%autoreload 2` above, but makes the reload explicit when that is disabled.
for _module in (dataset_compute, snapshot_obj):
    importlib.reload(_module)

In [None]:
sim_id = "V1_MR_fix"
# Halo identifiers passed to snap.index_of_halo(...); presumably
# (group number, subgroup number) — TODO confirm.
m31 = (1, 0)
mw = (2, 0)

snap = snapshot_obj.Snapshot(sim_id, 127)
# Subhalo centres of potential, converted from cm to Mpc.
cops = snap.get_subhalos("CentreOfPotential") * units.cm.to(units.Mpc)
mask_vmax = dataset_compute.prune_vmax(snap)


def _cops_by_class(sat_masks, isol_mask):
    """Select the CoPs of vmax-pruned subhalos in each classification.

    ``sat_masks[0]`` / ``sat_masks[1]`` select M31 / MW satellites and
    ``isol_mask`` selects isolated subhalos; each is combined with
    ``mask_vmax`` before indexing ``cops``.
    """
    return {
        "M31_satellites": cops[np.logical_and(sat_masks[0], mask_vmax)],
        "MW_satellites": cops[np.logical_and(sat_masks[1], mask_vmax)],
        "isolated": cops[np.logical_and(isol_mask, mask_vmax)],
    }


# Classification by distance from the hosts.
masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(snap, m31, mw)
split_cops_by_r = _cops_by_class(masks_sat, mask_isol)

# Classification by group-number membership (these masks remain in scope).
masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(snap, m31, mw)
split_cops_by_gn = _cops_by_class(masks_sat, mask_isol)

In [None]:
def circle(centre_x, centre_y, r, n):
    """Return sample coordinates of a closed circle, for line plotting.

    Parameters
    ----------
    centre_x, centre_y : float
        Centre of the circle.
    r : float
        Radius.
    n : int
        Number of sample points. The first and last points coincide, so the
        plotted outline has no gap.

    Returns
    -------
    x, y : numpy.ndarray
        Coordinate arrays of length ``n``.
    """
    # The previous version sampled only ceil(n/2) angles and stopped short of
    # 2*pi, leaving the outline open; including the endpoint closes it.
    t = np.linspace(0, 2*np.pi, int(n))
    return centre_x + r*np.cos(t), centre_y + r*np.sin(t)

In [None]:
# Shared styling for the comparison figure:
# scatter marker size and dashed-circle line width.
s, lw = 0.1, 0.3

In [None]:
# Compare the two satellite classifications side by side: left column by
# group number, right column by distance; top row x-y, bottom row x-z.
fig, axes = plt.subplots(nrows=2, ncols=2, sharey='row', sharex='col',
                         figsize=(7, 7))
plt.subplots_adjust(wspace=0.04, hspace=0.04)

m31_idx = snap.index_of_halo(m31[0], m31[1])
mw_idx = snap.index_of_halo(mw[0], mw[1])
LG_centre = dataset_compute.compute_LG_centre(snap, m31, mw) \
                * units.cm.to(units.Mpc)

ax_size = 5  # side length of every panel [Mpc], centred on the LG centre

# (position [Mpc], colour, dashed-circle radius [Mpc]) for each reference point.
reference_points = ((cops[m31_idx], 'red', 0.3),
                    (cops[mw_idx], 'blue', 0.3),
                    (LG_centre, 'k', 2))


def _plot_panel(ax, split_cops, x, y):
    """Draw one projection of the subhalo classification onto ``ax``.

    ``split_cops`` maps 'isolated', 'M31_satellites' and 'MW_satellites' to
    position arrays indexed as ``[:, axis]``; ``x``/``y`` are the coordinate
    indices shown on the horizontal/vertical axes. Each reference point in
    ``reference_points`` is marked and surrounded by a dashed circle.
    """
    for key, colour in (('isolated', 'gray'),
                        ('M31_satellites', 'pink'),
                        ('MW_satellites', 'lightblue')):
        ax.scatter(split_cops[key][:, x], split_cops[key][:, y],
                   c=colour, s=s)

    for centre, colour, radius in reference_points:
        ax.scatter(centre[x], centre[y], c=colour, s=s)
        x_circ, y_circ = circle(centre[x], centre[y], radius, 10000)
        ax.plot(x_circ, y_circ, c=colour, linestyle='dashed', linewidth=lw)


# Axes are shared per column (x) and per row (y), so setting limits on one
# panel propagates to its partner. Raw strings avoid the invalid "\m" escape.
for ax in axes[1]:
    ax.set_xlim(LG_centre[0] - ax_size/2, LG_centre[0] + ax_size/2)
    ax.set_xlabel(r"$x [\mathrm{Mpc}]$")

axes[0, 0].set_ylim(LG_centre[1] - ax_size/2, LG_centre[1] + ax_size/2)
axes[1, 0].set_ylim(LG_centre[2] - ax_size/2, LG_centre[2] + ax_size/2)
axes[0, 0].set_ylabel(r"$y [\mathrm{Mpc}]$")
axes[1, 0].set_ylabel(r"$z [\mathrm{Mpc}]$")

axes[0, 0].set_title("By group number")
axes[0, 1].set_title("By distance")

for row, (x, y) in enumerate(((0, 1), (0, 2))):
    _plot_panel(axes[row, 0], split_cops_by_gn, x, y)
    _plot_panel(axes[row, 1], split_cops_by_r, x, y)

filename = 'distinction_comparison_{}.png'.format(sim_id)
home = os.path.dirname(snapshot_obj.__file__)
path = os.path.join(home, "Figures")
# savefig does not create missing directories.
os.makedirs(path, exist_ok=True)
filename = os.path.join(path, filename)
fig.savefig(filename, dpi=300)