In [1]:
import os
from tqdm import tqdm_notebook as tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from pytorch_pretrained_bert import BertTokenizer
import nltk
import glob
import itertools
from xml.dom import minidom
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
from nltk.tokenize import TreebankWordTokenizer
from nltk.stem.snowball import SnowballStemmer
from collections import Counter

# word_tokenize = TreebankWordTokenizer().tokenize

def get_annotation(element, indicator):
    """Turn one annotation DOM element into a normalised tuple.

    Args:
        element: DOM element exposing ``tagName`` and ``getAttribute``.
        indicator: name of the attribute that refines the label
            (the value is lower-cased, stripped, and spaces become ``_``).

    Returns:
        ``(text, label)`` for ``SMOKER`` and ``FAMILY_HIST`` elements, otherwise
        ``(text, label, time)``, where ``text`` is the stripped, lower-cased
        ``text`` attribute, ``label`` is ``"<tagname>.<indicator value>"`` and
        ``time`` is the normalised ``time`` attribute.
    """
    def _normalised(attribute_name):
        # Shared normalisation for the label suffix and the time attribute.
        return element.getAttribute(attribute_name).lower().strip().replace(' ', '_')

    surface_text = element.getAttribute('text').strip().lower()
    label = element.tagName.lower() + '.' + _normalised(indicator)

    # These two tags are emitted without a time component.
    if element.tagName in ('SMOKER', 'FAMILY_HIST'):
        return (surface_text, label)
    return (surface_text, label, _normalised('time'))
    
def tokenise_annotation(annotation):
    """Replace an annotation's text with its word tokens.

    Args:
        annotation: sequence whose first item is the annotated text and whose
            second item is the label.

    Returns:
        ``(word_tokenize(text), label)``.
    """
    tokens = word_tokenize(annotation[0])
    label = annotation[1]
    return (tokens, label)

def combine_annotations(annotations):
    """Fold the time attribute of each annotation into its label.

    Three-element annotations ``(text, label, time)`` become
    ``(text, label + '.' + time)``. When a label occurs (on any text) with all
    of ``before_dct``, ``during_dct`` and ``after_dct``, every three-element
    annotation carrying that label becomes ``(text, label + '.continuing')``
    instead. Annotations of any other length are passed through unchanged.

    Args:
        annotations: iterable of annotation tuples.

    Returns:
        A list of the unique resulting annotations, in sorted order so that the
        result (and any stable sort applied to it afterwards) is reproducible
        across runs.
    """
    # Materialise once so a one-shot iterable can be scanned twice.
    annotations = list(annotations)

    # (label, time) pairs seen anywhere; a set keeps the lookups below O(1).
    seen = {(a[1], a[2]) for a in annotations if len(a) == 3}

    all_times = ('before_dct', 'during_dct', 'after_dct')
    results = set()
    for annotation in annotations:
        if len(annotation) != 3:
            results.add(annotation)
            continue
        text, label, time = annotation
        if all((label, t) in seen for t in all_times):
            results.add((text, label + '.continuing'))
        else:
            results.add((text, label + '.' + time))
    return sorted(results)

def find_sublist(sublist, alist):
    """Locate every occurrence of ``sublist`` inside ``alist``.

    Args:
        sublist: the sequence of items to look for.
        alist: the sequence to search in.

    Returns:
        A list of ``(start, end)`` index pairs, ``end`` inclusive, one per
        occurrence (overlapping occurrences are all reported). An empty
        ``sublist`` matches nothing and yields ``[]``.
    """
    # Guard: indexing sublist[0] below would raise IndexError on an empty query.
    if not sublist:
        return []

    width = len(sublist)
    first = sublist[0]
    return [
        (start, start + width - 1)
        for start, item in enumerate(alist)
        # Cheap first-item check before paying for the slice comparison.
        if item == first and alist[start:start + width] == sublist
    ]

def annotate(tags, annotations, indices):
    """Write ``'I-<label>'`` into ``tags`` over every matched span, in place.

    Args:
        tags: mutable list of tag strings, one per token; modified in place.
        annotations: sequence whose ``i``-th item has its label at position 1.
        indices: ``indices[i]`` is the list of ``(start, end)`` spans (``end``
            inclusive) belonging to ``annotations[i]``.

    Spans are applied in order, so a later annotation overwrites an earlier one
    wherever their spans overlap.
    """
    for position, spans in enumerate(indices):
        for span in spans:
            first, last = span[0], span[1]
            for token_index in range(first, last + 1):
                tags[token_index] = 'I-' + annotations[position][1]

