In [None]:
# Load the autoreload extension in mode 2 so that edits to imported modules
# are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

# NOTE(review): presumably enables IPython's "greedy" tab completion — confirm
# that the option is still supported by the installed IPython version.
%config IPCompleter.greedy=True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy import units
from astropy.constants import G
import importlib

In [None]:
import os
import sys

# Make the project's helper modules importable: they are expected in the
# sibling directory ../apostletools (resolved against the current working
# directory).
apt_path = os.path.abspath(os.path.join('..', 'apostletools'))
# Only append once, so that re-running this cell does not keep growing
# sys.path with duplicate entries.
if apt_path not in sys.path:
    sys.path.append(apt_path)
import dataset_comp
import snapshot
import curve_fit

In [None]:
# Explicitly re-import the local helper modules so that the current state of
# their source files is what this kernel session uses.
importlib.reload(dataset_comp)
importlib.reload(snapshot)
importlib.reload(curve_fit)

# Metallicity

Metallicity of an individual star is defined here as the mass fraction of the metals, i.e. of all elements heavier than helium: $Z^* = \frac{ \sum_{j>\text{He}} m_j}{\sum_k m_k}$. Metallicity of a galaxy is defined as the mass-weighted average over the star particles of that galaxy: $Z = \frac{ \sum_j m_j Z^*_j}{\sum_k m_k}$ (in the EAGLE simulations).

I further normalize the metallicity of a galaxy by the solar metallicity $Z_\odot = 0.0134$ (Wikipedia), and use the base-10 logarithm of the normalized quantity: $\log_{10} (Z / Z_\odot) = \log_{10} \frac{ \sum_j m_j Z^*_j}{\sum_k m_k} - \log_{10} Z_\odot$.

## Motivation

What makes metallicity in low-mass halos lower:
- Gas and dust less bound; some high-metallicity material gets ejected into the inter-galactic medium
- Formed earlier, from less recycled gas
- Lower recycling rate (less active?)

A low-mass galaxy creates a relatively shallow potential well for the interstellar gas and dust. Thus, high-metallicity material is more easily ejected from a low-mass galaxy by a supernova explosion. Therefore, metallicity will tend to be higher in more massive galaxies.

---

## Set Parameters for the Plots

Choose the snapshot and the simulations, and define M31 and MW in each simulation. Also, set the colors used for each simulation:

In [None]:
# Snapshot number passed to every Snapshot that is constructed below.
snap_id = 127

# Parallel per-simulation lists: position k in each list belongs to the
# same simulation (they are zipped together when the data dict is built).
sim_ids = [
    'V1_MR_fix',
    'V1_MR_curvaton_p082_fix',
]
names = [
    'plain-LCDM',
    'spec-p082',
]
colors = [
    'black',
    'red',
]

# Identifiers of the two primaries in each simulation.
# NOTE(review): presumably (group number, subgroup number) pairs — confirm
# against dataset_comp.
m31 = [
    (1, 0),
    (1, 0),
]
mw = [
    (2, 0),
    (1, 1),
]

Choose how to distinguish between satellite and isolated galaxies:

In [None]:
# Criterion used to separate satellites from isolated galaxies; the data
# retrieval cell handles the values 'by_r' and 'by_gn'.
distinction = 'by_r'
# NOTE(review): maxdi is not referenced by any other cell visible in this
# notebook — confirm whether it should be passed to the satellite split.
maxdi = 2000 # Maximum distance from LG centre for isolated

Set the low-mass threshold: subhalos whose $v_\mathrm{max}$ falls below this value (in km/s) are excluded as potentially non-physical:

In [None]:
# Lower v_max limit in km/s; subhalos below it are regarded as potentially
# non-physical.
lowm = 10

---

## Retrieve Data

### Create a Dictionary

For easy handling of the relevant data, define a data dictionary that, at the top level, has entries for all simulations. Under each simulation entry, add items for the needed datasets and, under the 'Selections' key, a sub-dictionary of masking arrays for each needed condition (e.g. satellite, luminous, $v_\mathrm{max}$ inside range, etc.).

First, add the above definitions into the data dict:

In [None]:
# Top-level data dictionary: one entry per simulation, keyed by its display
# name, holding the snapshot handle, the primary identifiers and the colour.
data = {}
for name, sim_id, m31_ns, mw_ns, col in zip(names, sim_ids, m31, mw, colors):
    snap_obj = snapshot.Snapshot(sim_id, snap_id, name=name)
    data[name] = dict(
        snapshot=snap_obj,
        M31_identifier=m31_ns,
        MW_identifier=mw_ns,
        Color=col,
    )

Then, loop over simulations, retrieve data, compute masking arrays, and add to the dictionary:

In [None]:
# Solar metallicity (mass fraction) in base-10 logarithm, matching the
# log_10 definition of the normalised metallicity given in the text above.
metal_sun = np.log10(0.0134)

