In [16]:
import sys
sys.path.append("../src")
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import SubsetRandomSampler

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, NeighborLoader, RandomNodeLoader, ClusterData, ClusterLoader
from torch_geometric.nn import MessagePassing, SAGEConv
from torch_geometric.utils import index_to_mask, to_undirected, remove_self_loops
from torch_cluster import radius_graph
from torch_scatter import scatter_add

from data import *
from loader import *

import tqdm
from typing import Tuple, Optional

from model import EdgeInteractionGNN, EdgeInteractionLayer, MultiSAGENet, SAGEGraphConvNet

# Putting it all together

In [17]:
def get_spatial_train_valid_indices(data, k: int, K: int = 3, boxsize: float = 75/0.6774, 
                                   pad: float = 3, indices_mask: Optional[torch.Tensor]=None):
    """Create spatial train/validation indices using z-coordinate splits.
    
    This creates spatially separated train/validation sets by dividing the simulation
    box along the z-axis. It correctly handles periodic boundaries and optional
    pre-filtering to create a true partition of the data.
    
    Returns:
        Tuple of (train_indices, valid_indices) as GLOBAL torch tensors, valid
        for indexing the original `data` object.
    """

    z_coords = data.pos[:, 2]

    valid_start = (k / K * boxsize)
    valid_end = ((k + 1) / K * boxsize)
    
    if k == K - 1:
        spatial_valid_mask = (z_coords >= valid_start) | (z_coords < (valid_end % boxsize))
    else:
        spatial_valid_mask = (z_coords >= valid_start) & (z_coords < valid_end)
    
    train_start = ((k + 1) / K * boxsize + pad) % boxsize
    train_end = (k / K * boxsize - pad) % boxsize
    
    if train_start > train_end:
        spatial_train_mask = (z_coords >= train_start) | (z_coords <= train_end)
    else:
        spatial_train_mask = (z_coords >= train_start) & (z_coords <= train_end)

    if indices_mask is None:
        final_mask = torch.ones_like(z_coords, dtype=torch.bool)
    else:
        final_mask = indices_mask

    # Use logical AND to get the final masks for training and validation
    final_valid_mask = spatial_valid_mask & final_mask
    final_train_mask = spatial_train_mask & final_mask

    valid_indices = final_valid_mask.nonzero(as_tuple=True)[0]
    train_indices = final_train_mask.nonzero(as_tuple=True)[0]

    # Double-check for overlap, which should now be impossible by construction
    overlap = set(train_indices.tolist()) & set(valid_indices.tolist())
    assert len(overlap) == 0, f"Found {len(overlap)} overlapping indices"

    print(f"Fold {k}/{K}: Train={len(train_indices)}, Valid={len(valid_indices)}")
    
    return train_indices, valid_indices


