# Training MLP models

- Training logs saved in `results/logs/mlp_fold_{k}.txt`
- Predictions saved in `results/predictions/mlp_fold_{k}.parquet`

**Input features:**
- subhalo mass
- Vmax
- is_central

**Predictions**
- log_Mstar
- log_Mgas
- log_var(log_Mstar) -- not saved
- log_var(log_Mgas) -- not saved

**Important notes**
- We make an initial cut on subhalos with no valid Mstar or Mgas targets, so there are fewer subhalos (132426) than in the `results/cosmic_graphs_3Mpc.pkl` dataset (132953).
- If you want to run the final training analysis (~1.5min per k-fold, or 5 minutes in all), then skip down to the "All Together Now" section

In [1]:
import sys
sys.path.append("../src")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch.utils.data import Dataset, DataLoader
import tqdm
from typing import Tuple, Optional

# Prototype

## Data

In [2]:
with open("../results/cosmic_graphs_3Mpc.pkl", "rb") as f:
    env_data = pickle.load(f)

In [3]:
row_wise_mask = np.logical_and(
    np.isfinite(env_data.x).all(axis=1),   # need all inputs
    np.isfinite(env_data.y).any(axis=1)    # allowed to have gas mass = NaN
).type(torch.bool)

row_wise_mask.sum().item()

132426

In [4]:
X = env_data.x[row_wise_mask]
y = env_data.y[row_wise_mask]
is_central = env_data.is_central[row_wise_mask]

In [14]:
# def get_spatial_train_valid_indices(data, k: int, K: int = 3, boxsize: float = 75/0.6774, 
#                                    pad: float = 3, epsilon: float = 1e-10, indices_mask=None):
#     """Create spatial train/validation indices using z-coordinate splits.
    
#     This creates spatially separated train/validation sets by dividing the simulation
#     box along the z-axis. Each fold uses 1/K of the box for validation and the rest
#     for training (with padding to avoid boundary effects).
    
#     Args:
#         data: PyTorch Geometric data object with pos attribute
#         k: Fold index (0 to K-1)
#         K: Total number of folds
#         boxsize: Simulation box size in Mpc
#         pad: Padding between train/valid regions in Mpc
#         epsilon: Small value to avoid boundary issues
#         indices_mask: either None or a boolean mask of length X.shape[0]
        
#     Returns:
#         Tuple of (train_indices, valid_indices) as torch tensors
#     """

#     if indices_mask is None:
#         z_coords = data.pos[:, 2]
#     else:
#         z_coords = data.pos[:, 2][indices_mask]

#     all_indices = torch.arange(len(z_coords))
    
    
#     # Calculate validation region boundaries
#     valid_start = (k / K * boxsize) % boxsize
#     valid_end = ((k + 1) / K * boxsize) % boxsize
    
#     # Handle wrap-around case
#     if valid_start > valid_end:  # Wraps around the boundary
#         valid_mask = (z_coords >= valid_start) | (z_coords <= valid_end)
#     else:
#         valid_mask = (z_coords >= valid_start) & (z_coords <= valid_end)
    
#     # Create training region with padding
#     train_start = ((k + 1) / K * boxsize + pad) % boxsize
#     train_end = (k / K * boxsize - pad) % boxsize
    
#     # Handle wrap-around for training region
#     if train_start > train_end:  # Wraps around the boundary
#         train_mask = (z_coords >= train_start) | (z_coords <= train_end)
#     else:
#         train_mask = (z_coords >= train_start) & (z_coords <= train_end)


#     # Get indices
#     train_indices = train_mask.nonzero(as_tuple=True)[0]
#     valid_indices = valid_mask.nonzero(as_tuple=True)[0]
    
#     # Ensure zero overlap
#     overlap = set(train_indices.tolist()) & set(valid_indices.tolist())
#     assert len(overlap) == 0, f"Found {len(overlap)} overlapping indices between train and validation"

#     if indices_mask is not None:
#         # Convert the boolean mask to a tensor of indices
#         allowed_indices = indices_mask.nonzero(as_tuple=True)[0]
        
#         # Find the intersection of the spatial slice and the allowed indices
#         valid_indices = torch.tensor(list(set(valid_indices.tolist()) & set(allowed_indices.tolist())), dtype=torch.long)
#         train_indices = torch.tensor(list(set(train_indices.tolist()) & set(allowed_indices.tolist())), dtype=torch.long)
        
#     print(f"Fold {k}/{K}: Train={len(train_indices)}, Valid={len(valid_indices)}")
    
