## Fine-tuning of WizardLM-13B on EdtSum

Install and Load Required Libraries

In [1]:
! pip3 install -q -U transformers
! pip install -q -U datasets
! pip3 install -q -U peft
! pip install -q -U trl
! pip3 install -q -U auto-gptq
! pip3 install -q -U optimum
! pip3 install -q -U bitsandbytes

In [1]:
import os
os.environ['TRANSFORMERS_CACHE'] = '/home/kmb85/rds/hpc-work/huggingface'

In [2]:
import transformers
import torch
from datasets import load_dataset
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer



### Load WizardLM-13B and Tokenizer

In [4]:
model_name_or_path = "WizardLM/WizardLM-13B-V1.2"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4b_quant_type='nf4',
    torch_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    use_safetensors=True,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
    token=""
)



In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True,
                                          token="")
tokenizer.pad_token=tokenizer.eos_token

In [6]:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

### Load LoRA Adapter

In [7]:
config = LoraConfig(
    r=32,
    lora_alpha=16,
    bias="none",
    task_type="CASUAL_LM",
    target_modules=["q_proj", "v_proj"]
)

In [8]:
model=get_peft_model(model, config)

### Dataset preparation

In [3]:
dataset = load_dataset('ChanceFocus/flare-edtsum')

In [4]:
dataset

DatasetDict({
    test: Dataset({
        features: ['id', 'query', 'answer', 'text'],
        num_rows: 2000
    })
})

In [5]:
from datasets import DatasetDict

total_size = 2000
train_size = int(0.7 * total_size)
test_size = int(0.15 * total_size)

train_subset = dataset['test'].select(range(train_size))
test_subset = dataset['test'].select(range(train_size, train_size + test_size))
validation_subset = dataset['test'].select(range(train_size + test_size, total_size))

split_datasets = DatasetDict({
    'train': train_subset,
    'test': test_subset,
    'validation': validation_subset
})

In [6]:
prompt = "You are given a text that consists of multiple sentences. Your task is to perform abstractive summarization on this text. Use your understanding of the content to express the main ideas and crucial details in a shorter, coherent, and natural sounding text."

def generate_train_prompt(data_point):
    input_text = data_point['text']
    summary = data_point['answer']
    text = f'{prompt}\n###Input:\n{input_text}\n###Output:\n{summary}'
    return {'text': text, 'labels': summary}

In [7]:
train_dataset = split_datasets['train'].shuffle().map(generate_train_prompt)

Map:   0%|          | 0/1400 [00:00<?, ? examples/s]

In [8]:
validation_dataset = split_datasets['validation'].shuffle().map(generate_train_prompt)

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

In [9]:
def generate_test_prompt(data_point):
    input_text = data_point['text']
    text = f'{prompt}\n###Input:\n{input_text}\n###Output:\n'
    return {'text': text}

In [10]:
test_dataset = split_datasets['test'].shuffle().map(generate_test_prompt)

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

### Training

In [17]:
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    learning_rate=0.00003,
    bf16=True,
    num_train_epochs=8,
    save_strategy="epoch",
    save_safetensors=True,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    output_dir='./experiments',
    remove_unused_columns=False,
    warmup_ratio=0.03,
    logging_strategy='steps',
    evaluation_strategy='steps',
    logging_steps=15,
    label_names=['labels'],
    eval_steps=15,
    group_by_length=True
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    args=training_args,
    tokenizer=tokenizer,
    dataset_text_field='text',
    peft_config=config,
    max_seq_length=4096
)

