# Custom NER for Identifying Diseases and Treatments

This notebook implements a custom Named Entity Recognition (NER) system to identify diseases and treatments from a medical dataset. The dataset is provided in tokenized format, where each word is associated with a label:
- `O` indicates "Other"
- `D` indicates "Disease"
- `T` indicates "Treatment"

## Steps in this Notebook
1. **Data Preprocessing:** Reconstruct sentences and labels from the tokenized dataset.
2. **Concept Identification:** Identify key concepts in the dataset using PoS tagging.
3. **Defining Features for CRF:** Create features for training the CRF model.
4. **Getting Features for Words and Sentences:** Apply feature definitions to all sentences.
5. **Defining Input and Target Variables:** Prepare input features and labels for training and testing.
6. **Building the Model:** Train the CRF model on the training dataset.
7. **Evaluating the Model:** Evaluate the model on the test dataset using F1 score and classification metrics.
8. **Identifying Diseases and Predicted Treatments:** Extract relationships between diseases and treatments using the trained model.


In [1]:
!pip install -q tabulate==0.9.0 spacy==3.8.3 sklearn-crfsuite==0.5.0

## Step 1: Data Preprocessing
The dataset is provided in tokenized format, where each word is stored on a separate line, and sentences are separated by blank lines. In this step, I will:
1. Reconstruct sentences and labels from the training and testing datasets.
2. Count the number of sentences and labels in the processed datasets.


In [2]:
# Paths to the dataset files
train_sent_path = 'data/train_sent'
train_label_path = 'data/train_label'
test_sent_path = 'data/test_sent'
test_label_path = 'data/test_label'

In [3]:
import re

def filter_no_entity_statements(sentences, labels):
    """
    Filters out sentences and their corresponding labels that do not contain any entity labels ('D' or 'T').

    Parameters:
    sentences (list): A list of tokenized sentences, where each sentence is a list of words.
    labels (list): A list of label sequences, where each sequence corresponds to a sentence.

    Returns:
    tuple: Filtered lists of sentences and labels.
    """
    filtered_sentences = []
    filtered_labels = []

    for sentence, label_sequence in zip(sentences, labels):
        # Check if the label sequence contains any 'D' or 'T' labels
        if any(label in {'D', 'T'} for label in label_sequence):
            filtered_sentences.append(sentence)
            filtered_labels.append(label_sequence)

    return filtered_sentences, filtered_labels

def clean_sentence(sentence, labels):
    """
    Cleans a tokenized sentence by normalizing, removing noise, and correcting token splits.
    Adjusts the labels accordingly to ensure alignment with cleaned tokens.

    Parameters:
    sentence (list): A list of tokens (words) in the sentence.
    labels (list): A list of labels corresponding to the tokens.

    Returns:
    tuple: A cleaned list of tokens and their adjusted labels.
    """
    cleaned_sentence = []
    cleaned_labels = []

    for token, label in zip(sentence, labels):
        # Remove special characters except hyphens in compound words
        cleaned_token = re.sub(r'[^\w\-]', '', token)

        # Preserve meaningful hyphenated words, otherwise split on hyphen
        if '-' in cleaned_token and len(cleaned_token.split('-')) > 1:
            sub_tokens = cleaned_token.split('-')
            if all(len(sub) > 1 for sub in sub_tokens):  # If all parts are meaningful
                cleaned_sentence.append(cleaned_token)
                cleaned_labels.append(label)
            else:
                for sub_token in sub_tokens:
                    cleaned_sentence.append(sub_token)
                    cleaned_labels.append(label)
        else:
            # Lowercase all tokens except acronyms or proper nouns
            if cleaned_token.isupper() and len(cleaned_token) > 1:  # Acronyms like "HIV"
                cleaned_sentence.append(cleaned_token)
                cleaned_labels.append(label)
            else:
                cleaned_sentence.append(cleaned_token.lower())
                cleaned_labels.append(label)

    # Filter out empty tokens or noise tokens
    final_sentence = []
    final_labels = []
    for tok, lab in zip(cleaned_sentence, cleaned_labels):
        if tok and re.search(r'[a-zA-Z]', tok):  # Ensure valid tokens remain
            final_sentence.append(tok)
            final_labels.append(lab)

    return final_sentence, final_labels


