In [None]:
# Load IPython's autoreload extension and use mode 2, which re-imports
# modules before executing each cell so edits to them are picked up.
%load_ext autoreload
%autoreload 2

# Turn on IPython's "greedy" tab completion.
%config IPCompleter.greedy=True

In [None]:
import numpy as np
from astropy import units
import os
import astropy.units as u
import matplotlib.pyplot as plt

import snapshot_obj
import dataset_compute

import importlib

In [None]:
# Explicitly re-import the two project modules so that edits made to them
# on disk are visible to the cells below.
# NOTE(review): this looks redundant with `%autoreload 2` in the first cell;
# confirm before removing.
importlib.reload(snapshot_obj)
importlib.reload(dataset_compute)

In [None]:
snap_id = 127

# sim_ids and names are paired BY POSITION in the zip below, so their order
# must match: each simulation directory has to line up with its own label.
# (Previously the p082 and p084 runs were listed in opposite orders in the two
# lists, which mislabelled the two curvaton models in the data and the plots.)
sim_ids = ["V1_LR_fix", "V1_LR_curvaton_p082_fix", "V1_LR_curvaton_p084_fix"]
names = ["LCDM", "p082", "p084"]
# zip() silently truncates to the shorter list, so catch a mismatch early.
assert len(sim_ids) == len(names), "sim_ids and names must pair one-to-one"

# Resulting layout:
#   data[name][galaxy]["gal_centre"]          -> centre of the central galaxy
#   data[name][galaxy]["radius"][lum_key]     -> sorted satellite distances (kpc)
#   data[name][galaxy]["sat_count"][lum_key]  -> cumulative satellite count
# with galaxy in {"M31", "MW"}; lum_key takes whatever keys get_satellites
# returns with split_luminous=True (the plotting cell uses "luminous"/"dark").
data = {name : {} for name in names}
for sim_id, name in zip(sim_ids, names):
    # Get data:
    snap = snapshot_obj.Snapshot(sim_id, snap_id, name)
    cops = snap.get_subhalos("CentreOfPotential")
    # Central galaxies are looked up as halo (1, 0) and halo (2, 0)
    # -- presumably (group number, subgroup number); verify in snapshot_obj.
    M31 = cops[snap.index_of_halo(1, 0)]
    MW = cops[snap.index_of_halo(2, 0)]

    # Split into satellites:
    cops_M31 = dataset_compute.get_satellites(snap, cops,
                                          galaxy=1, split_luminous=True)
    cops_MW = dataset_compute.get_satellites(snap, cops,
                                          galaxy=2, split_luminous=True)

    # Add to dictionary. "radius" initially holds the satellite positions and
    # is overwritten with the radial distances in the loop below.
    data[name] = {"M31": {"gal_centre": M31, "radius": cops_M31, "sat_count": {}},
                  "MW": {"gal_centre": MW, "radius": cops_MW, "sat_count": {}}}

    for gal_key, gal_dict in data[name].items():
        gal_centre = gal_dict["gal_centre"]
        # Only values of existing keys are replaced while iterating, so the
        # dictionary size never changes during the loop.
        for lum_key, x in gal_dict["radius"].items():
            # Compute distance to central:
            # (periodic_wrap presumably shifts positions across the periodic
            # box boundary relative to the centre -- see dataset_compute.)
            r = dataset_compute.periodic_wrap(snap, gal_centre, x)
            # Conversion factor cm -> kpc; assumes positions are stored in cm.
            r = np.linalg.norm(r - gal_centre, axis=1) * units.cm.to(units.kpc)

            # Sort, then append an end point at 1000 kpc with the final count
            # repeated, so the cumulative curve continues flat out to the
            # right-hand edge of the figure:
            r = np.sort(r)
            gal_dict['radius'][lum_key] = \
                np.concatenate((r, np.array([1000])))

            gal_dict['sat_count'][lum_key] = \
                np.concatenate((np.arange(1, r.size+1), np.array([r.size])))

In [None]:
# Axis limits: x is the distance from the central galaxy, y the satellite count.
x_down, x_up = 0, 1000
y_down, y_up = 0, 50

# One line colour per simulation, indexed by plotting order.
color = [
    "black",
    "red",
    "blue",
    "green",
]

In [None]:
# Build the figure's save path:
#   <directory of snapshot_obj>/Figures/satellite_radial_distribution_<name>..._<name>.png
# with one "_<name>" suffix per simulation, in the order given by `names`.
home = os.path.dirname(snapshot_obj.__file__)
path = os.path.join(home, "Figures")
filename = os.path.join(
    path,
    'satellite_radial_distribution'
    + "".join("_{}".format(name) for name in names)
    + ".png",
)

In [None]:
# Two panels side by side: left for M31 (Andromeda), right for the Milky Way.
fig, axes = plt.subplots(ncols=2, figsize=(14,6))
plt.subplots_adjust(wspace=0.3)

# Set axis:
for ax in axes:
    #ax.set_yscale('log')
    ax.set_xlim(x_down, x_up)
    ax.set_ylim(y_down, y_up)
    # Raw strings: "\m" is not a valid Python escape sequence, so a plain
    # string literal here triggers an invalid-escape warning.
    ax.set_xlabel(r'$r[\mathrm{kpc}]$', fontsize=16)
    # The plotted quantity is the un-normalised cumulative satellite count
    # (1, 2, ..., N), so the axis is labelled N(<r) rather than a fraction.
    ax.set_ylabel(r'$N(<r)$', fontsize=16)

axes[0].set_title('Andromeda')
axes[1].set_title('Milky Way')

# Add cumulative-count curves: one colour per simulation, solid lines for
# luminous satellites and dashed lines for dark ones.
line_styles = {'luminous': '-', 'dark': '--'}
for i, (name, entry) in enumerate(data.items()):
    for ax, gal_key in zip(axes, ('M31', 'MW')):
        for lum_key, style in line_styles.items():
            ax.plot(entry[gal_key]['radius'][lum_key],
                    entry[gal_key]['sat_count'][lum_key],
                    c=color[i], linestyle=style,
                    label='{} {}'.format(name, lum_key))

# The legend entries are the same for both panels, so it is drawn only once.
axes[0].legend(loc='upper right')
plt.tight_layout()

#plt.savefig(filename, dpi=200)