In [2]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Use BigQuery DataFrames to cluster and characterize complaints

<table align="left">

  <td>
    <a href="https://colab.research.google.com/github/googleapis/python-bigquery-dataframes/blob/main/notebooks/generative_ai/bq_dataframes_llm_kmeans.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>
  <td>
    <a href="https://github.com/googleapis/python-bigquery-dataframes/blob/main/notebooks/generative_ai/bq_dataframes_llm_kmeans.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/googleapis/python-bigquery-dataframes/main/notebooks/generative_ai/bq_dataframes_llm_kmeans.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
      Open in Vertex AI Workbench
    </a>
  </td>                                                                                               
</table>

## Overview

The goal of this notebook is to demonstrate a comment characterization algorithm for an online business. We will accomplish this using [Google's PaLM 2](https://ai.google/discover/palm2/) and [KMeans clustering](https://en.wikipedia.org/wiki/K-means_clustering) in three steps:

1. Use PaLM2TextEmbeddingGenerator to [generate text embeddings](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings) for each of 10000 complaints sent to an online bank. If you're not familiar with what a text embedding is, it's a list of numbers that are like coordinates in an imaginary "meaning space" for sentences. (It's like [word embeddings](https://en.wikipedia.org/wiki/Word_embedding), but for more general text.) The important point for our purposes is that similar sentences are close to each other in this imaginary space.
2. Use KMeans clustering to group together complaints whose text embeddings are near each other. This will give us sets of similar complaints, but we don't yet know _why_ these complaints are similar.
3. Simply ask PaLM2TextGenerator in English what the difference is between the groups of complaints that we got. Thanks to the power of modern LLMs, the response might give us a very good idea of what these complaints are all about, but remember to ["understand the limits of your dataset and model."](https://ai.google/responsibility/responsible-ai-practices/#:~:text=Understand%20the%20limitations%20of%20your%20dataset%20and%20model)

We will tie these pieces together in Python using BigQuery DataFrames. [Click here](https://cloud.google.com/bigquery/docs/dataframes-quickstart) to learn more about BigQuery DataFrames!

### Dataset

This notebook uses the [CFPB Consumer Complaint Database](https://console.cloud.google.com/marketplace/product/cfpb/complaint-database).

### Costs

This tutorial uses billable components of Google Cloud:

* BigQuery (compute)
* BigQuery ML

Learn about [BigQuery compute pricing](https://cloud.google.com/bigquery/pricing#analysis_pricing_models),
and [BigQuery ML pricing](https://cloud.google.com/bigquery/pricing#bqml),
and use the [Pricing Calculator](https://cloud.google.com/products/calculator/)
to generate a cost estimate based on your projected usage.

## Step 1: Text embedding 

### Project Setup

In [3]:
import os

import bigframes.pandas as bpd

# Google Cloud project and BigQuery location used by every query in this
# notebook. Set the GOOGLE_CLOUD_PROJECT environment variable to run the
# tutorial in your own project; when it is unset, the original "bigframes-dev"
# project is used, so the default behavior is unchanged.
bpd.options.bigquery.project = os.environ.get("GOOGLE_CLOUD_PROJECT", "bigframes-dev")
bpd.options.bigquery.location = "us"

Data Input

In [4]:
# Public CFPB complaint table, read lazily as a BigQuery DataFrame.
COMPLAINTS_TABLE_ID = "bigquery-public-data.cfpb_complaints.complaint_database"
input_df = bpd.read_gbq(COMPLAINTS_TABLE_ID)

In [5]:
# Keep only the free-text narrative and drop complaints that have none.
narrative_columns = ["consumer_complaint_narrative"]
issues_df = input_df[narrative_columns].dropna()
# Preview the first five complaints.
issues_df.head(5)

Unnamed: 0,consumer_complaint_narrative
0,"Those Accounts Are Not mine, I never authorize..."
11,"Legal Department, This credit dispute is being..."
12,"Hello my name is XXXX XXXX, I have looked into..."
15,I HAVE REVIEWED MY CREDIT REPORT AND FOUND SOM...
16,On my credit report these are not my items rep...


In [6]:
# Randomly pick 10,000 complaints to work with.
SAMPLE_SIZE = 10_000
downsampled_issues_df = issues_df.sample(n=SAMPLE_SIZE)
# NOTE(review): no random seed is passed here, so each run presumably draws a
# different sample and the clusters below will vary between runs.

Generate the text embeddings

In [7]:
from bigframes.ml.llm import PaLM2TextEmbeddingGenerator

# The embedding model is created without passing a connection name.
model = PaLM2TextEmbeddingGenerator()

In [8]:
# Computing the embeddings takes roughly 5 minutes.
predicted_embeddings = model.predict(downsampled_issues_df)
# Each row's list of numbers is the text embedding of one complaint.
predicted_embeddings.head(n=5)

Unnamed: 0,text_embedding
355,"[0.0032048337161540985, 0.018182063475251198, ..."
414,"[-0.025085292756557465, -0.05178036540746689, ..."
650,"[0.0020703477784991264, -0.027994778007268906,..."
969,"[-0.009529653936624527, -0.03827650472521782, ..."
1009,"[0.0190849881619215, -0.026688968762755394, 0...."


In [None]:
# Put each complaint next to its text embedding in one DataFrame; rows are
# matched on the shared index.
combined_df = downsampled_issues_df.join(predicted_embeddings, how="left")

## Step 2: KMeans clustering

In [10]:
from bigframes.ml.cluster import KMeans

# Number of groups the complaints are divided into.
N_CLUSTERS = 10
cluster_model = KMeans(n_clusters=N_CLUSTERS)

In [11]:
# The text embedding is the only feature used for clustering.
embedding_features = combined_df[["text_embedding"]]
# Fit KMeans and assign every complaint to a group (~5 minutes).
cluster_model.fit(embedding_features)
clustered_result = cluster_model.predict(embedding_features)
# CENTROID_ID is the ID number of the group each complaint was assigned to.
clustered_result.head(n=5)

Unnamed: 0,CENTROID_ID
355,4
414,2
650,1
969,5
1009,5


In [12]:
# Attach each complaint's group number (CENTROID_ID) alongside its text and
# embedding, matching rows on the shared index.
combined_clustered_result = combined_df.join(clustered_result, how="left")

## Step 3: Summarize the complaints

In [13]:
# Using bigframes, with syntax identical to pandas, keep only the complaints
# that KMeans assigned to the first and second groups (CENTROID_ID 1 and 2).
def select_cluster_complaints(clustered_df, centroid_id):
    """Return the complaint narratives assigned to one KMeans group.

    Args:
        clustered_df: DataFrame with "CENTROID_ID" and
            "consumer_complaint_narrative" columns.
        centroid_id: CENTROID_ID value of the group to keep.

    Returns:
        A DataFrame with only the "consumer_complaint_narrative" column,
        restricted to rows whose CENTROID_ID equals `centroid_id`.
    """
    in_cluster = clustered_df["CENTROID_ID"] == centroid_id
    return clustered_df[in_cluster][["consumer_complaint_narrative"]]


cluster_1_result = select_cluster_complaints(combined_clustered_result, 1)
# Download only the first 5 complaints of each group; that is all the prompt uses.
cluster_1_result_pandas = cluster_1_result.head(5).to_pandas()

cluster_2_result = select_cluster_complaints(combined_clustered_result, 2)
cluster_2_result_pandas = cluster_2_result.head(5).to_pandas()

In [14]:
# Build plain-text prompts to send to PaLM 2. Use only 5 complaints from each group.
def build_comment_list(list_number, complaints_pandas, max_comments=5):
    """Format complaint narratives as a numbered plain-text list.

    The result is "comment list <list_number>:\\n" followed by one
    "<k>. <narrative>\\n" line per complaint, numbered from 1.

    Args:
        list_number: label shown in the list header.
        complaints_pandas: pandas DataFrame with a
            "consumer_complaint_narrative" column.
        max_comments: at most this many rows are included. A group with
            fewer rows yields a shorter list instead of raising IndexError.
    """
    narratives = complaints_pandas["consumer_complaint_narrative"].iloc[:max_comments]
    numbered = (f"{k}. {text}\n" for k, text in enumerate(narratives, start=1))
    return f"comment list {list_number}:\n" + "".join(numbered)


prompt1 = build_comment_list(1, cluster_1_result_pandas)
prompt2 = build_comment_list(2, cluster_2_result_pandas)

print(prompt1)
print(prompt2)


comment list 1:
1. I bought my home XX/XX/XXXX for the amount of {$220000.00}. The home was appraised closing with a value of {$260000.00} at closing. When purchasing the home I did not provide a downpayment in the amount of 20 % of the home value, therefore I had to purchase private mortgage insurance ( P.M.I. ) on the home until 20 % of the home value was paid off. 20 % of {$260000.00} is {$53000.00}. This means I would have to owe ( $ XXXX- {$53000.00} ) {$210000.00} or less for the P.M.I. to be taken off of my monthly mortgage payments. According to law, the lender should take the P.M.I. off of my loan once the 20 % is met. At the time of closing my borrower did not provide me a PMI disclosure form to identify when the 20 % mark would be met. 

When closing on my home my loan was thru XXXX XXXX XXXX, for the past 5+ years my loan was taken over by Wells Fargo Home Mortgage and they are my current lenders. I have never missed or been late on a mortgage payment. 

In XX/XX/XXXX I rea

In [15]:
# The plain English request we will make of PaLM 2.
# Adjacent string literals are concatenated verbatim, so the first fragment
# needs a trailing space (without it the prompt reads "betweenthe").
prompt = (
    "Please highlight the most obvious difference between "
    "the two lists of comments:\n" + prompt1 + prompt2
)
print(prompt)

Please highlight the most obvious difference betweenthe two lists of comments:
comment list 1:
1. I bought my home XX/XX/XXXX for the amount of {$220000.00}. The home was appraised closing with a value of {$260000.00} at closing. When purchasing the home I did not provide a downpayment in the amount of 20 % of the home value, therefore I had to purchase private mortgage insurance ( P.M.I. ) on the home until 20 % of the home value was paid off. 20 % of {$260000.00} is {$53000.00}. This means I would have to owe ( $ XXXX- {$53000.00} ) {$210000.00} or less for the P.M.I. to be taken off of my monthly mortgage payments. According to law, the lender should take the P.M.I. off of my loan once the 20 % is met. At the time of closing my borrower did not provide me a PMI disclosure form to identify when the 20 % mark would be met. 

When closing on my home my loan was thru XXXX XXXX XXXX, for the past 5+ years my loan was taken over by Wells Fargo Home Mortgage and they are my current lenders

In [16]:
from bigframes.ml.llm import PaLM2TextGenerator

# The text generator is given a BigQuery connection named
# "<project>.<location>.<connection_id>". Build it from the configured project
# and location so it stays in sync with the project setup cell; with the
# default settings this is "bigframes-dev.us.bigframes-ml", as before.
connection_name = (
    f"{bpd.options.bigquery.project}.{bpd.options.bigquery.location}.bigframes-ml"
)
q_a_model = PaLM2TextGenerator(connection_name=connection_name)

In [17]:
# Wrap our single prompt string in a one-row DataFrame with a "prompt" column,
# which is what gets passed to the text-generation model.
prompt_rows = {"prompt": [prompt]}
df = bpd.DataFrame(prompt_rows)
# NOTE(review): the saved output of the predict cell fails with a BigQuery
# "Unclosed string literal" syntax error. The multi-line prompt may not be
# escaped correctly when this local DataFrame is inlined into SQL; verify
# against the installed bigframes version.

  if _pandas_api.is_sparse(col):


In [19]:
# Ask PaLM 2 to answer the prompt held in `df`.
major_difference = q_a_model.predict(df)
# The result has a single row; pull PaLM 2's generated text out of it.
generated_text = major_difference["ml_generate_text_llm_result"]
generated_text.iloc[0]

BadRequest: 400 POST https://bigquery.googleapis.com/bigquery/v2/projects/bigframes-dev/jobs?prettyPrint=false: Syntax error: Unclosed string literal at [5:104]

Location: us
Job ID: 9b28df64-af3c-4dcc-b679-4300c3deab88
 [{'@type': 'type.googleapis.com/google.rpc.DebugInfo', 'detail': '[INVALID_INPUT] message=QUERY_ERROR: [Syntax error: Unclosed string literal at [5:104]] errorProto=code: "QUERY_ERROR"\nargument: "Syntax error: Unclosed string literal at [5:104]"\nlocation_type: OTHER\nlocation: "query"\n\n\tat com.google.cloud.helix.common.Exceptions.fromProto(Exceptions.java:2072)\n\tat com.google.cloud.helix.server.job.DremelErrorUtil.checkStatusWithDremelDetails(DremelErrorUtil.java:162)\n\tat com.google.cloud.helix.server.job.GoogleSqlQueryTransformer.parseQueryUncached(GoogleSqlQueryTransformer.java:527)\n\tat com.google.cloud.helix.server.job.GoogleSqlQueryTransformer.parseQuery(GoogleSqlQueryTransformer.java:511)\n\tat com.google.cloud.helix.server.job.GoogleSqlQueryTransformer.validateQuery(GoogleSqlQueryTransformer.java:251)\n\tat com.google.cloud.helix.server.job.LocalQueryJobController.checkQuery(LocalQueryJobController.java:4331)\n\tat com.google.cloud.helix.server.job.LocalQueryJobController.checkInternal(LocalQueryJobController.java:4461)\n\tat com.google.cloud.helix.server.job.LocalQueryJobController.checkAsync(LocalQueryJobController.java:4415)\n\tat com.google.cloud.helix.server.job.LocalSqlJobController.checkAsync(LocalSqlJobController.java:125)\n\tat com.google.cloud.helix.server.job.LocalJobController.check(LocalJobController.java:1247)\n\tat com.google.cloud.helix.server.job.JobControllerModule$1.check(JobControllerModule.java:461)\n\tat com.google.cloud.helix.server.job.JobStateMachine$1.check(JobStateMachine.java:3585)\n\tat com.google.cloud.helix.server.job.JobStateMachine.dryRunJob(JobStateMachine.java:2515)\n\tat com.google.cloud.helix.server.job.JobStateMachine.execute(JobStateMachine.java:2494)\n\tat com.google.cloud.helix.server.job.ApiJobStateChanger.execute(ApiJobStateChanger.java:33)\n\tat com.google.cloud.helix.server.job.rosy.HelixJobRosy.insertNormalizedJob(HelixJobRosy.java:1998)\n\tat 
com.google.cloud.helix.server.job.rosy.HelixJobRosy.insertJobInternal(HelixJobRosy.java:2467)\n\tat com.google.cloud.helix.server.job.rosy.HelixJobRosy.insertInternal(HelixJobRosy.java:2492)\n\tat com.google.cloud.helix.server.job.rosy.HelixJobRosy.insertRequestInternal(HelixJobRosy.java:3918)\n\tat com.google.cloud.helix.server.job.rosy.HelixJobRosy.insert(HelixJobRosy.java:3892)\n\tat jdk.internal.reflect.GeneratedMethodAccessor305.invoke(Unknown Source)\n\tat java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)\n\tat java.base/java.lang.reflect.Method.invoke(Unknown Source)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.lambda$innerContinuation$3(RpcRequestProxy.java:435)\n\tat com.google.cloud.helix.common.rosy.RosyRequestDapperHookFactory$TracingRequestHook.call(RosyRequestDapperHookFactory.java:88)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.lambda$makeContinuation$4(RpcRequestProxy.java:461)\n\tat com.google.cloud.helix.common.rosy.RosyRequestCredsHookFactory$1.call(RosyRequestCredsHookFactory.java:56)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.lambda$makeContinuation$4(RpcRequestProxy.java:461)\n\tat com.google.cloud.helix.common.rosy.RosyRequestConcurrentCallsHookFactory$Hook.call(RosyRequestConcurrentCallsHookFactory.java:101)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.lambda$makeContinuation$4(RpcRequestProxy.java:461)\n\tat com.google.cloud.helix.common.rosy.RosyRequestVarzHookFactory$Hook.call(RosyRequestVarzHookFactory.java:464)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.lambda$makeContinuation$4(RpcRequestProxy.java:461)\n\tat com.google.cloud.helix.server.rosy.RosyRequestAuditHookFactory$1.call(RosyRequestAuditHookFactory.java:110)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.lambda$makeContinuation$4(RpcRequestProxy.java:461)\n\tat 
com.google.cloud.helix.common.rosy.RequestSecurityExtensionForGwsHookFactory$1.call(RequestSecurityExtensionForGwsHookFactory.java:69)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.lambda$makeContinuation$4(RpcRequestProxy.java:461)\n\tat com.google.cloud.helix.common.rosy.RosyRequestSecurityContextHookFactory$1.call(RosyRequestSecurityContextHookFactory.java:80)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.lambda$makeContinuation$4(RpcRequestProxy.java:461)\n\tat com.google.cloud.helix.server.rosy.RosyRequestContextHookFactory.call(RosyRequestContextHookFactory.java:58)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.lambda$makeContinuation$4(RpcRequestProxy.java:461)\n\tat com.google.cloud.helix.common.rosy.RpcRequestProxy.invoke(RpcRequestProxy.java:666)\n\tat com.sun.proxy.$Proxy52.insert(Unknown Source)\n\tat com.google.cloud.helix.proto.proto2api.HelixJobService$ServiceParameters$1.handleRequest(HelixJobService.java:917)\n\tat com.google.net.rpc3.impl.server.RpcServerInterceptor2Util$RpcApplicationHandlerAdaptor.handleRequest(RpcServerInterceptor2Util.java:82)\n\tat com.google.net.rpc3.impl.server.AggregatedRpcServerInterceptors.interceptRpc(AggregatedRpcServerInterceptors.java:97)\n\tat com.google.net.rpc3.impl.server.RpcServerInterceptor2Util$InterceptedApplicationHandlerImpl.handleRequest(RpcServerInterceptor2Util.java:67)\n\tat com.google.net.rpc3.impl.server.RpcServerInternalContext.runRpcInApplicationWithCancellation(RpcServerInternalContext.java:686)\n\tat com.google.net.rpc3.impl.server.RpcServerInternalContext.lambda$runRpcInApplication$0(RpcServerInternalContext.java:651)\n\tat io.grpc.Context.run(Context.java:536)\n\tat com.google.net.rpc3.impl.server.RpcServerInternalContext.runRpcInApplication(RpcServerInternalContext.java:651)\n\tat com.google.net.rpc3.util.RpcInProcessConnector$ServerInternalContext.lambda$runWithExecutor$1(RpcInProcessConnector.java:1964)\n\tat 
com.google.common.context.ContextRunnable.runInContext(ContextRunnable.java:83)\n\tat io.grpc.Context.run(Context.java:536)\n\tat com.google.tracing.GenericContextCallback.runInInheritedContext(GenericContextCallback.java:75)\n\tat com.google.common.context.ContextRunnable.run(ContextRunnable.java:74)\n\tat java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)\n\tat java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)\n\tat java.base/java.lang.Thread.run(Unknown Source)\n\tSuppressed: java.lang.Exception: Including call stack from HelixFutures\n\t\tat com.google.cloud.helix.common.HelixFutures.getHelixException(HelixFutures.java:76)\n\t\tat com.google.cloud.helix.common.HelixFutures.get(HelixFutures.java:42)\n\t\tat com.google.cloud.helix.server.job.JobStateMachine.dryRunJob(JobStateMachine.java:2514)\n\t\t... 45 more\n\tSuppressed: java.lang.Exception: Including call stack from HelixFutures\n\t\tat com.google.cloud.helix.common.HelixFutures.getHelixException(HelixFutures.java:76)\n\t\tat com.google.cloud.helix.common.HelixFutures.get(HelixFutures.java:42)\n\t\t... 41 more\n'}]