In [None]:
from pathlib import Path

import numpy as np
from astropy.table import Table
import pandas as pd
import matplotlib.pyplot as plt
import mpl_scatter_density
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, PolynomialFeatures

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

import torch.nn as nn
import torch.optim as optim
import os
import math
from plot_utils import plot_kiel_scatter_density

### Read in the cleaned APOGEE data and make a Kiel Diagram for the whole sample.

In [None]:
# Cleaned APOGEE catalogue; later cells use its TEFF, LOGG and FE_H columns.
apogee_path = Path("data") / "apogee_cleaned.parquet"
apogee_cat = pd.read_parquet(apogee_path)


In [None]:
# Kiel diagram of the whole cleaned sample: TEFF and LOGG as coordinates,
# FE_H passed as the colour quantity. A title is set so the figure stands alone.
fig_kiel, ax_kiel = plot_kiel_scatter_density(
    apogee_cat['TEFF'],
    apogee_cat['LOGG'],
    apogee_cat['FE_H'],
    title='Kiel Diagram - Full Sample',
)

### Train-test split and feature scaling for predicting [Fe/H] from TEFF and LOGG

In [None]:
# Build train/test sets for predicting [Fe/H] from (TEFF, LOGG) and standardize the features.
# Both sets are deliberately small random subsamples of the catalogue:
#   * training set: ~0.1% of all stars (test_size=0.999 holds out the other ~99.9%);
#   * test set: 0.2% of that held-out remainder.
TRAIN_HOLDOUT_FRACTION = 0.999    # fraction of stars NOT used for training
TEST_FRACTION_OF_HOLDOUT = 0.002  # fraction of the held-out stars kept as the test set
SPLIT_SEED = 42

X = apogee_cat[['TEFF', 'LOGG']]
y = apogee_cat['FE_H']

# First split: tiny training set vs. everything else.
X_train, X_holdout, y_train, y_holdout = train_test_split(
    X, y, test_size=TRAIN_HOLDOUT_FRACTION, random_state=SPLIT_SEED
)
# Second split: keep only a small slice of the held-out stars as the test set.
_, X_test, _, y_test = train_test_split(
    X_holdout, y_holdout, test_size=TEST_FRACTION_OF_HOLDOUT, random_state=SPLIT_SEED
)

# The scaler is fit on the training features only, then applied to the test features,
# so no test-set statistics leak into training.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train).astype("float64")
X_test_scaled = scaler.transform(X_test).astype("float64")
y_train = y_train.astype("float64")
y_test = y_test.astype("float64")

# Unscaled Kiel-diagram coordinates of the test stars; the MLP experiment loop
# references test_teff / test_logg, which were previously never defined.
test_teff = X_test['TEFF']
test_logg = X_test['LOGG']

assert len(X_train_scaled) > 0 and len(X_test_scaled) > 0, "train/test split produced an empty set"

# Print the sizes of the training and testing sets
print(f"Training set size: {X_train_scaled.shape[0]}")
print(f"Testing set size: {X_test_scaled.shape[0]}")

In [None]:
# Kiel diagram of the TRAINING set, coloured by its true FE_H labels
# (these are the training targets, not model predictions).
fig_kiel, ax_kiel = plot_kiel_scatter_density(
    X_train['TEFF'],
    X_train['LOGG'],
    y_train,
    title='Kiel Diagram - Train Set',
    colorbar_label='[Fe/H]',
    scatter=True,
)


In [None]:


# Define the MLP model using PyTorch Lightning
class MLP(pl.LightningModule):
    """Fully connected regression network trained with an MSE loss.

    Architecture: one ``nn.Linear`` + ``nn.ReLU`` pair per entry of ``network_shape``,
    followed by a final ``nn.Linear`` head with ``output_dim`` outputs. An empty
    ``network_shape`` yields a single linear layer.

    Args:
        network_shape (sequence of int): Width of each hidden layer, in order.
        input_dim (int): Number of input features.
        learning_rate (float): Adam learning rate. Default: 1e-3.
        output_dim (int): Number of regression outputs. Default: 1.
    """

    def __init__(self, network_shape, input_dim, learning_rate=1e-3, output_dim=1):
        super().__init__()
        # Store the constructor arguments in ``self.hparams`` (and in checkpoints) so that
        # ``MLP.load_from_checkpoint(path)`` can rebuild the model without re-passing them.
        self.save_hyperparameters()

        layers = []
        current_dim = self.hparams.input_dim
        for hidden_dim in self.hparams.network_shape:
            layers += [nn.Linear(current_dim, hidden_dim), nn.ReLU()]
            current_dim = hidden_dim
        layers.append(nn.Linear(current_dim, self.hparams.output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to a batch of feature rows."""
        return self.network(x)

    def _shared_step(self, batch, stage, prog_bar=False):
        """Compute the MSE loss of one ``(x, y)`` batch and log it epoch-wise as ``f"{stage}_loss"``."""
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=prog_bar)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        # Validation loss is the quantity callers monitor, so surface it in the progress bar.
        return self._shared_step(batch, 'val', prog_bar=True)

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Use Adam over all parameters with the stored learning rate."""
        return optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

# Function to fit MLP, make predictions, and plot Kiel diagram
def fit_mlp_regression(
    X_train_scaled_np, X_test_scaled_np, y_train_series, y_test_series,
    test_teff_series, test_logg_series, network_shape, **kwargs
):
    """
    Fit an MLP regression model with PyTorch Lightning and predict on the test set.

    The training data is split into train/validation subsets. The validation loss drives
    early stopping and checkpoint selection. The best checkpoint is evaluated on the test
    set and used for the returned predictions.

    Args:
        X_train_scaled_np (np.ndarray): Scaled training features, one row per sample.
        X_test_scaled_np (np.ndarray): Scaled test features, same columns as the training features.
        y_train_series (pd.Series): Training target values.
        y_test_series (pd.Series): Test target values.
        test_teff_series (pd.Series): Test set effective temperatures. Currently unused;
            kept in the signature for callers that pass Kiel-diagram coordinates.
        test_logg_series (pd.Series): Test set surface gravities. Currently unused (see above).
        network_shape (list of int): Number of neurons in each hidden layer.
        **kwargs: Additional hyperparameters for training.
            learning_rate (float): Optimizer learning rate. Default: 1e-3.
            batch_size (int): Batch size for DataLoaders. Default: 256.
            max_epochs (int): Maximum number of training epochs. Default: 100.
            patience (int): Patience for EarlyStopping. Default: 10.
            val_split_ratio (float): Fraction of training data to use for validation. Default: 0.2.
            num_workers (int): Number of workers for DataLoader. Default: 0.
            accelerator (str): PyTorch Lightning accelerator ('cpu', 'gpu', 'auto'). Default: 'auto'.
            devices (any): PyTorch Lightning devices. Default: 'auto'.
            checkpoint_dir (str): Directory to save model checkpoints. Default: 'mlp_checkpoints/'.
            seed (int): Seed for weight initialization, shuffling and the train/val split. Default: 42.

    Returns:
        tuple: (best_model, predicted_feh_mlp)
            best_model (MLP): The trained model (best checkpoint if one was saved), in eval mode.
            predicted_feh_mlp (np.ndarray): Flattened 1-D predictions on the test set.

    Raises:
        ValueError: If ``val_split_ratio`` leaves the train or validation subset empty.
    """
    # Hyperparameters from kwargs with defaults
    learning_rate = kwargs.get('learning_rate', 1e-3)
    batch_size = kwargs.get('batch_size', 256)
    max_epochs = kwargs.get('max_epochs', 100)
    patience = kwargs.get('patience', 10)
    val_split_ratio = kwargs.get('val_split_ratio', 0.2)
    num_workers = kwargs.get('num_workers', 0)  # 0 = load in the main process (broadest compatibility)
    accelerator = kwargs.get('accelerator', 'auto')
    devices = kwargs.get('devices', 'auto')
    checkpoint_dir = kwargs.get('checkpoint_dir', 'mlp_checkpoints/')
    seed = kwargs.get('seed', 42)

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Seed BEFORE the model is built. The original code seeded only after construction,
    # which left the initial weights unseeded and runs non-reproducible.
    pl.seed_everything(seed, workers=True)

    # Convert numpy arrays / pandas Series to float32 tensors; targets become (n, 1) columns.
    X_train_tensor = torch.tensor(X_train_scaled_np, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train_series.values.reshape(-1, 1), dtype=torch.float32)
    X_test_tensor = torch.tensor(X_test_scaled_np, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test_series.values.reshape(-1, 1), dtype=torch.float32)

    # Split the full training set into actual train and validation subsets.
    full_train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    num_train_samples = len(full_train_dataset)
    num_val_samples = int(val_split_ratio * num_train_samples)
    num_actual_train_samples = num_train_samples - num_val_samples
    if num_val_samples == 0 or num_actual_train_samples == 0:
        # EarlyStopping/ModelCheckpoint monitor val_loss, so both subsets must be non-empty.
        raise ValueError(
            f"val_split_ratio={val_split_ratio} on {num_train_samples} training samples gives "
            f"{num_actual_train_samples} train / {num_val_samples} validation samples; both must be > 0."
        )

    train_dataset, val_dataset = random_split(
        full_train_dataset, [num_actual_train_samples, num_val_samples],
        generator=torch.Generator().manual_seed(seed),  # reproducible split
    )

    persistent_workers_flag = num_workers > 0  # persistent workers require num_workers > 0
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, persistent_workers=persistent_workers_flag)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=persistent_workers_flag)
    test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=persistent_workers_flag)

    # Callbacks: stop on stalled val_loss and keep only the best checkpoint.
    early_stop_callback = EarlyStopping(monitor='val_loss', patience=patience, verbose=False, mode='min')
    shape_str = "_".join(map(str, network_shape))  # network shape embedded in the checkpoint filename
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=checkpoint_dir,
        filename=f'mlp-shape_{shape_str}-{{epoch:02d}}-{{val_loss:.4f}}',
        save_top_k=1,
        mode='min',
    )

    model = MLP(
        network_shape=network_shape,
        input_dim=X_train_scaled_np.shape[1],
        learning_rate=learning_rate
    )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[early_stop_callback, checkpoint_callback],
        accelerator=accelerator,
        devices=devices,
        logger=True,  # Uses TensorBoardLogger by default, logs to lightning_logs/
        enable_progress_bar=True,
        deterministic=True  # For reproducibility, might impact performance
    )

    print(f"\nTraining MLP with network shape: {network_shape}...")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # Reload the best checkpoint. If none was written, load_from_checkpoint("") would
    # fail, so fall back to the in-memory model.
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        print(f"Loading best model from: {best_model_path}")
        best_model = MLP.load_from_checkpoint(best_model_path)
    else:
        print("No checkpoint was saved; using the final in-memory model.")
        best_model = model

    test_results = trainer.test(best_model, dataloaders=test_loader, verbose=False)
    print(f"Test results for MLP shape {network_shape}: {test_results}")

    # Predict on the test set, placing the inputs on whatever device the model ended up on.
    best_model.eval()
    with torch.no_grad():
        predicted_feh_mlp_tensor = best_model(X_test_tensor.to(best_model.device))
    predicted_feh_mlp = predicted_feh_mlp_tensor.cpu().numpy().flatten()

    return best_model, predicted_feh_mlp


