In [1]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification, pipeline
from datasets import ClassLabel, Sequence, load_dataset
import evaluate
import numpy as np
import pandas as pd
from spacy import displacy

Load pre-trained SciBERT model and tokenizer

In [2]:
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModelForTokenClassification.from_pretrained("allenai/scibert_scivocab_uncased", num_labels=3)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initi

Load dataset and split dataset into training and validation sets

In [3]:
datasets = load_dataset("ade_corpus_v2", "Ade_corpus_v2_drug_ade_relation")

Found cached dataset ade_corpus_v2 (/home/linuxmint/.cache/huggingface/datasets/ade_corpus_v2/Ade_corpus_v2_drug_ade_relation/1.0.0/940d61334dbfac6b01ac5d00286a2122608b8dc79706ee7e9206a1edb172c559)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
datasets    

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'effect', 'indexes'],
        num_rows: 6821
    })
})

In [5]:
datasets["train"][0]

{'text': 'Intravenous azithromycin-induced ototoxicity.',
 'drug': 'azithromycin',
 'effect': 'ototoxicity',
 'indexes': {'drug': {'start_char': [12], 'end_char': [24]},
  'effect': {'start_char': [33], 'end_char': [44]}}}

In [6]:
consolidated_dataset = {}

for row in datasets["train"]:
    if row["text"] in consolidated_dataset:
        consolidated_dataset[row["text"]]["drug_indices_start"].update(row["indexes"]["drug"]["start_char"])
        consolidated_dataset[row["text"]]["drug_indices_end"].update(row["indexes"]["drug"]["end_char"])
        
    else:
        consolidated_dataset[row["text"]] = {
            "text": row["text"],
            "drug": [row["drug"]],
            # use sets because the indices can repeat for various reasons
            "drug_indices_start": set(row["indexes"]["drug"]["start_char"]),
            "drug_indices_end": set(row["indexes"]["drug"]["end_char"])
        }

df = pd.DataFrame(list(consolidated_dataset.values()))
# for this trial use small subset
df = df[:500]
df.head()

Unnamed: 0,text,drug,drug_indices_start,drug_indices_end
0,Intravenous azithromycin-induced ototoxicity.,[azithromycin],{12},{24}
1,"Immobilization, while Paget's bone disease was...",[dihydrotachysterol],{91},{109}
2,Unaccountable severe hypercalcemia in a patien...,[dihydrotachysterol],{84},{102}
3,METHODS: We report two cases of pseudoporphyri...,[naproxen],"{58, 71}","{80, 66}"
4,"Naproxen, the most common offender, has been a...",[Naproxen],{0},{8}


In [7]:
df["drug_indices_start"] = df["drug_indices_start"].apply(list).apply(sorted)
df["drug_indices_end"] = df["drug_indices_end"].apply(list).apply(sorted)
df.head()

Unnamed: 0,text,drug,drug_indices_start,drug_indices_end
0,Intravenous azithromycin-induced ototoxicity.,[azithromycin],[12],[24]
1,"Immobilization, while Paget's bone disease was...",[dihydrotachysterol],[91],[109]
2,Unaccountable severe hypercalcemia in a patien...,[dihydrotachysterol],[84],[102]
3,METHODS: We report two cases of pseudoporphyri...,[naproxen],"[58, 71]","[66, 80]"
4,"Naproxen, the most common offender, has been a...",[Naproxen],[0],[8]


In [8]:
# save to JSON to then import into Dataset object
df.to_json("dataset.jsonl", orient="records", lines=True)

cons_dataset = load_dataset("json", data_files="dataset.jsonl")
cons_dataset = cons_dataset["train"].train_test_split()
cons_dataset

Downloading and preparing dataset json/default to /home/linuxmint/.cache/huggingface/datasets/json/default-33347ae98205c6fd/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /home/linuxmint/.cache/huggingface/datasets/json/default-33347ae98205c6fd/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'drug_indices_start', 'drug_indices_end'],
        num_rows: 375
    })
    test: Dataset({
        features: ['text', 'drug', 'drug_indices_start', 'drug_indices_end'],
        num_rows: 125
    })
})

Token Labeling

O - outside any entity we care about

B-DRUG - the beginning of a DRUG entity

I-DRUG - inside a DRUG entity

In [9]:
label_list = ['O', 'B-DRUG', 'I-DRUG']

custom_seq = Sequence(feature=ClassLabel(num_classes=3, 
                                         names=label_list,
                                         names_file=None, id=None), length=-1, id=None)

cons_dataset["train"].features["ner_tags"] = custom_seq
cons_dataset["test"].features["ner_tags"] = custom_seq

