# Training and Deploying a Binary Classifier on the IMDB Dataset

This notebook is based on the [official Hugging Face sample notebook](https://github.com/huggingface/notebooks/blob/main/sagemaker/01_getting_started_pytorch/sagemaker-notebook.ipynb).

For more details, see the [Hugging Face documentation](https://huggingface.co/docs/sagemaker/getting-started).

Before running it, install the dependencies listed in `pyproject.toml`.

## Checking Permissions

If you run this notebook locally, first follow [this guide](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) to create an IAM role with `AmazonSageMakerFullAccess` attached. Then specify that role's name below (`sagemaker-local` in this example).
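If the role does not exist yet, it can also be created from code instead of the console. The sketch below is an alternative to the guide, not part of the original notebook. It assumes the role name `sagemaker-local` and uses the standard trust policy that lets the SageMaker service assume the role. The IAM client is passed in as an argument, so the logic can be checked without AWS credentials.

```python
import json

# AWS managed policy referenced in the guide above
SAGEMAKER_FULL_ACCESS_ARN = "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess"


def sagemaker_trust_policy() -> dict:
    """Trust policy that allows the SageMaker service to assume the role."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": "sagemaker.amazonaws.com"},
                "Action": "sts:AssumeRole",
            }
        ],
    }


def create_sagemaker_role(iam_client, role_name: str = "sagemaker-local") -> str:
    """Create the role, attach AmazonSageMakerFullAccess, and return the role ARN."""
    response = iam_client.create_role(
        RoleName=role_name,
        AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy()),
    )
    iam_client.attach_role_policy(RoleName=role_name, PolicyArn=SAGEMAKER_FULL_ACCESS_ARN)
    return response["Role"]["Arn"]
```

Usage would be `create_sagemaker_role(boto3.client("iam"))`. The function needs to run only once per account.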

In [3]:
# When running this notebook on SageMaker Studio, run the following instead of this cell:
#
# import sagemaker
#
# role = sagemaker.get_execution_role()

import boto3

role_name = "sagemaker-local"

iam = boto3.client("iam")
role = iam.get_role(RoleName=role_name)["Role"]["Arn"]


In [2]:
import sagemaker

sess = sagemaker.Session()


# Preprocessing

Load the IMDB dataset with the `datasets` library and preprocess it for training.

In [4]:
from datasets import load_dataset
from transformers import AutoTokenizer

# tokenizer used in preprocessing
tokenizer_name = "distilbert-base-uncased"

# dataset used
dataset_name = "imdb"

# s3 key prefix for the data
s3_prefix = "samples/datasets/imdb"


In [5]:
# load dataset
train_dataset, test_dataset = load_dataset(dataset_name, split=["train", "test"])
# shrink the test dataset to 10k randomly chosen samples
test_dataset = test_dataset.shuffle().select(range(10000))

print(train_dataset)
print(test_dataset)


Reusing dataset imdb (/Users/lisa/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)



Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 10000
})


In [6]:
train_dataset[0]


{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

In [7]:
# download tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# tokenizer helper function
def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True)


# tokenize dataset
train_dataset = train_dataset.map(tokenize, batched=True)
test_dataset = test_dataset.map(tokenize, batched=True)

train_dataset[0]



{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be
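With `padding="max_length", truncation=True`, every review becomes a fixed-length sequence. The pure-Python sketch below is a simplified model of that behaviour for a single sequence, not the tokenizer's implementation. It assumes the `distilbert-base-uncased` special-token IDs `[CLS]=101`, `[SEP]=102`, and `[PAD]=0`, which is consistent with the leading `101` in the next cell's output. It also assumes the model's maximum length is 512. The word-piece IDs in the usage line are arbitrary.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed distilbert-base-uncased special tokens
MAX_LENGTH = 512                      # assumed model maximum sequence length


def pad_and_truncate(token_ids, max_length=MAX_LENGTH):
    """Wrap word-piece IDs in [CLS]/[SEP], truncate to fit, and pad to max_length."""
    # Reserve two positions for the special tokens, then truncate the content
    content = list(token_ids)[: max_length - 2]
    input_ids = [CLS_ID] + content + [SEP_ID]
    # 1 marks real tokens, 0 marks padding the model should ignore
    attention_mask = [1] * len(input_ids)
    padding = max_length - len(input_ids)
    input_ids += [PAD_ID] * padding
    attention_mask += [0] * padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


print(pad_and_truncate([1045, 2293], max_length=6))
# {'input_ids': [101, 1045, 2293, 102, 0, 0], 'attention_mask': [1, 1, 1, 1, 0, 0]}
```

Padding every example to the same length keeps batches rectangular. The cost is extra computation on padding tokens, which `attention_mask` tells the model to ignore.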

In [8]:
# set format for pytorch
train_dataset = train_dataset.rename_column("label", "labels")
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
test_dataset = test_dataset.rename_column("label", "labels")
test_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

train_dataset[0]


{'labels': tensor(0),
 'input_ids': tensor([  101,  1045, 12524,  1045,  2572,  8025,  1011,  3756,  2013,  2026,
          2678,  3573,  2138,  1997,  2035,  1996,  6704,  2008,  5129,  2009,
          2043,  2009,  2001,  2034,  2207,  1999,  3476,  1012,  1045,  2036,
          2657,  2008,  2012,  2034,  2009,  2001,  8243,  2011,  1057,  1012,
          1055,  1012,  8205,  2065,  2009,  2412,  2699,  2000,  4607,  2023,
          2406,  1010,  3568,  2108,  1037,  5470,  1997,  3152,  2641,  1000,
          6801,  1000,  1045,  2428,  2018,  2000,  2156,  2023,  2005,  2870,
          1012,  1026,  7987,  1013,  1028,  1026,  7987,  1013,  1028,  1996,
          5436,  2003,  8857,  2105,  1037,  2402,  4467,  3689,  3076,  2315,
         14229,  2040,  4122,  2000,  4553,  2673,  2016,  2064,  2055,  2166,
          1012,  1999,  3327,  2016,  4122,  2000,  3579,  2014,  3086,  2015,
          2000,  2437,  2070,  4066,  1997,  4516,  2006,  2054,  1996,  2779,
         25430, 1

In [9]:
train_dataset


Dataset({
    features: ['text', 'labels', 'input_ids', 'attention_mask'],
    num_rows: 25000
})

## Uploading the Preprocessed Data to S3

In [10]:
from datasets.filesystems import S3FileSystem

s3 = S3FileSystem()

# save train_dataset to s3
training_input_path = f"s3://{sess.default_bucket()}/{s3_prefix}/train"
train_dataset.save_to_disk(training_input_path, fs=s3)

# save test_dataset to s3
test_input_path = f"s3://{sess.default_bucket()}/{s3_prefix}/test"
test_dataset.save_to_disk(test_input_path, fs=s3)
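In the next section, these two S3 prefixes are passed to `fit` as the `train` and `test` channels, and the `hyperparameters` dict is passed to the script as command-line flags. The training script `./smhf/train.py` is not shown in this notebook. The sketch below is a hypothetical version of its argument handling, assuming SageMaker's script-mode conventions. Under those conventions, each channel is copied to a local directory exposed as `SM_CHANNEL_<NAME>`, and anything written to `SM_MODEL_DIR` is packaged into `model.tar.gz`. The fallback paths are assumptions for running outside SageMaker.

```python
import argparse
import os


def parse_args(argv=None, env=None):
    """Hypothetical argument handling for smhf/train.py (not the author's script)."""
    env = os.environ if env is None else env
    parser = argparse.ArgumentParser()
    # Hyperparameters arrive as flags, e.g. --epochs 1 --train_batch_size 32
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--model_name", type=str, default="distilbert-base-uncased")
    # Local copies of the "train"/"test" channels and the model output directory
    parser.add_argument("--training_dir", type=str,
                        default=env.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument("--test_dir", type=str,
                        default=env.get("SM_CHANNEL_TEST", "/opt/ml/input/data/test"))
    parser.add_argument("--model_dir", type=str,
                        default=env.get("SM_MODEL_DIR", "/opt/ml/model"))
    return parser.parse_args(argv)
```

Inside such a script, the datasets saved above could then be reloaded with `datasets.load_from_disk(args.training_dir)` and `load_from_disk(args.test_dir)`.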


## Running Training and Uploading the Model and Results to S3

In [12]:
from sagemaker.huggingface import HuggingFace

# hyperparameters, which are passed into the training job
hyperparameters = {
    "epochs": 1,
    "train_batch_size": 32,
    "model_name": "distilbert-base-uncased",
}

# parse metrics in the training log
# https://huggingface.co/docs/sagemaker/train#sagemaker-metrics
metric_definitions = [
    {"Name": "loss", "Regex": "'loss': (.*?),"}
]

huggingface_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="./smhf",
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    role=role,
    transformers_version="4.12",
    pytorch_version="1.9",
    py_version="py38",
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
)

# starting the train job with our uploaded datasets as input
huggingface_estimator.fit({"train": training_input_path, "test": test_input_path})


2022-04-16 10:10:53 Starting - Starting the training job...
2022-04-16 10:11:20 Starting - Preparing the instances for training
ProfilerReport-1650103853: InProgress
.........
2022-04-16 10:12:43 Downloading - Downloading input data...
2022-04-16 10:13:18 Training - Downloading the training image...........................
2022-04-16 10:17:59 Training - Training image download completed. Training in progress.
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2022-04-16 10:17:43,973 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2022-04-16 10:17:43,995 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2022-04-16 10:17:44,001 sagemaker_pytorch_container.training INFO     Invoking user training script.
2022-04-16 10:17:44,419 sagemaker-training-toolkit INFO     Invoking user script
Training
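SageMaker applies the `metric_definitions` regex to the training log to record the `loss` metric. The snippet below checks what `'loss': (.*?),` captures from a line in the dict format printed by the Hugging Face `Trainer`. The sample line is made up for illustration and is not taken from this job's log.

```python
import re

LOSS_REGEX = r"'loss': (.*?),"  # same pattern as in metric_definitions

# Illustrative log line (made-up values) in the format printed by Trainer
log_line = "{'loss': 0.3125, 'learning_rate': 4.5e-05, 'epoch': 0.1}"

match = re.search(LOSS_REGEX, log_line)
print(float(match.group(1)))  # 0.3125
```

The leading quote in the pattern means keys such as `'eval_loss'` or `'train_loss'` are not matched. The lazy `.*?` stops at the first comma, so the value is captured only when another key follows `'loss'`.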

## Deploying an Inference API with the Trained Model

In [13]:
from sagemaker.huggingface.model import HuggingFaceModel

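# S3 URI of the model artifact written by the training job above
# (the same value as huggingface_estimator.model_data)
model_data = "s3://sagemaker-ap-northeast-1-007376390068/huggingface-pytorch-training-2022-04-16-10-10-52-577/output/model.tar.gz"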

model = HuggingFaceModel(
    model_data=model_data,
    role=role,
    transformers_version="4.12",
    pytorch_version="1.9",
    py_version="py38",
)

predictor = model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")


-----!

## Deploying with SageMaker Serverless Inference

[Reference](https://github.com/aws/sagemaker-python-sdk/issues/3012)

In [22]:
import sagemaker
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.serverless import ServerlessInferenceConfig

model_data = "s3://sagemaker-ap-northeast-1-007376390068/huggingface-pytorch-training-2022-04-16-10-10-52-577/output/model.tar.gz"

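# Serverless endpoints have no instance type, which the SDK otherwise uses to
# choose the inference image, so retrieve the CPU image URI explicitly
# (workaround from the issue linked above; instance_type only selects CPU vs GPU)
image_uri = sagemaker.image_uris.retrieve(
    framework="huggingface",
    region="ap-northeast-1",
    version="4.12",
    py_version="py38",
    image_scope="inference",
    instance_type="ml.m5.xlarge",
    base_framework_version="pytorch1.9",
)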

model = HuggingFaceModel(
    image_uri=image_uri,
    model_data=model_data,
    role=role,
    transformers_version="4.12",
    pytorch_version="1.9",
    py_version="py38",
)

serverless_config = ServerlessInferenceConfig(max_concurrency=1)
predictor = model.deploy(serverless_inference_config=serverless_config)


---------!
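`ServerlessInferenceConfig(max_concurrency=1)` leaves the memory size at its SDK default of 2048 MB. For reference, the sketch below builds a roughly equivalent endpoint with the low-level `boto3` SageMaker client. Here `ServerlessConfig` takes the place of `InstanceType` and `InitialInstanceCount`. The model name, the `-config` naming scheme, and the variant name are hypothetical placeholders. `model.deploy` performs these steps, plus model creation, on its own.

```python
def create_serverless_endpoint(sm_client, model_name, endpoint_name,
                               memory_size_in_mb=2048, max_concurrency=1):
    """Create a serverless endpoint config and endpoint for an existing SageMaker model."""
    config_name = f"{endpoint_name}-config"  # hypothetical naming scheme
    sm_client.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=[
            {
                "VariantName": "AllTraffic",
                "ModelName": model_name,
                # Serverless settings replace the instance type and count
                "ServerlessConfig": {
                    "MemorySizeInMB": memory_size_in_mb,
                    "MaxConcurrency": max_concurrency,
                },
            }
        ],
    )
    sm_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name)
    return config_name
```

With a real client, this would be called as `create_serverless_endpoint(boto3.client("sagemaker"), "<model name>", "<endpoint name>")`.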

## Sending Requests to the Endpoint and Getting the Results

In [24]:
from sagemaker.huggingface.model import HuggingFacePredictor

# name of the endpoint created above (available as predictor.endpoint_name)
endpoint_name = "huggingface-pytorch-inference-2022-04-16-12-17-45-987"

predictor = HuggingFacePredictor(
    endpoint_name=endpoint_name
)
sentiment_inputs = {
    "inputs": [
        "I love this coffee.",
        "I never drink this coffee.",
        "I always drink this coffee.",
    ]
}
predictor.predict(sentiment_inputs)


[{'label': 'LABEL_1', 'score': 0.9879827499389648},
 {'label': 'LABEL_0', 'score': 0.8062602877616882},
 {'label': 'LABEL_1', 'score': 0.9366642832756042}]
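`HuggingFacePredictor.predict` sends the dict above to the endpoint as JSON. The sketch below sends the same request without the SageMaker SDK, through the `sagemaker-runtime` client of `boto3`. The mapping of `LABEL_0` to negative and `LABEL_1` to positive follows the IMDB label encoding (0 = neg, 1 = pos). It assumes the training script did not set custom label names, which is consistent with the `LABEL_*` output above.

```python
import json

# IMDB encodes 0 as negative and 1 as positive
LABEL_NAMES = {"LABEL_0": "negative", "LABEL_1": "positive"}


def classify(runtime_client, endpoint_name, texts):
    """Send texts to the endpoint and return (text, sentiment, score) tuples."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps({"inputs": texts}),
    )
    # The response body is a stream; read it and decode the JSON predictions
    predictions = json.loads(response["Body"].read())
    return [
        (text, LABEL_NAMES.get(p["label"], p["label"]), p["score"])
        for text, p in zip(texts, predictions)
    ]
```

With a real client, this would be called as `classify(boto3.client("sagemaker-runtime"), endpoint_name, ["I love this coffee."])`.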

## Deleting the Endpoint

In [16]:
from sagemaker.huggingface.model import HuggingFacePredictor

predictor = HuggingFacePredictor(
    endpoint_name=endpoint_name
)
predictor.delete_endpoint()
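`delete_endpoint()` removes the endpoint and, by default, its endpoint configuration. The SageMaker model created by `deploy` remains, and the SDK's `predictor.delete_model()` can remove it. The sketch below performs a fuller cleanup with the low-level `boto3` SageMaker client. It looks up the configuration and models before deleting anything, because they can no longer be found from the endpoint once it is gone.

```python
def delete_endpoint_resources(sm_client, endpoint_name):
    """Delete an endpoint, its endpoint configuration, and the models it serves."""
    # Resolve the config and model names while the endpoint still exists
    config_name = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    model_names = [variant["ModelName"] for variant in config["ProductionVariants"]]

    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    for model_name in model_names:
        sm_client.delete_model(ModelName=model_name)
    return config_name, model_names
```

With a real client, this would be called as `delete_endpoint_resources(boto3.client("sagemaker"), endpoint_name)`. It applies to both the real-time and the serverless endpoint.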
