In [None]:
import json

import gc_utils
import h5py
import numpy as np
import pandas as pd
import utilities as ut

In [369]:
# --- Simulation selection and directory layout ---
sim = "m12i"
sim_dir = "/Users/z5114326/Documents/simulations/"
fir_dir = f"{sim_dir}{sim}/{sim}_res7100"

# Column names applied to both snapshot-time tables read below.
snapshot_columns = [
    "index",
    "scale_factor",
    "redshift",
    "time_Gyr",
    "lookback_time_Gyr",
    "time_width_Myr",
]

# --- Snapshot-time table stored inside the m12i simulation directory ---
# NOTE(review): this path is pinned to "m12i" instead of being built from
# `sim`; confirm that is intentional before pointing `sim` at another run.
all_data = f"{sim_dir}m12i/m12i_res7100/snapshot_times.txt"
all_snaps = pd.read_table(all_data, sep=r"\s+", header=None, comment="#")
all_snaps.columns = list(snapshot_columns)

# --- Snapshot-time table read from snapshot_times_public.txt ---
# sim_dir already ends with "/", so the resulting path contains "//".
pub_data = f"{sim_dir}/snapshot_times_public.txt"
pub_snaps = pd.read_table(pub_data, sep=r"\s+", header=None, comment="#")
pub_snaps.columns = list(snapshot_columns)

# Snapshot numbers and times of the rows in the public table, in file order.
snp_lst = pub_snaps["index"].values
tim_lst = pub_snaps["time_Gyr"].values

# --- Simulation codes (JSON) ---
sim_code_file = f"{sim_dir}simulation_codes.json"
with open(sim_code_file, "r") as file:
    sim_codes = json.load(file)

# --- Processed data file ---
# Opened read-only and left open: later cells read from proc_data directly.
proc_file = f"{sim_dir}{sim}/{sim}_processed.hdf5"
proc_data = h5py.File(proc_file, "r")

In [321]:
# Load the halo tree for `sim`. Later cells index it by the keys "tid",
# "snapshot", "descendant.index", "position" and "velocity".
halt = gc_utils.get_halo_tree(sim, sim_dir)

Retrieving Halo Tree.....................: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.28s/it]


In [322]:
# Maps dataset names in the processed file's per-iteration "source" group
# (keys) to the shorter names under which they are stored in
# data_dict[<iteration>]["source"] (values).
src_vars = {
    "gc_id": "gcid",
    "group_id": "grpid",
    "feh": "feh",
    "form_time": "tfor",
    "logm_tform": "logm_tfor",
    "logm_z0": "logm_tz0",
    "ptype": "ptype",
    "pubsnap_zform": "snap_tforp",
    "t_acc": "tacc",
    "t_dis": "tdis",
    "survive_flag": "s_flag",
    "survived_accretion": "sa_flag",
    "snap_acc": "snap_tacc",
    "halo_zform": "halo_tfor",
    "snap_zform": "snap_tfor",
}

# Maps dataset names in each per-snapshot group of the processed file (keys)
# to the names used in data_dict[<iteration>]["snapshots"][<snap_id>] (values).
# A dataset that is absent from a given snapshot group is skipped when copying.
snp_vars = {
    "gc_id": "gcid",
    "group_id": "grpid",
    "ptype": "ptype",
    "now_accreted": "nacc_flag",
    "ecc": "ecc",
    "ek": "ek",
    "ep_agama": "ep",
    "et": "et",
    "inc": "inc",
    "lz_norm": "circ",
    "mass": "logm",
    "vel.sph": "host.vel.sph",
    "pos.sph": "host.pos.sph",
    "survive_flag": "s_flag",
    "survived_accretion": "sa_flag",
    "snap_part_idx": "pidx",
    "bound_flag": "bnd_flag",
}

