<a href="https://colab.research.google.com/github/moasifk/nlp-to-sql/blob/main/NLP_SQL_VertexAI_BigQuery.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [1]:
! pip install langchain==0.0.340 --quiet
! pip install chromadb==0.4.13 --quiet
! pip install google-cloud-bigquery[pandas] --quiet
! pip install google-cloud-aiplatform --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.5/56.5 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.4/49.4 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.8/437.8 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.8/91.8 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.8/60.8 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.3/41.3 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

# Vertex configuration

In [2]:
VERTEX_PROJECT = "cmi2-0-genai" # @param{type: "string"}
VERTEX_REGION = "us-central1" # @param{type: "string"}

# BigQuery configuration

Don't change these options. These are the public datasets used in this workshop.

In [3]:
BIGQUERY_DATASET = "noaa_tsunami" # @param{type: "string"}
BIGQUERY_PROJECT = "bigquery-public-data" # @param{type: "string"}

# Authentication

In [4]:
from google.colab import auth
auth.authenticate_user()

import vertexai
vertexai.init(project=VERTEX_PROJECT, location=VERTEX_REGION)

# Schemas as context for the prompt

In [5]:
from google.cloud import bigquery
import json

bq_client = bigquery.Client(project=VERTEX_PROJECT)
bq_tables = bq_client.list_tables(dataset=f"{BIGQUERY_PROJECT}.{BIGQUERY_DATASET}")
schemas = []
for bq_table in bq_tables:
   t = bq_client.get_table(f"{BIGQUERY_PROJECT}.{BIGQUERY_DATASET}.{bq_table.table_id}")
   schema_fields = [f.to_api_repr() for f in t.schema]
   schema = f"The schema for table {bq_table.table_id} is the following: \n```{json.dumps(schema_fields, indent=1)}```"
   schemas.append(schema)

print(f"Found {len(schemas)} tables in dataset {BIGQUERY_PROJECT}:{BIGQUERY_DATASET}")

Found 2 tables in dataset bigquery-public-data:noaa_tsunami


# Vector store

We add the schemas as documents to a vector store, to be added to the prompt later.

We will retrieve only one document from the store for the prompt: the most relevant doc.

In [7]:
from langchain.embeddings import VertexAIEmbeddings
from langchain.vectorstores import Chroma

embeddings = VertexAIEmbeddings()
vector_store = Chroma.from_texts(schemas, embedding=embeddings)
try: # Avoid duplicated documents
  vector_store.delete_collection()
except:
  print("No need to clean the vector store")
n_docs = len(vector_store.get()['ids'])
retriever = vector_store.as_retriever(search_kwargs={'k': 1})
print(f"The vector store has {n_docs} documents")

The vector store has 2 documents


# Models

In [11]:
from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI

query_model = ChatVertexAI(model_name="codechat-bison", max_output_tokens=2048)
interpret_data_model = ChatVertexAI(max_output_tokens=2048)
agent_model = ChatVertexAI(max_output_tokens=1024)

# Get a SQL query chain

In [12]:
SQL_PROMPT = """You are a SQL and BigQuery expert.

Your job is to create a query for BigQuery in SQL.

The following paragraph contains the schema of the table used for a query. It is encoded in JSON format.

{context}

Create a BigQuery SQL query for the following user input, using the above table.

The user and the agent have done this conversation so far:
{chat_history}

Follow these restrictions strictly:
- Only return the SQL code.
- Do not add backticks or any markup. Only write the query as output. NOTHING ELSE.
- In FROM, always use the full table path, using `{project}` as project and `{dataset}` as dataset.
- Always transform country names to full uppercase. For instance, if the country is Japan, you should use JAPAN in the query.

User input: {question}

SQL query:
"""

In [13]:
from langchain.schema.vectorstore import VectorStoreRetriever
def get_documents(retriever: VectorStoreRetriever, question: str) -> str:
  # Return only the first document
  output = ""
  for d in retriever.get_relevant_documents(question):
    output += d.page_content
    output += "\n"
    return output

In [14]:
from operator import itemgetter
from langchain.prompts import PromptTemplate
from langchain.schema import StrOutputParser


### EXERCISE STARTS HERE
prompt_template = PromptTemplate(
    input_variables=["context", "chat_history", "question", "project", "dataset"],
    template=SQL_PROMPT)

partial_prompt = prompt_template.partial(project=BIGQUERY_PROJECT,
                                         dataset=BIGQUERY_DATASET)

