<a href="https://colab.research.google.com/github/mohammadreza-mohammadi94/Deep-Learning-Projects/blob/main/NER-Medical-Texts/ner_medical_texts_model_subclassing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [3]:
!pip install -q datasets==3.6.0
!pip install -q seqeval

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone


In [22]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, losses, callbacks, regularizers
from datasets import load_dataset
from tensorflow.keras.preprocessing.sequence import pad_sequences
from seqeval.metrics import f1_score, classification_report

# Configuration

In [23]:
# Hyper-parameters shared by the data pipeline, the model and training.
MAX_LEN: int = 64          # sequence length after padding / truncation
EMBEDDING_DIM: int = 128   # width of the learned word embeddings
RNN_UNITS: int = 128       # hidden units per LSTM direction
BATCH_SIZE: int = 32       # sentences per gradient step
EPOCHS: int = 15           # upper bound on training epochs

# Data Loading & Preprocessing

In [24]:
def prepare_data():
    """
    Load the tner/bc5cdr dataset and convert tokens and tags to padded int arrays.

    The vocabulary is built from the training split only. Ids 0 and 1 are
    reserved for "[PAD]" and "[UNK]", and real words are numbered from 2. As a
    result, no word shares the padding id (the model's Embedding uses
    mask_zero=True, so a word with id 0 would be silently masked out) or the
    unknown-word id.

    Both tokens and tags are padded AND truncated at the end ("post") to
    MAX_LEN, so a long sentence keeps its first MAX_LEN tokens. Tag rows are
    padded with id 0. That id is indistinguishable from a real tag with id 0
    ('O' per the dataset's label map; verify).

    Returns:
        ((X_train, y_train), (X_val, y_val), (X_test, y_test), word2idx, test_sent)
        Each X / y is an integer array of shape (num_sentences, MAX_LEN).
        word2idx maps token -> id, and test_sent is the unpadded list of
        test-split token lists.
    """
    print(">> Loading BC5CDR Dataset...")
    dataset = load_dataset("tner/bc5cdr")

    # Extract tokens & tags per split
    train_sent, train_tags = dataset['train']['tokens'], dataset['train']['tags']
    test_sent, test_tags = dataset['test']['tokens'], dataset['test']['tags']
    val_sent, val_tags = dataset['validation']['tokens'], dataset['validation']['tags']

    # Build vocab: special tokens keep ids 0/1, real words start right after them.
    specials = {"[PAD]": 0, "[UNK]": 1}
    vocab = sorted({w for s in train_sent for w in s}.difference(specials))
    word2idx = dict(specials)
    word2idx.update({w: i for i, w in enumerate(vocab, start=len(specials))})
    pad_id, unk_id = word2idx["[PAD]"], word2idx["[UNK]"]

    def encode(sentences, tags_list):
        """Map tokens to ids (OOV -> [UNK]) and pad/truncate tokens and tags to MAX_LEN."""
        X = [[word2idx.get(w, unk_id) for w in s] for s in sentences]
        X_p = pad_sequences(
            X, maxlen=MAX_LEN, padding='post', truncating='post', value=pad_id
        )
        y_p = pad_sequences(
            tags_list, maxlen=MAX_LEN, padding='post', truncating='post', value=0
        )
        return X_p, y_p

    X_train, y_train = encode(train_sent, train_tags)
    X_test, y_test = encode(test_sent, test_tags)
    X_val, y_val = encode(val_sent, val_tags)

    return (X_train, y_train), (X_val, y_val), (X_test, y_test), word2idx, test_sent


# Model Definition

In [25]:
class BioNERModel(tf.keras.Model):
    """
    Bidirectional-LSTM token tagger for medical named-entity recognition.

    Pipeline per token position:
        Embedding (id 0 masked as padding) -> SpatialDropout1D ->
        Bidirectional LSTM (return_sequences) -> Dropout -> Dense,
    producing one unnormalised logit per tag (no softmax; pair with a
    from_logits=True loss).

    Args:
        vocab_size: Rows in the embedding table; must exceed the largest token id.
        num_tags: Number of output tag classes.
        embedding_dim: Embedding width. Defaults to the EMBEDDING_DIM constant.
        rnn_units: Units per LSTM direction. Defaults to the RNN_UNITS constant.
        dropout_rate: Rate used by both dropout layers (default 0.3).
        **kwargs: Forwarded to tf.keras.Model (e.g. ``name``).
    """
    def __init__(self, vocab_size, num_tags, embedding_dim=None, rnn_units=None,
                 dropout_rate=0.3, **kwargs):
        super().__init__(**kwargs)
        # Resolve defaults at construction time so the config-cell values are used.
        if embedding_dim is None:
            embedding_dim = EMBEDDING_DIM
        if rnn_units is None:
            rnn_units = RNN_UNITS

        self.embedding = layers.Embedding(vocab_size, embedding_dim, mask_zero=True)
        self.spatial_dropout = layers.SpatialDropout1D(dropout_rate)
        self.bi_lstm = layers.Bidirectional(
            layers.LSTM(
                rnn_units,
                return_sequences=True,  # one output per token, needed for tagging
                # Small L2 penalties to keep input/recurrent weights from growing large
                kernel_regularizer=regularizers.l2(1e-5),
                recurrent_regularizer=regularizers.l2(1e-5),
            )
        )
        self.dropout = layers.Dropout(dropout_rate)
        self.classifier = layers.Dense(num_tags)

    def call(self, inputs, training=False):
        """Return per-token tag logits for a batch of padded token-id sequences."""
        x = self.embedding(inputs)
        x = self.spatial_dropout(x, training=training)
        x = self.bi_lstm(x, training=training)
        x = self.dropout(x, training=training)
        return self.classifier(x)


