# Import necessary libraries

In [None]:
import numpy as np
import pandas as pd
from pandas import read_csv
import pickle
import matplotlib.pyplot as plt

import tensorflow as tf
import keras
from tensorflow.keras.layers import Input, Dense, Dropout, LayerNormalization, MultiHeadAttention, TimeDistributed, Masking, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Positional encoding layer

In [None]:
class PositionalEncoding(Layer):
    """Sinusoidal positional encoding.

    Unlike a typical Keras layer, this is called with two Python ints,
    ``pe(seq_len, d_model)``, rather than with an input tensor. It returns a
    constant float32 tensor of shape ``(1, seq_len, d_model)`` that the caller
    adds to its inputs.

    Column ``i`` of position ``pos`` holds ``pos * rate(i)`` passed through
    ``sin`` for even ``i`` and ``cos`` for odd ``i``, where
    ``rate(i) = 1 / 10000 ** (2 * (i // 2) / d_model)``.
    """

    def __init__(self, **kwargs):
        # Forward standard Layer kwargs (name, dtype, ...) so the layer can be
        # named and re-created from its config like any other Keras layer.
        super().__init__(**kwargs)

    def get_angles(self, pos, i, d_model):
        """Return ``pos * rate(i)``, broadcasting ``pos`` against ``i``.

        Args:
            pos: positions, e.g. a column vector of shape ``(seq_len, 1)``.
            i: feature indices, e.g. a row vector of shape ``(1, d_model)``.
            d_model: feature dimension used to scale the frequencies.
        """
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
        return pos * angle_rates

    def call(self, seq_len, d_model):
        """Build the ``(1, seq_len, d_model)`` float32 encoding table."""
        angles = self.get_angles(np.arange(seq_len)[:, np.newaxis],
                                 np.arange(d_model)[np.newaxis, :],
                                 d_model)
        # Even feature columns use sin, odd columns use cos (in place).
        angles[:, 0::2] = np.sin(angles[:, 0::2])
        angles[:, 1::2] = np.cos(angles[:, 1::2])
        # Leading batch axis of 1 so the table broadcasts over the batch.
        return tf.cast(angles[np.newaxis, ...], tf.float32)

# Encoder

In [None]:
def transformer_encoder(inputs, head_size, num_heads, ff_dim, mask, dropout=0):
    """Apply one post-norm Transformer encoder layer, then project to ``ff_dim``.

    Args:
        inputs: sequence tensor; its last dimension is the model width.
        head_size: ``key_dim`` of each attention head.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network and also the width
            of the returned tensor.
        mask: ``attention_mask`` for self-attention (the caller passes a float
            padding mask over time steps), or None.
        dropout: dropout rate after attention and after the feed-forward net.

    Returns:
        Tensor with the same leading dimensions as ``inputs`` and last
        dimension ``ff_dim``.
    """
    # Self-attention sub-layer with residual connection + LayerNorm.
    attention = MultiHeadAttention(num_heads=num_heads, key_dim=head_size)(
        inputs, inputs, attention_mask=mask)
    attention = Dropout(dropout)(attention)
    out1 = LayerNormalization(epsilon=1e-6)(inputs + attention)

    # Position-wise feed-forward sub-layer; the second Dense maps back to the
    # input width so the residual addition is shape-compatible.
    ffn = Dense(ff_dim, activation='relu')(out1)
    ffn = Dense(inputs.shape[-1])(ffn)
    ffn = Dropout(dropout)(ffn)
    out2 = LayerNormalization(epsilon=1e-6)(out1 + ffn)

    # Final linear projection to ff_dim (the caller uses the same ff_dim for
    # both modalities, so their encoder outputs end up with equal width).
    return Dense(ff_dim)(out2)

In [None]:
def cross_attention_block(inputs_a, inputs_b, head_size, num_heads, dropout=0, mask_a=None, mask_b=None):
    """Bidirectional cross-attention between two modality sequences.

    A attends to B and B attends to A. Each direction has its own
    MultiHeadAttention layer, dropout, residual connection and LayerNorm.

    Args:
        inputs_a: sequence for modality A.
        inputs_b: sequence for modality B.
        head_size: ``key_dim`` of each attention head.
        num_heads: number of attention heads.
        dropout: dropout rate applied to each attention output.
        mask_a: padding mask over A's time steps (last axis indexes A's steps),
            or None.
        mask_b: padding mask over B's time steps, or None.

    Returns:
        ``(out_ab, out_ba, score_a, score_b)``: A updated by attending to B,
        B updated by attending to A, and the attention scores of each
        direction.
    """
    # A queries B: the keys are B's time steps, so B's padding mask applies.
    attention_ab, score_a = MultiHeadAttention(num_heads=num_heads, key_dim=head_size)(
        inputs_a, inputs_b, attention_mask=mask_b, return_attention_scores=True)
    attention_ab = Dropout(dropout)(attention_ab)
    out_ab = LayerNormalization(epsilon=1e-6)(inputs_a + attention_ab)

    # B queries A: the keys are A's time steps, so A's padding mask applies.
    attention_ba, score_b = MultiHeadAttention(num_heads=num_heads, key_dim=head_size)(
        inputs_b, inputs_a, attention_mask=mask_a, return_attention_scores=True)
    attention_ba = Dropout(dropout)(attention_ba)
    out_ba = LayerNormalization(epsilon=1e-6)(inputs_b + attention_ba)

    return out_ab, out_ba, score_a, score_b


