In [2]:
import os 
os.environ["HF_HOME"] = "/home/uw8/huggingface_cache"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

import time
import pandas as pd
import numpy as np
import re


import outlines
from outlines import models, generate


from vllm import LLM, SamplingParams

#from fuzzywuzzy import process

from functools import partial
import json

# Exp1 - Direct Generic Extraction (LLM)

In [3]:
# Configuration
CHECKPOINT_FILE = "drug_extraction_checkpoint.csv"
DATA_FILE = "sample_dummy_dataset.csv"
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"

In [None]:
# Initialize model components
llm = LLM(
    model=MODEL_NAME,
    dtype="float16",
    tensor_parallel_size=2,
    gpu_memory_utilization=0.9,
    max_model_len=4096,
    enforce_eager=True
)

INFO 06-14 20:56:07 __init__.py:207] Automatically detected platform cuda.


INFO 06-14 20:56:12 config.py:549] This model supports multiple tasks: {'generate', 'reward', 'classify', 'score', 'embed'}. Defaulting to 'generate'.
INFO 06-14 20:56:12 config.py:1382] Defaulting to use mp for distributed inference
INFO 06-14 20:56:12 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.3) with config: model='meta-llama/Llama-3.1-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, 

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]


[1;36m(VllmWorkerProcess pid=2906359)[0;0m INFO 06-14 20:56:17 model_runner.py:1115] Loading model weights took 7.5123 GB
INFO 06-14 20:56:17 model_runner.py:1115] Loading model weights took 7.5123 GB


In [4]:
# Schema Definition
DRUG_SCHEMA = '''{
    "type": "object",
    "properties": {
        "generic_names": {
            "type": "array",
            "items": {"type": "string"},
            "minItems": 0
        }
    },
    "required": ["generic_names"]
}'''


def format_prompt(text: str) -> str:
    return f"""Extract and normalize drug names to generic forms in JSON format. Rules:
1. Convert brands to generics (Oncovin → vincristine)
2. Expand abbreviations (MTX → methotrexate)
3. Correct misspellings (Methotrxate → methotrexate)
4. If a generic equivalent is unknown, include the raw drug name in lowercase.
5. Use lowercase only

Examples:
Input: Administered Oncovin and IT MTX
Output: {{"generic_names": ["vincristine", "methotrexate"]}}



Process this text:
{text}
JSON Output:"""

In [5]:

sampling_params = SamplingParams(
    temperature=0.2,
    top_k=150,
    top_p=0.6,
    repetition_penalty=1.1,
    max_tokens=8000,
    seed=42
)

model = models.VLLM(llm)
generator = generate.json(
    model,
    DRUG_SCHEMA,
    whitespace_pattern=r"[\s]*",
)

In [6]:


def safe_json_loads(x):
    """Safely parse JSON with fallback"""
    if pd.isna(x) or x.strip() in ['', '{}', '[]']:
        return []
    try:
        return json.loads(x)
    except json.JSONDecodeError:
        return []

def process_batch(batch_df):
    prompts = [format_prompt(text) for text in batch_df["text_concat"]]
    results = []
    
    try:
        responses = generator(prompts, sampling_params=sampling_params)
        for response, (_, row) in zip(responses, batch_df.iterrows()):
            try:
                if isinstance(response, str):
                    data = json.loads(response)
                else:
                    data = response
                drugs = data.get("generic_names", [])
                results.append({
                    "unique_key": row["unique_key"],
                    "text_concat": row["text_concat"],
                    "json_extraction": json.dumps(data, ensure_ascii=False),
                    "extracted_drugs": drugs
                })
            except Exception as e:
                results.append({
                    "unique_key": row["unique_key"],
                    "text_concat": row["text_concat"],
                    "json_extraction": "{}",
                    "extracted_drugs": []
                })
    except Exception as e:
        print(f"Batch failed: {str(e)[:200]}")
        for _, row in batch_df.iterrows():
            results.append({
                "unique_key": row["unique_key"],
                "text_concat": row["text_concat"],
                "json_extraction": "{}",
                "extracted_drugs": []
            })
    
    return pd.DataFrame(results)

def load_checkpoint():
    if os.path.exists(CHECKPOINT_FILE):
        try:
            df = pd.read_csv(
                CHECKPOINT_FILE,
                converters={
                    'unique_drugs': safe_json_loads,
                    'extracted_drugs': safe_json_loads,
                    'json_extraction': safe_json_loads
                }
            )
            if 'unique_key' not in df.columns:
                raise ValueError("Corrupted checkpoint - missing columns")
            return df
        except Exception as e:
            print(f"Checkpoint reset due to error: {str(e)[:200]}")
            os.rename(CHECKPOINT_FILE, f"{CHECKPOINT_FILE}.corrupted")
            return pd.DataFrame()
    return pd.DataFrame()

