# Stellar Parameter and Abundance Estimation using ML

## Introduction

This project began as a senior honours dissertation to analyse Milky Way Survey (MWS) spectra from the Dark Energy Spectroscopic Instrument (DESI) and estimate stellar parameters — effective temperature ($T_\text{eff}$), surface gravity ($\log g$), and metallicity [Fe/H] — along with detailed chemical abundances, using machine learning.

This work aims to improve current techniques for inferring these stellar quantities from observed spectra. More accurate estimates lead to better models of stellar populations and a deeper understanding of Galactic structure and evolution, particularly of how stellar populations differ with location in the Milky Way.

We apply two main ML strategies:

1. Random Forest Regression with PCA for dimensionality reduction.

2. Neural Networks for learning complex mappings from spectra to parameters.

Future work includes integrating physics-informed loss functions to guide the models using astrophysical priors and improve physical reliability.

This notebook walks through:

- Preprocessing the DESI spectra (e.g., Doppler correction, flux normalization),

- Training and evaluating ML models for parameter estimation,

- Visualizing model performance and residuals,

- Discussing implications and next steps.

```python
# Imports
import numpy as np
import pandas as pd
from joblib import load
import matplotlib.pyplot as plt
#from sklearn.metrics import root_mean_squared_error

# load configuration
from utils import load_config, setup_env
setup_env(load_config('config.yaml'))
```

## 🔍 3. Data Preprocessing

Preprocessing is a crucial step in this project to ensure the spectral data is clean, consistent, and suitable for input into machine learning models. The raw data consists of DESI **FITS files** containing flux measurements from three cameras (**B**, **R**, **Z**) across different wavelength ranges.

The preprocessing pipeline involves the following steps:

---

### 📁 3.1 Load FITS Spectra

- Each DESI observation is stored in its own FITS file, which is loaded with `astropy.io.fits`.
- Each file contains flux, wavelength, and masking information for three spectral bands:
  - **B-band**: blue wavelengths
  - **R-band**: red wavelengths
  - **Z-band**: near-infrared wavelengths

Each FITS file also contains metadata (RA/Dec) needed to cross-match with label data.

---

### 🌌 3.2 Cross-Match with APOGEE Catalog

To obtain ground-truth labels (e.g. [Fe/H], $\log g$), the DESI targets are **matched by sky coordinates** with high-fidelity **APOGEE DR17** data:
- RA/Dec from both catalogs are compared using `astropy.coordinates.SkyCoord`.
- Matches within **3 arcseconds** are accepted.
- For each matched star, the relevant **chemical abundance values** are extracted and stored.

This step ensures the spectra and the training labels refer to the **same stars**.
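
A dependency-free sketch of the matching logic is below. The notebook itself uses `SkyCoord`, where the equivalent call would be `desi_coords.match_to_catalog_sky(apogee_coords)` (hypothetical variable names); it returns the nearest-neighbour index and on-sky separation for each DESI target. The brute-force NumPy version here, with the hypothetical helper `crossmatch`, computes every pairwise separation with the haversine formula, so it is only practical for small catalogs:

```python
import numpy as np

def crossmatch(ra1, dec1, ra2, dec2, max_sep_arcsec=3.0):
    """Match each source in catalog 1 to its nearest neighbour in catalog 2.

    Coordinates in degrees. Returns (idx, sep_arcsec, matched), where
    matched is True when the nearest neighbour lies within max_sep_arcsec.
    """
    ra1, dec1, ra2, dec2 = (np.radians(np.asarray(a, dtype=float)) for a in (ra1, dec1, ra2, dec2))
    dra = ra1[:, None] - ra2[None, :]
    ddec = dec1[:, None] - dec2[None, :]
    # Haversine formula: numerically stable for arcsecond-scale separations
    h = np.sin(ddec / 2) ** 2 + np.cos(dec1)[:, None] * np.cos(dec2)[None, :] * np.sin(dra / 2) ** 2
    sep = np.degrees(2 * np.arcsin(np.sqrt(np.clip(h, 0.0, 1.0)))) * 3600.0
    idx = np.argmin(sep, axis=1)               # nearest catalog-2 source per row
    best = sep[np.arange(len(ra1)), idx]       # its separation in arcsec
    return idx, best, best <= max_sep_arcsec
```

Only rows with `matched == True` would keep their APOGEE labels; the rest are dropped from training.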

---

### 🧼 3.3 Wavelength Alignment & Trimming

Adjacent cameras overlap slightly in wavelength. To produce a single spectrum without duplicated pixels:
- The last **25 pixels** of the **B-band**, the first **26 pixels** and last **63 pixels** of the **R-band**, and the first **63 pixels** of the **Z-band** are removed.
- This removes the overlapping, edge-effect-prone regions and creates a continuous 1D wavelength grid.

All cameras are then **stitched together** into one full spectrum.
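
The trimming and stitching can be sketched with NumPy as below. The per-camera grids in the example are an assumption (the standard DESI coadd grids at 0.8 Å per pixel: B 3600 to 5800 Å, R 5760 to 7620 Å, Z 7520 to 9824 Å); with them, the trim counts above remove exactly the overlapping pixels, so the stitched grid is uniform. `trim_and_stitch` is a hypothetical helper:

```python
import numpy as np

# Pixels trimmed from each camera, as described above: (from start, from end)
TRIM = {"b": (0, 25), "r": (26, 63), "z": (63, 0)}

def trim_and_stitch(waves, fluxes):
    """Trim camera edges and concatenate B, R, Z into one spectrum.

    waves, fluxes: dicts keyed by "b", "r", "z" holding 1D arrays.
    """
    w_parts, f_parts = [], []
    for cam in ("b", "r", "z"):
        start, end = TRIM[cam]
        stop = len(waves[cam]) - end  # avoids the [:-0] pitfall when end == 0
        w_parts.append(waves[cam][start:stop])
        f_parts.append(fluxes[cam][start:stop])
    return np.concatenate(w_parts), np.concatenate(f_parts)

# Assumed DESI coadd wavelength grids (0.8 Å steps)
waves = {
    "b": 3600.0 + 0.8 * np.arange(2751),
    "r": 5760.0 + 0.8 * np.arange(2326),
    "z": 7520.0 + 0.8 * np.arange(2881),
}
fluxes = {cam: np.ones_like(w) for cam, w in waves.items()}

wave, flux = trim_and_stitch(waves, fluxes)  # expect 7781 pixels, ~3600-9824 Å
```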

---

### ⚠️ 3.4 Masking and Interpolation

Each camera includes a **mask array** indicating bad pixels. These pixels are:
- Marked as masked values (`np.ma.MaskedArray`)
- Interpolated using a 1D linear interpolation based on neighboring valid pixels

> This step ensures **no NaNs or masked values** remain in the flux array, which is critical for ML input.
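
A minimal sketch of this step follows, assuming that a nonzero mask value flags a bad pixel; `interpolate_masked` is a hypothetical helper, not the notebook's own function. Non-finite flux values are treated as bad as well:

```python
import numpy as np

def interpolate_masked(wave, flux, mask):
    """Fill bad pixels by linear interpolation from neighbouring good pixels.

    mask: per-pixel array where nonzero means bad (assumed convention).
    Returns a new array; the input flux is not modified.
    """
    flux = np.asarray(flux, dtype=float)
    # Flag mask bits and any NaN/inf values as bad
    flux_ma = np.ma.MaskedArray(flux, mask=(np.asarray(mask) != 0) | ~np.isfinite(flux))
    bad = np.ma.getmaskarray(flux_ma)
    out = np.array(flux_ma.filled(0.0))  # explicit copy; bad pixels set below
    # np.interp repeats the first/last good value for bad pixels at the edges
    out[bad] = np.interp(wave[bad], wave[~bad], out[~bad])
    return out
```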

---

### 🚀 3.5 Doppler Shift Correction

Each star's heliocentric **radial velocity** ($v_\text{helio}$) from APOGEE is used to correct for Doppler shifts:
- The observed wavelengths are shifted to the **rest frame** using the non-relativistic relation (accurate for $v_\text{helio} \ll c$):

  $$ \lambda_\text{rest} = \frac{\lambda_\text{obs}}{1 + v_\text{helio}/c} $$

- The flux is interpolated to a **common rest-frame wavelength grid**, ensuring consistency across stars.
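
A minimal sketch of the correction and resampling, assuming wavelengths in Å and velocities in km/s; `to_rest_frame` and the target grid are hypothetical, and linear interpolation with `np.interp` is one reasonable resampling choice rather than the notebook's confirmed scheme:

```python
import numpy as np

C_KMS = 299_792.458  # speed of light in km/s

def to_rest_frame(wave_obs, flux, v_helio, wave_grid):
    """Shift one spectrum to the rest frame and resample onto a shared grid.

    wave_obs, wave_grid in Angstrom (increasing); v_helio in km/s.
    """
    # Non-relativistic Doppler correction: lambda_rest = lambda_obs / (1 + v/c)
    wave_rest = wave_obs / (1.0 + v_helio / C_KMS)
    # Linear interpolation; grid points outside the spectrum get the edge values
    return np.interp(wave_grid, wave_rest, flux)
```

For scale, at 5000 Å a velocity of 100 km/s shifts a line by 5000 × 100 / 299792.458 ≈ 1.67 Å, enough to misalign features between stars if left uncorrected.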

---

### 🔃 3.6 Flux Normalization

The rest-frame flux is then **continuum-normalized**:
- A **moving median filter** (window size = 151 pixels) is applied to approximate the continuum.
- The flux is divided by this smoothed continuum, reducing the impact of instrumental effects and star-specific baselines.
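
The running-median normalization can be sketched with NumPy alone. `continuum_normalize` is a hypothetical name, and the reflect padding at the spectrum edges is an assumption; the notebook may instead use a library filter such as `scipy.ndimage.median_filter`. The window is odd so each median is centred on its pixel:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def continuum_normalize(flux, window=151):
    """Divide a spectrum by a running-median continuum estimate.

    Returns (normalized_flux, continuum). Edges are reflect-padded so the
    output has the same length as the input.
    """
    if window % 2 == 0:
        raise ValueError("window must be odd so the median is centred")
    flux = np.asarray(flux, dtype=float)
    half = window // 2
    padded = np.pad(flux, half, mode="reflect")
    # One row per pixel, each row holding the `window` pixels centred on it
    continuum = np.median(sliding_window_view(padded, window), axis=1)
    return flux / continuum, continuum
```

A narrow absorption line barely moves the median, so it survives normalization; features much broader than the window would partly be absorbed into the continuum estimate.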

---

### 💾 3.7 Save Processed Data

- All normalized, rest-frame flux arrays are stacked into a single matrix: `flux.joblib`, shape = (n_stars, n_pixels).
- Corresponding labels (stellar parameters and abundances) are saved in `labels.csv`.

---

### 🧪 Example: Plotting a Normalized Spectrum

```python
from joblib import load
import matplotlib.pyplot as plt

flux = load("data/spectral_dir/flux.joblib")
plt.figure(figsize=(10, 4))
plt.plot(flux[0])
plt.title("Normalized, Rest-Frame Spectrum (Star 0)")
plt.xlabel("Wavelength Index")
plt.ylabel("Normalized Flux")
plt.grid(True)
plt.show()
```


## 🤖 4. Modeling

The goal of this stage is to map the **preprocessed spectral flux** of a star to its **stellar parameters** (e.g., $T_\text{eff}$, $\log g$, [Fe/H]) and **chemical abundances** (e.g., [C/Fe], [Mg/Fe], etc.) using machine learning models.

Two different types of models are trained and evaluated:

- 🌲 **Random Forest Regression** with **Incremental PCA** (for dimensionality reduction)
- 🧠 **Feedforward Neural Networks** tuned using KerasTuner

---

### 🌲 4.1 Random Forest + PCA

High-dimensional spectral data (thousands of wavelength points) poses challenges for classical models. To address this:

- We use **Incremental PCA (IPCA)** to reduce flux dimensionality while preserving variance.
- Then, we apply **Random Forest Regression** — a robust, interpretable ensemble model.

This is implemented as a **`sklearn.Pipeline`**, which includes:
1. `StandardScaler()` for feature scaling,
2. `IncrementalPCA()` for reducing dimensions (e.g., `n_components=100`),
3. `RandomForestRegressor()` for prediction.

Hyperparameters such as the number of PCA components, tree depth, and leaf size are tuned using **`GridSearchCV`** with 5-fold cross-validation.
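
The pipeline and grid search can be sketched as below. The step names (`scaler`, `ipca`, `rf`), the grid values, and `n_estimators` are illustrative assumptions rather than the notebook's settings; the `step__parameter` keys are how `GridSearchCV` addresses parameters inside a `Pipeline`. Note that `IncrementalPCA` needs `n_components` to be no larger than the number of samples in each batch and the number of features:

```python
import numpy as np
from sklearn.decomposition import IncrementalPCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

def build_search(n_components=(50, 100, 200), max_depth=(None, 20),
                 min_samples_leaf=(1, 5), n_estimators=200):
    """Scaler -> IncrementalPCA -> RandomForest, tuned with 5-fold CV."""
    pipe = Pipeline([
        ("scaler", StandardScaler()),   # zero mean, unit variance per pixel
        ("ipca", IncrementalPCA()),     # compress thousands of pixels
        ("rf", RandomForestRegressor(n_estimators=n_estimators, random_state=0)),
    ])
    # "<step>__<param>" keys reach into the named pipeline steps
    grid = {
        "ipca__n_components": list(n_components),
        "rf__max_depth": list(max_depth),
        "rf__min_samples_leaf": list(min_samples_leaf),
    }
    return GridSearchCV(pipe, grid, cv=5, scoring="neg_root_mean_squared_error")

# Usage (flux: (n_stars, n_pixels) array, y: one label per star, no NaNs):
# search = build_search().fit(flux, y)
# print(search.best_params_, -search.best_score_)
```

Wrapping scaling and PCA inside the pipeline means they are refit on each training fold, so no information from the validation fold leaks into the components.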

#### 🛠 Example: Load a trained model and predict Teff

```python
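import numpy as np
import pandas as pd
from joblib import load
from sklearn.metrics import root_mean_squared_error  # scikit-learn >= 1.4

flux = load("data/spectral_dir/flux.joblib")
labels = pd.read_csv("data/label_dir/labels.csv")
y_true = labels["teff"].to_numpy()
mask = ~np.isnan(y_true)  # keep only stars with a Teff label

model = load("models/teff_model.joblib")  # saved Pipeline: scaler -> IPCA -> RF
y_pred = model.predict(flux[mask])

# If the model was trained on stars from this same file, this RMSE includes
# training stars and will be optimistic; use a held-out split for a fair number.
rmse = root_mean_squared_error(y_true[mask], y_pred)
print(f"Teff RMSE (Random Forest): {rmse:.2f} K")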
```

---

### 🧠 4.2 Neural Network (NN) Models

To model non-linear relationships in the spectral data, we also train **feedforward neural networks (FNNs)** using **TensorFlow/Keras**.

The neural network architecture is selected using **KerasTuner**, which explores:
- Number of hidden units (32–128),
- Dropout rate (0–0.5),
- Learning rate (1e-4 to 1e-2).

**Architecture Summary:**
- Input layer = flattened flux vector
- 2 hidden layers (ReLU activations, Dropout optional)
- Output layer = scalar (parameter or abundance value)

Training uses:
- `Adam` optimizer
- `EarlyStopping` and `ReduceLROnPlateau` callbacks
- Loss = mean squared error

#### 🛠 Example: Predict with trained NN model

```python
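import numpy as np
import pandas as pd
from joblib import load
from sklearn.metrics import root_mean_squared_error  # scikit-learn >= 1.4
from tensorflow.keras.models import load_model

flux = load("data/spectral_dir/flux.joblib")
labels = pd.read_csv("data/label_dir/labels.csv")
y_true = labels["teff"].to_numpy()
mask = ~np.isnan(y_true)  # keep only stars with a Teff label

model = load_model("models/teff_model.keras")
y_pred = model.predict(flux[mask]).flatten()  # (n, 1) output -> (n,)

# If the model was trained on stars from this same file, this RMSE includes
# training stars and will be optimistic; use a held-out split for a fair number.
rmse = root_mean_squared_error(y_true[mask], y_pred)
print(f"Teff RMSE (Neural Network): {rmse:.2f} K")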
```

---

### 📉 4.3 Residual Analysis & Model Comparison

For each parameter, we compute:
- **Predicted value**
- **Residual** = true value – predicted value
- **RMSE**: Root Mean Squared Error

These are visualized to assess systematic offsets and model uncertainty.
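
These quantities can be computed with a few lines of NumPy. The sketch below uses a hypothetical helper and also reports the mean residual (an addition not in the list above), which separates a systematic offset from random scatter; RMSE alone mixes the two:

```python
import numpy as np

def residual_stats(y_true, y_pred):
    """Residuals (true - predicted), mean offset, and RMSE.

    Pairs where either value is NaN (e.g. a missing APOGEE label) are skipped.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ok = np.isfinite(y_true) & np.isfinite(y_pred)
    resid = y_true[ok] - y_pred[ok]
    bias = resid.mean()                    # systematic offset
    rmse = np.sqrt(np.mean(resid ** 2))    # overall error, offset included
    return resid, bias, rmse
```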

#### 🛠 Example: Plot residuals vs. [Fe/H]

```python
import matplotlib.pyplot as plt
import pandas as pd

residuals = pd.read_csv("residuals/teff.csv")

plt.scatter(residuals["[Fe/H]"], residuals["Residuals"], alpha=0.4)
plt.axhline(0, color='gray', linestyle='--')
plt.title("Teff Residuals vs [Fe/H]")
plt.xlabel("[Fe/H]")
plt.ylabel("Teff Residuals (K)")
plt.grid(True)
plt.show()
```

---

### 🧪 4.4 Performance Summary

A full summary of RMSE values across all parameters is saved in:

```python
errors = pd.read_csv("output/results/predictions_errors.csv")
errors.sort_values("RMSE")
```

This helps identify which stellar parameters are easier/harder to predict, and which model (RF or NN) performs better for each.

---

Next, we outline how physical constraints could be built into these models, before returning to the results in Section 6.


## 🧬 5. Toward Physics-Informed Neural Networks (PINNs)

While classical machine learning models (like Random Forests and standard Neural Networks) can effectively learn complex mappings between spectral features and stellar labels, they do so **agnostically**, without any knowledge of stellar physics. This can lead to models that:

- Produce physically inconsistent predictions (e.g., [C/N] trends with mass or age in red giants that contradict first-dredge-up expectations),
- Generalize poorly in low-data regimes,
- Latch onto spurious correlations.

To address this, we are exploring the use of **Physics-Informed Neural Networks (PINNs)**.

---

### 🤔 What is a PINN?

A **Physics-Informed Neural Network** incorporates **domain-specific physical constraints** into the loss function of a neural network. Instead of minimizing error purely based on training labels (supervised loss), the model is also penalized for violating **physical laws or empirical relationships**.

In our case, candidate physics-based constraints include:

| Constraint | Rationale |
|-----------|-----------|
| Spectral Flux Conservation | Total energy across wavelength range should be bounded |
| [C/N] vs. $\log g$ or age | Carbon and nitrogen abundances evolve predictably in red giants |
| Known stellar relations (e.g., $g = GM/R^2$, so $\log g = \log M - 2\log R + \text{const.}$) | Consistency with stellar structure theory |
| Abundance gradients vs. [Fe/H] | Enforcing chemical evolution trends |
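
The stellar-structure row rests on $g = GM/R^2$; in the usual logarithmic units this becomes $\log g = \log g_\odot + \log(M/M_\odot) - 2\log(R/R_\odot)$. The sketch below evaluates it with cgs constants (hypothetical helper name). Using it as a loss term would require independent mass and radius estimates, which this notebook does not provide:

```python
import numpy as np

# Solar surface gravity in cgs (g in cm s^-2)
G_CGS = 6.674e-8      # cm^3 g^-1 s^-2
M_SUN_G = 1.989e33    # g
R_SUN_CM = 6.957e10   # cm
LOGG_SUN = np.log10(G_CGS * M_SUN_G / R_SUN_CM**2)  # about 4.44

def logg_from_mass_radius(mass_msun, radius_rsun):
    """log g from g = G M / R^2, with M and R in solar units."""
    return LOGG_SUN + np.log10(mass_msun) - 2.0 * np.log10(radius_rsun)
```

For example, a $1\,M_\odot$ giant with $R = 10\,R_\odot$ has $\log g \approx 2.44$, two dex below the Sun.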

---

### 🏗 PINN Loss Function Design

The total loss $\mathcal{L}_{\text{total}}$ combines:

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{data}} + \lambda_{\text{phys}} \cdot \mathcal{L}_{\text{physics}}
$$

- **$\mathcal{L}_{\text{data}}$**: standard MSE between predicted and true values.
- **$\mathcal{L}_{\text{physics}}$**: penalty for violating known astrophysical relations.
- **$\lambda_{\text{phys}}$**: a tunable weight controlling how strongly physics is enforced.

---

### ⚙️ Implementation Roadmap

Here’s the plan for integrating PINNs into this project:

1. **Define Physics Constraints**  
   For example, encode the [C/N] vs. $\log g$ relation as a differentiable loss term.

2. **Extend Neural Network Training Code**  
   Modify `train_neural_network()` to include a custom loss function combining MSE and physics loss.

3. **Experiment with λ Weighting**  
   Perform hyperparameter tuning to find a balance between data fit and physics regularization.

4. **Evaluate Performance & Interpretability**  
   Assess whether PINNs reduce residuals, improve physical trends, and generalize better on edge cases.

---

### 📘 Example: Adding a Physics Term

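Here's a conceptual sketch of what a physics-informed loss could look like in TensorFlow. The column layout of `y_pred` is a hypothetical choice, and because `total_loss` takes `logg` as an extra input it would need a custom training loop (or a closure over `logg`) rather than being passed directly to `model.compile`:

```python
import tensorflow as tf

# Hypothetical output layout: column 0 = [C/Fe], column 1 = [N/Fe]
C_FE, N_FE = 0, 1

def physics_loss(y_pred, logg):
    # Toy constraint: penalize [C/N] above a logg-dependent ceiling of 0.75 - 0.5*logg
    c_n = y_pred[:, C_FE] - y_pred[:, N_FE]
    penalty = tf.nn.relu(c_n + 0.5 * (logg - 1.5))  # zero when the constraint holds
    return tf.reduce_mean(penalty)

def total_loss(y_true, y_pred, logg, lambda_phys=0.1):
    # Data term (MSE) plus the weighted physics term
    mse = tf.reduce_mean(tf.square(y_true - y_pred))
    phys = physics_loss(y_pred, logg)
    return mse + lambda_phys * phys
```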

---

### 🧠 Why PINNs Matter

Incorporating physics is expected to:
- Improve **model interpretability**,
- Reduce **overfitting to noisy training labels**,
- Make the models more **robust to outliers and extrapolation**,
- Help bridge **theory and data-driven methods** in stellar astrophysics.

---

### 🔮 Future Work

- Incorporate more complex constraints (e.g., HR diagram priors, stellar evolution tracks).
- Apply PINNs to low-S/N spectra and rare stellar populations.
- Benchmark PINN models vs. standard NNs across different galactic environments.

---

PINNs represent an exciting next step in combining astrophysical knowledge with the flexibility of deep learning. Work is ongoing to prototype and evaluate their full integration in this framework.


## 📊 6. Results & Evaluation

After training both **Random Forest** and **Neural Network** models, we assess their performance on a set of stellar parameters and chemical abundances. The key metric used for evaluation is the **Root Mean Squared Error (RMSE)**:

$$
\text{RMSE} = \sqrt{ \frac{1}{N} \sum_{i=1}^N (y_i - \hat{y}_i)^2 }
$$

This gives a direct measure of prediction error in physical units (e.g., Kelvin for $T_\text{eff}$, dex for abundances).

---

### 📋 6.1 Summary of Model Performance

The code below assembles the RMSE values for each parameter and model into a single table:

```python
import pandas as pd
rf_results = pd.read_csv("output/results/predictions_errors.csv")  # Random Forest
rf_results["Model"] = "Random Forest"

nn_results = pd.read_csv("results/predictions_errors.csv")  # Neural Network
nn_results["Model"] = "Neural Network"

summary = pd.concat([rf_results, nn_results])
summary = summary.pivot(index="Parameter", columns="Model", values="RMSE").sort_index()
summary
```

---

### ⚖️ 6.2 Random Forest vs Neural Network

This bar plot compares the RMSE for each parameter across both model types:

```python
import matplotlib.pyplot as plt

summary.plot(kind='bar', figsize=(14, 6))
plt.title("Model Performance: Random Forest vs Neural Network")
plt.ylabel("RMSE")
plt.xticks(rotation=45, ha='right')
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()
```

> 📌 **Observation**: For parameters like $T_\text{eff}$ and [Fe/H], both models may perform similarly. For abundances with weaker spectral signatures (e.g., [C/Fe], [Ti/Fe]), neural networks may capture subtler features, since they learn from the full flux vector rather than from a fixed set of PCA components.

---

### 📉 6.3 Residual Analysis

We visualize the residuals (true - predicted) for a few key parameters to identify systematic biases or model failures.

#### 🔎 Example: Teff Residuals

```python
res = pd.read_csv("residuals/teff.csv")
plt.figure(figsize=(8, 5))
plt.scatter(res["[Fe/H]"], res["Residuals"], alpha=0.4)
plt.axhline(0, color="black", linestyle="--")
plt.title("Teff Residuals vs [Fe/H]")
plt.xlabel("[Fe/H]")
plt.ylabel("Residuals (K)")
plt.grid(True)
plt.show()
```

You can repeat this for other parameters like `logg`, `[Mg/Fe]`, etc., to spot patterns.
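
Scatter plots of thousands of stars can hide small systematic offsets. One way to make such trends visible is a median of the residuals in bins of [Fe/H], sketched below with a hypothetical helper; the bin edges in the usage comment are illustrative:

```python
import numpy as np

def binned_median(x, y, bin_edges):
    """Median of y in bins of x, e.g. Teff residuals in bins of [Fe/H].

    Returns bin centres, medians (NaN for empty bins) and counts per bin.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    edges = np.asarray(bin_edges, dtype=float)
    n_bins = len(edges) - 1
    idx = np.digitize(x, edges) - 1  # bin index; out of range -> -1 or n_bins
    centers = 0.5 * (edges[:-1] + edges[1:])
    med = np.full(n_bins, np.nan)
    counts = np.zeros(n_bins, dtype=int)
    for i in range(n_bins):
        in_bin = idx == i
        counts[i] = in_bin.sum()
        if counts[i] > 0:
            med[i] = np.median(y[in_bin])
    return centers, med, counts

# Usage with the residual file above:
# centers, med, n = binned_median(res["[Fe/H]"], res["Residuals"], np.arange(-2.5, 0.75, 0.25))
# plt.plot(centers, med, "o-")
```

Medians are less sensitive than means to the occasional catastrophic prediction, and the counts show which bins are too sparse to trust.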

---

### 🌍 6.4 Trends vs Metallicity

Plotting predictions against [Fe/H] can reveal physical consistency:

```python
teff_preds = pd.read_csv("results/teff_predictions.csv")
plt.figure(figsize=(8, 5))
plt.scatter(teff_preds["[Fe/H]"], teff_preds["teff"], alpha=0.5, s=10)
plt.title("Predicted Teff vs [Fe/H]")
plt.xlabel("[Fe/H]")
plt.ylabel("Teff (K)")
plt.grid(True)
plt.show()
```

---

### 🧠 6.5 General Observations

- **Teff and [Fe/H]** are predicted with relatively low error, suggesting strong signal in the flux.
- **Surface gravity ($\log g$)** is often harder to estimate and may benefit from PINN regularization.
- **Abundance predictions** (e.g., [C/Fe], [Mg/Fe]) show greater spread, particularly for rare elements or low-S/N spectra.

---

### ✅ 6.6 Evaluation Summary

| Metric         | Comment |
|----------------|---------|
| **RMSE**       | Used to quantify overall error per parameter |
| **Residuals**  | Help identify systematic under/overestimation |
| **Trend plots**| Useful for checking astrophysical consistency |
| **Model Comparison** | Neural networks often outperform RFs on more subtle parameters, but RFs are faster and more interpretable |

---

In the next section, we discuss scientific interpretation and future improvements — including using Physics-Informed models to boost physical reliability.
