## Imports

In [None]:
%load_ext autoreload
%autoreload 2

%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy import units
from astropy.constants import G

In [None]:
import os
import sys

# Make the project-local `apostletools` modules importable. The path is
# resolved relative to the current working directory. The membership check
# keeps re-runs of this cell from appending the same entry to sys.path again.
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
if apt_path not in sys.path:
    sys.path.append(apt_path)

import snapshot
import dataset_comp

In [None]:
import importlib
importlib.reload(dataset_comp)
importlib.reload(snapshot)

# Locations of the Most Massive Halos

In [None]:
# Simulations to compare at snapshot 127 (the "MR" runs; presumably medium
# resolution, going by the simulation names). Each entry maps a plot label to
# the Snapshot handle and the (GN, SGN) pairs of the halos to show.
snap_id = 127
data = {
    label: {
        "Snapshot": snapshot.Snapshot(sim_name, snap_id),
        "Groups": np.array([[1, 0], [1, 1], [1, 2], [2, 0],
                            [2, 1], [3, 0], [4, 0], [5, 0]]),
    }
    for label, sim_name in (
        ("plain-LCDM", "V1_MR_fix"),
        ("curv-p082", "V1_MR_curvaton_p082_fix"),
    )
}

In [None]:
# "LR" runs at snapshot 127 (presumably low resolution; the plot cell saves
# into Figures/LowResolution). Running this cell replaces any `data` built by
# the MR cell above. Each entry maps a plot label to the Snapshot handle and
# the (GN, SGN) pairs of the halos to show.
snap_id = 127
data = {
    label: {
        "Snapshot": snapshot.Snapshot(sim_name, snap_id),
        "Groups": np.array([[1, 0], [1, 1], [1, 2], [2, 0],
                            [2, 1], [3, 0], [4, 0], [5, 0]]),
    }
    for label, sim_name in (
        ("plain-LCDM-LR", "V1_LR_fix"),
        ("curv-p082-LR", "V1_LR_curvaton_p082_fix"),
        ("curv-p084-LR", "V1_LR_curvaton_p084_fix"),
    )
}

In [None]:
# For every simulation, pick out the requested halos and store their centre
# of potential and particle coordinates, both converted from cm to Mpc.
# The conversion factor is loop-invariant, so compute it once.
cm_to_mpc = units.cm.to(units.Mpc)

for sim_data in data.values():
    snap = sim_data["Snapshot"]
    # Per-subhalo centre of potential, and particle coordinates grouped by
    # subhalo. NOTE(review): assumed to be in cm given the conversion below;
    # confirm against snapshot/dataset_comp.
    cops = snap.get_subhalos("CentreOfPotential")
    coords = dataset_comp.group_particles_by_subhalo(snap, "Coordinates")["Coordinates"]

    # Indices of the selected (GN, SGN) halos, in the order listed in "Groups".
    halo_indices = [snap.index_of_halo(gn, sgn) for gn, sgn in sim_data["Groups"]]

    sim_data["CentreOfPotential"] = [cops[idx] * cm_to_mpc for idx in halo_indices]
    sim_data["Coordinates"] = [coords[idx] * cm_to_mpc for idx in halo_indices]

## Plot

In [None]:
# Choose font sizes for titles, axis labels, tick labels and the legend.
# They are applied to matplotlib via plt.rcParams.update(parameters) in the
# plotting cell.
parameters = {
    "axes.titlesize": 8,
    "axes.labelsize": 7,
    "xtick.labelsize": 6,
    "ytick.labelsize": 6,
    "legend.fontsize": 7,
    "legend.title_fontsize": 7,
}

In [None]:
filename = 'main_halo_coords.png'

# Figures are written to ../Figures/LowResolution, resolved relative to the
# current working directory.
# NOTE(review): the folder name assumes the LR `data` cell was used.
path = os.path.abspath(os.path.join('..', 'Figures', 'LowResolution'))
# A single join gives the full absolute output path; joining again is
# redundant and would double the directory if `path` were ever relative.
filename = os.path.join(path, filename)

