In [1]:
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json
from pathlib import Path
from sagemaker.utils import name_from_base

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
default_bucket = sess.default_bucket()  # bucket to house artifacts
region = sess._region_name

In [2]:
code_tarname = 'acc_longchat13b_model'

In [3]:
!git clone https://github.com/lm-sys/FastChat.git {code_tarname}

Cloning into 'acc_longchat13b_model'...
remote: Enumerating objects: 3761, done.[K
remote: Counting objects: 100% (154/154), done.[K
remote: Compressing objects: 100% (96/96), done.[K
remote: Total 3761 (delta 87), reused 104 (delta 56), pack-reused 3607[K
Receiving objects: 100% (3761/3761), 30.47 MiB | 46.64 MiB/s, done.
Resolving deltas: 100% (2620/2620), done.


In [None]:
!cp model.py {code_tarname}
!cp requirements.txt {code_tarname}
!cp serving.properties {code_tarname}

!rm -rf {code_tarname}.tar.gz
!rm -rf {code_tarname}/.ipynb_checkpoints
!tar czvf {code_tarname}.tar.gz {code_tarname}/

# copy the deployment configs tar to a path (different from hf model artifacts)
s3_code_artifact = sess.upload_data(f"{code_tarname}.tar.gz",
                                    default_bucket,
                                    sagemaker.utils.name_from_base("tmp/v1") # generate random path
                                   )
print(s3_code_artifact)

In [4]:
from sagemaker.model import Model

# specify a inference container version, found at: 
# https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers
inference_image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.22.1-deepspeed0.9.2-cu118"

model = Model(image_uri=inference_image_uri,
              model_data=s3_code_artifact,
              role=role)

endpoint_name = sagemaker.utils.name_from_base(code_tarname.replace('_','-'))

model.deploy(initial_instance_count = 1,
             instance_type = 'ml.g5.12xlarge',
             endpoint_name = endpoint_name,
             container_startup_health_check_timeout = 1200
            )

----------------------!

## Predict

In [5]:
from sagemaker import serializers, deserializers

# endpoint_name = 'acc-longchat7b-model-2023-07-28-06-49-46-747'

predictor = sagemaker.Predictor(
            endpoint_name=endpoint_name,
            sagemaker_session=sess,
            serializer=serializers.JSONSerializer(),
            deserializer=deserializers.JSONDeserializer(),
            )

In [9]:
predictor.predict(
    {"inputs": "tuna sandwich nutritional content is ",
     "parameters": {"do_sample": True, "temperature": 0.7, "repetition_penalty":1.1, "max_new_tokens": 200}
    }
)

{'generated_text': '1079 calories, 14g protein, 6g carbohydrates, and 3.5g fat. The sandwich consists of canned tuna, whole grain bread, lettuce, tomatoes, onions, and mayonnaise or mustard. Tuna is a good source of protein and omega-3 fatty acids, while whole grain bread provides fiber and other nutrients. Lettuce, tomatoes, onions, and mayonnaise or mustard add additional vitamins and minerals to the meal.'}

In [10]:
predictor.predict(
    {"inputs": '''A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user's input
             USER: how to make a good pizza? ASSISTANT:''', 
     "parameters": {"temperature": 0.7, "repetition_penalty":1.2, "max_new_tokens": 200}
     }
)

{'generated_text': 'Making a great pizza requires a combination of quality ingredients, proper preparation techniques, and attention to detail. Here is a step-by-step guide on how to create a delicious homemade pizza:\n\nIngredients:\n\n* 1 large pizza dough (store-bought or homemade)\n* Freshly grated mozzarella cheese\n* Shredded lettuce\n* Sliced tomatoes\n* Chopped olives (optional)\n* Ground beef or sausage (optional)\n* Your favorite sauce (homemade or store-bought)\n* Olive oil\n* Salt and pepper\n* Optional: fresh basil leaves for garnish\n\nInstructions:\n\n1. Preheat your oven to the highest temperature possible (usually around 500°F / 260°C'}

In [11]:
%%timeit -n3 -r1
predictor.predict(
    {"inputs": "what nutritional contents is in a tuna sandwich", 
     "parameters": {"do_sample": True, "min_length":100, "max_length":100}
     }
)

6.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 3 loops each)
