In [1]:
import os, xml, xml.etree.ElementTree as ET, numpy as np

In [2]:
START_CDATA = "<TEXT><![CDATA["
END_CDATA   = "]]></TEXT>"

TAGS        = ['MEDICATION', 'OBSEE', 'SMOKER', 'HYPERTENSION', 'event', 'FAMILY_HIST']

def read_xml_file(xml_path, event_tag_type='ALL_CHILDREN', match_text=True):
    with open(xml_path, mode='r') as f:
        lines = f.readlines()
        text, in_text = [], False
        for i, l in enumerate(lines):
            if START_CDATA in l:
                text.append(list(l[l.find(START_CDATA) + len(START_CDATA):]))
                in_text = True
            elif END_CDATA in l:
                text.append(list(l[:l.find(END_CDATA)]))
                break
            elif in_text:
#                 if xml_path.endswith('180-03.xml') and '0808' in l and 'Effingham' in l:
#                     print("Adjusting known error")
#                     l = l[:9] + ' ' * 4 + l[9:]
# #                 elif xml_path.endswith('188-05.xml') and 'Johnson & Johnson' in l:
# #                     print("Adjusting known error")
# #                     l = l.replace('&', 'and')
                text.append(list(l))
        
    pos_transformer = {}
    
    linear_pos = 1
    for line, sentence in enumerate(text):
        for char_pos, char in enumerate(sentence):
            pos_transformer[linear_pos] = (line, char_pos)
            linear_pos += 1
        
    try: xml_parsed = ET.parse(xml_path)
    except:
        print(xml_path)
        raise
        
    tag_containers = xml_parsed.findall('TAGS')
    assert len(tag_containers) == 1, "Found multiple tag sets!"
    tag_container = tag_containers[0]
    
#     event_tags = tag_container.getchildren() if event_tag_type == 'ALL_CHILDREN' else tag_container.findall('event')
    event_tags = tag_container.findall('EVENT')
    event_labels = [['O'] * len(sentence) for sentence in text]
    for event_tag in event_tags:
        base_label = event_tag.attrib['type']
        start_pos, end_pos, event_text = event_tag.attrib['start'], event_tag.attrib['end'], event_tag.attrib['text']
        start_pos, end_pos = int(start_pos)+1, int(end_pos)
        event_text = ' '.join(event_text.split())
#         if event_text == "0808 O’neil’s Court":
#             print("Adjusting known error")
#             end_pos -= 4
#         if event_text == 'Johnson and Johnson' and xml_path.endswith('188-05.xml'):
#             print("Adjusting known error")
#             event_text = 'Johnson & Johnson'
        

        (start_line, start_char), (end_line, end_char) = pos_transformer[start_pos], pos_transformer[end_pos]
            
        obs_text = []
        for line in range(start_line, end_line+1):
            t = text[line]
            s = start_char if line == start_line else 0
            e = end_char if line == end_line else len(t)
            obs_text.append(''.join(t[s:e+1]).strip())
        obs_text = ' '.join(obs_text)
        obs_text = ' '.join(obs_text.split())
        
        if '&apos;' in obs_text and '&apos;' not in event_text: event_text = event_text.replace("'", "&apos;")
        if '&quot;' in obs_text and '&quot;' not in event_text: event_text = event_text.replace('"', '&quot;')
              
        if match_text: assert obs_text == event_text, (
            ("Texts don't match! %s v %s" % (event_text, obs_text)) + '\n' + str((
                start_pos, end_pos, line, s, e, t, xml_path
            ))
        )
            
        if base_label.strip() == '': continue
        
        event_labels[end_line][end_char]     = 'I-%s' % base_label
        event_labels[start_line][start_char] = 'B-%s' % base_label
        
        for line in range(start_line, end_line+1):
            t = text[line]
            s = start_char+1 if line == start_line else 0
            e = end_char-1 if line == end_line else len(t)-1
            for i in range(s, e+1): event_labels[line][i] = 'I-%s' % base_label

    return text, event_labels
    