def gaussian_nll_loss(y_pred: torch.Tensor, y_true: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Compute Gaussian negative log-likelihood loss *with masking out infinite values*.
    
    Args:
        y_pred: Model predictions
        y_true: Ground truth values  
        logvar: Log variance predictions
        
    Returns:
        Gaussian NLL loss
    """
    finite_mask = (y_true > 0.) & (y_true.isfinite())
    
    if not finite_mask.any():
        return torch.tensor(0.0, device=y_pred.device, requires_grad=True)
    

    y_pred_masked = y_pred[finite_mask]
    y_true_masked = y_true[finite_mask]
    mse_loss = F.mse_loss(y_pred_masked, y_true_masked)
    
    return 0.5 * (mse_loss / 10**logvar + logvar)


def compute_rmse(preds, targs):
    """lil helper func"""
    finite_mask = (targs > 0.) & (np.isfinite(targs))
    y_pred_masked = preds[finite_mask]
    y_true_masked = targs[finite_mask]
    return np.mean((y_pred_masked - y_true_masked)**2)**0.5


def configure_optimizer(model, lr, wd,):
    """Only apply weight decay to weights, but not to other
    parameters like biases or LayerNorm. Based on minGPT version.
    """

    decay, no_decay = set(), set()
    yes_wd_modules = (nn.Linear, )
    no_wd_modules = (nn.LayerNorm, )
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, yes_wd_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, no_wd_modules):
                no_decay.add(fpn)
    param_dict = {pn: p for pn, p in model.named_parameters()}

    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": wd},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.},
    ]

    optimizer = torch.optim.AdamW(
        optim_groups, 
        lr=lr, 
    )

    return optimizer
    

def train_epoch_env_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
    augment: bool = True
) -> float:
    """Train one epoch for GNN model.
    
    Args:
        dataloader: Data loader for training data (X, y tuples)
        model: Model to train
        optimizer: Optimizer
        device: Device to train on
        
    Returns:
        Average training loss for the epoch
    """
    model.train()
    loss_total = 0
    
    for data in (dataloader):
        if augment: # add random noise
            data_node_features_scatter = 3e-4 * torch.randn_like(data.x[:, :-1]) * torch.std(data.x[:, :-1], dim=0)
            data.x[:, :-1] += data_node_features_scatter
            assert not torch.isnan(data.x).any() 
            
            if hasattr(data, "edge_attr") and data.edge_attr is not None:
                data_edge_features_scatter = 3e-4 * torch.randn_like(data.edge_attr) * torch.std(data.edge_attr, dim=0)            
                data.edge_attr += data_edge_features_scatter
                assert not torch.isnan(data.edge_attr).any() 

        data.to(device)
        
        optimizer.zero_grad()
        output = model(data)

        y_pred, logvar_pred = output.chunk(2, dim=1)
        
        assert not torch.isnan(y_pred).any() and not torch.isnan(logvar_pred).any()
        
        y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
        logvar_pred = logvar_pred.mean()
        
        loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        loss_total += loss.item()
        
    return loss_total / len(dataloader)


def validate_env_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    device: str
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Validate GNN model.
    
    Args:
        dataloader: Validation data loader
        model: Model to validate
        device: Device to validate on
        
    Returns:
        Tuple of (loss, predictions, targets)
    """
    model.eval()
    loss_total = 0
    y_preds = []
    y_trues = []

    # kinda janky but otherwise this seems impossible to track
    subhalo_ids = []
    is_central = []
    
    for data in dataloader:
        with torch.no_grad():
            data.to(device)
            
            output = model(data)
            y_pred, logvar_pred = output.chunk(2, dim=1)
            
            y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
            logvar_pred = logvar_pred.mean()
            
            loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)
            loss_total += loss.item()
            
            y_preds.append(y_pred.detach().cpu().numpy())
            y_trues.append(data.y.detach().cpu().numpy())
            subhalo_ids.append(data.subhalo_id.detach().cpu().numpy())
            is_central.append(data.is_central.detach().cpu().numpy())
    
    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)
    subhalo_ids = np.concatenate(subhalo_ids, axis=0)
    is_central = np.concatenate(is_central, axis=0).flatten()
    
    return loss_total / len(dataloader), y_preds, y_trues, subhalo_ids, is_central

In [22]:
ENV_DISTANCE_SCALES = [0.6, 1.0, 1.4, 1.8, 2.3]

K_FOLDS = 3
USE_LOOPS = False
NUM_PARTS = 48

N_EPOCHS = 500
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-4

N_LAYERS = 1
N_HIDDEN = 64
N_LATENT = 16
N_UNSHARED_LAYERS = 16
AGGR_FUNC = "multi"

device = "cuda"


In [None]:
# data loading & determine split
subhalos = pd.read_parquet("../results/subhalos.parquet")

