In [1]:
# !pip install datasets
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install transformers[sentencepiece]
# !pip install evaluate
# !pip install accelerate

In [2]:
# Load the drug-review/condition dataset that was previously saved to local disk.
from datasets import load_from_disk

# Directory holding the saved dataset.
data_dir = "/home/jovyan/Works/Practice/dataset/drugscom/drug_review-condition"

review_condition_dataset = load_from_disk(data_dir)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Drop the leftover index column (presumably a pandas export artifact), but only
# while it is still present: `remove_columns` raises on a missing column, which
# used to make this cell fail when it was re-run on the same kernel.
index_column = "__index_level_0__"
if any(index_column in columns for columns in review_condition_dataset.column_names.values()):
    review_condition_dataset = review_condition_dataset.remove_columns(index_column)
review_condition_dataset

DatasetDict({
    train: Dataset({
        features: ['condition', 'review'],
        num_rows: 16189
    })
    validation: Dataset({
        features: ['condition', 'review'],
        num_rows: 4048
    })
})

In [4]:
# The task is multiclass classification: every distinct condition found in the
# training split becomes one class label.
unique_cond = review_condition_dataset["train"].unique(column="condition")
len(unique_cond)

239

In [5]:
# Give each condition a contiguous integer id, following the order of `unique_cond`.
label2id = dict(zip(unique_cond, range(len(unique_cond))))
label2id

{'psoriatic arthritis': 0,
 'sedation': 1,
 'osteoporosis': 2,
 'amenorrhea': 3,
 'bronchitis': 4,
 'pseudotumor cerebri': 5,
 'polycystic ovary syndrome': 6,
 'opiate withdrawal': 7,
 "crohn's disease, maintenance": 8,
 'dermatitis': 9,
 'cold symptoms': 10,
 'hot flashes': 11,
 'pain': 12,
 'bowel preparation': 13,
 'constipation, drug induced': 14,
 'nasal congestion': 15,
 'anorexia': 16,
 'birth control': 17,
 'menstrual disorders': 18,
 'diverticulitis': 19,
 'cough': 20,
 'opiate dependence': 21,
 'head lice': 22,
 'anxiety': 23,
 'urinary incontinence': 24,
 'panic disorde': 25,
 'gout': 26,
 'benign prostatic hyperplasia': 27,
 'acute coronary syndrome': 28,
 'plaque psoriasis': 29,
 'peripheral neuropathy': 30,
 'allergies': 31,
 'copd': 32,
 'opioid-induced constipation': 33,
 'dental abscess': 34,
 'rhinitis': 35,
 'asthma, acute': 36,
 'hypogonadism, male': 37,
 'chronic fatigue syndrome': 38,
 "crohn's disease": 39,
 'perimenopausal symptoms': 40,
 'neuralgia': 41,
 'uppe

In [6]:
# Remove rows whose condition never appears in the training split. Only the
# validation split can lose rows; every training row passes by construction.
# A set gives O(1) membership checks instead of scanning the list for every row.
known_conditions = set(unique_cond)
drug_review_dataset = review_condition_dataset.filter(
    lambda x: x["condition"] in known_conditions
)


In [7]:
# Inspect the filtered splits; compare the row counts with the unfiltered dataset.
drug_review_dataset

DatasetDict({
    train: Dataset({
        features: ['condition', 'review'],
        num_rows: 16189
    })
    validation: Dataset({
        features: ['condition', 'review'],
        num_rows: 4048
    })
})

In [8]:
# Encode labels
def encode_label(example):
    """Attach the integer class id for a single example.

    Looks up the example's ``"condition"`` string in the module-level
    ``label2id`` mapping and stores the result under ``"labels"``.

    Args:
        example: Mapping with a ``"condition"`` key; it is modified in place.

    Returns:
        The same ``example`` object, now carrying a ``"labels"`` entry.

    Raises:
        KeyError: If the condition is not a key of ``label2id``.
    """
    condition_name = example["condition"]
    example["labels"] = label2id[condition_name]
    return example

In [9]:
# Encode labels on the *filtered* dataset: it only holds conditions that were
# seen in the training split, so `encode_label` cannot raise KeyError on a
# validation-only condition. (Previously the filtered dataset was never used.)
review_condition_dataset = drug_review_dataset.map(encode_label)

In [10]:
# Confirm that both splits now expose the new `labels` column.
review_condition_dataset

DatasetDict({
    train: Dataset({
        features: ['condition', 'review', 'labels'],
        num_rows: 16189
    })
    validation: Dataset({
        features: ['condition', 'review', 'labels'],
        num_rows: 4048
    })
})

In [11]:
# Spot-check one encoded training example (condition, review text, integer label).
review_condition_dataset["train"][10]

{'condition': "crohn's disease, maintenance",
 'review': "i was on pentasa for 19 years but every time i had a colonoscopy the crohn's was progressing.  i finally ended in the hospital with a severe inflammation.  i was told by 2 gi docs that the pentasa was probably not working but back in 1997, except for 6mp they really had no other alternative.  2 months ago i switched to humira.  many more side effects than pentasa however.",
 'labels': 8}

In [12]:
from transformers import AutoTokenizer

# Checkpoint name used to load the tokenizer here; the classifier is loaded from the same name later.
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [13]:
# Tokenize the dataset
def tokenize_function(example):
    """Tokenize the ``"review"`` field with the module-level ``tokenizer``.

    Truncation is enabled; no padding is requested here. Other fields of
    ``example`` (including ``"labels"``) are not read or returned.

    Args:
        example: Mapping with a ``"review"`` entry. The notebook maps this
            function with ``batched=True``, so that entry is a batch of texts.

    Returns:
        Whatever ``tokenizer`` returns for the review text(s).
    """
    return tokenizer(example["review"], truncation=True)


In [14]:
# Tokenize every split; batched mode hands `tokenize_function` many reviews per call.
tokenized_dataset = review_condition_dataset.map(
    function=tokenize_function,
    batched=True,
)

Map:   0%|          | 0/4048 [00:00<?, ? examples/s]

Map: 100%|██████████| 4048/4048 [00:00<00:00, 9107.77 examples/s]


In [15]:
# Check which columns tokenization added to each split.
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['condition', 'review', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16189
    })
    validation: Dataset({
        features: ['condition', 'review', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 4048
    })
})

