***
### Import of required libraries
***

In [None]:
import pickle
from typing import Generator, Tuple, Any

import numpy as np
import tensorflow as tf
import keras
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    LSTM,
    Dense,
    Input,
    Reshape,
    Concatenate,
    Flatten,
    Conv1D,
)
from tensorflow.keras.callbacks import ModelCheckpoint
import wandb
from wandb.keras import WandbMetricsLogger

***
### Import of training data, validation data and scalers
***

##### Training data

In [None]:
# Training arrays are read from the cluster store.
# NOTE: allow_pickle=True permits np.load to unpickle object arrays, which can
# execute code contained in the file -- only point these paths at trusted files.

# Time-variant training input
train_in_var = np.load("/cluster/home/krum/store/train_in32_var.npy", allow_pickle=True)
print(f"Training input time variant shape: {train_in_var.shape}")

# Time-invariant training input
train_in_con = np.load("/cluster/home/krum/store/train_in32_con.npy", allow_pickle=True)
print(f"Training input time invariant shape: {train_in_con.shape}")

# Training target
train_out = np.load("/cluster/home/krum/store/train_out32.npy", allow_pickle=True)
print(f"Training output shape: {train_out.shape}")

##### Validation data

In [None]:
# Validation arrays are read from the cluster store.
# NOTE: allow_pickle=True permits np.load to unpickle object arrays, which can
# execute code contained in the file -- only point these paths at trusted files.

# Time-variant validation input
val_in_var = np.load("/cluster/home/krum/store/val_in32_var.npy", allow_pickle=True)
print(f"Validation input time variant shape: {val_in_var.shape}")

# Time-invariant validation input
val_in_con = np.load("/cluster/home/krum/store/val_in32_con.npy", allow_pickle=True)
print(f"Validation input time invariant shape: {val_in_con.shape}")

# Validation target
val_out = np.load("/cluster/home/krum/store/val_out32.npy", allow_pickle=True)
print(f"Validation output shape: {val_out.shape}")

##### Scaler

In [None]:
# NOTE: pickle.load can run arbitrary code stored in the file; these paths must
# only ever point at trusted, self-produced scaler dumps.

# Scaler fitted on the model inputs
with open("/cluster/home/krum/store/scaler_in.pkl", "rb") as scaler_file:
    scaler_in = pickle.load(scaler_file)

# Scaler fitted on the model outputs
with open("/cluster/home/krum/store/scaler_out.pkl", "rb") as scaler_file:
    scaler_out = pickle.load(scaler_file)

***
### Batch generators
***

In [None]:
def batch_generator_train(
    batch_size: int,
) -> Generator[Tuple[Tuple[Any, Any], Any], None, None]:
    """
    Endlessly generates shuffled training batches.

    Batches are drawn from the module-level arrays ``train_in_var``,
    ``train_in_con`` and ``train_out``. The sample indices are reshuffled
    before every pass over the data, so each sample is yielded exactly once
    per pass in a random order. The last batch of a pass is smaller when the
    number of samples is not a multiple of ``batch_size``.

    Parameters
    ----------
    batch_size : int
        The maximum number of samples per batch (must be at least 1)

    Yields
    -------
    Tuple[Tuple[Any, Any], Any]
        ``((input_var, input_con), output)``: the time-variant inputs, the
        time-invariant inputs and the targets of the same randomly chosen samples.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1 or the training set is empty.
        Because this is a generator, the error surfaces when the first batch
        is requested.
    """
    train_data_size = train_in_var.shape[0]

    # Without these guards the inner loop below would never run and the outer
    # ``while True`` would spin forever without yielding anything.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if train_data_size == 0:
        raise ValueError("training dataset is empty")

    indices = np.arange(train_data_size)
    while True:
        np.random.shuffle(indices)  # New sample order for every pass over the data
        for start_idx in range(0, train_data_size, batch_size):
            # Slicing clamps at the end of the array, so the last batch may be short
            batch_indices = indices[start_idx : start_idx + batch_size]

            input_var = train_in_var[batch_indices]
            input_con = train_in_con[batch_indices]
            output = train_out[batch_indices]

            yield (input_var, input_con), output


