In [1]:
import sys
sys.path.append("../src")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import pickle
import torch

In [2]:
BASE_DIR = Path("../").resolve()
RESULTS_DIR = BASE_DIR / "results"

K_FOLDS = 3

In [3]:
results = dict()

for model in ["mlp", "tree_gnn", "env_gnn"]:
    results[model] = pd.concat(
        [pd.read_parquet(f"{RESULTS_DIR}/predictions/{model}_fold_{k}.parquet") for k in range(K_FOLDS)],
        axis=0
    )


valid_subhalo_ids = list(set(results["env_gnn"].index).intersection(results["tree_gnn"].index))
len(valid_subhalo_ids)

123001

In [4]:
results_columns = ["log_Mstar_pred", "log_Mstar_true", "log_Mgas_pred", "log_Mgas_true"]

df = (
    results["mlp"].loc[valid_subhalo_ids]
    .reset_index()
    .join(
        results["tree_gnn"].loc[valid_subhalo_ids][results_columns],
        rsuffix="_tree_gnn",
        on="subhalo_id",
    ).join(
        results["env_gnn"].loc[valid_subhalo_ids][results_columns],
        rsuffix="_env_gnn",
        on="subhalo_id",
    ).rename(
        {col: f"{col}_mlp" for col in results_columns},
        axis=1
    )
)

print(df.columns)

Index(['subhalo_id', 'log_Mstar_pred_mlp', 'log_Mstar_true_mlp',
       'log_Mgas_pred_mlp', 'log_Mgas_true_mlp', 'is_central',
       'log_Mstar_pred_tree_gnn', 'log_Mstar_true_tree_gnn',
       'log_Mgas_pred_tree_gnn', 'log_Mgas_true_tree_gnn',
       'log_Mstar_pred_env_gnn', 'log_Mstar_true_env_gnn',
       'log_Mgas_pred_env_gnn', 'log_Mgas_true_env_gnn'],
      dtype='object')


## Comparison figures

In [5]:
from sklearn.metrics import r2_score, mean_absolute_error, root_mean_squared_error, median_absolute_error

