This notebook demonstrates batch prediction on GCP for chip research classification.

In [16]:
import os
from pathlib import Path

import vertexai
from bigframes import pandas as bpd
from dotenv import load_dotenv
from google.cloud import bigquery
from jinja2 import Template
from vertexai.batch_prediction import BatchPredictionJob

load_dotenv()

# Required: GCP project ID, read from the environment or a local .env file
PROJECT = os.getenv('PROJECT', 'gcp-cset-projects')

# Model name must be one of 
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini#models_that_support_batch_predictions
MODEL = "gemini-1.5-flash-002"

# Must be us-central1 at time of writing
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini#request_a_batch_prediction_job_2
LOCATION = 'us-central1'
CENTRAL_DATASET = os.getenv("CENTRAL_DATASET", "tech_topics_demo_central")

# But our input data is in multi-region US
US_DATASET = os.getenv("US_DATASET", "tech_topics_demo")

bq_client = bigquery.Client(project=PROJECT)
vertexai.init(project=PROJECT, location=LOCATION)

central_dataset_reference = bigquery.DatasetReference(PROJECT, CENTRAL_DATASET)
central_dataset = bigquery.Dataset(central_dataset_reference)
central_dataset.location = LOCATION

us_dataset_reference = bigquery.DatasetReference(PROJECT, US_DATASET)
us_dataset = bigquery.Dataset(us_dataset_reference)

In [9]:
# Create the datasets if they don't exist
central_dataset = bq_client.create_dataset(central_dataset, exists_ok=True)
us_dataset = bq_client.create_dataset(us_dataset, exists_ok=True)

We'll run just 1K inputs through the pipeline. Below we write them to a table called `corpus` in a dataset in the `US` multi-region.

In [None]:
corpus_sql = Path('sql/chip_corpus.sql').read_text()
corpus_sql += ' LIMIT 1000'

job_config = bigquery.QueryJobConfig(
    destination=f"{PROJECT}.{US_DATASET}.corpus",
    write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE
)

query_job = bq_client.query(corpus_sql, job_config=job_config)
query_job.result()
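
Appending ` LIMIT 1000` to the file contents assumes that `sql/chip_corpus.sql` holds a single query with no trailing semicolon and no `LIMIT` of its own. A minimal sketch of a more defensive alternative wraps the query in a subquery instead. `with_limit` is a hypothetical helper, not part of the original pipeline, and the example query is made up.

```python
def with_limit(sql: str, limit: int) -> str:
    """Wrap a single query in a subquery and cap the number of rows returned.

    Hypothetical helper: trailing whitespace and semicolons are stripped so the
    statement can be embedded, and the newlines keep a trailing `--` comment
    from swallowing the closing parenthesis.
    """
    if limit < 0:
        raise ValueError("limit must be non-negative")
    inner = sql.rstrip().rstrip(";").rstrip()
    return f"SELECT * FROM (\n{inner}\n) LIMIT {limit}"


# Made-up query standing in for the contents of sql/chip_corpus.sql
limited_sql = with_limit("SELECT id, title FROM papers;\n", 1000)
print(limited_sql)
```

As in the original cell, `LIMIT` without `ORDER BY` returns an arbitrary subset of rows, which is fine for a demo run.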

The next few cells copy our demo dataset containing the `corpus` inputs from the multi-region `US` location to `us-central1`, using the BigQuery Data Transfer Service. This is necessary because (at time of writing) Batch Prediction on Vertex AI is only available in `us-central1`, while our input data is in `US`.

First, we create a data transfer configuration.

In [25]:
from google.cloud import bigquery_datatransfer


def create_transfer_config(source_dataset, destination_dataset, destination_location):
    transfer_client = bigquery_datatransfer.DataTransferServiceClient()
    parent = f"projects/{PROJECT}/locations/{destination_location}"

    # Transfer configuration for a dataset copy
    transfer_config = bigquery_datatransfer.TransferConfig(
        destination_dataset_id=destination_dataset,
        display_name=f"Transfer {source_dataset} to {destination_dataset}",
        data_source_id="cross_region_copy",
        params={
            "source_dataset_id": source_dataset,
            "source_project_id": PROJECT,
        },
        schedule_options={
            "disable_auto_scheduling": True,
        }
    )
    transfer_config = transfer_client.create_transfer_config(
        parent=parent,
        transfer_config=transfer_config,
    )
    return transfer_config


# Create the transfer job config
us_to_central_config = create_transfer_config(US_DATASET, CENTRAL_DATASET, LOCATION)

Next, we trigger the transfer job.

In [None]:
import datetime

now = datetime.datetime.now(datetime.timezone.utc)

transfer_client = bigquery_datatransfer.DataTransferServiceClient()
transfer_run = transfer_client.start_manual_transfer_runs({
    "parent": us_to_central_config.name,
    "requested_run_time": now,
})
transfer_run = transfer_run.runs[0]

Finally, we wait for it to complete.

In [41]:
from google.cloud.bigquery_datatransfer import TransferState
import time

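# Poll until the run reaches a terminal state; CANCELLED is included so a cancelled run can't loop forever
terminal_states = (TransferState.SUCCEEDED, TransferState.FAILED, TransferState.CANCELLED)
while transfer_run.state not in terminal_states: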
    time.sleep(5)
    transfer_run = transfer_client.get_transfer_run({
        "name": transfer_run.name,
    })
transfer_run.state

<TransferState.SUCCEEDED: 4>
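
The polling loop above has no upper bound on how long it waits. A minimal sketch of a bounded variant follows. `poll_until` is a hypothetical helper, and the string states stand in for `TransferState` values so the example runs without GCP access.

```python
import time


def poll_until(fetch_state, terminal_states, interval=5.0, timeout=600.0, sleep=time.sleep):
    """Call fetch_state() until it returns a terminal state or the timeout elapses."""
    deadline = time.monotonic() + timeout
    while True:
        state = fetch_state()
        if state in terminal_states:
            return state
        if time.monotonic() >= deadline:
            raise TimeoutError(f"still in state {state!r} after {timeout} seconds")
        sleep(interval)


# Stand-in for repeated get_transfer_run calls, so the sketch runs offline
fake_states = iter(["PENDING", "RUNNING", "SUCCEEDED"])
final_state = poll_until(
    lambda: next(fake_states),
    {"SUCCEEDED", "FAILED", "CANCELLED"},
    interval=0,
)
print(final_state)
```

With the notebook's objects, `fetch_state` could be `lambda: transfer_client.get_transfer_run({"name": transfer_run.name}).state`, with the terminal states taken from `TransferState`.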

If our BQ inputs were already in `us-central1`, we could skip the above steps.

We're now ready to prepare our inputs for the first batch prediction job, which generates a one-sentence summary for each title-abstract pair in the inputs.

Below we render the SQL template that will create the batch input table for the summarization task.

In [None]:
def render_template(template_path, **kwargs):
    return Template(Path(template_path).read_text()).render(**kwargs)


summary_inputs_sql = render_template(
    'sql/summary_inputs.sql',
    dataset=CENTRAL_DATASET,
    corpus='corpus'
)

We load the system instructions for summarization.

In [104]:
summary_prompt = Path('prompts/chip-summarization.txt').read_text()
print(summary_prompt)

You are an expert in physics, chemistry, engineering, and materials science. Your task is to provide a concise summary of a research paper from its title and abstract text. (Or if the text isn't from a research paper, simply answer 'None'.) Your summary should be one sentence in length. Briefly mention the motivation for the work, and then focus on the research task(s) and research method(s) employed by the authors. Do not describe the purported benefits, advantages, importance, or impact of the research.


Next we define some UDFs that the pipeline uses.

In [106]:
udfs_sql = render_template('sql/udfs.sql', dataset=CENTRAL_DATASET)
query_job = bq_client.query(udfs_sql, location=LOCATION)
query_job.result()

<google.cloud.bigquery.table._EmptyRowIterator at 0x115c58590>

And we're ready to create the batch inputs.

In [109]:
summary_inputs_table = 'summary_inputs'
summary_outputs_table = 'summary_outputs'


def create_input_table(
        query,
        destination_table,
        prompt,
        temperature=0.5,
        max_output_tokens=512
):
    job_config = bigquery.QueryJobConfig(
        destination=f"{PROJECT}.{CENTRAL_DATASET}.{destination_table}",
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
        query_parameters=[
            bigquery.ScalarQueryParameter("systemInstructions", "STRING", prompt),
            bigquery.ScalarQueryParameter("temperature", "FLOAT64", temperature),
            bigquery.ScalarQueryParameter("maxOutputTokens", "INT64", max_output_tokens),
        ]
    )
    job = bq_client.query(query, job_config=job_config, location=LOCATION)
    print(f'Wrote {job.result().total_rows:,} rows to {destination_table}')


create_input_table(
    summary_inputs_sql,
    summary_inputs_table,
    summary_prompt
)

Wrote 1,000 rows to summary_inputs
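
`sql/summary_inputs.sql` isn't shown here. Gemini batch prediction from BigQuery reads each request from a `request` column containing a JSON request body, so the three query parameters presumably end up in a payload like the one sketched below. `build_request` and the way the title and abstract are combined are assumptions for illustration; the field names follow the Gemini `generateContent` request format.

```python
import json


def build_request(text, system_instructions, temperature=0.5, max_output_tokens=512):
    """Build one batch request body as a JSON string (hypothetical helper)."""
    return json.dumps({
        "contents": [{"role": "user", "parts": [{"text": text}]}],
        "systemInstruction": {"parts": [{"text": system_instructions}]},
        "generationConfig": {
            "temperature": temperature,
            "maxOutputTokens": max_output_tokens,
        },
    })


# Placeholder inputs; the real text comes from the corpus table
request = build_request(
    "Title: An example title\nAbstract: An example abstract.",
    "Summarize the paper in one sentence.",
)
print(request)
```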


Below we start the batch summarization job.

In [110]:
def create_batch_job(input_table, output_table, model):
    input_uri = f"bq://{PROJECT}.{CENTRAL_DATASET}.{input_table}"
    output_uri = f"bq://{PROJECT}.{CENTRAL_DATASET}.{output_table}"
    return BatchPredictionJob.submit(
        source_model=model,
        input_dataset=input_uri,
        output_uri_prefix=output_uri,
    )


summary_job = create_batch_job(summary_inputs_table, summary_outputs_table, MODEL)

BatchPredictionJob created. Resource name: projects/855475113448/locations/us-central1/batchPredictionJobs/9023521745373495296
To use this BatchPredictionJob in another session:
job = batch_prediction.BatchPredictionJob('projects/855475113448/locations/us-central1/batchPredictionJobs/9023521745373495296')
View Batch Prediction Job:
https://console.cloud.google.com/ai/platform/locations/us-central1/batch-predictions/9023521745373495296?project=855475113448


In [111]:
def await_job(job):
    """Poll a batch prediction job every 5 seconds until it ends, then report the outcome."""
    while not job.has_ended:
        time.sleep(5)
        job.refresh()
    elapsed = datetime.datetime.now(tz=datetime.timezone.utc) - job.create_time
    if job.has_succeeded:
        print(f"Job succeeded after {elapsed} h:m:s")
    else:
        print(f"Job failed: {job.error}")


await_job(summary_job)

Job succeeded after 0:02:58.059158 h:m:s


The job finishes in about three minutes for just 1K inputs. Now, using the one-sentence summaries, we prepare a classification task. Below is SQL for creating the batch input table for classification.

In [None]:
classify_inputs_sql = render_template(
    'sql/classify_inputs.sql',
    dataset=CENTRAL_DATASET,
    summaries=summary_outputs_table,
)

Here are the system instructions for classification.

In [117]:
classify_prompt = Path('prompts/chip-classification.txt').read_text()
print(classify_prompt)

You are an expert in physics, chemistry, engineering, and materials science. Your task is to classify a research publication given a summary of its contents. Think carefully, and then decide whether the work focuses on the design and manufacturing of integrated circuits, or has applications to chip technology. If so answer YES; otherwise answer NO. Limit your answer to YES or NO.


Using the above, we create the batch inputs for classification.

In [118]:
classify_inputs_table = 'classify_inputs'
classify_outputs_table = 'classify_outputs'

create_input_table(
    classify_inputs_sql,
    classify_inputs_table,
    classify_prompt,
    temperature=0.0,
    max_output_tokens=5
)

Wrote 793 rows to classify_inputs


And we start the batch classification job.

In [119]:
classify_job = create_batch_job(classify_inputs_table, classify_outputs_table, MODEL)
await_job(classify_job)

BatchPredictionJob created. Resource name: projects/855475113448/locations/us-central1/batchPredictionJobs/4534558816791953408
To use this BatchPredictionJob in another session:
job = batch_prediction.BatchPredictionJob('projects/855475113448/locations/us-central1/batchPredictionJobs/4534558816791953408')
View Batch Prediction Job:
https://console.cloud.google.com/ai/platform/locations/us-central1/batch-predictions/4534558816791953408?project=855475113448
Job succeeded after 0:02:54.205242 h:m:s


The SQL below parses the output from the classification task.

In [None]:
labels_sql = render_template(
    'sql/labels.sql',
    dataset=CENTRAL_DATASET,
    labels=classify_outputs_table,
    corpus='corpus',
)

We write the resulting labels to a `labels` table.

In [123]:
labels_table = 'labels'

job_config = bigquery.QueryJobConfig(
    destination=f"{PROJECT}.{CENTRAL_DATASET}.{labels_table}",
    write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
)
query_job = bq_client.query(labels_sql, job_config=job_config, location=LOCATION)
print(f'Wrote {query_job.result().total_rows:,} rows to {labels_table}')

Wrote 1,000 rows to labels
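
`sql/labels.sql` isn't shown, but the kind of parsing it performs can be sketched in Python. Each batch output row stores the model reply as a JSON response, with the generated text under `candidates[0].content.parts[0].text`. The helper `parse_label` and its handling of missing or unexpected answers are assumptions, not the notebook's actual logic.

```python
import json


def parse_label(response_json):
    """Map a YES/NO model reply to True/False; return None if there is no usable answer."""
    try:
        response = json.loads(response_json)
        text = response["candidates"][0]["content"]["parts"][0]["text"]
    except (TypeError, ValueError, KeyError, IndexError):
        return None
    if not isinstance(text, str):
        return None
    # Exact match after trimming, so a reply such as "None" is not read as "NO"
    answer = text.strip().strip(".").upper()
    if answer == "YES":
        return True
    if answer == "NO":
        return False
    return None


example = json.dumps({"candidates": [{"content": {"parts": [{"text": "YES\n"}]}}]})
print(parse_label(example))
```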


The outputs from each step include `usageMetadata` that we can aggregate to estimate costs. For 1K inputs, the whole pipeline cost about 0.6 cents (roughly $0.006).

In [129]:
usage_sql = render_template(
    'sql/usage.sql',
    dataset=CENTRAL_DATASET,
    summaries=summary_outputs_table,
    labels=classify_outputs_table,
)

bpd.close_session()
bpd.options.bigquery.location = LOCATION

bpd.read_gbq(usage_sql)

   prompt  output  prompt_cost  output_cost  total_cost
0  428550   48503     0.004018     0.001819    0.005837
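
The arithmetic behind this table can be reproduced as a sanity check. The two rates below are assumptions: they are back-calculated from the `prompt_cost` and `output_cost` columns, not read from `sql/usage.sql` (which isn't shown) or from a price list, so check current Vertex AI pricing and the units used in the SQL before reusing them.

```python
# Totals taken from the usage table above
PROMPT_UNITS = 428_550
OUTPUT_UNITS = 48_503

# Assumed rates in USD per 1,000 units, back-calculated from the table
PROMPT_RATE_PER_1K = 0.000009375
OUTPUT_RATE_PER_1K = 0.0000375

prompt_cost = PROMPT_UNITS / 1000 * PROMPT_RATE_PER_1K
output_cost = OUTPUT_UNITS / 1000 * OUTPUT_RATE_PER_1K
total_cost = prompt_cost + output_cost

print(f"prompt={prompt_cost:.6f} output={output_cost:.6f} total={total_cost:.6f}")
print(f"total in cents: {total_cost * 100:.2f}")
```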


We wrap up by counting our predictions per label.

In [131]:
bpd.read_gbq(f"""\
select
  label,
  count(*) as count
from {CENTRAL_DATASET}.labels
group by label
""")

   label  count
0   True     35
1  False    965
1,False,965
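
The row counts reported along the way put these labels in context. The numbers below are copied from the cell outputs above. Why only 793 of the 1,000 summaries became classification inputs isn't stated; one plausible reason, offered as an assumption, is that the remaining summaries were `None` and were filtered out by `sql/classify_inputs.sql`.

```python
corpus_rows = 1000      # rows written to `labels`
classified_rows = 793   # rows written to `classify_inputs`
positive = 35           # rows labelled True

unclassified = corpus_rows - classified_rows
share_of_corpus = positive / corpus_rows
share_of_classified = positive / classified_rows

print(f"{unclassified} rows were not sent to the classifier")
print(f"{share_of_corpus:.1%} of the corpus is labelled True")
print(f"{share_of_classified:.1%} of classified rows are labelled True")
```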
