<a href="https://colab.research.google.com/github/gyasifred/NLP-Techniques/blob/main/MODEL_NLP_PROJECT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%cd drive/MyDrive/CLINLPPROJ/

[Errno 2] No such file or directory: 'drive/MyDrive/CLINLPPROJ/'
/content/drive/MyDrive/CLINLPPROJ


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, LSTM, Concatenate, MultiHeadAttention, LayerNormalization
from tensorflow.keras.models import Model
from transformers import TFBertModel
from datetime import datetime
from tensorflow.keras import mixed_precision
import os
import numpy as np
import json
import keras

# ---------------------------------------------------------------------------
# Configuration parameters (hardcoded for this run)
# ---------------------------------------------------------------------------

# Filesystem locations: processed input arrays, and where models/logs go.
data_dir = '/content/drive/MyDrive/CLINLPPROJ/processed_data'
model_dir = "/content/drive/MyDrive/CLINLPPROJ/models/clinical_prediction"

# Optimisation settings.
batch_size = 4
epochs = 2
learning_rate = 1e-4

# Input dimensions.
sequence_length = 128  # time steps per time-series sample
time_series_features = 24  # features per time step
static_features = 21  # width of the static feature vector

# Text (BERT) branch.
bert_model_name = "emilyalsentzer/Bio_ClinicalBERT"  # pretrained checkpoint
bert_dropout_rate = 0.2  # dropout applied after the BERT projection
bert_input_size = 512  # token sequence length fed to BERT

# Time-series (LSTM) branch.
lstm_sizes = [64, 32, 32]  # units per LSTM layer
lstm_dropout = 0.2

# Static-feature (MLP) branch.
mlp_sizes = [128, 64]  # units per MLP layer
mlp_dropout = 0.4

# Attention-based fusion of the three branches.
fusion_dropout = 0.4
num_attention_heads = 12
fusion_key_dim = 64  # per-head key dimension

# Device selection: GPU to use (set -1 for CPU).
# NOTE(review): nothing in the visible code reads this value -- confirm
# whether it is meant to drive CUDA_VISIBLE_DEVICES.
gpu = "1"

@keras.saving.register_keras_serializable()
class BERTLayer(tf.keras.layers.Layer):
    """Wrap a BERT encoder and project its position-0 embedding.

    The layer runs the wrapped encoder, takes the hidden state at sequence
    position 0 of the encoder's first output (the [CLS] slot for BERT-style
    tokenisation), and passes it through a ReLU ``Dense`` layer followed by
    ``Dropout``.

    Args:
        bert_model: Callable encoder accepting ``input_ids``,
            ``attention_mask`` and ``training`` keyword arguments. May be
            ``None`` (as produced by ``from_config``), in which case a model
            must be assigned to ``self.bert`` before the layer is called.
        output_size: Number of units of the projection ``Dense`` layer.
        dropout_rate: Dropout rate applied after the projection.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, bert_model, output_size, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.bert = bert_model
        self.output_size = output_size
        self.dropout_rate = dropout_rate

        self.dense = Dense(output_size, activation="relu")
        self.dropout = Dropout(dropout_rate)

    def get_config(self):
        """Return the serialisable config.

        Only ``output_size`` and ``dropout_rate`` are added to the base
        config; the wrapped BERT model itself is not serialised.
        """
        config = super().get_config()
        config.update({
            'output_size': self.output_size,
            'dropout_rate': self.dropout_rate,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from ``config`` without a BERT model.

        The encoder is not part of the config, so the returned layer has
        ``self.bert`` set to ``None``. The caller must attach a BERT model
        (``layer.bert = ...``) before the layer can be called.
        """
        return cls(bert_model=None, **config)

    def call(self, inputs, training=False):
        """Encode the tokens and return the projected position-0 embedding.

        Args:
            inputs: Pair ``(input_ids, attention_mask)``; both are cast to
                ``int32`` before being passed to the encoder.
            training: Forwarded to the encoder and to the dropout layer.

        Returns:
            Tensor produced by ``Dense(output_size)`` followed by dropout.

        Raises:
            RuntimeError: If no BERT model is attached (e.g. the layer was
                created through ``from_config``).
        """
        if self.bert is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not callable" raised further down.
            raise RuntimeError(
                "BERTLayer has no BERT model attached. Layers restored via "
                "from_config() are created with bert_model=None; assign a "
                "model to `layer.bert` before calling the layer."
            )

        input_ids, attention_mask = inputs
        input_ids = tf.cast(input_ids, dtype=tf.int32)
        attention_mask = tf.cast(attention_mask, dtype=tf.int32)

        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, training=training)
        sequence_output = bert_output[0]
        # Keep only the hidden state at sequence position 0.
        pooled_output = sequence_output[:, 0, :]
        x = self.dense(pooled_output)
        x = self.dropout(x, training=training)
        return x


