In [None]:
# Load IPython's autoreload extension so edited project modules
# (snapshot_obj, dataset_compute, curve_fit) are picked up without a restart.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

# Greedy tab completion (completes on results of expressions; evaluates code on TAB).
%config IPCompleter.greedy=True

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from astropy import units
from pathlib import Path
import os
import time

import snapshot_obj
import dataset_compute
import curve_fit

import importlib

In [None]:
# Explicitly reload the project modules so source edits take effect.
# NOTE(review): largely redundant while `%autoreload 2` is active.
for _module in (snapshot_obj, dataset_compute, curve_fit):
    importlib.reload(_module)
del _module

In [None]:
# How satellites are identified: "by_r" (distance-based split) or
# "by_gn" (group-number-based split).
distinction: str = "by_gn"

In [None]:
# Simulation run and snapshot to analyse.
sim_id = "V1_LR_fix"
snap_id = 127
# Host halo identifiers, presumably (GroupNumber, SubGroupNumber) pairs
# — TODO confirm against dataset_compute.split_satellites_*.
m31 = (1, 0)
mw = (2, 0)
snap = snapshot_obj.Snapshot(sim_id, snap_id)

In [None]:
# Per-particle group / subgroup numbers of star particles (part_type 4).
gns, sgns = (snap.get_particles(field, part_type=[4])
             for field in ("GroupNumber", "SubGroupNumber"))
# Star particle masses, converted from grams (as the conversion assumes) to Msun.
star_masses = units.g.to(units.Msun) * snap.get_particles("Masses", part_type=[4])

In [None]:
# Select star particles by their individual mass (in Msun).
# An earlier 1e6–1e7 Msun selection was immediately overwritten and never
# took effect; only the 1e5–1e6 Msun window below was actually used.
STAR_MASS_MIN_MSUN = 10**5
STAR_MASS_MAX_MSUN = 10**6
mask_mass_range = np.logical_and(star_masses > STAR_MASS_MIN_MSUN,
                                 star_masses < STAR_MASS_MAX_MSUN)
# Group the selected particles' formation times and masses by subhalo.
grouped_data = dataset_compute.group_selected_particles_by_subhalo(
    snap, mask_mass_range, "StellarFormationTime", "Masses", part_type=[4])

In [None]:
# Number of star particles inside the selected mass window.
print(np.count_nonzero(mask_mass_range))

In [None]:
# Grouped formation times and masses of the selected star particles,
# presumably one entry per subhalo (indexed by subhalo masks below).
star_form_times_halo, star_masses_halo = (
    np.array(grouped_data[key]) for key in ("StellarFormationTime", "Masses"))

In [None]:
# Subhalo masks: luminous vs dark, and subhalos with nonzero Vmax.
mask_lum, mask_dark = dataset_compute.split_luminous(snap)
mask_nonzero_vmax = dataset_compute.prune_vmax(snap)

# Split into satellites of the hosts (m31, mw) and isolated subhalos:
if distinction == "by_r":
    masks_sat, mask_isol = dataset_compute.split_satellites_by_distance(
        snap, m31, mw)
elif distinction == "by_gn":
    masks_sat, mask_isol = dataset_compute.split_satellites_by_group_number(
        snap, m31, mw)
else:
    # Without this, an unknown value would silently reuse masks left over
    # from a previous run of this cell.
    raise ValueError(
        f"Unknown distinction {distinction!r}; expected 'by_r' or 'by_gn'")

# Satellite of any host (masks_sat presumably holds one mask per host),
# restricted to luminous subhalos with nonzero Vmax.
mask_sat = np.logical_and.reduce([np.logical_or.reduce(masks_sat),
                                  mask_lum, mask_nonzero_vmax])
mask_isol = np.logical_and.reduce([mask_isol,
                                   mask_lum, mask_nonzero_vmax])

In [None]:
# Flatten the per-subhalo formation-time arrays of all selected satellites
# into a single array over star particles.
sft_all = np.concatenate(list(star_form_times_halo[mask_sat]))

In [None]:
fig, axes = plt.subplots()

n_stars_in_sats = sft_all.size
print(n_stars_in_sats)
# Weight each particle by 1/N so the bar heights are fractions of all
# selected satellite star particles (they sum to 1).
axes.hist(sft_all, bins=10, weights=np.ones(n_stars_in_sats)/n_stars_in_sats)
axes.set(xlabel="Stellar formation time (StellarFormationTime)",
         ylabel="Fraction of star particles",
         title="Formation times of star particles in satellites");

In [None]:
# Particle (GroupNumber, SubGroupNumber) pairs as rows, sorted by group
# number first and subgroup number second.
sort = np.lexsort((sgns, gns))
idents = np.column_stack((gns[sort], sgns[sort]))
print(idents.shape)

