In [18]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification, pipeline
from datasets import ClassLabel, Sequence, load_dataset
import evaluate
import numpy as np
import pandas as pd
from spacy import displacy

Load pre-trained SciBERT model and tokenizer

In [19]:
# Load the SciBERT tokenizer and the SciBERT encoder with a new token-classification head.
# num_labels=3 gives one output class per tag used in this notebook (O, B-DRUG, I-DRUG).
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModelForTokenClassification.from_pretrained("allenai/scibert_scivocab_uncased", num_labels=3)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initi

Load the dataset, merge rows that share the same sentence, and split the result into training and test sets

In [20]:
# Load the drug/adverse-effect relation configuration of the ADE corpus (a single "train" split).
datasets = load_dataset("ade_corpus_v2", "Ade_corpus_v2_drug_ade_relation")

Found cached dataset ade_corpus_v2 (C:/Users/albbl/.cache/huggingface/datasets/ade_corpus_v2/Ade_corpus_v2_drug_ade_relation/1.0.0/940d61334dbfac6b01ac5d00286a2122608b8dc79706ee7e9206a1edb172c559)
100%|██████████| 1/1 [00:00<00:00, 1000.31it/s]


In [21]:
# Show the available splits, their columns and row counts.
datasets    

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'effect', 'indexes'],
        num_rows: 6821
    })
})

In [22]:
# Inspect one raw record: the sentence, one drug/effect pair and their character offsets.
datasets["train"][0]

{'text': 'Intravenous azithromycin-induced ototoxicity.',
 'drug': 'azithromycin',
 'effect': 'ototoxicity',
 'indexes': {'drug': {'start_char': [12], 'end_char': [24]},
  'effect': {'start_char': [33], 'end_char': [44]}}}

In [23]:
# The raw corpus has one row per (drug, effect) relation, so the same sentence
# appears several times. Merge those rows into one record per sentence that
# carries every drug name and every drug character span.
consolidated_dataset = {}

for row in datasets["train"]:
    drug_span = row["indexes"]["drug"]
    entry = consolidated_dataset.get(row["text"])

    if entry is None:
        consolidated_dataset[row["text"]] = {
            "text": row["text"],
            "drug": [row["drug"]],
            # use sets because the indices can repeat for various reasons
            "drug_indices_start": set(drug_span["start_char"]),
            "drug_indices_end": set(drug_span["end_char"]),
        }
        continue

    # Keep the name list in step with the spans: a repeated sentence may
    # mention a drug that was not seen in its first row.
    if row["drug"] not in entry["drug"]:
        entry["drug"].append(row["drug"])
    entry["drug_indices_start"].update(drug_span["start_char"])
    entry["drug_indices_end"].update(drug_span["end_char"])

df = pd.DataFrame(list(consolidated_dataset.values()))
# for this trial use small subset
# .copy() so later column assignments act on an independent frame, not a slice view
df = df[:50].copy()
df.head()

Unnamed: 0,text,drug,drug_indices_start,drug_indices_end
0,Intravenous azithromycin-induced ototoxicity.,[azithromycin],{12},{24}
1,"Immobilization, while Paget's bone disease was...",[dihydrotachysterol],{91},{109}
2,Unaccountable severe hypercalcemia in a patien...,[dihydrotachysterol],{84},{102}
3,METHODS: We report two cases of pseudoporphyri...,[naproxen],"{58, 71}","{80, 66}"
4,"Naproxen, the most common offender, has been a...",[Naproxen],{0},{8}


In [24]:
# Turn each index set into an ascending list so that the i-th start lines up
# with the i-th end when the rows are labelled later.
for index_column in ("drug_indices_start", "drug_indices_end"):
    df[index_column] = df[index_column].apply(sorted)
df.head()

Unnamed: 0,text,drug,drug_indices_start,drug_indices_end
0,Intravenous azithromycin-induced ototoxicity.,[azithromycin],[12],[24]
1,"Immobilization, while Paget's bone disease was...",[dihydrotachysterol],[91],[109]
2,Unaccountable severe hypercalcemia in a patien...,[dihydrotachysterol],[84],[102]
3,METHODS: We report two cases of pseudoporphyri...,[naproxen],"[58, 71]","[66, 80]"
4,"Naproxen, the most common offender, has been a...",[Naproxen],[0],[8]


In [25]:
# save to JSON to then import into Dataset object
df.to_json("./data/dataset.jsonl", orient="records", lines=True)

cons_dataset = load_dataset("json", data_files="./data/dataset.jsonl")
cons_dataset = cons_dataset["train"].train_test_split()
cons_dataset

Found cached dataset json (C:/Users/albbl/.cache/huggingface/datasets/json/default-a5cef1f4c130dd6b/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████| 1/1 [00:00<00:00, 478.75it/s]


DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'drug_indices_start', 'drug_indices_end'],
        num_rows: 375
    })
    test: Dataset({
        features: ['text', 'drug', 'drug_indices_start', 'drug_indices_end'],
        num_rows: 125
    })
})

Token Labeling

O - outside any entity we care about

B-DRUG - the beginning of a DRUG entity

I-DRUG - inside a DRUG entity

In [26]:
# Tag vocabulary; a tag's position in this list is its integer class id.
label_list = ['O', 'B-DRUG', 'I-DRUG']

