In [1]:
import sys
sys.path.append("../src")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import SubsetRandomSampler

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader, DynamicBatchSampler
from torch_geometric.nn import MessagePassing, SAGEConv, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.utils import index_to_mask, to_undirected, remove_self_loops
from torch_cluster import radius_graph
from torch_scatter import scatter_add

import tqdm
from typing import Tuple, Optional, List

from loader import create_bonsai_stump_pairs
from model import EdgeInteractionGNN, EdgeInteractionLayer, MultiSAGENet, ModernSAGENet, SAGEEncoder, BonsaiStumpSAGENet, ModernBonsaiStumpSAGENet

from muon import SingleDeviceMuonWithAuxAdam


from pathlib import Path
# Repository root relative to the notebook's working directory, and the folder
# that holds pickled inputs plus every artefact written by this notebook.
BASE_DIR = Path("../")
RESULTS_DIR = BASE_DIR / "results"

# Prototype

## Load previous helper functions

In [2]:
def get_spatial_train_valid_indices(data, k: int, K: int = 3, boxsize: float = 75/0.6774,
                                   pad: float = 3, epsilon: float = 1e-10, indices_mask=None):
    """Create spatial train/validation indices using z-coordinate splits.

    The box is cut into K slabs along z. Fold ``k`` validates on the half-open
    slab ``[k/K * boxsize, (k+1)/K * boxsize)`` and trains on the region that is
    at least ``pad`` away from that slab (with periodic wrap-around). Using a
    half-open slab means a point lying exactly on a slab boundary belongs to
    exactly one validation fold instead of two.

    Args:
        data: Object with a ``pos`` tensor whose column 2 is the z coordinate.
        k: Fold index (0 to K-1).
        K: Total number of folds.
        boxsize: Simulation box size, in the same units as ``data.pos``.
        pad: Padding between train and validation regions.
        epsilon: Unused; kept so existing calls that pass it keep working.
        indices_mask: Either None or a boolean mask over the rows of ``data.pos``.

    Returns:
        Tuple of (train_indices, valid_indices) as torch tensors. The indices
        are positions within the z-coordinate array *after* ``indices_mask`` is
        applied, i.e. they index ``data.pos[indices_mask]`` rather than
        ``data.pos`` when a mask is given.

    Raises:
        AssertionError: If any index is in both the train and validation sets
            (possible only when ``pad`` is too small or K is degenerate).
    """
    z_coords = data.pos[:, 2]
    if indices_mask is not None:
        z_coords = z_coords[indices_mask]

    # Validation slab boundaries, wrapped into [0, boxsize).
    valid_start = (k / K * boxsize) % boxsize
    valid_end = ((k + 1) / K * boxsize) % boxsize

    # `>=` (not `>`) so that K == 1, where both bounds wrap to 0, selects the whole box.
    if valid_start >= valid_end:  # slab wraps around the periodic boundary
        valid_mask = (z_coords >= valid_start) | (z_coords < valid_end)
    else:
        valid_mask = (z_coords >= valid_start) & (z_coords < valid_end)

    # Training region: the complement of the slab, shrunk by `pad` on both sides.
    train_start = ((k + 1) / K * boxsize + pad) % boxsize
    train_end = (k / K * boxsize - pad) % boxsize

    if train_start > train_end:  # region wraps around the periodic boundary
        train_mask = (z_coords >= train_start) | (z_coords <= train_end)
    else:
        train_mask = (z_coords >= train_start) & (z_coords <= train_end)

    # Ensure zero overlap (checked on the masks; no Python-level set building needed).
    n_overlap = int((train_mask & valid_mask).sum())
    assert n_overlap == 0, f"Found {n_overlap} overlapping indices between train and validation"

    train_indices = train_mask.nonzero(as_tuple=True)[0]
    valid_indices = valid_mask.nonzero(as_tuple=True)[0]

    print(f"Fold {k}/{K}: Train={len(train_indices)}, Valid={len(valid_indices)}")

    return train_indices, valid_indices


