In [1]:
import os
import asyncio
import random
import time
from typing import List, Literal, Optional, Dict, Any
from dotenv import load_dotenv
load_dotenv()

import pandas as pd
import numpy as np
from pydantic import BaseModel, Field
import instructor
from openai import AsyncOpenAI

# Paths
IN_DATA_PATH = "/Users/benjamindykstra/development/icd-10-coding/data/processed/structured_datasets.train.parquet"  # user-preferred
ALT_IN_DATA_PATH = "/Users/benjamindykstra/development/icd-10-coding/data/processed/structured_dataset.train.parquet"
OUT_DATA_PATH = "/Users/benjamindykstra/development/icd-10-coding/data/processed/discharge_summaries.train.parquet"

# Model config
MODEL_NAME = "gpt-5-mini"
TEMPERATURE = 0.7
CONCURRENCY = 12
MAX_RETRIES = 5

# Require API key
assert os.environ.get("OPENAI_API_KEY"), "Please set OPENAI_API_KEY in your environment."

# Structured output client (Instructor + AsyncOpenAI)
client = instructor.from_openai(AsyncOpenAI())


In [2]:
class DischargeSummary(BaseModel):
    discharge_summary: str = Field(description="Full discharge summary narrative text (250-400 words)")


def choose_policy(rng: random.Random) -> str:
    r = rng.random()
    if r < 0.5:
        return "primary_only"
    if r < 0.8:
        return "partial"
    return "all"

def get_policy_dict(rng: random.Random, primary_code: str, secondary_codes: List[str]) -> Dict[str, Any]:
    '''
    Returns a dictionary with the policy and the codes that should be included in the diagnosis_codes list
    The policy is one of "primary_only", "partial", "all"
    50% of the time, return only the primary ICD-10 code
    20% of the time, return the primary code and 1-n random secondary codes
    30% of the time, return all of the provided codes
    '''
    
    r = rng.random()
    if r < 0.5:
        return ({
            'policy': "primary_only",
            'diagnosis_codes': [primary_code]
        })
    if r < 0.8:
        # print(secondary_codes)
        if len(secondary_codes):
            return ({
                'policy': "partial",
                'diagnosis_codes': list(set([primary_code] + random.sample(list(secondary_codes), rng.randint(1, len(secondary_codes)))))
            })
        else:
            return ({
                'policy': "primary_only",
                'diagnosis_codes': [primary_code]
            })
    return ({
        'policy': "all",
        'diagnosis_codes': list(set([primary_code] + list(secondary_codes)))
    })


In [3]:
rng = random.Random(17)
# row = rows[1]
for row in rows:
  print(get_policy_dict(rng, row.get('primary_icd10'), row.get('icd10_codes', [])))

NameError: name 'rows' is not defined

In [4]:
from datetime import datetime

format_instructions = DischargeSummary.model_json_schema()

def format_list(values: Optional[List[Any]], limit: int = 10) -> str:
    if not values:
        return "None"
    vs = [str(v) for v in values[:limit]]
    if len(values) > limit:
        vs.append("…")
    return ", ".join(vs)


