In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer
from transformers import TrainingArguments

# Supervised Fine-tuning of a Hugging Face LLM Model

Supervised Fine-tuning (SFT) teaches a model new behavior and skills. SFT and RAG are the two main ways to customize an LLM. [Newer approaches](https://gorilla.cs.berkeley.edu/blogs/9_raft.html) are also emerging.

## When to Fine-tune?

Before fine-tuning, one should consider RAG. RAG is used mainly to incorporate new knowledge (facts) into an LLM-based application. These facts remain in your own documents and are never “learned” by the model. The model can read the facts in your documents and reason about them. This is very useful, and many applications can be written using just RAG.

Fine-tuning teaches new behavior and skills to a model. Consider these use cases:

- Read a medical document and answer questions. Medical documents are written in a particular way that an out-of-the-box model may not understand very well.
- Convert a plain text question into a SQL query.
- Summarize a property lease document. A model by default may not know what the key points are in a lease.

In each of these cases RAG may not provide a satisfactory solution. We can, however, teach the model to do these things to our expected level of satisfaction. This is where SFT comes in. We call it “supervised” because during training we provide the question and context and teach the model what a good answer would look like.

SFT can also be used to incorporate new knowledge into the model, but that is not its primary purpose, for two reasons:

- The facts learned by a model are never quite exact. For example, the Mistral-7B-Instruct model is 15GB in size. Even though it was trained on a huge corpus of text, it cannot possibly contain all the knowledge of the world. That is not how LLMs work. They are not knowledge repositories.
- Facts can change from day to day. Retraining a model can be expensive.

## Parameter Efficient Fine-tuning (PEFT)

Retraining a large model can take a lot of GPU power and time. Fortunately, we’ve discovered that during training we can calculate the gradients and apply corrections to a small subset of the weights and still get very high quality results. This is called Parameter Efficient Fine-tuning (PEFT).

Several approaches to PEFT exist, including adapters, Low-Rank Adaptation (LoRA), and prefix tuning. In this notebook we will use LoRA.
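
To get a feel for why LoRA is cheap, here is a rough parameter count for a single weight matrix. This is a sketch under stated assumptions: LoRA freezes the original matrix and trains two much smaller matrices of rank `r`, and the 2048 x 2048 shape is an illustrative size only, not a measurement of any particular model.

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in the two low-rank matrices A (r x d_in) and B (d_out x r)."""
    return r * d_in + d_out * r


def full_params(d_in: int, d_out: int) -> int:
    """Parameters in the original dense weight matrix."""
    return d_in * d_out


# Illustrative shape only (an assumption, not read from a real model config).
d_in, d_out = 2048, 2048

for r in (8, 64, 256):
    lora = lora_trainable_params(d_in, d_out, r)
    full = full_params(d_in, d_out)
    print(f"r={r:>3}: {lora:>9,} trainable vs {full:,} frozen ({lora / full:.2%})")
```

For this shape, a rank-8 adapter is under 1% of the matrix, while a rank-256 adapter is a quarter of it, so the rank is a direct trade-off between capacity and cost.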

## Quantization and PEFT

The story continues to get even better. [We’ve found out](https://arxiv.org/abs/2305.14314) that quantization can be combined with PEFT and still get high quality results. When quantization is combined with LoRA, we call it QLoRA.

In this notebook we will do QLoRA.
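
The basic idea of quantization can be sketched in a few lines of plain Python. Note the assumptions: this is simple symmetric "absmax" rounding to 16 levels with one shared scale, chosen because it is easy to read. QLoRA itself uses a different 4-bit data type and block-wise scaling, so treat this as an illustration of the concept rather than what bitsandbytes does internally.

```python
def quantize_4bit(weights: list[float]) -> tuple[list[int], float]:
    """Map floats to signed integer codes in [-7, 7] using one shared scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 7 if max_abs > 0 else 1.0
    # Clamp guards against floating-point results slightly outside the range.
    codes = [max(-7, min(7, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes: list[int], scale: float) -> list[float]:
    """Recover approximate floats from the integer codes."""
    return [c * scale for c in codes]


weights = [0.42, -0.07, 0.0, -0.35, 0.21]  # made-up values for illustration
codes, scale = quantize_4bit(weights)
restored = dequantize(codes, scale)

for w, c, r in zip(weights, codes, restored):
    print(f"{w:+.2f} -> code {c:+d} -> {r:+.3f}")
```

Each restored value is within half a quantization step of the original. That small, bounded error is what lets a quantized base model still serve as a good starting point for LoRA.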

## Hugging Face Support for PEFT

Hugging Face first released the [PEFT](https://github.com/huggingface/peft) library. It later released [TRL](https://github.com/huggingface/trl), which wraps PEFT and makes fine-tuning even easier to run. Both support quantization and work in conjunction with BitsAndBytes.

The [SFTTrainer](https://huggingface.co/docs/trl/v0.8.0/en/sft_trainer#trl.SFTTrainer) class forms the heart of the TRL library.

In this notebook we will use TRL. The 4-bit loading used later also needs the bitsandbytes package. These packages are installed as follows.

```
pip install peft trl bitsandbytes
```

## The Business Problem

[Midjourney](https://www.midjourney.com/) is a text-to-art generator. It takes some skill to come up with effective prompts that include special instructions like “35mm film, epic, dramatic, photorealistic, --ar 3:2”. We are being asked to develop an application where users can enter plain English prompts and get back effective Midjourney prompts.

An example interaction with the LLM will look like this.

```
User:
Generate a prompt for Midjourney based on the sentence below.

A blue table against a red wall.

LLM:
blue table, red wall, photorealistic --ar 4:3 --v 5.1
```

## Proposed Solution

This problem is a classic use case for fine-tuning. Adding knowledge to the system using RAG won’t help very much. We need to teach the model a new skill: translating plain English text into a Midjourney prompt.

We choose to use a few techniques to speed up training and inference:

- Use a small language model and see if that works well. We will use TinyLlama/TinyLlama-1.1B-Chat-v1.0.
- Use LoRA PEFT.
- Use 4-bit quantization.
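
A back-of-the-envelope estimate shows how much the model size and bit width matter. The sketch below assumes that the weights dominate memory and ignores activations, optimizer state, and any quantization overhead. The parameter count is simply read from the "1.1B" in the model name.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory needed to hold the weights alone, in gigabytes."""
    return num_params * bits_per_param / 8 / 1e9


num_params = 1.1e9  # assumption: taken from the "1.1B" in the model name

for bits in (32, 16, 4):
    print(f"{bits:>2}-bit weights: ~{weight_memory_gb(num_params, bits):.2f} GB")
```

By the same arithmetic, a 7-billion-parameter model at 16 bits comes to about 14 GB, which is in line with the size quoted earlier for Mistral-7B-Instruct.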

We will use the [TheBossLevel123/midjourney-prompt-enhancement](https://huggingface.co/datasets/TheBossLevel123/midjourney-prompt-enhancement) dataset for training. It gives us exactly what we need – translation between plain text and good quality Midjourney prompts.

Go to the dataset's home page and review what it looks like.

## Prepare Training Data

SFTTrainer supports [two different data formats](https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support). We will format each piece of training data as follows.

```json
{
  "messages": [
    {"role": "system", "content": "Generate a prompt for Midjourney based on the sentence below"},
    {"role": "user", "content": "A bustling city street scene at night."},
    {"role": "assistant", "content": "Photorealistic bustling city street at night, vibrant lights, busy pedestrians, urban life --ar 16:9 --v 5.2"}
  ]
}
```

The source dataset has two columns:

- input – The plain text prompt
- output – Good quality Midjourney prompt

The code below loads the dataset, reformats it into the structure SFTTrainer expects, and saves it to the ``train_dataset.jsonl`` file.

In [None]:
def prepare_data():
 
    #Data mapping function
    def create_conversation(sample):
        return {
          "messages": [
            {"role": "system", "content": "Generate a prompt for Midjourney based on the sentence below"},
            {"role": "user", "content": sample["input"]},
            {"role": "assistant", "content": sample["output"]}
          ]
        }
 
    dataset = load_dataset(
        "TheBossLevel123/midjourney-prompt-enhancement", 
        split="train")
     
    #By default the map() function merges new columns to the dataset.
    #We need only the "messages" column. So, delete the input and output columns.
    dataset = dataset.map(
        create_conversation, 
        remove_columns=["input", "output"])
     
    # Save dataset
    dataset.to_json("train_dataset.jsonl", orient="records")
 
#Run data conversion
prepare_data()

JSONL is an interesting format where each line is a JSON document. Open the ``train_dataset.jsonl`` file and review it.
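
If you would rather check the file programmatically, a small validator along these lines can help. This is only a sketch: `check_record` is a hypothetical helper, not part of the datasets or TRL libraries, and it checks nothing beyond the three-message layout shown earlier.

```python
import json

EXPECTED_ROLES = ["system", "user", "assistant"]


def check_record(line: str) -> bool:
    """Return True if one JSONL line holds a system/user/assistant conversation."""
    record = json.loads(line)
    messages = record.get("messages", [])
    roles = [m.get("role") for m in messages]
    # Every message must carry a non-empty string as its content.
    has_text = all(isinstance(m.get("content"), str) and m["content"] for m in messages)
    return roles == EXPECTED_ROLES and has_text


# A hand-written line in the same shape as the converted dataset.
sample_line = json.dumps({
    "messages": [
        {"role": "system", "content": "Generate a prompt for Midjourney based on the sentence below"},
        {"role": "user", "content": "A bustling city street scene at night."},
        {"role": "assistant", "content": "bustling city street at night --ar 16:9"},
    ]
})

print(check_record(sample_line))
```

To check the real file, open ``train_dataset.jsonl`` and call `check_record` on each non-empty line.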

Data conversion needs to be done only once. Before running training we need to load the converted data.

In [None]:
train_dataset = load_dataset(
    "json", 
    data_files="train_dataset.jsonl", 
    split="train")

## Load the Base Model

This code will load the base model with 4-bit quantization.

In [None]:
bnb_config = BitsAndBytesConfig(
    #For 4bit quantization
    load_in_4bit=True
)
 
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    quantization_config=bnb_config)
 
# Do not do this. The quantized model is placed on the GPU when it is loaded:
# model.to(device)
 
tokenizer = AutoTokenizer.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0")

#Some documents ask you to call setup_chat_format().
#Don't do this. Doing so will overwrite the correct
#tokenizer settings with chatml settings.
# model, tokenizer = setup_chat_format(model, tokenizer)
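
To see why this matters, compare how two chat formats lay out the same conversation. The functions below are hand-written approximations for illustration only. The first mimics the `<|role|>` layout described on the TinyLlama chat model card, and the second mimics ChatML. Both are assumptions about the exact strings. The authoritative template is whatever `tokenizer.chat_template` contains.

```python
def render_zephyr_style(messages, add_generation_prompt=True):
    """Approximate the <|role|> ... </s> layout (assumed, not read from the tokenizer)."""
    text = "".join(f"<|{m['role']}|>\n{m['content']}</s>\n" for m in messages)
    return text + ("<|assistant|>\n" if add_generation_prompt else "")


def render_chatml_style(messages, add_generation_prompt=True):
    """Approximate the ChatML <|im_start|> ... <|im_end|> layout (also assumed)."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    return text + ("<|im_start|>assistant\n" if add_generation_prompt else "")


messages = [
    {"role": "system", "content": "Generate a prompt for Midjourney based on the sentence below"},
    {"role": "user", "content": "A blue table against a red wall."},
]

print(render_zephyr_style(messages))
print(render_chatml_style(messages))
```

A model that was instruction-tuned on one layout has never seen the special tokens of the other, which is why swapping the template before a short fine-tuning run is best avoided.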

## Evaluate the Base Model

Before running any training we should see if the base model is any good at solving our problems. We write a simple utility to perform text generation.

In [None]:
#Generate midjourney prompt
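def generate(model, tokenizer, prompt):
    # Print tokens to the console as they are generated
    streamer = TextStreamer(tokenizer)

    # Use the same system prompt as the training data
    messages = [
        {"role": "system", "content": "Generate a prompt for Midjourney based on the sentence below"},
        {"role": "user", "content": prompt}
    ]

    device = "cuda"

    # Apply the chat template and append the marker that starts the assistant turn
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt").to(device)

    # The output is streamed, so the returned token ids are not needed
    model.generate(encoded, streamer=streamer, max_new_tokens=2000)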


#Give it a try.
generate(model, tokenizer, "A blue table against a red wall.")

This outputs something not very usable as a Midjourney prompt.

```
Amidst a sea of red, a blue table stands tall, 
a symbol of hope and unity.
```

Clearly, the base model is not very good at this. Let’s see if fine-tuning will help.

## Run Training

First we configure the training parameters. We run training for 6 epochs. Each batch will have 3 samples of training data. We set the maximum sequence length to only 2000 because we're using a very small language model.

In [None]:
peft_config = LoraConfig(
        lora_alpha=128,
        lora_dropout=0.05,
        r=256,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
)
 
args = TrainingArguments(
    output_dir="trained-model", # directory to save and repository id
    num_train_epochs=6,                     # number of training epochs
    per_device_train_batch_size=3,          # batch size per device during training
    gradient_accumulation_steps=2,          # number of batches to accumulate gradients over before each weight update
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=2,                        # log every 2 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    bf16=False,                             # bfloat16 precision is turned off
    tf32=False,                             # tf32 precision is turned off
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper (has no effect with the plain "constant" scheduler)
    lr_scheduler_type="constant",           # use constant learning rate scheduler
)
 
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    peft_config=peft_config,
    max_seq_length=2000,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens": False,  # We template with special tokens
        "append_concat_token": False, # No need to add additional separator token
    }
)
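
A few of these numbers interact, and it helps to work them out explicitly. The sketch below copies the values from the configuration above. The sequence count is a placeholder, because with `packing=True` the trainer counts packed sequences rather than original rows, so the real step counts will differ. A single GPU is also assumed.

```python
import math

# Values copied from the configuration above.
lora_alpha, r = 128, 256
per_device_train_batch_size = 3
gradient_accumulation_steps = 2
num_train_epochs = 6

# Placeholder: number of training sequences after packing (hypothetical value).
num_sequences = 600
num_devices = 1  # assumption: single GPU

# LoRA scales the low-rank update by alpha / r before adding it to the frozen weights.
lora_scaling = lora_alpha / r

# Samples that contribute to each weight update.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_devices

# Approximate number of weight updates.
steps_per_epoch = math.ceil(num_sequences / effective_batch)
total_steps = steps_per_epoch * num_train_epochs

print(f"LoRA scaling factor: {lora_scaling}")
print(f"Effective batch size: {effective_batch}")
print(f"Approximate optimizer steps: {total_steps}")
```

With these settings each weight update sees 6 samples, and the LoRA update is scaled by 0.5 before it is added to the frozen weights.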

Now, we can begin training. As training progresses you should see a dramatic reduction in loss. This is always a welcome sign.

In [None]:
trainer.train()

While training is going on, you can use the ``nvidia-smi`` command to check GPU usage and memory availability.

## Save the Model

A checkpoint is saved at the end of every epoch in the ./trained-model folder, but we should also save the final version. Because we trained with PEFT, this saves the LoRA adapter weights along with the tokenizer.

In [None]:
trainer.save_model()

## Evaluate the Model

We can use the same prompt that we tried before.

In [None]:
generate(model, tokenizer, "A blue table against a red wall.")

It will now generate this.

```
blue table, red wall, photorealistic --ar 4:3 --v 5.1
```

The model now understands how to generate Midjourney prompts.
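
Eyeballing one output is a start. For a slightly more systematic check, you could split each generated prompt into its description and its `--` parameters. The helper below is hypothetical and deliberately simple. It recognizes only `--name value` pairs like the ones that appear in this notebook, and it does not validate them against Midjourney's actual parameter list.

```python
import re

# Matches pairs such as "--ar 4:3" or "--v 5.1".
PARAM_PATTERN = re.compile(r"--(\w+)\s+(\S+)")


def split_midjourney_prompt(prompt: str) -> tuple[str, dict[str, str]]:
    """Separate the descriptive text from the --name value parameters."""
    params = dict(PARAM_PATTERN.findall(prompt))
    description = PARAM_PATTERN.sub("", prompt).strip()
    return description, params


description, params = split_midjourney_prompt(
    "blue table, red wall, photorealistic --ar 4:3 --v 5.1"
)
print(description)
print(params)
```

An output with no parameters at all, like the base model's reply earlier, comes back with an empty dictionary, which makes it easy to flag.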

## Run Inference

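To run inference we need to load the fine-tuned model from the ``./trained-model`` folder. Because we trained with LoRA, this folder holds the adapter weights rather than a full copy of the model. With the peft package installed, ``from_pretrained()`` loads the base model and applies the adapter. We don’t pass a quantization config here, so the base model is loaded without quantization.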

Before continuing, I recommend that you restart the notebook session or run this code to free up memory.

In [None]:
#Free up memory taken up during training
import gc

del model
del trainer
gc.collect()
torch.cuda.empty_cache()

In [None]:
#Load the model
model = AutoModelForCausalLM.from_pretrained(
    "./trained-model")

#You need to load it into the GPU
device = "cuda"
 
model.to(device)
 
tokenizer = AutoTokenizer.from_pretrained(
    "./trained-model")

Run inference.

In [None]:
generate(model, tokenizer, "A blue table against a red wall.")

## Summary

Fine-tuning teaches new skills to a model. It takes more GPU power than RAG, but certain things simply cannot be done using RAG. Over the last few months fine-tuning has become more democratized. We’re able to retrain a model on a regular GPU using techniques like quantization and PEFT. It’s amazing what even a small LLM can do once it is fine-tuned.