# PEFT fine-tuning of Mistral 7B

This notebook shows how to fine-tune the [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) model for a specific domain using parameter-efficient fine-tuning (PEFT). We'll use a [legal question answering](https://huggingface.co/datasets/umarbutler/open-australian-legal-qa) dataset, which contains legal questions with answers backed by relevant case law.

## Prerequisites

This notebook needs a kernel with PyTorch 2.1 or later (the Fully Sharded Data Parallel setup requires it) and a CUDA-capable GPU.

To keep the runtime short, training stops after a limited number of steps, which is just under one epoch. For better results, train for more steps or epochs.
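
The version requirement can be checked before installing anything. The sketch below is one possible check, not part of the original notebook: `meets_minimum` and `check_environment` are hypothetical helper names, and the `(2, 1)` threshold is taken from the FSDP comment in the install cell.

```python
import re


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Return True if a version string such as '2.2.1+cu121' is >= minimum."""
    # Keep only the leading numeric release segment; local tags like '+cu121' are ignored.
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    parsed = tuple(int(part) if part is not None else 0 for part in match.groups())
    # Pad the minimum to three components so the tuples compare element by element.
    padded_minimum = tuple(minimum) + (0,) * (3 - len(minimum))
    return parsed >= padded_minimum


def check_environment() -> None:
    """Print the installed PyTorch version and whether a CUDA GPU is visible."""
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
        return
    status = "OK" if meets_minimum(torch.__version__, (2, 1)) else "too old for FSDP"
    print("PyTorch", torch.__version__, status)
    print("CUDA GPU available:", torch.cuda.is_available())
```

Calling `check_environment()` in the first cell gives an early warning before the model download starts.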

## Install dependencies

In [2]:
%pip install --upgrade pip --quiet

Note: you may need to restart the kernel to use updated packages.


In [3]:
#
# Fully Sharded Data Parallel (FSDP) requires PyTorch >= 2.1.0
#
%pip install --upgrade torch

Collecting torch
  Using cached torch-2.2.1-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting typing-extensions>=4.8.0 (from torch)
  Using cached typing_extensions-4.10.0-py3-none-any.whl.metadata (3.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==

In [4]:
%pip install -U bitsandbytes
%pip install -U git+https://github.com/huggingface/transformers.git
%pip install -U git+https://github.com/huggingface/peft.git
%pip install -U git+https://github.com/huggingface/accelerate.git
%pip install -U datasets scipy ipywidgets
%pip install sentencepiece

Collecting bitsandbytes
  Downloading bitsandbytes-0.43.0-py3-none-manylinux_2_24_x86_64.whl.metadata (1.8 kB)
Downloading bitsandbytes-0.43.0-py3-none-manylinux_2_24_x86_64.whl (102.2 MB)
Installing collected packages: bitsandbytes
Successfully installed bitsandbytes-0.43.0
Note: you may need to restart the kernel to use updated packages.
Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-labigq45
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-labigq45
  Resolved https://github.com/huggingface/transformers.git to commit 66ce9593fdb8e340df546ddd0774eb444f17a12c
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  In

## Imports

In [2]:
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

# Offload full model and optimizer state dicts to CPU when they are gathered (for example, when saving)
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

Detected kernel version 4.14.336, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


## Dataset

We use the [Open Australian Legal QA](https://huggingface.co/datasets/umarbutler/open-australian-legal-qa) dataset from the Hugging Face Hub.

In [8]:
from datasets import load_dataset

full_dataset = load_dataset("umarbutler/open-australian-legal-qa", "default")

In [9]:
full_dataset['train']

Dataset({
    features: ['question', 'answer', 'text', 'prompt', 'source'],
    num_rows: 2124
})

In [10]:
from datasets import DatasetDict

# 80% train, 20% test + validation
train_testvalid = full_dataset['train'].train_test_split(test_size=0.2)
# Split the 20% holdout in half: 10% test, 10% validation
test_valid = train_testvalid['test'].train_test_split(test_size=0.5)
# Gather the three splits into a single DatasetDict
dataset = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']})
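
The row counts that the two-stage split produces can be worked out by hand. The sketch below assumes the test share is rounded up to a whole row, which reproduces the 1699 training and 212 validation rows printed later in this notebook. `split_sizes` is a hypothetical helper, not a `datasets` function.

```python
import math


def split_sizes(n_rows: int, test_size: float) -> tuple:
    """Return (n_train, n_test), rounding the test share up to a whole row."""
    n_test = math.ceil(n_rows * test_size)
    return n_rows - n_test, n_test


# First split: 2124 rows, 20% held out
n_train, n_holdout = split_sizes(2124, 0.2)

# Second split: the holdout is halved; its 'train' part becomes 'valid'
n_valid, n_test = split_sizes(n_holdout, 0.5)

print(n_train, n_valid, n_test)
```

Neither call to `train_test_split` in the cell above passes a `seed`, so the rows that land in each split change between runs. Passing `seed=` makes the split reproducible.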

## Load base model

In [11]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "mistralai/Mistral-7B-v0.1"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config)

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

`low_cpu_mem_usage` was None, now set to True since model is quantized.


model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [12]:
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    model_max_length=512,
    padding_side="left"
)

tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

In [16]:
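def tokenize(prompt):
    # Pad or truncate every example to exactly 512 tokens
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        padding="max_length",
    )

    # Causal LM training: the labels are the input tokens themselves
    result['labels'] = result['input_ids'].copy()

    return result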
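
Copying `input_ids` into `labels` includes the padding positions. The `DataCollatorForLanguageModeling(..., mlm=False)` passed to the trainer later in this notebook rebuilds the labels and replaces pad positions with `-100`, so they do not contribute to the loss. The sketch below imitates that masking on plain Python lists. The token IDs are made up for illustration, and `mask_padding` is a hypothetical helper, not a library function.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def mask_padding(input_ids, pad_token_id):
    """Copy input_ids into labels, replacing every pad token with IGNORE_INDEX."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


# Illustrative left-padded sequence: pad/EOS id 2, BOS id 1, three made-up content IDs, EOS
example_ids = [2, 2, 2, 1, 5000, 6000, 7000, 2]
labels = mask_padding(example_ids, pad_token_id=2)
print(labels)
```

Because `pad_token` is set to `eos_token`, the genuine end-of-sequence token is masked together with the padding, so the model gets no training signal for when to stop. A dedicated pad token avoids this, at the cost of resizing the embedding matrix.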

## Prepare dataset


In [17]:
def generate_and_tokenize_prompt(data_point):
    # Training text: BOS + dataset prompt + answer marker + reference answer + EOS
    full_prompt = f"""{tokenizer.bos_token}{data_point['prompt']}

### Answer: {data_point['answer']}{tokenizer.eos_token}"""

    return tokenize(full_prompt)
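
It helps to inspect the training text before tokenizing a whole dataset. The sketch below rebuilds the same template without loading a tokenizer. `build_prompt` and the sample record are made up for illustration, and `<s>` and `</s>` are Mistral's BOS and EOS strings.

```python
def build_prompt(prompt: str, answer: str, bos: str = "<s>", eos: str = "</s>") -> str:
    """Assemble one training example: BOS, prompt, blank line, answer marker, answer, EOS."""
    return f"{bos}{prompt}\n\n### Answer: {answer}{eos}"


# Made-up record with the same two fields the dataset provides
sample = {
    "prompt": "What did the Tribunal decide?",
    "answer": "It imposed a reprimand.",
}

text = build_prompt(sample["prompt"], sample["answer"])
print(text)
```

Checking that the string starts with the BOS marker and ends with the EOS marker is a quick way to catch template mistakes. The Mistral tokenizer may also prepend its own BOS token, so it is worth looking at the first few `input_ids` of one tokenized example as well.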

In [19]:
tokenized_training_dataset = dataset['train'].map(generate_and_tokenize_prompt)
tokenized_validation_dataset = dataset['valid'].map(generate_and_tokenize_prompt)

Map:   0%|          | 0/1699 [00:00<?, ? examples/s]

Map:   0%|          | 0/212 [00:00<?, ? examples/s]

In [20]:
print(tokenized_training_dataset)
print(tokenized_validation_dataset)

Dataset({
    features: ['question', 'answer', 'text', 'prompt', 'source', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1699
})
Dataset({
    features: ['question', 'answer', 'text', 'prompt', 'source', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 212
})


## Test the base model on one sample

In [21]:
eval_prompt = dataset['test'][0]['prompt']
eval_prompt

"# Snippet\nThe snippet from an Australian legal document from which you must synthesise a question and answer is provided below.\n<document_metadata>\n<document_title>Law Society of New South Wales v McCartney [2017] NSWCATOD 130</document_title>\n<document_jurisdiction>New South Wales</document_jurisdiction>\n<document_type>Decision</document_type>\n</document_metadata>\n<snippet>\n32. The applicant did not rely on the alternative limb of s 497, namely that the conduct constituted a “substantial” failure to reach or maintain a reasonable standard of competence and diligence. On the basis of what was said in the Xu case, we are of the view that the respondent’s conduct did not constitute professional misconduct under s 497. The respondent’s conduct was, like Mr Xu, incredibly sloppy and fell well short of the standard of competence and diligence that a member of the public is entitled to expect of a reasonably competent Australian legal practitioner, but it did not, in our view, const

In [22]:
eval_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    add_bos_token = True
)

model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to("cuda")

model.eval()
with torch.no_grad():
    print(eval_tokenizer.decode(model.generate(**model_input, max_new_tokens=256)[0]))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s> # Snippet
The snippet from an Australian legal document from which you must synthesise a question and answer is provided below.
<document_metadata>
<document_title>Law Society of New South Wales v McCartney [2017] NSWCATOD 130</document_title>
<document_jurisdiction>New South Wales</document_jurisdiction>
<document_type>Decision</document_type>
</document_metadata>
<snippet>
32. The applicant did not rely on the alternative limb of s 497, namely that the conduct constituted a “substantial” failure to reach or maintain a reasonable standard of competence and diligence. On the basis of what was said in the Xu case, we are of the view that the respondent’s conduct did not constitute professional misconduct under s 497. The respondent’s conduct was, like Mr Xu, incredibly sloppy and fell well short of the standard of competence and diligence that a member of the public is entitled to expect of a reasonably competent Australian legal practitioner, but it did not, in our view, constitute
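
The printed text starts with the prompt because `generate` returns the prompt tokens followed by the continuation. To see only what the model wrote, the prompt can be sliced off before decoding. The sketch below shows the idea on plain lists with made-up token IDs. `new_tokens_only` is a hypothetical helper.

```python
def new_tokens_only(generated_ids, prompt_length):
    """Drop the echoed prompt from a generated sequence of token IDs."""
    return list(generated_ids[prompt_length:])


prompt_ids = [1, 101, 102, 103]          # made-up IDs standing in for the tokenized prompt
generated = prompt_ids + [201, 202, 2]   # generate() output: prompt followed by continuation

print(new_tokens_only(generated, len(prompt_ids)))
```

With the objects in this notebook, the prompt length is `model_input["input_ids"].shape[1]`, and the sliced IDs can be passed to `eval_tokenizer.decode(..., skip_special_tokens=True)`.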

## Fine-tune the model

In [23]:
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [24]:
print(model)

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )

In [25]:
def print_trainable_parameters(model):
    
    all_params = 0
    trainable_params = 0
    
    for _, param in model.named_parameters():
        all_params += param.numel()
        
        if param.requires_grad:
            trainable_params += param.numel()
    
    print(
        f"trainable params: {trainable_params}; all params: {all_params}; % trainable: {100 * (trainable_params / all_params)}"
    )
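
Applied to a 4-bit model, this helper reports far fewer than 7 billion parameters in total. A likely explanation, stated here as an assumption, is that bitsandbytes packs two 4-bit values into each stored byte, so `numel()` of a quantized weight is half the dense count, while the embedding, `lm_head`, and norm weights stay unquantized. The sketch below redoes the count from the layer shapes printed above.

```python
HIDDEN, KV, MLP, VOCAB, LAYERS = 4096, 1024, 14336, 32000, 32

# (in_features, out_features) of the seven Linear4bit modules in each decoder layer
layer_linears = [
    (HIDDEN, HIDDEN), (HIDDEN, KV), (HIDDEN, KV), (HIDDEN, HIDDEN),  # q, k, v, o
    (HIDDEN, MLP), (HIDDEN, MLP), (MLP, HIDDEN),                     # gate, up, down
]
quantized = LAYERS * sum(i * o for i, o in layer_linears)

# Assumed to stay unquantized: token embedding, lm_head, two RMSNorms per layer, final norm
unquantized = 2 * VOCAB * HIDDEN + (2 * LAYERS + 1) * HIDDEN

dense_total = quantized + unquantized
packed_total = quantized // 2 + unquantized  # two 4-bit values per stored byte

print(f"dense parameters:      {dense_total:,}")
print(f"numel() after packing: {packed_total:,}")
```

Under these assumptions the packed total is 3,752,071,168. Adding the 21,260,288 LoRA parameters gives the 3,773,331,456 reported for the LoRA model in this notebook, so the printed percentage of trainable parameters is measured against the packed count rather than the dense one.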

In [26]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        
        "gate_proj",
        "up_proj",
        "down_proj",
        
        "lm_head"
    ],
    bias="none",
    lora_dropout=0.05,
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

model = accelerator.prepare_model(model)

trainable params: 21260288; all params: 3773331456; % trainable: 0.5634354746703705
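
The trainable count can be reproduced by hand. For a linear layer with `in_features` inputs and `out_features` outputs, LoRA adds two matrices of shapes `r x in_features` and `out_features x r`. The sketch below sums those over the targeted modules, using the shapes from the model printout and assuming `lm_head` maps 4096 features to the 32000-token vocabulary.

```python
R = 8  # LoRA rank from the config above
HIDDEN, KV, MLP, VOCAB, LAYERS = 4096, 1024, 14336, 32000, 32

# (in_features, out_features) for each targeted module inside one decoder layer
per_layer = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV),
    "v_proj": (HIDDEN, KV),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}


def lora_params(in_features: int, out_features: int, r: int = R) -> int:
    """Parameters in the A (r x in) and B (out x r) matrices of one adapter."""
    return r * (in_features + out_features)


decoder_total = LAYERS * sum(lora_params(i, o) for i, o in per_layer.values())
lm_head_total = lora_params(HIDDEN, VOCAB)

print(decoder_total + lm_head_total)
```

The result matches the `trainable params` figure printed above. Doubling `r` doubles the adapter size, since the count is linear in the rank.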


In [27]:
if torch.cuda.device_count() > 1:
    model.is_parallelizable = True
    model.model_parallel = True

In [28]:
import transformers
from datetime import datetime

base_model_name = "mistral"
project = "finetune-legal"
run_name = base_model_name + "-" + project
output_dir = "./" + run_name

tokenizer.pad_token = tokenizer.eos_token

trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_training_dataset,
    eval_dataset=tokenized_validation_dataset,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=5,
        per_device_train_batch_size=2,

        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant':False},
        gradient_accumulation_steps=4,
        
        max_steps=200, # reduced from 1000
        learning_rate=2.5e-5,
        logging_steps=50,
        bf16=True,
        optim="paged_adamw_8bit",
        logging_dir="./",
        save_strategy="steps",
        save_steps=50,
        evaluation_strategy="steps",
        eval_steps=50,
        do_eval=True,
        report_to="none",
        run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer,mlm=False)
)

model.config.use_cache = False

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 4.14.336, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50   | 1.421         | 1.179742        |
| 100  | 1.1186        | 1.139615        |
| 150  | 1.1289        | 1.128315        |
| 200  | 1.1031        | 1.12397         |




TrainOutput(global_step=200, training_loss=1.1928878974914552, metrics={'train_runtime': 1676.2015, 'train_samples_per_second': 0.955, 'train_steps_per_second': 0.119, 'total_flos': 3.50548150714368e+16, 'train_loss': 1.1928878974914552, 'epoch': 0.94})
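
The `epoch` value of 0.94 follows from the batch settings. The sketch below redoes the arithmetic, assuming a single GPU. That assumption is consistent with the reported throughput.

```python
per_device_batch = 2      # per_device_train_batch_size
grad_accum = 4            # gradient_accumulation_steps
max_steps = 200
num_devices = 1           # assumption: one GPU
train_rows = 1699         # rows in the tokenized training set
train_runtime = 1676.2015 # seconds, from TrainOutput

samples_per_step = per_device_batch * grad_accum * num_devices
samples_seen = samples_per_step * max_steps
epochs = samples_seen / train_rows
throughput = samples_seen / train_runtime

print(f"samples per optimizer step: {samples_per_step}")
print(f"samples seen: {samples_seen}")
print(f"epochs: {epochs:.2f}")
print(f"samples per second: {throughput:.3f}")
```

Raising `max_steps` to 213 or more would cover the full training set at least once with these settings.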

## Try the fine-tuned model

In [29]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

eval_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    add_bos_token = True
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [30]:
from peft import PeftModel

ft_model = PeftModel.from_pretrained(base_model, "mistral-finetune-legal/checkpoint-200")

In [None]:
eval_prompt = dataset['test'][0]['prompt']

In [31]:
eval_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1"
)

model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to("cuda")

ft_model.eval()
with torch.no_grad():
    print(eval_tokenizer.decode(ft_model.generate(**model_input, max_new_tokens=256)[0]))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s> # Snippet
The snippet from an Australian legal document from which you must synthesise a question and answer is provided below.
<document_metadata>
<document_title>Law Society of New South Wales v McCartney [2017] NSWCATOD 130</document_title>
<document_jurisdiction>New South Wales</document_jurisdiction>
<document_type>Decision</document_type>
</document_metadata>
<snippet>
32. The applicant did not rely on the alternative limb of s 497, namely that the conduct constituted a “substantial” failure to reach or maintain a reasonable standard of competence and diligence. On the basis of what was said in the Xu case, we are of the view that the respondent’s conduct did not constitute professional misconduct under s 497. The respondent’s conduct was, like Mr Xu, incredibly sloppy and fell well short of the standard of competence and diligence that a member of the public is entitled to expect of a reasonably competent Australian legal practitioner, but it did not, in our view, constitute

In [32]:
dataset['test'][0]['answer']

"In the case of Law Society of New South Wales v McCartney [2017] NSWCATOD 130, the Tribunal decided that the respondent's conduct did not constitute professional misconduct under s 497. The Tribunal found the respondent's conduct to be incredibly sloppy and falling short of the standard of competence and diligence expected of a reasonably competent Australian legal practitioner, but it did not constitute professional misconduct. The Tribunal also did not accept the contention that the respondent's conduct constituted professional misconduct in the Allinson sense. However, the Tribunal agreed with the applicant's contentions that the respondent should be reprimanded and that a fine should be imposed."

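## DoRA

DoRA (Weight-Decomposed Low-Rank Adaptation) is a LoRA variant that PEFT enables with `use_dora=True` in `LoraConfig`. This section repeats the training run with the same hyperparameters as the LoRA run above.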

In [33]:
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [34]:
base_model.gradient_checkpointing_enable()
base_model = prepare_model_for_kbit_training(base_model)

In [35]:
dora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        
        "gate_proj",
        "up_proj",
        "down_proj",
        
        "lm_head"
    ],
    bias="none",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    use_dora = True
)

In [36]:
dora_model = get_peft_model(base_model, dora_config)
print_trainable_parameters(dora_model)

dora_model = accelerator.prepare_model(dora_model)

trainable params: 22668544; all params: 3774739712; % trainable: 0.6005326387919169
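
The DoRA count is slightly higher than the LoRA count. A minimal sketch of where the difference may come from, assuming DoRA adds one learnable magnitude value per output feature of every adapted module on top of the usual LoRA matrices:

```python
R = 8  # rank from dora_config
HIDDEN, KV, MLP, VOCAB, LAYERS = 4096, 1024, 14336, 32000, 32

# (in_features, out_features) of the adapted modules in one decoder layer
layer_linears = [
    (HIDDEN, HIDDEN), (HIDDEN, KV), (HIDDEN, KV), (HIDDEN, HIDDEN),  # q, k, v, o
    (HIDDEN, MLP), (HIDDEN, MLP), (MLP, HIDDEN),                     # gate, up, down
]
lm_head = (HIDDEN, VOCAB)

# Low-rank matrices: r * (in + out) per module, as in plain LoRA
lora_part = LAYERS * sum(R * (i + o) for i, o in layer_linears) + R * sum(lm_head)

# Assumed DoRA extra: one magnitude value per output feature of each module
magnitude_part = LAYERS * sum(o for _, o in layer_linears) + lm_head[1]

print(f"LoRA matrices:     {lora_part:,}")
print(f"magnitude vectors: {magnitude_part:,}")
print(f"total trainable:   {lora_part + magnitude_part:,}")
```

Under this assumption the total equals the 22,668,544 printed above, and the 1,408,256 extra values also account for the difference between the two `all params` figures in this notebook.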


In [37]:
if torch.cuda.device_count() > 1:
    dora_model.is_parallelizable = True
    dora_model.model_parallel = True

In [38]:
project = "finetune-legal-dora"
run_name = base_model_name + "-" + project
output_dir = "./" + run_name

tokenizer.pad_token = tokenizer.eos_token

trainer = transformers.Trainer(
    model=dora_model,
    train_dataset=tokenized_training_dataset,
    eval_dataset=tokenized_validation_dataset,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=5,
        per_device_train_batch_size=2,

        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant':False},
        gradient_accumulation_steps=4,
        
        max_steps=200, # reduced from 1000
        learning_rate=2.5e-5,
        logging_steps=50,
        bf16=True,
        optim="paged_adamw_8bit",
        logging_dir="./",
        save_strategy="steps",
        save_steps=50,
        evaluation_strategy="steps",
        eval_steps=50,
        do_eval=True,
        report_to="none",
        run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer,mlm=False)
)

dora_model.config.use_cache = False

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 4.14.336, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50   | 1.4104        | 1.17766         |
| 100  | 1.1181        | 1.139276        |
| 150  | 1.1287        | 1.128569        |
| 200  | 1.1023        | 1.124236        |




TrainOutput(global_step=200, training_loss=1.1898723220825196, metrics={'train_runtime': 3712.375, 'train_samples_per_second': 0.431, 'train_steps_per_second': 0.054, 'total_flos': 3.5061736931328e+16, 'train_loss': 1.1898723220825196, 'epoch': 0.94})
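
The two runs can be compared directly from the numbers reported in this notebook. The sketch below only restates those figures. The dictionary layout is an illustration, not something the trainer produces.

```python
# Final-step figures copied from the two training outputs in this notebook
lora = {"runtime_s": 1676.2015, "val_loss": 1.12397, "trainable": 21260288}
dora = {"runtime_s": 3712.375, "val_loss": 1.124236, "trainable": 22668544}

slowdown = dora["runtime_s"] / lora["runtime_s"]
loss_gap = dora["val_loss"] - lora["val_loss"]
extra_params = dora["trainable"] - lora["trainable"]

print(f"DoRA runtime / LoRA runtime: {slowdown:.2f}x")
print(f"validation loss difference (DoRA - LoRA): {loss_gap:+.6f}")
print(f"extra trainable parameters with DoRA: {extra_params:,}")
```

In this single 200-step run, DoRA took a little over twice as long and ended at nearly the same validation loss. A difference this small is within what a different random split could produce, so it should not be read as a ranking of the two methods.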