def clean_dataset(sentences, labels):
    """
    Cleans a dataset of tokenized sentences and adjusts their corresponding labels.

    Parameters:
    sentences (list): A list of sentences, where each sentence is a list of tokens.
    labels (list): A list of label sequences, where each sequence corresponds to a sentence.

    Returns:
    tuple: A cleaned list of sentences and their adjusted labels.
    """
    cleaned_sentences = []
    cleaned_labels = []

    for sentence, label_sequence in zip(sentences, labels):
        cleaned_sentence, cleaned_label = clean_sentence(sentence, label_sequence)
        cleaned_sentences.append(cleaned_sentence)
        cleaned_labels.append(cleaned_label)

    return cleaned_sentences, cleaned_labels

def process_data(file_path):
    """
    Read a dataset file and reconstruct sentences or labels.

    Parameters:
    file_path (str): Path to the file containing data in tokenized format.

    Returns:
    list: A list of sentences or labels reconstructed from the file.
    """
    sentences = []
    current_sentence = []

    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line == "":  # A blank line indicates the end of a sentence
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []
            else:
                current_sentence.append(line)
        if current_sentence:  # Add the last sentence if the file does not end with a blank line
            sentences.append(current_sentence)

    return sentences

In [4]:
# Process train and test datasets
train_sentences = process_data(train_sent_path)
train_labels = process_data(train_label_path)
test_sentences = process_data(test_sent_path)
test_labels = process_data(test_label_path)

# Verify by printing counts
print(f"Number of sentences in train dataset: {len(train_sentences)}")
print(f"Number of label lines in train dataset: {len(train_labels)}")
print(f"Number of sentences in test dataset: {len(test_sentences)}")
print(f"Number of label lines in test dataset: {len(test_labels)}")

Number of sentences in train dataset: 2599
Number of label lines in train dataset: 2599
Number of sentences in test dataset: 1056
Number of label lines in test dataset: 1056


### Clean data upfront for noise

In [5]:
# train_sentences, train_labels = clean_dataset(train_sentences, train_labels)
# test_sentences, test_labels = clean_dataset(test_sentences, test_labels)

# train_sentences, train_labels = filter_no_entity_statements(sentences=train_sentences, labels=train_labels)
# test_sentences, test_labels = filter_no_entity_statements(sentences=test_sentences, labels=test_labels)

# Verify by printing counts
print(f"Number of sentences in train dataset: {len(train_sentences)}")
print(f"Number of label lines in train dataset: {len(train_labels)}")

print(f"Number of sentences in test dataset: {len(test_sentences)}")
print(f"Number of label lines in test dataset: {len(test_labels)}")

Number of sentences in train dataset: 2599
Number of label lines in train dataset: 2599
Number of sentences in test dataset: 1056
Number of label lines in test dataset: 1056


NOTE : I concluded on not cleaning any data upfront as all attempts were detremental to model performance.

In [6]:
import random

def print_sentences_with_labels(sentences, labels, num_sentences=5):
    """
    Prints a specified number of random sentences along with their labels.

    Parameters:
    sentences (list): List of tokenized sentences.
    labels (list): List of label sequences corresponding to the sentences.
    num_samples (int): Number of random samples to print. Default is 5.
    """
    assert len(sentences) == len(labels), "Sentences and labels must have the same length."
    
    random_indices = random.sample(range(len(sentences)), num_sentences)
    for idx in random_indices:
        sentence = sentences[idx]
        label_sequence = labels[idx]
        formatted_output = " ".join([f"[{label}]{word}" for word, label in zip(sentence, label_sequence)])
        print (f'Sentence # {idx}')
        print(formatted_output)
        print()

print_sentences_with_labels(train_sentences, train_labels, num_sentences=5)

Sentence # 381
[O]The [O]women [O]presented [O]with [D]cardiovascular [D]symptoms [O]or [O]a [D]heart [D]murmur

Sentence # 2436
[T]Surgical [T]treatment [O]for [D]lung [D]hydatid [D]disease

Sentence # 1489
[O]Breastfeeding [O]and [O]catch-up [O]growth [O]in [O]infants [O]born [O]small [O]for [O]gestational [O]age

Sentence # 1738
[D]Vascular [D]Parkinson [D]syndromes [O]: [O]a [O]controversial [O]concept

Sentence # 2271
[O]Refugees [O]with [D]crawling [D]lice [O]were [O]treated [O]with [O]a [T]pediculicide [T]containing [T]1 [T]% [T]permethrin



## Step 2: Concept Identification
In this step, I will identify key concepts (e.g., diseases and treatments) from the dataset by:
1. Performing Part-of-Speech (PoS) tagging on the text data.
2. Extracting tokens with PoS tags corresponding to nouns (`NOUN` and `PROPN`).
3. Counting the frequency of these tokens across the entire dataset (both training and testing data).
4. Printing the top 25 most frequently mentioned concepts.


