# Serving a Fine-Tuned Gemma Model with Multiple LoRA Adapters on Databricks

This tutorial shows how to serve [`gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it) with multiple LoRA adapters on Databricks Model Serving.

Environment for this notebook:
- Runtime: 16.1 GPU ML Runtime
- Instance: Tested on `g5.8xlarge` on AWS; a smaller GPU cluster should also work
- MLflow 2.15

Serving Endpoint requirement:
- For this example, 1 T4 GPU is sufficient without further quantization (GPU Small)
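
As a rough sanity check on this sizing, weight memory can be estimated as parameter count times bytes per parameter. The sketch below assumes roughly 2.6 billion parameters for gemma-2-2b-it and 16 GB of memory on a T4. Both figures are approximations, and `weight_memory_gb` is a hypothetical helper, not part of any library. The estimate ignores activations, the KV cache, and the comparatively small adapter weights.

```python
def weight_memory_gb(n_params: float, bits_per_param: int) -> float:
    """Approximate memory for the model weights alone, in GB (10**9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


GEMMA_2_2B_PARAMS = 2.6e9  # assumption: approximate parameter count
T4_MEMORY_GB = 16          # assumption: nominal T4 memory

for label, bits in [("float16", 16), ("4-bit", 4)]:
    gb = weight_memory_gb(GEMMA_2_2B_PARAMS, bits)
    print(f"{label}: ~{gb:.1f} GB of weights ({gb / T4_MEMORY_GB:.0%} of a T4)")
```

Under these assumptions the float16 weights take about a third of the GPU, which is consistent with serving this model unquantized.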


## Install required packages

Run the cells below to set up and install the required libraries. Since gemma-2-2b-it is small enough to fit on the cluster GPU, we do not load a quantized base model. For larger models (e.g. Llama 7B or 8B models), we can use `bitsandbytes` to [quantize the base model to 4-bit](https://huggingface.co/blog/4bit-transformers-bitsandbytes). We also need `accelerate`, `peft`, and `transformers` to load the base model and PEFT adapters.

In [0]:
%pip install bitsandbytes==0.45
%pip install accelerate
%pip install -U peft
%pip install -U transformers

dbutils.library.restartPython()

In [0]:
# import os
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] ='max_split_size_mb:128'
# import torch
# torch.cuda.empty_cache()
# torch.cuda.memory_summary() 

In [0]:
# !nvidia-smi

## Loading the model

In this section we load the [gemma-2-2b-it](https://huggingface.co/google/gemma-2-2b-it) model and a few popular open-source adapters from Hugging Face and save them to Unity Catalog Volumes.

Adapters we will be using here:
  - [google-cloud-partnership/gemma-2-2b-it-lora-sql](https://huggingface.co/google-cloud-partnership/gemma-2-2b-it-lora-sql)
  - [google-cloud-partnership/gemma-2-2b-it-lora-jap-en](https://huggingface.co/google-cloud-partnership/gemma-2-2b-it-lora-jap-en)
  - [google-cloud-partnership/gemma-2-2b-it-lora-magicoder](https://huggingface.co/google-cloud-partnership/gemma-2-2b-it-lora-magicoder)


In [0]:
catalog = "cindy_demo_catalog"
schema = "llm_fine_tuning"
volume = "hf_models"

spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{schema}.{volume}")

In [0]:
base_model_path = f'/Volumes/{catalog}/{schema}/{volume}/gemma-2-2b-it'
adapters_path = f'/Volumes/{catalog}/{schema}/{volume}/adapters' # Directory to store all adapters
adapters_mapping_path = f'/Volumes/{catalog}/{schema}/{volume}/adapters_mapping.json' # Mapping of adapters to model names
adapters_mapping = {'sql' :'gemma-2-2b-it-lora-sql',
                    'japanese': 'gemma-2-2b-it-lora-jap-en',
                    'coder': 'gemma-2-2b-it-lora-coder'
                    }

# base_tokenizer_path = f'/Volumes/{catalog}/{schema}/{volume}/gemma-2-2b-it-tokenzier' ## Specify this if tokenizer is not stored with the base model and has a different path

import os
if not os.path.exists(adapters_path):
    dbutils.fs.mkdirs(adapters_path)
    
import json
with open(adapters_mapping_path, 'w') as f:
    json.dump(adapters_mapping, f)
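
Before logging the model, it can help to confirm that every entry in the mapping points at a directory that actually contains an adapter. The sketch below checks for `adapter_config.json`, the config file PEFT writes alongside adapter weights. `find_missing_adapters` is a hypothetical helper, and the demo runs against a temporary directory rather than a real Volume.

```python
import json
import os
import tempfile


def find_missing_adapters(mapping_path: str, adapters_dir: str) -> list[str]:
    """Return adapter names whose directory lacks an adapter_config.json file."""
    with open(mapping_path) as f:
        mapping = json.load(f)
    missing = []
    for name, subdir in mapping.items():
        config_file = os.path.join(adapters_dir, subdir, "adapter_config.json")
        if not os.path.isfile(config_file):
            missing.append(name)
    return missing


# Demo on a temporary directory instead of a Unity Catalog Volume.
with tempfile.TemporaryDirectory() as tmp:
    demo_mapping = {"sql": "gemma-2-2b-it-lora-sql", "coder": "gemma-2-2b-it-lora-coder"}
    demo_mapping_path = os.path.join(tmp, "adapters_mapping.json")
    with open(demo_mapping_path, "w") as f:
        json.dump(demo_mapping, f)

    # Only the "sql" adapter exists in this demo layout.
    sql_dir = os.path.join(tmp, "adapters", "gemma-2-2b-it-lora-sql")
    os.makedirs(sql_dir)
    open(os.path.join(sql_dir, "adapter_config.json"), "w").close()

    demo_missing = find_missing_adapters(demo_mapping_path, os.path.join(tmp, "adapters"))
    print("missing adapters:", demo_missing)
```

In the notebook, the equivalent call would be `find_missing_adapters(adapters_mapping_path, adapters_path)` after the adapters have been downloaded.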

## (Optional) Download base model and adapters to Unity Catalog Volumes
- Requires a Hugging Face token with access to the Gemma model

In [0]:
# dbutils.widgets.text("huggingface_token", "", "Enter Parameter")
# os.environ['HF_TOKEN'] = dbutils.widgets.get("huggingface_token")

In [0]:
# from peft import PeftModel
# from huggingface_hub import snapshot_download

# # Download base Gemma model
# base_model_path = snapshot_download(repo_id="google/gemma-2-2b-it", local_dir=base_model_path)

# # Download LoRA adapters
# lora_sql_path = snapshot_download(repo_id="google-cloud-partnership/gemma-2-2b-it-lora-sql", local_dir=f"{adapters_path}/gemma-2-2b-it-lora-sql")
# lora_jap_en_path = snapshot_download(repo_id="google-cloud-partnership/gemma-2-2b-it-lora-jap-en", local_dir=f"{adapters_path}/gemma-2-2b-it-lora-jap-en")
# lora_coder_path = snapshot_download(repo_id='google-cloud-partnership/gemma-2-2b-it-lora-magicoder', local_dir=f"{adapters_path}/gemma-2-2b-it-lora-coder")

## (Optional) Load all adapters and test peft model locally

In [0]:
# from transformers import AutoTokenizer, AutoModelForCausalLM
# import torch
# from peft import PeftModel

# base_model = AutoModelForCausalLM.from_pretrained(base_model_path, device_map="cuda:0")
# tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# peft_model = base_model 

# for name, path in adapters_mapping.items():
#   # This loads the adapter onto the model under the provided adapter name.
#   peft_model.load_adapter(f"{adapters_path}/{path}", adapter_name=name)
#   print('loaded Peft Model',name)


In [0]:
# generated_texts

In [0]:
## peft_model.delete_adapter("sql") ## To remove an adapter from peft model

## Create MLflow PyFunc Model with Multiple Adapters

In [0]:

import mlflow
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel, PeftConfig

class FINETUNED_QLORA(mlflow.pyfunc.PythonModel):
    # Load base model, tokenizer, and adapters.
    def load_context(self, context):
        import json
        import os

        ## Uncomment this to load a quantized model: requires less memory, but inference is slower due to de-quantization overhead
        # bnb_config = BitsAndBytesConfig(
        # load_in_4bit=True,
        # bnb_4bit_quant_type="nf4",
        # # bnb_4bit_use_double_quant=True,
        # bnb_4bit_compute_dtype=torch.float16,
        # )

        # Load the tokenizer and set the pad token to the EOS token.
        self.tokenizer = AutoTokenizer.from_pretrained(context.artifacts['base_model'])
        self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Load the base model in float16 (pass quantization_config=bnb_config to use 4-bit quantization instead).
        self.base_model = AutoModelForCausalLM.from_pretrained(
            context.artifacts['base_model'], 
            return_dict=True, 
            # quantization_config=bnb_config, 
            torch_dtype=torch.float16,
            device_map={"": 0})
        
        # Load PEFT adapters from a dictionary artifact.
        with open(context.artifacts["adapters_mapping"], "r") as f:
            self.adapters_mapping = json.load(f)

        print('loaded adapter mappings')
       
        self.model = self.base_model
  
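        # Resolve the adapters directory once; this also avoids nesting double quotes
        # inside an f-string, which is a syntax error before Python 3.12.
        adapters_dir = context.artifacts["adapters"]
        for adapter_name, adapter_path in self.adapters_mapping.items():
            self.model.load_adapter(f"{adapters_dir}/{adapter_path}", adapter_name=adapter_name)
            print('loaded Peft Model', f"{adapters_dir}/{adapter_path}")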

        ## Set the model to evaluation mode. Use this for a merged model
        # self.model.eval()

        self.model.config.use_cache = False

    def predict(self, context, model_input, params=None):
        # params is optional; fall back to the defaults below when it is omitted.
        params = params or {}

        # The serving request sends one row whose "prompts" cell holds a list[str],
        # so [0] selects that list. Note: when a plain dict is passed directly to
        # predict(), [0] selects only the first prompt.
        prompts = model_input.get("prompts")[0]

        print('input:', prompts)

        temperature = float(params.get('temperature', 0.1))
        max_tokens = int(params.get('max_tokens', 100))
        adapter_name = params.get('adapter', 'sql')
        print('params: ', temperature, max_tokens, adapter_name)

        # Activate the requested adapter; return early if it was never loaded.
        if adapter_name in self.adapters_mapping:
            self.model.set_adapter(adapter_name)
        else:
            print('no adapter found')
            return 'no adapter found'

        # Tokenize the input prompt(s) with padding and truncation, and move to CUDA.
        batch = self.tokenizer(text=prompts, padding=True, truncation=True, return_tensors='pt').to('cuda')

        with torch.amp.autocast('cuda'):
            output_tokens = self.model.generate(
                input_ids=batch.input_ids,
                # Pass the attention mask so padded positions are ignored in batched prompts.
                attention_mask=batch.attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode the generated tokens into text.
        generated_texts = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
        return generated_texts
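
The shape of `model_input["prompts"]` depends on how `predict` is called: it may be a single string, a flat list of strings, or one row that holds a list. One way to make the wrapper tolerant of all three is to flatten the input before tokenizing. The helper below is a hypothetical sketch, not part of MLflow. It relies only on duck typing (`tolist`), so the same code would also accept a pandas Series or NumPy array without importing either library.

```python
def normalize_prompts(model_input) -> list[str]:
    """Flatten the "prompts" field into a flat list of strings."""
    column = model_input["prompts"]
    if hasattr(column, "tolist"):  # e.g. a pandas Series or NumPy array
        column = column.tolist()
    if isinstance(column, str):
        return [column]

    prompts = []
    for item in column:
        if isinstance(item, str):
            prompts.append(item)
        else:
            # A nested list or array: one row that holds several prompts.
            if hasattr(item, "tolist"):
                item = item.tolist()
            prompts.extend(str(p) for p in item)
    return prompts


print(normalize_prompts({"prompts": "what is Databricks?"}))
print(normalize_prompts({"prompts": ["what is Databricks?", "what's ML"]}))
print(normalize_prompts({"prompts": [["what is Databricks?", "what's ML"]]}))
```

If a helper like this were used, `predict` would call `normalize_prompts(model_input)` in place of indexing with `[0]`.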

## (Optional) Test MLflow PyFunc Model locally

In [0]:
# Create a local model context object with required artifact paths for testing
artifacts = {
    "base_model": base_model_path,  
    "adapters_mapping": adapters_mapping_path,
    "adapters": adapters_path
            }

class ModelContext:
    def __init__(self):
        self.artifacts = artifacts
# Instantiate a dummy context.
dummy_context = ModelContext()
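
As a lighter alternative to a hand-written stub class, the standard library's `types.SimpleNamespace` can stand in for the context, since `load_context` only reads the `artifacts` attribute. This is a sketch with placeholder paths. `stub_artifacts` and `stub_context` are illustrative names, and the real `/Volumes/...` paths defined earlier would be used in practice.

```python
from types import SimpleNamespace

# Placeholder paths for illustration only.
stub_artifacts = {
    "base_model": "/tmp/demo/gemma-2-2b-it",
    "adapters_mapping": "/tmp/demo/adapters_mapping.json",
    "adapters": "/tmp/demo/adapters",
}

# An object exposing .artifacts, which is all load_context() reads.
stub_context = SimpleNamespace(artifacts=stub_artifacts)
print(stub_context.artifacts["base_model"])
```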

In [0]:
# Instantiate your pyfunc wrapper and load the context.
finetuned_model = FINETUNED_QLORA()
finetuned_model.load_context(dummy_context)

In [0]:
test_input =  {'prompts': ["what is Databricks?", "what's ML"]}
params = {
    "temperature": 0.1,
    "max_tokens": 100,
    "adapter": "sql"
}
# Run a prediction and display the output.
print("Testing prediction...")
generated_output = finetuned_model.predict(dummy_context, test_input, params)
print("Generated Output:", generated_output)

In [0]:
generated_output

In [0]:
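test_input = {
    "prompts": ["データブリックスとは"],  # "What is Databricks?" in Japanese
}
# Switch to the Japanese-English adapter; otherwise params still selects "sql".
params = {**params, "adapter": "japanese"}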

# Run a prediction and display the output.
print("Testing prediction...")
generated_output = finetuned_model.predict(dummy_context, test_input, params)
print("Generated Output:", generated_output)

In [0]:
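test_input = {
    "prompts": ["generate some pandas code to create a dataframe"]
}
# Switch to the coding adapter; otherwise params keeps the adapter from the previous cell.
params = {**params, "adapter": "coder"}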

# Run a prediction and display the output.
print("Testing prediction...")
generated_output = finetuned_model.predict(dummy_context, test_input, params)
print("Generated Output:", generated_output)

## Log to MLflow + Register model in UC

In [0]:
import mlflow
from mlflow.models.signature import infer_signature

# Set mlflow registry to databricks-uc
mlflow.set_registry_uri("databricks-uc")

# Specify an input example that conforms to the input schema for the task.
input_data = {"prompts": ["what is Databricks?", "what's ML"]}
               
params = {
    "temperature": 0.1,
    "max_tokens": 100,
    "adapter": "sql"
}
input_example = (input_data, params)

output_example = {"generated_texts": ["what is Databricks?\n\n```sql\nSELECT * FROM Databricks;```\nThis query retrieves all records from the 'Databricks' table.\n", "what's ML model performance for each model?\nmodel\n```sql\nSELECT model_name, performance_score FROM model_performance;```\nThis query retrieves the ML model performance for each model by selecting the model_name and performance_score columns from the model_performance table.\n"]}

signature = infer_signature(input_data, output_example, params)
signature 
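
Note that the output example above echoes each prompt at the start of its generated text, because the decoded sequence includes the input tokens. If only the completion is wanted, the echoed prefix can be removed after decoding. The sketch below is one simple way to do that. `strip_prompt` is a hypothetical helper, and it leaves the text unchanged when the output does not start with the exact prompt (for example if decoding altered whitespace).

```python
def strip_prompt(prompt: str, generated: str) -> str:
    """Remove the echoed prompt from the start of a generated string, if present."""
    if generated.startswith(prompt):
        return generated[len(prompt):].lstrip()
    return generated


example_prompt = "what is Databricks?"
example_output = "what is Databricks?\n\nSELECT * FROM Databricks;"
print(strip_prompt(example_prompt, example_output))
```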

In [0]:
model_name = "gemma_2_multi_adapters"
registered_model_name = f"{catalog}.{schema}.{model_name}"


artifacts = {
    # "tokenizer": base_tokenizer_path,      
    "base_model": base_model_path,  
    "adapters_mapping": adapters_mapping_path,
    "adapters": adapters_path
            }


with mlflow.start_run() as run:  
    model_info = mlflow.pyfunc.log_model(
        "model",
        python_model=FINETUNED_QLORA(),
        artifacts= artifacts,
        pip_requirements=["torch==2.5.0", "torchvision==0.20.0","transformers==4.46.3", "accelerate==1.1.1", "peft==0.15.2", "bitsandbytes==0.45.0"],
        input_example= input_example,
        signature=signature,
        registered_model_name=registered_model_name
    )

## (Optional) Load MLflow Model locally to test

In [0]:
## Restart Python to free GPU memory and avoid OOM on a small GPU cluster (16 GB should be plenty for Gemma 2 without restarting)
#  dbutils.library.restartPython()

In [0]:
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

In [0]:
import mlflow
import pandas as pd
import numpy as np
input_data = {
    "prompts":["what is Databricks?","import pandas as"]
}
params = {
    "temperature": 0.1,
    "max_tokens": 100,
    "adapter": "coder"
    }

preds = loaded_model.predict(input_data, params)

In [0]:
preds

## Serve Registered Model

In [0]:
import requests
import json

# Set the name of the MLflow endpoint
endpoint_name = "gemma_2_multi_adapters"

# Get the version of the model registered in the logging step above
model_version = model_info.registered_model_version
print(model_version)
# Name of the registered MLflow model
registered_model_name = f"{catalog}.{schema}.{model_name}"

# Get the API endpoint and token for the current notebook context
API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()


Create the model serving endpoint. Gemma-2-2b-it fits on a T4 GPU (GPU Small) without quantization; choose the GPU size based on the size of the loaded pyfunc model.

In [0]:

# Specify the type of compute (CPU, GPU_SMALL, GPU_LARGE, etc.) 
workload_type = "GPU_SMALL" 

# Specify the scale-out size of compute (Small, Medium, Large, etc.)
workload_size = "Small" 

# Specify scale to zero (only supported for CPU endpoints)
scale_to_zero = True 

data = {
    "name": endpoint_name,
    "config": {
        "served_entities": [
            {
                "entity_name": registered_model_name,
                "entity_version": model_version,
                "workload_size": workload_size,
                "scale_to_zero_enabled": scale_to_zero,
                "workload_type": workload_type,
            }
        ]
    },
}

headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.post(
    url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers
)

print(json.dumps(response.json(), indent=4))

If the endpoint already exists, you can update it with the desired model version or endpoint configuration.

In [0]:
from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")
endpoint = client.update_endpoint(
    endpoint=endpoint_name,
    config={
        "served_entities": [
            {
                "entity_name": registered_model_name,
                "entity_version": model_version,
                "scale_to_zero_enabled": scale_to_zero,
                "workload_type": workload_type,
                "workload_size": workload_size,
            }
        ],
    },
)

## View your endpoint
To see more information about your endpoint, go to the Serving UI and search for your endpoint name.

## Query your endpoint
Once your endpoint is ready, you can query it by making an API request. Depending on the model size and complexity, it can take 30 minutes or more for the endpoint to become ready.
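
Rather than checking the UI by hand, readiness can be polled programmatically. The sketch below assumes the endpoint description returned by `GET {API_ROOT}/api/2.0/serving-endpoints/{endpoint_name}` carries a `state` object with a `ready` field that becomes `"READY"` and a `config_update` field that may become `"UPDATE_FAILED"`. Treat those field values as assumptions to verify against the API reference. `wait_until_ready` is a hypothetical helper, and the fetch function is injected so the loop can be demonstrated without a live endpoint.

```python
import time


def wait_until_ready(get_endpoint, timeout_s=3600, poll_s=30, sleep=time.sleep):
    """Poll get_endpoint() until state.ready == "READY".

    Returns True when ready, False on timeout; raises if the update failed.
    """
    waited = 0
    while waited <= timeout_s:
        state = get_endpoint().get("state", {})
        if state.get("ready") == "READY":
            return True
        if state.get("config_update") == "UPDATE_FAILED":  # assumed status value
            raise RuntimeError("Endpoint update failed; check the Serving UI logs.")
        sleep(poll_s)
        waited += poll_s
    return False


# Demo with canned responses instead of a live endpoint.
_fake_responses = iter([
    {"state": {"ready": "NOT_READY", "config_update": "IN_PROGRESS"}},
    {"state": {"ready": "READY", "config_update": "NOT_UPDATING"}},
])
demo_ready = wait_until_ready(lambda: next(_fake_responses), poll_s=1, sleep=lambda s: None)
print("ready:", demo_ready)

# In the notebook, the fetch function might look like this (not run here):
# get_endpoint = lambda: requests.get(
#     f"{API_ROOT}/api/2.0/serving-endpoints/{endpoint_name}", headers=headers
# ).json()
```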

In [0]:

data = {
  "dataframe_split": {
    "columns": [
      "prompts"
    ],
    "data": [
      [
        [
          "what is Databricks?",
          "what's ML"
        ]
      ]
    ]
  },
  "params": {
    "temperature": 0.1,
    "max_tokens": 100,
    "adapter": "sql"
  }
}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.post(
    url=f"{API_ROOT}/serving-endpoints/{endpoint_name}/invocations", json=data, headers=headers
)

print(json.dumps(response.json()))
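
To switch adapters per request without hand-editing the payload, the request body can be built by a small function, and the response can be unpacked defensively. The sketch below mirrors the `dataframe_split` layout used above. It assumes a successful response wraps the model output under a `"predictions"` key, and both helpers are hypothetical names rather than part of any client library.

```python
def build_payload(prompts, adapter, temperature=0.1, max_tokens=100):
    """Build an invocation body: one row whose "prompts" cell holds the list of prompts."""
    return {
        "dataframe_split": {"columns": ["prompts"], "data": [[list(prompts)]]},
        "params": {"temperature": temperature, "max_tokens": max_tokens, "adapter": adapter},
    }


def extract_predictions(response_json):
    """Return the model output, or raise with the full body if the call failed."""
    if "predictions" not in response_json:  # assumed success key
        raise ValueError(f"Unexpected response: {response_json}")
    return response_json["predictions"]


demo_payload = build_payload(["import pandas as"], adapter="coder")
print(demo_payload)
print(extract_predictions({"predictions": ["import pandas as pd"]}))
```

With these helpers, querying a different adapter only changes the `adapter` argument, for example `requests.post(url, json=build_payload(prompts, "japanese"), headers=headers)`.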