def batch_generator_val(
    batch_size: int,
) -> Generator[Tuple[Tuple[Any, Any], Any], None, None]:
    """
    Endlessly generates shuffled validation batches.

    Batches are drawn from the module-level arrays ``val_in_var``,
    ``val_in_con`` and ``val_out``. The sample indices are reshuffled before
    every pass over the data, so each sample is yielded exactly once per pass
    in a random order. The last batch of a pass is smaller when the number of
    samples is not a multiple of ``batch_size``.

    Parameters
    ----------
    batch_size : int
        The maximum number of samples per batch (must be at least 1)

    Yields
    -------
    Tuple[Tuple[Any, Any], Any]
        ``((input_var, input_con), output)``: the time-variant inputs, the
        time-invariant inputs and the targets of the same randomly chosen samples.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1 or the validation set is empty.
        Because this is a generator, the error surfaces when the first batch
        is requested.
    """
    val_data_size = val_in_var.shape[0]

    # Without these guards the inner loop below would never run and the outer
    # ``while True`` would spin forever without yielding anything.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if val_data_size == 0:
        raise ValueError("validation dataset is empty")

    indices = np.arange(val_data_size)
    while True:
        np.random.shuffle(indices)  # New sample order for every pass over the data
        for start_idx in range(0, val_data_size, batch_size):
            # Slicing clamps at the end of the array, so the last batch may be short
            batch_indices = indices[start_idx : start_idx + batch_size]

            input_var = val_in_var[batch_indices]
            input_con = val_in_con[batch_indices]
            output = val_out[batch_indices]

            yield (input_var, input_con), output

***
### Custom measures
***

In [None]:
# Per-feature minima and maxima stored on the output scaler.
# Positions 0, 1 and 2 are used as latitude, longitude and altitude.
lat_min, lon_min, alt_min = scaler_out.data_min_[:3]
lat_max, lon_max, alt_max = scaler_out.data_max_[:3]

