# 🩺 Multi-Label Medical Classification: The OHSUMED Project

## 1. Project Introduction & Problem Definition
* **Objective**: To build a robust classification system that indexes medical abstracts with terms from the MeSH (Medical Subject Headings) hierarchy.
* **The Challenge**: Unlike standard single-label classification, this is an Extreme Multi-Label problem involving nearly 3,000 active categories and over 300,000 documents. The project evaluates the trade-offs between statistical keyword weighting (TF-IDF), semantic embeddings (BERT/Word2Vec), and sequential modeling (RNNs).

---
## 2. Data Loading & Initial Exploratory Data Analysis (EDA)
* **Strategy**: We load the OHSUMED dataset (1987-1991 cohort).
* **Key Insight**: Initial analysis revealed a massive label space of 106,799 unique raw tags. A standard document-to-label distribution check confirmed the multi-label nature, with some documents containing up to 30 clinical tags.
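Below is a minimal sketch of the document-to-label distribution check. It assumes each raw `mesh_terms` value is a single string of semicolon-separated MeSH headings; the delimiter used in the Hugging Face parquet files is an assumption. The helper names `split_mesh_terms` and `label_distribution_stats` are hypothetical.

```python
from collections import Counter
from statistics import mean


def split_mesh_terms(raw, sep=";"):
    """Split one raw MeSH string into stripped, non-empty tags (delimiter is assumed)."""
    if not raw:
        return []
    return [tag.strip() for tag in raw.split(sep) if tag.strip()]


def label_distribution_stats(raw_mesh_column, sep=";"):
    """Unique-tag count, max/mean tags per document, and the raw tag frequencies."""
    per_doc = [split_mesh_terms(raw, sep) for raw in raw_mesh_column]
    freq = Counter(tag for tags in per_doc for tag in tags)
    lengths = [len(tags) for tags in per_doc]
    return {
        "n_unique_tags": len(freq),
        "max_tags_per_doc": max(lengths, default=0),
        "mean_tags_per_doc": mean(lengths) if lengths else 0.0,
        "tag_counts": freq,
    }


# Toy rows standing in for train['mesh_terms']
toy = [
    "Human; Male; Leukemia/*diagnosis",
    "Human; Female; Heart Failure, Congestive/drug therapy; Digoxin/*therapeutic use",
    None,
]
stats = label_distribution_stats(toy)
print(stats["n_unique_tags"], stats["max_tags_per_doc"])  # 6 4
```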

---

## 3. Data Integrity & Clinical Label Engineering

**The "Check Tag" Problem**: Analysis showed that the most frequent labels were "**Human**," "**Male**," and "**Female**." These are MeSH "Check Tags": routine metadata that carries no information about the disease under study, so they add nothing to disease classification.

**Normalization Strategy:**
* Implemented a "Fingerprinting" algorithm to standardize punctuation and casing (e.g., matching "Non-U.S. Gov't" to "non u s govt").
* Stripped importance markers (*) and subheadings (/diagnosis).
* Removed non-clinical demographic tags so that the model focuses on learning pathology and pharmacology concepts.
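The normalization steps above can be sketched as follows. This is an illustrative stand-in for `clean_mesh_smart`, not its actual implementation. Three parts are assumptions: the function names, the fingerprint rules (lowercase, drop apostrophes, turn other punctuation into spaces), and the short `CHECK_TAGS` stop-list.

```python
import re

# Hypothetical, non-exhaustive stop-list of check tags / non-clinical headings (fingerprinted form)
CHECK_TAGS = {"human", "male", "female", "animal", "adult", "aged", "middle age",
              "support non u s govt", "support u s govt p h s"}


def fingerprint(tag):
    """Lowercase, drop apostrophes, map other punctuation to spaces, collapse whitespace."""
    tag = tag.lower().replace("'", "")
    tag = re.sub(r"[^a-z0-9]+", " ", tag)
    return " ".join(tag.split())


def normalize_mesh_tags(raw, sep=";"):
    """Strip '/subheading' suffixes and '*' markers, fingerprint, drop check tags, dedupe."""
    cleaned = []
    for tag in (raw or "").split(sep):
        tag = tag.split("/")[0].replace("*", "").strip()  # keep only the main heading
        fp = fingerprint(tag)
        if fp and fp not in CHECK_TAGS and fp not in cleaned:
            cleaned.append(fp)
    return cleaned


print(fingerprint("Non-U.S. Gov't"))  # non u s govt
print(normalize_mesh_tags("Human; Male; *Leukemia/diagnosis; Leukemia/drug therapy"))  # ['leukemia']
```

Deduplicating after stripping subheadings matters: "Leukemia/diagnosis" and "Leukemia/drug therapy" collapse into one target label.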

---

## 4. The **Pareto Principle**: Data-Driven Dimensionality Reduction

* **The Strategy**: To address the "Long Tail" problem, where thousands of labels appear only once, I applied the Pareto Principle (**80/20** Rule).
* **The Decision**: By calculating the cumulative frequency of clinical label assignments, I found that the 2,891 most frequent labels account for 80% of all label assignments in the corpus. This data-driven cutoff retains most of the labeling signal while ensuring that every kept label has enough training examples for the models to learn from.
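Here is a toy version of the cumulative-frequency cutoff. It mirrors the `(cumulative_perc <= 0.80).sum()` rule used in the notebook. That rule counts the labels whose cumulative share stays at or below 80%, so the label that first crosses the threshold is excluded. The function name and the counts are invented for illustration.

```python
from collections import Counter


def pareto_cutoff(label_counts, coverage=0.80):
    """Number of most-frequent labels whose cumulative share stays <= coverage."""
    counts = sorted(label_counts.values(), reverse=True)
    total = sum(counts)
    running, n_kept = 0, 0
    for c in counts:
        running += c
        if running / total <= coverage:
            n_kept += 1
        else:
            break  # the cumulative share only grows, so no later label can qualify
    return n_kept


toy_counts = Counter({"neoplasms": 50, "hypertension": 20, "asthma": 10,
                      "digoxin": 10, "leukemia": 5, "rare_a": 3, "rare_b": 2})
print(pareto_cutoff(toy_counts))  # 3  (cumulative shares 0.50, 0.70, 0.80, then 0.90)
```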

---

## 5. Feature Engineering (The Input X)

**Preprocessing Pipeline**:
To maximize the signal-to-noise ratio, I developed a cleaning function:
* **Feature Fusion**: Concatenated Title and Abstract to ensure high-value keywords are captured.
* **Numerical Normalization**: Replaced numbers with a [num] token, which records that a dose or measurement was present without adding every distinct value to the vocabulary.
* **Linguistic Consolidation**: Used NLTK's WordNet Lemmatizer to merge variants such as "Infections" and "Infection."
* **Vocabulary Pareto**: Identified that 6,485 unique words account for 90% of the corpus tokens.
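The sketch below illustrates the cleaning steps and the 90% vocabulary cutoff. It is not the notebook's `feature_cleaner`: the names `clean_text` and `vocab_size_for_coverage` are hypothetical, and tokenization is a simple regex. The lemmatizer is passed in as a function so the sketch runs without NLTK data; in the pipeline above you would pass `WordNetLemmatizer().lemmatize`.

```python
import re
from collections import Counter

NUM_RE = re.compile(r"\d+(?:[.,]\d+)*")  # integers and decimals such as 2.5 or 1,000


def clean_text(title, abstract, lemmatize=lambda w: w, stopwords=frozenset()):
    """Fuse title and abstract, replace numbers with [num], tokenize, drop stopwords, lemmatize."""
    text = f"{title or ''} {abstract or ''}".lower()
    text = NUM_RE.sub(" [num] ", text)
    tokens = re.findall(r"\[num\]|[a-z]+", text)
    return " ".join(lemmatize(t) for t in tokens if t not in stopwords)


def vocab_size_for_coverage(docs, coverage=0.90):
    """Smallest number of most-frequent words whose counts cover `coverage` of all tokens."""
    counts = sorted(Counter(w for d in docs for w in d.split()).values(), reverse=True)
    total, running = sum(counts), 0
    for n, c in enumerate(counts, start=1):
        running += c
        if running / total >= coverage:
            return n
    return len(counts)


toy_lemmas = {"infections": "infection"}  # tiny stand-in for a real lemmatizer
doc = clean_text("Infections after surgery", "Patients received 2.5 mg daily.",
                 lemmatize=lambda w: toy_lemmas.get(w, w), stopwords=frozenset({"after"}))
print(doc)  # infection surgery patients received [num] mg daily
```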

---

## 6. Feature Representation & Vectorization Showdown

Next, we transform the cleaned medical abstracts into numerical matrices. Because OHSUMED is a domain-specific (medical) dataset, the choice of how to represent a word is one of the most important factors in model performance. We compared several generations of text representations to identify the strongest feature signal.

We represent the corpus using five distinct methodologies:
* **Statistical Baselines (Sparse)**:
    * ***Bag-of-Words (BoW)***: Simple frequency counts to establish a "Baseline 0."
    * ***TF-IDF (1,2-grams)***: Statistical weighting of single words and bigrams (e.g., "heart failure") to capture specific clinical phrases.
* **Static Embeddings (Dense)**:
    * ***Word2Vec***: Trained directly on the OHSUMED corpus to capture medical-specific semantic relationships.
    * ***GloVe***: Pre-trained on 6 billion tokens (Wikipedia 2014 + Gigaword 5 newswire) to compare general-world knowledge against medical-specific data.
* **Dynamic Transformers (SOTA)**:
    * ***BioBERT Embeddings***: 768-dimensional vectors extracted from a model pre-trained on millions of PubMed abstracts, capturing deep clinical context.
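Word2Vec and GloVe produce one vector per word, so each abstract must be pooled into a single document vector before it can feed a classifier. This sketch assumes mean pooling over in-vocabulary tokens, a common choice. The function `mean_pool_doc_vector` is hypothetical and only approximates what helpers such as `get_glove_vec` might do. Any word-to-NumPy-vector mapping works here; gensim `KeyedVectors`, for example, supports both `in` and indexing.

```python
import numpy as np


def mean_pool_doc_vector(doc, vectors, dim):
    """Average the embeddings of in-vocabulary tokens; return zeros if none are known."""
    known = [vectors[w] for w in doc.split() if w in vectors]
    if not known:
        return np.zeros(dim, dtype=np.float32)
    return np.mean(known, axis=0).astype(np.float32)


# Toy 3-dimensional "embeddings"
toy_vectors = {
    "heart": np.array([1.0, 0.0, 0.0]),
    "failure": np.array([0.0, 1.0, 0.0]),
}
vec = mean_pool_doc_vector("acute heart failure", toy_vectors, dim=3)
print(vec)  # [0.5 0.5 0. ]
```

Averaging discards word order, so these dense baselines behave like a smoothed bag of words. Stacking the per-document vectors with `np.vstack` gives the dense feature matrix.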

**Engineering Decisions**

To ensure computational efficiency and statistical validity, the following constraints were applied:
* **Vocabulary Pareto Cutoff**: Instead of an arbitrary number, I selected a vocabulary size that covers 90% of all tokens in the corpus, ensuring we ignore rare typos while keeping high-value medical terms.
* **Memory Optimization**: Frequency-based features (BoW, TF-IDF) are stored as Compressed Sparse Row (CSR) matrices because almost all of their entries are zero, while embeddings are kept as dense NumPy arrays because every dimension is populated.
* **Normalization Alignment**: All vectorizers consume the same cleaned text (`X_cleaned`), so differences in results reflect the representation rather than the preprocessing.
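A back-of-the-envelope calculation shows why CSR storage matters for the 25,000-column TF-IDF matrix. The CSR cost is the sum of its `data`, `indices` and `indptr` arrays. The document count and the figure of 100 non-zero entries per row are illustrative assumptions, not measurements from this corpus.

```python
def dense_bytes(n_rows, n_cols, itemsize=8):
    """Memory for a dense matrix (float64 by default)."""
    return n_rows * n_cols * itemsize


def csr_bytes(n_rows, nnz, data_itemsize=8, index_itemsize=4):
    """Memory for a CSR matrix: values + column indices + row pointer (n_rows + 1 entries)."""
    return nnz * data_itemsize + nnz * index_itemsize + (n_rows + 1) * index_itemsize


n_docs, n_features, nnz_per_row = 290_000, 25_000, 100  # assumed sizes
print(f"dense: {dense_bytes(n_docs, n_features) / 1e9:.1f} GB")          # dense: 58.0 GB
print(f"csr:   {csr_bytes(n_docs, n_docs * nnz_per_row) / 1e9:.2f} GB")  # csr:   0.35 GB
```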

---

```python
import nltk
import numpy as np
import pandas as pd
import scipy.sparse as sp
import joblib
from collections import Counter
from IPython.display import display
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.callbacks import EarlyStopping
from gensim.models import Word2Vec
import gensim.downloader as api
from transformers import BertTokenizer, BertModel

# Import our custom source files
from src.text_processing import clean_mesh_smart, feature_cleaner
from src.visualization import plot_raw_mesh_frequencies, plot_label_shift, plot_pareto_coverage, plot_history
from src.feature_extraction import get_doc_vec, get_glove_vec, get_biobert_embeddings_memory_safe
from src.ml_models import run_classical_ml_experiment
from src.dl_utils import batch_densify, build_sequence_model, evaluate_dl_model_memory

# Download required NLTK data
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)
```

```python
# Load Data
splits = {'train': 'train-00000-of-00001.parquet', 'test': 'test-00000-of-00001.parquet'}
train = pd.read_parquet("hf://datasets/community-datasets/ohsumed/" + splits["train"])
test = pd.read_parquet("hf://datasets/community-datasets/ohsumed/" + splits["test"])

# Visualize raw tags
plot_raw_mesh_frequencies(train, top_n=30)
```

```python
# Clean out generic demographic and meta-tags using our heuristic filter
train['clinical_terms_list'] = train['mesh_terms'].apply(clean_mesh_smart)
test['clinical_terms_list'] = test['mesh_terms'].apply(clean_mesh_smart)

plot_label_shift(train)
```

```python
# Flatten cleaned labels and calculate coverage
all_clinical = [label for sublist in train['clinical_terms_list'] for label in sublist]
counts_series = pd.Series(Counter(all_clinical)).sort_values(ascending=False)

cumulative_perc = counts_series.cumsum() / counts_series.sum()
n_labels_at_80 = (cumulative_perc <= 0.80).sum()

print(f"Number of labels required to capture 80% of clinical targets: {n_labels_at_80}")
plot_pareto_coverage(cumulative_perc, n_labels_at_80)

# Filter targets to the vital few (a set gives constant-time membership checks)
top_clinical_labels = set(counts_series.head(n_labels_at_80).index)
train['final_target_labels'] = train['clinical_terms_list'].apply(lambda x: [t for t in x if t in top_clinical_labels])
test['final_target_labels'] = test['clinical_terms_list'].apply(lambda x: [t for t in x if t in top_clinical_labels])
```

```python
# Filter empty targets and binarize
train = train[train['final_target_labels'].apply(len) > 0].copy()
test = test[test['final_target_labels'].apply(len) > 0].copy()

mlb = MultiLabelBinarizer(sparse_output=True)
y_train = mlb.fit_transform(train['final_target_labels'])
y_test = mlb.transform(test['final_target_labels'])

# Text Fusion and Cleaning
train['X_raw'] = train['title'].fillna('') + " " + train['abstract'].fillna('')
test['X_raw'] = test['title'].fillna('') + " " + test['abstract'].fillna('')

train['X_cleaned'] = train['X_raw'].apply(feature_cleaner)
test['X_cleaned'] = test['X_raw'].apply(feature_cleaner)
```

```python
# 1. TF-IDF
tfidf_vectorizer = TfidfVectorizer(max_features=25000, ngram_range=(1, 2), min_df=2)
X_train_tfidf = tfidf_vectorizer.fit_transform(train['X_cleaned'])
X_test_tfidf = tfidf_vectorizer.transform(test['X_cleaned'])

# 2. GloVe
glove_model = api.load("glove-wiki-gigaword-100")
X_train_glove = np.array([get_glove_vec(doc, glove_model, 100) for doc in train['X_cleaned']])
X_test_glove = np.array([get_glove_vec(doc, glove_model, 100) for doc in test['X_cleaned']])

# 3. BioBERT (Memory Safe)
biobert_tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')
biobert_model = BertModel.from_pretrained('dmis-lab/biobert-base-cased-v1.1')

X_train_bert = get_biobert_embeddings_memory_safe(train['X_cleaned'].tolist(), biobert_model, biobert_tokenizer)
X_test_bert = get_biobert_embeddings_memory_safe(test['X_cleaned'].tolist(), biobert_model, biobert_tokenizer)
```

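```python
import os

# Save features for downstream modeling (create output folders if they do not exist yet)
os.makedirs('features', exist_ok=True)
os.makedirs('models', exist_ok=True)

sp.save_npz('features/X_train_tfidf.npz', X_train_tfidf)
sp.save_npz('features/X_test_tfidf.npz', X_test_tfidf)
sp.save_npz('features/y_train_sparse.npz', y_train)
sp.save_npz('features/y_test_sparse.npz', y_test)

np.savez_compressed('features/dense_embeddings.npz',
                    glove_train=X_train_glove, glove_test=X_test_glove,
                    bert_train=X_train_bert, bert_test=X_test_bert)

joblib.dump(mlb, 'models/mlb_object.pkl')
print("✅ Data preparation complete. Features and targets saved.")
```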

```python
print("Loading Classical ML Features...")
X_train_bow = sp.load_npz('features/X_train_bow.npz')
X_test_bow = sp.load_npz('features/X_test_bow.npz')
X_train_tfidf = sp.load_npz('features/X_train_tfidf.npz')
X_test_tfidf = sp.load_npz('features/X_test_tfidf.npz')

y_train = sp.load_npz('features/y_train_sparse.npz')
y_test = sp.load_npz('features/y_test_sparse.npz')

dense_data = np.load('features/dense_embeddings.npz')
X_train_w2v, X_test_w2v = dense_data['w2v_train'], dense_data['w2v_test']
X_train_glove, X_test_glove = dense_data['glove_train'], dense_data['glove_test']
X_train_bert, X_test_bert = dense_data['bert_train'], dense_data['bert_test']

feature_sets_train = {"Bag-of-Words": X_train_bow, "TF-IDF": X_train_tfidf, "W2V": X_train_w2v, "GloVe": X_train_glove, "BioBERT": X_train_bert}
feature_sets_test = {"Bag-of-Words": X_test_bow, "TF-IDF": X_test_tfidf, "W2V": X_test_w2v, "GloVe": X_test_glove, "BioBERT": X_test_bert}
```

```python
# The Classical ML Showdown
# Runs every feature-set/classifier combination (skipping known slow or poor pairings)
# and appends each result to the CSV as soon as it finishes
results_df = run_classical_ml_experiment(feature_sets_train, feature_sets_test, y_train, y_test, 'results/ohsumed_ml_results.csv')
display(results_df.head(10))
```

---
## 7. Classical Machine Learning Approach

**Methodology**: Evaluated four feature sets (BoW, TF-IDF, W2V, GloVe) across four classification families using a Binary Relevance (One-Vs-Rest) strategy.

**Analysis of Classical Results**:
* **The Winner**: XGBoost on Bag-of-Words (0.3274 F1-Micro) achieved the highest score, suggesting that non-linear decision trees make effective use of raw keyword counts.
* **Efficiency Leader**: Linear SVM on TF-IDF (0.2991 F1-Micro) provided near-peak performance while being 8x faster than XGBoost.
* **Memory Management**: Sequential processing (n_jobs=1) was implemented to prevent memory duplication during the training of the 2,891 independent label-wise classifiers.

> Note: Although BioBERT embeddings are computed in the code, they were excluded from this comparison because training the classifiers on them would have taken many hours.
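Under Binary Relevance, each label column is treated as its own binary problem. The headline metric, micro-averaged F1, pools true positives, false positives and false negatives over all labels before computing precision and recall, so frequent labels dominate the score. Below is a minimal NumPy sketch on a toy example with 3 documents and 4 labels. The function name `micro_f1` is hypothetical; scikit-learn's `f1_score(..., average='micro')` computes the same quantity.

```python
import numpy as np


def micro_f1(y_true, y_pred):
    """Micro-F1 for binary indicator matrices of shape (n_docs, n_labels)."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = np.sum(y_true & y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    if tp == 0:
        return 0.0  # e.g. an all-zero predictor
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    return float(2 * precision * recall / (precision + recall))


y_true = [[1, 0, 0, 1], [0, 1, 0, 0], [1, 0, 1, 0]]
y_pred = [[1, 0, 0, 0], [0, 1, 1, 0], [1, 0, 0, 0]]
print(round(micro_f1(y_true, y_pred), 4))  # 0.6667  (TP=3, FP=1, FN=2)
```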

---

```python
# Deep Learning Text Preparation
train = pd.read_parquet('data/train_cleaned_dl.parquet')
test = pd.read_parquet('data/test_cleaned_dl.parquet')

# Tokenization setup: 90% Pareto vocabulary, 90th-percentile sequence length
VOCAB_SIZE = 6485
doc_lens = train['X_cleaned'].apply(lambda x: len(x.split()))
MAX_LEN = int(np.percentile(doc_lens, 90))

tokenizer = Tokenizer(num_words=VOCAB_SIZE, oov_token="[UNK]")
tokenizer.fit_on_texts(train['X_cleaned'])

# truncating='post' keeps the start of each document (the fused title) when cutting long texts;
# the Keras default ('pre') would drop the beginning instead
X_train_dl = pad_sequences(tokenizer.texts_to_sequences(train['X_cleaned']), maxlen=MAX_LEN, padding='post', truncating='post')
X_test_final = pad_sequences(tokenizer.texts_to_sequences(test['X_cleaned']), maxlen=MAX_LEN, padding='post', truncating='post')

# Use our memory-safe densifier
y_train_dense = batch_densify(sp.load_npz('features/y_train_binarized.npz'))
y_test_final = batch_densify(sp.load_npz('features/y_test_binarized.npz'))

# Validation Split
X_train_final, X_val, y_train_final, y_val = train_test_split(X_train_dl, y_train_dense, test_size=0.2, random_state=42)
print(f"Deep Learning Data Ready. Train Shape: {X_train_final.shape}, Val Shape: {X_val.shape}")
```

---
## 8. 📑 Deep Learning Performance & Analysis

### 8.1. Objective
The goal of this phase was to evaluate sequential Deep Learning architectures against the classical machine learning baselines. We aimed to determine whether modeling the word order of medical abstracts provides a significant improvement when classifying documents into the 2,891 Pareto-selected MeSH categories.
    
### 8.2. Experimental Setup
All models utilized the following unified architecture parameters:
* **Vocabulary Size**: 6,485 words (90% Pareto coverage).
* **Sequence Length**: 162 tokens (90th percentile of document length).
* **Embedding Dimension**: 128 (Latent semantic space).
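* **Global Architecture**: Applied Global Average Pooling over the recurrent outputs, so every one of the 162 time steps contributes directly to the document representation. This gives a shorter gradient path than relying only on the final hidden state, which helps mitigate vanishing gradients.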
* **Output Layer**: 2,891 neurons with Sigmoid activation (Multi-label requirement).
* **Loss Function**: Binary Cross-entropy.
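The output head described above can be illustrated with NumPy:

1. Average the recurrent outputs over time, ignoring padded positions.
2. Project the result to one logit per label.
3. Apply an element-wise sigmoid, so each label gets an independent probability.
4. Score the predictions with binary cross-entropy averaged over all cells.

Shapes are shrunk for readability and the weights are random; this is not the notebook's `build_sequence_model`. Masking the padded steps is an assumption here. Keras' `GlobalAveragePooling1D` can apply the same masking when a mask is propagated, for example from `Embedding(mask_zero=True)`.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def masked_global_average(h, mask):
    """h: (batch, time, units); mask: (batch, time), 1 for real tokens, 0 for padding."""
    m = mask[:, :, None].astype(h.dtype)
    return (h * m).sum(axis=1) / np.maximum(m.sum(axis=1), 1.0)


def binary_cross_entropy(y, p, eps=1e-7):
    """Mean BCE over every (document, label) cell."""
    p = np.clip(p, eps, 1 - eps)
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))


rng = np.random.default_rng(0)
batch, time, units, n_labels = 2, 5, 8, 6            # toy sizes (real: time=162, n_labels=2,891)
h = rng.normal(size=(batch, time, units))             # stand-in for per-step RNN outputs
mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # first document is post-padded
W = rng.normal(size=(units, n_labels)) * 0.1
b = np.zeros(n_labels)

probs = sigmoid(masked_global_average(h, mask) @ W + b)  # one independent probability per label
y = np.zeros((batch, n_labels))
y[0, 1] = 1
y[1, 4] = 1
print(probs.shape, round(binary_cross_entropy(y, probs), 3))
```

A sigmoid (rather than softmax) output is what allows several labels to be active for the same document.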

---

```python
# Deep Learning Training Loop
models_to_train = ["Simple RNN", "Bidirectional RNN", "LSTM", "Bidirectional LSTM"]
trained_models = {}
callbacks = [EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)]

for model_name in models_to_train:
    print(f"\n--- Training {model_name} ---")
    model = build_sequence_model(model_name, VOCAB_SIZE, y_train_final.shape[1])

    history = model.fit(X_train_final, y_train_final, epochs=30, batch_size=64,
                        validation_data=(X_val, y_val), callbacks=callbacks, verbose=1)

    plot_history(history, model_name)
    model.save(f'models/ohsumed_{model_name.replace(" ", "_").lower()}.keras')
    trained_models[model_name] = model
```

```python
# Final Memory-Safe Evaluation
print("Evaluating Deep Learning models on the unseen Test Set...")
dl_results_list = []

for name, model in trained_models.items():
    # evaluate_dl_model_memory handles the batch-chunking and uint8 conversion behind the scenes
    res = evaluate_dl_model_memory(model, X_test_final, y_test_final, name)
    dl_results_list.append(res)

# Convert to DataFrame, sort by best F1-Score, and display
results_dl_df = pd.DataFrame(dl_results_list).sort_values(by="F1 (Micro)", ascending=False)
display(results_dl_df)

# Save the deep learning results alongside the classical ML results
results_dl_df.to_csv('results/ohsumed_dl_results.csv', index=False)
print("✅ Deep Learning evaluation complete. All results saved to disk.")
```

---
### 8.3. Technical Analysis & Insights
1. **The Sparsity Challenge**
    * In an extreme multi-label setting (2,891 classes), the label matrix is 99.9% sparse. During the initial epochs, the more complex models (LSTM/Bi-LSTM) drifted toward a "conservative" solution, predicting zero for every label because that alone keeps the loss low. The Simple RNN broke through this barrier and reached an F1-Micro of 0.1632 (matching our Logistic Regression baseline), whereas the more complex models plateaued under the standard 0.5 decision threshold.
2. **Statistical vs. Sequential Learning**
    * A key finding of this project is that statistical keyword-frequency models (XGBoost on Bag-of-Words, Linear SVM on TF-IDF) outperformed sequential models (RNN/LSTM) on this medical corpus. Because medical terminology is highly specific ("Leukemia" is rarely ambiguous), word counts already carry most of the signal, and XGBoost (0.3274 F1-Micro) remained the best-performing model in this hardware-constrained environment.
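A quick calculation shows why the all-zero solution is such a strong attractor under binary cross-entropy. Assume, for illustration only, that a document carries 10 of the 2,891 labels. Consider a model that outputs the same probability p for every label. Its loss is minimized when p equals the positive rate, and that loss is far below the loss of an uninformed p = 0.5. Yet with a 0.5 threshold such a model never predicts a positive. The function name `constant_bce` is hypothetical.

```python
import numpy as np


def constant_bce(p, n_labels=2891, n_pos=10):
    """Mean BCE over one document's labels when every label gets the same probability p."""
    return float(-(n_pos * np.log(p) + (n_labels - n_pos) * np.log(1 - p)) / n_labels)


base_rate = 10 / 2891  # the loss-minimizing constant is the positive rate itself
for p in (0.5, 0.1, base_rate):
    print(f"p={p:.4f}  mean BCE={constant_bce(p):.4f}")
# p=0.1 and p=base_rate both sit below a 0.5 decision threshold, so they predict no labels
# at all (micro-F1 = 0) even though their loss is far lower than the uninformed p=0.5.
```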

### 8.4. Engineering Solutions 
To handle the scale of 290,000 documents and 2,891 labels, several engineering strategies were implemented:
1. **Memory-Safe Batch Prediction**: Developed a manual chunking loop that processes the test set in 10,000-row increments, avoiding the RAM spike that caused a single `model.predict()` call to fail.
2. **Data Type Optimization**: Stored target matrices as `uint8`, reducing the RAM footprint of the labels from 6.2 GB to 810 MB.
3. **Feature Caching**: Saved BioBERT embeddings and Deep Learning tensors to local storage, so later runs reload them instead of recomputing them, which ensures persistence and reproducibility.
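The first two strategies can be combined in a short sketch. It predicts in fixed-size chunks and stores thresholded predictions as `uint8`, so the full floating-point probability matrix never sits in memory at once. `predict_in_chunks` and `toy_predict` are hypothetical stand-ins for the notebook's `evaluate_dl_model_memory`. With Keras you could pass `lambda x: model.predict(x, verbose=0)` as `predict_fn`. The `>=` threshold comparison is an assumption.

```python
import numpy as np


def predict_in_chunks(predict_fn, X, n_labels, chunk_size=10_000, threshold=0.5):
    """Call predict_fn on row chunks and keep only thresholded uint8 predictions."""
    out = np.empty((X.shape[0], n_labels), dtype=np.uint8)
    for start in range(0, X.shape[0], chunk_size):
        stop = min(start + chunk_size, X.shape[0])
        probs = predict_fn(X[start:stop])  # float probabilities exist for this chunk only
        out[start:stop] = (probs >= threshold).astype(np.uint8)
    return out


def toy_predict(x):
    """Stand-in model: every label gets probability first_feature / 10."""
    return np.repeat(x[:, :1] / 10.0, 3, axis=1)


X = (np.arange(25) % 10).reshape(25, 1)
preds = predict_in_chunks(toy_predict, X, n_labels=3, chunk_size=7)
print(preds.dtype, preds.shape, int(preds.sum()))  # uint8 (25, 3) 30

# uint8 stores 1 byte per cell versus 8 for float64
print(np.dtype(np.float64).itemsize // np.dtype(np.uint8).itemsize)  # 8
```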

### 8.5. Final Conclusion  
This project successfully established a robust pipeline for Extreme Multi-Label Classification. While the Deep Learning models reached a higher threshold-independent AUC (0.82), the classical XGBoost model on Bag-of-Words achieved the best F1-Micro and therefore the most effective practical performance.

### 8.6. Future Work
Further optimization would involve threshold tuning (lowering the 0.5 cutoff) and weighted loss functions that push the LSTMs to prioritize rare labels over the "easy" zeros.

---

## 9. Project Conclusion
    
**Technical Milestones**
This project established a comprehensive pipeline for Extreme Multi-Label Classification (XMLC) on a massive scale. 
Over the course of the analysis, we successfully navigated the challenges of:
* **High Cardinality**: Managing an output space of 2,891 medical categories.
* **Big Data Engineering**: Processing 350,000 documents and optimizing RAM usage through uint8 data types and chunked batch predictions.
* **Technological Comparison**: Benchmarking 40 years of NLP evolution, from simple frequency counts to 110-million parameter Transformers.

**Key Findings**
* **Keyword Density Is Crucial**: In this medical corpus, statistical keyword models (XGBoost on Bag-of-Words, Linear SVM on TF-IDF) proved highly effective. Because medical jargon is so specific, simple word counting often provided a clearer signal than sequential learning.
* **The "Winner"**: XGBoost on Bag-of-Words achieved the top score of 0.3274 F1-Micro.
* **Sequential Challenges**: Sequential models (RNNs/LSTMs) required more training time than classical models to overcome the "all-zero" baseline caused by extreme label sparsity.


## 10. Future Work

While the current framework provides a stable baseline for large-scale medical indexing, the following avenues represent the logical progression for improving predictive recall and model depth:
    
**1.** ***Dynamic Threshold Optimization***
A universal 0.5 threshold is inherently conservative for sparse multi-label data. Per-label threshold optimization would let the model adapt to rare diseases, which could substantially increase recall while limiting any loss of precision on high-frequency categories.
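Here is a minimal sketch of per-label threshold tuning. For each label, it picks the grid threshold that maximizes that label's F1 on a validation split, and falls back to 0.5 when the label has no positives there. The function name, the candidate grid, and the fallback are assumptions. Tuning should use the validation split (`X_val`, `y_val` in the notebook), never the test set.

```python
import numpy as np


def tune_label_thresholds(y_val, p_val, grid=None, default=0.5):
    """Pick, per label column, the grid threshold that maximizes that label's F1."""
    grid = np.linspace(0.05, 0.95, 19) if grid is None else grid
    y_val = np.asarray(y_val, dtype=bool)
    p_val = np.asarray(p_val, dtype=float)
    thresholds = np.full(y_val.shape[1], default)
    for j in range(y_val.shape[1]):
        truth = y_val[:, j]
        if not truth.any():
            continue  # no validation positives: keep the default threshold
        best_f1 = -1.0
        for t in grid:
            pred = p_val[:, j] >= t
            tp = np.sum(pred & truth)
            denom = pred.sum() + truth.sum()  # |pred| + |true| = 2TP + FP + FN
            f1 = 2 * tp / denom if denom else 0.0
            if f1 > best_f1:
                best_f1, thresholds[j] = f1, t
    return thresholds


y_val = np.array([[1, 0, 0], [0, 0, 0], [1, 1, 0], [0, 0, 0]])
p_val = np.array([[0.32, 0.90, 0.40], [0.12, 0.17, 0.30],
                  [0.37, 0.62, 0.20], [0.22, 0.07, 0.10]])
thr = tune_label_thresholds(y_val, p_val)
print(np.round(thr, 2))
```

At test time the tuned vector broadcasts across label columns: `(p_test >= thr).astype(np.uint8)`.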
    
**2.** ***Handling Class Imbalance (Focal Loss)***
With 2,891 labels, the "zeros" vastly outnumber the "ones." Focal Loss or Weighted Cross-Entropy would push the model to focus on the difficult rare diseases instead of taking the "easy win" of predicting zeros.
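For reference, the binary focal loss (Lin et al., 2017) scales each cross-entropy term by (1 - p_t)^γ, where p_t is the probability the model assigns to the true outcome. As a result, confidently correct "easy zeros" contribute almost nothing to the loss. Below is a NumPy sketch. γ = 2 and α = 0.25 are the defaults reported in that paper, not values tuned for OHSUMED, and the function name is hypothetical.

```python
import numpy as np


def binary_focal_loss(y, p, gamma=2.0, alpha=0.25, eps=1e-7):
    """Mean binary focal loss over all (document, label) cells."""
    y = np.asarray(y, dtype=float)
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)
    p_t = np.where(y == 1, p, 1 - p)              # probability given to the true outcome
    alpha_t = np.where(y == 1, alpha, 1 - alpha)  # class-balancing weight
    return float(np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t)))


# An "easy zero" (true 0, predicted 0.01) vs. a missed rare label (true 1, predicted 0.01)
easy_zero = binary_focal_loss([0], [0.01])
missed_positive = binary_focal_loss([1], [0.01])
print(easy_zero, missed_positive)
```

With γ = 0 and α = 0.5, the expression reduces to half of ordinary binary cross-entropy. This makes it easy to check the effect of the focusing term in isolation.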
    
**3.** ***Hierarchical Label Aggregation***
The MeSH vocabulary is a directed acyclic graph. Leveraging Hierarchical Multi-Label Learning (HMLL) would allow the model to utilize relationship dependencies (e.g., predicting "Neoplasms" as a prerequisite for predicting "Lung Neoplasms"), structurally reducing the search space for the classifier.