# Building a model with fine-tuning and interpolation

## Environment variables

In [None]:
import os

# Process-wide settings, applied before torch is imported further down.
os.environ.update(
    {
        # CUDA caching-allocator option (expandable segments; see the PyTorch
        # memory-management notes for its effect on fragmentation).
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        # Disable the Hugging Face tokenizers library's own parallelism.
        "TOKENIZERS_PARALLELISM": "false",
    }
)
# Debugging toggle, intentionally left disabled:
# os.environ["NCCL_DEBUG"] = "INFO"

## Configure model

In [None]:
# Base model: passed to TrainingArgs for fine-tuning and to interpolate_models
# as the original model to blend with.
model_path = "ibm-granite/granite-3.3-8b-instruct"

# Last path component of model_path; used to pick the chat template and to
# name the output and processed-data directories.
model_name = os.path.basename(model_path)

## Configure data

Configure `data_name` such that the messages data file is `messages_data_${data_name}.jsonl`. If `data_name` is empty or `None`, the file `messages_data.jsonl` is used instead.

In [None]:
# Dataset identifier; use "" or None when the data file has no suffix.
data_name = "ibm-annual-report"

# "_<data_name>" when a name is given, otherwise the empty string. Reused for
# the data file name and for the output directory names.
_data_name = f"_{data_name}" if data_name else ""

# Messages-format training data (JSONL).
messages_data_path = f"messages_data{_data_name}.jsonl"

# Set to True to re-process the data even if processed data already exists.
force_process_data = False

## Configure fine-tuning

In [None]:
# Number of training epochs (passed to TrainingArgs).
num_epochs = 3
# Save a checkpoint after this many samples seen; 0 turns this off.
save_samples = 0
# Passed to TrainingArgs. When True the checkpoint is looked up under
# hf_format/last_epoch, otherwise under hf_format/samples_<N>.
keep_last_checkpoint_only = False

## Configure interpolation

In [None]:
# Weight of the fine-tuned model, passed to interpolate_models.
# NOTE(review): presumably the original model gets 1 - trained_model_weight;
# confirm in interpolator.py.
trained_model_weight = 0.5

## Fine-tuning

In [None]:
import torch

# Fail early with an explicit error rather than a bare assert: asserts are
# stripped under `python -O` and carry no explanation.
if not torch.cuda.is_available():
    raise RuntimeError(
        "CUDA is not available; fine-tuning requires at least one GPU."
    )

# One training process per visible GPU.
nproc_per_node = torch.cuda.device_count()
print(f"nproc_per_node: {nproc_per_node}", flush=True)

# Single-node run.
nnodes = 1
print(f"nnodes: {nnodes}", flush=True)

In [None]:
# Chat template: Granite models use the generic IBM template shipped with the
# training library; any other model gets no explicit template.
chat_tmpl_dir = "../src/instructlab/training/chat_templates"
chat_tmpl_path = (
    f"{chat_tmpl_dir}/ibm_generic_tmpl.py" if "granite" in model_name else None
)

# Output locations, keyed by model name and (optional) dataset suffix.
ckpt_output_dir = f"experiments/training_output-{model_name}{_data_name}"
processed_data_dir = f"data/processed-data-{model_name}{_data_name}"

# Process the data when no processed file exists yet, or when forced.
_processed_file_exists = os.path.isfile(f"{processed_data_dir}/data.jsonl")
process_data = not _processed_file_exists or force_process_data

For fine-tuning, we use the InstructLab Training library, built for optimal and efficient fine-tuning on any messages-format data. Using its Python interface, we are able to launch the model training.

In this case, we make sure to install from the `main` branch to get the latest generic Causal LM support:

In [None]:
# %%capture
# %pip install git+https://github.com/instructlab/training.git@main

We start by importing the necessary pieces from the library:

In [None]:
from instructlab.training.config import (
    TorchrunArgs,
    TrainingArgs,
    DistributedBackend,
    FSDPOptions,
)
from instructlab.training.main_ds import run_training

We then define our distributed settings via TorchrunArgs. In our case, we trained on a single node with 8 H100 GPUs:

In [None]:
# Distributed launch settings for the training run.
torch_args = TorchrunArgs(
    nproc_per_node=nproc_per_node,  # torch.cuda.device_count(), computed above
    nnodes=nnodes,  # 1: single-node run
    node_rank=0,  # rank of this node; 0 since there is only one
    rdzv_id=123,  # rendezvous id -- presumably arbitrary for a single node; confirm
    rdzv_endpoint="0.0.0.0:8888",  # rendezvous host:port
)

