# Fine-Tune LLaMa on Koyeb

### Step 0: Install Dependencies and Log In to HuggingFace and WandB

Running the cells below will prompt you for your HuggingFace API token and your WandB API key. Make sure you have requested access to [LLaMa 3.1 8B Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) on HuggingFace first; otherwise downloading the model will fail.

In [1]:
!pip install llama-recipes ipywidgets wandb > /dev/null

import huggingface_hub
huggingface_hub.login()

Optionally, you can log in to WandB to track the model's loss during training.

In [None]:
import wandb
wandb.login()

### Step 1: Load the Model

In this step, we load the model and tokenizer from the HuggingFace Hub. We define the training config and the hyperparameters we'll use to fine-tune the model.

In [None]:
import torch
from transformers import LlamaForCausalLM, AutoTokenizer
from llama_recipes.configs import train_config as TRAIN_CONFIG

train_config = TRAIN_CONFIG()
train_config.model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
train_config.num_epochs = 1
train_config.run_validation = False
train_config.gradient_accumulation_steps = 4
train_config.batch_size_training = 1
train_config.lr = 3e-4
train_config.use_fast_kernels = True
train_config.use_fp16 = True
train_config.context_length = 4096
train_config.batching_strategy = "packing"
train_config.output_dir = "Meta-Llama-3.1-8B-Instruct-Apple-MLX"

from transformers import BitsAndBytesConfig
config = BitsAndBytesConfig(
    load_in_8bit=True,
)

model = LlamaForCausalLM.from_pretrained(
    train_config.model_name,
    device_map="auto",
    quantization_config=config,
    use_cache=False,
    attn_implementation="sdpa" if train_config.use_fast_kernels else None,
    torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(train_config.model_name)
tokenizer.pad_token = tokenizer.eos_token
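
To get a feel for what these hyperparameters imply, here is a rough back-of-the-envelope sketch. It assumes that packing fills every sequence up to `context_length`, and that the model has roughly 8 billion parameters (an approximation, not a number taken from this notebook). It ignores activations, adapter weights, optimizer state, and any layers that stay in higher precision, so treat the memory figures as a lower bound rather than a measurement.

```python
# Back-of-the-envelope numbers for the config above (a sketch, not a measurement).
batch_size_training = 1
gradient_accumulation_steps = 4
context_length = 4096

# With the "packing" strategy every sequence is filled up to context_length,
# so each optimizer step sees this many tokens:
tokens_per_optimizer_step = (
    batch_size_training * gradient_accumulation_steps * context_length
)

# ASSUMPTION: ~8e9 parameters, 1 byte each in 8-bit and 2 bytes each in fp16.
approx_params = 8e9
weights_gb_8bit = approx_params * 1 / 1e9
weights_gb_fp16 = approx_params * 2 / 1e9

print(f"tokens per optimizer step: {tokens_per_optimizer_step}")
print(
    f"approx. weight memory: {weights_gb_8bit:.0f} GB (8-bit) "
    f"vs {weights_gb_fp16:.0f} GB (fp16)"
)
```

Halving the weight memory is the main reason for loading the model with `load_in_8bit=True` before attaching an adapter.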

### Step 2: Test Model Generation

Let's now generate some text using our non-fine-tuned model to see how it performs. We'll ask it to explain how to compute the Fast Fourier Transform operation in MLX.

In [None]:
SYSTEM_PROMPT = "You are a helpful AI coding assistant with expert knowledge of Apple's latest machine learning framework: MLX. You can help answer questions about MLX, provide code snippets, and help debug code."

def complete_chat(model, tokenizer, messages, **kwargs):
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True).to(model.device)
    num_input_tokens = len(inputs["input_ids"][0])
    model.eval()
    with torch.no_grad():
        return tokenizer.decode(model.generate(**inputs, **kwargs)[0][num_input_tokens:], skip_special_tokens=True)


def complete_chat_single_turn(model, tokenizer, user: str, **kwargs):
    return complete_chat(model, tokenizer, [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user},
    ], **kwargs)

print(complete_chat_single_turn(model, tokenizer, "How do I compute the fast Fourier transform for a signal in MLX?", max_new_tokens=128))

As expected, the model knows very little about Apple MLX and hallucinates a response. MLX was released in December 2023, which coincides with LLaMa 3.1's training data cutoff, so the model saw very few MLX-related samples during training.
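
It can help to see roughly what `apply_chat_template` turns the messages into before the model sees them. The sketch below uses a hypothetical `render_chat` helper that approximates the LLaMa 3 prompt layout. It is a simplification: the template shipped with the tokenizer may insert additional content into the system turn, so use `tokenizer.apply_chat_template` for anything real.

```python
# Simplified rendering of a LLaMa-3-style chat prompt (illustration only).
# ASSUMPTION: the real template bundled with the tokenizer may add extra
# content to the system turn; this helper is hypothetical.
def render_chat(messages, add_generation_prompt=True):
    text = "<|begin_of_text|>"
    for message in messages:
        # Each turn is a role header, the content, then an end-of-turn token.
        text += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
        text += f"{message['content'].strip()}<|eot_id|>"
    if add_generation_prompt:
        # Open an assistant turn so the model continues with its answer.
        text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return text


prompt = render_chat([
    {"role": "system", "content": "You are a helpful AI coding assistant."},
    {"role": "user", "content": "How do I compute an FFT in MLX?"},
])
print(prompt)
```

Each turn ends with `<|eot_id|>`, and the trailing assistant header is what `add_generation_prompt=True` contributes.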

### Step 3: Define and Load our Custom Dataset

Let's load the dataset we've built in the previous steps. We'll define a custom function to process each sample in our dataset, tokenize it and return samples in the format expected by the model. You can read through the comments in the code below to understand how samples are processed.
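
Before reading the full cell, here is a toy version of the label-masking step using made-up token IDs. The `EOT` value and the `mask_prompt_labels` helper are hypothetical stand-ins. In the real code the end-of-turn ID comes from the tokenizer.

```python
# Toy version of the label-masking step, using made-up token IDs.
IGNORE_INDEX = -100  # positions with this label are skipped by the loss
EOT = 9              # hypothetical stand-in for the <|eot_id|> token ID


def mask_prompt_labels(input_ids, eot_id=EOT):
    labels = list(input_ids)
    eot_positions = [i for i, tok in enumerate(input_ids) if tok == eot_id]
    # Expect exactly: [system] EOT [user] EOT [assistant] EOT
    assert len(eot_positions) == 3, f"expected 3 EOT tokens, got {len(eot_positions)}"
    end_of_user_turn = eot_positions[1]
    # Hide everything up to and including the second EOT from the loss.
    labels[: end_of_user_turn + 1] = [IGNORE_INDEX] * (end_of_user_turn + 1)
    return labels


#            system  EOT  user   EOT  answer    EOT
input_ids = [1, 2,   9,   3, 4,  9,   5, 6, 7,  9]
labels = mask_prompt_labels(input_ids)
print(labels)
```

Only the answer tokens and the final end-of-turn token keep their labels, so the loss is computed on the assistant's reply alone. In the real cell, the assistant header tokens that follow the second EOT also stay unmasked, because the mask stops at that EOT.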

In [None]:
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.utils.config_utils import get_dataloader_kwargs
from llama_recipes.utils.dataset_utils import DATASET_PREPROC, get_preprocessed_dataset
from copy import deepcopy
from dataclasses import dataclass
import datasets
import torch


# We define our custom dataset preprocessing function. It loads our dataset from the Hub,
# tokenizes each sample according to LLaMa 3.1's chat template and masks the loss for the
# system prompt and the user prompt.
def get_apple_mlx_qa_dataset(dataset_config, tokenizer, split_name):
    dataset = datasets.load_dataset(
        "koyeb/Apple-MLX-QA", split="train" if split_name == "train" else "test"
    )

    def apply_chat_template(sample):
        return {
            "input_ids": tokenizer.apply_chat_template(
                [
                    {
                        "role": "system",
                        "content": SYSTEM_PROMPT,
                    },
                    {
                        "role": "user",
                        "content": sample["question"],
                    },
                    {
                        "role": "assistant",
                        "content": sample["answer"],
                    },
                ],
                tokenize=True,
                add_generation_prompt=False,
            )
        }

    dataset = dataset.map(
        apply_chat_template,
        remove_columns=list(dataset.features),  # type: ignore
    )

    def create_labels_with_mask(sample):
        labels = deepcopy(sample["input_ids"])

        # The EOT token marks the end of a turn in a conversation.
        # In our case, the first EOT comes after the system prompt, the second
        # after the user prompt, and the third after the assistant answer.
        # > [system prompt] EOT [user prompt] EOT [assistant answer] EOT
        eot = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        indices = [i for i, token in enumerate(sample["input_ids"]) if token == eot]
        assert len(indices) == 3, f"{len(indices)} != 3. {sample['input_ids']}"

        # Mask the loss for the system prompt and the user prompt. We don't want
        # the model to predict the question, only the answer.
        labels[0 : indices[1] + 1] = [-100] * (indices[1] + 1)
        assert len(labels) == len(
            sample["input_ids"]
        ), f"{len(labels)} != {len(sample['input_ids'])}"

        return {"labels": labels}

    dataset = dataset.map(create_labels_with_mask)

    def convert_to_tensors(sample):
        return {
            "input_ids": torch.LongTensor(sample["input_ids"]),
            "labels": torch.LongTensor(sample["labels"]),
            "attention_mask": torch.tensor([1] * len(sample["labels"])),
        }

    dataset = dataset.map(convert_to_tensors)

    return dataset


# To use a custom dataset with LLaMa Recipes, you need to define a custom dataclass
# that contains information about the dataset.
@dataclass
class apple_mlx_qa_dataset:
    dataset: str =  "apple_mlx_qa_dataset"
    train_split: str = "train"
    test_split: str = "test"
    trust_remote_code: bool = False

# Then, you need to register the dataset preprocessing function in the `DATASET_PREPROC` dictionary.
DATASET_PREPROC["apple_mlx_qa_dataset"] = get_apple_mlx_qa_dataset

# Finally, we define a utility function to create a PyTorch dataloader from a split of our dataset.
def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
    
    if split == "train" and train_config.batching_strategy == "packing":
        dataset = ConcatDataset(dataset, chunk_size=train_config.context_length)

    # Create data loader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **dl_kwargs,
    )
    return dataloader


train_dataloader = get_dataloader(tokenizer, apple_mlx_qa_dataset, train_config, "train")
eval_dataloader = get_dataloader(tokenizer, apple_mlx_qa_dataset, train_config, "test")
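
The `"packing"` batching strategy concatenates tokenized samples and cuts the result into chunks of `context_length` tokens, so no compute is spent on padding. The sketch below shows the general idea with a hypothetical `pack_samples` helper. It is not the library's implementation, and details such as how leftover tokens are handled are assumptions.

```python
# Minimal sketch of sequence packing (illustrative, not the library's code).
def pack_samples(samples, chunk_size):
    """Concatenate tokenized samples and cut them into fixed-size chunks."""
    keys = ("input_ids", "attention_mask", "labels")
    buffer = {k: [] for k in keys}
    chunks = []
    for sample in samples:
        for k in keys:
            buffer[k].extend(sample[k])
        # Emit full chunks as soon as enough tokens have accumulated.
        while len(buffer["input_ids"]) >= chunk_size:
            chunks.append({k: buffer[k][:chunk_size] for k in keys})
            buffer = {k: buffer[k][chunk_size:] for k in keys}
    # ASSUMPTION: leftover tokens that do not fill a chunk are dropped.
    return chunks


toy_samples = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [-100, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [-100, 5]},
    {"input_ids": [6, 7, 8, 9], "attention_mask": [1, 1, 1, 1], "labels": [-100, -100, 8, 9]},
]
packed = pack_samples(toy_samples, chunk_size=4)
for chunk in packed:
    print(chunk["input_ids"], chunk["labels"])
```

Note that a chunk can contain the end of one sample and the start of the next. The masked labels travel with their tokens, so prompt tokens stay excluded from the loss after packing.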

### Step 4: Prepare Model for Parameter-Efficient Fine-Tuning (PEFT)

We can use the `peft` library from HuggingFace to train only a small subset of our model's parameters. This reduces training time and memory requirements. It can also help prevent catastrophic forgetting, a phenomenon where the model loses what it learned during pre-training when it is fine-tuned on a new dataset.

In [5]:
from peft import get_peft_model, prepare_model_for_kbit_training, LoraConfig
from dataclasses import asdict
from llama_recipes.configs import lora_config as LORA_CONFIG

lora_config = LORA_CONFIG()
lora_config.r = 8
lora_config.lora_alpha = 32
lora_config.lora_dropout = 0.01

peft_config = LoraConfig(**asdict(lora_config))

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
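
To see why LoRA is so cheap, it helps to count the parameters it adds. For a linear layer with `in_features` inputs and `out_features` outputs, a rank-`r` adapter adds two small matrices with `r * (in_features + out_features)` parameters in total. The sketch below applies that formula under assumptions that do not come from this notebook: a hidden size of 4096, 32 decoder layers, 8 key/value heads of dimension 128, and adapters on `q_proj` and `v_proj` only.

```python
# Rough count of the trainable parameters LoRA adds (a sketch under assumptions).
# ASSUMPTIONS, not taken from the notebook: hidden size 4096, 32 decoder layers,
# 8 key/value heads of dimension 128, adapters on q_proj and v_proj only.
def lora_params(in_features, out_features, rank):
    # LoRA adds A (rank x in_features) and B (out_features x rank).
    return rank * (in_features + out_features)


rank, lora_alpha = 8, 32
hidden_size, num_layers = 4096, 32
kv_dim = 8 * 128

per_layer = (
    lora_params(hidden_size, hidden_size, rank)  # q_proj
    + lora_params(hidden_size, kv_dim, rank)     # v_proj
)
trainable = per_layer * num_layers
scaling = lora_alpha / rank  # factor applied to the low-rank update

print(f"trainable LoRA parameters: {trainable:,}")
print(f"scaling factor: {scaling}")
```

You can compare this estimate with the real numbers by calling `model.print_trainable_parameters()` on the PEFT model.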

### Step 5: Fine-Tune the Model

We'll now fine-tune the model on our custom dataset. We'll use the `train` function from LLaMa Recipes and pass in our model and dataloader. If you have logged in with WandB, you will be able to track the training process on the [WandB dashboard](https://wandb.ai).

In [None]:
import torch.optim as optim
import wandb  # needed here even if the optional WandB login cell was skipped
from llama_recipes.utils.train_utils import train
from torch.optim.lr_scheduler import StepLR


wandb_run = wandb.init(project="finetune-llama-on-koyeb")
wandb_run.config.update(train_config)

model.train()

optimizer = optim.AdamW(
    model.parameters(),
    lr=train_config.lr,
    weight_decay=train_config.weight_decay,
)
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

# Start the training process
results = train(
    model,
    train_dataloader,
    eval_dataloader,
    tokenizer,
    optimizer,
    scheduler,
    train_config.gradient_accumulation_steps,
    train_config,
    None,
    None,
    None,
    wandb_run,
)
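
`StepLR` with `step_size=1` multiplies the learning rate by `gamma` every time the scheduler is stepped. The sketch below assumes the scheduler is stepped once per epoch and uses `gamma=0.85` purely as an example value. The `step_lr_schedule` helper is hypothetical.

```python
# What StepLR(step_size=1, gamma=...) does to the learning rate (sketch).
def step_lr_schedule(initial_lr, gamma, num_epochs):
    # Learning rate used during epoch e is initial_lr * gamma ** e.
    return [initial_lr * gamma**epoch for epoch in range(num_epochs)]


# ASSUMPTION: gamma=0.85 is only an example; check train_config.gamma.
schedule = step_lr_schedule(initial_lr=3e-4, gamma=0.85, num_epochs=3)
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.3e}")
```

Under that assumption, the decay has no effect with `num_epochs = 1`: the whole run uses the initial learning rate of `3e-4`.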

### Step 6: Check Model Generation

Now that the fine-tuning is complete, let's see if our model is able to help us compute the Fast Fourier Transform operation in MLX.

In [None]:
print(complete_chat_single_turn(model, tokenizer, "How do I compute the fast Fourier transform for a signal in MLX?", max_new_tokens=128))

Awesome! It now uses the correct API for MLX. You can experiment with different prompts, evaluate the model's performance and try to identify gaps in its knowledge.

### Step 7: Push Model to HuggingFace Hub

If you're happy with the model's outputs, you can save it to the HuggingFace Hub. This will allow you to share it with the community and use it in your projects.

In [None]:
hf_model_name = "Meta-Llama-3.1-8B-Instruct-Apple-MLX-Adapter"
hf_org = input("Enter the HuggingFace organization you want to push the model to: ")
model.push_to_hub(f"{hf_org}/{hf_model_name}")

This pushes only the LoRA adapter's weights to the HuggingFace Hub (about 10 MB). That is enough to use the model in your Python code; however, it's more convenient to merge the LoRA adapter with the base model. You can do this by running the code below.

In [None]:
model = model.merge_and_unload()
model_name = "Meta-Llama-3.1-8B-Instruct-Apple-MLX"
model.push_to_hub(f"{hf_org}/{model_name}")
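
Conceptually, merging folds the low-rank update into the base weight, `W' = W + (lora_alpha / r) * B @ A`, so the merged model produces the same outputs without a separate adapter. The pure-Python sketch below checks this on a tiny example with rank 1. The matrices and helper functions are made up for illustration, and this is not how the `peft` library implements the merge internally.

```python
# Tiny pure-Python illustration of merging a LoRA adapter into a weight matrix.
def matmul(a, b):
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add_scaled(base, delta, scale):
    return [
        [base[i][j] + scale * delta[i][j] for j in range(len(base[0]))]
        for i in range(len(base))
    ]


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight (out x in)
B = [[1.0], [0.5]]            # LoRA matrix B (out x r), toy rank r = 1
A = [[0.5, -1.0]]             # LoRA matrix A (r x in)
scaling = 4.0                 # lora_alpha / r, e.g. 32 / 8

x = [[1.0], [2.0]]            # input as a column vector (in x 1)

# Adapter kept separate: W x + scaling * B (A x)
separate = add_scaled(matmul(W, x), matmul(B, matmul(A, x)), scaling)

# Adapter merged: W' = W + scaling * B A, then W' x
W_merged = add_scaled(W, matmul(B, A), scaling)
merged = matmul(W_merged, x)

print(separate, merged)
```

Both paths give the same result, which is why the merged model can be loaded like any regular checkpoint.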