In [None]:
# Load the autoreload extension and switch it to mode 2, so that edits made
# to imported modules are picked up without restarting the kernel:
%load_ext autoreload
%autoreload 2

# Turn on the greedy option of the IPython tab completer:
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from astropy import units
from pathlib import Path
import os
import time

import snapshot_obj
import dataset_compute
import curve_fit

import importlib

In [None]:
# Re-import the project-local helper modules so that the versions currently
# on disk are the ones used by the cells below:
for _module in (snapshot_obj, dataset_compute, curve_fit):
    importlib.reload(_module)

# Subhalo stellar ages

## Construct data dictionary

Add entries for each simulation, and specify M31 and MW galaxies:

In [None]:
# Snapshot number, the simulation runs to compare, and a short label for
# each run (the three lists below are matched by position):
snap_id = 127
sim_ids = ["V1_LR_fix", "V1_LR_curvaton_p082_fix", "V1_LR_curvaton_p084_fix"]
names = ["LCDM", "p082", "p084"]

# Define M31 and MW in each simulation (one tuple per entry of sim_ids):
m31 = [(1, 0), (1, 0), (1, 0)]
mw = [(2, 0), (1, 1), (1, 1)]

# One entry per simulation, keyed by its label, holding the snapshot object
# together with the identifiers of the two central galaxies:
data = {
    label: {
        "snapshot": snapshot_obj.Snapshot(run_id, snap_id, name=label),
        "M31_identifier": m31_id,
        "MW_identifier": mw_id,
    }
    for label, run_id, m31_id, mw_id in zip(names, sim_ids, m31, mw)
}

Choose how to distinguish between satellite and isolated galaxies:

In [None]:
# Criterion used by the cells below to separate satellites from isolated
# galaxies: "by_r" selects dataset_compute.split_satellites_by_distance,
# "by_gn" selects dataset_compute.split_satellites_by_group_number.
distinction = "by_gn"

In [None]:
# For every simulation: group the selected star particles by subhalo,
# convert their formation times to redshifts and store the results,
# split into satellites and isolated galaxies, in the data dictionary.
for name, sim_data in data.items():
    snap = sim_data["snapshot"]

    # Read star particle formation times for star particles of
    # each subhalo. Select only star particles with masses in the
    # given range:
    star_masses = snap.get_particles("Masses", part_type=[4]) \
                      * units.g.to(units.Msun)
    mask_mass_range = np.logical_and(star_masses > 10**3,
                                     star_masses < 10**8)
    grouped_data = dataset_compute.group_selected_particles_by_subhalo(
        snap, "StellarFormationTime", "InitialMass",
        selection_mask=mask_mass_range, part_type=[4])

    print(np.concatenate(grouped_data["StellarFormationTime"]).size)
    for sft in grouped_data["StellarFormationTime"][:10]:
        print("    ", sft.size)

    # Formation redshift (1/sft - 1) of every selected star particle,
    # kept as one array per subhalo:
    sf_z_by_subhalo = [1/sft - 1
                       for sft in grouped_data["StellarFormationTime"]]

    # Initial-mass-weighted mean formation redshift of every subhalo. A
    # subhalo whose selected particles carry no mass gets NaN explicitly,
    # instead of triggering a 0/0 warning:
    mass_weighted_sf_z = np.array([
        np.sum(m*z)/np.sum(m) if np.sum(m) != 0 else np.nan
        for m, z in zip(grouped_data["InitialMass"], sf_z_by_subhalo)])

    # The per-subhalo arrays differ in length, so pack them into an explicit
    # object array element by element (np.array() rejects ragged input in
    # recent NumPy versions):
    sf_z = np.empty(len(sf_z_by_subhalo), dtype=object)
    for idx, z in enumerate(sf_z_by_subhalo):
        sf_z[idx] = z

    # Split into satellites:
    if distinction == "by_r":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    elif distinction == "by_gn":
        masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(
            snap, sim_data["M31_identifier"], sim_data["MW_identifier"])
    else:
        # Fail loudly instead of reusing masks from a previous iteration:
        raise ValueError(
            "distinction must be 'by_r' or 'by_gn', got {!r}".format(
                distinction))

    print(name, np.sum(np.logical_or.reduce(masks_sat)))
    mask_lum, mask_dark = dataset_compute.split_luminous(snap)
    mask_nonzero_vmax = dataset_compute.prune_vmax(snap)

    # Final subhalo selections: luminous, passing prune_vmax, and either a
    # satellite of one of the two centrals or isolated:
    select_sat = np.logical_and.reduce(
        [np.logical_or.reduce(masks_sat), mask_lum, mask_nonzero_vmax])
    select_isol = np.logical_and.reduce(
        [mask_isol, mask_lum, mask_nonzero_vmax])

    # Add separate datasets for each subhalo to the data dictionary:
    data[name]["separated"] = {
        "StellarFormationTime": {
            "satellites": sf_z[select_sat],
            "isolated": sf_z[select_isol]}}

    # Combine datasets of subhalos and add to the data dictionary. An empty
    # selection yields an empty array (np.concatenate would raise on it):
    separated = data[name]["separated"]["StellarFormationTime"]
    data[name]["combined"] = {
        "StellarFormationTime": {
            population: (np.concatenate(list(arrays)) if len(arrays) > 0
                         else np.array([]))
            for population, arrays in separated.items()},
        "InitialMassWeightedBirthZ": {
            "satellites": mass_weighted_sf_z[select_sat],
            "isolated": mass_weighted_sf_z[select_isol]}}