def gaussian_nll_loss(y_pred: torch.Tensor, y_true: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Gaussian negative log-likelihood style loss over the usable targets only.

    A target entry is usable when it is finite and strictly positive; all other
    entries are dropped from both ``y_pred`` and ``y_true`` before the error is
    computed.

    Args:
        y_pred: Model predictions, same shape as ``y_true``.
        y_true: Ground truth values.
        logvar: Predicted log-variance. It is exponentiated with base 10 to get
            the variance, while the penalty term adds ``logvar`` unchanged.

    Returns:
        ``0.5 * (mse / 10**logvar + logvar)``, or a fresh zero tensor (with
        ``requires_grad=True``, detached from the model) if no entry is usable.
    """
    usable = y_true.isfinite() & (y_true > 0.)

    if usable.any():
        squared_error = F.mse_loss(y_pred[usable], y_true[usable])
        return 0.5 * (squared_error / 10**logvar + logvar)

    # Nothing to fit in this batch: return a zero that backward() can still be called on.
    return torch.tensor(0.0, device=y_pred.device, requires_grad=True)


def compute_rmse(preds, targs):
    """Root-mean-square error over target entries that are finite and > 0.

    Args:
        preds: NumPy array of predictions, same shape as ``targs``.
        targs: NumPy array of targets; non-finite or non-positive entries are ignored.

    Returns:
        The RMSE over the retained entries, pooled across all columns.
    """
    usable = np.isfinite(targs) & (targs > 0.)
    residuals = preds[usable] - targs[usable]
    return np.mean(residuals**2)**0.5

## Load data

In [3]:
# Load the list of merger-tree graphs. Uses the shared RESULTS_DIR constant so
# this cell reads from the same place as the later cells in the notebook.
# NOTE: pickle executes arbitrary code on load; only open files you produced yourself.
with open(RESULTS_DIR / "merger_trees.pkl", "rb") as f:
    tree_data = pickle.load(f)

In [60]:
# Peek at the first tree to see which attributes each graph carries.
tree_data[0]

Data(x=[902888, 4], edge_index=[2, 902887], y=[2], root_subhalo_id=59551)

In [4]:
# Number of nodes (rows of x) in every tree; reused for the histogram and as the batch-size cap.
n_subhalos_per_tree = [tree.x.shape[0] for tree in tree_data]

In [None]:
# Histogram of log10(tree size) with a logarithmic count axis.
plt.figure(figsize=(4,4), dpi=150)
plt.hist(np.log10(n_subhalos_per_tree), log=True, bins=100)
plt.xlabel(r"log $N_{\rm subhalo\ per\ tree}$", fontsize=12)
plt.ylabel(r"$dN_{\rm trees}/d \log N_{\rm subhalo\ per\ tree}$", fontsize=12);
plt.grid(alpha=0.15)
# NOTE(review): clf() wipes the figure that was just drawn; presumably this is
# here to keep the plot out of the saved notebook. Remove it to see the histogram.
plt.clf();

In [6]:
# Load the environment graph; its positions are used to assign spatial folds.
# Uses the shared RESULTS_DIR constant, consistent with the later cells.
# NOTE: pickle executes arbitrary code on load; only open files you produced yourself.
with open(RESULTS_DIR / "cosmic_graphs_3Mpc.pkl", "rb") as f:
    env_data = pickle.load(f)

In [7]:
# Root subhalo ID of every tree, in tree_data order.
tree_subhalo_ids = [tree.root_subhalo_id for tree in tree_data]

In [8]:
# Display the environment graph's attributes and their shapes.
env_data

Data(x=[132953, 3], edge_index=[2, 6653247], edge_attr=[6653247, 6], y=[132953, 2], pos=[132953, 3], vel=[132953, 3], is_central=[132953, 1], x_hydro=[132953, 2], pos_hydro=[132953, 3], vel_hydro=[132953, 3], halfmassradius=[132953, 1], subhalo_id=[132953], overdensity=[132953])

In [9]:
# Boolean mask over env_data nodes: True where the node's subhalo_id is the root of some tree.
tree_crossmatches = torch.isin(env_data.subhalo_id, torch.tensor(tree_subhalo_ids))

print("Num merger trees:", len(tree_data))
print("Num env nodes:", env_data.x.shape[0])

# Can be smaller than the number of trees when some tree roots are absent from env_data.
print("Crossmatches:", tree_crossmatches.sum().item())

Num merger trees: 123004
Num env nodes: 132953
Crossmatches: 123001


In [10]:
K_FOLDS = 3

# One (train_indices, valid_indices) pair per fold, restricted to env nodes that have a tree.
train_valid_split = [
    get_spatial_train_valid_indices(env_data, k=k, K=K_FOLDS, indices_mask=tree_crossmatches)
    for k in range(K_FOLDS)
]

# Sanity check: the validation folds together must account for every crossmatched node.
assert sum(len(v) for t, v in train_valid_split) == tree_crossmatches.sum().item()

Fold 0/3: Train=78822, Valid=36452
Fold 1/3: Train=70062, Valid=46397
Fold 2/3: Train=75703, Valid=40152


In [None]:
# Map root subhalo_id -> tree, so trees can be fetched by ID rather than by list position.
tree_map = {tree.root_subhalo_id: tree for tree in tree_data}

# Subhalo IDs of the crossmatched env nodes, kept in env_data order. The fold
# indices computed with indices_mask=tree_crossmatches index into this tensor.
subhalo_ids_with_trees = env_data.subhalo_id[tree_crossmatches]

## Sampling and data loading

In [11]:
# Prototype on a single fold.
k = 0

train_indices, valid_indices = train_valid_split[k]

In [79]:
# Give every target a leading graph dimension ([2] -> [1, 2]) so that batching
# stacks targets row-wise instead of flattening them into one long vector.
for tree in tree_data:
    tree.y = tree.y.reshape(1, -1)

In [80]:

# The fold indices are positions within the crossmatched env_data subset, NOT
# positions within tree_data, so indexing tree_data with them directly is wrong:
# train_dataset = [tree_data[idx] for idx in train_indices]
# valid_dataset = [tree_data[idx] for idx in valid_indices]

# Translate fold indices into subhalo IDs for this train/valid split.
train_ids_for_fold = subhalo_ids_with_trees[train_indices].numpy()
valid_ids_for_fold = subhalo_ids_with_trees[valid_indices].numpy()

# Build datasets by looking each tree up by its root subhalo_id.
train_dataset = [tree_map[sub_id] for sub_id in train_ids_for_fold]
valid_dataset = [tree_map[sub_id] for sub_id in valid_ids_for_fold]

# Cap each batch by total node count rather than by number of graphs; the cap
# equals the largest tree so that every tree fits in some batch.
train_sampler = DynamicBatchSampler(
    train_dataset,
    max_num=max(n_subhalos_per_tree), # around 1e6
    mode="node",
    shuffle=True
)

valid_sampler = DynamicBatchSampler(
    valid_dataset,
    max_num=max(n_subhalos_per_tree),
    mode="node",
    shuffle=False
)

# Each loader must wrap the same dataset its sampler was built from. The
# validation loader previously wrapped train_dataset, so "validation" batches
# were drawn from training trees.
train_loader = DataLoader(train_dataset, batch_sampler=train_sampler)
valid_loader = DataLoader(valid_dataset, batch_sampler=valid_sampler)

## GNN model

In [95]:
# Input/output widths are read off the first tree; hidden width and depth are small for prototyping.
N_IN = tree_data[0].x.shape[1]
N_OUT = tree_data[0].y.shape[1]
N_HIDDEN = 8
N_LAYERS = 8

# NOTE(review): hard-coded to GPU; this cell's consumers will fail on a CPU-only machine.
device = "cuda"

In [96]:
# Instantiate the prototype network and move it to the configured device.
model = MultiSAGENet(
    n_in=N_IN,
    n_hidden=N_HIDDEN,
    n_layers=N_LAYERS,
    n_out=N_OUT
).to(device)

In [97]:
# Total number of trainable scalar parameters in the model.
sum(p.numel() for p in model.parameters() if p.requires_grad)

3116

In [99]:
def configure_optimizer(model, lr, wd,):
    """Build an AdamW optimizer that applies weight decay to weights only.

    Based on the minGPT version. Every parameter of ``model`` ends up in
    exactly one of two groups:

    * group 0 (``weight_decay=wd``): ``weight`` of ``nn.Linear`` modules, plus
      any other non-bias parameter with two or more dimensions;
    * group 1 (``weight_decay=0``): every parameter whose name ends in ``bias``,
      ``weight`` of ``nn.LayerNorm`` modules, plus any other parameter with
      fewer than two dimensions.

    The fallback rules matter for layers that are not ``nn.Linear`` /
    ``nn.LayerNorm`` instances: without them such parameters would be left out
    of the optimizer and silently never updated.

    Args:
        model: The ``nn.Module`` whose parameters are optimized.
        lr: Learning rate shared by both groups.
        wd: Weight decay applied to group 0.

    Returns:
        A ``torch.optim.AdamW`` instance with the two parameter groups above.
    """
    yes_wd_modules = (nn.Linear, )
    no_wd_modules = (nn.LayerNorm, )

    # Canonical (de-duplicated) parameter names; tied parameters appear once here.
    param_dict = {pn: p for pn, p in model.named_parameters()}

    decay, no_decay = set(), set()
    for mn, m in model.named_modules():
        # recurse=False: classify each parameter on the module that owns it.
        for pn, p in m.named_parameters(recurse=False):
            fpn = '%s.%s' % (mn, pn) if mn else pn
            if fpn not in param_dict:
                # Alias of a tied parameter that is already listed under another name.
                continue
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, yes_wd_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, no_wd_modules):
                no_decay.add(fpn)
            elif p.ndim >= 2:
                # Matrix-like parameter of an unrecognised layer: treat as a weight.
                decay.add(fpn)
            else:
                # Gains, scalars and other vector parameters: no decay.
                no_decay.add(fpn)

    # Sorted for a deterministic parameter order inside each group.
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": wd},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.},
    ]

    optimizer = torch.optim.AdamW(
        optim_groups,
        lr=lr,
    )

    return optimizer

In [100]:
def train_epoch_merger_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
    augment: bool = True
) -> float:
    """Train the GNN for one epoch.

    Args:
        dataloader: Loader yielding batched graphs with ``x``, ``y`` and
            optionally ``edge_attr`` attributes.
        model: Model to train; its output is split in half along dim 1 into
            predictions and log-variances.
        optimizer: Optimizer stepped once per batch.
        device: Device to train on.
        augment: If True, add Gaussian noise (3e-4 times the per-column batch
            standard deviation) to every node feature column except the last,
            and to all edge features when present.

    Returns:
        Training loss averaged over graphs: each batch loss is weighted by the
        number of graphs in that batch, so the value does not depend on batch size.
    """
    model.train()
    loss_total = 0.0
    n_graphs = 0
    for data in dataloader:
        if augment: # add random noise
            data_node_features_scatter = 3e-4 * torch.randn_like(data.x[:, :-1]) * torch.std(data.x[:, :-1], dim=0)
            data.x[:, :-1] += data_node_features_scatter
            assert not torch.isnan(data.x).any()

            if hasattr(data, "edge_attr") and data.edge_attr is not None:
                data_edge_features_scatter = 3e-4 * torch.randn_like(data.edge_attr) * torch.std(data.edge_attr, dim=0)
                data.edge_attr += data_edge_features_scatter
                assert not torch.isnan(data.edge_attr).any()

        data.to(device)

        optimizer.zero_grad()
        output = model(data)

        # First half of the output columns: predictions; second half: log-variances.
        y_pred, logvar_pred = output.chunk(2, dim=1)

        assert not torch.isnan(y_pred).any() and not torch.isnan(logvar_pred).any()

        y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
        # A single scalar log-variance per batch.
        logvar_pred = logvar_pred.mean()

        loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # The loss is a batch mean, so weight it by batch size before dividing by
        # the total graph count; otherwise the result shrinks with ~1/batch_size.
        batch_graphs = data.y.shape[0]
        loss_total += loss.item() * batch_graphs
        n_graphs += batch_graphs

    return loss_total / n_graphs


def validate_merger_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    device: str
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Evaluate the GNN on a validation loader without gradient tracking.

    Args:
        dataloader: Validation data loader yielding batched graphs with ``y``.
        model: Model to validate; its output is split in half along dim 1 into
            predictions and log-variances.
        device: Device to validate on.

    Returns:
        Tuple of (loss, predictions, targets). The loss is averaged over graphs
        (each batch loss weighted by its number of graphs); predictions and
        targets are NumPy arrays concatenated in loader order.
    """
    model.eval()
    loss_total = 0.0
    n_graphs = 0
    y_preds = []
    y_trues = []

    with torch.no_grad():
        for data in dataloader:
            data.to(device)

            output = model(data)
            y_pred, logvar_pred = output.chunk(2, dim=1)

            y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
            logvar_pred = logvar_pred.mean()

            loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)

            # Weight the batch-mean loss by batch size so the final ratio is a
            # true per-graph average rather than scaling with ~1/batch_size.
            batch_graphs = data.y.shape[0]
            loss_total += loss.item() * batch_graphs
            n_graphs += batch_graphs

            y_preds.append(y_pred.detach().cpu().numpy())
            y_trues.append(data.y.detach().cpu().numpy())

    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)

    return loss_total / n_graphs, y_preds, y_trues

In [101]:
# Optimizer hyperparameters for the prototype run.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

optimizer = configure_optimizer(model, LEARNING_RATE, WEIGHT_DECAY)

In [None]:
N_EPOCHS = 200

# Per-epoch history, kept for later inspection/plotting.
train_losses = []
valid_losses = []
valid_rmses = []


for epoch in range(N_EPOCHS):

    # Use the configured `device` (the one the model was moved to) instead of a second hard-coded "cuda".
    train_loss = train_epoch_merger_gnn(train_loader, model, optimizer, device=device)
    valid_loss, preds, targs = validate_merger_gnn(valid_loader, model, device=device)

    valid_rmse = compute_rmse(preds, targs)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    valid_rmses.append(valid_rmse)

    # Report every 10th epoch; the printed number is the zero-based epoch index.
    if (epoch + 1) % 10 == 0:
        print(f"{epoch: >3d} {valid_rmse:.4f}")

