In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM, Dense, Dropout, Masking, Reshape, Layer, Lambda, Concatenate, LayerNormalization, MultiHeadAttention, Add, Flatten, Dot
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, classification_report
from tensorflow.keras.optimizers import Adadelta
from keras.layers import BatchNormalization
from tensorflow.keras.initializers import HeNormal

def grn(input, hidden_size=None):
    """Simplified gated residual block: ELU dense -> sigmoid dense, skip add, LayerNorm.

    Args:
        input: Keras tensor whose last axis is the feature axis. (The name
            shadows the builtin but is kept because callers pass it by keyword.)
        hidden_size: Output width. ``None`` keeps the input's feature width.

    Returns:
        ``LayerNormalization()(skip + gate)`` whose last dimension is
        ``hidden_size`` (or the input's feature width when ``hidden_size`` is None).
    """
    units = input.shape[-1] if hidden_size is None else hidden_size

    first_dense = Dense(units=units, activation="elu")(input)
    # NOTE: unlike the GLU in the TFT paper, the sigmoid output is added to the
    # skip path directly rather than multiplied with a separate linear branch.
    dense_gate = Dense(units=units, activation="sigmoid")(first_dense)

    skip = input
    if input.shape[-1] != units:
        # Add() needs matching widths; project the skip path when the block
        # changes the feature width (e.g. concatenated or flattened inputs).
        skip = Dense(units=units)(input)

    skip_conn = Add()([skip, dense_gate])
    return LayerNormalization()(skip_conn)

def variable_selection(input, hidden_size):
    """Re-weight GRN-encoded features with softmax weights computed from the whole input window.

    Args:
        input: Keras tensor; assumed 3-D (batch, time, features), as produced
            by the ``Input`` layers in ``TFT``. ``Flatten`` merges all non-batch
            axes into one context vector.
        hidden_size: Width passed to both GRNs.

    Returns:
        ``grn(input)`` with each feature scaled by a per-sample softmax weight
        that is shared across all time steps; same shape as ``grn(input)``.
    """
    from tensorflow.keras.layers import Multiply

    grn_1 = grn(input=input, hidden_size=hidden_size)
    flatten_ = Flatten()(input)
    grn_2 = grn(input=flatten_, hidden_size=hidden_size)

    n_features = grn_1.shape[-1]
    weights = Dense(units=n_features, activation="softmax")(grn_2)
    # A Dot over the last axis would contract the feature axis away and leave a
    # tensor the downstream LSTM cannot consume. Instead, reshape the weights to
    # (batch, 1, n_features) so they broadcast over the time axis.
    weights = Reshape((1, n_features))(weights)
    return Multiply()([grn_1, weights])
    
def _attention_residual(query, key_value, hidden_size, attention_heads):
    """Multi-head attention of ``query`` over ``key_value``, then residual add and LayerNorm."""
    attended = MultiHeadAttention(
        num_heads=attention_heads,
        key_dim=hidden_size // attention_heads,
    )(query=query, key=key_value, value=key_value)
    return LayerNormalization()(Add()([query, attended]))


def _build_tft_model(static_shape, past_shape, future_shape,
                     hidden_size, attention_heads, n_classes):
    """Return the uncompiled TFT Keras model for the given per-sample input shapes."""
    inputs_static = Input(shape=static_shape)
    inputs_past = Input(shape=past_shape)
    inputs_future = Input(shape=future_shape)

    vs_static = variable_selection(inputs_static, hidden_size=hidden_size)
    vs_past = variable_selection(inputs_past, hidden_size=hidden_size)
    vs_future = variable_selection(inputs_future, hidden_size=hidden_size)

    # return_sequences=True so the gated skip connection below is applied per
    # time step instead of broadcasting only the final LSTM state over all steps.
    lstm_past = LSTM(hidden_size, return_sequences=True)(vs_past)
    lstm_future = LSTM(hidden_size, return_sequences=True)(vs_future)

    gate_past = Dense(hidden_size, activation="sigmoid")(lstm_past)
    gate_future = Dense(hidden_size, activation="sigmoid")(lstm_future)

    norm_lstm_past = LayerNormalization()(Add()([vs_past, gate_past]))
    norm_lstm_future = LayerNormalization()(Add()([vs_future, gate_future]))

    # NOTE(review): Concatenate on the feature axis requires the static branch to
    # have the same number of time steps as both the past and future branches.
    # Confirm the static arrays are shaped accordingly.
    combine_past_static = Concatenate(axis=-1)([norm_lstm_past, vs_static])
    combine_future_static = Concatenate(axis=-1)([norm_lstm_future, vs_static])

    grn_past = grn(combine_past_static, hidden_size=hidden_size)
    grn_future = grn(combine_future_static, hidden_size=hidden_size)

    # Self-attention within each branch, then cross-attention in which the
    # future sequence queries the past sequence.
    norm_past = _attention_residual(grn_past, grn_past, hidden_size, attention_heads)
    norm_future = _attention_residual(grn_future, grn_future, hidden_size, attention_heads)
    norm_cross = _attention_residual(norm_future, norm_past, hidden_size, attention_heads)

    grn_cross = grn(norm_cross, hidden_size=hidden_size)

    class_predictions = Dense(
        n_classes, activation='softmax', name="class", kernel_initializer=HeNormal()
    )(grn_cross)
    # The model's inputs must be the symbolic Input tensors, not the data arrays.
    return Model(inputs=[inputs_static, inputs_past, inputs_future], outputs=class_predictions)


