In [None]:
%%sh
pip install -qU pip sagemaker

In [None]:
import json

import boto3
from IPython.display import Markdown, display

import sagemaker
from sagemaker.djl_inference.model import DJLModel

In [None]:
# Execution role resolved by the SageMaker SDK; it is passed to the model definition below.
role = sagemaker.get_execution_role()

In [None]:
# LMI container configuration reference:
# https://docs.djl.ai/master/docs/serving/serving/docs/lmi/index.html

# Serving options handed to the model container through its environment.
lmi_env = dict(
    OPTION_DTYPE="bf16",
    OPTION_MAX_MODEL_LEN="32768",
    OPTION_TRUST_REMOTE_CODE="true",
    OPTION_ROLLING_BATCH="vllm",
    TENSOR_PARALLEL_DEGREE="max",
    OPTION_MAX_ROLLING_BATCH_SIZE="16",
)

# Model definition: the Hugging Face model id plus the container options above.
model = DJLModel(model_id="arcee-ai/arcee-lite", role=role, env=lmi_env)

# Create the endpoint on a single instance; `predictor` is the handle used by the
# inference and clean-up cells further down.
predictor = model.deploy(
    instance_type="ml.g5.xlarge",
    initial_instance_count=1,
    container_startup_health_check_timeout=300,
)

In [None]:
# Show the name of the deployed endpoint; the boto3 example below addresses the endpoint by it.
predictor.endpoint_name

# Model Inference

#### Inference with the SageMaker SDK

In [None]:
# Request/response schema reference:
# https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/lmi_input_output_schema.html

# Prompt sent to the model as the "inputs" field of the request.
prompt = """Please write a marketing pitch for a new SaaS AI platform called Arcee Cloud.
We will send this pitch by email to business and technical decision-makers, so make it sound exciting yet professional.
The contact email is sales@arcee.ai. Feel free to use emojis as appropriate.
Arcee Cloud makes it simple for enterprise users to tailor open-source small language models to their own domain knowledge,
in order to build high-quality, cost-effective and secure AI solutions."""

# Request payload shared by the SageMaker SDK and boto3 examples below.
# Flags are real booleans (serialized as JSON true/false), matching "do_sample":
# a non-empty string such as "false" is truthy and can be misread as "enabled".
# NOTE(review): the LMI schema documents `stream` as a top-level request field;
# it is kept inside "parameters" here to preserve the original request shape — confirm.
body = {
    "inputs": prompt,
    "parameters": {
        "do_sample": True,
        "max_new_tokens": 2048,
        "stream": False,
        # The cells below read response["details"], so details must be requested.
        "details": True,
    },
}

In [None]:
%%time
# Send the request payload through the SageMaker SDK predictor; the following cells
# index the result as a dictionary ("details", "generated_text").
response = predictor.predict(body)

In [None]:
# Pull the generation statistics out of the "details" section of the SDK response.
details = response["details"]
generated_tokens, finish_reason = details["generated_tokens"], details["finish_reason"]
print(f"Generated tokens: {generated_tokens}, finish reason: {finish_reason}")

In [None]:
# Render the generated text as Markdown instead of printing the raw string.
display(Markdown(response["generated_text"]))

#### Inference with the boto3 SDK

In [None]:
# SageMaker runtime client from boto3, used below to invoke the endpoint without the SageMaker SDK.
smrt_client = boto3.client("sagemaker-runtime")

In [None]:
%%time
response = smrt_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    ContentType="application/json",
    Accept="application/json",
    Body=json.dumps(body),
)

In [None]:
# Parse the JSON payload returned by invoke_endpoint and keep only the generated text.
# The text gets its own name instead of overwriting `response`, so the raw result
# stays available and `response` does not silently change from a dict to a str.
# NOTE(review): the "Body" stream can presumably be consumed only once — re-run the
# invoke_endpoint cell before re-running this one.
generated_text = json.load(response["Body"])["generated_text"]
display(Markdown(generated_text))

# Clean up

In [None]:
# Remove the resources created by the deployment: the SageMaker model as well as the endpoint.
# NOTE(review): delete_model() is called first because it is expected to look the model up
# through the endpoint configuration, which delete_endpoint() removes by default — confirm
# against the installed SDK version.
try:
    predictor.delete_model()
finally:
    # Always delete the endpoint, even if model deletion fails: it is the billable resource.
    predictor.delete_endpoint()