In [6]:
metrics_mapping = {
    r"$R^2$": lambda p, y: r2_score(y, p, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"RMSE": lambda p, y: root_mean_squared_error(p, y, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"MAE":lambda p, y: mean_absolute_error(p, y, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"NMAD": lambda p, y: 1.4826 * median_absolute_error(p, y, sample_weight=np.isfinite(y.values).nonzero()[0]),
}

In [7]:
from easyquery import Query, QueryMaker

In [8]:
min_stellar_mass = 8.5

q_mlp = Query("is_central == 1", f"log_Mstar_true_mlp > {min_stellar_mass}", QueryMaker.isfinite("log_Mgas_true_mlp"))
q_tree_gnn = Query("is_central == True", f"log_Mstar_true_tree_gnn > {min_stellar_mass}", QueryMaker.isfinite("log_Mgas_true_tree_gnn"))
q_env_gnn = Query("is_central == 1", f"log_Mstar_true_env_gnn > {min_stellar_mass}", QueryMaker.isfinite("log_Mgas_true_env_gnn"))

In [9]:
for target in ["log_Mstar", "log_Mgas"]:
    print("".join(["="]*10))
    print(f"{target}")
    print("".join(["="]*10))
    for model, q in zip(["mlp", "tree_gnn", "env_gnn"], [q_mlp, q_tree_gnn, q_env_gnn]):
        for metric, func in metrics_mapping.items():
            filtered = q.filter(df)
            score = func(filtered[f"{target}_pred_{model}"], filtered[f"{target}_true_{model}"])
            print(f"{model: >8s} {metric: >7s}: {score:.5f}")

log_Mstar
     mlp   $R^2$: 0.88877
     mlp    RMSE: 0.25452
     mlp     MAE: 0.18126
     mlp    NMAD: 0.19503
tree_gnn   $R^2$: 0.93452
tree_gnn    RMSE: 0.19528
tree_gnn     MAE: 0.14125
tree_gnn    NMAD: 0.16165
 env_gnn   $R^2$: 0.92009
 env_gnn    RMSE: 0.21573
 env_gnn     MAE: 0.15263
 env_gnn    NMAD: 0.15982
log_Mgas
     mlp   $R^2$: 0.77372
     mlp    RMSE: 0.27334
     mlp     MAE: 0.17710
     mlp    NMAD: 0.17751
tree_gnn   $R^2$: 0.79492
tree_gnn    RMSE: 0.26022
tree_gnn     MAE: 0.17205
tree_gnn    NMAD: 0.17773
 env_gnn   $R^2$: 0.85828
 env_gnn    RMSE: 0.21632
 env_gnn     MAE: 0.13518
 env_gnn    NMAD: 0.12437


# Pred vs True comparisons

In [10]:
import cmasher as cmr

In [11]:
fig, axes = plt.subplots(2, 3, figsize=(11.5, 8), dpi=300, sharex=False, sharey=True)

target_captions = {
    "log_Mstar": r"log($M_\bigstar/M_\odot$)",
    "log_Mgas": r"log($M_{\rm gas}/M_\odot$)",
}

model_captions = {
    "mlp": "Subhalo only: MLP",
    "tree_gnn": "Merger tree: GNN",
    "env_gnn": "Environment: GNN"
}

for target, ax_row, cmap in zip(target_captions.keys(), axes, [cmr.torch_r, cmr.voltage_r]):
    for model, q, ax in zip(model_captions.keys(), [q_mlp, q_tree_gnn, q_env_gnn], ax_row.flat):
        p = q.filter(df)[f"{target}_pred_{model}"]
        y = q.filter(df)[f"{target}_true_{model}"]
        
        # ax.scatter(y, p, s=5, alpha=0.15, color=cmap(0.7), linewidths=0, edgecolors="none", rasterized=True)
        ax.hist2d(y, p, bins=[70, 70], range=[(8,12), (8,12)], cmap=cmap, rasterized=True)

        ax.plot([8, 12], [8, 12], ls='-', c='w', lw=0.9)
        ax.plot([8, 12], [8, 12], ls='-', c='0.5', lw=0.1)

        ax.text(0.03, 0.9, f"{model_captions[model]}", fontsize=16, ha="left", transform=ax.transAxes)
        ax.text(0.03, 0.8, f"{target_captions[target]}", fontsize=16, ha="left", transform=ax.transAxes)

        ax.set_xlim(8.5, 12)
        ax.set_ylim(8.5, 12)
        ax.set_xticks([9, 10, 11, 12], [9, 10, 11, 12])
        ax.set_yticks([9, 10, 11, 12], [9, 10, 11, 12])
        
        ax.set_aspect("equal")
        ax.grid(alpha=0.15)
        
        for z, [metric, func] in enumerate(metrics_mapping.items()):
            score = func(p, y)
            ax.text(0.97, 0.08*(0.5 + z), f"{metric: >7s} = {score:.3f}", fontsize=12, ha="right", transform=ax.transAxes)
        
        
fig.subplots_adjust(wspace=0.075, hspace=0.2, left=0.075, right=0.975, top=0.975, bottom=0.075)

axes[0][0].set_ylabel(r"Pred log($M_\bigstar/M_\odot$)", fontsize=12)
axes[1][0].set_ylabel(r"Pred log($M_{\rm gas}/M_\odot$)", fontsize=12)

axes[0][0].set_xlabel(r"True log($M_\bigstar/M_\odot$)", fontsize=12)
axes[0][1].set_xlabel(r"True log($M_\bigstar/M_\odot$)", fontsize=12)
axes[0][2].set_xlabel(r"True log($M_\bigstar/M_\odot$)", fontsize=12)

axes[1][0].set_xlabel(r"True log($M_{\rm gas}/M_\odot$)", fontsize=12)
axes[1][1].set_xlabel(r"True log($M_{\rm gas}/M_\odot$)", fontsize=12)
axes[1][2].set_xlabel(r"True log($M_{\rm gas}/M_\odot$)", fontsize=12)

plt.savefig(RESULTS_DIR / "figures/results_pred-vs-true.pdf")

plt.clf()

<Figure size 3450x2400 with 0 Axes>

# Training curves

In [12]:
fig, axes = plt.subplots(1, 3, figsize=(12, 4), dpi=300, sharex=False, sharey=True)


for model, ax in zip(model_captions.keys(), axes):
    for k, c in enumerate(cmr.torch_r([0.3, 0.5, 0.7])):
        log_file = RESULTS_DIR / f"logs/{model}_fold_{k}.txt"
        losses = pd.read_csv(log_file)
        
        ax.plot(losses["valid_RMSE"], c=c)
    
    
    ax.text(0.97, 0.9, f"{model_captions[model]}", fontsize=16, ha="right", transform=ax.transAxes)

    ax.set_ylim(0.27, 0.45)

    ax.grid(alpha=0.15)
        
fig.subplots_adjust(wspace=0.075, hspace=0.2, left=0.075, right=0.975, top=0.975, bottom=0.125)

axes[0].text(0.05, 0.20, f"Fold 1", color=cmr.torch_r(0.3), fontsize=16, ha="left", transform=axes[0].transAxes)
axes[0].text(0.05, 0.12, f"Fold 2", color=cmr.torch_r(0.5), fontsize=16, ha="left", transform=axes[0].transAxes)
axes[0].text(0.05, 0.04, f"Fold 3", color=cmr.torch_r(0.7), fontsize=16, ha="left", transform=axes[0].transAxes)

axes[0].set_ylabel(r"Avg Validation RMSE [dex]", fontsize=12)
axes[0].set_xlabel(r"Epoch", fontsize=12)
axes[1].set_xlabel(r"Epoch", fontsize=12)
axes[2].set_xlabel(r"Epoch", fontsize=12)

plt.savefig(RESULTS_DIR / "figures/results_validation-rmse-curves.pdf")

plt.clf()

<Figure size 3600x1200 with 0 Axes>