In [None]:
import os
import sys

# Put the parent directory on the import path; presumably this is where the
# project-local `common` package imported by later cells lives.
ROOT_PATH = os.path.abspath("../")
if ROOT_PATH not in sys.path:
    # Guard so that re-running this cell does not stack duplicate entries.
    sys.path.append(ROOT_PATH)

In [None]:
import os
import json
import urllib
import time
import base64
import ffmpeg
import sagemaker
import boto3
import secrets
from PIL import Image
from diffusers.utils import export_to_video, make_image_grid
from botocore.exceptions import ClientError
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.s3 import s3_path_join

from IPython.display import Video
from common.utils.time import get_current_time, get_seed
from common.utils.images import encode_image_base64_from_file

In [None]:
# Optional bucket override: leave as None to fall back to the SageMaker
# session's default bucket, or set it to a bucket name before running the cell.
sm_session_bucket = None

sm_session = sagemaker.Session()
sm_runtime_client = boto3.client("sagemaker-runtime")

if sm_session_bucket is None:
    # No bucket name was given, so use the session default.
    sm_session_bucket = sm_session.default_bucket()

try:
    sm_role = sagemaker.get_execution_role()
except ValueError:
    # get_execution_role() raised ValueError; look the role up by its
    # conventional name through IAM instead.
    iam_client = boto3.client("iam")
    sm_role = iam_client.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

print(f"sagemaker role arn: {sm_role}")
# Report the bucket that was actually resolved above (honours an override).
print(f"sagemaker bucket: {sm_session_bucket}")
print(f"sagemaker session region: {sm_session.boto_region_name}")

In [None]:
# Seed for this run, obtained from the project helper `get_seed`.
seed = get_seed()

FRAME_OUT_PATH = "frames_out"
VIDEO_OUT_PATH = "video_out"

# Create the directories that store the output; existing ones are kept as-is.
for out_dir in (FRAME_OUT_PATH, VIDEO_OUT_PATH):
    os.makedirs(out_dir, exist_ok=True)

# Endpoint name: fill in the placeholder below, or restore this snippet to
# read it from "endpoint.txt":
#     with open("endpoint.txt", "r") as f:
#         saved_data = json.load(f)
#     endpoint_name = saved_data["endpoint_name"]
endpoint_name = "<YOU_NEED_TO_FILL_HERE>"

print(f"Endpoint: {endpoint_name}")
print(f"Seed: {seed}")

### 4.2 Set Movie Name and Inference Parameters

In [None]:
input_image_path = "sample/champagne.jpg"
video_output_name = "champagne"

# Read the dimensions from the image that is actually sent to the endpoint,
# so the requested width/height always match the uploaded picture.
img = Image.open(input_image_path)
width, height = img.size

print(width, height)
img.show()

In [None]:
fps = 6

# Request payload sent to the endpoint. The source image travels inline as
# base64; width/height come from the image-inspection cell above.
data = dict(
    image=encode_image_base64_from_file(input_image_path),
    width=width,
    height=height,
    num_frames=25,
    num_inference_steps=25,
    min_guidance_scale=1.0,
    max_guidance_scale=3.0,
    fps=fps,  # author's note: [5, 30]
    motion_bucket_id=127,  # author's note: < 255
    noise_aug_strength=0.02,
    decode_chunk_size=8,
    seed=seed,
)

## 5: Upload Request Payload and Invoke Endpoint


### 5.1: Upload Request Payload

- Amazon S3에 JSON request payload 업로드 후, 해당 payload로 inference 수행

In [None]:
def upload_data_to_s3(data):
    """Serialize ``data`` to JSON and upload it to S3 as a request payload.

    The JSON file is staged in a throw-away temporary directory, so no
    ``payload_*.json`` files pile up next to the notebook. The object name is
    still ``payload_<timestamp>.json`` under the ``async_inference/input``
    prefix.

    Args:
        data: JSON-serializable request payload.

    Returns:
        The value returned by ``sm_session.upload_data`` (the location of the
        uploaded payload).
    """
    import tempfile

    timestamp = get_current_time(format="%Y%m%d_%H%M%S")
    filename = f"payload_{timestamp}.json"

    # The directory (and the staged file) is removed on exit, even on failure.
    with tempfile.TemporaryDirectory() as staging_dir:
        local_path = os.path.join(staging_dir, filename)

        with open(local_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)

        return sm_session.upload_data(
            local_path,
            # Use the bucket resolved in the session-setup cell so that a
            # user-supplied override is honoured.
            bucket=sm_session_bucket,
            key_prefix="async_inference/input",
            extra_args={"ContentType": "application/json"},
        )

In [None]:
# Upload the request payload built above; the returned location is passed to
# the asynchronous endpoint invocation in a later cell.
input_s3_location = upload_data_to_s3(data)
print(f"Request payload location: {input_s3_location}")

### 5.2: Invoke the Endpoint for Inference

- `num_frames`가 25인 경우, 약 2분 정도 소요됩니다.
- 모델 호출에 대한 응답을 받기 위해 Amazon S3 버킷을 폴링합니다.

