In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Overview

This notebook demonstrates how to call a model endpoint on Vertex AI to generate game asset designs from text prompts. The generated images are uploaded to a Cloud Storage (GCS) bucket for further processing.

### Objective

- Run online text-to-image predictions against a deployed Vertex AI endpoint and store the generated images in Cloud Storage.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

### 1. Initialize

In [None]:
# Google Cloud project ID; used by aiplatform.init and to build the endpoint resource path.
PROJECT_ID = ""  # @param {type:"string"}

# Region of the Vertex AI resources (e.g. "us-central1"); used by
# aiplatform.init and to build the endpoint resource path.
REGION = ""  # @param {type:"string"}

# Cloud Storage bucket passed to aiplatform.init as the staging bucket.
# Fill it without the 'gs://' prefix.
GCS_BUCKET = ""  # @param {type:"string"}

# Cloud Storage bucket that receives the generated images.
# Bucket name only, without the 'gs://' prefix.
GENERATED_IMAGES_BUCKET = "" # @param {type:"string"}

# ID of the deployed Vertex AI endpoint that serves the text-to-image model.
ENDPOINT_ID=""  # @param {type:"string"}

# The service account for deploying fine tuned model. The service account looks like:
# '<account_name>@<project>.iam.gserviceaccount.com'
# NOTE(review): not referenced by any cell in this notebook.
SERVICE_ACCOUNT = ""  # @param {type:"string"}

### 2. Define Helper Functions

In [None]:
from typing import Dict, List, Union
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from PIL import Image
from google.cloud import storage
from io import BytesIO

import base64
import glob
import os
import datetime
import time
import uuid
import matplotlib.pyplot as plt

# aiplatform.init expects staging_bucket as a full "gs://..." URI, while the
# config cell asks for the bare bucket name, so add the scheme here. When no
# bucket is configured, pass None (the SDK default) instead of an empty string.
STAGING_BUCKET_URI = None
if GCS_BUCKET:
    STAGING_BUCKET_URI = (
        GCS_BUCKET if GCS_BUCKET.startswith("gs://") else f"gs://{GCS_BUCKET}"
    )
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET_URI)

def generate_guid():
    """Return a new random (version 4) UUID rendered as text.

    Returns:
        str: the canonical 36-character, hyphen-separated UUID string.
    """
    fresh_id = uuid.uuid4()
    return f"{fresh_id}"

def base64_to_image(image_str):
    """Decode a base64-encoded image payload into a PIL image.

    Args:
        image_str: base64 text (or bytes) containing an encoded image, e.g. one
            element of the endpoint's ``response.predictions``.

    Returns:
        PIL.Image.Image on success, or None if the payload cannot be decoded.
        Errors are printed rather than raised so that one bad prediction does
        not abort the whole batch.
    """
    try:
        image = Image.open(BytesIO(base64.b64decode(image_str, validate=False)))
        # Image.open is lazy and only parses the header. Force a full decode
        # here so truncated/corrupt payloads are reported by this function
        # instead of failing later inside save() or paste().
        image.load()
        return image
    except Exception as e:
        print(f"Error decoding Base64 string: {e}")
        return None