In [10]:
def generate_row_labels(row, verbose=False):
    """ Given a row from the consolidated `Ade_corpus_v2_drug_ade_relation` dataset, 
    generates BIO tags for drug and effect entities. 
    
    """

    text = row["text"]

    labels = []
    label = "O"
    prefix = ""
    
    # while iterating through tokens, increment to traverse all drug and effect spans
    drug_index = 0
    
    tokens = tokenizer(text, return_offsets_mapping=True)

    for n in range(len(tokens["input_ids"])):
        offset_start, offset_end = tokens["offset_mapping"][n]

        # should only happen for [CLS] and [SEP]
        if offset_end - offset_start == 0:
            labels.append(-100)
            continue
        
        if drug_index < len(row["drug_indices_start"]) and offset_start == row["drug_indices_start"][drug_index]:
            label = "DRUG"
            prefix = "B-"
        
        labels.append(label_list.index(f"{prefix}{label}"))
            
        if drug_index < len(row["drug_indices_end"]) and offset_end == row["drug_indices_end"][drug_index]:
            label = "O"
            prefix = ""
            drug_index += 1

        # need to transition "inside" if we just entered an entity
        if prefix == "B-":
            prefix = "I-"
    
    if verbose:
        print(f"{row}\n")
        orig = tokenizer.convert_ids_to_tokens(tokens["input_ids"])
        for n in range(len(labels)):
            print(orig[n], labels[n])
    tokens["labels"] = labels
    
    return tokens

In [11]:
generate_row_labels(cons_dataset["train"][2], verbose=True)

{'text': 'In some cases this seems to happen because spironolactone causes diarrhoea.', 'drug': ['spironolactone'], 'drug_indices_start': [43], 'drug_indices_end': [57]}

[CLS] -100
in 0
some 0
cases 0
this 0
seems 0
to 0
happen 0
because 0
spiro 1
##no 2
##lact 2
##one 2
causes 0
diarrhoea 0
. 0
[SEP] -100


{'input_ids': [102, 121, 693, 1299, 238, 4109, 147, 12535, 923, 28968, 2682, 15438, 574, 4080, 27596, 205, 103], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 2), (3, 7), (8, 13), (14, 18), (19, 24), (25, 27), (28, 34), (35, 42), (43, 48), (48, 50), (50, 54), (54, 57), (58, 64), (65, 74), (74, 75), (0, 0)], 'labels': [-100, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, -100]}

In [12]:
labeled_dataset = cons_dataset.map(generate_row_labels)

Map:   0%|          | 0/375 [00:00<?, ? examples/s]

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Fine-tuning

In [13]:
task = "ner" # Should be one of "ner", "pos" or "chunk"
model_checkpoint = "allenai/scibert_scivocab_uncased"
batch_size = 16

In [14]:
model_name = model_checkpoint.split("/")[-1]
args = TrainingArguments(
    f"{model_name}-finetuned-{task}",
    evaluation_strategy = "epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.05,
    logging_steps=1
)

In [15]:
data_collator = DataCollatorForTokenClassification(tokenizer)

In [16]:
metric = evaluate.load("seqeval")

