In [None]:
"""
Created on Fri Jan 19

@author: mginolfi

"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout,Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import callbacks
import tensorflow as tf
import scipy.ndimage
from sklearn.model_selection import train_test_split
from scipy.interpolate import interp1d



In [None]:
""" Build the training features (X) and labels (Y) together with their normalisations """

df_train = pd.read_pickle('train_dataset.pickle')

# --- Features -------------------------------------------------------------
# One row per object: the per-object spectra are stacked into a single array.
all_spectra_train = np.stack(df_train['combined_spectrum'].values)

# Scale by the single largest value found anywhere in the training spectra.
# The same constant is reused for the validation and test spectra.
max_value_train = np.max(all_spectra_train)
all_spectra_train_normalized = all_spectra_train / max_value_train
X_train = all_spectra_train_normalized

# --- Labels ---------------------------------------------------------------
# Columns read: 'z', 'log_m' and 'sfr' (the latter is converted to log10).
all_redshift_train = df_train['z'].values
all_stellar_masses_train = df_train['log_m'].values
all_sfr_train = np.log10(df_train['sfr'].values)

# Y has one column per label, in the order (z, log_m, log10(sfr)).
Y_train = np.stack(
    [all_redshift_train, all_stellar_masses_train, all_sfr_train],
    axis=1,
)

# Per-column statistics of the training labels; also used to standardise
# the validation and test labels.
Y_train_mean = np.mean(Y_train, axis=0)
Y_train_std = np.std(Y_train, axis=0)

# Standardised training labels (zero mean, unit variance per column).
Y_train_normalized = (Y_train - Y_train_mean) / Y_train_std

# The DataFrame is no longer needed once the arrays have been extracted.
del df_train

In [None]:
""" Build the validation features (X) and labels (Y), normalised with the training statistics """

df_val = pd.read_pickle('validation_dataset.pickle')

# --- Features -------------------------------------------------------------
# Stack the per-object spectra and scale them with the *training* maximum,
# so that training and validation inputs share one scale.
all_spectra_val = np.stack(df_val['combined_spectrum'].values)
all_spectra_val_normalized = all_spectra_val / max_value_train
X_val = all_spectra_val_normalized

# --- Labels ---------------------------------------------------------------
# Same columns and same order as the training labels: (z, log_m, log10(sfr)).
all_redshift_val = df_val['z'].values
all_stellar_masses_val = df_val['log_m'].values
all_sfr_val = np.log10(df_val['sfr'].values)

Y_val = np.stack(
    [all_redshift_val, all_stellar_masses_val, all_sfr_val],
    axis=1,
)

# Standardise with the training mean / std (not the validation ones).
Y_val_normalized = (Y_val - Y_train_mean) / Y_train_std

# The DataFrame is no longer needed once the arrays have been extracted.
del df_val

In [None]:
""" load test set, make X & Y dataset and normalise"""

# NOTE: unlike the training/validation frames, df_test is not deleted at the
# end of this cell, so it stays available for later look-ups.
df_test = pd.read_pickle('test_dataset.pickle')

# Extract features for test set
all_spectra_test = np.stack(df_test['combined_spectrum'].values)

# Normalization of test spectra (with the training maximum, so that all
# three splits share the same input scale)
all_spectra_test_normalized = all_spectra_test / max_value_train

# make X test
X_test = all_spectra_test_normalized

# Extract labels for test set
all_redshift_test = df_test['z'].values
all_stellar_masses_test = df_test['log_m'].values
all_sfr_test = np.log10(df_test['sfr'].values)

# define labels, column order: (z, log_m, log10(sfr))
Y_test = np.column_stack((all_redshift_test, all_stellar_masses_test, all_sfr_test))

# Normalize test labels with the training mean / std
Y_test_normalized = (Y_test - Y_train_mean) / Y_train_std

# Read the wavelength axis, needed below for checks & visualisations.
# Only the first object's grid is used, so read that row directly instead of
# stacking the whole column first.
# NOTE(review): assumes every object shares the same wavelength grid — confirm.
wavelength_axis = np.asarray(df_test['combined_vacuumWave'].iloc[0])

In [None]:
""""""""""""""""""""""""""""""""
"""   Modelling              """
""""""""""""""""""""""""""""""""

from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def create_joint_model(input_shape, encoding_dim):
    """Build the joint autoencoder + multi-task regression network.

    A shared dense encoder feeds a latent layer; from that latent layer two
    branches start: a dense decoder that reconstructs the input, and a small
    regression trunk with three single-unit linear heads.

    Parameters
    ----------
    input_shape : int or 1-element tuple/list
        Number of input features per sample. A bare integer (as obtained from
        ``X.shape[1]``) or a one-element shape such as ``(n_features,)`` are
        both accepted.
    encoding_dim : int
        Number of units of the latent layer.

    Returns
    -------
    tensorflow.keras.Model
        Uncompiled model with one input and four outputs, in this order:
        ``decoded_output``, ``redshift``, ``stellar_mass``, ``sfr``.

    Raises
    ------
    ValueError
        If ``input_shape`` is a tuple/list that does not have exactly one
        element (the network is fully dense and only handles flat inputs).
    """
    # Normalise the shape argument: Input() is given an explicit tuple (a bare
    # int is not accepted as a shape by every Keras release), while the final
    # decoder layer needs the plain integer width.
    if isinstance(input_shape, (tuple, list)):
        if len(input_shape) != 1:
            raise ValueError(
                f"create_joint_model expects flat inputs; got input_shape={input_shape!r}"
            )
        n_features = int(input_shape[0])
    else:
        n_features = int(input_shape)

    # Encoder
    inputs = Input(shape=(n_features,))
    encoded = Dense(1024, activation='elu')(inputs)
    encoded = Dropout(0.3)(encoded)
    encoded = Dense(512, activation='elu')(encoded)
    encoded = Dropout(0.3)(encoded)

    # Latent space (shared by the decoder and the regression branch)
    encoded = Dense(encoding_dim, activation='elu')(encoded)

    # Decoder: mirrors the encoder and maps back to the input width.
    # NOTE(review): the sigmoid bounds the reconstruction to (0, 1); this
    # assumes the normalised input spectra lie in that range — confirm.
    decoded = Dense(512, activation='elu')(encoded)
    decoded = Dropout(0.3)(decoded)
    decoded = Dense(1024, activation='elu')(decoded)
    decoded = Dropout(0.3)(decoded)
    decoded = Dense(n_features, activation='sigmoid', name='decoded_output')(decoded)

    # Regression trunk, fed from the latent layer
    regression = Dense(64, activation='elu')(encoded)
    regression = Dropout(0.3)(regression)

    # Task-specific outputs: one linear unit per target
    redshift_output = Dense(1, activation='linear', name='redshift')(regression)
    stellar_mass_output = Dense(1, activation='linear', name='stellar_mass')(regression)
    sfr_output = Dense(1, activation='linear', name='sfr')(regression)

    # Combined model; the output order here fixes the order of predict() results
    combined_model = Model(inputs, [decoded, redshift_output, stellar_mass_output, sfr_output])

    return combined_model

# Settings for the joint network
task_names = ['redshift', 'stellar_mass', 'sfr']  # names of the regression heads
encoding_dim = 1000                               # units in the latent layer
input_shape = X_train.shape[1]                    # length of the second axis of X_train

# Instantiate the joint model and print its layer-by-layer summary
joint_model = create_joint_model(input_shape=input_shape, encoding_dim=encoding_dim)
joint_model.summary()

In [None]:
""" Class for metrics tracking """
from tensorflow.keras.callbacks import Callback

class MetricsPlotter(Callback):
    """Keras callback that records losses/MAEs per epoch and redraws them after every epoch.

    Parameters
    ----------
    task_names : iterable of str
        Names of the regression outputs. For each name ``t`` the callback
        reads ``f'{t}_mae'`` and ``f'val_{t}_mae'`` from the epoch logs.
    include_autoencoder : bool, default False
        When True, ``decoded_output_loss`` / ``val_decoded_output_loss`` are
        recorded as well and drawn in their own panel.

    Attributes
    ----------
    train_loss, val_loss : list
        Total loss per epoch (``loss`` / ``val_loss`` log entries).
    task_metrics : dict
        ``{task: {'train_mae': [...], 'val_mae': [...]}}``.
    autoencoder_loss : dict
        ``{'train': [...], 'val': [...]}``; only present when
        ``include_autoencoder`` is True.
    """

    def __init__(self, task_names, include_autoencoder=False):
        # Let the Keras base class initialise its own state first.
        super().__init__()
        self.include_autoencoder = include_autoencoder
        self.train_loss = []
        self.val_loss = []
        self.task_metrics = {task: {'train_mae': [], 'val_mae': []} for task in task_names}
        if self.include_autoencoder:
            self.autoencoder_loss = {'train': [], 'val': []}

    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's values to the history and redraw the figure.

        Entries missing from ``logs`` are stored as ``None``.
        """
        # The signature allows logs=None; treat that as "nothing logged".
        logs = logs or {}

        self.train_loss.append(logs.get('loss'))
        self.val_loss.append(logs.get('val_loss'))
        if self.include_autoencoder:
            self.autoencoder_loss['train'].append(logs.get('decoded_output_loss'))
            self.autoencoder_loss['val'].append(logs.get('val_decoded_output_loss'))

        for task in self.task_metrics.keys():
            self.task_metrics[task]['train_mae'].append(logs.get(f'{task}_mae'))
            self.task_metrics[task]['val_mae'].append(logs.get(f'val_{task}_mae'))

        self.plot_metrics(epoch)

    def plot_metrics(self, epoch):
        """Draw one panel for the total loss, one for the autoencoder (optional) and one per task.

        ``epoch`` is kept for interface compatibility; the x axis is derived
        from the length of the recorded series, so the plot stays consistent
        even if the same callback instance is reused for several ``fit`` calls.
        """
        num_tasks = len(self.task_metrics) + (1 if self.include_autoencoder else 0)
        num_plots = num_tasks + 1
        cols = 2
        rows = (num_plots + cols - 1) // cols

        # squeeze=False guarantees a 2-D axes array even when rows == 1.
        # No plt.clf() beforehand: with no current figure it would create (and
        # later display) an extra empty figure.
        fig, axs = plt.subplots(rows, cols, figsize=(12, 4 * rows), squeeze=False)
        panels = list(axs.flat)  # row-major order: (0,0), (0,1), (1,0), ...

        def draw(ax, train_series, val_series, train_label, val_label, title):
            ax.plot(range(len(train_series)), train_series, label=train_label)
            ax.plot(range(len(val_series)), val_series, label=val_label)
            ax.set_title(title)
            ax.legend()

        draw(panels[0], self.train_loss, self.val_loss,
             'Training Loss', 'Validation Loss', 'Total Loss')

        plot_index = 1
        if self.include_autoencoder:
            draw(panels[plot_index],
                 self.autoencoder_loss['train'], self.autoencoder_loss['val'],
                 'Autoencoder Training Loss', 'Autoencoder Validation Loss',
                 'Autoencoder Loss')
            plot_index += 1

        for task, metrics in self.task_metrics.items():
            draw(panels[plot_index],
                 metrics['train_mae'], metrics['val_mae'],
                 f'{task} Training MAE', f'{task} Validation MAE',
                 f'{task.capitalize()} MAE')
            plot_index += 1

        # Hide grid cells that were allocated but not used (odd panel count).
        for unused_ax in panels[plot_index:]:
            unused_ax.set_visible(False)

        fig.tight_layout()
        plt.show()
        # A new figure is created every epoch; release this one explicitly so
        # figures do not accumulate over a long training run.
        plt.close(fig)

