## First, imports:

In [None]:
# Load the autoreload extension and set it to mode 2, so that imported
# modules are reloaded before each cell is executed:
%load_ext autoreload
%autoreload 2

# Turn on greedy tab-completion in the IPython completer:
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from astropy import units
from astropy.cosmology import FlatLambdaCDM, z_at_value

Import my library:

In [None]:
import os
import sys

# Put the "apostletools" directory (located in the parent of the current
# working directory) on the module search path, so that the modules imported
# right below can be found:
apt_path = os.path.abspath(os.path.join("..", "apostletools"))
sys.path.append(apt_path)

import snapshot
import dataset_comp

In [None]:
# Explicitly re-import the two library modules, so that the definitions used
# below come from their current source files:
import importlib
importlib.reload(snapshot)
importlib.reload(dataset_comp)

# Mass Distribution of Subhalos

To visualize the mass distribution, at $z=0$, I plot a simple subhalo count accumulation curve, which at any point on the mass axis gives the total number of subhalos (or of satellite or isolated galaxies) with masses larger than the given mass. Mass is measured by $v_\mathrm{max} = \max_{r} \sqrt{\frac{G M(<r)}{r}}$.

## Motivation

Plenty of interesting and instructive observations can be made from the mass distribution figure alone. For the curv-p08* simulation, we expect to find an indication of inhibited structure formation towards the small scales, relative to the plain-lcdm. However, on the larger scales, the curves from different simulations should approach each other (although random effects from low number counts also come into play).

But the absence of small-scale power in the initial physical power spectrum means that the component of power that is due to numerical noise also becomes more significant --- and, indeed, will dominate on small enough scales (shot noise becomes more relevant towards the resolution limits of the simulation). This is expected to be visible in the mass function as well.

Furthermore, the mass functions of satellite galaxies and isolated galaxies in any simulation should be expected to differ.

From the single plot alone, it is impossible to make the connection between any particular feature of the mass function and any of the above-mentioned potential causes. Of course, we are also only looking at a single simulated instance of the LG.

---

## Set Parameters for the Plots

Choose the snapshot and the simulations, and define M31 and MW in each simulation:

Create a dictionary of the datasets from each simulation. 

In [None]:
# Snapshot number that is loaded for every simulation:
snap_id = 127

# One entry per simulation, keyed by a short simulation name. Each entry holds:
#   "Snapshot": the snapshot.Snapshot object built from the simulation
#               identifier and snap_id,
#   "M31_ID", "MW_ID": integer pairs identifying the two central galaxies;
#               they are passed on to
#               dataset_comp.split_satellites_by_distance
#               (presumably (group number, subgroup number) -- TODO confirm),
#   "Color": list of plotting colors (the plotting cells in this notebook
#               only read the first element).
data = {
    "plain-LCDM-LR": {
        "Snapshot": snapshot.Snapshot("V1_LR_fix", snap_id),
        "M31_ID": (1, 0),
        "MW_ID": (2, 0),
        "Color": ["gray"]
    },
    "curv-p082-LR": {
        "Snapshot": snapshot.Snapshot("V1_LR_curvaton_p082_fix", snap_id),
        "M31_ID": (1, 0),
        "MW_ID": (1, 1),
        "Color": ["pink"]
    },
    # NOTE(review): M31_ID and MW_ID are identical for this simulation, unlike
    # in every other entry -- verify that this is intended.
    "curv-p084-LR": {
        "Snapshot": snapshot.Snapshot("V1_LR_curvaton_p084_fix", snap_id),
        "M31_ID": (1, 0),
        "MW_ID": (1, 0),
        "Color": ["lightblue"]
    },
    "plain-LCDM": {
        "Snapshot": snapshot.Snapshot("V1_MR_fix", snap_id),
        "M31_ID": (1, 0),
        "MW_ID": (2, 0),
        "Color": ["black", "gray"]
    },
    "curv-p082": {
        "Snapshot": snapshot.Snapshot("V1_MR_curvaton_p082_fix", snap_id),
        "M31_ID": (1, 0),
        "MW_ID": (1, 1),
        "Color": ["red", "pink"]
    }
}

---

## Retrieve Data

### Create a Dictionary

