# Zero-shot object detection with the LMI container on SageMaker
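In this tutorial, you will deploy a zero-shot object detection model (Grounding DINO) to a SageMaker endpoint using the Large Model Inference (LMI) container from AWS Deep Learning Containers (DLC), and then run inference against it.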

Before running this notebook, make sure the following permissions are granted:

- S3 bucket write access (to upload the model artifact)
- SageMaker access

## Step 1: Upgrade the SageMaker SDK and import dependencies

In [None]:
%pip install sagemaker boto3 awscli --upgrade --quiet

In [None]:
import sagemaker
from sagemaker.djl_inference.model import DJLModel

role = sagemaker.get_execution_role()  # execution role for the endpoint
session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs

## Step 2: Prepare the model artifacts
The LMI container expects the following artifacts to set up the model:
- serving.properties (optional): Defines the model server settings
- model.py (required): A Python file that defines the core inference logic
- requirements.txt (optional): Any additional pip packages to install

In [None]:
%%writefile serving.properties
engine=Python
# Uncomment to enable dynamic server-side batching
# batch_size=5
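With `batch_size` set, DJL Serving can aggregate requests that arrive close together into a single `handle()` call, so a client only benefits if it sends requests concurrently rather than one at a time. DJL Serving also documents a `max_batch_delay` property (in milliseconds) that caps how long the server waits to fill a batch; check the configuration options for your LMI version before relying on it. Below is a minimal client-side sketch: `predict_concurrently` is a hypothetical helper, and `predict_fn` is assumed to be a callable such as the `predictor.predict` created in Step 4.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List


def predict_concurrently(predict_fn: Callable[[Any], Any],
                         payloads: List[Any],
                         max_workers: int = 5) -> List[Any]:
    """Send payloads in parallel so the server has a chance to batch them.

    predict_fn is assumed to be something like predictor.predict (hypothetical
    wiring, not part of the LMI API). Results come back in input order.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map runs calls concurrently but preserves the input order
        return list(pool.map(predict_fn, payloads))
```

For example, `predict_concurrently(predictor.predict, [payload_1, payload_2, payload_3])`, where each payload has the same shape as the request in Step 5. Threads are a reasonable fit here because each call spends most of its time waiting on the network.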

In [None]:
%%writefile model.py


import logging

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

from djl_python import Input
from djl_python import Output


class ZeroShotObjectDetection(object):

    def __init__(self):
        self.device = None
        self.model = None
        self.processor = None
        self.initialized = False

    def initialize(self, properties: dict):
        """
        Initialize model.
        """
        model_id = "IDEA-Research/grounding-dino-base"
        device_id = properties.get("device_id", "-1")
        device_id = "cpu" if device_id == "-1" else "cuda:" + device_id
        self.device = torch.device(device_id)
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            model_id).to(self.device)
        self.initialized = True

    def inference(self, inputs):
        outputs = Output()
        try:
            batch = inputs.get_batches()
            images = []
            text = []
            sizes = []
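            for item in batch:
                data = item.get_as_json()
                # Accept either {"inputs": {...}} or the bare payload
                data = data.pop("inputs", data)
                response = requests.get(data["image_url"]["url"],
                                        stream=True,
                                        timeout=30)
                response.raise_for_status()
                # Convert to RGB so grayscale or RGBA images work with the processor
                image = Image.open(response.raw).convert("RGB")
                images.append(image)
                text.append(data["text"])
                # PIL reports (width, height); post-processing expects (height, width)
                sizes.append(image.size[::-1])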

            # padding=True lets prompts of different lengths share one batch
            model_inputs = self.processor(images=images,
                                          text=text,
                                          padding=True,
                                          return_tensors="pt").to(self.device)
            with torch.no_grad():
                model_outputs = self.model(**model_inputs)

            # Newer transformers releases rename box_threshold to threshold
            results = self.processor.post_process_grounded_object_detection(
                model_outputs,
                model_inputs.input_ids,
                box_threshold=0.4,
                text_threshold=0.3,
                target_sizes=sizes)
            for i, result in enumerate(results):
                ret = {
                    "labels": result["labels"],
                    "scores": result["scores"].tolist(),
                    "boxes": result["boxes"].cpu().detach().numpy().tolist(),
                }
                if inputs.is_batch():
                    outputs.add_as_json(ret, batch_index=i)
                else:
                    outputs.add_as_json(ret)
        except Exception as e:
            logging.exception("ZeroShotObjectDetection inference failed")
            # Report the failure to the client as an error response
            outputs = Output().error(str(e))

        return outputs


_service = ZeroShotObjectDetection()


def handle(inputs: Input):
    """
    Default handler function
    """
    if not _service.initialized:
        # stateful model
        _service.initialize(inputs.get_properties())

    if inputs.is_empty():
        # initialization request
        return None

    return _service.inference(inputs)
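Before packaging, it can help to check request payloads locally. The parsing half of `inference()` has no dependency on `djl_python`, PyTorch, or a GPU, so it can be pulled out and exercised on its own. A minimal sketch; `parse_request` and `to_target_size` are hypothetical helper names, not part of model.py or the LMI API:

```python
from typing import Any, Dict, Tuple


def parse_request(payload: Dict[str, Any]) -> Tuple[str, str]:
    """Return (image_url, text) from a request body.

    Mirrors model.py: the body may be wrapped as {"inputs": {...}} or sent bare.
    """
    # Copy first so pop() does not mutate the caller's dict
    data = dict(payload)
    data = data.pop("inputs", data)
    try:
        url = data["image_url"]["url"]
        text = data["text"]
    except (KeyError, TypeError) as e:
        raise ValueError(f"malformed request, missing or invalid field: {e}") from e
    return url, text


def to_target_size(pil_size: Tuple[int, int]) -> Tuple[int, int]:
    """PIL gives (width, height); post-processing expects (height, width)."""
    width, height = pil_size
    return height, width
```

Both the wrapped and bare forms of the Step 5 request yield the same `(url, text)` pair. Unlike model.py, the sketch raises a `ValueError` naming the missing field, which makes malformed payloads easier to spot before deployment.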

In [None]:
# %%writefile requirements.txt
# Add any extra pip dependencies here, one per line (delete this cell if not needed)

In [None]:
%%sh
mkdir mymodel
mv serving.properties mymodel/
mv model.py mymodel/
# mv requirements.txt mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel
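If a shell is not available, the same file layout can be produced with Python's standard-library `tarfile` module. A minimal sketch; `package_model` is a hypothetical helper, and unlike the shell cell it leaves the source files in place and does not add a separate directory entry to the archive:

```python
import tarfile
from pathlib import Path
from typing import List


def package_model(files: List[str],
                  archive: str = "mymodel.tar.gz",
                  prefix: str = "mymodel") -> List[str]:
    """Pack files into a gzipped tarball under a top-level <prefix>/ folder.

    Returns the member names so the layout can be checked before upload.
    """
    with tarfile.open(archive, "w:gz") as tar:
        for f in files:
            # arcname places each file at <prefix>/<filename> inside the archive
            tar.add(f, arcname=f"{prefix}/{Path(f).name}")
    # Re-open the archive to report what was actually written
    with tarfile.open(archive, "r:gz") as tar:
        return tar.getnames()
```

For example, `package_model(["serving.properties", "model.py"])` writes `mymodel.tar.gz` and returns its member names, which should all start with `mymodel/`.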

## Step 3: Upload the model artifact to S3

In [None]:
bucket = session.default_bucket()  # default bucket to host artifacts
code_artifact = session.upload_data("mymodel.tar.gz", bucket, "lmi-model")
print(f"S3 code or model tarball uploaded to ---> {code_artifact}")

## Step 4: Build the SageMaker endpoint
In this step, we build the SageMaker endpoint from scratch.

### Getting the container image URI (optional)

Check out available images: [Large Model Inference available DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)

In [None]:
# Choose a specific version of LMI image directly:
# image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"
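If you prefer not to hard-code the region, the URI can be assembled from its parts. A minimal sketch, assuming the registry account `763104351884` and the tag from the example above; some regions use a different registry account, so confirm against the available-images list. `lmi_image_uri` is a hypothetical helper; the SageMaker SDK's `sagemaker.image_uris.retrieve` is the SDK's own lookup mechanism if you would rather not build the string by hand.

```python
def lmi_image_uri(region: str,
                  tag: str = "0.29.0-lmi11.0.0-cu124",
                  account: str = "763104351884") -> str:
    """Build an LMI DLC image URI for a region (hypothetical helper).

    The default account and tag are copied from the commented example above;
    they are assumptions, not values valid for every region or release.
    """
    return f"{account}.dkr.ecr.{region}.amazonaws.com/djl-inference:{tag}"
```

For example, `image_uri = lmi_image_uri(session.boto_region_name)` picks up the region of the current SageMaker session.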

### Create SageMaker model

Here we use the [LMI PySDK](https://sagemaker.readthedocs.io/en/stable/frameworks/djl/using_djl.html) to create the model.

Check out more [configuration options](https://docs.djl.ai/docs/serving/serving/docs/lmi/deployment_guide/configurations.html#environment-variable-configurations).

In [None]:
model = DJLModel(
    model_data=code_artifact,
    # image_uri=image_uri,  # choose a specific version of LMI DLC image
    role=role)

### Create SageMaker endpoint

Specify the instance type to use and a name for the endpoint.

In [None]:
instance_type = "ml.g4dn.2xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
)

## Step 5: Run inference
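Grounding DINO takes all candidate labels in a single text string. The Hugging Face Grounding DINO documentation notes that queries should be lowercase and each terminated by a period, which is the format used in the request below. A small sketch for building that string from a list of labels; `build_prompt` is a hypothetical helper:

```python
from typing import Iterable


def build_prompt(labels: Iterable[str]) -> str:
    """Join labels into a Grounding DINO text query such as 'a cat. a remote control.'"""
    # Normalize each label: trim, lowercase, drop any trailing periods, trim again
    cleaned = [label.strip().lower().rstrip(".").strip() for label in labels]
    # Skip empty labels and terminate each remaining one with a period
    return " ".join(f"{label}." for label in cleaned if label)
```

For example, `build_prompt(["a cat", "a remote control"])` produces the `"text"` value used in the request below.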

In [None]:
predictor.predict(
    {
        "text": "a cat. a remote control.",
        "image_url": {
            "url": "http://images.cocodataset.org/val2017/000000039769.jpg"
        }
    }
)
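The response holds parallel `labels`, `scores`, and `boxes` lists, with boxes in `[x_min, y_min, x_max, y_max]` pixel coordinates of the original image (the handler passes the image size as `target_sizes`). A minimal sketch for turning that into a ranked list, assuming the predictor returns the parsed JSON dict; `summarize_detections` and its `min_score` cutoff are hypothetical, applied on top of the server-side `box_threshold`:

```python
from typing import Any, Dict, List, Tuple

Detection = Tuple[str, float, List[float]]


def summarize_detections(response: Dict[str, Any],
                         min_score: float = 0.5) -> List[Detection]:
    """Zip labels/scores/boxes and keep detections at or above min_score.

    Results are sorted by score, highest first.
    """
    detections = zip(response["labels"], response["scores"], response["boxes"])
    kept = [(label, score, box) for label, score, box in detections
            if score >= min_score]
    return sorted(kept, key=lambda d: d[1], reverse=True)
```

For example, `summarize_detections(predictor.predict(payload))` returns `(label, score, box)` tuples that can be printed or drawn onto the image.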

## Clean up the environment

In [None]:
session.delete_endpoint(endpoint_name)
session.delete_endpoint_config(endpoint_name)
model.delete_model()