In [323]:
def get_halo(halt, halo_tid, snap, tid_to_index, cache):
    """Find the descendant of a halo at (or just after) a target snapshot.

    Starting from the tree entry whose tree id is ``halo_tid``, the
    ``halt["descendant.index"]`` pointers are followed until the current
    entry's snapshot is no longer earlier than ``snap``.

    Parameters
    ----------
    halt : mapping of array-likes
        Halo tree; the keys "tid", "snapshot" and "descendant.index" are read.
    halo_tid : hashable
        Tree id of the starting halo; must be a key of ``tid_to_index``.
    snap : int
        Target snapshot number.
    tid_to_index : dict
        Maps tree id -> index into the ``halt`` arrays.
    cache : dict
        Memo keyed by ``(halo_tid, snap)``; read and updated in place.

    Returns
    -------
    tuple
        ``(tid, tidx)``: tree id and tree index of the halo that was reached.
        If the descendant chain ends before ``snap``, the last halo on the
        chain is returned (its snapshot is then earlier than ``snap``).

    Raises
    ------
    KeyError
        If ``halo_tid`` is not present in ``tid_to_index``.
    """
    key = (halo_tid, snap)

    # Return the memoised answer if this (halo, snapshot) pair was seen before.
    if key in cache:
        return cache[key]

    tidx = tid_to_index[halo_tid]

    # Walk forward in time along the descendant pointers.
    while halt["snapshot"][tidx] < snap:
        next_tidx = halt["descendant.index"][tidx]
        # A negative pointer is treated as "no descendant" (presumably the
        # tree's null value -- confirm). Following it would index from the end
        # of the arrays, returning an unrelated halo or looping forever.
        if next_tidx < 0:
            break
        tidx = next_tidx

    result = (halt["tid"][tidx], tidx)
    cache[key] = result  # store result for reuse
    return result

In [None]:
data_dict = {}

# The halo tree does not change between iterations, so the tid -> index
# lookup and the get_halo memo are built once and shared by every iteration
# (get_halo's result depends only on the tree, the halo tid and the snapshot).
tid_to_index = {tid: idx for idx, tid in enumerate(halt["tid"])}
halo_cache = {}

# for itid in proc_data.keys():
for it in range(0, 5):
    it_id = gc_utils.iteration_name(it)
    it_dict = {"source": {}, "snapshots": {}}

    # --- Source information ---
    # Copy each source dataset, keeping only entries with analyse_flag == 1.
    src_dat = proc_data[it_id]["source"]
    ana_msk = src_dat["analyse_flag"][()] == 1
    for var, key in src_vars.items():
        it_dict["source"][key] = src_dat[var][ana_msk]

    # --- Per-GC lookups: gcid -> "tfor" value, gcid -> "halo_tfor" value ---
    src_gcid_tfor = dict(zip(it_dict["source"]["gcid"], it_dict["source"]["tfor"]))
    src_gcid_halo = dict(zip(it_dict["source"]["gcid"], it_dict["source"]["halo_tfor"]))

    # --- Main loop over snapshots ---
    for snap_id in proc_data[it_id]["snapshots"].keys():
        # The snapshot number is whatever follows the first four characters of the name.
        snap = int(snap_id[4:])
        # Label-based lookup; assumes the table's default integer row labels
        # coincide with snapshot numbers -- TODO confirm.
        tim = all_snaps["time_Gyr"][snap]

        snp_dat = proc_data[it_id]["snapshots"][snap_id]
        snap_out = {}
        it_dict["snapshots"][snap_id] = snap_out

        # Copy over the snapshot variables that exist in this snapshot group.
        for var, key in snp_vars.items():
            if var in snp_dat.keys():
                snap_out[key] = snp_dat[var][()]

        # Difference of the first and third tidal eigenvalue datasets.
        snap_out["tideig"] = snp_dat["tideig_1"][()] - snp_dat["tideig_3"][()]

        gcids = snap_out["gcid"]

        # Snapshot time minus each GC's "tfor" value.
        snap_out["age"] = np.array([tim - src_gcid_tfor[gcid] for gcid in gcids])

        # For each GC, follow the halo tree from its "halo_tfor" halo to this snapshot.
        hosts = [get_halo(halt, src_gcid_halo[gcid], snap, tid_to_index, halo_cache) for gcid in gcids]
        snap_out["halo_tid"] = np.array([tid for tid, _ in hosts])
        snap_out["halo_tidx"] = np.array([tidx for _, tidx in hosts])

    data_dict[it_id] = it_dict

