In [None]:
# Load the autoreload extension and select mode 2, so that imported modules
# are re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

# Turn on IPython's greedy tab-completion option.
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from astropy import units
from pathlib import Path
import os

import snapshot_obj
import simulation
import dataset_compute

import importlib

In [None]:
# Explicitly re-import the three project modules so that edits made to their
# source since the kernel started are picked up.
importlib.reload(snapshot_obj)
importlib.reload(simulation)
importlib.reload(dataset_compute)

## Construct data dictionary

Add entries for each simulation, and specify M31 and MW galaxies:

In [None]:
# Snapshot 127 of the three V1_LR runs.
# Note: `data` is rebuilt from scratch here, so whichever of the
# data-construction cells runs last decides which simulations are analysed.
snap_id = 127
sim_ids = ["V1_LR_fix", "V1_LR_curvaton_p082_fix", "V1_LR_curvaton_p084_fix"]
names = ["LCDM", "p082", "p084"]

# Identifier pairs of M31 and MW in each simulation, in the same order as
# `names` (presumably (group number, subgroup number) -- TODO confirm).
# NOTE(review): for p084 both M31 and MW are (1,0) -- confirm this is intended.
m31 = [(1,0), (1,0), (1,0)]
mw = [(2,0), (1,1), (1,0)]

# One entry per simulation: the snapshot object plus the two identifiers.
data = {
    label: {
        "snapshot": snapshot_obj.Snapshot(run_id, snap_id, name=label),
        "M31_identifier": m31_id,
        "MW_identifier": mw_id,
    }
    for label, run_id, m31_id, mw_id in zip(names, sim_ids, m31, mw)
}

In [None]:
# Snapshot 126 of the two V1_MR runs.
# Note: `data` is rebuilt from scratch here, so whichever of the
# data-construction cells runs last decides which simulations are analysed.
snap_id = 126
sim_ids = ["V1_MR_fix", "V1_MR_curvaton_p082_fix"]
names = ["LCDM", "p082"]

# Identifier pairs of M31 and MW in each simulation, in the same order as
# `names` (presumably (group number, subgroup number) -- TODO confirm).
m31 = [(1,0), (1,0)]
mw = [(2,0), (1,1)]

# One entry per simulation: the snapshot object plus the two identifiers.
data = {
    label: {
        "snapshot": snapshot_obj.Snapshot(run_id, snap_id, name=label),
        "M31_identifier": m31_id,
        "MW_identifier": mw_id,
    }
    for label, run_id, m31_id, mw_id in zip(names, sim_ids, m31, mw)
}

In [None]:
# Snapshot 127 of the two V1_LR runs (LCDM and p082 only).
# Note: `data` is rebuilt from scratch here, so whichever of the
# data-construction cells runs last decides which simulations are analysed.
snap_id = 127
sim_ids = ["V1_LR_fix", "V1_LR_curvaton_p082_fix"]
names = ["LCDM", "p082"]

# Identifier pairs of M31 and MW in each simulation, in the same order as
# `names` (presumably (group number, subgroup number) -- TODO confirm).
m31 = [(1,0), (1,0)]
mw = [(2,0), (1,1)]

# One entry per simulation: the snapshot object plus the two identifiers.
data = {
    label: {
        "snapshot": snapshot_obj.Snapshot(run_id, snap_id, name=label),
        "M31_identifier": m31_id,
        "MW_identifier": mw_id,
    }
    for label, run_id, m31_id, mw_id in zip(names, sim_ids, m31, mw)
}

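Choose how to distinguish between satellite and isolated galaxies (`"by_r"` or `"by_gn"`), and define the four $v_\mathrm{max}$ bins (in km/s) used to group the rotation curves: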

In [None]:
# Satellite/isolated criterion: "by_gn" or "by_r".
distinction = "by_gn"

# (lower, upper) v_max limits of the four bins.
low, semilow = (15, 25), (25, 35)
semihigh, high = (35, 45), (45, 55)

Read the rotation-curve datasets, split the subhalos into satellites and isolated galaxies, bin them by $v_\mathrm{max}$, and add the result to the data dictionary. We also disregard dark halos and potential spurious halos with $v_\mathrm{max} = 0$.

In [None]:
def _split_per_subhalo(values, offsets):
    """Cut a flat array into one array per subhalo.

    `values` is split at the indices `offsets[1:]`, giving `len(offsets)`
    pieces. The pieces may have different lengths, so they are stored in a
    1-D object array: calling `np.array` on a ragged list raises a ValueError
    in recent NumPy versions. The result can still be filtered with a boolean
    mask along its first axis.
    """
    pieces = np.split(values, offsets[1:])
    per_subhalo = np.empty(len(pieces), dtype=object)
    # Element-wise assignment keeps the result 1-D even when all pieces
    # happen to have the same length.
    for k, piece in enumerate(pieces):
        per_subhalo[k] = piece
    return per_subhalo


# v_max bins, in the order in which they are stored in the data dictionary:
vmax_keys = ["Low", "SemiLow", "SemiHigh", "High"]
vmax_bins = [low, semilow, semihigh, high]

for name, sim_data in data.items():
    # Get data:
    snap = sim_data["snapshot"]
    rot_curves = snap.get_subhalos(
        'Vcirc', group='Extended/RotationCurve/All')
    sub_offset = snap.get_subhalos(
        'SubOffset', group='Extended/RotationCurve/All')

    # Column 0 is scaled by the cm -> km factor and column 1 by the cm -> kpc
    # factor (presumably cm/s -> km/s and cm -> kpc; verify against the
    # snapshot's unit convention).
    v_circ = _split_per_subhalo(
        rot_curves[:,0] * units.cm.to(units.km), sub_offset)
    radii = _split_per_subhalo(
        rot_curves[:,1] * units.cm.to(units.kpc), sub_offset)

    # Split into satellites:
    if distinction == "by_r":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    elif distinction == "by_gn":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    else:
        # Fail with a clear message instead of a NameError on `masks_sat`.
        raise ValueError(
            "distinction must be 'by_r' or 'by_gn', got {!r}".format(
                distinction))

    mask_lum, mask_dark = dataset_compute.split_luminous(snap)

    # One selection mask per v_max bin:
    vmax_masks = [dataset_compute.prune_vmax(snap, low_lim=lower, up_lim=upper)
                  for lower, upper in vmax_bins]

    # A subhalo counts as a satellite if any of the masks in `masks_sat`
    # selects it (the same for every bin, so computed once).
    mask_any_sat = np.logical_or.reduce(masks_sat)

    # Add to dictionary:
    for dataset_name, dataset in zip(["VCirc", "Radius"], [v_circ, radii]):
        data[name][dataset_name] = {
            "satellites": {
                key: dataset[np.logical_and.reduce(
                    [mask_any_sat, mask_lum, mask])]
                for key, mask in zip(vmax_keys, vmax_masks)},
            "isolated": {
                key: dataset[np.logical_and.reduce(
                    [mask_isol, mask_lum, mask])]
                for key, mask in zip(vmax_keys, vmax_masks)},
        }

## Plot

In [None]:
# Axis limits:
x_down, x_up = 10, 100
y_down, y_up = 5 * 10**5, 2 * 10**10

# Per-simulation marker styles (indexed by simulation order):
fcolor = ["black", "red", "blue", "green"]           # strong colours
mcolor = ["gray", "pink", "lightblue", "lightgreen"]  # light colours
marker = ['+', "o", "^", 1]

In [None]:
# Figures go to the "Figures" directory next to the snapshot_obj module.
home = os.path.dirname(snapshot_obj.__file__)
path = os.path.join(home, "Figures")

# File name: fixed stem + satellite criterion + one "_<name>" per simulation.
# NOTE(review): the 'SM_vs_Vmax' stem does not match the rotation-curve
# figures drawn in this notebook -- confirm the intended name.
filename = os.path.join(
    path,
    'SM_vs_Vmax_{}'.format(distinction)
    + "".join(map("_{}".format, names))
    + ".png",
)

In [None]:
# Rotation curves of satellite galaxies, one panel per v_max bin.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14,14))

# (axis, bin key in the data dictionary, v_max limits, x-axis range):
panels = [
    (axes[0,0], "Low", low, (0, 10)),
    (axes[0,1], "SemiLow", semilow, (0, 10)),
    (axes[1,0], "SemiHigh", semihigh, (0, 20)),
    (axes[1,1], "High", high, (0, 20)),
]

# Set axis:
for ax, key, vmax_lim, x_lim in panels:
    ax.set_xlim(*x_lim)
    # Units follow the cm -> kpc and cm -> km conversions applied when the
    # rotation curves were read (assumes the snapshot stores cgs units).
    ax.set_xlabel(r"$r$ [kpc]")
    ax.set_ylabel(r"$v_\mathrm{circ}$ [km/s]")
    # Raw string: "\m" is not a valid escape sequence in a normal literal.
    ax.text(0.6, 0.1,
            r"${} \mathrm{{km/s}} < v_\mathrm{{max}} < {} \mathrm{{km/s}}$".format(
                vmax_lim[0], vmax_lim[1]),
            transform=ax.transAxes)

# Add rotation curves (one line per galaxy, one colour per simulation):
for i, (name, entry) in enumerate(data.items()):
    # Empty line that only serves as the legend entry for this simulation.
    axes[0,0].plot([], [], c=mcolor[i], label=name)

    for ax, key, vmax_lim, x_lim in panels:
        v_circ = entry['VCirc']['satellites'][key]
        radius = entry['Radius']['satellites'][key]
        for v, r in zip(v_circ, radius):
            ax.plot(r, v, c=mcolor[i])

# Upper left, so the legend does not cover the v_max annotation.
axes[0,0].legend(loc='upper left')
plt.tight_layout()

# plt.savefig(filename, dpi=200)

In [None]:
# Rotation curves of isolated galaxies, one panel per v_max bin.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14,14))

# (axis, bin key in the data dictionary, v_max limits, x-axis range):
panels = [
    (axes[0,0], "Low", low, (0, 10)),
    (axes[0,1], "SemiLow", semilow, (0, 10)),
    (axes[1,0], "SemiHigh", semihigh, (0, 20)),
    (axes[1,1], "High", high, (0, 20)),
]

# Set axis:
for ax, key, vmax_lim, x_lim in panels:
    ax.set_xlim(*x_lim)
    # Units follow the cm -> kpc and cm -> km conversions applied when the
    # rotation curves were read (assumes the snapshot stores cgs units).
    ax.set_xlabel(r"$r$ [kpc]")
    ax.set_ylabel(r"$v_\mathrm{circ}$ [km/s]")
    # Raw string: "\m" is not a valid escape sequence in a normal literal.
    ax.text(0.6, 0.1,
            r"${} \mathrm{{km/s}} < v_\mathrm{{max}} < {} \mathrm{{km/s}}$".format(
                vmax_lim[0], vmax_lim[1]),
            transform=ax.transAxes)

# Add rotation curves (one line per galaxy, one colour per simulation):
for i, (name, entry) in enumerate(data.items()):
    # Empty line that only serves as the legend entry for this simulation.
    axes[0,0].plot([], [], c=mcolor[i], label=name)

    for ax, key, vmax_lim, x_lim in panels:
        v_circ = entry['VCirc']['isolated'][key]
        radius = entry['Radius']['isolated'][key]
        for v, r in zip(v_circ, radius):
            ax.plot(r, v, c=mcolor[i])

# Upper left, so the legend does not cover the v_max annotation.
axes[0,0].legend(loc='upper left')
plt.tight_layout()

# plt.savefig(filename, dpi=200)