## Plot

In [None]:
# Histogram colours; the list is passed as `color` together with one dataset
# per simulation, so it needs one entry per simulation:
cols = ["grey", "pink", "lightblue"]

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(14,6))
plt.subplots_adjust(wspace=0.3)

n_bins = 5

# One panel per galaxy population: (key in the data dictionary, panel title)
panels = [("satellites", "Satellite galaxies"),
          ("isolated", "Isolated galaxies")]

for ax, (population, title) in zip(axes, panels):
    # Set axis:
    ax.set_xlabel('$z$', fontsize=16)
    ax.set_ylabel('Number density', fontsize=16)
    ax.invert_xaxis()
    ax.set_title(title)

    # Collect the mass-weighted birth redshifts of every simulation. Each
    # galaxy is weighted by 1/n so the bars of one simulation sum to one:
    sf_z = []
    weights = []
    for name, sim_data in data.items():
        birth_z = sim_data["combined"]["InitialMassWeightedBirthZ"][population]
        n_galaxies = birth_z.size
        print(name, n_galaxies)
        sf_z.append(birth_z)
        weights.append(np.ones(n_galaxies)/n_galaxies)

    # Add histograms, one bar group per simulation. The labels are passed
    # as a real list of strings: a dict_keys view is not a sequence of
    # strings and gets collapsed into a single label by recent matplotlib.
    ax.hist(sf_z, n_bins, weights=weights, color=cols, label=list(data))

axes[0].legend()

In [None]:
# The cells below inspect a single simulation. Select its snapshot
# explicitly (the last entry of `names`, which is also the value the loop
# variable `snap` was left with above) instead of relying on leaked state:
snap = data[names[-1]]["snapshot"]

# Group number, subgroup number and mass (converted from g to solar masses)
# of every star particle (part_type 4):
gns = snap.get_particles("GroupNumber", part_type=[4])
sgns = snap.get_particles("SubGroupNumber", part_type=[4])
star_masses = snap.get_particles("Masses", part_type=[4]) * units.g.to(units.Msun)

In [None]:
# Select by star particle masses (in solar masses). Only one range is
# active; the former 10**6 - 10**7 selection was overwritten immediately
# and has been dropped as dead code:
mask_mass_range = np.logical_and(star_masses > 10**5, star_masses < 10**6)

# Group the formation times and masses of the selected particles by subhalo:
grouped_data = dataset_compute.group_selected_particles_by_subhalo(
    snap, "StellarFormationTime", "Masses", selection_mask=mask_mass_range, part_type=[4])

In [None]:
# Number of star particles that pass the mass selection:
n_selected = np.sum(mask_mass_range)
print(n_selected)

In [None]:
# Store the per-subhalo arrays in object arrays so that they can be indexed
# with boolean subhalo masks. The arrays differ in length, so they are
# packed element by element: np.array() on ragged input raises in recent
# NumPy versions.
star_form_times_halo = np.empty(len(grouped_data["StellarFormationTime"]),
                                dtype=object)
for i_sub, sft in enumerate(grouped_data["StellarFormationTime"]):
    star_form_times_halo[i_sub] = sft

star_masses_halo = np.empty(len(grouped_data["Masses"]), dtype=object)
for i_sub, masses in enumerate(grouped_data["Masses"]):
    star_masses_halo[i_sub] = masses

In [None]:
mask_lum, mask_dark = dataset_compute.split_luminous(snap)
mask_nonzero_vmax = dataset_compute.prune_vmax(snap)

# Look up the M31 and MW identifiers that belong to `snap`. The split
# functions take the identifiers of one simulation, not the full `m31` and
# `mw` lists covering all simulations:
snap_data = next(entry for entry in data.values()
                 if entry["snapshot"] is snap)
m31_id = snap_data["M31_identifier"]
mw_id = snap_data["MW_identifier"]

# Split into satellites:
if distinction == "by_r":
    masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(
        snap, m31_id, mw_id)
elif distinction == "by_gn":
    masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(
        snap, m31_id, mw_id)
else:
    raise ValueError(
        "distinction must be 'by_r' or 'by_gn', got {!r}".format(distinction))

# Keep only luminous subhalos that pass prune_vmax:
mask_sat = np.logical_and.reduce([np.logical_or.reduce(masks_sat),
                                 mask_lum, mask_nonzero_vmax])