In [None]:
#%%
""" Compile the joint model and train it """

# Every output (reconstruction and the three regression heads) uses MSE.
losses = {
    output_name: 'mse'
    for output_name in ('decoded_output', 'redshift', 'stellar_mass', 'sfr')
}

# Relative weight of each output's loss in the total loss; tune as needed.
loss_weights = dict(decoded_output=1, redshift=2, stellar_mass=0.5, sfr=0.5)

joint_model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss=losses,
    loss_weights=loss_weights,
    # MAE is tracked for the regression heads only, not for the reconstruction.
    metrics={head: 'mae' for head in ('redshift', 'stellar_mass', 'sfr')},
)

# Early stopping watches the validation loss of the redshift head (not the
# total validation loss): stop after 10 epochs without improvement, log when
# that happens, and roll the weights back to the best epoch.
early_stopping = callbacks.EarlyStopping(
    monitor='val_redshift_loss',
    patience=10,
    verbose=1,
    restore_best_weights=True,
)

# Live plotting of the total loss, the autoencoder loss and the per-task MAEs.
metrics_plotter = MetricsPlotter(task_names=task_names, include_autoencoder=True)

# Targets follow the model's output order: the input itself for the
# reconstruction, then one label column per regression head.
train_targets = [X_train] + [Y_train_normalized[:, col] for col in range(3)]
val_targets = [X_val] + [Y_val_normalized[:, col] for col in range(3)]

