In [None]:
!pip install huggingface_hub
!pip install -Uq pip
!pip install -Uq boto3 sagemaker

In [None]:
import json
import os

import boto3
import sagemaker
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

from utils.LineIterator import LineIterator

# Resolve the IAM role ARN the SageMaker model/endpoint will run under.
try:
    # Succeeds when the notebook can infer its own SageMaker execution role.
    role = sagemaker.get_execution_role()
except ValueError:
    # Fallback when no role can be inferred: look up a role with a fixed name via IAM.
    iam = boto3.client(service_name='iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Environment variables for the Hugging Face LLM (TGI) serving container.
# MAX_INPUT_LENGTH (prompt tokens) must be strictly smaller than MAX_TOTAL_TOKENS
# (prompt + generated tokens). The TGI 1.x launcher defaults MAX_TOTAL_TOKENS to 2048
# and refuses to start when MAX_INPUT_LENGTH >= MAX_TOTAL_TOKENS, so both are set here.
# MAX_BATCH_TOTAL_TOKENS may additionally be set to limit how many tokens are
# processed in parallel during generation.
hf_token = os.environ.get('HUGGING_FACE_HUB_TOKEN')

hub = {
    'HF_MODEL_ID': 'DeepvizLab/newton-cot',
    'SM_NUM_GPUS': json.dumps(1),
    'MAX_INPUT_LENGTH': json.dumps(2048),  # max length of the input text, in tokens
    'MAX_TOTAL_TOKENS': json.dumps(4096),  # max input + generated tokens per request
}

# Never hard-code the Hub token in the notebook: read it from the environment.
# It is only required for private/gated models, so it is omitted when unset.
if hf_token:
    hub['HUGGING_FACE_HUB_TOKEN'] = hf_token

# Wrap the model in a SageMaker HuggingFaceModel that uses the Hugging Face LLM
# container image (backend "huggingface", version 1.4.2) configured by `hub`.
huggingface_model = HuggingFaceModel(
    role=role,
    env=hub,
    image_uri=get_huggingface_llm_image_uri(
        "huggingface",
        version="1.4.2",
    ),
)

# Deploy the model to a real-time SageMaker endpoint on one ml.g5.2xlarge instance.
# container_startup_health_check_timeout is in seconds; the generous value gives the
# container time to fetch and load the model before startup health checks fail.
predictor = huggingface_model.deploy(
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
    container_startup_health_check_timeout=1000,
)


In [19]:
# Endpoint to query. Reuse the endpoint created by the deployment cell when it has run
# in this kernel; otherwise replace the placeholder with an existing endpoint's name.
# (The previous placeholder nested single quotes inside double quotes, so the
# quotes would have ended up in the endpoint name.)
endpoint_name = (
    predictor.endpoint_name
    if 'predictor' in globals()
    else '<REPLACE WITH YOUR ENDPOINT NAME>'
)

In [20]:
def get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload):
    """Invoke a SageMaker endpoint with a streaming response.

    Args:
        sagemaker_runtime: a client exposing ``invoke_endpoint_with_response_stream``
            (e.g. ``boto3.client('sagemaker-runtime')``).
        endpoint_name: name of the endpoint to call.
        payload: JSON-serialisable request body.

    Returns:
        Whatever ``invoke_endpoint_with_response_stream`` returns, unchanged.
    """
    request = {
        "EndpointName": endpoint_name,
        "Body": json.dumps(payload),
        "ContentType": "application/json",
        "CustomAttributes": 'accept_eula=false',
    }
    return sagemaker_runtime.invoke_endpoint_with_response_stream(**request)

def print_response_stream(response_stream):
    """Print generated tokens from a streaming endpoint response as they arrive.

    Args:
        response_stream: the response dict returned by
            ``invoke_endpoint_with_response_stream``; its ``'Body'`` event stream is
            split into lines by ``LineIterator`` (assumed to yield ``bytes`` lines).

    Each line carrying a payload holds one JSON object, possibly after a prefix such
    as ``data:``, so parsing starts at the first ``{``. Lines without a ``{`` are
    skipped. Tokens whose text is the ``</s>`` end-of-sequence marker are not printed.

    Raises:
        RuntimeError: if a streamed event has no ``token`` (e.g. the endpoint
            reported an error mid-stream).
    """
    event_stream = response_stream['Body']
    start_json = b'{'
    stop_token = '</s>'
    for line in LineIterator(event_stream):
        start = line.find(start_json)
        if start == -1:
            # Blank lines and anything without a JSON payload.
            continue
        data = json.loads(line[start:].decode('utf-8'))
        token = data.get('token')
        if token is None:
            # Surface the server's message instead of an opaque KeyError.
            raise RuntimeError(f"Unexpected event from endpoint (no 'token'): {data}")
        text = token['text']
        if text != stop_token:
            # flush so each token shows up immediately instead of when the buffer fills.
            print(text, end='', flush=True)
    # End the streamed text with a newline so later output starts on a fresh line.
    print()

def build_prompt(text, system_prompt="You have to reply to the user requests"):
    """Wrap a user message in the Llama-2 chat prompt format.

    The previous triple-quoted template leaked a leading newline, indentation spaces
    and trailing whitespace after ``[/INST]`` into the prompt. This version emits the
    template exactly, with no stray whitespace.

    Args:
        text: the user request.
        system_prompt: instructions placed in the ``<<SYS>>`` section.

    Returns:
        ``<s>[INST] <<SYS>>\\n{system_prompt}\\n<</SYS>>\\n\\n{text} [/INST]``
    """
    return f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{text} [/INST]"

In [None]:
# Runtime client used to invoke the deployed endpoint, and the user request to send.
sagemaker_runtime = boto3.client('sagemaker-runtime')
prompt = "Tell me something about the Roman Empire"

# Text-generation parameters forwarded to the serving container.
inference_params = dict(
    do_sample=True,          # sample rather than decode greedily
    top_p=0.6,
    temperature=0.9,
    top_k=50,
    max_new_tokens=512,
    repetition_penalty=1.03,
    stop=["</s>"],           # stop at the end-of-sequence marker
    return_full_text=False,  # return only the generated text, not the prompt
)

# Request body: the formatted prompt, generation parameters, and a flag asking the
# container to stream tokens back instead of returning one final response.
payload = dict(
    inputs=build_prompt(prompt),
    parameters=inference_params,
    stream=True,
)

# Send the request and print tokens as they arrive.
resp = get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload)
print_response_stream(resp)