In [65]:
import os
import xml
import xml.etree.ElementTree as ET
from collections import Counter

import numpy as np
import pandas as pd

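The i2b2 challenge datasets are not distributed in IOB (Inside-Outside-Beginning) format. This notebook converts the `<EVENT>` annotations of the 2012 dataset into IOB-tagged tokens so they can be fed into our models; other annotation types are ignored.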

Note: this notebook is adapted from the i2b2 2012 preprocessing notebook in the [ClinicalBERT GitHub repo](https://github.com/EmilyAlsentzer/clinicalBERT/blob/master/downstream_tasks/i2b2_preprocessing/i2b2_2012/Reformat.ipynb); most of the code is *not* my own.

# Convert i2b2 2012 Data to IOB Format

In [45]:
START_CDATA = "<TEXT><![CDATA["
END_CDATA   = "]]></TEXT>"

# NOTE(review): TAGS is not referenced anywhere in this notebook (only <EVENT> tags are read).
TAGS        = ['MEDICATION', 'OBSEE', 'SMOKER', 'HYPERTENSION', 'event', 'FAMILY_HIST']


def _read_cdata_lines(xml_path, encoding='utf-8'):
    """Return the <TEXT> CDATA payload of `xml_path` as a list of per-line character lists.

    The first element holds whatever follows START_CDATA on its line and the last element
    holds whatever precedes END_CDATA on its line. Line endings are kept, so concatenating
    every character reproduces the CDATA payload that the EVENT offsets index into.
    """
    text, in_text = [], False
    with open(xml_path, mode='r', encoding=encoding) as f:
        for raw_line in f:
            if not in_text:
                start_idx = raw_line.find(START_CDATA)
                if start_idx == -1:
                    continue
                raw_line = raw_line[start_idx + len(START_CDATA):]
                in_text = True
            end_idx = raw_line.find(END_CDATA)
            if end_idx != -1:
                # Also handles a CDATA section that opens and closes on the same line.
                text.append(list(raw_line[:end_idx]))
                break
            text.append(list(raw_line))
    return text


def _build_pos_transformer(text):
    """Map 1-based linear character positions over `text` to (line index, column) pairs."""
    pos_transformer = {}
    linear_pos = 1
    for line_idx, chars in enumerate(text):
        for char_pos in range(len(chars)):
            pos_transformer[linear_pos] = (line_idx, char_pos)
            linear_pos += 1
    return pos_transformer


def read_xml_file(xml_path, event_tag_type='ALL_CHILDREN', match_text=True, encoding='utf-8'):
    """Read one i2b2 2012 XML file into per-character text and per-character IOB labels.

    Args:
        xml_path: path to the annotated XML file.
        event_tag_type: not used; only <EVENT> children of <TAGS> are read. Kept so
            existing callers that pass it keep working.
        match_text: if True, assert that the text recovered from each EVENT's start/end
            offsets equals the tag's `text` attribute (whitespace-normalised).
        encoding: encoding used to read the raw file. It should match the file's XML
            declaration (which ElementTree honours) so character offsets line up.

    Returns:
        (text, event_labels): `text` is a list of lines, each a list of characters
        (line endings included); `event_labels` has the same shape and holds 'O',
        'B-<type>' or 'I-<type>' for every character.

    Raises:
        AssertionError: if the file does not contain exactly one <TAGS> element, or (with
            match_text) an EVENT's offsets do not reproduce its text.
        KeyError: if an EVENT offset falls outside the extracted text.
    """
    text = _read_cdata_lines(xml_path, encoding=encoding)
    pos_transformer = _build_pos_transformer(text)

    try:
        xml_parsed = ET.parse(xml_path)
    except Exception:
        # Report which file failed before propagating the error.
        print(xml_path)
        raise

    tag_containers = xml_parsed.findall('TAGS')
    assert len(tag_containers) == 1, "Found multiple tag sets!"
    tag_container = tag_containers[0]

    event_tags = tag_container.findall('EVENT')
    event_labels = [['O'] * len(sentence) for sentence in text]
    for event_tag in event_tags:
        base_label = event_tag.attrib['type']
        # `start` is treated as 0-based and `end` as exclusive; adding 1 to `start` turns
        # both into 1-based, inclusive keys of pos_transformer.
        start_pos = int(event_tag.attrib['start']) + 1
        end_pos = int(event_tag.attrib['end'])
        event_text = ' '.join(event_tag.attrib['text'].split())

        (start_line, start_char), (end_line, end_char) = pos_transformer[start_pos], pos_transformer[end_pos]

        # Rebuild the span's text from the raw characters, whitespace-normalised.
        obs_parts = []
        for line in range(start_line, end_line + 1):
            t = text[line]
            s = start_char if line == start_line else 0
            e = end_char if line == end_line else len(t)
            obs_parts.append(''.join(t[s:e + 1]).strip())
        obs_text = ' '.join(' '.join(obs_parts).split())

        # The raw text may contain literal `&apos;` / `&quot;` entities, whereas ElementTree
        # returns unescaped attribute values; re-escape before comparing.
        if '&apos;' in obs_text and '&apos;' not in event_text:
            event_text = event_text.replace("'", "&apos;")
        if '&quot;' in obs_text and '&quot;' not in event_text:
            event_text = event_text.replace('"', '&quot;')

        if match_text:
            assert obs_text == event_text, (
                ("Texts don't match! %s v %s" % (event_text, obs_text)) + '\n' + str((
                    start_pos, end_pos, start_line, start_char, end_line, end_char, xml_path
                ))
            )

        # Events with a blank type are text-checked above but left unlabelled.
        if base_label.strip() == '':
            continue

        event_labels[end_line][end_char]     = 'I-%s' % base_label
        event_labels[start_line][start_char] = 'B-%s' % base_label

        # Fill the span interior (including line endings of non-final lines) with I- tags.
        for line in range(start_line, end_line + 1):
            t = text[line]
            s = start_char + 1 if line == start_line else 0
            e = end_char - 1 if line == end_line else len(t) - 1
            for i in range(s, e + 1):
                event_labels[line][i] = 'I-%s' % base_label

    return text, event_labels
    
def _split_label_runs(chars, labels):
    """Yield (chars, labels) slices over maximal runs of characters sharing one entity type.

    The run key is ``label[2:]``: '' for 'O' and the entity type for 'B-'/'I-' labels. A run
    therefore never mixes 'O' with an entity or two different entity types, while the B-/I-
    tags of one entity stay together. Transitions on the last character are split too.
    """
    run_start = 0
    for i in range(1, len(labels) + 1):
        if i == len(labels) or labels[i][2:] != labels[i - 1][2:]:
            yield chars[run_start:i], labels[run_start:i]
            run_start = i


def _chunk_word_labels(labels_chunk, n_words):
    """Pick one label per word of a run: the first char's label for the first word, the
    second char's label for middle words, and the last char's label for the final word."""
    if labels_chunk[0] == 'O':
        return ['O'] * n_words
    if n_words == 1:
        return [labels_chunk[0]]
    return [labels_chunk[0]] + [labels_chunk[1]] * (n_words - 2) + [labels_chunk[-1]]


def merge_into_words(text_by_char, all_labels_by_char):
    """Collapse per-character lines and labels into per-word tokens and IOB labels.

    Args:
        text_by_char: list of lines, each a list of characters.
        all_labels_by_char: list of lines, each a list of per-character labels
            ('O', 'B-<type>', 'I-<type>') with the same shape as `text_by_char`.

    Returns:
        (text_by_word, all_labels_by_word): whitespace-split words per line and one label
        per word. Lines that contain no words are dropped.

    Raises:
        AssertionError: if the two inputs do not have matching shapes.
    """
    assert len(text_by_char) == len(all_labels_by_char), "Incorrect # of sentences!"

    text_by_word, all_labels_by_word = [], []
    for sentence_by_char, labels_by_char in zip(text_by_char, all_labels_by_char):
        assert len(sentence_by_char) == len(labels_by_char), "Incorrect # of chars in sentence!"

        sentence_by_word, labels_by_word = [], []
        for text_chunk, labels_chunk in _split_label_runs(sentence_by_char, labels_by_char):
            words = ''.join(text_chunk).split()
            if not words:
                # Whitespace-only run (e.g. a labelled line ending) contributes no tokens.
                continue
            sentence_by_word.extend(words)
            labels_by_word.extend(_chunk_word_labels(labels_chunk, len(words)))

        if sentence_by_word:
            text_by_word.append(sentence_by_word)
            all_labels_by_word.append(labels_by_word)
    return text_by_word, all_labels_by_word

def _patients_to_iob_text(patient_nums, texts_by_patient, labels_by_patient):
    """Render the sentences of `patient_nums` as IOB text.

    Each token becomes a "<patient_num> <word> <label>" line; sentences are separated by a
    blank line.
    """
    sentences = []
    for patient_num in patient_nums:
        for words, labels in zip(texts_by_patient[patient_num], labels_by_patient[patient_num]):
            sentences.append('\n'.join(
                '%s %s %s' % (patient_num, word, label) for word, label in zip(words, labels)
            ))
    return '\n\n'.join(sentences)


def reprocess_event_labels(folders, base_path='.', event_tag_type='event', match_text=True,
                           dev_set_size=None, seed=None):
    """Convert every XML file under `folders` into IOB text, split into train/dev by patient.

    Args:
        folders: directories (relative to `base_path`) containing `*.xml` files whose file
            name stem is an integer patient/record id (e.g. '18.xml').
        base_path: directory that `folders` are relative to.
        event_tag_type: forwarded to read_xml_file (which does not currently use it).
        match_text: forwarded to read_xml_file; asserts offsets reproduce the tag text.
        dev_set_size: fraction of patients to hold out as a dev split, or None for no
            dev split (all patients go to the first return value).
        seed: optional seed for the dev-split permutation. None uses NumPy's global RNG,
            so `np.random.seed(...)` still controls the split.

    Returns:
        (train_text, dev_text): IOB strings with one "docid word label" line per token and
        a blank line between sentences.
    """
    all_texts_by_patient, all_labels_by_patient = {}, {}

    for folder in folders:
        folder_dir = os.path.join(base_path, folder)
        # Sorted so output order does not depend on os.listdir's arbitrary ordering.
        xml_filenames = sorted(x for x in os.listdir(folder_dir) if x.endswith('.xml'))
        for xml_filename in xml_filenames:
            patient_num = int(os.path.splitext(xml_filename)[0])
            xml_filepath = os.path.join(folder_dir, xml_filename)

            text_by_char, labels_by_char = read_xml_file(
                xml_filepath,
                event_tag_type=event_tag_type,
                match_text=match_text
            )
            text_by_word, labels_by_word = merge_into_words(text_by_char, labels_by_char)

            all_texts_by_patient.setdefault(patient_num, []).extend(text_by_word)
            all_labels_by_patient.setdefault(patient_num, []).extend(labels_by_word)

    # Sorted so that, for a given RNG state, the split is reproducible.
    patients = sorted(all_texts_by_patient)

    if dev_set_size is None:
        train_patients, dev_patients = patients, []
    else:
        N_train = int(len(patients) * (1 - dev_set_size))
        rng = np.random if seed is None else np.random.RandomState(seed)
        patients_random = rng.permutation(patients)
        train_patients = list(patients_random[:N_train])
        dev_patients = list(patients_random[N_train:])

    print(f"Number of train recs = {len(train_patients)}")
    print(f"Number of dev recs = {len(dev_patients)}")

    return (
        _patients_to_iob_text(train_patients, all_texts_by_patient, all_labels_by_patient),
        _patients_to_iob_text(dev_patients, all_texts_by_patient, all_labels_by_patient),
    )

In [46]:
# Seed NumPy's global RNG so the random 90/10 train/dev patient split is reproducible
# across kernel restarts.
np.random.seed(42)
final_train_text, final_dev_text = reprocess_event_labels(
    ['data/i2b2/2012/2012-07-15.original-annotation.release'], dev_set_size=0.1, match_text=True
)

Number of train recs = 171
Number of dev recs = 19


In [47]:
# The test release is not subdivided (dev_set_size=None), so every record lands in the
# first return value; offset/text agreement is not asserted here (match_text=False).
test_text, _ = reprocess_event_labels(
    folders=['data/i2b2/2012/2012-08-08.test-data.event-timex-groundtruth/xml'],
    dev_set_size=None,
    match_text=False,
)

Number of train recs = 120
Number of dev recs = 0


In [88]:
sum([171, 19, 120])  # train + dev + test record counts

310

In [48]:
# Length, in characters, of the IOB-formatted training text.
len(final_train_text)

1310674

In [55]:
# Preview the first 400 characters of the training IOB text ("docid word tag" per line).
print(final_train_text[:400])

18 Admission B-OCCURRENCE
18 Date O
18 : O

18 2016-08-08 O

18 Discharge B-OCCURRENCE
18 Date O
18 : O

18 2016-08-15 O

18 Discharge B-OCCURRENCE
18 Date O
18 : O

18 2016-08-15 O

18 HISTORY O
18 OF O
18 PRESENT O
18 ILLNESS O
18 : O

18 The O
18 patient O
18 is O
18 a O
18 37 O
18 year O
18 old O
18 lady O
18 with O
18 type B-PROBLEM
18 1 I-PROBLEM
18 diabetes I-PROBLEM
18 mellitus I-PROBLEM
1


Convert each split's IOB text into a pandas DataFrame (columns `docid`, `word`, `NER_tag`) and pickle it.

In [79]:
def convert_to_pandas_df(text):
    """Parse "docid word tag" lines into a DataFrame, skipping blank lines.

    Returns a DataFrame with string columns 'docid', 'word' and 'NER_tag', one row per
    non-blank line, in input order. Raises ValueError if a non-blank line does not split on
    single spaces into exactly three fields, and AssertionError if any field is blank.
    """
    columns = {'docid': [], 'word': [], 'NER_tag': []}
    for row in text.split('\n'):
        if not row.strip():
            continue
        fields = row.split(' ')
        docid, word, label = fields
        for field in fields:
            assert field.strip()
        columns['docid'].append(docid)
        columns['word'].append(word)
        columns['NER_tag'].append(label)
    return pd.DataFrame(columns)

In [82]:
# Tokenised training split -> DataFrame, persisted as a pickle.
train_df = convert_to_pandas_df(text=final_train_text)
train_df.to_pickle(path='data/i2b2/2012/i2b2_train_dataset_df.pkl')

In [51]:
# Preview the first 400 characters of the dev IOB text.
print(final_dev_text[:400])

311 ADMISSION B-OCCURRENCE
311 DATE O
311 : O

311 04/07/97 O

311 DISCHARGE B-OCCURRENCE
311 DATE O
311 : O

311 04/08/97 O

311 HISTORY O
311 OF O
311 PRESENT O
311 ILLNESS O
311 : O

311 Mr. O
311 Vessels O
311 is O
311 a O
311 49-year-old O
311 man O
311 status O
311 post O
311 orthotopic B-TREATMENT
311 heart I-TREATMENT
311 transplantation I-TREATMENT
311 in O
311 1991 O
311 at O
311 Dauteno


In [83]:
# Tokenised dev split -> DataFrame, persisted as a pickle.
dev_df = convert_to_pandas_df(text=final_dev_text)
dev_df.to_pickle(path='data/i2b2/2012/i2b2_dev_dataset_df.pkl')

In [52]:
# Preview the first 400 characters of the test IOB text.
print(test_text[:400])

516 ADMISSION B-OCCURRENCE
516 DATE O
516 : O

516 10/14/96 O

516 DISCHARGE B-OCCURRENCE
516 DATE O
516 : O

516 10/27/96 O
516 date O
516 of O
516 birth B-OCCURRENCE
516 ; O
516 September O
516 30 O
516 , O
516 1917 O

516 THER O
516 PROCEDURES O
516 : O

516 arterial B-TEST
516 catheterization I-TEST
516 on O
516 10/14/96 O
516 , O
516 head B-TEST
516 CT I-TEST
516 scan I-TEST
516 on O
516 10/1


In [84]:
# Tokenised test split -> DataFrame, persisted as a pickle.
test_df = convert_to_pandas_df(text=test_text)
test_df.to_pickle(path='data/i2b2/2012/i2b2_test_dataset_df.pkl')

In [53]:
# Tally IOB tags across all three splits, checking that every non-blank line is a
# "docid word tag" triple whose tag is 'O' or starts with 'B-' / 'I-'.
label_counter = Counter()
for split_text in (final_train_text, final_dev_text, test_text):
    for line in split_text.split('\n'):
        if line == '':
            continue
        fields = line.split()
        assert len(fields) == 3, "malformed line! %r" % line
        label = fields[-1]
        assert label == 'O' or label.startswith(('B-', 'I-')), "label wrong! %s" % label
        label_counter[label] += 1
# Kept as a plain dict in first-seen order.
labels = dict(label_counter)

In [54]:
# Per-tag counts over the train, dev and test splits (computed in the previous cell).
labels

{'B-OCCURRENCE': 5774,
 'O': 114910,
 'B-PROBLEM': 9319,
 'I-PROBLEM': 13543,
 'B-TREATMENT': 7098,
 'I-TREATMENT': 6748,
 'I-OCCURRENCE': 3590,
 'B-EVIDENTIAL': 1334,
 'B-TEST': 4762,
 'I-TEST': 5931,
 'I-EVIDENTIAL': 84,
 'B-CLINICAL_DEPT': 1724,
 'I-CLINICAL_DEPT': 3253}

In [39]:
# Optionally also export the raw IOB text ("docid word tag" lines, blank line between
# sentences) as plain-text files. Disabled by default; the pickled DataFrames are the
# primary outputs of this notebook.
WRITE_TSV = False
if WRITE_TSV:
    os.makedirs('data/i2b2/2012/processed', exist_ok=True)
    for split_name, split_text in (('train', final_train_text), ('dev', final_dev_text), ('test', test_text)):
        with open(os.path.join('data/i2b2/2012/processed', '%s.tsv' % split_name), mode='w', encoding='utf-8') as f:
            f.write(split_text)

In [86]:
# Stack the three splits into one frame. ignore_index=True gives a fresh 0..n-1 index;
# without it each split's own 0-based index would produce duplicate index labels.
result_df = pd.concat([train_df, dev_df, test_df], ignore_index=True)

In [87]:
# Persist the combined train/dev/test frame alongside the per-split pickles.
result_df.to_pickle('data/i2b2/2012/i2b2_dataset_df.pkl')