# Hugging Faceで公開されている大規模言語モデルをSageMakerにデプロイ


* 対象モデル
  
  line-corporation/japanese-large-lm-3.6b-instruction-sft
  
  https://huggingface.co/line-corporation/japanese-large-lm-3.6b-instruction-sft


### SageMakerライブラリーのインストール

In [None]:
# Install (or upgrade to the latest) SageMaker Python SDK for this notebook.
# NOTE(review): the inference cell further down imports `transformers`, which
# this command does not install — confirm it is already available in the kernel.
%pip install sagemaker --upgrade


### インポート

In [None]:
import sagemaker
import boto3


### IAMロールの取得

In [None]:
# Resolve the ARN of the IAM role that is passed to the Hugging Face model below.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    # get_execution_role() raised ValueError (presumably because the notebook
    # is not running with a SageMaker execution role — TODO confirm), so fall
    # back to looking the role up by name through IAM.
    # The named role must have the `AmazonSageMakerFullAccess` policy attached.
    role_name = 'AmazonSageMaker-ExecutionRole'
    iam = boto3.client('iam')
    role = iam.get_role(RoleName=role_name)['Role']['Arn']


### モデル名などのパラメーターを指定

In [None]:
# Instance type on which the endpoint is deployed.
instance_type = 'ml.g5.2xlarge'

# Hugging Face Hub ID of the model to deploy; also used to load the tokenizer.
model_id = 'line-corporation/japanese-large-lm-3.6b-instruction-sft'


### SageMakerへデプロイ

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel

# Configuration passed to the model as `env`: which Hub model
# (https://huggingface.co/models) to load and which task to serve.
hub = {
    'HF_MODEL_ID': model_id,
    'HF_TASK': 'text-generation',
}

# Build the Hugging Face model definition: the IAM role used to create the
# endpoint plus the Python / PyTorch / Transformers versions to run with.
huggingface_model = HuggingFaceModel(
    role=role,
    env=hub,
    py_version='py39',
    pytorch_version='1.13',
    transformers_version='4.26',
)

# Deploy to SageMaker Inference on a single instance and keep the returned
# predictor, which the later cells use for inference and cleanup.
predictor = huggingface_model.deploy(
    instance_type=instance_type,
    initial_instance_count=1,
)


### 推論

In [None]:
from transformers import AutoTokenizer

# The tokenizer is loaded locally (non-fast implementation); in this cell it
# is only used to look up the padding token ID for the request parameters.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)

# Question to send to the model.
input_text = '''四国の県名を全て列挙してください。'''

# Generation settings sent alongside the prompt.
generation_params = {
    'max_length': 1024,
    'do_sample': True,
    'temperature': 0.7,
    'top_p': 0.9,
    'top_k': 0,
    'repetition_penalty': 1.1,
    'num_beams': 1,
    'pad_token_id': tokenizer.pad_token_id,
    'num_return_sequences': 1,
}

# Request payload: the question wrapped in the user/system prompt layout,
# plus the generation settings.
data = {
    'inputs': f'ユーザー: {input_text}\nシステム: ',
    'parameters': generation_params,
}

# Send the request to the deployed endpoint.
result = predictor.predict(data)

# Bare expression so the notebook displays the response as the cell output.
result


### エンドポイントとモデルの削除

In [None]:
# Clean up the resources created by the deployment: the model, the endpoint,
# and the endpoint configuration.
#
# The model is deleted first, while the endpoint and its configuration still
# exist, because the predictor resolves the model name(s) through them
# (per the SageMaker Python SDK Predictor implementation — verify against the
# installed SDK version).
predictor.delete_model()

# Also delete the endpoint configuration; leaving it behind
# (delete_endpoint_config=False) would keep an orphaned resource in the account.
predictor.delete_endpoint(delete_endpoint_config=True)