In [None]:
for snap in snp_lst:
    snap_id = gc_utils.snapshot_name(snap)

    # Host rotation is switched off so that the halo tree and the particle
    # data are compared in the same frame.
    part = gc_utils.open_snapshot(snap, fir_dir, species=["dark", "star"], assign_hosts_rotation=False)

    # for itid in proc_data.keys():
    for it in range(0, 5):
        it_id = gc_utils.iteration_name(it)
        it_dict = data_dict[it_id]
        snap_out = it_dict["snapshots"][snap_id]

        gcids = snap_out["gcid"][()]
        snp_gcid_map = {gcid: idx for idx, gcid in enumerate(gcids)}

        halo_pos, halo_vel = [], []
        for gcid in gcids:
            idx = snp_gcid_map[gcid]
            pidx = snap_out["pidx"][idx]
            ptype = snap_out["ptype"][idx].decode("utf-8")
            halo_tidx = snap_out["halo_tidx"][idx]
            nacc = snap_out["nacc_flag"][idx]

            # Flag set: reuse the stored host-centric spherical coordinates.
            if nacc == 1:
                halo_pos.append(snap_out["host.pos.sph"][idx])
                halo_vel.append(snap_out["host.vel.sph"][idx])
                continue

            # Otherwise measure the particle relative to its tracked halo.
            gc_xyz = part[ptype]["position"][pidx]
            halo_xyz = halt["position"][halo_tidx]

            pos_xyz = ut.coordinate.get_distances(
                gc_xyz,
                halo_xyz,
                part.info["box.length"],
                part.snapshot["scalefactor"],
                False,
            )  # [kpc physical]
            halo_pos.append(ut.coordinate.get_positions_in_coordinate_system(pos_xyz, "cartesian", "spherical"))

            vel_xyz = ut.coordinate.get_velocity_differences(
                part[ptype]["velocity"][pidx],
                halt["velocity"][halo_tidx],
                gc_xyz,
                halo_xyz,
                part.info["box.length"],
                part.snapshot["scalefactor"],
                part.snapshot["time.hubble"],
                False,
            )
            halo_vel.append(
                ut.coordinate.get_velocities_in_coordinate_system(vel_xyz, pos_xyz, "cartesian", "spherical")
            )

        snap_out["halo.pos.sph"] = np.array(halo_pos)
        snap_out["halo.vel.sph"] = np.array(halo_vel)

    # Drop the particle data before the next snapshot is opened.
    del part

Retrieving Snapshot 20..................: 100%|█████████████████████████████████████████████████████████████████████████| 1/1 [00:13<00:00, 13.15s/it]
Retrieving Snapshot 23..................: 100%|█████████████████████████████████████████████████████████████████████████| 1/1 [00:12<00:00, 12.42s/it]
Retrieving Snapshot 26..................: 100%|█████████████████████████████████████████████████████████████████████████| 1/1 [00:12<00:00, 12.98s/it]
Retrieving Snapshot 29..................: 100%|█████████████████████████████████████████████████████████████████████████| 1/1 [00:12<00:00, 12.66s/it]
Retrieving Snapshot 33..................: 100%|█████████████████████████████████████████████████████████████████████████| 1/1 [00:13<00:00, 13.41s/it]
Retrieving Snapshot 37..................: 100%|█████████████████████████████████████████████████████████████████████████| 1/1 [00:12<00:00, 12.51s/it]
Retrieving Snapshot 41..................: 100%|███████████████████████████████████████████████

In [None]:
# Time of each public snapshot keyed by snapshot number, so that weights can
# be aligned with a snapshot by its number rather than by list position.
pub_time = {int(s): t for s, t in zip(snp_lst, tim_lst)}