# Variable-length sequence feature whose items are class labels from label_list.
custom_seq = Sequence(feature=ClassLabel(num_classes=len(label_list), names=label_list))

# Register the feature description under "ner_tags" on both splits.
for split_name in ("train", "test"):
    cons_dataset[split_name].features["ner_tags"] = custom_seq

In [27]:
def generate_row_labels(row, verbose=False):
    """Tokenize a consolidated ADE row and attach BIO tag ids for its drug spans.

    Given a row from the consolidated `Ade_corpus_v2_drug_ade_relation` dataset,
    every token whose character range overlaps a drug span is tagged: the first
    token of a span gets B-DRUG, further tokens of the same span get I-DRUG and
    all remaining tokens get O. Tokens with an empty character range (the
    special tokens such as [CLS] and [SEP]) get -100.

    Spans are matched by overlap rather than by exact boundary equality, so a
    span that starts or ends in the middle of a token is still labelled and
    cannot cause later spans to be skipped or the tag to stay "open".

    Args:
        row: mapping with "text" and the parallel, ascending lists
            "drug_indices_start" / "drug_indices_end" (character offsets,
            end exclusive).
        verbose: when true, print the row followed by one "token label" line
            per token.

    Returns:
        The tokenizer output for ``row["text"]`` (including "offset_mapping")
        with an extra "labels" entry holding one id per token.
    """

    text = row["text"]

    # The i-th start belongs to the i-th end; sorting the pairs lets a single
    # forward-moving cursor walk the spans while the tokens are scanned.
    spans = sorted(zip(row["drug_indices_start"], row["drug_indices_end"]))

    outside_id = label_list.index("O")
    begin_id = label_list.index("B-DRUG")
    inside_id = label_list.index("I-DRUG")

    tokens = tokenizer(text, return_offsets_mapping=True)

    labels = []
    span_index = 0    # first span that has not been fully passed yet
    opened_span = -1  # index of the span that already received its B- tag

    for offset_start, offset_end in tokens["offset_mapping"]:
        # should only happen for [CLS] and [SEP]
        if offset_end - offset_start == 0:
            labels.append(-100)
            continue

        # Skip spans that end at or before the start of this token.
        while span_index < len(spans) and spans[span_index][1] <= offset_start:
            span_index += 1

        overlaps = span_index < len(spans) and spans[span_index][0] < offset_end
        if not overlaps:
            labels.append(outside_id)
        elif opened_span == span_index:
            labels.append(inside_id)
        else:
            labels.append(begin_id)
            opened_span = span_index

    if verbose:
        print(f"{row}\n")
        orig = tokenizer.convert_ids_to_tokens(tokens["input_ids"])
        for n in range(len(labels)):
            print(orig[n], labels[n])
    tokens["labels"] = labels

    return tokens

In [28]:
# Spot-check the labelling on one training row; verbose=True prints each token next to its label id.
generate_row_labels(cons_dataset["train"][2], verbose=True)

{'text': 'Of the four patients who responded to HU with an increase in total Hb, all reported symptomatic improvement and three have not required further transfusions.', 'drug': ['HU'], 'drug_indices_start': [38], 'drug_indices_end': [40]}

[CLS] -100
of 0
the 0
four 0
patients 0
who 0
responded 0
to 0
hu 1
with 0
an 0
increase 0
in 0
total 0
hb 0
, 0
all 0
reported 0
symptomatic 0
improvement 0
and 0
three 0
have 0
not 0
required 0
further 0
transfusion 0
##s 0
. 0
[SEP] -100


{'input_ids': [102, 131, 111, 1379, 568, 975, 12879, 147, 4000, 190, 130, 1242, 121, 1114, 7009, 422, 355, 1214, 12104, 3523, 137, 874, 360, 302, 1761, 911, 12601, 30113, 205, 103], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 2), (3, 6), (7, 11), (12, 20), (21, 24), (25, 34), (35, 37), (38, 40), (41, 45), (46, 48), (49, 57), (58, 60), (61, 66), (67, 69), (69, 70), (71, 74), (75, 83), (84, 95), (96, 107), (108, 111), (112, 117), (118, 122), (123, 126), (127, 135), (136, 143), (144, 155), (155, 156), (156, 157), (0, 0)], 'labels': [-100, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -100]}

In [29]:
# Tokenize and label every row of both splits (one generate_row_labels call per row).
labeled_dataset = cons_dataset.map(generate_row_labels)

                                                               

Fine-tuning

In [30]:
# Settings shared by the training cells below.
task = "ner" # Should be one of "ner", "pos" or "chunk"
# Hub id of the checkpoint; its last path component is reused in the output directory name.
model_checkpoint = "allenai/scibert_scivocab_uncased"
# Per-device batch size, used for both training and evaluation.
batch_size = 16

In [31]:
# The output directory is named after the checkpoint's last path component and the task.
model_name = model_checkpoint.rsplit("/", 1)[-1]
args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-{task}",
    evaluation_strategy="epoch",  # evaluate once per epoch
    num_train_epochs=5,
    learning_rate=1e-5,
    weight_decay=0.05,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    logging_steps=1,  # report the training loss after every optimisation step
)