For easy handling of the relevant data, define a data dictionary that, at the top level, has entries for all simulations. Under each simulation sim_data, add items for the needed datasets and, under the "Selections" key, a sub-dictionary of masking arrays for each needed condition (e.g. satellite, luminous, $v_\mathrm{max}$ inside range, etc.).

In [None]:
# Build the cosmology from the header of the "plain-LCDM" snapshot (the other
# simulations are assumed to share the same parameters). "HubbleParam" is
# multiplied by 100 to obtain the H0 argument of FlatLambdaCDM.
ref_snap = data["plain-LCDM"]["Snapshot"]
H0, Om0 = (
    ref_snap.get_attribute(attr, "Header")
    for attr in ("HubbleParam", "Omega0")
)
cosmo = FlatLambdaCDM(Om0=Om0, H0=100 * H0)

Then, loop over simulations, retrieve data, compute masking arrays, and add to the dictionary:

In [None]:
# Stellar-mass threshold (in solar masses, like `sm` below) used for the "SM"
# selection mask:
low_sm = 10**7

# Present-day age of the universe. It does not depend on the simulation or on
# the subhalo, so evaluate it once instead of once per subhalo:
age_now = cosmo.age(0).value

for sim_data in data.values():
    # Get data:
    snap = sim_data["Snapshot"]

    # Stellar masses, converted from grams to solar masses:
    sm = snap.get_subhalos("Stars/Mass") * units.g.to(units.Msun)

    # Initial-mass-weighted stellar birth redshift of each subhalo, and the
    # corresponding lookback time:
    sfz = snap.get_subhalos("InitialMassWeightedBirthZ")
    sf_time = age_now - np.array([cosmo.age(z).value for z in sfz])

    # First column of the "Max_Vcirc" dataset, converted from cm to km
    # (i.e. cm/s to km/s for a velocity):
    max_point = snap.get_subhalos("Max_Vcirc", "Extended")
    vmax = max_point[:, 0] * units.cm.to(units.km)

    # Read the scale factor at formation time for each star particle in each
    # subhalo:
    sf_a = dataset_comp.group_particles_by_subhalo(
        snap, "StellarFormationTime", part_type=[4]
    )["StellarFormationTime"]

    # Star-formation onset of a subhalo = formation of its earliest star
    # particle, i.e. the largest redshift z = 1/a - 1 among its particles.
    # Subhalos without star particles get NaN:
    soz = np.array([np.max(1 / a - 1) if a.size > 0 else np.nan for a in sf_a])
    # Convert the onset redshifts to lookback times:
    so_time = np.array([age_now - cosmo.age(z).value for z in soz])

    # Compute masking arrays:
    mask_m31, mask_mw, mask_isol = dataset_comp.split_satellites_by_distance(
        snap, sim_data["M31_ID"], sim_data["MW_ID"]
    )
    mask_sat = np.logical_or(mask_m31, mask_mw)
    mask_lum, mask_dark = dataset_comp.split_luminous(snap)
    mask_vmax = dataset_comp.prune_vmax(snap, low_lim=10)

    # Store the datasets and the selection masks under this simulation:
    sim_data["OnsetZ"] = soz
    sim_data["StarFormationOnset"] = so_time
    sim_data["SFZ"] = sfz
    sim_data["SFTime"] = sf_time
    sim_data["SM"] = sm
    sim_data["Vmax"] = vmax

    sim_data["Selections"] = {
        "M31": mask_m31,
        "MW": mask_mw,
        "Satellite": mask_sat,
        "Isolated": mask_isol,
        "Luminous": mask_lum,
        "Dark": mask_dark,
        "Vmax": mask_vmax,
        "SM": (sm > low_sm)
    }

## Plot Only Total Counts