def isplit(iterable, splitters):
    """Split ``iterable`` into runs of items separated by splitter items.

    Args:
        iterable: the items to partition.
        splitters: container of separator values (tested with ``in``).

    Returns:
        A list of non-empty lists; separators are dropped and consecutive
        separators do not produce empty runs.
    """
    runs = []
    current = []
    for item in iterable:
        if item in splitters:
            # A separator closes the run in progress, if there is one.
            if current:
                runs.append(current)
                current = []
        else:
            current.append(item)
    if current:
        runs.append(current)
    return runs

def replace_elements(alist, indices):
    """Overwrite selected positions of each row of ``alist`` with ``-1``, in place.

    Args:
        alist: list of mutable sequences (rows); modified in place.
        indices: ``indices[i]`` lists the positions of ``alist[i]`` to replace.
    """
    for row_number, positions in enumerate(indices):
        for position in positions:
            alist[row_number][position] = -1
            
def write_to_file(filename, data, index):
    """Write a numbered listing of ``data`` to ``filename``.

    Each output line is ``"<number> <item>"`` where ``number`` starts at
    ``index`` and ``item`` is the first element of the corresponding entry of
    ``data``. The file is overwritten if it already exists.

    Args:
        filename: path of the file to create.
        data: sequence of indexable entries; only ``entry[0]`` is written.
        index: number assigned to the first entry.
    """
    # The context manager closes the handle even if a write raises.
    with open(filename, 'w') as out:
        for number, entry in enumerate(data, start=index):
            out.write("%d %s\n" % (number, entry[0]))

def generate_files(data, labels, files, paths=None):
    """Write one encoded training file per document into each output directory.

    For document ``i`` the output file is named ``files[i][17:-4] + '.txt'``
    and holds one line per sentence: the comma-joined items of
    ``data[i][j]``, a space, then the comma-joined items of ``labels[i][j]``.

    Args:
        data: per-document list of per-sentence item lists.
        labels: same shape as ``data``; the labels for each item.
        files: source file names, parallel to ``data``.
        paths: output directory prefixes (each must end with a path separator
            because it is concatenated directly with the file name). Defaults
            to the cnn/rnn/lstm training directories.
    """
    if paths is None:
        paths = ['../models/cnn/data/training/', '../models/rnn/data/training/', '../models/lstm/data/training/']

    for i in range(len(data)):
        # NOTE(review): the slice drops a fixed 17-character prefix and a
        # 4-character suffix (presumably a directory prefix and ".xml") --
        # confirm it matches the paths actually passed in.
        stem = files[i][17:-4]

        # Render the document once; the same content goes to every directory.
        lines = [
            ','.join(str(x) for x in data[i][j]) + ' ' + ','.join(str(x) for x in labels[i][j]) + '\n'
            for j in range(len(data[i]))
        ]
        for path in paths:
            # The context manager closes the handle even if a write raises.
            with open(path + stem + '.txt', 'w') as out:
                out.writelines(lines)
    
def print_data(encoded_data, encoded_labels, data_indices, label_indices):
    """Print every encoded token next to its label, one pair per line.

    Token ids are looked up in ``data_indices`` at ``id - 2`` and label ids in
    ``label_indices`` at ``id - 1``; in both tables the first element of the
    entry is the string that gets printed.

    Args:
        encoded_data: documents -> sentences -> token ids.
        encoded_labels: same shape as ``encoded_data``, holding label ids.
        data_indices: lookup table for token ids.
        label_indices: lookup table for label ids.
    """
    for doc_no, document in enumerate(encoded_data):
        for sent_no, sentence in enumerate(document):
            for tok_no, token_id in enumerate(sentence):
                word = data_indices[token_id - 2][0]
                label = label_indices[encoded_labels[doc_no][sent_no][tok_no] - 1][0]
                print(word + " " + label)

