# Abstraction Alignment to Analyze the MIMIC-III Dataset
We apply abstraction alignment to analyze the abstractions encoded in the MIMIC-III dataset, using the ICD-9 hierarchy as the human abstraction. This example is loosely based on the "Analyzing Medical Dataset Encodings with Healthcare Professionals" case study in the abstraction alignment paper.


In this notebook, we load the MIMIC-III dataset of clinical notes labeled with multiple ICD-9 codes. We treat the labels as the dataset's encodings and use abstraction alignment to compute the most commonly confused concepts.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import csv
import json
from tqdm import tqdm
from treelib import Tree
from itertools import combinations

import mimic

## MIMIC-III Data and ICD-9 Hierarchy
First, we load the MIMIC-III dataset that contains clinical notes and their corresponding ICD-9 codes.

In [3]:
# User-defined paths -- TODO: update with your own
DATA_DIR = '/nobackup/users/aboggust/data/mimic/mimicdata/'
ICD9_FILE = os.path.join(DATA_DIR, 'ICD9_descriptions')
TEST_DATA_FILE = os.path.join(DATA_DIR, 'mimic3', 'test_full.csv')
OUTPUT_FILE = os.path.join(DATA_DIR, 'mimic_node_pair_confusions.pkl')

In [4]:
# Load the MIMIC-III dataset of clinical notes and labels
def load_data(filename):
    notes = {}
    labels = {}
    with open(filename, 'r') as f:
        notes_reader = csv.reader(f)
        for i, note in enumerate(notes_reader):
            if i == 0: continue # skip header
            hadm_id = note[1]
            notes[hadm_id] = note[2]
            labels[hadm_id] = note[3].split(';')
    return notes, labels

test_notes, test_labels = load_data(TEST_DATA_FILE)

notes = {}
notes.update(test_notes)

labels = {}
labels.update(test_labels)

print(f'{len(notes)} test data instances')

3372 test data instances
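
The loader above reads the admission ID from column 1, the note text from column 2, and a semicolon-separated label string from column 3. The sketch below applies the same parsing to an in-memory CSV so the expected layout is easy to see. The header names and sample rows are made up for illustration and are not taken from MIMIC-III.

```python
import csv
import io

# Hypothetical sample file; column names and values are illustrative only.
SAMPLE_CSV = (
    "SUBJECT_ID,HADM_ID,TEXT,LABELS\n"
    "1,100,example note text,401.9;38.93\n"
    "2,200,another note,250.00\n"
)

def parse_rows(handle):
    """Map admission ID -> note text and admission ID -> list of codes."""
    reader = csv.reader(handle)
    next(reader, None)  # skip the header row
    notes, labels = {}, {}
    for row in reader:
        hadm_id = row[1]
        notes[hadm_id] = row[2]
        labels[hadm_id] = row[3].split(';')  # multi-label: one note, many codes
    return notes, labels

sample_notes, sample_labels = parse_rows(io.StringIO(SAMPLE_CSV))
print(sample_labels)
```

Each admission maps to a list of codes, which is what the later cells iterate over when propagating labels through the ICD-9 tree.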


In [6]:
# Create the MIMIC tree
TREE = mimic.make_tree(ICD9_FILE)
print('Created MIMIC tree of ICD-9 codes')
print(f'Tree depth: {TREE.depth()}; Num nodes: {TREE.size()}; Num leaves: {len(TREE.leaves())}')
print(mimic.show(TREE))

471693it [01:51, 4103.74it/s]

Created MIMIC tree of ICD-9 codes
Tree depth: 7; Num nodes: 21166; Num leaves: 17447
@: ICD9 Hierarchy Root (None)
├── 00-99.99: PROCEDURES (None)
│   ├── 00: Procedures and interventions, Not Elsewhere Classified (None)
│   │   ├── 00.0: Therapeutic ultrasound (None)
│   │   │   ├── 00.01: Therapeutic ultrasound of vessels of head and neck (None)
│   │   │   ├── 00.02: Therapeutic ultrasound of heart (None)
│   │   │   ├── 00.03: Therapeutic ultrasound of peripheral vascular vessels (None)
│   │   │   └── 00.09: Other therapeutic ultrasound (None)
│   │   ├── 00.1: Pharmaceuticals (None)
│   │   │   ├── 00.10: Implantation of chemotherapeutic agent (None)
│   │   │   ├── 00.11: Infusion of drotrecogin alfa (activated) (None)
│   │   │   ├── 00.12: Administration of inhaled nitric oxide (None)
│   │   │   ├── 00.13: Injection or infusion of nesiritide (None)
│   │   │   ├── 00.14: Injection or infusion of oxazolidinone class of antibiotics (None)
│   │   │   ├── 00.15: High-dose infusi

In [7]:
# Prune the tree to remove unused nodes
relevant_codes = []
for label in labels.values():
    relevant_codes.extend(label)
relevant_codes = set(relevant_codes)
print(f"{len(relevant_codes)} relevant codes.")

relevant_nodes = set([])
for code in relevant_codes:
    node = TREE.get_node(code)
    if node is None: continue
    relevant_nodes.add(code)
    for ancestor in TREE.rsearch(code):
        relevant_nodes.add(ancestor)
print(f"{len(relevant_nodes)} relevant nodes.")

pruned_tree = Tree()
for level in tqdm(range(TREE.depth() + 1)):
    level_nodes = TREE.filter_nodes(lambda n: TREE.depth(n) == level)
    for node in level_nodes:
        if node.identifier in relevant_nodes:
            parent = TREE.parent(node.identifier)
            if parent is not None:
                parent = parent.identifier
            pruned_tree.create_node(
                tag=node.tag, 
                identifier=node.identifier, 
                parent=parent,
                data=None
            )
print('Pruned MIMIC tree of ICD-9 codes')
print(f'Tree depth: {pruned_tree.depth()}; Num nodes: {pruned_tree.size()}; Num leaves: {len(pruned_tree.leaves())}')
print(mimic.show(pruned_tree))

4075 relevant codes.
5779 relevant nodes.




Pruned MIMIC tree of ICD-9 codes
Tree depth: 7; Num nodes: 5779; Num leaves: 3805
@: ICD9 Hierarchy Root (None)
├── 00-99.99: PROCEDURES (None)
│   ├── 08-16.99: OPERATIONS ON THE EYE (None)
│   │   ├── 10: Operations on conjunctiva (None)
│   │   │   ├── 10.2: Diagnostic procedures on conjunctiva (None)
│   │   │   └── 10.9: Other operations on conjunctiva (None)
│   │   ├── 11: Operations on cornea (None)
│   │   │   ├── 11.0: Magnetic removal of embedded foreign body from cornea (None)
│   │   │   ├── 11.3: Excision of pterygium (None)
│   │   │   └── 11.4: Excision or destruction of tissue or other lesion of cornea (None)
│   │   ├── 12: Operations on iris, ciliary body, sclera, and anterior chamber (None)
│   │   │   ├── 12.0: Removal of intraocular foreign body from anterior segment of eye (None)
│   │   │   ├── 12.3: Iridoplasty and coreoplasty (None)
│   │   │   ├── 12.4: Excision or destruction of lesion of iris and ciliary body (None)
│   │   │   ├── 12.5: Facilitation of int

### Analyze the dataset with abstraction alignment
We use concept co-confusion to find pairs of dataset concepts that commonly co-occur.

In [None]:
# Get node pairs from the fitted abstractions
node_pairs = {}
for hadm_id, label in tqdm(labels.items()):
    tree = mimic.propagate(label, TREE)
    weighted_nodes = tree.filter_nodes(lambda n: n.data > 0)
    for pair in combinations(weighted_nodes, 2):
        pair = tuple(sorted(list(pair)))
        node_pairs.setdefault(pair, 0)
        node_pairs[pair] += 1
normalized_node_pairs = {pair: value/len(labels) for pair, value in node_pairs.items()}


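
To make the counting step concrete, here is a minimal sketch on a toy hierarchy. It assumes that propagation gives positive weight to each labeled code and to all of its ancestors. That is an assumption about `mimic.propagate`, whose implementation is not shown in this notebook. The `PARENT` dictionary, `propagate_up`, and `count_pairs` are hypothetical stand-ins, and the sketch keys pairs on identifier strings so that pairs from different notes compare equal.

```python
from collections import Counter
from itertools import combinations

# Hypothetical toy hierarchy: child identifier -> parent identifier.
PARENT = {
    'A': 'root', 'B': 'root',
    'A1': 'A', 'A2': 'A',
    'B1': 'B',
}

def propagate_up(codes, parent):
    """Return the labeled codes plus all of their ancestors."""
    active = set()
    for code in codes:
        node = code
        # Stop early once we reach a node whose ancestors are already marked.
        while node is not None and node not in active:
            active.add(node)
            node = parent.get(node)
    return active

def count_pairs(label_sets, parent):
    """Fraction of notes in which each pair of active nodes co-occurs."""
    counts = Counter()
    for codes in label_sets:
        active = sorted(propagate_up(codes, parent))
        # Sorting first makes every pair tuple canonical.
        counts.update(combinations(active, 2))
    total = len(label_sets)
    return {pair: n / total for pair, n in counts.items()}

toy_labels = [['A1', 'B1'], ['A1', 'A2']]
freq = count_pairs(toy_labels, PARENT)
print(freq[('A1', 'B1')], freq[('A', 'A1')])
```

Here `A1` and `B1` co-occur in one of the two notes, while `A` and `A1` co-occur in both because `A` is an ancestor of `A1`. Ancestor pairs like the latter are trivially frequent, which is why connected pairs are filtered out in the next cell.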

In [None]:
# Remove pairs that are connected
unconnected_pairs = {}
for pair, value in tqdm(normalized_node_pairs.items()):
    level_a = TREE.level(pair[0].identifier)
    level_b = TREE.level(pair[1].identifier)
    if level_a < level_b:
        if pruned_tree.is_ancestor(pair[0].identifier, pair[1].identifier):
            continue
    if level_a > level_b:
        if pruned_tree.is_ancestor(pair[1].identifier, pair[0].identifier):
            continue
    unconnected_pairs[pair] = value
print(f'Originally {len(normalized_node_pairs)} pairs; {len(unconnected_pairs)} unconnected pairs')
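
The filter above drops any pair in which one node is an ancestor of the other. A minimal sketch of the same idea on a toy hierarchy follows. `PARENT`, `is_ancestor`, and `drop_connected` are hypothetical names used only for illustration, with a plain dictionary standing in for the treelib tree.

```python
# Hypothetical toy hierarchy: child identifier -> parent identifier.
PARENT = {'A': 'root', 'B': 'root', 'A1': 'A', 'B1': 'B'}

def is_ancestor(candidate, node, parent):
    """True if `candidate` lies on the path from `node` up to the root."""
    current = parent.get(node)
    while current is not None:
        if current == candidate:
            return True
        current = parent.get(current)
    return False

def drop_connected(pair_values, parent):
    """Keep only pairs where neither node is an ancestor of the other."""
    return {
        (a, b): value
        for (a, b), value in pair_values.items()
        if not is_ancestor(a, b, parent) and not is_ancestor(b, a, parent)
    }

pairs = {
    ('A', 'A1'): 1.0,     # parent and child: connected
    ('A1', 'root'): 1.0,  # node and root: connected
    ('A1', 'B1'): 0.5,    # different branches: kept
    ('A', 'B'): 0.5,      # siblings: kept
}
kept = drop_connected(pairs, PARENT)
print(sorted(kept))
```

The sketch checks both directions for every pair. The notebook cell instead compares levels first, so it only walks the tree in the direction where an ancestor relationship is possible.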

In [None]:
# Top confused concepts
top_pairs = sorted(list(unconnected_pairs.keys()), key=lambda x: unconnected_pairs[x], reverse=True)
print('HIGHEST OVERALL CONFUSION:')
for i in range(5):
    print(f'{top_pairs[i]} --- {unconnected_pairs[top_pairs[i]]:.2%}')

In [None]:
# Level-1 concept confusion: pairs where both nodes sit directly below the root
top_level_pairs = {p: v for p, v in unconnected_pairs.items() if TREE.level(p[0].identifier) == 1 and TREE.level(p[1].identifier) == 1}
top_level_pairs = sorted(list(top_level_pairs.keys()), key=lambda x: top_level_pairs[x], reverse=True)
print('LEVEL-1 CONCEPT CO_CONFUSION')
for i in range(5):
    print(f'{top_level_pairs[i]} --- {unconnected_pairs[top_level_pairs[i]]:.2%}')

In [None]:
# Codable concept confusion: skip range nodes (e.g., 08-16.99), whose identifiers contain '-'
codable_pairs = {p: v for p, v in unconnected_pairs.items() if '-' not in p[0].identifier and '-' not in p[1].identifier}
codable_pairs = sorted(list(codable_pairs.keys()), key=lambda x: codable_pairs[x], reverse=True)
print('CODABLE NODE CO_CONFUSION:')
for i in range(5):
    print(f'{codable_pairs[i]} --- {unconnected_pairs[codable_pairs[i]]:.2%}')