In [7]:
# For formatting outputs
from tabulate import tabulate

In [8]:
import spacy
from spacy.cli import download
try:
    spacy.load("en_core_web_sm")
except:
    download("en_core_web_sm")
    exit(1)

In [9]:
import spacy
from collections import Counter

# Load spaCy model for PoS tagging
nlp = spacy.load("en_core_web_sm")

def extract_noun_phrases(sentences):
    """
    Extract nouns and proper nouns from the given sentences.

    Parameters:
    sentences (list): A list of tokenized sentences.

    Returns:
    list: A list of nouns and proper nouns extracted from the sentences.
    """
    nouns = []
    for sentence in sentences:
        doc = nlp(" ".join(sentence))
        for token in doc:
            if token.pos_ in ["NOUN", "PROPN"]:  # Select nouns and proper nouns
                nouns.append(token.text.lower())
    return nouns

# Combine training and testing sentences for concept identification
all_sentences = train_sentences + test_sentences

# Extract nouns and calculate their frequencies
nouns = extract_noun_phrases(all_sentences)
noun_frequencies = Counter(nouns)

In [10]:
# Print the top 25 most common nouns
table_data = [[i, concept, freq] for i, (concept, freq) in enumerate(noun_frequencies.most_common(25), start=1)]
print(tabulate(table_data, headers=['#', "Concept", "Frequency"], tablefmt="github"))

|   # | Concept      |   Frequency |
|-----|--------------|-------------|
|   1 | patients     |         507 |
|   2 | treatment    |         304 |
|   3 | %            |         247 |
|   4 | cancer       |         211 |
|   5 | therapy      |         177 |
|   6 | study        |         174 |
|   7 | disease      |         149 |
|   8 | cell         |         142 |
|   9 | lung         |         118 |
|  10 | results      |         116 |
|  11 | group        |         111 |
|  12 | effects      |          99 |
|  13 | gene         |          91 |
|  14 | chemotherapy |          91 |
|  15 | use          |          87 |
|  16 | effect       |          82 |
|  17 | women        |          81 |
|  18 | analysis     |          76 |
|  19 | risk         |          74 |
|  20 | surgery      |          73 |
|  21 | cases        |          72 |
|  22 | p            |          72 |
|  23 | rate         |          68 |
|  24 | survival     |          67 |
|  25 | response     |          66 |


In [11]:
# Print the top 25 least common nouns
table_data = [[i, concept, freq] for i, (concept, freq) in enumerate(noun_frequencies.most_common()[:-26:-1], start=1)]
print(tabulate(table_data, headers=['#', "Concept", "Frequency"], tablefmt="github"))

|   # | Concept           |   Frequency |
|-----|-------------------|-------------|
|   1 | abortion          |           1 |
|   2 | myeloma           |           1 |
|   3 | tandem            |           1 |
|   4 | occlusions        |           1 |
|   5 | thrombogenicity   |           1 |
|   6 | vasoreactivity    |           1 |
|   7 | epoetin           |           1 |
|   8 | timolol           |           1 |
|   9 | tartrate          |           1 |
|  10 | brimonidine       |           1 |
|  11 | poliovirus        |           1 |
|  12 | poliomyelitis     |           1 |
|  13 | celecoxib         |           1 |
|  14 | formoterol        |           1 |
|  15 | dry               |           1 |
|  16 | levodopa          |           1 |
|  17 | methyltransferase |           1 |
|  18 | catechol          |           1 |
|  19 | tolcapone         |           1 |
|  20 | colonoscopy       |           1 |
|  21 | malathion         |           1 |
|  22 | spoon             |       

## Step 3: Defining Features for CRF

### Explanation of Feature Selection:

1. **Word-Level Features**: Capture fundamental properties of words, such as their lowercased form, prefixes, and suffixes, to identify linguistic patterns.

2. **POS and Dependency Features**: Leverage spaCy's Part-of-Speech tagging and dependency parsing to understand syntactic roles and relationships within the sentence.

3. **Contextual Features**: Incorporate information about preceding and following words, including their grammatical roles, to provide a comprehensive context for each word.

4. **N-grams**: Include both bigrams and trigrams to capture sequential patterns and relationships across consecutive words, crucial for identifying compound terms and multi-word entities.

5. **Phrase Boundary Features**: Detect syntactic indicators, such as adjectives and compounds, that suggest the beginning or continuation of an entity phrase.