# Putting it all together (full tree)

In [2]:
def get_spatial_train_valid_indices(data, k: int, K: int = 3, boxsize: float = 75/0.6774, 
                                   pad: float = 3, indices_mask: Optional[torch.Tensor]=None):
    """Split nodes into spatially separated train/validation sets along z.

    Fold ``k`` validates on the half-open slab ``[k/K, (k+1)/K) * boxsize`` and
    trains on the region at least ``pad`` away from it, wrapping periodically.

    Args:
        data: Object with a ``pos`` tensor whose column 2 is the z coordinate.
        k: Fold index (0 to K-1).
        K: Total number of folds.
        boxsize: Simulation box size, in the units of ``data.pos``.
        pad: Padding between the train and validation regions.
        indices_mask: Optional boolean mask over the rows of ``data.pos``; rows
            where it is False are excluded from both sets.

    Returns:
        Tuple of (train_indices, valid_indices) as GLOBAL torch tensors, valid
        for indexing the original `data` object.
    """
    z = data.pos[:, 2]

    # Validation slab; for the last fold a wrapped upper bound is also accepted.
    slab_lo = k / K * boxsize
    slab_hi = (k + 1) / K * boxsize
    if k == K - 1:
        in_valid = (z >= slab_lo) | (z < (slab_hi % boxsize))
    else:
        in_valid = (z >= slab_lo) & (z < slab_hi)

    # Training region: everything outside the slab, pulled in by `pad` on each side.
    region_lo = ((k + 1) / K * boxsize + pad) % boxsize
    region_hi = (k / K * boxsize - pad) % boxsize
    above_lo = z >= region_lo
    below_hi = z <= region_hi
    in_train = (above_lo | below_hi) if region_lo > region_hi else (above_lo & below_hi)

    # Restrict both regions to the caller-selected rows (all rows if no mask was given).
    selectable = torch.ones_like(z, dtype=torch.bool) if indices_mask is None else indices_mask
    valid_indices = (in_valid & selectable).nonzero(as_tuple=True)[0]
    train_indices = (in_train & selectable).nonzero(as_tuple=True)[0]

    # Defensive check that no node is in both sets.
    shared = set(train_indices.tolist()).intersection(valid_indices.tolist())
    assert len(shared) == 0, f"Found {len(shared)} overlapping indices"

    print(f"Fold {k}/{K}: Train={len(train_indices)}, Valid={len(valid_indices)}")

    return train_indices, valid_indices


def gaussian_nll_loss(y_pred: torch.Tensor, y_true: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Gaussian negative log-likelihood style loss over the usable targets only.

    A target entry is usable when it is finite and strictly positive; all other
    entries are dropped from both ``y_pred`` and ``y_true`` before the error is
    computed.

    Args:
        y_pred: Model predictions, same shape as ``y_true``.
        y_true: Ground truth values.
        logvar: Predicted log-variance. It is exponentiated with base 10 to get
            the variance, while the penalty term adds ``logvar`` unchanged.

    Returns:
        ``0.5 * (mse / 10**logvar + logvar)``, or a fresh zero tensor (with
        ``requires_grad=True``, detached from the model) if no entry is usable.
    """
    usable = y_true.isfinite() & (y_true > 0.)

    if usable.any():
        squared_error = F.mse_loss(y_pred[usable], y_true[usable])
        return 0.5 * (squared_error / 10**logvar + logvar)

    # Nothing to fit in this batch: return a zero that backward() can still be called on.
    return torch.tensor(0.0, device=y_pred.device, requires_grad=True)


def compute_rmse(preds, targs):
    """Root-mean-square error over target entries that are finite and > 0.

    Args:
        preds: NumPy array of predictions, same shape as ``targs``.
        targs: NumPy array of targets; non-finite or non-positive entries are ignored.

    Returns:
        The RMSE over the retained entries, pooled across all columns.
    """
    usable = np.isfinite(targs) & (targs > 0.)
    residuals = preds[usable] - targs[usable]
    return np.mean(residuals**2)**0.5


def configure_optimizer(model, lr, wd,):
    """Build an AdamW optimizer that applies weight decay to weights only.

    Based on the minGPT version. Every parameter of ``model`` ends up in
    exactly one of two groups:

    * group 0 (``weight_decay=wd``): ``weight`` of ``nn.Linear`` modules, plus
      any other non-bias parameter with two or more dimensions;
    * group 1 (``weight_decay=0``): every parameter whose name ends in ``bias``,
      ``weight`` of ``nn.LayerNorm`` modules, plus any other parameter with
      fewer than two dimensions.

    The fallback rules matter for layers that are not ``nn.Linear`` /
    ``nn.LayerNorm`` instances: without them such parameters would be left out
    of the optimizer and silently never updated.

    Args:
        model: The ``nn.Module`` whose parameters are optimized.
        lr: Learning rate shared by both groups.
        wd: Weight decay applied to group 0.

    Returns:
        A ``torch.optim.AdamW`` instance with the two parameter groups above.
    """
    yes_wd_modules = (nn.Linear, )
    no_wd_modules = (nn.LayerNorm, )

    # Canonical (de-duplicated) parameter names; tied parameters appear once here.
    param_dict = {pn: p for pn, p in model.named_parameters()}

    decay, no_decay = set(), set()
    for mn, m in model.named_modules():
        # recurse=False: classify each parameter on the module that owns it.
        for pn, p in m.named_parameters(recurse=False):
            fpn = '%s.%s' % (mn, pn) if mn else pn
            if fpn not in param_dict:
                # Alias of a tied parameter that is already listed under another name.
                continue
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, yes_wd_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, no_wd_modules):
                no_decay.add(fpn)
            elif p.ndim >= 2:
                # Matrix-like parameter of an unrecognised layer: treat as a weight.
                decay.add(fpn)
            else:
                # Gains, scalars and other vector parameters: no decay.
                no_decay.add(fpn)

    # Sorted for a deterministic parameter order inside each group.
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": wd},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.},
    ]

    optimizer = torch.optim.AdamW(
        optim_groups,
        lr=lr,
    )

    return optimizer
    
def train_epoch_merger_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
    augment: bool = True
) -> float:
    """Train the GNN for one epoch.

    Args:
        dataloader: Loader yielding batched graphs with ``x``, ``y`` and
            optionally ``edge_attr`` attributes.
        model: Model to train; its output is split in half along dim 1 into
            predictions and log-variances.
        optimizer: Optimizer stepped once per batch.
        device: Device to train on.
        augment: If True, add Gaussian noise (3e-4 times the per-column batch
            standard deviation) to every node feature column except the last,
            and to all edge features when present.

    Returns:
        Training loss averaged over graphs: each batch loss is weighted by the
        number of graphs in that batch, so the value does not depend on batch size.
    """
    model.train()
    loss_total = 0.0
    n_graphs = 0
    for data in dataloader:
        if augment: # add random noise
            data_node_features_scatter = 3e-4 * torch.randn_like(data.x[:, :-1]) * torch.std(data.x[:, :-1], dim=0)
            data.x[:, :-1] += data_node_features_scatter
            assert not torch.isnan(data.x).any()

            if hasattr(data, "edge_attr") and data.edge_attr is not None:
                data_edge_features_scatter = 3e-4 * torch.randn_like(data.edge_attr) * torch.std(data.edge_attr, dim=0)
                data.edge_attr += data_edge_features_scatter
                assert not torch.isnan(data.edge_attr).any()

        data.to(device)

        optimizer.zero_grad()
        output = model(data)

        # First half of the output columns: predictions; second half: log-variances.
        y_pred, logvar_pred = output.chunk(2, dim=1)

        assert not torch.isnan(y_pred).any() and not torch.isnan(logvar_pred).any()

        y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
        # A single scalar log-variance per batch.
        logvar_pred = logvar_pred.mean()

        loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # The loss is a batch mean, so weight it by batch size before dividing by
        # the total graph count; otherwise the result shrinks with ~1/batch_size.
        batch_graphs = data.y.shape[0]
        loss_total += loss.item() * batch_graphs
        n_graphs += batch_graphs

    return loss_total / n_graphs


def validate_merger_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    device: str
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Evaluate the GNN on a validation loader without gradient tracking.

    Args:
        dataloader: Validation data loader yielding batched graphs with ``y``.
        model: Model to validate; its output is split in half along dim 1 into
            predictions and log-variances.
        device: Device to validate on.

    Returns:
        Tuple of (loss, predictions, targets). The loss is averaged over graphs
        (each batch loss weighted by its number of graphs); predictions and
        targets are NumPy arrays concatenated in loader order.
    """
    model.eval()
    loss_total = 0.0
    n_graphs = 0
    y_preds = []
    y_trues = []

    with torch.no_grad():
        for data in dataloader:
            data.to(device)

            output = model(data)
            y_pred, logvar_pred = output.chunk(2, dim=1)

            y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
            logvar_pred = logvar_pred.mean()

            loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)

            # Weight the batch-mean loss by batch size so the final ratio is a
            # true per-graph average rather than scaling with ~1/batch_size.
            batch_graphs = data.y.shape[0]
            loss_total += loss.item() * batch_graphs
            n_graphs += batch_graphs

            y_preds.append(y_pred.detach().cpu().numpy())
            y_trues.append(data.y.detach().cpu().numpy())

    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)

    return loss_total / n_graphs, y_preds, y_trues