In [None]:
def count_curve(arr, norm=None):
    """Return the points of a descending cumulative count curve.

    For each value v in `arr`, the curve gives the number of elements that
    are >= v, divided by `norm`. A leading point (1e-5, 1) and a trailing
    point (max(arr), 0) are added, so that the plotted line extends towards
    the axis on the left and drops to zero on the right.

    Parameters
    ----------
    arr : array_like
        One-dimensional collection of values. It is not modified (the values
        are sorted in a copy).
    norm : number, optional
        Normalization of the counts. If omitted (or falsy), the number of
        elements in `arr` is used, so that the curve gives fractions.

    Returns
    -------
    x, y : numpy.ndarray
        Curve coordinates, each of length ``len(arr) + 2``. Two empty arrays
        are returned for empty input, so that plotting them draws nothing.
    """
    # np.sort returns a sorted copy, leaving the caller's array untouched:
    values = np.sort(np.asarray(arr))
    n = values.size

    # Nothing to count; there is no largest element to end the curve at:
    if n == 0:
        return np.array([]), np.array([])

    if not norm:
        norm = n

    x = np.concatenate(([10**-5], values, values[-1:]))
    y = np.concatenate(([1], np.arange(n, 0, -1) / norm, [0]))

    return x, y

In [None]:
# Choose font sizes:
parameters = {'axes.titlesize': 10,
              'axes.labelsize': 8,
              'xtick.labelsize': 8,
              'ytick.labelsize': 8,
              'legend.fontsize': 8,
              'legend.title_fontsize': 7}

# Set fonts:
plt.rcParams.update(parameters)
plt.tight_layout()

In [None]:
# Cumulative distributions of the star-formation redshifts of luminous
# satellite (left panel) and isolated (right panel) galaxies.
fig, axes = plt.subplots(ncols=2, figsize=(6, 3), sharey="row")
plt.subplots_adjust(wspace=0)

# Set axis. The plotted quantities are redshifts ("OnsetZ" and "SFZ"), so the
# axes are labelled accordingly. Raw strings keep the LaTeX backslashes
# without relying on invalid escape sequences.
axes[0].set_ylim(0, 1.2)
axes[0].set_ylabel(r"$f\,(>z)$")
for ax in axes:
    ax.set_xlabel(r"$z$")
    ax.invert_xaxis()

axes[0].set_title("Satellite galaxies")
axes[1].set_title("Isolated galaxies")

# Select the medium-resolution simulations. Iterating over a tuple (instead of
# a set intersection) keeps the plotting and legend order deterministic.
data_mr = {key: data[key] for key in ("plain-LCDM", "curv-p082")
           if key in data}
for sim_name, sim_data in data_mr.items():

    sf_z = sim_data["SFZ"]
    onset_z = sim_data["OnsetZ"]
    vmax = sim_data["Vmax"]
    selections = sim_data["Selections"]
    color = sim_data["Color"][0]

    # The two vmax bins (in km/s) used in both panels:
    mask_low = np.logical_and(vmax > 10, vmax < 25)
    mask_high = vmax > 25

    # ISOLATED GALAXIES
    # -----------------

    mask_isol = np.logical_and(selections["Luminous"],
                               selections["Isolated"])

    # Onset redshift, low-vmax bin:
    x, y = count_curve(onset_z[np.logical_and(mask_isol, mask_low)])
    axes[1].plot(x, y, c=color, linestyle="solid", label=sim_name)

    # Onset redshift, high-vmax bin:
    x, y = count_curve(onset_z[np.logical_and(mask_isol, mask_high)])
    axes[1].plot(x, y, c=color, linestyle="dashed", label=sim_name)

    # SATELLITES
    # ----------

    mask_sat = np.logical_and(selections["Luminous"],
                              selections["Satellite"])

    # Onset redshift, low-vmax bin:
    x, y = count_curve(onset_z[np.logical_and(mask_sat, mask_low)])
    axes[0].plot(x, y, c=color, linestyle="solid", label=sim_name)

    # Onset redshift, high-vmax bin:
    x, y = count_curve(onset_z[np.logical_and(mask_sat, mask_high)])
    axes[0].plot(x, y, c=color, linestyle="dashed", label=sim_name)

    # High-vmax satellites that also pass the stellar-mass selection:
    mask = np.logical_and.reduce([mask_sat, mask_high, selections["SM"]])

    # ... their onset redshifts:
    x, y = count_curve(onset_z[mask])
    axes[0].plot(x, y, c=color, linestyle="dotted", label=sim_name)

    # ... and their mass-weighted stellar birth redshifts:
    x, y = count_curve(sf_z[mask])
    axes[0].plot(x, y, c=color, linestyle="dashdot", label=sim_name)

axes[1].legend()

plt.tight_layout()