6. **Sentence Position Indicators**: Use `BOS` (Beginning of Sentence) and `EOS` (End of Sentence) flags to capture word position within the sentence, aiding in boundary detection for entities.

7. **Placeholder Features**: Handle missing or inapplicable features consistently with placeholders (e.g., `<NONE>`, `<START>`, `<END>`) to maintain uniformity across feature dictionaries.

In [12]:
def word2features(sentence, i):
    """
    Generate features for a single word in a sentence with consistent context relationships.
    
    Parameters:
    sentence (list): A list of tokens (words) in the sentence.
    i (int): Index of the word in the sentence.

    Returns:
    dict: A dictionary of features for the word.
    """
    word = sentence[i]
    features = {
        'word.lower()': word.lower(),
        'word.prefix': word[:3].lower(),
        'word.suffix': word[-3:].lower(),
    }

    # PoS tagging using spaCy
    doc = nlp(" ".join(sentence))
    token = doc[i]
    features['pos'] = token.pos_
    features['dep'] = token.dep_
    features['head'] = token.head.text.lower()
    features['head.pos'] = token.head.pos_

    # Previous word features
    if i > 0:
        prev_token = doc[i - 1]
        features.update({
            'prev_word.lower()': sentence[i - 1].lower(),
            'prev_word.pos': prev_token.pos_,
            'prev_word.dep': prev_token.dep_,
        })
        features['bigram.prev'] = sentence[i - 1].lower() + '_' + word.lower()
    else:
        features.update({
            'prev_word.lower()': '<START>',
            'prev_word.pos': '<NONE>',
            'prev_word.dep': '<NONE>',
            'bigram.prev': '<NONE>',
        })

    # Next word features
    if i < len(sentence) - 1:
        next_token = doc[i + 1]
        features.update({
            'next_word.lower()': sentence[i + 1].lower(),
            'next_word.pos': next_token.pos_,
            'next_word.dep': next_token.dep_,
        })
        features['bigram.next'] = word.lower() + '_' + sentence[i + 1].lower()
    else:
        features.update({
            'next_word.lower()': '<END>',
            'next_word.pos': '<NONE>',
            'next_word.dep': '<NONE>',
            'bigram.next': '<NONE>',
        })

    # Trigram features
    if i > 1:
        features['trigram.prev'] = sentence[i - 2].lower() + '_' + sentence[i - 1].lower() + '_' + word.lower()
    else:
        features['trigram.prev'] = '<NONE>'
    if i < len(sentence) - 2:
        features['trigram.next'] = word.lower() + '_' + sentence[i + 1].lower() + '_' + sentence[i + 2].lower()
    else:
        features['trigram.next'] = '<NONE>'

    # Phrase boundary features
    features['is_descriptor'] = token.dep_ in ['amod', 'compound']

    # Sentence boundary features
    features['BOS'] = (i == 0)
    features['EOS'] = (i == len(sentence) - 1)

    return features


def sent2features(sentence):
    """
    Generate features for all words in a sentence.

    Parameters:
    sentence (list): A list of tokens (words) in the sentence.

    Returns:
    list: A list of dictionaries, each containing features for a word.
    """
    return [word2features(sentence, i) for i in range(len(sentence))]


## Step 4: Getting Features for Words and Sentences
Using the feature extraction functions defined earlier, I will generate features for all sentences in the training and testing datasets. This involves:
1. Applying `sent2features` to each sentence.
2. Preparing the data in a format suitable for training and evaluating the CRF model.


In [13]:
# TODO - run feature extraction in parallel to speed up
# from joblib import Parallel, delayed

# def prepare_features_and_labels(sentences, labels):
#     """
#     Generate features and labels for all sentences in the dataset.

#     Parameters:
#     sentences (list): A list of sentences, where each sentence is a list of tokens (words).
#     labels (list): A list of label sequences, where each sequence corresponds to a sentence.

#     Returns:
#     tuple: A tuple containing:
#         - features (list): A list of feature dictionaries for each sentence.
#         - labels (list): A list of label sequences for each sentence.
#     """
#     # Parallelize the sentence feature extraction using threads
#     features = Parallel(n_jobs=-1, prefer="threads")(delayed(sent2features)(sentence) for sentence in sentences)
#     return features, labels

# # Prepare features and labels for the train dataset
# train_features, train_labels = prepare_features_and_labels(train_sentences, train_labels)

# # Prepare features and labels for the test dataset
# test_features, test_labels = prepare_features_and_labels(test_sentences, test_labels)

