<a href="https://colab.research.google.com/github/bluemens/ICD_Coding_Notebook/blob/main/ICD_coding_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# Install dependencies
!pip install transformers nltk --quiet

In [35]:
# Import packages
import re
import nltk
import torch
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, DataLoader

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

from sklearn.metrics import f1_score, classification_report
from sklearn.tree import export_text

In [6]:
# Download sentence tokenizer
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [71]:
# Load ClinicalBERT tokenizer
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bert_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bert_model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [46]:
# Sample raw note
sample_note = """
CHIEF COMPLAINT: SOB, fatigue. Other stuff
HPI: 65 y/o M w/ h/o CHF presents w/ SOB x 3 days, worse on exertion. Denies CP.
ASSESSMENT: Likely CHF exacerbation. Will start IV Lasix.
"""

In [38]:

def parse_sections(note):
    """
    Parse a clinical note into labeled sections using regex-based headers.
    Returns a dict {section_name: section_text}
    """
    # Common section headers (expand as needed)
    section_headers = [
        "CHIEF COMPLAINT", "HISTORY OF PRESENT ILLNESS", "HPI",
        "PAST MEDICAL HISTORY", "FAMILY HISTORY", "SOCIAL HISTORY",
        "PHYSICAL EXAM", "ASSESSMENT", "PLAN", "ASSESSMENT AND PLAN",
        "REVIEW OF SYSTEMS", "MEDICATIONS", "LABS", "DISCHARGE DIAGNOSIS"
    ]

    # Normalize headers
    section_pattern = "|".join([re.escape(h) for h in section_headers])
    regex = re.compile(rf"^\s*({section_pattern})\s*[:\-]?\s*$", re.IGNORECASE | re.MULTILINE)

    sections = {}
    matches = list(regex.finditer(note))

    for i, match in enumerate(matches):
        start = match.end()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(note)
        section_name = match.group(1).strip().upper()
        section_text = note[start:end].strip()
        sections[section_name] = section_text

    return sections

In [73]:
def preprocess_note(note):
    """
    Parses note into sections, tokenizes into sentences,
    maps sentences to sections, prepares input for ClinicalBERT.
    """
    structured = {}
    structured["note_id"] = "example-001"
    structured["sections"] = parse_sections(note)

    sentences = []
    sentence_to_section = {}

    # Tokenize by sentence, track section
    for section, text in structured["sections"].items():
        section_sentences = nltk.sent_tokenize(text)
        for sent in section_sentences:
            sentence_to_section[len(sentences)] = section
            sentences.append(sent)

    # Join all for BERT input
    full_doc = " ".join(sentences)
    encoded = tokenizer(full_doc, padding=True, truncation=True, return_tensors="pt")

    structured["sentences"] = sentences
    structured["sentence_to_section"] = sentence_to_section
    structured["full_doc"] = full_doc
    structured["tokens"] = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    structured["token_ids"] = encoded["input_ids"][0].tolist()
    structured["encoded"] = encoded  # optional: keep raw inputs for later

    with torch.no_grad():
        output = bert_model(**encoded)
        structured["cls_embedding"] = output.last_hidden_state[:, 0, :].squeeze(0)  # shape [768]


    return structured

In [77]:
#Experimental example note with Sections (structured not free)
note = """
CHIEF COMPLAINT:
Chest pain and shortness of breath.

HPI:
Patient reports SOB worsening over 3 days.

ASSESSMENT:
CHF exacerbation. Start IV Lasix.
"""

#Free Text Examples
clinical_notes = [
    "Patient presents with chest pain and shortness of breath.",
    "Type 2 diabetes with foot ulcer.",
    "Sepsis post-surgery with fever.",
    "Asthma and wheezing.",
    "Appendicitis and abdominal pain.",
]


data = preprocess_note(note)
print(data["sentences"])
print(data["sentence_to_section"])
print(data["tokens"][:20])  # preview
print(data["token_ids"][:20])  # preview

['Chest pain and shortness of breath.', 'Patient reports SOB worsening over 3 days.', 'CHF exacerbation.', 'Start IV Lasix.']
{0: 'CHIEF COMPLAINT', 1: 'HPI', 2: 'ASSESSMENT', 3: 'ASSESSMENT'}
['[CLS]', 'chest', 'pain', 'and', 'short', '##ness', 'of', 'breath', '.', 'patient', 'reports', 'sob', 'worse', '##ning', 'over', '3', 'days', '.', 'ch', '##f']
[101, 2229, 2489, 1105, 1603, 1757, 1104, 2184, 119, 5351, 3756, 20295, 4146, 3381, 1166, 124, 1552, 119, 22572, 2087]


