In this toy project, the modeling task is to predict the first character of an ICD-10 code from its description.

Run this cell to import the tools we need.

In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
%matplotlib inline

import polars as pl
import torch 
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline

from clinical_ai.data_utils import render_chat_inputs, predict_letters_batched 

The model consists of a large language model (LLM) and a tokenizer.

In [13]:
model_id = "Qwen/Qwen2.5-0.5B-Instruct"

llm = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_id,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Token id used for padding during generation. Fall back to EOS when the
# tokenizer defines no pad token, and record that fallback on the tokenizer
# itself so anything that pads through the tokenizer uses the same id
# (pad_id is also passed explicitly to generation calls in later cells).
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = pad_id

# Decoder-only models continue from the right end of the input, so batched
# prompts must be left-padded; with right padding the new token would be
# generated after pad tokens. Setting this is a no-op if already "left".
tokenizer.padding_side = "left"

# `model` is the text-generation pipeline wrapping `llm` + `tokenizer`;
# later cells call it directly. Uses GPU 0 when CUDA is available, else CPU (-1).
model = pipeline(
    task="text-generation",
    model=llm,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

Device set to use cpu


This is how you use the model:

You give it a set of instructions (`user_instruction`) describing how to use the data (`user_context`). The two are packed into a prompt, and the model predicts what comes next after the prompt — hopefully, what it predicts is the right answer.

In [19]:
user_instruction = """
I am about to give you a description of a disease. 
Predict the first character of the ICD10 Code associated with that description.
In general E codes are related to Endocrine, nutritional and metabolic diseases
'Type 2 diabetes mellitus with diabetic chronic kidney disease' corresponds to 'E11.22'
So for example: 
if I say 'Type 2 diabetes mellitus with diabetic chronic kidney disease', then you just respond 'E'
"""

# Smoke test of the prompt plumbing. Note: the context below is the exact example
# already embedded in user_instruction, so a correct answer here shows the
# pipeline works end to end, not that the model generalizes to new descriptions.
enc = render_chat_inputs(
    tokenizer=tokenizer,
    user_instruction=user_instruction,
    user_context="Type 2 diabetes mellitus with diabetic chronic kidney disease",
)

prompt = enc["text"]

print(repr(prompt), "\n")

# Greedy decoding of exactly one new token — the answer is meant to be a single letter.
out = model(
    prompt,
    max_new_tokens=1,
    do_sample=False,            # greedy (argmax) decoding
    num_return_sequences=1,
    return_full_text=False,     # only return the newly generated continuation
    pad_token_id=pad_id,
    eos_token_id=tokenizer.eos_token_id,
)

first_char_text = out[0]["generated_text"]
print(f"The first character is: {first_char_text}")

"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nI am about to give you a description of a disease. \nPredict the first character of the ICD10 Code associated with that description.\nIn general E codes are related to Endocrine, nutritional and metabolic diseases\n'Type 2 diabetes mellitus with diabetic chronic kidney disease' corresponds to 'E11.22'\nSo for example: \nif I say 'Type 2 diabetes mellitus with diabetic chronic kidney disease', then you just respond 'E'\n<|im_end|>\n<|im_start|>user\nType 2 diabetes mellitus with diabetic chronic kidney disease<|im_end|>\n<|im_start|>assistant\n" 

The first character is: E


Now let's load some data for the tasks.

In [20]:
# Keep only codes whose first character is A–E, then draw a reproducible
# 100-row sample (fixed seed) so every run scores the same descriptions.
first_letters_kept = ["A", "B", "C", "D", "E"]
df = (
    pl.read_csv("icd10_order_fy26.csv")
    .filter(pl.col("CODE").str.slice(0, 1).is_in(first_letters_kept))
    .sample(n=100, with_replacement=False, shuffle=True, seed=42)
)
df.head(4)

ORDER_NUMBER,CODE,IS_VALID,SHORT_DESCRIPTION,LONG_DESCRIPTION,PARENT_ORDER_NUMBER,PARENT_CODE,PARENT_IS_VALID,PARENT_SHORT_DESCRIPTION,PARENT_LONG_DESCRIPTION
i64,str,i64,str,str,i64,str,i64,str,str
4198,"""E103211""",1,"""Type 1 diab with mild nonp rtn…","""Type 1 diabetes mellitus with …",4184,"""E10""",0,"""Type 1 diabetes mellitus""","""Type 1 diabetes mellitus"""
1645,"""C4419""",0,"""Oth malignant neoplasm of skin…","""Other specified malignant neop…",1606,"""C44""",0,"""Other and unspecified malignan…","""Other and unspecified malignan…"
5073,"""E849""",1,"""Cystic fibrosis, unspecified""","""Cystic fibrosis, unspecified""",5067,"""E84""",0,"""Cystic fibrosis""","""Cystic fibrosis"""
3616,"""D6103""",1,"""Fanconi anemia""","""Fanconi anemia""",3612,"""D61""",0,"""Oth aplastic anemias and other…","""Other aplastic anemias and oth…"


In [21]:
# Column names and polars dtypes of the sampled frame
df.schema

Schema([('ORDER_NUMBER', Int64),
        ('CODE', String),
        ('IS_VALID', Int64),
        ('SHORT_DESCRIPTION', String),
        ('LONG_DESCRIPTION', String),
        ('PARENT_ORDER_NUMBER', Int64),
        ('PARENT_CODE', String),
        ('PARENT_IS_VALID', Int64),
        ('PARENT_SHORT_DESCRIPTION', String),
        ('PARENT_LONG_DESCRIPTION', String)])

We use the instructions above to make predictions for all the disease descriptions. This may take several seconds; you will know it's done predicting when the table appears.

In [28]:
# --- Build one chat prompt per LONG_DESCRIPTION ---
# render_chat_inputs(...) returns a dict whose "text" entry is the rendered prompt
# string fed to the pipeline. Missing (null) descriptions become an empty context.
descriptions = df["LONG_DESCRIPTION"].fill_null("").to_list()

prompts = [
    render_chat_inputs(
        tokenizer=tokenizer,
        user_instruction=user_instruction,
        user_context=desc,
    )["text"]
    for desc in descriptions
]

print(len(prompts), "prompts")

preds = predict_letters_batched(prompts, model, tokenizer, pad_id, batch_size=32)

# Fail loudly if the helper dropped or duplicated predictions, rather than
# attaching a misaligned PREDICTION column.
assert len(preds) == df.height, f"expected {df.height} predictions, got {len(preds)}"

# --- Attach to Polars DataFrame ---
df = df.with_columns(pl.Series(name="PREDICTION", values=preds))

# Quick sanity check: first 10 rows alongside their predictions
df.select(["CODE", "LONG_DESCRIPTION", "PREDICTION"]).head(10)

100 prompts
shape: (10, 3)
┌─────────┬─────────────────────────────────┬────────────┐
│ CODE    ┆ LONG_DESCRIPTION                ┆ PREDICTION │
│ ---     ┆ ---                             ┆ ---        │
│ str     ┆ str                             ┆ str        │
╞═════════╪═════════════════════════════════╪════════════╡
│ E103211 ┆ Type 1 diabetes mellitus with … ┆ E          │
│ C4419   ┆ Other specified malignant neop… ┆ T          │
│ E849    ┆ Cystic fibrosis, unspecified    ┆ C          │
│ D6103   ┆ Fanconi anemia                  ┆ C          │
│ E093293 ┆ Drug or chemical induced diabe… ┆ D          │
│ D142    ┆ Benign neoplasm of trachea      ┆ B          │
│ A880    ┆ Enteroviral exanthematous feve… ┆ E          │
│ D2210   ┆ Melanocytic nevi of unspecifie… ┆ C          │
│ B582    ┆ Toxoplasma meningoencephalitis  ┆ T          │
│ E71121  ┆ Propionic acidemia              ┆ T          │
└─────────┴─────────────────────────────────┴────────────┘


Here are a few examples to give you an idea of the general shape of the answers.

First, here is how to calculate the accuracy of predicting whether or not a description belongs to the E group of ICD-10 codes.

In [30]:
# Add E first character
df = df.with_columns(
    pl.col("CODE").str.slice(0, 1).alias("E_FIRST_CHAR")
)

# Binary ground truth: 1 if true code starts with E, else 0
df = df.with_columns(
    (pl.col("E_FIRST_CHAR") == "E").cast(pl.Int8).alias("IS_E_TRUE"),
    (pl.col("PREDICTION") == "E").cast(pl.Int8).alias("IS_E_PRED")
)

# Correct if prediction matches truth
df = df.with_columns(
    (pl.col("IS_E_TRUE") == pl.col("IS_E_PRED")).alias("IS_CORRECT_E_BINARY")
)

# Compute accuracy
accuracy_E_binary = df["IS_CORRECT_E_BINARY"].mean()
print(f"Binary accuracy (is E vs not E): {accuracy_E_binary:.4f}")

Binary accuracy (is E vs not E): 0.7800


This is how you calculate the recall, also known as sensitivity: the fraction of all true E descriptions that the model finds.

In [25]:
# All actual E codes
true_E = df.filter(pl.col("E_FIRST_CHAR") == "E")

# True positives: predicted E when actual is E
true_pos_E = true_E.filter(pl.col("PREDICTION") == "E").height

# Total actual E
total_true_E = true_E.height

recall_E = true_pos_E / total_true_E if total_true_E > 0 else float("nan")

print(f"Recall (sensitivity) for E: {recall_E:.4f}")

Recall (sensitivity) for E: 0.2800


# 3 Tasks

1. Write one cell that calculates the positive predictive value (PPV), also known as precision. In other words: when the model predicts that a description is an E, what is the probability that it really is an E?

2. Given this confusion matrix:

```
TN = 91 
TP = 3 
FP = 3 
FN = 3
```

What are the accuracy, precision (PPV), and recall (sensitivity)? If the accuracy is >90% and the model is therefore considered good enough to use in practice, what is the model's customer most likely to complain about?

3. In the cells below, redo the code above, modifying the instruction part of the prompt to increase the positive predictive value (PPV), also known as precision.