# Amazon Bedrock Batch Inference

## Information
### Public information
- Developer documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference.html
- Quota: https://docs.aws.amazon.com/bedrock/latest/userguide/quotas.html#quotas-batch
- Code samples: https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference-example.html

### 2024.1.25
- Public Preview
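- How to use
    - REST API: probably exists, but cumbersome to use
    - CLI: does not appear to be available
    - SDK: a preview build is available
        - Python and Java only
        - https://d2eo22ngex1n9g.cloudfront.net/Documentation/SDK/bedrock-python-sdk-reinvent.zip
    - Console: does not appear to be available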

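## Background
- There is a theory that the Bedrock batch inference quota is independent of the on-demand quota (so batch jobs would not run into the on-demand quota, which would be welcome)
- Submit a text-to-image batch inference job (SDXL, Titan, or both) for about 500 images and observe how it completes (or fails to complete)
- It would be good news if the batch job turns out clearly faster than what the on-demand quota allows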


## Test conditions
- Inference method: batch
- Generation: text to image
- Models: amazon.titan-image-generator-v1 & stability.stable-diffusion-xl-v1

In [2]:
model_id:str = "amazon.titan-image-generator-v1"
# model_id:str = "stability.stable-diffusion-xl-v1"

prompt:str = "A dog running at a park."
number_of_images:int = 60
image_generation_config:dict = {
    "numberOfImages": 1,
    "quality": "standard",
    "height": 1024,
    "width": 1024,
    "cfgScale": 8.0,
    "seed": 0
}

file_name:str = "input"

## Single inference

First, check that a single on-demand request works.

Sample code: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-image.html#model-parameters-titan-image-code-examples

In [None]:
from func import inference
inference(model_id, prompt)
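
The `inference` helper imported from `func` is not shown in this notebook. The block below is a minimal sketch of what such a helper might look like, not the actual implementation: it assumes a `boto3.client("bedrock-runtime")` object is passed in as `runtime_client`, and its signature (with explicit `runtime_client` and `config` arguments) is hypothetical and differs from the two-argument call above.

```python
import base64
import json


def build_titan_image_body(prompt: str, config: dict) -> str:
    """Serialize a Titan Image Generator text-to-image request body."""
    return json.dumps({
        "taskType": "TEXT_IMAGE",
        "textToImageParams": {"text": prompt},
        "imageGenerationConfig": config,
    })


def inference(runtime_client, model_id: str, prompt: str, config: dict) -> bytes:
    """Hypothetical stand-in for func.inference: return the first image as raw bytes.

    Assumption: runtime_client is boto3.client("bedrock-runtime").
    """
    response = runtime_client.invoke_model(
        modelId=model_id,
        body=build_titan_image_body(prompt, config),
        accept="application/json",
        contentType="application/json",
    )
    payload = json.loads(response["body"].read())
    if payload.get("error"):
        raise RuntimeError(f"Image generation error: {payload['error']}")
    # Images come back base64-encoded; decode the first one.
    return base64.b64decode(payload["images"][0])
```

With the variables defined earlier, a call would look like `inference(runtime_client, model_id, prompt, image_generation_config)`.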

## Verification

### Environment setup

In [None]:
SDK = "sdk"
ZIP = f"{SDK}.zip"
! wget --no-check-certificate -O $ZIP https://d2eo22ngex1n9g.cloudfront.net/Documentation/SDK/bedrock-python-sdk-reinvent.zip

In [None]:
! unzip -n -d $SDK $ZIP

The steps below follow this page: https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference-example.html

In [None]:
! find $SDK -type f -name boto*.whl

The kernel must be restarted after running the following cell.

In [None]:
%pip install $SDK/botocore-1.32.4-py3-none-any.whl
%pip install $SDK/boto3-1.29.4-py3-none-any.whl

In [None]:
! aws --version

### Input data format
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference-data.html

Sample input JSON follows.

Input JSON Lines format
```JSON
{
    "recordId": "12 character alphanumeric string",
    "modelInput": {JSON body}
}
...
{
    "recordId": "12 character alphanumeric string",
    "modelInput": {JSON body}
}
```
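
The format above is pretty-printed for readability, but in a JSON Lines file each record has to sit on a single line. As a rough sketch, the helpers below wrap arbitrary model inputs into records and sanity-check the result before upload. The names `to_jsonl` and `check_jsonl` are hypothetical, and the checks only mirror the 12-character alphanumeric `recordId` described above; they are not the service's own validation.

```python
import json


def to_jsonl(model_inputs: list) -> str:
    """Wrap each model input in a batch record and join the records as JSON Lines."""
    lines = []
    for index, model_input in enumerate(model_inputs):
        record = {
            # A zero-padded index yields a 12-character alphanumeric recordId.
            "recordId": str(index).zfill(12),
            "modelInput": model_input,
        }
        lines.append(json.dumps(record))
    return "\n".join(lines)


def check_jsonl(jsonl: str) -> int:
    """Return the record count; raise ValueError on the first malformed line."""
    count = 0
    for number, line in enumerate(jsonl.splitlines(), start=1):
        record = json.loads(line)  # JSONDecodeError is a ValueError subclass
        record_id = record.get("recordId")
        if not (isinstance(record_id, str) and len(record_id) == 12 and record_id.isalnum()):
            raise ValueError(f"line {number}: bad recordId {record_id!r}")
        if not isinstance(record.get("modelInput"), dict):
            raise ValueError(f"line {number}: modelInput must be a JSON object")
        count += 1
    return count
```

Running `check_jsonl` on the generated string before uploading it to S3 catches formatting mistakes locally, without having to submit a job first.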

Inference request JSON for Titan: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html#model-parameters-titan-request-body

```JSON
{
    "inputText": string,
    "textGenerationConfig": {
        "temperature": float,  
        "topP": float,
        "maxTokenCount": int,
        "stopSequences": [string]
    }
}
```

Batch inference input JSON Lines for Titan
```JSON
{
    "modelInput": {
        "inputText": string,
        "textGenerationConfig": {
            "temperature": float,  
            "topP": float,
            "maxTokenCount": int,
            "stopSequences": [string]
        }
    }
}
...
{
    "modelInput": {
        "inputText": string,
        "textGenerationConfig": {
            "temperature": float,  
            "topP": float,
            "maxTokenCount": int,
            "stopSequences": [string]
        }
    }
}
```

### Condition settings

### Loading environment settings

In [3]:
from yaml import safe_load
with open("config.yaml") as config_file:
    config = safe_load(config_file)

role = config.get("role")
bucket_name = config.get("bucket_name")
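
Because `config.get` silently returns `None` for a missing key, a typo in `config.yaml` would only surface later, when the job is created. The sketch below fails early instead. `check_config` is a hypothetical helper, and the only keys it knows about are the two read above.

```python
REQUIRED_KEYS = ("role", "bucket_name")


def check_config(config: dict) -> dict:
    """Return only the required settings, raising KeyError if any is missing or empty."""
    missing = [key for key in REQUIRED_KEYS if not config.get(key)]
    if missing:
        raise KeyError(f"config.yaml is missing: {', '.join(missing)}")
    return {key: config[key] for key in REQUIRED_KEYS}


# Placeholder values for illustration only, not real resources.
example = check_config({
    "role": "arn:aws:iam::123456789012:role/example-batch-role",
    "bucket_name": "example-bucket",
})
print(example["bucket_name"])
```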

### Writing the batch input data

In [4]:
# prompts = [prompt for _ in range(number_of_images)]

In [5]:
prompts = [f"{i+2} dogs running at a park." for i in range(number_of_images)]

In [6]:
import json

requests = list()
for i, prompt in enumerate(prompts):
    body = {
        "taskType": "TEXT_IMAGE",
        "textToImageParams": {
            "text": prompt
        },
        "imageGenerationConfig": image_generation_config
    }
    record = {"recordId": str(i).zfill(12), "modelInput": body}
    requests.append(json.dumps(record))

# One JSON object per line (JSON Lines).
jsonl = "\n".join(requests)

jsonl

'{"recordId": "000000000000", "modelInput": {"taskType": "TEXT_IMAGE", "textToImageParams": {"text": "2 dogs running at a park."}, "imageGenerationConfig": {"numberOfImages": 1, "quality": "standard", "height": 1024, "width": 1024, "cfgScale": 8.0, "seed": 0}}}\n{"recordId": "000000000001", "modelInput": {"taskType": "TEXT_IMAGE", "textToImageParams": {"text": "3 dogs running at a park."}, "imageGenerationConfig": {"numberOfImages": 1, "quality": "standard", "height": 1024, "width": 1024, "cfgScale": 8.0, "seed": 0}}}\n{"recordId": "000000000002", "modelInput": {"taskType": "TEXT_IMAGE", "textToImageParams": {"text": "4 dogs running at a park."}, "imageGenerationConfig": {"numberOfImages": 1, "quality": "standard", "height": 1024, "width": 1024, "cfgScale": 8.0, "seed": 0}}}\n{"recordId": "000000000003", "modelInput": {"taskType": "TEXT_IMAGE", "textToImageParams": {"text": "5 dogs running at a park."}, "imageGenerationConfig": {"numberOfImages": 1, "quality": "standard", "height": 102

In [7]:
from boto3 import resource
s3 = resource('s3')
bucket = s3.Bucket(bucket_name)
dir = "Bedrock/Batch-Inference"
jsonl_key = f"{dir}/input/{file_name}.jsonl"
jsonl_obj = bucket.Object(key=jsonl_key)
jsonl_obj.put(Body=jsonl)

{'ResponseMetadata': {'RequestId': '8Q65DGR2GVCP0CHN',
  'HostId': 'SqC/L/hqdQOTFkPA5uBeIeGdTA4UIxeoTUAP5nSxvPxL28i9axpcgGTcKgw7xqf85MeOWmUOYGU=',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amz-id-2': 'SqC/L/hqdQOTFkPA5uBeIeGdTA4UIxeoTUAP5nSxvPxL28i9axpcgGTcKgw7xqf85MeOWmUOYGU=',
   'x-amz-request-id': '8Q65DGR2GVCP0CHN',
   'date': 'Tue, 30 Jan 2024 04:35:36 GMT',
   'x-amz-server-side-encryption': 'AES256',
   'etag': '"032d31570b84e0e5c2bea8f14b38a93f"',
   'server': 'AmazonS3',
   'content-length': '0'},
  'RetryAttempts': 0},
 'ETag': '"032d31570b84e0e5c2bea8f14b38a93f"',
 'ServerSideEncryption': 'AES256'}

In [8]:
! aws s3 cp s3://$bucket_name/$jsonl_key -

{"recordId": "000000000000", "modelInput": {"taskType": "TEXT_IMAGE", "textToImageParams": {"text": "2 dogs running at a park."}, "imageGenerationConfig": {"numberOfImages": 1, "quality": "standard", "height": 1024, "width": 1024, "cfgScale": 8.0, "seed": 0}}}
{"recordId": "000000000001", "modelInput": {"taskType": "TEXT_IMAGE", "textToImageParams": {"text": "3 dogs running at a park."}, "imageGenerationConfig": {"numberOfImages": 1, "quality": "standard", "height": 1024, "width": 1024, "cfgScale": 8.0, "seed": 0}}}
{"recordId": "000000000002", "modelInput": {"taskType": "TEXT_IMAGE", "textToImageParams": {"text": "4 dogs running at a park."}, "imageGenerationConfig": {"numberOfImages": 1, "quality": "standard", "height": 1024, "width": 1024, "cfgScale": 8.0, "seed": 0}}}
{"recordId": "000000000003", "modelInput": {"taskType": "TEXT_IMAGE", "textToImageParams": {"text": "5 dogs running at a park."}, "imageGenerationConfig": {"numberOfImages": 1, "quality": "standard", "height": 1024, "

### Inference

In [9]:
from boto3 import client
bedrock = client(service_name="bedrock")

In [10]:
inputDataConfig = ({
    "s3InputDataConfig": {
        "s3Uri": f"s3://{bucket_name}/{jsonl_key}"
    }
})

output_dir:str = f"s3://{bucket_name}/{dir}/output/"
outputDataConfig=({
    "s3OutputDataConfig": {
        "s3Uri": output_dir
    }
})

from utils import get_formatted_time
job_name:str = f"{model_id}-{get_formatted_time()}"
job_name

'amazon.titan-image-generator-v1-20240130-133537'
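
`get_formatted_time` comes from a local `utils` module that is not shown here. Judging only from the job name printed above, it appears to return the local time as `YYYYMMDD-HHMMSS`. The sketch below is a hypothetical stand-in built on that inference; the optional `now` argument is added here purely to make it testable.

```python
from datetime import datetime


def get_formatted_time(now=None) -> str:
    """Hypothetical stand-in for utils.get_formatted_time.

    The format is inferred from the job name shown above, not from the real helper.
    """
    moment = now or datetime.now()
    return moment.strftime("%Y%m%d-%H%M%S")


print(get_formatted_time(datetime(2024, 1, 30, 13, 35, 37)))  # 20240130-133537
```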

In [11]:
# %%time

# response = bedrock.create_model_invocation_job(
#     roleArn = role,
#     modelId = model_id,
#     jobName = job_name,
#     inputDataConfig = inputDataConfig,
#     outputDataConfig = outputDataConfig
# )
# print(response)

# job_arn = response.get("jobArn")
# job_id = job_arn.split("/")[-1].strip()
# print(job_id)

# from time import sleep
# job_status:str = "Submitted"

# while job_status in ("Submitted", "InProgress"):
#     sleep(10)
#     job:dict = bedrock.get_model_invocation_job(jobIdentifier=job_arn)
#     job_status:str = job.get("status")
#     print(job_status)

# print(bedrock.get_model_invocation_job(jobIdentifier=job_arn).get("message"))
# ! aws s3 cp $output_dir$job_id/manifest.json.out -

In [None]:
from time import perf_counter, sleep

def wait(monitored_status:str) -> str:
    job_status:str = monitored_status
    print(job_status)

    started_time = perf_counter()
    while job_status == monitored_status:
        job:dict = bedrock.get_model_invocation_job(jobIdentifier=job_arn)
        job_status:str = job.get("status")
        sleep(1)
    ended_time = perf_counter()

    time_delta = ended_time - started_time
    print(time_delta)
    return job_status
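
The `wait` function above loops for as long as the status stays the same, so a job that never leaves a state would block the notebook indefinitely. As an alternative sketch, the version below polls until the job leaves the in-flight states and gives up after a timeout. It assumes the same two in-flight statuses used in the commented-out loop earlier (`Submitted`, `InProgress`). The function name and the default `interval` and `timeout` values are illustrative choices, not recommendations from the documentation.

```python
from time import monotonic, sleep

# Same pair of statuses as in the commented-out polling loop earlier in this notebook.
IN_FLIGHT = ("Submitted", "InProgress")


def wait_until_done(client, job_arn: str, interval: float = 10.0, timeout: float = 3600.0) -> str:
    """Poll the job until it leaves the in-flight states; return the final status.

    Assumption: client is the boto3 "bedrock" client created above.
    """
    deadline = monotonic() + timeout
    while True:
        job = client.get_model_invocation_job(jobIdentifier=job_arn)
        status = job.get("status")
        if status not in IN_FLIGHT:
            return status
        if monotonic() >= deadline:
            raise TimeoutError(f"{job_arn} is still {status} after {timeout} seconds")
        sleep(interval)
```

Passing the client and the ARN as arguments, rather than reading globals, also makes the function usable before `job_arn` exists in the notebook namespace.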

In [13]:
%%time

response = bedrock.create_model_invocation_job(
    roleArn = role,
    modelId = model_id,
    jobName = job_name,
    inputDataConfig = inputDataConfig,
    outputDataConfig = outputDataConfig
)
from pprint import pprint
pprint(response)

job_arn = response.get("jobArn")
job_id = job_arn.split("/")[-1].strip()
print(job_id)

job_status:str = wait(monitored_status="Submitted")
job_status:str = wait(monitored_status=job_status)

print(bedrock.get_model_invocation_job(jobIdentifier=job_arn).get("message"))
! aws s3 cp $output_dir$job_id/manifest.json.out -

{'ResponseMetadata': {'RequestId': '1e8f7967-6a28-4ad2-aaad-e02ea60c1519', 'HTTPStatusCode': 200, 'HTTPHeaders': {'date': 'Tue, 30 Jan 2024 04:35:41 GMT', 'content-type': 'application/json', 'content-length': '85', 'connection': 'keep-alive', 'x-amzn-requestid': '1e8f7967-6a28-4ad2-aaad-e02ea60c1519'}, 'RetryAttempts': 0}, 'jobArn': 'arn:aws:bedrock:us-east-1:624045005200:model-invocation-job/ibcf54hj71gq'}
ibcf54hj71gq
Submitted


Job statuses are described here: https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference-list.html

In [None]:
# jobs = bedrock.list_model_invocation_jobs().get("invocationJobSummaries")
# for job in jobs: print(job.get("status"))
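
When many jobs have been submitted, a per-status tally is easier to read than one line per job. The sketch below counts statuses across all pages of results. It assumes the response carries a `nextToken` field for pagination and that the call accepts it as a parameter, which may not hold for the preview SDK used here. `count_job_statuses` is a hypothetical name.

```python
from collections import Counter


def count_job_statuses(client) -> Counter:
    """Tally batch job statuses, following nextToken pagination if present."""
    counts = Counter()
    kwargs = {}
    while True:
        page = client.list_model_invocation_jobs(**kwargs)
        for summary in page.get("invocationJobSummaries", []):
            counts[summary.get("status")] += 1
        token = page.get("nextToken")
        if not token:
            return counts
        # Assumption: the API accepts nextToken to fetch the following page.
        kwargs = {"nextToken": token}
```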

Note: the file displayed below contains image binary data, so even a handful of images makes it large enough to slow down the editor.

In [None]:
# ! aws s3 cp $output_dir$job_id/input.jsonl.out -

In [None]:
import base64
import io
import json
from PIL import Image

output_key = f"{dir}/output/{job_id}/{file_name}.jsonl.out"
output_obj = bucket.Object(key=output_key).get()
binary_contents = output_obj.get("Body").read()

# contents = list()
for line in io.BytesIO(binary_contents):
    content = json.loads(line.decode("utf-8"))
    # finish_reason = content.get("error")
    # if finish_reason is not None: print(f"Image generation error. Error is {finish_reason}")
    # contents.append(content)
    print(content.get("modelInput").get("textToImageParams").get("text"))
    base64_image = content.get("modelOutput").get("images")[0]
    base64_bytes = base64_image.encode('ascii')
    image_bytes = base64.b64decode(base64_bytes)
    image = Image.open(io.BytesIO(image_bytes))
    image.show()
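
Opening a viewer window per image becomes unwieldy with dozens of records, and the loop above stops at the first record that has no `modelOutput`. As a sketch of one alternative, the function below writes each image to disk and skips failed records. It assumes the output record layout used above (`recordId`, `modelOutput.images`, and an `error` field as in the commented-out lines) and that the images are PNG, which is an assumption behind the `.png` extension. `save_images` is a hypothetical name.

```python
import base64
import json
from pathlib import Path


def save_images(jsonl_bytes: bytes, out_dir: str) -> list:
    """Decode the first image of every successful record and write it to out_dir."""
    target = Path(out_dir)
    target.mkdir(parents=True, exist_ok=True)
    saved = []
    for line in jsonl_bytes.splitlines():
        if not line.strip():
            continue  # ignore blank lines
        record = json.loads(line)
        images = (record.get("modelOutput") or {}).get("images") or []
        if record.get("error") or not images:
            print(f"skipped {record.get('recordId')}: {record.get('error')}")
            continue
        # Assumption: the decoded bytes are PNG data.
        path = target / f"{record.get('recordId')}.png"
        path.write_bytes(base64.b64decode(images[0]))
        saved.append(path)
    return saved
```

With the variables from the cell above, `save_images(binary_contents, "output_images")` would write one file per successful record into a local `output_images` directory.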