# Sentence Embeddings with Hugging Face Transformers, Sentence Transformers and Amazon SageMaker - Custom Inference for creating document embeddings with Hugging Face's Transformers


Welcome to this getting started guide. We will use the Hugging Face Inference DLCs and the Amazon SageMaker Python SDK to create a [real-time inference endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) running a Sentence Transformers model for document embeddings.
Currently, the [SageMaker Hugging Face Inference Toolkit](https://github.com/aws/sagemaker-huggingface-inference-toolkit) supports the [pipeline feature](https://huggingface.co/transformers/main_classes/pipelines.html) from Transformers for zero-code deployment. This means you can run compatible Hugging Face Transformers models without providing pre- and post-processing code. You only need to set the environment variables `HF_TASK` and `HF_MODEL_ID` when creating the endpoint, and the Inference Toolkit takes care of the rest.
This is a great feature if you are working with existing [pipelines](https://huggingface.co/transformers/main_classes/pipelines.html). 

If you want to run other tasks, such as creating document embeddings, you can provide the pre- and post-processing code yourself via an `inference.py` script. The Hugging Face Inference Toolkit allows the user to override the default methods of the `HuggingFaceHandlerService`.

The custom module can override the following methods:  

* `model_fn(model_dir)` overrides the default method for loading a model. The return value `model` will be used in `predict` for predictions. `model_fn` receives the argument `model_dir`, the path to your unzipped `model.tar.gz`.
* `transform_fn(model, data, content_type, accept_type)` overrides the default transform function with your custom implementation. You will need to implement your own `preprocess`, `predict` and `postprocess` steps in the `transform_fn`. This method can't be combined with `input_fn`, `predict_fn` or `output_fn` mentioned below.
* `input_fn(input_data, content_type)` overrides the default method for preprocessing. The return value `data` will be used in `predict` for predictions. The inputs are:
  - `input_data` is the raw body of your request.
  - `content_type` is the content type from the request header.
* `predict_fn(processed_data, model)` overrides the default method for predictions. The return value `predictions` will be used in `postprocess`. The inputs are `processed_data`, the result from `preprocess`, and `model`, the return value of `model_fn`.
* `output_fn(prediction, accept)` overrides the default method for postprocessing. The return value `result` will be the response of your request (e.g. `JSON`). The inputs are:
  - `prediction` is the result from `predict`.
  - `accept` is the return accept type from the HTTP Request, e.g. `application/json`.
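
As an illustration of the override hooks listed above, here is a minimal sketch of an `input_fn` / `output_fn` pair for JSON requests. It is not the toolkit's default implementation; the `{"inputs": ...}` payload shape and the choice to reject every content type other than `application/json` are assumptions made for this example.

```python
import json


def input_fn(input_data, content_type):
    """Decode the raw request body into a dict with a list of sentences."""
    if content_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")
    if isinstance(input_data, (bytes, bytearray)):
        input_data = input_data.decode("utf-8")
    payload = json.loads(input_data)
    # Assumption: the request carries a string or a list of strings under "inputs".
    inputs = payload["inputs"]
    if isinstance(inputs, str):
        inputs = [inputs]
    return {"inputs": inputs}


def output_fn(prediction, accept):
    """Serialize the prediction into the response body."""
    if accept != "application/json":
        raise ValueError(f"Unsupported accept type: {accept}")
    return json.dumps(prediction)
```

The value returned by `input_fn` is what `predict_fn` receives as its first argument, and the value returned by `predict_fn` is what `output_fn` receives as `prediction`.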

In this example, we are going to use a Sentence Transformers model to create document embeddings using mean pooling.
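
To make the idea of mean pooling concrete, the following sketch averages token vectors while skipping padding positions. It uses plain Python lists instead of tensors, and `mean_pool` is a hypothetical helper written for this illustration only.

```python
def mean_pool(token_embeddings, attention_mask):
    """Average the token vectors whose attention mask entry is 1."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            for i, value in enumerate(vector):
                summed[i] += value
    # Guard against division by zero when every position is masked out.
    count = max(sum(attention_mask), 1e-9)
    return [s / count for s in summed]


# Two real tokens followed by one padding token.
tokens = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]
pooled = mean_pool(tokens, mask)
print(pooled)  # [2.0, 3.0]
```

The padding vector `[9.0, 9.0]` does not affect the result because its mask entry is 0.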


_NOTE: You can run this demo in SageMaker Studio, on your local machine, or in SageMaker Notebook Instances._



## Development Environment and Permissions

### Installation 


In [None]:
%pip install sagemaker --upgrade

In [None]:
import sagemaker

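from packaging import version

# compare parsed versions; a plain string comparison would rank "2.100.0" below "2.75.0"
assert version.parse(sagemaker.__version__) >= version.parse("2.75.0")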

### Permissions

_If you are going to use SageMaker in a local environment (not SageMaker Studio or Notebook Instances), you need access to an IAM role with the required permissions for SageMaker. You can find out more about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html)._

In [None]:
import sagemaker
import boto3
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

## Create custom `inference.py` script

To use a custom inference script, you need to create an `inference.py` script. In our example, we are going to override `model_fn` to load our sentence transformer correctly and `predict_fn` to apply mean pooling.

We are going to use the [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model. It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.
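
For semantic search, embeddings are typically compared with cosine similarity. The sketch below ranks a few toy 3-dimensional vectors against a query; the real model returns 384-dimensional vectors, and `cosine_similarity` and `rank_documents` are hypothetical helpers written for this illustration.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(v * v for v in vector))
    return [v / norm for v in vector] if norm > 0 else list(vector)


def cosine_similarity(a, b):
    """Cosine similarity, computed as the dot product of the normalized vectors."""
    a, b = l2_normalize(a), l2_normalize(b)
    return sum(x * y for x, y in zip(a, b))


def rank_documents(query_vec, doc_vecs):
    """Return (index, score) pairs sorted from most to least similar."""
    scores = [(idx, cosine_similarity(query_vec, vec)) for idx, vec in enumerate(doc_vecs)]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)


query = [1.0, 0.0, 0.0]
docs = [[0.0, 1.0, 0.0], [0.9, 0.1, 0.0], [-1.0, 0.0, 0.0]]
ranking = rank_documents(query, docs)
print([idx for idx, _ in ranking])  # [1, 0, 2]
```

If the embeddings are already L2-normalized, as they are in the inference script in this guide, the normalization step is redundant and a plain dot product gives the same ranking.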

In [None]:
%%writefile code/inference.py

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def model_fn(model_dir):
    # Load model and tokenizer from the local model directory
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModel.from_pretrained(model_dir)
    return model, tokenizer

def predict_fn(data, model_and_tokenizer):
    # unpack model and tokenizer
    model, tokenizer = model_and_tokenizer
    
    # Tokenize sentences
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    
    # return a dictionary of plain Python lists so that it is JSON serializable
    return {"vectors": sentence_embeddings.tolist()}


## Create custom `HuggingFaceModel`


In [8]:
from sagemaker.huggingface.model import HuggingFaceModel

# Hub Model configuration. <https://huggingface.co/models>
hub = {
    'HF_MODEL_ID':'sentence-transformers/all-MiniLM-L6-v2',
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
   entry_point="inference.py",   # entry point of the script
   source_dir="code",            # source directory of the script
   env=hub,                      # configuration for loading model from Hub
   role=role,                    # iam role with permissions to create an Endpoint
   transformers_version="4.12",  # transformers version used
   pytorch_version="1.9",        # pytorch version used
   py_version='py38',            # python version used
)

# deploy the endpoint
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge"
    )


## Request Inference Endpoint using the `HuggingFacePredictor`

The `.deploy()` method returns a `HuggingFacePredictor` object, which can be used to request inference.

In [11]:
data = {
  "inputs": "the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .",
}

res = predictor.predict(data=data)
print(res)

The response is a dictionary with a single key, `vectors`, which holds one 384-dimensional embedding per input sentence.


### Delete the model and endpoint

In [44]:
predictor.delete_model()
predictor.delete_endpoint()