<a href="https://colab.research.google.com/github/ganeevsingh18/Drug_prediction/blob/main/Copy_of_fine_tuning_drugs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SciBERT Fine-Tuning on Drug/ADE Corpus

In [None]:
! pip install datasets transformers seqeval

Collecting datasets
  Downloading datasets-2.14.4-py3-none-any.whl (519 kB)
Collecting transformers
  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
  Preparing metadata (setup.py) ... done
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)


In [None]:
! pip install spacy



In [None]:
from datasets import Dataset, ClassLabel, Sequence, load_dataset, load_metric
import numpy as np
import pandas as pd
from spacy import displacy
import transformers
from transformers import (AutoModelForTokenClassification,
                          AutoTokenizer,
                          DataCollatorForTokenClassification,
                          pipeline,
                          TrainingArguments,
                          Trainer)

---
## Dataset Exploration

We use the `Ade_corpus_v2_drug_ade_relation` subset of the `ade_corpus_v2` dataset, which provides labeled spans for drug names and adverse effects.

See the dataset page: https://huggingface.co/datasets/ade_corpus_v2

In [None]:
datasets = load_dataset("ade_corpus_v2", "Ade_corpus_v2_drug_ade_relation")


In [None]:
datasets

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'effect', 'indexes'],
        num_rows: 6821
    })
})

In [None]:
datasets["train"][0]

{'text': 'Intravenous azithromycin-induced ototoxicity.',
 'drug': 'azithromycin',
 'effect': 'ototoxicity',
 'indexes': {'drug': {'start_char': [12], 'end_char': [24]},
  'effect': {'start_char': [33], 'end_char': [44]}}}
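The `indexes` field stores character offsets into `text`, and the row above is consistent with a half-open `[start_char, end_char)` convention: slicing the sentence with the offsets recovers the entity strings. A minimal check, using that row copied by hand so no download is needed (`span_texts` is a hypothetical helper, not part of the notebook):

```python
# Row copied from the output above; offsets are character positions in `text`.
row = {
    "text": "Intravenous azithromycin-induced ototoxicity.",
    "drug": "azithromycin",
    "effect": "ototoxicity",
    "indexes": {
        "drug": {"start_char": [12], "end_char": [24]},
        "effect": {"start_char": [33], "end_char": [44]},
    },
}


def span_texts(row, kind):
    """Return the substrings of row["text"] covered by each `kind` span (end exclusive)."""
    idx = row["indexes"][kind]
    return [row["text"][s:e] for s, e in zip(idx["start_char"], idx["end_char"])]


print(span_texts(row, "drug"))    # ['azithromycin']
print(span_texts(row, "effect"))  # ['ototoxicity']
```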

---
## Dataset Consolidation

On closer inspection, many sentences appear several times in the dataset, once for each (drug, adverse effect) pair they contain. For example, all of these rows share the same sentence:
```
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'insulin', 'effect': 'increasing myalgia', 'indexes': {'drug': {'start_char': [37], 'end_char': [44]}, 'effect': {'start_char': [147], 'end_char': [165]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'cresol', 'effect': 'lost consciousness', 'indexes': {'drug': {'start_char': [74], 'end_char': [80]}, 'effect': {'start_char': [233], 'end_char': [251]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'cresol', 'effect': 'high fever', 'indexes': {'drug': {'start_char': [74], 'end_char': [80]}, 'effect': {'start_char': [179], 'end_char': [189]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'insulin', 'effect': 'high fever', 'indexes': {'drug': {'start_char': [37], 'end_char': [44]}, 'effect': {'start_char': [179], 'end_char': [189]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'insulin', 'effect': 'lost consciousness', 'indexes': {'drug': {'start_char': [37], 'end_char': [44]}, 'effect': {'start_char': [233], 'end_char': [251]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'insulin', 'effect': 'respiratory and metabolic acidosis', 'indexes': {'drug': {'start_char': [37], 'end_char': [44]}, 'effect': {'start_char': [194], 'end_char': [228]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'cresol', 'effect': 'respiratory and metabolic acidosis', 'indexes': {'drug': {'start_char': [74], 'end_char': [80]}, 'effect': {'start_char': [194], 'end_char': [228]}}}
```

This is a problem for NER. If we assigned one set of token labels per row as-is, the same tokens in the same sentence would receive different labels in different rows: in the example above, `cresol` would be tagged as a drug in the cresol rows but as `O` in the insulin rows. This would confuse the model during fine-tuning, so we first consolidate all entity spans for each unique sentence, and then label every known entity in a single pass.
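A toy illustration of the conflict, using a hypothetical sentence and single-word entities rather than real corpus rows (`tag_words` is an illustrative helper): labeling each row separately gives the same token two different tags, while merging the pairs first gives one consistent sequence.

```python
# Toy sentence with two drug-effect pairs, mimicking how the corpus repeats
# a sentence once per pair. Word-level tags are used for simplicity.
words = ["insulin", "and", "cresol", "caused", "fever"]
rows = [
    {"drug": "insulin", "effect": "fever"},
    {"drug": "cresol", "effect": "fever"},
]


def tag_words(words, drugs, effects):
    """Tag single-word entities as B-DRUG / B-EFFECT; everything else is O."""
    return ["B-DRUG" if w in drugs else "B-EFFECT" if w in effects else "O"
            for w in words]


# One label sequence per row: "cresol" is O in the first, B-DRUG in the second.
per_row = [tag_words(words, {r["drug"]}, {r["effect"]}) for r in rows]
print(per_row[0])  # ['B-DRUG', 'O', 'O', 'O', 'B-EFFECT']
print(per_row[1])  # ['O', 'O', 'B-DRUG', 'O', 'B-EFFECT']

# Consolidated: merge all pairs for the sentence, then label once.
merged = tag_words(words, {r["drug"] for r in rows}, {r["effect"] for r in rows})
print(merged)      # ['B-DRUG', 'O', 'B-DRUG', 'O', 'B-EFFECT']
```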

In [None]:
consolidated_dataset = {}

for row in datasets["train"]:
    if row["text"] in consolidated_dataset:
        consolidated_dataset[row["text"]]["drug_indices_start"].update(row["indexes"]["drug"]["start_char"])
        consolidated_dataset[row["text"]]["drug_indices_end"].update(row["indexes"]["drug"]["end_char"])
        consolidated_dataset[row["text"]]["effect_indices_start"].update(row["indexes"]["effect"]["start_char"])
        consolidated_dataset[row["text"]]["effect_indices_end"].update(row["indexes"]["effect"]["end_char"])
        consolidated_dataset[row["text"]]["drug"].append(row["drug"])
        consolidated_dataset[row["text"]]["effect"].append(row["effect"])

    else:
        consolidated_dataset[row["text"]] = {
            "text": row["text"],
            "drug": [row["drug"]],
            "effect": [row["effect"]],
            # use sets because the indices can repeat for various reasons
            "drug_indices_start": set(row["indexes"]["drug"]["start_char"]),
            "drug_indices_end": set(row["indexes"]["drug"]["end_char"]),
            "effect_indices_start": set(row["indexes"]["effect"]["start_char"]),
            "effect_indices_end": set(row["indexes"]["effect"]["end_char"])
        }

---
With the dataset consolidated, we need to assign per-token labels to each sentence. First, we convert the consolidated records into a Hugging Face `Dataset` object, going through a pandas DataFrame and a JSON Lines file.

In [None]:
pd.DataFrame(datasets["train"])

| | text | drug | effect | indexes |
|---|---|---|---|---|
| 0 | Intravenous azithromycin-induced ototoxicity. | azithromycin | ototoxicity | {'drug': {'start_char': [12], 'end_char': [24]... |
| 1 | Immobilization, while Paget's bone disease was... | dihydrotachysterol | increased calcium-release | {'drug': {'start_char': [91], 'end_char': [109... |
| 2 | Unaccountable severe hypercalcemia in a patien... | dihydrotachysterol | hypercalcemia | {'drug': {'start_char': [84], 'end_char': [102... |
| 3 | METHODS: We report two cases of pseudoporphyri... | naproxen | pseudoporphyria | {'drug': {'start_char': [58], 'end_char': [66]... |
| 4 | METHODS: We report two cases of pseudoporphyri... | oxaprozin | pseudoporphyria | {'drug': {'start_char': [71], 'end_char': [80]... |
| ... | ... | ... | ... | ... |
| 6816 | Lithium treatment was terminated in 1975 becau... | Lithium | lithium intoxication | {'drug': {'start_char': [0], 'end_char': [7]},... |
| 6817 | Lithium treatment was terminated in 1975 becau... | lithium | lithium intoxication | {'drug': {'start_char': [52], 'end_char': [59]... |
| 6818 | Eosinophilia caused by clozapine was observed ... | clozapine | Eosinophilia | {'drug': {'start_char': [23], 'end_char': [32]... |
| 6819 | Eosinophilia has been encountered from 0.2 to ... | clozapine | Eosinophilia | {'drug': {'start_char': [55], 'end_char': [64]... |


In [None]:
datasets["train"]["indexes"][0]

{'drug': {'start_char': [12], 'end_char': [24]},
 'effect': {'start_char': [33], 'end_char': [44]}}

In [None]:
df = pd.DataFrame(list(consolidated_dataset.values()))

In [None]:
df.head()

| | text | drug | effect | drug_indices_start | drug_indices_end | effect_indices_start | effect_indices_end |
|---|---|---|---|---|---|---|---|
| 0 | Intravenous azithromycin-induced ototoxicity. | [azithromycin] | [ototoxicity] | {12} | {24} | {33} | {44} |
| 1 | Immobilization, while Paget's bone disease was... | [dihydrotachysterol] | [increased calcium-release] | {91} | {109} | {143} | {168} |
| 2 | Unaccountable severe hypercalcemia in a patien... | [dihydrotachysterol] | [hypercalcemia] | {84} | {102} | {21} | {34} |
| 3 | METHODS: We report two cases of pseudoporphyri... | [naproxen, oxaprozin] | [pseudoporphyria, pseudoporphyria] | {58, 71} | {80, 66} | {32} | {47} |
| 4 | Naproxen, the most common offender, has been a... | [Naproxen] | [erythropoietic protoporphyria] | {0} | {8} | {134} | {163} |


In [None]:
# Sets don't preserve insertion order, so we sort the start and end offsets
# independently. Assuming no two spans of the same type overlap, the i-th
# sorted start then pairs with the i-th sorted end.
for col in ["drug_indices_start", "drug_indices_end",
            "effect_indices_start", "effect_indices_end"]:
    df[col] = df[col].apply(sorted)
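To see why the sorting trick works, and where it would break, here is a small sketch using the drug offsets of row 3 from `df.head()` above (naproxen at 58 to 66, oxaprozin at 71 to 80). `pair_spans` is a hypothetical helper, and the nested spans in the second call are made up to show the failure mode. A related failure: two same-type spans sharing a start offset collapse into one set entry, leaving start and end lists of different lengths.

```python
def pair_spans(starts, ends):
    """Pair start/end offsets by sorting each collection independently.

    Only valid if spans of the same type never overlap or nest.
    """
    return list(zip(sorted(starts), sorted(ends)))


# Drug offsets from row 3 of df.head() above (naproxen, oxaprozin).
print(pair_spans({58, 71}, {80, 66}))  # [(58, 66), (71, 80)]

# Hypothetical nested spans (0, 20) and (5, 10): sorting mis-pairs them.
print(pair_spans({0, 5}, {20, 10}))    # [(0, 10), (5, 20)]
```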

In [None]:
# write to JSON Lines, then load it back as a Hugging Face Dataset
df.to_json("dataset.jsonl", orient="records", lines=True)

In [None]:
cons_dataset = load_dataset("json", data_files="dataset.jsonl")


In [None]:
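# No train/test split is provided, so we create our own.
# By default train_test_split() holds out 25% of rows for testing;
# no seed is set, so the split differs between runs.
cons_dataset = cons_dataset["train"].train_test_split()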

In [None]:
cons_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'effect', 'drug_indices_start', 'drug_indices_end', 'effect_indices_start', 'effect_indices_end'],
        num_rows: 3203
    })
    test: Dataset({
        features: ['text', 'drug', 'effect', 'drug_indices_start', 'drug_indices_end', 'effect_indices_start', 'effect_indices_end'],
        num_rows: 1068
    })
})

---
## Token Labeling

Finally, we can label each token with its entity. We use BIO tagging with two entity types, `DRUG` and `EFFECT`, which gives five possible classes per token (a `B-` and an `I-` tag for each type, plus `O`):

* `O` - outside any entity we care about
* `B-DRUG` - the beginning of a `DRUG` entity
* `I-DRUG` - inside a `DRUG` entity
* `B-EFFECT` - the beginning of an `EFFECT` entity
* `I-EFFECT` - inside an `EFFECT` entity
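As a minimal sketch of the scheme, assuming entity spans are already expressed as token index ranges (`bio_tags` is a hypothetical helper, and the token list is a shortened, made-up example in the style of SciBERT's wordpieces), the first token of each entity gets `B-`, the remaining tokens get `I-`, and the tag names map to ids by their position in `label_list`:

```python
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-EFFECT', 'I-EFFECT']


def bio_tags(num_tokens, spans):
    """Assign BIO tags from spans given as (first_token, end_token_exclusive, type)."""
    tags = ["O"] * num_tokens
    for start, end, ent_type in spans:
        tags[start] = f"B-{ent_type}"
        for i in range(start + 1, end):
            tags[i] = f"I-{ent_type}"
    return tags


tokens = ["of", "cyp", "##rote", "##ron", "##e", "acetate",
          "developed", "femoral", "head", "necrosis"]
tags = bio_tags(len(tokens), [(1, 6, "DRUG"), (7, 10, "EFFECT")])
print(tags)
# ['O', 'B-DRUG', 'I-DRUG', 'I-DRUG', 'I-DRUG', 'I-DRUG', 'O', 'B-EFFECT', 'I-EFFECT', 'I-EFFECT']
print([label_list.index(t) for t in tags])
# [0, 1, 2, 2, 2, 2, 0, 3, 4, 4]
```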

In [None]:
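label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-EFFECT', 'I-EFFECT']

# feature type describing a variable-length sequence of the five tags
custom_seq = Sequence(feature=ClassLabel(names=label_list))

# Note: the datasets have no "ner_tags" column, and `Dataset.features` returns
# a copy of the schema, so assigning into it (e.g.
# `cons_dataset["train"].features["ner_tags"] = custom_seq`) neither adds nor
# retypes a column. The per-token labels are created later by
# `generate_row_labels` and stored under "labels".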

In [None]:
custom_seq

Sequence(feature=ClassLabel(names=['O', 'B-DRUG', 'I-DRUG', 'B-EFFECT', 'I-EFFECT'], id=None), length=-1, id=None)

In [None]:
cons_dataset["train"].features["ner_tags"]

KeyError: 'ner_tags'

In [None]:
cons_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'effect', 'drug_indices_start', 'drug_indices_end', 'effect_indices_start', 'effect_indices_end'],
        num_rows: 3203
    })
    test: Dataset({
        features: ['text', 'drug', 'effect', 'drug_indices_start', 'drug_indices_end', 'effect_indices_start', 'effect_indices_end'],
        num_rows: 1068
    })
})

In [None]:
model_checkpoint = "allenai/scibert_scivocab_uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


In [None]:
def generate_row_labels(row, verbose=False):
    """ Given a row from the consolidated `Ade_corpus_v2_drug_ade_relation` dataset,
    generates BIO tags for drug and effect entities.

    Returns the tokenizer output for `row["text"]` with an added "labels" list
    holding one label id per token (-100 for special tokens, so the loss ignores
    them). Assumes every span starts and ends on a token boundary; otherwise the
    labels for that span, and for later spans of the same type, can be wrong.
    """

    text = row["text"]

    labels = []
    # current entity type and BIO prefix, carried over from token to token
    label = "O"
    prefix = ""

    # while iterating through tokens, increment to traverse all drug and effect spans
    drug_index = 0
    effect_index = 0

    tokens = tokenizer(text, return_offsets_mapping=True)

    for n in range(len(tokens["input_ids"])):
        offset_start, offset_end = tokens["offset_mapping"][n]

        # zero-width offsets should only happen for [CLS] and [SEP]
        if offset_end - offset_start == 0:
            labels.append(-100)
            continue

        # entering an entity; if a drug and an effect start at the same
        # offset, the drug takes precedence
        if drug_index < len(row["drug_indices_start"]) and offset_start == row["drug_indices_start"][drug_index]:
            label = "DRUG"
            prefix = "B-"

        elif effect_index < len(row["effect_indices_start"]) and offset_start == row["effect_indices_start"][effect_index]:
            label = "EFFECT"
            prefix = "B-"

        labels.append(label_list.index(f"{prefix}{label}"))

        # leaving an entity: reset to "O" and move on to the next span of that type
        if drug_index < len(row["drug_indices_end"]) and offset_end == row["drug_indices_end"][drug_index]:
            label = "O"
            prefix = ""
            drug_index += 1

        elif effect_index < len(row["effect_indices_end"]) and offset_end == row["effect_indices_end"][effect_index]:
            label = "O"
            prefix = ""
            effect_index += 1

        # need to transition "inside" if we just entered an entity
        if prefix == "B-":
            prefix = "I-"

    if verbose:
        print(f"{row}\n")
        orig = tokenizer.convert_ids_to_tokens(tokens["input_ids"])
        for n in range(len(labels)):
            print(orig[n], labels[n])
    tokens["labels"] = labels

    return tokens
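`generate_row_labels` walks the sorted span lists with two pointers and only recognizes a span when a token's start offset equals `start_char` exactly (and another token's end offset equals `end_char`). The sketch below shows the same character-to-token alignment with a different, containment-based rule: a token is tagged when its offsets fall entirely inside a span. `align_labels` and the offsets are illustrative assumptions, hand-written in place of a real tokenizer call, not actual SciBERT output.

```python
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-EFFECT', 'I-EFFECT']


def align_labels(offsets, spans):
    """Map character-level spans to token-level BIO label ids.

    offsets: (start, end) character offsets per token; a zero-width offset such
             as (0, 0) marks a special token ([CLS]/[SEP]) and gets -100.
    spans:   (start_char, end_char, entity_type) tuples with an exclusive end.
    """
    labels = []
    for tok_start, tok_end in offsets:
        if tok_start == tok_end:
            labels.append(-100)
            continue
        tag = "O"
        for span_start, span_end, ent_type in spans:
            # containment rule: the token lies entirely inside the span
            if span_start <= tok_start and tok_end <= span_end:
                prefix = "B-" if tok_start == span_start else "I-"
                tag = prefix + ent_type
                break
        labels.append(label_list.index(tag))
    return labels


# Hypothetical offsets for "Intravenous azithromycin-induced ototoxicity."
offsets = [(0, 0), (0, 11), (12, 16), (16, 24), (24, 25), (25, 32),
           (33, 36), (36, 44), (44, 45), (0, 0)]
spans = [(12, 24, "DRUG"), (33, 44, "EFFECT")]
print(align_labels(offsets, spans))
# [-100, 0, 1, 2, 0, 0, 3, 4, 0, -100]
```

One trade-off of this rule: a span boundary that falls mid-token does not derail the spans after it; the straddling token is simply left as `O` (which can leave an `I-` tag without a preceding `B-`).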


In [None]:
cons_dataset["train"][56]

{'text': 'Chlorambucil-induced chromosome damage to human lymphocytes is dose-dependent and cumulative.',
 'drug': ['Chlorambucil'],
 'effect': ['chromosome damage'],
 'drug_indices_start': [0],
 'drug_indices_end': [12],
 'effect_indices_start': [21],
 'effect_indices_end': [38]}

In [None]:
cons_dataset["train"][2]["text"][11:31]

' 3 AS patients treat'

In [None]:
# inspect the labels generated for one training example

generate_row_labels(cons_dataset["train"][3], verbose=True)

{'text': 'We report the case histories of two patients with histologically confirmed adenocarcinoma of the prostate, both of whom had been treated with steroidal anti-androgen therapy in the form of cyproterone acetate prior to radical or palliative pelvic irradiation, and who subsequently developed femoral head avascular necrosis.', 'drug': ['cyproterone acetate'], 'effect': ['femoral head avascular necrosis'], 'drug_indices_start': [189], 'drug_indices_end': [208], 'effect_indices_start': [291], 'effect_indices_end': [322]}

[CLS] -100
we 0
report 0
the 0
case 0
histories 0
of 0
two 0
patients 0
with 0
histologically 0
confirmed 0
adenocarcinoma 0
of 0
the 0
prostate 0
, 0
both 0
of 0
whom 0
had 0
been 0
treated 0
with 0
steroid 0
##al 0
anti 0
- 0
androgen 0
therapy 0
in 0
the 0
form 0
of 0
cyp 1
##rote 2
##ron 2
##e 2
acetate 2
prior 0
to 0
radical 0
or 0
palliative 0
pelvic 0
irradiation 0
, 0
and 0
who 0
subsequently 0
developed 0
femoral 3
head 4
av 4
##ascular 4
necrosis 4
. 0


{'input_ids': [102, 185, 2024, 111, 820, 17102, 131, 502, 568, 190, 22571, 3804, 14091, 131, 111, 6625, 422, 655, 131, 7861, 883, 528, 2338, 190, 11809, 120, 821, 579, 16573, 2223, 121, 111, 592, 131, 7592, 11444, 1809, 30107, 9382, 1979, 147, 6382, 234, 17241, 13707, 8896, 422, 137, 975, 5224, 1815, 11572, 2795, 873, 2375, 9191, 205, 103], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 2), (3, 9), (10, 13), (14, 18), (19, 28), (29, 31), (32, 35), (36, 44), (45, 49), (50, 64), (65, 74), (75, 89), (90, 92), (93, 96), (97, 105), (105, 106), (107, 111), (112, 114), (115, 119), (120, 123), (124, 128), (129, 136), (137, 141),
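To sanity-check the labels, it can help to turn them back into entity strings and compare them with the row's `drug` and `effect` fields. A minimal sketch (`decode_entities` is a hypothetical helper, not part of the notebook), using a subset of the tokens and label ids printed above:

```python
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-EFFECT', 'I-EFFECT']


def decode_entities(tokens, label_ids):
    """Group BIO label ids back into (entity_type, text) pairs.

    Tokens labeled -100 are skipped; '##' wordpieces are glued to the previous
    piece, and whole words are joined with a single space.
    """
    entities = []
    inside = False  # True while the previous token belongs to an open entity
    for tok, lid in zip(tokens, label_ids):
        if lid == -100:
            inside = False
            continue
        tag = label_list[lid]
        is_sub = tok.startswith("##")
        piece = tok[2:] if is_sub else tok
        if tag.startswith("B-"):
            entities.append([tag[2:], piece])
            inside = True
        elif tag.startswith("I-") and inside and entities[-1][0] == tag[2:]:
            entities[-1][1] += piece if is_sub else " " + piece
        else:
            inside = False
    return [tuple(e) for e in entities]


tokens = ["of", "cyp", "##rote", "##ron", "##e", "acetate", "prior", "developed",
          "femoral", "head", "av", "##ascular", "necrosis", "."]
label_ids = [0, 1, 2, 2, 2, 2, 0, 0, 3, 4, 4, 4, 4, 0]
print(decode_entities(tokens, label_ids))
# [('DRUG', 'cyproterone acetate'), ('EFFECT', 'femoral head avascular necrosis')]
```

Joining whole words with a space is a heuristic: an entity containing punctuation, such as `increased calcium-release`, would come back as `increased calcium - release`. It is good enough for eyeballing, not for exact string comparison.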

In [None]:
labeled_dataset = cons_dataset.map(generate_row_labels)