# Train the model.
# NOTE(review): batch_size=1026 looks like it might be a typo for 1024 — confirm.
history = joint_model.fit(
    x=X_train,
    y=train_targets,
    validation_data=(X_val, val_targets),
    batch_size=1026,
    epochs=300,
    shuffle=True,
    callbacks=[early_stopping, metrics_plotter],
)


In [None]:
""" Check history """

def plot_history(history, task, early_stopping_epoch=None):
    """Plot training and validation curves stored in a Keras ``History`` object.

    Parameters
    ----------
    history : object with a ``history`` dict
        Typically the return value of ``model.fit``; curves are read from
        ``history.history``.
    task : str
        ``'total'`` plots the total loss (``loss`` / ``val_loss``). Any other
        value is used as an output name: ``<task>_loss`` / ``val_<task>_loss``
        are plotted and, when ``<task>_mae`` / ``val_<task>_mae`` were also
        recorded, a second panel shows the MAE.
    early_stopping_epoch : int, optional
        If given, a dashed vertical line is drawn at this epoch in every panel.

    Raises
    ------
    KeyError
        If the loss curves for ``task`` are not present in the history. The
        check happens before any figure is created.
    """
    curves = history.history

    # Work out which keys are needed and fail early, before opening a figure.
    if task == 'total':
        loss_keys = ('loss', 'val_loss')
    else:
        loss_keys = (task + '_loss', 'val_' + task + '_loss')
    missing = [key for key in loss_keys if key not in curves]
    if missing:
        raise KeyError(f"history has no entries {missing} for task {task!r}")

    def draw_panel(position, train_key, val_key, title, ylabel):
        # One subplot: train curve, validation curve and optional epoch marker.
        plt.subplot(*position)
        plt.plot(curves[train_key])
        plt.plot(curves[val_key])
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('Epoch')
        # The second curve comes from the validation data passed to fit().
        plt.legend(['Train', 'Validation'], loc='upper left')
        if early_stopping_epoch is not None:
            plt.axvline(x=early_stopping_epoch, color='gray', linestyle='--')

    plt.figure(figsize=(12, 4))

    if task == 'total':
        draw_panel((1, 1, 1), 'loss', 'val_loss', 'Total Model Loss', 'Loss')
    else:
        # Outputs compiled without an MAE metric only get the loss panel.
        mae_keys = (task + '_mae', 'val_' + task + '_mae')
        has_mae = all(key in curves for key in mae_keys)
        n_panels = 2 if has_mae else 1

        draw_panel((1, n_panels, 1), loss_keys[0], loss_keys[1],
                   'Model Loss for ' + task.capitalize(), 'Loss')
        if has_mae:
            draw_panel((1, 2, 2), mae_keys[0], mae_keys[1],
                       'Model MAE for ' + task.capitalize(), 'MAE')

    plt.show()