mask_isol = np.logical_and.reduce([mask_isol,
                                   mask_lum, mask_nonzero_vmax])

In [None]:
# Merge the formation times of all selected satellites into one flat array:
sat_formation_times = star_form_times_halo[mask_sat]
sft_all = np.concatenate(sat_formation_times)

In [None]:
fig, axes = plt.subplots()

# Give every entry the weight 1/n, so the histogram shows fractions rather
# than raw counts:
n_stars_in_sats = sft_all.size
print(n_stars_in_sats)
uniform_weights = np.ones(n_stars_in_sats) / n_stars_in_sats
axes.hist(sft_all, bins=10, weights=uniform_weights)

In [None]:
# Order the star particles by group number first and subgroup number second,
# then pair the two numbers up row by row:
sort = np.lexsort((sgns, gns))
idents = np.array([gns[sort], sgns[sort]]).T
print(idents.shape)

# Time the extraction of the unique (group number, subgroup number) pairs,
# leaving out the rows that carry the largest subgroup number present:
t = time.time()
largest_sgn = idents[:, 1].max()
mask_halo = idents[:, 1] < largest_sgn
idents_unique = np.unique(idents[mask_halo], axis=0)
elapsed = time.time() - t
print(elapsed)
print(idents_unique)
print(idents_unique.shape)

In [None]:
# Subhalos whose "Stars/Mass" entry is positive:
stellar_mass = snap.get_subhalos("Stars/Mass")
mask_lum_subh = stellar_mass > 0
print(np.sum(mask_lum_subh))

# Column 4 of "SubLengthType" for those subhalos, cast to int (presumably
# the number of star particles per subhalo -- TODO confirm):
sub_length_type = snap.get_subhalos("SubLengthType").astype(int)
counts_true = sub_length_type[mask_lum_subh, 4]

In [None]:
t = time.time()
# Count the star particles of every unique (group number, subgroup number)
# pair. `idents[mask_halo]` holds one row per particle, so np.unique can
# return the counts directly instead of re-scanning all particles once per
# pair. The pairs come out in the same sorted order as before.
unique_pairs, pair_counts = np.unique(idents[mask_halo], axis=0,
                                      return_counts=True)
counts = pair_counts.tolist()
print(time.time()-t)
print(counts)

In [None]:
# Running total of the per-subhalo counts:
print(np.asarray(counts).cumsum())

In [None]:
t = time.time()
subhalo_gns = snap.get_subhalos("GroupNumber")
subhalo_sgns = snap.get_subhalos("SubGroupNumber")

# For every subhalo, count the entries of gns/sgns that carry its
# (group number, subgroup number) pair:
counts = []
for gn, sgn in zip(subhalo_gns, subhalo_sgns):
    in_subhalo = (gns == gn) & (sgns == sgn)
    counts.append(np.sum(in_subhalo))

# Running total of the counts:
subhalo_offsets = np.cumsum(counts)

print(time.time() - t)

In [None]:
# Compare the computed counts with the reference counts. np.array_equal
# reports False when the two differ in length, whereas `==` between
# mismatched shapes raises in recent NumPy versions:
print(np.array_equal(counts, counts_true))

In [None]:
# Distribution of star particle masses over half-decade bins between
# 10**5 and 10**8:
mass_bin_edges = [10**5, 5*10**5, 10**6, 5*10**6, 10**7, 5*10**7, 10**8]
print(np.histogram(star_masses, bins=mass_bin_edges))

In [None]:
# Per-subhalo formation times and masses of the selected satellites.
# `mask_sat` is a subhalo mask, so it has to index the per-subhalo arrays:
# `star_form_times` is not defined anywhere in this notebook and
# `star_masses` holds one entry per particle, not per subhalo.
star_form_times_sat = star_form_times_halo[mask_sat]
star_masses_sat = star_masses_halo[mask_sat]

In [None]:
# Check that no star particles are in dark halos. The per-subhalo mass
# arrays are picked with a plain loop, which works for arrays of different
# lengths, and an empty selection yields an empty array instead of an error:
dark_masses = [masses for masses, is_dark
               in zip(grouped_data["Masses"], mask_dark) if is_dark]
empty = np.concatenate(dark_masses) if dark_masses else np.array([])

empty.size

In [None]:
# Use dedicated names for the subhalo-level numbers so that the
# particle-level `gns` and `sgns` used by earlier cells are not overwritten:
subhalo_gns = snap.get_subhalos("GroupNumber")
subhalo_sgns = snap.get_subhalos("SubGroupNumber")
sort, split = dataset_compute.sort_and_split_by_subhalo(subhalo_gns,
                                                        subhalo_sgns)
# NOTE(review): `sort` is returned but not applied before splitting --
# confirm whether `subhalo_gns[sort]` is what should be split here.
print(np.split(subhalo_gns, split))

In [None]:
# Number of split positions returned by sort_and_split_by_subhalo:
n_splits = split.size
print(n_splits)