In [25]:
#Concept Extraction
def load_icd10_blocks_from_txt(file_path):
    """
    Load ICD-10 block concept vocabulary from a structured .txt file.
    Each line should follow the format: CODE-CODE  Description
    Example: A00-A09  Intestinal infectious diseases
    """
    block_pattern = re.compile(r'^([A-Z]\d{2})-([A-Z]\d{2})\s+(.*)$')
    blocks = []

    with open(file_path, 'r') as file:
        for line in file:
            match = block_pattern.match(line.strip())
            if match:
                start_code, end_code, description = match.groups()
                blocks.append({
                    "block_start": start_code,
                    "block_end": end_code,
                    "concept_name": description
                })

    return pd.DataFrame(blocks)


def icd_code_to_block(code, block_df):
    code = code.strip().upper()
    for _, row in block_df.iterrows():
        if row["block_start"] <= code <= row["block_end"]:
            return row["concept_name"]
    return None

# Embed all concept descriptions
def get_concept_description_embeddings(concept_names):
    embeddings = []
    with torch.no_grad():
        for desc in concept_names:
            inputs = tokenizer(desc, return_tensors="pt").to(device)
            output = bert_model(**inputs)
            cls = output.last_hidden_state[:, 0, :]  # [CLS] token
            embeddings.append(cls.squeeze(0))
    return torch.stack(embeddings)  # [K, 768]


In [37]:
block_df = load_icd10_blocks_from_txt("/content/icd_blocks.txt")
NUM_CONCEPTS = block_df.shape[0]
concept_vocab = block_df["concept_name"].tolist()
concept_index = {concept: i for i, concept in enumerate(concept_vocab)}

description_embeddings = get_concept_description_embeddings(concept_vocab).to(device)

In [36]:
class ConceptDataset(Dataset):
    def __init__(self, notes, label_blocks, concept_index):
        self.notes = notes
        self.labels = label_blocks
        self.concept_index = concept_index

    def __len__(self):
        return len(self.notes)

    def __getitem__(self, idx):
        cls = self.data[idx]["cls_embedding"]
        c = [0] * len(self.concept_index)
        for label in self.labels[idx]:
            if label in self.concept_index:
                c[self.concept_index[label]] = 1
        return cls, torch.tensor(c).float()


class ConceptPredictor(nn.Module):
    def __init__(self, hidden_dim=768, num_concepts=len(concept_names)):
        super().__init__()
        self.feature_net = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.output_layer = nn.Linear(256, num_concepts)
        self.sigmoid = nn.Sigmoid()

    def forward(self, cls_embedding):
        feats = self.feature_net(cls_embedding)
        logits = self.output_layer(feats)
        probs = self.sigmoid(logits)
        return probs

In [None]:
def collate_batch(batch):
    cls_embeddings, labels = zip(*batch)
    return torch.stack(cls_embeddings), torch.stack(labels)

In [None]:
# === Simulated dataset ===
note_icd_codes = [
    ["I50.9", "J18.9"],     # Heart failure, Pneumonia
    ["E11.9"],              # Type 2 diabetes
    ["A41.9"]             # Sepsis
]

note_labels = []
for code_list in note_icd_codes:
    blocks = set()
    for code in code_list:
        block = icd_code_to_block(code, block_df)
        if block:
            blocks.add(block)
    note_labels.append(list(blocks))

In [None]:
# === Setup training ===
dataset = ConceptDataset(clinical_notes, note_labels, concept_index)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_batch)

concept_model = ConceptPredictor().to(device)
optimizer = torch.optim.Adam(concept_model.parameters(), lr=1e-4)
loss_fn = nn.BCELoss()

In [None]:
#Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for epoch in range(5):
    total_loss = 0
    concept_model.train()

    for cls_batch, label_batch in dataloader:
        cls_batch = cls_batch.to(device)
        label_batch = label_batch.to(device)

        preds = concept_model(cls_batch)

        # Standard concept prediction loss
        L_bce = loss_fn(preds, targets)

        # Alignment loss: output weights should match BERT embeddings of concepts
        W = concept_model.output_layer.weight  # shape: [K, 256]
        W_proj = W @ concept_model.feature_net[0].weight.T  # Project to input space
        L_align = torch.nn.functional.mse_loss(W_proj, description_embeddings)

        # Total loss (λ = 0.1)
        loss = L_bce + 0.1 * L_align

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1}: Loss = {total_loss:.4f}")

