# Fine-tuning Example

This notebook demonstrates how to fine-tune a code-generating LLM to follow a specific prompt and output format.
We want Python functions to be enclosed in XML tags:

```python
# <function name="foo">
def foo():
    # content
# </function>
```
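
Because every example follows this fixed layout, model output can be checked mechanically. The sketch below pulls each tagged function out of generated text with a regular expression. `extract_functions` is a hypothetical helper and is not part of the notebook. It assumes the exact comment layout shown above and does not handle nested tags.

```python
import re

# Matches the tag layout above: an opening comment carrying the name attribute,
# the function source, then the closing comment. re.DOTALL lets ".*?" span newlines.
FUNCTION_PATTERN = re.compile(
    r'# <function name="(?P<name>[^"]+)">\n(?P<code>.*?)# </function>',
    re.DOTALL,
)


def extract_functions(text):
    """Return a dict mapping each tagged function name to its source code."""
    return {m.group("name"): m.group("code") for m in FUNCTION_PATTERN.finditer(text)}


sample = '<bos># <function name="foo">\ndef foo():\n    return 42\n# </function><eos>'
print(extract_functions(sample))
```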

We will obtain training data by extracting functions from a repository using GitPython and tree-sitter.

## ⚠️ Warnings

⚡ Fine-tuning happens in place. Once the PEFT adapter is wrapped around the model, the **original model is lost**. If you need to test the original model again or start over with different parameters, **restart the notebook kernel**.

⚡ If your loss becomes **NaN**, training has diverged. Try reducing the learning rate, introducing a warm-up phase, changing the order of the training data, or adjusting other parameters (such as the LoRA rank or alpha). Make sure your model is loaded in **bfloat16**, because float16 does not always have sufficient range to cover all optimization steps.
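
The range argument can be made concrete. The sketch below computes the largest finite value of each format from its bit layout:

- float16: 5 exponent bits, 10 mantissa bits
- bfloat16: 8 exponent bits, 7 mantissa bits
- float32: 8 exponent bits, 23 mantissa bits

It uses plain Python arithmetic rather than a PyTorch call. In PyTorch, `torch.finfo(dtype).max` reports the same limits. float16 tops out at 65504. bfloat16 keeps float32's exponent range and gives up mantissa precision instead.

```python
def max_finite(exponent_bits, mantissa_bits):
    """Largest finite value of an IEEE-754-style binary float format."""
    bias = 2 ** (exponent_bits - 1) - 1
    # The all-ones exponent is reserved for inf/NaN, so the largest usable
    # unbiased exponent is (2**exponent_bits - 2) - bias.
    max_exponent = (2 ** exponent_bits - 2) - bias
    # Largest significand is binary 1.111...1 = 2 - 2**-mantissa_bits.
    return (2 - 2 ** -mantissa_bits) * 2.0 ** max_exponent


formats = {
    "float16": (5, 10),
    "bfloat16": (8, 7),
    "float32": (8, 23),
}
for name, (e, m) in formats.items():
    print(f"{name:>8}: max = {max_finite(e, m):.5g}")
```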

In [1]:
import huggingface_hub
from transformers import GemmaTokenizer, AutoModelForCausalLM, AutoTokenizer
import torch

# Load an LLM
* This example uses [CodeGemma 1.1](https://huggingface.co/google/codegemma-1.1-2b)
* We use the small 2B variant; note that fine-tuning larger models is much more memory-intensive!

In [2]:
gpu = torch.device('cuda:0')
model_id = "google/codegemma-1.1-2b"
tokenizer = GemmaTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=gpu, torch_dtype=torch.bfloat16)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


In [3]:
# A standard greedy generation helper
def generate(prompt, max_new_tokens=200):
    inputs = tokenizer.encode(prompt, return_tensors='pt').to(gpu)
    outputs = model.generate(inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0])

In [4]:
# Test how the model responds to the desired prompt (probably not so well)
print(generate('# <function name="test_http_404">\n'))

<bos># <function name="test_http_404">
def test_http_404():
    """
    Test the HTTP 404 response for a non-existent page.
    """
    response = client.get('/non-existent-page')
    assert response.status_code == 404
    assert response.content == b'<h1>404 Not Found</h1>'
    assert response.headers['Content-Type'] == 'text/html; charset=utf-8'
    assert response.headers['Cache-Control'] == 'no-cache, no-store, must-revalidate'
    assert response.headers['Pragma'] == 'no-cache'
    assert response.headers['Expires'] == '0'
    assert response.headers['X-Content-Type-Options'] == 'nosniff'
    assert response.headers['X-Frame-Options'] == 'DENY'
    assert response.headers['X-


# Data Procurement

## Getting Raw Data
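This part requires the packages `autopep8`, `tree-sitter`, `tree-sitter-python`, and `GitPython`. The tree-sitter calls below follow the older py-tree-sitter API:

* the `Language(..., "python")` constructor
* `Parser.set_language`
* `query.captures` returning a list of tuples

Newer releases changed these signatures, so you may need to pin an older version of `tree-sitter`.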

* We will parse all `.py` files from the HEAD commit of a local clone of the [Flask](https://github.com/pallets/flask) repository
* Using a tree-sitter query, we extract all function definitions
* We wrap them in the desired prompt format with XML tags

In [5]:
import git
import autopep8
import tree_sitter_python as tspython
from tree_sitter import Language, Parser

PY_LANGUAGE = Language(tspython.language(), "python")

In [6]:
repo = git.Repo('./flask.git')
tree = repo.head.commit.tree

parser = Parser()
parser.set_language(PY_LANGUAGE)
query = PY_LANGUAGE.query('''(function_definition) @func''')

files = [item.data_stream.read()
         for item in tree.list_traverse()
         if item.type == 'blob'
         and item.name.endswith('.py')]

def format_node(node):
    code = autopep8.fix_code(node.text.decode('utf-8'))
    name = node.child_by_field_name('name').text.decode('utf-8')
    return f'# <function name="{name}">\n{code}# </function>'

functions = [format_node(node)
             for file in files
             for node, _ in query.captures(parser.parse(file).root_node)]
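
If you would rather avoid the tree-sitter dependency, Python's built-in `ast` module can do a similar extraction. The sketch below is a substitute technique, not what this notebook uses. `extract_functions_ast` is a hypothetical name, and unlike `format_node` it skips the autopep8 pass. Two differences from tree-sitter to keep in mind:

* `ast.parse` rejects files with syntax errors, whereas tree-sitter tolerates them.
* Git blobs must be decoded with `.decode('utf-8')` before being passed in.

```python
import ast
import textwrap


def extract_functions_ast(source):
    """Return every (possibly nested) function in `source`, wrapped in the prompt tags."""
    tree = ast.parse(source)
    nodes = [n for n in ast.walk(tree)
             if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))]
    nodes.sort(key=lambda n: (n.lineno, n.col_offset))  # restore document order
    wrapped = []
    for node in nodes:
        # padded=True keeps the first line's indentation, so dedent works for methods.
        # Decorators are not included, like tree-sitter's function_definition node.
        segment = ast.get_source_segment(source, node, padded=True)
        code = textwrap.dedent(segment)
        wrapped.append(f'# <function name="{node.name}">\n{code}\n# </function>')
    return wrapped


example = '''
class View:
    def index(self, test="a"):
        return test
'''
for item in extract_functions_ast(example):
    print(item)
```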

## Format Data for Training

Next, we need to process the dataset into a format consumable by an LLM training procedure.
* We demonstrate the use of the `datasets` library to deal with (possibly large) datasets
* We **tokenize** our training examples
* To evaluate whether training improves anything, we split off a small **test set**; the remaining data is our **train set**
* Training happens in **blocks of the same size**, so we split our training data into blocks with an equal number of tokens, **padding** the last block if necessary
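
To make the blocking idea concrete, here is a toy version on plain lists. The token ids are made up, and 1 stands in for `<EOS>`. As in the notebook, the examples are concatenated first, so a block can span two functions, and only the last block is padded. `split_into_blocks` is an illustrative helper and is not used later.

```python
def split_into_blocks(token_ids, block_size, pad_id):
    """Cut a flat token list into equal-size blocks, padding the last one."""
    blocks = [token_ids[i:i + block_size] for i in range(0, len(token_ids), block_size)]
    if blocks and len(blocks[-1]) < block_size:
        blocks[-1] = blocks[-1] + [pad_id] * (block_size - len(blocks[-1]))
    return blocks


# Two toy "tokenized functions" (made-up ids), each ending in EOS = 1
examples = [[5, 6, 7, 1], [8, 9, 10, 11, 12, 1]]
flat = [tok for ex in examples for tok in ex]  # concatenate all examples
print(split_into_blocks(flat, block_size=4, pad_id=1))
```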

In [7]:
import datasets

In [8]:
data = datasets.Dataset.from_dict({'source': functions})

In [9]:
print(data[300]['source'])

# <function name="test_multi_route_class_views">
def test_multi_route_class_views(app, client):
    class View:
        def __init__(self, app):
            app.add_url_rule("/", "index", self.index)
            app.add_url_rule("/<test>/", "index", self.index)

        def index(self, test="a"):
            return test

    _ = View(app)
    rv = client.open("/")
    assert rv.data == b"a"
    rv = client.open("/b/")
    assert rv.data == b"b"
# </function>


In [10]:
def tokenize(dataset_row):
    source = dataset_row['source']
    input_ids = tokenizer.encode(source) + [tokenizer.eos_token_id]
    labels = input_ids.copy()
    
    return {
        'input_ids': input_ids,
        'labels': labels
    }

In [11]:
tokenized_data = data.map(tokenize, remove_columns=['source'])

In [12]:
import itertools

def block(data, block_size=128):
    '''Arranges a batch of tokenized examples into blocks of block_size tokens'''

    # concatenate all examples of the batch into one token stream (linear time, unlike sum(..., []))
    concatenated = list(itertools.chain.from_iterable(data['input_ids']))
    length = len(concatenated)

    # cut the stream into length // block_size full blocks
    truncated_length = (length // block_size) * block_size
    blocked_ids = [concatenated[i : i + block_size] for i in range(0, truncated_length, block_size)]

    # pad the leftover tokens (if any) with <EOS> to form one last full block
    pad_length = block_size - (length % block_size)
    if pad_length != block_size:
        blocked_ids += [concatenated[truncated_length:] + [tokenizer.eos_token_id] * pad_length]

    # format as transformers-friendly model input
    assert len(blocked_ids) > 0
    return {
        'input_ids': blocked_ids,
        'labels': blocked_ids.copy()}

In [13]:
split_dataset = tokenized_data.train_test_split(
    test_size = 0.1,
    shuffle = True,
    seed = 421337)
test_data = split_dataset['test']
train_data = split_dataset['train']

In [14]:
test_data_blocks = test_data.map(block, batched=True)
train_data_blocks = train_data.map(block, batched=True)

# Fine-tuning using LoRA

## Configuring the Training Procedure

To perform training, we have to decide on a method, which will be LoRA in our case.
* We configure a **LoRA adapter**, which inserts additional low-rank matrices into the model
* We configure a **Collator**, which pads the examples of a batch into equally shaped tensors and builds the `labels` used for the loss
* We configure a **DataLoader** which creates training batches from our collated data

In [15]:
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
from transformers import get_linear_schedule_with_warmup, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader
from tqdm import tqdm

In [16]:
# Check all model layers. We are looking for the names of the attention matrices (Q and V)
model

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaR

### The LoRA Parameters
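For this experiment, we learn a rank-8 update (`r = 8`) for each targeted matrix. The change to an m × n weight matrix is represented as the product of an m × 8 and an 8 × n matrix. Only 8 · (m + n) parameters are trained instead of m · n.
We can also configure `lora_alpha`. The learned update is scaled by `lora_alpha / r` before it is added to the original weights, so alpha controls how strongly the diff influences the original matrix.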
**Note that `r` and `alpha` are first guesses. Depending on your task you may need to adjust them and see if performance improves!**
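
As a sanity check on these shapes, the trainable-parameter count that `print_trainable_parameters` reports further down can be reproduced by hand. The sketch makes two assumptions:

* Only `q_proj` and `v_proj` are adapted in each of the 18 decoder layers, as in the config below.
* Gemma's linear layers have no biases, as the model printout above shows, so `bias="all"` adds nothing.

`lora_params` is an illustrative helper.

```python
def lora_params(in_features, out_features, r):
    """Parameters LoRA adds to one Linear layer: A is r x in, B is out x r."""
    return r * in_features + out_features * r


r = 8
lora_alpha = 16
num_layers = 18                      # GemmaDecoderLayer count from the model printout
q_proj = lora_params(2048, 2048, r)  # q_proj: 2048 -> 2048
v_proj = lora_params(2048, 256, r)   # v_proj: 2048 -> 256
trainable = num_layers * (q_proj + v_proj)
print(trainable, "trainable params, update scaling =", lora_alpha / r)
```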

In [17]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    inference_mode=False,
    target_modules=['q_proj', 'v_proj'],   # this is specific to each model! Look up the exact names in the model
    r=8,
    lora_alpha=16,
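    lora_dropout=0.05,  # to help with generalization, dropout zeroes 5% of the LoRA layers' inputs during training
    bias="all"  # also train all bias terms (Gemma's linear layers have none, so this adds no parameters)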
)

In [18]:
# This wraps our LLM into the adapter model
peft_model = get_peft_model(model, peft_config)

In [19]:
# Check how many parameters we need to train (should be orders of magnitude fewer than the full model)
peft_model.print_trainable_parameters()

trainable params: 921,600 || all params: 2,507,094,016 || trainable%: 0.036759690467068624


### Formatting for Training

* We need to decide on a **batch size**, the number of blocks processed in parallel in each training step. To save GPU memory, we stick to a low number; if you have memory to spare, increase it.
* We decide on a number of training **epochs**, i.e. how many times the model sees the entire training set. As long as we see improvements, we can increase the number. For fine-tuning, a small number (1 to 3) is often enough.

In [20]:
batch_size = 4
num_epochs = 2

In [21]:
# book-keeping: if the tokenizer defines no padding token, pad with <EOS> instead (not all tokenizers are set up for fine-tuning)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token

# collator that pads each batch and copies input_ids into labels (padding positions become -100);
# the model itself shifts the labels when computing the loss, so that token n predicts token n+1
collator = DataCollatorForLanguageModeling(
            tokenizer,
            mlm=False,  # we could also "mask" random tokens to introduce some noise, but we don't need that here
            pad_to_multiple_of=8,
            return_tensors="pt",)
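
The collator does not shift anything itself. With `mlm=False`, it does two things:

* copies `input_ids` into `labels`
* replaces padding positions in `labels` with -100, a value the cross-entropy loss ignores by default

The causal-LM forward pass then compares the prediction at position i with the label at position i + 1. The sketch below mimics both steps on plain lists with made-up token ids (0 is assumed to be the pad id). `make_labels` and `prediction_pairs` are illustrative helpers, not library code.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy by default


def make_labels(input_ids, pad_id):
    """Mimic the mlm=False collator: labels copy the inputs, padding is masked out."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in input_ids]


def prediction_pairs(input_ids, labels):
    """Shift by one as a causal LM loss does: (input token, token it must predict)."""
    return [(tok, lab) for tok, lab in zip(input_ids[:-1], labels[1:])
            if lab != IGNORE_INDEX]


ids = [2, 10, 11, 12, 0, 0]  # toy ids, 0 = padding
labels = make_labels(ids, pad_id=0)
print(labels)
print(prediction_pairs(ids, labels))
```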

In [22]:
train_dataloader = DataLoader(train_data_blocks, collate_fn=collator, batch_size=batch_size)
eval_dataloader = DataLoader(test_data_blocks, collate_fn=collator, batch_size=batch_size)

### Configuring the Optimizer

The optimizer is the core of our training. It updates all trainable parameters based on the gradient of the loss with respect to each of them.
* We need to set a **learning rate**, which scales the size of each update step. This value tends to be low and needs experimentation to get right. We start with a guess between 1/10000 and 1/1000.
* The learning rate is changed over time by a **scheduler**. This ensures we "converge" by learning in smaller and smaller steps.
* We could configure a **warmup**, during which the learning rate increases before it decreases again, so the model can "settle" a bit before full-size updates happen. We don't do this here, but if the training data is very different from what the model was pre-trained on, a warmup can help stabilize training.
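
The schedule is easy to reason about on paper. The function below is a plain-Python sketch of the multiplier that a linear warmup-then-decay schedule applies to the base learning rate. It reflects our reading of `get_linear_schedule_with_warmup`; the transformers source is the authoritative version. The 301 batches per epoch come from the training progress bar further down. `linear_schedule_factor` is an illustrative name.

```python
def linear_schedule_factor(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given step."""
    if step < num_warmup_steps:
        # warmup: ramp linearly from 0 up to 1
        return step / max(1, num_warmup_steps)
    # decay: fall linearly from 1 to 0 at the final step
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


base_lr = 3e-4
total_steps = 301 * 2  # batches per epoch * num_epochs
for step in (0, 150, 301, 450, 602):
    print(step, base_lr * linear_schedule_factor(step, 0, total_steps))
```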

In [23]:
lr = 3e-4

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)  # The AdamW optimizer is well-suited for LLMs

# our schedule will decrease learning rate linearly with time
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

## The Training Loop

Here, we rely on all the bits we configured above:
* We set the model to **training mode** (this enables dropout, including the LoRA dropout)
* Our **DataLoader** yields training batches
* We **run** each batch through the model and obtain a **loss**; autograd records the computation so that gradients can be computed
* We **back-propagate the loss** through the model. Now each trainable weight has a gradient describing how it contributed to the error in the model output.
* The **Optimizer** updates the weights based on this information
* The **Scheduler** updates the optimizer's learning rate
* We repeat the above for every batch in the training set
* After each epoch, we run every test batch through the model (in evaluation mode, without gradients) and average the loss to check whether training also improves results on held-out data

In [24]:
for epoch in range(num_epochs):
    model.train()  # set to training mode
    total_loss = 0
    
    for step, batch in enumerate(tqdm(train_dataloader)):
        outputs = model(**batch.to(gpu))  # run batch through model
        loss = outputs.loss
        total_loss += loss.detach().cpu().float()
        loss.backward()      # propagate loss back
        optimizer.step()     # update weights
        lr_scheduler.step()  # update learning rate
        optimizer.zero_grad()

    model.eval()  # set to evaluation mode
    eval_loss = 0
    for step, batch in enumerate(tqdm(eval_dataloader)):
        with torch.no_grad():
            outputs = model(**batch.to(gpu))
        loss = outputs.loss
        eval_loss += loss.detach().cpu().float()

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    train_epoch_loss = total_loss / len(train_dataloader)
    print(f"{epoch=}: {train_epoch_loss=} {eval_epoch_loss=}")

    # save adapter to be loaded later
    peft_model.save_pretrained(f'./my-checkpoint-ep{epoch}')

100%|██████████| 301/301 [00:16<00:00, 18.09it/s]
100%|██████████| 27/27 [00:00<00:00, 39.64it/s]


epoch=0: train_epoch_loss=tensor(1.4998) eval_epoch_loss=tensor(1.2637)


100%|██████████| 301/301 [00:16<00:00, 18.15it/s]
100%|██████████| 27/27 [00:00<00:00, 39.66it/s]


epoch=1: train_epoch_loss=tensor(1.1866) eval_epoch_loss=tensor(1.2272)
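
Cross-entropy losses are easier to compare once converted to perplexity, exp(loss). Roughly, perplexity is how many equally likely tokens the model is choosing between at each step. The sketch converts the eval losses printed above. It assumes the reported loss is a mean natural-log cross-entropy per token. Because the epoch value is an average of per-batch means, the result is approximate.

```python
import math

# Eval losses copied from the training log above
losses = {
    "epoch 0 eval": 1.2637,
    "epoch 1 eval": 1.2272,
}
for name, loss in losses.items():
    print(f"{name}: loss={loss:.4f}  perplexity={math.exp(loss):.2f}")
```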


In [25]:
print(generate('# <function name="test_http_404">\n'))

<bos># <function name="test_http_404">
def test_http_404(app):
    @app.route("/notfound")
    def notfound():
        return "not found", 404

    rv = app.test_client().get("/notfound")
    assert rv.status_code == 404
    assert rv.data == b"not found"
# </function><eos>