# Plot the total loss and the per-task curves. The dashed marker is placed at
# the epoch with the lowest validation redshift loss.
best_val_redshift_epoch = np.argmin(history.history['val_redshift_loss'])
for curve_name in ('total', 'redshift', 'stellar_mass', 'sfr'):
    plot_history(history, curve_name, early_stopping_epoch=best_val_redshift_epoch)

In [None]:
""" Check global predictions on test-set """

# Provide a target for every compiled output, keyed by output-layer name.
# The reconstruction output ('decoded_output') is trained against the input
# itself, so it is evaluated against X_test here; leaving it out would make
# the evaluation inconsistent with how the model was compiled and trained.
test_metrics = joint_model.evaluate(
    X_test,
    {
        'decoded_output': X_test,
        'redshift': Y_test_normalized[:, 0],
        'stellar_mass': Y_test_normalized[:, 1],
        'sfr': Y_test_normalized[:, 2],
    },
)

def inverse_transform(normalized_values, means, stds):
    """Undo a standardisation: map ``(x - mean) / std`` values back to ``x``.

    Parameters
    ----------
    normalized_values : scalar or array-like supporting ``*`` and ``+``
        Standardised values.
    means, stds : scalar or array-like
        Mean and standard deviation that were used for the standardisation.

    Returns
    -------
    ``normalized_values * stds + means``, with the usual broadcasting of the
    operand types.
    """
    rescaled = normalized_values * stds
    return rescaled + means

