## MIMIC-III Misspelling Synthetic Dataset (with new lexicon)

This notebook makes an synthetic typo dataset from MIMIC-III dataset. We randomly choose words from clinical notes, corrupt them.

The possible corruptions are (We do one or two corruptions for each word):
1. Adding a character
2. Deleting a character
3. Substituting a character
4. Swaping two adjacent characters

The dictionary (a set of valid words) generated with the `LRWD` and the `prevariants` table of UMLS and a English dictionary from [here](https://github.com/dwyl/english-words).

The output of this notebook is a dataset of (context, typo, answer) in TSV format.

In [1]:
import os
import sys
import re
import csv
import shutil
import random
import multiprocessing
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import defaultdict, Counter
from utils import clean_text, sanitize_text

In [2]:
# Data path
mimic_note_fpath = '../data/mimic3/NOTEEVENTS.csv'  # MIMIC-III
mimic_tools_dpath = '../scripts/mimic-tools/'  # Pseudonymization
lexicon_fpath = '../data/lexicon/lexicon.json'  # Dictionary

# Output path
data_root = '../data/mimic_synthetic/'
num_val_examples = 10000
num_test_examples = 10000
num_examples = num_val_examples + num_test_examples
all_output_fpath = os.path.join(data_root, 'all.tsv')
val_output_fpath = os.path.join(data_root, 'val.tsv')
test_output_fpath = os.path.join(data_root, 'test.tsv')
# min_word_len = 3
min_word_len = 1
no_corruption_prob = 0.1
max_corruptions = 2
do_substitution = True
do_transposition = True
DEFAULT_MAX_CHARACTER_POSITIONS = 64

pseudo_in_dpath = os.path.join(data_root, 'temp')
pseudo_out_dpath = os.path.join(data_root, 'temp_pseudonym')

## Load data

In [3]:
# Read MIMIC-III notes
print(f'Read {os.path.basename(mimic_note_fpath)}... ', end='')
df_notes = pd.read_csv(mimic_note_fpath, low_memory=False)
df_notes = df_notes.set_index('ROW_ID')
print(f'done! {len(df_notes)} notes')

Read NOTEEVENTS.csv... done! 2083180 notes


In [4]:
# Load & preprocess clinspell lexicon
print(f'Read {lexicon_fpath}... ', end='')
with open(lexicon_fpath, 'r') as fd:
    vocab = json.load(fd)
vocab_set = set(vocab)
print(f'{len(vocab)} words')

Read ../data/lexicon/lexicon.json... 822919 words


## Select notes & corrupt words

In [5]:
# Choose words
puncs = list("[]!\"#$%&'()*+,./:;<=>?@\^_`{|}~-")

def random_word_context(text, max_trial=100):
    words = text.split()
    
    trial = 0
    done = False
    while trial < max_trial and not done:
        # Select a word
        trial += 1
        w_idx = random.randint(0, len(words)-1)
        word, left_res, right_res = words[w_idx], [], []
        
        # If the word is already in vocab, it's good to go.
        if len(word) >= min_word_len and (word.lower() in vocab_set) and \
                len(word) < DEFAULT_MAX_CHARACTER_POSITIONS - 4:
            done = True
        else:
            # Otherwise, detach puncs at the first and the last char, and check again
            if word[0] in puncs:
                word, left_res = word[1:], [word[0]]
            else:
                word, left_res = word, []
            if not word: continue  # The word was just a punc

            if word[-1] in puncs:
                word, right_res = word[:-1], [word[-1]]
            else:
                word, right_res = word, []

            if len(word) < min_word_len or (not word.lower() in vocab_set) or \
                    len(word) >= DEFAULT_MAX_CHARACTER_POSITIONS - 4:
                continue

            # Check whether it's anonymized field
            right_snip = ' '.join(words[w_idx+1:w_idx+5])
            if '**]' in right_snip and '[**' not in right_snip:
                continue
            left_snip = ' '.join(words[w_idx-4:w_idx])
            if '[**' in left_snip and '**]' not in left_snip:
                continue
            
            # Pass!
            done = True
            
    if done:
        return word, ' '.join(words[:w_idx] + left_res), ' '.join(right_res + words[w_idx+1:])
    else:
        raise ValueError('failed to choose word')

In [6]:
# Corrupt words
alphabet = 'abcdefghijklmnopqrstuvwxyz'
def random_alphabet():
    return random.choice(alphabet)

operation_list = ['ins', 'del']
if do_substitution:
    operation_list.append('sub')
if do_transposition:
    operation_list.append('tra')

def single_corruption(word):
    while True:
        oper = random.choice(operation_list)

        if oper == "del":  # deletion
            if len(word) == 1: continue
            cidx = random.randint(0, len(word)-1)
            ret = word[:cidx] + word[cidx+1:]
            break
        elif oper == "ins":  # insertion
            cidx = random.randint(0, len(word))
            ret = word[:cidx] + random_alphabet() + word[cidx:]
            break
        elif oper == "sub":  # substitution
            cidx = random.randint(0, len(word)-1)
            while True:
                c = random_alphabet()
                if c != word[cidx]:
                    ret = word[:cidx] + c + word[cidx+1:]
                    break
        elif oper == "tra":  # transposition
            if len(word) == 1 : continue
            cidx = random.randint(0, len(word)-2) # swap cidx-th and (cidx+1)-th char
            if word[cidx+1] == word[cidx]: continue
            ret = word[:cidx] + word[cidx+1] + word[cidx] + word[cidx+2:]
            break
        else:
            raise ValueError(f'Wrong operation {oper}')
    return ret

def corrupt_word(word_original, max_corruptions=2):
    if no_corruption_prob > 0.0:
        if random.uniform(0, 1) < no_corruption_prob:
            return word_original

    num_corruption = random.randint(1, max_corruptions)
    while True:
        word = word_original
        for i in range(num_corruption):
            word = single_corruption(word)
        if word_original != word:
            break
    return word

In [7]:
# Select note indexes randomly
random.seed(1234)
note_ids = list(df_notes.index)
random.shuffle(note_ids)

count, typo_noteids = 0, set()
for nid in note_ids:
    note = df_notes.loc[nid].TEXT
    if len(note.strip()) >= 2000 and nid not in typo_noteids:  # Only choose for len(text) >= 1000
        typo_noteids.add(nid)
        count += 1
    if count == num_examples:
        break

typo_noteids = list(typo_noteids)
print(typo_noteids[:10])

[655360, 393217, 524290, 393216, 393224, 16, 393244, 1572897, 34, 786477]


In [8]:
# Select words
examples = []
for nid in tqdm(typo_noteids):
    text = df_notes.loc[nid].TEXT
    word, left, right = random_word_context(text)
    examples.append([word, left, right])

100%|█████████████████████████████████████████████████████| 20000/20000 [00:05<00:00, 3928.25it/s]


In [9]:
# See how many words have punctuations
words = list(zip(*examples))[0]
words_with_punc = list(filter(lambda w: sum([not c.isalpha() for c in w]), words))
print(f'{len(words_with_punc)} words have punctuation')
words_with_punc

129 words have punctuation


['trans-jugular',
 "patient's",
 "CX'S",
 'vancomycin-resistant',
 'intra-',
 "CK's",
 "MAP'S",
 'fat-containing',
 "patient's",
 "Gerota's",
 'non-distended',
 "patient's",
 'Post-hemorrhagic',
 "patient's",
 'NON-CORONARY',
 'sub-mandibular',
 'G-tube',
 'e-mail',
 'MP-RAGE',
 'non-heparin',
 'IN-',
 "rec'd",
 'Non-tender',
 'post-surgical',
 'c-collar',
 'well-defined',
 'ad-lib',
 'lima-lad',
 "patient's",
 't-cell',
 'Non-tender',
 'C-SPINE',
 'time-out',
 'post-op',
 'wall-to-wall',
 'Ill-defined',
 'post-operative',
 'Post-op',
 "patient's",
 'c-pap',
 're-intubate',
 "patient's",
 'intra-abdominal',
 'right-sided',
 'moderate-sized',
 'non-distended',
 'x-ray',
 'Non-tender',
 're-oriented',
 'intra-abdominal',
 'double-lumen',
 "Sat's",
 'work-up',
 'Non-distended',
 "non-Hodgkin's",
 'Non-tender',
 "Patient's",
 'year-old',
 'Non-invasive',
 'mild-to-moderate',
 'vaso-vagal',
 'mild-to-moderate',
 'extra-axial',
 "family's",
 "Pt's",
 'third-order',
 'LIMA-LAD',
 'Non-tender'

In [10]:
# Write notes with misspells
if os.path.exists(pseudo_in_dpath):
    shutil.rmtree(pseudo_in_dpath)        
if os.path.exists(pseudo_out_dpath):
    shutil.rmtree(pseudo_out_dpath)
    
os.makedirs(pseudo_in_dpath)
for noteid, (_, left, right) in zip(typo_noteids, examples):
    with open(os.path.join(pseudo_in_dpath, f'{noteid}_left.txt'), 'w', encoding='utf-8') as fd:
        fd.write(left)
    with open(os.path.join(pseudo_in_dpath, f'{noteid}_right.txt'), 'w', encoding='utf-8') as fd:
        fd.write(right)

In [11]:
# pip install requests joblib sqlalchemy gensim
! python {os.path.join(mimic_tools_dpath, 'main.py')} REPLACE \
    --input-dir {os.path.join(os.getcwd(), pseudo_in_dpath)} \
    --output-dir {os.path.join(os.getcwd(), pseudo_out_dpath)} \
    --list-dir {os.path.join(mimic_tools_dpath, 'lists')}

2022-04-21 20:02:38,532 Starting placeholder replacing
2022-04-21 20:02:38,533 Loading lists
2022-04-21 20:02:38,557 * Postal addresses: 20000 [656C Newport Court Coatesville, PA 19320 ...]
2022-04-21 20:02:38,803 * Last names: 88799 [SMITH, JOHNSON, WILLIAMS, JONES, BROWN ...]
2022-04-21 20:02:38,806 * Male first names: 1219 [JAMES, JOHN, ROBERT, MICHAEL, WILLIAM ...]
2022-04-21 20:02:38,816 * Female first names: 4275 [MARY, PATRICIA, LINDA, BARBARA, ELIZABETH ...]
2022-04-21 20:02:38,848 * Phone numbers: 20000 [(666) 372-7835, (923) 739-2644 ...]
2022-04-21 20:02:38,874 * Companies: 20000 [Ligula Aenean Gravida Ltd, Non Bibendum Sed LLC ...]
2022-04-21 20:02:38,874 * Countries: 264 [Afghanistan, Albania, Algeria, American Samoa ...]
2022-04-21 20:02:38,896 * Emails: 20000 [enim.Suspendisse.aliquet@Crasdictum.com, sapien.Cras.dolor@Curabitur.org ...]
2022-04-21 20:02:38,897 * Holiday names: 187 [Administrative Professionals Day, Air Force Birthday ...]
2022-04-21 20:02:38,901 * Hospit

In [12]:
# Read pseudonymized notes
def process_note(note):
    note = re.sub('\n', ' ', note)
    note = re.sub('\t', ' ', note)
    note = sanitize_text(clean_text(note))
    return note

for nid, example in tqdm(zip(typo_noteids, examples), total=len(typo_noteids)):
    with open(os.path.join(pseudo_out_dpath, f'{nid}_left.txt'), 'r', encoding='utf-8') as fd:
        note = fd.read()
        note = process_note(note)
        example[1] = note
    with open(os.path.join(pseudo_out_dpath, f'{nid}_right.txt'), 'r', encoding='utf-8') as fd:
        note = fd.read()
        note = process_note(note)
        example[2] = note
    example[0] = example[0].lower()

100%|██████████████████████████████████████████████████████| 20000/20000 [00:39<00:00, 506.99it/s]


In [13]:
# Corrupt words
print('Generate corrupted words... ')
random.seed(1234)
correct_words = [e[0] for e in examples]
typo_words = [corrupt_word(w, max_corruptions) for w in correct_words]
for i, (w1, w2) in enumerate(zip(correct_words, typo_words)):
    print(f'\t{w1} -> {w2}')
    if i == 5: break
print('done!')

Generate corrupted words... 
	bp -> dcbp
	tracking -> ztracking
	not -> not
	much -> uwmch
	to -> ot
	patient -> patient
done!


## Generate dataset for BERT

Data format (TSV): `index`,`note_id`,`word`,`left`,`right`,`correct`

- `index`: index of the data (starting from 0)
- `note_id`: ROW_ID of MIMIC-III `NOTEEVENTS.csv`
- `word`: the word of interest (typo)
- `left`: left context
- `right`: right context
- `correct`: correction

In [14]:
random.seed(1234)
data_split_idx = list(range(num_examples))
random.shuffle(data_split_idx)
val_idx, test_idx = data_split_idx[:num_val_examples], data_split_idx[num_val_examples:]
val_idx.sort()
test_idx.sort()

In [15]:
if not os.path.exists(data_root):
    os.makedirs(data_root)

for fpath, idx_list in [(all_output_fpath, list(range(num_examples))),
                   (val_output_fpath, val_idx),
                   (test_output_fpath, test_idx)]:
    print(f'Write examples to {fpath}... ', end='', flush=True)
    with open(fpath, 'w') as fd:
        writer = csv.writer(fd, delimiter='\t')
        writer.writerow(['index', 'note_id', 'word', 'left', 'right', 'correct'])  
        for i in idx_list:
            nid, (correct, left, right), typo = typo_noteids[i], examples[i], typo_words[i]
            left = ' '.join(left.split(' ')[-128:])
            right = ' '.join(right.split(' ')[:128])
            line = [i, nid, typo, left, right, correct]
            writer.writerow(line)
    print('done!')

Write examples to ../data/mimic_synthetic/all.tsv... done!
Write examples to ../data/mimic_synthetic/val.tsv... done!
Write examples to ../data/mimic_synthetic/test.tsv... done!