In [3]:
# Number of spatial cross-validation folds.
K_FOLDS = 3

# Network width and depth for the full-tree run.
N_HIDDEN = 32
N_LAYERS = 12

# Shared by both optimizer parameter groups; weight decay applies to matrix-shaped weights only.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# hyperparams for one-cycle LR schedule (based on Jespersen+ 2022)
# PCT_START = 0.15
# FINAL_DIV = 1e3

# Graphs per batch.
BATCH_SIZE = 32
# NOTE(review): not read by any cell visible here; the loaders below always use fixed-size batches.
DYNAMIC_SAMPLING = False


N_EPOCHS = 25

device = "cuda"

In [4]:
# Reload the trees from disk so that re-running this cell always starts from
# the raw features (the column selection below is not idempotent on its own).
with open(RESULTS_DIR / "merger_trees.pkl", "rb") as f:
    tree_data = pickle.load(f)

# need to know for dynamic sampling
n_subhalos_per_tree = [tree.x.shape[0] for tree in tree_data]

# reshape tree.y! -- also lazy removal of features that are always -inf
for tree in tree_data: 
    # Keep columns 0, 1, 3 and 5 onward, i.e. drop feature columns 2 and 4.
    tree.x = torch.concatenate([tree.x[:, :2], tree.x[:, 3:4], tree.x[:, 5:]], axis=1)
    # Replace any remaining non-finite feature value with the sentinel -3.
    tree.x[~torch.isfinite(tree.x)] = -3
    # Leading graph dimension so batched targets stack as [n_graphs, n_targets].
    tree.y = tree.y.reshape(1, -1)

In [5]:

# use env graph to assign same folds as other experiments
# note that this will be a SUBSET of the env graph subhalos!!!
with open(RESULTS_DIR /  "cosmic_graphs_3Mpc.pkl", "rb") as f:
    env_data = pickle.load(f)

# tree_subhalo_ids = [tree.root_subhalo_id for tree in tree_data]
# tree_crossmatches = torch.isin(env_data.subhalo_id, torch.tensor(tree_subhalo_ids))

# find which subhalos in env_data have trees
# (Python-level membership test per node, producing a boolean mask over env_data rows)
all_tree_ids = set(tree.root_subhalo_id for tree in tree_data)
tree_crossmatches = torch.tensor([sid.item() in all_tree_ids for sid in env_data.subhalo_id], dtype=torch.bool)


# Fold indices returned here are GLOBAL: they index env_data directly.
train_valid_split = [
    get_spatial_train_valid_indices(env_data, k=k, K=K_FOLDS, indices_mask=tree_crossmatches)
    for k in range(K_FOLDS)
]

# metadata
# NOTE(review): `subhalo_ids` and this all-True `is_central` are not read by the
# cells visible here (the fold loop uses env_data.is_central) -- confirm they are still needed.
subhalo_ids = env_data.subhalo_id[tree_crossmatches]
is_central = torch.full_like(subhalo_ids, True, dtype=bool)


# create a mapping from subhalo_id to tree for efficient lookup
tree_map = {tree.root_subhalo_id: tree for tree in tree_data}

# Model input/output widths, read from the first (already column-filtered) tree.
N_IN = tree_data[0].x.shape[1]
N_OUT = tree_data[0].y.shape[1]

Fold 0/3: Train=78822, Valid=36452
Fold 1/3: Train=70062, Valid=46397
Fold 2/3: Train=75703, Valid=40152


In [None]:
for k in range(K_FOLDS):
    # Resolve every output path up front and make sure its directory exists, so
    # a missing folder cannot make the run fail after the training has finished.
    log_file = RESULTS_DIR / f"logs/tree_gnn_fold_{k}.txt"
    results_file = RESULTS_DIR / f"predictions/tree_gnn_fold_{k}.parquet"
    model_file = RESULTS_DIR / f"models/tree_gnn_fold_{k}.pth"
    for out_path in (log_file, results_file, model_file):
        out_path.parent.mkdir(parents=True, exist_ok=True)

    # Append mode: re-running this cell adds another header + rows to the same file.
    with open(log_file, "a") as f:
        f.write("epoch,train_loss,valid_loss,valid_RMSE\n")

    # train-valid split & dataloaders
    train_indices, valid_indices = train_valid_split[k]

    # map back to subhalo_ids (fold indices are global, so index env_data directly)
    train_ids_for_fold = env_data.subhalo_id[train_indices].numpy()
    valid_ids_for_fold = env_data.subhalo_id[valid_indices].numpy()

    is_central_valid_fold = env_data.is_central[valid_indices].flatten().numpy()

    train_dataset = [tree_map[sub_id] for sub_id in train_ids_for_fold]
    valid_dataset = [tree_map[sub_id] for sub_id in valid_ids_for_fold]

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # shuffle=False keeps predictions in the same order as valid_ids_for_fold.
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # A fresh model per fold.
    model = MultiSAGENet(
        n_in=N_IN,
        n_hidden=N_HIDDEN,
        n_layers=N_LAYERS,
        n_out=N_OUT
    )
    # model = ModernSAGENet(
    #     n_in=N_IN,
    #     n_hidden=N_HIDDEN,
    #     n_layers=N_LAYERS,
    #     n_out=N_OUT
    # )
    model.to(device);

    # optimizer = configure_optimizer(model, LEARNING_RATE, WEIGHT_DECAY)
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #     optimizer,
    #     max_lr=LEARNING_RATE,
    #     epochs=N_EPOCHS, 
    #     steps_per_epoch=len(train_loader),
    #     pct_start=PCT_START,
    #     final_div_factor=FINAL_DIV,
    # )

    # Muon optimizer: matrix-shaped parameters go to Muon (with weight decay),
    # everything with fewer than two dimensions goes to the auxiliary Adam group.
    hidden_weights = [p for p in model.parameters() if p.ndim >= 2]
    hidden_gains_biases = [p for p in model.parameters() if p.ndim < 2]
    param_groups = [
        dict(params=hidden_weights, use_muon=True, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY),
        dict(params=hidden_gains_biases, use_muon=False, lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0),
    ]

    optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

    train_losses = []
    valid_losses = []
    valid_rmses = []

    epoch_pbar = tqdm.tqdm(range(N_EPOCHS), desc=f"Fold {k} Training", leave=True)
    for epoch in epoch_pbar:

        train_loss = train_epoch_merger_gnn(train_loader, model, optimizer, device=device)
        valid_loss, preds, targs = validate_merger_gnn(valid_loader, model, device=device)
        # scheduler.step()

        valid_rmse = compute_rmse(preds, targs)
        # Log after every epoch so partial runs still leave a usable record.
        with open(log_file, "a") as f:
            f.write(f"{epoch:d},{train_loss:.6f},{valid_loss:.6f},{valid_rmse:.6f}\n")

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        valid_rmses.append(valid_rmse)

        epoch_pbar.set_postfix({'valid_rmse': f'{valid_rmse:.4f}'})

    # save predictions (those of the final epoch)
    results_df = pd.DataFrame({
        "subhalo_id": valid_ids_for_fold,
        "log_Mstar_pred": preds[:, 0],
        "log_Mstar_true": targs[:, 0],
        "log_Mgas_pred": preds[:, 1],
        "log_Mgas_true": targs[:, 1],
        "is_central": is_central_valid_fold,
    }).set_index("subhalo_id")

    results_df.to_parquet(results_file)

    # save model weights
    torch.save(model.state_dict(), model_file)

Fold 0 Training:   4%|█                          | 1/25 [00:55<22:05, 55.24s/it, valid_rmse=0.3090]

In [None]:
# Drop the last fold's model and force a garbage-collection pass before the next experiment.
# NOTE(review): this releases Python references only; confirm whether the CUDA
# caching allocator also needs to be emptied to actually free GPU memory.
del model
import gc
gc.collect()

# Training Stumps + Bonsais (jointly)