In [32]:
# Batch collator for token classification, built around the SciBERT tokenizer; passed to the Trainer.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [33]:
# seqeval metric; compute_metrics reads its overall precision, recall, F1 and accuracy.
metric = evaluate.load("seqeval")

In [34]:
def compute_metrics(p):
    """Score a batch of token-classification predictions with seqeval.

    Args:
        p: pair ``(predictions, labels)`` where ``predictions`` holds per-token
            class scores (classes on axis 2) and ``labels`` holds the gold
            class ids, with -100 marking positions to ignore.

    Returns:
        dict with the overall "precision", "recall", "f1" and "accuracy"
        reported by the seqeval metric.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        kept_predictions = []
        kept_labels = []
        for predicted_id, gold_id in zip(predicted_row, gold_row):
            # -100 marks special tokens, which are excluded from scoring
            if gold_id == -100:
                continue
            kept_predictions.append(label_list[predicted_id])
            kept_labels.append(label_list[gold_id])
        true_predictions.append(kept_predictions)
        true_labels.append(kept_labels)

    results = metric.compute(predictions=true_predictions, references=true_labels)

    summary = {}
    for short_name in ("precision", "recall", "f1", "accuracy"):
        summary[short_name] = results[f"overall_{short_name}"]
    return summary

In [35]:
# Wire the model, hyper-parameters, labelled splits, collator and metric function together.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=labeled_dataset["train"],
    eval_dataset=labeled_dataset["test"],
    compute_metrics=compute_metrics,
)

In [36]:
# Run fine-tuning; the returned TrainOutput reports the step count, training loss and runtime metrics.
trainer.train()

  0%|          | 0/120 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  1%|          | 1/120 [00:09<19:34,  9.87s/it]

{'loss': 1.1536, 'learning_rate': 9.916666666666668e-06, 'epoch': 0.04}


  2%|▏         | 2/120 [00:21<20:56, 10.65s/it]

{'loss': 0.9336, 'learning_rate': 9.833333333333333e-06, 'epoch': 0.08}


  2%|▎         | 3/120 [00:28<17:54,  9.18s/it]

{'loss': 0.772, 'learning_rate': 9.75e-06, 'epoch': 0.12}


  3%|▎         | 4/120 [00:34<15:32,  8.04s/it]

{'loss': 0.6811, 'learning_rate': 9.666666666666667e-06, 'epoch': 0.17}


  4%|▍         | 5/120 [00:41<14:26,  7.53s/it]

{'loss': 0.5598, 'learning_rate': 9.583333333333335e-06, 'epoch': 0.21}


  5%|▌         | 6/120 [00:48<14:12,  7.48s/it]

{'loss': 0.4502, 'learning_rate': 9.5e-06, 'epoch': 0.25}


  6%|▌         | 7/120 [01:05<19:54, 10.57s/it]

{'loss': 0.4226, 'learning_rate': 9.416666666666667e-06, 'epoch': 0.29}


  7%|▋         | 8/120 [01:14<18:40, 10.00s/it]

{'loss': 0.3976, 'learning_rate': 9.333333333333334e-06, 'epoch': 0.33}


  8%|▊         | 9/120 [01:23<17:50,  9.64s/it]

{'loss': 0.4672, 'learning_rate': 9.250000000000001e-06, 'epoch': 0.38}


  8%|▊         | 10/120 [01:30<15:58,  8.71s/it]

{'loss': 0.4255, 'learning_rate': 9.166666666666666e-06, 'epoch': 0.42}


  9%|▉         | 11/120 [01:36<14:44,  8.11s/it]

{'loss': 0.3899, 'learning_rate': 9.083333333333333e-06, 'epoch': 0.46}


 10%|█         | 12/120 [01:44<14:32,  8.08s/it]

{'loss': 0.3412, 'learning_rate': 9e-06, 'epoch': 0.5}


 11%|█         | 13/120 [01:50<13:15,  7.43s/it]

{'loss': 0.3788, 'learning_rate': 8.916666666666667e-06, 'epoch': 0.54}


 12%|█▏        | 14/120 [01:57<12:54,  7.30s/it]

{'loss': 0.2657, 'learning_rate': 8.833333333333334e-06, 'epoch': 0.58}


 12%|█▎        | 15/120 [02:04<12:19,  7.04s/it]

{'loss': 0.2832, 'learning_rate': 8.750000000000001e-06, 'epoch': 0.62}


 13%|█▎        | 16/120 [02:10<12:01,  6.93s/it]

{'loss': 0.243, 'learning_rate': 8.666666666666668e-06, 'epoch': 0.67}


 14%|█▍        | 17/120 [02:17<11:57,  6.97s/it]

{'loss': 0.188, 'learning_rate': 8.583333333333333e-06, 'epoch': 0.71}


 15%|█▌        | 18/120 [02:26<12:41,  7.47s/it]

{'loss': 0.2275, 'learning_rate': 8.5e-06, 'epoch': 0.75}


 16%|█▌        | 19/120 [02:33<12:26,  7.39s/it]

{'loss': 0.246, 'learning_rate': 8.416666666666667e-06, 'epoch': 0.79}


 17%|█▋        | 20/120 [02:40<12:05,  7.25s/it]

{'loss': 0.1664, 'learning_rate': 8.333333333333334e-06, 'epoch': 0.83}


 18%|█▊        | 21/120 [02:49<12:48,  7.76s/it]

{'loss': 0.2193, 'learning_rate': 8.25e-06, 'epoch': 0.88}


 18%|█▊        | 22/120 [02:57<12:57,  7.93s/it]

{'loss': 0.1507, 'learning_rate': 8.166666666666668e-06, 'epoch': 0.92}


 19%|█▉        | 23/120 [03:05<12:38,  7.82s/it]

{'loss': 0.1669, 'learning_rate': 8.083333333333334e-06, 'epoch': 0.96}


 20%|██        | 24/120 [03:09<10:29,  6.55s/it]

{'loss': 0.1045, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.0}


                                                
 20%|██        | 24/120 [03:26<10:29,  6.55s/it]

{'eval_loss': 0.14067256450653076, 'eval_precision': 0.8195488721804511, 'eval_recall': 0.7622377622377622, 'eval_f1': 0.7898550724637682, 'eval_accuracy': 0.9613746958637469, 'eval_runtime': 17.3047, 'eval_samples_per_second': 7.223, 'eval_steps_per_second': 0.462, 'epoch': 1.0}


 21%|██        | 25/120 [03:35<19:39, 12.42s/it]

{'loss': 0.1266, 'learning_rate': 7.916666666666667e-06, 'epoch': 1.04}


 22%|██▏       | 26/120 [03:42<16:52, 10.77s/it]

{'loss': 0.2151, 'learning_rate': 7.833333333333333e-06, 'epoch': 1.08}


 22%|██▎       | 27/120 [03:51<15:54, 10.27s/it]

{'loss': 0.1273, 'learning_rate': 7.75e-06, 'epoch': 1.12}


 23%|██▎       | 28/120 [03:57<14:01,  9.14s/it]

{'loss': 0.1509, 'learning_rate': 7.666666666666667e-06, 'epoch': 1.17}


 24%|██▍       | 29/120 [04:08<14:27,  9.53s/it]

{'loss': 0.1362, 'learning_rate': 7.583333333333333e-06, 'epoch': 1.21}


 25%|██▌       | 30/120 [04:13<12:30,  8.34s/it]

{'loss': 0.1284, 'learning_rate': 7.500000000000001e-06, 'epoch': 1.25}


 26%|██▌       | 31/120 [04:21<12:20,  8.32s/it]

{'loss': 0.086, 'learning_rate': 7.416666666666668e-06, 'epoch': 1.29}


 27%|██▋       | 32/120 [04:28<11:27,  7.81s/it]

{'loss': 0.1012, 'learning_rate': 7.333333333333333e-06, 'epoch': 1.33}


 28%|██▊       | 33/120 [04:35<11:04,  7.64s/it]

{'loss': 0.1031, 'learning_rate': 7.25e-06, 'epoch': 1.38}


 28%|██▊       | 34/120 [04:42<10:30,  7.33s/it]

{'loss': 0.118, 'learning_rate': 7.166666666666667e-06, 'epoch': 1.42}


 29%|██▉       | 35/120 [04:53<12:02,  8.50s/it]

{'loss': 0.0927, 'learning_rate': 7.083333333333335e-06, 'epoch': 1.46}


 30%|███       | 36/120 [05:05<13:29,  9.63s/it]

{'loss': 0.086, 'learning_rate': 7e-06, 'epoch': 1.5}


 31%|███       | 37/120 [05:12<12:07,  8.77s/it]

{'loss': 0.1114, 'learning_rate': 6.916666666666667e-06, 'epoch': 1.54}


 32%|███▏      | 38/120 [05:18<10:49,  7.92s/it]

{'loss': 0.0926, 'learning_rate': 6.833333333333334e-06, 'epoch': 1.58}


 32%|███▎      | 39/120 [05:25<10:14,  7.59s/it]

{'loss': 0.1279, 'learning_rate': 6.750000000000001e-06, 'epoch': 1.62}


 33%|███▎      | 40/120 [05:32<09:46,  7.33s/it]

{'loss': 0.106, 'learning_rate': 6.666666666666667e-06, 'epoch': 1.67}


 34%|███▍      | 41/120 [05:41<10:15,  7.79s/it]

{'loss': 0.0869, 'learning_rate': 6.5833333333333335e-06, 'epoch': 1.71}


 35%|███▌      | 42/120 [05:47<09:28,  7.29s/it]

{'loss': 0.0997, 'learning_rate': 6.5000000000000004e-06, 'epoch': 1.75}


 36%|███▌      | 43/120 [05:54<09:26,  7.36s/it]

{'loss': 0.1672, 'learning_rate': 6.416666666666667e-06, 'epoch': 1.79}


 37%|███▋      | 44/120 [06:01<09:10,  7.25s/it]

{'loss': 0.0786, 'learning_rate': 6.333333333333333e-06, 'epoch': 1.83}


 38%|███▊      | 45/120 [06:08<08:48,  7.04s/it]

{'loss': 0.0849, 'learning_rate': 6.25e-06, 'epoch': 1.88}


 38%|███▊      | 46/120 [06:15<08:34,  6.96s/it]

{'loss': 0.0909, 'learning_rate': 6.166666666666667e-06, 'epoch': 1.92}


 39%|███▉      | 47/120 [06:21<08:12,  6.75s/it]

{'loss': 0.0364, 'learning_rate': 6.083333333333333e-06, 'epoch': 1.96}


 40%|████      | 48/120 [06:26<07:22,  6.14s/it]

{'loss': 0.0615, 'learning_rate': 6e-06, 'epoch': 2.0}


                                                
 40%|████      | 48/120 [06:43<07:22,  6.14s/it]

{'eval_loss': 0.07730578631162643, 'eval_precision': 0.8028169014084507, 'eval_recall': 0.7972027972027972, 'eval_f1': 0.8, 'eval_accuracy': 0.9744525547445255, 'eval_runtime': 17.1751, 'eval_samples_per_second': 7.278, 'eval_steps_per_second': 0.466, 'epoch': 2.0}


 41%|████      | 49/120 [06:56<15:58, 13.50s/it]

{'loss': 0.0404, 'learning_rate': 5.916666666666667e-06, 'epoch': 2.04}


 42%|████▏     | 50/120 [07:07<14:44, 12.63s/it]

{'loss': 0.0892, 'learning_rate': 5.833333333333334e-06, 'epoch': 2.08}


 42%|████▎     | 51/120 [07:15<12:51, 11.18s/it]

{'loss': 0.0654, 'learning_rate': 5.75e-06, 'epoch': 2.12}


 43%|████▎     | 52/120 [07:21<11:10,  9.87s/it]

{'loss': 0.0322, 'learning_rate': 5.666666666666667e-06, 'epoch': 2.17}


 44%|████▍     | 53/120 [07:27<09:33,  8.56s/it]

{'loss': 0.0713, 'learning_rate': 5.583333333333334e-06, 'epoch': 2.21}


 45%|████▌     | 54/120 [07:34<08:49,  8.03s/it]

{'loss': 0.083, 'learning_rate': 5.500000000000001e-06, 'epoch': 2.25}


 46%|████▌     | 55/120 [07:42<08:40,  8.01s/it]

{'loss': 0.1003, 'learning_rate': 5.416666666666667e-06, 'epoch': 2.29}


 47%|████▋     | 56/120 [07:48<08:09,  7.64s/it]

{'loss': 0.0401, 'learning_rate': 5.333333333333334e-06, 'epoch': 2.33}


 48%|████▊     | 57/120 [07:55<07:42,  7.34s/it]

{'loss': 0.0772, 'learning_rate': 5.2500000000000006e-06, 'epoch': 2.38}


 48%|████▊     | 58/120 [08:01<07:12,  6.97s/it]

{'loss': 0.0613, 'learning_rate': 5.1666666666666675e-06, 'epoch': 2.42}


 49%|████▉     | 59/120 [08:09<07:26,  7.33s/it]

{'loss': 0.0502, 'learning_rate': 5.0833333333333335e-06, 'epoch': 2.46}


 50%|█████     | 60/120 [08:19<07:56,  7.94s/it]

{'loss': 0.082, 'learning_rate': 5e-06, 'epoch': 2.5}


 51%|█████     | 61/120 [08:28<08:11,  8.33s/it]

{'loss': 0.0641, 'learning_rate': 4.9166666666666665e-06, 'epoch': 2.54}


 52%|█████▏    | 62/120 [08:34<07:26,  7.70s/it]

{'loss': 0.0625, 'learning_rate': 4.833333333333333e-06, 'epoch': 2.58}


 52%|█████▎    | 63/120 [08:41<06:59,  7.37s/it]

{'loss': 0.0428, 'learning_rate': 4.75e-06, 'epoch': 2.62}


 53%|█████▎    | 64/120 [08:48<06:49,  7.31s/it]

{'loss': 0.0443, 'learning_rate': 4.666666666666667e-06, 'epoch': 2.67}


 54%|█████▍    | 65/120 [08:56<06:52,  7.50s/it]

{'loss': 0.0668, 'learning_rate': 4.583333333333333e-06, 'epoch': 2.71}


 55%|█████▌    | 66/120 [09:06<07:21,  8.18s/it]

{'loss': 0.1075, 'learning_rate': 4.5e-06, 'epoch': 2.75}


 56%|█████▌    | 67/120 [09:13<06:54,  7.82s/it]

{'loss': 0.0329, 'learning_rate': 4.416666666666667e-06, 'epoch': 2.79}


 57%|█████▋    | 68/120 [09:20<06:35,  7.61s/it]

{'loss': 0.0209, 'learning_rate': 4.333333333333334e-06, 'epoch': 2.83}


 57%|█████▊    | 69/120 [09:27<06:27,  7.60s/it]

{'loss': 0.0479, 'learning_rate': 4.25e-06, 'epoch': 2.88}


 58%|█████▊    | 70/120 [09:34<06:09,  7.38s/it]

{'loss': 0.0297, 'learning_rate': 4.166666666666667e-06, 'epoch': 2.92}


 59%|█████▉    | 71/120 [09:43<06:18,  7.72s/it]

{'loss': 0.0522, 'learning_rate': 4.083333333333334e-06, 'epoch': 2.96}


 60%|██████    | 72/120 [09:46<05:12,  6.52s/it]

{'loss': 0.0489, 'learning_rate': 4.000000000000001e-06, 'epoch': 3.0}


                                                
 60%|██████    | 72/120 [10:05<05:12,  6.52s/it]

{'eval_loss': 0.06415483355522156, 'eval_precision': 0.8354430379746836, 'eval_recall': 0.9230769230769231, 'eval_f1': 0.8770764119601329, 'eval_accuracy': 0.9826642335766423, 'eval_runtime': 18.8333, 'eval_samples_per_second': 6.637, 'eval_steps_per_second': 0.425, 'epoch': 3.0}


 61%|██████    | 73/120 [10:12<09:37, 12.30s/it]

{'loss': 0.0585, 'learning_rate': 3.916666666666667e-06, 'epoch': 3.04}


 62%|██████▏   | 74/120 [10:21<08:34, 11.19s/it]

{'loss': 0.0214, 'learning_rate': 3.833333333333334e-06, 'epoch': 3.08}


 62%|██████▎   | 75/120 [10:29<07:40, 10.24s/it]

{'loss': 0.07, 'learning_rate': 3.7500000000000005e-06, 'epoch': 3.12}


 63%|██████▎   | 76/120 [10:38<07:19,  9.99s/it]

{'loss': 0.0282, 'learning_rate': 3.6666666666666666e-06, 'epoch': 3.17}


 64%|██████▍   | 77/120 [10:46<06:37,  9.24s/it]

{'loss': 0.0455, 'learning_rate': 3.5833333333333335e-06, 'epoch': 3.21}


 65%|██████▌   | 78/120 [10:52<05:56,  8.48s/it]

{'loss': 0.0932, 'learning_rate': 3.5e-06, 'epoch': 3.25}


 66%|██████▌   | 79/120 [11:01<05:53,  8.62s/it]

{'loss': 0.0519, 'learning_rate': 3.416666666666667e-06, 'epoch': 3.29}


 67%|██████▋   | 80/120 [11:11<05:59,  8.98s/it]

{'loss': 0.0324, 'learning_rate': 3.3333333333333333e-06, 'epoch': 3.33}


 68%|██████▊   | 81/120 [11:18<05:27,  8.40s/it]

{'loss': 0.0179, 'learning_rate': 3.2500000000000002e-06, 'epoch': 3.38}


 68%|██████▊   | 82/120 [11:26<05:09,  8.14s/it]

{'loss': 0.0261, 'learning_rate': 3.1666666666666667e-06, 'epoch': 3.42}


 69%|██████▉   | 83/120 [11:34<04:56,  8.01s/it]

{'loss': 0.0513, 'learning_rate': 3.0833333333333336e-06, 'epoch': 3.46}


 70%|███████   | 84/120 [11:41<04:37,  7.72s/it]

{'loss': 0.0196, 'learning_rate': 3e-06, 'epoch': 3.5}


 71%|███████   | 85/120 [11:47<04:19,  7.42s/it]

{'loss': 0.0304, 'learning_rate': 2.916666666666667e-06, 'epoch': 3.54}


 72%|███████▏  | 86/120 [12:00<05:06,  9.00s/it]

{'loss': 0.0697, 'learning_rate': 2.8333333333333335e-06, 'epoch': 3.58}


 72%|███████▎  | 87/120 [12:09<05:01,  9.14s/it]

{'loss': 0.024, 'learning_rate': 2.7500000000000004e-06, 'epoch': 3.62}


 73%|███████▎  | 88/120 [12:19<04:54,  9.21s/it]

{'loss': 0.0393, 'learning_rate': 2.666666666666667e-06, 'epoch': 3.67}


 74%|███████▍  | 89/120 [12:26<04:23,  8.49s/it]

{'loss': 0.022, 'learning_rate': 2.5833333333333337e-06, 'epoch': 3.71}


 75%|███████▌  | 90/120 [12:34<04:16,  8.55s/it]

{'loss': 0.0178, 'learning_rate': 2.5e-06, 'epoch': 3.75}


 76%|███████▌  | 91/120 [12:44<04:16,  8.84s/it]

{'loss': 0.099, 'learning_rate': 2.4166666666666667e-06, 'epoch': 3.79}


 77%|███████▋  | 92/120 [12:52<04:03,  8.71s/it]

{'loss': 0.0335, 'learning_rate': 2.3333333333333336e-06, 'epoch': 3.83}


 78%|███████▊  | 93/120 [13:00<03:43,  8.28s/it]

{'loss': 0.0445, 'learning_rate': 2.25e-06, 'epoch': 3.88}


 78%|███████▊  | 94/120 [13:06<03:23,  7.82s/it]

{'loss': 0.0775, 'learning_rate': 2.166666666666667e-06, 'epoch': 3.92}


 79%|███████▉  | 95/120 [13:14<03:13,  7.74s/it]

{'loss': 0.0511, 'learning_rate': 2.0833333333333334e-06, 'epoch': 3.96}


 80%|████████  | 96/120 [13:17<02:36,  6.51s/it]

{'loss': 0.03, 'learning_rate': 2.0000000000000003e-06, 'epoch': 4.0}


                                                
 80%|████████  | 96/120 [13:36<02:36,  6.51s/it]

{'eval_loss': 0.05511501431465149, 'eval_precision': 0.8376623376623377, 'eval_recall': 0.9020979020979021, 'eval_f1': 0.8686868686868687, 'eval_accuracy': 0.9826642335766423, 'eval_runtime': 18.1406, 'eval_samples_per_second': 6.891, 'eval_steps_per_second': 0.441, 'epoch': 4.0}


 81%|████████  | 97/120 [13:43<04:38, 12.12s/it]

{'loss': 0.0161, 'learning_rate': 1.916666666666667e-06, 'epoch': 4.04}


 82%|████████▏ | 98/120 [13:52<04:09, 11.33s/it]

{'loss': 0.0223, 'learning_rate': 1.8333333333333333e-06, 'epoch': 4.08}


 82%|████████▎ | 99/120 [14:02<03:48, 10.87s/it]

{'loss': 0.0479, 'learning_rate': 1.75e-06, 'epoch': 4.12}


 83%|████████▎ | 100/120 [14:09<03:11,  9.59s/it]

{'loss': 0.0561, 'learning_rate': 1.6666666666666667e-06, 'epoch': 4.17}


 84%|████████▍ | 101/120 [14:16<02:48,  8.87s/it]

{'loss': 0.0133, 'learning_rate': 1.5833333333333333e-06, 'epoch': 4.21}


 85%|████████▌ | 102/120 [14:24<02:38,  8.79s/it]

{'loss': 0.0223, 'learning_rate': 1.5e-06, 'epoch': 4.25}


 86%|████████▌ | 103/120 [14:32<02:24,  8.47s/it]

{'loss': 0.0461, 'learning_rate': 1.4166666666666667e-06, 'epoch': 4.29}


 87%|████████▋ | 104/120 [14:39<02:09,  8.09s/it]

{'loss': 0.0407, 'learning_rate': 1.3333333333333334e-06, 'epoch': 4.33}


 88%|████████▊ | 105/120 [14:45<01:52,  7.53s/it]

{'loss': 0.0159, 'learning_rate': 1.25e-06, 'epoch': 4.38}


 88%|████████▊ | 106/120 [14:52<01:43,  7.37s/it]

{'loss': 0.0342, 'learning_rate': 1.1666666666666668e-06, 'epoch': 4.42}


 89%|████████▉ | 107/120 [15:01<01:38,  7.60s/it]

{'loss': 0.0111, 'learning_rate': 1.0833333333333335e-06, 'epoch': 4.46}


 90%|█████████ | 108/120 [15:08<01:28,  7.39s/it]

{'loss': 0.0202, 'learning_rate': 1.0000000000000002e-06, 'epoch': 4.5}


 91%|█████████ | 109/120 [15:15<01:20,  7.32s/it]

{'loss': 0.0288, 'learning_rate': 9.166666666666666e-07, 'epoch': 4.54}


 92%|█████████▏| 110/120 [15:24<01:18,  7.88s/it]

{'loss': 0.0276, 'learning_rate': 8.333333333333333e-07, 'epoch': 4.58}


 92%|█████████▎| 111/120 [15:32<01:11,  7.96s/it]

{'loss': 0.0173, 'learning_rate': 7.5e-07, 'epoch': 4.62}


 93%|█████████▎| 112/120 [15:39<01:01,  7.67s/it]

{'loss': 0.0222, 'learning_rate': 6.666666666666667e-07, 'epoch': 4.67}


 94%|█████████▍| 113/120 [15:48<00:56,  8.07s/it]

{'loss': 0.0286, 'learning_rate': 5.833333333333334e-07, 'epoch': 4.71}


 95%|█████████▌| 114/120 [16:01<00:56,  9.48s/it]

{'loss': 0.0525, 'learning_rate': 5.000000000000001e-07, 'epoch': 4.75}


 96%|█████████▌| 115/120 [16:09<00:45,  9.16s/it]

{'loss': 0.0858, 'learning_rate': 4.1666666666666667e-07, 'epoch': 4.79}


 97%|█████████▋| 116/120 [16:16<00:33,  8.49s/it]

{'loss': 0.0316, 'learning_rate': 3.3333333333333335e-07, 'epoch': 4.83}


 98%|█████████▊| 117/120 [16:24<00:24,  8.19s/it]

{'loss': 0.045, 'learning_rate': 2.5000000000000004e-07, 'epoch': 4.88}


 98%|█████████▊| 118/120 [16:32<00:16,  8.39s/it]

{'loss': 0.0102, 'learning_rate': 1.6666666666666668e-07, 'epoch': 4.92}


 99%|█████████▉| 119/120 [16:42<00:08,  8.69s/it]

{'loss': 0.0585, 'learning_rate': 8.333333333333334e-08, 'epoch': 4.96}


100%|██████████| 120/120 [16:45<00:00,  7.14s/it]

{'loss': 0.057, 'learning_rate': 0.0, 'epoch': 5.0}


                                                 
100%|██████████| 120/120 [17:03<00:00,  8.53s/it]

{'eval_loss': 0.05464215949177742, 'eval_precision': 0.8280254777070064, 'eval_recall': 0.9090909090909091, 'eval_f1': 0.8666666666666667, 'eval_accuracy': 0.9820559610705596, 'eval_runtime': 17.9349, 'eval_samples_per_second': 6.97, 'eval_steps_per_second': 0.446, 'epoch': 5.0}
{'train_runtime': 1023.8974, 'train_samples_per_second': 1.831, 'train_steps_per_second': 0.117, 'train_loss': 0.1294136055589964, 'epoch': 5.0}





TrainOutput(global_step=120, training_loss=0.1294136055589964, metrics={'train_runtime': 1023.8974, 'train_samples_per_second': 1.831, 'train_steps_per_second': 0.117, 'train_loss': 0.1294136055589964, 'epoch': 5.0})

In [46]:
# Quick smoke test of the fine-tuned model on one sentence.
sequence = "I was treated with furosemid."
ner_pipe = pipeline(task="ner", model=model, tokenizer=tokenizer)
# One dict per word piece; the entity is reported as LABEL_<id>, where <id> indexes label_list.
for entity in ner_pipe(sequence):
    print(entity)

{'entity': 'LABEL_0', 'score': 0.9763866, 'index': 1, 'word': 'i', 'start': 0, 'end': 1}
{'entity': 'LABEL_0', 'score': 0.997759, 'index': 2, 'word': 'was', 'start': 2, 'end': 5}
{'entity': 'LABEL_0', 'score': 0.9955455, 'index': 3, 'word': 'treated', 'start': 6, 'end': 13}
{'entity': 'LABEL_0', 'score': 0.9892511, 'index': 4, 'word': 'with', 'start': 14, 'end': 18}
{'entity': 'LABEL_1', 'score': 0.84162325, 'index': 5, 'word': 'fur', 'start': 19, 'end': 22}
{'entity': 'LABEL_2', 'score': 0.94467264, 'index': 6, 'word': '##ose', 'start': 22, 'end': 25}
{'entity': 'LABEL_2', 'score': 0.98947084, 'index': 7, 'word': '##mid', 'start': 25, 'end': 28}
{'entity': 'LABEL_0', 'score': 0.9896006, 'index': 8, 'word': '.', 'start': 28, 'end': 29}


In [37]:
# Score the held-out split: class scores -> class ids -> tag strings -> seqeval report.
predictions, labels, _ = trainer.predict(labeled_dataset["test"])
predictions = predictions.argmax(axis=2)

# Pair each predicted id with its gold id, dropping positions masked with -100 (special tokens).
aligned_pairs = [
    [(predicted_id, gold_id) for predicted_id, gold_id in zip(predicted_row, gold_row) if gold_id != -100]
    for predicted_row, gold_row in zip(predictions, labels)
]
true_predictions = [
    [label_list[predicted_id] for predicted_id, gold_id in pairs]
    for pairs in aligned_pairs
]
true_labels = [
    [label_list[gold_id] for predicted_id, gold_id in pairs]
    for pairs in aligned_pairs
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

100%|██████████| 8/8 [00:15<00:00,  1.91s/it]


array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int64)

In [38]:
# Token-classification pipeline used by visualize_entities; device=-1 keeps inference on the CPU.
# NOTE(review): despite the "effect" in its name, the model was trained on DRUG tags only.
effect_ner_model = pipeline(task="ner", model=model, tokenizer=tokenizer, device=-1)

In [39]:
def visualize_entities(sentence):
    """Highlight the tokens of `sentence` that the model tags as DRUG.

    Runs `effect_ner_model` on the sentence, keeps every token whose predicted
    tag is not "O" and hands those tokens to displaCy's manual entity renderer.

    The predicted tag is resolved by name when the pipeline already reports a
    name from `label_list`, and otherwise from the numeric suffix of the
    default "LABEL_<id>" form. Parsing only the last character would fail for
    named labels and for ids of 10 or more.

    Args:
        sentence: raw text to analyse and render.

    Returns:
        Whatever `displacy.render` returns for the rendered sentence.
    """
    entities = []

    for token in effect_ner_model(sentence):
        tag = token["entity"]
        if tag in label_list:
            label = tag
        else:
            label = label_list[int(tag.rsplit("_", 1)[-1])]

        # label_list[0] is the "outside" tag; only real entity tokens are drawn
        if label == label_list[0]:
            continue

        # copy so the pipeline's own result dict is left untouched
        entities.append({**token, "label": label})

    params = [{"text": sentence,
               "ents": entities,
               "title": None}]

    return displacy.render(params, style="ent", manual=True, options={
        "colors": {
                   "B-DRUG": "#f08080",
                   "I-DRUG": "#f08080",
               },
    })
    

In [40]:
#examples = [
#    "Abortion, miscarriage or uterine hemorrhage associated with misoprostol (Cytotec), a labor-inducing drug.",
#    "Addiction to many sedatives and analgesics, such as diazepam, morphine, etc.",
#    "Birth defects associated with thalidomide",
#    "Bleeding of the intestine associated with aspirin therapy",
#    "Cardiovascular disease associated with COX-2 inhibitors (i.e. Vioxx)",
#    "Deafness and kidney failure associated with gentamicin (an antibiotic)"
#]

# for example in examples:
#     visualize_entities(example)
#     print(f"{'*' * 50}\n")

# One example sentence per line. Lines read from a file keep their trailing
# newline, so strip them first; otherwise blank lines are never equal to ""
# and are sent to the model anyway.
with open("./data/Albers.txt", "r", encoding="utf8") as file:
    examples = [line.strip() for line in file]
examples = [example for example in examples if example]

for example in examples:
    visualize_entities(example)
    print(f"{'*' * 50}\n")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************



**************************************************