# Decoder

In [None]:
def transformer_decoder(inputs, encoder_outputs, head_size, num_heads, ff_dim, dropout=0, encoder_mask=None):
    """Apply one post-norm Transformer decoder layer.

    Args:
        inputs: decoder sequence; its last dimension is the output width.
        encoder_outputs: encoder memory attended to by cross-attention.
        head_size: ``key_dim`` of each attention head.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        dropout: dropout rate after each sub-layer.
        encoder_mask: optional ``attention_mask`` for cross-attention over the
            encoder time steps, e.g. a padding mask broadcastable to
            (batch, heads, decoder_len, encoder_len) — TODO confirm at call
            site. The default None attends to every encoder step.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    # Masked (causal) self-attention: each position sees only itself and
    # earlier positions.
    attention_out = MultiHeadAttention(num_heads=num_heads, key_dim=head_size)(
        inputs, inputs, use_causal_mask=True)
    attention_out = Dropout(dropout)(attention_out)
    out1 = LayerNormalization(epsilon=1e-6)(inputs + attention_out)

    # Cross-attention over the encoder output, optionally masking padded steps.
    attention_out = MultiHeadAttention(num_heads=num_heads, key_dim=head_size)(
        out1, encoder_outputs, attention_mask=encoder_mask)
    attention_out = Dropout(dropout)(attention_out)
    out2 = LayerNormalization(epsilon=1e-6)(out1 + attention_out)

    # Position-wise feed-forward, projected back to the input width for the
    # residual connection.
    ffn_out = Dense(ff_dim, activation='relu')(out2)
    ffn_out = Dense(inputs.shape[-1])(ffn_out)
    ffn_out = Dropout(dropout)(ffn_out)
    return LayerNormalization(epsilon=1e-6)(out2 + ffn_out)

# Build the Cross-Attentional AutoRegressive Transformer (CAAT-EHR)

In [None]:
def build_transformer_model(max_seq_length, dim_a, dim_b, head_size, num_heads, ff_dim, dropout=0):
    """Build the CAAT-EHR pre-training model and its encoder sub-model.

    Each modality (A and B) gets a self-attention encoder. The two are then
    fused with bidirectional cross-attention, and an autoregressive decoder
    predicts two time points of the concatenated A+B features.

    Args:
        max_seq_length: number of time steps in each encoder input.
        dim_a: feature dimension of modality A.
        dim_b: feature dimension of modality B.
        head_size: ``key_dim`` of each attention head.
        num_heads: number of attention heads.
        ff_dim: width of the feed-forward and fused representations.
        dropout: dropout rate used throughout.

    Returns:
        ``(model, encoder)``:
            - ``model``: the full encoder-decoder model taking
              ``[input_a, input_b, decoder_input]``.
            - ``encoder``: the encoder-only model taking
              ``[input_a, input_b]`` and returning the fused representation.
    """
    pe = PositionalEncoding()

    # Raw model inputs. Keep these names bound to the Input tensors: Keras
    # functional models are defined from their Input tensors, so rebinding
    # them (``x += pe``) would hand Model() intermediate tensors instead.
    encoder_input_a = Input(shape=(max_seq_length, dim_a))
    encoder_input_b = Input(shape=(max_seq_length, dim_b))
    decoder_inputs = Input(shape=(2, dim_a + dim_b))

    # Padding masks (1.0 = real step, 0.0 = padded): a step is padding when its
    # first feature equals -50.0. Computed on the RAW inputs; once positional
    # encoding is added, padded entries are no longer exactly -50.0.
    time_step_mask_a = tf.math.not_equal(encoder_input_a[:, :, 0], -50.0)
    time_step_mask_a = tf.cast(time_step_mask_a[:, None, None, :], tf.float32)
    time_step_mask_b = tf.math.not_equal(encoder_input_b[:, :, 0], -50.0)
    time_step_mask_b = tf.cast(time_step_mask_b[:, None, None, :], tf.float32)

    # Add sinusoidal positional information.
    encoded_a = encoder_input_a + pe(max_seq_length, dim_a)
    encoded_b = encoder_input_b + pe(max_seq_length, dim_b)
    decoder_x = decoder_inputs + pe(2, dim_a + dim_b)

    # Per-modality self-attention encoders.
    encoder_output_a = transformer_encoder(encoded_a, head_size, num_heads, ff_dim, time_step_mask_a, dropout)
    encoder_output_b = transformer_encoder(encoded_b, head_size, num_heads, ff_dim, time_step_mask_b, dropout)

    # Cross-attention between modalities (attention scores are not used here).
    cross_a, cross_b, _, _ = cross_attention_block(
        encoder_output_a, encoder_output_b, head_size, num_heads, dropout,
        time_step_mask_a, time_step_mask_b)

    # Concatenate the cross-attention outputs feature-wise and fuse them.
    combined = tf.concat([cross_a, cross_b], axis=2)
    combined = Dense(ff_dim, activation='relu')(combined)
    combined = Dropout(dropout)(combined)
    encoder_output = LayerNormalization(epsilon=1e-6)(combined)

    # Decoder over the two target time points, attending to the fused encoding.
    decoder_outputs = transformer_decoder(decoder_x, encoder_output, head_size, num_heads, ff_dim, dropout)
    outputs = TimeDistributed(Dense(dim_a + dim_b))(decoder_outputs)

    # CAAT-EHR model and its standalone encoder (shares the same layers).
    model = Model([encoder_input_a, encoder_input_b, decoder_inputs], outputs)
    encoder = Model([encoder_input_a, encoder_input_b], encoder_output, name='encoder')
    return model, encoder

# Retrieve the data

In [None]:
# Unpickle the pre-training data: modality 1, modality 2, and the target.
# NOTE: pd.read_pickle unpickles arbitrary Python objects — only load these
# files from a trusted source.
_loaded = []
for file_name in ('modal1.pkl', 'modal2.pkl', 'target.pkl'):
    _loaded.append(pd.read_pickle(file_name))
X_pretrain_modal1, X_pretrain_modal2, pretrain_target = _loaded
del _loaded

In [None]:
# Display the modality-1 array shape; presumably (n_samples, MAX_SEQ_LENGTH, INPUT_DIM_1) to match the model input — TODO confirm.
X_pretrain_modal1.shape

In [None]:
# Display the modality-2 array shape; presumably (n_samples, MAX_SEQ_LENGTH, INPUT_DIM_2) to match the model input — TODO confirm.
X_pretrain_modal2.shape

In [None]:
# Display the target shape; presumably (n_samples, 2, INPUT_DIM_1 + INPUT_DIM_2) to match the decoder input/output — TODO confirm.
pretrain_target.shape

# Hyperparameters

In [None]:
# ---- Data shape (expected to match the loaded pre-training arrays) ----
MAX_SEQ_LENGTH = 20                  # maximum number of time steps per sequence
INPUT_DIM_1, INPUT_DIM_2 = 15, 30    # feature widths of modality 1 and modality 2

# ---- Architecture ----
HEAD_SIZE, NUM_HEADS = 32, 2         # attention key_dim per head, number of heads
FF_DIM = 64                          # feed-forward / fused representation width
DROPOUT_RATE = 0.1

# ---- Optimisation ----
BATCH_SIZE, EPOCHS = 64, 100
LEARNING_RATE = 1e-4

# Building the model

In [None]:
# Construct the full encoder-decoder model plus the standalone encoder, then
# compile the full model for pre-training with Adam and mean-squared error.
model, encoder_model = build_transformer_model(
    max_seq_length=MAX_SEQ_LENGTH,
    dim_a=INPUT_DIM_1,
    dim_b=INPUT_DIM_2,
    head_size=HEAD_SIZE,
    num_heads=NUM_HEADS,
    ff_dim=FF_DIM,
    dropout=DROPOUT_RATE,
)
optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(loss='mse', optimizer=optimizer)

# Pre-train the model

In [None]:
# Train the model with teacher forcing.
# The decoder uses a causal mask that includes the current position, plus a
# residual path. If it were fed the unshifted target, each position could read
# the very value it must predict and the task would collapse to copying. So
# the decoder input is the target shifted right by one step, with a zero
# vector as the start step (adjust if zero is a meaningful value in the data).
target = np.asarray(pretrain_target)
decoder_input = np.concatenate([np.zeros_like(target[:, :1]), target[:, :-1]], axis=1)

# Stop once the training loss has not improved for 20 epochs.
callback = keras.callbacks.EarlyStopping(monitor='loss', patience=20, mode="min")
history = model.fit([X_pretrain_modal1, X_pretrain_modal2, decoder_input], target,
                    batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[callback])

# Save the encoder model

In [None]:
# Persist only the encoder sub-model in the native Keras format; the full
# encoder-decoder `model` is not saved here.
encoder_model.save(filepath='transformer_encoder.keras')