def image_grid(imgs, rows=2, cols=2):
    """Tile images row-major into one white-background RGB grid image.

    Every cell is sized from ``imgs[0]``, and a 10px white gap follows each
    column (including the last one).

    Args:
        imgs: non-empty sequence of PIL images, assumed to all share the size
            of ``imgs[0]``.
        rows: minimum number of rows; grown automatically when
            ``len(imgs) > rows * cols`` so no image is silently dropped.
        cols: number of columns (must be >= 1).

    Returns:
        PIL.Image.Image: the composed grid in RGB mode.

    Raises:
        ValueError: if ``imgs`` is empty or ``cols`` is less than 1.
    """
    if not imgs:
        raise ValueError("image_grid() needs at least one image")
    if cols < 1:
        raise ValueError("cols must be >= 1")

    gap = 10
    w, h = imgs[0].size
    # Ceil-divide so every image gets its own cell.
    rows = max(rows, (len(imgs) + cols - 1) // cols)
    grid = Image.new(
        mode="RGB", size=(cols * (w + gap), rows * h), color=(255, 255, 255)
    )
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        # The horizontal gap depends on the column index, not the flat index,
        # so every row starts at x=0 and stays inside the canvas.
        grid.paste(img, box=(col * (w + gap), row * h))
    return grid

def save_pil_image_to_gcs(pil_image, image_format="JPEG", bucket_name=None):
    """Encode a PIL image and upload it to Cloud Storage under a random name.

    The object is named ``generated_logo_<uuid4>.<ext>``, where the extension
    matches ``image_format`` (``jpg`` for JPEG).

    Args:
        pil_image (PIL.Image.Image): the image to upload.
        image_format (str, optional): PIL format name used to encode the image
            and to build the ``image/<format>`` content type.
            Defaults to "JPEG". Supported formats depend on PIL.
        bucket_name (str, optional): target bucket name, without ``gs://``.
            Defaults to ``GENERATED_IMAGES_BUCKET``, read at call time.

    Returns:
        str | None: the ``gs://`` URI of the uploaded object, or None if the
        upload failed. Errors are printed, not raised, so one failed upload
        does not stop the prediction loop.
    """
    if bucket_name is None:
        bucket_name = GENERATED_IMAGES_BUCKET

    try:
        fmt = image_format.upper()
        # Name the object with an extension that matches the encoding, so the
        # file name agrees with the content type.
        extension = "jpg" if fmt == "JPEG" else fmt.lower()
        filename = f"generated_logo_{generate_guid()}.{extension}"

        # JPEG cannot store alpha or palette images; flatten them to RGB
        # instead of failing the upload.
        if fmt == "JPEG" and pil_image.mode in ("RGBA", "LA", "P"):
            pil_image = pil_image.convert("RGB")

        output_buffer = BytesIO()
        pil_image.save(output_buffer, format=image_format)

        storage_client = storage.Client()
        blob = storage_client.bucket(bucket_name).blob(filename)
        blob.upload_from_string(
            output_buffer.getvalue(), content_type=f"image/{image_format.lower()}"
        )

        uri = f"gs://{bucket_name}/{filename}"
        print(f"Image saved to {uri}")
        return uri

    except Exception as e:
        print(f"An error occurred during upload: {e}")
        return None



### 3. Initialize Prompts

In [None]:
# Add or update prompts as needed.

# Generation settings shared by every request. Each instance below is a prompt
# plus its own copy of these values, so editing them here updates all requests.
SHARED_GENERATION_PARAMS = {
    "height": 1024,
    "width": 1024,
    "guidance_scale": 5,
    "num_inference_steps": 25,
    "seed": 366868260,
    "cfg_scale": 7.5,
    "negative_prompt": "bad art, ugly, deformed, watermark, duplicated, bad spelling, No text, no clip art",
    "sampler": "DPM++ 2M SDE Karras",
}

PROMPTS = [
    "a cat standing with hands in its pocket in space with a a warrior helmet",
    "a cartoon dog standing posing in with a vintage warrior helmet, gold boots and a bow tie in a transparent background",
]

instances = [{"prompt": text, **SHARED_GENERATION_PARAMS} for text in PROMPTS]


### 4. Initialize the Endpoint and Run Predictions

In [None]:
# Vertex AI serves online predictions from a regional host. Derive it from
# REGION so it matches the location embedded in the endpoint resource path.
api_endpoint = f"{REGION}-aiplatform.googleapis.com"
client_options = {"api_endpoint": api_endpoint}

# Initialize the client that will be used to create and send requests.
# It only needs to be created once and can be reused for multiple requests.
client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)

# Fully qualified endpoint resource name built from the config cell values.
endpoint = client.endpoint_path(
    project=PROJECT_ID, location=REGION, endpoint=ENDPOINT_ID
)

# Decoded images kept in memory so they can be displayed after the loop.
copytoprint = []

for instance in instances:
    # One request per prompt.
    response = client.predict(endpoint=endpoint, instances=[instance])
    print("Response")
    print(" Deployed_model_id:", response.deployed_model_id)
    for prediction in response.predictions:
        decoded = base64_to_image(prediction)
        if decoded is None:
            # base64_to_image already printed the error. Skip the upload and
            # keep None out of the grid.
            continue
        save_pil_image_to_gcs(decoded)
        copytoprint.append(decoded)

if not copytoprint:
    raise RuntimeError(
        "No images were decoded from the endpoint responses; see the errors above."
    )

# At most two images per row, and only as many rows as needed.
grid_cols = min(2, len(copytoprint))
grid_rows = (len(copytoprint) + grid_cols - 1) // grid_cols

# Display the images after storing them in the GCS bucket.
image_grid(copytoprint, rows=grid_rows, cols=grid_cols)