def _plot_training_history(history):
    """Plot train/validation loss (left axis) and f1_score (right axis) per epoch."""
    fig, ax_loss = plt.subplots()

    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss', color='tab:red')
    ax_loss.plot(history.history['loss'], label='Train Loss', color='red', linestyle='-')
    ax_loss.plot(history.history['val_loss'], label='Validation Loss', color='red', linestyle='--')
    ax_loss.tick_params(axis='y', labelcolor='tab:red')

    ax_f1 = ax_loss.twinx()
    ax_f1.set_ylabel('F1 score', color='tab:blue')
    ax_f1.plot(history.history['f1_score'], label='Train f1_score', color='blue', linestyle='-')
    ax_f1.plot(history.history['val_f1_score'], label='Validation f1_score', color='blue', linestyle='--')
    ax_f1.tick_params(axis='y', labelcolor='tab:blue')

    # A figure-level legend gathers labelled lines from both axes; placed above the plot.
    fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=2)

    plt.title('Model F1 Score and Loss')
    plt.tight_layout()
    plt.show()


def TFT(X_train_static, X_valid_static, X_test_static, X_train_past, X_valid_past, X_test_past,
        X_train_future, X_valid_future, X_test_future,
        Y_train, Y_valid, Y_test, weights, epochs=10,
        batch_size=100, hidden_size=80, attention_heads=4, n_classes=3):
    """Build, train, plot and evaluate a simplified Temporal Fusion Transformer classifier.

    Pipeline:
    1. Static, past and future inputs each go through ``variable_selection``.
    2. Past and future get an LSTM with a gated residual and are concatenated
       with the static branch.
    3. Each branch gets self-attention, then the future branch cross-attends
       over the past branch.
    4. A softmax head predicts ``n_classes`` classes at every future position.

    Args:
        X_{train,valid,test}_{static,past,future}: Data arrays. The Input layers
            are sized from ``X_train_*.shape[1]`` and ``shape[2]``, so all splits
            must share those dimensions.
        Y_train, Y_valid: Integer (sparse) class labels, as required by
            ``sparse_categorical_crossentropy``.
        Y_test: Returned unchanged alongside the predictions.
        weights: Passed to ``Model.fit`` as ``class_weight``.
        epochs, batch_size: Training settings.
        hidden_size: Width of the LSTM/GRN/attention features.
        attention_heads: Number of heads; ``key_dim = hidden_size // attention_heads``.
        n_classes: Number of output classes (defaults to the former hard-coded 3).

    Returns:
        ``(Y_test, y_pred)`` where ``y_pred`` is the argmax over the class axis
        of the model's predictions on the test inputs.
    """
    tf.keras.backend.clear_session()

    model = _build_tft_model(
        (X_train_static.shape[1], X_train_static.shape[2]),
        (X_train_past.shape[1], X_train_past.shape[2]),
        (X_train_future.shape[1], X_train_future.shape[2]),
        hidden_size, attention_heads, n_classes,
    )

    # Standardized learning-rate optimizer shared across all models being compared.
    optimizer = Adadelta(learning_rate=1.0, rho=0.8, epsilon=1e-7)

    # NOTE(review): Keras' F1Score metric expects one-hot targets shaped
    # (batch, n_classes). With sparse integer labels and per-step outputs it may
    # raise or report per-class arrays; confirm against the installed Keras.
    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['f1_score'])

    model.summary()

    history = model.fit(
        x=[X_train_static, X_train_past, X_train_future], y=Y_train,
        validation_data=([X_valid_static, X_valid_past, X_valid_future], Y_valid),
        epochs=epochs,
        batch_size=batch_size,
        class_weight=weights,
        shuffle=False)

    _plot_training_history(history)

    # Multi-input models take their inputs as a single list; passing them
    # positionally would bind them to predict's batch_size/verbose arguments.
    y_pred = model.predict([X_test_static, X_test_past, X_test_future])
    y_pred = np.argmax(y_pred, axis=-1)

    return Y_test, y_pred