## Fine-tuning of Mixtral-8x7B on Tatqa

Install and Load Required Libraries

In [2]:
! pip3 install -q -U transformers
! pip install -q -U datasets
! pip3 install -q -U peft
! pip install -q -U trl
! pip3 install -q -U auto-gptq
! pip3 install -q -U optimum
! pip3 install -q -U bitsandbytes

In [3]:
import os
os.environ['TRANSFORMERS_CACHE'] = '/home/kmb85/rds/hpc-work/huggingface'

In [4]:
import transformers
import torch
from datasets import load_dataset
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer



### Load Mixtral-8x7B-v0.1 and Tokenizer

In [4]:
model_name_or_path = "mistralai/Mixtral-8x7B-v0.1"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4b_quant_type='nf4',
    torch_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    use_safetensors=True,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
    token=""
)

Downloading shards:   0%|          | 0/19 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True,
                                          token="")
tokenizer.pad_token=tokenizer.eos_token

In [6]:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

### Load LoRA Adapter

In [7]:
config = LoraConfig(
    r=32,
    lora_alpha=16,
    bias="none",
    task_type="CASUAL_LM",
    target_modules=["q_proj", "v_proj"]
)

In [8]:
model=get_peft_model(model, config)

### Dataset preparation

In [5]:
dataset = load_dataset('ChanceFocus/flare-tatqa')

In [6]:
dataset

DatasetDict({
    test: Dataset({
        features: ['id', 'query', 'answer', 'text'],
        num_rows: 1668
    })
})

In [7]:
from datasets import DatasetDict

total_size = 1668
train_size = int(0.7 * total_size)
test_size = int(0.15 * total_size)

train_subset = dataset['test'].select(range(train_size))
test_subset = dataset['test'].select(range(train_size, train_size + test_size))
validation_subset = dataset['test'].select(range(train_size + test_size, total_size))

split_datasets = DatasetDict({
    'train': train_subset,
    'test': test_subset,
    'validation': validation_subset
})

In [8]:
import re

def remove_conversations(text):
    cleaned_text = re.sub(r'Conversations:.*?Question:', 'Question:', text, flags=re.DOTALL)
    return cleaned_text

def add_marker_before_first_occurrence(text, search_string):
    pattern = re.escape(search_string)
    marked_text = re.sub(pattern, '\n###' + search_string, text, count=1)
    return marked_text

In [9]:
def generate_train_prompt(data_point):
    query = data_point['query']
    query = add_marker_before_first_occurrence(query, 'Context:')
    query = add_marker_before_first_occurrence(query, 'Question:')
    query = add_marker_before_first_occurrence(query, 'Answer:')
    answer = data_point['answer']
    text = f'{query}\n{answer}'
    return {'text': text, 'labels': answer}

In [10]:
train_dataset = split_datasets['train'].shuffle().map(generate_train_prompt)

Map:   0%|          | 0/1167 [00:00<?, ? examples/s]

In [11]:
validation_dataset = split_datasets['validation'].shuffle().map(generate_train_prompt)

Map:   0%|          | 0/251 [00:00<?, ? examples/s]

In [12]:
def generate_test_prompt(data_point):
    query = data_point['query']
    query = add_marker_before_first_occurrence(query, 'Context:')
    query = add_marker_before_first_occurrence(query, 'Question:')
    query = add_marker_before_first_occurrence(query, 'Answer:')
    text = f'{query}\n'
    return {'text':text}

In [13]:
test_dataset = split_datasets['test'].shuffle().map(generate_test_prompt)

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

### Training

In [18]:
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=0.00001,
    bf16=True,
    num_train_epochs=16,
    save_strategy="epoch",
    save_safetensors=True,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    output_dir='./experiments',
    remove_unused_columns=False,
    warmup_ratio=0.03,
    logging_strategy='steps',
    evaluation_strategy='steps',
    logging_steps=15,
    label_names=['labels'],
    eval_steps=15,
    group_by_length=True
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    args=training_args,
    tokenizer=tokenizer,
    dataset_text_field='text',
    peft_config=config,
    max_seq_length=4096
)

Map:   0%|          | 0/1167 [00:00<?, ? examples/s]