def get_chunks(text, max_length=256, ch='\n'):
    """Cut ``text`` into pieces at occurrences of ``ch``.

    While the remaining text has more than ``max_length`` Treebank tokens, it
    is cut at the last occurrence of ``ch`` that lies before the end offset of
    the token at index ``max_length``. Cutting stops early when the text
    contains no ``ch`` at all (after printing ``'uh oh'``) or when the chosen
    cut position is 0. Whatever remains is always appended as the final piece,
    so that piece can still exceed ``max_length`` tokens.

    Args:
        text: the string to split.
        max_length: token budget that triggers a cut.
        ch: the character to cut at; it stays at the start of the next piece.

    Returns:
        List of string pieces which concatenate back to ``text``.
    """
    treebank = TreebankWordTokenizer()

    def _token_ends(segment):
        # End character offset of every token in the segment.
        return [span[1] for span in treebank.span_tokenize(segment)]

    remaining = text
    pieces = []
    ends = _token_ends(remaining)
    while len(ends) > max_length:
        break_positions = [pos for pos, letter in enumerate(remaining) if letter == ch]
        if not break_positions:
            print('uh oh')
            break

        # Last break position strictly before the boundary token's end offset.
        boundary = ends[max_length]
        eligible = [pos for pos in break_positions if pos < boundary]
        cut = eligible[-1] if eligible else 0
        if cut == 0:
            break

        pieces.append(remaining[:cut])
        remaining = remaining[cut:]
        ends = _token_ends(remaining)

    pieces.append(remaining)
    return pieces

In [2]:
# XML element names of the annotation categories that are extracted.
tagnames = [
    'CAD', 'DIABETES', 'FAMILY_HIST', 'HYPERLIPIDEMIA',
    'HYPERTENSION', 'MEDICATION', 'OBESE', 'SMOKER',
]

# The two training-set directories; every .xml file directly inside them is used.
folder1 = '/host_home/data/i2b2/2014/training/training-RiskFactors-Complete-Set1'
folder2 = '/host_home/data/i2b2/2014/training/training-RiskFactors-Complete-Set2'
files1, files2 = [glob.glob(folder + '/*.xml') for folder in (folder1, folder2)]
files = files1 + files2

# Maximum number of word tokens per training sequence.
max_length = 128

In [6]:
# Build the tagged corpus: for every record, mark the tokens covered by an
# annotation, stem the tokens, and store the result as (stemmed token, tag)
# sequences of at most `max_length` tokens each.
tagged_sents = list()

# The stemmer's construction does not depend on the record being processed,
# so it is created once here rather than once per file.
stemmer = SnowballStemmer("english")