In [None]:
def get_output(output_location, poll_interval=15, timeout=None):
    """Poll S3 until the object at ``output_location`` exists and return it.

    Args:
        output_location: ``s3://bucket/key`` URI of the expected output object.
        poll_interval: Seconds to sleep between attempts (default 15).
        timeout: Maximum number of seconds to keep polling. ``None`` (the
            default) polls indefinitely, as before.

    Returns:
        The content returned by ``sm_session.read_s3_file`` for the object.

    Raises:
        ClientError: For any S3 error other than ``NoSuchKey``.
        TimeoutError: If ``timeout`` elapses before the object appears.
    """
    # Import the submodule explicitly: a bare `import urllib` does not
    # guarantee that `urllib.parse` has been loaded.
    from urllib.parse import urlparse

    output_url = urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]  # drop the leading "/"
    deadline = None if timeout is None else time.monotonic() + timeout

    while True:
        try:
            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            # Only "object not there yet" is retried; everything else is real.
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise

        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"No model output at {output_location} after {timeout} seconds"
            )

        print("Waiting for model output...")
        time.sleep(poll_interval)


def load_video_frames(video_frames):
    """Decode base64 frames, save each as a JPEG file and open it as an image.

    Files are written to ``FRAME_OUT_PATH`` as ``frame_NN.jpg``, numbered from
    1 and zero-padded to at least two digits.

    Args:
        video_frames: Iterable of base64-encoded frame strings.

    Returns:
        List of images opened from the written files, in input order.
    """
    images = []

    for number, encoded in enumerate(video_frames, start=1):
        payload = bytes(encoded, "raw_unicode_escape")
        target = f"{FRAME_OUT_PATH}/frame_{number:02d}.jpg"

        with open(target, "wb") as sink:
            sink.write(base64.decodebytes(payload))

        images.append(Image.open(target, mode="r"))

    return images

In [None]:
# Queue the request on the asynchronous endpoint. The input is the payload
# uploaded to S3 earlier; the response carries an "OutputLocation" that the
# next cell polls for the result.
response = sm_runtime_client.invoke_endpoint_async(
    EndpointName=endpoint_name,
    InputLocation=input_s3_location,
    InvocationTimeoutSeconds=3600,
)

print(f"Model response payload location: {response['OutputLocation']}")

## 6: Frames to MP4 Video

### 6.1: Frames to MP4 Video

- 각 프레임 binary 객체를 JPEG로 변환한 다음, Hugging Face의 `diffusers.utils.export_to_video` 메서드를 사용하여 MP4로 결합합니다.

In [None]:
# Block until the endpoint has written its result, then decode it.
output = get_output(response["OutputLocation"])
# Keep the response under its own name so the request payload in `data`
# survives and the upload cell can be re-run safely.
result = json.loads(output)
loaded_video_frames = load_video_frames(result["frames"])
print(f"Load video frames: {len(loaded_video_frames)}")

# Combine the decoded frames into an MP4 at the frame rate that was requested.
video_output_path = f"{VIDEO_OUT_PATH}/{video_output_name}.mp4"
export_to_video(loaded_video_frames, video_output_path, fps=fps)
print(f"Video created: {video_output_path}")

### 6.2: Display Frames as Grid

- 25개의 프레임을 5x5 grid로 표시합니다

In [None]:
# Arrange the loaded frames into a 5x5 grid and show it at half size.
image = make_image_grid(loaded_video_frames, 5, 5)
# NOTE(review): this rebinds the notebook-level `width`/`height` (previously
# the input image size) to the half-size grid dimensions. The video display
# cell below reads this `width`; re-running the payload cell after this one
# would pick up the grid size instead of the image size.
(width, height) = (image.width // 2, image.height // 2)
im_resized = image.resize((width, height))
display(im_resized)

### 6.3: Display Video

- 생성한 비디오 파일을 Notebook에서 재생합니다


In [None]:
def display_video(video_path, frame_width):
    """Build an IPython ``Video`` widget for the given video.

    Args:
        video_path: Path/URL of the video, passed to ``Video`` as ``url``.
        frame_width: Display width passed to ``Video`` as ``width``.

    Returns:
        The ``Video`` object, configured with the
        "controls muted autoplay loop" HTML attributes.
    """
    player_options = {
        "url": video_path,
        "width": frame_width,
        "html_attributes": "controls muted autoplay loop",
    }
    return Video(**player_options)

In [None]:
# Play the generated video inline at a quarter of the current `width`.
display_video(video_output_path, width // 4)

## 7: Generating Multiple Video Variations

- 하나의 이미지로 여러 video variation을 생성합니다


In [None]:
input_image_path = "sample/beach_bike.jpg"
video_output_name = "beach_bike"

width = 1024
height = 576
fps = 6

# The source image is the same for every variation, so encode it only once.
encoded_image = encode_image_base64_from_file(input_image_path)

# Generate three variations of the same image; only the seed differs.
for i in range(3):
    seed = get_seed()
    data = {
        "image": encoded_image,
        "width": width,
        "height": height,
        "num_frames": 25,
        "num_inference_steps": 25,
        "min_guidance_scale": 1.0,
        "max_guidance_scale": 3.0,
        "fps": fps,
        "motion_bucket_id": 127,
        "noise_aug_strength": 0.02,
        "decode_chunk_size": 8,
        "seed": seed,
    }

    input_s3_location = upload_data_to_s3(data)
    response = sm_runtime_client.invoke_endpoint_async(
        EndpointName=endpoint_name,
        InputLocation=input_s3_location,
        InvocationTimeoutSeconds=3600,
    )

    # Wait for this variation to finish before submitting the next one.
    output = get_output(response["OutputLocation"])
    # Separate name for the response so the request payload is not clobbered.
    result = json.loads(output)
    loaded_video_frames = load_video_frames(result["frames"])

    # One MP4 per variation, suffixed with the loop index.
    video_output_path = f"{VIDEO_OUT_PATH}/{video_output_name}_{i}.mp4"
    export_to_video(loaded_video_frames, video_output_path, fps=fps)
    print(f"Video created: {video_output_path}")