for D_LINK in ENV_DISTANCE_SCALES:
    print(f"Working on distance scale {D_LINK}")
    
    env_data_fname = f"../results/env/cosmic_graphs_{D_LINK}Mpc.pkl"
    if not os.path.exists(env_data_fname):
        env_data = make_cosmic_graph(subhalos, D_LINK) 
        with open(env_data_fname, "wb") as f:
            pickle.dump(env_data, f)
    else:
        with open(env_data_fname, "rb") as f:
            env_data = pickle.load(f)
    
    train_valid_split = [
        get_spatial_train_valid_indices(env_data, k=k, K=K_FOLDS)
        for k in range(K_FOLDS)
    ]
    
    assert sum(len(v) for t, v in train_valid_split) == env_data.y.shape[0]
    
    # remove self-loops
    if not USE_LOOPS:
        env_data.edge_index, env_data.edge_attr = remove_self_loops(env_data.edge_index, env_data.edge_attr)
    
    # keep these variables for later use
    is_central = env_data.is_central
    subhalo_ids = env_data.subhalo_id
    
    # create a global mask but we'll use it later! -- note somewhat janky implementation for each fold right outside the training loop
    isfinite_mask = np.logical_and(
        np.isfinite(env_data.x).all(axis=1),
        np.isfinite(env_data.y).any(axis=1)
    ).type(torch.bool)
    
    
    # dynamically determine num features
    node_features = env_data.x.shape[1]
    edge_features = env_data.edge_attr.shape[1]
    out_features = env_data.y.shape[1]
    
    for k in range(K_FOLDS):
        log_file = f"../results/env/logs/env_gnn_{D_LINK}Mpc_fold_{k}.txt"
        
        with open(log_file, "a") as f:
            f.write(f"epoch,train_loss,valid_loss,valid_RMSE\n")
    
        # train-valid split & dataloaders
        train_indices, valid_indices = train_valid_split[k]
        
        train_data = ClusterData(
            env_data.subgraph(train_indices), 
            num_parts=NUM_PARTS, 
            recursive=False,
            log=False
        )
        train_loader = ClusterLoader(
            train_data,
            shuffle=True,
            batch_size=1,
        )
        
        valid_data = ClusterData(
            env_data.subgraph(valid_indices), 
            num_parts=NUM_PARTS // 2, 
            recursive=False,
            log=False
        )
        valid_loader = ClusterLoader(
            valid_data,
            shuffle=False, 
            batch_size=1,
        )
    
        model = EdgeInteractionGNN(
            node_features=node_features,
            edge_features=edge_features, 
            n_layers=N_LAYERS, 
            hidden_channels=N_HIDDEN,
            latent_channels=N_LATENT,
            n_unshared_layers=N_UNSHARED_LAYERS,
            n_out=out_features,
            aggr=(["sum", "max", "mean"] if AGGR_FUNC == "multi" else AGGR_FUNC)
        )
        model.to(device);
    
        optimizer = configure_optimizer(model, LEARNING_RATE, WEIGHT_DECAY)
    
        train_losses = []
        valid_losses = []
        valid_rmses = []
        
        epoch_pbar = tqdm.tqdm(range(N_EPOCHS), desc=f"Fold {k} Training", leave=True)
        for epoch in epoch_pbar:
            if epoch == int(N_EPOCHS * 0.25):
                optimizer = configure_optimizer(model, LEARNING_RATE/5, WEIGHT_DECAY)
            elif epoch == (N_EPOCHS * 0.5):
                optimizer = configure_optimizer(model, LEARNING_RATE/25, WEIGHT_DECAY)
            elif epoch == (N_EPOCHS * 0.75):
                optimizer = configure_optimizer(model, LEARNING_RATE/125, WEIGHT_DECAY)
                
            train_loss = train_epoch_env_gnn(train_loader, model, optimizer, device="cuda")
            valid_loss, preds, targs, subhalo_ids, is_central = validate_env_gnn(valid_loader, model, device="cuda")
        
            valid_rmse = compute_rmse(preds, targs)
            with open(log_file, "a") as f:
                f.write(f"{epoch:d},{train_loss:.6f},{valid_loss:.6f},{valid_rmse:.6f}\n")
                
            train_losses.append(train_loss)
            valid_losses.append(valid_loss)
            valid_rmses.append(valid_rmse)
        
            epoch_pbar.set_postfix({'valid_rmse': f'{valid_rmse:.4f}'})
        
        # save predictions
        results_file = f"../results/env/predictions/env_gnn_{D_LINK}Mpc_fold_{k}.parquet"
    
        results_df = pd.DataFrame({
            "subhalo_id": subhalo_ids,
            "log_Mstar_pred": preds[:, 0],
            "log_Mstar_true": targs[:, 0],
            "log_Mgas_pred": preds[:, 1],
            "log_Mgas_true": targs[:, 1],
            "is_central": is_central,
        }).set_index("subhalo_id")
    
        results_df.to_parquet(results_file)
    
        # save model weights
        model_file = f"../results/env/models/env_gnn_{D_LINK}Mpc_fold_{k}.pth"
        torch.save(model.state_dict(), model_file)

Working on distance scale 0.6
Fold 0/3: Train=84952, Valid=39685
Fold 1/3: Train=75945, Valid=49913
Fold 2/3: Train=81916, Valid=43355


Fold 0 Training:   0%|                                                     | 0/500 [00:00<?, ?it/s]

# Results from varying environment

