# Fine-tuning Mistral 7B using QLoRA for Georeferencing Biological Collection Records

### Environment Setup
- GPU: NVIDIA A40-24Q (24 GB)
- 8 CPUs, 64 GB memory, 200 GB HDD
- CUDA 12.2

## Installing the packages

In [None]:
# You only need to run this once per machine
# bitsandbytes: quantization backend used by BitsAndBytesConfig (4-bit loading below)
!pip install -q -U bitsandbytes
# transformers, peft and accelerate are installed from their GitHub main branches,
# i.e. the latest (possibly unreleased) code, so keyword names follow current APIs.
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
# Remaining utilities: dataset loading, scientific helpers, notebook widgets, plotting
!pip install -q -U datasets scipy ipywidgets matplotlib

In [None]:
# Setting-up accelerator
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
)

# Full (unsharded) state dicts for model and optimizer, offloaded to CPU;
# rank0_only=False means every rank receives them, not just rank 0.
_model_state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
_optim_state_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=_model_state_cfg,
    optim_state_dict_config=_optim_state_cfg,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

### 1. Load Dataset

Load the training, validation, and test sets.

In [None]:
from datasets import load_dataset


def _load_csv_split(csv_path):
    """Load one CSV file as a ``datasets.Dataset``.

    A CSV loaded this way ends up in a single split named ``'train'``, so that
    split is selected regardless of which partition the file actually holds.
    """
    return load_dataset('csv', data_files=csv_path, split='train')


# Partition sizes: train 70%, eval 15%, test 15%.
train_dataset = _load_csv_split('Data/train_data.csv')
eval_dataset = _load_csv_split('Data/eval_data.csv')
test_dataset = _load_csv_split('Data/test_data.csv')

## Formatting prompts

In [None]:
def formatting_func(row):
    """Build the training prompt text for one georeferencing record.

    Args:
        row: One dataset row (a mapping) with ``'stateProvince'``, ``'locality'``
            and ``'geolocation'`` entries.

    Returns:
        A single string made of the task instruction, a context line naming the
        state/province (when present), the locality description, and the
        geographic coordinates.

    NOTE: ``<country>`` is emitted literally; presumably it is a placeholder to be
    replaced with the real country name.
    """
    import math

    province = row['stateProvince']
    # Missing values may arrive as None, as a float NaN (pandas-style empty CSV
    # cells), or as a blank string. Treat all three as "no state/province" so the
    # context never reads "nan, <country>." or ", <country>.".
    is_missing = (
        province is None
        or (isinstance(province, float) and math.isnan(province))
        or (isinstance(province, str) and not province.strip())
    )
    location_context = "<country>." if is_missing else f"{province}, <country>."

    instruction = (
        "Task: Accurately georeference the location provided in the 'Locality Description' below, expressing the coordinates in decimal degrees."
        f"\nContext: This 'Locality Description' refers to a location in {location_context}"
        f"\nLocality Description: {row['locality']}\nGeographic Coordinates: {row['geolocation']}")
    return instruction

### 2. Load Base Model

Loading Mistral - `mistralai/Mistral-7B-v0.1` - using 4-bit quantization

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "mistralai/Mistral-7B-v0.1"

# 4-bit NF4 weight quantization with double quantization; matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized base model entirely onto the first CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
    device_map="cuda:0",
)

### 3. Tokenization

Set up the tokenizer.

In [None]:
# Pad on the left, and add BOS/EOS tokens to every encoded prompt.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    add_eos_token=True,
    padding_side="left",
)
# Reuse EOS as the padding token.
# NOTE(review): with pad == EOS, collators that mask pad ids in the labels cannot
# tell padding apart from the genuine end-of-sequence token. Confirm this is intended.
tokenizer.pad_token = tokenizer.eos_token

To choose a suitable `max_length`, first tokenize the prompts without truncation or padding, then plot the distribution of their lengths.

In [None]:
def generate_and_tokenize_prompt(prompt):
    """Format one dataset row into prompt text and tokenize it without truncation or padding."""
    prompt_text = formatting_func(prompt)
    return tokenizer(prompt_text)


# Unpadded first pass over both splits; used to inspect sequence lengths.
tokenized_train_dataset, tokenized_val_dataset = (
    split.map(generate_and_tokenize_prompt)
    for split in (train_dataset, eval_dataset)
)

import matplotlib.pyplot as plt

