# Fine Tune LLM (On Text Classification Data)

We'll walk through prepping data and fine-tuning a model in this notebook. Databricks makes it easy to fine tune a model with a few lines of code. We will be working with a text classification dataset: the [ml4pubmed dataset from hugging face](https://huggingface.co/datasets/ml4pubmed/pubmed-text-classification-cased) which contains text from different sections of scientific research articles.

This notebook is the first in a series to showcase the performance gains provided by combining fine-tuning and prompt optimization. See this article for more [detail](https://www.databricks.com/blog/building-state-art-enterprise-agents-90x-cheaper-automated-prompt-optimization). 

In [0]:
%pip install databricks_genai
%pip install databricks-sdk
dbutils.library.restartPython()

In [0]:
from databricks.model_training import foundation_model as fm
from datasets import load_dataset
import os
import mlflow
from pyspark.sql.functions import col

## Set Variables & Prompt

In [0]:
catalog = "megan_fang_demos"
schema = "llm_opt"
model_name = "cc-meta-llama-3-1-8b-instruct"
registered_model_name = f"{catalog}.{schema}.{model_name}"
dataset = "ml4pubmed/pubmed-text-classification-cased"

spark.sql(f"USE CATALOG {catalog}")
spark.sql(f"USE SCHEMA {schema}")

os.environ["HF_DATASETS_CACHE"] = "/Volumes/megan_fang_demos/llm_opt/datasets"

Let's include a helpful system prompt for our LLM.

In [0]:
system_prompt = """Given a text field containing a sentence from a research paper abstract, classify it into one of the following categories and output only the category label:

**Categories:**
- BACKGROUND: Contextual information, established knowledge, problem statements, or general facts that set up the research context (e.g., "Although opioids are effective treatments for postoperative pain, they contribute to the delayed recovery of gastrointestinal function.")
- OBJECTIVE: Research aims, goals, purposes, hypotheses, or what the study was designed to investigate (e.g., "This study was designed to assess...", "The aim was to investigate...", "To determine whether...")
- METHODS: Descriptions of experimental procedures, data collection methods, study protocols, analytical approaches, statistical methods, or control group descriptions (e.g., "Blood samples were collected...", "This fifth group served as a control...", "mean (+/- SD) was used")
- RESULTS: Findings, outcomes, statistical data, success rates, observed effects, or factual outcomes from the study (e.g., "The success rate was 70.4%...", "Follow-up coronary angiography was performed in 108 patients...", "a significant improvement in terms of OS (p = 0.02)")
- CONCLUSIONS: Final interpretations, implications, recommendations, suggestions, or what the authors conclude from their findings. These often contain interpretive language and speculation (e.g., "From this finding, we conclude that...", "may be associated with...")

**Critical Classification Guidelines:**

1. **OBJECTIVE vs BACKGROUND**: 
    - Sentences describing what a study "was designed to," "aimed to," "evaluated whether," or starting with "To determine/investigate" are OBJECTIVE, not BACKGROUND
    - "There is a need to..." statements are OBJECTIVE (expressing research purpose), not BACKGROUND
    - Study purpose statements are always OBJECTIVE regardless of their position in the abstract

2. **RESULTS vs CONCLUSIONS**: 
    - RESULTS report factual findings, statistical outcomes, and observed data without interpretation
    - CONCLUSIONS contain interpretive statements with words like "may be," "suggest," "conclude," "might," or implications drawn from results
    - Statements with "evidence suggesting" are CONCLUSIONS, not BACKGROUND

3. **Key Phrases for OBJECTIVE**: 
    - "There is a need to..."
    - "This study evaluated whether..."
    - "was designed to assess..."
    - "The aim was to..."
    - "To determine..."
    - "This clinical study was designed to assess whether..."

4. **Key Phrases for CONCLUSIONS**: 
    - "we conclude that..."
    - "may be associated with..."
    - "might benefit from..."
    - "suggest that..."
    - "evidence suggesting..."
    - Any speculative or interpretive language

5. **METHODS indicators**: 
    - Descriptions of procedures, controls, statistical approaches
    - Data collection timing and methods
    - Analytical methods and statistical significance thresholds
    - Control group descriptions

6. **Focus on primary purpose**: Classify based on the main function of the sentence in the research narrative, not just its position in the abstract.

**Common Misclassification Patterns to Avoid:**
- Study objectives stated as "This study was designed to..." should be OBJECTIVE, not BACKGROUND
- "There is a need to identify..." statements are OBJECTIVE, not BACKGROUND
- Statements with "may be associated with" or similar speculative language are CONCLUSIONS, not RESULTS
- "Evidence suggesting" statements are CONCLUSIONS, not BACKGROUND
- Research aims are always OBJECTIVE regardless of how they are phrased

**Classification Strategy:**
1. Look for explicit objective markers first (study aims, purposes, "to determine")
2. Check for speculative/interpretive language indicating CONCLUSIONS
3. Identify factual data reporting for RESULTS
4. Look for procedural descriptions for METHODS
5. Default to BACKGROUND only for established knowledge or context-setting information

Based on the above, categorize the following sentence and output only the category label (BACKGROUND, OBJECTIVE, METHODS, RESULTS, or CONCLUSIONS) without any additional text or explanation : \n\n"""

## Test Base Model

Let's test the performance of our LLM before fine-tuning to get a baseline performance.

In [0]:
test_table_name = "hf_pubmed_test"
pred = spark.sql(f"""
        SELECT
            ai_query('databricks-meta-llama-3-1-8b-instruct', concat('{system_prompt}', prompt)) AS prediction,
            response,
            prompt
        FROM {catalog}.{schema}.{test_table_name}
        LIMIT 10""")
display(pred)

In [0]:
total_count = pred.count()
correct_predictions = pred.filter(col("prediction") == col("response")).count()
overall_accuracy = correct_predictions / total_count

print(f"Total samples: {total_count}")
print(f"Baseline Accuracy: {overall_accuracy:.4f}")

## Load Data

In [0]:
def load_hf_dataset(dataset: str, split: str="train", include_cols: list[str]=['prompt', 'response'], rename_cols: Dict | None = None, register_to_uc: boolean=False, catalog: str=catalog, schema: schema=schema, table_name: str=train_table_name) -> SparkDataframe:
  """
  Load dataset from hugging face and converts to spark dataframe. Optionally register to Unity Catalog.

  Args:
    dataset (str): Hugging Face dataset name
    split (str): Dataset split to load
    include_cols (list): Columns to include in the final dataframe
    rename_cols (dict): Dictionary of column names to rename
    register_to_uc (bool): Whether to register the dataframe to Unity Catalog
    catalog (str): Unity Catalog catalog name
    schema (str): Unity Catalog schema name
    table_name (str): Unity Catalog table name
  Returns:
    Spark DataFrame: Converted dataframe
  """
  # Load hugging face dataset
  hf_df = load_dataset(
      path=dataset,
      split=split
  )

  # Convert to pandas then spark dataframe
  pd_df = hf_df.to_pandas()[include_cols].dropna()
  hf_spark = spark.createDataFrame(pd_df)

  if rename_cols is not None:
    hf_spark = hf_spark.withColumnRenamed(rename_cols["prompt"], "prompt").withColumnRenamed(rename_cols["response"], "response")

  # Optionally register to unity catalog
  if register_to_uc:
    hf_spark.write.format("delta").mode("overwrite").saveAsTable(f"{catalog}.{schema}.{table_name}")

  return hf_spark

In [0]:
train_table_name = "hf_pubmed_train"
rename_cols = {"prompt": "description", "response": "target"}

hf_train = load_hf_dataset(dataset=dataset, include_cols=['target', 'description'], rename_cols=rename_cols, register_to_uc=True, table_name=train_table_name)

In [0]:
eval_table_name = "hf_pubmed_eval"
hf_eval = load_hf_dataset(dataset=dataset, split="validation", include_cols=['target', 'description'], rename_cols=rename_cols, register_to_uc=True, table_name=eval_table_name)

## Transform Data For Chat Completion

Databricks fine-tuning supports a few different task types:

- **Chat completion**: Train your model on chat logs between a user and an AI assistant. The text is automatically formatted into the appropriate format for the specific model. 
- **Instruction fine-tuning**: Train your model on structured prompt-response data. Use this to adapt your model to a new task, change its response style, or add instruction-following capabilities. This task does not automatically apply any formatting to your data and is only recommended when custom data formatting is required.

I found chat completion to perform better than instruction fine-tuning for this task, so let's format our data for chat completion. 

Chat completion requires a list of role and prompt, following the OpenAI standard:

In [0]:
[
  {"role": "system", "content": "[system prompt]"},
  {"role": "user", "content": "Here is a documentation page:[RAG context]. Based on this, answer the following question: [user question]"},
  {"role": "assistant", "content": "[answer]"}
]

In [0]:
spark.sql(f"""
CREATE OR REPLACE TABLE hf_pubmed_train_chat_complete AS
SELECT 
    ARRAY(
        STRUCT('user' AS role, CONCAT('{system_prompt}', '\n', prompt) AS content),
        STRUCT('assistant' AS role, response AS content)
    ) AS messages
FROM hf_pubmed_train;
""")

spark.table('hf_pubmed_train_chat_complete').limit(10).display()

In [0]:
spark.sql(f"""
CREATE OR REPLACE TABLE hf_pubmed_eval_chat_complete AS
SELECT 
    ARRAY(
        STRUCT('user' AS role, CONCAT('{system_prompt}', '\n', prompt) AS content),
        STRUCT('assistant' AS role, response AS content)
    ) AS messages
FROM hf_pubmed_eval;
""")

spark.table('hf_pubmed_eval_chat_complete').limit(10).display()

## Start Fine Tuning Run

See supported model types and more details on databricks' fine tuning module [here](https://docs.databricks.com/aws/en/large-language-models/foundation-model-training/).

Notes on lower learning rate 

In [0]:
def get_current_cluster_id():
  import json
  return json.loads(dbutils.notebook.entry_point.getDbutils().notebook().getContext().safeToJson())['attributes']['clusterId']

In [0]:
run = fm.create(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
                data_prep_cluster_id=get_current_cluster_id(), # necessary when using a Unity Catalog dataset to train
                train_data_path=f"{catalog}.{schema}.{train_table_name}",
                eval_data_path=f"{catalog}.{schema}.{eval_table_name}",
                register_to=f"{catalog}.{schema}.{model_name}",
                task_type="CHAT_COMPLETION",
                learning_rate="5e-8",
                training_duration="20ep")
run

In [0]:
from mlflow import MlflowClient
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Get latest version
versions = client.search_model_versions(f"name='{catalog}.{schema}.{model_name}'")
latest_version = max(versions, key=lambda mv: int(mv.version)).version

# Create or update the alias to point at the specified version
alias = "champion"
client.set_registered_model_alias(name=f"{catalog}.{schema}.{model_name}", alias=alias, version=latest_version)

## Create Serving Endpoint

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
    ServedEntityInput,
    EndpointCoreConfigInput,
    AiGatewayConfig,
    AiGatewayInferenceTableConfig
)

