## Deploying Hugging Face Models with SageMaker

This notebook covers:
1. Deploying a Hugging Face model to a SageMaker inference endpoint
2. Running inference against the endpoint in two ways

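### Prerequisites
0. Request service quota for the instance type you need. Typical requirements are roughly:
    * 7B: ml.g5.2xlarge
    * 13B: ml.g5.12xlarge
    * 70B: ml.g5.24xlarge or larger
1. Create a SageMaker notebook instance (for example, ml.t3.medium)
2. Copy this notebook to the instance and open it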
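
The instance sizes above follow from how much GPU memory the model weights need. As a rough sketch (the helper name `weight_memory_gib` is made up for illustration, and the estimate ignores the KV cache, activations, and runtime overhead), weight memory is parameter count times bits per parameter. Each A10G GPU on a g5 instance has 24 GB of memory; ml.g5.2xlarge has one GPU, while ml.g5.12xlarge and ml.g5.24xlarge have four.

```python
# Rough weight-memory estimate only; real deployments need extra headroom
# for the KV cache, activations, and the serving runtime.
GIB = 1024 ** 3


def weight_memory_gib(num_params_billion: float, bits_per_param: int) -> float:
    """Approximate memory needed for the model weights alone, in GiB."""
    total_bytes = num_params_billion * 1e9 * bits_per_param / 8
    return total_bytes / GIB


for size in (7, 13, 70):
    fp16 = weight_memory_gib(size, 16)
    int4 = weight_memory_gib(size, 4)
    print(f"{size}B: fp16 ~ {fp16:.1f} GiB, 4-bit ~ {int4:.1f} GiB")
```

By this estimate a 70B model in fp16 does not fit in the memory of four A10G GPUs, which is one reason to pick a 4-bit quantized checkpoint for that size.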

### Deploy the model

In [None]:
import json
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

# Use the notebook's execution role; if that fails, look the role up by name.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID': 'TheBloke/Airoboros-L2-70B-3.1.2-GPTQ',
    'SM_NUM_GPUS': json.dumps(4),
    'QUANTIZE': 'gptq',  # set according to the model's configuration
}

# Create the Hugging Face model class
huggingface_model = HuggingFaceModel(
    image_uri=get_huggingface_llm_image_uri("huggingface", version="1.4.2"),
    env=hub,
    role=role,
)

# Deploy the model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.24xlarge",  # make sure your quota covers this instance type
    container_startup_health_check_timeout=300,
)


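The container reads its settings from environment variables, so every value in the `hub` dictionary has to be a string (which is why the GPU count goes through `json.dumps`). A minimal sketch of building that dictionary for different models follows; the helper name `build_tgi_env` is hypothetical, and it only uses the three variables that appear in this notebook.

```python
import json
from typing import Dict, Optional


def build_tgi_env(model_id: str, num_gpus: int, quantize: Optional[str] = None) -> Dict[str, str]:
    """Build the container environment; all values must be strings."""
    env = {
        "HF_MODEL_ID": model_id,
        "SM_NUM_GPUS": json.dumps(num_gpus),  # the integer 4 becomes the string "4"
    }
    if quantize is not None:
        # Only set QUANTIZE for quantized checkpoints such as GPTQ models.
        env["QUANTIZE"] = quantize
    return env


hub = build_tgi_env("TheBloke/Airoboros-L2-70B-3.1.2-GPTQ", num_gpus=4, quantize="gptq")
print(hub)
```

The result can be passed as `env=hub` to `HuggingFaceModel` in the same way as the hand-written dictionary.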

### Inference

#### Using the predictor

In [None]:
# send request
predictor.predict({
	"inputs": "Hey my name is Julien! How are you?",
})
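
`predictor.predict` returns the parsed JSON response. Assuming the container replies in the usual Text Generation Inference shape, a list holding one dictionary with a `generated_text` key (this notebook does not show the output, so treat the shape as an assumption), a hypothetical helper such as `extract_generated_text` can pull out the text and fail loudly on anything unexpected.

```python
def extract_generated_text(response):
    """Return the text from a response shaped like [{"generated_text": "..."}].

    The response shape is an assumption; print the raw response first if unsure.
    """
    if (
        isinstance(response, list)
        and response
        and isinstance(response[0], dict)
        and "generated_text" in response[0]
    ):
        return response[0]["generated_text"]
    raise ValueError(f"Unexpected response shape: {response!r}")


# Example with a hand-written response, no endpoint required.
sample = [{"generated_text": "Hey Julien, I'm doing well!"}]
print(extract_generated_text(sample))
```

With a live endpoint this would be called as `extract_generated_text(predictor.predict({"inputs": "..."}))`.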

#### Calling the endpoint through the SageMaker runtime

In [None]:
import boto3
import json

client = boto3.client('sagemaker-runtime')

# SageMaker endpoint name; replace with your own.
endpoint_name = "huggingface-pytorch-tgi-inference-2024-03-29-08-47-13-951"
content_type = "application/json"  # MIME type of the input data in the request body
accept = "application/json"        # desired MIME type of the inference in the response
# Payload for inference
payload = """
{
  "inputs": "Tell me a story about a cowboy",
  "parameters": {
    "temperature": 0.7,
    "max_new_tokens": 1024,
    "top_p": 0.95,
    "top_k": 40
  }
}
"""
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=content_type,
    Accept=accept,
    Body=bytes(payload, 'UTF-8'),
)

result = json.loads(response['Body'].read().decode('UTF-8'))

print(result)
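
Writing the payload as a raw string works, but quoting mistakes are easy to make once the prompt contains quotes or newlines. A minimal sketch of the same call with the body built by `json.dumps` is shown here; `build_payload` and `invoke` are hypothetical helper names, and `client` is expected to be a `boto3.client('sagemaker-runtime')` object created by the caller.

```python
import json


def build_payload(prompt, **parameters):
    """Serialize the prompt and optional generation parameters to JSON bytes."""
    body = {"inputs": prompt}
    if parameters:
        body["parameters"] = parameters
    return json.dumps(body).encode("utf-8")


def invoke(client, endpoint_name, prompt, **parameters):
    """Call invoke_endpoint on the given client and return the decoded JSON response."""
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=build_payload(prompt, **parameters),
    )
    return json.loads(response["Body"].read().decode("utf-8"))


# Building the payload does not need an endpoint.
print(build_payload('Tell me a "short" story', temperature=0.7, max_new_tokens=1024))
```

Usage against a real endpoint would look like `invoke(client, endpoint_name, "Tell me a story", temperature=0.7, top_p=0.95, top_k=40)`.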