for name, sim_data in data.items():
    # Get data: stellar mass scaled by the gram-to-solar-mass factor, and the
    # base-10 log-metallicity relative to solar.
    # NOTE(review): subhalos with zero stellar metallicity give -inf here
    # (with a NumPy warning); presumably these are the non-luminous ones that
    # the 'Luminous' selection removes before plotting — confirm.
    snap = sim_data["snapshot"]
    sm = snap.get_subhalos("Stars/Mass") * units.g.to(units.Msun)
    metal = np.log10(snap.get_subhalos("Stars/Metallicity")) - metal_sun

    # Split into satellites (one mask per primary) and isolated galaxies:
    if distinction == "by_r":
        masks_sat, mask_isol = dataset_comp.split_satellites_by_distance(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    elif distinction == "by_gn":
        masks_sat, mask_isol = dataset_comp.split_satellites_by_group_number(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    else:
        # Fail early instead of hitting a NameError below or silently reusing
        # the masks left over from a previous iteration / cell run.
        raise ValueError(
            "Unknown distinction {!r}; expected 'by_r' or 'by_gn'".format(
                distinction))

    # Compute masking arrays:
    mask_m31 = masks_sat[0]
    mask_mw = masks_sat[1]
    mask_lum, mask_dark = dataset_comp.split_luminous(snap)

    # Prune potential spurious subhalos, using the v_max threshold configured
    # in `lowm` rather than a hard-coded value:
    mask_phys = dataset_comp.prune_vmax(snap, low_lim=lowm)
    mask_m31 = np.logical_and(mask_phys, mask_m31)
    mask_mw = np.logical_and(mask_phys, mask_mw)
    mask_isol = np.logical_and(mask_phys, mask_isol)
    mask_lum = np.logical_and(mask_phys, mask_lum)
    mask_dark = np.logical_and(mask_phys, mask_dark)

    # Add datasets to dictionary:
    sim_data['SM'] = sm
    sim_data['Metallicity'] = metal

    # Add selections (masking arrays); 'Satellite' is the union of the
    # satellites of both primaries:
    sim_data['Selections'] = {
        'M31': mask_m31,
        'MW': mask_mw,
        'Satellite': np.logical_or(mask_m31, mask_mw),
        'Isolated': mask_isol,
        'Luminous': mask_lum,
        'Dark': mask_dark
    }

## Plot

In [None]:
# Set some parameters:
x_down = 10**6; x_up = 10**11
y_down = -6; y_up = 1

# Set marker styles:
fcolor = ["black", "red", "blue", "green"]
mcolor = ["gray", "pink", "lightblue", "lightgreen"]
marker = ['+', "o", "^", 1]

In [None]:
# Two panels side by side: satellites on the left, isolated galaxies on the
# right.
fig, axes = plt.subplots(ncols=2, figsize=(14,6))
plt.subplots_adjust(wspace=0.3)

# Set axis:
for ax in axes:
    ax.set_xscale('log')
    ax.set_xlim(x_down, x_up)
    ax.set_ylim(y_down, y_up)
    # Raw string: '\m' and '\o' are not valid Python escape sequences, so a
    # plain literal triggers a warning on recent Python versions.
    ax.set_xlabel(r'$M_*[\mathrm{M_\odot}]$', fontsize=16)
    ax.set_ylabel('$<[Z]>$', fontsize=16)

axes[0].set_title('Satellite galaxies')
axes[1].set_title('Isolated galaxies')

# Add scatter plots: for every simulation, draw the luminous members of each
# panel's selection.
# NOTE(review): '+' is an unfilled marker; check how edgecolor='none'
# interacts with it in the installed Matplotlib version.
panel_selections = ('Satellite', 'Isolated')
for i, (name, entry) in enumerate(data.items()):
    for ax, selection in zip(axes, panel_selections):
        mask = np.logical_and(entry['Selections'][selection],
                              entry['Selections']['Luminous'])
        x = entry['SM'][mask]
        y = entry['Metallicity'][mask]
        ax.scatter(x, y, s=20, marker=marker[i], c=mcolor[i],
                   edgecolor='none', label=name)

axes[0].legend(loc='lower right')
plt.tight_layout()

In [None]:
    
# Overlay a dashed median curve per simulation on each panel of the existing
# figure (left: satellites, right: isolated galaxies).
n_median_points = 7
median_panels = (
    (0, 'Satellite', "# of satellites: {}"),
    (1, 'Isolated', "# of isolated galaxies: {}"),
)
for i, (name, entry) in enumerate(data.items()):
    luminous = entry['Selections']['Luminous']
    for panel_idx, selection, count_msg in median_panels:
        mask = np.logical_and(entry['Selections'][selection], luminous)
        x = entry['SM'][mask]
        y = entry['Metallicity'][mask]
        print(count_msg.format(x.size))

        median = curve_fit.median_once_more(
            x, y, n_points_per_bar=n_median_points)
        if median is None:
            # The helper signals failure by returning None.
            print("Could not fit median for:", name)
        else:
            axes[panel_idx].plot(median[0], median[1], c=fcolor[i],
                                 linestyle='--')

# Last expression of the cell: re-display the updated figure.
fig

### Save the Figure

In [None]:
# Construct saving location: the file name encodes the satellite criterion
# followed by the name of every simulation shown.
filename = 'metallicity_{}{}.png'.format(
    distinction, "".join("_{}".format(name) for name in names))

home = os.path.abspath(os.path.join('..'))
path = os.path.join(home,'Figures', 'MediumResolution')
# Create the output directory if needed, so that saving does not fail on a
# fresh checkout.
os.makedirs(path, exist_ok=True)
filename = os.path.join(path, filename)

# Save through the Figure object built in the plotting cells. plt.savefig
# acts on pyplot's *current* figure, which in a later notebook cell is
# typically no longer this one (the inline backend closes figures after each
# cell), so it can write an empty image.
fig.savefig(filename, dpi=200)