# Serve the Salesforce/xgen-7b-8k-base model with Amazon SageMaker Hosting

👋 Hey there! 🌟

Let's walk through, step by step, how to deploy and run inference on the **Salesforce XGen-7B-8K-base** model using the **Large Model Inference (LMI)** container that AWS provides, which is built on **DJL Serving**. 😄

With roughly 7 billion parameters, **Salesforce XGen-7B-8K-base** is a relatively small large language model (LLM) that fits comfortably on a single GPU, so we'll use the `ml.g5.2xlarge` instance, which comes with **1** GPU (an NVIDIA A10G with 24 GB of memory). 🖥️
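As a rough sanity check on that sizing, here is a back-of-the-envelope estimate of how much memory the weights alone need. It is a sketch under stated assumptions: about 7 billion parameters, 2 bytes per parameter in bfloat16, and 24 GB on the instance's single GPU. It ignores the KV cache, activations and framework overhead, which is why it keeps 20% headroom (an arbitrary choice). `weights_gb` and `fits_on_gpu` are hypothetical helpers, not part of any library.

```python
# Back-of-the-envelope weight memory for a ~7B-parameter model.
# Assumptions: ~7e9 parameters, decimal gigabytes, and a 24 GB GPU.
# KV cache, activations and CUDA/framework overhead are NOT counted.

BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2, "int8": 1}


def weights_gb(num_params: float, dtype: str = "bf16") -> float:
    """Memory needed for the model weights alone, in GB (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


def fits_on_gpu(num_params: float, gpu_mem_gb: float,
                dtype: str = "bf16", headroom: float = 0.8) -> bool:
    """True if the weights use at most `headroom` of the GPU memory,
    leaving the remainder for the KV cache and activations."""
    return weights_gb(num_params, dtype) <= headroom * gpu_mem_gb


if __name__ == "__main__":
    params = 7e9
    print(f"bf16 weights: {weights_gb(params):.1f} GB")  # prints: bf16 weights: 14.0 GB
    print("fits on a 24 GB GPU:", fits_on_gpu(params, gpu_mem_gb=24))  # prints: ... True
```

Under these assumptions the bf16 weights need about 14 GB, which leaves room on a 24 GB GPU; the same weights in fp32 (about 28 GB) would not fit.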

## Setup

To get started, install the dependencies needed to package the model and run inference on Amazon SageMaker, making sure the `sagemaker` and `boto3` packages are up to date. It's a simple process! 🚀

In [1]:
!pip install sagemaker boto3 --upgrade  --quiet

## Imports and variables

In [2]:
import os
import time
import json
import boto3
import jinja2
import sagemaker
from pathlib import Path
from sagemaker import image_uris
from sagemaker.utils import name_from_base

In [3]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
hf_model_id = 'Salesforce/xgen-7b-8k-base'
model_id = hf_model_id.replace('/','-')
s3_code_prefix_accelerate = f"hf-large-model/{model_id}/accelerate"  # folder within bucket where code artifact will go

region = sess.boto_region_name  # public attribute instead of the private sess._region_name
account_id = sess.account_id()

s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

jinja_env = jinja2.Environment()

### 1. Create SageMaker-compatible model artifacts

To prepare the model for deployment on a SageMaker endpoint, we need a few files for both SageMaker and the container. No worries, it's a straightforward process! We'll store them in a local folder, including `serving.properties` (which defines parameters for the LMI container) and `requirements.txt` (which lists the extra dependencies to install). 📂

In [4]:
directory_name = f"code_{model_id.replace('-','_')}_accelerate"
os.makedirs(directory_name, exist_ok=True)

In the `serving.properties` file, you define the engine to use and the model you want to host. Pay particular attention to the `tensor_parallel_degree` parameter: if a single GPU doesn't have enough memory to hold the entire model, setting it above 1 shards the model across that many GPUs.

For this deployment we use an `ml.g5.2xlarge` instance, which has 1 GPU with enough memory to load the model, so we set `tensor_parallel_degree=1`. Never specify a value larger than the number of GPUs the instance provides, or the deployment is likely to fail. ❌🙅‍♂️
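To catch that mistake before deploying, you could check the value against the instance's GPU count. The sketch below is illustrative only: `parse_properties` and `check_tensor_parallel` are hypothetical helpers, the GPU table covers just a few G5 sizes, and treating a missing `option.tensor_parallel_degree` as 1 is an assumption of this sketch rather than documented LMI behavior.

```python
# Sanity-check serving.properties against the target instance's GPU count.
# The table is hard-coded for illustration; extend it for other instances.
GPUS_PER_INSTANCE = {
    "ml.g5.2xlarge": 1,
    "ml.g5.12xlarge": 4,
    "ml.g5.48xlarge": 8,
}


def parse_properties(text: str) -> dict:
    """Parse simple key=value lines, skipping blank lines and # comments."""
    props = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        key, _, value = line.partition("=")
        props[key.strip()] = value.strip()
    return props


def check_tensor_parallel(text: str, instance_type: str) -> int:
    """Return the configured degree, or raise if the instance can't satisfy it."""
    props = parse_properties(text)
    # Assumption of this sketch: an unset degree is treated as 1.
    degree = int(props.get("option.tensor_parallel_degree", "1"))
    gpus = GPUS_PER_INSTANCE[instance_type]  # KeyError for unknown instances
    if not 1 <= degree <= gpus:
        raise ValueError(
            f"tensor_parallel_degree={degree} but {instance_type} has {gpus} GPU(s)"
        )
    return degree


if __name__ == "__main__":
    props_text = """engine=Python
option.model_id=Salesforce/xgen-7b-8k-base
option.tensor_parallel_degree=1
"""
    print(check_tensor_parallel(props_text, "ml.g5.2xlarge"))  # prints 1
```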

In [5]:
%%writefile ./{directory_name}/serving.properties
engine=Python
option.model_id={{hf_model_id}}
option.tensor_parallel_degree=1

Overwriting ./code_Salesforce_xgen_7b_8k_base_accelerate/serving.properties


In [6]:
%%writefile ./{directory_name}/requirements.txt
torch==2.0.1
einops==0.5.0
tiktoken
transformers==4.30.2
accelerate

Overwriting ./code_Salesforce_xgen_7b_8k_base_accelerate/requirements.txt


In [7]:
# render the Jinja placeholder {{hf_model_id}} in `serving.properties` with the actual Hugging Face model ID
properties_path = Path(f"{directory_name}/serving.properties")
template = jinja_env.from_string(properties_path.read_text())
properties_path.write_text(template.render(hf_model_id=hf_model_id))  # write_text also closes the file
!pygmentize {directory_name}/serving.properties | cat -n

     1	engine=Python
     2	option.model_id=Salesforce/xgen-7b-8k-base
     3	option.tensor_parallel_degree=1


### 2. Create a model.py with custom inference code

With SageMaker, you can bring your own inference script. Here, we create a `model.py` file with the code needed to serve the Salesforce XGen-7B-8K-base model.

Two scripts are provided below, and both work. I recommend the second (uncommented) one, as it produced slightly faster responses: it calls the model's `generate()` API directly instead of going through the `pipeline()` API.

For more on the difference between the `pipeline()` and `generate()` APIs, see this helpful [Hugging Face forum thread](https://discuss.huggingface.co/t/pipeline-vs-model-generate/26203). 📚

In [8]:
# %%writefile ./{directory_name}/model.py
# from djl_python import Input, Output
# import os
# import torch
# import transformers
# from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# from typing import Any, Dict, Tuple
# import warnings

# predictor = None
# print(f"transformers version: {transformers.__version__}")


# def get_model(properties):
#     model_name = properties["model_id"]
#     local_rank = int(os.getenv("LOCAL_RANK", "0"))
#     tokenizer = AutoTokenizer.from_pretrained(model_name, 
#                                           trust_remote_code=True,
#                                          )
#     model = AutoModelForCausalLM.from_pretrained(model_name, 
#                                                  torch_dtype=torch.bfloat16,
#                                                  device_map="auto"
#                                                 )
#     generator = pipeline(
#         task="text-generation", model=model, tokenizer=tokenizer, device_map="auto", torch_dtype=torch.bfloat16
#     )
#     return generator


# def handle(inputs: Input) -> None:
#     global predictor
#     if not predictor:
#         predictor = get_model(inputs.get_properties())
#     if inputs.is_empty():
#         # Model server makes an empty call to warmup the model on startup
#         return None
#     data = inputs.get_as_json()
#     text = data.pop("text", data)
#     parameters = data.pop("parameters", None) or {}  # avoid **None when "parameters" is missing
#     outputs = predictor(text, **parameters)
#     result = {"generated_text": outputs[0]['generated_text']}
#     return Output().add_as_json(result)

In [9]:
%%writefile ./{directory_name}/model.py
from typing import Optional

import torch
import transformers
from djl_python import Input, Output
from transformers import AutoModelForCausalLM, AutoTokenizer

predictor = None
print(f"transformers version: {transformers.__version__}")


def get_model(properties):
    model_name = properties["model_id"]
    # The XGen tokenizer ships as custom (tiktoken-based) code, hence trust_remote_code
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # let Accelerate place the weights on the available GPU(s)
    )
    return {"model": model, "tokenizer": tokenizer}


def handle(inputs: Input) -> Optional[Output]:
    global predictor
    if not predictor:
        predictor = get_model(inputs.get_properties())
    if inputs.is_empty():
        # Model server makes an empty call to warm up the model on startup
        return None
    model, tokenizer = predictor["model"], predictor["tokenizer"]
    data = inputs.get_as_json()
    text = data.pop("text", data)
    # A missing or null "parameters" field would make **params fail, so default to {}
    params = data.pop("parameters", None) or {}
    # Tokenize and move the tensors to the same device as the model
    model_inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        sample = model.generate(**model_inputs, **params)
    result = {"generated_text": tokenizer.decode(sample[0])}
    return Output().add_as_json(result)

Overwriting ./code_Salesforce_xgen_7b_8k_base_accelerate/model.py
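The request contract that `handle()` reads is a JSON body with a `text` field and an optional `parameters` object whose entries are forwarded to `generate()`. The sketch below mirrors that parsing step without torch or DJL (treating a missing or null `parameters` field as an empty dict), so you can check payloads locally. `parse_request` is a hypothetical helper, not part of `djl_python`.

```python
import json
from typing import Any, Dict, Tuple, Union


def parse_request(body: Union[str, bytes]) -> Tuple[Any, Dict[str, Any]]:
    """Split a JSON request body into (text, generation kwargs),
    using the same field names as handle() in model.py."""
    data = json.loads(body)
    # As in model.py: fall back to the remaining payload if "text" is absent.
    text = data.pop("text", data)
    # A missing or null "parameters" field becomes {} so **params is always safe.
    params = data.pop("parameters", None) or {}
    return text, params


if __name__ == "__main__":
    body = json.dumps({"text": "The population of Greece is",
                       "parameters": {"max_new_tokens": 50}})
    print(parse_request(body))
    # ('The population of Greece is', {'max_new_tokens': 50})
```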


### 3. Create the Tarball and then upload to S3 location
Now, let's package our artifacts into a single `model.tar.gz` file and upload it to S3, where SageMaker will pick it up at deployment time. 📦💨

In [10]:
!rm -f model.tar.gz
!rm -rf {directory_name}/.ipynb_checkpoints
!tar czvf model.tar.gz -C {directory_name} .
s3_code_artifact_accelerate = sess.upload_data("model.tar.gz", bucket, s3_code_prefix_accelerate)
print(f"S3 Code or Model tar for accelerate uploaded to --- > {s3_code_artifact_accelerate}")

./
./requirements.txt
./serving.properties
./model.py
S3 Code or Model tar for accelerate uploaded to --- > s3://sagemaker-eu-west-1-069230569860/hf-large-model/Salesforce-xgen-7b-8k-base/accelerate/model.tar.gz
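If you prefer to stay in Python instead of shelling out to `tar`, the standard-library `tarfile` module can produce the same layout, with the files at the root of the archive (which is what `tar -C <dir> .` achieves). This is a sketch: `make_model_tarball` is a hypothetical helper, and it skips `.ipynb_checkpoints` to match the `rm -rf` step above.

```python
import tarfile
from pathlib import Path
from typing import List


def make_model_tarball(src_dir: str, out_path: str = "model.tar.gz") -> List[str]:
    """Pack every file under src_dir into a gzipped tarball, with paths
    relative to src_dir, and return the archive member names."""
    src = Path(src_dir)
    with tarfile.open(out_path, "w:gz") as tar:
        for path in sorted(src.rglob("*")):
            rel = path.relative_to(src)
            # Skip Jupyter checkpoint folders, like the `rm -rf` step above.
            if ".ipynb_checkpoints" in rel.parts or not path.is_file():
                continue
            tar.add(str(path), arcname=rel.as_posix())
    with tarfile.open(out_path, "r:gz") as tar:
        return tar.getnames()


# Example usage with the notebook's variables (not executed here):
# make_model_tarball(directory_name)
# s3_code_artifact_accelerate = sess.upload_data("model.tar.gz", bucket, s3_code_prefix_accelerate)
```

Keep `out_path` outside `src_dir`, otherwise the archive could end up trying to include itself.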


### 4. Define a serving container, SageMaker Model and SageMaker endpoint

Now, we can move on to creating a SageMaker endpoint to serve our model. 🚀

#### Define the serving container
In this step, we specify the container image used to serve the model: SageMaker's Large Model Inference (LMI) container, which runs DJL Serving. Inside it, our `model.py` loads the weights with Hugging Face Accelerate (`device_map="auto"`). ⚡️🧪

In [11]:
inference_image_uri = (
    f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.22.1-deepspeed0.8.3-cu118"
)

print(f"Image going to be used is ---- > {inference_image_uri}")

Image going to be used is ---- > 763104351884.dkr.ecr.eu-west-1.amazonaws.com/djl-inference:0.22.1-deepspeed0.8.3-cu118


#### Create the SageMaker model, endpoint configuration and endpoint


In [12]:
model_name_acc = name_from_base(model_id)
print(model_name_acc)

Salesforce-xgen-7b-8k-base-2023-06-29-14-03-10-807


In [13]:
create_model_response = sm_client.create_model(
    ModelName=model_name_acc,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": inference_image_uri, "ModelDataUrl": s3_code_artifact_accelerate},
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

Created Model: arn:aws:sagemaker:eu-west-1:069230569860:model/salesforce-xgen-7b-8k-base-2023-06-29-14-03-10-807


In [14]:
model_name = model_name_acc
print(f"Building EndpointConfig and Endpoint for: {model_name}")

Building EndpointConfig and Endpoint for: Salesforce-xgen-7b-8k-base-2023-06-29-14-03-10-807
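The next cell derives the endpoint configuration and endpoint names by appending suffixes to the model name, so a longer Hugging Face model ID could push them past SageMaker's naming limits. As I read the SageMaker API reference, model, endpoint configuration and endpoint names may be at most 63 characters of letters, digits and hyphens, starting and ending with a letter or digit; treat that rule as an assumption and double-check the docs. The hypothetical `check_name` helper below validates a name before any API call is made.

```python
import re

# Assumed naming rule for SageMaker model, endpoint-config and endpoint names:
# 1-63 chars, letters/digits/hyphens, alphanumeric at both ends.
NAME_PATTERN = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9])*")
MAX_NAME_LENGTH = 63


def check_name(name: str) -> str:
    """Return name unchanged, or raise ValueError if it breaks the rule above."""
    if len(name) > MAX_NAME_LENGTH or not NAME_PATTERN.fullmatch(name):
        raise ValueError(f"invalid SageMaker resource name ({len(name)} chars): {name!r}")
    return name


if __name__ == "__main__":
    base = "Salesforce-xgen-7b-8k-base-2023-06-29-14-03-10-807"
    for suffix in ("", "-config", "-endpoint"):
        print(check_name(base + suffix))
```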


In [15]:
endpoint_config_name = f"{model_name}-config"
endpoint_name = f"{model_name}-endpoint"

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name,
            "InstanceType": "ml.g5.2xlarge",
            "InitialInstanceCount": 1,
            "ModelDataDownloadTimeoutInSeconds": 3600,
            "ContainerStartupHealthCheckTimeoutInSeconds": 3600,
            # "VolumeSizeInGB": 512
        },
    ],
)
endpoint_config_response

{'EndpointConfigArn': 'arn:aws:sagemaker:eu-west-1:069230569860:endpoint-config/salesforce-xgen-7b-8k-base-2023-06-29-14-03-10-807-config',
 'ResponseMetadata': {'RequestId': 'e48feb47-1c25-4a43-b7e4-b1c602c22483',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': 'e48feb47-1c25-4a43-b7e4-b1c602c22483',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '138',
   'date': 'Thu, 29 Jun 2023 14:03:10 GMT'},
  'RetryAttempts': 0}}

In [16]:
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

Created Endpoint: arn:aws:sagemaker:eu-west-1:069230569860:endpoint/salesforce-xgen-7b-8k-base-2023-06-29-14-03-10-807-endpoint


In [17]:
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: InService
Arn: arn:aws:sagemaker:eu-west-1:069230569860:endpoint/salesforce-xgen-7b-8k-base-2023-06-29-14-03-10-807-endpoint
Status: InService
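The polling loop above keeps going only while the status is `Creating`, and it never gives up. A slightly more defensive sketch adds a timeout and raises if the endpoint ends up in any state other than `InService` (for example `Failed`). `wait_for_status` is a hypothetical helper; boto3 also ships a built-in waiter, `sm_client.get_waiter("endpoint_in_service")`, which you may prefer.

```python
import time
from typing import Callable, Iterable


def wait_for_status(describe: Callable[[], str],
                    in_progress: Iterable[str] = ("Creating", "Updating"),
                    success: str = "InService",
                    poll_seconds: float = 60,
                    timeout_seconds: float = 3600,
                    sleep: Callable[[float], None] = time.sleep) -> str:
    """Call describe() until the status leaves `in_progress`.

    Returns the final status if it equals `success`; raises RuntimeError
    for any other terminal status and TimeoutError after timeout_seconds.
    """
    pending = frozenset(in_progress)
    waited = 0.0
    status = describe()
    while status in pending:
        if waited >= timeout_seconds:
            raise TimeoutError(f"status still {status!r} after {waited:.0f}s")
        sleep(poll_seconds)
        waited += poll_seconds
        status = describe()
        print("Status: " + status)
    if status != success:
        raise RuntimeError(f"endpoint finished in status {status!r}")
    return status


# Example usage with the notebook's client (not executed here):
# wait_for_status(lambda: sm_client.describe_endpoint(
#     EndpointName=endpoint_name)["EndpointStatus"])
```

Injecting `describe` and `sleep` keeps the helper independent of boto3, so it can be exercised with fakes.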


### Let's use the endpoint and run inference

In [18]:
%%timeit -r 5
response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps({"text": "The population of Greece is",
                     "parameters": {
                          "max_new_tokens": 500,  # more new tokens means a longer response time
                          "temperature": 0.1,
                          "top_p": 0.75,
                          "top_k": 40,
                          "repetition_penalty": 1.9,
                          "do_sample": True,
                          "num_return_sequences": 1,
                          # "return_full_text": False,  # pipeline() option to avoid returning the prompt; generate() does not accept it
                          "best_of": None,
                          "truncate": None,
                     }}
                     ),
    ContentType="application/json",
)

38.8 s ± 86.7 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
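Both invocation cells build the same JSON body by hand. A small helper keeps that in one place and drops options set to `None`, such as `best_of` and `truncate`, which carry no value in this payload. `build_body` is a hypothetical name; the body shape (`text` plus `parameters`) matches what `handle()` in `model.py` reads.

```python
import json
from typing import Any


def build_body(prompt: str, **parameters: Any) -> str:
    """Serialize {"text": prompt, "parameters": {...}}, dropping None values."""
    params = {key: value for key, value in parameters.items() if value is not None}
    return json.dumps({"text": prompt, "parameters": params})


# Example usage with the notebook's client (not executed here):
# body = build_body("The population of Greece is", max_new_tokens=500,
#                   temperature=0.1, top_p=0.75, top_k=40,
#                   repetition_penalty=1.9, do_sample=True,
#                   num_return_sequences=1)
# response_model = smr_client.invoke_endpoint(
#     EndpointName=endpoint_name, Body=body, ContentType="application/json")
```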


In [19]:
prompt = "The population of Greece is"
response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps({"text": prompt,
                     "parameters": {
                          "max_new_tokens": 500,  # more new tokens means a longer response time
                          "temperature": 0.01,
                          "top_p": 0.85,
                          "top_k": 40,
                          "repetition_penalty": 1.9,
                          "do_sample": True,
                          "num_return_sequences": 1,
                          # "return_full_text": False,  # pipeline() option to avoid returning the prompt; generate() does not accept it
                          "best_of": None,
                          "truncate": None,
                     }}
                     ),
    ContentType="application/json",
)

def process_generated_text(text, stopwords, prompt=None):
    if prompt:
        text = text[len(prompt):]
    
    for word in stopwords:
        position = text.find(word)
        if position != -1:
            text = text[:position]
    return text
    

r = response_model["Body"].read().decode("utf8")

# Load the JSON string as a dictionary
data_dict = json.loads(r)

# Access the dictionary elements
generated_text = data_dict['generated_text']

# Print the generated text
print(process_generated_text(generated_text, ['<|endoftext|>'], prompt=prompt))

 10.8 million people (2016). The country has a very high life expectancy, with the average Greek living to be 79 years old and women live on an even longer time – 83 year-old!
Greece’s economy relies heavily upon tourism as it accounts for about 20% GDP in total revenue from all sources including exports which are also significant at around 15%. Tourism makes up over 30 percent or more than $10 billion dollars annually according some estimates while others say that figure could reach closer towards 40%, but either way this number will continue growing due mainly because there aren't enough hotels available right now so demand keeps increasing every day without any sign off slowing down anytime soon especially since they have just opened their first ever casino last month called “Konstantinos Casino Resort & Spa" located near Athens International Airport where many international travelers pass through daily en route home after visiting other countries like Italy etc., thus creatin
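`process_generated_text` removes the prompt by slicing off `len(prompt)` characters, which assumes the decoded output starts with the prompt verbatim. If a tokenizer ever prepends a special token or normalizes whitespace, that slice would cut into the real output instead. A more defensive variant, sketched below, strips the prompt only when it really is a prefix and cuts at the earliest stop word; `strip_prompt_and_stop` is a hypothetical name.

```python
from typing import Iterable, Optional


def strip_prompt_and_stop(text: str, stopwords: Iterable[str],
                          prompt: Optional[str] = None) -> str:
    """Drop the prompt only if it is a real prefix, then cut at the
    earliest stop word found in the remaining text."""
    if prompt and text.startswith(prompt):
        text = text[len(prompt):]
    positions = [text.find(word) for word in stopwords]
    cut = min((pos for pos in positions if pos != -1), default=len(text))
    return text[:cut]


if __name__ == "__main__":
    raw = "The population of Greece is 10.8 million<|endoftext|>ignored"
    print(repr(strip_prompt_and_stop(raw, ["<|endoftext|>"],
                                     prompt="The population of Greece is")))
    # ' 10.8 million'
```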

In [20]:
print(generated_text)

The population of Greece is 10.8 million people (2016). The country has a very high life expectancy, with the average Greek living to be 79 years old and women live on an even longer time – 83 year-old!
Greece’s economy relies heavily upon tourism as it accounts for about 20% GDP in total revenue from all sources including exports which are also significant at around 15%. Tourism makes up over 30 percent or more than $10 billion dollars annually according some estimates while others say that figure could reach closer towards 40%, but either way this number will continue growing due mainly because there aren't enough hotels available right now so demand keeps increasing every day without any sign off slowing down anytime soon especially since they have just opened their first ever casino last month called “Konstantinos Casino Resort & Spa" located near Athens International Airport where many international travelers pass through daily en route home after visiting other countries li

### Clean Up

In [21]:
# Delete the endpoint
sm_client.delete_endpoint(EndpointName=endpoint_name)

{'ResponseMetadata': {'RequestId': '287929f4-f5de-4dd7-9516-4eb9a59200e3',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '287929f4-f5de-4dd7-9516-4eb9a59200e3',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '0',
   'date': 'Thu, 29 Jun 2023 14:21:50 GMT'},
  'RetryAttempts': 0}}

In [22]:
# Delete the model and the endpoint configuration
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=model_name)

{'ResponseMetadata': {'RequestId': '514786ec-624f-46cc-b2cc-78d8af29508a',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '514786ec-624f-46cc-b2cc-78d8af29508a',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '0',
   'date': 'Thu, 29 Jun 2023 14:21:51 GMT'},
  'RetryAttempts': 0}}
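Deleting resources one by one, as above, only touches this notebook's endpoint. If you have several runs to clean up but want to leave unrelated resources in the account alone, you could filter by name instead of deleting everything. The sketch below is a hypothetical helper, `delete_matching`, that uses boto3 paginators with the `NameContains` filter of the SageMaker list APIs, handles endpoints first, and defaults to a dry run so you can review the matches before anything is removed.

```python
from typing import List, Tuple

# (paginator operation, response key, name field, delete method name)
RESOURCES: List[Tuple[str, str, str, str]] = [
    ("list_endpoints", "Endpoints", "EndpointName", "delete_endpoint"),
    ("list_endpoint_configs", "EndpointConfigs", "EndpointConfigName", "delete_endpoint_config"),
    ("list_models", "Models", "ModelName", "delete_model"),
]


def delete_matching(client, name_contains: str, dry_run: bool = True) -> List[str]:
    """Find endpoints, endpoint configs and models whose names contain
    `name_contains`; delete them unless dry_run is True. Returns the names."""
    matched = []
    for list_op, key, field, delete_op in RESOURCES:
        paginator = client.get_paginator(list_op)
        # Collect all pages first so deletions don't disturb pagination.
        names = [
            item[field]
            for page in paginator.paginate(NameContains=name_contains)
            for item in page[key]
        ]
        for name in names:
            if not dry_run:
                getattr(client, delete_op)(**{field: name})
            matched.append(name)
    return matched


# Example usage (not executed here):
# delete_matching(sm_client, model_id)                 # preview only
# delete_matching(sm_client, model_id, dry_run=False)  # actually delete
```

Deleting an endpoint that is still being created fails with a `ValidationException`, as the output of the cell below shows, so wait until it is `InService` first.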

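### Delete all endpoints, endpoint configurations & models

**Warning:** the cell below deletes *every* SageMaker model, endpoint and endpoint configuration in the current account and region, not only the ones created by this notebook. Run it only if that is really what you want.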

In [46]:
import boto3

def delete_resources(resource_type):
    client = boto3.client('sagemaker')
    delete_method = getattr(client, f"delete_{resource_type}")
    # e.g. "endpoint_config" -> "EndpointConfig"
    resource_type_name = resource_type.replace('_', ' ').title().replace(' ', '')
    # A single list_* call returns only the first page of results, so iterate over all pages
    paginator = client.get_paginator(f"list_{resource_type}s")
    resources = [
        resource
        for page in paginator.paginate()
        for resource in page[f"{resource_type_name}s"]
    ]
    for resource in resources:
        resource_name = resource[f"{resource_type_name}Name"]
        print(f"Deleting {resource_type}: {resource_name}")
        try:
            delete_method(**{f"{resource_type_name}Name": resource_name})
            print("Deleted")
        except Exception as e:
            print("An error occurred:", str(e))

def main():
    resource_types = ['model', 'endpoint', 'endpoint_config']  # Add more resource types if needed

    for resource_type in resource_types:
        delete_resources(resource_type)

if __name__ == "__main__":
    main()

Deleting model: Salesforce-xgen-7b-8k-base-2023-06-29-01-22-40-473
Deleted
Deleting model: Salesforce-xgen-7b-8k-base-2023-06-29-00-07-15-303
Deleted
Deleting model: Salesforce-xgen-7b-8k-base-2023-06-29-00-05-44-523
Deleted
Deleting endpoint: Salesforce-xgen-7b-8k-base-2023-06-29-10-46-22-359-endpoint
An error occurred: An error occurred (ValidationException) when calling the DeleteEndpoint operation: Cannot update in-progress endpoint "arn:aws:sagemaker:eu-west-1:069230569860:endpoint/salesforce-xgen-7b-8k-base-2023-06-29-10-46-22-359-endpoint".
Deleting endpoint: Salesforce-xgen-7b-8k-base-2023-06-29-01-36-44-565-endpoint
Deleted
Deleting endpoint: Salesforce-xgen-7b-8k-base-2023-06-29-01-30-00-053-endpoint
Deleted
Deleting endpoint: Salesforce-xgen-7b-8k-base-2023-06-29-01-22-40-473-endpoint
Deleted
Deleting endpoint: Salesforce-xgen-7b-8k-base-2023-06-29-00-07-15-303-endpoint
Deleted
Deleting endpoint: Salesforce-xgen-7b-8k-base-2023-06-29-00-05-44-523-endpoint
Deleted
Deleting en