def merge_into_words(text_by_char, all_labels_by_char):
    assert len(text_by_char) == len(all_labels_by_char), "Incorrect # of sentences!"
    
    N = len(text_by_char)
    
    text_by_word, all_labels_by_word = [], []
    
    for sentence_num in range(N):
        sentence_by_char = text_by_char[sentence_num]
        labels_by_char   = all_labels_by_char[sentence_num]
        
        assert len(sentence_by_char) == len(labels_by_char), "Incorrect # of chars in sentence!"
        S = len(sentence_by_char)
        
        if labels_by_char == (['O'] * len(sentence_by_char)):
            sentence_by_word = ''.join(sentence_by_char).split()
            labels_by_word   = ['O'] * len(sentence_by_word)
        else: 
            sentence_by_word, labels_by_word = [], []
            text_chunks, labels_chunks = [], []
            s = 0
            for i in range(S):
                if i == S-1:
                    text_chunks.append(sentence_by_char[s:])
                    labels_chunks.append(labels_by_char[s:])
                elif labels_by_char[i] == 'O': continue
                else:
                    if i > 0 and labels_by_char[i-1] == 'O':
                        text_chunks.append(sentence_by_char[s:i])
                        labels_chunks.append(labels_by_char[s:i])
                        s = i
                    if labels_by_char[i+1] == 'O' or labels_by_char[i+1][2:] != labels_by_char[i][2:]:
                        text_chunks.append(sentence_by_char[s:i+1])
                        labels_chunks.append(labels_by_char[s:i+1])
                        s = i+1
                
            for text_chunk, labels_chunk in zip(text_chunks, labels_chunks):
                assert len(text_chunk) == len(labels_chunk), "Bad Chunking (len)"
                assert len(text_chunk) > 0, "Bad chunking (len 0)" + str(text_chunks) + str(labels_chunks)
                
                labels_set = set(labels_chunk)
                assert labels_set == set(['O']) or (len(labels_set) <= 3 and 'O' not in labels_set), (
                    ("Bad chunking (contents) %s" % ', '.join(labels_set))+ str(text_chunks) + str(labels_chunks)
                )
                
                text_chunk_by_word = ''.join(text_chunk).split()
                W = len(text_chunk_by_word)
                if W == 0: 
#                     assert labels_set == set(['O']), "0-word chunking and non-0 label!" + str(
#                         text_chunks) + str(labels_chunks
#                     )
                    continue
                
                if labels_chunk[0] == 'O': labels_chunk_by_word = ['O'] * W
                elif W == 1:               labels_chunk_by_word = [labels_chunk[0]]
                elif W == 2:               labels_chunk_by_word = [labels_chunk[0], labels_chunk[-1]]
                else:                      labels_chunk_by_word = [
                        labels_chunk[0]
                    ] + [labels_chunk[1]] * (W - 2) + [
                        labels_chunk[-1]
                    ]
                    
                sentence_by_word.extend(text_chunk_by_word)
                labels_by_word.extend(labels_chunk_by_word)

        assert len(sentence_by_word) == len(labels_by_word), "Incorrect # of words in sentence!"    
        
        if len(sentence_by_word) == 0: continue
            
        text_by_word.append(sentence_by_word)
        all_labels_by_word.append(labels_by_word)
    return text_by_word, all_labels_by_word

def reprocess_event_labels(folders, base_path='.', event_tag_type='event', match_text=True, dev_set_size=None):
    all_texts_by_patient, all_labels_by_patient = {}, {}

    for folder in folders:
        folder_dir = os.path.join(base_path, folder)
        xml_filenames = [x for x in os.listdir(folder_dir) if x.endswith('xml')]
        for xml_filename in xml_filenames:
            try:
                patient_num = int(xml_filename[:-4])
                xml_filepath = os.path.join(folder_dir, xml_filename)

                text_by_char, labels_by_char = read_xml_file(
                    xml_filepath,
                    event_tag_type=event_tag_type,
                    match_text=match_text
                )
                text_by_word, labels_by_word = merge_into_words(text_by_char, labels_by_char)

                if patient_num not in all_texts_by_patient:
                    all_texts_by_patient[patient_num] = []
                    all_labels_by_patient[patient_num] = []

                all_texts_by_patient[patient_num].extend(text_by_word)
                all_labels_by_patient[patient_num].extend(labels_by_word)
            except:
                print(xml_filename)
            
    patients = set(all_texts_by_patient.keys())
    
    if dev_set_size is None: train_patients, dev_patients = list(patients), []
    else:
        N_train = int(len(patients) * (1-dev_set_size))
        patients_random = np.random.permutation(list(patients))
        train_patients = list(patients_random[:N_train])
        dev_patients   = list(patients_random[N_train:])
    
    train_texts, train_labels = [], []
    dev_texts, dev_labels = [], []
    
    for patient_num in train_patients:
        train_texts.extend(all_texts_by_patient[patient_num])
        train_labels.extend(all_labels_by_patient[patient_num])

    for patient_num in dev_patients:
        dev_texts.extend(all_texts_by_patient[patient_num])
        dev_labels.extend(all_labels_by_patient[patient_num])


    train_out_text_by_sentence = []
    for text, labels in zip(train_texts, train_labels):
        train_out_text_by_sentence.append('\n'.join('%s %s' % x for x in zip(text, labels)))
    dev_out_text_by_sentence = []
    for text, labels in zip(dev_texts, dev_labels):
        dev_out_text_by_sentence.append('\n'.join('%s %s' % x for x in zip(text, labels)))

    return '\n\n'.join(train_out_text_by_sentence), '\n\n'.join(dev_out_text_by_sentence)