In [14]:
def prepare_features_and_labels(sentences, labels):
    """
    Generate features and labels for all sentences in the dataset.

    Parameters:
    sentences (list): A list of sentences, where each sentence is a list of tokens (words).
    labels (list): A list of label sequences, where each sequence corresponds to a sentence.

    Returns:
    tuple: A tuple containing:
        - features (list): A list of feature dictionaries for each sentence.
        - labels (list): A list of label sequences for each sentence.
    """
    features = [sent2features(sentence) for sentence in sentences]
    return features, labels

# Prepare features and labels for the train dataset
train_features, train_labels = prepare_features_and_labels(train_sentences, train_labels)

# Prepare features and labels for the test dataset
test_features, test_labels = prepare_features_and_labels(test_sentences, test_labels)

## Step 5: Defining Input and Target Variables
In this step, I will define the input features and target labels for the CRF model:
1. Input Variables: Features extracted for each word in the sentences.
2. Target Variables: Corresponding labels (`O`, `D`, `T`) for each word in the sentences.

Additionally, I will display a random example from the training dataset in a tabular format to inspect the features and labels.


In [15]:
import random
from tabulate import tabulate

# Display the number of samples for training and testing
print(f"Number of training samples: {len(train_features)}")
print(f"Number of testing samples: {len(test_features)}")

Number of training samples: 2599
Number of testing samples: 1056


In [16]:
# Function to display features and labels in a tabular format
def display_random_example(features, labels, sentences):
    """
    Display a random example from the dataset in a tabular format.

    Parameters:
    features (list): List of feature dictionaries for the dataset.
    labels (list): List of label sequences corresponding to the features.
    sentences (list): List of tokenized sentences.
    """
    # Select a random example
    random_index = random.randint(0, len(features) - 1)
    example_features = features[random_index]
    example_labels = labels[random_index]
    example_sentence = sentences[random_index]
    
    # Prepare the data for tabulation
    table_data = []
    for i, (word, label, feature) in enumerate(zip(example_sentence, example_labels, example_features)):
        row = [i + 1, word, label] + [f"{key}: {value}" for key, value in feature.items()]
        table_data.append(row)
    
    # Define headers for the table
    headers = ["Index", "Word", "Label"] + [f"Feature {i + 1}" for i in range(len(example_features[0]))]

    # Display the table using tabulate
    print(f"\nRandom Example from Training Set (Index {random_index}):")
    print(tabulate(table_data, headers=headers, tablefmt="github"))

# Display a random example from the training set
display_random_example(train_features, train_labels, train_sentences)


Random Example from Training Set (Index 1228):
|   Index | Word          | Label   | Feature 1                   | Feature 2        | Feature 3        | Feature 4   | Feature 5     | Feature 6        | Feature 7      | Feature 8                        | Feature 9             | Feature 10              | Feature 11                    | Feature 12                       | Feature 13            | Feature 14              | Feature 15                    | Feature 16                                | Feature 17                                | Feature 18           | Feature 19   | Feature 20   |
|---------|---------------|---------|-----------------------------|------------------|------------------|-------------|---------------|------------------|----------------|----------------------------------|-----------------------|-------------------------|-------------------------------|----------------------------------|-----------------------|-------------------------|-------------------------------|

## Step 6: Building the Model

In this step, I perform model selection by evaluating a Conditional Random Field (CRF) model with various combinations of hyperparameters. These include optimization algorithms (`lbfgs`, `arow`, `pa`), regularization coefficients (`c1`, `c2`), and maximum iterations. 

The goal is to identify the best-performing model based on the F1-score for the `D` (Disease) label. The best model is then assigned to the variable `crf_model` for further use.


I am not evaluating `lbfgs` as I already tested it to conclude it is not a candidate.

In [47]:
from sklearn_crfsuite import CRF
from sklearn_crfsuite.metrics import flat_classification_report