t = time.time()
# Exclude particles carrying the largest subgroup number present
# (presumably a "not in any subhalo" sentinel — TODO confirm).
mask_halo = idents[:, 1] < idents[:, 1].max()
idents_unique = np.unique(idents[mask_halo], axis=0)
print(time.time() - t)
print(idents_unique)
print(idents_unique.shape)

In [None]:
# Subhalos with nonzero stellar mass in the catalogue.
mask_lum_subh = snap.get_subhalos("Stars/Mass") > 0
print(np.count_nonzero(mask_lum_subh))
# Reference star-particle counts of those subhalos: column 4 of SubLengthType
# (presumably part type 4, i.e. stars).
counts_true = snap.get_subhalos("SubLengthType")[mask_lum_subh, 4].astype(int)

In [None]:
t = time.time()
# Number of star particles per (GroupNumber, SubGroupNumber) pair kept by
# mask_halo. np.unique returns the pairs in lexicographic order, which is the
# same order as np.unique(idents[mask_halo], axis=0). mask_halo depends only on
# the subgroup number, so every particle of a kept pair is in the masked set.
# The counts therefore equal a full rescan of gns/sgns per pair, computed in one pass.
counts = np.unique(idents[mask_halo], axis=0, return_counts=True)[1].tolist()
print(time.time() - t)
print(counts)

In [None]:
# Running total of the per-pair particle counts (equivalent to np.cumsum).
print(np.add.accumulate(counts))

In [None]:
t = time.time()
subhalo_gns = snap.get_subhalos("GroupNumber")
subhalo_sgns = snap.get_subhalos("SubGroupNumber")

# Count star particles for each subhalo's (GroupNumber, SubGroupNumber) pair.
# Tally every particle pair once with np.unique, then look each subhalo up.
# Subhalos with no star particles get 0. This replaces a full scan of all
# particles per subhalo (O(n_subhalos * n_particles)).
particle_pairs, particle_pair_counts = np.unique(
    np.column_stack((gns, sgns)), axis=0, return_counts=True)
count_by_pair = dict(zip(map(tuple, particle_pairs.tolist()),
                         particle_pair_counts.tolist()))
counts = np.array([count_by_pair.get(pair, 0)
                   for pair in zip(np.asarray(subhalo_gns).tolist(),
                                   np.asarray(subhalo_sgns).tolist())])

# Running total of the per-subhalo counts, i.e. the end index of each
# subhalo's block if particles are stored contiguously in this subhalo order.
# NOTE(review): "offsets" usually means start indices; if so, use
# np.concatenate(([0], np.cumsum(counts)[:-1])) — confirm with the consumer.
subhalo_offsets = np.cumsum(counts)

print(time.time() - t)

In [None]:
# `counts` has one entry per subhalo, whereas `counts_true` only covers the
# luminous subhalos selected by mask_lum_subh. Compare matching subsets;
# comparing the full arrays fails or raises on the shape mismatch.
print(np.array_equal(np.asarray(counts)[mask_lum_subh], counts_true))

In [None]:
# Star-particle mass distribution over bins 1e5, 5e5, 1e6, ..., 5e7, 1e8 Msun.
print(np.histogram(star_masses,
                   bins=[m * 10**e for e in range(5, 8) for m in (1, 5)] + [10**8]))

In [None]:
# Per-subhalo star-particle arrays restricted to the selected satellites.
# mask_sat is a subhalo-level mask (as used for sft_all), so it must index
# the per-subhalo grouped arrays. `star_form_times` was never defined, and
# `star_masses` is per particle, not per subhalo.
star_form_times_sat = star_form_times_halo[mask_sat]
star_masses_sat = star_masses_halo[mask_sat]

In [None]:
# Check that no star particles are in dark halos. This only covers the star
# particles selected by mask_mass_range, since grouped_data holds just those.
masses_dark = np.array(grouped_data["Masses"])[mask_dark]
# np.concatenate raises on an empty sequence, so handle "no dark subhalos".
empty = np.concatenate(masses_dark) if len(masses_dark) else np.array([])

empty.size

In [None]:
# Subhalo-level identifiers. Use distinct names so the particle-level
# `gns` / `sgns` loaded earlier are not clobbered: the particle-counting
# cells rely on them and would break if re-run after this cell.
subhalo_gns = snap.get_subhalos("GroupNumber")
subhalo_sgns = snap.get_subhalos("SubGroupNumber")
sort, split = dataset_compute.sort_and_split_by_subhalo(subhalo_gns, subhalo_sgns)
# NOTE(review): if `split` holds boundaries in the order given by `sort`,
# this should be np.split(subhalo_gns[sort], split) — confirm.
print(np.split(subhalo_gns, split))

In [None]:
# Number of split points returned by sort_and_split_by_subhalo.
print(np.size(split))