def build_messages(row: pd.Series) -> List[Dict[str, str]]:
    # Extract fields safely
    subject_id = int(row.get("subject_id")) if pd.notna(row.get("subject_id")) else None
    hadm_id = int(row.get("hadm_id")) if pd.notna(row.get("hadm_id")) else None
    age = int(row.get("age_at_admit")) if pd.notna(row.get("age_at_admit")) else None
    gender = str(row.get("gender")) if pd.notna(row.get("gender")) else None
    admittime = row.get("admittime")
    dischtime = row.get("dischtime")
    los_days = float(row.get("length_of_stay_days")) if pd.notna(row.get("length_of_stay_days")) else None
    admission_type = str(row.get("admission_type")) if pd.notna(row.get("admission_type")) else None

    primary_icd10 = row.get("primary_icd10") if pd.notna(row.get("primary_icd10")) else None
    primary_icd10_desc = row.get("primary_icd10_desc") if pd.notna(row.get("primary_icd10_desc")) else None
    icd10_codes = list(row.get("icd10_codes")) if isinstance(row.get("icd10_codes"), np.ndarray) else []
    icd10_descs = list(row.get("icd10_descriptions")) if isinstance(row.get("icd10_descriptions"), np.ndarray) else []

    procedures = list(row.get("procedures_icd10")) if isinstance(row.get("procedures_icd10"), np.ndarray) else []
    meds = list(row.get("meds_discharge_like")) if isinstance(row.get("meds_discharge_like"), np.ndarray) else []

    # Compose context block
    ctx = []
    # if subject_id is not None:
    #     ctx.append(f"subject_id: {subject_id}")
    # if hadm_id is not None:
    #     ctx.append(f"hadm_id: {hadm_id}")
    if age is not None:
        ctx.append(f"age_at_admit: {age}")
    if gender:
        ctx.append(f"gender: {gender}")
    if admission_type:
        ctx.append(f"admission_type: {admission_type}")
    if los_days is not None:
        ctx.append(f"length_of_stay_days: {los_days:.1f}")
    if primary_icd10:
        ctx.append(f"primary_icd10: {primary_icd10}")
    if primary_icd10_desc:
        ctx.append(f"primary_icd10_desc: {primary_icd10_desc}")
    if icd10_codes:
        ctx.append(f"icd10_codes: {format_list(icd10_codes, 30)}")
    if icd10_descs:
        ctx.append(f"icd10_descriptions: {format_list(icd10_descs, 10)}")
    if procedures:
        ctx.append(f"procedures_icd10: {format_list(procedures, 20)}")
    if meds:
        ctx.append(f"meds_discharge_like: {format_list(meds, 15)}")

    context_text = "\n".join(ctx)

    # Preliminary prompt per overall_plan.md spec
    system_prompt = '''
        You are a hospital physician writing a discharge summary for a patient being referred to home health services.
        Write a clinically realistic, concise discharge summary (250-400 words). Maintain authenticity and use appropriate medical terminology.
    '''

    user_prompt = f"""
Patient Information and Context:
{context_text}

Instructions:
- Brief history of present illness
- Hospital course
- Key findings and procedures
- Discharge condition and functional status (2-3 mentions relevant to home health)
- Medications and follow-up plan

Format your output with the following JSON schema: {format_instructions}
"""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return messages


In [5]:
import json
async def call_model(messages) -> DischargeSummary:
    return await client.chat.completions.create(
        model=MODEL_NAME,
        response_model=DischargeSummary,
        messages=messages,
        # temperature=TEMPERATURE,
    )


async def generate_for_row(row: pd.Series, rng: random.Random, sem: asyncio.Semaphore) -> Dict[str, Any]:
    # semaphore is used to limit the number of concurrent requests to the model
    async with sem:
        # policy = choose_policy(rng)
        policy_dict = get_policy_dict(rng, row.get('primary_icd10'), row.get('icd10_codes', []))
        messages = build_messages(row)
        # print(json.dumps(messages, indent=2))

        delay = 1.0
        for attempt in range(MAX_RETRIES):
            try:
                result: DischargeSummary = await call_model(messages)
                return {
                    "subject_id": int(row["subject_id"]),
                    "hadm_id": int(row["hadm_id"]),
                    "policy": policy_dict['policy'],
                    "discharge_summary": result.discharge_summary,
                    "diagnosis_codes": policy_dict['diagnosis_codes'],
                    "model": MODEL_NAME,
                }
            except Exception as e:
                print(e)
                if attempt == MAX_RETRIES - 1:
                    raise
                await asyncio.sleep(delay + rng.random())
                delay *= 2


async def run_async(rows: List[pd.Series], concurrency: int = CONCURRENCY, seed: int = 17) -> List[Dict[str, Any]]:
    sem = asyncio.Semaphore(concurrency)
    rng = random.Random(seed)
    tasks = [asyncio.create_task(generate_for_row(r, rng, sem)) for r in rows]
    results = await asyncio.gather(*tasks)
    return results


