# HealthCareMagic-100K Dataset Exploration

This notebook explores the HealthCareMagic-100K dataset to understand its structure, content, and prepare it for fine-tuning a medical QA model.

In [None]:
# Import libraries
import os
import json
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from datasets import load_dataset
from wordcloud import WordCloud
from tqdm.notebook import tqdm

# Set figure size
plt.rcParams["figure.figsize"] = (12, 8)

## Load HealthCareMagic-100K Dataset

The HealthCareMagic-100K dataset contains 100,000+ patient questions and doctor answers from a medical consultation website.

In [None]:
# Helper for the local fallback: one JSON object per non-empty line
def read_jsonl(path):
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

train_df = val_df = test_df = None

# Try loading the dataset directly from Hugging Face
try:
    dataset = load_dataset("vaibhavs10/healthcaremagic-100k")
    print(f"Dataset loaded from Hugging Face: {dataset}")

    # Convert to pandas DataFrames; guard the test split so that a missing
    # split does not trigger the local fallback below
    train_df = dataset["train"].to_pandas()
    if "test" in dataset:
        test_df = dataset["test"].to_pandas()

    print(f"Training set: {len(train_df)} examples")
    if test_df is not None:
        print(f"Test set: {len(test_df)} examples")
except Exception as e:
    print(f"Error loading dataset from Hugging Face: {e}")

    # Fall back to the locally processed files
    data_dir = "../data/processed"
    if not os.path.exists(os.path.join(data_dir, "train.jsonl")):
        # Every later cell needs train_df, so fail loudly here
        raise FileNotFoundError(
            "Processed data files not found. Please run the data preparation script first."
        ) from e

    train_df = pd.DataFrame(read_jsonl(os.path.join(data_dir, "train.jsonl")))
    val_df = pd.DataFrame(read_jsonl(os.path.join(data_dir, "validation.jsonl")))
    test_df = pd.DataFrame(read_jsonl(os.path.join(data_dir, "test.jsonl")))

    print(f"Training set: {len(train_df)} examples")
    print(f"Validation set: {len(val_df)} examples")
    print(f"Test set: {len(test_df)} examples")
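If the local-file fallback above is used, a quick schema check can catch files written with different field names before they cause confusing errors in later cells. This is a minimal sketch: `read_jsonl_checked` and `REQUIRED_KEYS` are hypothetical names, and it assumes every record must carry `input` and `output` (the notebook treats `instruction` as optional).

```python
import json
from pathlib import Path

# Assumption: later cells always need these fields ("instruction" is optional)
REQUIRED_KEYS = {"input", "output"}

def read_jsonl_checked(path, required_keys=REQUIRED_KEYS):
    """Read a JSONL file, skipping blank lines and rejecting records that lack required keys."""
    records = []
    with Path(path).open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank or trailing lines
            record = json.loads(line)
            missing = required_keys - set(record)
            if missing:
                raise ValueError(f"{path}:{line_no} is missing keys {sorted(missing)}")
            records.append(record)
    return records
```

Raising on the first bad record, with its line number, keeps the failure close to its cause instead of surfacing later as a `KeyError` in a plotting cell.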

## Explore Dataset Structure

In [None]:
# Display sample examples
print("Training sample:")
train_df.head(2)

In [None]:
# Column information
print("Dataset columns:")
for col in train_df.columns:
    print(f"- {col}")

## Analyze Data Distributions

In [None]:
# Text length statistics
def get_text_stats(df, text_column):
    lengths = df[text_column].str.split().str.len()
    return {
        "mean": lengths.mean(),
        "median": lengths.median(),
        "min": lengths.min(),
        "max": lengths.max(),
        "90th_percentile": lengths.quantile(0.9),
        "95th_percentile": lengths.quantile(0.95),
        "99th_percentile": lengths.quantile(0.99),
    }

# Input text statistics
input_stats = get_text_stats(train_df, "input")
print("Input text statistics (word count):")
for key, value in input_stats.items():
    print(f"- {key}: {value:.1f}")

# Output text statistics
output_stats = get_text_stats(train_df, "output")
print("\nOutput text statistics (word count):")
for key, value in output_stats.items():
    print(f"- {key}: {value:.1f}")
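As a cross-check on the pandas statistics above, the same word-count summary can be written in plain Python. This sketch assumes whitespace splitting (what `str.split()` does with no arguments) and linear interpolation between order statistics for percentiles, which matches the default `interpolation="linear"` of pandas `quantile`. The helper names are illustrative.

```python
import statistics

def percentile(sorted_values, q):
    """Linear-interpolation percentile (q in [0, 1]) of an already sorted, non-empty list."""
    position = q * (len(sorted_values) - 1)
    lower = int(position)
    upper = min(lower + 1, len(sorted_values) - 1)
    fraction = position - lower
    return sorted_values[lower] + (sorted_values[upper] - sorted_values[lower]) * fraction

def word_length_stats(texts):
    """Word-count summary for an iterable of strings, splitting on whitespace."""
    lengths = sorted(len(text.split()) for text in texts)
    return {
        "mean": statistics.mean(lengths),
        "median": statistics.median(lengths),
        "min": lengths[0],
        "max": lengths[-1],
        "90th_percentile": percentile(lengths, 0.90),
        "95th_percentile": percentile(lengths, 0.95),
        "99th_percentile": percentile(lengths, 0.99),
    }
```

For a column without missing values, `word_length_stats(train_df["input"])` should agree with `get_text_stats(train_df, "input")` up to floating-point rounding.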

In [None]:
# Plot text length distributions
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Input length distribution (lengths above 500 words are clipped into the last bin)
input_lengths = train_df["input"].str.split().str.len()
sns.histplot(input_lengths.clip(upper=500), ax=axes[0], bins=50, kde=True)
axes[0].set_title("Input Text Length Distribution")
axes[0].set_xlabel("Word Count (clipped at 500)")
axes[0].set_ylabel("Frequency")

# Output length distribution (lengths above 500 words are clipped into the last bin)
output_lengths = train_df["output"].str.split().str.len()
sns.histplot(output_lengths.clip(upper=500), ax=axes[1], bins=50, kde=True)
axes[1].set_title("Output Text Length Distribution")
axes[1].set_xlabel("Word Count (clipped at 500)")
axes[1].set_ylabel("Frequency")

plt.tight_layout()
plt.show()
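Because both histograms clip lengths at 500 words, every longer text lands in the last bin, which can make that bin look like a spike. A small helper reports how much of the data has its true length hidden this way; `share_clipped` is a hypothetical name, and the default threshold simply mirrors the `clip(upper=500)` calls above.

```python
def share_clipped(lengths, threshold=500):
    """Fraction of values strictly above `threshold`, i.e. those whose true length clipping hides."""
    lengths = list(lengths)
    if not lengths:
        return 0.0
    return sum(1 for n in lengths if n > threshold) / len(lengths)
```

It accepts any iterable of numbers, so `share_clipped(input_lengths)` works directly on the pandas Series computed in the plotting cell.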

## Extract Medical Topics and Conditions

In [None]:
# Extract common medical conditions from text
medical_conditions = [
    "diabetes", "hypertension", "asthma", "arthritis", "depression", "anxiety",
    "cancer", "heart disease", "stroke", "alzheimer", "parkinson", "epilepsy",
    "copd", "flu", "pneumonia", "hiv", "tuberculosis", "malaria", "hepatitis",
    "migraine", "osteoporosis", "thyroid", "lupus", "fibromyalgia", "eczema",
    "psoriasis", "allergy", "infection", "fever", "pain", "inflammation",
    "fracture", "injury", "surgery", "pregnancy"
]

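# Return the conditions mentioned at least once in a text
# (each condition counts once per example, so totals are example counts)
def extract_conditions(text, conditions_list):
    text = text.lower()  # lowercase once instead of once per condition
    found_conditions = []
    for condition in conditions_list:
        # re.escape guards against regex metacharacters in condition names
        pattern = r"\b" + re.escape(condition) + r"\b"
        if re.search(pattern, text):
            found_conditions.append(condition)
    return found_conditions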

# Apply to the dataset
all_conditions = []
for text in tqdm(train_df["input"] + " " + train_df["output"]):
    all_conditions.extend(extract_conditions(text, medical_conditions))

# Count frequencies
condition_counts = Counter(all_conditions)
top_conditions = condition_counts.most_common(15)

# Plot top conditions
plt.figure(figsize=(12, 8))
sns.barplot(x=[condition for condition, count in top_conditions],
            y=[count for condition, count in top_conditions])
plt.title("Top 15 Medical Conditions in the Dataset")
plt.xlabel("Condition")
plt.ylabel("Number of Examples Mentioning Condition")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()
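Running one regex per condition over 100,000+ examples is slow. A faster variant, sketched below, compiles a single case-insensitive alternation with word boundaries and counts each condition at most once per text. The helper names are hypothetical. One caveat: an alternation finds non-overlapping matches, so if one term were contained in another (say `heart disease` and `disease`), a phrase would count only toward the longer term, unlike the per-condition loop above. The current `medical_conditions` list has no such overlaps.

```python
import re
from collections import Counter

def build_condition_pattern(conditions):
    """One case-insensitive alternation over all terms, longest first, with word boundaries."""
    # Longest-first ordering lets multi-word terms win over any shorter term sharing a prefix
    alternatives = "|".join(re.escape(c) for c in sorted(conditions, key=len, reverse=True))
    return re.compile(r"\b(?:" + alternatives + r")\b", re.IGNORECASE)

def count_conditions(texts, conditions):
    """Count how many texts mention each condition at least once."""
    pattern = build_condition_pattern(conditions)
    counts = Counter()
    for text in texts:
        # A set, so repeated mentions within one text count once
        counts.update({match.lower() for match in pattern.findall(text)})
    return counts
```

In the notebook this would be `count_conditions(train_df["input"] + " " + train_df["output"], medical_conditions).most_common(15)`.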

## Word Cloud Visualization

In [None]:
# Create word clouds for input and output text
def create_wordcloud(text, title):
    wordcloud = WordCloud(width=800, height=400, background_color="white", 
                         max_words=100, contour_width=3, contour_color="steelblue")
    
    # Generate word cloud
    wordcloud.generate(text)
    
    # Display
    plt.figure(figsize=(12, 8))
    plt.imshow(wordcloud, interpolation="bilinear")
    plt.axis("off")
    plt.title(title, fontsize=16)
    plt.tight_layout()
    plt.show()

# Sample up to 1000 examples for faster processing
sample_df = train_df.sample(min(1000, len(train_df)), random_state=42)

# Create word clouds
create_wordcloud(" ".join(sample_df["input"]), "Word Cloud of Patient Questions")
create_wordcloud(" ".join(sample_df["output"]), "Word Cloud of Doctor Answers")
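Word clouds are driven by word frequencies, and patient-doctor chats contain greetings and sign-offs that can crowd out clinical terms. The sketch below builds a frequency table with an extra stopword set. Both stopword sets are assumptions (the extra entries are guesses at this dataset's conversational filler, not a measured list); in the notebook, `wordcloud.STOPWORDS` is a fuller base set.

```python
import re
from collections import Counter

# Assumed conversational filler; extend after inspecting the data
EXTRA_STOPWORDS = {"hi", "hello", "thanks", "thank", "doctor", "dear", "regards"}
# A deliberately small base list for illustration
BASE_STOPWORDS = {"the", "a", "an", "and", "or", "is", "are", "i", "my", "you", "your",
                  "to", "of", "in", "for", "it", "this", "that", "with", "on", "be", "have"}

def word_frequencies(texts, stopwords=BASE_STOPWORDS | EXTRA_STOPWORDS, min_length=3):
    """Lowercased word counts across texts, dropping stopwords and tokens shorter than min_length."""
    counts = Counter()
    for text in texts:
        for word in re.findall(r"[a-z']+", text.lower()):
            word = word.strip("'")
            if len(word) >= min_length and word not in stopwords:
                counts[word] += 1
    return counts
```

The resulting mapping can be passed to `WordCloud.generate_from_frequencies` in place of `generate(text)`.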

## Analyze Response Structure

In [None]:
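# Check whether responses mention diagnosis- or treatment-related keywords.
# This is a keyword proxy, not true section detection. The leading word
# boundary stops "plan" from matching inside words such as "explanation"
# while still accepting inflections such as "plans" or "treatments".
DIAGNOSIS_RE = re.compile(r"\b(?:diagnosis|assessment|impression)", re.IGNORECASE)
TREATMENT_RE = re.compile(r"\b(?:treatment|plan|recommendation|therapy|management)", re.IGNORECASE)

def has_diagnosis(text):
    return bool(DIAGNOSIS_RE.search(text))

def has_treatment(text):
    return bool(TREATMENT_RE.search(text))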

# Apply to the dataset
train_df["has_diagnosis"] = train_df["output"].apply(has_diagnosis)
train_df["has_treatment"] = train_df["output"].apply(has_treatment)
train_df["has_both"] = train_df["has_diagnosis"] & train_df["has_treatment"]

# Calculate percentages
diagnosis_percent = train_df["has_diagnosis"].mean() * 100
treatment_percent = train_df["has_treatment"].mean() * 100
both_percent = train_df["has_both"].mean() * 100

print(f"Responses mentioning diagnosis keywords: {diagnosis_percent:.1f}%")
print(f"Responses mentioning treatment keywords: {treatment_percent:.1f}%")
print(f"Responses mentioning both: {both_percent:.1f}%")

# Plot
plt.figure(figsize=(10, 6))
plt.bar(["Diagnosis", "Treatment", "Both"], [diagnosis_percent, treatment_percent, both_percent])
plt.title("Percentage of Responses Mentioning Diagnosis and Treatment Keywords")
plt.ylabel("Percentage (%)")
plt.ylim(0, 100)
for i, v in enumerate([diagnosis_percent, treatment_percent, both_percent]):
    plt.text(i, v+2, f"{v:.1f}%", ha="center")
plt.show()
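Unanchored keyword patterns can inflate these percentages: `plan` matches inside `explanation`. The comparison below contrasts an unanchored pattern with one anchored at a leading word boundary, using the treatment keywords from the cell above. The example sentences are made up for illustration.

```python
import re

KEYWORDS = ["treatment", "plan", "recommendation", "therapy", "management"]

# Unanchored: a match anywhere, including inside longer words
UNANCHORED = re.compile("|".join(KEYWORDS), re.IGNORECASE)
# Anchored at the start of a word: suffixes such as "-s" or "-ned" still match
ANCHORED = re.compile(r"\b(?:" + "|".join(KEYWORDS) + ")", re.IGNORECASE)

def keyword_hits(text):
    """Return (unanchored_match, anchored_match) for one text."""
    return bool(UNANCHORED.search(text)), bool(ANCHORED.search(text))

for text in [
    "Thanks for the detailed explanation.",
    "We planned an X-ray.",
    "Physiotherapy may help.",
]:
    unanchored_hit, anchored_hit = keyword_hits(text)
    print(f"{text!r}: unanchored={unanchored_hit}, anchored={anchored_hit}")
```

The last sentence shows the trade-off: the anchored pattern no longer counts compounds such as "physiotherapy", so add such terms explicitly if they matter.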

## Analyze Instruction Template Distribution

In [None]:
# If we have instruction templates in the processed data
if "instruction" in train_df.columns:
    # Count different instruction templates
    instruction_counts = train_df["instruction"].value_counts()
    
    print("Instruction template distribution:")
    for instruction, count in instruction_counts.items():
        print(f"- {instruction}: {count} examples ({count/len(train_df)*100:.1f}%)")
    
    # Plot
    plt.figure(figsize=(14, 6))
    sns.countplot(y="instruction", data=train_df, order=instruction_counts.index)
    plt.title("Distribution of Instruction Templates")
    plt.tight_layout()
    plt.show()
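If the processed data was written by more than one script, the same template can appear with trivial differences in whitespace or trailing punctuation, splitting one template across several rows of the count. A hedged normalization step, assuming such variants should be treated as one template, could look like this (`normalize_template` and `template_counts` are hypothetical helpers):

```python
import re
from collections import Counter

def normalize_template(instruction):
    """Collapse internal whitespace, trim, and drop trailing punctuation so near-identical templates group together."""
    collapsed = re.sub(r"\s+", " ", instruction).strip()
    return collapsed.rstrip(".!?: ")

def template_counts(instructions):
    """Count templates after normalization."""
    return Counter(normalize_template(i) for i in instructions)
```

With pandas, the equivalent is `train_df["instruction"].map(normalize_template).value_counts()`.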

## Check Token Lengths for Model Compatibility

In [None]:
# Load tokenizer
from transformers import AutoTokenizer

try:
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
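except Exception as e:
    print(f"Error loading tokenizer: {e}")
    # The meta-llama checkpoints are gated on the Hub; GPT-2's tokenizer is only a rough proxy
    print("Falling back to the GPT-2 tokenizer; token counts are approximate")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")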

# Summary statistics over a list of token counts
def get_token_stats(token_lengths):
    return {
        "mean": np.mean(token_lengths),
        "median": np.median(token_lengths),
        "95th_percentile": np.percentile(token_lengths, 95),
        "99th_percentile": np.percentile(token_lengths, 99),
        "max": np.max(token_lengths),
    }

# Planned training sequence length; Llama 3.1 itself supports a 128K-token context
MAX_SEQ_LENGTH = 4096

# Combine instruction, input, and output as they would be formatted for training
if all(col in train_df.columns for col in ["instruction", "input", "output"]):
    # Sample for faster processing
    sample_df = train_df.sample(min(1000, len(train_df)), random_state=42)

    # Format as training examples. These tags are a custom template, not Llama 3.1's
    # native chat format, so the tokenizer splits each tag into several ordinary tokens.
    formatted_examples = []
    for _, row in sample_df.iterrows():
        if row["input"]:
            formatted = f"<|user|>\n{row['instruction']}\n\n{row['input']}<|endofuser|>\n<|assistant|>\n{row['output']}<|endofassistant|>"
        else:
            formatted = f"<|user|>\n{row['instruction']}<|endofuser|>\n<|assistant|>\n{row['output']}<|endofassistant|>"
        formatted_examples.append(formatted)

    # Tokenize each example once and reuse the counts for stats, plot, and overflow check
    token_lengths = [len(tokenizer.encode(text)) for text in formatted_examples]
    token_stats = get_token_stats(token_lengths)

    print("Token length statistics for formatted examples:")
    for key, value in token_stats.items():
        print(f"- {key}: {value:.1f}")

    # Plot token length distribution
    plt.figure(figsize=(12, 6))
    sns.histplot(token_lengths, bins=50, kde=True)
    plt.axvline(x=MAX_SEQ_LENGTH, color="r", linestyle="--", label=f"Max training length ({MAX_SEQ_LENGTH})")
    plt.title("Token Length Distribution of Formatted Examples")
    plt.xlabel("Token Count")
    plt.ylabel("Frequency")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Check percentage exceeding the training sequence length
    exceeding = sum(1 for length in token_lengths if length > MAX_SEQ_LENGTH)
    print(f"Percentage of examples exceeding {MAX_SEQ_LENGTH} tokens: {exceeding / len(token_lengths) * 100:.2f}%")
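The formatting step and the length check can be factored into small pure functions, which makes them testable without downloading a tokenizer. In this sketch, `format_example` reproduces the notebook's custom tag template, and `count_tokens` is any callable returning a token count, for example `lambda t: len(tokenizer.encode(t))`; the whitespace counter in the usage note is only a stand-in. Both function names are hypothetical.

```python
def format_example(instruction, input_text, output):
    """Render one record with the notebook's custom <|user|>/<|assistant|> template."""
    if input_text:
        user = f"{instruction}\n\n{input_text}"
    else:
        user = instruction
    return f"<|user|>\n{user}<|endofuser|>\n<|assistant|>\n{output}<|endofassistant|>"

def share_over_limit(texts, count_tokens, max_length=4096):
    """Fraction of texts whose token count exceeds max_length."""
    lengths = [count_tokens(text) for text in texts]
    if not lengths:
        return 0.0
    return sum(length > max_length for length in lengths) / len(lengths)
```

For a quick offline check, `share_over_limit(texts, lambda t: len(t.split()), max_length=...)` uses word counts instead of real tokens.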

## Extract and Visualize Medical Specialties

In [None]:
# Define medical specialties and their associated keywords
specialties = {
    "cardiology": ["heart", "cardiac", "myocardial", "infarction", "angina", "hypertension", "arrhythmia"],
    "pulmonology": ["lung", "pulmonary", "respiratory", "asthma", "copd", "pneumonia"],
    "neurology": ["brain", "neural", "seizure", "stroke", "dementia", "alzheimer", "parkinson"],
    "gastroenterology": ["stomach", "intestine", "bowel", "liver", "pancreas", "gallbladder", "ulcer"],
    "orthopedics": ["bone", "joint", "fracture", "arthritis", "osteoporosis", "rheumatoid"],
    "endocrinology": ["diabetes", "thyroid", "hormone", "insulin", "adrenal", "pituitary"],
    "infectious_disease": ["infection", "bacterial", "viral", "fungal", "sepsis", "antibiotic"],
    "dermatology": ["skin", "rash", "eczema", "psoriasis", "acne", "dermatitis"],
    "obstetrics_gynecology": ["pregnancy", "menstrual", "uterus", "ovary", "cervical", "ovulation"],
    "psychiatry": ["depression", "anxiety", "bipolar", "schizophrenia", "mental", "psychiatric"],
    "urology": ["kidney", "bladder", "urinary", "prostate", "urination", "urine"],
    "oncology": ["cancer", "tumor", "malignant", "chemotherapy", "radiation", "carcinoma"],
}

# Function to categorize text by specialty
def categorize_by_specialty(text, specialties_dict):
    text = text.lower()
    matched_specialties = []
    
    for specialty, keywords in specialties_dict.items():
        for keyword in keywords:
            if re.search(r'\b' + keyword + r'\b', text):
                matched_specialties.append(specialty)
                break
    
    return matched_specialties

# Apply to a sample for faster processing
sample_df = train_df.sample(min(1000, len(train_df)), random_state=42)

# Extract specialties
all_specialties = []
for text in tqdm(sample_df["input"] + " " + sample_df["output"]):
    all_specialties.extend(categorize_by_specialty(text, specialties))

# Count frequencies
specialty_counts = Counter(all_specialties)
total_samples = len(sample_df)

# Plot specialty distribution. An example can match several specialties,
# so the percentages can sum to more than 100%.
plt.figure(figsize=(12, 8))
specialty_df = pd.DataFrame({
    "Specialty": list(specialty_counts.keys()),
    "Count": list(specialty_counts.values()),
    "Percentage": [count/total_samples*100 for count in specialty_counts.values()]
}).sort_values("Count", ascending=False)

sns.barplot(x="Percentage", y="Specialty", data=specialty_df)
plt.title("Medical Specialties Distribution in Dataset")
plt.xlabel("Percentage of Examples (%)")
plt.tight_layout()
plt.show()
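Because `categorize_by_specialty` can return several specialties for one example, flattening them into `all_specialties` loses how many labels each example received. A summary that keeps per-specialty coverage separate from the number of labels per example makes the multi-label nature explicit. `summarize_specialty_labels` is a hypothetical helper that takes one list of matched specialties per example.

```python
from collections import Counter

def summarize_specialty_labels(labels_per_example):
    """Per-specialty coverage (%) plus how many examples matched 0, 1, 2, ... specialties."""
    total = len(labels_per_example)
    coverage = Counter()
    labels_histogram = Counter()
    for labels in labels_per_example:
        unique = set(labels)
        coverage.update(unique)
        labels_histogram[len(unique)] += 1
    percentages = {name: count / total * 100 for name, count in coverage.items()}
    return percentages, dict(sorted(labels_histogram.items()))
```

In the notebook the input would be `[categorize_by_specialty(t, specialties) for t in sample_df["input"] + " " + sample_df["output"]]`; the histogram also shows how many examples matched no specialty keyword at all.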

## Summary and Recommendations

### Dataset Summary

Based on the exploration above, here's a summary of the HealthCareMagic-100K dataset:

1. **Size**: Approximately 100,000 medical consultations divided into training and testing sets
2. **Format**: Patient questions (input) and doctor answers (output) pairs
3. **Text Length**: 
   - Patient questions: Average of X words (Y tokens)
   - Doctor answers: Average of X words (Y tokens)
4. **Structure**: 
   - X% of responses mention diagnosis-related keywords
   - X% mention treatment-related keywords
   - X% mention both
5. **Medical Specialties**: The dataset covers a wide range of specialties, with X, Y, and Z being the most common
6. **Common Conditions**: The most frequently mentioned conditions include X, Y, and Z

### Recommendations for Fine-tuning

Based on this analysis, here are recommendations for fine-tuning LLaMA 3.1 8B on this dataset:

1. **Context Length**: A maximum training sequence length of 4,096 tokens (well below Llama 3.1's 128K-token context window) should cover X% of examples. For longer examples, consider truncation strategies that preserve the most important information.

2. **Instruction Templates**: Use templates that explicitly prompt for diagnosis and treatment plans to encourage structured responses.

3. **Response Structure**: Consider adding post-processing to ensure responses have clear diagnosis and treatment sections, as this is the format in X% of the training data.

4. **Batch Size**: Given the average token length of examples, a per-device batch size of 2-4 combined with gradient accumulation should be feasible on 8 A100 GPUs.

5. **Evaluation Metrics**: Beyond general text generation metrics, include specialized metrics for medical response quality, such as diagnosis presence, treatment recommendation quality, and medical terminology usage.

6. **Domain Coverage**: The dataset has good coverage across major medical specialties, but consider evaluating performance separately for each specialty to identify areas where the model might need improvement.

7. **Data Preprocessing**: Consider structuring outputs more consistently with clear diagnosis and treatment sections for examples where these aren't explicitly marked.
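For recommendation 1, one simple truncation strategy keeps the doctor's answer intact and trims the patient question until the example fits. The sketch below works on lists of token IDs; it assumes the answer is the part worth preserving, that the start of the question matters most (so the tail is dropped), and that `prefix_ids`/`middle_ids` hold the template tokens around the question. These are assumptions, not a tested recipe from this project, and `truncate_question` is a hypothetical name.

```python
def truncate_question(prefix_ids, question_ids, middle_ids, answer_ids, max_length=4096):
    """Trim question tokens so len(prefix + question + middle + answer) <= max_length.

    Returns the concatenated IDs, or None if even an empty question does not fit.
    """
    fixed = len(prefix_ids) + len(middle_ids) + len(answer_ids)
    budget = max_length - fixed
    if budget < 0:
        return None  # the answer alone is too long; such examples need another strategy
    # Keep the first `budget` question tokens and drop the rest
    return prefix_ids + question_ids[:budget] + middle_ids + answer_ids
```

With a Hugging Face tokenizer, each piece could be encoded separately via `tokenizer.encode(piece, add_special_tokens=False)` before calling this function.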