# Hugging Faceで公開されている大規模言語モデルをSageMakerにデプロイ


* 対象モデル
  
  stabilityai/japanese-stablelm-base-alpha-7b
  
  https://huggingface.co/stabilityai/japanese-stablelm-base-alpha-7b


### SageMaker Python SDKのインストール

In [None]:
# Install or upgrade the SageMaker Python SDK into the environment of the running kernel.
# NOTE: a kernel restart may be needed before the upgraded package is picked up.
%pip install sagemaker --upgrade


### インポート

In [None]:
import sagemaker
import boto3


### IAMロールの取得

In [None]:
# Resolve the IAM role ARN that SageMaker uses to create the model and endpoint.
# get_execution_role() is tried first; when it cannot determine a role it
# raises ValueError, and we fall back to looking up a named role through IAM.
# The fallback role name is specific to one AWS account, so it can be
# overridden with the SAGEMAKER_ROLE_NAME environment variable instead of
# editing this cell. The hard-coded name remains the default.
import os

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    # Role with the `AmazonSageMakerFullAccess` policy attached.
    role_name = os.environ.get(
        'SAGEMAKER_ROLE_NAME',
        'AmazonSageMaker-ExecutionRole-20230617T201891',
    )
    role = iam.get_role(RoleName=role_name)['Role']['Arn']


### モデル名などのパラメーターを指定

In [None]:
# Hugging Face Hub repository of the model that will be deployed.
model_id = "stabilityai/japanese-stablelm-base-alpha-7b"

# SageMaker instance type that will host the inference endpoint.
instance_type = "ml.g5.2xlarge"


### SageMakerへデプロイ

このモデルはリポジトリ内の独自のモデルコードを使用するため、環境変数`HF_MODEL_TRUST_REMOTE_CODE`に`'true'`を指定する必要があります。

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel

# Environment passed to the Hugging Face inference container. It names the
# Hub model to load and the pipeline task to serve. It also sets the
# trust-remote-code flag, which this model requires.
# (Model catalogue: <https://huggingface.co/models>)
hub = dict(
    HF_MODEL_ID=model_id,
    HF_TASK='text-generation',
    HF_MODEL_TRUST_REMOTE_CODE='true',
)

# Model definition: the execution role plus the framework versions
# (Transformers / PyTorch / Python) of the serving container.
huggingface_model = HuggingFaceModel(
    role=role,
    env=hub,
    py_version='py310',
    pytorch_version='2.0',
    transformers_version='4.28',
)

# Stand up a real-time endpoint backed by a single instance of the chosen type.
predictor = huggingface_model.deploy(
    instance_type=instance_type,
    initial_instance_count=1,
)


### 推論

In [None]:
# Prompt for the model to continue ("To accelerate scientific research with AI, ...").
prompt = """
AI で科学研究を加速するには、
""".strip()

# Generation settings sent alongside the prompt: sampling with nucleus
# (top-p) filtering, capped at 128 newly generated tokens.
generation_params = {
    'max_new_tokens': 128,
    'temperature': 1,
    'top_p': 0.95,
    'do_sample': True,
}
data = {'inputs': prompt, 'parameters': generation_params}

# Invoke the endpoint; the response is displayed as the cell output.
result = predictor.predict(data)

result


### エンドポイントとモデルの削除

In [None]:
# Tear down everything deploy() created so the endpoint does not keep running.
# Delete the model first: Predictor.delete_model() discovers the model names
# by describing the endpoint and its endpoint configuration, so both must
# still exist when it runs. delete_endpoint() then removes the endpoint
# configuration (its default, delete_endpoint_config=True) and the endpoint
# itself, so no orphaned endpoint configuration is left behind.
predictor.delete_model()
predictor.delete_endpoint()