In [6]:
DischargeSummary.model_json_schema()

{'properties': {'discharge_summary': {'description': 'Full discharge summary narrative text (250-400 words)',
   'title': 'Discharge Summary',
   'type': 'string'}},
 'required': ['discharge_summary'],
 'title': 'DischargeSummary',
 'type': 'object'}

In [7]:
# Load a small sample and run

def load_input_df() -> pd.DataFrame:
    path = IN_DATA_PATH if os.path.exists(IN_DATA_PATH) else ALT_IN_DATA_PATH
    return pd.read_parquet(path)


def to_parquet_append_safe(df: pd.DataFrame, out_path: str) -> None:
    # Simple overwrite for prototype
    df.to_parquet(out_path, index=False)


sample_n = 20
full_df = load_input_df()
sample_df = full_df.sample(n=min(sample_n, len(full_df)), random_state=17)
rows = [sample_df.iloc[i] for i in range(len(sample_df))]

# res = await run_async(rows, concurrency=CONCURRENCY)
# res_df = pd.DataFrame(res)
# res_df.head()


In [10]:
full_df.shape

(223, 21)

In [None]:
# find what's missing
set(full_df[full_df['subject_id'] == 10378479][['icd10_codes']].values[0][0])



set()

In [140]:
print(json.dumps(res[0], indent=2))