In [None]:
def get_spatial_train_valid_indices(data, k: int, K: int = 3, boxsize: float = 75/0.6774, 
                                   pad: float = 3, indices_mask: Optional[torch.Tensor]=None):
    """Split nodes into spatially separated train/validation sets along z.

    Fold ``k`` validates on the half-open slab ``[k/K, (k+1)/K) * boxsize`` and
    trains on the region at least ``pad`` away from it, wrapping periodically.

    Args:
        data: Object with a ``pos`` tensor whose column 2 is the z coordinate.
        k: Fold index (0 to K-1).
        K: Total number of folds.
        boxsize: Simulation box size, in the units of ``data.pos``.
        pad: Padding between the train and validation regions.
        indices_mask: Optional boolean mask over the rows of ``data.pos``; rows
            where it is False are excluded from both sets.

    Returns:
        Tuple of (train_indices, valid_indices) as GLOBAL torch tensors, valid
        for indexing the original `data` object.
    """
    z = data.pos[:, 2]

    # Validation slab; for the last fold a wrapped upper bound is also accepted.
    slab_lo = k / K * boxsize
    slab_hi = (k + 1) / K * boxsize
    if k == K - 1:
        in_valid = (z >= slab_lo) | (z < (slab_hi % boxsize))
    else:
        in_valid = (z >= slab_lo) & (z < slab_hi)

    # Training region: everything outside the slab, pulled in by `pad` on each side.
    region_lo = ((k + 1) / K * boxsize + pad) % boxsize
    region_hi = (k / K * boxsize - pad) % boxsize
    above_lo = z >= region_lo
    below_hi = z <= region_hi
    in_train = (above_lo | below_hi) if region_lo > region_hi else (above_lo & below_hi)

    # Restrict both regions to the caller-selected rows (all rows if no mask was given).
    selectable = torch.ones_like(z, dtype=torch.bool) if indices_mask is None else indices_mask
    valid_indices = (in_valid & selectable).nonzero(as_tuple=True)[0]
    train_indices = (in_train & selectable).nonzero(as_tuple=True)[0]

    # Defensive check that no node is in both sets.
    shared = set(train_indices.tolist()).intersection(valid_indices.tolist())
    assert len(shared) == 0, f"Found {len(shared)} overlapping indices"

    print(f"Fold {k}/{K}: Train={len(train_indices)}, Valid={len(valid_indices)}")

    return train_indices, valid_indices


def gaussian_nll_loss(y_pred: torch.Tensor, y_true: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Gaussian negative log-likelihood style loss over the usable targets only.

    A target entry is usable when it is finite and strictly positive; all other
    entries are dropped from both ``y_pred`` and ``y_true`` before the error is
    computed.

    Args:
        y_pred: Model predictions, same shape as ``y_true``.
        y_true: Ground truth values.
        logvar: Predicted log-variance. It is exponentiated with base 10 to get
            the variance, while the penalty term adds ``logvar`` unchanged.

    Returns:
        ``0.5 * (mse / 10**logvar + logvar)``, or a fresh zero tensor (with
        ``requires_grad=True``, detached from the model) if no entry is usable.
    """
    usable = y_true.isfinite() & (y_true > 0.)

    if usable.any():
        squared_error = F.mse_loss(y_pred[usable], y_true[usable])
        return 0.5 * (squared_error / 10**logvar + logvar)

    # Nothing to fit in this batch: return a zero that backward() can still be called on.
    return torch.tensor(0.0, device=y_pred.device, requires_grad=True)


def compute_rmse(preds, targs):
    """Root-mean-square error over target entries that are finite and > 0.

    Args:
        preds: NumPy array of predictions, same shape as ``targs``.
        targs: NumPy array of targets; non-finite or non-positive entries are ignored.

    Returns:
        The RMSE over the retained entries, pooled across all columns.
    """
    usable = np.isfinite(targs) & (targs > 0.)
    residuals = preds[usable] - targs[usable]
    return np.mean(residuals**2)**0.5


def configure_optimizer(model, lr, wd,):
    """Build an AdamW optimizer that applies weight decay to weights only.

    Based on the minGPT version. Every parameter of ``model`` ends up in
    exactly one of two groups:

    * group 0 (``weight_decay=wd``): ``weight`` of ``nn.Linear`` modules, plus
      any other non-bias parameter with two or more dimensions;
    * group 1 (``weight_decay=0``): every parameter whose name ends in ``bias``,
      ``weight`` of ``nn.LayerNorm`` modules, plus any other parameter with
      fewer than two dimensions.

    The fallback rules matter for layers that are not ``nn.Linear`` /
    ``nn.LayerNorm`` instances: without them such parameters would be left out
    of the optimizer and silently never updated.

    Args:
        model: The ``nn.Module`` whose parameters are optimized.
        lr: Learning rate shared by both groups.
        wd: Weight decay applied to group 0.

    Returns:
        A ``torch.optim.AdamW`` instance with the two parameter groups above.
    """
    yes_wd_modules = (nn.Linear, )
    no_wd_modules = (nn.LayerNorm, )

    # Canonical (de-duplicated) parameter names; tied parameters appear once here.
    param_dict = {pn: p for pn, p in model.named_parameters()}

    decay, no_decay = set(), set()
    for mn, m in model.named_modules():
        # recurse=False: classify each parameter on the module that owns it.
        for pn, p in m.named_parameters(recurse=False):
            fpn = '%s.%s' % (mn, pn) if mn else pn
            if fpn not in param_dict:
                # Alias of a tied parameter that is already listed under another name.
                continue
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, yes_wd_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, no_wd_modules):
                no_decay.add(fpn)
            elif p.ndim >= 2:
                # Matrix-like parameter of an unrecognised layer: treat as a weight.
                decay.add(fpn)
            else:
                # Gains, scalars and other vector parameters: no decay.
                no_decay.add(fpn)

    # Sorted for a deterministic parameter order inside each group.
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": wd},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.},
    ]

    optimizer = torch.optim.AdamW(
        optim_groups,
        lr=lr,
    )

    return optimizer
    
def train_epoch_merger_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
    augment: bool = True
) -> float:
    """Train the GNN for one epoch.

    Args:
        dataloader: Loader yielding batched graphs with ``x``, ``y`` and
            optionally ``edge_attr`` attributes.
        model: Model to train; its output is split in half along dim 1 into
            predictions and log-variances.
        optimizer: Optimizer stepped once per batch.
        device: Device to train on.
        augment: If True, add Gaussian noise (3e-4 times the per-column batch
            standard deviation) to every node feature column except the last,
            and to all edge features when present.

    Returns:
        Training loss averaged over graphs: each batch loss is weighted by the
        number of graphs in that batch, so the value does not depend on batch size.
    """
    model.train()
    loss_total = 0.0
    n_graphs = 0
    for data in dataloader:
        if augment: # add random noise
            data_node_features_scatter = 3e-4 * torch.randn_like(data.x[:, :-1]) * torch.std(data.x[:, :-1], dim=0)
            data.x[:, :-1] += data_node_features_scatter
            assert not torch.isnan(data.x).any()

            if hasattr(data, "edge_attr") and data.edge_attr is not None:
                data_edge_features_scatter = 3e-4 * torch.randn_like(data.edge_attr) * torch.std(data.edge_attr, dim=0)
                data.edge_attr += data_edge_features_scatter
                assert not torch.isnan(data.edge_attr).any()

        data.to(device)

        optimizer.zero_grad()
        output = model(data)

        # First half of the output columns: predictions; second half: log-variances.
        y_pred, logvar_pred = output.chunk(2, dim=1)

        assert not torch.isnan(y_pred).any() and not torch.isnan(logvar_pred).any()

        y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
        # A single scalar log-variance per batch.
        logvar_pred = logvar_pred.mean()

        loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # The loss is a batch mean, so weight it by batch size before dividing by
        # the total graph count; otherwise the result shrinks with ~1/batch_size.
        batch_graphs = data.y.shape[0]
        loss_total += loss.item() * batch_graphs
        n_graphs += batch_graphs

    return loss_total / n_graphs


def validate_merger_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    device: str
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Evaluate the model on a validation loader without gradient tracking.

    Args:
        dataloader: Validation data loader yielding batched graphs with ``y``
        model: Model to validate. Its output is split in half along dim 1 into
            predictions and log-variances.
        device: Device to validate on

    Returns:
        Tuple of (loss, predictions, targets). The loss is the per-graph mean
        (each batch loss weighted by its number of graphs); predictions and
        targets are concatenated over batches in loader order.
    """
    model.eval()
    loss_total = 0.0
    n_graphs = 0
    y_preds = []
    y_trues = []

    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)

            output = model(data)
            y_pred, logvar_pred = output.chunk(2, dim=1)

            y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
            logvar_pred = logvar_pred.mean()

            loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)

            # `loss` is a batch mean; weight by batch size so the final division
            # by n_graphs yields a true per-graph average.
            batch_graphs = data.y.shape[0]
            loss_total += loss.item() * batch_graphs
            n_graphs += batch_graphs

            y_preds.append(y_pred.detach().cpu().numpy())
            y_trues.append(data.y.detach().cpu().numpy())

    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)

    return loss_total / n_graphs, y_preds, y_trues

In [None]:
# Read the three pickled inputs from the results directory: the environment
# graph, the bonsai trees and the stump trees.
env_graph_path = RESULTS_DIR / "cosmic_graphs_3Mpc.pkl"
bonsai_path = RESULTS_DIR / "merger_tree_bonsais.pkl"
stump_path = RESULTS_DIR / "merger_tree_stumps.pkl"

with env_graph_path.open("rb") as fh:
    env_data = pickle.load(fh)

with bonsai_path.open("rb") as fh:
    bonsais = pickle.load(fh)

with stump_path.open("rb") as fh:
    stumps = pickle.load(fh)