In [34]:
results_dict = {}
for D_LINK in [0.6, 1.0, 1.4, 1.8, 2.3]:
    
    results_dict[D_LINK] = pd.concat(
        [pd.read_parquet(f"../results/env/predictions/env_gnn_{D_LINK}Mpc_fold_{k}.parquet") for k in range(K_FOLDS)],
        axis=0
    )
    # include k folds
    results_dict[D_LINK]['k_fold'] = np.concatenate([
        np.full(pd.read_parquet(f"../results/env/predictions/env_gnn_{D_LINK}Mpc_fold_{k}.parquet").shape[0], k) 
        for k in range(K_FOLDS)
    ])


In [35]:

from sklearn.metrics import r2_score, mean_absolute_error, root_mean_squared_error, median_absolute_error
from easyquery import Query, QueryMaker

In [84]:
metrics_mapping = {
    r"$R^2$": lambda p, y: r2_score(y, p, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"RMSE": lambda p, y: root_mean_squared_error(p, y, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"MAE":lambda p, y: mean_absolute_error(p, y, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"NMAD": lambda p, y: 1.4826 * median_absolute_error(p, y, sample_weight=np.isfinite(y.values).nonzero()[0]),
    r"Bias": lambda p, y: np.average(p - y, weights=np.isfinite(y.values).nonzero()[0])
}

min_stellar_mass = 8.5
q_env_gnn = Query("is_central == 1", f"log_Mstar_true > {min_stellar_mass}", QueryMaker.isfinite("log_Mgas_true"))
# q_env_gnn = Query(f"log_Mstar_true > {min_stellar_mass}", QueryMaker.isfinite("log_Mgas_true"))

In [85]:
# avg & error -- weighted by number of samples
for target in ["log_Mstar", "log_Mgas"]:
    print("".join(["="]*10))
    print(f"{target}")
    print("".join(["="]*10))
    q = q_env_gnn

    for metric, func in metrics_mapping.items():
        for D_LINK in [0.6, 1.0, 1.4, 1.8, 2.3]:
            df = results_dict[D_LINK]
            scores = []
            weights = []
            for k in range(3):
                qk = Query(f"k_fold == {k}")
                filtered = (qk & q).filter(df)
                scores.append(func(filtered[f"{target}_pred"], filtered[f"{target}_true"]))
                weights.append(len(filtered))   
            avg_weighted = np.average(scores, weights=weights)
            std_weighted = np.sqrt(np.cov(scores, aweights=weights))

            print(f"{D_LINK: 0.1f} Mpc {metric: >7s}: ${avg_weighted:.4f} \pm {std_weighted:.4f}$")

log_Mstar
 0.6 Mpc   $R^2$: $0.9188 \pm 0.0059$
 1.0 Mpc   $R^2$: $0.9210 \pm 0.0049$
 1.4 Mpc   $R^2$: $0.9175 \pm 0.0060$
 1.8 Mpc   $R^2$: $0.9232 \pm 0.0043$
 2.3 Mpc   $R^2$: $0.9214 \pm 0.0024$
 0.6 Mpc    RMSE: $0.1439 \pm 0.0065$
 1.0 Mpc    RMSE: $0.1409 \pm 0.0087$
 1.4 Mpc    RMSE: $0.1435 \pm 0.0060$
 1.8 Mpc    RMSE: $0.1389 \pm 0.0054$
 2.3 Mpc    RMSE: $0.1410 \pm 0.0065$
 0.6 Mpc     MAE: $0.1067 \pm 0.0031$
 1.0 Mpc     MAE: $0.1053 \pm 0.0055$
 1.4 Mpc     MAE: $0.1057 \pm 0.0056$
 1.8 Mpc     MAE: $0.1035 \pm 0.0035$
 2.3 Mpc     MAE: $0.1053 \pm 0.0049$
 0.6 Mpc    NMAD: $0.1236 \pm 0.0032$
 1.0 Mpc    NMAD: $0.1234 \pm 0.0050$
 1.4 Mpc    NMAD: $0.1209 \pm 0.0100$
 1.8 Mpc    NMAD: $0.1210 \pm 0.0057$
 2.3 Mpc    NMAD: $0.1225 \pm 0.0058$
 0.6 Mpc    Bias: $-0.0277 \pm 0.0094$
 1.0 Mpc    Bias: $-0.0285 \pm 0.0131$
 1.4 Mpc    Bias: $-0.0311 \pm 0.0050$
 1.8 Mpc    Bias: $-0.0263 \pm 0.0054$
 2.3 Mpc    Bias: $-0.0224 \pm 0.0185$
log_Mgas
 0.6 Mpc   $R^2$: $0.8043 