# Photo-Z Challenge: Machine Learning Pipeline
Welcome to the interactive Photo-Z Challenge notebook! 

This notebook integrates the entire machine learning pipeline into a single, step-by-step educational flow. We will cover:
1. Environment and Plotting Setup
2. Data Preprocessing and Custom Datasets
3. Model Architecture Details
4. Training Loop Execution
5. Validation and Metric Visualization

In [1]:
import os
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from scipy.stats import gaussian_kde
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
# --- 2. PLOTTING STYLE SETTINGS ---
# Dark theme for every figure produced below.
plt.style.use('dark_background')

# Global font sizes and line width, set one rcParam at a time.
plt.rcParams['font.size'] = 16
plt.rcParams['axes.labelsize'] = 20
plt.rcParams['axes.titlesize'] = 22
plt.rcParams['xtick.labelsize'] = 16
plt.rcParams['ytick.labelsize'] = 16
plt.rcParams['legend.fontsize'] = 16
plt.rcParams['lines.linewidth'] = 4

# Line colour used for each value of the TYPE column when plotting metrics.
TYPE_COLORS = {
    'GALAXY_ID': 'tab:cyan',
    'QSO': 'tab:red',
    'GALAXY_OOD1': 'yellow',
    'GALAXY_OOD2': 'gold',
    'Passive': 'tab:orange',
    'ELG': 'tab:green',
}

def get_smart_limits(data, padding=0.05):
    """Return padded axis limits spanning the 1st-99th percentile of *data*.

    Non-finite entries (NaN, +/-inf) are ignored. Any array-like is accepted
    (list, tuple, NumPy array, pandas Series); it is converted to a float
    array before filtering.

    Args:
        data: Array-like of numeric values.
        padding: Fraction of the percentile span added on each side.

    Returns:
        A ``(low, high)`` tuple. ``(0, 1)`` when there is no finite value.
    """
    values = np.asarray(data, dtype=float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return 0, 1
    # One call computes both percentiles over the same filtered data.
    low, high = np.percentile(finite, [1, 99])
    span = high - low
    return low - span * padding, high + span * padding

## Configuration & Data Preprocessing
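We define helper functions that read the configuration file, fill missing values with zeros, and normalize the selected inputs using the median and the Median Absolute Deviation (MAD). These statistics are computed on the training set and reused for validation. We also define the PyTorch `Dataset` that will feed our model.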

In [3]:
def load_config(config_path="./config.yaml"):
    """Read the YAML configuration file and return its parsed content.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content.
    """
    # Pin the encoding so the file is decoded the same way on every platform
    # instead of depending on the locale default.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def get_mad(series):
    """Return the median absolute deviation of *series* around its median."""
    center = series.median()
    abs_deviation = (series - center).abs()
    return abs_deviation.median()

def preprocess_data(df, config, train_stats=None):
    """Build the feature matrix, targets and type labels from *df*.

    Feature columns are the concatenation of the column lists of every group
    named in ``config['data']['selected_features']``. Missing values are
    replaced by 0.0, then the columns of the selected groups that also appear
    in ``config['data']['features_to_normalize']`` are rescaled as
    ``(x - median) / MAD``.

    Args:
        df: Input table. Must provide the feature columns, ``TYPE`` and either
            ``Z`` or the trio ``SPECTYPE`` / ``Z_QSO`` / ``Z_GAL``.
        config: Parsed configuration dictionary.
        train_stats: Optional ``{'medians': ..., 'mads': ...}`` to reuse. When
            ``None`` the statistics are computed from *df* itself.

    Returns:
        ``(features, targets, types, train_stats, n_features)``.
    """
    data_cfg = config['data']
    chosen_groups = data_cfg['selected_features']
    group_columns = data_cfg['inputs']

    selected_cols = [name for group in chosen_groups for name in group_columns[group]]
    features = df[selected_cols].copy()

    # Target redshift: use 'Z' when present, otherwise pick the QSO or galaxy
    # redshift depending on SPECTYPE.
    if 'Z' not in df.columns:
        y = np.where(df['SPECTYPE'] == 2.0, df['Z_QSO'], df['Z_GAL'])
    else:
        y = df['Z'].values

    types = df['TYPE'].values
    features = features.fillna(0.0)

    # Only groups that were actually selected can be normalized.
    cols_to_norm = [
        name
        for group in data_cfg['features_to_normalize']
        if group in chosen_groups
        for name in group_columns[group]
    ]

    if train_stats is None:
        to_scale = features[cols_to_norm]
        train_stats = {
            'medians': to_scale.median(),
            # A zero MAD would divide by zero, so it is replaced by 1.0.
            'mads': to_scale.apply(get_mad).replace(0, 1.0),
        }

    centered = features[cols_to_norm] - train_stats['medians']
    features[cols_to_norm] = centered / train_stats['mads']

    return features.values, y, types, train_stats, len(selected_cols)

class PhotoZDataset(Dataset):
    """Dataset yielding ``(features, target, type)`` triples by index.

    Features and targets are stored as float tensors (targets as a single
    column); the type labels are kept as a NumPy array.
    """

    def __init__(self, features, targets, types):
        feature_tensor = torch.FloatTensor(features)
        # Targets become one column so each sample's target has shape (1,).
        target_column = torch.FloatTensor(targets).reshape(-1, 1)

        self.features = feature_tensor
        self.targets = target_column
        self.types = np.array(types)

    def __len__(self):
        """Number of samples, i.e. the length of the feature tensor."""
        n_samples = len(self.features)
        return n_samples

    def __getitem__(self, idx):
        """Return the feature row, target and type label stored at *idx*."""
        sample = (self.features[idx], self.targets[idx], self.types[idx])
        return sample

## Model Architecture
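The core model is a configurable Multilayer Perceptron (MLP) defined by `PhotoZNet`; its hidden-layer widths and dropout rates are read from the configuration. We optimize it using a custom Delta Z Loss, the mean of $|z_{pred} - z_{true}| / (1 + z_{true})$, tailored for photometric redshift challenges. A standard MSE loss is used instead when the configured loss type is not `deltaz`.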

In [4]:
class PhotoZNet(nn.Module):
    """Multilayer perceptron producing one output value per input row.

    Each hidden stage is ``Linear -> ReLU -> Dropout``; a final ``Linear``
    maps the last hidden width to a single output.

    Args:
        input_size: Number of input features.
        hidden_layers: Width of each hidden layer, in order.
        dropout_rates: Dropout probability applied after each hidden layer;
            must have the same length as ``hidden_layers``.

    Raises:
        ValueError: If ``hidden_layers`` and ``dropout_rates`` differ in
            length.
    """

    def __init__(self, input_size, hidden_layers, dropout_rates):
        super().__init__()
        hidden_layers = list(hidden_layers)
        dropout_rates = list(dropout_rates)
        # zip() would silently drop the surplus entries of the longer list,
        # building a smaller network than the configuration asked for.
        if len(hidden_layers) != len(dropout_rates):
            raise ValueError(
                f"hidden_layers has {len(hidden_layers)} entries but "
                f"dropout_rates has {len(dropout_rates)}; they must match"
            )

        layers = []
        in_dim = input_size
        for h_dim, drop_rate in zip(hidden_layers, dropout_rates):
            layers.extend([nn.Linear(in_dim, h_dim), nn.ReLU(), nn.Dropout(drop_rate)])
            in_dim = h_dim

        layers.append(nn.Linear(in_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to *x* and return its output."""
        return self.net(x)

class DeltaZLoss(nn.Module):
    """Mean of ``|y_pred - y_true| / (1 + y_true)``."""

    def forward(self, y_pred, y_true):
        """Return the scalar loss for predictions *y_pred* and targets *y_true*.

        When the two tensors hold the same number of elements but have
        different shapes (e.g. ``(N, 1)`` predictions and ``(N,)`` targets),
        the targets are reshaped to the prediction shape. Without this,
        broadcasting would compare every prediction with every target and
        average over an ``N x N`` matrix.
        """
        if y_true.shape != y_pred.shape and y_true.numel() == y_pred.numel():
            y_true = y_true.reshape(y_pred.shape)
        numerator = torch.abs(y_pred - y_true)
        denominator = 1.0 + y_true
        return torch.mean(numerator / denominator)

## Training the Model
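This cell loads the HDF5 datasets, applies preprocessing, initializes the neural network, and runs the training loop over the specified epochs, reporting the validation loss after each epoch. Finally, it saves the trained weights to the experiment directory.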

In [None]:
cfg = load_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Read each split exactly once.
df_train = pd.read_hdf(cfg['data']['train_path'], key='data')
df_val = pd.read_hdf(cfg['data']['val_path'], key='data')

# Normalization statistics come from the training split and are reused for validation.
X_train, y_train, types_train, stats, input_dim = preprocess_data(df_train, cfg, train_stats=None)
X_val, y_val, types_val, _, _ = preprocess_data(df_val, cfg, train_stats=stats)

train_dataset = PhotoZDataset(X_train, y_train, types_train)
train_loader = DataLoader(train_dataset, batch_size=cfg['data']['batch_size'], shuffle=True, drop_last=True)

model = PhotoZNet(
    input_size=input_dim,
    hidden_layers=cfg['model']['hidden_layers'],
    dropout_rates=cfg['model']['dropout_rates']
).to(device)

criterion = DeltaZLoss() if cfg['training']['loss_type'] == 'deltaz' else nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=cfg['training']['learning_rate'])

epochs = cfg['training']['epochs']

# Validation tensors do not change between epochs, so build them once.
# The targets are reshaped to (N, 1) to match the model output: comparing an
# (N, 1) output with (N,) targets would broadcast to an N x N matrix, giving a
# wrong loss value and quadratic memory use.
val_inputs = torch.FloatTensor(X_val).to(device)
val_targets = torch.FloatTensor(y_val).reshape(-1, 1).to(device)

for epoch in range(epochs):
    model.train()
    train_loss_acc = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    for inputs, targets, _ in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss_acc += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    # Whole validation set in a single forward pass, dropout off, no gradients.
    model.eval()
    with torch.no_grad():
        val_outputs = model(val_inputs)
        total_val_loss = criterion(val_outputs, val_targets).item()

    print(f"Epoch {epoch+1}: Train Loss: {train_loss_acc / len(train_loader):.4f} | Val Loss Global: {total_val_loss:.4f}")

# Persist the trained weights (state dict only).
os.makedirs(cfg['experiment']['save_dir'], exist_ok=True)
save_path = os.path.join(cfg['experiment']['save_dir'], f"{cfg['experiment']['group_name']}.pth")
torch.save(model.state_dict(), save_path)
print(f"Model saved to: {save_path}")

Using device: cpu


Epoch 1/10 [Train]:   0%|          | 0/4028 [00:00<?, ?it/s]

## Evaluation and Visualization
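Now we run inference on the validation set and calculate our primary metrics from the normalized residual $\Delta z = (z_{pred} - z_{true}) / (1 + z_{true})$: Bias (the median of $\Delta z$), $\sigma_{NMAD}$ ($1.4826 \times$ the median absolute deviation of $\Delta z$), and Outlier Fraction (the fraction of objects with $|\Delta z| > 0.15$).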

In [None]:
def compute_metrics_binned(df, x_col, z_true_col='Z_TRUE', z_pred_col='Z_PRED', bins=None,
                           min_count=100, outlier_threshold=0.15):
    """Compute bias, sigma_NMAD and outlier fraction of dz in bins of *x_col*.

    ``dz`` is ``(z_pred - z_true) / (1 + z_true)``. Bins are half-open,
    ``[bins[i], bins[i+1])``.

    Args:
        df: Table holding ``x_col``, ``z_true_col`` and ``z_pred_col``.
        x_col: Column whose values are binned.
        z_true_col: Column with the reference redshift.
        z_pred_col: Column with the predicted redshift.
        bins: Sequence of bin edges. When ``None`` nothing is computed.
        min_count: Bins with fewer rows than this get NaN for the centre and
            for every metric.
        outlier_threshold: A row is an outlier when ``|dz|`` exceeds this.

    Returns:
        ``(centers, bias, sigma, outlier_fraction)`` as NumPy arrays with one
        entry per bin, or four ``None`` values when ``bins`` is ``None``.
    """
    if bins is None:
        return None, None, None, None

    centers, bias_list, sigma_list, outlier_list = [], [], [], []
    dz = (df[z_pred_col] - df[z_true_col]) / (1 + df[z_true_col])
    x_values = df[x_col]

    for lower, upper in zip(bins[:-1], bins[1:]):
        subset_dz = dz[(x_values >= lower) & (x_values < upper)]

        if len(subset_dz) < min_count:
            # Too few objects for a meaningful statistic: leave a gap.
            centers.append(np.nan)
            bias_list.append(np.nan)
            sigma_list.append(np.nan)
            outlier_list.append(np.nan)
            continue

        median_dz = np.median(subset_dz)
        centers.append(0.5 * (lower + upper))
        bias_list.append(median_dz)
        # 1.4826 * MAD: the normalized median absolute deviation.
        sigma_list.append(1.4826 * np.median(np.abs(subset_dz - median_dz)))
        outlier_list.append(np.sum(np.abs(subset_dz) > outlier_threshold) / len(subset_dz))

    return np.array(centers), np.array(bias_list), np.array(sigma_list), np.array(outlier_list)

def draw_metric_page(df, title_prefix, filter_condition=None):
    """Draw a 2x3 grid of binned metrics, one line per object TYPE.

    The top row bins by ``MAG_i`` and the bottom row by ``Z``; the columns
    show bias, sigma_NMAD (log scale) and outlier percentage (log scale).
    Types with fewer than 10 rows are skipped.

    Args:
        df: Table with ``TYPE``, ``MAG_i``, ``Z`` and ``Z_PRED`` columns.
        title_prefix: Figure title.
        filter_condition: Optional boolean mask applied to *df* first.
    """
    data = df[filter_condition].copy() if filter_condition is not None else df.copy()
    unique_types = data['TYPE'].unique()

    fig, axs = plt.subplots(2, 3, figsize=(22, 16))
    fig.suptitle(title_prefix, fontsize=26, weight='bold')

    # One entry per subplot row: (row, binned column, x label, x limits, bin edges).
    panels = [
        (0, 'MAG_i', 'MAG_i', (17, 23.3), np.arange(17, 23.3 + 0.25, 0.25)),
        (1, 'Z', 'Z True', (0, 1.8), np.arange(0, 1.8 + 0.1, 0.1)),
    ]

    for t in unique_types:
        subset = data[data['TYPE'] == t]
        if len(subset) < 10:
            continue
        lbl = f"{t} (N={len(subset)})"
        color = TYPE_COLORS.get(t, 'white')

        for row, x_col, _, _, bins in panels:
            x, bias, sigma, outliers = compute_metrics_binned(subset, x_col, 'Z', 'Z_PRED', bins)
            # Small offset keeps a 0 % outlier rate drawable on the log axis.
            outlier_pct = (outliers * 100) + 1e-6
            for ax, y in zip(axs[row], (bias, sigma, outlier_pct)):
                ax.plot(x, y, label=lbl, color=color, marker='o', markersize=8)

    # --- Formatting ---
    # Raw strings: '\D' and '\s' are invalid escape sequences in normal
    # string literals and trigger a warning on recent Python versions.
    y_labels = (r'Bias $\Delta z$', r'$\sigma_{NMAD}$ (log)', 'Outlier % (log)')

    for row, _, x_label, x_limits, _ in panels:
        for col_idx, ax in enumerate(axs[row]):
            ax.grid(False)
            ax.set_xlabel(x_label)
            ax.set_xlim(*x_limits)
            ax.set_ylabel(y_labels[col_idx])
            if col_idx > 0:
                ax.set_yscale('log')

    # Only request a legend when at least one type was actually plotted.
    handles, _ = axs[0, 0].get_legend_handles_labels()
    if handles:
        axs[0, 0].legend(loc='lower left', frameon=True)

    plt.tight_layout()
    plt.show()

# Run Evaluation and Plotting
# Switch to eval mode (disables dropout) and skip autograd bookkeeping, so the
# predictions are deterministic and no graph is kept for the validation set.
model.eval()
with torch.no_grad():
    val_predictions = model(torch.FloatTensor(X_val).to(device))
df_val['Z_PRED'] = val_predictions.cpu().numpy().flatten()

# Rebuild the reference redshift column when the table does not provide 'Z'.
if 'Z' not in df_val.columns:
    df_val['Z'] = np.where(df_val['SPECTYPE'] == 2.0, df_val['Z_QSO'], df_val['Z_GAL'])

draw_metric_page(df_val, "Performance Metrics (All Objects)")