# Define Weighted Loss
def get_weighted_loss(class_weights, ignore_label=None):
    """
    Build a class-weighted sparse cross-entropy normalised by the sum of weights.

    loss = sum_t(w[y_t] * CE_t) / (sum_t(w[y_t]) + 1e-8). Giving larger weights
    to entity tags increases their influence, which favours entity recall.

    Args:
        class_weights: 1-D float tensor/list of per-class weights, indexed by
            integer label id.
        ignore_label: Optional integer label value to exclude from the loss
            (e.g. a sentinel such as -1 used to pad tag sequences). Positions
            with this label get weight 0 in both numerator and denominator.
            With the default None, every position counts. Padded positions
            labelled with a real class id (e.g. 0) are then weighted like that class.

    Returns:
        A ``loss_fn(y_true, y_pred)`` that takes integer labels and
        ``from_logits`` predictions and returns a scalar.
    """
    # Stateless loss object, built once rather than on every call.
    cce = losses.SparseCategoricalCrossentropy(from_logits=True, reduction=None)

    def loss_fn(y_true, y_pred):
        labels = tf.cast(y_true, tf.int32)
        if ignore_label is not None:
            keep = tf.not_equal(labels, ignore_label)
            # Replace ignored positions with class 0 so neither the CE nor
            # tf.gather sees an out-of-range id (gather errors on CPU for -1).
            # Their contribution is zeroed via the weights below.
            labels = tf.where(keep, labels, tf.zeros_like(labels))

        loss_val = cce(labels, y_pred)
        weights = tf.gather(class_weights, labels)
        if ignore_label is not None:
            weights = weights * tf.cast(keep, weights.dtype)

        # Safe normalization to prevent division by zero
        return tf.reduce_sum(loss_val * weights) / (tf.reduce_sum(weights) + 1e-8)

    return loss_fn

# Main Pipeline

In [26]:
def main(seed=42):
    """
    Run the full pipeline: load data, train BioNERModel, print a seqeval report.

    Args:
        seed: Passed to tf.keras.utils.set_random_seed (seeds Python, NumPy and
            TensorFlow) so that Restart & Run All reproduces the run.
            Pass None to leave global RNG state untouched.
    """
    if seed is not None:
        tf.keras.utils.set_random_seed(seed)

    # Load data
    (X_train, y_train), (X_val, y_val), (X_test, y_test), word2idx, original_test_tokens = prepare_data()

    # Integer tag ids used by tner/bc5cdr. Per the dataset card, label2id is
    #   {"O": 0, "B-Chemical": 1, "B-Disease": 2, "I-Disease": 3, "I-Chemical": 4}
    # The ids are interleaved, NOT grouped as B/I-Chem then B/I-Dis.
    # NOTE(review): re-check this map if the dataset revision changes.
    tag_idx_to_name = {0: "O", 1: "B-CHM", 2: "B-DIS", 3: "I-DIS", 4: "I-CHM"}
    num_tags = len(tag_idx_to_name)

    # Loss weight per tag id (same order as tag_idx_to_name): 'O' -> 1.0,
    # Chemical tags -> 5.0, Disease tags -> 7.0. Rarer entity tags are
    # up-weighted so the model prioritises them.
    class_weights = tf.constant([1.0, 5.0, 7.0, 7.0, 5.0], dtype=tf.float32)

    # Instantiate Model
    model = BioNERModel(len(word2idx), num_tags)

    # Labels are integer ids (one per token), hence the sparse accuracy metric.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=get_weighted_loss(class_weights),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
    )

    # Stop when val_loss stalls (keeping the best weights) and halve the LR on plateaus.
    cb_list = [
        callbacks.EarlyStopping(
            monitor='val_loss', patience=4, restore_best_weights=True),
        callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.5, patience=1)
    ]

    # Model Training
    print("\n>> Training Model...")
    model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        callbacks=cb_list
    )

    # EVALUATION (SEQEVAL)
    print("\n>> Evaluating on Test Set...")
    logits = model.predict(X_test)
    preds = np.argmax(logits, axis=-1)

    def get_real_tags(y_indices, original_sentences):
        """Convert rows of tag ids into tag-name lists, dropping trailing padding."""
        # For sentences longer than MAX_LEN the slice stops at the padded width,
        # so true and predicted lists stay aligned. Tokens removed by truncation
        # are simply not evaluated.
        return [
            [tag_idx_to_name[int(idx)] for idx in y_indices[i][:len(sentence)]]
            for i, sentence in enumerate(original_sentences)
        ]

    true_tags = get_real_tags(y_test, original_test_tokens)
    pred_tags = get_real_tags(preds, original_test_tokens)

    print("\nDetailed Classification Report:")
    print(classification_report(true_tags, pred_tags))


# Execution
main()

>> Loading BC5CDR Dataset...

>> Training Model...
Epoch 1/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 202ms/step - accuracy: 0.9544 - loss: 0.4959 - val_accuracy: 0.9758 - val_loss: 0.2557 - learning_rate: 0.0010
Epoch 2/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 225ms/step - accuracy: 0.9840 - loss: 0.1429 - val_accuracy: 0.9820 - val_loss: 0.3139 - learning_rate: 0.0010
Epoch 3/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 199ms/step - accuracy: 0.9451 - loss: 0.0704 - val_accuracy: 0.2967 - val_loss: 0.3625 - learning_rate: 5.0000e-04
Epoch 4/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 218ms/step - accuracy: 0.3122 - loss: 0.0485 - val_accuracy: 0.2969 - val_loss: 0.3712 - learning_rate: 2.5000e-04
Epoch 5/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 220ms/step - accuracy: 0.3159 - loss: 0.0451 - val_accuracy: 0.2976 - val_loss: 0.4010 - learning_rate: 1.2