In [1]:
# ==============================================================
# Example usage of trained Neural Networks
# Author: Elena González Prieto
# ==============================================================

import sys
import torch
import numpy as np
from sample_NNs import ClassificationNeuralNetwork, RegressionNeuralNetwork
import torch.nn.functional as F



In [None]:
# --------------------------------------------------------------
# Model loading helpers
# --------------------------------------------------------------

def _load_checkpoint_into(model, checkpoint_path):
    """Load a saved checkpoint into ``model`` and switch the model to inference mode.

    Parameters
    ----------
    model : torch.nn.Module
        Freshly constructed network whose architecture matches the checkpoint.
    checkpoint_path : str
        Path to a file written with ``torch.save`` that contains a
        ``"model_state_dict"`` entry.

    Returns
    -------
    dict
        The full checkpoint, so callers can read any extra entries
        (e.g. the training normalization statistics).

    Raises
    ------
    KeyError
        If the checkpoint has no ``"model_state_dict"`` entry.
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # On recent PyTorch versions ``weights_only`` defaults to True, which may
    # refuse non-tensor entries (e.g. NumPy arrays); for a trusted file, pass
    # ``weights_only=False`` if loading fails for that reason.
    checkpoint = torch.load(
        checkpoint_path,
        map_location=torch.device("cpu"),  # load on CPU for demonstration
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    # Inference mode: disables training-only behaviour such as dropout or
    # batch-statistics updates, if the architecture contains such layers.
    model.eval()
    return checkpoint


def _as_numpy(value):
    """Return ``value`` as a NumPy array, whether it was saved as a tensor or an array."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


# --------------------------------------------------------------
# Classification Model
# --------------------------------------------------------------

classification_model = ClassificationNeuralNetwork()

# Load weights from the saved checkpoint and put the model in eval mode
classification_checkpoint = _load_checkpoint_into(
    classification_model,
    "sample_classmodel.pt",  # replace with your file path
)

# Normalization statistics (mean and std used during training). Converted to
# NumPy so the NumPy preprocessing below behaves the same whether the
# checkpoint stored them as tensors or as arrays.
input_mean = _as_numpy(classification_checkpoint["train_mean"])
input_std = _as_numpy(classification_checkpoint["train_std"])


# --------------------------------------------------------------
# Regression Model
# --------------------------------------------------------------

regression_model = RegressionNeuralNetwork()

# Load weights from the saved checkpoint and put the model in eval mode
regression_checkpoint = _load_checkpoint_into(
    regression_model,
    "sample_regmodel.pt",  # replace with your file path
)
# NOTE(review): the regression model is fed input normalized with the
# classification checkpoint's statistics. If the regression checkpoint stores
# its own train_mean/train_std, confirm they are identical.


In [None]:

# Data must be pre-processed in exactly the same way as the training data.
# Feature order and transformations expected by both models:
#   0: log10(Age [Gyr] + 0.001)
#   1: log10(Rp [Rsun] + 1)
#   2: log10(Vinf [km/s])
#   3: ln(Mass1 [Msun])
#   4: ln(Mass2 [Msun])

# --------------------------------------------------------------
# Example input data
# --------------------------------------------------------------

def encode_features(age_gyr, rp_rsun, vinf_kms, mass1_msun, mass2_msun):
    """Transform physical collision parameters into the models' raw feature vector.

    Parameters
    ----------
    age_gyr : float
        Stellar age in Gyr.
    rp_rsun : float
        Periapsis distance in Rsun.
    vinf_kms : float
        Vinf in km/s; must be > 0 (its log10 is taken).
    mass1_msun, mass2_msun : float
        Masses of the two stars in Msun; must be > 0 (their ln is taken).

    Returns
    -------
    numpy.ndarray
        float32 array of shape (5,), in the feature order listed above.
        The values are NOT yet normalized.
    """
    return np.array([
        np.log10(age_gyr + 0.001),
        np.log10(rp_rsun + 1.),
        np.log10(vinf_kms),
        np.log(mass1_msun),
        np.log(mass2_msun),
    ], dtype=np.float32)


# Example: two 1 Msun stars of age 5 Gyr, colliding at 100 km/s with a periapsis of 0.5 Rsun
star1_mass_msun = 1.
star2_mass_msun = 1.
data = encode_features(
    age_gyr=5.,
    rp_rsun=0.5,
    vinf_kms=100.,
    mass1_msun=star1_mass_msun,
    mass2_msun=star2_mass_msun,
)

# Total initial mass of the system [Msun], used to turn the regression outputs
# (mass fractions) back into masses. Computed from the raw masses rather than
# by exponentiating the float32 log features.
initial_masses = star1_mass_msun + star2_mass_msun

# --------------------------------------------------------------
# Normalization
# --------------------------------------------------------------

# Subtract the mean and divide by the standard deviation computed on the training set,
# so the new input has the same scaling as the data the model was trained on.
normalized_data = (data - input_mean) / input_std

# Convert to a float32 tensor and add a batch dimension (shape (1, 5) for per-feature
# statistics). torch.as_tensor accepts either a NumPy array or a tensor; the latter
# occurs if the checkpoint stored the statistics as tensors.
X = torch.as_tensor(normalized_data, dtype=torch.float32).unsqueeze(0)

In [None]:
# --------------------------------------------------------------
# Make Predictions
# --------------------------------------------------------------

# Classification labels
CLASS_LABELS = {
    0: "mutual destruction, no stars left",
    1: "merger, one star left",
    2: "fly-by, two stars left",
    3: "stripped star, one star left",
}

# Regression outputs (shape = [1, 3]), each a fraction of the total initial mass:
#   0: M1,final    [Msun] / Mtot,initial [Msun]
#   1: M2,final    [Msun] / Mtot,initial [Msun]
#   2: Mejec,final [Msun] / Mtot,initial [Msun]   (Mejec is the unbound mass)
# Multiply by the total initial mass of the system to recover masses in Msun.

# Inference only: eval() disables training-time behaviour (e.g. dropout, if the
# networks use it), and no_grad() skips building the autograd graph.
classification_model.eval()
regression_model.eval()
with torch.no_grad():
    classification_pred = classification_model(X)
    regression_pred = regression_model(X)

# Kept so any later cells using the original (misspelled) name still work.
classsification_pred = classification_pred

predicted_class = torch.argmax(classification_pred, dim=1).item()
predicted_values = regression_pred.squeeze(0).tolist()  # [1, 3] tensor -> list of 3 floats

star1_final_mass = predicted_values[0] * initial_masses
star2_final_mass = predicted_values[1] * initial_masses
ejected_mass = predicted_values[2] * initial_masses

print(f"Predicted class: {predicted_class} ({CLASS_LABELS.get(predicted_class, 'unknown label')})")
print(f"Predicted Star 1 final mass in [MSUN]: {star1_final_mass:.2f}")
print(f"Predicted Star 2 final mass in [MSUN]: {star2_final_mass:.2f}")
print(f"Predicted unbound (ejected) mass in [MSUN]: {ejected_mass:.2f}")