# for itid in proc_data.keys():
for it in range(0, 5):
    it_id = gc_utils.iteration_name(it)
    it_dict = data_dict[it_id]

    # --- GC IDs and accretion data ---
    gcid_arr = np.array(it_dict["source"]["gcid"][()])
    tacc_arr = np.array(it_dict["source"]["snap_tacc"][()])
    n_gcs = len(gcid_arr)

    # --- Snapshot ordering ---
    # Sort snapshot IDs numerically by the number after 'snap'
    snap_ids = sorted(it_dict["snapshots"].keys(), key=lambda x: int(x.replace("snap", "")))
    snap_nums = [int(x.replace("snap", "")) for x in snap_ids]
    n_snaps = len(snap_ids)

    # Every column needs a weight; fail with a clear message instead of a
    # shape mismatch or silently misaligned columns.
    missing = [s for s in snap_nums if s not in pub_time]
    if missing:
        raise KeyError(f"{it_id}: snapshots {missing} have no entry in the public snapshot table")

    # --- Preallocate arrays: rows follow gcid_arr, columns follow snap_ids ---
    halo_r = np.full((n_gcs, n_snaps), np.nan)
    host_r = np.full((n_gcs, n_snaps), np.nan)
    tide_m = np.full((n_gcs, n_snaps), np.nan)

    # --- Fill arrays ---
    for j, snap_id in enumerate(snap_ids):
        snap_dat = it_dict["snapshots"][snap_id]
        idx_map = {g: i for i, g in enumerate(snap_dat["gcid"][()])}

        # Rows of gcid_arr that are present in this snapshot, and where they sit in it.
        valid = np.isin(gcid_arr, snap_dat["gcid"][()])
        idxs = [idx_map[g] for g in gcid_arr[valid]]

        halo_r[valid, j] = snap_dat["halo.pos.sph"][:, 0][idxs]
        host_r[valid, j] = snap_dat["host.pos.sph"][:, 0][idxs]
        tide_m[valid, j] = snap_dat["tideig"][idxs]

    # --- Masks and weights ---
    # True where a halo-centric radius was recorded for that GC and snapshot.
    valid_mask = ~np.isnan(halo_r)
    # NOTE(review): each column is weighted by the snapshot's time value itself,
    # not by the interval between snapshots -- confirm this is the intended weighting.
    snap_times = np.array([pub_time[s] for s in snap_nums], dtype=float)
    weights = np.broadcast_to(snap_times, halo_r.shape)

    # --- Helper: safe weighted average ---
    def weighted_avg(arr, mask):
        """Row-wise weighted mean of ``arr`` over the entries selected by ``mask``.

        Uses the enclosing ``weights`` array. NaN entries are ignored in the
        sums; rows with no selected entry (or zero total weight) yield -1.
        """
        num = np.nansum(arr * weights * mask, axis=1)
        den = np.nansum(weights * mask, axis=1)
        with np.errstate(invalid="ignore", divide="ignore"):
            avg = np.divide(num, den, out=np.full_like(num, np.nan), where=den > 0)
        return np.where(np.any(mask, axis=1) & (den > 0), avg, -1)

    # --- Birth radii (first valid snapshot) ---
    # Columns without data are given the out-of-range value n_snaps so argmin
    # lands on the first column that has data; rows with none are set to -1.
    first_valid_idx = np.where(valid_mask, np.arange(n_snaps), n_snaps).argmin(axis=1)
    has_valid = np.any(valid_mask, axis=1)
    rows = np.arange(n_gcs)
    birth_halo_radii = np.where(has_valid, halo_r[rows, first_valid_idx], -1)
    birth_host_radii = np.where(has_valid, host_r[rows, first_valid_idx], -1)

    # --- Pre/Post accretion masks ---
    # A column counts as pre-accretion when its snapshot number is below the GC's snap_tacc.
    acc_mask = np.array(snap_nums, dtype=int)[None, :] < tacc_arr[:, None]
    valid_pre = acc_mask & valid_mask
    valid_pos = (~acc_mask) & valid_mask

    # --- Assign back to source dict ---
    src = it_dict["source"]
    src["halo.r.birth"] = birth_halo_radii
    src["host.r.birth"] = birth_host_radii

    src["tideig.avg"] = weighted_avg(tide_m, valid_mask)
    src["halo.r.avg"] = weighted_avg(halo_r, valid_mask)
    src["host.r.avg"] = weighted_avg(host_r, valid_mask)

    src["halo.r.avg.pre"] = weighted_avg(halo_r, valid_pre)
    src["host.r.avg.pre"] = weighted_avg(host_r, valid_pre)
    src["host.r.avg.pos"] = weighted_avg(host_r, valid_pos)

    src["tideig.avg.pre"] = weighted_avg(tide_m, valid_pre)
    src["tideig.avg.pos"] = weighted_avg(tide_m, valid_pos)

In [485]:
# Inspect which per-GC source fields are now stored for iteration "it000".
data_dict["it000"]["source"].keys()

dict_keys(['gcid', 'grpid', 'feh', 'tfor', 'logm_tfor', 'logm_tz0', 'ptype', 'snap_tforp', 'tacc', 'tdis', 's_flag', 'sa_flag', 'snap_tacc', 'halo_tfor', 'snap_tfor', 'halo.r.birth', 'host.r.birth', 'tideig.avg', 'halo.r.avg', 'host.r.avg', 'halo.r.avg.pre', 'host.r.avg.pre', 'host.r.avg.pos', 'tideig.avg.pre', 'tideig.avg.pos'])