> **_NOTE:_**  **This script is intended to be run on a SageMaker Notebook instance!**

## Prerequisites
- A **SageMaker Notebook** instance and an **S3 bucket** for storing the model bundle are already set up, with their permissions configured.

## Step 1
Use git to clone this repository to your SageMaker Notebook instance, then open run.ipynb there.

## Step 2
Prepare the model files for SageMaker. Run the code blocks below in sequence.

In [None]:
!mkdir handler
!mkdir handler/code
!mkdir handler/MAR-INF

In [None]:
%%writefile handler/code/requirements.txt
sentence-transformers==5.0.0

In [None]:
%%writefile handler/MAR-INF/MANIFEST.json
{
  "runtime": "python",
  "model": {
    "modelName": "neuralsparse",
    "handler": "neural_sparse_handler.py",
    "modelVersion": "1.0",
    "configFile": "neural_sparse_config.yaml"
  },
  "archiverVersion": "0.9.0"
}

In [None]:
%%writefile handler/neural_sparse_config.yaml
## configs about dynamic batch inference https://docs.pytorch.org/serve/batch_inference_with_ts.html
## batchSize: the maximum number of requests to aggregate. Each request can contain multiple documents.
batchSize: 16
maxBatchDelay: 5
responseTimeout: 300

In [None]:
%%writefile handler/neural_sparse_handler.py

import os
import re
import itertools
import json
import torch

from ts.torch_handler.base_handler import BaseHandler
from sentence_transformers.sparse_encoder import SparseEncoder

model_id = os.environ.get(
    "MODEL_ID", "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
)
max_bs = int(os.environ.get("MAX_BS", 32))
trust_remote_code = model_id.endswith("gte")

class SparseEncodingModelHandler(BaseHandler):
    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, context):
        self.manifest = context.manifest
        properties = context.system_properties

        # Print initialization parameters
        print(f"Initializing SparseEncodingModelHandler with model_id: {model_id}")

        # load model and tokenizer
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )
        print(f"Using device: {self.device}")
        self.model = SparseEncoder(model_id, device=self.device, trust_remote_code=trust_remote_code)
        self.initialized = True

    def preprocess(self, requests):
        inputSentence = []
        batch_idx = []

        for request in requests:
            request_body = request.get("body")
            if isinstance(request_body, bytearray):
                request_body = request_body.decode("utf-8")
                request_body = json.loads((request_body))
            if isinstance(request_body, list):
                inputSentence += request_body
                batch_idx.append(len(request_body))
            else:
                inputSentence.append(request_body)
                batch_idx.append(1)

        return inputSentence, batch_idx

    def handle(self, data, context):
        inputSentence, batch_idx = self.preprocess(data)
        model_output = self.model.encode_document(inputSentence, batch_size=max_bs)
        sparse_embedding = list(map(dict,self.model.decode(model_output)))

        outputs = [sparse_embedding[s:e]
           for s, e in zip([0]+list(itertools.accumulate(batch_idx))[:-1],
                           itertools.accumulate(batch_idx))]
        return outputs

Pack the handler folder into a tarball and upload it to your S3 bucket.

In handler/neural_sparse_handler.py, we define model loading, pre-processing, inference, and post-processing. We use mixed precision to accelerate inference.
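
The pre- and post-processing in the handler can be hard to follow because several requests are flattened into one list and then split again. The sketch below reproduces that bookkeeping with the standard library only. `fake_encode` is a hypothetical stand-in for the real sparse encoder, and the function names are illustrative rather than part of the handler.

```python
import itertools
import json


def flatten_requests(requests):
    """Flatten request bodies into one sentence list plus per-request counts."""
    sentences, counts = [], []
    for request in requests:
        body = request.get("body")
        # Raw bytes are decoded and parsed as JSON before use.
        if isinstance(body, (bytes, bytearray)):
            body = json.loads(body.decode("utf-8"))
        if isinstance(body, list):
            sentences += body
            counts.append(len(body))
        else:
            sentences.append(body)
            counts.append(1)
    return sentences, counts


def regroup(flat_outputs, counts):
    """Slice the flat outputs back into one list per original request."""
    ends = list(itertools.accumulate(counts))
    starts = [0] + ends[:-1]
    return [flat_outputs[s:e] for s, e in zip(starts, ends)]


def fake_encode(sentences):
    # Hypothetical stand-in for the sparse encoder: every word gets weight 1.0.
    return [{word.lower(): 1.0 for word in s.split()} for s in sentences]


requests = [
    {"body": bytearray(b'["first doc", "second doc"]')},
    {"body": "third doc"},
]
sentences, counts = flatten_requests(requests)
grouped = regroup(fake_encode(sentences), counts)
print(counts)        # [2, 1]
print(len(grouped))  # 2
```

Each caller gets back exactly as many embeddings as documents it sent, even though the encoder sees a single flat list.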

In handler/neural_sparse_config.yaml, we define the TorchServe settings, including dynamic batching.
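
To see how `batchSize` and the `MAX_BS` environment variable interact, a small worked calculation may help. It assumes, as in the handler, that all documents from the aggregated requests are passed to the encoder together and encoded in chunks of at most `MAX_BS`. The helper name `forward_passes` is illustrative only.

```python
import math


def forward_passes(docs_per_request, max_bs=32):
    """Return (total documents, encoder passes) for one aggregated batch."""
    total_docs = sum(docs_per_request)
    return total_docs, math.ceil(total_docs / max_bs)


# 16 aggregated requests (batchSize: 16), each carrying 5 documents.
total, passes = forward_passes([5] * 16, max_bs=32)
print(total, passes)  # 80 3
```

So `batchSize` bounds the number of requests per aggregation, while `MAX_BS` bounds the number of documents per encoder step and is the setting to lower if GPU memory runs out.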

In [None]:
import os

bucket_name = "your_bucket_name"
os.system("tar -czvf neural-sparse-handler.tar.gz -C handler/ .")
os.system(
    f"aws s3 cp neural-sparse-handler.tar.gz s3://{bucket_name}/neural-sparse-handler.tar.gz"
)
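
Before uploading, it can be worth checking that the archive has the expected layout. The sketch below is an alternative to the `tar` command that uses Python's `tarfile` module. It then checks that the handler and config file named in MANIFEST.json exist. The helper names and the throwaway demo folder are assumptions for illustration. Unlike `tar -C handler/ .`, the member names here carry no leading `./`.

```python
import json
import tarfile
import tempfile
from pathlib import Path


def build_tarball(src_dir, out_path):
    """Pack the contents of src_dir at the archive root and list the members."""
    src = Path(src_dir)
    with tarfile.open(str(out_path), "w:gz") as tar:
        for path in sorted(src.rglob("*")):
            arcname = path.relative_to(src).as_posix()
            tar.add(str(path), arcname=arcname, recursive=False)
    with tarfile.open(str(out_path), "r:gz") as tar:
        return tar.getnames()


def missing_manifest_files(src_dir):
    """Return files named in MANIFEST.json that are absent from src_dir."""
    src = Path(src_dir)
    manifest = json.loads((src / "MAR-INF" / "MANIFEST.json").read_text())
    model = manifest["model"]
    wanted = [model["handler"], model["configFile"]]
    return [name for name in wanted if not (src / name).is_file()]


# Demo on a throwaway folder that mimics the handler layout.
with tempfile.TemporaryDirectory() as tmp:
    handler_dir = Path(tmp) / "handler"
    (handler_dir / "code").mkdir(parents=True)
    (handler_dir / "MAR-INF").mkdir()
    (handler_dir / "code" / "requirements.txt").write_text("sentence-transformers==5.0.0\n")
    (handler_dir / "neural_sparse_handler.py").write_text("# handler placeholder\n")
    (handler_dir / "neural_sparse_config.yaml").write_text("batchSize: 16\n")
    (handler_dir / "MAR-INF" / "MANIFEST.json").write_text(json.dumps({
        "model": {
            "handler": "neural_sparse_handler.py",
            "configFile": "neural_sparse_config.yaml",
        }
    }))
    names = build_tarball(handler_dir, Path(tmp) / "neural-sparse-handler.tar.gz")
    missing = missing_manifest_files(handler_dir)

print(names)
print(missing)  # []
```

On the real folder, calling `missing_manifest_files("handler")` before the upload would flag a handler or config file that MANIFEST.json references but that was never written.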

## Step 3
Use the SageMaker Python SDK to deploy the tarball to a real-time inference endpoint.

Here we use ml.g5.xlarge, a GPU instance with good price-performance.

Adjust the region to match your setup. By default it is taken from the notebook's boto3 session.

In [None]:
# constants that can be customized for models
model_id = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
## The maximum number of documents to encode in a single inference step. Too large a value causes CUDA OOM.
## Even with batchSize set to 16, the actual number of documents can be larger, because one request can contain multiple documents.
max_batch_size = "32"

# constants related to deployment
model_name = "ns-handler"
endpoint_name = "ns-handler"
instance_type = "ml.g5.xlarge"
initial_instance_count = 1

# run this cell
import boto3
import sagemaker
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

role = sagemaker.get_execution_role()
sess = boto3.Session()
region = sess.region_name
smsess = sagemaker.Session(boto_session=sess)

envs = {
    "TS_ASYNC_LOGGING": "true",
    "MODEL_ID": model_id,
    "MAX_BS": max_batch_size,
}

baseimage = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=region,
    py_version="py312",
    image_scope="inference",
    version="2.6",
    instance_type=instance_type,
)

model = Model(
    model_data=f"s3://{bucket_name}/neural-sparse-handler.tar.gz",
    image_uri=baseimage,
    role=role,
    predictor_cls=Predictor,
    name=model_name,
    sagemaker_session=smsess,
    env=envs,
)

predictor = model.deploy(
    instance_type=instance_type,
    initial_instance_count=initial_instance_count,
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
    ModelDataDownloadTimeoutInSeconds=3600,
    ContainerStartupHealthCheckTimeoutInSeconds=3600,
    VolumeSizeInGB=16,
)

print(predictor.endpoint_name)

## Step 4

After the endpoint is created, send a sample request to see how it works.

In [None]:
# run this cell
import json

body = ["Currently New York is rainy."]
amz = boto3.client("sagemaker-runtime")

response = amz.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=json.dumps(body),
    ContentType="application/json",
)

res = response["Body"].read()
results = json.loads(res.decode("utf8"))
results

Response:
```python
[{'weather': 1.0386549234390259,
  'york': 1.0295555591583252,
  'ny': 0.9703458547592163,
  'rain': 0.9549243450164795,
  'rainy': 0.9478658437728882,
  'nyc': 0.8449130058288574,
  'new': 0.6880059838294983,
  'raining': 0.6789529323577881,
  'current': 0.6762993931770325,
  'wet': 0.6448248028755188,
  'rainfall': 0.6405332088470459,
  'currently': 0.6092915534973145,
  'now': 0.586189329624176,
  'manhattan': 0.5858010053634644,
  'today': 0.5322379469871521,
  'temperature': 0.5275187492370605,
  'climate': 0.48528429865837097,
  'is': 0.481422483921051,
  'y': 0.4586825370788574,
  '##yo': 0.45718008279800415,
  'cloudy': 0.41763371229171753,
  'it': 0.41397932171821594,
  'forecast': 0.38210317492485046,
  'rains': 0.3785228431224823,
  'rained': 0.35427314043045044,
  'yorkshire': 0.31092309951782227,
  'snow': 0.30391135811805725,
  'yorker': 0.28260838985443115,
  'time': 0.27697092294692993,
  'sunny': 0.2620435059070587,
  'nyu': 0.2503677308559418,
  'in': 0.24964851140975952,
  'windy': 0.2452678382396698,
  'presently': 0.22908653318881989,
  'stormy': 0.21931196749210358,
  'temperatures': 0.21101005375385284,
  'tonight': 0.20632436871528625,
  'present': 0.20109090209007263,
  'this': 0.20102401077747345,
  'us': 0.1935725212097168,
  'nj': 0.18026664853096008,
  'storm': 0.17380213737487793,
  'week': 0.17336463928222656,
  'news': 0.16366833448410034,
  '##storm': 0.16161945462226868,
  'here': 0.14572882652282715,
  'temps': 0.13970820605754852,
  'lately': 0.13716177642345428,
  '##weather': 0.13432787358760834,
  'te': 0.1198926791548729,
  'yesterday': 0.11460382491350174,
  'or': 0.11349867284297943,
  'storms': 0.11013525724411011,
  'sunshine': 0.09905409067869186,
  'usa': 0.09774350374937057,
  'clouds': 0.09281915426254272,
  'humidity': 0.09233205765485764,
  'humid': 0.086763896048069,
  'daylight': 0.08338665962219238,
  'state': 0.08252169191837311,
  'winter': 0.07992527633905411,
  'summer': 0.07536710053682327,
  'fog': 0.06763386726379395,
  'mood': 0.06538641452789307,
  'like': 0.06360717862844467,
  'hurricane': 0.062024328857660294,
  'water': 0.061854153871536255,
  'hudson': 0.0577932633459568,
  'gloom': 0.04488009959459305,
  'flu': 0.04299859702587128,
  'sunday': 0.039578113704919815,
  'brooklyn': 0.03740933537483215,
  'season': 0.03519425913691521,
  'month': 0.026503682136535645,
  'america': 0.025791412219405174,
  'monsoon': 0.01986435428261757,
  'color': 0.015449629165232182,
  'seasons': 0.012146473862230778,
  'does': 0.006621183827519417,
  'snowy': 0.0020988560281693935}]
```
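
The response is a list with one token-to-weight dictionary per input document. A minimal sketch for inspecting the heaviest tokens is shown here. `top_k_tokens` is an illustrative helper, and `sample` holds a few entries copied from the response above with weights rounded.

```python
def top_k_tokens(embedding, k=5):
    """Return the k highest-weighted (token, weight) pairs, heaviest first."""
    return sorted(embedding.items(), key=lambda item: item[1], reverse=True)[:k]


# A few entries from the sample response, weights rounded to four decimals.
sample = {"rain": 0.9549, "weather": 1.0387, "is": 0.4814, "york": 1.0296, "ny": 0.9703}
for token, weight in top_k_tokens(sample, k=3):
    print(f"{token}: {weight:.4f}")
# weather: 1.0387
# york: 1.0296
# ny: 0.9703
```

With the live endpoint, `top_k_tokens(results[0])` would apply the same inspection to the first document's embedding.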

## Step 5
> **_NOTE:_**  **This step is intended to be run on an instance that has access to the OpenSearch cluster!**

Register this SageMaker endpoint with your OpenSearch cluster.

See the OpenSearch documentation for more information. The demo request body below authenticates with an access_key and secret_key. Choose the authentication method that fits your use case.

### create connector

(Fill in the region and predictor.endpoint_name in the request body.)
```json
POST /_plugins/_ml/connectors/_create
{
  "name": "test",
  "description": "Test connector for Sagemaker model",
  "version": 1,
  "protocol": "aws_sigv4",
  "credential": {
    "access_key": "your access key",
    "secret_key": "your secret key"
  },
  "parameters": {
    "region": "{region}",
    "service_name": "sagemaker",
    "input_docs_processed_step_size": 2
  },
  "actions": [
    {
      "action_type": "predict",
      "method": "POST",
      "headers": {
        "content-type": "application/json"
      },
      "url": "https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{predictor.endpoint_name}/invocations",
      "request_body": "${parameters.input}"
    }
  ],
  "client_config": {
    "max_retry_times": -1,
    "max_connection": 60,
    "retry_backoff_millis": 10
  }
}
```
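
Filling the placeholders by hand is error-prone, so one option is to build the request body in Python and serialize it with `json.dumps`, which also guarantees valid JSON. This is a sketch only: `build_connector_body` is a hypothetical helper, and the region and endpoint name in the usage lines are example values.

```python
import json


def build_connector_body(region, endpoint_name, access_key, secret_key):
    """Return the connector request body with region and endpoint filled in."""
    url = (
        f"https://runtime.sagemaker.{region}.amazonaws.com"
        f"/endpoints/{endpoint_name}/invocations"
    )
    return {
        "name": "test",
        "description": "Test connector for Sagemaker model",
        "version": 1,
        "protocol": "aws_sigv4",
        "credential": {"access_key": access_key, "secret_key": secret_key},
        "parameters": {
            "region": region,
            "service_name": "sagemaker",
            "input_docs_processed_step_size": 2,
        },
        "actions": [
            {
                "action_type": "predict",
                "method": "POST",
                "headers": {"content-type": "application/json"},
                "url": url,
                # Kept verbatim: this placeholder is resolved by OpenSearch, not Python.
                "request_body": "${parameters.input}",
            }
        ],
        "client_config": {
            "max_retry_times": -1,
            "max_connection": 60,
            "retry_backoff_millis": 10,
        },
    }


# Example values only; substitute your own region, endpoint name, and credentials.
body = build_connector_body("us-east-1", "ns-handler", "your access key", "your secret key")
print(json.dumps(body, indent=2))
```

The printed JSON can be pasted as the body of the `POST /_plugins/_ml/connectors/_create` request.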

### register model
```json
POST /_plugins/_ml/models/_register?deploy=true
{
  "name": "test",
  "function_name": "remote",
  "version": "1.0.0",
  "connector_id": "{connector id}",
  "description": "Test connector for Sagemaker model"
}
```