Map:   0%|          | 0/1400 [00:00<?, ? examples/s]

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [18]:
model.config.use_cache = False
trainer.state.log_history = True
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkmb85[0m ([33mcam_kiril[0m). Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss
15,1.6026,1.505949
30,1.5357,1.448113
45,1.4907,1.388971
60,1.425,1.328787
75,1.3522,1.2919
90,1.3257,1.278317
105,1.3141,1.269004
120,1.3136,1.263967
135,1.3156,1.261327
150,1.2968,1.260275


Checkpoint destination directory ./experiments/checkpoint-21 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-43 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-65 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-87 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-109 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-131 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-153 already exists and is non-empty. Saving will pro

TrainOutput(global_step=168, training_loss=1.3867902329989843, metrics={'train_runtime': 12186.1862, 'train_samples_per_second': 0.919, 'train_steps_per_second': 0.014, 'total_flos': 1.0787647261741056e+18, 'train_loss': 1.3867902329989843, 'epoch': 7.68})

### Save the fine-tuned model

In [19]:
model.save_pretrained(f'WizardLM-13B-edtsum_batch_size_8_epochs_8')

### Evaluate the fine-tuned model

In [13]:
test_dataset = test_dataset.shuffle(seed=42)

In [14]:
import requests
import ast

request = {
    'max_new_tokens': 200,
    'temperature': 0.1,
    'repetition_penalty': 1,
    'top_p': 0.7,
}

url = "http://127.0.0.1:5070/api/v1/generate"
headers = {'Content-Type': 'application/json'}

In [15]:
def trim_float_string(s):
    s = s.replace('###', '')
    s = s.replace('\n', '')
    return s

In [16]:
from datasets import load_metric
import requests
import ast

rouge_metric = load_metric("rouge")

total_scores = {'rouge1': [], 'rouge2': [], 'rougeL': [], 'rougeLsum': []}
num_evaluated = 0

for i in range(len(test_dataset)):
    request['prompt'] = test_dataset[i]['text']
    response = requests.post(url, json=request)

    prediction_text = ast.literal_eval(response.text)["results"][0]['text'].lower()
    correct_ans_text = trim_float_string(test_dataset[i]['answer'].lower())

    if not prediction_text.strip() or not correct_ans_text.strip():
        continue

    rouge_scores = rouge_metric.compute(predictions=[prediction_text], references=[correct_ans_text])

    for key in total_scores.keys():
        total_scores[key].append(rouge_scores[key].mid.fmeasure)
    num_evaluated += 1

  rouge_metric = load_metric("rouge")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [17]:
average_scores = {key: sum(values) / num_evaluated for key, values in total_scores.items() if num_evaluated > 0}
print(f"Average ROUGE Scores fine-tuned model: {average_scores}")

Average ROUGE Scores fine-tuned model: {'rouge1': 0.20703350009989505, 'rouge2': 0.10989082512507194, 'rougeL': 0.1692579996830393, 'rougeLsum': 0.16966090762580183}


### Evaluate the base model

In [11]:
test_dataset = test_dataset.shuffle(seed=42)

In [12]:
import requests
import ast

request = {
    'max_new_tokens': 200,
    'temperature': 0.1,
    'repetition_penalty': 1,
    'top_p': 0.7,
}

url = "http://127.0.0.1:5030/api/v1/generate"
headers = {'Content-Type': 'application/json'}

In [13]:
def trim_float_string(s):
    s = s.replace('###', '')
    s = s.replace('\n', '')
    return s

In [14]:
from datasets import load_metric
import requests
import ast

rouge_metric = load_metric("rouge")

total_scores = {'rouge1': [], 'rouge2': [], 'rougeL': [], 'rougeLsum': []}
num_evaluated = 0

for i in range(len(test_dataset)):
    request['prompt'] = test_dataset[i]['text']
    response = requests.post(url, json=request)

    prediction_text = ast.literal_eval(response.text)["results"][0]['text'].lower()
    correct_ans_text = trim_float_string(test_dataset[i]['answer'].lower())

    if not prediction_text.strip() or not correct_ans_text.strip():
        continue

    rouge_scores = rouge_metric.compute(predictions=[prediction_text], references=[correct_ans_text])

    for key in total_scores.keys():
        total_scores[key].append(rouge_scores[key].mid.fmeasure)
    num_evaluated += 1

  rouge_metric = load_metric("rouge")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [15]:
average_scores = {key: sum(values) / num_evaluated for key, values in total_scores.items() if num_evaluated > 0}
print(f"Average ROUGE Scores: {average_scores}")

Average ROUGE Scores: {'rouge1': 0.18480767445062993, 'rouge2': 0.08276666205900492, 'rougeL': 0.14285202317736728, 'rougeLsum': 0.14491822014612332}
