# Gemma3 Multi-Adapter Serving with Stateful Inference using LMI

This notebook demonstrates how to use the LMI container to serve Gemma3 multi-adapter LoRA with stateful sessions enabled.

Stateful sessions is a feature that routes all requests within the same session to the same instance, allowing your ML application to reuse previously processed information. This can reduce latency and improve the overall user experience.

Stateful sessions are configured with the following environment variables:

* `OPTION_ENABLE_STATEFUL_SESSIONS`: Whether to enable stateful sessions support; defaults to true.
* `OPTION_SESSIONS_PATH`: The path where session data is saved; defaults to "/dev/shm/djl_sessions".
* `OPTION_SESSIONS_EXPIRATION`: The time, in seconds, that a session remains valid before it expires; defaults to 1200.
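These options are passed to the container as environment variables, like the other `OPTION_*` settings in the `env` dictionary used later in this notebook. The sketch below builds them with the documented defaults. `stateful_session_env` is a hypothetical helper, not part of LMI; it returns strings because SageMaker passes container environment variables as strings.

```python
def stateful_session_env(enabled=True,
                         sessions_path="/dev/shm/djl_sessions",
                         expiration_seconds=1200):
    """Hypothetical helper: build the stateful-session OPTION_* variables.

    Defaults mirror the documented LMI defaults; every value is a string
    because SageMaker passes container environment variables as strings.
    """
    if expiration_seconds <= 0:
        raise ValueError("expiration_seconds must be positive")
    return {
        "OPTION_ENABLE_STATEFUL_SESSIONS": "true" if enabled else "false",
        "OPTION_SESSIONS_PATH": sessions_path,
        "OPTION_SESSIONS_EXPIRATION": str(int(expiration_seconds)),
    }


# Example: keep sessions alive for 30 minutes instead of the default 20.
session_env = stateful_session_env(expiration_seconds=1800)
print(session_env["OPTION_SESSIONS_EXPIRATION"])  # 1800
```

To apply non-default values, the result could be merged into the model environment (for example `env.update(session_env)`) before the `sagemaker.Model` is created.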

## Install Packages and Import Dependencies

In [None]:
!pip install sagemaker boto3 transformers huggingface-hub

## Deploy to SageMaker 

In [None]:
import sagemaker
import boto3
import json

print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

In [None]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess.boto_region_name  # public attribute; avoids the private _region_name

sm_client = boto3.client(service_name="sagemaker", region_name=region)
sm_runtime = boto3.client(service_name="sagemaker-runtime", region_name=region)

In [None]:
inference_image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.34.0-lmi16.0.0-cu128"

print(f"Inference container image: {inference_image_uri}")

### Download and upload adapter weights

In [None]:
from huggingface_hub import snapshot_download
snapshot_download("Cossale/poetry-gemma3-4B-LoRA", local_dir="./adapter1", local_dir_use_symlinks=False)

In [None]:
#
# PLEASE NOTE - Adapter files must be packaged as a "tar.gz" archive and uploaded to S3
#

adapter_filename = "adapter.tar.gz"
adapter_s3_uri = f"s3://{bucket}/gemma-3-4b-adapter/{adapter_filename}"

print(adapter_s3_uri)

In [None]:
!cd adapter1 && tar -czvf ../{adapter_filename} .
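The shell `tar` command above archives everything under `./adapter1`, including hidden entries such as a `.cache` folder that `snapshot_download` may leave there. A Python alternative using the standard-library `tarfile` module is sketched below: it places the adapter files at the archive root and skips hidden entries. `package_adapter` and `archive_members` are hypothetical helpers, and the file names used when testing it (`adapter_config.json`, `adapter_model.safetensors`) are the usual PEFT names, assumed here.

```python
import tarfile
from pathlib import Path


def package_adapter(adapter_dir, output_path, skip_prefixes=(".",)):
    """Hypothetical helper: tar.gz the adapter files with paths relative to adapter_dir.

    Any path whose first component starts with one of skip_prefixes
    (by default hidden entries such as ".cache" or ".gitattributes") is left out.
    """
    adapter_dir = Path(adapter_dir)
    with tarfile.open(output_path, "w:gz") as tar:
        for path in sorted(adapter_dir.rglob("*")):
            rel = path.relative_to(adapter_dir)
            if rel.parts[0].startswith(tuple(skip_prefixes)):
                continue
            if path.is_file():
                # arcname keeps files at the archive root, e.g. "adapter_config.json"
                tar.add(path, arcname=rel.as_posix())
    return output_path


def archive_members(archive_path):
    """List the regular files stored in a tar.gz archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        return sorted(m.name for m in tar.getmembers() if m.isfile())


# Usage in the notebook: package_adapter("./adapter1", adapter_filename)
```

Listing the archive with `archive_members` before uploading is a cheap way to confirm the adapter files sit at the root rather than under a nested folder.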

In [None]:
!aws s3 cp {adapter_filename} {adapter_s3_uri}

### Create SageMaker Model and Endpoint

In [None]:
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

model_id = "unsloth/gemma-3-4b-it"

model_name = endpoint_name = "IC-endpoint-gemma3"
base_inference_component_name = "base-" + model_name

env = {
    "HF_MODEL_ID": model_id,
    "SERVING_FAIL_FAST": "True",
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_ASYNC_MODE": "true",
    "OPTION_ROLLING_BATCH": "disable",
    "OPTION_ENTRYPOINT": "djl_python.lmi_vllm.vllm_async_service",
    "OPTION_ENABLE_LORA": "true",
    "OPTION_MAX_LORAS": "4",
    "OPTION_MAX_CPU_LORAS": "8",
    "OPTION_MAX_LORA_RANK": "64",
}

lmi_model = sagemaker.Model(image_uri = inference_image_uri,
                            env = env,
                            role = role,
                            name = model_name)
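The LoRA settings above are forwarded to vLLM: `OPTION_MAX_LORAS` bounds how many adapters can be used in a single batch, `OPTION_MAX_CPU_LORAS` how many adapters are cached in CPU memory, and `OPTION_MAX_LORA_RANK` the largest adapter rank the server accepts. A pre-flight check against the downloaded adapter can catch a mismatch before a long deployment. The sketch below is a hypothetical helper; it assumes the PEFT layout, where `adapter_config.json` stores the rank under the key `r`, and the `max_cpu_loras >= max_loras` rule mirrors vLLM's own validation.

```python
import json
from pathlib import Path


def check_lora_limits(env, adapter_config_path):
    """Hypothetical pre-flight check of an adapter against the LoRA env settings.

    Returns a list of problem descriptions; an empty list means no issue found.
    """
    max_loras = int(env["OPTION_MAX_LORAS"])
    max_cpu_loras = int(env["OPTION_MAX_CPU_LORAS"])
    max_rank = int(env["OPTION_MAX_LORA_RANK"])

    problems = []
    if max_cpu_loras < max_loras:
        problems.append(
            f"OPTION_MAX_CPU_LORAS ({max_cpu_loras}) is below OPTION_MAX_LORAS ({max_loras})"
        )

    # PEFT stores the LoRA rank under the "r" key of adapter_config.json.
    rank = int(json.loads(Path(adapter_config_path).read_text())["r"])
    if rank > max_rank:
        problems.append(f"adapter rank {rank} exceeds OPTION_MAX_LORA_RANK ({max_rank})")
    return problems


# Usage in the notebook: check_lora_limits(env, "./adapter1/adapter_config.json")
```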


In [None]:
lmi_model.deploy(instance_type = "ml.g6.12xlarge",
                 initial_instance_count = 1,
                 container_startup_health_check_timeout = 600,
                 endpoint_name = endpoint_name,
                 endpoint_type = sagemaker.enums.EndpointType.INFERENCE_COMPONENT_BASED,
                 inference_component_name = base_inference_component_name,
                 resources = ResourceRequirements(requests={"num_accelerators": 1, "memory": 4096, "copies": 1}))
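`ResourceRequirements` reserves capacity for each copy of the base inference component: here one accelerator and 4096 MB of memory per copy. Dividing the instance capacity by the per-copy request gives a rough upper bound on how many copies could share one instance. The sketch assumes an ml.g6.12xlarge provides 4 GPUs and 192 GiB (196,608 MiB) of memory; verify those figures for your instance type and treat the result as an upper bound, not a guarantee. `max_copies` is a hypothetical helper.

```python
def max_copies(instance_gpus, instance_memory_mb, gpus_per_copy, memory_mb_per_copy):
    """Hypothetical helper: upper bound on IC copies that fit on one instance."""
    if gpus_per_copy <= 0 or memory_mb_per_copy <= 0:
        raise ValueError("per-copy requests must be positive")
    by_gpu = instance_gpus // gpus_per_copy
    by_memory = instance_memory_mb // memory_mb_per_copy
    # The scarcer resource decides how many copies fit.
    return min(by_gpu, by_memory)


# Assumed ml.g6.12xlarge capacity: 4 GPUs, 192 GiB = 196608 MiB of memory.
print(max_copies(4, 192 * 1024, gpus_per_copy=1, memory_mb_per_copy=4096))  # 4
```

Under these assumptions the GPU count is the binding limit, so deploying with `"copies": 1` leaves three GPUs unreserved.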

In [None]:
%%time

ic1_adapter_name = f"ic1-adapter-{model_name}"

adapter_create_inference_component_response = sm_client.create_inference_component(
    InferenceComponentName = ic1_adapter_name,
    EndpointName = endpoint_name,
    Specification={
        "BaseInferenceComponentName": base_inference_component_name,
        "Container": {
            "ArtifactUrl": adapter_s3_uri
        },
    },
)

sess.wait_for_inference_component(ic1_adapter_name)

print(f"\nCreated Adapter inference component ARN: {adapter_create_inference_component_response['InferenceComponentArn']}")
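The notebook registers a single adapter, but the same call can be repeated to attach further adapters to the same base component. The sketch below builds one `create_inference_component` request per adapter. `adapter_ic_requests` is a hypothetical helper, and the second S3 URI in the example is a placeholder, not an artifact created in this notebook.

```python
def adapter_ic_requests(endpoint_name, base_component_name, adapters):
    """Hypothetical helper: build create_inference_component kwargs per adapter.

    `adapters` maps a short adapter name to the S3 URI of its tar.gz archive.
    """
    requests = []
    for short_name, artifact_url in adapters.items():
        if not artifact_url.startswith("s3://") or not artifact_url.endswith(".tar.gz"):
            raise ValueError(f"{short_name}: expected an s3:// URI to a .tar.gz archive")
        requests.append({
            "InferenceComponentName": f"{short_name}-{base_component_name}",
            "EndpointName": endpoint_name,
            "Specification": {
                "BaseInferenceComponentName": base_component_name,
                "Container": {"ArtifactUrl": artifact_url},
            },
        })
    return requests


example = adapter_ic_requests(
    "IC-endpoint-gemma3",
    "base-IC-endpoint-gemma3",
    {
        "poetry": "s3://my-bucket/gemma-3-4b-adapter/adapter.tar.gz",
        "second": "s3://my-bucket/another-adapter/adapter.tar.gz",  # placeholder URI
    },
)
print([r["InferenceComponentName"] for r in example])
```

In the notebook, each request would be sent with `sm_client.create_inference_component(**request)` and then awaited with `sess.wait_for_inference_component(request["InferenceComponentName"])`.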

In [None]:
import urllib.parse  # "import urllib" alone does not guarantee urllib.parse is loaded

cw_path = urllib.parse.quote_plus(f'/aws/sagemaker/InferenceComponents/{base_inference_component_name}', safe='', encoding=None, errors=None)

print(f'You can view your inference component logs here:\n\n https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}#logsV2:log-groups/log-group/{cw_path}')

## Start Session

To start a session with a stateful model, send an `InvokeEndpoint` request with the `SessionId` parameter set to `NEW_SESSION`, and set "requestType" to "NEW_SESSION" in the request payload.

In [None]:
payload = {
    "requestType": "NEW_SESSION"
}
payload = json.dumps(payload)

create_session_response = sm_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=base_inference_component_name,
    Body=payload,
    ContentType="application/json",
    SessionId="NEW_SESSION")

The LMI container handles the request by starting a new session. It returns the session ID and the session's expiration timestamp (in UTC) in the following HTTP response header:

```
X-Amzn-SageMaker-Session-Id: session_id; Expires=yyyy-mm-ddThh:mm:ssZ
```

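SageMaker passes this value back to the client in the `X-Amzn-SageMaker-New-Session-Id` response header, which botocore exposes in lowercase under `ResponseMetadata`. The session ID is the part before the semicolon, so we can extract it from the `invoke_endpoint` response.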

In [None]:
session_id = create_session_response['ResponseMetadata']['HTTPHeaders']['x-amzn-sagemaker-new-session-id'].split(';')[0]

print(f"session_id: {session_id}")
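The cell above keeps only the part before the semicolon, but the header value also carries the session's expiry time, which tells a client when it must start a new session. `parse_session_header` and `session_expired` are hypothetical helpers that assume the exact `session_id; Expires=yyyy-mm-ddThh:mm:ssZ` format shown above.

```python
from datetime import datetime, timezone


def parse_session_header(value):
    """Hypothetical helper: split 'session_id; Expires=yyyy-mm-ddThh:mm:ssZ'."""
    session_part, _, expires_part = value.partition(";")
    session_id = session_part.strip()
    expires_at = None
    key, _, timestamp = expires_part.strip().partition("=")
    if key == "Expires" and timestamp:
        # The trailing "Z" marks UTC; attach the timezone explicitly.
        expires_at = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%SZ").replace(
            tzinfo=timezone.utc
        )
    return session_id, expires_at


def session_expired(expires_at, now=None):
    """Return True once the expiry time has passed (False if it is unknown)."""
    if expires_at is None:
        return False
    now = now or datetime.now(timezone.utc)
    return now >= expires_at


sid, expires = parse_session_header("abc123; Expires=2025-01-01T12:00:00Z")
print(sid, expires.isoformat())  # abc123 2025-01-01T12:00:00+00:00
```

In the notebook it could be applied to `create_session_response['ResponseMetadata']['HTTPHeaders']['x-amzn-sagemaker-new-session-id']`.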

## Make Inference Requests

To use the same session for a subsequent inference request, the client sends another `InvokeEndpoint` request, specifying the session ID in the `SessionId` parameter. SageMaker then routes the request to the same ML instance where the session was started.
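The base and adapter invocations send the same payload and differ only in `InferenceComponentName`, while the `SessionId` stays fixed. A small wrapper makes that pattern explicit. `chat_in_session` is a hypothetical helper; it expects any object with an `invoke_endpoint` method shaped like the boto3 `sagemaker-runtime` client (the notebook's `sm_runtime`) and assumes the `choices[0].message.content` response shape used in this notebook.

```python
import json


def chat_in_session(runtime, endpoint_name, component_name, session_id,
                    prompt, temperature=0.9, max_tokens=256):
    """Hypothetical helper: send one chat request within an existing session."""
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        InferenceComponentName=component_name,  # base IC or adapter IC
        Body=json.dumps(payload),
        ContentType="application/json",
        SessionId=session_id,  # same ID keeps requests on the same instance
    )
    body = json.loads(response["Body"].read().decode("utf8"))
    return body["choices"][0]["message"]["content"]


# Usage in the notebook:
# for component in (base_inference_component_name, ic1_adapter_name):
#     print(chat_in_session(sm_runtime, endpoint_name, component, session_id,
#                           "Name popular places to visit in London?"))
```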

### Invoke Base IC

In [None]:
payload={
    "messages": [
        {"role": "user", "content": "Name popular places to visit in London?"}
    ],
    "temperature": 0.9,
    "max_tokens": 256,
}

component_to_invoke = base_inference_component_name

response_model = sm_runtime.invoke_endpoint(
    EndpointName = endpoint_name,
    InferenceComponentName = component_to_invoke,
    Body = json.dumps(payload),
    ContentType = "application/json",
    SessionId=session_id
)

base_response = json.loads(response_model["Body"].read().decode("utf8"))["choices"][0]["message"]["content"]

print(f'Base Model Response:\n\n {base_response}\n')

### Invoke Adapter IC

In [None]:
payload={
    "messages": [
        {"role": "user", "content": "Name popular places to visit in London?"}
    ],
    "temperature": 0.9,
    "max_tokens": 256,
}

component_to_invoke = ic1_adapter_name

response_model = sm_runtime.invoke_endpoint(
    EndpointName = endpoint_name,
    InferenceComponentName = component_to_invoke,
    Body = json.dumps(payload),
    ContentType = "application/json",
    SessionId=session_id
)

adapter_response = json.loads(response_model["Body"].read().decode("utf8"))["choices"][0]["message"]["content"]

print(f'Adapter Response:\n\n {adapter_response}\n')

## Close Session

To close a session, the client sends a final `InvokeEndpoint` request, providing the session ID in the `SessionId` parameter and setting "requestType" to "CLOSE" in the request payload.

In [None]:
payload = {
    "requestType": "CLOSE"
}
payload = json.dumps(payload)

close_session_response = sm_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=base_inference_component_name,
    Body=payload,
    ContentType="application/json",
    SessionId=session_id)

When the container closes the session, it returns the closed session's ID in the following HTTP response header:

```
X-Amzn-SageMaker-Closed-Session-Id: session_id
```

We can extract the closed session ID from the `invoke_endpoint` response.

In [None]:
closed_session_id = close_session_response['ResponseMetadata']['HTTPHeaders']['x-amzn-sagemaker-closed-session-id']

print(f"closed_session_id: {closed_session_id}")
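The start and close steps can be tied together so that a session is always closed, even when a request made in between fails. `stateful_session` is a hypothetical context manager built only from the requests shown in this notebook; like the cells above, it reads the session ID from the `x-amzn-sagemaker-new-session-id` header and expects a client shaped like the boto3 `sagemaker-runtime` client.

```python
import json
from contextlib import contextmanager


@contextmanager
def stateful_session(runtime, endpoint_name, component_name):
    """Hypothetical helper: open a session, yield its ID, and always close it."""
    opened = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        InferenceComponentName=component_name,
        Body=json.dumps({"requestType": "NEW_SESSION"}),
        ContentType="application/json",
        SessionId="NEW_SESSION",
    )
    headers = opened["ResponseMetadata"]["HTTPHeaders"]
    session_id = headers["x-amzn-sagemaker-new-session-id"].split(";")[0].strip()
    try:
        yield session_id
    finally:
        # Close the session even if the body of the with-block raised.
        runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            InferenceComponentName=component_name,
            Body=json.dumps({"requestType": "CLOSE"}),
            ContentType="application/json",
            SessionId=session_id,
        )


# Usage in the notebook:
# with stateful_session(sm_runtime, endpoint_name, base_inference_component_name) as sid:
#     ...  # invoke the base and adapter ICs with SessionId=sid
```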

## Clean up Resources

To delete a single adapter, call `delete_inference_component` with the adapter IC name.

In [None]:
sess.delete_inference_component(ic1_adapter_name, wait = True)
print(f'Adapter Component {ic1_adapter_name} deleted.')

Deleting the base model IC also automatically deletes any associated adapter ICs.

In [None]:
sess.delete_inference_component(base_inference_component_name, wait = True)

print(f'Base Component {base_inference_component_name} deleted.')

Delete the endpoint, its endpoint configuration, and the model.

In [None]:
sess.delete_endpoint(endpoint_name)
print(f'Endpoint {endpoint_name} deleted.')

sess.delete_endpoint_config(endpoint_name)
print(f'Endpoint Configuration {endpoint_name} deleted.')

sess.delete_model(model_name)
print(f'Model {model_name} deleted.')