# Function to train and evaluate a CRF model with specified parameters
def train_crf(algorithm, c1, c2, max_iterations, train_features, train_labels, test_features, test_labels):
    """
    Train and evaluate a CRF model.

    Parameters:
    - algorithm (str): The optimization algorithm (e.g., 'lbfgs', 'arow', 'pa').
    - c1 (float): Coefficient for L1 regularization (only for lbfgs).
    - c2 (float): Coefficient for L2 regularization (only for lbfgs).
    - max_iterations (int): Maximum number of iterations for training.
    - train_features (list): Features for the training data.
    - train_labels (list): Labels for the training data.
    - test_features (list): Features for the test data.
    - test_labels (list): Labels for the test data.

    Returns:
    - f1_score_d (float): F1-score for the 'D' (Disease) label on the test set.
    - model: The trained CRF model.
    - report: Classification report.
    """
    if algorithm in ["arow", "pa"]:
        crf = CRF(
            algorithm=algorithm,
            max_iterations=max_iterations,
            all_possible_transitions=True
        )
    else:
        crf = CRF(
            algorithm=algorithm,
            c1=c1,
            c2=c2,
            max_iterations=max_iterations,
            all_possible_transitions=True
        )

    crf.fit(X=train_features, y=train_labels)
    predictions = crf.predict(test_features)
    report = flat_classification_report(test_labels, predictions, output_dict=True)
    f1_score_d = report['D']['f1-score']
    return f1_score_d, crf, report

In [48]:
# Define parameters for grid search
algorithms = ['arow', 'pa']
max_iterations_values = [100, 125, 150, 175, 200, 250, 300]

parameter_combinations = (
    [(alg, None, None, max_iter) for alg in ['arow', 'pa'] for max_iter in max_iterations_values]
)

# Grid search for best model
best_f1_d = 0
best_model = None
best_params = None
best_report = None

for algorithm, c1, c2, max_iterations in parameter_combinations:
    if algorithm in ["arow", "pa"]:
        print(f"Evaluating: Algorithm={algorithm}, max_iterations={max_iterations}")
    # else:
    #     print(f"Evaluating: Algorithm={algorithm}, c1={c1}, c2={c2}, max_iterations={max_iterations}")
    try:
        f1_d, model, report = train_crf(
            algorithm, c1, c2, max_iterations,
            train_features, train_labels, test_features, test_labels
        )
        if f1_d > best_f1_d:
            best_f1_d = f1_d
            best_model = model
            best_params = (algorithm, c1, c2, max_iterations)
            best_report = report
    except Exception as e:
        print(f"Error with combination Algorithm={algorithm}, c1={c1}, c2={c2}, max_iterations={max_iterations}: {e}")

best_model

Evaluating: Algorithm=arow, max_iterations=100
Evaluating: Algorithm=arow, max_iterations=125
Evaluating: Algorithm=arow, max_iterations=150
Evaluating: Algorithm=arow, max_iterations=175
Evaluating: Algorithm=arow, max_iterations=200
Evaluating: Algorithm=arow, max_iterations=250
Evaluating: Algorithm=arow, max_iterations=300
Evaluating: Algorithm=pa, max_iterations=100
Evaluating: Algorithm=pa, max_iterations=125
Evaluating: Algorithm=pa, max_iterations=150
Evaluating: Algorithm=pa, max_iterations=175
Evaluating: Algorithm=pa, max_iterations=200
Evaluating: Algorithm=pa, max_iterations=250
Evaluating: Algorithm=pa, max_iterations=300


In [60]:
# Hard coded to avoid confusion 
crf_model = CRF(
            algorithm='pa',
            max_iterations=150,
            all_possible_transitions=True
)

crf_model.fit(X=train_features, y=train_labels)
crf_model

## Step 7: Evaluating the Model
In this step, I will evaluate the CRF model's performance using the test dataset. The model will:
1. Predict labels for each token in the test sentences.
2. Calculate the F1 score for overall performance.
3. Display a detailed classification report to analyze the model's predictions for each label (`O`, `D`, `T`).


In [61]:
from sklearn_crfsuite import metrics

# Predict labels for the test dataset
test_predictions = crf_model.predict(test_features)

# Evaluate the model using the F1 score
# f1_score = metrics.flat_f1_score(
#     test_labels, test_predictions, average='weighted', labels=crf_model.classes_
# )

# print(f"F1 Score: {f1_score:.3f}")

# Print classification report for detailed evaluation
classification_report = metrics.flat_classification_report(
    test_labels, test_predictions, labels=crf_model.classes_, digits=2
)
print("Classification Report:")
print(classification_report)

Classification Report:
              precision    recall  f1-score   support

           O       0.94      0.98      0.96     16127
           D       0.79      0.64      0.71      1450
           T       0.80      0.56      0.66      1041

    accuracy                           0.93     18618
   macro avg       0.85      0.73      0.78     18618
weighted avg       0.92      0.93      0.92     18618



### **Model Performance Summary**

#### **Overall Performance**
- **Accuracy**: **0.93**
- **Macro Average F1-Score**: **0.78**
- **Weighted Average F1-Score**: **0.93**

