# Standard instructions for using the LMI container on SageMaker
In this tutorial, you will deploy an LMI (Large Model Inference) container from the AWS Deep Learning Containers (DLC) to SageMaker and run inference with it.

Please make sure the following permissions are granted before running the notebook:

- S3 bucket push access
- SageMaker access

## Step 1: Let's upgrade the SageMaker SDK and import dependencies

In [None]:
%pip install sagemaker boto3 awscli --upgrade  --quiet

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

### (Remove if not needed) Upload a HuggingFace model to an S3 bucket

LMI can download a model directly from an S3 bucket. This step is recommended if you would like to speed up the model loading process.

In [None]:
%pip install huggingface_hub --upgrade --quiet

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path
import os

# - This will download the model into the /tmp directory (the cache_dir passed to snapshot_download below)
local_model_path = Path("/tmp")
local_model_path.mkdir(exist_ok=True)
model_name = "facebook/opt-6.7b"
# Only download pytorch checkpoint files
allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model"]

# - Use snapshot_download to fetch the model, since the model is stored in the repository using LFS
model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_model_path,
    allow_patterns=allow_patterns,
)
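
To see which files the `allow_patterns` filter would keep, here is a minimal sketch that applies the same glob patterns with the standard-library `fnmatch` module. The file names are hypothetical examples, not the actual repository listing, and `huggingface_hub`'s own matching may differ in edge cases.

```python
from fnmatch import fnmatch

# Same patterns as in the download cell above.
allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model"]


def is_allowed(filename, patterns):
    """Return True if the file name matches at least one glob pattern."""
    return any(fnmatch(filename, pattern) for pattern in patterns)


# Hypothetical file names, for illustration only (not the real repo listing).
example_files = [
    "config.json",
    "pytorch_model-00001-of-00002.bin",
    "model-00001-of-00002.safetensors",
    "flax_model.msgpack",
    "merges.txt",
]

kept = [name for name in example_files if is_allowed(name, allow_patterns)]
print(kept)
```

With these patterns, `.safetensors` and `.msgpack` files are skipped. If the model you want ships only safetensors weights, you would likely need to add `"*.safetensors"` to the list.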


Upload the model to the S3 bucket:

In [None]:
key_prefix = "lmi_models/mymodel"
model_artifact = sess.upload_data(path=model_download_path, key_prefix=key_prefix)
print(f"Model uploaded to --- > {model_artifact}")
print(f"You can set option.s3url={model_artifact}")

## Step 2: Start preparing model artifacts
In the LMI container, we expect some artifacts to help set up the model:
- serving.properties (required): Defines the model server settings
- model.py (optional): A Python file that defines the core inference logic
- requirements.txt (optional): Any additional pip packages that need to be installed
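
As a rough illustration of what `serving.properties` might contain, the sketch below renders a small dictionary into `key=value` lines. The helper `render_properties` is hypothetical. Apart from `option.s3url`, which is mentioned earlier in this notebook, the keys and values shown (`engine`, `option.tensor_parallel_degree`, `option.dtype`) are assumptions that should be checked against the LMI documentation for your container version. The bucket name is a placeholder.

```python
def render_properties(settings):
    """Render a dict as .properties text, one key=value pair per line."""
    return "\n".join(f"{key}={value}" for key, value in settings.items()) + "\n"


# Assumed example values; adjust them to your model, container, and instance.
settings = {
    "engine": "DeepSpeed",
    "option.s3url": "s3://<your-bucket>/lmi_models/mymodel/",  # placeholder
    "option.tensor_parallel_degree": 1,
    "option.dtype": "fp16",
}

text = render_properties(settings)
print(text)
```

The printed text can be pasted into the `%%writefile serving.properties` cell.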

In [None]:
%%writefile serving.properties
# Start writing content here

In [None]:
%%writefile model.py
# Start writing content here (remove this cell if not used)

In [None]:
%%writefile requirements.txt
# Start writing content here (remove this file if not needed)

In [None]:
%%sh
mkdir mymodel
mv serving.properties mymodel/
# remove the following lines if not needed
mv model.py mymodel/
mv requirements.txt mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel
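
If a shell is not available, the same packaging step can be sketched in Python with the standard-library `tarfile` module. The helper name `pack_model_dir` is hypothetical, and the demo runs in a temporary directory so it does not touch the files created above.

```python
import tarfile
import tempfile
from pathlib import Path


def pack_model_dir(model_dir, archive_path):
    """Create a gzip-compressed tarball with model_dir as the top-level folder."""
    model_dir = Path(model_dir)
    with tarfile.open(archive_path, "w:gz") as tar:
        tar.add(model_dir, arcname=model_dir.name)
    return archive_path


# Demo in a temporary directory so no existing files are touched.
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "mymodel"
    model_dir.mkdir()
    (model_dir / "serving.properties").write_text("# Start writing content here\n")

    archive = pack_model_dir(model_dir, Path(tmp) / "mymodel.tar.gz")
    with tarfile.open(archive) as tar:
        members = sorted(tar.getnames())

print(members)
```

Listing the archive members before uploading is a cheap way to confirm that `serving.properties` sits under the `mymodel/` folder, matching the layout produced by the shell cell.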

## Step 3: Start building SageMaker endpoint
In this step, we will build a SageMaker endpoint from scratch.

### Getting the container image URI

Available frameworks are:
- djl-deepspeed (0.22.1, 0.23.0, 0.24.0, 0.25.0)
- djl-fastertransformer (0.22.1, 0.23.0, 0.24.0, 0.25.0)
- djl-neuronx (0.22.1, 0.23.0, 0.24.0, 0.25.0)

In [None]:
image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=sess.boto_session.region_name,
    version="0.25.0",
)

### Upload the artifact to S3 and create the SageMaker model

In [None]:
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")
env = {"HUGGINGFACE_HUB_CACHE": "/tmp", "TRANSFORMERS_CACHE": "/tmp"}

model = Model(image_uri=image_uri, model_data=code_artifact, env=env, role=role)

### Create the SageMaker endpoint

You need to specify the instance type to use and the endpoint name.

In [None]:
instance_type = "ml.g4dn.4xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")

model.deploy(initial_instance_count=1,
             instance_type=instance_type,
             endpoint_name=endpoint_name,
             # container_startup_health_check_timeout=3600
            )

# Our requests and responses will be in JSON format, so we specify the serializer and the deserializer
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)

## Step 4: Test and benchmark the inference

In [None]:
%%timeit -n3 -r1
predictor.predict(
    {"inputs": "Large model inference is", "parameters": {}}
)
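
Outside of a notebook, where `%%timeit` is not available, a comparable measurement can be sketched with `time.perf_counter`. The helper `time_calls` and the stand-in `fake_predict` are hypothetical names used only so that the sketch runs without a live endpoint.

```python
import statistics
import time


def time_calls(fn, repeats=3):
    """Call fn() `repeats` times and return the per-call latencies in seconds."""
    latencies = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        latencies.append(time.perf_counter() - start)
    return latencies


def fake_predict():
    # Stand-in for predictor.predict(...) so the sketch runs without an endpoint.
    time.sleep(0.01)


latencies = time_calls(fake_predict, repeats=3)
print(f"mean={statistics.mean(latencies):.3f}s  max={max(latencies):.3f}s")
```

To measure the real endpoint, pass something like `lambda: predictor.predict({"inputs": "Large model inference is", "parameters": {}})` in place of `fake_predict`. Keeping the individual latencies, not only the mean, makes a slow first request easier to spot.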

## Clean up the environment

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()