Copyright 2024 Google, LLC. This software is provided as-is,
without warranty or representation for any use or purpose. Your
use of it is subject to your agreement with Google.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# How to use Batch Predictions with Gemini

This notebook shows how to use Vertex AI's Gemini batch prediction API. For more information, see https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini

## Prepare the Python development environment

First, let's define the project-specific variables that tailor this notebook to your GCP environment. Replace YOUR_PROJECT_ID with your own GCP project ID.

In [None]:
project_id = "YOUR_PROJECT_ID"
location = "global"
region = "us-central1"
bq_source_dataset_id = "parts_data"
bq_source_table = "inventory"
bq_batch_dataset_id = "gemini_batch_test"
bq_batch_table = "batch_input_table"
model_ver = "gemini-1.5-pro-001"
qa_model_ver = "gemini-1.5-flash-001"
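
The later cells repeatedly build fully qualified table IDs and `bq://` URIs from these variables with f-strings. As a minimal sketch, the two hypothetical helpers below (`qualified_table_id` and `bq_uri` are illustrative names, not part of any Google library) show the string formats involved:

```python
def qualified_table_id(project: str, dataset: str, table: str) -> str:
    """Return the `project.dataset.table` form used by the BigQuery client."""
    return f"{project}.{dataset}.{table}"


def bq_uri(project: str, dataset: str, table: str = "") -> str:
    """Return a `bq://` URI; omit `table` to point at a whole dataset."""
    parts = [project, dataset]
    if table:
        parts.append(table)
    return "bq://" + ".".join(parts)


# Example values matching the defaults defined in this notebook.
example_input_uri = bq_uri("YOUR_PROJECT_ID", "gemini_batch_test", "batch_input_table")
example_output_prefix = bq_uri("YOUR_PROJECT_ID", "gemini_batch_test")
```

Keeping the formatting in one place makes it harder for the input URI and the table ID to drift apart when a dataset or table name changes.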

Install any needed Python modules from our requirements.txt file. Most Vertex Workbench environments include all the packages we'll be using, but if you are using an external Jupyter Notebook or require additional packages for your own needs, you can add them to the included requirements.txt file and run the following commands.

In [None]:
#pip install -r requirements.txt

Update the google-cloud-aiplatform package to the latest version if needed.

In [None]:
#pip install --upgrade google-cloud-aiplatform

Now we will import all required modules. For our purpose, we will be utilizing the following:

- vertexai - The primary library for working with the Vertex AI Platform on GCP
- BatchPredictionJob - Used to submit and manage batch prediction jobs with Gemini
- GenerativeModel - Used to call Gemini directly when rating the generated results
- bigquery - Work with data stored in BigQuery
- IPython.display - Render HTML and Markdown responses from the Gemini APIs

In [None]:
import time
import json
from IPython.display import HTML, Markdown, display

from google.cloud import bigquery

import vertexai
from vertexai.preview.batch_prediction import BatchPredictionJob
from vertexai.generative_models import GenerativeModel

## Create an example source table in BQ

First we need to create a source table in BigQuery. For this example, we will create a new dataset and table to store some inventory data related to automotive parts. The source inventory data will then be imported from the inventory.csv file included in this repo.

Construct a BigQuery client object and set source_dataset_id to the ID of the dataset to create.

In [None]:
client = bigquery.Client(project_id)
source_dataset_id = bq_source_dataset_id

Construct a Dataset object to send to the API.

In [None]:
# Construct a full Dataset object to create.
dataset = bigquery.Dataset(f"{project_id}.{source_dataset_id}")

# Specify the geographic location where the dataset should reside.
dataset.location = region

# Send the dataset to your Google Cloud Project
dataset = client.create_dataset(dataset, exists_ok=True)  # API request
print(f"Created dataset {dataset.project}.{dataset.dataset_id}")

Create a source table for the example inventory

In [None]:
# Set table_id to the ID of the table to create.
table_id = f"{project_id}.{source_dataset_id}.{bq_source_table}"

# Set the schema of the table
schema = [
    bigquery.SchemaField("vehicle_manufacturer", "STRING", mode="NULLABLE"),
    bigquery.SchemaField("vehicle_model", "STRING", mode="NULLABLE"),
    bigquery.SchemaField("part_name", "STRING", mode="NULLABLE"),
    bigquery.SchemaField("part_number", "STRING", mode="NULLABLE"),
    bigquery.SchemaField("part_description", "STRING", mode="NULLABLE"),
]

# Create the table
table = bigquery.Table(table_id, schema=schema)
table = client.create_table(table)  # Make an API request.
print(f"Created table {table.project}.{table.dataset_id}.{table.table_id}")

Import the inventory.csv file to populate the new table with some example inventory data

In [None]:
# Set job config
job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.CSV,
    skip_leading_rows=1,
)

# Open the local file
with open("inventory.csv", "rb") as source_file:
    # Create and run the load job
    job = client.load_table_from_file(
        source_file,
        table_id,
        job_config=job_config,
    )

    job.result()  # Wait for the load job to complete

# Print the number of rows loaded
table = client.get_table(table_id)
print(f"Loaded {table.num_rows} rows to {table_id}.")
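
Because the load job uses `skip_leading_rows=1`, BigQuery ignores the header row and maps columns by position. A quick local check before loading can catch a reordered or truncated file. The sketch below assumes the header names in inventory.csv match the schema field names, which this notebook does not state; `check_csv_header` is a hypothetical helper, and the sample data is made up for illustration.

```python
import csv
import io

# Column order taken from the table schema defined above.
EXPECTED_COLUMNS = [
    "vehicle_manufacturer",
    "vehicle_model",
    "part_name",
    "part_number",
    "part_description",
]


def check_csv_header(text_stream, expected=EXPECTED_COLUMNS):
    """Validate the header row and return the number of data rows that follow."""
    reader = csv.reader(text_stream)
    header = next(reader, None)
    if header is None:
        raise ValueError("CSV file is empty")
    header = [name.strip() for name in header]
    if header != list(expected):
        raise ValueError(f"Unexpected header {header}, expected {list(expected)}")
    return sum(1 for _ in reader)


# Illustrative in-memory sample; a real run would pass open("inventory.csv", newline="").
sample_csv = io.StringIO(
    "vehicle_manufacturer,vehicle_model,part_name,part_number,part_description\n"
    "ExampleMotors,Model A,Brake Pad,BP-001,Front brake pad set\n"
)
sample_row_count = check_csv_header(sample_csv)
```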

## Create the Batch Input Dataset and Table

We will now create a new dataset and table for the batch prediction job. Batch prediction is an efficient way to send a large number of multimodal prompts that are not latency sensitive. Unlike online prediction, where you are limited to one input prompt per request, a single batch request can carry many prompts. The responses are then written asynchronously to your BigQuery output location. More information can be found online at https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini

Define the BQ Dataset for the batch prediction job.

In [None]:
batch_dataset_id = bq_batch_dataset_id

Construct a Dataset object to create, specify the region to store the data in and send the request to the API.

In [None]:
dataset = bigquery.Dataset(f"{project_id}.{batch_dataset_id}")

# Specify the geographic location where the dataset should reside.
dataset.location = region

# Send the dataset to your Google Cloud Project
dataset = client.create_dataset(dataset, exists_ok=True)  # API request
print(f"Created dataset {dataset.project}.{dataset.dataset_id}")

Create the batch input table

In [None]:
# Set table_id to the ID of the table to create.
table_id = f"{project_id}.{batch_dataset_id}.{bq_batch_table}"

# Set the schema of the table
schema = [
    bigquery.SchemaField("request", "STRING", mode="NULLABLE"),
]

# Create the table
table = bigquery.Table(table_id, schema=schema)
table = client.create_table(table)  # Make an API request.
print(f"Created table {table.project}.{table.dataset_id}.{table.table_id}")

Parse the source inventory table and create an entry in the batch prediction table for each item

In [None]:
# Table IDs
inventory_table_id = f"{project_id}.{bq_source_dataset_id}.{bq_source_table}"
batch_input_table_id = f"{project_id}.{bq_batch_dataset_id}.{bq_batch_table}"

Define the SQL query that fetches the fields needed to build each request

In [None]:
# SQL query to read the fields used in each prompt from the source table
query = f"""
SELECT 
    vehicle_manufacturer, 
    vehicle_model, 
    part_name 
FROM 
    `{inventory_table_id}`
"""

Run the query, build a JSON request for each row, and insert it into the batch input table

In [None]:
# Execute the query
query_job = client.query(query)
results = query_job.result()

for row in results:
    # Create the JSON structure for each row
    data = {
        "contents": [
            {
                "role": "user",
                "parts": {
                    "text": f"Write an SEO optimized text for a Product Listing Page of 300 to 400 words. Keep the following points in mind: Main subject & main keyword: {row.vehicle_manufacturer} {row.vehicle_model} auto parts. Sub keywords: buy {row.vehicle_model} {row.part_name}, order {row.vehicle_model} {row.part_name} online. Written for the following website: https://my_autoparts.com/. Make sure the headers are not too similar and write it in HTML."
                }
            }
        ],
        "system_instruction": {
            "parts": [{"text": "You are an SEO engineer, specializing in generating content for search engine optimization."}]
        },
        "generation_config": {"top_k": 5}
    }

    # Convert the Python dictionary to a JSON string
    json_data = json.dumps(data)

    # Insert the JSON data into the destination table
    errors = client.insert_rows_json(batch_input_table_id, [{"request": json_data}])
    if errors == []:
        continue
        #print("New row inserted.")
    else:
        print(f"Encountered errors while inserting row: {errors}")
        
print('Rows inserted')
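
The loop above issues one `insert_rows_json` call per inventory row. One possible variation, sketched below, separates request building from insertion and groups rows into chunks so that each call carries several rows. `build_request_row` and `chunked` are hypothetical helpers, the prompt is abbreviated for readability, and the chunk size of 500 is an arbitrary illustrative value.

```python
import json


def build_request_row(manufacturer, model, part_name):
    """Return one row for the batch input table: {"request": "<JSON string>"}."""
    prompt = (
        "Write an SEO optimized text for a Product Listing Page of 300 to 400 words. "
        f"Main subject & main keyword: {manufacturer} {model} auto parts. "
        f"Sub keywords: buy {model} {part_name}, order {model} {part_name} online."
    )
    request = {
        # Same structure as the request built in the cell above.
        "contents": [{"role": "user", "parts": {"text": prompt}}],
        "system_instruction": {
            "parts": [{"text": "You are an SEO engineer, specializing in generating content for search engine optimization."}]
        },
        "generation_config": {"top_k": 5},
    }
    return {"request": json.dumps(request)}


def chunked(items, size):
    """Yield consecutive slices of `items` with at most `size` elements each."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


# Illustrative rows; in the notebook these would come from the query results.
example_rows = [
    build_request_row("ExampleMotors", "Model A", "brake pads"),
    build_request_row("ExampleMotors", "Model B", "air filter"),
]
example_batches = list(chunked(example_rows, 500))
```

Each batch could then be passed to `client.insert_rows_json(batch_input_table_id, batch)`, checking the returned error list exactly as the loop above does.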

## Define and submit a Batch Prediction job for Gemini

Initialize vertexai

In [None]:
vertexai.init(project=project_id, location=region)

Next we'll create the Gemini batch prediction job

In [None]:
job = BatchPredictionJob.submit(
    model_ver,   # source_model 
    #"gs://rkiles-test/gemini-batch/batch_data2.json", # input URI if using GCS
    input_dataset = f'bq://{project_id}.{bq_batch_dataset_id}.{bq_batch_table}',  # input dataset if using BQ
    output_uri_prefix = f'bq://{project_id}.{bq_batch_dataset_id}'  # This will generate a new output table in BQ
)

View and monitor the job status. You can also view the status in the GCP Cloud Console under Vertex AI -> Batch Predictions

In [None]:
# Check job status
print(f"Job resource name: {job.resource_name}")
print(f"Model resource name with the job: {job.model_name}")
print(f"Job state: {job.state.name}")

# Refresh the job until complete
while not job.has_ended:
  time.sleep(5)
  job.refresh()

# Check if the job succeeds
if job.has_succeeded:
  print("Job succeeded!")
else:
  print(f"Job failed: {job.error}")
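
The polling loop above runs until the job ends, however long that takes. A minimal sketch of a bounded alternative follows, adding a timeout and a gradually increasing delay between polls. `wait_for_job` is a hypothetical helper that relies only on the `has_ended`, `has_succeeded` and `refresh()` members already used in this notebook; `StubJob` is a stand-in object for a dry run, not a Vertex AI class, and the default durations are illustrative.

```python
import time


def wait_for_job(job, poll_seconds=5.0, max_poll_seconds=60.0,
                 timeout_seconds=3600.0, sleep=time.sleep, clock=time.monotonic):
    """Poll `job` until it ends and return True if it succeeded.

    Raises TimeoutError if the job is still running after `timeout_seconds`.
    """
    deadline = clock() + timeout_seconds
    delay = poll_seconds
    while not job.has_ended:
        if clock() >= deadline:
            raise TimeoutError(f"Job still running after {timeout_seconds} seconds")
        sleep(delay)
        job.refresh()
        # Double the delay up to a cap so long jobs are polled less often.
        delay = min(delay * 2, max_poll_seconds)
    return bool(job.has_succeeded)


class StubJob:
    """Hypothetical stand-in for a batch job, used only to dry-run the helper."""

    def __init__(self, polls_until_done):
        self.polls_left = polls_until_done
        self.has_succeeded = False

    @property
    def has_ended(self):
        return self.polls_left <= 0

    def refresh(self):
        self.polls_left -= 1
        if self.polls_left <= 0:
            self.has_succeeded = True


# Dry run that skips real sleeping.
finished_ok = wait_for_job(StubJob(3), sleep=lambda seconds: None)
```

With the real job object from the cell above, the call would be `wait_for_job(job)`.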

Check the location of the output

In [None]:
print(f"Job output location: {job.output_location}")

Capture the output table of the batch prediction job

In [None]:
output_table = job.output_location.split(".")[-1]
print(output_table)
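
Splitting on `.` and taking the last element works when the output location has the form `bq://project.dataset.table`. The sketch below makes that assumption explicit and fails loudly when the string has a different shape (for example, a project ID that itself contains dots). `parse_bq_uri` and `BigQueryTableRef` are hypothetical names, and the sample URI is made up.

```python
from typing import NamedTuple


class BigQueryTableRef(NamedTuple):
    project: str
    dataset: str
    table: str


def parse_bq_uri(uri: str) -> BigQueryTableRef:
    """Split `bq://project.dataset.table` (prefix optional) into its three parts."""
    prefix = "bq://"
    path = uri[len(prefix):] if uri.startswith(prefix) else uri
    parts = path.split(".")
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"Expected project.dataset.table, got: {uri!r}")
    return BigQueryTableRef(*parts)


# Illustrative value only; the real string comes from job.output_location.
example_ref = parse_bq_uri("bq://my-project.gemini_batch_test.predictions_example")
```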

List all the GenAI batch prediction jobs under the project

In [None]:
#for bpj in BatchPredictionJob.list():
  #print(f"Job ID: '{bpj.name}', Job state: {bpj.state.name}, Job model: {bpj.model_name}")

## Print the response

Let's print the response from the batch predictions

We will start by defining a function that we can use to rate and verify the quality of the generated results from the batch prediction job. We can specify a different model from the one used for the batch prediction. In this example, we used gemini-1.5-pro-001 (model_ver) for the predictions and gemini-1.5-flash-001 (qa_model_ver) for rating and QA.

In [None]:
# Function that asks the QA model to rate a piece of generated content
def get_rating(source):
    model = GenerativeModel(
        model_name=qa_model_ver,
        system_instruction=[
            "You are a professional Search Engine Optimization engineer.",
            "You specialize in creating content that is optimized for search engine results.",
        ],
    )

    response = model.generate_content(
        f'''<OBJECTIVE>Rank the following content for SEO quality using a system of 1-10 with 1 being the lowest and 10 being the highest. Provide reasoning for your ranking.</OBJECTIVE> {source}'''
    )
    return response.text
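
`get_rating` returns free-form text, so storing a numeric score requires parsing it out of the response. The sketch below assumes the model states its score as "N/10" or "N out of 10"; nothing guarantees that wording, so the hypothetical helper `extract_score` returns `None` when no such pattern is found.

```python
import re

# Matches "8/10", "8 / 10" or "8 out of 10" for scores from 1 to 10.
_SCORE_PATTERN = re.compile(r"\b(10|[1-9])\s*(?:/|out of)\s*10\b", re.IGNORECASE)


def extract_score(rating_text):
    """Return the first 1-10 score found in `rating_text`, or None."""
    match = _SCORE_PATTERN.search(rating_text)
    return int(match.group(1)) if match else None


example_score = extract_score("I would rate this content 8/10 because the headers are distinct.")
```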

Specify the output table created by the batch prediction job and define the query that reads it

In [None]:
output_table_id = f'{bq_batch_dataset_id}.{output_table}'

# SQL query to read data from the output table
query = f"""
SELECT *
FROM `{output_table_id}`
"""

Run the query against the output table

In [None]:
query_job = client.query(query)
results = query_job.result()

Loop over the output of the batch process, pausing after each result. Press Enter to continue, enter 'r' to review/rate the response (this calls the get_rating function we created earlier), or enter 'q' to quit the loop.

In [None]:
# Process the results
for row in results:
    # Load the JSON string from the 'response' column
    response_data = json.loads(row.response)

    # Extract the generated text
    generated_text = response_data[0]['content']['parts'][0]['text']

    # Display the generated text
    #display(HTML(f"<strong>Request:</strong> {row.request}<br>"))
    print("**Generated Text:**\n")
    display(Markdown(generated_text))  # Use Markdown for rendering
    print("-" * 50)  # Add a separator between iterations

    while True:
        user_input = input("Press Enter for next, 'r' to review/rate, or 'q' to quit: ")
        if user_input.lower() == 'q':
            break  # Exit the loop if the user presses 'q'
        elif user_input.lower() == 'r':
            rating = get_rating(generated_text)
            display(Markdown(rating))
            # TODO: Store the rating in your database or use it as needed
            break
        else: 
            break  # Continue to the next iteration if Enter is pressed

    if user_input.lower() == 'q':
        break
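
The loop above indexes the response as a list (`response_data[0]`). The exact shape of the `response` column may differ between API versions, so a more defensive extraction can help avoid a crash on an unexpected row. The sketch below accepts either the list shape used above or a dictionary with a `candidates` key; that key mirrors the online Gemini response format and is an assumption for batch output. `extract_generated_text` is a hypothetical helper, and the sample payloads are made up.

```python
import json


def extract_generated_text(response):
    """Return the first candidate's text, or None if the shape is unexpected.

    `response` may be a JSON string or an already parsed list/dict.
    """
    data = json.loads(response) if isinstance(response, str) else response
    # Accept either a bare list of candidates or {"candidates": [...]}.
    candidates = data.get("candidates", []) if isinstance(data, dict) else data
    try:
        return candidates[0]["content"]["parts"][0]["text"]
    except (IndexError, KeyError, TypeError):
        return None


# Illustrative payloads covering both shapes.
list_shape = json.dumps([{"content": {"parts": [{"text": "<h1>Example</h1>"}]}}])
dict_shape = {"candidates": [{"content": {"parts": [{"text": "<h1>Example</h1>"}]}}]}
text_from_list = extract_generated_text(list_shape)
text_from_dict = extract_generated_text(dict_shape)
```

Rows for which the helper returns `None` can be skipped or logged instead of stopping the review loop.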