def plot_predictions(predicted, actual, task_name):
    """Scatter predicted against actual values and overlay the 1:1 line.

    Parameters
    ----------
    predicted : array-like
        Predicted values (y axis).
    actual : array with ``min()`` / ``max()`` methods
        True values (x axis); their range sets the extent of the 1:1 line.
    task_name : str
        Name inserted into the figure title.
    """
    # End points of the identity line, spanning the range of the true values.
    lowest, highest = actual.min(), actual.max()

    plt.scatter(actual, predicted, alpha=0.1)
    plt.plot([lowest, highest], [lowest, highest], 'k--', lw=4)

    plt.title(f'Predicted vs Actual Values for {task_name}')
    plt.xlabel('Actual Values')
    plt.ylabel('Predicted Values')
    plt.show()


# Run the model on the test spectra
predictions = joint_model.predict(X_test)

# Output order: [reconstructed spectra, redshift, stellar mass, SFR].
# The reconstruction (index 0) is not needed here.
redshift_pred, stellar_mass_pred, sfr_pred = predictions[1], predictions[2], predictions[3]

# Bring each prediction back from standardised units to the label's own
# units, using the matching column of the training statistics.
redshift_pred_rescaled = inverse_transform(redshift_pred.squeeze(), Y_train_mean[0], Y_train_std[0])
stellar_mass_pred_rescaled = inverse_transform(stellar_mass_pred.squeeze(), Y_train_mean[1], Y_train_std[1])
sfr_pred_rescaled = inverse_transform(sfr_pred.squeeze(), Y_train_mean[2], Y_train_std[2])

# Predicted-vs-actual plot for each task. The dict order matches the column
# order of Y_test, so position i selects the corresponding true values.
rescaled_by_task = {
    'Redshift': redshift_pred_rescaled,
    'Stellar Mass': stellar_mass_pred_rescaled,
    'SFR': sfr_pred_rescaled,
}
for i, task_name in enumerate(rescaled_by_task):
    actual = Y_test[:, i]
    predicted = rescaled_by_task[task_name]
    plot_predictions(predicted, actual, task_name)


In [None]:
""" Inspect the residuals of a specific task """

# redshift: prediction minus true value, both in un-normalised units
residuals = np.subtract(redshift_pred_rescaled, Y_test[:, 0])
plt.hist(residuals, bins=100)

# Last expression of the cell, so the notebook displays it:
# the standard deviation of the redshift residuals.
residuals.std()

In [None]:
""" check individual spectra """