#     return train_indices, valid_indices

In [16]:
def get_spatial_train_valid_indices(data, k: int, K: int = 3, boxsize: float = 75/0.6774, 
                                   pad: float = 3, indices_mask: Optional[torch.Tensor]=None):
    """Create spatial train/validation indices using z-coordinate splits.
    
    This creates spatially separated train/validation sets by dividing the simulation
    box along the z-axis. It correctly handles periodic boundaries and optional
    pre-filtering to create a true partition of the data.
    
    Returns:
        Tuple of (train_indices, valid_indices) as GLOBAL torch tensors, valid
        for indexing the original `data` object.
    """

    z_coords = data.pos[:, 2]

    valid_start = (k / K * boxsize)
    valid_end = ((k + 1) / K * boxsize)
    
    if k == K - 1:
        spatial_valid_mask = (z_coords >= valid_start) | (z_coords < (valid_end % boxsize))
    else:
        spatial_valid_mask = (z_coords >= valid_start) & (z_coords < valid_end)
    
    train_start = ((k + 1) / K * boxsize + pad) % boxsize
    train_end = (k / K * boxsize - pad) % boxsize
    
    if train_start > train_end:
        spatial_train_mask = (z_coords >= train_start) | (z_coords <= train_end)
    else:
        spatial_train_mask = (z_coords >= train_start) & (z_coords <= train_end)

    if indices_mask is None:
        final_mask = torch.ones_like(z_coords, dtype=torch.bool)
    else:
        final_mask = indices_mask

    # Use logical AND to get the final masks for training and validation
    final_valid_mask = spatial_valid_mask & final_mask
    final_train_mask = spatial_train_mask & final_mask

    # --- Step 4: Convert boolean masks to global indices ---
    valid_indices = final_valid_mask.nonzero(as_tuple=True)[0]
    train_indices = final_train_mask.nonzero(as_tuple=True)[0]

    # Double-check for overlap, which should now be impossible by construction
    overlap = set(train_indices.tolist()) & set(valid_indices.tolist())
    assert len(overlap) == 0, f"Found {len(overlap)} overlapping indices"

    print(f"Fold {k}/{K}: Train={len(train_indices)}, Valid={len(valid_indices)}")
    
    return train_indices, valid_indices

In [17]:
K_FOLDS = 3

train_valid_split = [
    get_spatial_train_valid_indices(env_data, k=k, K=K_FOLDS, indices_mask=None)
    for k in range(K_FOLDS)
]

assert sum(len(v) for t, v in train_valid_split) == env_data.x.shape[0]

Fold 0/3: Train=84952, Valid=39685
Fold 1/3: Train=75945, Valid=49913
Fold 2/3: Train=81916, Valid=43355


In [18]:
K_FOLDS = 3

train_valid_split = [
    get_spatial_train_valid_indices(env_data, k=k, K=K_FOLDS, indices_mask=row_wise_mask)
    for k in range(K_FOLDS)
]

assert sum(len(v) for t, v in train_valid_split) == sum(row_wise_mask)

Fold 0/3: Train=84601, Valid=39562
Fold 1/3: Train=75669, Valid=49687
Fold 2/3: Train=81609, Valid=43177


## MLP and training hyperparams

In [7]:
def train_epoch_mlp(
    dataloader: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str
) -> float:
    """Train one epoch for MLP model.
    
    Args:
        dataloader: Data loader for training data (X, y tuples)
        model: Model to train
        optimizer: Optimizer
        device: Device to train on
        
    Returns:
        Average training loss for the epoch
    """
    model.train()
    loss_total = 0
    
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()
        output = model(X)

        y_pred, logvar_pred = output.chunk(2, dim=1)
        
        assert not torch.isnan(y_pred).any() and not torch.isnan(logvar_pred).any()
        
        y_pred = y_pred.view(-1, y.shape[1] if len(y.shape) > 1 else 1)
        logvar_pred = logvar_pred.mean()
        
        loss = gaussian_nll_loss(y_pred, y, logvar_pred)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        loss_total += loss.item()
        
    return loss_total / len(dataloader)