The model demonstrates strong overall performance, especially for non-entity tokens (`O`), while maintaining moderate performance for entities (`D` and `T`). There is room for improvement in recall for minority classes.

---

#### **Key Observations**
1. **Strengths**:
   - Outstanding performance for the `O` class, indicating strong handling of non-entity tokens.
   - High precision for `D` and `T` classes minimizes false positives.

2. **Weaknesses**:
   - Recall for `D` and `T` classes remains suboptimal, leading to missed entities.
   - The lower macro F1-score (**0.78**) highlights the variability in performance between classes.

---

#### **Conclusion**
The model excels in detecting non-entity tokens and maintains good precision for entity classes. However, further improvements in recall for the `D` and `T` classes are necessary to enhance overall entity detection. Refining features or augmenting the training dataset may help address these challenges.

## Step 8: Identifying Diseases and Predicted Treatments
In this step, I will extract diseases and their corresponding treatments from the test dataset using the trained CRF model. The output will be structured as a dictionary, where:
- Each disease (label `D`) is a key.
- Treatments (label `T`) associated with the disease are the values.
Additionally, the results for the specific disease "hereditary retinoblastoma" will be explicitly extracted to meet the assignment's requirements.


In [62]:
from collections import defaultdict
import spacy
import re

# Load spaCy's small English model for dependency parsing
nlp = spacy.load("en_core_web_sm")

def extract_diseases_and_treatments(sentences, predictions):
    """
    Extract diseases and treatments, including descriptive multi-word entities,
    with reduced noise using dependency parsing and validation.

    Parameters:
    sentences (list): A list of tokenized sentences.
    predictions (list): A list of predicted label sequences for each sentence.

    Returns:
    dict: A dictionary where keys are diseases (D) with descriptors and values are lists of treatments (T).
    """
    disease_treatment_map = defaultdict(list)

    def is_valid_entity(entity):
        """
        Validate if the extracted entity is meaningful.

        Parameters:
        entity (str): The entity to validate.

        Returns:
        bool: True if the entity is valid, False otherwise.
        """
        # Disallow entities with invalid characters or overly short entities
        if re.search(r"[()\d]", entity) or len(entity.split()) < 1 or re.match(r"^[A-Z]\.$", entity):
            return False
        # Exclude overly generic terms
        if entity.lower() in ["disease", "cancer", "advanced disease"]:
            return False
        return True

    def is_valid_treatment(treatment):
        """
        Validate if the extracted treatment is meaningful.

        Parameters:
        treatment (str): The treatment to validate.

        Returns:
        bool: True if the treatment is valid, False otherwise.
        """
        # Exclude generic terms and overly short treatments
        invalid_terms = {"and", "with", "the", "of"}
        return treatment.isalpha() and len(treatment) > 2 and treatment.lower() not in invalid_terms

    for sentence, prediction in zip(sentences, predictions):
        # Convert the tokenized sentence into a spaCy Doc object for dependency parsing
        doc = nlp(" ".join(sentence))

        current_disease = None
        for idx, (word, label) in enumerate(zip(sentence, prediction)):
            if label == "D":  # Identify disease
                # Start forming a multi-word entity
                token = doc[idx]
                descriptor = set()

                # Add adjectives or compound descriptors linked to the disease
                for child in token.children:
                    if child.dep_ in ["amod", "compound"] and child.pos_ in ["ADJ", "NOUN"]:
                        descriptor.add(child.text)

                # Check for preceding descriptors in the sentence
                j = idx - 1
                while j >= 0 and prediction[j] == "O":
                    prev_token = doc[j]
                    if prev_token.dep_ in ["amod", "compound"] and prev_token.pos_ in ["ADJ", "NOUN"]:
                        descriptor.add(sentence[j])
                    j -= 1

                # Combine descriptor with the disease
                descriptor_list = list(descriptor)
                current_disease = " ".join(descriptor_list + [word])

                # Include subsequent words labeled as `D` to form a multi-word entity
                k = idx + 1
                while k < len(sentence) and prediction[k] == "D":
                    current_disease += f" {sentence[k]}"
                    k += 1

                # Skip to the last word of the entity
                idx = k - 1

                # Validate disease entity
                if not is_valid_entity(current_disease):
                    current_disease = None

            elif label == "T" and current_disease:  # Associate treatment with the disease
                if is_valid_treatment(word):
                    disease_treatment_map[current_disease].append(word)

    # Post-process the map to remove non-alphabetic treatments and normalize phrases
    final_map = {}
    for disease, treatments in disease_treatment_map.items():
        meaningful_treatments = list(set(t for t in treatments if is_valid_treatment(t)))  # Deduplicate treatments
        if is_valid_entity(disease):
            final_map[disease] = meaningful_treatments

    return final_map