{
  "subject_id": 10923503,
  "hadm_id": 23939725,
  "policy": "partial",
  "discharge_summary": "**Discharge Summary**  \n**Patient Name:** [Patient Name]  \n**Date of Admission:** [Admission Date]  \n**Date of Discharge:** [Discharge Date]  \n**Attending Physician:** [Attending Physician Name]  \n  \n**History of Present Illness:**  \nThe patient is a 62-year-old male who presented with severe recurrent major depressive disorder (ICD-10: F332) and suicidal ideations. He was admitted urgently due to an exacerbation of depressive symptoms, exacerbated by recent psychosocial stressors and a history of alcohol abuse. During hospitalization, the patient was assessed for safety and initiated on a psychiatric treatment plan.  \n  \n**Hospital Course:**  \nOver the course of 6.8 days, the patient underwent a comprehensive psychiatric evaluation and management. He was stabilized on Sertraline, a selective serotonin reuptake inhibitor, along with supportive therapy. His suicidal ideations were

In [135]:
set(res_df[res_df['subject_id'] == 10378479]['diagnosis_codes'].values[0]).difference(set(full_df[full_df['subject_id'] == 10378479]['icd10_codes'].values[0]))

set()

In [129]:
list(full_df[full_df['subject_id'] == 10378479]['icd10_codes'].values[0]) == res_df[res_df['subject_id'] == 10378479]['diagnosis_codes'].values[0]

False

In [47]:
print(OUT_DATA_PATH)

/Users/benjamindykstra/development/icd-10-coding/data/processed/discharge_summaries.train.parquet


In [61]:
full_df['procedures_icd10'][0]

array(['0W9B30Z'], dtype=object)

In [8]:
# Run generations on the full dataset 
rows_full = [full_df.iloc[i] for i in range(len(full_df))]
res_full = await run_async(rows_full, concurrency=CONCURRENCY)
res_full_df = pd.DataFrame(res_full)

# Save full results (overwrites OUT_DATA_PATH)
to_parquet_append_safe(res_full_df, OUT_DATA_PATH)
res_full_df.head()


Unnamed: 0,subject_id,hadm_id,policy,discharge_summary,diagnosis_codes,model
0,17504528,20171885,partial,History of Present Illness: 56-year-old female...,"[I10, Z8572, Z5181, I5032, I2510, K219, I071, ...",gpt-5-mini
1,14273001,20371042,all,History of Present Illness: 73-year-old female...,"[I10, Z87891, C3431, I5031, Y929, T380X5A, K21...",gpt-5-mini
2,11357031,27612249,primary_only,History of present illness: 58-year-old man ad...,[I5033],gpt-5-mini
3,13673554,25741865,partial,History of Present Illness: 71-year-old male w...,"[N184, I5033, E1022, J918, Z9641, I739, I2510,...",gpt-5-mini
4,19017808,20589756,partial,Brief history of present illness: 77-year-old ...,"[W19XXXA, K7460, B964, Y929, I5032, E119, N288...",gpt-5-mini


In [144]:
res_full_df['policy'].value_counts()

policy
primary_only    530
partial         301
all             215
Name: count, dtype: int64

In [11]:
res_full_df.shape

(1046, 6)

In [14]:
res_full_df = res_full_df.drop(columns=['subject_id'])

In [9]:
training_df = pd.read_parquet('../data/processed/structured_dataset.train.parquet')

# join the training df with the res_full_df on hadm_id
training_df_with_summaries = training_df.merge(res_full_df, on=['hadm_id'], how='inner')
training_df_with_summaries.head()




Unnamed: 0,subject_id_x,hadm_id,admittime,dischtime,admission_type,discharge_location,gender,dod,age_at_admit,length_of_stay_days,...,primary_pdgm_bucket_simple,procedures_icd10,num_procedures_total,meds_discharge_like,medication_count,subject_id_y,policy,discharge_summary,diagnosis_codes,model
0,17504528,20171885,2137-06-17 20:40:00,2137-06-20 16:41:00,DIRECT EMER.,HOME,F,2142-07-24,56,2.834028,...,Cardiac & Circulatory,[0W9B30Z],1.0,"[5% Dextrose, Albuterol 0.083% Neb Soln, Aspir...",24.0,17504528,partial,History of Present Illness: 56-year-old female...,"[I10, Z8572, Z5181, I5032, I2510, K219, I071, ...",gpt-5-mini
1,14273001,20371042,2187-10-19 19:00:00,2187-10-21 18:13:00,OBSERVATION ADMIT,HOME,F,2188-03-27,73,1.967361,...,Cardiac & Circulatory,,,"[Acetaminophen, Aspirin, Atorvastatin, Bisacod...",19.0,14273001,all,History of Present Illness: 73-year-old female...,"[I10, Z87891, C3431, I5031, Y929, T380X5A, K21...",gpt-5-mini
2,11357031,27612249,2139-01-17 21:04:00,2139-01-22 18:00:00,OBSERVATION ADMIT,HOME HEALTH CARE,M,2144-10-28,58,4.872222,...,Cardiac & Circulatory,,,"[0.9% Sodium Chloride (Mini Bag Plus), Amoxici...",30.0,11357031,primary_only,History of present illness: 58-year-old man ad...,[I5033],gpt-5-mini
3,13673554,25741865,2176-02-25 19:39:00,2176-03-02 17:30:00,OBSERVATION ADMIT,HOME HEALTH CARE,M,2180-11-21,71,5.910417,...,Cardiac & Circulatory,[0W993ZX],1.0,"[Acetaminophen, Aspirin, Atorvastatin, Cepacol...",17.0,13673554,partial,History of Present Illness: 71-year-old male w...,"[N184, I5033, E1022, J918, Z9641, I739, I2510,...",gpt-5-mini
4,19017808,20589756,2183-01-05 21:59:00,2183-01-08 18:35:00,OBSERVATION ADMIT,HOME HEALTH CARE,F,2184-11-16,77,2.858333,...,Cardiac & Circulatory,,,"[0.9% Sodium Chloride (Mini Bag Plus), Acetami...",25.0,19017808,partial,Brief history of present illness: 77-year-old ...,"[W19XXXA, K7460, B964, Y929, I5032, E119, N288...",gpt-5-mini


In [10]:
training_df_with_summaries['primary_icd10'].apply(lambda x: [x]) + training_df_with_summaries['icd10_codes'].apply(lambda x: list(x))

0       [I5032, I5032, E118, I10, E785, I2510, I071, D...
1       [I5031, I5031, I82501, C7931, C3431, E039, Z79...
2       [I5033, I5033, J9692, J9691, E870, E872, E662,...
3       [I5033, I5033, C9110, J918, N184, N179, I129, ...
4       [I5032, I5032, I8510, N179, I272, K7460, N390,...
                              ...                        
1041    [Z5111, Z5111, C8339, I10, G40909, E8339, R740...
1042    [Z5112, Z5112, G92, C774, C786, C7989, C772, N...
1043    [Z5111, Z5111, C8378, C8290, Z86711, E119, I35...
1044    [Z5111, Z5111, C8235, C8333, E876, R12, T380X5...
1045    [Z432, Z432, K50914, O8619, O9963, Z3A37, Z904...
Length: 1046, dtype: object

In [11]:
# for each row get the all the icd10 codes, and find what the missing codes are using a set difference
training_df_with_summaries['true_icd_codes'] = training_df_with_summaries['icd10_codes'].apply(lambda x: list(x))
training_df_with_summaries['missing_codes'] = training_df_with_summaries[['true_icd_codes', 'diagnosis_codes']].apply(lambda x: list(set(x['true_icd_codes']) - set(x['diagnosis_codes'])), axis=1)
training_df_with_summaries['missing_codes'].head()




0    [I480, D473, Z8673, Z7901, E785, E118, Z954, F...
1                                                   []
2    [E662, E1165, Z6833, R000, J9692, E785, I272, ...
3                                                   []
4    [S5001XA, K7581, I10, I4891, N390, I878, N179,...
Name: missing_codes, dtype: object

In [12]:
training_df.shape

(1046, 21)

In [16]:
training_df_with_summaries['missing_codes'][0]

['Z7901',
 'I272',
 'E039',
 'D72829',
 'F419',
 'E785',
 'J45909',
 'I480',
 'Z8673',
 'I2510',
 'F329']

In [17]:
training_df_with_summaries['diagnosis_codes'][0]

['E118',
 'Z5181',
 'I5032',
 'K219',
 'D473',
 'I10',
 'I071',
 'M8580',
 'Z8572',
 'D649',
 'Z954']

In [13]:
training_df_with_summaries.to_parquet('../data/processed/structured_dataset_with_discharge_summaries.train.parquet', index = False)

In [14]:
res = await generate_for_row(rows[1], random.Random(17), asyncio.Semaphore(CONCURRENCY))

In [16]:
print(json.dumps(res, indent=2))

{
  "subject_id": 10378479,
  "hadm_id": 24421815,
  "policy": "partial",
  "discharge_summary": "Brief history of present illness: 63-year-old man presented to the emergency department with acute onset substernal chest pain and ST-elevation in anterior leads. Cardiac enzymes were elevated consistent with an acute anterior STEMI. Past medical history notable for prior liver and kidney transplant, paroxysmal atrial fibrillation, type 2 diabetes mellitus, hypertension, atherosclerotic coronary disease with prior angina, and long-term anticoagulant use.\n\nHospital course: Patient was admitted directly to the cardiac catheterization lab and underwent urgent percutaneous coronary intervention to the culprit anterior coronary distribution with successful revascularization. Post-procedure he was monitored in the cardiac unit for 48 hours. Chest pain resolved, troponin trended down, and telemetry showed rate-controlled atrial fibrillation without sustained arrhythmia. No access-site complicat

In [24]:
type(rows[0].get("icd10_codes"))

numpy.ndarray

In [37]:
res_df['policy'].value_counts()

policy
primary_only    11
all              6
partial          3
Name: count, dtype: int64

In [None]:
# Save results for the sample
if not res_df.empty:
    to_parquet_append_safe(res_df, OUT_DATA_PATH)
res_df[["subject_id", "hadm_id", "policy"]].head()
