# Deploying the model to a SageMaker endpoint

In [None]:
# IP address of the Elasticsearch server.
# NOTE: the right-hand side below is a template placeholder, not valid Python.
# Replace it with the actual address as a quoted string before running this cell.
IP_ADDRESS_ES = <IP address of ES server>

In [None]:
import sagemaker

# Create a SageMaker session; the bucket name and role obtained here are reused
# further down when uploading the model archive and creating the model object.
sess = sagemaker.Session()
sagemaker_session_bucket = sess.default_bucket()  # bucket that receives model.tar.gz below
role = sagemaker.get_execution_role()  # passed as `role` to HuggingFaceModel below

In [None]:
%%writefile code/inference.py
from haystack.nodes import BM25Retriever
from haystack.nodes import FARMReader
from haystack.pipelines import ExtractiveQAPipeline
from haystack.document_stores import ElasticsearchDocumentStore
import json


def dumper(obj):
    """Fallback serializer handed to ``json.dumps`` through ``default=``.

    Args:
        obj: An object the JSON encoder could not serialize by itself.

    Returns:
        The result of ``obj.toJSON()`` when that call succeeds, otherwise the
        object's attribute dictionary (``obj.__dict__``).

    Raises:
        AttributeError: If ``toJSON`` is unavailable or fails and ``obj`` has
            no ``__dict__`` either.
    """
    try:
        return obj.toJSON()
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt and SystemExit
        # must propagate instead of being turned into a __dict__ dump.
        return obj.__dict__


def model_fn(model_dir):
    """Build the extractive question-answering pipeline.

    The Elasticsearch host is read from the ``IP_ADDRESS_ES`` environment
    variable. It has to be set in the environment this script runs in (for a
    SageMaker endpoint, for example through the ``env`` argument of the model
    object). A ``$IP_ADDRESS_ES`` token cannot be used here: ``%%writefile``
    stores the cell body verbatim, which would leave invalid Python in the
    generated file.

    Args:
        model_dir: Unused by this function; the reader is loaded by model
            name ("deepset/roberta-base-squad2") rather than from this path.

    Returns:
        An ``ExtractiveQAPipeline`` built from a ``FARMReader`` and a
        ``BM25Retriever`` over the Elasticsearch document store.

    Raises:
        ValueError: If ``IP_ADDRESS_ES`` is unset or empty.
    """
    # Function-scope import so the block does not rely on a file-level change.
    import os

    es_host = os.environ.get("IP_ADDRESS_ES")
    if not es_host:
        raise ValueError(
            "Environment variable IP_ADDRESS_ES must contain the address of "
            "the Elasticsearch server."
        )

    document_store = ElasticsearchDocumentStore(
        host=es_host,
        port=9200,
        index="document",
    )

    retriever = BM25Retriever(document_store=document_store)
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)
    pipe = ExtractiveQAPipeline(reader, retriever)

    return pipe


def predict_fn(data, pipe):
    """Answer one request with the pipeline and wrap the result as JSON text.

    Args:
        data: Request dictionary. The mandatory "inputs" entry (the query) and
            the optional "parameters" entry are removed from it.
        pipe: Object whose ``run(query=..., params=...)`` produces the
            prediction.

    Returns:
        A dictionary with a single "response" key holding the prediction
        serialized as a JSON string (non-serializable objects go through
        ``dumper``).

    Raises:
        KeyError: If ``data`` has no "inputs" entry.
    """
    question = data.pop("inputs")
    run_params = data.pop("parameters", None)

    serialized = json.dumps(
        pipe.run(query=question, params=run_params),
        default=dumper,
    )
    return {"response": serialized}

In [None]:
# Enter the model directory and pack everything in it into model.tar.gz.
# NOTE(review): the inference script was written to code/inference.py relative
# to the previous working directory; confirm it ends up under model/code/ so
# that it is part of the archive.
# NOTE(review): %cd changes the working directory for all following cells, and
# on a re-run the `*` glob would also match an existing model.tar.gz; confirm
# this cell is only meant to be executed once per session.
%cd model
!tar zcvf model.tar.gz *

In [None]:
# Destination S3 URI for the archive, inside the session bucket obtained earlier.
s3_location=f"s3://{sagemaker_session_bucket}/haystack-demo/model.tar.gz"
# Upload model.tar.gz from the current working directory; IPython substitutes
# $s3_location with the Python variable before the shell command runs.
!aws s3 cp model.tar.gz $s3_location

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel

# Describe the model to SageMaker: archive location, execution role and the
# framework versions of the Hugging Face inference container.
huggingface_model = HuggingFaceModel(
    model_data=s3_location,
    role=role,
    transformers_version="4.17",
    pytorch_version="1.10",
    py_version="py38",
    # Expose the Elasticsearch address to the container as an environment
    # variable so it is available to the inference code via os.environ
    # instead of having to be hard-coded into the script.
    env={"IP_ADDRESS_ES": str(IP_ADDRESS_ES)},
)

In [None]:
# Deploy the model to an endpoint served by a single ml.g4dn.xlarge instance.
predictor = huggingface_model.deploy(initial_instance_count=1, instance_type="ml.g4dn.xlarge")

## Testing the endpoint

In [None]:
# Request payload: the question plus per-node parameters for the pipeline.
data = {
    "inputs": "Who killed Tywin?",
    "parameters": {
        "Retriever": {"top_k": 10},
        "Reader": {"top_k": 3},
    },
}
res = predictor.predict(data=data)

In [None]:
import json
# The endpoint returns the prediction as a JSON string under "response";
# decode it back into Python objects.
result = json.loads(res["response"])

In [None]:
# Print every returned answer with its confidence, context and source document.
for r in result["answers"]:
    print(
        f"Answer: {r['answer']}\n"
        f"Confidence: {r['score']*100:.1f}%\n"
        f"Context: {r['context']}\n"
        f"Document: {r['meta']['name']}\n"
    )