In [3]:
final_train_text, final_dev_text = reprocess_event_labels(
    ['2012-07-15.original-annotation.release'], dev_set_size=0.1, match_text=True
)

./2012-07-15.original-annotation.release/382.xml
382.xml
./2012-07-15.original-annotation.release/152.xml
152.xml
./2012-07-15.original-annotation.release/143.xml
143.xml
./2012-07-15.original-annotation.release/422.xml
422.xml
./2012-07-15.original-annotation.release/272.xml
272.xml
./2012-07-15.original-annotation.release/547.xml
547.xml
./2012-07-15.original-annotation.release/23.xml
23.xml
./2012-07-15.original-annotation.release/807.xml
807.xml


In [4]:
test_text, _ = reprocess_event_labels(
    ['2012-08-08.test-data.event-timex-groundtruth/xml'], match_text=False, dev_set_size=None
)

./2012-08-08.test-data.event-timex-groundtruth/xml/527.xml
527.xml
./2012-08-08.test-data.event-timex-groundtruth/xml/53.xml
53.xml
./2012-08-08.test-data.event-timex-groundtruth/xml/687.xml
687.xml
./2012-08-08.test-data.event-timex-groundtruth/xml/802.xml
802.xml
./2012-08-08.test-data.event-timex-groundtruth/xml/397.xml
397.xml
./2012-08-08.test-data.event-timex-groundtruth/xml/627.xml
627.xml


In [5]:
print(final_train_text[:500])

Admission B-OCCURRENCE
Date O
: O

2013-10-27 O

Discharge B-OCCURRENCE
Date O
: O

2013-11-03 O

Service O
: O

MEDICINE O

History O
of O
Present O
Illness O
: O

42 O
year O
old O
female O
with O
h/o O
cholangiocarcinoma B-PROBLEM
dx O
in O
2009 O
s/p O
resection B-TREATMENT
, O
with O
recent B-TEST
CT I-TEST
showing B-EVIDENTIAL
met B-PROBLEM
cholangiocarcinoma I-PROBLEM
in O
9/2004 O
. O

Pt O
was O
recently O
admitted B-OCCURRENCE
for O
fever B-PROBLEM
due O
to O
cholangitis B-PROBLEM
on O


In [6]:
print(final_dev_text[:400])

ADMISSION B-OCCURRENCE
DATE O
: O

03/11/2002 O

DISCHARGE B-OCCURRENCE
DATE O
: O

03/14/2002 O

DISCHARGE B-OCCURRENCE
DATE O
: O

03/14/2002 O

HISTORY O
OF O
PRESENT O
ILLNESS O
: O

This O
is O
a O
62-year-old O
hospice O
chaplain O
who O
was O
referred B-OCCURRENCE
by O
Dr. O
Tomedankell O
Flowayles O
and O
Dr. O
Es O
Oarekote O
for O
evaluation B-TEST
of O
his B-PROBLEM
right I-PROBLEM
hip 


In [7]:
print(test_text[:400])

ADMISSION B-OCCURRENCE
DATE O
: O

10/14/96 O

DISCHARGE B-OCCURRENCE
DATE O
: O

10/27/96 O
date O
of O
birth B-OCCURRENCE
; O
September O
30 O
, O
1917 O

THER O
PROCEDURES O
: O

arterial B-TEST
catheterization I-TEST
on O
10/14/96 O
, O
head B-TEST
CT I-TEST
scan I-TEST
on O
10/14/96 O

HISTORY O
AND O
REASON O
FOR O
HOSPITALIZATION O
: O

Granrivern O
Call O
is O
a O
79-year-old O
right O
han


In [8]:
labels = {}
for s in final_train_text, final_dev_text, test_text:
    for line in s.split('\n'):
        if line == '': continue
        label = line.split()[-1]
        assert label == 'O' or label.startswith('B-') or label.startswith('I-'), "label wrong! %s" % label
        if label not in labels: labels[label] = 1
        else: labels[label] += 1

In [9]:
labels

{'B-OCCURRENCE': 5351,
 'O': 107645,
 'B-PROBLEM': 8717,
 'B-TREATMENT': 6571,
 'B-TEST': 4477,
 'I-TEST': 5655,
 'B-EVIDENTIAL': 1273,
 'I-PROBLEM': 12870,
 'I-OCCURRENCE': 3258,
 'I-TREATMENT': 6300,
 'B-CLINICAL_DEPT': 1537,
 'I-CLINICAL_DEPT': 2858,
 'I-EVIDENTIAL': 82}

In [10]:
f = open("../../NER/2012/label.txt", "w")
for label in labels:
    f.write(label+"\n")
f.close()

with open('../../NER/2012/train.txt', mode='w') as f:
    f.write(final_train_text)
with open('../../NER/2012/dev.txt', mode='w') as f:
    f.write(final_dev_text)
with open('../../NER/2012/test.txt', mode='w') as f:
    f.write(test_text)