# Example usage: train one MLP per network shape and collect its test-set predictions.
# Relies on variables from earlier cells: X_train_scaled, X_test_scaled, y_train, y_test, X_test.
# The test-set Kiel coordinates are taken straight from X_test, so no extra globals are needed.

mlp_network_shapes = [
    [64],             # Single hidden layer with 64 neurons
    [64, 32],         # Two hidden layers: 64 neurons, then 32 neurons
    [128, 64, 32]     # Three hidden layers
]
mlp_results = {}

# Common hyperparameters for all MLP runs (see fit_mlp_regression for the full list).
mlp_hyperparams = {
    'learning_rate': 0.001,
    'batch_size': 512,
    'max_epochs': 50,  # Adjust as needed, can be short for demonstration
    'patience': 5,     # Adjust as needed
    'num_workers': 0,  # 0 for simplicity; raise (e.g. os.cpu_count()) if the environment supports it
    # Use the GPU only when one is available so the notebook also runs on CPU-only machines.
    'accelerator': 'gpu' if torch.cuda.is_available() else 'cpu',
    'devices': 1,
}

for shape in mlp_network_shapes:
    print("-" * 50)
    shape_key = f'shape_{"_".join(map(str, shape))}'

    # The function defined above is fit_mlp_regression; the old call used an undefined name.
    model_mlp, predictions_mlp = fit_mlp_regression(
        X_train_scaled, X_test_scaled, y_train, y_test,
        X_test['TEFF'], X_test['LOGG'],
        network_shape=shape,
        **mlp_hyperparams,
    )
    mlp_results[shape_key] = {
        'model': model_mlp,
        'predictions': predictions_mlp,
    }

print("-" * 50)
print("MLP regression experiments complete.")
print(f"Results stored in mlp_results dictionary with keys: {list(mlp_results.keys())}")