## Medical Question-Answering - PubMed HuggingFace Dataset

In [None]:
!pip install transformers trl datasets peft accelerate bitsandbytes sentencepiece

In [None]:
import os
import gc
from datasets import load_dataset
import torch
import torch.nn as nn
import bitsandbytes as bnb
import transformers
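from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
from transformers import logging, pipeline  # used later for log verbosity and the inference pipeline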
from peft import LoraConfig, PeftModel, get_peft_config
from trl import SFTTrainer

## Load PubMed QA data

In [None]:
data = load_dataset("pubmed_qa", "pqa_labeled")

In [None]:
data

DatasetDict({
    train: Dataset({
        features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
        num_rows: 1000
    })
})

In [None]:
data['train']['question'][0]

'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?'

In [None]:
data['train']['long_answer'][0]

'Results depicted mitochondrial dynamics in vivo as PCD progresses within the lace plant, and highlight the correlation of this organelle with other organelles during developmental PCD. To the best of our knowledge, this is the first report of mitochondria and chloroplasts moving on transvacuolar strands to form a ring structure surrounding the nucleus during developmental PCD. Also, for the first time, we have shown the feasibility for the use of CsA in a whole plant system. Overall, our findings implicate the mitochondria as playing a critical and early role in developmentally regulated PCD in the lace plant.'

In [None]:
data['train']['context'][0]

{'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
  'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). Window stage leaves were stained with the mitochondrial dye MitoT

In [None]:
# Hyperparameters for LoRA and training
LORA_ALPHA = 32
LORA_DROPOUT = 0.2
LORA_R = 4

LEARNING_RATE = 2e-4
NUM_EPOCHS = 1
BATCH_SIZE = 1
WEIGHT_DECAY = 0.001
MAX_GRAD_NORM = 0.3
gradient_accumulation_steps = 16
STEPS = 1
OPTIM = "paged_adamw_32bit"  # "adam" is not a valid TrainingArguments optimizer name
MAX_STEPS = 512
OUTPUT_DIR = "./results"
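
With these settings, each optimizer step consumes `BATCH_SIZE * gradient_accumulation_steps` examples, and a positive `MAX_STEPS` decides how long training runs. The arithmetic can be sketched as below, assuming a single GPU and the 1,000-row train split shown earlier. The helper name `training_budget` is hypothetical and not part of any library.

```python
def training_budget(batch_size, grad_accum_steps, max_steps, num_rows, num_devices=1):
    """Return (effective batch size, examples seen, passes over the data)."""
    effective_batch = batch_size * grad_accum_steps * num_devices
    examples_seen = effective_batch * max_steps
    epochs = examples_seen / num_rows
    return effective_batch, examples_seen, epochs


# Values taken from the hyperparameter cell; num_rows from the dataset printout.
effective_batch, examples_seen, epochs = training_budget(
    batch_size=1, grad_accum_steps=16, max_steps=512, num_rows=1000
)
print(effective_batch, examples_seen, epochs)
```

Under these assumptions, 512 steps correspond to roughly eight passes over the data rather than the single epoch suggested by `NUM_EPOCHS`.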

## Quantization configuration using Bitsandbytes

In [None]:
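# Load weights in 4-bit NF4 with double quantization; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Native bfloat16 support requires compute capability 8.0 or higher (Ampere or newer).
print(torch.cuda.get_device_capability())
device_map = "cuda:0"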
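
To see why 4-bit loading matters on a small GPU, here is a rough estimate of weight memory. It assumes a round figure of 7 billion parameters and ignores quantization constants, activations, adapter weights, and optimizer state, so treat the numbers as a lower bound. The function name `weight_memory_gib` is hypothetical.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Memory needed to store the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


NUM_PARAMS = 7e9  # assumption: round figure for a "7B" model

fp16_gib = weight_memory_gib(NUM_PARAMS, 16)
nf4_gib = weight_memory_gib(NUM_PARAMS, 4)
print(f"fp16: {fp16_gib:.2f} GiB, 4-bit: {nf4_gib:.2f} GiB")
```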

## Define model and tokenization

In [None]:
model_name = "nousresearch/llama-2-7b-chat-hf"

In [None]:
model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map=device_map,
        )
model.config.pretraining_tp = 1

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
# Llama-2 ships without a pad token, so reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
torch.cuda.empty_cache()

## Define LoRA adapter

In [None]:
peft_config = LoraConfig(
      lora_alpha= LORA_ALPHA,
      lora_dropout= LORA_DROPOUT,
      r= LORA_R,
      bias="none",
      task_type="CAUSAL_LM",
  )
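
The LoRA settings above fix two quantities: the scaling factor `lora_alpha / r` applied to the adapter update, and the number of trainable parameters. The sketch below works both out. It assumes a hidden size of 4096, 32 decoder layers, and that only the query and value projections are adapted (the notebook does not set `target_modules`, so this depends on the library default for Llama models).

```python
LORA_ALPHA = 32
LORA_R = 4

# Assumptions, not stated in the notebook:
HIDDEN_SIZE = 4096        # width of each adapted projection
NUM_LAYERS = 32           # decoder layers in the 7B model
TARGETS_PER_LAYER = 2     # q_proj and v_proj


def lora_params_per_matrix(r, d_in, d_out):
    """LoRA adds A with shape (r, d_in) and B with shape (d_out, r)."""
    return r * (d_in + d_out)


scaling = LORA_ALPHA / LORA_R
trainable = NUM_LAYERS * TARGETS_PER_LAYER * lora_params_per_matrix(
    LORA_R, HIDDEN_SIZE, HIDDEN_SIZE
)
print(f"scaling = {scaling}, trainable LoRA parameters = {trainable:,}")
```

Under these assumptions the adapter trains about 2.1 million parameters, a tiny fraction of the base model.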

## Setup training parameters

In [None]:
# A positive max_steps takes precedence over num_train_epochs.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    max_grad_norm=MAX_GRAD_NORM,
    logging_steps=STEPS,
    num_train_epochs=NUM_EPOCHS,
    max_steps=MAX_STEPS,
)

In [None]:
torch.cuda.empty_cache()

In [None]:
trainer = SFTTrainer(
        model=model,
        train_dataset=data['train'],
        peft_config=peft_config,
        dataset_text_field= "question",
        max_seq_length=512,
        tokenizer=tokenizer,
        args=training_args,
)
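
Note that the trainer above reads only the `question` column, so the model is trained on questions without their answers. One possible remedy is to build a single text column that pairs each question with its `long_answer`. The sketch below shows such a formatter on a plain dictionary. The function name `format_example`, the `text` column name, and the `[INST] ... [/INST]` layout are illustrative choices, not something the notebook defines.

```python
def format_example(example):
    """Combine a question and its long answer into one training string."""
    question = example["question"].strip()
    answer = example["long_answer"].strip()
    # Special tokens such as <s> are left to the tokenizer.
    return {"text": f"[INST] {question} [/INST] {answer}"}


# Toy record with the same keys as a PubMedQA row (contents are placeholders).
sample = {"question": " Is X associated with Y? ", "long_answer": " Yes, in this toy record. "}
print(format_example(sample)["text"])
```

If this approach is used, the formatter would typically be applied with `data['train'].map(format_example)` and the trainer pointed at the new column with `dataset_text_field="text"`.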

In [None]:
trainer.train()

{'loss': 3.2842, 'learning_rate': 0.000199609375, 'epoch': 0.02}
{'loss': 3.4232, 'learning_rate': 0.00019921875000000001, 'epoch': 0.03}
{'loss': 3.0619, 'learning_rate': 0.000198828125, 'epoch': 0.05}
{'loss': 2.8441, 'learning_rate': 0.00019843750000000002, 'epoch': 0.06}
{'loss': 3.1515, 'learning_rate': 0.000198046875, 'epoch': 0.08}
{'loss': 3.3112, 'learning_rate': 0.00019765625, 'epoch': 0.1}
{'loss': 2.988, 'learning_rate': 0.000197265625, 'epoch': 0.11}
{'loss': 2.7994, 'learning_rate': 0.000196875, 'epoch': 0.13}
{'loss': 2.7615, 'learning_rate': 0.00019648437500000002, 'epoch': 0.14}
{'loss': 2.7665, 'learning_rate': 0.00019609375, 'epoch': 0.16}
{'loss': 2.7176, 'learning_rate': 0.00019570312500000002, 'epoch': 0.18}
{'loss': 2.827, 'learning_rate': 0.0001953125, 'epoch': 0.19}
{'loss': 2.6007, 'learning_rate': 0.000194921875, 'epoch': 0.21}
{'loss': 2.6688, 'learning_rate': 0.00019453125000000002, 'epoch': 0.22}
{'loss': 2.5487, 'learning_rate': 0.000194140625, 'epoch': 0

KeyboardInterrupt: ignored
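
The logged learning rates are consistent with a linear decay from `LEARNING_RATE` to zero over `MAX_STEPS` with no warmup. The sketch below reproduces that schedule by hand. The function name `linear_lr` is hypothetical, and the zero-warmup assumption is inferred from the log values rather than stated in the notebook.

```python
def linear_lr(step, base_lr=2e-4, max_steps=512, warmup_steps=0):
    """Linear warmup followed by linear decay to zero at max_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, max_steps - step)
    return base_lr * remaining / max(1, max_steps - warmup_steps)


# Compare with the first logged values, up to floating-point rounding.
for step in (1, 2, 3, 15):
    print(step, linear_lr(step))
```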

In [None]:
logging.set_verbosity(logging.CRITICAL)
torch.cuda.empty_cache()

In [None]:
model_to_save = trainer.model.module if hasattr(trainer.model, 'module') else trainer.model # Take care of distributed/parallel training
model_to_save.save_pretrained("outputs")

## Test the model

### Using Inference pipeline

In [None]:
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=500)

In [None]:
prompt = "Who is at risk for Prostate Cancer?"

In [None]:
template = f"""<s>[INST] <<SYS>>
You are an honest Medical assistant bot.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information.
<</SYS>>

{prompt} [/INST]
"""

In [None]:
result = pipe(template)

In [None]:
response = result[0]['generated_text']
index = response.find("[/INST]")+len("[/INST]")

In [None]:
print(response[index:].strip())
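
The slicing above assumes the `[/INST]` marker is always present. If it is missing, `find` returns -1 and the slice silently starts at offset 6, cutting off the beginning of the text. A slightly more defensive variant is sketched below. The helper name `extract_answer` is hypothetical.

```python
def extract_answer(generated_text, marker="[/INST]"):
    """Return the text after the first marker, or the whole text if it is absent."""
    _, found, tail = generated_text.partition(marker)
    return tail.strip() if found else generated_text.strip()


# Toy strings standing in for pipeline output.
print(extract_answer("<s>[INST] question [/INST]\n answer text "))
print(extract_answer("text without the marker"))
```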

Prostate cancer can affect both men and women, but the risk of developing the disease is higher in men. Men who have a family history of prostate cancer, men who have a history of chronic inflammation, men who have a high level of testosterone, and men who have a strong genetic predisposition to the disease are at a higher risk of developing prostate cancer.


### Without using pipeline

In [None]:
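from peft import PeftModel

In [None]:
# Load the trained adapter weights saved in 'outputs'. Calling get_peft_model with only
# the config would attach a freshly initialized adapter instead of the trained one.
# In a fresh session, reload the quantized base model before this step.
model = PeftModel.from_pretrained(model, 'outputs')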

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
inputs = tokenizer(template, return_tensors="pt").to(device)
# The 4-bit model was already placed on the GPU via device_map, so no .to(device) call is needed.
outputs = model.generate(**inputs, max_new_tokens=1024)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

[INST] <<SYS>>
You are a helpful, respectful and honest Medical and Legal assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Who is at risk for Prostate Cancer? [/INST]

Prostate cancer can occur in both men and women, but it is more common in men. The risk of developing prostate cancer increases with age. In fact, the lifetime risk of developing prostate cancer is about 1 in 6 men. African American men have a higher risk of developing prostate cancer than other men. Additionally, men with a family history of prostate cancer are also at a higher risk. It is important 

## What's next

### Join the Discord server to be part of the community and learn more about LLMs and GenAI: [https://discord.com/invite/hEMqtDXCHA](https://discord.com/invite/hEMqtDXCHA)

In [None]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-9ecb8e28-fc76-7481-7af2-4d3de2c801bf)