Ideas:
- Plot separately $v_\mathrm{max} > 25$ km/s and below (where the mass functions decouple). If this is the mass scale, where the cut-off kicks in, the difference could be very notable below it.
- Or maybe plot: all, $v_\mathrm{max} > 25$ km/s, and $v_\mathrm{max} > 15$ km/s, normalized to the total number (or then, just offset?)
- Or just make a scatter plot, as the bins contain so few points?


About half of these galaxies have masses below $10^7 M_\odot$, so is there much sense in this plot (most of the low-mass halos are also the hosts of low-mass galaxies)?

Maybe only the onset plot is really informative.

### Save the Figure

In [None]:
# Write the current figure as a PNG file into ../Figures/MediumResolution
# (relative to the working directory):
path = os.path.abspath(os.path.join("..", "Figures", "MediumResolution"))
filename = os.path.join(path, "sm_distribution_norm.png")

fig.savefig(filename, bbox_inches="tight", dpi=300)

## Plot Scatter: Stellar Ages vs. $M_*$

In [None]:
# Choose font sizes:
parameters = {'axes.titlesize': 12,
              'axes.labelsize': 10,
              'xtick.labelsize': 9,
              'ytick.labelsize': 9,
              'legend.fontsize': 10}

# Marker size
ms = 15
a = 0.7

# Set fonts:
plt.rcParams.update(parameters)
plt.tight_layout()

In [None]:
# Scatter plot: mass-weighted stellar birth redshift vs. vmax, for luminous
# satellite ('+') and isolated (open squares) galaxies.
fig, ax = plt.subplots(sharey=True, figsize=(3, 3))

ax.set_xscale('log')
# Raw string: keeps the LaTeX backslashes without invalid escape sequences.
ax.set_xlabel(r'$v_\mathrm{max} [\mathrm{km/s}]$')
# The plotted 'SFZ' values are redshifts, not ages in Gyr:
ax.set_ylabel('Mean Stellar Birth Redshift')

# Add scatter plots. Iterating over a tuple (instead of a set intersection)
# keeps the drawing order of the simulations deterministic.
data_mr = {key: data[key] for key in ('plain-LCDM', 'curv-p082')
           if key in data}
for sim_data in data_mr.values():

    selections = sim_data['Selections']
    color = sim_data["Color"][0]

    # Luminous subhalos that pass the vmax cut:
    mask = np.logical_and(selections['Luminous'], selections['Vmax'])

    mask_sat = np.logical_and(mask, selections['Satellite'])
    x = sim_data['Vmax'][mask_sat]
    y = sim_data['SFZ'][mask_sat]
    ax.scatter(x, y, alpha=a, marker='+', c=color, s=ms)

    mask_isol = np.logical_and(mask, selections['Isolated'])
    x = sim_data['Vmax'][mask_isol]
    y = sim_data['SFZ'][mask_isol]
    ax.scatter(x, y, alpha=a, marker='s',
               facecolor='none', s=ms, edgecolor=color)

In [None]:
# Scatter plot: star-formation onset redshift vs. vmax, for luminous
# satellite ('+') and isolated (open squares) galaxies.
fig, ax = plt.subplots(sharey=True, figsize=(3, 3))

ax.set_xscale('log')
# Raw string: keeps the LaTeX backslashes without invalid escape sequences.
ax.set_xlabel(r'$v_\mathrm{max} [\mathrm{km/s}]$')
# The plotted 'OnsetZ' values are redshifts, not times in Gyr:
ax.set_ylabel('SF Onset Redshift')

# Add scatter plots. Iterating over a tuple (instead of a set intersection)
# keeps the drawing order of the simulations deterministic.
data_mr = {key: data[key] for key in ('plain-LCDM', 'curv-p082')
           if key in data}
for sim_data in data_mr.values():

    selections = sim_data['Selections']
    color = sim_data["Color"][0]

    # Luminous subhalos that pass the vmax cut:
    mask = np.logical_and(selections['Luminous'], selections['Vmax'])

    mask_sat = np.logical_and(mask, selections['Satellite'])
    x = sim_data['Vmax'][mask_sat]
    y = sim_data['OnsetZ'][mask_sat]
    ax.scatter(x, y, alpha=a, marker='+', c=color, s=ms)

    mask_isol = np.logical_and(mask, selections['Isolated'])
    x = sim_data['Vmax'][mask_isol]
    y = sim_data['OnsetZ'][mask_isol]
    ax.scatter(x, y, alpha=a, marker='s',
               facecolor='none', s=ms, edgecolor=color)

