<a href="https://colab.research.google.com/gist/zredlined/75a14557b5a288551131d432a2c2f249/gemma-lora-finetune-synthetic-data-ner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to Fine-Tune Gemma to Label PII in Financial Data

Recently, Google released [Gemma](https://huggingface.co/blog/gemma), a new family of state-of-the-art open LLMs. Gemma comes in two sizes: 7B parameters, for efficient deployment and development on consumer-size GPUs and TPUs, and 2B parameters, for CPU and on-device applications. Both come in base and instruction-tuned variants.

This blog post is derived from Phil Schmid's excellent [How to Fine-Tune LLMs in 2024 with Hugging Face](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl) blog post. We will use Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index) & [datasets](https://huggingface.co/docs/datasets/index), along with synthetic training data generated by [Gretel Navigator](https://gretel.ai/navigator).

1. Setup development environment
2. Create and prepare the synthetic dataset that we generated earlier
3. Fine-tune LLM using `trl` and the `SFTTrainer`
4. Test and evaluate the LLM - Visualize results, identify gaps, iterate on synthetic training data if necessary

Note: This blog was created to run on an NVIDIA A100 40GB GPU, but it can also run on other GPUs based on the Ampere architecture or later with 24GB or more of memory, such as the A10G or the consumer-grade RTX 3090/4090.

## 1. Setup development environment

Our first step is to install the Hugging Face libraries and PyTorch, including `trl`, `transformers`, and `datasets`. If you haven't heard of `trl` yet, don't worry: it is a new library built on top of `transformers` and `datasets` that makes it easier to fine-tune, apply RLHF to, and align open LLMs.

If you are using a GPU with the Ampere architecture (e.g. NVIDIA A10G or RTX 4090/3090) or newer, you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. TL;DR: it accelerates training by up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main).
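A quick back-of-the-envelope sketch shows why the quadratic term matters. Standard attention materializes a `seq_len x seq_len` score matrix per head; Flash Attention computes the same result tile by tile without storing that full matrix. The head count and 2-byte (bf16) values below are assumptions for illustration, not Gemma's exact configuration.

```python
def attention_scores_bytes(seq_len: int, num_heads: int, bytes_per_value: int = 2) -> int:
    """Memory for one layer's full (seq_len x seq_len) attention-score matrices,
    one per head, for a single sequence, as materialized by standard attention."""
    return seq_len * seq_len * num_heads * bytes_per_value


# Assumption: 8 attention heads and 2-byte (bf16) scores; check your model config.
NUM_HEADS = 8

for seq_len in (512, 1024, 2048, 4096):
    mib = attention_scores_bytes(seq_len, NUM_HEADS) / 2**20
    print(f"seq_len={seq_len:>5}: {mib:7.1f} MiB per layer per sequence")
```

Doubling the sequence length quadruples this buffer, which is the cost Flash Attention avoids paying in GPU memory.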

In [None]:
# Install Pytorch & other libraries
!pip install -Uqqr requirements.txt

In [None]:
import torch

# Flash Attention 2 requires an Ampere (compute capability 8.x) or newer GPU
assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'

We will also need to log in to our [Hugging Face](https://huggingface.co) account to access Gemma. Before you can use Gemma, you need to agree to its terms of use: visit the [Gemma page](https://huggingface.co/google/gemma-7b) and follow the gating instructions to request access.

In [None]:
from getpass import getpass
from huggingface_hub import login

# getpass hides the token as you type, unlike input()
hf_token = getpass("Please enter your Hugging Face token: ")

login(token=hf_token, add_to_git_credential=False)

## 2. Create and prepare the dataset

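In this section, we will convert our synthetically generated training set into a Hugging Face `Dataset`, which supports batching and streaming during training, with each sample in the conversational `messages` format shown below. The ChatML tokenizer we load later renders these messages into the ChatML prompt format.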

```json
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
```

The latest release of `trl` supports conversational dataset formats, so we don't need to do any additional formatting of the dataset; we can use it as is.
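The tokenizer's chat template (here from `philschmid/gemma-tokenizer-chatml`) is what turns each `messages` list into a single training string. The sketch below hand-rolls the generic ChatML layout to show roughly what that looks like. It is an assumption-based illustration: the real template may differ in details such as a leading BOS token, so in practice rely on `tokenizer.apply_chat_template` rather than this function.

```python
def to_chatml(messages, add_generation_prompt=False):
    """Render OpenAI-style {"role", "content"} messages as a ChatML string."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so the model continues from here at inference time
        text += "<|im_start|>assistant\n"
    return text


example = [
    {"role": "system", "content": "You are..."},
    {"role": "user", "content": "Wire $500 to account 12345."},
]
print(to_chatml(example, add_generation_prompt=True))
```

Each turn is wrapped in `<|im_start|>` and `<|im_end|>`, which is why the `<|im_end|>` token id is used later to stop generation.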

In [None]:
import json
from datasets import Dataset
from smart_open import open

# Path to the synthetic training data
dataset_path = "https://storage.googleapis.com/gretel-public-data/gretel-datasets/synthetic-pii-training-data/generated_results_with_spans.jsonl"

# System prompt shared by every conversation
system_message = """You are an expert named entity recognition system for financial documents. Users will send you text and you will return only a labeled version of the exact same text, without any additional content. Enclose the recognized entities in curly braces and square brackets, like this: {[ENTITY_TYPE]text}. Do not provide any explanations, summaries, or other additional text."""

def create_conversation(sample):
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": sample["document_text"]},
            {"role": "assistant", "content": sample["text_markup"]}
        ]
    }

def load_jsonl(file_path):
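    with open(file_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline at the end of the file)
        data = [json.loads(line) for line in f if line.strip()]
        for item in data:
            if 'pii_spans' in item:
                # Store nested span data as a JSON string so the column has a simple, uniform type
                item['pii_spans'] = json.dumps(item['pii_spans'])
    return data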

def pretty_print_conversation(messages):
    print("Conversation:")
    for msg in messages:
        print(f"\nRole: {msg['role']}")
        print(f"Content: {msg['content']}\n")
        print("-" * 50)

# Load dataset from JSONL file
data = load_jsonl(dataset_path)

# Convert data into a dictionary with column names as keys
column_names = list(data[0].keys())
columns = {column: [sample[column] for sample in data] for column in column_names}

# Convert to Dataset object
dataset = Dataset.from_dict(columns)

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, batched=False)

# Split dataset into training (first 80%) and testing (last 20%) sets, without shuffling
train_size = int(len(dataset) * 0.8)
train_dataset = dataset.select(range(train_size))
test_dataset = dataset.select(range(train_size, len(dataset)))

pretty_print_conversation(train_dataset[42]["messages"])

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Save datasets to disk
train_dataset.to_json("train_dataset.json", orient="records")
test_dataset.to_json("test_dataset.json", orient="records")
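The assistant targets use the `{[ENTITY_TYPE]text}` markup defined in the system prompt. A small regex-based parser, sketched below under the assumption that labeled text never contains a `}` character, makes it easy to inspect samples or to check that stripping the labels from `text_markup` gives back `document_text`. The entity type names in the example are made up.

```python
import re

# Matches {[ENTITY_TYPE]text}, the label format described in the system prompt
ENTITY_PATTERN = re.compile(r"\{\[([^\]]+)\](.*?)\}")


def extract_entities(markup: str) -> list[tuple[str, str]]:
    """Return (entity_type, text) pairs in order of appearance."""
    return [(m.group(1), m.group(2)) for m in ENTITY_PATTERN.finditer(markup)]


def strip_markup(markup: str) -> str:
    """Remove the labels, leaving only the original document text."""
    return ENTITY_PATTERN.sub(lambda m: m.group(2), markup)


example = "Contact {[name]Jane Doe} at {[email]jane@example.com}."
print(extract_entities(example))
print(strip_markup(example))
```

For a sample from the dataset, a quick sanity check would be `strip_markup(sample["text_markup"]) == sample["document_text"]`.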



## 3. Fine-tune LLM using `trl` and the `SFTTrainer`

We will use the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl` to fine-tune our model. The `SFTTrainer` makes supervised fine-tuning of open LLMs straightforward. It is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:
* Dataset formatting, including conversational and instruction format
* Training on completions only, ignoring prompts
* Packing datasets for more efficient training
* PEFT (parameter-efficient fine-tuning) support including Q-LoRA
* Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)
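Packing concatenates many short examples into fixed-length blocks so that little compute is spent on padding. The simplified sketch below works on lists of token ids and drops the final partial block; it illustrates the idea only and is not `trl`'s implementation, which buffers examples and handles remainders in its own way. Passing `eos_token_id=None` mirrors the `append_concat_token=False` setting used later in this notebook (no separator between examples).

```python
def pack_sequences(tokenized_examples, max_seq_length, eos_token_id=None):
    """Concatenate token-id lists and cut the stream into max_seq_length blocks.

    If eos_token_id is given, it is appended after each example as a separator.
    The trailing remainder shorter than max_seq_length is dropped.
    """
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids)
        if eos_token_id is not None:
            stream.append(eos_token_id)
    n_full_blocks = len(stream) // max_seq_length
    return [
        stream[i * max_seq_length:(i + 1) * max_seq_length]
        for i in range(n_full_blocks)
    ]


examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(pack_sequences(examples, max_seq_length=4))  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```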

We will use the dataset formatting, packing, and PEFT features in our example. As our PEFT method we will use [QLoRA](https://arxiv.org/abs/2305.14314), a technique that uses quantization to reduce the memory footprint of large language models during fine-tuning without sacrificing performance.
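A rough sketch of why 4-bit quantization matters: weight memory scales linearly with bits per parameter. The 2.5B parameter count below is an approximation for Gemma 2B, and the estimate ignores quantization constants, activations, optimizer state, the LoRA adapters, and any layers kept in higher precision.

```python
def weight_memory_gib(num_params: float, bits_per_param: float) -> float:
    """Approximate memory needed to hold the model weights alone."""
    return num_params * bits_per_param / 8 / 2**30


NUM_PARAMS = 2.5e9  # assumption: approximate parameter count of Gemma 2B

for label, bits in [("bf16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>9}: {weight_memory_gib(NUM_PARAMS, bits):.2f} GiB")
```

Going from 16 to 4 bits cuts the frozen base weights to a quarter of their size, leaving more headroom for activations and the small set of trainable LoRA parameters.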

In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig

# Model and tokenizer IDs
model_id = "google/gemma-2b-it"
tokenizer_id = "philschmid/gemma-tokenizer-chatml"
output_dir = "gemma-2b-it-label-chatml"

# BitsAndBytesConfig for int-4 quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_use_double_quant=True, 
    bnb_4bit_quant_type="nf4", 
    bnb_4bit_compute_dtype=torch.bfloat16
)

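# Configure the environment to use the first GPU.
# Note: CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialized
# (calling torch.cuda functions, as the capability check above does, initializes it),
# so on a multi-GPU machine set it at the very top of the notebook or in the shell.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print("CUDA_VISIBLE_DEVICES set to:", os.environ["CUDA_VISIBLE_DEVICES"])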

# Load model with specified device_map and quantization configurations
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map={'': torch.cuda.current_device()},
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)

# Load and configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.padding_side = 'right'  # Adjust to prevent warnings

# LoRA configuration for model parameter-efficient tuning
peft_config = LoraConfig(
    lora_alpha=8,
    lora_dropout=0.05,
    r=6, 
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

# Training arguments for model training
args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=8,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=10,
    save_strategy="epoch",
    bf16=True,
    tf32=True,
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=False,
    report_to="tensorboard",
)

Start training by calling the `train()` method on our `SFTTrainer` instance. This runs the training loop for 8 epochs, as set by `num_train_epochs`. Since we are using a PEFT method, only the adapter weights are saved, not the full model.

In [None]:
%%time

from trl import SFTTrainer

max_seq_length = 1512 # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,  # train on the 80% split only so test_dataset stays unseen
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": False, # No need to add additional separator token
    }
)

# start training; checkpoints are saved to output_dir every epoch (push_to_hub=False, so nothing is uploaded)
trainer.train()

# save model
trainer.save_model()

Training with Flash Attention for 8 epochs on a synthetic dataset of 3k samples (780 steps) took 1:05:12 (h:mm:ss) on an `a2-highgpu-1g` instance with a single 40GB NVIDIA A100 GPU. The instance costs $1.61/h, which brings the total cost to only about $2.
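The cost figure is simple arithmetic. The sketch below reproduces it, reading the reported run time as 1 hour, 5 minutes, and 12 seconds, and also shows how many packed sequences each optimizer step consumes under the `TrainingArguments` above (a single GPU is assumed).

```python
from datetime import timedelta


def training_cost_usd(duration: timedelta, usd_per_hour: float) -> float:
    """Cost of renting one instance for the given wall-clock duration."""
    return duration.total_seconds() / 3600 * usd_per_hour


run_time = timedelta(hours=1, minutes=5, seconds=12)  # reported wall-clock time
print(f"Training cost: ${training_cost_usd(run_time, 1.61):.2f}")  # Training cost: $1.75

# Packed sequences consumed per optimizer step with the TrainingArguments used above
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
num_gpus = 1  # assumption: single-GPU run
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(f"Effective batch size: {effective_batch_size} packed sequences per step")
```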



### Optional: Merge LoRA adapter into the original model

When using QLoRA, we only train adapters, not the full model, so only the adapter weights are saved during training. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method.

Check out the [How to Fine-Tune LLMs in 2024 with Hugging Face](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#optional-merge-lora-adapter-in-to-the-original-model) blog post for how to do it.
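Conceptually, merging folds the low-rank update back into each adapted weight matrix, so inference no longer needs the separate adapter. The pure-Python sketch below shows the standard LoRA merge with PEFT's default `lora_alpha / r` scaling on toy matrices. It illustrates the math only and is not how `merge_and_unload` is implemented; with QLoRA, the base model is typically reloaded in 16-bit precision before merging.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(W, A, B, lora_alpha, r):
    """Fold a LoRA update into the base weight: W + (lora_alpha / r) * (B @ A).

    W: d_out x d_in base weight, B: d_out x r, A: r x d_in.
    """
    scale = lora_alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


W = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [2.0]]  # d_out x r, with r = 1
A = [[3.0, 4.0]]    # r x d_in
print(merge_lora(W, A, B, lora_alpha=2, r=1))  # [[7.0, 8.0], [12.0, 17.0]]
```

The merged matrix gives the same output as the base layer plus the scaled adapter path, `W x + (lora_alpha / r) * B (A x)`, but with a single matrix multiplication.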



## 4. Test Model and run Inference

After training is done, we want to test and evaluate our model. We will run the model on samples from the test split and evaluate it with a simple loop, using accuracy as our metric.

_Note: Evaluating generative AI models is not a trivial task, since one input can have multiple correct outputs. If you want to learn more about evaluating generative models, check out the [Evaluate LLMs and RAG a practical example using Langchain and Hugging Face](https://www.philschmid.de/evaluate-llm) blog post._



In [None]:
# Free GPU memory before reloading the saved model from disk
import gc

del model
del trainer
gc.collect()
torch.cuda.empty_cache()

We load the adapted model and the tokenizer into a `pipeline` to easily test it, and extract the token id of `<|im_end|>` so we can pass it to the `generate` method as the end-of-sequence token.

In [None]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline

peft_model_id = "gemma-2b-it-label-chatml"

# Load the base model with the trained PEFT adapter, plus the saved tokenizer
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", torch_dtype=torch.float16)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Token id of <|im_end|>, the ChatML end-of-turn marker, used to stop generation
eos_token = tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
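Passing this id as `eos_token_id` tells generation to stop as soon as the model emits `<|im_end|>`, the end of its assistant turn. The toy loop below sketches that stopping rule with a scripted stand-in for the model; `greedy_generate` and `next_token` are hypothetical names for illustration, and this simplified version drops the stop token from its output.

```python
def greedy_generate(next_token, prompt_ids, eos_token_id, max_new_tokens=256):
    """Append tokens from next_token(ids) until eos_token_id or the length limit."""
    generated = []
    for _ in range(max_new_tokens):
        token = next_token(list(prompt_ids) + generated)
        if token == eos_token_id:
            break  # end of the assistant turn: stop without keeping the stop token
        generated.append(token)
    return generated


# Scripted stand-in for a model: emits 5, 6, then the stop token 99, then 7
scripted = iter([5, 6, 99, 7])
print(greedy_generate(lambda ids: next(scripted), [1, 2], eos_token_id=99))  # [5, 6]
```

Without the stop id, generation would run on until `max_new_tokens`, and the model could append text after its labeled output.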

Let's test a few prompt samples and see how the model performs.

In [None]:
from finetune_utils import test_inference_batch

# Define prompts
prompts = [
    "The project site is located at 33.35283, -111.78903, where the new system will be rolled out first, showcasing our commitment to innovation in this region.",
    "This agreement is executed on this 15th day of October, 2023, at New York, NY. Passport numbers: 635510778, 365520581, and 603951954 are provided for verification purposes.",
    "We received a support ticket from a customer reporting slow network connectivity at 046 Malone Neck.",
]

# Repeat the prompts (raise the multiplier to test larger batches; 1 keeps a single copy)
prompts *= 1

# Process prompts in batch
generated_texts = test_inference_batch(prompts, pipe, eos_token_id=eos_token)

# Print results
for i, generated_text in enumerate(generated_texts):
    prompt = prompts[i]
    total_chars_prompt = len(prompt)
    total_chars_response = len(generated_text)
    print(f"Prompt:\n{prompt}")
    print(f"Response:\n{generated_text}")
    print(f"Total characters in prompt: {total_chars_prompt}")
    print(f"Total characters in response: {total_chars_response}")
    print("-" * 50)

## Evaluating the fine-tuned model

Now we are ready to evaluate our fine-tuned model against the `test_dataset`. The evaluation routine below measures detection accuracy overall (across all document and PII types), as well as per PII type and per document type. This granularity helps us understand where the model performs well and where we might need to tune the synthetic data or provide additional examples.

We also time the cell with `%%time` to keep track of how long inference takes. Note: this is for illustrative purposes only; we are not doing any optimization here to speed up inference.

In [None]:
%%time
from finetune_utils import evaluate_pii_labeling_accuracy

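MAX_EXAMPLES = 500

# Evaluate on at most MAX_EXAMPLES test samples (select() would fail if the range exceeded the dataset)
eval_subset = test_dataset.select(range(min(MAX_EXAMPLES, len(test_dataset))))
overall_accuracy, pii_type_counters, doc_type_pii_type_counters, detection_percentages, missed_detections = evaluate_pii_labeling_accuracy(eval_subset, pipe)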

print("=" * 40)
print(f"Overall Accuracy: {overall_accuracy:.2f}% (Found: {sum(c['found'] for c in pii_type_counters.values())}, Missed: {sum(c['missed'] for c in pii_type_counters.values())})")
print("=" * 40)
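`evaluate_pii_labeling_accuracy` lives in `finetune_utils`, which isn't shown here. As a hypothetical, simplified stand-in (not the helper's actual logic), found/missed counting can be sketched by parsing the `{[ENTITY_TYPE]text}` markup from both the reference and the model output and matching entities exactly by type and text. This measures recall only; a fuller evaluation would also count false positives. The entity type names are made up.

```python
import re
from collections import Counter, defaultdict

# Same {[ENTITY_TYPE]text} markup as in the system prompt
ENTITY_PATTERN = re.compile(r"\{\[([^\]]+)\](.*?)\}")


def count_found_missed(gold_markup, predicted_markup, counters=None):
    """Update per-type found/missed counts for one document (recall-style:
    extra predicted entities are not penalized)."""
    if counters is None:
        counters = defaultdict(lambda: {"found": 0, "missed": 0})
    remaining = Counter(ENTITY_PATTERN.findall(predicted_markup))
    for entity in ENTITY_PATTERN.findall(gold_markup):
        entity_type = entity[0]
        if remaining[entity] > 0:
            remaining[entity] -= 1  # each predicted entity can match only once
            counters[entity_type]["found"] += 1
        else:
            counters[entity_type]["missed"] += 1
    return counters


def overall_accuracy_pct(counters):
    found = sum(c["found"] for c in counters.values())
    total = found + sum(c["missed"] for c in counters.values())
    return 100.0 * found / total if total else 0.0


gold = "Call {[name]Ann Lee} at {[phone_number]555-0100}."
pred = "Call {[name]Ann Lee} at 555-0100."
counters = count_found_missed(gold, pred)
print(dict(counters), overall_accuracy_pct(counters))
```

Passing the same `counters` object across documents accumulates totals per PII type; keying a second set of counters by document type gives the per-document breakdown.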

## Exploring evaluation results

Let's use some visualizations to see where our model works well and where it underperforms. We can use these insights to tune the synthetic data generation prompt to produce more examples of a particular document type, PII type, or both, or to generate more synthetic data for classes that are under-represented in the training data.

In [None]:
from finetune_utils import plot_found_vs_missed

plot_found_vs_missed(pii_type_counters, peft_model_id)


## Heatmap: Visualizing PII Detection Across Document Types

We'll create a heatmap using Plotly to visualize our model's performance in detecting PII across different document types. While Large Language Models (LLMs) and today's Small Language Models (SLMs) are more adaptable than traditional NER models (even BERT!), they still require examples specific to the unique data businesses work with.

By analyzing the heatmap, we can identify areas where the model performs well and where it needs improvement. These insights will guide us in collecting additional training data or refining examples for specific document types, ultimately improving the model's accuracy and robustness across diverse data schemas.

In [None]:
from finetune_utils import plot_detection_percentages_heatmap

plot_detection_percentages_heatmap(detection_percentages, peft_model_id)
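A heatmap like this is a grid of detection rates, with document types as rows and PII types as columns. The sketch below shows one way to compute such a grid from nested found/missed counts. The nested layout is an assumption about `doc_type_pii_type_counters` (the real structure in `finetune_utils` may differ), and the counts are made up.

```python
def compute_detection_percentages(doc_type_counters):
    """Turn {doc_type: {pii_type: {"found": int, "missed": int}}} into
    {doc_type: {pii_type: percent_found}}; cells with no examples become None."""
    table = {}
    for doc_type, per_pii_type in doc_type_counters.items():
        row = {}
        for pii_type, c in per_pii_type.items():
            total = c["found"] + c["missed"]
            row[pii_type] = round(100.0 * c["found"] / total, 1) if total else None
        table[doc_type] = row
    return table


# Made-up counts for illustration
counts = {
    "loan_application": {"name": {"found": 9, "missed": 1}, "ssn": {"found": 3, "missed": 3}},
    "wire_transfer": {"account_number": {"found": 4, "missed": 0}},
}
print(compute_detection_percentages(counts))
```

The resulting nested dict can be turned into a matrix for any heatmap function, such as Plotly's `px.imshow`; low cells point to document/PII combinations that need more synthetic examples.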


In [None]:
# Directly investigate any missed detections
missed_detections

## Estimating Throughput

As part of an optimized pipeline on an A100 GPU running in batched mode with 500-token inputs and an accelerated backend such as vLLM or TensorRT-LLM, the Gemma 7B model can sustain an average throughput of about 550 tokens/sec, and Gemma 2B an estimated 2,000 tokens/sec. Given Gemma's expanded vocabulary, we estimate 4 characters per token on average.

550 tokens/sec * 4 characters/token * 3600 seconds/hour = 7,920,000 characters per hour that can be anonymized on an A100 GPU at $1.60 per hour. That means 1 MB of text (about 1 million characters) can be anonymized with Gemma 7B for about $0.20 in compute. The smaller Gemma 2B model can accomplish the same task for approximately $0.05 in compute.
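The same arithmetic as a small calculator, using the throughput, characters-per-token, and hourly-price estimates stated above (all of which are assumptions rather than measurements made in this notebook):

```python
def chars_per_hour(tokens_per_sec, chars_per_token=4):
    """Characters processed per hour at a sustained token throughput."""
    return tokens_per_sec * chars_per_token * 3600


def usd_per_million_chars(tokens_per_sec, usd_per_hour, chars_per_token=4):
    """Compute cost to process 1 million characters (roughly 1 MB of text)."""
    return usd_per_hour / chars_per_hour(tokens_per_sec, chars_per_token) * 1_000_000


print(f"{chars_per_hour(550):,} chars/hour")  # 7,920,000 chars/hour
print(f"Gemma 7B: ${usd_per_million_chars(550, 1.6):.3f} per 1M chars")  # Gemma 7B: $0.202 per 1M chars
print(f"Gemma 2B: ${usd_per_million_chars(2000, 1.6):.3f} per 1M chars")  # Gemma 2B: $0.056 per 1M chars
```

Swapping in your own measured throughput and instance price gives a quick cost estimate for other hardware or models.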

In comparison, NER frameworks such as Flair and spaCy (`en_core_web_trf`) achieve between 1,184 and 3,768 words per second on a GPU (about 2-5x more efficient than Gemma 2B in compute) and ~89% accuracy on industry-standard NER benchmarks, but they offer less support for multiple languages, and neither can customize detections via prompt tuning.

Coming soon: cutting-edge GPUs like NVIDIA's H200 could deliver order-of-magnitude gains with Gemma models. NVIDIA quotes a single H200 GPU delivering 79,000 tokens per second on Gemma 2B and 19,000 tokens per second on the larger 7B model. With Gemma 2B, and assuming a compute cost of around $8/hr for an H200, this would reduce the cost of anonymizing 1 MB of text from about $0.05 to about $0.007 (79,000 tokens/sec * 4 characters/token * 3600 seconds/hour ≈ 1.14 billion characters per hour).