@keras.saving.register_keras_serializable()
class FusionLayer(tf.keras.layers.Layer):
    """Fuse three branch embeddings with multi-head self-attention.

    Each of the three inputs is projected to ``fusion_dim`` with its own
    ReLU ``Dense`` layer. The projections are stacked along a new axis 1,
    self-attended, added back residually, layer-normalised and finally
    averaged over axis 1.

    Args:
        num_heads: Number of attention heads.
        key_dim: Per-head key dimension for ``MultiHeadAttention``.
        dropout_rate: Dropout rate used inside the attention layer.
        fusion_dim: Common width every branch is projected to.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_heads, key_dim, dropout_rate, fusion_dim=128, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.dropout_rate = dropout_rate
        self.fusion_dim = fusion_dim

    def build(self, input_shape):
        """Create the projection, attention and normalisation sub-layers."""
        # Project all inputs to the same dimension
        self.lstm_projection = Dense(self.fusion_dim, activation="relu")
        self.bert_projection = Dense(self.fusion_dim, activation="relu")
        self.mlp_projection = Dense(self.fusion_dim, activation="relu")

        self.attention = MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
            dropout=self.dropout_rate
        )
        self.layer_norm = LayerNormalization(epsilon=1e-6)

        super().build(input_shape)

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim,
            'dropout_rate': self.dropout_rate,
            'fusion_dim': self.fusion_dim
        })
        return config

    def compute_output_shape(self, input_shape):
        """Return ``(None, fusion_dim)``; the batch dimension is left unknown."""
        return (None, self.fusion_dim)

    def call(self, inputs, training=False):
        """Fuse ``(lstm_output, bert_output, mlp_output)`` into one vector.

        Args:
            inputs: Sequence of three tensors, unpacked in the order LSTM,
                BERT, MLP. Presumably each is ``(batch, features)`` --
                confirm against the caller.
            training: Forwarded to the attention layer to toggle its
                dropout. Defaults to ``False`` so that calling the layer
                without an explicit flag does not apply dropout, matching
                ``BERTLayer.call``.

        Returns:
            Mean over axis 1 of the layer-normalised attention output.
        """
        lstm_output, bert_output, mlp_output = inputs

        # Project all inputs to same dimension
        lstm_proj = self.lstm_projection(lstm_output)
        bert_proj = self.bert_projection(bert_output)
        mlp_proj = self.mlp_projection(mlp_output)

        # Add sequence dimension
        lstm_proj = tf.expand_dims(lstm_proj, axis=1)
        bert_proj = tf.expand_dims(bert_proj, axis=1)
        mlp_proj = tf.expand_dims(mlp_proj, axis=1)

        # Concatenate along sequence dimension: one "token" per branch
        concat_output = tf.concat([lstm_proj, bert_proj, mlp_proj], axis=1)

        # Self-attention
        attention_output = self.attention(
            query=concat_output,
            value=concat_output,
            key=concat_output,
            training=training
        )

        # Residual connection and normalization
        normalized_output = self.layer_norm(concat_output + attention_output)

        # Global pooling across sequence dimension
        output = tf.reduce_mean(normalized_output, axis=1)
        return output


def build_model():
    """Build the three-branch clinical prediction model.

    Branches:
        * time series -> two stacked LSTMs -> dropout -> Dense
        * text        -> ``BERTLayer`` over ``bert_model_name``
        * static      -> Dense -> dropout -> Dense

    The branch outputs are combined by ``FusionLayer`` and fed to a small
    classification head ending in a single sigmoid unit.

    All sizes come from the module-level configuration. Only
    ``lstm_sizes[0]``, ``lstm_sizes[1]`` and ``mlp_sizes[0]`` are read here;
    any further entries of those lists are not used by this function.

    Returns:
        An uncompiled ``Model`` whose inputs are a dict keyed by
        ``time_series_input``, ``bert_input_ids``, ``bert_attention_mask``
        and ``static_input``.
    """
    # Common width every branch is projected to before fusion.
    branch_dim = 128

    # Initialize BERT
    bert_model = TFBertModel.from_pretrained(bert_model_name)
    bert_layer = BERTLayer(bert_model, branch_dim, bert_dropout_rate)

    # Define inputs
    time_series_input = Input(shape=(sequence_length, time_series_features),
                            name="time_series_input")
    # Token length follows the configured BERT input size rather than a
    # duplicated literal, so the two cannot drift apart.
    bert_input_ids = Input(shape=(bert_input_size,), dtype=tf.int32, name="bert_input_ids")
    bert_attention_mask = Input(shape=(bert_input_size,), dtype=tf.int32, name="bert_attention_mask")
    static_input = Input(shape=(static_features,), name="static_input")

    # LSTM branch
    x_lstm = LSTM(lstm_sizes[0], return_sequences=True)(time_series_input)
    x_lstm = LSTM(lstm_sizes[1])(x_lstm)
    x_lstm = Dropout(lstm_dropout)(x_lstm)
    lstm_output = Dense(branch_dim, activation="relu")(x_lstm)

    # BERT branch
    bert_output = bert_layer([bert_input_ids, bert_attention_mask])

    # Static features branch
    x_mlp = Dense(mlp_sizes[0], activation="relu")(static_input)
    x_mlp = Dropout(mlp_dropout)(x_mlp)
    mlp_output = Dense(branch_dim, activation="relu")(x_mlp)

    # Fusion
    fusion_layer = FusionLayer(
        num_heads=num_attention_heads,
        key_dim=fusion_key_dim,
        dropout_rate=fusion_dropout,
        fusion_dim=branch_dim
    )
    fusion_output = fusion_layer([lstm_output, bert_output, mlp_output])

    # Final classification layers
    x = Dense(64, activation="relu")(fusion_output)
    x = Dropout(0.2)(x)
    # The file enables the global 'mixed_float16' policy; the output layer is
    # pinned to float32 so the sigmoid and the loss are computed in full
    # precision, as the TensorFlow mixed precision guide recommends.
    output = Dense(1, activation="sigmoid", name="prediction", dtype="float32")(x)

    # Create model with dictionary inputs
    model = Model(
        inputs={
            'time_series_input': time_series_input,
            'bert_input_ids': bert_input_ids,
            'bert_attention_mask': bert_attention_mask,
            'static_input': static_input
        },
        outputs=output
    )

    return model

# Enable mixed precision to reduce memory usage.
# This runs at module level, i.e. before main() calls build_model(), and
# installs 'mixed_float16' as the global Keras dtype policy.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

def load_data(data_dir, split, target_seq_len=128):
    """Load one data split from ``<data_dir>/<split>_data.npz``.

    The archive must contain the arrays ``time_series_features``,
    ``bert_input_ids``, ``bert_attention_mask``, ``static_features`` and
    ``labels``. Each array is converted to the dtype the model expects.

    If ``time_series_features`` is 2-D, a new axis 1 is inserted and every
    row is repeated ``target_seq_len`` times along it; arrays of any other
    rank are passed through unchanged.

    Args:
        data_dir: Directory containing the ``.npz`` archives.
        split: Split name used to form the file name (e.g. ``'val'``).
        target_seq_len: Number of repetitions along the inserted axis for
            2-D time-series input. Defaults to 128, the previously
            hard-coded value.

    Returns:
        Dict with keys ``time_series_features`` (float32),
        ``bert_input_ids`` (int32), ``bert_attention_mask`` (int32),
        ``static_features`` (float32) and ``labels`` (float32).

    Raises:
        FileNotFoundError: If the archive does not exist.
        KeyError: If a required array is missing from the archive.
    """
    archive_path = os.path.join(data_dir, f'{split}_data.npz')

    # NpzFile keeps the underlying file open until closed; the context
    # manager releases it. astype() copies, so the returned arrays stay
    # valid after the archive is closed.
    with np.load(archive_path) as data:
        time_series = data['time_series_features'].astype(np.float32)
        bert_input_ids = data['bert_input_ids'].astype(np.int32)
        bert_attention_mask = data['bert_attention_mask'].astype(np.int32)
        static = data['static_features'].astype(np.float32)
        labels = data['labels'].astype(np.float32)

    if time_series.ndim == 2:
        # Insert axis 1 and tile each row along it.
        time_series = np.expand_dims(time_series, axis=1)
        time_series = np.repeat(time_series, target_seq_len, axis=1)

    return {
        'time_series_features': time_series,
        'bert_input_ids': bert_input_ids,
        'bert_attention_mask': bert_attention_mask,
        'static_features': static,
        'labels': labels,
    }

def create_dataset(data, batch_size, shuffle=False):
    """Turn a loaded split into a batched, prefetched ``tf.data.Dataset``.

    Args:
        data: Mapping with the keys ``time_series_features``,
            ``bert_input_ids``, ``bert_attention_mask``, ``static_features``
            and ``labels``.
        batch_size: Number of examples per batch.
        shuffle: When true, shuffle with a buffer of 1000 elements before
            batching.

    Returns:
        Dataset yielding ``(inputs, labels)`` pairs, where ``inputs`` is a
        dict keyed by ``time_series_input``, ``bert_input_ids``,
        ``bert_attention_mask`` and ``static_input``.
    """
    # Rename the stored arrays to the keys used for the model inputs.
    model_inputs = {
        'time_series_input': data['time_series_features'],
        'bert_input_ids': data['bert_input_ids'],
        'bert_attention_mask': data['bert_attention_mask'],
        'static_input': data['static_features'],
    }
    pipeline = tf.data.Dataset.from_tensor_slices((model_inputs, data['labels']))

    if shuffle:
        pipeline = pipeline.shuffle(1000)

    pipeline = pipeline.batch(batch_size)
    return pipeline.prefetch(tf.data.AUTOTUNE)


def main():
    """Train the clinical prediction model and persist it.

    Steps:
        1. Load the ``train`` and ``val`` splits from ``data_dir`` and wrap
           them in ``tf.data`` pipelines.
        2. Build and compile the model (binary cross-entropy, Adam).
        3. Fit with checkpointing, early stopping, LR reduction on plateau
           and TensorBoard logging, all monitoring ``val_loss``.
        4. Save the final model to ``<model_dir>/final_model.keras`` and the
           hyper-parameters to ``<model_dir>/model_config.json``.

    Raises:
        Exception: Any error raised during training or saving is printed and
            then re-raised unchanged.
    """
    # Load and prepare data
    print("Loading training data...")
    # Train on the 'train' split; the previous 'test' argument trained the
    # model on held-out data, contradicting the message above.
    train_data = load_data(data_dir, 'train')
    print("Loading validation data...")
    val_data = load_data(data_dir, 'val')

    # Create datasets
    train_dataset = create_dataset(train_data, batch_size, shuffle=True)
    val_dataset = create_dataset(val_data, batch_size, shuffle=False)

    # Print shapes of the first batch for debugging
    for inputs, labels in train_dataset.take(1):
        print("\nInput shapes:")
        for key, tensor in inputs.items():
            print(f"{key}: {tensor.shape}")
        print(f"Labels shape: {labels.shape}")

    # Build and compile model
    print("\nBuilding model...")
    model = build_model()
    model.summary()

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(
        optimizer=optimizer,
        loss="binary_crossentropy",
        metrics=[
            "accuracy",
            tf.keras.metrics.AUC(name='auc'),
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall')
        ]
    )

    # Create model directory and callbacks
    os.makedirs(model_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(model_dir, "best_model.keras"),
            monitor="val_loss",
            save_best_only=True,
            verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss",
            patience=5,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            patience=3,
            verbose=1
        ),
        tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(model_dir, "logs", timestamp)
        ),
    ]

    # Train model
    print("\nStarting training...")
    try:
        model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=epochs,
            callbacks=callbacks,
            verbose=1
        )

        # Save final model. `custom_objects` is a load-time argument and is
        # rejected by the native .keras save path; the custom layers are
        # already registered through @register_keras_serializable.
        final_model_path = os.path.join(model_dir, "final_model.keras")
        tf.keras.models.save_model(model, final_model_path)
        print(f"\nModel training complete. Final model saved at {final_model_path}")

        # Save model configuration
        model_config_path = os.path.join(model_dir, "model_config.json")
        model_config = {
            "bert_model_name": bert_model_name,
            "sequence_length": sequence_length,
            "time_series_features": time_series_features,
            "static_features": static_features,
            "lstm_sizes": lstm_sizes,
            "lstm_dropout": lstm_dropout,
            "mlp_sizes": mlp_sizes,
            "mlp_dropout": mlp_dropout,
            "fusion_dropout": fusion_dropout,
            "num_attention_heads": num_attention_heads,
            "fusion_key_dim": fusion_key_dim,
            "learning_rate": learning_rate
        }

        with open(model_config_path, 'w', encoding='utf-8') as f:
            json.dump(model_config, f, indent=4)
        print(f"Model configuration saved at {model_config_path}")

    except Exception as e:
        # Surface the failure in the notebook output, then propagate it.
        print(f"\nTraining failed with error: {str(e)}")
        raise


# Run the training pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()