In [None]:
# Scatter plot: lookback time of the mass-weighted stellar birth redshift
# vs. stellar mass, for satellite ('+') and isolated (open squares) subhalos
# that pass the vmax cut.
fig, ax = plt.subplots(sharey=True, figsize=(3, 3))

ax.set_xscale('log')
# Raw string: keeps the LaTeX backslashes without invalid escape sequences.
ax.set_xlabel(r'$M_*[\mathrm{M_\odot}]$')
ax.set_ylabel('Mean Stellar Age [Gyr]')

# Add scatter plots. Iterating over a tuple (instead of a set intersection)
# keeps the drawing order of the simulations deterministic.
data_mr = {key: data[key] for key in ('plain-LCDM', 'curv-p082')
           if key in data}
for sim_data in data_mr.values():

    selections = sim_data['Selections']
    color = sim_data["Color"][0]

    mask = np.logical_and(selections['Satellite'], selections['Vmax'])
    x = sim_data['SM'][mask]
    y = sim_data['SFTime'][mask]
    ax.scatter(x, y, alpha=a, marker='+', c=color, s=ms)

    mask = np.logical_and(selections['Isolated'], selections['Vmax'])
    x = sim_data['SM'][mask]
    y = sim_data['SFTime'][mask]
    ax.scatter(x, y, alpha=a, marker='s',
               facecolor='none', s=ms, edgecolor=color)

In [None]:
# Scatter plot: lookback time of the star-formation onset vs. stellar mass,
# for satellite ('+') and isolated (open squares) subhalos that pass the
# vmax cut.
fig, ax = plt.subplots(sharey=True, figsize=(3, 3))

ax.set_xscale('log')
# Raw string: keeps the LaTeX backslashes without invalid escape sequences.
ax.set_xlabel(r'$M_*[\mathrm{M_\odot}]$')
ax.set_ylabel('SF Onset [Gyr]')

# Add scatter plots. Iterating over a tuple (instead of a set intersection)
# keeps the drawing order of the simulations deterministic.
data_mr = {key: data[key] for key in ('plain-LCDM', 'curv-p082')
           if key in data}
for sim_data in data_mr.values():

    selections = sim_data['Selections']
    color = sim_data["Color"][0]

    mask = np.logical_and(selections['Satellite'], selections['Vmax'])
    x = sim_data['SM'][mask]
    y = sim_data['StarFormationOnset'][mask]
    ax.scatter(x, y, alpha=a, marker='+', c=color, s=ms)

    mask = np.logical_and(selections['Isolated'], selections['Vmax'])
    x = sim_data['SM'][mask]
    y = sim_data['StarFormationOnset'][mask]
    ax.scatter(x, y, alpha=a, marker='s',
               facecolor='none', s=ms, edgecolor=color)

In [None]:
# Decorate the most recently created axes `ax` (presumably the M_* vs. SF
# onset figure of the previous cell -- depends on cell execution order).

# Vertical guide line at x = 10**7 (the same value as `low_sm`):
ax.axvline(10**7, linestyle="dotted", c='gray')

# Empty scatter series that only serve as legend handles for the two marker
# types, drawn in the first color of the "plain-LCDM" simulation:
legend_color = data["plain-LCDM"]["Color"][0]
ax.scatter([], [], marker='+', c=legend_color,
           s=ms, alpha=a, label="Satellite")
ax.scatter([], [], marker='s', facecolor='none', edgecolor=legend_color,
           s=ms, alpha=a, label="Isolated")
ax.legend(loc='lower right')

# Bare expression as the last statement of the cell: displays the figure.
fig

### Save the Figure

In [None]:
# Construct saving location:
filename = 'metallicity'
for name in data.keys():
    filename += "_{}".format(name)
filename += ".png"
        
home = os.path.abspath(os.path.join('..'))
path = os.path.join(home,'Figures', 'MediumResolution')
filename = os.path.join(path, filename)

fig.savefig(filename, dpi=300, bbox_inches='tight')