def plot_data_lengths(tokenize_train_dataset, tokenized_val_dataset, bins=20):
    """Plot a histogram of tokenized sequence lengths across two datasets.

    The histogram is used to pick ``max_length`` for the final tokenization pass.

    Args:
        tokenize_train_dataset: Tokenized training split; every row needs an
            ``'input_ids'`` entry. (Parameter name kept unchanged so keyword
            callers keep working.)
        tokenized_val_dataset: Tokenized validation split, same row format.
        bins: Number of histogram bins (defaults to the previous fixed value, 20).
    """
    # Iterate the arguments themselves, not the module-level globals of similar
    # names, so the plot always reflects the datasets that were passed in.
    lengths = [
        len(row['input_ids'])
        for dataset in (tokenize_train_dataset, tokenized_val_dataset)
        for row in dataset
    ]
    print(len(lengths))  # number of sequences included in the histogram

    # Plotting the histogram
    plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=bins, alpha=0.7, color='blue')
    plt.xlabel('Length of input_ids')
    plt.ylabel('Frequency')
    plt.title('Distribution of Lengths of input_ids')
    plt.show()


plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)

Tokenize the prompts using the `max_length` identified from the length distribution.

In [None]:
# Default to the longest sequence from the unpadded first pass, so nothing is
# truncated. Replace it with a smaller value read off the histogram above to
# trim long outliers (truncation can cut off the trailing coordinates).
max_length = max(
    len(row['input_ids'])
    for split in (tokenized_train_dataset, tokenized_val_dataset)
    for row in split
)


def generate_and_tokenize_prompt_final(prompt):
    """Tokenize one formatted row to exactly ``max_length`` tokens, with labels.

    The text is truncated to ``max_length`` and padded up to it (on the tokenizer's
    padding side). ``labels`` mirror ``input_ids``, except that padded positions
    (``attention_mask == 0``) are set to -100 so no loss is computed on padding.
    """
    result = tokenizer(
        formatting_func(prompt),
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # Use attention_mask rather than the pad id to find padding, because the pad
    # token may share its id with a real token (e.g. EOS).
    result["labels"] = [
        token_id if keep else -100
        for token_id, keep in zip(result["input_ids"], result["attention_mask"])
    ]
    return result

Reformat the prompt and tokenize each sample:

In [None]:
# Final pass: fixed-length encodings with labels, for the train and validation splits.
tokenized_train_dataset, tokenized_val_dataset = (
    split.map(generate_and_tokenize_prompt_final)
    for split in (train_dataset, eval_dataset)
)

### 4. Set Up LoRA

In [None]:
from peft import prepare_model_for_kbit_training

# Turn on gradient checkpointing for the quantized model, then let PEFT prepare it
# for k-bit training. The keyword below is spelled out at its default value.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

In [None]:
from peft import LoraConfig, get_peft_model

# Modules that receive LoRA adapters. Names follow the Mistral implementation in
# transformers: attention q/k/v/o projections, MLP gate/up/down projections, and
# the output head.
lora_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "lm_head",
]

config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,  # adapter scaling is lora_alpha / r = 2
    lora_dropout=0.05,  # Conventional
    bias="none",
    target_modules=lora_targets,
)

model = get_peft_model(model, config)

# Wrap with the accelerator configured earlier; comment this out to train without it.
model = accelerator.prepare_model(model)


Connect to Weights & Biases to track the training metrics.

In [None]:
!pip install -q wandb -U

import wandb, os
wandb.login()

wandb_project = "finetune-mistral-georeferencing"
if len(wandb_project) > 0:
    os.environ["WANDB_PROJECT"] = wandb_project

### 5. Training

In [None]:
import transformers
from datetime import datetime

# Run naming: the output directory and W&B run name are derived from these.
project = "finetune-mistral-georeferencing"
base_model_name = "mistral"
run_name = base_model_name + "-" + project
output_dir = "./" + run_name

# Re-assert EOS as the padding token (the same choice as in the tokenizer set-up cell).
# NOTE(review): DataCollatorForLanguageModeling(mlm=False) rebuilds ``labels`` from
# ``input_ids`` and sets every position equal to ``pad_token_id`` to -100. With
# pad == EOS, the genuine end-of-sequence token is masked too, so the model may
# not learn to stop generating. Consider a dedicated pad token, set before tokenizing.
tokenizer.pad_token = tokenizer.eos_token

trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        optim="adamw_8bit",
        logging_steps=50,
        learning_rate=2e-4,
        # ``evaluation_strategy`` was renamed to ``eval_strategy`` (transformers 4.41),
        # and the old keyword was later removed. The install cell pulls transformers
        # from its main branch, where only ``eval_strategy`` is accepted.
        eval_strategy="steps",
        do_eval=True,
        eval_steps=100,
        # Use bf16 mixed precision when the GPU supports it, otherwise fp16.
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        num_train_epochs=3,
        weight_decay=0.0,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        gradient_checkpointing=True,
        report_to="wandb",
        # Minute-resolution timestamp makes the W&B run name distinct per launch.
        run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Disable the generation KV cache during training; it is not needed there and
# conflicts with gradient checkpointing.
model.config.use_cache = False
trainer.train()

# For a PEFT-wrapped model, save_pretrained presumably writes only the LoRA adapter
# weights and config, not the full base model.
new_model = "./fine-tuned-model"

trainer.model.save_pretrained(new_model)