In [None]:
# Combine the loaded bonsai and stump collections into `tree_data`. Later cells
# index each element with "bonsai" / "stump" and read its `.y` and
# `.root_subhalo_id` attributes.
tree_data = create_bonsai_stump_pairs(bonsais, stumps)

In [None]:
# Per tree: drop feature columns 2 and 4 (also -inf) from both sub-graphs,
# replace any remaining non-finite feature with -3, and make y a single row.
for tree in tree_data:
    for part in ("bonsai", "stump"):
        feats = tree[part].x
        feats = torch.cat((feats[:, :2], feats[:, 3:4], feats[:, 5:]), dim=1)
        feats[~torch.isfinite(feats)] = -3
        tree[part].x = feats
    tree.y = tree.y.reshape(1, -1)

In [15]:
# Number of spatial cross-validation folds.
K_FOLDS = 3

# Model size.
N_HIDDEN = 32
N_LAYERS = 8

# Optimizer settings.
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 1e-4

# hyperparams for one-cycle LR schedule (based on Jespersen+ 2022)
# PCT_START = 0.15
# INIT_DIV = 1e1
# FINAL_DIV = 1e3


BATCH_SIZE = 128
DYNAMIC_SAMPLING = False


N_EPOCHS = 25

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [16]:
# Root subhalo IDs of all trees, used to flag which env_data subhalos have a tree.
all_tree_ids = {tree.root_subhalo_id for tree in tree_data}

has_tree_flags = []
for sid in env_data.subhalo_id:
    has_tree_flags.append(sid.item() in all_tree_ids)
tree_crossmatches = torch.tensor(has_tree_flags, dtype=torch.bool)


# One (train_indices, valid_indices) pair per fold, restricted to subhalos with trees.
train_valid_split = []
for k in range(K_FOLDS):
    fold_indices = get_spatial_train_valid_indices(env_data, k=k, K=K_FOLDS, indices_mask=tree_crossmatches)
    train_valid_split.append(fold_indices)

# metadata
subhalo_ids = env_data.subhalo_id[tree_crossmatches]
is_central = torch.ones_like(subhalo_ids, dtype=torch.bool)

# lookup table: root subhalo ID -> tree
tree_map = {}
for tree in tree_data:
    tree_map[tree.root_subhalo_id] = tree

# Input / output widths are read off the first tree.
first_tree = tree_data[0]
N_IN = first_tree['bonsai'].x.shape[1]
N_OUT = first_tree.y.shape[1]

Fold 0/3: Train=78822, Valid=36452
Fold 1/3: Train=70062, Valid=46397
Fold 2/3: Train=75703, Valid=40152


In [None]:
for k in range(K_FOLDS):
    # Output locations for this fold. Create the parent folders up front so a
    # fresh checkout does not fail on the first write (log, parquet or weights).
    log_file = RESULTS_DIR / f"logs/bstree_gnn_fold_{k}.txt"
    results_file = RESULTS_DIR / f"predictions/bstree_gnn_fold_{k}.parquet"
    model_file = RESULTS_DIR / f"models/bstree_gnn_fold_{k}.pth"
    for output_path in (log_file, results_file, model_file):
        output_path.parent.mkdir(parents=True, exist_ok=True)

    # The log is opened in append mode: every run adds a fresh header line
    # followed by one row per epoch.
    with open(log_file, "a") as f:
        f.write("epoch,train_loss,valid_loss,valid_RMSE\n")

    # train-valid split & dataloaders
    train_indices, valid_indices = train_valid_split[k]
    
    # map back to subhalo_ids
    train_ids_for_fold = env_data.subhalo_id[train_indices].numpy()
    valid_ids_for_fold = env_data.subhalo_id[valid_indices].numpy()

    is_central_valid_fold = env_data.is_central[valid_indices].flatten().numpy()

    train_dataset = [tree_map[sub_id] for sub_id in train_ids_for_fold]
    valid_dataset = [tree_map[sub_id] for sub_id in valid_ids_for_fold]

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # shuffle=False keeps the predictions aligned with valid_ids_for_fold below.
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
  
    model = ModernBonsaiStumpSAGENet(
        n_in=N_IN,
        n_hidden=N_HIDDEN,
        n_layers=N_LAYERS,
        n_out=N_OUT,
        act_fn=nn.SiLU()
    )
    
    # model = BonsaiStumpSAGENet(
    #     n_in=N_IN,
    #     n_hidden=N_HIDDEN,
    #     n_layers=N_LAYERS,
    #     n_out=N_OUT,
    # )
    
    model.to(device)

    # # 1-cycle
    # optimizer = configure_optimizer(model, LEARNING_RATE, WEIGHT_DECAY)
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #     optimizer,
    #     max_lr=LEARNING_RATE,
    #     epochs=N_EPOCHS, 
    #     steps_per_epoch=len(train_loader),
    #     pct_start=PCT_START,
    #     div_factor=INIT_DIV,
    #     final_div_factor=FINAL_DIV,
    # )

    # Muon optimizer: parameters with >= 2 dims go to the Muon group (with weight
    # decay), everything else (gains, biases) to the auxiliary Adam group.
    hidden_weights = [p for p in model.parameters() if p.ndim >= 2]
    hidden_gains_biases = [p for p in model.parameters() if p.ndim < 2]
    param_groups = [
        dict(params=hidden_weights, use_muon=True, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY),
        dict(params=hidden_gains_biases, use_muon=False, lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0),
    ]
    
    optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

    train_losses = []
    valid_losses = []
    valid_rmses = []
    
    epoch_pbar = tqdm.tqdm(range(N_EPOCHS), desc=f"Fold {k} Training", leave=True)
    for epoch in epoch_pbar:
  
        train_loss = train_epoch_merger_gnn(train_loader, model, optimizer, augment=False, device=device) # augment doesn't work with the HeteroData
        valid_loss, preds, targs = validate_merger_gnn(valid_loader, model, device=device)
        # scheduler.step()
    
        valid_rmse = compute_rmse(preds, targs)
        with open(log_file, "a") as f:
            f.write(f"{epoch:d},{train_loss:.6f},{valid_loss:.6f},{valid_rmse:.6f}\n")
        
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        valid_rmses.append(valid_rmse)
    
        epoch_pbar.set_postfix({'valid_rmse': f'{valid_rmse:.4f}'})

    # save predictions (from the last epoch) for the validation subhalos
    results_df = pd.DataFrame({
        "subhalo_id": valid_ids_for_fold,
        "log_Mstar_pred": preds[:, 0],
        "log_Mstar_true": targs[:, 0],
        "log_Mgas_pred": preds[:, 1],
        "log_Mgas_true": targs[:, 1],
        "is_central": is_central_valid_fold,
    }).set_index("subhalo_id")

    results_df.to_parquet(results_file)

    # save model weights
    torch.save(model.state_dict(), model_file)

Fold 0 Training:  28%|███████▌                   | 7/25 [03:31<09:00, 30.05s/it, valid_rmse=0.3073]

# Residual learning with full Merger Tree

In [1]:
import sys
sys.path.append("../src")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import SubsetRandomSampler

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader, DynamicBatchSampler
from torch_geometric.nn import MessagePassing, SAGEConv, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.utils import index_to_mask, to_undirected, remove_self_loops
from torch_cluster import radius_graph
from torch_scatter import scatter_add

import tqdm
from typing import Tuple, Optional, List

from loader import create_bonsai_stump_pairs
from model import EdgeInteractionGNN, EdgeInteractionLayer, MultiSAGENet, ModernSAGENet, SAGEEncoder, BonsaiStumpSAGENet, ModernBonsaiStumpSAGENet

from muon import SingleDeviceMuonWithAuxAdam


from pathlib import Path
BASE_DIR = Path("../")
RESULTS_DIR = Path(BASE_DIR / "results")