In [17]:
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [18]:
trainer = Trainer(
    model,
    args,
    train_dataset=labeled_dataset["train"],
    eval_dataset=labeled_dataset["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics, 

)

In [19]:
trainer.train()



  0%|          | 0/120 [00:00<?, ?it/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 1.1622, 'learning_rate': 9.916666666666668e-06, 'epoch': 0.04}
{'loss': 0.956, 'learning_rate': 9.833333333333333e-06, 'epoch': 0.08}
{'loss': 0.7715, 'learning_rate': 9.75e-06, 'epoch': 0.12}
{'loss': 0.6826, 'learning_rate': 9.666666666666667e-06, 'epoch': 0.17}
{'loss': 0.5312, 'learning_rate': 9.583333333333335e-06, 'epoch': 0.21}
{'loss': 0.5858, 'learning_rate': 9.5e-06, 'epoch': 0.25}
{'loss': 0.4812, 'learning_rate': 9.416666666666667e-06, 'epoch': 0.29}
{'loss': 0.4224, 'learning_rate': 9.333333333333334e-06, 'epoch': 0.33}
{'loss': 0.3507, 'learning_rate': 9.250000000000001e-06, 'epoch': 0.38}
{'loss': 0.4277, 'learning_rate': 9.166666666666666e-06, 'epoch': 0.42}
{'loss': 0.3596, 'learning_rate': 9.083333333333333e-06, 'epoch': 0.46}
{'loss': 0.3445, 'learning_rate': 9e-06, 'epoch': 0.5}
{'loss': 0.2788, 'learning_rate': 8.916666666666667e-06, 'epoch': 0.54}
{'loss': 0.3014, 'learning_rate': 8.833333333333334e-06, 'epoch': 0.58}
{'loss': 0.304, 'learning_rate': 8.75

  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.15661506354808807, 'eval_precision': 0.6413793103448275, 'eval_recall': 0.636986301369863, 'eval_f1': 0.6391752577319587, 'eval_accuracy': 0.9452599388379205, 'eval_runtime': 87.8638, 'eval_samples_per_second': 1.423, 'eval_steps_per_second': 0.091, 'epoch': 1.0}
{'loss': 0.1228, 'learning_rate': 7.916666666666667e-06, 'epoch': 1.04}
{'loss': 0.145, 'learning_rate': 7.833333333333333e-06, 'epoch': 1.08}
{'loss': 0.1777, 'learning_rate': 7.75e-06, 'epoch': 1.12}
{'loss': 0.1058, 'learning_rate': 7.666666666666667e-06, 'epoch': 1.17}
{'loss': 0.135, 'learning_rate': 7.583333333333333e-06, 'epoch': 1.21}
{'loss': 0.1531, 'learning_rate': 7.500000000000001e-06, 'epoch': 1.25}
{'loss': 0.1542, 'learning_rate': 7.416666666666668e-06, 'epoch': 1.29}
{'loss': 0.1007, 'learning_rate': 7.333333333333333e-06, 'epoch': 1.33}
{'loss': 0.0992, 'learning_rate': 7.25e-06, 'epoch': 1.38}
{'loss': 0.131, 'learning_rate': 7.166666666666667e-06, 'epoch': 1.42}
{'loss': 0.1051, 'learning_ra

  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.08330574631690979, 'eval_precision': 0.7810650887573964, 'eval_recall': 0.9041095890410958, 'eval_f1': 0.838095238095238, 'eval_accuracy': 0.9755351681957186, 'eval_runtime': 79.4681, 'eval_samples_per_second': 1.573, 'eval_steps_per_second': 0.101, 'epoch': 2.0}
{'loss': 0.0362, 'learning_rate': 5.916666666666667e-06, 'epoch': 2.04}
{'loss': 0.0509, 'learning_rate': 5.833333333333334e-06, 'epoch': 2.08}
{'loss': 0.0458, 'learning_rate': 5.75e-06, 'epoch': 2.12}
{'loss': 0.079, 'learning_rate': 5.666666666666667e-06, 'epoch': 2.17}
{'loss': 0.0396, 'learning_rate': 5.583333333333334e-06, 'epoch': 2.21}
{'loss': 0.0569, 'learning_rate': 5.500000000000001e-06, 'epoch': 2.25}
{'loss': 0.0489, 'learning_rate': 5.416666666666667e-06, 'epoch': 2.29}
{'loss': 0.0696, 'learning_rate': 5.333333333333334e-06, 'epoch': 2.33}
{'loss': 0.043, 'learning_rate': 5.2500000000000006e-06, 'epoch': 2.38}
{'loss': 0.0619, 'learning_rate': 5.1666666666666675e-06, 'epoch': 2.42}
{'loss': 0.03

  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.05714305490255356, 'eval_precision': 0.8535031847133758, 'eval_recall': 0.9178082191780822, 'eval_f1': 0.8844884488448845, 'eval_accuracy': 0.9807339449541285, 'eval_runtime': 70.3165, 'eval_samples_per_second': 1.778, 'eval_steps_per_second': 0.114, 'epoch': 3.0}
{'loss': 0.055, 'learning_rate': 3.916666666666667e-06, 'epoch': 3.04}
{'loss': 0.0454, 'learning_rate': 3.833333333333334e-06, 'epoch': 3.08}
{'loss': 0.0297, 'learning_rate': 3.7500000000000005e-06, 'epoch': 3.12}
{'loss': 0.025, 'learning_rate': 3.6666666666666666e-06, 'epoch': 3.17}
{'loss': 0.0388, 'learning_rate': 3.5833333333333335e-06, 'epoch': 3.21}
{'loss': 0.0278, 'learning_rate': 3.5e-06, 'epoch': 3.25}
{'loss': 0.0293, 'learning_rate': 3.416666666666667e-06, 'epoch': 3.29}
{'loss': 0.0291, 'learning_rate': 3.3333333333333333e-06, 'epoch': 3.33}
{'loss': 0.0612, 'learning_rate': 3.2500000000000002e-06, 'epoch': 3.38}
{'loss': 0.0202, 'learning_rate': 3.1666666666666667e-06, 'epoch': 3.42}
{'loss': 

  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.0610656663775444, 'eval_precision': 0.7976878612716763, 'eval_recall': 0.9452054794520548, 'eval_f1': 0.865203761755486, 'eval_accuracy': 0.9782874617737003, 'eval_runtime': 85.8129, 'eval_samples_per_second': 1.457, 'eval_steps_per_second': 0.093, 'epoch': 4.0}
{'loss': 0.0591, 'learning_rate': 1.916666666666667e-06, 'epoch': 4.04}
{'loss': 0.0297, 'learning_rate': 1.8333333333333333e-06, 'epoch': 4.08}
{'loss': 0.017, 'learning_rate': 1.75e-06, 'epoch': 4.12}
{'loss': 0.0321, 'learning_rate': 1.6666666666666667e-06, 'epoch': 4.17}
{'loss': 0.0363, 'learning_rate': 1.5833333333333333e-06, 'epoch': 4.21}
{'loss': 0.037, 'learning_rate': 1.5e-06, 'epoch': 4.25}
{'loss': 0.1067, 'learning_rate': 1.4166666666666667e-06, 'epoch': 4.29}
{'loss': 0.0141, 'learning_rate': 1.3333333333333334e-06, 'epoch': 4.33}
{'loss': 0.0162, 'learning_rate': 1.25e-06, 'epoch': 4.38}
{'loss': 0.0136, 'learning_rate': 1.1666666666666668e-06, 'epoch': 4.42}
{'loss': 0.0822, 'learning_rate': 1.0

  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.051752250641584396, 'eval_precision': 0.8466257668711656, 'eval_recall': 0.9452054794520548, 'eval_f1': 0.8932038834951457, 'eval_accuracy': 0.982262996941896, 'eval_runtime': 69.9311, 'eval_samples_per_second': 1.787, 'eval_steps_per_second': 0.114, 'epoch': 5.0}
{'train_runtime': 4418.5339, 'train_samples_per_second': 0.424, 'train_steps_per_second': 0.027, 'train_loss': 0.1257776054398467, 'epoch': 5.0}


TrainOutput(global_step=120, training_loss=0.1257776054398467, metrics={'train_runtime': 4418.5339, 'train_samples_per_second': 0.424, 'train_steps_per_second': 0.027, 'train_loss': 0.1257776054398467, 'epoch': 5.0})

In [20]:
predictions, labels, _ = trainer.predict(labeled_dataset["test"])
predictions = np.argmax(predictions, axis=2)

# Remove ignored index (special tokens)
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

  0%|          | 0/8 [00:00<?, ?it/s]

{'DRUG': {'precision': 0.8466257668711656,
  'recall': 0.9452054794520548,
  'f1': 0.8932038834951457,
  'number': 146},
 'overall_precision': 0.8466257668711656,
 'overall_recall': 0.9452054794520548,
 'overall_f1': 0.8932038834951457,
 'overall_accuracy': 0.982262996941896}

In [21]:
effect_ner_model = pipeline(task="ner", model=model, tokenizer=tokenizer, device=-1)

In [22]:
def visualize_entities(sentence):
    tokens = effect_ner_model(sentence)
    entities = []
    
    for token in tokens:
        label = int(token["entity"][-1])
        if label != 0:
            token["label"] = label_list[label]
            entities.append(token)
    
    params = [{"text": sentence,
               "ents": entities,
               "title": None}]
    
    html = displacy.render(params, style="ent", manual=True, options={
        "colors": {
                   "B-DRUG": "#f08080",
                   "I-DRUG": "#f08080",
               },
    })
    

In [25]:
#examples = [
#    "Abortion, miscarriage or uterine hemorrhage associated with misoprostol (Cytotec), a labor-inducing drug.",
#    "Addiction to many sedatives and analgesics, such as diazepam, morphine, etc.",
#    "Birth defects associated with thalidomide",
#    "Bleeding of the intestine associated with aspirin therapy",
#    "Cardiovascular disease associated with COX-2 inhibitors (i.e. Vioxx)",
#    "Deafness and kidney failure associated with gentamicin (an antibiotic)"
#]

# for example in examples:
#     visualize_entities(example)
#     print(f"{'*' * 50}\n")

with open("Albers.txt", "r", encoding="utf8") as file:
    examples = file.readlines()
for example in examples:
    if example != "":
        visualize_entities(example)
        print(f"{'*' * 50}\n")

**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************

