# 01 - Generate Images 

This notebook generates images with the Vertex AI Imagen API (100 in test mode, 1,000 in production mode) and saves them to a Google Cloud Storage bucket.


References:

* [Generate images using text prompts  |  Generative AI on Vertex AI  |  Google Cloud](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images)
* [Imagen API  |  Generative AI on Vertex AI  |  Google Cloud](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api)

by: Justin Marciszewski | justinjm@google.com | AI/ML Specialist CE

## Prerequisites

See [README.md](README.md) for full details.

1. Set up a Google Cloud project
2. Create a Vertex AI Workbench instance

## Install required packages

Run the cell below to check whether the required packages are installed.

If any are missing, they are installed, and the next cell restarts the kernel automatically and shows a notification.

If all are already installed, nothing happens; proceed to the next step.

In [None]:
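# tuples of (import name, install name)
packages = [
    ('google.cloud.aiplatform', 'google-cloud-aiplatform'),
    ('PIL', 'Pillow'),
    ('tenacity', 'tenacity'),  # used by the upload retry logic below
]

import importlib.util  # importlib.util must be imported explicitly

install = False
for import_name, install_name in packages:
    try:
        found = importlib.util.find_spec(import_name) is not None
    except ModuleNotFoundError:
        # raised when a parent package (e.g. google.cloud) is missing
        found = False
    if not found:
        print(f'installing package {install_name}')
        install = True
        !pip install {install_name} -U -q --user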
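The check above relies on `importlib.util.find_spec`, which returns `None` when a top-level module cannot be found and raises `ModuleNotFoundError` when a parent package of a dotted name (for example `google.cloud`) is missing. A minimal sketch of the same check that only reports and never installs, using a hypothetical helper `missing_packages`:

```python
import importlib.util


def missing_packages(packages):
    """Return the pip install names whose import name cannot be found.

    `packages` is a list of (import name, install name) tuples, as in the cell above.
    """
    missing = []
    for import_name, install_name in packages:
        try:
            # returns None if the module is not found
            found = importlib.util.find_spec(import_name) is not None
        except ModuleNotFoundError:
            # raised when a parent package of a dotted name is missing
            found = False
        if not found:
            missing.append(install_name)
    return missing


print(missing_packages([("json", "json"), ("not_a_real_pkg.sub", "not-a-real-pkg")]))
```

Catching `ModuleNotFoundError` matters for dotted names such as `google.cloud.aiplatform`: without it, a missing `google.cloud` namespace would crash the check instead of triggering an install.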

### Restart Kernel (If Installs Occurred)

After the kernel restarts, continue running the notebook from the cell after this one.

In [None]:
if install:
    import IPython
    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

## Set constants

In [None]:
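project = !gcloud config get-value project
PROJECT_ID = project[0]

LOCATION = "us-central1"

BUCKET_NAME = f"{PROJECT_ID}-fruit-veg-image-model"

# only the last expression in a cell is displayed, so print both values explicitly
print(f"PROJECT_ID:  {PROJECT_ID}")
print(f"BUCKET_NAME: {BUCKET_NAME}")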
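Because `BUCKET_NAME` is derived from the project ID, it can be worth sanity-checking it locally before calling the API. A minimal sketch that covers only a subset of the Cloud Storage bucket naming rules (it ignores names containing dots and several other restrictions); `looks_like_valid_bucket_name` is a hypothetical helper:

```python
import re

# Subset of the Cloud Storage bucket naming rules (names without dots):
# 3-63 characters; lowercase letters, digits, dashes and underscores;
# must start and end with a letter or digit; must not start with "goog".
_BUCKET_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]$")


def looks_like_valid_bucket_name(name):
    """Hypothetical helper: rough check of a dot-free GCS bucket name."""
    return bool(_BUCKET_RE.match(name)) and not name.startswith("goog")


print(looks_like_valid_bucket_name("my-project-fruit-veg-image-model"))  # True
print(looks_like_valid_bucket_name("My_Project-fruit-veg-image-model"))  # False (uppercase)
```

A `False` here is a hint to pick a different suffix; the authoritative check is still the error returned by `create_bucket`.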

### Set `user_flag` for testing vs. production

`test` generates 100 images and costs roughly $2.

`prod` generates 1,000 images and costs roughly $20.
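Both estimates work out to roughly $0.02 per image, a figure implied by the numbers in this notebook rather than an official price (check current Vertex AI pricing before a large run). A minimal sketch of the arithmetic, with a hypothetical `estimate_cost` helper:

```python
def estimate_cost(num_prompts, images_per_prompt=1, price_per_image=0.02):
    """Rough spend in USD; price_per_image is an assumption taken from this notebook."""
    return num_prompts * images_per_prompt * price_per_image


for flag, num_prompts in {"test": 100, "prod": 1000}.items():
    print(f"{flag}: ~${estimate_cost(num_prompts):.2f}")
```

Note that cost scales with the total number of images, so raising `images_per_prompt` multiplies the spend.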

In [None]:
## TODO - SET USER FLAG (either 'test' or 'prod') #############################
user_flag = 'test'

# Basic input validation
if user_flag not in ['test', 'prod']:
    raise ValueError("Invalid input. Please enter either 'test' or 'prod'.")

print("=" * 80)
print(f"User flag set to: {user_flag}")
print("=" * 80)

## Packages

In [4]:
# standard library
import random
import threading
import time
import timeit
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from io import BytesIO

# third party
import vertexai
from google.api_core.exceptions import GoogleAPIError
from google.cloud import storage
from google.cloud.exceptions import NotFound
from PIL import Image
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from vertexai.preview.vision_models import ImageGenerationModel

## Enable APIs

In [None]:
%%bash
# 1. Get the list of CURRENTLY ENABLED services
enabled_services=$(gcloud services list --enabled | awk '{print $1}')

# 2. Services we WANT to ensure are enabled
services_to_enable=("storage.googleapis.com" "aiplatform.googleapis.com")

# 3. Check each desired service against the enabled list
for service in "${services_to_enable[@]}"; do
    if ! echo "$enabled_services" | grep -q "$service"; then
        echo "Enabling $service..."
        gcloud services enable "$service"
    else
        echo "$service is already enabled."
    fi
done
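The script compares the desired services against the enabled list. Note that `grep -q` matches substrings, and the `awk` output also keeps the `NAME` header line. A minimal Python sketch of the same logic with exact name matching (the service list below is illustrative, not real `gcloud` output):

```python
def services_to_enable(enabled, desired):
    """Return desired services that are not already enabled (exact name match)."""
    enabled_set = set(enabled)
    return [s for s in desired if s not in enabled_set]


enabled = ["NAME", "storage.googleapis.com", "compute.googleapis.com"]  # illustrative
desired = ["storage.googleapis.com", "aiplatform.googleapis.com"]
print(services_to_enable(enabled, desired))  # ['aiplatform.googleapis.com']
```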

## Clients

In [6]:
#  Vertex AI Client 
vertexai.init(project=PROJECT_ID, location=LOCATION)
generation_model = ImageGenerationModel.from_pretrained("imagegeneration@006")

#  Google Cloud Storage client
gcs = storage.Client(project=PROJECT_ID)

## Create Storage Bucket

In [7]:
def check_and_create_bucket(bucket_name, location):
    """Create the bucket in `location` unless it already exists."""
    try:
        gcs.get_bucket(bucket_name)
        print(f"Bucket {bucket_name} already exists.")
    except NotFound:
        gcs.create_bucket(bucket_or_name=bucket_name, location=location)
        print(f"Bucket {bucket_name} created.")
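`check_and_create_bucket` is idempotent: it only creates the bucket when the lookup raises `NotFound`, so calling it again leaves an existing bucket and its contents untouched. A minimal sketch of the same get-or-create pattern, run against a hypothetical in-memory `FakeStorageClient` (not part of `google-cloud-storage`) so it works without credentials:

```python
class NotFound(Exception):
    """Stand-in for google.cloud.exceptions.NotFound."""


class FakeStorageClient:
    """Hypothetical in-memory stand-in for storage.Client (illustration only)."""

    def __init__(self):
        self.buckets = {}

    def get_bucket(self, name):
        if name not in self.buckets:
            raise NotFound(name)
        return self.buckets[name]

    def create_bucket(self, bucket_or_name, location=None):
        self.buckets[bucket_or_name] = {"location": location}
        return self.buckets[bucket_or_name]


def get_or_create_bucket(client, name, location):
    """Return (bucket, created) using the same check-then-create logic as above."""
    try:
        return client.get_bucket(name), False
    except NotFound:
        return client.create_bucket(bucket_or_name=name, location=location), True


client = FakeStorageClient()
print(get_or_create_bucket(client, "demo-bucket", "us-central1")[1])  # True
print(get_or_create_bucket(client, "demo-bucket", "us-central1")[1])  # False
```

Check-then-create is fine for a single notebook; if two processes could create the same bucket concurrently, the second `create_bucket` call would fail and would need its own error handling.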

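Create the storage bucket if needed:

* First run: run the cell below to create the bucket.
* Later runs: you can skip this cell (or comment it out). `check_and_create_bucket` only creates the bucket when it does not already exist, so existing contents are never overwritten.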

In [None]:
check_and_create_bucket(BUCKET_NAME, LOCATION)

Finally, get a reference to the bucket for use in the image generation code. `gcs.bucket()` only builds a local `Bucket` object; it does not make an API call.

In [9]:
bucket = gcs.bucket(BUCKET_NAME)

## Functions to generate images 

In [None]:
background = ["white", "cookie sheet", "cutting board"]
objects = ["bellpepper_ripe", "apple_ripe", "banana_ripe", "bellpepper_rotten", "apple_rotten", "banana_rotten"]
images_per_prompt = 1
# number of prompts (one API call each) per run mode
num_prompts = {'test': 100, 'prod': 1000}[user_flag]

total_images = num_prompts * images_per_prompt
print("=" * 80)
print(f"Ready to generate {total_images} images")
print(f"Estimated cost: ${total_images * 0.02:.2f}")  # rough estimate of ~$0.02 per image
print("=" * 80)
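Each entry in `objects` encodes two labels joined by an underscore (food and quality), which `main()` splits to fill the prompt template. A minimal sketch of that step as a standalone function; `build_prompt` is a hypothetical name, and the template string is copied from `main()`:

```python
def build_prompt(chosen_object, chosen_background):
    """Split a label like 'apple_rotten' and fill the prompt template used in main()."""
    food, quality = chosen_object.split("_")
    return (f"Draw an image of {food} that is {quality} and on a "
            f"{chosen_background} and photorealistic photography")


print(build_prompt("apple_rotten", "cutting board"))
```

The unpacking assumes every label contains exactly one underscore, which holds for all entries in `objects`; a label such as `green_apple_ripe` would raise `ValueError`.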

In [35]:
image_counter = 1
counter_lock = threading.Lock()  # guards image_counter across worker threads


@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type(GoogleAPIError))
def upload_to_gcs(img_bytes, blob_name):
    """Upload JPEG bytes to the bucket, retrying transient API errors."""
    blob = bucket.blob(blob_name)
    # upload from bytes so every retry sends the full payload
    # (a file object would be left at its end position after a failed attempt)
    blob.upload_from_string(img_bytes, content_type="image/jpeg")


def generate_and_upload_image(chosen_background, chosen_object, prompt, num_prompts):
    """Generate image(s) for one prompt and upload each to GCS as a JPEG."""
    global image_counter

    response = generation_model.generate_images(prompt=prompt, number_of_images=images_per_prompt)

    if not response or not response.images:
        # no image returned, e.g. because the prompt or output was filtered
        print(f"No image returned for prompt: '{prompt}'")
        return

    for image_data in response.images:
        # JPEG has no alpha channel, so convert to RGB before re-encoding
        pil_image = Image.open(BytesIO(image_data._image_bytes)).convert("RGB")
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        blob_name = f"{chosen_object}_{chosen_background}_{random.randint(10000, 99999)}_{timestamp}.jpg".replace(' ', '_')

        img_byte_arr = BytesIO()
        pil_image.save(img_byte_arr, format="JPEG")
        try:
            upload_to_gcs(img_byte_arr.getvalue(), blob_name)
            with counter_lock:
                print(f"Image {image_counter} of {num_prompts}: '{blob_name}' uploaded to GCS.")
                image_counter += 1
        except Exception as e:
            print(f"Error uploading '{blob_name}': {e}")
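The upload is wrapped in `tenacity`'s `retry` with `stop_after_attempt(3)` and `wait_exponential(multiplier=1, min=4, max=10)`. As we understand tenacity's `wait_exponential`, the wait after attempt *n* is `multiplier * 2 ** (n - 1)` clamped to `[min, max]`, so with these settings both retries wait about 4 seconds; after the third failure tenacity raises `RetryError`, which the `except Exception` in `generate_and_upload_image` catches. A stdlib-only sketch of that schedule (`backoff_schedule` is a hypothetical helper; verify against the tenacity version installed):

```python
def backoff_schedule(max_attempts, multiplier=1, min_wait=4, max_wait=10, exp_base=2):
    """Waits (seconds) between attempts for exponential backoff clamped to [min_wait, max_wait].

    Mirrors our understanding of tenacity's wait_exponential. There is no wait
    after the final attempt, so max_attempts attempts produce max_attempts - 1 waits.
    """
    waits = []
    for attempt_number in range(1, max_attempts):
        raw = multiplier * exp_base ** (attempt_number - 1)
        waits.append(max(min_wait, min(raw, max_wait)))
    return waits


print(backoff_schedule(3))  # [4, 4]
print(backoff_schedule(6))  # [4, 4, 4, 8, 10]
```

With `min=4`, the first few waits are all raised to the floor; the exponential growth only becomes visible from the fourth retry onward.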

In [36]:
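from concurrent.futures import as_completed

def main():
    """Submit one generation job per prompt to a pool of 5 worker threads."""
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = []
        for _ in range(num_prompts):
            chosen_background = random.choice(background)
            chosen_object = random.choice(objects)

            # e.g. "apple_rotten" -> food "apple", quality "rotten"
            chosen_object_food, chosen_object_quality = chosen_object.split("_")
            prompt = f"Draw an image of {chosen_object_food} that is {chosen_object_quality} and on a {chosen_background} and photorealistic photography"

            futures.append(executor.submit(generate_and_upload_image, chosen_background, chosen_object, prompt, num_prompts))

        # Exceptions raised in worker threads (e.g. quota errors from the API) are
        # stored on the future; call result() so they are reported, not silently dropped.
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f"Image generation failed: {e}")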
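`executor.submit` returns a `Future`. If the submitted function raises (for example a quota error from `generate_images`), the exception is stored on the future and only re-raised when `result()` is called, so discarding the futures makes failures pass silently. A minimal stdlib sketch of collecting and reporting them; `flaky_task` is a hypothetical stand-in for `generate_and_upload_image`:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def flaky_task(i):
    """Hypothetical stand-in for generate_and_upload_image: fails when i is a multiple of 3."""
    if i % 3 == 0:
        raise RuntimeError(f"task {i} failed")
    return i


def run_all(n, max_workers=5):
    """Run n tasks in a thread pool; return (successful results, error messages)."""
    results, errors = [], []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(flaky_task, i) for i in range(n)]
        for future in as_completed(futures):
            try:
                results.append(future.result())
            except Exception as e:  # re-raised here from the worker thread
                errors.append(str(e))
    return sorted(results), sorted(errors)


print(run_all(6))
```

`as_completed` yields futures in completion order, which is why the sketch sorts its outputs before returning them.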

## Run Image Generation

The next cell runs the image generation, which is a long-running operation.

Keep this notebook open until the job completes.

If the job stops because of an unexpected error, re-run the cell below. Note that re-running starts a new batch of `num_prompts` images (and incurs the corresponding cost); it does not resume the previous batch.

If the cell does not re-run successfully, click "Kernel > Restart Kernel & clear all outputs" in the top navigation bar of the Workbench instance, then re-run the code in this notebook.

In [None]:
start_time = timeit.default_timer()
print("Started image generation job at: ", datetime.now().strftime("%H:%M:%S"))
print(f"Generating {num_prompts} total images")

main()

end_time = timeit.default_timer()
print("Completed image generation job at: ", datetime.now().strftime("%H:%M:%S"))
execution_time = end_time - start_time
print(f"Executed the function in {round(execution_time, 2)} seconds")