In [63]:
# Extract diseases and treatments using test sentences and predictions
# disease_treatment_dict = extract_diseases_and_treatments(train_sentences, train_labels)
disease_treatment_dict = extract_diseases_and_treatments(test_sentences, test_predictions)

In [64]:
print("\nComplete Disease-Treatment Dictionary:")
table_data = [[i + 1, disease, ', '.join(treatments)] for i, (disease, treatments) in enumerate(disease_treatment_dict.items())]
print(tabulate(table_data, headers=["#", "Disease", "Treatments"], tablefmt="github"))


Complete Disease-Treatment Dictionary:
|   # | Disease                                   | Treatments                                                                                          |
|-----|-------------------------------------------|-----------------------------------------------------------------------------------------------------|
|   1 | gestational diabetes cases                | glycemic, good, control                                                                             |
|   2 | hereditary retinoblastoma                 | radiotherapy                                                                                        |
|   3 | myocardial infarction                     | aspirin                                                                                             |
|   4 | ulcer                                     | treatment, antibiotic, intravenous                                                                  |
|   5 | hemorrhagic stroke          

In [65]:
# import pandas as pd
# disease_treatment_data = pd.DataFrame(
#     data = table_data,
#     columns=["#", "Disease", "Treatments"]
# )

# disease_treatment_data.set_index('#')    

### **Summary of Disease-Treatment Dictionary**

#### **Key Observations**

1. **Accurate Disease-Treatment Pairs**:
   - Diseases like **diabetes gestational cases**, **hereditary retinoblastoma**, and **myocardial infarction** are appropriately matched with treatments like `glycemic control`, `radiotherapy`, and `aspirin`.
   - Common conditions such as **colds**, **viremia**, and **sickle cell disease** have accurate treatments like `antibiotics`, `combination therapy`, and `hydroxyurea`.

2. **Ambiguous Treatments**:
   - Some treatments lack specificity or refer to procedures or anatomical terms:
     - **Intracranial hemorrhage** → `method` (not a specific treatment).
     - **Inflammatory disorders** → `large, intestine` (describes anatomy, not a therapy).
     - **CBD stones** → `exploration, surgical` (procedural, not specific).

3. **Irrelevant or Noisy Entries**:
   - Diseases like **deficiency** and **prevention** have vague or overly generic treatments (`therapy`, `vaccines, oral`), indicating potential noise in the data.

4. **Multi-Word Diseases**:
   - Verbose entries like **brain clinical consecutive advanced nsclc** and **response complete year among overall sclc** include excessive descriptors, complicating interpretation.
   - Treatments such as `cisplatin` and `carboplatin` are relevant but could benefit from cleaner disease names.

5. **Procedural and General Descriptions**:
   - Terms like **primary cancer** and **advanced stage** lack specificity in both disease and treatment descriptions.
   - Treatments like `resection` and `radiation` are procedural rather than therapeutic.

---

#### **Key Findings**
1. **Strengths**:
   - Accurate matching of common diseases such as **diabetes**, **retinoblastoma**, and **colorectal cancer** with appropriate treatments.
   - Identification of valid treatments for less common conditions like **sickle cell disease** and **meningitis**.

2. **Weaknesses**:
   - Noise in disease names (e.g., "advanced stage") and vague treatments (e.g., "method").
   - Procedural terms (e.g., "resection", "exploration") dominate some entries, reducing specificity.
   - General categories (e.g., "deficiency", "prevention") lack detailed context or relevance.

---

#### **Conclusion**
The dictionary captures accurate disease-treatment mappings for many common and rare conditions. However, improving clarity in disease names and filtering procedural or vague treatments is essential for enhancing usability. This refinement could focus on removing noise and ensuring specificity in disease-treatment relationships.

#### Validation as per rubric

In [66]:
# Display results for "hereditary retinoblastoma"
specific_disease = "hereditary retinoblastoma"
specific_treatments = disease_treatment_dict.get(specific_disease, [])

if specific_treatments:
    print(f"Predicted treatments for the disease '{specific_disease}': {', '.join(specific_treatments)}")
else:
    print(f"No treatments found for the disease '{specific_disease}'.")

Predicted treatments for the disease 'hereditary retinoblastoma': radiotherapy