In [None]:
def rmse_lat(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """
    RMSE of the latitude channel, computed after undoing the Min-Max scaling.

    Parameters
    ----------
    y_true : tf.Tensor
        Scaled ground truth; latitude is read from index 0 along axis 2
    y_pred : tf.Tensor
        Scaled prediction; latitude is read from index 0 along axis 2

    Returns
    ----------
    tf.Tensor
        Scalar RMSE between the unscaled true and predicted latitude values
    """
    channel = [0]  # list index keeps the gathered axis in place
    truth = tf.cast(tf.gather(y_true, channel, axis=2), tf.float32)
    guess = tf.cast(tf.gather(y_pred, channel, axis=2), tf.float32)

    # Scaler bounds as float32 scalars; broadcasting applies them element-wise
    offset = tf.cast(lat_min, tf.float32)
    span = tf.cast(lat_max, tf.float32) - offset

    # Reverse the Min-Max scaling
    truth_unnorm = truth * span + offset
    guess_unnorm = guess * span + offset

    squared_error = tf.math.square(guess_unnorm - truth_unnorm)
    return tf.math.sqrt(tf.math.reduce_mean(squared_error))


def rmse_lon(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """
    RMSE of the longitude channel, computed after undoing the Min-Max scaling.

    Parameters
    ----------
    y_true : tf.Tensor
        Scaled ground truth; longitude is read from index 1 along axis 2
    y_pred : tf.Tensor
        Scaled prediction; longitude is read from index 1 along axis 2

    Returns
    ----------
    tf.Tensor
        Scalar RMSE between the unscaled true and predicted longitude values
    """
    channel = [1]  # list index keeps the gathered axis in place
    truth = tf.cast(tf.gather(y_true, channel, axis=2), tf.float32)
    guess = tf.cast(tf.gather(y_pred, channel, axis=2), tf.float32)

    # Scaler bounds as float32 scalars; broadcasting applies them element-wise
    offset = tf.cast(lon_min, tf.float32)
    span = tf.cast(lon_max, tf.float32) - offset

    # Reverse the Min-Max scaling
    truth_unnorm = truth * span + offset
    guess_unnorm = guess * span + offset

    squared_error = tf.math.square(guess_unnorm - truth_unnorm)
    return tf.math.sqrt(tf.math.reduce_mean(squared_error))


def rmse_alt(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """
    RMSE of the altitude channel, computed after undoing the Min-Max scaling.

    Parameters
    ----------
    y_true : tf.Tensor
        Scaled ground truth; altitude is read from index 2 along axis 2
    y_pred : tf.Tensor
        Scaled prediction; altitude is read from index 2 along axis 2

    Returns
    ----------
    tf.Tensor
        Scalar RMSE between the unscaled true and predicted altitude values
    """
    channel = [2]  # list index keeps the gathered axis in place
    truth = tf.cast(tf.gather(y_true, channel, axis=2), tf.float32)
    guess = tf.cast(tf.gather(y_pred, channel, axis=2), tf.float32)

    # Scaler bounds as float32 scalars; broadcasting applies them element-wise
    offset = tf.cast(alt_min, tf.float32)
    span = tf.cast(alt_max, tf.float32) - offset

    # Reverse the Min-Max scaling
    truth_unnorm = truth * span + offset
    guess_unnorm = guess * span + offset

    squared_error = tf.math.square(guess_unnorm - truth_unnorm)
    return tf.math.sqrt(tf.math.reduce_mean(squared_error))

***
### Model definition
***

##### Size of model inputs and outputs

In [None]:
# Axis 1 of every array is used as the sequence length, axis 2 as the feature count.

# Time-dependent input
t_input_sequence_length, t_input_features = train_in_var.shape[1:3]
print(
    f"Time variant input: {t_input_sequence_length} timesteps with {t_input_features} features"
)

# Constant input
c_input_sequence_length, c_input_features = train_in_con.shape[1:3]
print(
    f"Time constant input: {c_input_sequence_length} timesteps with {c_input_features} features"
)

# Output
output_sequence_length, output_features = train_out.shape[1:3]
print(f"Output: {output_sequence_length} timesteps with {output_features} features")

##### Model architecture

In [None]:
# Define two sets of inputs
input_t = Input(shape=(t_input_sequence_length, t_input_features))
input_c = Input(shape=(c_input_sequence_length, c_input_features))

# Width of the final Dense layer of every branch. After concatenation the
# branches are folded back into a (number of branches, branch_units) sequence.
branch_units = 32

# First branch (Temporal): every time-variant feature gets its own Conv1D stack
conv_outputs = []
for i in range(t_input_features):
    # i : i + 1 keeps the feature axis so Conv1D receives a 3D tensor
    x = Conv1D(filters=32, kernel_size=5, padding="same", activation="relu")(
        input_t[..., i : i + 1]
    )
    x = Conv1D(filters=32, kernel_size=5, padding="same", activation="relu")(x)
    x = Conv1D(filters=32, kernel_size=5, padding="same", activation="relu")(x)
    x = Flatten()(x)
    x = Dense(branch_units, activation="relu")(x)
    conv_outputs.append(x)

t = Concatenate()(conv_outputs)

# Second branch (Constant features)
c = Flatten()(input_c)
c = Dense(32, activation="relu")(c)
c = Dense(32, activation="relu")(c)
c = Dense(branch_units, activation="relu")(c)

# Combine the outputs of the two branches
combined = Concatenate()([t, c])

# One row per branch: t_input_features temporal branches plus the constant
# branch. Derived instead of hard-coded so the model still builds when the
# number of time-variant features changes.
n_branches = t_input_features + 1

# Apply recurrent and dense layers after combining
z = Reshape((n_branches, branch_units))(combined)
z = LSTM(32, return_sequences=True)(z)
z = LSTM(32, return_sequences=False)(z)
z = Dense(output_sequence_length * output_features, activation="linear")(z)
z = Reshape((output_sequence_length, output_features))(z)

# Build model
model = Model(inputs=[input_t, input_c], outputs=z)

# Print overview
model.summary()

##### Custom loss function

In [None]:
# Loss function (Weighted MSE)
def custom_mse(weights):
    """
    Builds a weighted squared-error loss.

    Parameters
    ----------
    weights : array-like
        Weights multiplied element-wise with the squared error; they must
        broadcast against the (batch, timesteps, features) error tensor

    Returns
    ----------
    Callable
        Loss function ``weighted_mse(gt, pred)`` returning a scalar tensor
    """

    def weighted_mse(gt, pred):
        # NOTE(review): the numerator sums over every axis, including the batch
        # axis, while the denominator is the plain weight sum -- the value
        # therefore grows with the number of samples in the batch.
        return K.sum(weights * K.square(gt - pred)) / K.sum(weights)

    return weighted_mse


# Defining the weights: linear decay from 1 to 0.1 along the output sequence,
# repeated for every output feature -> shape (output_sequence_length, output_features).
# Sizes are taken from the data instead of being hard-coded as 37 x 3.
weights = np.linspace(1, 0.1, output_sequence_length)
weights = np.tile(weights, (output_features, 1)).T.astype(np.float32)

##### Model compilation

In [None]:
# The optimizer is taken from tf.keras -- the same Keras implementation that
# provides the layers and the Model above -- rather than the standalone `keras`
# package, so that model and optimizer cannot come from two different Keras
# installations. Adam is used with its default settings.
model.compile(
    loss=custom_mse(weights),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[rmse_lat, rmse_lon, rmse_alt],
)

***
### Run model
***

##### Wandb log

In [None]:
# Free-form description of this run, stored with the Weights & Biases run.
run_config = {
    "input_lstm": ["2 layers, 32 cells"],
    "input_dense": ["Flatten()", '3 x Dense(32, activation="relu")'],
    "output": [
        "3CNN 32, Kernel=5",
    ],
    "optimizer": "adam",
    "loss": "weighted mean_squared_error (inputs:1, outputs linear from 1 to 0.1)",
    "learning_rate": "standard (0.001)",
    "epoch": 200,
    "batch_size": 1024,
    "steps_per_epoch": 31392,
    "validation_steps": 10479,
    "comment": "Model with mass input",
}

# Start the run that the callbacks and the checkpoint file name refer to.
wandb.init(project="MIAR_4Departures", entity="zhaw_zav", config=run_config)

##### Model checkpoints

In [None]:
# Checkpoint file is named after the active wandb run
checkpoint_path = f"/cluster/home/krum/git/MT_krum_code/models/{wandb.run.name}.keras"

# Keep only the model with the lowest validation loss, checked once per epoch
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor="val_loss",
    save_best_only=True,
    save_freq="epoch",
    verbose=0,
)

# Multiply the learning rate by 0.8 after 6 epochs without sufficient
# validation-loss improvement, never going below 0.0001
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    mode="min",
    factor=0.8,
    patience=6,
    min_delta=0.0001,
    min_lr=0.0001,
)

# Stop training after 20 epochs without sufficient validation-loss improvement
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=20,
    min_delta=0.0001,
)

##### Fit model

In [None]:
# Training schedule (same values as logged in the wandb run config)
BATCH_SIZE = 1024
EPOCHS = 200
STEPS_PER_EPOCH = 31392
VALIDATION_STEPS = 10479

training_callbacks = [
    WandbMetricsLogger(200),
    model_checkpoint_callback,
    reduce_lr,
    early_stopping,
]

# Both generators are endless, so the step counts define how many batches
# make up one training epoch and one validation run.
model.fit(
    batch_generator_train(BATCH_SIZE),
    validation_data=batch_generator_val(BATCH_SIZE),
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    shuffle=False,
    callbacks=training_callbacks,
)

# Close the wandb run once training has returned
wandb.finish()