serving_endpoint_name = "cc-meta-llama-3-1-8b-instruct"
w = WorkspaceClient()

# Create the AI Gateway configuration
ai_gateway_config = AiGatewayConfig(
    inference_table_config=AiGatewayInferenceTableConfig(
        enabled=True,
        catalog_name=catalog,
        schema_name=schema,
        table_name_prefix="cc-meta-llama-3-1-8b-instruct_inference"
    )
)

endpoint_config = EndpointCoreConfigInput(
    name=serving_endpoint_name,
    served_entities=[
        ServedEntityInput(
            entity_name=registered_model_name,
            entity_version=latest_version,
            min_provisioned_throughput=0, # The minimum tokens per second that the endpoint can scale down to.
            max_provisioned_throughput=10900,# The maximum tokens per second that the endpoint can scale up to. 
            scale_to_zero_enabled=True
        )
    ]
)

existing_endpoint = next(
    (e for e in w.serving_endpoints.list() if e.name == serving_endpoint_name), None
)

if existing_endpoint is None:
    print(f"Creating the endpoint {serving_endpoint_name}, this will take a few minutes to package and deploy the endpoint...")
    w.serving_endpoints.create_and_wait(name=serving_endpoint_name, config=endpoint_config, ai_gateway=ai_gateway_config)
else:
    print(f"Endpoint {serving_endpoint_name} already exists...")

In [0]:
test_table_name = "hf_pubmed_test"
pred = spark.sql(f"""
        SELECT
            ai_query('{serving_endpoint_name}', concat('{system_prompt}', prompt)) AS prediction,
            response,
            prompt
        FROM {catalog}.{schema}.{test_table_name}
        """)
display(pred)

In [0]:
# Calculate overall accuracy
total_count = pred.count()
correct_predictions = pred.filter(col("prediction") == col("response")).count()
overall_accuracy = correct_predictions / total_count

print(f"Total samples: {total_count}")
print(f"Finetuned Accuracy: {overall_accuracy:.4f}")

Our baseline accuracy was 40%, and our new accuracy after fine-tuning is 65%!