In [1]:
import sys
sys.path.append("../src")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import pickle
import torch
import cmasher as cmr

from sklearn.metrics import r2_score, mean_absolute_error, root_mean_squared_error, median_absolute_error
from easyquery import Query, QueryMaker

USE_FULL_TREE = True

tree_model = "tree_gnn" if USE_FULL_TREE else "bstree_gnn"

In [2]:
BASE_DIR = Path("../").resolve()
RESULTS_DIR = BASE_DIR / "results"

K_FOLDS = 3

In [3]:
if "results-mhalo-vmax" in str(RESULTS_DIR):
    columns = ['subhalo_id_DMO', 'subhalo_loghalomass_DMO', 'subhalo_logvmax_DMO']
else:
    columns = [
        'subhalo_id_DMO', 'subhalo_loghalomass_DMO', 'subhalo_logvmax_DMO', 
        'subhalo_spin_DMO', 'subhalo_Vdisp_DMO', 'subhalo_VmaxRad_DMO',
        'subhalo_MRvmax_DMO', 'subhalo_RMhalf_DMO',  # 'subhalo_MRhalf_DMO', 'subhalo_MR2half_DMO' <- not present for trees
        'subhalo_x_DMO', 'subhalo_y_DMO', 'subhalo_z_DMO', 
        'subhalo_vx_DMO', 'subhalo_vy_DMO', 'subhalo_vz_DMO'
    ]
subhalos = pd.read_parquet(RESULTS_DIR / "subhalos.parquet", columns=columns)
subhalos = subhalos.rename({c: c.strip("_DMO") for c in subhalos.columns}, axis=1).set_index("subhalo_id")

In [4]:
results = dict()


for model in ["mlp", tree_model, "env_gnn", "tree_residual_gnn"]:
    results[model] = pd.concat(
        [pd.read_parquet(f"{RESULTS_DIR}/predictions/{model}_fold_{k}.parquet") for k in range(K_FOLDS)],
        axis=0
    )

# include k folds
results['mlp']['k_fold'] = np.concatenate([
    np.full(pd.read_parquet(f"{RESULTS_DIR}/predictions/mlp_fold_{k}.parquet").shape[0], k) 
    for k in range(K_FOLDS)
])

results['mlp'] = results['mlp'].reset_index().join(subhalos, on="subhalo_id", how="inner").set_index("subhalo_id")

valid_subhalo_ids = list(set(results["env_gnn"].index).intersection(results[tree_model].index).intersection(results['mlp'].index))
len(valid_subhalo_ids)

122993

In [30]:
results_columns = ["log_Mstar_pred", "log_Mstar_true", "log_Mgas_pred", "log_Mgas_true"]

df = (
    results["mlp"].loc[valid_subhalo_ids]
    .reset_index()
    .join(
        results[tree_model].loc[valid_subhalo_ids][results_columns],
        rsuffix="_tree_gnn",
        on="subhalo_id",
    )
    .join(
        results["tree_residual_gnn"].loc[valid_subhalo_ids][results_columns],
        rsuffix="_tree_residual_gnn",
        on="subhalo_id",
    )
    .join(
        results["env_gnn"].loc[valid_subhalo_ids][results_columns],
        rsuffix="_env_gnn",
        on="subhalo_id",
    ).rename(
        {col: f"{col}_mlp" for col in results_columns},
        axis=1
    )
)

print(df.columns)

Index(['subhalo_id', 'log_Mstar_pred_mlp', 'log_Mstar_true_mlp',
       'log_Mgas_pred_mlp', 'log_Mgas_true_mlp', 'is_central', 'k_fold',
       'subhalo_loghalomass', 'subhalo_logvmax', 'subhalo_spin',
       'subhalo_Vdisp', 'subhalo_VmaxRad', 'subhalo_MRvmax', 'subhalo_RMhalf',
       'subhalo_x', 'subhalo_y', 'subhalo_z', 'subhalo_vx', 'subhalo_vy',
       'subhalo_vz', 'log_Mstar_pred_tree_gnn', 'log_Mstar_true_tree_gnn',
       'log_Mgas_pred_tree_gnn', 'log_Mgas_true_tree_gnn',
       'log_Mstar_pred_tree_residual_gnn', 'log_Mstar_true_tree_residual_gnn',
       'log_Mgas_pred_tree_residual_gnn', 'log_Mgas_true_tree_residual_gnn',
       'log_Mstar_pred_env_gnn', 'log_Mstar_true_env_gnn',
       'log_Mgas_pred_env_gnn', 'log_Mgas_true_env_gnn'],
      dtype='object')


In [32]:
# update residual predictions based on the env_gnn base model
for col in ["log_Mstar", "log_Mgas"]:
    df[f"{col}_pred_tree_residual_gnn"] += df[f"{col}_pred_env_gnn"]
    df[f"{col}_true_tree_residual_gnn"] = df[f"{col}_true_env_gnn"]

In [33]:
df.to_csv(RESULTS_DIR / "predictions/tng-results_compiled.csv", index=False)

In [34]:
Query("log_Mstar_true_mlp > 8.5").count(df)

27617

## Comparison figures