In [None]:
# Projections, one per figure row:
# [horizontal axis name, vertical axis name, horizontal index, vertical index]
projs = [["x", "y", 0, 1],
         ["x", "z", 0, 2]]

# Hand-picked axis limits in Mpc, one (xlim, ylim) pair per projection row.
# NOTE(review): tuned for the current halo selection; update together with
# `projs` or the selected groups.
proj_limits = [((6, 11), (16.5, 21.5)),
               ((6, 11), (82, 87))]

# Colour pair per halo (by position in "Groups"):
# [particle colour, centre-of-potential colour].
col = [['lightblue', 'blue'],
       ['pink', 'crimson'],
       ['gray', 'black'],
       ['violet', 'darkviolet'],
       ['lightgreen', 'green'],
       ['yellow', 'gold'],
       ['orange', 'darkorange'],
       ['sandybrown', 'saddlebrown'],
       ['lightsteelblue', 'steelblue'],
       ['red', 'darkred']]

n_skip = 1  # only plot every n_skip:th particle

# Set fonts before creating the figure so the new axes pick them up.
plt.rcParams.update(parameters)

# squeeze=False keeps `axes` 2-D (rows = projections, columns = simulations)
# even when `data` holds a single simulation.
fig, axes = plt.subplots(sharey='row', sharex='col', figsize=(6, 3.5),
                         ncols=len(data), nrows=len(projs), squeeze=False)
fig.subplots_adjust(wspace=0.08, hspace=0.08)

# Axis limits for every panel, y-labels on the first column, and x-labels on
# the bottom row, all taken from `projs` / `proj_limits`.
for row, (proj, (xlim, ylim)) in enumerate(zip(projs, proj_limits)):
    for ax in axes[row]:
        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)
    axes[row, 0].set_ylabel("{} [Mpc]".format(proj[1]))
for ax in axes[-1]:
    ax.set_xlabel("{} [Mpc]".format(projs[-1][0]))

# Iterate through simulations (columns):
for i, (key, sim_data) in enumerate(data.items()):
    axes[0, i].set_title(key)
    n_groups = len(sim_data["Groups"])

    # Iterate through projections (rows):
    for j, proj in enumerate(projs):
        ax = axes[j, i]
        h_idx, v_idx = proj[2], proj[3]

        # Halo particles first, in the lighter shade...
        for k in range(n_groups):
            part = sim_data["Coordinates"][k][::n_skip]
            ax.scatter(part[:, h_idx], part[:, v_idx],
                       c=col[k % len(col)][0], s=0.1)

        # ...then the centres of potential on top, in the darker shade.
        for k in range(n_groups):
            cop = sim_data["CentreOfPotential"][k]
            ax.scatter(cop[h_idx], cop[v_idx], c=col[k % len(col)][1], s=0.3)

# Legend next to the bottom-right panel: one proxy marker per halo in its
# particle colour. Colours are assigned by position in "Groups", so the labels
# are taken from the last plotted simulation rather than a hard-coded key.
legend_ax = axes[-1, -1]
legend_groups = list(data.values())[-1]["Groups"]
handles = [legend_ax.scatter([], [], c=col[k % len(col)][0], s=10)
           for k in range(len(legend_groups))]
labels = ["{}, {}".format(grp[0], grp[1]) for grp in legend_groups]
legend_ax.legend(handles, labels,
                 loc='lower left',
                 bbox_to_anchor=(1, 0),
                 title="GN, SGN")
# Adjust the scaling factor so the legend text fits completely outside the
# plot (a smaller value leaves more space for the legend).
fig.subplots_adjust(right=0.78)

# Make sure the output directory exists before saving.
os.makedirs(os.path.dirname(filename), exist_ok=True)
fig.savefig(filename, dpi=300, bbox_inches='tight')