In [None]:
from pathlib import Path
import random
import re
import textwrap
from typing import Literal

from datasets import Dataset
import optuna
from peft import get_peft_model, LoraConfig, TaskType
import torch
from torch.utils.data import DataLoader
from transformers import (
  AutoModelForCausalLM, AutoTokenizer,
  BitsAndBytesConfig,
  DataCollatorForLanguageModeling,
  EarlyStoppingCallback,
  TextGenerationPipeline,
  Trainer, TrainerCallback, TrainerControl,
  TrainerState, TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint

# Demonstration of Granite-3B-Code-Base-2K model finetuning for C++ code completion

## Introduction
This notebook demonstrates how a [Granite-3B-Code-Base-2K](#Granite-3B-Code-Base-2K) model can be fine-tuned for C++ code completion, based on the *.cc* and *.h* files in a folder. The goal is to limit the required computing resources through [quantization](https://huggingface.co/docs/transformers/quantization/overview) and [parameter-efficient fine-tuning (PEFT)](https://huggingface.co/docs/peft/index). The demonstrated method uses the [Hugging Face Transformers](#Hugging-Face-Transformers) and [PyTorch](#PyTorch) deep learning libraries.

## Configuration

In [None]:
# General
# Seed shared by Python's `random`, PyTorch, the Trainer and the Optuna sampler.
RANDOM_SEED = 42

# Data
# Folder that is searched recursively for *.cc and *.h files.
DATA_DIR = "data"
# Number of consecutive source lines joined into one sample.
SAMPLE_SPAN = 3
# Offset, in lines, between the starts of two consecutive samples.
SAMPLE_STRIDE = 1

# Optuna
# When True, a hyperparameter search is run instead of a regular training.
OPTUNA_ENABLED = False
# Number of training epochs per Optuna trial.
OPTUNA_EPOCHS = 2

# Training
# Per-device batch size, used for both training and evaluation.
BATCH_SIZE = 32
# Passed to EarlyStoppingCallback as `early_stopping_patience`.
EARLY_STOP_PATIENCE = 5
LEARN_RATE = 1e-4
LEARN_RATE_SCHEDULER = "inverse_sqrt"
# Upper bound handed to the Trainer as `num_train_epochs`; the run is ended
# earlier by CustomCallback (after TRAIN_EPOCHS) or by early stopping.
MAX_TRAIN_EPOCHS = 1_000_000
# Hugging Face model id used for both the tokenizer and the model.
MODEL_NAME = "ibm-granite/granite-3b-code-base-2k"
# Trainer output folder; also where the last checkpoint is looked up.
MODELS_DIR = "models"
# Number of epochs CustomCallback counts down before requesting a stop.
TRAIN_EPOCHS = 100
# Fraction of the shuffled samples that goes into the validation set.
VALIDATION_SPLIT = 0.2
WARMUP_STEPS = 0
WEIGHT_DECAY = 0.05

# Text generation
# Passed to the generation pipeline as `max_new_tokens`.
MAX_GEN_TOKENS = 100

## Dataset retrieval

The dataset consists of the lines of all C++ source (*.cc*) and header (*.h*) files found under `DATA_DIR`.

In [None]:
# Seed Python's and PyTorch's global RNGs so that shuffling and training are
# repeatable after "Restart Kernel -> Run All".
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

cpp_path = Path(DATA_DIR)
# `Path.glob` does not guarantee an ordering, so the files are sorted;
# otherwise the line order (and with it the seeded shuffle and split below)
# would depend on the file system. Sources still come before headers.
cpp_filepaths = (sorted(cpp_path.glob("**/*.cc"))
                 + sorted(cpp_path.glob("**/*.h")))
if not cpp_filepaths:
  # Fail here with a clear message instead of much later, on an empty dataset.
  raise FileNotFoundError(
    f"No *.cc or *.h files found under '{cpp_path}'."
  )

raw_lines = []
for filepath in cpp_filepaths:
  # Decode explicitly as UTF-8 instead of the locale default; undecodable
  # bytes are replaced so that one odd file does not abort the whole load.
  raw_lines += filepath.read_text(encoding="utf-8",
                                  errors="replace").splitlines()

## Tokenizer

The [Granite-3B-Code-Base-2K](#Granite-3B-Code-Base-2K) model uses [byte-pair encoding](#Byte-pair-encoding) for its inputs. Thus, spaces, new lines and braces, which are important not just for C++ but also for many other programming languages, are part of the vocabulary. Although their embeddings are frozen, the model can still learn new, codebase-specific usages from the code at `DATA_DIR`.

In [None]:
# Load the tokenizer that belongs to the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Pad with the end-of-sequence token (presumably because the checkpoint does
# not define a dedicated padding token - verify against the model card).
tokenizer.pad_token_id = tokenizer.eos_token_id
# Reverse lookup table: token id -> token string.
inverted_vocab = {}
for token_text, token_id in tokenizer.get_vocab().items():
  inverted_vocab[token_id] = token_text

## Dataset preprocessing

First, single-line comments, *#include*s and empty lines are removed.

Then, based on the `SAMPLE_SPAN` and `SAMPLE_STRIDE` settings, the dataset lines are combined into potentially overlapping groups of length `SAMPLE_SPAN` and offset `SAMPLE_STRIDE`.

The formed groups are then shuffled and split into training and validation sets, based on the `VALIDATION_SPLIT` setting.

In [None]:
def _strip_line_comment(line: str) -> str:
  """Remove a trailing `//` comment from a single line of C++ code.

  A `//` inside a string or character literal (e.g. "http://host") is kept.
  Whitespace directly in front of the removed comment is dropped as well;
  lines without a comment are returned unchanged.

  Block comments (`/* ... */`) and raw string literals are not handled.
  """
  quote = ""
  pos = 0
  while pos < len(line):
    char = line[pos]
    if quote:
      if char == "\\":
        pos += 2  # Skip the escaped character, e.g. \" inside a string.
        continue
      if char == quote:
        quote = ""
    elif char == '"':
      quote = char
    elif char == "'":
      # An apostrophe right after an identifier character is treated as a
      # digit separator (1'000'000), not as the start of a char literal.
      previous = line[pos-1] if pos else ""
      if not (previous.isalnum() or previous == "_"):
        quote = char
    elif line.startswith("//", pos):
      return line[:pos].rstrip()
    pos += 1
  return line


# `raw_lines` is left untouched so that this cell can be re-run on its own.
# Blank lines (including whitespace-only ones) and #include directives
# (also in the `# include` spelling) are dropped after comment removal.
clean_lines = [
  line
  for line in map(_strip_line_comment, raw_lines)
  if line.strip() and not re.match(r"\s*#\s*include\b", line)
]

# Sliding window: SAMPLE_SPAN consecutive lines per sample, window start
# advanced by SAMPLE_STRIDE lines.
samples = ["\n".join(clean_lines[i:i+SAMPLE_SPAN])
           for i in range(0, len(clean_lines)-SAMPLE_SPAN+1, SAMPLE_STRIDE)]
random.shuffle(samples)

# Split by an explicit position: negative slicing (`samples[:-val_idx]`)
# would yield an EMPTY training set when val_idx is 0.
val_idx = int(VALIDATION_SPLIT * len(samples))
split_at = len(samples) - val_idx
train_samples = samples[:split_at]
val_samples = samples[split_at:]
if not train_samples or not val_samples:
  raise ValueError(
    f"Cannot build non-empty training and validation sets from "
    f"{len(samples)} sample(s) with VALIDATION_SPLIT={VALIDATION_SPLIT}."
  )

train_ds = Dataset.from_dict(tokenizer(train_samples))
val_ds = Dataset.from_dict(tokenizer(val_samples))
train_ds, val_ds

## Preparation for training

In order to limit the system resources required for training, 8-bit [quantization](https://huggingface.co/docs/transformers/quantization/overview) is applied to the model.
In order to preserve the pretrained model weights, `PEFT` is applied in the form of LoRA adapters, whose placement relies on the defaults that `PEFT` provides for the model architecture.
Standard [Hugging Face Transformers](#Hugging-Face-Transformers) API for [PyTorch](#PyTorch), such as the `Trainer`, `TrainingArguments` and `TrainerCallback` classes, is utilized to configure the [Granite-3B-Code-Base-2K](#Granite-3B-Code-Base-2K) model for training.

If the `OPTUNA_ENABLED` setting is *True*, a hyperparameter tuning process via [Optuna](#Optuna) is started instead of a training process. Thus, settings such as learning rate value and schedule, batch size and weight decay can be tuned.

Additionally, a `CustomCallback` class (based on `TrainerCallback`) is implemented to provide training control and, via the *evaluate_model* function, evaluation of the model's text generation capabilities on:
- a single, randomly selected, C++ code sample per training epoch;
- a custom C++ fragment during inference.

The *evaluate_model* function displays various evaluation data such as:
- the C++ code fragment;
- the tokens' representation as per the tokenizer;
- the prompt as per the *prompt_strategy* parameter, i.e., whether the first half of *cpp_text*, the second half or the whole *cpp_text* is to be used as prompt for code completion;
- the generated code without the prompt, as per the *suggest_changes* function.

The generated code is limited to at most `MAX_GEN_TOKENS` newly generated tokens.

In [None]:
def evaluate_model(
  model: AutoModelForCausalLM,
  tokenizer: AutoTokenizer,
  dataset: Dataset | None = None,
  cpp_text: str | None = None,
  prompt_strategy: Literal["start", "end", "all"] = "all",
) -> None:
  """Print a code-completion report for one C++ fragment.

  The report shows the fragment, its tokens, the prompt derived from it and
  the text the model generates for that prompt.

  Args:
    model: Model used for generation; it is switched to evaluation mode.
    tokenizer: Tokenizer used to encode/decode the fragment and to render
      the token strings.
    dataset: Tokenized dataset with an "input_ids" column. When given, one
      random row is evaluated and `cpp_text` is ignored.
    cpp_text: C++ fragment to evaluate when `dataset` is not given.
    prompt_strategy: Which part of the fragment's tokens becomes the prompt:
      the first half ("start"), the second half ("end") or everything
      (any other value, "all" by default).

  Raises:
    ValueError: If neither `dataset` nor `cpp_text` is set, if `dataset` is
      empty, or (from `suggest_changes`) if the resulting prompt is empty.
  """
  if dataset is None and cpp_text is None:
    raise ValueError("evaluate_model: "
                     "one of dataset and cpp_text must be set.")
  model.eval()  # type: ignore
  if dataset is not None:
    if len(dataset) == 0:
      raise ValueError("evaluate_model: dataset is empty.")
    index = random.randint(0, len(dataset)-1)
    # Fetch a single row rather than reading the whole "input_ids" column.
    ids = dataset[index]["input_ids"]
    cpp_text = tokenizer.decode(ids)  # type: ignore
  else:
    ids = tokenizer.encode(cpp_text)  # type: ignore
  print("-" * 80,
        f"[ C++ CODE ]\n\n{cpp_text}",
        sep="\n")

  # Use the tokenizer that was passed in, not a module-level lookup table,
  # so the report always matches the ids produced above.
  tokens_repr = " ".join(
    tokenizer.convert_ids_to_tokens(ids)  # type: ignore
  )
  tokens_lines = textwrap.wrap(tokens_repr, width=80,
                               expand_tabs=False,
                               replace_whitespace=False,
                               break_long_words=False,
                               break_on_hyphens=False,
                               drop_whitespace=False)
  print("-" * 80)
  print("[ TOKENS ]\n\n")
  for line in tokens_lines:
    print(line)

  half = len(ids) // 2
  if prompt_strategy == "start":
    # Keep at least one token: a single-token fragment would otherwise give
    # an empty prompt, which suggest_changes rejects.
    prompt_ids = ids[:max(half, 1)]
  elif prompt_strategy == "end":
    prompt_ids = ids[half:]
  else:
    prompt_ids = ids
  prompt = tokenizer.decode(prompt_ids)  # type: ignore
  print("-" * 80,
        f"[ PROMPT ]\n\n{prompt}",
        sep="\n")

  changes = suggest_changes(model, tokenizer, prompt=prompt)
  print("-" * 80,
        f"[ GENERATED ]\n\n{changes}",
        "-" * 80,
        sep="\n")


def suggest_changes(
  model: AutoModelForCausalLM,
  tokenizer: AutoTokenizer,
  prompt: str,
) -> str:
  """Generate a continuation for `prompt` and return only the new text.

  Args:
    model: Model used by the text-generation pipeline.
    tokenizer: Tokenizer used by the text-generation pipeline.
    prompt: Non-empty text to be continued.

  Returns:
    The generated text of the first result, without the prompt
    (`return_full_text=False`), limited to MAX_GEN_TOKENS new tokens.

  Raises:
    ValueError: If `prompt` is empty.
  """
  if prompt:
    generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
    candidates = generator(prompt,
                           return_full_text=False,
                           max_new_tokens=MAX_GEN_TOKENS)
    first_candidate = candidates[0]
    return first_candidate["generated_text"]
  raise ValueError("suggest_changes: prompt is empty.")


# Trainer configuration. In Optuna mode nothing is saved or reported and the
# best model is not reloaded; each trial runs for OPTUNA_EPOCHS epochs.
training_args = TrainingArguments(
  # Run control
  output_dir=MODELS_DIR,
  seed=RANDOM_SEED,
  num_train_epochs=MAX_TRAIN_EPOCHS if not OPTUNA_ENABLED else OPTUNA_EPOCHS,
  push_to_hub=False,
  report_to=["tensorboard"] if not OPTUNA_ENABLED else ["none"],
  # Batching
  per_device_train_batch_size=BATCH_SIZE,
  per_device_eval_batch_size=BATCH_SIZE,
  dataloader_drop_last=True,
  # Optimisation
  learning_rate=LEARN_RATE,
  lr_scheduler_type=LEARN_RATE_SCHEDULER,
  warmup_steps=WARMUP_STEPS,
  weight_decay=WEIGHT_DECAY,
  # Evaluation and checkpointing
  eval_strategy="epoch",
  metric_for_best_model="eval_loss",
  save_strategy="epoch" if not OPTUNA_ENABLED else "no",
  save_total_limit=1,
  save_only_model=False,
  load_best_model_at_end=not OPTUNA_ENABLED,
)


if not OPTUNA_ENABLED:
  class CustomCallback(TrainerCallback):
    """Trainer callback that samples the model and bounds the epoch count.

    After every epoch it prints a generation report for one random training
    sample and, once TRAIN_EPOCHS epochs have been counted since the
    callback was created, asks the Trainer to stop.
    """

    def __init__(self):
      # Countdown of the epochs that may still run for this callback instance.
      self._remaining_train_epochs = TRAIN_EPOCHS

    def on_epoch_end(
      self,
      args: TrainingArguments,
      state: TrainerState,
      control: TrainerControl,
      model: AutoModelForCausalLM,
      processing_class: AutoTokenizer,
      train_dataloader: DataLoader | None = None,
      **kwargs,
    ) -> None:
      """Report on a random training sample and update the epoch countdown.

      Args:
        args: Training arguments (unused).
        state: Trainer state (unused).
        control: Trainer control object; `should_training_stop` is set on it
          when the countdown is exhausted.
        model: Model being trained, used for the generation report.
        processing_class: Tokenizer handed to `evaluate_model`.
        train_dataloader: Training dataloader whose dataset is sampled; the
          report is skipped when it is None.
        **kwargs: Further callback arguments (unused).
      """
      if train_dataloader is not None:
        evaluate_model(model, processing_class,
                       dataset=train_dataloader.dataset)  # type: ignore
      self._remaining_train_epochs -= 1
      # `<=` rather than `==`: with a non-positive TRAIN_EPOCHS the counter
      # starts at or below zero and would never be exactly 0 again, so the
      # stop request would never be issued.
      if self._remaining_train_epochs <= 0:
        control.should_training_stop = True


def model_init(trial: optuna.Trial) -> AutoModelForCausalLM:
  """Build a fresh 8-bit quantized model wrapped with LoRA adapters.

  Args:
    trial: Trial object supplied by the Trainer; not used here.

  Returns:
    The PEFT-wrapped model. The number of trainable parameters is printed
    as a side effect.
  """
  base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
    torch_dtype="auto",
  )
  # NOTE(review): PEFT's quantization guide suggests calling
  # prepare_model_for_kbit_training on quantized base models - confirm
  # whether that is wanted here.
  adapter_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
  )
  peft_model = get_peft_model(base_model, adapter_config)

  # Count only the parameters that will receive gradients.
  trainable_params = 0
  for param in peft_model.parameters():
    if param.requires_grad:
      trainable_params += param.numel()
  print(f"Trainable parameters: {trainable_params:,}")
  return peft_model  # type: ignore


# Callbacks are only attached to a regular training run; CustomCallback is
# not even defined in Optuna mode, so it must not be referenced there.
trainer_cbs = (
  [
    CustomCallback(),  # type: ignore
    EarlyStoppingCallback(early_stopping_patience=EARLY_STOP_PATIENCE),
  ]
  if not OPTUNA_ENABLED
  else None
)
# `model_init` (instead of a ready model) lets the Trainer create the model
# itself, which hyperparameter search requires.
trainer = Trainer(
  args=training_args,
  model_init=model_init,  # type: ignore
  processing_class=tokenizer,
  data_collator=DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
  ),
  train_dataset=train_ds,
  eval_dataset=val_ds,
  callbacks=trainer_cbs,
)

## Training and evaluation

If `OPTUNA_ENABLED` is *True*, the trial parameters are configured in the `optuna_hp_space` function. The optimization process is performed with regard to the validation loss.

If `OPTUNA_ENABLED` is *False*, then a training process starts. The final validation loss and perplexity metrics are reported once the training process completes.

In [None]:
if OPTUNA_ENABLED:
  def optuna_hp_space(trial: optuna.Trial) -> dict[str, str | float]:
    """Return the hyperparameters suggested for one Optuna trial.

    The keys are TrainingArguments field names: learning rate, weight decay,
    learning-rate scheduler type and per-device training batch size.
    """
    return {
      "learning_rate": trial.suggest_float(
        "learning_rate", 1e-6, 1e-4, step=1e-6
      ),
      "weight_decay": trial.suggest_float(
        "weight_decay", 0, 1e-1, step=1e-2
      ),
      "lr_scheduler_type": trial.suggest_categorical(
        "lr_scheduler_type", ["constant", "cosine",
                              "inverse_sqrt", "linear",
                              "reduce_lr_on_plateau"]
      ),
      "per_device_train_batch_size": trial.suggest_categorical(
        "per_device_train_batch_size", [32]
      ),
    }
  # Seeded sampler so that the sequence of suggested trials is repeatable.
  optuna_sampler = optuna.samplers.TPESampler(seed=RANDOM_SEED)
  best_trials = trainer.hyperparameter_search(
    hp_space=optuna_hp_space,  # type: ignore
    direction="minimize",
    backend="optuna",
    n_trials=20,
    sampler=optuna_sampler,
  )
  print(best_trials)
else:
  # get_last_checkpoint lists the folder, so it is only called when the
  # folder exists; on a first run training simply starts from scratch.
  last_checkpoint = (get_last_checkpoint(MODELS_DIR)
                     if Path(MODELS_DIR).is_dir() else None)
  trainer.train(resume_from_checkpoint=last_checkpoint)
  eval_results = trainer.evaluate()
  print(eval_results)
  # Perplexity is the exponential of the validation loss.
  print("Perplexity: "
        f"{torch.exp(torch.tensor(eval_results['eval_loss'])).item():.2f}")

## Testing

In [None]:
if not OPTUNA_ENABLED:
  # Probe the fine-tuned model on a validation sample: `samples[0]` lies in
  # the training split, so it would only show how well the model reproduces
  # text it was trained on.
  evaluate_model(trainer.model, tokenizer,  # type: ignore
                 cpp_text=val_samples[0])

## References

<br><br>

### APA style for references
American Psychological Association. (2022). Creating an APA Style reference list guide. https://apastyle.apa.org/instructional-aids/creating-reference-list.pdf

American Psychological Association. (2024). APA Style common reference examples guide. https://apastyle.apa.org/instructional-aids/reference-examples.pdf

<br><br>

### Tokenization
#### Byte-pair encoding
Sennrich, R., Haddow, B., & Birch, A. (2015). Neural machine translation of rare words with subword units. arXiv preprint arXiv:1508.07909. https://arxiv.org/abs/1508.07909
- [Byte pair encoding - Wikipedia](https://en.wikipedia.org/wiki/Byte_pair_encoding)

<br><br>

### Machine learning models
#### Granite-3B-Code-Base-2K
Mishra, M., Stallone, M., Zhang, G., Shen, Y., Prasad, A., Soria, A.M., Merler, M., Selvam, P., Surendran, S., Singh, S., Sethi, M., Dang, X., Li, P., Wu, K., Zawad, S., Coleman, A., White, M., Lewis, M., Pavuluri, R., Koyfman, Y., Lublinsky, B., Bayser, M.D., Abdelaziz, I., Basu, K., Agarwal, M., Zhou, Y., Johnson, C., Goyal, A., Patel, H., Shah, Y., Zerfos, P., Ludwig, H., Munawar, A., Crouse, M., Kapanipathi, P., Salaria, S., Calio, B., Wen, S., Seelam, S.R., Belgodere, B.M., Fonseca, C., Singhee, A., Desai, N., Cox, D.D., Puri, R., & Panda, R. (2024). Granite Code Models: A Family of Open Foundation Models for Code Intelligence. ArXiv, abs/2405.04324. https://arxiv.org/abs/2405.04324

<br><br>

### Guides and tutorials
- [ibm-granite/granite-3b-code-base-2k · Hugging Face](https://huggingface.co/ibm-granite/granite-3b-code-base-2k)
- [Hugging Face - Documentation](https://huggingface.co/docs)

<br><br>

### Libraries
#### Hugging Face Transformers
Wolf, T., Debut, L., Sanh, V., Chaumond, J., Delangue, C., Moi, A., Cistac, P., Ma, C., Jernite, Y., Plu, J., Xu, C., Le Scao, T., Gugger, S., Drame, M., Lhoest, Q., & Rush, A. M. (2020). Transformers: State-of-the-Art Natural Language Processing [Conference paper]. 38–45. https://www.aclweb.org/anthology/2020.emnlp-demos.6
- [Transformers](https://huggingface.co/docs/transformers/index)

#### PyTorch
Ansel, J., Yang, E., He, H., Gimelshein, N., Jain, A., Voznesensky, M., Bao, B., Bell, P., Berard, D., Burovski, E., Chauhan, G., Chourdia, A., Constable, W., Desmaison, A., DeVito, Z., Ellison, E., Feng, W., Gong, J., Gschwind, M., Hirsh, B., Huang, S., Kalambarkar, K., Kirsch, L., Lazos, M., Lezcano, M., Liang, Y., Liang, J., Lu, Y., Luk, C., Maher, B., Pan, Y., Puhrsch, C., Reso, M., Saroufim, M., Siraichi, M. Y., Suk, H., Suo, M., Tillet, P., Wang, E., Wang, X., Wen, W., Zhang, S., Zhao, X., Zhou, K., Zou, R., Mathews, A., Chanan, G., Wu, P., & Chintala, S. (2024). PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation [Conference paper]. 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2 (ASPLOS '24). https://doi.org/10.1145/3620665.3640366
- [Start Locally | PyTorch](https://pytorch.org/get-started/locally)

<br><br>

### Tools
#### Optuna
Akiba, T., Sano, S., Yanase, T., Ohta, T., & Koyama, M. (2019). Optuna: A next-generation hyperparameter optimization framework [Conference paper]. *Proceedings of the 25th ACM SIGKDD international conference on knowledge discovery & data mining*, 2623–2631. https://doi.org/10.1145/3292500.3330701
- [Optuna: A hyperparameter optimization framework](https://optuna.readthedocs.io/en/stable/)