In [36]:
metrics_mapping = {
    r"$R^2$": lambda p, y: r2_score(y, p, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"RMSE": lambda p, y: root_mean_squared_error(p, y, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"MAE":lambda p, y: mean_absolute_error(p, y, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"NMAD": lambda p, y: 1.4826 * median_absolute_error(p, y, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"Bias": lambda p, y: np.average(p - y, weights=np.isfinite(y.values).nonzero()[0])
}

In [48]:
min_stellar_mass = 8.5

q_mlp = Query("is_central == 1", f"log_Mstar_true_mlp > {min_stellar_mass}", QueryMaker.isfinite("log_Mgas_true_mlp"))
q_tree_gnn = Query("is_central == True", f"log_Mstar_true_tree_gnn > {min_stellar_mass}", QueryMaker.isfinite("log_Mgas_true_tree_gnn"))
q_residual_gnn = Query("is_central == True", f"log_Mstar_true_tree_residual_gnn > {min_stellar_mass}", QueryMaker.isfinite("log_Mstar_true_tree_residual_gnn"), QueryMaker.isfinite("log_Mgas_true_env_gnn"))
q_env_gnn = Query("is_central == 1", f"log_Mstar_true_env_gnn > {min_stellar_mass}", QueryMaker.isfinite("log_Mgas_true_env_gnn"))

In [53]:
q_mlp.count(df), q_tree_gnn.count(df), q_residual_gnn.count(df), q_env_gnn.count(df)

(26465, 26465, 26465, 26465)

In [49]:
# avg & error -- weighted by number of samples
for target in ["log_Mstar", "log_Mgas"]:
    print("".join(["="]*10))
    print(f"{target}")
    print("".join(["="]*10))
    for model, q in zip(["mlp", "tree_gnn", "env_gnn", "tree_residual_gnn"], [q_mlp, q_tree_gnn, q_env_gnn, q_residual_gnn]):
        for metric, func in metrics_mapping.items():
            scores = []
            weights = []
            for k in range(3):
                qk = Query(f"k_fold == {k}")
                filtered = (qk & q).filter(df)
                scores.append(func(filtered[f"{target}_pred_{model}"], filtered[f"{target}_true_{model}"]))
                weights.append(len(filtered))   
            avg_weighted = np.average(scores, weights=weights)
            std_weighted = np.sqrt(np.cov(scores, aweights=weights))

            print(f"{model: >8s} {metric: >7s}: ${avg_weighted:.4f} \pm {std_weighted:.4f}$")

log_Mstar
     mlp   $R^2$: $0.9258 \pm 0.0071$
     mlp    RMSE: $0.2259 \pm 0.0083$
     mlp     MAE: $0.1622 \pm 0.0068$
     mlp    NMAD: $0.1740 \pm 0.0065$
     mlp    Bias: $-0.0786 \pm 0.0245$
tree_gnn   $R^2$: $0.9406 \pm 0.0048$
tree_gnn    RMSE: $0.2021 \pm 0.0069$
tree_gnn     MAE: $0.1506 \pm 0.0063$
tree_gnn    NMAD: $0.1698 \pm 0.0106$
tree_gnn    Bias: $-0.0788 \pm 0.0208$
 env_gnn   $R^2$: $0.9478 \pm 0.0011$
 env_gnn    RMSE: $0.1897 \pm 0.0009$
 env_gnn     MAE: $0.1397 \pm 0.0008$
 env_gnn    NMAD: $0.1556 \pm 0.0015$
 env_gnn    Bias: $-0.0507 \pm 0.0100$
tree_residual_gnn   $R^2$: $0.9466 \pm 0.0031$
tree_residual_gnn    RMSE: $0.1918 \pm 0.0037$
tree_residual_gnn     MAE: $0.1482 \pm 0.0033$
tree_residual_gnn    NMAD: $0.1774 \pm 0.0037$
tree_residual_gnn    Bias: $0.0640 \pm 0.0073$
log_Mgas
     mlp   $R^2$: $0.8488 \pm 0.0184$
     mlp    RMSE: $0.2268 \pm 0.0165$
     mlp     MAE: $0.1504 \pm 0.0159$
     mlp    NMAD: $0.1664 \pm 0.0268$
     mlp    Bias: $-0

In [50]:
# # overall means (weighted by complete sample)
# for target in ["log_Mstar", "log_Mgas"]:
#     print("".join(["="]*10))
#     print(f"{target}")
#     print("".join(["="]*10))
#     for model, q in zip(["mlp", "tree_gnn", "env_gnn"], [q_mlp, q_tree_gnn, q_env_gnn]):
#         for metric, func in metrics_mapping.items():
#             filtered = q.filter(df)
#             score = func(filtered[f"{target}_pred_{model}"], filtered[f"{target}_true_{model}"])
#             print(f"{model: >8s} {metric: >7s}: {score:.5f}")

# Pred vs True comparisons

In [45]:
fig, axes = plt.subplots(2, 4, figsize=(14.5, 8), dpi=300, sharex=False, sharey=True)

target_captions = {
    "log_Mstar": r"log($M_\bigstar/M_\odot$)",
    "log_Mgas": r"log($M_{\rm gas}/M_\odot$)",
}

model_captions = {
    "mlp": "Subhalo only: MLP",
    "tree_gnn": "Merger tree: GNN",
    "env_gnn": "Environment: GNN",
    "tree_residual_gnn": "Residual Tree + Env: GNN"
}

for target, ax_row, cmap in zip(target_captions.keys(), axes, [cmr.torch_r, cmr.voltage_r]):
    for model, q, ax in zip(model_captions.keys(), [q_mlp, q_tree_gnn, q_env_gnn, q_residual_gnn], ax_row.flat):
        p = q.filter(df)[f"{target}_pred_{model}"]
        y = q.filter(df)[f"{target}_true_{model}"]
        
        # ax.scatter(y, p, s=5, alpha=0.15, color=cmap(0.7), linewidths=0, edgecolors="none", rasterized=True)
        ax.hist2d(y, p, bins=[70, 70], range=[(8,12), (8,12)], cmap=cmap, rasterized=True)

        ax.plot([8, 12], [8, 12], ls='-', c='w', lw=0.9)
        ax.plot([8, 12], [8, 12], ls='-', c='0.5', lw=0.1)

        ax.text(0.03, 0.9, f"{model_captions[model]}", fontsize=16, ha="left", transform=ax.transAxes)
        ax.text(0.03, 0.8, f"{target_captions[target]}", fontsize=16, ha="left", transform=ax.transAxes)

        ax.set_xlim(8.5, 12)
        ax.set_ylim(8.5, 12)
        ax.set_xticks([9, 10, 11, 12], [9, 10, 11, 12])
        ax.set_yticks([9, 10, 11, 12], [9, 10, 11, 12])
        
        ax.set_aspect("equal")
        ax.grid(alpha=0.15)
        
        for z, [metric, func] in enumerate(metrics_mapping.items()):
            score = func(p, y)
            ax.text(0.97, 0.08*(0.5 + z), f"{metric: >7s} = {score:.3f}", fontsize=12, ha="right", transform=ax.transAxes)
        
        
fig.subplots_adjust(wspace=0.075, hspace=0.2, left=0.075, right=0.975, top=0.975, bottom=0.075)

axes[0][0].set_ylabel(r"Pred log($M_\bigstar/M_\odot$)", fontsize=12)
axes[1][0].set_ylabel(r"Pred log($M_{\rm gas}/M_\odot$)", fontsize=12)

axes[0][0].set_xlabel(r"True log($M_\bigstar/M_\odot$)", fontsize=12)
axes[0][1].set_xlabel(r"True log($M_\bigstar/M_\odot$)", fontsize=12)
axes[0][2].set_xlabel(r"True log($M_\bigstar/M_\odot$)", fontsize=12)

axes[1][0].set_xlabel(r"True log($M_{\rm gas}/M_\odot$)", fontsize=12)
axes[1][1].set_xlabel(r"True log($M_{\rm gas}/M_\odot$)", fontsize=12)
axes[1][2].set_xlabel(r"True log($M_{\rm gas}/M_\odot$)", fontsize=12)

plt.savefig(RESULTS_DIR / "figures/results_pred-vs-true.pdf")

plt.clf()

<Figure size 4350x2400 with 0 Axes>

# Training curves

In [14]:
fig, axes = plt.subplots(1, 3, figsize=(12, 4), dpi=300, sharex=False, sharey=True)


for model, ax in zip(model_captions.keys(), axes):
    for k, c in enumerate(cmr.torch_r([0.3, 0.5, 0.7])):
        log_file = RESULTS_DIR / f"logs/{model}_fold_{k}.txt"
        losses = pd.read_csv(log_file)
        
        ax.plot(losses["valid_RMSE"], c=c)
    
    
    ax.text(0.97, 0.9, f"{model_captions[model]}", fontsize=16, ha="right", transform=ax.transAxes)

    ax.set_ylim(0.19, 0.41)
    ax.set_yticks([0.2, 0.25, 0.30, 0.35, 0.40])

    ax.grid(alpha=0.15)
        
fig.subplots_adjust(wspace=0.075, hspace=0.2, left=0.075, right=0.975, top=0.975, bottom=0.125)

axes[0].text(0.05, 0.20, f"Fold 1", color=cmr.torch_r(0.3), fontsize=16, ha="left", transform=axes[0].transAxes)
axes[0].text(0.05, 0.12, f"Fold 2", color=cmr.torch_r(0.5), fontsize=16, ha="left", transform=axes[0].transAxes)
axes[0].text(0.05, 0.04, f"Fold 3", color=cmr.torch_r(0.7), fontsize=16, ha="left", transform=axes[0].transAxes)

axes[0].set_ylabel(r"Avg Validation RMSE [dex]", fontsize=12)
axes[0].set_xlabel(r"Epoch", fontsize=12)
axes[1].set_xlabel(r"Epoch", fontsize=12)
axes[2].set_xlabel(r"Epoch", fontsize=12)

plt.savefig(RESULTS_DIR / "figures/results_validation-rmse-curves.pdf")

plt.clf()

<Figure size 3600x1200 with 0 Axes>