In [2]:
def get_spatial_train_valid_indices(data, k: int, K: int = 3, boxsize: float = 75/0.6774, 
                                   pad: float = 3, indices_mask: Optional[torch.Tensor]=None):
    """Create spatial train/validation indices using z-coordinate splits.
    
    The box is divided into K slabs along the z-axis. Fold ``k`` validates on
    slab ``k`` and trains on the remaining slabs, shrunk by ``pad`` on both
    sides of the validation slab (wrapping around the periodic boundary).

    Args:
        data: Object whose ``pos`` attribute is indexed as ``pos[:, 2]`` to get
            the z-coordinates.
        k: Fold index (0 to K-1)
        K: Total number of folds
        boxsize: Box length along z, in the same units as ``data.pos``
        pad: Gap left between the train and validation regions
        indices_mask: Optional boolean mask over the rows of ``data.pos``; only
            rows where it is True can appear in either output.

    Returns:
        Tuple of (train_indices, valid_indices) as GLOBAL torch tensors, valid
        for indexing the original `data` object.

    Raises:
        ValueError: If ``k`` is not in ``[0, K)``, or if ``pad`` is negative or
            so large that no training region is left.
    """
    if K < 1 or not 0 <= k < K:
        raise ValueError(f"Fold index k={k} must satisfy 0 <= k < K (K={K})")
    if pad < 0 or 2 * pad >= boxsize * (K - 1) / K:
        raise ValueError(
            f"pad={pad} leaves no training region for K={K} folds in a box of size {boxsize}"
        )

    z_coords = data.pos[:, 2]

    valid_start = (k / K * boxsize)
    valid_end = ((k + 1) / K * boxsize)
    
    if k == K - 1:
        # Last slab: the upper edge wraps to 0, so accept everything above valid_start.
        spatial_valid_mask = (z_coords >= valid_start) | (z_coords < (valid_end % boxsize))
    else:
        spatial_valid_mask = (z_coords >= valid_start) & (z_coords < valid_end)
    
    train_start = ((k + 1) / K * boxsize + pad) % boxsize
    train_end = (k / K * boxsize - pad) % boxsize
    
    if train_start > train_end:
        # Training region wraps around the periodic boundary.
        spatial_train_mask = (z_coords >= train_start) | (z_coords <= train_end)
    else:
        spatial_train_mask = (z_coords >= train_start) & (z_coords <= train_end)

    # With pad == 0 the inclusive train edge coincides with the validation start;
    # drop validation points from the training set so the two never overlap.
    spatial_train_mask = spatial_train_mask & ~spatial_valid_mask

    if indices_mask is None:
        final_mask = torch.ones_like(z_coords, dtype=torch.bool)
    else:
        final_mask = indices_mask

    # Use logical AND to get the final masks for training and validation
    final_valid_mask = spatial_valid_mask & final_mask
    final_train_mask = spatial_train_mask & final_mask

    valid_indices = final_valid_mask.nonzero(as_tuple=True)[0]
    train_indices = final_train_mask.nonzero(as_tuple=True)[0]

    # Cheap tensor-level sanity check (no Python sets needed).
    n_overlap = int((final_valid_mask & final_train_mask).sum())
    assert n_overlap == 0, f"Found {n_overlap} overlapping indices"

    print(f"Fold {k}/{K}: Train={len(train_indices)}, Valid={len(valid_indices)}")
    
    return train_indices, valid_indices


def gaussian_nll_loss(y_pred: torch.Tensor, y_true: torch.Tensor, logvar: torch.Tensor,
                      min_target: Optional[float] = 0.0) -> torch.Tensor:
    """Compute a Gaussian negative log-likelihood style loss on valid targets only.

    Targets that are not finite are always ignored. By default targets that are
    not strictly greater than ``min_target`` (0) are ignored as well; pass
    ``min_target=None`` to keep every finite target, e.g. when the targets are
    signed residuals.

    NOTE(review): the variance is taken as ``10**logvar`` while the penalty term
    adds ``logvar`` itself (no ln(10) factor) -- confirm this scaling is intended.

    Args:
        y_pred: Model predictions
        y_true: Ground truth values, same shape as ``y_pred``
        logvar: Log variance prediction (base 10, as used by ``10**logvar``)
        min_target: Exclusive lower bound for a target to count as valid, or
            None to disable the lower bound.

    Returns:
        ``0.5 * (mse / 10**logvar + logvar)`` with the MSE taken over the valid
        entries, or a zero tensor (with ``requires_grad=True``) when no target
        is valid.
    """
    valid_mask = y_true.isfinite()
    if min_target is not None:
        valid_mask = valid_mask & (y_true > min_target)

    if not valid_mask.any():
        # Nothing to fit in this batch: return a zero loss that still supports backward().
        return torch.tensor(0.0, device=y_pred.device, requires_grad=True)

    mse_loss = F.mse_loss(y_pred[valid_mask], y_true[valid_mask])

    return 0.5 * (mse_loss / 10**logvar + logvar)


def compute_rmse(preds, targs, min_target: Optional[float] = 0.0):
    """Root-mean-square error between predictions and targets on valid entries.

    Non-finite targets are always ignored. By default targets that are not
    strictly greater than ``min_target`` (0) are ignored too; pass
    ``min_target=None`` to keep every finite target (e.g. signed residuals).

    Args:
        preds: NumPy array of predictions
        targs: NumPy array of targets, same shape as ``preds``
        min_target: Exclusive lower bound for a target to count as valid, or
            None to disable the lower bound.

    Returns:
        RMSE over the valid entries (NaN when there are none).
    """
    valid_mask = np.isfinite(targs)
    if min_target is not None:
        valid_mask = valid_mask & (targs > min_target)
    errors = preds[valid_mask] - targs[valid_mask]
    return np.mean(errors**2)**0.5


def configure_optimizer(model, lr, wd,):
    """Build an AdamW optimizer that applies weight decay to weights only.

    Based on the minGPT version: biases and LayerNorm weights are not decayed,
    nn.Linear weights are. Parameters that match none of these rules (for
    example weights owned by other module types) are assigned by shape --
    tensors with two or more dimensions are decayed, the rest are not -- so
    that every parameter of the model is actually handed to the optimizer.

    Args:
        model: Module whose parameters should be optimized
        lr: Learning rate for all parameter groups
        wd: Weight decay applied to the decayed group

    Returns:
        torch.optim.AdamW with two parameter groups: decayed (index 0) and
        non-decayed (index 1).
    """

    decay, no_decay = set(), set()
    yes_wd_modules = (nn.Linear, )
    no_wd_modules = (nn.LayerNorm, )
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, yes_wd_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, no_wd_modules):
                no_decay.add(fpn)
    param_dict = {pn: p for pn, p in model.named_parameters()}

    # Keep only canonical parameter names (aliases of shared parameters are not
    # keys of param_dict) and make the two groups disjoint.
    decay.intersection_update(param_dict)
    no_decay.intersection_update(param_dict)
    decay.difference_update(no_decay)

    # Anything not classified above would otherwise be silently left untrained.
    for pn, p in param_dict.items():
        if pn not in decay and pn not in no_decay:
            (decay if p.ndim >= 2 else no_decay).add(pn)

    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": wd},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.},
    ]

    optimizer = torch.optim.AdamW(
        optim_groups, 
        lr=lr, 
    )

    return optimizer
    
def train_epoch_merger_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
    augment: bool = True
) -> float:
    """Train the model for one epoch and return the mean loss per graph.

    Args:
        dataloader: Data loader yielding batched graphs with a ``y`` attribute.
            When ``augment`` is True the batch must also expose ``x``
            (and optionally ``edge_attr``).
        model: Model to train. Its output is split in half along dim 1 into
            predictions and log-variances.
        optimizer: Optimizer
        device: Device to train on
        augment: If True, add Gaussian noise scaled by 3e-4 times the
            per-column standard deviation to every node-feature column except
            the last one, and to the edge features when they exist.

    Returns:
        Average training loss for the epoch. Each batch loss is weighted by
        the number of graphs in the batch, so the value is a per-graph mean.
    """
    model.train()
    loss_total = 0.0
    n_graphs = 0
    for data in dataloader:
        if augment: # add random noise
            data_node_features_scatter = 3e-4 * torch.randn_like(data.x[:, :-1]) * torch.std(data.x[:, :-1], dim=0)
            data.x[:, :-1] += data_node_features_scatter
            assert not torch.isnan(data.x).any()

            if hasattr(data, "edge_attr") and data.edge_attr is not None:
                data_edge_features_scatter = 3e-4 * torch.randn_like(data.edge_attr) * torch.std(data.edge_attr, dim=0)
                data.edge_attr += data_edge_features_scatter
                assert not torch.isnan(data.edge_attr).any()

        # Rebind explicitly instead of relying on `.to` working in place.
        data = data.to(device)

        optimizer.zero_grad()
        output = model(data)

        # First half of the output columns: predictions; second half: log-variances.
        y_pred, logvar_pred = output.chunk(2, dim=1)

        assert not torch.isnan(y_pred).any() and not torch.isnan(logvar_pred).any()

        y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
        # A single scalar log-variance is used for the whole batch.
        logvar_pred = logvar_pred.mean()

        loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # `loss` is already a batch mean, so weight it by the batch's graph count
        # before dividing by the total number of graphs below.
        batch_graphs = data.y.shape[0]
        loss_total += loss.item() * batch_graphs
        n_graphs += batch_graphs

    return loss_total / n_graphs


