## Step 0: Import Libraries

In [1]:
import pandas as pd
import numpy as np
import sys
import warnings
import re
import json
import os
from collections import defaultdict
from tqdm import tqdm

# Silence all Python warnings for the rest of the kernel session to keep cell output clean.
# NOTE(review): this also hides pandas/numpy warnings (e.g. SettingWithCopyWarning,
# FutureWarning) that may point at real problems; consider filtering specific categories.
warnings.filterwarnings('ignore')



## Step 1: Load MedMentions Dataset

Parse the PubTator-format corpus, then split the mentions into train/dev/test sets using the provided PMID lists.

In [2]:
# Title/abstract header lines: "<pmid>|t|<title>" or "<pmid>|a|<abstract>".
# Only the first two '|' are structural, so titles/abstracts may themselves contain '|'.
_PUBTATOR_HEADER_RE = re.compile(r'^([^|\t]+)\|([ta])\|(.*)$')


def parse_pubtator_file(file_path, max_docs=None, context_window=200):
    """
    Parse a PubTator-format file and return one dictionary per annotated mention.

    Recognised line shapes:
      - ``<pmid>|t|<title>``: starts a new document
      - ``<pmid>|a|<abstract>``: abstract of the current document
      - six or more TAB-separated fields
        ``pmid, start, end, mention, entity_type, entity_id``: one mention
      - an empty line: document separator
    Any other line is ignored.

    Mention offsets are resolved against ``title + " " + abstract`` of the
    current document.

    Args:
        file_path: Path to corpus_pubtator.txt
        max_docs: Optional limit on the number of documents parsed
            (``None`` or ``0`` means no limit).
        context_window: Number of characters of context kept on each side
            of a mention (default 200).

    Returns:
        List of mention dictionaries with metadata.
    """
    mentions = []
    current_title = ""
    current_abstract = ""
    full_text = " "  # always == current_title + " " + current_abstract
    doc_count = 0

    print(f"Parsing {file_path}...")

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            line = line.rstrip('\n')

            if not line:  # Empty line = document separator
                current_title = ""
                current_abstract = ""
                full_text = " "
                continue

            header = _PUBTATOR_HEADER_RE.match(line)
            if header:
                _, line_type, text = header.groups()
                if line_type == 't':
                    # Check the limit *before* counting so doc_count reports only
                    # documents that were actually parsed.
                    if max_docs and doc_count >= max_docs:
                        break
                    doc_count += 1
                    current_title = text
                    current_abstract = ""  # a new title always starts a new document
                else:
                    current_abstract = text
                # Build the document text once per header line rather than once per
                # mention; all mentions of the document share this string.
                full_text = current_title + " " + current_abstract

            elif '\t' in line:
                fields = line.split('\t')
                if len(fields) >= 6:
                    pmid, start_str, end_str, mention_text, entity_type, entity_id = fields[:6]
                    start_offset = int(start_str)
                    end_offset = int(end_str)

                    # Clip the context window to the document boundaries.
                    context_start = max(0, start_offset - context_window)
                    context_end = min(len(full_text), end_offset + context_window)
                    context_left = full_text[context_start:start_offset]
                    context_right = full_text[end_offset:context_end]

                    mentions.append({
                        'pmid': pmid,
                        'mention': mention_text,
                        'entity_id': entity_id,
                        'entity_type': entity_type,
                        'start': start_offset,
                        'end': end_offset,
                        'context_left': context_left,
                        'context_right': context_right,
                        'full_context': context_left + " " + mention_text + " " + context_right,
                        'title': current_title,
                        'abstract': current_abstract,
                        'text': full_text
                    })

    print(f"Parsed {len(mentions)} mentions from {doc_count} documents")
    return mentions

In [3]:

# Input files of the MedMentions release (paths relative to the notebook's working directory).
CORPUS_PATH = 'data/MedMentions/corpus_pubtator.txt'
TRAIN_PMIDS_PATH = 'data/MedMentions/corpus_pubtator_pmids_trng.txt'
DEV_PMIDS_PATH = 'data/MedMentions/corpus_pubtator_pmids_dev.txt'
TEST_PMIDS_PATH = 'data/MedMentions/corpus_pubtator_pmids_test.txt'

print("LOADING MEDMENTIONS DATASET")

mentions_list = parse_pubtator_file(CORPUS_PATH)
# An empty parse yields a DataFrame with no columns, which would otherwise surface
# below as an opaque KeyError on 'pmid'; fail early with an actionable message.
assert mentions_list, f"No mentions parsed from {CORPUS_PATH}; check the path and file format."
df_all = pd.DataFrame(mentions_list)

print(f"\nTotal mentions: {len(df_all):,}")
print(f"Total PMIDs: {df_all['pmid'].nunique():,}")
print(f"Total unique UMLS entities: {df_all['entity_id'].nunique():,}")

LOADING MEDMENTIONS DATASET
Parsing data/MedMentions/corpus_pubtator.txt...


365672it [00:02, 128875.31it/s]


Parsed 352496 mentions from 4392 documents

Total mentions: 352,496
Total PMIDs: 4,392
Total unique UMLS entities: 34,724


In [4]:
print("\nLoading train/dev/test splits...")


def read_pmid_set(path):
    """Return the set of PMIDs listed one per line in `path`, ignoring blank lines and surrounding whitespace."""
    with open(path, 'r', encoding='utf-8') as f:
        return {line.strip() for line in f if line.strip()}


train_pmids = read_pmid_set(TRAIN_PMIDS_PATH)
dev_pmids = read_pmid_set(DEV_PMIDS_PATH)
test_pmids = read_pmid_set(TEST_PMIDS_PATH)

# A PMID present in two splits would leak documents between training and evaluation data.
assert train_pmids.isdisjoint(dev_pmids), "train/dev splits share PMIDs"
assert train_pmids.isdisjoint(test_pmids), "train/test splits share PMIDs"
assert dev_pmids.isdisjoint(test_pmids), "dev/test splits share PMIDs"

df_train = df_all[df_all['pmid'].isin(train_pmids)].copy()
df_val = df_all[df_all['pmid'].isin(dev_pmids)].copy()
df_test = df_all[df_all['pmid'].isin(test_pmids)].copy()

# Mentions whose PMID is listed in no split are dropped by the filters above; report them.
unassigned = len(df_all) - len(df_train) - len(df_val) - len(df_test)
if unassigned:
    print(f"\nWarning: {unassigned:,} mentions belong to PMIDs not listed in any split")

print(f"\n Split Statistics:")
print(f"Train: {len(df_train):,} mentions from {df_train['pmid'].nunique():,} documents")
print(f"Val:   {len(df_val):,} mentions from {df_val['pmid'].nunique():,} documents")
print(f"Test:  {len(df_test):,} mentions from {df_test['pmid'].nunique():,} documents")


Loading train/dev/test splits...

 Split Statistics:
Train: 211,029 mentions from 2,635 documents
Val:   71,062 mentions from 878 documents
Test:  70,405 mentions from 879 documents


## Step 2: Filter Valid Entities (Optional)

Remove mentions without a valid UMLS CUI (Concept Unique Identifier).

In [5]:
def filter_valid_entities(df, invalid_ids=('-',)):
    """
    Keep only mentions with a valid UMLS CUI (Concept Unique Identifier).

    A mention is dropped when its `entity_id` is missing (NaN/None) or equals
    one of `invalid_ids` (by default the placeholder '-').

    Args:
        df: DataFrame with an `entity_id` column.
        invalid_ids: Iterable of placeholder IDs meaning "no concept".

    Returns:
        A filtered copy of `df` (original index preserved).
    """
    original_count = len(df)

    valid_mask = df['entity_id'].notna() & ~df['entity_id'].isin(list(invalid_ids))
    df_filtered = df[valid_mask].copy()

    filtered_count = len(df_filtered)
    removed = original_count - filtered_count
    # Guard the percentage so an empty split does not raise ZeroDivisionError.
    removed_pct = removed / original_count * 100 if original_count else 0.0

    print(f"Original: {original_count:,} | Valid: {filtered_count:,} | Removed: {removed:,} ({removed_pct:.2f}%)")

    return df_filtered

print("Filtering valid entities...")


def _filter_with_header(header, frame):
    """Print a split header line, then return `frame` restricted to valid entity IDs."""
    print(header)
    return filter_valid_entities(frame)


df_train = _filter_with_header("\nTrain:", df_train)
df_val = _filter_with_header("\nVal:", df_val)
df_test = _filter_with_header("\nTest:", df_test)

Filtering valid entities...

Train:
Original: 211,029 | Valid: 211,029 | Removed: 0 (0.00%)

Val:
Original: 71,062 | Valid: 71,062 | Removed: 0 (0.00%)

Test:
Original: 70,405 | Valid: 70,405 | Removed: 0 (0.00%)


## Step 3: Normalize Mentions

Clean and normalize mention text

In [6]:
# Opening bracket -> matching closing bracket (and the reverse), used to keep balanced brackets.
_BRACKET_PAIRS = {'(': ')', '[': ']', '{': '}'}
_CLOSE_TO_OPEN = {close: open_ for open_, close in _BRACKET_PAIRS.items()}


def normalize_mention(mention):
    """
    Normalize mention text:
    - Collapse whitespace runs to a single space and trim both ends
    - Strip leading/trailing characters that are neither word characters nor '-'
    - Keep a stripped bracket when it balances a bracket still inside the mention,
      e.g. 'IL-1(beta)' stays 'IL-1(beta)' and '[Ca2+]i' stays '[Ca2+]i'
    - Preserve medical abbreviations (case is never changed)

    If stripping would leave nothing (e.g. '%'), the whitespace-normalized
    mention is returned instead, so a non-blank mention never becomes "".

    Note: other trailing symbols are still stripped ('Na+' -> 'Na').
    """
    collapsed = re.sub(r'\s+', ' ', mention).strip()

    # Split into <leading junk><core><trailing junk>; `core` is exactly what a plain
    # strip of non-word, non-hyphen characters at both ends would produce.
    leading = re.match(r'[^\w-]*', collapsed).group()
    core = re.sub(r'[^\w-]+$', '', collapsed[len(leading):])
    trailing = collapsed[len(leading) + len(core):]

    if not core:
        return collapsed

    # Re-attach closing brackets that close an unmatched opener inside the core.
    for ch in trailing:
        opener = _CLOSE_TO_OPEN.get(ch)
        if opener is None or core.count(opener) <= core.count(ch):
            break
        core += ch

    # Re-attach opening brackets (nearest first) that open an unmatched closer inside the core.
    for ch in reversed(leading):
        closer = _BRACKET_PAIRS.get(ch)
        if closer is None or core.count(closer) <= core.count(ch):
            break
        core = ch + core

    return core

print("Normalizing mentions...")

# Add the normalized text as a new column on each split frame (the frames are updated in place).
for split_frame in (df_train, df_val, df_test):
    split_frame['normalized_mention'] = split_frame['mention'].apply(normalize_mention)

print("\nExamples:")
sample = df_train.head(10)
changed_pairs = [
    (orig, norm)
    for orig, norm in zip(sample['mention'], sample['normalized_mention'])
    if orig != norm
]
for orig, norm in changed_pairs:
    print(f"  '{orig}' → '{norm}'")

print("\n Normalization complete!")

Normalizing mentions...

Examples:

 Normalization complete!


## Step 4: Extract Biomedical Features

Add domain-specific features for medical entity linking

In [7]:
def extract_biomedical_features(df, max_abbrev_len=6):
    """
    Return a copy of `df` with biomedical-specific feature columns appended.

    Added columns:
    - is_abbreviation: all cased characters uppercase (``str.isupper``) and the
      mention is at most `max_abbrev_len` characters long
    - mention_word_count: number of whitespace-separated tokens in the mention
    - semantic_type_main: first comma-separated entry of `entity_type`
    - mention_length: number of characters in the mention
    - context_length: number of whitespace-separated tokens in `full_context`

    The input frame is not modified, so re-running the calling cell is safe.

    Args:
        df: DataFrame with string columns `mention`, `entity_type`, `full_context`.
        max_abbrev_len: Maximum mention length (characters) to count as an abbreviation.

    Returns:
        New DataFrame with the five feature columns added.
    """
    mention = df['mention']
    return df.assign(
        is_abbreviation=mention.str.isupper() & mention.str.len().le(max_abbrev_len),
        mention_word_count=mention.str.split().str.len(),
        semantic_type_main=df['entity_type'].str.split(',').str[0],
        mention_length=mention.str.len(),
        context_length=df['full_context'].str.split().str.len(),
    )

print("Extracting biomedical features...")

# Rebind each split to its feature-augmented frame, processing train, val, test in that order.
df_train, df_val, df_test = (
    extract_biomedical_features(split_frame)
    for split_frame in (df_train, df_val, df_test)
)

print("\n Feature Statistics (Train):")
print(f"Abbreviations: {df_train['is_abbreviation'].sum():,} ({df_train['is_abbreviation'].mean()*100:.1f}%)")
print(f"Mean mention length: {df_train['mention_length'].mean():.1f} chars")
print(f"Mean context length: {df_train['context_length'].mean():.1f} words")
print(f"Mean words per mention: {df_train['mention_word_count'].mean():.2f}")
print(f"\nTop 5 semantic types:")
print(df_train['semantic_type_main'].value_counts().head())

Extracting biomedical features...

 Feature Statistics (Train):
Abbreviations: 17,526 (8.3%)
Mean mention length: 10.7 chars
Mean context length: 56.8 words
Mean words per mention: 1.37

Top 5 semantic types:
semantic_type_main
T080    18689
T169    14241
T081    11888
T033     9511
T116     9194
Name: count, dtype: int64


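## Step 5: Create Standardized Mention Records

Format each mention as a standardized record for entity-linking models. The `candidates` field is an empty placeholder, to be populated by a later candidate-generation step.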

In [8]:
# (record key, source column) pairs, listed in the order keys appear in each record.
_MENTION_RECORD_FIELDS = (
    # Core fields
    ('pmid', 'pmid'),
    ('mention', 'mention'),
    ('normalized_mention', 'normalized_mention'),
    # Context
    ('context_left', 'context_left'),
    ('context_right', 'context_right'),
    ('full_context', 'full_context'),
    # Entity information (the gold concept ID is exported under `label_id`)
    ('label_id', 'entity_id'),
    ('entity_type', 'entity_type'),
    ('semantic_type_main', 'semantic_type_main'),
    # Position
    ('start', 'start'),
    ('end', 'end'),
    # Features
    ('is_abbreviation', 'is_abbreviation'),
    ('mention_length', 'mention_length'),
    ('mention_word_count', 'mention_word_count'),
    ('context_length', 'context_length'),
    # Metadata
    ('title', 'title'),
    ('abstract', 'abstract'),
)


def create_mention_record(row):
    """
    Build a standardized mention record from one mention row.

    Args:
        row: Mapping-like row (e.g. a pandas Series from `iterrows`) holding the
            source columns named in `_MENTION_RECORD_FIELDS`.

    Returns:
        Dict with the fields of `_MENTION_RECORD_FIELDS` (`entity_id` renamed to
        `label_id`) followed by an empty `candidates` list to be populated later.
    """
    record = {key: row[column] for key, column in _MENTION_RECORD_FIELDS}
    # A fresh list per call, so records never share a candidates list.
    record['candidates'] = []
    return record

print("Creating standardized mention records...")

train_records = [create_mention_record(row) for _, row in df_train.iterrows()]
val_records = [create_mention_record(row) for _, row in df_val.iterrows()]
test_records = [create_mention_record(row) for _, row in df_test.iterrows()]

print(f"\nCreated {len(train_records):,} train records")
print(f"Created {len(val_records):,} val records")
print(f"Created {len(test_records):,} test records")


# `json` is imported in the first (imports) cell, so it is not re-imported here.
print("\n Example record:")
if train_records:  # avoid an IndexError when the train split is empty
    print(json.dumps(train_records[0], indent=2))
else:
    print("(no train records)")

Creating standardized mention records...

Created 211,029 train records
Created 71,062 val records
Created 70,405 test records

 Example record:
{
  "pmid": "25763772",
  "mention": "DCTN4",
  "normalized_mention": "DCTN4",
  "context_left": "",
  "context_right": " as a modifier of chronic Pseudomonas aeruginosa infection in cystic fibrosis Pseudomonas aeruginosa (Pa) infection in cystic fibrosis (CF) patients is associated with worse long-term pulmonary diseas",
  "full_context": " DCTN4  as a modifier of chronic Pseudomonas aeruginosa infection in cystic fibrosis Pseudomonas aeruginosa (Pa) infection in cystic fibrosis (CF) patients is associated with worse long-term pulmonary diseas",
  "label_id": "C4308010",
  "entity_type": "T116,T123",
  "semantic_type_main": "T116",
  "start": 0,
  "end": 5,
  "is_abbreviation": true,
  "mention_length": 5,
  "mention_word_count": 1,
  "context_length": 28,
  "title": "DCTN4 as a modifier of chronic Pseudomonas aeruginosa infection in cystic f

## Step 6: Export Preprocessed Data

Save each split in two formats: JSONL (for entity-linking models) and Parquet (for fast loading).

In [9]:
# Output directory shared by every export below.
EXPORT_DIR = 'data/processed/medmentions'
os.makedirs(EXPORT_DIR, exist_ok=True)

# Split name -> records, in the order the files are written.
split_records = {'train': train_records, 'val': val_records, 'test': test_records}

print("EXPORTING PREPROCESSED DATA")

# Format 1: JSONL (for entity linking models), one JSON object per line
print("\n1. Exporting JSONL format...")
for split_name, records in split_records.items():
    with open(os.path.join(EXPORT_DIR, f'{split_name}.jsonl'), 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record) + '\n' for record in records)
    print(f"   {split_name}.jsonl")

# Format 2: Parquet (for fast loading)
print("\n2. Exporting Parquet format...")
df_train_export = pd.DataFrame(train_records)
df_val_export = pd.DataFrame(val_records)
df_test_export = pd.DataFrame(test_records)
split_frames = {'train': df_train_export, 'val': df_val_export, 'test': df_test_export}
for split_name, frame in split_frames.items():
    frame.to_parquet(os.path.join(EXPORT_DIR, f'{split_name}.parquet'), index=False)
    print(f"   {split_name}.parquet")

print("\n All exports complete!")

EXPORTING PREPROCESSED DATA

1. Exporting JSONL format...
   train.jsonl
   val.jsonl
   test.jsonl

2. Exporting Parquet format...
   train.parquet
   val.parquet
   test.parquet

 All exports complete!


## Step 7: Generate Preprocessing Statistics

Summary statistics for documentation and analysis

In [10]:
def compute_split_statistics(df, split_name):
    """
    Compute summary statistics for one data split.

    Averages and percentages are returned as pre-formatted strings (as written
    to preprocessing_stats.csv); counts are integers. For an empty split the
    averages format as 'nan', `top_semantic_type` is None, and its count is 0.

    Args:
        df: Mention DataFrame with columns `pmid`, `entity_id`, `semantic_type_main`,
            `is_abbreviation`, `mention_length`, `context_length`, `mention_word_count`.
        split_name: Label stored in the `split` field.

    Returns:
        Dict of statistics for the split.
    """
    # Computed once and reused for both the top type and its count.
    semantic_counts = df['semantic_type_main'].value_counts()
    has_types = not semantic_counts.empty

    return {
        'split': split_name,
        'num_mentions': len(df),
        'num_pmids': df['pmid'].nunique(),
        'num_unique_entities': df['entity_id'].nunique(),
        'num_unique_semantic_types': df['semantic_type_main'].nunique(),
        'num_abbreviations': int(df['is_abbreviation'].sum()),
        'pct_abbreviations': f"{df['is_abbreviation'].mean()*100:.2f}%",
        'avg_mention_length': f"{df['mention_length'].mean():.1f}",
        'avg_context_length': f"{df['context_length'].mean():.1f}",
        'avg_words_per_mention': f"{df['mention_word_count'].mean():.2f}",
        'top_semantic_type': semantic_counts.index[0] if has_types else None,
        'top_semantic_type_count': int(semantic_counts.iloc[0]) if has_types else 0,
    }

print("Computing preprocessing statistics...\n")

# One statistics dict per split, in TRAIN / VAL / TEST order.
train_stats, val_stats, test_stats = (
    compute_split_statistics(split_frame, split_label)
    for split_frame, split_label in ((df_train, 'TRAIN'), (df_val, 'VAL'), (df_test, 'TEST'))
)

stats_df = pd.DataFrame([train_stats, val_stats, test_stats])
print("MEDMENTIONS PREPROCESSING STATISTICS")
print("=" * 70)
print(stats_df.to_string(index=False))

stats_df.to_csv('data/processed/medmentions/preprocessing_stats.csv', index=False)
print("\nStatistics saved to preprocessing_stats.csv")

Computing preprocessing statistics...

MEDMENTIONS PREPROCESSING STATISTICS
split  num_mentions  num_pmids  num_unique_entities  num_unique_semantic_types  num_abbreviations pct_abbreviations avg_mention_length avg_context_length avg_words_per_mention top_semantic_type  top_semantic_type_count
TRAIN        211029       2635                25691                        126              17526             8.31%               10.7               56.8                  1.37              T080                    18689
  VAL         71062        878                12610                        124               5744             8.08%               10.7               56.9                  1.37              T080                     6435
 TEST         70405        879                12419                        123               5922             8.41%               10.7               56.8                  1.37              T080                     6361

Statistics saved to preprocessing_stats.csv
