# Clinical Synthetic Data Generator
### Week 3 Exercise — Stella Oiro (Andela AI Engineering Bootcamp)

Generate realistic synthetic clinical datasets for research, education, and testing — no real patient data required.

**Dataset types you can generate:**
- Patient Demographics & Diagnoses
- Prescription Records
- Clinical Trial Participants
- Adverse Drug Event (ADE) Reports
- Laboratory Results

**Models supported:**
- HuggingFace: `meta-llama/Meta-Llama-3.1-8B-Instruct` (4-bit, T4 GPU)
- OpenAI: `gpt-4.1-mini` (via API key)

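> **Note:** This notebook is optimised for Google Colab with a T4 GPU.
> For the HuggingFace model you need a `HF_TOKEN` saved in Colab Secrets, and your HuggingFace account must have been granted access to the gated Llama 3.1 repository.
> For GPT you need an `OPENAI_API_KEY` saved in Colab Secrets.
> When running locally instead of in Colab, put the same keys in a `.env` file.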

In [None]:
# Install dependencies (Colab)
# `%pip` is an IPython/Jupyter magic, not plain Python: it installs into the
# environment of the running notebook kernel. `-q` keeps the install log quiet.
%pip install -q transformers accelerate bitsandbytes torch gradio openai python-dotenv

In [None]:
# Check GPU availability
import subprocess
def _nvidia_smi_report() -> str:
    """Return the `nvidia-smi` report, or a sentinel string when it cannot run.

    A missing binary or a non-zero exit both map to "NVIDIA-SMI has failed",
    which the check below recognises by the word "failed".
    """
    try:
        return subprocess.check_output(["nvidia-smi"], text=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        return "NVIDIA-SMI has failed"


gpu_info = _nvidia_smi_report()

if 'failed' not in gpu_info.lower():
    # A GPU answered: show the full report, then call out the T4 specifically.
    print(gpu_info)
    gpu_banner = (' Connected to T4 GPU — ready for HuggingFace model.'
                  if 'T4' in gpu_info else ' GPU detected.')
    print(gpu_banner)
else:
    print('  No GPU detected — HuggingFace model will be slow. GPT mode recommended.')

In [None]:
import os
import json
import pandas as pd
import torch
import gradio as gr
from openai import OpenAI

# ── API keys ─────────────────────────────────────────────────────────────────
# In Colab: use Secrets (key icon in left sidebar)
# Locally:  use a .env file
# Either key may be missing; the matching backend is then simply disabled.
try:
    from google.colab import userdata
except ImportError:
    userdata = None


def _colab_secret(name):
    """Return the Colab secret *name*, or None if it is absent or not shared.

    ``userdata.get`` raises instead of returning None when the secret does not
    exist or this notebook was not granted access to it. Without this guard, a
    user who configured only one of the two keys could not run the notebook.
    """
    try:
        return userdata.get(name)
    except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
        return None


if userdata is not None:
    OPENAI_API_KEY = _colab_secret('OPENAI_API_KEY')
    HF_TOKEN       = _colab_secret('HF_TOKEN')
    IN_COLAB = True
    print('Running in Google Colab')
else:
    from dotenv import load_dotenv
    load_dotenv(override=True)
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
    HF_TOKEN       = os.getenv('HF_TOKEN')
    IN_COLAB = False
    print('Running locally')

if OPENAI_API_KEY:
    print(f'OpenAI key found: {OPENAI_API_KEY[:8]}...')
else:
    print('  No OpenAI key — GPT mode disabled.')

if HF_TOKEN:
    print('HuggingFace token found.')
else:
    print('  No HF token — HuggingFace model disabled.')

In [None]:
# ── HuggingFace login ─────────────────────────────────────────────────────────
from huggingface_hub import login

# Log in only when a token was found; otherwise say so and move on.
if not HF_TOKEN:
    print('Skipping HuggingFace login (no token).')
else:
    login(HF_TOKEN, add_to_git_credential=True)
    print('Logged in to HuggingFace.')

In [None]:
# ── Load HuggingFace model ────────────────────────────────────────────────────
# Only loads if HF_TOKEN is present and GPU is available.
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

HF_MODEL_NAME = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
# Both stay None unless loading fully succeeds; later cells test for None.
hf_model     = None
hf_tokenizer = None

if HF_TOKEN and torch.cuda.is_available():
    print(f'Loading {HF_MODEL_NAME} with 4-bit quantisation...')
    # bfloat16 is only native on compute capability 8.x (Ampere) and newer.
    # The T4 this notebook targets is 7.5, so use float16 there instead.
    _cc_major, _cc_minor = torch.cuda.get_device_capability()
    _compute_dtype = torch.bfloat16 if _cc_major >= 8 else torch.float16
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=_compute_dtype,
        bnb_4bit_quant_type='nf4'
    )
    try:
        hf_tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
        # Reuse EOS as the pad token so generate() has a pad id to work with.
        hf_tokenizer.pad_token = hf_tokenizer.eos_token
        hf_model = AutoModelForCausalLM.from_pretrained(
            HF_MODEL_NAME,
            device_map='auto',
            quantization_config=quant_config
        )
        print(' HuggingFace model ready.')
    except Exception as err:  # notebook top level: degrade to GPT instead of aborting
        # e.g. gated-repo access not granted, GPU out of memory, bitsandbytes missing
        hf_model = None
        hf_tokenizer = None
        torch.cuda.empty_cache()
        print(f'  HuggingFace model failed to load: {err}')
        print('    GPT mode will be used instead.')
else:
    print('  HuggingFace model not loaded (no GPU or no HF token).')
    print('    GPT mode will be used instead.')

In [None]:
# ── OpenAI client ─────────────────────────────────────────────────────────────
GPT_MODEL = 'gpt-4.1-mini'

# Without a key there is no client; downstream code treats None as "GPT disabled".
if OPENAI_API_KEY:
    openai_client = OpenAI(api_key=OPENAI_API_KEY)
else:
    openai_client = None
print(f'GPT client ready: {bool(openai_client)}')

In [None]:
# ── Clinical dataset schemas ──────────────────────────────────────────────────
# Each schema tells the LLM exactly what fields to generate:
#   description – one-line summary inserted into the user prompt
#   fields      – field name, optionally followed by a "(hint)" giving allowed
#                 values, format or normal range
#   example     – one complete record; the prompt tells the model to follow it
#                 exactly, so its keys (and value style) must agree with the
#                 field list and its hints.

DATASET_SCHEMAS = {
    "Patient Demographics & Diagnoses": {
        "description": "Synthetic patient records with demographics and primary diagnoses.",
        "fields": [
            "patient_id (e.g. PT-001)",
            "age (18-90)",
            "sex (Male/Female/Non-binary)",
            "ethnicity",
            "primary_diagnosis (ICD-10 name and code)",
            "comorbidities (list of 1-3)",
            "bmi",
            "smoking_status (Never/Ex/Current)",
            "date_of_admission (YYYY-MM-DD, 2022-2024)"
        ],
        "example": {
            "patient_id": "PT-001",
            "age": 54,
            "sex": "Female",
            "ethnicity": "Black British",
            "primary_diagnosis": "Type 2 Diabetes Mellitus (E11)",
            "comorbidities": ["Hypertension", "Obesity"],
            "bmi": 31.4,
            "smoking_status": "Ex",
            "date_of_admission": "2023-07-14"
        }
    },
    "Prescription Records": {
        "description": "Synthetic prescriptions including drug, dose, and prescriber details.",
        "fields": [
            "prescription_id (e.g. RX-001)",
            "patient_id",
            "drug_name (generic)",
            "drug_class",
            "dose_mg",
            # Hint matches the example's plain-English style ("twice daily");
            # mixing in BD/TDS abbreviations gave the model conflicting guidance.
            "frequency (e.g. once daily, twice daily, three times daily)",
            "route (oral/IV/inhaled/topical)",
            "indication",
            "prescriber_role (GP/Consultant/Registrar/NP)",
            "date_prescribed (YYYY-MM-DD)",
            "duration_days"
        ],
        "example": {
            "prescription_id": "RX-001",
            "patient_id": "PT-001",
            "drug_name": "metformin",
            "drug_class": "Biguanide",
            "dose_mg": 500,
            "frequency": "twice daily",
            "route": "oral",
            "indication": "Type 2 Diabetes Mellitus",
            "prescriber_role": "GP",
            "date_prescribed": "2023-07-15",
            "duration_days": 90
        }
    },
    "Clinical Trial Participants": {
        "description": "Synthetic clinical trial enrolment records with outcomes.",
        "fields": [
            "participant_id (e.g. CTP-001)",
            "trial_id (e.g. TRIAL-2023-HF01)",
            "trial_phase (I/II/III/IV)",
            "intervention_arm (Drug/Placebo/Active comparator)",
            "drug_name",
            "dose_mg",
            "age",
            "sex",
            "primary_endpoint_met (Yes/No)",
            "serious_adverse_event (Yes/No)",
            "withdrawal_reason (Completed/AE/Lost to follow-up/Withdrawn consent)",
            "follow_up_weeks"
        ],
        "example": {
            "participant_id": "CTP-001",
            "trial_id": "TRIAL-2023-HF01",
            "trial_phase": "III",
            "intervention_arm": "Drug",
            "drug_name": "empagliflozin",
            "dose_mg": 10,
            "age": 67,
            "sex": "Male",
            "primary_endpoint_met": "Yes",
            "serious_adverse_event": "No",
            "withdrawal_reason": "Completed",
            "follow_up_weeks": 52
        }
    },
    "Adverse Drug Event Reports": {
        "description": "Synthetic pharmacovigilance reports modelled on WHO/CIOMS format.",
        "fields": [
            "report_id (e.g. ADE-001)",
            "suspect_drug",
            "indication",
            "adverse_event (MedDRA preferred term)",
            "severity (Mild/Moderate/Severe/Life-threatening)",
            "seriousness (Serious/Non-serious)",
            "causality (Certain/Probable/Possible/Unlikely)",
            "patient_age",
            "patient_sex",
            "time_to_onset_days",
            "outcome (Recovered/Recovering/Not recovered/Fatal/Unknown)",
            "reporter_type (Physician/Pharmacist/Patient/Nurse)"
        ],
        "example": {
            "report_id": "ADE-001",
            "suspect_drug": "warfarin",
            "indication": "Atrial fibrillation",
            "adverse_event": "Gastrointestinal haemorrhage",
            "severity": "Severe",
            "seriousness": "Serious",
            "causality": "Probable",
            "patient_age": 74,
            "patient_sex": "Female",
            "time_to_onset_days": 12,
            "outcome": "Recovered",
            "reporter_type": "Physician"
        }
    },
    "Laboratory Results": {
        "description": "Synthetic lab panels including haematology and metabolic results.",
        "fields": [
            "lab_id (e.g. LAB-001)",
            "patient_id",
            "collection_date (YYYY-MM-DD)",
            "haemoglobin_g_dL (normal 12-17)",
            "wbc_10e9_L (normal 4-11)",
            "platelets_10e9_L (normal 150-400)",
            "sodium_mmol_L (normal 135-145)",
            "potassium_mmol_L (normal 3.5-5)",
            "creatinine_umol_L (normal 60-110)",
            # Normal-range hints added so every lab value tells the model what
            # "normal" means, which the flag field depends on.
            "egfr_mL_min_1.73m2 (normal >=90)",
            "hba1c_mmol_mol (normal <42)",
            "flag (Normal/Abnormal/Critical)"
        ],
        "example": {
            "lab_id": "LAB-001",
            "patient_id": "PT-001",
            "collection_date": "2023-07-15",
            "haemoglobin_g_dL": 13.2,
            "wbc_10e9_L": 6.8,
            "platelets_10e9_L": 234,
            "sodium_mmol_L": 138,
            "potassium_mmol_L": 4.1,
            "creatinine_umol_L": 89,
            "egfr_mL_min_1.73m2": 72,
            "hba1c_mmol_mol": 58,
            "flag": "Abnormal"
        }
    }
}

print(f'Schemas loaded for {len(DATASET_SCHEMAS)} dataset types.')

In [None]:
# ── Prompt builder ────────────────────────────────────────────────────────────

def build_prompt(dataset_type: str, num_records: int) -> tuple[str, str]:
    """Assemble the ``(system_prompt, user_prompt)`` pair for one request.

    Args:
        dataset_type: key into ``DATASET_SCHEMAS`` (``KeyError`` if unknown).
        num_records: how many records the model is asked to produce.

    Returns:
        The system prompt (role + JSON-only output rules) and the user prompt
        (count, description, bulleted field hints, and the schema's example
        record pretty-printed as JSON).
    """
    spec = DATASET_SCHEMAS[dataset_type]
    bullet_list = '\n'.join([f'  - {field}' for field in spec['fields']])
    sample_json = json.dumps(spec['example'], indent=2)

    system_prompt = '\n'.join([
        "You are a clinical data engineer generating realistic but entirely synthetic medical datasets.",
        "Your output MUST be valid JSON only — a JSON array of objects, no markdown, no commentary.",
        "Use realistic clinical values, proper medical terminology, and plausible variation between records.",
    ])

    # Five paragraphs, separated by blank lines.
    user_prompt = '\n\n'.join([
        f"Generate exactly {num_records} synthetic {dataset_type} records.",
        f"Description: {spec['description']}",
        f"Required fields:\n{bullet_list}",
        f"Example of ONE record (follow this exact structure):\n{sample_json}",
        f"Return ONLY a JSON array of {num_records} objects. No extra text.",
    ])

    return system_prompt, user_prompt


print('Prompt builder ready.')

In [None]:
# ── Generation functions ──────────────────────────────────────────────────────

def _unwrap_json_records(parsed):
    """Return the list of record objects from a parsed JSON-mode reply.

    JSON mode yields a top-level object, so the array arrives wrapped, ideally
    under the "records" key the prompt asks for. Raises ValueError if no list
    of records can be found.
    """
    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, dict):
        wrapped = parsed.get('records')
        if isinstance(wrapped, list):
            return wrapped
        # Fall back to the first list whose items are objects. A list of plain
        # strings (e.g. a lone record's "comorbidities") is per-record data,
        # not the record array.
        for value in parsed.values():
            if isinstance(value, list) and value and all(isinstance(item, dict) for item in value):
                return value
    raise ValueError('GPT response did not contain a JSON array of records.')


def generate_with_gpt(dataset_type: str, num_records: int) -> list:
    """Generate synthetic records with the OpenAI model named by ``GPT_MODEL``.

    Args:
        dataset_type: key of ``DATASET_SCHEMAS`` selecting the record schema.
        num_records: number of records requested in the prompt.

    Returns:
        The list of generated records.

    Raises:
        RuntimeError: if no OpenAI client is configured.
        ValueError: if the reply is empty, truncated, not JSON, or holds no
            record array.
    """
    if not openai_client:
        raise RuntimeError('OpenAI client not configured. Add OPENAI_API_KEY.')
    system_prompt, user_prompt = build_prompt(dataset_type, num_records)
    # response_format=json_object forces a top-level JSON *object*, so name the
    # wrapper key explicitly instead of guessing which list to unwrap.
    user_prompt += (
        '\n\nBecause the response must be a JSON object, put the array under '
        'the key "records", i.e. {"records": [ ... ]}.'
    )
    response = openai_client.chat.completions.create(
        model=GPT_MODEL,
        messages=[
            {'role': 'system',  'content': system_prompt},
            {'role': 'user',    'content': user_prompt}
        ],
        response_format={'type': 'json_object'}
    )
    choice = response.choices[0]
    if choice.finish_reason == 'length':
        raise ValueError('GPT output was truncated (token limit reached); request fewer records.')
    raw = choice.message.content
    if not raw:
        raise ValueError('GPT returned an empty response.')
    return _unwrap_json_records(json.loads(raw))


def generate_with_hf(dataset_type: str, num_records: int) -> list:
    """Generate synthetic records with the locally loaded HuggingFace model.

    Args:
        dataset_type: key of ``DATASET_SCHEMAS`` selecting the record schema.
        num_records: number of records requested in the prompt; also scales the
            generation budget (250 tokens per record, capped at 3000).

    Returns:
        The first JSON array found in the model's reply.

    Raises:
        RuntimeError: if the model or tokenizer was not loaded.
        ValueError: if the reply contains no parseable JSON array.
    """
    if hf_model is None or hf_tokenizer is None:
        raise RuntimeError('HuggingFace model not loaded.')
    system_prompt, user_prompt = build_prompt(dataset_type, num_records)
    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user',   'content': user_prompt}
    ]
    input_ids = hf_tokenizer.apply_chat_template(
        messages, return_tensors='pt', add_generation_prompt=True
    ).to(hf_model.device)
    with torch.no_grad():
        outputs = hf_model.generate(
            input_ids,
            # One unpadded sequence, so every input position is real. Passing
            # the mask avoids inferring it from pad == eos.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=min(3000, num_records * 250),
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=hf_tokenizer.eos_token_id
        )
    # generate() returns prompt + continuation. Decode only the new tokens: the
    # prompt's example record contains JSON brackets (e.g. "comorbidities": [...])
    # that would otherwise be mistaken for the start of the answer.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    text = hf_tokenizer.decode(new_tokens, skip_special_tokens=True)
    start = text.find('[')
    if start == -1:
        raise ValueError('No JSON array found in model output.')
    try:
        # raw_decode parses one complete JSON value and ignores trailing chatter.
        records, _consumed = json.JSONDecoder().raw_decode(text[start:])
    except json.JSONDecodeError as err:
        raise ValueError(f'Model output is not a valid JSON array: {err}') from err
    return records


def generate_records(dataset_type: str, num_records: int, model_choice: str) -> list:
    """Dispatch to the backend named by *model_choice*.

    Labels beginning with 'GPT' use ``generate_with_gpt``. Any other label
    uses ``generate_with_hf``.
    """
    backend = generate_with_gpt if model_choice.startswith('GPT') else generate_with_hf
    return backend(dataset_type, num_records)


print('Generation functions ready.')

In [None]:
# ── Export helpers ────────────────────────────────────────────────────────────

def _export_stem(dataset_type: str) -> str:
    """Turn a dataset label into a file stem.

    Example: 'Patient Demographics & Diagnoses' becomes
    'patient_demographics_and_diagnoses'.
    """
    return dataset_type.lower().replace(' ', '_').replace('&', 'and')


def records_to_csv(records: list, dataset_type: str) -> str:
    """Write *records* to ``<stem>.csv`` in the working directory and return that path."""
    target = f"{_export_stem(dataset_type)}.csv"
    frame = pd.DataFrame(records)
    frame.to_csv(target, index=False)
    return target


def records_to_json(records: list, dataset_type: str) -> str:
    """Write *records* as indented JSON to ``<stem>.json`` and return that path."""
    target = f"{_export_stem(dataset_type)}.json"
    with open(target, 'w') as handle:
        json.dump(records, handle, indent=2)
    return target


print('Export helpers ready.')

In [None]:
# ── Gradio UI ─────────────────────────────────────────────────────────────────

# Offer only the backends that initialised in earlier cells. If none did, keep
# one placeholder so the dropdown and examples always have a default value.
# The placeholder starts with 'GPT', so choosing it routes to the GPT backend,
# which reports the missing key.
MODEL_CHOICES = []
if openai_client is not None:
    MODEL_CHOICES.append(f'GPT ({GPT_MODEL})')
if hf_model is not None:
    MODEL_CHOICES.append('HuggingFace (Llama-3.1-8B)')
if not MODEL_CHOICES:
    MODEL_CHOICES = [f'GPT ({GPT_MODEL}) — needs API key']


def gradio_generate(dataset_type, num_records, model_choice):
    """Gradio click handler: generate, preview and export one dataset.

    Returns ``(status, dataframe, csv_path, json_path)``, matching the outputs
    wired to the Generate button. On any error it returns the error text as the
    status and ``None`` for the other three outputs, so the UI never crashes.
    """
    try:
        records = generate_records(dataset_type, int(num_records), model_choice)
        preview = pd.DataFrame(records)
        csv_out = records_to_csv(records, dataset_type)
        json_out = records_to_json(records, dataset_type)
        summary = f' Generated {len(records)} {dataset_type} records using {model_choice}.'
    except Exception as exc:  # UI boundary: surface every failure as status text
        return f' Error: {exc}', None, None, None
    return summary, preview, csv_out, json_out


# Assemble the Gradio app. Inside the `with` blocks, statement order is render order.
with gr.Blocks(title='Clinical Synthetic Data Generator', theme=gr.themes.Soft()) as ui:

    # Header text, placed above the two-column layout.
    gr.Markdown("""
    #  Clinical Synthetic Data Generator
    Generate realistic synthetic clinical datasets for research, education, and testing.
    *All data is entirely synthetic — no real patient information is used or produced.*
    """)

    with gr.Row():
        # Left column: inputs, the Generate button, and the download slots.
        with gr.Column(scale=1):
            # Dataset types are the DATASET_SCHEMAS keys; the first one is the default.
            dataset_dd = gr.Dropdown(
                choices=list(DATASET_SCHEMAS.keys()),
                value=list(DATASET_SCHEMAS.keys())[0],
                label='Dataset Type'
            )
            # Record count: 5, 10, 15 or 20 (default 10).
            num_slider = gr.Slider(
                minimum=5, maximum=20, value=10, step=5,
                label='Number of Records'
            )
            # Model labels come from MODEL_CHOICES; the first entry is the default.
            model_dd = gr.Dropdown(
                choices=MODEL_CHOICES,
                value=MODEL_CHOICES[0],
                label='Model'
            )
            generate_btn = gr.Button('Generate Dataset', variant='primary')

            # File slots that receive the CSV/JSON paths returned by gradio_generate.
            gr.Markdown('### Download')
            csv_file  = gr.File(label='CSV')
            json_file = gr.File(label='JSON')

        # Right column (twice as wide): status message and tabular preview.
        with gr.Column(scale=2):
            status_box = gr.Textbox(label='Status', interactive=False)
            data_table = gr.Dataframe(label='Preview', wrap=True)

    # Clickable presets that fill the three inputs above.
    gr.Examples(
        examples=[
            ['Patient Demographics & Diagnoses',  10, MODEL_CHOICES[0]],
            ['Adverse Drug Event Reports',        10, MODEL_CHOICES[0]],
            ['Laboratory Results',                 5, MODEL_CHOICES[0]],
        ],
        inputs=[dataset_dd, num_slider, model_dd]
    )

    # gradio_generate returns a 4-tuple that maps positionally onto these outputs.
    generate_btn.click(
        fn=gradio_generate,
        inputs=[dataset_dd, num_slider, model_dd],
        outputs=[status_box, data_table, csv_file, json_file]
    )

print('UI built.')

In [None]:
# Launch the app. debug=True blocks the cell so errors are printed in its output.
# share=True publishes a public *.gradio.live URL that anyone can use to spend
# your OpenAI credits and GPU time. Request it only in Colab, where Gradio uses a
# share link by default anyway. Pass share=True explicitly if you want a public
# link when running locally.
ui.launch(share=IN_COLAB, debug=True)