def validate_merger_gnn(
    dataloader: DataLoader,
    model: nn.Module,
    device: str
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Evaluate the model on a validation loader without gradient tracking.

    Args:
        dataloader: Validation data loader yielding batched graphs with ``y``
        model: Model to validate. Its output is split in half along dim 1 into
            predictions and log-variances.
        device: Device to validate on

    Returns:
        Tuple of (loss, predictions, targets). The loss is the per-graph mean
        (each batch loss weighted by its number of graphs); predictions and
        targets are concatenated over batches in loader order.
    """
    model.eval()
    loss_total = 0.0
    n_graphs = 0
    y_preds = []
    y_trues = []

    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)

            output = model(data)
            y_pred, logvar_pred = output.chunk(2, dim=1)

            y_pred = y_pred.view(-1, data.y.shape[1] if len(data.y.shape) > 1 else 2)
            logvar_pred = logvar_pred.mean()

            loss = gaussian_nll_loss(y_pred, data.y, logvar_pred)

            # `loss` is a batch mean; weight by batch size so the final division
            # by n_graphs yields a true per-graph average.
            batch_graphs = data.y.shape[0]
            loss_total += loss.item() * batch_graphs
            n_graphs += batch_graphs

            y_preds.append(y_pred.detach().cpu().numpy())
            y_trues.append(data.y.detach().cpu().numpy())

    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)

    return loss_total / n_graphs, y_preds, y_trues

In [3]:
# Number of spatial cross-validation folds.
K_FOLDS = 3

# Model size.
N_HIDDEN = 32
N_LAYERS = 12

# Optimizer settings.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# hyperparams for one-cycle LR schedule (based on Jespersen+ 2022)
# PCT_START = 0.15
# FINAL_DIV = 1e3

BATCH_SIZE = 32
DYNAMIC_SAMPLING = False


N_EPOCHS = 25

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [32]:
# Load the full merger trees.
merger_tree_path = RESULTS_DIR / "merger_trees.pkl"
with merger_tree_path.open("rb") as fh:
    tree_data = pickle.load(fh)

# node count of every tree (need to know for dynamic sampling)
n_subhalos_per_tree = []
for tree in tree_data:
    n_subhalos_per_tree.append(tree.x.shape[0])

# Per tree: drop feature columns 2 and 4 (always -inf), replace any remaining
# non-finite feature with -3, and make y a single row.
for tree in tree_data:
    feats = torch.cat((tree.x[:, :2], tree.x[:, 3:4], tree.x[:, 5:]), dim=1)
    feats[~torch.isfinite(feats)] = -3
    tree.x = feats
    tree.y = tree.y.reshape(1, -1)

In [33]:
# The environment graph defines the folds so they match the other experiments.
# Only a SUBSET of its subhalos have a merger tree!
env_graph_path = RESULTS_DIR / "cosmic_graphs_3Mpc.pkl"
with env_graph_path.open("rb") as fh:
    env_data = pickle.load(fh)

# Root subhalo IDs of all trees, used to flag which env_data subhalos have a tree.
all_tree_ids = {tree.root_subhalo_id for tree in tree_data}

has_tree_flags = []
for sid in env_data.subhalo_id:
    has_tree_flags.append(sid.item() in all_tree_ids)
tree_crossmatches = torch.tensor(has_tree_flags, dtype=torch.bool)


# One (train_indices, valid_indices) pair per fold, restricted to subhalos with trees.
train_valid_split = []
for k in range(K_FOLDS):
    fold_indices = get_spatial_train_valid_indices(env_data, k=k, K=K_FOLDS, indices_mask=tree_crossmatches)
    train_valid_split.append(fold_indices)

# lookup table: root subhalo ID -> tree
tree_map = {}
for tree in tree_data:
    tree_map[tree.root_subhalo_id] = tree

# Input / output widths are read off the first tree.
first_tree = tree_data[0]
N_IN = first_tree.x.shape[1]
N_OUT = first_tree.y.shape[1]

Fold 0/3: Train=78822, Valid=36452
Fold 1/3: Train=70062, Valid=46397
Fold 2/3: Train=75703, Valid=40152


In [34]:
# change to residual learning...
# Stack the per-fold prediction tables of the environment GNN. The number of
# files follows K_FOLDS instead of a hard-coded 3 so it stays in sync with the
# fold definition used in this notebook.
env_predictions = pd.concat(
    [pd.read_parquet(RESULTS_DIR / f"predictions/env_gnn_fold_{k}.parquet") for k in range(K_FOLDS)],
    axis=0,
)
# The residual targets are looked up with `.loc[subhalo_id]`, which only returns
# a single row per subhalo when the index has no duplicates.
assert env_predictions.index.is_unique, "env_predictions has duplicate index entries"

In [35]:
# Replace each tree's target with the residual
#     (environment-GNN prediction) - (original target),
# keeping the original target in `tree.y_original`.
# NOTE(review): gaussian_nll_loss and compute_rmse only keep targets with
# y_true > 0, so every negative residual produced here is excluded from both the
# loss and the reported RMSE -- confirm that this is intended.
residual_columns = ["log_Mstar_pred", "log_Mgas_pred"]
for subhalo_id, tree in tree_map.items():
    # Always start from the stored original target so that re-running this cell
    # does not subtract the prediction a second time.
    y_original = getattr(tree, "y_original", None)
    if y_original is None:
        y_original = tree.y
        tree.y_original = y_original
    env_pred = torch.Tensor(env_predictions.loc[subhalo_id][residual_columns].values).reshape(1, -1)
    tree.y = env_pred - y_original

In [36]:
for k in range(K_FOLDS):
    # Output locations for this fold. Create the parent folders up front so a
    # fresh checkout does not fail on the first write (log, parquet or weights).
    log_file = RESULTS_DIR / f"logs/tree_residual_gnn_fold_{k}.txt"
    results_file = RESULTS_DIR / f"predictions/tree_residual_gnn_fold_{k}.parquet"
    model_file = RESULTS_DIR / f"models/tree_residual_gnn_fold_{k}.pth"
    for output_path in (log_file, results_file, model_file):
        output_path.parent.mkdir(parents=True, exist_ok=True)

    # The log is opened in append mode: every run adds a fresh header line
    # followed by one row per epoch.
    with open(log_file, "a") as f:
        f.write("epoch,train_loss,valid_loss,valid_RMSE\n")

    # train-valid split & dataloaders
    train_indices, valid_indices = train_valid_split[k]
    
    # map back to subhalo_ids
    train_ids_for_fold = env_data.subhalo_id[train_indices].numpy()
    valid_ids_for_fold = env_data.subhalo_id[valid_indices].numpy()

    is_central_valid_fold = env_data.is_central[valid_indices].flatten().numpy()

    train_dataset = [tree_map[sub_id] for sub_id in train_ids_for_fold]
    valid_dataset = [tree_map[sub_id] for sub_id in valid_ids_for_fold]

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # shuffle=False keeps the predictions aligned with valid_ids_for_fold below.
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = MultiSAGENet(
        n_in=N_IN,
        n_hidden=N_HIDDEN,
        n_layers=N_LAYERS,
        n_out=N_OUT
    )
    # model = ModernSAGENet(
    #     n_in=N_IN,
    #     n_hidden=N_HIDDEN,
    #     n_layers=N_LAYERS,
    #     n_out=N_OUT
    # )
    model.to(device)

    # optimizer = configure_optimizer(model, LEARNING_RATE, WEIGHT_DECAY)
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #     optimizer,
    #     max_lr=LEARNING_RATE,
    #     epochs=N_EPOCHS, 
    #     steps_per_epoch=len(train_loader),
    #     pct_start=PCT_START,
    #     final_div_factor=FINAL_DIV,
    # )

    # Muon optimizer: parameters with >= 2 dims go to the Muon group (with weight
    # decay), everything else (gains, biases) to the auxiliary Adam group.
    hidden_weights = [p for p in model.parameters() if p.ndim >= 2]
    hidden_gains_biases = [p for p in model.parameters() if p.ndim < 2]
    param_groups = [
        dict(params=hidden_weights, use_muon=True, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY),
        dict(params=hidden_gains_biases, use_muon=False, lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0),
    ]
    
    optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

    train_losses = []
    valid_losses = []
    valid_rmses = []
    
    epoch_pbar = tqdm.tqdm(range(N_EPOCHS), desc=f"Fold {k} Training", leave=True)
    for epoch in epoch_pbar:
  
        train_loss = train_epoch_merger_gnn(train_loader, model, optimizer, device=device)
        valid_loss, preds, targs = validate_merger_gnn(valid_loader, model, device=device)
        # scheduler.step()
    
        valid_rmse = compute_rmse(preds, targs)
        with open(log_file, "a") as f:
            f.write(f"{epoch:d},{train_loss:.6f},{valid_loss:.6f},{valid_rmse:.6f}\n")
        
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        valid_rmses.append(valid_rmse)
    
        epoch_pbar.set_postfix({'valid_rmse': f'{valid_rmse:.4f}'})

    # save predictions (from the last epoch) for the validation subhalos;
    # the "true" columns hold whatever is stored in tree.y for this experiment
    results_df = pd.DataFrame({
        "subhalo_id": valid_ids_for_fold,
        "log_Mstar_pred": preds[:, 0],
        "log_Mstar_true": targs[:, 0],
        "log_Mgas_pred": preds[:, 1],
        "log_Mgas_true": targs[:, 1],
        "is_central": is_central_valid_fold,
    }).set_index("subhalo_id")

    results_df.to_parquet(results_file)

    # save model weights
    torch.save(model.state_dict(), model_file)

Fold 0 Training: 100%|███████████████████████████████████████████████| 25/25 [26:19<00:00, 63.19s/it, valid_rmse=0.1952]
Fold 1 Training: 100%|███████████████████████████████████████████████| 25/25 [24:43<00:00, 59.35s/it, valid_rmse=0.2124]
Fold 2 Training: 100%|███████████████████████████████████████████████| 25/25 [26:12<00:00, 62.90s/it, valid_rmse=0.1985]