def validate_mlp(
    dataloader: DataLoader,
    model: nn.Module,
    device: str
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Validate MLP model.
    
    Args:
        dataloader: Validation data loader
        model: Model to validate
        device: Device to validate on
        
    Returns:
        Tuple of (loss, predictions, targets)
    """
    model.eval()
    loss_total = 0
    y_preds = []
    y_trues = []
    
    for X, y in dataloader:
        with torch.no_grad():
            X = X.to(device)
            y = y.to(device)
            
            output = model(X)
            y_pred, logvar_pred = output.chunk(2, dim=1)
            
            y_pred = y_pred.view(-1, y.shape[1] if len(y.shape) > 1 else 1)
            logvar_pred = logvar_pred.mean()
            
            loss = gaussian_nll_loss(y_pred, y, logvar_pred)
            loss_total += loss.item()
            
            y_preds.append(y_pred.detach().cpu().numpy())
            y_trues.append(y.detach().cpu().numpy())
    
    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)
    
    return loss_total / len(dataloader), y_preds, y_trues

def gaussian_nll_loss(y_pred: torch.Tensor, y_true: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Compute Gaussian negative log-likelihood loss *with masking out infinite values*.
    
    Args:
        y_pred: Model predictions
        y_true: Ground truth values  
        logvar: Log variance predictions
        
    Returns:
        Gaussian NLL loss
    """
    finite_mask = (y_true > 0.) & (y_true.isfinite())
    
    if not finite_mask.any():
        return torch.tensor(0.0, device=y_pred.device, requires_grad=True)
    

    y_pred_masked = y_pred[finite_mask]
    y_true_masked = y_true[finite_mask]
    mse_loss = F.mse_loss(y_pred_masked, y_true_masked)
    
    return 0.5 * (mse_loss / 10**logvar + logvar)


In [24]:
N_INPUTS = X.shape[1]
N_HIDDEN = 128
N_OUTPUTS = y.shape[1]

In [25]:
model = nn.Sequential(
    nn.Linear(N_INPUTS, N_HIDDEN),
    nn.ReLU(),
    nn.Linear(N_HIDDEN, N_HIDDEN),
    nn.ReLU(),
    nn.Linear(N_HIDDEN, N_HIDDEN),
    nn.ReLU(),
    nn.Linear(N_HIDDEN, 2*N_OUTPUTS)
)

model.cuda()

Sequential(
  (0): Linear(in_features=3, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=128, bias=True)
  (3): ReLU()
  (4): Linear(in_features=128, out_features=128, bias=True)
  (5): ReLU()
  (6): Linear(in_features=128, out_features=4, bias=True)
)

In [26]:
LEARNING_RATE = 1e-3
N_EPOCHS = 200
BATCH_SIZE = 1024

In [27]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [63]:
class SubhaloDataset(Dataset):
    def __init__(self, X, y, is_central):
        self.X = X
        self.y = y
        self.is_central = is_central

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [64]:
k = 0
train_indices, valid_indices = train_valid_split[k]

train_dataset = SubhaloDataset(X[train_indices], y[train_indices], is_central[train_indices])
valid_dataset = SubhaloDataset(X[valid_indices], y[valid_indices], is_central[valid_indices])

In [65]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [57]:
def compute_rmse(preds, targs):
    finite_mask = (targs > 0.) & (np.isfinite(targs))
    y_pred_masked = preds[finite_mask]
    y_true_masked = targs[finite_mask]
    return np.mean((y_pred_masked - y_true_masked)**2)**0.5

In [32]:
train_losses = []
valid_losses = []
valid_rmses = []

for epoch in range(1, N_EPOCHS+1):
    train_loss = train_epoch_mlp(train_loader, model, optimizer, device="cuda")
    valid_loss, preds, targs = validate_mlp(valid_loader, model, device="cuda")

    valid_rmse = compute_rmse(preds, targs)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    valid_rmses.append(valid_rmse)
    
    if epoch % 10 == 0:
        print(epoch, valid_rmse)

10 0.3778254047206109
20 0.3757114615242275
30 0.35868102945864944
40 0.34507368727894966
50 0.33290764205399503
60 0.33120742020357347
70 0.32710633647519666
80 0.3415095293096444
90 0.3394796459869219
100 0.3262023797304213
110 0.32665069064580127
120 0.33372600220902604
130 0.34176314051228246
140 0.32223641234266526
150 0.3203132908508958
160 0.32449678196655946
170 0.3174565625701904
180 0.32140523323413805
190 0.32300519755111495
200 0.32164500037333615


In [109]:
plt.figure(figsize=(4,4), dpi=150)
plt.plot(range(N_EPOCHS), valid_rmses)
plt.ylim(0.3, 0.4)
plt.grid(alpha=0.15)
plt.xlabel("Epochs", fontsize=12)
plt.ylabel("Validation RMSE [dex]", fontsize=12);

plt.clf()

<Figure size 600x600 with 0 Axes>

## Visualize predictions

In [107]:
plt.figure(figsize=(4,4), dpi=150)
select_centrals = (valid_dataset.is_central).flatten().numpy().astype(bool)

plt.scatter(targs[:, 0][select_centrals], preds[:, 0][select_centrals], c="C3", s=3, edgecolor="none", alpha=0.5, )
plt.scatter(targs[:, 0][~select_centrals], preds[:, 0][~select_centrals], c="C0", s=3, edgecolor="none", alpha=0.5, )
plt.plot([8, 12], [8, 12], ls="-", c="k", lw=1) 
plt.xlim(8, 12)
plt.ylim(8, 12)
plt.gca().set_aspect("equal")
plt.grid(alpha=0.15)
plt.xlabel(r"True log($M_{\bigstar}/M_{\odot}$)", fontsize=12)
plt.ylabel(r"Predicted log($M_{\bigstar}/M_{\odot}$)", fontsize=12);

plt.clf()

<Figure size 600x600 with 0 Axes>

In [108]:
plt.figure(figsize=(4,4), dpi=150)
select_centrals = (valid_dataset.is_central).flatten().numpy().astype(bool)

plt.scatter(targs[:, 1][select_centrals], preds[:, 1][select_centrals], c="C3", s=3, edgecolor="none", alpha=0.5, )
plt.scatter(targs[:, 1][~select_centrals], preds[:, 1][~select_centrals], c="C0", s=3, edgecolor="none", alpha=0.5, )
plt.plot([8, 12], [8, 12], ls="-", c="k", lw=1) 
plt.xlim(8, 12)
plt.ylim(8, 12)
plt.gca().set_aspect("equal")
plt.grid(alpha=0.15)
plt.xlabel(r"True log($M_{\rm gas}/M_{\odot}$)", fontsize=12)
plt.ylabel(r"Predicted log($M_{\rm gas}/M_{\odot}$)", fontsize=12);

plt.clf()

<Figure size 600x600 with 0 Axes>

## Print out some RMSE errors...

In [67]:
# RMSE for log Mstar given a cut on Mstar > 8.5
selection = valid_dataset.y[:, 0] > 8.5

np.mean(((preds[:, 0] - targs[:, 0])[selection])**2)**0.5

0.27665816446725827

In [80]:
# RMSE for log Mstar given a cut on Mstar > 8.5 *and CENTRALS*
selection = (valid_dataset.y[:, 0] > 8.5) & (valid_dataset.is_central).flatten().numpy()

np.mean(((preds[:, 0] - targs[:, 0])[selection])**2)**0.5

0.1536048205374387

In [76]:
# RMSE for log Mgas given a cut on log Mstar > 8.5 and valid Mgas
selection = (targs[:, 0] > 8.5) & torch.isfinite(valid_dataset.y[:, 1]).numpy()

np.mean(((preds[:, 1] - targs[:, 1])[selection])**2)**0.5

0.33502949494075734

In [84]:
# RMSE for log Mgas given a cut on log Mstar > 8.5 and valid Mgas, and for CENTRALS
selection = (targs[:, 0] > 8.5) & torch.isfinite(valid_dataset.y[:, 1]).numpy() & (valid_dataset.is_central).flatten().numpy()

np.mean(((preds[:, 1] - targs[:, 1])[selection])**2)**0.5

0.30505247753148884

# All together now

In [2]:
class SubhaloDataset(Dataset):
    """Super simple dataset class for loading into MLP"""
    def __init__(self, X, y, subhalo_id, is_central):
        self.X = X
        self.y = y
        self.subhalo_id = subhalo_id
        self.is_central = is_central

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


def get_spatial_train_valid_indices(data, k: int, K: int = 3, boxsize: float = 75/0.6774, 
                                   pad: float = 3, indices_mask: Optional[torch.Tensor]=None):
    """Create spatial train/validation indices using z-coordinate splits.
    
    This creates spatially separated train/validation sets by dividing the simulation
    box along the z-axis. It correctly handles periodic boundaries and optional
    pre-filtering to create a true partition of the data.
    
    Returns:
        Tuple of (train_indices, valid_indices) as GLOBAL torch tensors, valid
        for indexing the original `data` object.
    """

    z_coords = data.pos[:, 2]

    valid_start = (k / K * boxsize)
    valid_end = ((k + 1) / K * boxsize)
    
    if k == K - 1:
        spatial_valid_mask = (z_coords >= valid_start) | (z_coords < (valid_end % boxsize))
    else:
        spatial_valid_mask = (z_coords >= valid_start) & (z_coords < valid_end)
    
    train_start = ((k + 1) / K * boxsize + pad) % boxsize
    train_end = (k / K * boxsize - pad) % boxsize
    
    if train_start > train_end:
        spatial_train_mask = (z_coords >= train_start) | (z_coords <= train_end)
    else:
        spatial_train_mask = (z_coords >= train_start) & (z_coords <= train_end)

    if indices_mask is None:
        final_mask = torch.ones_like(z_coords, dtype=torch.bool)
    else:
        final_mask = indices_mask

    # Use logical AND to get the final masks for training and validation
    final_valid_mask = spatial_valid_mask & final_mask
    final_train_mask = spatial_train_mask & final_mask

    valid_indices = final_valid_mask.nonzero(as_tuple=True)[0]
    train_indices = final_train_mask.nonzero(as_tuple=True)[0]

    # Double-check for overlap, which should now be impossible by construction
    overlap = set(train_indices.tolist()) & set(valid_indices.tolist())
    assert len(overlap) == 0, f"Found {len(overlap)} overlapping indices"

    print(f"Fold {k}/{K}: Train={len(train_indices)}, Valid={len(valid_indices)}")
    
    return train_indices, valid_indices


def train_epoch_mlp(
    dataloader: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str
) -> float:
    """Train one epoch for MLP model.
    
    Args:
        dataloader: Data loader for training data (X, y tuples)
        model: Model to train
        optimizer: Optimizer
        device: Device to train on
        
    Returns:
        Average training loss for the epoch
    """
    model.train()
    loss_total = 0
    
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()
        output = model(X)

        y_pred, logvar_pred = output.chunk(2, dim=1)
        
        assert not torch.isnan(y_pred).any() and not torch.isnan(logvar_pred).any()
        
        y_pred = y_pred.view(-1, y.shape[1] if len(y.shape) > 1 else 1)
        logvar_pred = logvar_pred.mean()
        
        loss = gaussian_nll_loss(y_pred, y, logvar_pred)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        loss_total += loss.item()
        
    return loss_total / len(dataloader)


def validate_mlp(
    dataloader: DataLoader,
    model: nn.Module,
    device: str
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Validate MLP model.
    
    Args:
        dataloader: Validation data loader
        model: Model to validate
        device: Device to validate on
        
    Returns:
        Tuple of (loss, predictions, targets)
    """
    model.eval()
    loss_total = 0
    y_preds = []
    y_trues = []
    
    for X, y in dataloader:
        with torch.no_grad():
            X = X.to(device)
            y = y.to(device)
            
            output = model(X)
            y_pred, logvar_pred = output.chunk(2, dim=1)
            
            y_pred = y_pred.view(-1, y.shape[1] if len(y.shape) > 1 else 1)
            logvar_pred = logvar_pred.mean()
            
            loss = gaussian_nll_loss(y_pred, y, logvar_pred)
            loss_total += loss.item()
            
            y_preds.append(y_pred.detach().cpu().numpy())
            y_trues.append(y.detach().cpu().numpy())
    
    y_preds = np.concatenate(y_preds, axis=0)
    y_trues = np.concatenate(y_trues, axis=0)
    
    return loss_total / len(dataloader), y_preds, y_trues

def gaussian_nll_loss(y_pred: torch.Tensor, y_true: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Compute Gaussian negative log-likelihood loss *with masking out infinite values*.
    
    Args:
        y_pred: Model predictions
        y_true: Ground truth values  
        logvar: Log variance predictions
        
    Returns:
        Gaussian NLL loss
    """
    finite_mask = (y_true > 0.) & (y_true.isfinite())
    
    if not finite_mask.any():
        return torch.tensor(0.0, device=y_pred.device, requires_grad=True)
    

    y_pred_masked = y_pred[finite_mask]
    y_true_masked = y_true[finite_mask]
    mse_loss = F.mse_loss(y_pred_masked, y_true_masked)
    
    return 0.5 * (mse_loss / 10**logvar + logvar)


def compute_rmse(preds, targs):
    """lil helper func"""
    finite_mask = (targs > 0.) & (np.isfinite(targs))
    y_pred_masked = preds[finite_mask]
    y_true_masked = targs[finite_mask]
    return np.mean((y_pred_masked - y_true_masked)**2)**0.5

In [3]:
K_FOLDS = 3

with open("../results/cosmic_graphs_3Mpc.pkl", "rb") as f:
    env_data = pickle.load(f)

# mask out completely NaN/inf rows
isfinite_mask = np.logical_and(
    np.isfinite(env_data.x).all(axis=1),   # need all inputs
    np.isfinite(env_data.y).any(axis=1)    # allowed to have gas mass = NaN
).type(torch.bool)

train_valid_split = [
    get_spatial_train_valid_indices(env_data, k=k, K=K_FOLDS, indices_mask=isfinite_mask)
    for k in range(K_FOLDS)
]

assert sum(len(v) for t, v in train_valid_split) == sum(isfinite_mask)

X = env_data.x
y = env_data.y
is_central = env_data.is_central
subhalo_ids = env_data.subhalo_id

Fold 0/3: Train=84601, Valid=39562
Fold 1/3: Train=75669, Valid=49687
Fold 2/3: Train=81609, Valid=43177


In [4]:

# MLP hyperparams
N_INPUTS = X.shape[1]
N_HIDDEN = 128
N_OUTPUTS = y.shape[1]

# optimization hyperparms
LEARNING_RATE = 1e-3
N_EPOCHS = 200
BATCH_SIZE = 1024


In [5]:
for k in range(K_FOLDS):
    log_file = f"../results/logs/mlp_fold_{k}.txt"
    
    with open(log_file, "a") as f:
        f.write(f"epoch,train_loss,valid_loss,valid_RMSE\n")
    
    train_indices, valid_indices = train_valid_split[k]

    X_train, y_train, subhalo_ids_train, is_central_train = X[train_indices], y[train_indices], subhalo_ids[train_indices], is_central[train_indices]
    X_valid, y_valid, subhalo_ids_valid, is_central_valid = X[valid_indices], y[valid_indices], subhalo_ids[valid_indices], is_central[valid_indices]
    
    train_dataset = SubhaloDataset(X_train, y_train, subhalo_ids_train, is_central_train)
    valid_dataset = SubhaloDataset(X_valid, y_valid, subhalo_ids_valid, is_central_valid)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = nn.Sequential(
        nn.Linear(N_INPUTS, N_HIDDEN),
        nn.ReLU(),
        nn.Linear(N_HIDDEN, N_HIDDEN),
        nn.ReLU(),
        nn.Linear(N_HIDDEN, N_HIDDEN),
        nn.ReLU(),
        nn.Linear(N_HIDDEN, 2*N_OUTPUTS)
    ).cuda()
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    train_losses = []
    valid_losses = []
    valid_rmses = []

    epoch_pbar = tqdm.tqdm(range(N_EPOCHS), desc=f"Fold {k} Training", leave=True)
    for epoch in epoch_pbar:
        train_loss = train_epoch_mlp(train_loader, model, optimizer, device="cuda")
        valid_loss, preds, targs = validate_mlp(valid_loader, model, device="cuda")
    
        valid_rmse = compute_rmse(preds, targs)

        with open(log_file, "a") as f:
            f.write(f"{epoch:d},{train_loss:.6f},{valid_loss:.6f},{valid_rmse:.6f}\n")

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        valid_rmses.append(valid_rmse)

        epoch_pbar.set_postfix({'valid_rmse': f'{valid_rmse:.4f}'})

    # save predictions
    results_file = f"../results/predictions/mlp_fold_{k}.parquet"

    results_df = pd.DataFrame({
        "subhalo_id": subhalo_ids[valid_indices].numpy(),
        "log_Mstar_pred": preds[:, 0],
        "log_Mstar_true": targs[:, 0],
        "log_Mgas_pred": preds[:, 1],
        "log_Mgas_true": targs[:, 1],
        "is_central": is_central[valid_indices].flatten().numpy()
    }).set_index("subhalo_id")

    results_df.to_parquet(results_file)

    # save model weights
    model_file = f"../results/models/mlp_fold_{k}.pth"
    torch.save(model.state_dict(), model_file)

Fold 0 Training: 100%|████████████████████████| 200/200 [01:28<00:00,  2.25it/s, valid_rmse=0.3328]
Fold 1 Training: 100%|████████████████████████| 200/200 [01:23<00:00,  2.41it/s, valid_rmse=0.3511]
Fold 2 Training: 100%|████████████████████████| 200/200 [01:20<00:00,  2.47it/s, valid_rmse=0.3347]
