In [1]:
import os
import sys
sys.path.append("../")
import json
from argparse import ArgumentParser
from datasets import Dataset
from transformers import DataCollatorForTokenClassification, BertForTokenClassification, BertTokenizer 
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
from pprint import pprint
import pandas as pd

import spacy
from spacy import displacy


def load_data(path):
    data = pd.read_csv(path, header=None, delimiter='\t')
    data.columns = ['sent_id', 'text', 'label']
    data = data.groupby('sent_id').agg(list).reset_index()
    data = [(row.text, row.label) for row in data.itertuples()]
    return data

def tokenize_with_labels(tokenizer, sent_words, sent_labels, special_label):
    tok_sent = []
    labels = []
    for word, label in zip(sent_words, sent_labels):
        if type(word) == str:
            tok_word = tokenizer.tokenize(word)
            n_subwords = len(tok_word)

            tok_sent.extend(tok_word)
            labels.extend([label] * n_subwords)
    
    # Add special tokens
    if tok_sent[0] != '[CLS]':
        tok_sent.insert(0, '[CLS]')
        labels.insert(0, special_label)
    if tok_sent[-1] != '[SEP]':
        if tok_sent[-1] not in '.!?;':
            tok_sent.append('.')
            labels.append('O')
        tok_sent.append('[SEP]')
        labels.append(special_label)

    return tok_sent, labels

In [21]:
# raw_train_data = load_data("../data/ner_data_formatted/train.tsv")
# raw_test_data = load_data("../data/ner_data_formatted/test.tsv")
# raw_train_data = [(i, j) for i, j in raw_train_data if len(i) > 2 and not all(k=="O" for k in j)]
# raw_test_data = [(i, j) for i, j in raw_test_data if len(i) > 2 and not all(k=="O" for k in j)]

tokenizer = BertTokenizer(vocab_file="/anvil/projects/tdm/corporate/battelle-nl/ADE_NER_2023-02-04_4/vocab.txt", do_lower_case=False)

# # Tokenize data
# train_data = [tokenize_with_labels(tokenizer, i, j, '[PAD]') for i, j in raw_train_data if len(i) > 2]
# test_data = [tokenize_with_labels(tokenizer, i, j, '[PAD]') for i, j in raw_test_data if len(i) > 2]
# train_sents, train_labels = zip(*train_data)
# test_sents, test_labels = zip(*test_data)

# print("Labels:")
# pprint(set([l for sent in train_labels for l in sent])) 

labels = ['B-ADE',
    'B-Dosage',
    'B-Drug',
    'B-Duration',
    'B-Form',
    'B-Frequency',
    'B-Reason',
    'B-Route',
    'B-Strength',
    'I-ADE',
    'I-Dosage',
    'I-Drug',
    'I-Duration',
    'I-Form',
    'I-Frequency',
    'I-Reason',
    'I-Route',
    'I-Strength',
    'L-ADE',
    'L-Dosage',
    'L-Drug',
    'L-Duration',
    'L-Form',
    'L-Frequency',
    'L-Reason',
    'L-Route',
    'L-Strength',
    'O',
    'U-ADE',
    'U-Dosage',
    'U-Drug',
    'U-Duration',
    'U-Form',
    'U-Frequency',
    'U-Reason',
    'U-Route',
    'U-Strength',
    '[PAD]']



print("Loading pipeline")
nlp = pipeline("ner", model="/anvil/projects/tdm/corporate/battelle-nl/ADE_NER_2023-02-04_4", tokenizer="bert-base-cased")


Loading pipeline


In [97]:
print("Loading test.txt")

# display a menu to pick a file to process
files = os.listdir("../data/ner_data_formatted/txt/")
file = files[109]

text = ""
lines = []
processedlines = []
with open(os.path.join("../data/ner_data_formatted/txt/", file)) as f:
    for line in f:
        text += line
        lines.append(line)
        processedlines.append(nlp(line))

zipped = zip(lines, processedlines)

Loading test.txt


In [98]:
#process syllables
combinedlines = []

currlen = 0
for item in zipped:
    processed = item[1]
    combined = []
    for i in range(len(processed)):
        if processed[i]["word"].startswith('##'):
            continue
        # Otherwise, combine it with the next string if it starts with "##"
        word = processed[i]["word"]
        start = processed[i]["start"]
        end = processed[i]["end"]
        for j in range(i+1, len(processed)):
            if processed[j]["word"].startswith('##'):
                word += processed[j]["word"][2:]
                end = processed[j]["end"]
            else:
                break
        # consider the previous end
        combined.append({"word": word, "entity": processed[i]["entity"], "start": start + currlen, "end": end + currlen})
    currlen += len(item[0])
    combinedlines.extend(combined)





for i in range(len(combinedlines)):
    # example output: {'end': None, 'entity': 'LABEL_27', 'index': 131, 'score': 0.9996716, 'start': None, 'word': 'and'}
    combinedlines[i]["label"] = labels[int(combinedlines[i]["entity"].split("_")[1])]
    # if i == 0:
    #     combined[i]["start"] = 0
    #     combined[i]["end"] = len(combined[i]["word"])
    # else:
    #     combined[i]["start"] = combined[i-1]["end"] + 2
    #     combined[i]["end"] = combined[i]["start"] + len(combined[i]["word"])

pprint(combinedlines[102:150])
    
# Generate the visualization using displacy module
# options = {"ents": labels}
options = {"ents": [ent for ent in labels if ent != "O"]}

print("len of ents")
print(len(options["ents"]))
# doc = {"text": text, "ents": [{"start": i["start"], "end": i["end"], "label": i["label"]} for i in combinedlines]}
doc = {"text": text, "ents": [{"start": i["start"], "end": i["end"], "label": i["label"]} for i in combinedlines if i["label"] != "O"]}



[{'end': 417,
  'entity': 'LABEL_27',
  'label': 'O',
  'start': 411,
  'word': 'cancer'},
 {'end': 418, 'entity': 'LABEL_27', 'label': 'O', 'start': 417, 'word': ','},
 {'end': 426,
  'entity': 'LABEL_34',
  'label': 'U-Reason',
  'start': 419,
  'word': 'colitis'},
 {'end': 430, 'entity': 'LABEL_27', 'label': 'O', 'start': 427, 'word': 'NOS'},
 {'end': 441,
  'entity': 'LABEL_27',
  'label': 'O',
  'start': 431,
  'word': 'presenting'},
 {'end': 446, 'entity': 'LABEL_27', 'label': 'O', 'start': 442, 'word': 'with'},
 {'end': 453,
  'entity': 'LABEL_28',
  'label': 'U-ADE',
  'start': 447,
  'word': 'melana'},
 {'end': 454, 'entity': 'LABEL_27', 'label': 'O', 'start': 453, 'word': '.'},
 {'end': 463,
  'entity': 'LABEL_27',
  'label': 'O',
  'start': 456,
  'word': 'Patient'},
 {'end': 467, 'entity': 'LABEL_27', 'label': 'O', 'start': 464, 'word': 'was'},
 {'end': 470, 'entity': 'LABEL_27', 'label': 'O', 'start': 468, 'word': 'in'},
 {'end': 474, 'entity': 'LABEL_27', 'label': 'O', 's

In [99]:
displacy.render(doc, style="ent", options=options, manual=True)