## Fine-tuning of Llama-2-13b on med_qa_tw_en_bigbio_qa

Install and Load Required Libraries

In [1]:
! pip3 install -q -U transformers
! pip install -q -U datasets
! pip3 install -q -U peft
! pip install -q -U trl
! pip3 install -q -U auto-gptq
! pip3 install -q -U optimum
! pip3 install -q -U bitsandbytes

In [1]:
import os
os.environ['TRANSFORMERS_CACHE'] = '/home/kmb85/rds/hpc-work/huggingface'

In [2]:
import transformers
import torch
from datasets import load_dataset
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer

  from .autonotebook import tqdm as notebook_tqdm


### Load Llama-2-13b and Tokenizer

In [3]:
model_name_or_path = "meta-llama/Llama-2-13b-hf"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4b_quant_type='nf4',
    torch_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    use_safetensors=True,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
    token=""
)

config.json: 100%|██████████| 610/610 [00:00<00:00, 5.26MB/s]
model.safetensors.index.json: 100%|██████████| 33.4k/33.4k [00:00<00:00, 136MB/s]
Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]
model-00001-of-00003.safetensors:   0%|          | 0.00/9.95G [00:00<?, ?B/s][A
model-00001-of-00003.safetensors:   0%|          | 10.5M/9.95G [00:00<02:06, 78.4MB/s][A
model-00001-of-00003.safetensors:   0%|          | 31.5M/9.95G [00:00<01:32, 107MB/s] [A
model-00001-of-00003.safetensors:   1%|          | 52.4M/9.95G [00:00<01:28, 112MB/s][A
model-00001-of-00003.safetensors:   1%|          | 73.4M/9.95G [00:00<01:26, 114MB/s][A
model-00001-of-00003.safetensors:   1%|          | 94.4M/9.95G [00:00<01:25, 115MB/s][A
model-00001-of-00003.safetensors:   1%|          | 115M/9.95G [00:01<01:24, 116MB/s] [A
model-00001-of-00003.safetensors:   1%|▏         | 136M/9.95G [00:01<01:24, 117MB/s][A
model-00001-of-00003.safetensors:   2%|▏         | 157M/9.95G [00:01<01:23, 117MB/s][A
model

model-00001-of-00003.safetensors:  37%|███▋      | 3.72G/9.95G [00:33<00:53, 117MB/s][A
model-00001-of-00003.safetensors:  38%|███▊      | 3.74G/9.95G [00:33<00:52, 117MB/s][A
model-00001-of-00003.safetensors:  38%|███▊      | 3.76G/9.95G [00:33<00:52, 117MB/s][A
model-00001-of-00003.safetensors:  38%|███▊      | 3.79G/9.95G [00:33<00:52, 117MB/s][A
model-00001-of-00003.safetensors:  38%|███▊      | 3.81G/9.95G [00:33<00:52, 117MB/s][A
model-00001-of-00003.safetensors:  38%|███▊      | 3.83G/9.95G [00:33<00:52, 117MB/s][A
model-00001-of-00003.safetensors:  39%|███▊      | 3.85G/9.95G [00:34<00:51, 117MB/s][A
model-00001-of-00003.safetensors:  39%|███▉      | 3.87G/9.95G [00:34<00:57, 106MB/s][A
model-00001-of-00003.safetensors:  39%|███▉      | 3.89G/9.95G [00:34<00:55, 109MB/s][A
model-00001-of-00003.safetensors:  39%|███▉      | 3.91G/9.95G [00:34<00:54, 110MB/s][A
model-00001-of-00003.safetensors:  40%|███▉      | 3.93G/9.95G [00:34<01:02, 96.8MB/s][A
model-00001-of-00003

model-00001-of-00003.safetensors:  74%|███████▍  | 7.39G/9.95G [01:05<00:21, 117MB/s][A
model-00001-of-00003.safetensors:  75%|███████▍  | 7.41G/9.95G [01:05<00:21, 117MB/s][A
model-00001-of-00003.safetensors:  75%|███████▍  | 7.43G/9.95G [01:06<00:21, 117MB/s][A
model-00001-of-00003.safetensors:  75%|███████▍  | 7.46G/9.95G [01:06<00:21, 117MB/s][A
model-00001-of-00003.safetensors:  75%|███████▌  | 7.48G/9.95G [01:06<00:21, 117MB/s][A
model-00001-of-00003.safetensors:  75%|███████▌  | 7.50G/9.95G [01:06<00:20, 117MB/s][A
model-00001-of-00003.safetensors:  76%|███████▌  | 7.52G/9.95G [01:06<00:20, 117MB/s][A
model-00001-of-00003.safetensors:  76%|███████▌  | 7.54G/9.95G [01:07<00:20, 117MB/s][A
model-00001-of-00003.safetensors:  76%|███████▌  | 7.56G/9.95G [01:07<00:20, 117MB/s][A
model-00001-of-00003.safetensors:  76%|███████▌  | 7.58G/9.95G [01:07<00:20, 117MB/s][A
model-00001-of-00003.safetensors:  76%|███████▋  | 7.60G/9.95G [01:07<00:19, 117MB/s][A
model-00001-of-00003.

model-00002-of-00003.safetensors:  12%|█▏        | 1.17G/9.90G [00:10<01:40, 86.7MB/s][A
model-00002-of-00003.safetensors:  12%|█▏        | 1.20G/9.90G [00:10<01:32, 94.1MB/s][A
model-00002-of-00003.safetensors:  12%|█▏        | 1.21G/9.90G [00:10<01:32, 94.0MB/s][A
model-00002-of-00003.safetensors:  12%|█▏        | 1.22G/9.90G [00:10<01:35, 91.3MB/s][A
model-00002-of-00003.safetensors:  12%|█▏        | 1.24G/9.90G [00:10<01:27, 99.4MB/s][A
model-00002-of-00003.safetensors:  13%|█▎        | 1.26G/9.90G [00:11<01:22, 105MB/s] [A
model-00002-of-00003.safetensors:  13%|█▎        | 1.28G/9.90G [00:11<01:19, 109MB/s][A
model-00002-of-00003.safetensors:  13%|█▎        | 1.30G/9.90G [00:11<01:17, 111MB/s][A
model-00002-of-00003.safetensors:  13%|█▎        | 1.32G/9.90G [00:11<01:15, 113MB/s][A
model-00002-of-00003.safetensors:  14%|█▎        | 1.34G/9.90G [00:11<01:23, 103MB/s][A
model-00002-of-00003.safetensors:  14%|█▍        | 1.36G/9.90G [00:12<01:24, 101MB/s][A
model-00002-of-

model-00002-of-00003.safetensors:  49%|████▉     | 4.90G/9.90G [00:43<00:42, 117MB/s][A
model-00002-of-00003.safetensors:  50%|████▉     | 4.92G/9.90G [00:43<00:42, 117MB/s][A
model-00002-of-00003.safetensors:  50%|████▉     | 4.94G/9.90G [00:43<00:42, 117MB/s][A
model-00002-of-00003.safetensors:  50%|█████     | 4.96G/9.90G [00:44<00:42, 117MB/s][A
model-00002-of-00003.safetensors:  50%|█████     | 4.98G/9.90G [00:44<00:41, 117MB/s][A
model-00002-of-00003.safetensors:  51%|█████     | 5.00G/9.90G [00:44<00:41, 117MB/s][A
model-00002-of-00003.safetensors:  51%|█████     | 5.02G/9.90G [00:44<00:41, 117MB/s][A
model-00002-of-00003.safetensors:  51%|█████     | 5.04G/9.90G [00:44<00:41, 117MB/s][A
model-00002-of-00003.safetensors:  51%|█████     | 5.06G/9.90G [00:45<00:41, 117MB/s][A
model-00002-of-00003.safetensors:  51%|█████▏    | 5.09G/9.90G [00:45<00:41, 117MB/s][A
model-00002-of-00003.safetensors:  52%|█████▏    | 5.11G/9.90G [00:45<00:40, 117MB/s][A
model-00002-of-00003.

model-00002-of-00003.safetensors:  87%|████████▋ | 8.59G/9.90G [01:16<00:11, 117MB/s][A
model-00002-of-00003.safetensors:  87%|████████▋ | 8.61G/9.90G [01:16<00:11, 117MB/s][A
model-00002-of-00003.safetensors:  87%|████████▋ | 8.63G/9.90G [01:17<00:10, 117MB/s][A
model-00002-of-00003.safetensors:  87%|████████▋ | 8.65G/9.90G [01:17<00:10, 117MB/s][A
model-00002-of-00003.safetensors:  88%|████████▊ | 8.67G/9.90G [01:17<00:10, 117MB/s][A
model-00002-of-00003.safetensors:  88%|████████▊ | 8.69G/9.90G [01:17<00:10, 117MB/s][A
model-00002-of-00003.safetensors:  88%|████████▊ | 8.71G/9.90G [01:17<00:10, 117MB/s][A
model-00002-of-00003.safetensors:  88%|████████▊ | 8.73G/9.90G [01:17<00:09, 117MB/s][A
model-00002-of-00003.safetensors:  88%|████████▊ | 8.76G/9.90G [01:18<00:09, 117MB/s][A
model-00002-of-00003.safetensors:  89%|████████▊ | 8.78G/9.90G [01:18<00:09, 117MB/s][A
model-00002-of-00003.safetensors:  89%|████████▉ | 8.80G/9.90G [01:18<00:09, 117MB/s][A
model-00002-of-00003.

model-00003-of-00003.safetensors:  40%|████      | 2.49G/6.18G [00:21<00:32, 114MB/s][A
model-00003-of-00003.safetensors:  41%|████      | 2.51G/6.18G [00:22<00:31, 115MB/s][A
model-00003-of-00003.safetensors:  41%|████      | 2.53G/6.18G [00:22<00:31, 116MB/s][A
model-00003-of-00003.safetensors:  41%|████      | 2.55G/6.18G [00:22<00:31, 116MB/s][A
model-00003-of-00003.safetensors:  42%|████▏     | 2.57G/6.18G [00:22<00:31, 116MB/s][A
model-00003-of-00003.safetensors:  42%|████▏     | 2.59G/6.18G [00:22<00:37, 96.8MB/s][A
model-00003-of-00003.safetensors:  42%|████▏     | 2.61G/6.18G [00:23<00:34, 102MB/s] [A
model-00003-of-00003.safetensors:  43%|████▎     | 2.63G/6.18G [00:23<00:33, 106MB/s][A
model-00003-of-00003.safetensors:  43%|████▎     | 2.65G/6.18G [00:23<00:32, 109MB/s][A
model-00003-of-00003.safetensors:  43%|████▎     | 2.67G/6.18G [00:23<00:31, 112MB/s][A
model-00003-of-00003.safetensors:  44%|████▎     | 2.69G/6.18G [00:23<00:30, 113MB/s][A
model-00003-of-0000

model-00003-of-00003.safetensors: 100%|██████████| 6.18G/6.18G [00:54<00:00, 113MB/s][A
Downloading shards: 100%|██████████| 3/3 [03:52<00:00, 77.39s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.55s/it]
generation_config.json: 100%|██████████| 188/188 [00:00<00:00, 1.70MB/s]


In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True,
                                          token="")
tokenizer.pad_token=tokenizer.eos_token

tokenizer_config.json: 100%|██████████| 776/776 [00:00<00:00, 6.95MB/s]
tokenizer.model: 100%|██████████| 500k/500k [00:00<00:00, 72.0MB/s]
tokenizer.json: 100%|██████████| 1.84M/1.84M [00:00<00:00, 6.19MB/s]
special_tokens_map.json: 100%|██████████| 414/414 [00:00<00:00, 4.44MB/s]


In [5]:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

### Load LoRA Adapter

In [6]:
config = LoraConfig(
    r=32,
    lora_alpha=16,
    bias="none",
    task_type="CASUAL_LM",
)

In [7]:
model=get_peft_model(model, config)

### Dataset preparation

In [3]:
dataset = load_dataset('bigbio/med_qa', 'med_qa_tw_en_bigbio_qa')

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'question_id', 'document_id', 'question', 'type', 'choices', 'context', 'answer'],
        num_rows: 11298
    })
    test: Dataset({
        features: ['id', 'question_id', 'document_id', 'question', 'type', 'choices', 'context', 'answer'],
        num_rows: 1413
    })
    validation: Dataset({
        features: ['id', 'question_id', 'document_id', 'question', 'type', 'choices', 'context', 'answer'],
        num_rows: 1412
    })
})

In [5]:
DEFAULT_PROMPT = "Below is a medical question and four choices for answer. Output the correct choice to answer the question."

def generate_train_prompt(data_point):
    question = data_point[ 'question']
    choices = data_point['choices']
    choices_str = ''
    answer = data_point['answer']
    for choice in choices:
        choices_str += choice + "\n"
    answer = data_point[ 'answer'][0]
    text = f'{DEFAULT_PROMPT}\n###Question:\n{question}\n###Choices:\n{choices_str}###Output:\n{answer}'
    return {'text': text, 'labels': answer}

In [6]:
train_dataset = dataset['train'].shuffle().map(generate_train_prompt)

Map: 100%|██████████| 11298/11298 [00:01<00:00, 8033.44 examples/s]


In [7]:
validation_dataset = dataset['validation'].shuffle().map(generate_train_prompt)

Map: 100%|██████████| 1412/1412 [00:00<00:00, 8033.60 examples/s]


In [8]:
def generate_test_prompt(data_point):
    question = data_point[ 'question']
    choices = data_point['choices']
    choices_str = ''
    answer = data_point['answer']
    for choice in choices:
        choices_str += choice + "\n"
    answer = data_point[ 'answer'][0]
    text = f'{DEFAULT_PROMPT}\n###Question:\n{question}\n###Choices:\n{choices_str}###Output:\n'
    return {'text':text}

In [9]:
test_dataset = dataset['test'].shuffle().map(generate_test_prompt)

Map: 100%|██████████| 1413/1413 [00:00<00:00, 8135.39 examples/s]


### Training

In [15]:
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=20,
    gradient_accumulation_steps=20,
    learning_rate=0.00004,
    bf16=True,
    num_train_epochs=8,
    save_strategy="epoch",
    save_safetensors=True,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    output_dir='./experiments',
    remove_unused_columns=False,
    warmup_ratio=0.03,
    logging_strategy='steps',
    evaluation_strategy='steps',
    logging_steps=15,
    label_names=['labels'],
    eval_steps=15,
    group_by_length=True
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    args=training_args,
    tokenizer=tokenizer,
    dataset_text_field='text',
    peft_config=config,
    max_seq_length=4096
)

Map: 100%|██████████| 11298/11298 [00:00<00:00, 13980.88 examples/s]
Map: 100%|██████████| 1412/1412 [00:00<00:00, 13523.62 examples/s]
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [16]:
model.config.use_cache = False
trainer.state.log_history = True
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkmb85[0m ([33mcam_kiril[0m). Use [1m`wandb login --relogin`[0m to force relogin


You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
15,2.2167,2.116947
30,1.9972,1.803617
45,1.6753,1.610732
60,1.5127,1.510035
75,1.4238,1.427893
90,1.3496,1.39698
105,1.3174,1.374987
120,1.2944,1.359327
135,1.3173,1.350723
150,1.2983,1.345893


Checkpoint destination directory ./experiments/checkpoint-28 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-56 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-84 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-113 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-141 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-169 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-197 already exists and is non-empty.Saving will proceed b

TrainOutput(global_step=224, training_loss=1.4541422043527876, metrics={'train_runtime': 11169.8353, 'train_samples_per_second': 8.092, 'train_steps_per_second': 0.02, 'total_flos': 1.1256005108028211e+18, 'train_loss': 1.4541422043527876, 'epoch': 7.93})

### Save the fine-tuned model

In [17]:
model.save_pretrained(f'Llama-2-13b_med_qa_tw_en_bigbio_qa_batch_size_20')

### Evaluate the fine-tuned model

In [10]:
import requests

url = "http://127.0.0.1:5000/api/v1/generate"

In [11]:
test_dataset = test_dataset.shuffle(seed=42)

In [12]:
import textdistance
import tiktoken
import ast

def num_of_tokens_from_text(text):
    encoding=tiktoken.encoding_for_model(model_name='gpt-3.5-turbo')
    num_tokens=len(encoding.encode(text=text))
    return num_tokens


def similiary(str1, str2):
    return textdistance.hamming.normalized_similarity(str1, str2)

def substring_after(input_string, after_string):
    is_index = input_string.find(after_string)

    if is_index == -1:
        return input_string

    return input_string[is_index + 2:]

In [13]:
request = {
    'max_new_tokens': 100,
    'temperature': 0.1,
    'repetition_penalty': 1,
    'top_p': 0.7,
}
headers = {'Content-Type': 'application/json'}

In [14]:
total_correct = 0
num_samples = 200

In [180]:
import string

for i in range(num_samples):
    request['prompt'] = test_dataset[i]['text']
    response = requests.post(url, json=request)

    prediction = ast.literal_eval(response.text)["results"][0]['text'].lower()
    prediction = prediction.replace('the correct choice is', '').strip()
    prediction = prediction.replace('the most appropriate next step in his management would be to', '').strip()
    prediction = prediction.replace("the most likely cause of the patient's symptoms is", '').strip()
    prediction = prediction.replace("the likely mechanism of action of the medication in question is a", '').strip()
    prediction = prediction.replace("the most likely diagnosis is", '').strip()
    prediction = prediction.replace(':', '')
    prediction = prediction.replace('\n', '')
    prediction = ''.join(char for char in prediction if char not in string.punctuation)
    prediction = substring_after(prediction, 'is')
    prediction = substring_after(prediction, 'be')

    matched_words = 0
    for word in test_dataset[i]['answer'][0].lower().split():
        if word in prediction:
            matched_words+=1

    total_correct += matched_words/len(test_dataset[i]['answer'][0].lower().split())

In [181]:
correct_percentage = (total_correct / num_samples) * 100
print(f'Correctness percentage {correct_percentage}%')

Correctness percentage 35.52321408553606%


### Evaluate the RAG model

In [15]:
import requests

url_rag = "https://b1b6-131-111-184-110.ngrok-free.app/search"

payload = {
    "text": '',
    "number_documents": 5,
    'collection': 'med_qa_tw_en_bigbio'
}

In [16]:
dataset['test'] = dataset['test'].shuffle(seed=42)

In [17]:
total_correct = 0
num_samples = 200

In [18]:
request = {
    'max_new_tokens': 200,
    'temperature': 0.1,
    'repetition_penalty': 1,
    'top_p': 0.7,
    'stopping_strings': ['\n', '###']
}
headers = {'Content-Type': 'application/json'}

In [19]:
def generate_rag_prompt(data_point):
    question = data_point['text']
    answer = data_point['answer']
    text = f'{question}{answer}\n'
    return text

In [20]:
import string
DEFAULT_PROMPT = "Below are some medical questions wtih four choices and answers. Output the correct choice to answer the last question only based on the provided choices."


for i in range(num_samples):
    request['prompt'] = DEFAULT_PROMPT+'\n'

    payload['question'] = dataset['test'][i]['question']

    response_rag = requests.get(url_rag, json=payload)
    data_rag = response_rag.json()

    for record in data_rag:
        request['prompt'] += generate_rag_prompt(record)

    question = dataset['test'][i]['question']
    choices = dataset['test'][i]['choices']
    choices_str = ''
    for choice in choices:
        choices_str += choice + "\n"
    request['prompt'] += f'###Question:\n{question}\n###Choices:\n{choices_str}###Output:\n'
    response = requests.post(url, json=request)

    prediction = ast.literal_eval(response.text)["results"][0]['text'].lower()
    prediction = prediction.replace("'", '')
    prediction = prediction.replace('[', '')
    prediction = prediction.replace(']', '')

    total_correct += similiary(dataset['test'][i]['answer'][0].lower(), prediction)

In [21]:
correct_percentage = (total_correct / num_samples) * 100
print(f'Correctness percentage {correct_percentage}%')

Correctness percentage 45.60480317615946%