def save_checkpoint(df):
    df.to_csv(CHECKPOINT_FILE, index=False)

def compute_metrics(row):
    true_drugs = set(row['unique_drugs'])
    pred_drugs = set(row['extracted_drugs'])
    
    tp = len(true_drugs & pred_drugs)
    precision = tp / len(pred_drugs) if pred_drugs else 0
    recall = tp / len(true_drugs) if true_drugs else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0
    
    return pd.Series({
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "missing_drugs": list(true_drugs - pred_drugs),
        "hallucinated_drugs": list(pred_drugs - true_drugs)
    })

def print_avg_metrics(df):
    if not df.empty and 'precision' in df.columns:
        print("\nAverage Metrics Across All Processed Records:")
        print(f"Precision: {df['precision'].mean():.4f}")
        print(f"Recall:    {df['recall'].mean():.4f}")
        print(f"F1 Score:  {df['f1'].mean():.4f}")
    else:
        print("\nNo metrics available - empty dataset or missing columns")

def run_pipeline():
    df = load_checkpoint()
    raw_data = pd.read_csv(DATA_FILE)
    
    # Create unique key first
    raw_data["unique_key"] = (
        raw_data["patient_id_number"].astype(str) + "_" +
        raw_data["tumor_record_number"].astype(str) + "_" +
        raw_data["admission_id"].astype(str)
    )
    raw_data["unique_drugs"] = raw_data["unique_drugs"].apply(
        lambda x: x.split(", ") if isinstance(x, str) else []
    )
    
    # Find unprocessed records
    processed_keys = set(df["unique_key"]) if not df.empty else set()
    todo = raw_data[~raw_data["unique_key"].isin(processed_keys)]
    
    if todo.empty:
        print("All data processed")
        print_avg_metrics(df)
        return df
    
    print(f"Processing {len(todo)} new records...")
    
    # Process in batches
    batch_size = 32
    for i in range(0, len(todo), batch_size):
        batch = todo.iloc[i:i+batch_size].copy()
        batch_results = process_batch(batch)
        
        # Merge results
        merged = pd.merge(
            batch,
            batch_results,
            on=["unique_key", "text_concat"],
            how="left"
        )
        
        # Calculate metrics
        if not merged.empty:
            merged[["precision", "recall", "f1", "missing_drugs", "hallucinated_drugs"]] = \
                merged.apply(compute_metrics, axis=1)
        
        # Update and save incrementally
        df = pd.concat([df, merged], ignore_index=True)
        save_checkpoint(df)
        print(f"Processed batch {i//batch_size + 1}/{(len(todo)//batch_size)+1}")
    
    print_avg_metrics(df)
    return df

In [7]:
if __name__ == "__main__":
    final_df = run_pipeline()
    print("\nPipeline completed. Sample output:")
    if not final_df.empty:
        print(final_df[["unique_key", "text_concat", "extracted_drugs"]].head(2).to_string(index=False))

Processing 100 new records...


Processed prompts:   0%|          | 0/32 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts: 100%|██████████| 32/32 [01:13<00:00,  2.29s/it, est. speed input: 70.95 toks/s, output: 68.17 toks/s]  


Batch failed: Unterminated string starting at: line 1 column 11814 (char 11813)
Processed batch 1/4


Processed prompts: 100%|██████████| 32/32 [00:09<00:00,  3.21it/s, est. speed input: 520.21 toks/s, output: 111.02 toks/s]


Processed batch 2/4


Processed prompts: 100%|██████████| 32/32 [00:10<00:00,  3.15it/s, est. speed input: 513.56 toks/s, output: 112.01 toks/s]


Processed batch 3/4


Processed prompts: 100%|██████████| 4/4 [01:06<00:00, 16.52s/it, est. speed input: 9.80 toks/s, output: 61.32 toks/s]  

Batch failed: Expecting value: line 1 column 10568 (char 10567)
Processed batch 4/4

Average Metrics Across All Processed Records:
Precision: 0.4957
Recall:    0.4329
F1 Score:  0.4571

Pipeline completed. Sample output:
    unique_key                                                                                                                                     text_concat extracted_drugs
8270_3570_7659 The regimen included CHOP along with Methotrexate and Prednisone. Added Prednisone for hormonal therapy. Prescribed Caplacizumab and Rituximab.              []
1860_7056_4291            Administered IT MTX and Doxorubicin. Added Prednisone for hormonal therapy. Patient received Atezolizumab for lung cancer treatment.              []



