# Flan-T5-XL demo

- [Flan-T5-XL model card](https://huggingface.co/google/flan-t5-xl)
- [Sample notebook this demo is based on](https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text2text-generation-flan-t5.ipynb)

## Retrieve the URI of the container image that runs Flan-T5

In [1]:
from sagemaker import image_uris


model_id, model_version = (
    "huggingface-text2text-flan-t5-xl",
    "*",  # "*" selects the latest available model version
)

inference_instance_type = "ml.p3.2xlarge"

deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type
)

deploy_image_uri

'763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04'
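The returned string is an Amazon ECR image URI, and its tag encodes the serving stack (here PyTorch 1.10.2, Transformers 4.17.0, GPU, Python 3.8, CUDA 11.3). As a minimal sketch, the hypothetical helper below (not part of the SageMaker SDK) splits such a URI into account, region, repository and tag, which makes it easy to check the region and GPU variant. It assumes a commercial-region URI of the form `<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>` and does not handle digest (`@sha256:`) references.

```python
import re
from typing import NamedTuple


class EcrImageUri(NamedTuple):
    account: str
    region: str
    repository: str
    tag: str


# Hypothetical pattern for <account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>
_ECR_PATTERN = re.compile(
    r"^(?P<account>\d{12})\.dkr\.ecr\.(?P<region>[a-z0-9-]+)\.amazonaws\.com/"
    r"(?P<repository>[^:]+):(?P<tag>.+)$"
)


def parse_ecr_image_uri(uri: str) -> EcrImageUri:
    """Split an ECR image URI into its parts; raise ValueError if it does not match."""
    match = _ECR_PATTERN.match(uri)
    if match is None:
        raise ValueError(f"not an ECR image URI: {uri!r}")
    return EcrImageUri(**match.groupdict())


parts = parse_ecr_image_uri(
    "763104351884.dkr.ecr.us-east-1.amazonaws.com/"
    "huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04"
)
print(parts.region)        # us-east-1
print("gpu" in parts.tag)  # True
```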

## Retrieve the URI of the Flan-T5 model artifacts
Use the same `model_id` and `model_version` as for the image URI.

In [3]:
from sagemaker import model_uris
from sagemaker.utils import name_from_base


# Retrieve the model uri.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)

model_uri

's3://jumpstart-cache-prod-us-east-1/huggingface-infer/infer-huggingface-text2text-flan-t5-xl.tar.gz'
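The model URI points to a `.tar.gz` archive in a JumpStart-managed S3 bucket, and this URI is what gets passed to `Model(model_data=...)`. If you want to inspect the object yourself (for example with the AWS CLI or boto3), you need the bucket and key separately. A minimal sketch using only the standard library; `split_s3_uri` is a hypothetical helper:

```python
from typing import Tuple
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Return (bucket, key) for an s3://bucket/key URI."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    # urlparse keeps the leading "/" in the path; S3 keys do not have it.
    return parsed.netloc, parsed.path.lstrip("/")


bucket, key = split_s3_uri(
    "s3://jumpstart-cache-prod-us-east-1/huggingface-infer/"
    "infer-huggingface-text2text-flan-t5-xl.tar.gz"
)
print(bucket)  # jumpstart-cache-prod-us-east-1
print(key)     # huggingface-infer/infer-huggingface-text2text-flan-t5-xl.tar.gz
```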

## Create the model
This only builds a `Model` object on the client side; nothing has been deployed yet.

In [4]:
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.session import Session


sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()

endpoint_name = name_from_base(f"flan-t5-{model_id}")

model = Model(
    image_uri=deploy_image_uri,
    model_data=model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)
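The base string passed to `name_from_base` above is 40 characters long, while SageMaker endpoint and model names are limited to 63 characters (alphanumerics and hyphens, starting and ending with an alphanumeric). To our understanding, `name_from_base` appends a millisecond timestamp and trims the base so the result fits. The sketch below is a hypothetical re-implementation for illustration only (`timestamped_name` and `is_valid_endpoint_name` are not SDK functions); it takes the time as a parameter so the output is deterministic.

```python
import re
from datetime import datetime

# Endpoint names: alphanumerics and hyphens, starting and ending with an alphanumeric.
_ENDPOINT_NAME_PATTERN = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}")
MAX_NAME_LENGTH = 63


def timestamped_name(base: str, now: datetime, max_length: int = MAX_NAME_LENGTH) -> str:
    """Hypothetical stand-in for name_from_base: append a timestamp, trimming the base to fit."""
    timestamp = now.strftime("%Y-%m-%d-%H-%M-%S-%f")[:-3]  # keep milliseconds only
    trimmed_base = base[: max_length - len(timestamp) - 1]  # -1 for the joining hyphen
    return f"{trimmed_base}-{timestamp}"


def is_valid_endpoint_name(name: str) -> bool:
    # The regex alone does not bound hyphens, so check the length explicitly.
    return len(name) <= MAX_NAME_LENGTH and _ENDPOINT_NAME_PATTERN.fullmatch(name) is not None


name = timestamped_name(
    "flan-t5-huggingface-text2text-flan-t5-xl",
    datetime(2023, 3, 1, 12, 34, 56, 789000),
)
print(name)       # flan-t5-huggingface-text2text-flan-t5-x-2023-03-01-12-34-56-789
print(len(name))  # 63
print(is_valid_endpoint_name(name))  # True
```

Note that the trailing `l` of `xl` is cut off: with a 23-character timestamp, only 39 characters of the base survive.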

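## Deploy the model
This creates an endpoint that can serve inference requests.
Deployment takes a while; the SDK prints `-` while waiting and `!` once the endpoint is ready.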

In [5]:
import json

import boto3


model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
    volume_size=30,
)

-----------!
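`model.deploy` above blocks until the endpoint is ready. If you would rather deploy without blocking (we believe `deploy` accepts `wait=False` for this), a minimal polling loop over the `DescribeEndpoint` API could look like the sketch below. `wait_for_endpoint` and its defaults are hypothetical; boto3 also provides a built-in `endpoint_in_service` waiter for the same purpose. The client is passed in so the loop can be exercised offline with a stand-in object.

```python
import time
from typing import Any, Callable, List


def wait_for_endpoint(
    sm_client: Any,
    endpoint_name: str,
    poll_seconds: float = 30.0,
    max_attempts: int = 60,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll DescribeEndpoint until the endpoint is InService.

    sm_client is expected to be boto3.client("sagemaker") or any object
    with a compatible describe_endpoint(EndpointName=...) method.
    """
    for _ in range(max_attempts):
        status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"endpoint {endpoint_name} failed to deploy")
        sleep(poll_seconds)
    raise TimeoutError(f"endpoint {endpoint_name} not InService after {max_attempts} checks")


class FakeClient:
    """Offline stand-in that returns a scripted sequence of statuses."""

    def __init__(self, statuses: List[str]) -> None:
        self._statuses = iter(statuses)

    def describe_endpoint(self, EndpointName: str) -> dict:
        return {"EndpointStatus": next(self._statuses)}


print(wait_for_endpoint(FakeClient(["Creating", "Creating", "InService"]), "demo",
                        sleep=lambda s: None))  # InService
# Against AWS (not run here): wait_for_endpoint(boto3.client("sagemaker"), endpoint_name)
```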

## Helper functions for calling the endpoint

In [6]:
def query_endpoint(encoded_text, endpoint_name):
    # Send the input text to the endpoint.
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/x-text", Body=encoded_text
    )
    return response


def parse_response(query_response):
    # Extract the generated text from the endpoint's JSON response.
    model_predictions = json.loads(query_response["Body"].read())
    generated_text = model_predictions["generated_text"]
    return generated_text


def generate_text(text):
    # Generate text with Flan-T5.
    query_response = query_endpoint(text.encode("utf-8"), endpoint_name=endpoint_name)
    generated_text = parse_response(query_response)
    return f"Input: {text}\nOutput: {generated_text}\n"

In [7]:
text1 = "Translate to Spanish:  I'm Ken."
text2 = "A step by step recipe to make bolognese pasta"
text3 = "元気ですか？"
text4 = """
Review: This moive is so great and once again dazzles and delights us
this movie review sentence negative or positive?"""


for text in [text1, text2, text3, text4]:
    print(generate_text(text))


Input: Translate to Spanish:  I'm Ken.
Output: Soy Ken.

Input: A step by step recipe to make bolognese pasta
Output: Step 1: Preheat oven to 375 degrees F. Place ground beef in a large

Input: 元気ですか？
Output: ?

Input: 
Review: This moive is so great and once again dazzles and delights us
this movie review sentence negative or positive?
Output: positive
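The bolognese recipe above stops mid-sentence, which suggests the container's default generation length is short. The reference notebook sends a JSON body with generation parameters instead of plain text. The sketch below follows our reading of that notebook and should be treated as an assumption to verify against your model version: `ContentType="application/json"`, a `"text_inputs"` key plus parameters such as `"max_length"`, and a `"generated_texts"` list in the response. The helper names and the default of 200 are hypothetical; the parsing is exercised offline with a fake response shaped like the `invoke_endpoint` return value.

```python
import io
import json
from typing import Any, Dict, List


def build_json_payload(text: str, max_length: int = 200, num_return_sequences: int = 1) -> bytes:
    """Encode a request body with generation parameters (assumed key names)."""
    payload: Dict[str, Any] = {
        "text_inputs": text,
        "max_length": max_length,
        "num_return_sequences": num_return_sequences,
    }
    return json.dumps(payload).encode("utf-8")


def parse_multiple_texts(query_response: Dict[str, Any]) -> List[str]:
    """Return the list under "generated_texts" (assumed response key for JSON requests)."""
    model_predictions = json.loads(query_response["Body"].read())
    return model_predictions["generated_texts"]


# Offline check: io.BytesIO stands in for the streaming Body of a real response.
fake_response = {"Body": io.BytesIO(b'{"generated_texts": ["Step 1: ...", "Step 2: ..."]}')}
print(parse_multiple_texts(fake_response))  # ['Step 1: ...', 'Step 2: ...']

# Against the deployed endpoint (not run here):
# client = boto3.client("runtime.sagemaker")
# response = client.invoke_endpoint(
#     EndpointName=endpoint_name,
#     ContentType="application/json",
#     Body=build_json_payload("A step by step recipe to make bolognese pasta"),
# )
# print(parse_multiple_texts(response))
```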



## Delete the model and endpoint

In [8]:
model_predictor.delete_model()
model_predictor.delete_endpoint()
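A running endpoint keeps incurring instance charges until it is deleted, so if a cell fails between deployment and cleanup, the endpoint can be left behind. One way to guard against that is to tie deletion to a `try`/`finally` block. The context manager below is a hypothetical helper, not an SDK feature; it only relies on `deploy`, `delete_model` and `delete_endpoint`, the same calls this notebook already uses. It is demonstrated with stand-in objects so it runs without AWS.

```python
from contextlib import contextmanager
from typing import Any, Iterator, List


@contextmanager
def deployed(model: Any, **deploy_kwargs: Any) -> Iterator[Any]:
    """Deploy a model, yield the predictor, and always delete the model and endpoint afterwards."""
    predictor = model.deploy(**deploy_kwargs)
    try:
        yield predictor
    finally:
        predictor.delete_model()
        predictor.delete_endpoint()


# Offline demonstration with stand-in objects (not SageMaker classes).
class _FakePredictor:
    def __init__(self, log: List[str]) -> None:
        self.log = log

    def delete_model(self) -> None:
        self.log.append("delete_model")

    def delete_endpoint(self) -> None:
        self.log.append("delete_endpoint")


class _FakeModel:
    def __init__(self) -> None:
        self.log: List[str] = []

    def deploy(self, **kwargs: Any) -> _FakePredictor:
        self.log.append("deploy")
        return _FakePredictor(self.log)


fake_model = _FakeModel()
try:
    with deployed(fake_model, initial_instance_count=1):
        raise RuntimeError("simulated failure during inference")
except RuntimeError:
    pass
print(fake_model.log)  # ['deploy', 'delete_model', 'delete_endpoint']

# With the real model (not run here):
# with deployed(model, initial_instance_count=1, instance_type=inference_instance_type,
#               predictor_cls=Predictor, endpoint_name=endpoint_name):
#     print(generate_text("Translate to Spanish:  I'm Ken."))
```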