# Input will be like {"input": "SOME_QUESTION", "chat_history": "HISTORY"}
docs = {"context": lambda x: get_documents(retriever, x['input'])}
question = {"question": itemgetter("input")}
chat_history = {"chat_history": itemgetter("chat_history")}
query_chain = docs | question | chat_history | partial_prompt | query_model
query = query_chain | StrOutputParser()
### EXERCISE ENDS HERE

In [15]:
from langchain.callbacks.tracers import ConsoleCallbackHandler
# Example
x = {"input": "Which countries in Asia had more houses damaged? Give me the top 3", "chat_history": ""}
print(query.invoke(x, config={'callbacks': [ConsoleCallbackHandler()]}))

[32;1m[1;3m[chain/start][0m [1m[1:chain:RunnableSequence] Entering Chain run with input:
[0m{
  "input": "Which countries in Asia had more houses damaged? Give me the top 3",
  "chat_history": ""
}
[32;1m[1;3m[chain/start][0m [1m[1:chain:RunnableSequence > 2:chain:RunnableParallel] Entering Chain run with input:
[0m{
  "input": "Which countries in Asia had more houses damaged? Give me the top 3",
  "chat_history": ""
}
[32;1m[1;3m[chain/start][0m [1m[1:chain:RunnableSequence > 2:chain:RunnableParallel > 3:chain:<lambda>] Entering Chain run with input:
[0m{
  "input": "Which countries in Asia had more houses damaged? Give me the top 3",
  "chat_history": ""
}
[32;1m[1;3m[chain/start][0m [1m[1:chain:RunnableSequence > 2:chain:RunnableParallel > 4:chain:RunnableLambda] Entering Chain run with input:
[0m{
  "input": "Which countries in Asia had more houses damaged? Give me the top 3",
  "chat_history": ""
}[32;1m[1;3m[chain/start][0m [1m[1:chain:RunnableSequence > 2

# Add more outputs to the previous chain

In [16]:
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.schema.runnable import RunnableLambda

def _dict_to_json(x: dict) -> str:
  return "```\n" + json.dumps(x) + "\n```"

query_response_schema = [
    ResponseSchema(name="query", description="SQL query to solve the user question."),
    ResponseSchema(name="question", description="Question asked by the user."),
    ResponseSchema(name="context", description="Documents retrieved from the vector store.")
  ]
query_output_parser = StructuredOutputParser.from_response_schemas(query_response_schema)
query_output_json = docs | question | {"query": query} | RunnableLambda(_dict_to_json) | StrOutputParser()
query_output = query_output_json | query_output_parser

In [28]:
# Example
x = {"input": "Which countries in Asia had more houses damaged? Give me the top 3", "chat_history": ""}
query_output.invoke(x)  # Output is now a dictionary, input for the next chain

{'context': 'The schema for table historical_runups is the following: \n```[\n {\n  "name": "id",\n  "type": "INTEGER",\n  "mode": "NULLABLE",\n  "description": "The unique numeric identifier of the record."\n },\n {\n  "name": "tsevent_id",\n  "type": "INTEGER",\n  "mode": "NULLABLE",\n  "description": "The unique numeric identifier of the tsunami source event record that links the runup with the event."\n },\n {\n  "name": "year",\n  "type": "INTEGER",\n  "mode": "NULLABLE",\n  "description": "Valid values: -2000 to Present Format +/-yyyy (-is B.C, +is A.D.)  The Date and Time are given in Universal Coordinated Time (also known as Greenwich Mean Time). The local date may be one day different."\n },\n {\n  "name": "month",\n  "type": "INTEGER",\n  "mode": "NULLABLE",\n  "description": "Valid values: 1-12 The Date and Time are given in Universal Coordinated Time (also known as Greenwich Mean Time). The local date may be one day different."\n },\n {\n  "name": "day",\n  "type": "INTEGER

# Interpret the output chain

In [18]:
INTERPRET_PROMPT = """You are a BigQuery expert. You are also expert in extracting data from CSV.

The following paragraph describes the schema of the table used for a query. It is encoded in JSON format.

{context}

A user asked this question:
{question}

To find the answer, the following SQL query was run in BigQuery:
```
{query}
```

The result of that query was the following table in CSV format:
```
{result}
```

Based on those results, provide a brief answer to the user question.

Follow these restrictions strictly:
- Do not add any explanation about how the answer is obtained, just write the answer.
- Extract any value related to the answer only from the result of the query. Do not use any other data source.
- Just write the answer, omit the question from your answer, this is a chat, just provide the answer.
- If you cannot find the answer in the result, do not make up any data, just say "I cannot find the answer"
"""

In [19]:
from google.cloud import bigquery
def get_bq_csv(bq_client: bigquery.Client, query: str) -> str:
  df = bq_client.query(query, location="US").to_dataframe()
  return df.to_csv(index=False)

In [20]:
# Get the output of the previous chain

### EXERCISE STARTS HERE
query = {"query": itemgetter("query")}
context = {"context": itemgetter("context")}
question = {"question": itemgetter("question")}
query_result = {"result": lambda x: get_bq_csv(bq_client, x["query"])}

prompt = PromptTemplate(
    input_variables=["question", "query", "result", "context"],
    template=INTERPRET_PROMPT)

run_bq_chain = context | question | query | query_result | prompt
run_bq_result = run_bq_chain | interpret_data_model | StrOutputParser()
### EXERCISE ENDS HERE

In [30]:
run_bq_result

{
  context: RunnableLambda(...),
  question: RunnableLambda(...),
  query: RunnableLambda(...),
  result: RunnableLambda(lambda x: get_bq_csv(bq_client, x['query']))
}
| PromptTemplate(input_variables=['context', 'query', 'question', 'result'], template='You are a BigQuery expert. You are also expert in extracting data from CSV.\n\nThe following paragraph describes the schema of the table used for a query. It is encoded in JSON format.\n\n{context}\n\nA user asked this question:\n{question}\n\nTo find the answer, the following SQL query was run in BigQuery:\n```\n{query}\n```\n\nThe result of that query was the following table in CSV format:\n```\n{result}\n```\n\nBased on those results, provide a brief answer to the user question.\n\nFollow these restrictions strictly:\n- Do not add any explanation about how the answer is obtained, just write the answer.\n- Extract any value related to the answer only from the result of the query. Do not use any other data source.\n- Just write the a

In [31]:
# Example
x = {"input": "Which countries in Asia had more houses damaged Give me the top 3", "chat_history": ""}
print(run_bq_result.invoke(query_output.invoke(x)))

BadRequest: 400 Syntax error: Unexpected identifier `` at [1:2]; reason: invalidQuery, location: query, message: Syntax error: Unexpected identifier `` at [1:2]

Location: US
Job ID: 859f5975-7331-4c1c-8e0b-9d23db712a5c


# Agent: putting everything together

In [None]:
from langchain.memory import ConversationBufferWindowMemory

agent_memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=10,
    return_messages=True)

In [None]:
AGENT_PROMPT = """You are a very powerful assistant that can answer questions using BigQuery.

You can invoke the tool user_question_tool to answer questions using BigQuery.

You can invoke the tool Calculator if you need to do mathematical operations.

Always use the tools to try to answer the questions. Use the chat history for context. Never try to use any other external information.

Assume that the user may write with misspellings, fix the spelling of the user before passing the question to any tool.

Don't mention what tool you have used in your answer.
"""

In [None]:
from langchain import LLMMathChain
from langchain.tools import Tool

math_chain = LLMMathChain.from_llm(llm=agent_model)
math_tool = Tool(
  name="Calculator",
  description="Useful for when you need to answer questions about math.",
  func=math_chain.run,
  coroutine=math_chain.arun)

In [None]:
from langchain.tools import tool
from langchain.callbacks.tracers import ConsoleCallbackHandler

@tool
def user_question_tool(question) -> str:
  """Useful to answer natural language questions from users using BigQuery."""
  config={'callbacks': [ConsoleCallbackHandler()]}
  memory = agent_memory.buffer_as_str.strip()
  question = {"input": question, "chat_history": memory}
  result = run_bq_result.invoke(query_output.invoke(question, config=config), config=config)
  return result.strip()

In [None]:
from langchain.agents import AgentType, initialize_agent, AgentExecutor

agent_kwgards = {"system_message": AGENT_PROMPT}
agent_tools = [math_tool, user_question_tool]

agent_memory.clear()

# Fill the missing options
agent = initialize_agent(
    ### TODO
    ### TODO
    ### TODO
    agent_kwgards=agent_kwgards,
    max_iterations=5,
    early_stopping_method='generate',
    verbose=True)

In [None]:
q = "Which countries had more houses damaged? Give me the top 3"
agent.invoke(x)

In [None]:
agent_memory

In [None]:
q = "Of those countries, which one had more deaths?"
agent.invoke(q)