# Fine-tuning a Code LLM on Custom Code on a single GPU

In this example, we will fine-tune a code LLM on private code bases to improve its contextual awareness and make it more useful for an organization's specific needs.

## Setup

In [None]:
!pip install -qU transformers datasets peft bitsandbytes flash-attn

In [None]:
from huggingface_hub import notebook_login
notebook_login()

## Dataset

In this example, we will use the [`smangrul/hf-stack-v1`](https://huggingface.co/datasets/smangrul/hf-stack-v1) dataset, for which the author picked the top 10 Hugging Face public repositories on GitHub. Non-code files, such as images, audio files, and presentations, have been excluded from the data. For Jupyter notebooks, only cells containing code were kept. The final dataset contains the repo id, file path, and file content.

## Model

We will fine-tune [`bigcode/starcoderbase-1b`](https://huggingface.co/bigcode/starcoderbase-1b), which is a 1B parameter model trained on 80+ programming languages.

We need to define some variables.

In [None]:
MODEL = 'bigcode/starcoderbase-1b'
DATASET = 'smangrul/hf-stack-v1'
DATA_COLUMN = 'content' # column name containing the code content

SEQ_LENGTH = 2048

# Training arguments
MAX_STEPS = 2000
BATCH_SIZE = 16
GR_ACC_STEPS = 1 # gradient_accumulation_steps
LR = 5E-4
LR_SCHEDULER_TYPE = 'cosine'
WEIGHT_DECAY = 0.01
NUM_WARMUP_STEPS = 30
EVAL_FREQ = 100
SAVE_FREQ = 100
LOG_FREQ = 25
OUTPUT_DIR = 'peft-starcoder-lora'
BF16 = True
FP16 = False

# FIM transformations arguments
FIM_RATE = 0.5
FIM_SPM_RATE = 0.5

# LoRA
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.0
LORA_TARGET_MODULES = 'c_proj, c_attn, q_attn, c_fc'

# bitsandbytes
USE_NESTED_QUANT = True
BNB_4BIT_COMPUTE_DTYPE = 'bfloat16'

SEED = 111

In [None]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
    set_seed,
    BitsAndBytesConfig
)

set_seed(SEED)

## Prepare the data

As the dataset is likely to be quite large, we can enable streaming mode. Streaming lets us load the data progressively as we iterate over the dataset instead of downloading the whole dataset at once.

We will reserve the first 4000 examples as the validation set, and everything else will be the training data.
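
The split can be pictured with plain Python iterators. This is only a sketch of the idea: `fake_stream`, `take`, and `skip` below are hypothetical stand-ins built on `itertools.islice`, not the `datasets` API itself.

```python
from itertools import islice


def fake_stream(n):
    """Hypothetical stand-in for a streamed dataset: yields records lazily."""
    for i in range(n):
        yield {"content": f"file_{i}"}


def take(stream_factory, n):
    # First n records of a fresh pass over the stream
    return islice(stream_factory(), n)


def skip(stream_factory, n):
    # Everything after the first n records of a fresh pass over the stream
    return islice(stream_factory(), n, None)


def make_stream():
    return fake_stream(10)


valid = list(take(make_stream, 4))
train = list(skip(make_stream, 4))
print(len(valid), len(train))
```

Because both splits are cut from the same ordered stream, no record ends up in both the validation and the training set.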

In [None]:
from datasets import load_dataset
import torch
from tqdm import tqdm

dataset = load_dataset(
    DATASET,
    data_dir='data',
    split='train',
    streaming=True
)
dataset

In [None]:
valid_data = dataset.take(4000)
train_data = dataset.skip(4000)
train_data = train_data.shuffle(buffer_size=5000, seed=SEED)

The dataset now contains raw data with code of arbitrary length. For training, we need inputs of fixed length, so we need to create an iterable dataset that returns constant-length chunks of tokens from a stream of text files.

To help us estimate the number of tokens in the text buffer, we need to estimate the average number of characters per token in the dataset. We will take 400 examples from the dataset to provide a reasonable estimate of the overall character-to-token ratio.

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)


def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
    """Estimate the average number of characters per token in the dataset"""

    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        total_characters += len(example[data_column])
        total_tokens += len(tokenizer(example[data_column]).tokens())

    return total_characters / total_tokens

In [None]:
chars_per_token = chars_token_ratio(train_data, tokenizer, DATA_COLUMN)
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

The character-to-token ratio can also be used as an indicator of the quality of text tokenization. For example, a character-to-token ratio of 1.0 means that each character is represented with a token, which is not very meaningful and indicates poor tokenization. In standard English text, one token is typically equivalent to approximately four characters, meaning the character-to-token ratio is around 4.0. We can expect a lower ratio in the code dataset, but generally speaking, a number between 2.0 and 3.5 can be considered good enough.
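
To make the ratio concrete, here is a minimal sketch that computes it with a toy tokenizer and then converts it into a character budget for a text buffer. `toy_tokenize` is a hypothetical regex-based stand-in, not the StarCoder tokenizer, so the number it produces is only illustrative.

```python
import re


def toy_tokenize(text):
    """Hypothetical tokenizer: words and single punctuation marks become tokens."""
    return re.findall(r"\w+|[^\w\s]", text)


def chars_per_token_estimate(texts):
    # Same ratio as chars_token_ratio: total characters / total tokens
    total_chars = sum(len(t) for t in texts)
    total_tokens = sum(len(toy_tokenize(t)) for t in texts)
    return total_chars / total_tokens


samples = ["def add(a, b):", "return a + b"]
ratio = chars_per_token_estimate(samples)

# Characters to buffer in order to fill roughly `num_sequences` sequences
# of `seq_length` tokens each (illustrative values)
seq_length, num_sequences = 2048, 1024
buffer_chars = seq_length * ratio * num_sequences
print(f"ratio={ratio:.2f}, buffer of about {buffer_chars:,.0f} characters")
```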

### Optional FIM transformations

Autoregressive language models typically generate sequences from left to right. By applying the FIM transformations from the paper [*Efficient Training of Language Models to Fill in the Middle*](https://arxiv.org/abs/2207.14255), the model can also learn to infill text.

We will define the FIM transformations and use them when creating the iterable dataset. We can skip FIM transformations by setting `fim_rate = 0`.
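
Before looking at the token-level implementation, the two layouts can be sketched on plain strings. The function `fim_transform` below is a hypothetical illustration that uses the sentinel strings `<fim_prefix>`, `<fim_suffix>`, and `<fim_middle>`; the actual transformation operates on token ids.

```python
import random

FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE = "<fim_prefix>", "<fim_suffix>", "<fim_middle>"


def fim_transform(text, rng, spm=False):
    # Pick two random cut points and split into prefix / middle / suffix
    lo, hi = sorted(rng.randint(0, len(text)) for _ in range(2))
    prefix, middle, suffix = text[:lo], text[lo:hi], text[hi:]
    if spm:
        # SPM layout: the suffix comes first, then prefix and middle
        return f"{FIM_PREFIX}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}{prefix}{middle}"
    # PSM layout: prefix, suffix, then the middle the model must produce
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}{middle}"


text = "def add(a, b):\n    return a + b\n"
psm = fim_transform(text, random.Random(0))
spm = fim_transform(text, random.Random(1), spm=True)
print(repr(psm))
print(repr(spm))
```

In both layouts the middle segment ends up last, so a left-to-right model learns to generate it conditioned on the surrounding code.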

In [None]:
import functools
import numpy as np

# Helper function to get token ids of the special tokens for prefix, suffix, and middle for FIM transformations
@functools.lru_cache(maxsize=None)
def get_fim_token_ids(tokenizer):
    try:
        FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map['additional_special_tokens'][1:5]
        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (
            tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
        )
    except KeyError:
        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None

    return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id



def permute(sample, np_rng, suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id,
            fim_rate=0.5, fim_spm_rate=0.5, truncate_or_pad=False):
    """Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate.
    Using two FIM modes:
    PSM and SPM (with a probability of fim_spm_rate)
    """
    # FIM transformations will apply to samples with a probability of fim_rate
    if np_rng.binomial(1, fim_rate):
        # Split the sample into prefix, middle, and suffix, based on randomly generated indices
        # stored in the `boundaries` list
        boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))
        boundaries.sort()

        prefix = np.array(sample[: boundaries[0]], dtype=np.int64)
        middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)
        suffix = np.array(sample[boundaries[1] :], dtype=np.int64)

        if truncate_or_pad:
            # Calculate the new total length of the sample,
            # taking into account tokens indicating prefix, middle, and suffix
            new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3
            diff = new_length - len(sample)

            # Truncate or pad if there is difference in length between the new length and the original
            if diff > 0:
                if suffix.shape[0] <= diff:
                    return sample, np_rng

                suffix = suffix[: suffix.shape[0] - diff]
            elif diff < 0:
                suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])

        # With the probability of fim_spm_rate, apply SPM variant of FIM transformations
        # SPM: suffix, prefix, middle
        if np_rng.binomial(1, fim_spm_rate):
            new_sample = np.concatenate([
                [prefix_tok_id, suffix_tok_id],
                suffix,
                [middle_tok_id],
                prefix,
                middle
            ])
        # Otherwise, apply the PSM variant of FIM transformations
        # PSM: prefix, suffix, middle
        else:
            new_sample = np.concatenate([
                [prefix_tok_id],
                prefix,
                [suffix_tok_id],
                suffix,
                [middle_tok_id],
                middle
            ])

    else:
        # do not apply FIM transformations
        new_sample = sample

    return list(new_sample), np_rng

Next, we will define the `ConstantLengthDataset`, an iterable dataset that will return constant-length chunks of tokens.
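
The core packing step can be sketched in isolation: tokenized files are concatenated with a separator token between them and then sliced into fixed-length chunks, with any incomplete trailing chunk dropped. The helper `pack_into_chunks` and the toy token ids are hypothetical and only illustrate the idea.

```python
def pack_into_chunks(tokenized_docs, seq_length, eos_id):
    # Concatenate all documents, appending the separator id after each one
    stream = []
    for doc in tokenized_docs:
        stream.extend(doc + [eos_id])

    # Slice the stream into fixed-length chunks; drop a trailing partial chunk
    chunks = []
    for i in range(0, len(stream), seq_length):
        chunk = stream[i : i + seq_length]
        if len(chunk) == seq_length:
            chunks.append(chunk)
    return chunks


docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
chunks = pack_into_chunks(docs, seq_length=4, eos_id=0)
print(chunks)  # [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Note that a chunk may span several files, and a file may be split across chunks; the separator token is what marks the boundaries.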

In [None]:
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
import random

# Create an Iterable Dataset that returns constant-length chunks of tokens
# from a stream of text files.
class ConstantLengthDataset(IterableDataset):
    """Iterable dataset that returns constant-length chunks of tokens from a stream of text files"""
    def __init__(self, tokenizer, dataset, infinite=False, seq_length=1024,
                 num_of_sequences=1024, chars_per_token=3.6, content_field='content',
                 fim_rate=0.5, fim_spm_rate=0.5, seed=0):
        """
        Parameters
        ----------
        tokenizer: Tokenizer
            The processor used for processing the data
        dataset: dataset.Dataset
            Dataset with text files
        infinite: bool
            If True the iterator is reset after dataset reaches end else stops
        seq_length: int
            Length of token sequences to return
        num_of_sequences: int
            Number of token sequences to keep in buffer
        chars_per_token: float
            Number of characters per token used to estimate number of tokens in text buffer
        fim_rate: float
            Rate (0.0 to 1.0) that sample will be permuted with FIM
        fim_spm_rate: float
            Rate (0.0 to 1.0) of FIM permutations that will use SPM
        seed: int
            Seed for random number generator
        """
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.content_field = content_field
        self.fim_rate = fim_rate
        self.fim_spm_rate = fim_spm_rate
        self.seed = seed

        self.suffix_tok_id, self.prefix_tok_id, self.middle_tok_id, self.pad_tok_id = get_fim_token_ids(self.tokenizer)
        if not self.suffix_tok_id and self.fim_rate > 0:
            print('FIM is not supported by the tokenizer, disabling FIM')
            self.fim_rate = 0


    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        np_rng = np.random.RandomState(seed=self.seed)

        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator)[self.content_field])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)['input_ids']
            all_token_ids = []

            for tokenized_input in tokenized_inputs:
                # optionally do FIM permutations
                if self.fim_rate > 0:
                    tokenized_input, np_rng = permute(
                        tokenized_input,
                        np_rng,
                        self.suffix_tok_id,
                        self.prefix_tok_id,
                        self.middle_tok_id,
                        self.pad_tok_id,
                        fim_rate=self.fim_rate,
                        fim_spm_rate=self.fim_spm_rate,
                        truncate_or_pad=False
                    )

                all_token_ids.extend(tokenized_input + [self.concat_token_id])

            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)

            random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    'input_ids': torch.LongTensor(example),
                    'labels': torch.LongTensor(example)
                }

In [None]:
train_dataset = ConstantLengthDataset(
    tokenizer,
    train_data,
    infinite=True,
    seq_length=SEQ_LENGTH,
    chars_per_token=chars_per_token,
    content_field=DATA_COLUMN,
    fim_rate=FIM_RATE,
    fim_spm_rate=FIM_SPM_RATE,
    seed=SEED
)

eval_dataset = ConstantLengthDataset(
    tokenizer,
    valid_data,
    infinite=False,
    seq_length=SEQ_LENGTH,
    chars_per_token=chars_per_token,
    content_field=DATA_COLUMN,
    fim_rate=FIM_RATE,
    fim_spm_rate=FIM_SPM_RATE,
    seed=SEED
)

## Prepare the model

We will load the model in 4-bit quantized form to reduce memory usage.
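
As a rough back-of-the-envelope estimate of why quantization helps, the memory needed for the weights alone scales with the number of bits per parameter. The sketch below assumes a round figure of one billion parameters and ignores activations, gradients, optimizer state, and the small overhead of quantization constants.

```python
def weight_memory_gib(num_params, bits_per_param):
    # bits -> bytes -> GiB
    return num_params * bits_per_param / 8 / 2**30


NUM_PARAMS = 1_000_000_000  # assumption: round figure for a "1B" model

for label, bits in [("float32", 32), ("bfloat16", 16), ("4-bit", 4)]:
    print(f"{label:>8}: {weight_memory_gib(NUM_PARAMS, bits):.2f} GiB")
```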

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora import LoraLayer

load_in_8bit = False

# 4bit quantization
compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=USE_NESTED_QUANT,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=compute_dtype
)

device_map = {"": 0}

model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    load_in_8bit=load_in_8bit,
    quantization_config=bnb_config,
    device_map=device_map,
    use_cache=False, # we will use gradient checkpointing
    trust_remote_code=True,
    attn_implementation='flash_attention_2' # replaces the deprecated use_flash_attention_2=True
)

When using a quantized model for training, we need to call `prepare_model_for_kbit_training()` to preprocess the quantized model:

In [None]:
model = prepare_model_for_kbit_training(model)

Now that the quantized model is ready, we can set up a LoRA configuration. To train a model using LoRA, we need to wrap the base model as a `PeftModel`. This involves defining the LoRA configuration with `LoraConfig` and wrapping the original model with `get_peft_model()`:
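
To get a feel for the rank and alpha settings, the sketch below counts the parameters LoRA adds to a single linear layer and computes the scaling factor applied to the low-rank update. The layer size of 2048 x 2048 is a hypothetical example, not a dimension read from the model.

```python
def lora_param_count(d_in, d_out, r):
    # LoRA trains two low-rank matrices: A (r x d_in) and B (d_out x r)
    return r * (d_in + d_out)


def lora_scaling(alpha, r):
    # The low-rank update B @ A is scaled by alpha / r
    return alpha / r


d_in = d_out = 2048  # hypothetical layer size, for illustration only
r, alpha = 8, 32     # same values as LORA_R and LORA_ALPHA

full = d_in * d_out
added = lora_param_count(d_in, d_out, r)
print(f"full layer: {full:,} params, LoRA adds: {added:,} ({100 * added / full:.2f}%)")
print(f"scaling factor: {lora_scaling(alpha, r)}")
```

The actual share of trainable parameters for the whole model is what `model.print_trainable_parameters()` reports.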

In [None]:
peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    r=LORA_R,
    bias='none',
    task_type='CAUSAL_LM',
    target_modules=[m.strip() for m in LORA_TARGET_MODULES.split(',')] # strip spaces around module names
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

## Train the model

In [None]:
train_data.start_iteration = 0

training_args = TrainingArguments(
    output_dir=f"{OUTPUT_DIR}",
    dataloader_drop_last=True,
    evaluation_strategy='steps',
    save_strategy='steps',
    max_steps=MAX_STEPS,
    eval_steps=EVAL_FREQ,
    save_steps=SAVE_FREQ,
    logging_steps=LOG_FREQ,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LR,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    warmup_steps=NUM_WARMUP_STEPS,
    weight_decay=WEIGHT_DECAY,
    gradient_accumulation_steps=GR_ACC_STEPS,
    gradient_checkpointing=True,
    bf16=BF16,
    fp16=FP16,
    push_to_hub=False,
    include_tokens_per_second=True
)
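
With these settings, the number of tokens the model sees during training follows directly from the constants defined earlier. The sketch below repeats those values so that it runs on its own, and assumes a single GPU, as in this example.

```python
# Values copied from the configuration cell
BATCH_SIZE = 16
GR_ACC_STEPS = 1
SEQ_LENGTH = 2048
MAX_STEPS = 2000

# Every sequence has exactly SEQ_LENGTH tokens, so the budget is a simple product
tokens_per_step = BATCH_SIZE * GR_ACC_STEPS * SEQ_LENGTH
total_tokens = tokens_per_step * MAX_STEPS

print(f"tokens per optimizer step: {tokens_per_step:,}")
print(f"tokens over the whole run: {total_tokens:,}")
```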

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

print("Training...")
trainer.train()

## Inference

In [None]:
from peft import PeftModel
import torch

# load the original model first
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    quantization_config=None,
    device_map=None,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).cuda()

# merge fine-tuned weights with the base model
peft_model_id = f"Your_HF_username/{OUTPUT_DIR}"
model = PeftModel.from_pretrained(base_model, peft_model_id)
model = model.merge_and_unload() # returns the base model with the LoRA weights merged in

Now we can use the merged model for inference.

In [None]:
def get_code_completion(prefix, suffix):
    text = f"""<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"""
    model.eval()
    outputs = model.generate(
        input_ids=tokenizer(text, return_tensors='pt').input_ids.cuda(),
        max_new_tokens=128,
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.0
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

In [None]:
prefix = """from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM
peft_config = LoraConfig(
"""
suffix = """"""

print(get_code_completion(prefix, suffix))
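
The example above leaves `suffix` empty. As a sketch of how an infilling prompt with a non-empty suffix is laid out, the hypothetical helper `build_fim_prompt` below reproduces only the string formatting used in `get_code_completion`, without calling the model; the prefix and suffix snippets are made up for illustration.

```python
def build_fim_prompt(prefix, suffix):
    # Same layout as in get_code_completion: prefix, then suffix, then the
    # marker after which the model is expected to generate the middle
    return f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"


prefix = "def add(a, b):\n    result = "
suffix = "\n    return result\n"

prompt = build_fim_prompt(prefix, suffix)
print(prompt)
```

Passing the same `prefix` and `suffix` to `get_code_completion` would ask the model to fill in the expression between them.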