Map:   0%|          | 0/251 [00:00<?, ? examples/s]

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [19]:
model.config.use_cache = False
trainer.state.log_history = True
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkmb85[0m ([33mcam_kiril[0m). Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss
15,1.3828,1.383254
30,1.4188,1.379405
45,1.4474,1.370802
60,1.4323,1.358168
75,1.4294,1.344447
90,1.3992,1.330559
105,1.3428,1.315585
120,1.3025,1.298849
135,1.2964,1.280177
150,1.2919,1.261208


Checkpoint destination directory ./experiments/checkpoint-73 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-146 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-219 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-292 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-365 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-438 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-511 already exists and is non-empty. Saving will 

Checkpoint destination directory ./experiments/checkpoint-1095 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-1168 already exists and is non-empty. Saving will proceed but saved results may be invalid.


TrainOutput(global_step=1168, training_loss=1.1129945508421284, metrics={'train_runtime': 25649.5116, 'train_samples_per_second': 0.728, 'train_steps_per_second': 0.046, 'total_flos': 3.2745554016524206e+18, 'train_loss': 1.1129945508421284, 'epoch': 16.0})

### Save the fine-tuned model

In [20]:
model.save_pretrained(f'Mixtral-8x7B-v0.1-tatqa_batch_size_4_epochs_16')

### Evaluate the fine-tuned model

In [14]:
test_dataset = test_dataset.shuffle(seed=42)

In [15]:
import requests
import ast

request = {
    'max_new_tokens': 200,
    'temperature': 0.1,
    'repetition_penalty': 1,
    'top_p': 0.7,
}

url = "http://127.0.0.1:5090/api/v1/generate"
headers = {'Content-Type': 'application/json'}

In [16]:
def trim_float_string(float_string):
    if float_string.endswith('.0'):
        return float_string[:-2]
    else:
        return float_string

In [17]:
def extract_and_compare(numerical_output, llm_output):
    numerical_value = float(numerical_output)

    numbers_in_text = re.findall(r'\b\d+\b', llm_output)
    words_in_text = re.findall(r'\b[a-zA-Z]+\b', llm_output)

    llm_values = [float(num) for num in numbers_in_text]

    for word in words_in_text:
        try:
            llm_values.append(float(w2n.word_to_num(word)))
        except ValueError:
            continue

    return any(numerical_value == llm_val for llm_val in llm_values)

In [18]:
total_correct = 0

In [19]:
from word2number import w2n

for i in range(len(test_dataset)):
    request['prompt'] = test_dataset[i]['text']
    response = requests.post(url, json=request)

    prediction = ast.literal_eval(response.text)["results"][0]['text'].lower()
    correct_ans = trim_float_string(test_dataset[i]['answer'].lower())

    if correct_ans in prediction:
        total_correct += 1
    else:
        try:
            if extract_and_compare(correct_ans, prediction):
                total_correct += 1
        except:
            continue

In [20]:
correct_percentage = (total_correct / len(test_dataset)) * 100
print(f'Exact Match(EM) Accuracy fined-tuned model: {correct_percentage}%')

Exact Match(EM) Accuracy fined-tuned model: 34.4%


### Evaluate the base model

In [24]:
test_dataset = test_dataset.shuffle(seed=42)

In [25]:
import requests
import ast

request = {
    'max_new_tokens': 200,
    'temperature': 0.1,
    'repetition_penalty': 1,
    'top_p': 0.7,
}

url = "http://127.0.0.1:5010/api/v1/generate"
headers = {'Content-Type': 'application/json'}

In [26]:
def trim_float_string(float_string):
    if float_string.endswith('.0'):
        return float_string[:-2]
    else:
        return float_string

In [27]:
def extract_and_compare(numerical_output, llm_output):
    numerical_value = float(numerical_output)

    numbers_in_text = re.findall(r'\b\d+\b', llm_output)
    words_in_text = re.findall(r'\b[a-zA-Z]+\b', llm_output)

    llm_values = [float(num) for num in numbers_in_text]

    for word in words_in_text:
        try:
            llm_values.append(float(w2n.word_to_num(word)))
        except ValueError:
            continue

    return any(numerical_value == llm_val for llm_val in llm_values)

In [28]:
total_correct = 0

In [29]:
from word2number import w2n

for i in range(len(test_dataset)):
    request['prompt'] = test_dataset[i]['text']
    response = requests.post(url, json=request)

    prediction = ast.literal_eval(response.text)["results"][0]['text'].lower()
    correct_ans = trim_float_string(test_dataset[i]['answer'].lower())

    if correct_ans in prediction:
        total_correct += 1
    else:
        try:
            if extract_and_compare(correct_ans, prediction):
                total_correct += 1
        except:
            continue

In [30]:
correct_percentage = (total_correct / len(test_dataset)) * 100
print(f'Exact Match(EM) Accuracy {correct_percentage}%')

Exact Match(EM) Accuracy 34.0%