for file in tqdm(files):
    root = minidom.parse(file)
    annotation_objects = [root.getElementsByTagName(x) for x in tagnames]
    # The attribute that refines the label depends on the tag:
    # MEDICATION -> 'type1', SMOKER -> 'status', everything else -> 'indicator'.
    annotations = [[[get_annotation(z, 'type1')
                if z.tagName == 'MEDICATION' else get_annotation(z, 'status')
                if z.tagName == 'SMOKER' else get_annotation(z, 'indicator')
                for z in y.getElementsByTagName(y.tagName)] 
                for y in x] for x in annotation_objects]
    # Drop empty groups, flatten the two nesting levels, and de-duplicate.
    annotations = [[y for y in x if len(y) > 0] for x in annotations if len(x) > 0]
    annotations = list(set([y for x in [y for x in annotations for y in x] for y in x]))
    # Discard the two labels that are not tagged, and annotations with no text.
    annotations = [x for x in annotations if x[1] != 'family_hist.not_present' and x[1] != 'smoker.unknown']
    annotations = [x for x in annotations if x[0] != '']
    
    annotations = combine_annotations(annotations)
    annotations = [tokenise_annotation(x) for x in annotations]
    # Longest annotations first, so they are written into `tags` first and
    # shorter ones applied later overwrite them where spans overlap.
    annotations.sort(key=lambda x: len(x[0]), reverse=True)
    
    text = root.getElementsByTagName("TEXT")[0].firstChild.data
    text = word_tokenize(text.lower())

    # Match on the unstemmed tokens; stemming happens only after tagging.
    indices = [find_sublist(x[0], text) for x in annotations]
    tags = ['O' for x in text]
    annotate(tags, annotations, indices) 
    
    text = [stemmer.stem(x) for x in text]
    
    tagged_sent = list(zip(text, tags)) 

    # Cut long records into consecutive windows of `max_length` tokens.
    if len(tagged_sent) > max_length:
        ls_tagged_sent = [tagged_sent[i * max_length:(i + 1) * max_length] for i in range((len(tagged_sent) + max_length - 1) // max_length )]
        tagged_sents.extend(ls_tagged_sent)
    else:
        tagged_sents.append(tagged_sent)


HBox(children=(IntProgress(value=0, max=790), HTML(value='')))




In [7]:
# Sanity check: display the first sequence as (stemmed token, tag) pairs.
tagged_sents[0]

[('record', 'O'),
 ('date', 'O'),
 (':', 'O'),
 ('2067-05-03', 'O'),
 ('narrat', 'O'),
 ('histori', 'O'),
 ('55', 'O'),
 ('yo', 'O'),
 ('woman', 'O'),
 ('who', 'O'),
 ('present', 'O'),
 ('for', 'O'),
 ('f/u', 'O'),
 ('seen', 'O'),
 ('in', 'O'),
 ('cardiac', 'O'),
 ('rehab', 'O'),
 ('local', 'O'),
 ('last', 'O'),
 ('week', 'O'),
 ('and', 'O'),
 ('bp', 'I-hypertension.high_bp.before_dct'),
 ('170/80', 'I-hypertension.high_bp.before_dct'),
 ('.', 'I-hypertension.high_bp.before_dct'),
 ('they', 'O'),
 ('call', 'O'),
 ('us', 'O'),
 ('and', 'O'),
 ('we', 'O'),
 ('increas', 'O'),
 ('her', 'O'),
 ('hctz', 'I-medication.diuretic.continuing'),
 ('to', 'O'),
 ('25', 'O'),
 ('mg', 'O'),
 ('from', 'O'),
 ('12.5', 'O'),
 ('mg.', 'O'),
 ('state', 'O'),
 ('her', 'O'),
 ('bp', 'O'),
 ("'s", 'O'),
 ('were', 'O'),
 ('fine', 'O'),
 ('there', 'O'),
 ('sinc', 'O'),
 ('-', 'O'),
 ('130-140/70-80', 'O'),
 ('.', 'O'),
 ('saw', 'O'),
 ('dr', 'O'),
 ('oakley', 'O'),
 ('4/5/67', 'O'),
 ('-', 'O'),
 ('she', 'O'),


In [8]:
# Debug switch: uncomment to run on a tiny subset.
# tagged_sents = tagged_sents[:4]

# Sort the tag set so every tag receives the same index on every run; the
# iteration order of a set of strings is not stable across interpreter
# sessions, which would otherwise make tag2idx/idx2tag (and therefore the
# classifier's output layer) differ from run to run.
tags = sorted(set(word_pos[1] for sent in tagged_sents for word_pos in sent))

# By convention, the 0'th slot is reserved for padding.
tags = ["<pad>"] + tags

tag2idx = {tag:idx for idx, tag in enumerate(tags)}
idx2tag = {idx:tag for idx, tag in enumerate(tags)}

# Let's split the data into train and test (or eval)
from sklearn.model_selection import train_test_split
# Fixed seed so the same sequences land in the held-out set on every run.
SPLIT_SEED = 42
train_data, test_data = train_test_split(tagged_sents, test_size=.1, random_state=SPLIT_SEED)
# Printed explicitly: a bare expression in the middle of a cell is discarded.
print(len(train_data), len(test_data))

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

class PosDataset(data.Dataset):
    """Dataset of tagged token sequences encoded for a BERT token classifier.

    Relies on the module-level ``tokenizer`` (word -> word pieces -> ids) and
    ``tag2idx`` (tag string -> integer id) being defined before items are read.
    """

    def __init__(self, tagged_sents):
        """Store each sequence wrapped in ``[CLS]`` / ``[SEP]``.

        Args:
            tagged_sents: iterable of sequences of ``(word, tag)`` pairs.
        """
        sents, tags_li = [], [] # list of lists
        for sent in tagged_sents:
            words = [word_pos[0] for word_pos in sent]
            tags = [word_pos[1] for word_pos in sent]
            sents.append(["[CLS]"] + words + ["[SEP]"])
            # The two special tokens carry the "<pad>" (no decision) tag.
            tags_li.append(["<pad>"] + tags + ["<pad>"])
        self.sents, self.tags_li = sents, tags_li

    def __len__(self):
        """Number of stored sequences."""
        return len(self.sents)

    def __getitem__(self, idx):
        """Encode sequence ``idx``.

        Returns:
            ``(words, x, is_heads, tags, y, seqlen)`` where ``words`` and
            ``tags`` are space-joined strings, ``x`` is the list of word-piece
            ids, ``is_heads`` flags the first piece of every word with 1,
            ``y`` is the list of tag ids aligned with ``x`` (non-first pieces
            get the ``"<pad>"`` id), and ``seqlen`` is ``len(y)``.
        """
        words, tags = self.sents[idx], self.tags_li[idx] # words, tags: string list

        # We give credits only to the first piece.
        x, y = [], [] # list of ids
        is_heads = [] # list. 1: the token is the first piece of a word
        for w, t in zip(words, tags):
            tokens = tokenizer.tokenize(w) if w not in ("[CLS]", "[SEP]") else [w]
            if not tokens:
                # The tokenizer produced no pieces for this word. Without a
                # placeholder the word would still contribute one label and one
                # head flag but no input id, breaking the alignment asserted
                # below.
                tokens = ["[UNK]"]
            xx = tokenizer.convert_tokens_to_ids(tokens)

            is_head = [1] + [0]*(len(tokens) - 1)

            t = [t] + ["<pad>"] * (len(tokens) - 1)  # <PAD>: no decision
            yy = [tag2idx[each] for each in t]  # (T,)

            x.extend(xx)
            is_heads.extend(is_head)
            y.extend(yy)

        assert len(x)==len(y)==len(is_heads), "len(x)={}, len(y)={}, len(is_heads)={}".format(len(x), len(y), len(is_heads))

        # seqlen
        seqlen = len(y)

        # to string
        words = " ".join(words)
        tags = " ".join(tags)
        return words, x, is_heads, tags, y, seqlen

In [9]:
def pad(batch):
    '''Collate samples into a batch, right-padding ids to the longest sample.

    Each sample is a tuple whose positions are
    (words, x, is_heads, tags, y, seqlen). ``x`` (position 1) and ``y``
    (position -2) are padded with 0 -- the "<pad>" id -- and returned as
    ``torch.LongTensor``; the other fields are returned as plain lists.
    '''
    def column(position):
        # The same field gathered from every sample in the batch.
        return [sample[position] for sample in batch]

    words = column(0)
    is_heads = column(2)
    tags = column(3)
    seqlens = column(-1)
    maxlen = np.array(seqlens).max()

    def padded(position):
        # 0: <pad>
        return [sample[position] + [0] * (maxlen - len(sample[position])) for sample in batch]

    x = torch.LongTensor(padded(1))
    y = torch.LongTensor(padded(-2))
    return words, x, is_heads, tags, y, seqlens

In [11]:
from pytorch_pretrained_bert import BertModel

class Net(nn.Module):
    """BERT encoder followed by a linear layer that scores every tag per token."""

    def __init__(self, vocab_size=None):
        """Load the pretrained encoder and create the output layer.

        Args:
            vocab_size: number of output classes of the linear layer.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')

        # 768 is the input width of the classifier; presumably the hidden size
        # of the loaded encoder -- confirm if the checkpoint is changed.
        self.fc = nn.Linear(768, vocab_size)
        self.device = device

    def _last_encoded_layer(self, x):
        """Run the encoder and return the last of the layers it reports."""
        encoded_layers, _ = self.bert(x)
        return encoded_layers[-1]

    def forward(self, x, y):
        '''
        x: (N, T). int64
        y: (N, T). int64

        Returns (logits, y, y_hat): the per-token class scores, the labels
        moved to the module-level device, and the argmax of the scores.
        '''
        x = x.to(device)
        y = y.to(device)

        if self.training:
            self.bert.train()
            enc = self._last_encoded_layer(x)
        else:
            # Outside training the encoder runs without gradient tracking.
            self.bert.eval()
            with torch.no_grad():
                enc = self._last_encoded_layer(x)

        logits = self.fc(enc)
        y_hat = logits.argmax(-1)
        return logits, y, y_hat

In [12]:
def train(model, iterator, optimizer, criterion):
    """Run one pass over ``iterator``, updating ``model`` after every batch.

    Args:
        model: callable as ``model(x, y)`` returning ``(logits, y, y_hat)``.
        iterator: yields batches of ``(words, x, is_heads, tags, y, seqlens)``.
        optimizer: optimizer over the model's parameters.
        criterion: loss taking flattened logits ``(N*T, C)`` and labels ``(N*T,)``.

    The loss is printed on every 10th step, starting with step 0.
    """
    model.train()
    for step, (_words, x, _is_heads, _tags, y, _seqlens) in enumerate(iterator):
        optimizer.zero_grad()

        # logits: (N, T, C); targets: the labels as returned by the model, (N, T)
        logits, targets, _ = model(x, y)
        num_classes = logits.shape[-1]
        flat_logits = logits.view(-1, num_classes)  # (N*T, C)
        flat_targets = targets.view(-1)  # (N*T,)

        loss = criterion(flat_logits, flat_targets)
        loss.backward()
        optimizer.step()

        if step % 10 == 0: # monitoring
            print("step: {}, loss: {}".format(step, loss.item()))

def eval(model, iterator):
    """Evaluate ``model`` on ``iterator``, write predictions, print accuracy.

    For every word (special first/last tokens excluded) one line
    ``"<word> <true tag> <predicted tag>"`` is written to the file ``result``
    in the current directory, with a blank line after each sequence. Only the
    prediction at the first word piece of each word is used. Accuracy over all
    written words is printed as ``acc=<value>`` (``nan`` when nothing was
    evaluated).

    Note: this function shadows the built-in ``eval`` within this notebook.

    Args:
        model: callable as ``model(x, y)`` returning ``(logits, y, y_hat)``
            with ``y_hat`` a tensor of predicted tag ids, shape (N, T).
        iterator: yields batches of ``(words, x, is_heads, tags, y, seqlens)``.
    """
    model.eval()

    Words, Is_heads, Tags, Y_hat = [], [], [], []
    with torch.no_grad():
        for batch in iterator:
            words, x, is_heads, tags, y, seqlens = batch

            _, _, y_hat = model(x, y)  # y_hat: (N, T)

            Words.extend(words)
            Is_heads.extend(is_heads)
            Tags.extend(tags)
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    ## gets results and save; accuracy is tallied while writing, so the file
    ## does not have to be reopened and parsed again afterwards
    n_correct, n_total = 0, 0
    with open("result", 'w') as fout:
        for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
            # Keep the prediction of each word's first piece only; zip also
            # drops the padded tail because is_heads is unpadded.
            y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
            preds = [idx2tag[hat] for hat in y_hat]
            assert len(preds)==len(words.split())==len(tags.split())
            # [1:-1] skips the [CLS] / [SEP] positions.
            for w, t, p in zip(words.split()[1:-1], tags.split()[1:-1], preds[1:-1]):
                fout.write("{} {} {}\n".format(w, t, p))
                n_total += 1
                n_correct += int(t == p)
            fout.write("\n")

    ## calc metric
    acc = n_correct / n_total if n_total else float("nan")

    print("acc=%.2f"%acc)

In [13]:
# One output class per entry of tag2idx; the network is moved to the selected
# device and then wrapped for data-parallel execution.
model = nn.DataParallel(Net(vocab_size=len(tag2idx)).to(device))

train_dataset = PosDataset(train_data)
eval_dataset = PosDataset(test_data)

# Options common to both loaders; `pad` collates samples into padded tensors.
shared_loader_options = dict(num_workers=1, collate_fn=pad)
train_iter = data.DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, **shared_loader_options)
test_iter = data.DataLoader(dataset=eval_dataset, batch_size=8, shuffle=False, **shared_loader_options)

optimizer = optim.Adam(model.parameters(), lr = 0.0001)

# Label id 0 is "<pad>", so padded and non-head positions add nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [14]:
# One pass over the training loader, then evaluation on the held-out loader
# (writes the per-word predictions to the file "result" and prints accuracy).
train(model, train_iter, optimizer, criterion)
eval(model, test_iter)

step: 0, loss: 4.719751358032227
step: 10, loss: 0.31458622217178345
step: 20, loss: 0.357718825340271
step: 30, loss: 0.5714498162269592
step: 40, loss: 1.133885383605957
step: 50, loss: 0.35433152318000793
step: 60, loss: 0.2622257471084595
step: 70, loss: 0.2912757396697998
step: 80, loss: 0.2950582206249237
step: 90, loss: 0.30344244837760925
step: 100, loss: 0.24819862842559814
step: 110, loss: 0.11084044724702835
step: 120, loss: 0.27604228258132935
step: 130, loss: 0.30938076972961426
step: 140, loss: 0.2378695160150528
step: 150, loss: 0.22441233694553375
step: 160, loss: 0.25023359060287476
step: 170, loss: 0.12621469795703888
step: 180, loss: 0.26662611961364746
step: 190, loss: 0.20178253948688507
step: 200, loss: 0.08162254095077515
step: 210, loss: 0.19173061847686768
step: 220, loss: 0.23193714022636414
step: 230, loss: 0.1111311987042427
step: 240, loss: 0.3526153266429901
step: 250, loss: 0.17804056406021118
step: 260, loss: 0.31672778725624084
step: 270, loss: 0.060564