def plot_spectrum_with_halpha_and_saliency(index, X_test, Y_test, model, Y_train_mean, Y_train_std, wavelength_axis):
    """Plot one spectrum with predicted/actual H-alpha positions and a redshift saliency map.

    Parameters
    ----------
    index : int
        Row of ``X_test`` / ``Y_test`` to visualise.
    X_test : array, indexable by ``index``
        Model inputs; ``X_test[index]`` is the spectrum fed to the model.
    Y_test : 2-D array
        True labels; column 0 is read as the actual redshift.
    model : callable Keras model
        Called on a batch of one spectrum; its second output (position 1) is
        taken as the standardised redshift prediction.
    Y_train_mean, Y_train_std : indexable
        Element 0 of each is used to bring the redshift prediction back to
        un-normalised units.
    wavelength_axis : array-like
        x values for the spectrum and the saliency curve.

    Notes
    -----
    The saliency map is the absolute gradient of the *redshift output only*
    with respect to the input, scaled to a maximum of 1. Differentiating the
    full list of model outputs would sum the gradients of all of them
    (including the reconstruction) and would no longer describe what drives
    the redshift estimate.
    """
    object_spectrum = X_test[index]

    # Batch of one sample, as a float32 tensor the tape can watch.
    input_sample_tensor = tf.convert_to_tensor(object_spectrum, dtype=tf.float32)
    input_sample_tensor = tf.expand_dims(input_sample_tensor, axis=0)

    # Single forward pass: it provides both the prediction and the graph
    # needed for the gradient. training=False keeps dropout switched off.
    with tf.GradientTape() as tape:
        tape.watch(input_sample_tensor)
        outputs = model(input_sample_tensor, training=False)
        redshift_output = outputs[1]  # second output = redshift head

    # Standardised prediction -> plain float -> un-normalised redshift.
    predicted_redshift = float(np.squeeze(redshift_output.numpy()))
    predicted_redshift_rescaled = inverse_transform(predicted_redshift, Y_train_mean[0], Y_train_std[0])

    # Actual redshift
    actual_redshift = Y_test[index, 0]

    # H-alpha line wavelength in Ångström (rest frame)
    h_alpha_rest = 6562.8

    # Observed position of the line: rest wavelength stretched by (1 + z)
    predicted_h_alpha_observed = h_alpha_rest * (1 + predicted_redshift_rescaled)
    actual_h_alpha_observed = h_alpha_rest * (1 + actual_redshift)

    # Saliency: |d redshift / d input|, normalised to its maximum.
    gradient = tape.gradient(redshift_output, input_sample_tensor)
    processed_grad = tf.abs(gradient)
    # divide_no_nan yields 0 instead of NaN if the gradient is zero everywhere.
    processed_grad = tf.math.divide_no_nan(processed_grad, tf.reduce_max(processed_grad))
    processed_grad = processed_grad.numpy().flatten()  # 1-D numpy array

    # Plot the spectrum with saliency map overlay
    plt.figure(figsize=(12, 6))

    # Vertical lines for predicted and actual H-alpha line positions
    plt.axvline(predicted_h_alpha_observed, color='green', linestyle='--', label='Predicted H-alpha')
    plt.axvline(actual_h_alpha_observed, color='red', linestyle='--', label='Actual H-alpha')

    # Plain Python floats for the text box
    predicted_redshift_scalar = float(predicted_redshift_rescaled)
    actual_redshift_scalar = float(actual_redshift)

    # Inset with redshift information
    textstr = f'Predicted Redshift: {predicted_redshift_scalar:.2f}\nActual Redshift: {actual_redshift_scalar:.2f}'
    plt.gcf().text(0.75, 0.15, textstr, fontsize=10, bbox=dict(facecolor='white', alpha=0.5))

    # Overlay the saliency map in red
    plt.plot(wavelength_axis, processed_grad, label='Saliency Map', color='red', alpha=1, lw=0.1)

    # Plot the spectrum, scaled to its own maximum so it shares the y range
    # of the normalised saliency curve
    plt.plot(wavelength_axis, object_spectrum / object_spectrum.max(), label='Spectrum', lw=1, alpha=1)

    plt.xlabel('Wavelength (Å)')
    plt.ylabel('Intensity / Saliency')
    plt.title(f'Spectrum with Predicted and Actual H-alpha Line and Saliency Map (Object {index})')
    plt.legend(loc='lower left')
    plt.show()



# Choose which test object to inspect (row position in X_test / Y_test)
index = 11
# Alternative: select the object by its ID instead of its position
# special_ID = 333011988000028
# index = np.where(df_test['ID'] == special_ID)[0][0]
plot_spectrum_with_halpha_and_saliency(
    index=index,
    X_test=X_test,
    Y_test=Y_test,
    model=joint_model,
    Y_train_mean=Y_train_mean,
    Y_train_std=Y_train_std,
    wavelength_axis=wavelength_axis,
)