In [None]:
#SYMBOLIC COMPRESSION
concept_model.eval()
concept_vectors = []
target_labels = []

with torch.no_grad():
    for cls_batch, label_batch in dataloader:
        cls_batch = cls_batch.to(device)
        preds = concept_model(cls_batch).cpu()
        concept_vectors.append(preds)
        target_labels.append(label_batch)

X = torch.cat(concept_vectors).numpy()  # shape: [N, K]
Y = torch.cat(target_labels).numpy()    # shape: [N, K]

In [None]:
#Tree Fitting
multi_tree = MultiOutputClassifier(
    DecisionTreeClassifier(max_depth=4, min_samples_leaf=5)
)
multi_tree.fit(X, Y)

In [None]:
#Code Prediction Evaluation

Y_pred = multi_tree.predict(X)

print("Micro F1:", f1_score(Y, Y_pred, average='micro'))
print("Macro F1:", f1_score(Y, Y_pred, average='macro'))
print("\nClassification Report:\n", classification_report(Y, Y_pred, target_names=concept_names))

for i, tree in enumerate(multi_tree.estimators_):
    block_name = concept_names[i]
    print(f"\n--- Rules for: {block_name} ---")
    print(export_text(tree, feature_names=list(concept_index.keys())))

In [None]:
def predict_icd_from_note(note_text):
    # Preprocess
    structured = preprocess_note(note_text, tokenizer, bert_model)
    c_hat = concept_model(structured["cls_embedding"].unsqueeze(0).to(device)).cpu().numpy()

    # Predict ICD blocks using symbolic layer
    icd_pred = multi_tree.predict(c_hat)
    predicted_blocks = [concept_names[i] for i, val in enumerate(icd_pred[0]) if val == 1]

    return predicted_blocks

In [None]:
#Decision Tree Visualization

plt.figure(figsize=(12, 6))
plot_tree(multi_tree.estimators_[block_index],
          feature_names=list(concept_index.keys()),
          class_names=["Absent", "Present"],
          filled=True)
plt.show()

In [None]:
# Interpretability Metric Evaluation

# === 1. Get predictions ===
Y_pred = multi_tree.predict(X)

# === 2. Compute F1 Scores ===
micro_f1 = f1_score(Y_true, Y_pred, average='micro')
macro_f1 = f1_score(Y_true, Y_pred, average='macro')

print(f"Micro F1 score: {micro_f1:.4f}")
print(f"Macro F1 score: {macro_f1:.4f}")
print()

# === 3. Avg. Number of Rules per Block ===
def count_positive_leaves(tree):
    """Count number of leaf nodes that predict class 1."""
    tree_ = tree.tree_
    n_nodes = tree_.node_count
    children_left = tree_.children_left
    children_right = tree_.children_right
    value = tree_.value  # shape [node_id, 1, 2]

    def is_leaf(node_id):
        return children_left[node_id] == children_right[node_id] == -1

    count = 0
    for node_id in range(n_nodes):
        if is_leaf(node_id):
            if np.argmax(value[node_id][0]) == 1:  # class 1 (positive)
                count += 1
    return count

total_rules = 0
for tree in multi_tree.estimators_:
    total_rules += count_positive_leaves(tree)

avg_rules_per_block = total_rules / len(multi_tree.estimators_)
print(f"Avg. number of rules per block: {avg_rules_per_block:.2f}")

# === 4. Avg. Concepts per Rule ===
def count_concepts_per_rule(tree, feature_names):
    tree_ = tree.tree_
    paths = []

    def recurse(node, path):
        if tree_.children_left[node] == -1 and tree_.children_right[node] == -1:
            paths.append(path)
            return
        feat = feature_names[tree_.feature[node]]
        recurse(tree_.children_left[node], path + [feat])
        recurse(tree_.children_right[node], path + [feat])

    recurse(0, [])
    concept_counts = [len(set(path)) for path in paths if path]  # avoid empty
    return concept_counts

all_counts = []
feature_names = list(concept_index.keys())
for tree in multi_tree.estimators_:
    counts = count_concepts_per_rule(tree, feature_names)
    all_counts.extend(counts)

avg_concepts_per_rule = np.mean(all_counts)
print(f"Avg. concepts used per rule: {avg_concepts_per_rule:.2f}")

# === 5. Optional: Show a few symbolic rules ===
print("\n--- Example symbolic rules ---")
for i, tree in enumerate(multi_tree.estimators_[:3]):
    print(f"\nRules for: {concept_names[i]}")
    print(export_text(tree, feature_names=feature_names))