In [16]:
# Look at the first tokenized training example in full.
tokenized_dataset["train"][0]

{'condition': 'psoriatic arthritis',
 'review': 'this drug worked like magic for me, going from not being able to dress myself, tie shoelaces or buttons to being totally mobile again within 2 doses! also worked wonders on psoriasis, all with no outward side-effects for more than a year now! thank you!',
 'labels': 0,
 'input_ids': [101,
  2023,
  4319,
  2499,
  2066,
  3894,
  2005,
  2033,
  1010,
  2183,
  2013,
  2025,
  2108,
  2583,
  2000,
  4377,
  2870,
  1010,
  5495,
  10818,
  19217,
  2015,
  2030,
  11287,
  2000,
  2108,
  6135,
  4684,
  2153,
  2306,
  1016,
  21656,
  999,
  2036,
  2499,
  16278,
  2006,
  8827,
  11069,
  6190,
  1010,
  2035,
  2007,
  2053,
  15436,
  2217,
  1011,
  3896,
  2005,
  2062,
  2084,
  1037,
  2095,
  2085,
  999,
  4067,
  2017,
  999,
  102],
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,


In [17]:
# Build the data collator. Tokenization above only truncates and applies no
# padding, so padding is left to this collator when batches are assembled.
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [18]:
# Pull a single tokenized training example (index 1000) for inspection.
# NOTE(review): despite the plural name, `samples` holds one example here.
samples = tokenized_dataset["train"][1000]
samples

{'condition': 'diarrhea',
 'review': "paregoric is a wonderful medication. not sure if i have crohn's disease or ibs but when having a flare it is the best thing i have found.   as a child i was given a medication with paregoric in it and it always worked and made me feel better.   i hate the fact that medications are so abused that the government feels the need to take it off the shelf.   i have not been able to get a prescription for more than 5 years.",
 'labels': 230,
 'input_ids': [101,
  11968,
  20265,
  7277,
  2003,
  1037,
  6919,
  14667,
  1012,
  2025,
  2469,
  2065,
  1045,
  2031,
  13675,
  11631,
  2078,
  1005,
  1055,
  4295,
  2030,
  21307,
  2015,
  2021,
  2043,
  2383,
  1037,
  17748,
  2009,
  2003,
  1996,
  2190,
  2518,
  1045,
  2031,
  2179,
  1012,
  2004,
  1037,
  2775,
  1045,
  2001,
  2445,
  1037,
  14667,
  2007,
  11968,
  20265,
  7277,
  1999,
  2009,
  1998,
  2009,
  2467,
  2499,
  1998,
  2081,
  2033,
  2514,
  2488,
  1012,
  1045,
  522

In [19]:
import evaluate
import numpy as np

# Load both metrics once at module scope so `compute_metrics` reuses them
# instead of reloading on every evaluation call.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_preds):
    """Score one evaluation pass with accuracy and macro-averaged F1.

    Args:
        eval_preds: A ``(logits, labels)`` pair; the predicted class for each
            row is the argmax of its logits over the last axis.

    Returns:
        A dict with the keys ``"accuracy"`` and ``"f1"``, taken from the
        module-level ``accuracy_metric`` and ``f1_metric`` results.
    """
    logits, labels = eval_preds
    predicted_ids = np.argmax(logits, axis=-1)

    # Both metrics are fed the same predictions/references pair.
    shared_args = {"predictions": predicted_ids, "references": labels}
    accuracy_result = accuracy_metric.compute(**shared_args)
    f1_result = f1_metric.compute(average="macro", **shared_args)

    return {"accuracy": accuracy_result["accuracy"], "f1": f1_result["f1"]}


In [20]:
# Load the model
from transformers import AutoModelForSequenceClassification

# Size the classification head from the label mapping rather than a hard-coded
# 239, so it cannot drift from the data. Passing both mappings stores the
# condition names in the model config alongside the integer ids.
id2label = {idx: condition for condition, idx in label2id.items()}
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
from transformers import TrainingArguments

# Train for 10 epochs and evaluate at the end of each one.
training_args = TrainingArguments(
    output_dir="/home/jovyan/Works/Practice/trainers/drug-review-trainer",
    eval_strategy="epoch",
    num_train_epochs=10,
)

In [22]:
from transformers import Trainer

# Wire the model, arguments, tokenized splits, collator and metrics together.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; with eval_strategy="epoch" the metrics are reported once per epoch.
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,4.1586,3.513812,0.299901,0.206413
2,2.5654,2.303151,0.478014,0.403534
3,1.6601,1.764624,0.563241,0.513817