We then set our model and data paths, checkpoint output path, and hyperparameters via the TrainingArgs object:

In [None]:
# Paths and hyperparameters for the supervised fine-tuning run.
train_args = TrainingArgs(
    model_path=model_path,  # base model to fine-tune
    chat_tmpl_path=chat_tmpl_path,  # None unless model_name contains "granite"
    data_path=messages_data_path,  # messages-format JSONL input
    ckpt_output_dir=ckpt_output_dir,  # find_last_checkpoint reads from here later
    data_output_dir=processed_data_dir,  # processed data ids/labels/masks
    max_seq_len=20000,  # presumably the max tokens per sample -- TODO confirm
    max_batch_len=30000,  # max tokens per gpu
    num_epochs=num_epochs,
    effective_batch_size=256,  # target batch size per model update
    learning_rate=2e-5,
    warmup_steps=25,
    save_samples=save_samples,  # save ckpt after num of samples seen (0=off)
    checkpoint_at_epoch=True,  # save ckpt after every epoch
    accelerate_full_state_at_epoch=False,  # save full-state for resuming
    process_data=process_data,  # can set to false if data processed before
    keep_last_checkpoint_only=keep_last_checkpoint_only,  # see fine-tuning config cell
    distributed_backend=DistributedBackend.FSDP,
    fsdp_options=FSDPOptions(cpu_offload_params=False),  # no CPU offload of params
)

Finally, we kick off SFT via the run_training function:

In [None]:
# flush=True so the markers are written immediately rather than sitting in
# the stdout buffer while training runs.
print("Start training", flush=True)

# Launch fine-tuning with the distributed and training settings defined above.
run_training(torch_args=torch_args, train_args=train_args)

print("Finished training", flush=True)

Upon completion, the Hugging Face-format checkpoints are in `{ckpt_output_dir}/hf_format`: `{num_epochs}` of them (one per epoch) when `keep_last_checkpoint_only` is `False`, or only the last one when it is `True`. The full run logs and metrics will also be recorded in `{ckpt_output_dir}`. Running the final training as a Python script rather than in a notebook may help with the progress bar output written to stdout.

## Interpolation

When the training has completed successfully, we interpolate the last checkpoint with the original model to recover capabilities that may have been lost during the training process. `{output_model_path}` will be `{trained_model_path}-interp` by default.

We can also interpolate models manually as follows.
```sh
python interpolator.py --model_path {model_path} --trained_model_path {trained_model_path} --trained_model_weight {trained_model_weight}
```

In [None]:
import glob


def find_last_checkpoint(ckpt_output_dir: str) -> str | None:
    last_checkpoint_path = None

    # For keep_last_checkpoint_only is True
    # See https://github.com/instructlab/training/blob/4eb4173f2508dc1fd8db7e30b59609f0ceeb25ac/src/instructlab/training/config.py#L229
    ckpt_dirs = glob.glob(f"{ckpt_output_dir}/hf_format/last_epoch")
    for ckpt_dir in ckpt_dirs:
        last_checkpoint_path = ckpt_dir

    # For keep_last_checkpoint_only is False
    if last_checkpoint_path is None:
        ckpt_dirs = glob.glob(f"{ckpt_output_dir}/hf_format/samples_*")
        samples_len = len("samples_")
        max_num_samples = -1
        for ckpt_dir in ckpt_dirs:
            if not os.path.isdir(ckpt_dir):
                continue
            num_samples_str = os.path.basename(ckpt_dir)[samples_len:]
            try:
                num_samples = int(num_samples_str)
            except ValueError:
                continue
            if max_num_samples < num_samples:
                max_num_samples = num_samples
                last_checkpoint_path = ckpt_dir

    return last_checkpoint_path

In [None]:
# Locate the newest checkpoint written by the fine-tuning run.
trained_model_path = find_last_checkpoint(ckpt_output_dir)

if trained_model_path is not None:
    # Imported here so interpolator.py is only needed when there is a
    # checkpoint to interpolate.
    from interpolator import interpolate_models

    print(f"Trained model path: {trained_model_path}")

    # Blend the fine-tuned checkpoint with the original model.
    output_model_path = interpolate_models(
        model_path, trained_model_path, trained_model_weight=trained_model_weight
    )

    print(f"Output model path: {output_model_path}")
else:
    # Make the skipped step visible instead of finishing silently.
    print(
        f"No checkpoint found under {ckpt_output_dir}/hf_format; "
        "skipping interpolation"
    )