# Serve multiple LoRA adapters efficiently on SageMaker

In this tutorial, we will learn how to serve many Low-Rank Adapters (LoRA) on top of the same base model efficiently on the same GPU. In order to do this, we'll deploy the LoRA Exchange ([LoRAX](https://github.com/predibase/lorax/tree/main)) inference server to SageMaker Hosting. 

These are the steps we will take:

1. [Set up our environment](#setup)
2. [Build a new LoRAX container image compatible with SageMaker, push it to Amazon ECR](#container)
3. [Download adapters from the HuggingFace Hub and upload them to S3](#download_adapter)
4. [Deploy the extended LoRAX container to SageMaker](#deploy)
5. [Compare outputs of the base model and the adapter model](#compare)
6. [Benchmark our deployed endpoint under different traffic patterns - same adapter, and random access to many adapters](#benchmark)


## What is LoRAX? 

LoRAX is a production-ready framework specialized in multi-adapter serving. It lets many adapters efficiently share the same GPU resources, which dramatically reduces the cost of serving without compromising on throughput or latency. Some of the features that enable this are:

* Dynamic Adapter Loading - fine-tuned LoRA weights are loaded from storage (local or remote) just-in-time as requests come in at runtime
* Tiered Weight Caching - fast exchanging of LoRA adapters between requests, and offloading of adapter weights to CPU and disk when they are not needed, to avoid out-of-memory errors.
* Continuous Multi-Adapter Batching - a fair scheduling policy that continuously batches requests targeted at different LoRA adapters so they can be processed in parallel, optimizing aggregate throughput.
* Optimized Inference - high throughput and low latency optimizations including tensor parallelism, pre-compiled CUDA kernels ([flash-attention](https://arxiv.org/abs/2307.08691), [paged attention](https://arxiv.org/abs/2309.06180), [SGMV](https://arxiv.org/abs/2310.18547)), quantization, token streaming.
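
To make the tiered weight caching idea above more concrete, here is a toy sketch of a two-tier least-recently-used cache. It is not LoRAX's implementation: the class name `TieredAdapterCache`, the slot counts, and the `fake_load` function are all hypothetical and only illustrate how adapters could be promoted to a small "GPU" tier and offloaded to a larger "CPU" tier instead of being reloaded from storage.

```python
from collections import OrderedDict


class TieredAdapterCache:
    """Toy two-tier LRU cache: a small 'gpu' tier backed by a larger 'cpu' tier."""

    def __init__(self, gpu_slots, cpu_slots):
        self.gpu_slots = gpu_slots
        self.cpu_slots = cpu_slots
        self.gpu = OrderedDict()  # adapter_id -> weights, least recently used first
        self.cpu = OrderedDict()

    def get(self, adapter_id, load_from_storage):
        """Return weights for adapter_id, promoting them to the GPU tier."""
        if adapter_id in self.gpu:
            self.gpu.move_to_end(adapter_id)  # mark as most recently used
            return self.gpu[adapter_id]
        if adapter_id in self.cpu:
            weights = self.cpu.pop(adapter_id)  # promote from the CPU tier
        else:
            weights = load_from_storage(adapter_id)  # cold load (disk or S3)
        self._put_gpu(adapter_id, weights)
        return weights

    def _put_gpu(self, adapter_id, weights):
        self.gpu[adapter_id] = weights
        while len(self.gpu) > self.gpu_slots:
            evicted_id, evicted = self.gpu.popitem(last=False)  # LRU entry
            self.cpu[evicted_id] = evicted  # offload to CPU instead of dropping
            while len(self.cpu) > self.cpu_slots:
                self.cpu.popitem(last=False)  # next access reloads from storage


loads = []  # records which adapters needed a cold load


def fake_load(adapter_id):
    """Hypothetical stand-in for reading adapter weights from disk or S3."""
    loads.append(adapter_id)
    return f"weights-{adapter_id}"


cache = TieredAdapterCache(gpu_slots=2, cpu_slots=2)
for adapter_id in ["01", "02", "03", "01"]:
    cache.get(adapter_id, fake_load)

print(loads)            # ['01', '02', '03']
print(list(cache.gpu))  # ['03', '01']
```

The second request for adapter `01` is served from the CPU tier, so it does not trigger another cold load.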

You can read more about LoRAX [here](https://predibase.com/blog/lora-exchange-lorax-serve-100s-of-fine-tuned-llms-for-the-cost-of-one).

<a id="setup"></a>
## 1. Set up our environment

In [None]:
!pip install -U boto3 sagemaker huggingface_hub --quiet

In [None]:
import sagemaker
import boto3
sess = sagemaker.Session()

# sagemaker session bucket -> used for uploading data, models and logs
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
region = sess.boto_region_name  # public attribute; avoids relying on the private _region_name

print(f"sagemaker role arn: {role}")
print(f"sagemaker session region: {region}")

<a id="container"></a>
## 2. Build a new LoRAX container image compatible with SageMaker, push it to Amazon ECR

This example includes a `Dockerfile` and `sagemaker_entrypoint.sh` in the `sagemaker_lorax` directory. Building this new container image makes LoRAX compatible with SageMaker Hosting, namely by launching the server on port 8080 via the container's `ENTRYPOINT` instruction. [Here](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-run-image) you can find the basic interfaces required to adapt any container for deployment on SageMaker Hosting.

In [None]:
!cat sagemaker_lorax/sagemaker_entrypoint.sh

In [None]:
!cat sagemaker_lorax/Dockerfile

We build the new container image and push it to a new ECR repository. Note that SageMaker [supports private Docker registries](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-containers-inference-private.html) as well.

In [None]:
%%bash -s {region}
algorithm_name="sagemaker-lorax"  # name of your algorithm
tag="0.8.0"
region=$1

account=$(aws sts get-caller-identity --query Account --output text)

image_uri="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:${tag}"

# If the repository doesn't exist in ECR (in the target region), create it.
if ! aws ecr describe-repositories --repository-names "${algorithm_name}" --region "${region}" > /dev/null 2>&1
then
    aws ecr create-repository --repository-name "${algorithm_name}" --region "${region}" > /dev/null
fi

cd sagemaker_lorax/ && docker build  --build-arg VERSION=$tag -t ${algorithm_name}:${tag} .

# Authenticate Docker to an Amazon ECR registry
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin ${account}.dkr.ecr.${region}.amazonaws.com

# Tag the image
docker tag ${algorithm_name}:${tag} ${image_uri}

# Push the image to the repository
docker push ${image_uri}

# Save image name to tmp file to use when deploying endpoint
echo $image_uri > /tmp/image_uri

<a id="download_adapter"></a>
## 3. Download adapter from HuggingFace Hub and push it to S3

We are going to simulate storing our adapter weights on S3 and having LoRAX load them dynamically as we invoke them. This covers most scenarios, including deployment after you’ve fine-tuned your own adapter and pushed it to S3, as well as securing deployments with no internet access inside your VPC, as detailed in this [blog post](https://www.philschmid.de/sagemaker-llm-vpc#2-upload-the-model-to-amazon-s3).

We first download an adapter trained with Mistral Instruct v0.1 as the base model to a local directory. This particular adapter was trained on GSM8K, a grade school math dataset.

In [None]:
from pathlib import Path
from huggingface_hub import snapshot_download

HF_MODEL_ID = "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k"
# create model dir
model_dir = Path('mistral-adapter')
model_dir.mkdir(exist_ok=True)

# Download model from Hugging Face into model_dir
snapshot_download(
    HF_MODEL_ID,
    local_dir=str(model_dir), # download to model dir
    local_dir_use_symlinks=False, # use no symlinks to save disk space
    revision="main", # use a specific revision, e.g. refs/pr/21
)

We copy this same adapter `n_adapters` times to different S3 prefixes in our SageMaker session bucket, simulating a large number of adapters we want to serve on the same endpoint and underlying GPU.

In [None]:
import os

s3 = boto3.client('s3')

def upload_folder_to_s3(local_path, s3_bucket, s3_prefix):
    for root, dirs, files in os.walk(local_path):
        for file in files:
            local_file_path = os.path.join(root, file)
            s3_object_key = os.path.join(s3_prefix, os.path.relpath(local_file_path, local_path))
            s3.upload_file(local_file_path, s3_bucket, s3_object_key)

# Upload the folder n_adapters times under different prefixes
n_adapters=50
base_prefix = 'lorax/mistral-adapters'
for i in range(1, n_adapters+1):
    prefix = f'{base_prefix}/0{i}' if i < 10 else f'{base_prefix}/{i}'

    upload_folder_to_s3(model_dir, sagemaker_session_bucket, prefix)
    print(f'Uploaded folder to S3 with prefix: {prefix}')
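
The prefix and key construction used above can be sketched as two small helpers. This is a minimal illustration under stated assumptions: the helper names `adapter_prefix` and `s3_key` are hypothetical, and `adapter_config.json` is only an example file name. Joining with `posixpath` keeps S3 keys slash-separated even when the notebook runs on an operating system whose local path separator is a backslash.

```python
import posixpath


def adapter_prefix(base_prefix, index, width=2):
    """Return the zero-padded S3 prefix for the adapter copy numbered `index`."""
    return f"{base_prefix}/{index:0{width}d}"


def s3_key(prefix, relative_path):
    """Build an S3 object key with forward slashes, whatever the local separator is."""
    return posixpath.join(prefix, relative_path.replace("\\", "/"))


print(adapter_prefix("lorax/mistral-adapters", 7))   # lorax/mistral-adapters/07
print(adapter_prefix("lorax/mistral-adapters", 42))  # lorax/mistral-adapters/42
print(s3_key("lorax/mistral-adapters/07", "adapter_config.json"))
# lorax/mistral-adapters/07/adapter_config.json
```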

<a id="deploy"></a>
## 4. Deploy SageMaker endpoint


Now we deploy a SageMaker endpoint and set the `ADAPTER_BUCKET` environment variable to our SageMaker session bucket, which enables downloading adapters from S3.

If you have any problems deploying on g5.xlarge, you can change the instance type to g5.2xlarge or g5.4xlarge.
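
Before deploying, it can help to sanity-check the token limits passed to the container. The sketch below assumes, as the comments in the deployment cell state, that `MAX_TOTAL_TOKENS` counts the input plus the generated tokens, so `MAX_INPUT_LENGTH` has to be smaller for any generation to fit. The helper name `check_token_limits` and the `example_env` dictionary are hypothetical and only mirror the values used in this tutorial.

```python
import json


def check_token_limits(env):
    """Return the token budget left for generation, or raise if there is none."""
    max_input = int(json.loads(env["MAX_INPUT_LENGTH"]))
    max_total = int(json.loads(env["MAX_TOTAL_TOKENS"]))
    if max_input >= max_total:
        raise ValueError(
            f"MAX_INPUT_LENGTH ({max_input}) must be smaller than "
            f"MAX_TOTAL_TOKENS ({max_total})"
        )
    return max_total - max_input


# Same values as the deployment cell, serialized the same way
example_env = {
    "MAX_INPUT_LENGTH": json.dumps(1024),
    "MAX_TOTAL_TOKENS": json.dumps(4096),
}
print(check_token_limits(example_env))  # 3072
```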

In [None]:
import json
import datetime

from sagemaker import Model
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Retrieve image_uri from tmp file
image_uri = !cat /tmp/image_uri
# Increased health check timeout to give time for model download
health_check_timeout = 800
number_of_gpu = 1
instance_type = "ml.g5.xlarge"
endpoint_name = sagemaker.utils.name_from_base("sm-lorax")


# Model and Endpoint configuration parameters
env = {
  'HF_MODEL_ID': "mistralai/Mistral-7B-Instruct-v0.1", # model_id from hf.co/models
  'SM_NUM_GPUS': json.dumps(number_of_gpu), # Number of GPU used per replica
  'MAX_INPUT_LENGTH': json.dumps(1024),  # Max length of input text
  'MAX_TOTAL_TOKENS': json.dumps(4096),  # Max length of the generation (including input text)
  'ADAPTER_BUCKET': sagemaker_session_bucket,
}

lorax_model = Model(
    image_uri=image_uri[0],
    role=role,
    env=env
)

lorax_predictor = lorax_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout, 
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)

In [None]:
# You can reinstantiate the Predictor object if you restart the notebook or Predictor is None
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
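# endpoint_name is defined in the deployment cell above.
# If you restarted the kernel, assign your endpoint's name to endpoint_name here first.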

lorax_predictor = Predictor(
    endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

<a id="compare"></a>
## 5. Invoke base model and adapter, compare outputs

We can invoke the base Mistral model, as well as any of the adapters in our bucket! LoRAX takes care of downloading them, continuously batching requests for different adapters, and managing GPU and CPU memory by loading and offloading adapters.

Let’s inspect the difference between the base model’s response and the adapter’s response:

<div class="alert alert-block alert-info">
⚠️ Known issue: S3 download failed for adapter IDs 1 through 5 but worked as expected for all other adapters. The cause has not been debugged yet and seems related to the S3 prefix. As a workaround, a 0 is prepended to the adapter ID when the ID is below 10.
</div>

In [None]:
prompt = '[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]'

payload_base = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
    }
}

payload_adapter = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
        "adapter_id": f'{base_prefix}/01',  # S3 prefix of the adapter inside ADAPTER_BUCKET
        "adapter_source": "s3"
    }
}

response_base = lorax_predictor.predict(payload_base)
response_adapter = lorax_predictor.predict(payload_adapter)

print(f'Base model output:\n-------------\n {response_base[0]["generated_text"]}')
print(f'Adapter output:\n-------------\n {response_adapter[0]["generated_text"]}')
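
Since the base and adapter payloads differ only in their `parameters`, the pattern can be captured in a small helper. This is a sketch: `build_payload` is a hypothetical name, and the request fields are the ones already used in the cell above.

```python
def build_payload(prompt, adapter_id=None, max_new_tokens=64):
    """Build a request body; omit adapter fields to target the base model."""
    parameters = {"max_new_tokens": max_new_tokens}
    if adapter_id is not None:
        parameters["adapter_id"] = adapter_id
        parameters["adapter_source"] = "s3"
    return {"inputs": prompt, "parameters": parameters}


base = build_payload("[INST] What is 2 + 2? [/INST]")
adapted = build_payload("[INST] What is 2 + 2? [/INST]", adapter_id="lorax/mistral-adapters/01")

print(sorted(base["parameters"]))     # ['max_new_tokens']
print(sorted(adapted["parameters"]))  # ['adapter_id', 'adapter_source', 'max_new_tokens']
```

Either dictionary could then be passed to `lorax_predictor.predict(...)` in place of the hand-written payloads.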

<a id="benchmark"></a>
## 6. Benchmark single adapter vs. random access to adapters



First, we call each of the adapters once in sequence, so that they are already downloaded to the endpoint instance’s disk before the benchmark starts. We want to exclude S3 download latency from the benchmark metrics.

In [None]:
from tqdm import tqdm

for i in tqdm(range(1, n_adapters + 1)):
    adapter_id = f'{base_prefix}/0{i}' if i < 10 else f'{base_prefix}/{i}'
    payload_adapter = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 64,
            "adapter_id": adapter_id,
            "adapter_source": "s3"
        }
    }
    lorax_predictor.predict(payload_adapter)

Now we are ready to benchmark. For the single adapter case, we invoke the adapter `total_requests` times from `num_threads` concurrent clients.

For the multi-adapter case, we invoke a random adapter from any of the clients, until all adapters have been invoked `total_requests//num_adapters` times.
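
The random-access policy described above can be checked offline, without calling the endpoint. The sketch below is a single-threaded dry run with a hypothetical helper name, `random_schedule`. It repeats the selection logic (pick a random adapter that still has calls remaining) and confirms that every adapter ends up invoked the same number of times.

```python
import random
from collections import Counter


def random_schedule(total_requests, num_adapters, seed=0):
    """Return the order of 1-based adapter IDs produced by random selection."""
    rng = random.Random(seed)
    remaining = [total_requests // num_adapters] * num_adapters
    order = []
    while True:
        # Adapters that still have calls left
        candidates = [i for i, count in enumerate(remaining) if count > 0]
        if not candidates:
            break
        choice = rng.choice(candidates)
        remaining[choice] -= 1
        order.append(choice + 1)  # adapter IDs are 1-based
    return order


order = random_schedule(total_requests=300, num_adapters=50)
counts = Counter(order)
print(len(order), set(counts.values()))  # 300 {6}
```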

In [None]:
# Adjust if you run into connection pool errors
# import botocore

# Configure botocore to use a larger connection pool
# config = botocore.config.Config(max_pool_connections=100)

In [None]:
import threading
import time
import random


# Configuration
total_requests = 300
num_adapters = 50
num_threads = 20  # Adjust based on your system capabilities


# Shared lock and counters for # invocations of each adapter 
adapter_counters = [total_requests // num_adapters] * num_adapters
counters_lock = threading.Lock()

def invoke_adapter(aggregate_latency, single_adapter=False):
    global total_requests
    latencies = []
    while True:
        with counters_lock:
            if single_adapter:
                adapter_id = 1
                if total_requests > 0:
                    total_requests -= 1
                else:
                    break
            else:
                # Find an adapter that still needs to be called
                remaining_adapters = [i for i, count in enumerate(adapter_counters) if count > 0]
                if not remaining_adapters:
                    break
                adapter_id = random.choice(remaining_adapters) + 1
                adapter_counters[adapter_id - 1] -= 1

        prompt = '[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]'
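        # Use the adapter chosen above (not the leftover loop variable `i` from the warm-up cell)
        invoke_adapter_id = f'{base_prefix}/0{adapter_id}' if adapter_id < 10 else f'{base_prefix}/{adapter_id}'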
        payload_adapter = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 64,
                "adapter_id": invoke_adapter_id,
                "adapter_source": "s3"
            }
        }
        start_time = time.time()
        response_adapter = lorax_predictor.predict(payload_adapter)
        latency = time.time() - start_time
        latencies.append(latency)

    aggregate_latency.extend(latencies)

def benchmark_scenario(single_adapter=False):
    threads = []
    all_latencies = []
    start_time = time.time()

    for _ in range(num_threads):
        thread_latencies = []
        all_latencies.append(thread_latencies)
        thread = threading.Thread(target=invoke_adapter, args=(thread_latencies, single_adapter))
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

    # Measure wall-clock time once, right after all threads finish
    elapsed = time.time() - start_time
    total_latency = sum(sum(latencies) for latencies in all_latencies)
    total_requests_made = sum(len(latencies) for latencies in all_latencies)
    average_latency = total_latency / total_requests_made
    throughput = total_requests_made / elapsed

    print(f"Total Time: {elapsed} s")
    print(f"Average Latency: {average_latency} s")
    print(f"Throughput: {throughput} requests/s")

# Run benchmarks
print("Benchmarking: Single Adapter Multiple Times")
benchmark_scenario(single_adapter=True)

print("\nBenchmarking: Multiple Adapters with Random Access")
benchmark_scenario()
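
The benchmark above reports only the average latency, which can hide slow outliers. As a possible extension, the sketch below computes nearest-rank percentiles from a list of latencies. The `percentile` helper is hypothetical, and the sample values are made up for illustration; they are not measurements from this endpoint.

```python
import math


def percentile(latencies, p):
    """Nearest-rank percentile: smallest sample with at least p% of samples at or below it."""
    if not latencies:
        raise ValueError("latencies must not be empty")
    ordered = sorted(latencies)
    rank = max(1, math.ceil(p * len(ordered) / 100))
    return ordered[rank - 1]


# Made-up latencies in seconds, with one slow outlier
sample = [0.9, 1.1, 1.0, 1.2, 4.0]
print(percentile(sample, 50))  # 1.1
print(percentile(sample, 95))  # 4.0
```

To apply this to real measurements, `benchmark_scenario` would need to return or expose the per-thread lists in `all_latencies`, flattened into a single list.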


<a id="cleanup"></a>
## 7. Clean up endpoint resources

In [None]:
lorax_predictor.delete_endpoint()
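
Deleting the endpoint does not remove the adapter copies uploaded to S3 in step 3. If you also want to remove them, a sketch along these lines could work. The helper name `delete_prefix` is hypothetical; it relies on the boto3 S3 client's `list_objects_v2` paginator and `delete_objects` call. It is destructive, so double-check the bucket and prefix before running it.

```python
def delete_prefix(s3_client, bucket, prefix):
    """Delete every object under `prefix`; return the number of keys deleted."""
    deleted = 0
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
        if keys:
            # A listing page holds at most 1000 keys, which is also the delete_objects limit
            s3_client.delete_objects(Bucket=bucket, Delete={"Objects": keys})
            deleted += len(keys)
    return deleted


# Example usage with the names from this notebook (uncomment to run):
# import boto3
# n = delete_prefix(boto3.client("s3"), sagemaker_session_bucket, "lorax/mistral-adapters/")
# print(f"Deleted {n} objects")
```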