<a href="https://colab.research.google.com/github/mgfrantz/CTME-llm-lecture-resources/blob/main/labs/fine_tuning_and_inference_with_axolotl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine tuning and inference

In this lab, we're going to demonstrate the process of fine tuning.
Note that this **does not mean we will have a better, cheaper, or faster model than API providers.**
Given the small amount of data we have and the time/resource constraints of this lab setting, we can't shoot for a highly performant model.
But we will have a **general design pattern for fine tuning**, and we will see that the fine-tuned model handles our task better than the base model it started from.

[Axolotl](https://github.com/axolotl-ai-cloud/axolotl) is a convenient library that helps fine tune text generation models.
In this notebook, we will use `axolotl` to fine tune a small LLM on a dataset we've created.

We will be using the small and open Llama 3.2 3B model today.
Here's our agenda:

- Run the model in the notebook to demonstrate that it cannot do anything we want it to do out of the box.
- Load the conversations we generated yesterday and prepare them for training by converting them into the ChatML format and tokenizing them.
- Fine-tune the Llama model using QLoRA
- Export the model to .gguf so we can run it anywhere with `ollama` or `llama.cpp`
- Test our model in our agent to demonstrate that it is better than the base model
- Push our model artifacts to Hugging Face so they can be run anywhere

# Setup

## Ollama

We will be using ollama for local inference.
To set it up, open the Colab terminal ↙.
Then, run the following commands:

```
curl -fsSL https://ollama.com/install.sh | sh # install ollama
ollama serve &                                # start the ollama server and return the terminal
ollama pull llama3.2:3b                       # pull llama3.2:3b (the model we'll compare our fine-tune against)
```

This will pull the model we'll be testing against.
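
Before moving on, it can help to confirm from Python that the server is running and the model was actually pulled. The sketch below assumes Ollama's default address (`http://localhost:11434`) and its `/api/tags` endpoint, which lists the models available locally. The helper names `parse_model_names` and `list_local_models` are made up for this example.

```python
import json
import urllib.error
import urllib.request


def parse_model_names(payload: dict) -> list:
    """Extract model names from the JSON body returned by /api/tags."""
    return [model["name"] for model in payload.get("models", [])]


def list_local_models(base_url: str = "http://localhost:11434", timeout: float = 5.0):
    """Return the names of locally pulled models, or None if the server is unreachable."""
    try:
        with urllib.request.urlopen(f"{base_url}/api/tags", timeout=timeout) as response:
            return parse_model_names(json.load(response))
    except (urllib.error.URLError, OSError, ValueError):
        # Connection refused, timeout, or a body that is not valid JSON
        return None
```

For example, `"llama3.2:3b" in (list_local_models() or [])` should be `True` once the pull has finished.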

## Installs

In [None]:
!pip install -qqqq \
    huggingface_hub \
    bitsandbytes \
    accelerate \
    "transformers[torch]" \
    llama-cpp-python \
    vllm \
    "llama-index>=0.11.17" \
    "llama-index-core>=0.10.43" \
    "openinference-instrumentation-llama-index>=2" \
    "opentelemetry-proto>=1.12.0" \
    arize-phoenix-otel \
    nest-asyncio \
    llama-index-callbacks-arize-phoenix \
    llama-index-readers-database \
    llama-index-llms-openai \
    llama-index-embeddings-fastembed \
    fastembed-gpu \
    llama-index-llms-ollama \
    llama-index-agent-openai \
    --progress-bar off

# Clone llama.cpp for conversion to gguf
!git clone https://github.com/ggerganov/llama.cpp.git

## Imports

In [None]:
import sqlite3
import pandas as pd
from IPython.display import display
from rich import print
from typing import Literal, List
from pydantic import BaseModel, Field
from time import sleep
from sqlalchemy import create_engine
from google.colab import userdata
import json
import os
import phoenix as px
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
from openinference.instrumentation import using_metadata
from phoenix.otel import register
from enum import Enum
from tqdm.auto import tqdm
from llama_index.core import VectorStoreIndex, Document
from llama_index.core.tools import FunctionTool, QueryEngineTool, RetrieverTool
from llama_index.core.agent import ReActAgent
from llama_index.core.llms import ChatMessage
from llama_index.core import PromptTemplate
from llama_index.readers.database import DatabaseReader
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.llms.ollama import Ollama
# from llama_index.llms.huggingface import HuggingFaceLLM
# from llama_index.llms.llama_cpp import LlamaCPP
# from llama_index.llms.vllm import Vllm
from transformers import BitsAndBytesConfig
import asyncio
import nest_asyncio

token = userdata.get('HF_TOKEN')
os.environ['HF_TOKEN'] = token

## All the code to set up our agent 👇

Since we have all the code we used to run our agent yesterday, let's bring that over so we can use it in this notebook.
Here is the solution code, but if your code differs feel free to replace it with your own.

In [None]:
if not os.path.exists('drive/MyDrive/CTME-LLM-labs/ecommerce.db'):
    print("The ecommerce.db database does not exist. Please make sure you're connected to Google Drive or upload it to the Colab notebook or re-run lesson 1.")
!cp drive/MyDrive/CTME-LLM-labs/ecommerce.db .

In [None]:
def query(q:str, db:str='ecommerce.db') -> pd.DataFrame:
    """
    Executes a SQL query against the SQLite database and returns the result as a pandas DataFrame.
    Use this function when you want to query a database and return results.

    Args:
        q (str): The SQL query to execute.
        db (str, optional): The path to the SQLite database file. Defaults to 'ecommerce.db'.

    Returns:
        pd.DataFrame: The result of the SQL query as a pandas DataFrame.
    """
    connection = sqlite3.connect(db)
    cursor = connection.cursor()
    cursor.execute(q)
    result = cursor.fetchall()
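    # Pass the column names up front so an empty result still yields a DataFrame with named columns
    # (assigning df.columns afterwards raises a length-mismatch error when no rows come back).
    df = pd.DataFrame(result, columns=[i[0] for i in cursor.description])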
    connection.close()
    return df

def execute(q:str, db:str='ecommerce.db') -> None:
    """
    Executes an SQL query against the SQLite database.
    Use this when you want to run commands like updates, inserts, or deletes that don't return results.

    Args:
        q (str): The SQL query to execute.
        db (str, optional): The path to the SQLite database file. Defaults to 'ecommerce.db'.

    Returns:
        None
    """
    connection = sqlite3.connect(db)
    cursor = connection.cursor()
    cursor.execute(q)
    connection.commit()
    connection.close()

def instrument():
    """
    Starts a Phoenix session and instruments LlamaIndex so traces are sent to it.

    Returns:
        session: The phoenix session object.
    """
    session = px.launch_app()
    tracer_provider = register(endpoint="http://127.0.0.1:6006/v1/traces")
    LlamaIndexInstrumentor().instrument(skip_dep_check=True, tracer_provider=tracer_provider)
    return session

def end_session(session):
    """
    Ends a phoenix session.

    Args:
        session: The phoenix session object.

    Returns:
        None
    """
    !rm {session.database_url.replace('sqlite:///', '')}
    session.end()

In [None]:
AGENT_SYSTEM_PROMPT = PromptTemplate("""\
You are a helpul customer service assistant for MikeCorp, an ecommerce company selling electronics. \
You are designed to help with a variety of problems a customer may have, including account management, order management, and product-related queries. \
If you are ever unsure what to do, please escalate.

## Tools

You have access to a wide variety of tools. You are responsible for using the tools in any sequence you deem
appropriate to complete the task at hand.
This may require breaking the task into subtasks and using different tools to complete each subtask.

You have access to the following tools:
{tool_desc}


## Output Format

Please answer in English using the following format:

```
Thought: I need to use a tool to help me answer the question.
Action: tool name (one of {tool_names}) if using a tool.
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
```

Please ALWAYS start with a Thought.

NEVER surround your response with markdown code markers. \
You may use code markers within your response if you need to.

Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}.

If this format is used, the user will respond in the following format:

```
Observation: tool response
```

You should keep repeating the above format till you have enough information \
to answer the question without using any more tools. \
At that point, you MUST respond in the one of the following two formats:

```
Thought: I can answer without using any more tools. I'll use the user's language to answer
Answer:
```

```
Thought: I cannot answer the question with the provided tools.
Answer:
```

## Current Conversation

Below is the current conversation consisting of interleaving human and assistant messages. \
Conversation:
""")

def get_random_id():
    return query("SELECT customer_id FROM customers ORDER BY random() LIMIT 1;").iloc[0,0]

def create_agent(llm, tools, system_prompt=AGENT_SYSTEM_PROMPT, verbose=False):
    agent = ReActAgent.from_tools(tools, llm=llm, verbose=verbose)
    prompt_dict = agent.get_prompts()
    prompt_dict['agent_worker:system_prompt'] = system_prompt
    agent.update_prompts(prompt_dict)
    return agent

In [None]:
def does_id_exist(id:int)-> bool:
    df = query(f"SELECT customer_id FROM customers WHERE customer_id = {id}")
    return len(df) > 0

def get_update_pin(customer_id:int) -> FunctionTool:
    if not does_id_exist(customer_id):
        raise ValueError("Customer id not found on file.")
    def update_pin(new_pin:str) -> str:
        """Use when you want to update the customer's pin.

        Args:
            new_pin (str): The new pin.

        Returns:
            str: A message indicating the success or failure of the update.
        """
        execute(f"UPDATE customers SET pin = '{new_pin}' WHERE customer_id = {customer_id}")
        return "Pin updated successfully."

    return FunctionTool.from_defaults(update_pin)

def get_update_address(customer_id:int):
    if not does_id_exist(customer_id):
        raise ValueError("Customer id not found on file.")
    def update_address(street:str, city:str, state:str, zip:str, country:str) -> str:
        """Use when you want to update the customer's address.

        Args:
            street (str): The street address.
            city (str): The city.
            state (str): The state.
            zip (str): The zip code.
            country (str): The country.

        Returns:
            str: A message indicating the success or failure of the update.
        """
        execute(f"UPDATE customers SET street_address = '{street}', city = '{city}', state = '{state}', zip_code = '{zip}', country = '{country}' WHERE customer_id = {customer_id}")
        return "Address updated successfully."

    return FunctionTool.from_defaults(update_address)

def get_update_phone_number(customer_id:int):
    if not does_id_exist(customer_id):
        raise ValueError("Customer id not found on file.")
    def update_phone_number(phone:str) -> str:
        """Use when you want to update the customer's phone number.

        Args:
            phone (str): The new phone number.

        Returns:
            str: A message indicating the success or failure of the update.
        """
        execute(f"UPDATE customers SET phone = '{phone}' WHERE customer_id = {customer_id}")
        return "Phone number updated successfully."

    return FunctionTool.from_defaults(update_phone_number)

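def get_user_management_tools(customer_id=None):
    # Resolve the default at call time; `customer_id=get_random_id()` would run once, when the function is defined.
    if customer_id is None:
        customer_id = get_random_id()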
    return [
        get_update_pin(customer_id),
        get_update_address(customer_id),
        get_update_phone_number(customer_id)
    ]

def get_list_orders(customer_id:int):
    if not does_id_exist(customer_id):
        raise ValueError("Customer id not found on file.")
    def list_orders() -> List[dict]:
        """Use when you want to list all order data for the current customer. Takes no arguments.

        Returns:
            List[dict]: A list of dictionaries containing the order data.
        """
        df = query(f"""
        SELECT o.order_id, i.name, o.ordered_date, o.status, o.estimated_delivery, o.shipping_carrier, o.tracking_number, o.shipping_address
        FROM orders o
        INNER JOIN customers c ON c.customer_id = o.customer_id
        INNER JOIN items i ON i.item_id = o.item_id
        WHERE c.customer_id = {customer_id}
        """)
        return df.to_dict(orient='records')
    return FunctionTool.from_defaults(list_orders)

def does_customer_have_order_id(customer_id, order_id):
    df = query(f"""
    SELECT order_id
    FROM orders
    WHERE customer_id = {customer_id} AND order_id = {order_id}
    """)
    return len(df) > 0

def get_cancel_order(customer_id:int):
    if not does_id_exist(customer_id):
        raise ValueError("Customer id not found on file.")
    def cancel_order(order_id:int) -> str:
        """Use when you want to cancel an order.

        Args:
            order_id (int): the order ID.

        Returns:
            str
        """
        if not does_customer_have_order_id(customer_id, order_id):
            return "Order does not belong to this customer."
        execute(f"UPDATE orders SET status = 'cancelled' WHERE order_id = '{order_id}'")
        return "Order cancelled successfully."
    return FunctionTool.from_defaults(cancel_order)

def get_update_order_address(customer_id:int):
    if not does_id_exist(customer_id):
        raise ValueError("Customer id not found on file.")
    def update_order_address(order_id:int, new_street:str, new_city:str, new_zip:str, new_state:str, new_country:str) -> str:
        """Use when you want to update the shipping address of an order.

        Args:
            order_id (int): The order ID.
            new_street (str): The new street address.
            new_city (str): The new city.
            new_zip (str): The new zip code.
            new_state (str): The new state.
            new_country (str): The new country.

        Returns:
            str: A message indicating the success or failure of the update.
        """
        if not does_customer_have_order_id(customer_id, order_id):
            return "Order does not belong to this customer."
        new_address = f"{new_street}, {new_city}, {new_state}, {new_zip}, {new_country}"
        execute(f"UPDATE orders SET shipping_address = '{new_address}' WHERE order_id = '{order_id}'")
        return "Shipping address updated successfully."
    return FunctionTool.from_defaults(update_order_address)

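def get_order_tools(customer_id=None):
    # Resolve the default at call time; `customer_id=get_random_id()` would run once, when the function is defined.
    if customer_id is None:
        customer_id = get_random_id()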
    return [
        get_list_orders(customer_id),
        get_cancel_order(customer_id),
        get_update_order_address(customer_id)
    ]

# Load documents
engine = create_engine('sqlite:///ecommerce.db')
docs = DatabaseReader(engine=engine).load_data(query="SELECT item_id, description, price, quantity, name AS text FROM items")
embed_model = FastEmbedEmbedding(model_name='mixedbread-ai/mxbai-embed-large-v1')
index = VectorStoreIndex.from_documents(docs, embed_model=embed_model, show_progress=True)
inventory_tools = [
    RetrieverTool.from_defaults(index.as_retriever(), description="Useful when you need to answer a question by searching items.", name='search_items'),
]

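def get_all_tools(customer_id=None):
    # Resolve the default at call time; `customer_id=get_random_id()` would run once, when the function is defined.
    if customer_id is None:
        customer_id = get_random_id()
    return (
        # NOTE: get_base_tools is not defined in this notebook; define it (e.g. carry it over from the previous lab) before calling get_all_tools.
        get_base_tools(customer_id)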
        + get_user_management_tools(customer_id)
        + get_order_tools(customer_id)
        + inventory_tools
    )
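
`get_all_tools` calls `get_base_tools`, and a later cell calls `get_customer_information`, but neither is defined in this notebook; they presumably come from the previous lab. If you don't have your own versions to paste in, the following is a minimal stand-in, not the original solution code. It assumes only that the `customers` table has a `customer_id` column, as the queries above do. The tool name `lookup_customer_information` is hypothetical. Unlike the f-string queries above, it binds the id with a `?` placeholder, which avoids quoting problems and SQL injection.

```python
import sqlite3


def get_customer_information(customer_id: int, db: str = "ecommerce.db") -> dict:
    """Return the customer's row as a dict (hypothetical stand-in for the previous lab's helper)."""
    connection = sqlite3.connect(db)
    try:
        connection.row_factory = sqlite3.Row  # rows behave like mappings keyed by column name
        row = connection.execute(
            "SELECT * FROM customers WHERE customer_id = ?",
            (int(customer_id),),  # int() because sqlite3 cannot bind numpy integers from pandas
        ).fetchone()
    finally:
        connection.close()
    if row is None:
        raise ValueError("Customer id not found on file.")
    return dict(row)


def get_base_tools(customer_id: int, db: str = "ecommerce.db") -> list:
    """Wrap the lookup as an agent tool bound to one customer (assumed shape, mirroring the tools above)."""
    # Imported here so get_customer_information stays usable without llama-index installed
    from llama_index.core.tools import FunctionTool

    def lookup_customer_information() -> dict:
        """Use when you need the current customer's account details, such as their address or phone number."""
        return get_customer_information(customer_id, db)

    return [FunctionTool.from_defaults(lookup_customer_information)]
```

Binding `customer_id` inside the closure follows the same pattern as `get_update_pin` and the other tool factories: the model never supplies the id itself, so it cannot look up another customer's record.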

# Does our agent work at all with the model before fine tuning?

Let's use the `Ollama` LLM class to run `llama3.2:3b`.
Make sure to run `ollama pull llama3.2:3b` in the Colab terminal before loading the model - otherwise it will give you an error.

Once you've loaded the model, let's run a few basic things we'd expect our agent to do.
Does it work?

In [None]:
llm = Ollama('llama3.2:3b')

In [None]:
customer_id = get_random_id()
customer_info = get_customer_information(customer_id)
print(customer_info)

In [None]:
agent = create_agent(llm, get_all_tools(customer_id), verbose=True)

In [None]:
agent.chat("Look up my information using my id and tell me my address.")

## Discussion:

* Based on the model's outputs, what do you think the training process was?
* Did it give any of the right information?
* Did it call any functions? Why or why not?
* Do you think the model learning our previous conversations might help?

# Fine tuning: Environment setup and imports

In this section, we install and set up `axolotl`, the framework we will use to configure our fine tuning.
We will also configure our `HF_TOKEN` here, so we can upload our model to 🤗 once we're done fine-tuning.

In [None]:
# Install axolotl
import os
if os.path.exists("axolotl"):
  !rm -rf axolotl
!git clone https://github.com/axolotl-ai-cloud/axolotl
# This handles a mismatch between xformers torch requirements and that of other dependencies
with open('/content/axolotl/requirements.txt', 'r') as file:
    requirements = file.read()
    # replace xformers==0.0.27 with xformers
    requirements = requirements.replace('xformers==0.0.27', 'xformers')
with open('/content/axolotl/requirements.txt', 'w') as file:
    file.write(requirements)
!pip install -qqqq ninja packaging mlflow=="2.13.0" --progress-bar off
!cd axolotl && pip install -qqqq -e ".[flash-attn,deepspeed]" --progress-bar off

In [None]:
# Set the `HF_TOKEN` env variable
from google.colab import userdata
import os
token = userdata.get('HF_TOKEN')
os.environ['HF_TOKEN'] = token

In [None]:
# Pull in our train.jsonl and eval.jsonl from Google drive
if not os.path.exists("/content/drive/MyDrive/CTME-LLM-labs/train.jsonl"):
    raise FileNotFoundError("train.jsonl does not exist. Please make sure you've connected to Google Drive and run the first two lab notebooks.")
# `!` shell commands never raise Python exceptions, so use -f / -p instead of try/except
!rm -rf data
!mkdir -p data
!cp /content/drive/MyDrive/CTME-LLM-labs/*.jsonl data/

In [None]:
# Observe 1 row from the training data.
!head data/train.jsonl -n 1 | python -m json.tool
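
Eyeballing one row is a good start, but a structural check over the whole file is cheap. The sketch below assumes each line is a JSON object with a `conversations` list whose turns carry `from` and `value` fields. The allowed role names (`system`, `human`, `gpt`) are the common sharegpt convention and are an assumption about this dataset; adjust them if your data uses other roles. The function names are made up for this example.

```python
import json

ALLOWED_ROLES = ("system", "human", "gpt")  # assumed sharegpt role names; adjust to your data


def validate_sharegpt_line(line, allowed_roles=ALLOWED_ROLES):
    """Return a list of human-readable problems found in one JSONL line (empty if none)."""
    try:
        row = json.loads(line)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc.msg}"]
    turns = row.get("conversations") if isinstance(row, dict) else None
    if not isinstance(turns, list) or not turns:
        return ["missing or empty 'conversations' list"]
    problems = []
    for i, turn in enumerate(turns):
        if not isinstance(turn, dict):
            problems.append(f"turn {i}: not an object")
            continue
        if turn.get("from") not in allowed_roles:
            problems.append(f"turn {i}: unexpected 'from' value {turn.get('from')!r}")
        value = turn.get("value")
        if not isinstance(value, str) or not value.strip():
            problems.append(f"turn {i}: 'value' is missing or empty")
    return problems


def validate_file(path):
    """Map 1-based line numbers to their problems, skipping clean lines."""
    report = {}
    with open(path, encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            problems = validate_sharegpt_line(line)
            if problems:
                report[lineno] = problems
    return report
```

Calling `validate_file("data/train.jsonl")` returns an empty dict when every row passes.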

# Axolotl configuration

In this section, we define everything about how we want to fine tune the model, including what model we want to fine tune, where the data is, what template we want to use, and where to export the model.

This config was mostly copied from [Axolotl's repository of examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples).
You can spend a lot of time fiddling around with hyperparameters, but these examples are pretty good and relatively easy to modify for anything you want to do.
Your time is *much* better spent curating data and making sure it is properly formatted than messing around with hyperparameters.
There are also model-specific quirks that make it challenging to reuse one model's fine tuning configuration for another.
For example, the modules targeted for LoRA adapters may be named differently in different model families (llama, gemma, mistral, etc.).
Do yourself a favor and start with something that works!

In [None]:
%%writefile axolotl.yaml
base_model: meta-llama/Llama-3.2-3B

load_in_8bit: false
load_in_4bit: true
strict: false
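adapter: qlora

# Where checkpoints, LoRA adapters, and the merged model are written; later cells read from /content/model-out
output_dir: ./model-out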

# Data config
dataset_prepared_path: data
chat_template: chatml
datasets:
  - path: data/train.jsonl
    ds_type: json
    data_files:
      - data/train.jsonl
    conversation: alpaca
    type: sharegpt

test_datasets:
  - path: data/eval.jsonl
    ds_type: json
    # You need to specify a split. For "json" datasets the default split is called "train".
    split: train
    type: sharegpt
    conversation: alpaca
    data_files:
      - data/eval.jsonl

sequence_len: 4096
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 2
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"
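
A few of the numbers in this config only make sense in combination. As a rough sketch (the function names are made up, and a single GPU is assumed, as on a typical Colab runtime), the arithmetic below shows how `micro_batch_size` and `gradient_accumulation_steps` combine into the effective batch size, and how `lora_alpha` and `lora_r` set the scale applied to the LoRA update under the standard LoRA scaling of `alpha / r`.

```python
def effective_batch_size(micro_batch_size, gradient_accumulation_steps, num_gpus=1):
    """Number of (packed) sequences that contribute to one optimizer step."""
    return micro_batch_size * gradient_accumulation_steps * num_gpus


def lora_scale(lora_alpha, lora_r):
    """Multiplier applied to the low-rank update under standard LoRA scaling."""
    return lora_alpha / lora_r


def max_tokens_per_step(sequence_len, micro_batch_size, gradient_accumulation_steps, num_gpus=1):
    """Upper bound on tokens seen per optimizer step when every sequence is full."""
    return sequence_len * effective_batch_size(micro_batch_size, gradient_accumulation_steps, num_gpus)


# Values taken from the config above
print(effective_batch_size(2, 4))       # 8
print(lora_scale(16, 32))               # 0.5
print(max_tokens_per_step(4096, 2, 4))  # 32768
```

With `sample_packing: true`, each of those 8 sequences can hold several short conversations, so one optimizer step may cover many more than 8 training examples.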

# Data preparation

Currently, our data is in the `sharegpt` format: each example is an array of conversation turns, each with a `from` and a `value` field. This raw structure can't be used directly to fine-tune the model.
We need to render it with a chat template.
In this case, we will use the popular ChatML template.
Below is the Jinja template for ChatML ([source](https://huggingface.co/docs/transformers/main/en/chat_templating#what-template-should-i-use)):

```jinja
{%- for message in messages %}
    {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
{%- endfor %}
```

When we pipe our chat messages through this template, the actual text will be formatted like so ([source](https://huggingface.co/docs/transformers/main/en/chat_templating#what-are-generation-prompts)):

```
<|im_start|>user
Hi there!<|im_end|>
<|im_start|>assistant
Nice to meet you!<|im_end|>
<|im_start|>user
Can I ask a question?<|im_end|>
```
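
The same transformation can be sketched in plain Python. This is only an illustration of what the template does, not how `axolotl` implements it. The role mapping from sharegpt names (`human`, `gpt`) to ChatML roles (`user`, `assistant`) is an assumption about the dataset, and the function names are made up.

```python
# Assumed mapping from sharegpt "from" values to ChatML roles
ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}


def sharegpt_to_messages(conversations):
    """Convert sharegpt turns ({'from', 'value'}) into role/content messages."""
    return [
        {"role": ROLE_MAP[turn["from"]], "content": turn["value"]}
        for turn in conversations
    ]


def render_chatml(messages):
    """Apply the ChatML layout: <|im_start|>role\\ncontent<|im_end|>\\n for each message."""
    return "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )


conversation = [
    {"from": "human", "value": "Hi there!"},
    {"from": "gpt", "value": "Nice to meet you!"},
    {"from": "human", "value": "Can I ask a question?"},
]
print(render_chatml(sharegpt_to_messages(conversation)))
```

Printing the result reproduces the formatted conversation shown above.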

We also need to tokenize our data, turning it into tensors of integers that the model can read and learn from.
`axolotl` will also help us tokenize the data and perform several optimizations such as sample packing (putting multiple smaller samples in the same training example to reduce the number of padding tokens) and creating masks so we don't train on system prompts and user messages.

Thankfully, `axolotl` handles all these complex configs!
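
To make sample packing and loss masking concrete, here is a toy sketch in plain Python. It is not `axolotl`'s implementation: `build_labels` and `pack_first_fit` are illustrative names, the token ids are made up, and real packing also has to keep packed samples from attending to each other. The one convention borrowed from the real stack is `-100`, the label value that PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def build_labels(segments):
    """Build (input_ids, labels) from (role, token_ids) pairs.

    Only assistant tokens keep their ids as labels; system and user tokens
    are masked so they contribute nothing to the loss.
    """
    input_ids, labels = [], []
    for role, token_ids in segments:
        input_ids.extend(token_ids)
        if role == "assistant":
            labels.extend(token_ids)
        else:
            labels.extend([IGNORE_INDEX] * len(token_ids))
    return input_ids, labels


def pack_first_fit(lengths, sequence_len):
    """Greedy first-fit packing: put each sample into the first bin with enough room.

    Returns a list of bins, each a list of sample indices.
    """
    bins, free = [], []
    for idx, n in enumerate(lengths):
        if n > sequence_len:
            raise ValueError(f"sample {idx} has {n} tokens, more than sequence_len={sequence_len}")
        for b, room in enumerate(free):
            if n <= room:
                bins[b].append(idx)
                free[b] -= n
                break
        else:  # no existing bin had room
            bins.append([idx])
            free.append(sequence_len - n)
    return bins


ids, labels = build_labels([("system", [1, 2]), ("user", [3]), ("assistant", [4, 5])])
print(ids)     # [1, 2, 3, 4, 5]
print(labels)  # [-100, -100, -100, 4, 5]
print(pack_first_fit([3000, 1000, 2000, 500], sequence_len=4096))  # [[0, 1], [2, 3]]
```

In the packing example, four samples fit into two sequences of 4096 tokens instead of four mostly padded ones.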

In [None]:
# Tokenize the data
!python -m axolotl.cli.preprocess /content/axolotl.yaml

# Fine tuning

Wow, that was tough!
Now let's do another hard thing: fine tune the model.
The following command will launch our fine tuning run.
It will save the LoRA adapters in the output folder.

In [None]:
# Prefixing a command with ! runs it as a shell command
!accelerate launch -m axolotl.cli.train /content/axolotl.yaml

# Merge weights

In this cell, we will merge the LoRA weights into the original model so we end up with a single set of standalone weights.
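
Conceptually, for each adapted linear layer the merge adds the scaled low-rank update to the frozen weight: `W_merged = W + (lora_alpha / lora_r) * B @ A`. The toy version below uses plain Python lists to show the arithmetic; the real merge operates on full tensors, and the function names here are made up.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora(weight, lora_a, lora_b, lora_alpha, lora_r):
    """Return W + (alpha / r) * B @ A.

    weight: out x in, lora_b: out x r, lora_a: r x in.
    """
    scale = lora_alpha / lora_r
    delta = matmul(lora_b, lora_a)  # low-rank update with the same shape as weight
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (2 x 2)
A = [[1.0, 2.0]]              # r x in, with r = 1
B = [[1.0], [0.0]]            # out x r
print(merge_lora(W, A, B, lora_alpha=2, lora_r=1))  # [[3.0, 4.0], [0.0, 1.0]]
```

After merging, the adapter matrices are no longer needed at inference time, which is what lets us export a single model file in the next step.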

In [None]:
!python3 -m axolotl.cli.merge_lora axolotl.yaml

# Export to .gguf

We have decided we want to run this model with `ollama`, so we need to export to `.gguf`.
Thankfully, `llama.cpp` comes with a handy script that helps us export our 🤗 `transformers`-style model to `.gguf`.

In [None]:
# Convert to .gguf
!python llama.cpp/convert_hf_to_gguf.py /content/model-out/merged \
  --outfile /content/model-out/customer-service-agent-merged.gguf \
  --outtype q8_0
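
A quick way to sanity-check the exported file before uploading it is to read its header. GGUF files begin with the four magic bytes `GGUF` followed by a 32-bit version number. The sketch below assumes the usual little-endian layout, and `read_gguf_version` is a made-up helper name.

```python
import struct

GGUF_MAGIC = b"GGUF"


def read_gguf_version(path):
    """Return the GGUF version number, or None if the file lacks the GGUF magic bytes."""
    with open(path, "rb") as fh:
        header = fh.read(8)
    if len(header) < 8 or header[:4] != GGUF_MAGIC:
        return None
    (version,) = struct.unpack("<I", header[4:8])  # little-endian unsigned 32-bit integer
    return version
```

For example, `read_gguf_version("/content/model-out/customer-service-agent-merged.gguf")` should return an integer rather than `None` if the conversion succeeded.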

# Upload to 🤗

Our final step is to upload the model to Hugging Face.
Make sure you have the `HF_TOKEN` environment variable set, then run the next several cells.
Make sure to populate your hf.co username and the name of the repo you want to upload to.
This will upload everything in `/content/model-out/`, including the merged model and the `.gguf` file.

In [None]:
!huggingface-cli login --token $HF_TOKEN --add-to-git-credential

In [None]:
import os
username = None
repo_name = None
if not username or not repo_name:
    username = input("Username: ")
    repo_name = input("Repo name: ")
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
!huggingface-cli upload {username}/{repo_name} /content/model-out/ .

# Use your fine tune

Ok! Finally, you have your fine tuned model.
Let's go back up to the top of the notebook and use it instead of our original model and see if there was any improvement.
When you instantiate the `Ollama` LLM class, replace the model we originally used with `"hf.co/{your_username}/{your_repo_name}"`.
Then, run the agent normally as we did before!
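
As a small convenience, the model name can be built from the same `username` and `repo_name` used for the upload. The helper below is hypothetical. The optional quantization suffix follows Ollama's `hf.co/{username}/{repository}:{quantization}` naming for GGUF repos, and `Q8_0` is assumed here only because the export above used `--outtype q8_0`.

```python
def ollama_hf_model_name(username, repo_name, quantization=None):
    """Build the model name Ollama uses for a GGUF repo hosted on Hugging Face."""
    if not username or not repo_name:
        raise ValueError("username and repo_name are both required")
    name = f"hf.co/{username}/{repo_name}"
    return f"{name}:{quantization}" if quantization else name


print(ollama_hf_model_name("your_username", "your_repo_name"))
# hf.co/your_username/your_repo_name
print(ollama_hf_model_name("your_username", "your_repo_name", "Q8_0"))
# hf.co/your_username/your_repo_name:Q8_0
```

As with the base model, pull it in the Colab terminal first (`ollama pull` followed by that name), then pass the same string to the `Ollama` class.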

# Discussion: Next steps

If you were to try and improve on this solution, what steps would you take?