# Fine-tune a Large Language Model with LoRA

This is a part of Lab 6 of the [EE292D Edge ML class](https://ee292d.github.io/) at Stanford, which covers parameter-efficient fine-tuning and deployment of LLMs.

You'll need a GPU for this exercise. As with previous labs, we recommend using one for free on Colab. [Click here](https://colab.research.google.com/github/ee292d/labs/blob/main/lab6/notebook.ipynb) to open this notebook in a Colab instance, then change your runtime type to GPU.

## Overview

Our goal is to fine-tune a small LLM for a new task, then prepare it for deployment on a Raspberry Pi. In this example, we will fine-tune a base model that has been pre-trained for _completion_ (i.e., to predict the next words in the input sentence) so that we can use it for _chat_.

In [1]:
!pip install -U transformers datasets peft trl accelerate
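
Because `-U` pulls the newest release of each package, and these libraries change their APIs fairly often, it can help to record which versions you actually ended up with. A minimal sketch using only the standard library (the helper name `installed_versions` is ours, not part of any of these packages):

```python
from importlib.metadata import PackageNotFoundError, version

# Packages used in this notebook.
PACKAGES = ["transformers", "datasets", "peft", "trl", "accelerate", "torch"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name:>12}: {ver or 'not installed'}")
```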



## Choosing a Base Model

We'll work with a lightweight base model: [Phi-1.5](https://huggingface.co/microsoft/phi-1_5). At 1.3B parameters, Phi-1.5 fits in under 3GB of memory when loaded at 16-bit precision.
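
The memory figure comes from simple arithmetic: the number of parameters times the number of bytes stored per parameter. The sketch below is a rough estimate for the weights only. It ignores activations, the KV cache, and framework overhead, and the function name is illustrative.

```python
def weight_memory_gib(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory needed to hold the weights, in GiB."""
    return num_params * bytes_per_param / 2**30


# Parameter counts in the range of the small Phi models.
for label, num_params in [("1.3B", 1.3e9), ("2.7B", 2.7e9)]:
    for dtype, nbytes in [("float32", 4), ("bfloat16", 2), ("int8", 1)]:
        gib = weight_memory_gib(num_params, nbytes)
        print(f"{label} params @ {dtype:>8}: {gib:5.2f} GiB")
```

Halving the precision halves the footprint, which is why loading in `bfloat16` rather than the default `float32` matters on a small GPU.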

In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the weights in 16-bit precision and move the model to the GPU.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to("cuda")

# The tokenizer must come from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/phi-1_5",
    trust_remote_code=True
)

Now that we have the model loaded, we can try an input:


In [12]:
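# Tokenize the prompt and move the resulting tensors to the GPU.
inputs = tokenizer(
    "Softly sewn",
    return_tensors="pt"
).to("cuda")

# max_length counts the prompt tokens as well as the generated ones.
outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]
text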

'Softly sewn. / A patchwork quilt of love. / A quilt of memories. / A quilt of dreams. / A quilt of life. / A quilt of you. / A quilt of me. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt of us. / A quilt'

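The continuation starts out plausibly, but it soon locks onto a single line, "A quilt of us.", and repeats it until generation is cut off by the `max_length` limit.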
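
To put a number on the repetition in the output above, one option is to split the text on the ` / ` separator and count how often each line occurs. A small sketch (the helper name is ours, and the sample string is a shortened version of the output above):

```python
from collections import Counter


def repeated_lines(text: str, sep: str = " / ") -> dict:
    """Return {line: count} for every line that occurs more than once."""
    counts = Counter(part.strip() for part in text.split(sep))
    return {line: n for line, n in counts.items() if n > 1}


sample = (
    "Softly sewn. / A patchwork quilt of love. / A quilt of memories. "
    "/ A quilt of us. / A quilt of us. / A quilt of us."
)
print(repeated_lines(sample))
```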


## Preparing a Fine-tuning Dataset

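A fine-tuning dataset consists of a set of examples of the text we want the model to learn to produce. Here the examples are haiku from the `statworx/haiku` dataset: we take the first 1,000 and hold out 20% of them for evaluation. Each example stores its poem in the `text` field.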

In [2]:
from datasets import load_dataset

dataset = (
    load_dataset("statworx/haiku")["train"]
    .select(range(1000))
    .train_test_split(test_size=0.2)
)
dataset

DatasetDict({
    train: Dataset({
        features: ['source', 'text', 'text_phonemes', 'keywords', 'keyword_phonemes', 'gruen_score', 'text_punc'],
        num_rows: 800
    })
    test: Dataset({
        features: ['source', 'text', 'text_phonemes', 'keywords', 'keyword_phonemes', 'gruen_score', 'text_punc'],
        num_rows: 200
    })
})

In [9]:
dataset['train'][0]

{'source': 'bfbarry',
 'text': 'Failing to warm you. / Swimming through a sea of stars. / Of imagined paths.',
 'text_phonemes': 'fey|lihng tax waorm yuw / swih|maxng thruw ax siy ahv staarz / ahv ax|mae|jhaxnd paedhz',
 'keywords': 'swimming through',
 'keyword_phonemes': 'swih|maxng thruw',
 'gruen_score': 0.684268266,
 'text_punc': None}
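
Each record stores the whole haiku in the `text` field, with its lines separated by ` / `. A small sketch of splitting a record back into individual lines, using the example shown above (the helper `haiku_lines` is hypothetical, not part of `datasets`):

```python
def haiku_lines(record: dict) -> list:
    """Split the 'text' field of a record into its individual lines."""
    return [line.strip() for line in record["text"].split("/")]


# The relevant fields of the first training example shown above.
example = {
    "text": "Failing to warm you. / Swimming through a sea of stars. / Of imagined paths.",
    "keywords": "swimming through",
}

for line in haiku_lines(example):
    print(line)
```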

## Fine-tuning
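
LoRA freezes the pre-trained weights and learns a low-rank update for selected weight matrices. Instead of training a full `d_out x d_in` matrix, it trains two small matrices, `B` (`d_out x r`) and `A` (`r x d_in`), and adds `(lora_alpha / r) * B @ A` to the frozen weight. The sketch below works through the arithmetic for the `r=16`, `lora_alpha=8` settings used in this section. The 2048 x 2048 layer size is an assumption chosen for illustration, not a value read from the model.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix."""
    # A has shape (r, d_in) and B has shape (d_out, r).
    return r * (d_in + d_out)


d_in = d_out = 2048       # assumed layer size, for illustration only
r, lora_alpha = 16, 8     # values from the LoraConfig in this section

full = d_in * d_out
lora = lora_param_count(d_in, d_out, r)
scaling = lora_alpha / r  # factor applied to the low-rank update

print(f"full matrix:   {full:,} parameters")
print(f"LoRA adapters: {lora:,} parameters ({100 * lora / full:.2f}% of full)")
print(f"update scaling: {scaling}")
```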

In [7]:
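from transformers import TrainingArguments
from peft import get_peft_model, LoraConfig, PeftModel
from trl import SFTTrainer
import re

# Optional: trade extra compute for lower memory use during training.
# model.gradient_checkpointing_enable()

# Checkpoints are written to this directory.
output_dir = './'

# LoRA settings: rank-16 adapters, scaled by lora_alpha / r = 0.5.
peft_config = LoraConfig(
    r=16,
    lora_alpha=8,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Wrap the base model so that only the adapter weights are trainable.
lora_model = get_peft_model(model, peft_config)

training_args = TrainingArguments(
    output_dir=output_dir,
    save_strategy='epoch',
    logging_steps=10,
    per_device_eval_batch_size=2,
    per_device_train_batch_size=2
)

# Batching needs a padding token; reuse the end-of-sequence token for it.
tokenizer.pad_token = tokenizer.eos_token

# The model is already wrapped by get_peft_model above, so peft_config is
# not passed to the trainer a second time.
# Note: these keyword arguments follow the TRL release this lab was written
# against; newer releases may expect some of them elsewhere.
trainer = SFTTrainer(
    model=lora_model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    dataset_text_field="text",
    max_seq_length=2048,
    tokenizer=tokenizer,
    args=training_args
)

trainer.train()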
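
To get a feel for how long this run is, the number of optimizer steps can be worked out from the dataset and batch sizes. The sketch assumes a single GPU, no gradient accumulation, one training sample per dataset example, and the `TrainingArguments` default of 3 epochs, since the cell above does not set `num_train_epochs`.

```python
import math


def steps_per_epoch(num_examples: int, per_device_batch_size: int,
                    num_devices: int = 1, grad_accum: int = 1) -> int:
    """Optimizer steps needed to see every training example once."""
    effective_batch = per_device_batch_size * num_devices * grad_accum
    return math.ceil(num_examples / effective_batch)


num_epochs = 3  # assumed: the TrainingArguments default
per_epoch = steps_per_epoch(num_examples=800, per_device_batch_size=2)

print(f"steps per epoch: {per_epoch}")
print(f"total steps:     {per_epoch * num_epochs}")
```

With `save_strategy='epoch'`, a checkpoint is written at the end of each epoch, which under these assumptions means every 400 steps.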



| Step | Training Loss |
|------|---------------|
| 10   | 5.0368        |
| 20   | 5.0820        |
| 30   | 4.5348        |
| 40   | 4.8391        |
| 50   | 4.6746        |
| 60   | 4.1220        |
| 70   | 4.1411        |
| 80   | 3.6757        |
| 90   | 4.0212        |
| 100  | 4.1701        |
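
The logged loss is noisy from one entry to the next, so the trend is easier to see by comparing averages over blocks of steps. A short sketch using the ten values logged above:

```python
# Training loss logged every 10 steps, copied from the output above.
losses = [5.0368, 5.082, 4.5348, 4.8391, 4.6746,
          4.122, 4.1411, 3.6757, 4.0212, 4.1701]


def mean(values):
    return sum(values) / len(values)


first_half = mean(losses[:5])   # steps 10-50
second_half = mean(losses[5:])  # steps 60-100

print(f"mean loss, steps 10-50:  {first_half:.2f}")
print(f"mean loss, steps 60-100: {second_half:.2f}")
print(f"improvement:             {first_half - second_half:.2f}")
```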





## Merging Weights