<div style="display: flex; align-items: left;">
    <a href="https://sites.google.com/corp/google.com/genai-solutions/home?authuser=0">
        <img src="https://storage.googleapis.com/miscfilespublic/Linkedin%20Banner%20%E2%80%93%202.png" style="margin-right">
    </a>
</div>

In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


<h1 align="center">Open Data QnA - Chat with your SQL Database (standalone notebook)</h1> 

---

This notebook is a standalone version for quickly testing the solution without cloning the complete GitHub repository.
Given its standalone nature, the notebook supports only BigQuery, both as the data source and as the vector store.

The notebook first walks through the Vector Store Setup needed for running the Open Data QnA application. 

Currently supported Source DBs **from the complete solution** are: 
- PostgreSQL on Google Cloud SQL 
- BigQuery

Furthermore, the following vector stores are supported:
- pgvector on PostgreSQL 
- BigQuery vector

**If you want to use a different data source or a different vector store, make sure to clone the repository and run the standard (non-standalone) notebook instead!**

The setup part covers the following steps: 
> 1. Configuration: initial GCP project, IAM permissions, environment and database setup, including logging on BigQuery for analytics

> 2. Creation of table, column, and known good query embeddings in the vector store for Retrieval Augmented Generation (RAG)


Afterwards, you will be able to run the Open Data QnA Pipeline to generate SQL queries and answer questions over your data source. 

The pipeline run covers the following steps: 

> 1. Take the user question and generate SQL in the dialect corresponding to the data source

> 2. Execute the SQL query and retrieve the data

> 3. Generate a natural language response and charts to display

> 4. Clean up resources
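The three pipeline steps above can be sketched as a small orchestration function. This is a minimal sketch, not the notebook's pipeline code: `run_pipeline`, `PipelineResult`, and the three callables are hypothetical names, injected so that the flow runs without any Google Cloud resources. Clean-up (step 4) is left out.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class PipelineResult:
    question: str
    sql: str
    rows: List[Any]
    answer: str


def run_pipeline(
    question: str,
    generate_sql: Callable[[str], str],
    execute_sql: Callable[[str], List[Any]],
    respond: Callable[[str, List[Any]], str],
) -> PipelineResult:
    """Run steps 1-3 of the pipeline with injected (hypothetical) callables."""
    # Step 1: translate the question into SQL for the data source's dialect
    sql = generate_sql(question)
    # Step 2: execute the SQL and retrieve the result rows
    rows = execute_sql(sql)
    # Step 3: turn the rows into a natural-language answer
    answer = respond(question, rows)
    return PipelineResult(question=question, sql=sql, rows=rows, answer=answer)


# Stub callables stand in for the agents and the BigQuery connector.
result = run_pipeline(
    "How many movies are there?",
    generate_sql=lambda q: "SELECT COUNT(*) AS n FROM `my_project.imdb.movies`",
    execute_sql=lambda sql: [{"n": 42}],
    respond=lambda q, rows: f"There are {rows[0]['n']} movies.",
)
print(result.answer)  # There are 42 movies.
```

Keeping each step behind a callable makes it easy to swap the stubs for the SQL-generation agent, the connector, and the response step without changing the control flow.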



### 📒 Using this interactive notebook

If you have not used this IDE with Jupyter notebooks before, it will prompt you to install the Python and Jupyter extensions. Please go ahead and install them.

Click the **run** icons ▶️  of each cell within this notebook.

> 💡 Alternatively, you can run the currently selected cell with `Ctrl + Enter` (or `⌘ + Enter` on a Mac).

> ⚠️ **To avoid any errors**, wait for each section to finish in their order before clicking the next “run” icon.

This sample must be connected to a **Google Cloud project**; nothing else is required.

You can use an existing project. Alternatively, you can create a new Cloud project [with cloud credits for free.](https://cloud.google.com/free/docs/gcp-free-tier)

# 🚧 **0. Prerequisites**

Make sure that the Google Cloud CLI is installed before moving on to the next cell. You can refer to the link below for guidance.

Installation Guide: https://cloud.google.com/sdk/docs/install

## **0.1. Set Up Poetry Environment and GCP Project**

### 💻 **Install Code Dependencies**
Install the dependencies by running the commands below.


In [None]:
# Install the Python packages used in this notebook (google-auth provides the google.auth module imported below)
!pip install vertexai pandas db-dtypes google google-cloud-bigquery google-auth tabulate google-cloud-bigquery-connection

### 📌 **Important Step: Activate your virtual environment and authenticate with the Google Cloud CLI [run these in a terminal]**

Once the venv has been created (either in the local directory or in the cache directory), open a terminal on the same machine where your notebooks are running and run the commands below.


```
poetry shell #this command should activate your venv and you should see it enters into the venv

##inside the activated venv shell

gcloud auth login
gcloud auth application-default login
gcloud services enable \
    serviceusage.googleapis.com \
    cloudresourcemanager.googleapis.com --project <<Enter Project Id>>
gcloud auth application-default set-quota-project <<Enter Project Id for using resources>>

```
In IDEs, adding the Jupyter extensions automatically gives you the option to change the kernel. If not, manually select the venv's Python interpreter in your IDE (the path looks like e.g. /home/admin_/Talk2Data/.venv/bin/python or ~cache/user/Talk2Data/.venv/bin/python).

**Extra steps if you are running inside JupyterLab or Jupyter environments on Workbench, etc.**

In environments where you cannot select the interpreter path manually as described above, you need to add the venv as a kernel manually.

Run the steps above and then continue with the commands below:

```
##still inside the activated venv shell

pip install jupyter

ipython kernel install --name "openqna-venv" --user 

```

Restart your kernel, or close the existing notebook and open it again. You should now see "openqna-venv" in the kernel drop-down.

**What did we do here?**

* Created Application Default Credentials for the code to use
* Added the venv as a kernel that can be selected for running the notebooks (for standalone Jupyter setups like Workbench)

### 🔐 **Authenticate to Google Cloud** (Colab)
Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.


In [None]:
"""Colab Auth""" 
# from google.colab import auth
# auth.authenticate_user()

### 🔗 **Connect Your Google Cloud Project**
Time to connect your Google Cloud Project to this notebook. 

In [None]:
import os 

#@markdown Please fill in the value below with your GCP project ID and then run the cell.
project_id = "my_project"

# Quick input validations.
assert project_id, "⚠️ Please provide your Google Cloud Project ID"

# Configure gcloud.
!gcloud config set project {project_id}
print(f'Project has been set to {project_id}')

os.environ['GOOGLE_CLOUD_QUOTA_PROJECT']=project_id
os.environ['GOOGLE_CLOUD_PROJECT']=project_id

### ⚙️ **Enable Required API Services in the GCP Project**

In [None]:
#Enable all the required APIs for the Open Data QnA solution

!gcloud services enable \
  cloudapis.googleapis.com \
  compute.googleapis.com \
  iam.googleapis.com \
  run.googleapis.com \
  sqladmin.googleapis.com \
  aiplatform.googleapis.com \
  bigquery.googleapis.com

### **Set Parameters**

Fill out the parameters and configuration settings below. 
These are the parameters for connecting to the source databases and setting configurations for the vector store tables to be created. 
Additionally, you can specify whether you have known good queries you want to use in the pipeline run and whether you want to enable logging.

**Known good queries:** if you have known working user question <-> SQL query pairs, you can put them into the file `scripts/known_good_sql.csv`. This will be used as a caching layer and for in-context learning: If an exact match of the user question is found in the vector store, the pipeline will skip SQL Generation and output the cached SQL query. If the similarity score is between 90-100%, the known good queries will be used as few-shot examples by the SQL Generator Agent. 
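A minimal sketch of the routing described above, assuming each known good query arrives with a similarity score against the user question. `route_known_good_queries` is a hypothetical helper; the 0.90 threshold comes from the 90-100% range mentioned above, and the exact-match check mirrors the case-insensitive comparison used by `getExactMatches` in the connector class.

```python
from typing import List, Optional, Tuple

# (example_user_question, example_generated_sql, similarity score in [0, 1])
Example = Tuple[str, str, float]


def route_known_good_queries(
    question: str,
    examples: List[Example],
    few_shot_threshold: float = 0.90,
) -> Tuple[Optional[str], List[Example]]:
    """Return (cached_sql, few_shot_examples) for a user question."""
    for ex_question, ex_sql, _score in examples:
        # Case-insensitive exact match, mirroring lower(...) = lower(...) in getExactMatches
        if ex_question.lower() == question.lower():
            return ex_sql, []  # cache hit: skip SQL generation entirely
    # No exact match: keep sufficiently similar pairs as few-shot examples
    few_shot = [ex for ex in examples if ex[2] >= few_shot_threshold]
    return None, few_shot


# Illustrative pairs; the similarity scores are made up for this example.
examples = [
    ("How many movies are there?", "SELECT COUNT(*) FROM `my_project.imdb.movies`", 0.93),
    ("Which movie has the most reviews?", "SELECT ...", 0.41),
]

cached, shots = route_known_good_queries("how many movies are there?", examples)
print(cached)  # SELECT COUNT(*) FROM `my_project.imdb.movies`
cached, shots = route_known_good_queries("Count all movies", examples)
print(cached, len(shots))  # None 1
```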

In [None]:
# Data source details
DATA_SOURCE = 'bigquery' # Options: 'bigquery' and 'cloudsql-pg' i.e, PostgreSQL database on Google Cloud SQL

# Please specify what you would like to use as vector store for embeddings
VECTOR_STORE = 'bigquery-vector' # Options: 'bigquery-vector' i.e, Bigquery vector and 'cloudsql-pgvector' i.e, pgvector on PostgreSQL


# If you have chosen 'bigquery' as DATA_SOURCE; provide information below
BQ_REGION = 'us-central1'
BQ_DATASET_NAME = 'imdb'
# You can specify the names of the BigQuery tables you want to query over. If left empty, Open Data QnA will parse through all the tables in the dataset.
BQ_TABLE_LIST = None # either None or a list of table names in format ['reviews', 'ratings']


# Specify if you have example question & known-good-query pairs you want to leverage 
EXAMPLES = False 

# If logging is enabled OR you are using bigquery-vector as the vector store, a BigQuery dataset will be created to hold the tables.
# You can rename the dataset below if you wish to do so.
BQ_OPENDATAQNA_DATASET_NAME = 'opendataqna'

Quick input verifications below:

In [None]:

# Input verification - Source
assert DATA_SOURCE in {'bigquery', 'cloudsql-pg'}, "⚠️ Invalid DATA_SOURCE. Must be 'bigquery' or 'cloudsql-pg'"

# Input verification - Vector Store
assert VECTOR_STORE in {'bigquery-vector', 'cloudsql-pgvector'}, "⚠️ Invalid VECTOR_STORE. Must be 'bigquery-vector' or 'cloudsql-pgvector'"

if DATA_SOURCE == 'bigquery':
    assert BQ_REGION, "⚠️ Please provide the Data Set Region"
    assert BQ_DATASET_NAME, "⚠️ Please provide the name of the dataset on Bigquery"

    DATASET_REGION = BQ_REGION


## **0.2. Implement Agent Classes and Helper Functions**

#### BigQuery Connector

The `BQConnector` class wraps the BigQuery client used throughout this notebook. It runs queries and returns the results as pandas DataFrames, finds similar or exact matches for a question in the vector store tables (similar matches via `VECTOR_SEARCH` over the stored embeddings), and validates generated SQL with a dry run, which reports how many bytes a query would process without actually running it.
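`retrieve_matches` keeps rows where `1-distance > similarity_threshold`. With `distance_type=>"COSINE"`, the distance is 1 minus the cosine similarity, so the filter keeps matches whose cosine similarity exceeds the threshold. The sketch below emulates that query locally in plain Python; `local_vector_search` and the toy 2-d embeddings are illustrative assumptions, not how BigQuery executes the search internally.

```python
import math
from typing import List, Sequence, Tuple


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine distance = 1 - cosine similarity (range 0 to 2); assumes non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def local_vector_search(
    query: Sequence[float],
    rows: List[Tuple[str, Sequence[float]]],
    top_k: int,
    similarity_threshold: float,
) -> List[Tuple[str, float]]:
    """Emulate VECTOR_SEARCH(..., top_k, COSINE) followed by WHERE 1 - distance > threshold."""
    scored = sorted(
        ((content, cosine_distance(query, emb)) for content, emb in rows),
        key=lambda pair: pair[1],  # nearest neighbours first
    )
    nearest = scored[:top_k]  # top_k is applied before the WHERE filter
    return [(content, 1.0 - dist) for content, dist in nearest if 1.0 - dist > similarity_threshold]


rows = [
    ("reviews table", [1.0, 0.0]),
    ("ratings table", [0.6, 0.8]),
    ("actors table", [0.0, 1.0]),
]
hits = local_vector_search([1.0, 0.0], rows, top_k=2, similarity_threshold=0.7)
print([content for content, _ in hits])  # ['reviews table']
```

Because `top_k` limits the candidates before the `WHERE` clause runs, a high threshold can return fewer than `top_k` rows, or none, which is when the connector prints "Did not find any results".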

In [None]:
"""
BigQuery Connector Class
"""
from google.cloud import bigquery
from google.cloud import bigquery_connection_v1 as bq_connection
from abc import ABC
from datetime import datetime
import google.auth
import pandas as pd
from google.cloud.exceptions import NotFound

def get_auth_user():
    credentials, project_id = google.auth.default()

    if hasattr(credentials, 'service_account_email'):
        return credentials.service_account_email
    else:
        return "Not Determined"

def bq_specific_data_types(): 
    return '''
    BigQuery offers a wide variety of datatypes to store different types of data effectively. Here's a breakdown of the available categories:
    Numeric Types -
    INTEGER (INT64): Stores whole numbers within the range of -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807. Ideal for non-fractional values.
    FLOAT (FLOAT64): Stores approximate floating-point numbers with a range of -1.7E+308 to 1.7E+308. Suitable for decimals with a degree of imprecision.
    NUMERIC: Stores exact fixed-precision decimal numbers, with up to 38 digits of precision and 9 digits to the right of the decimal point. Useful for precise financial and accounting calculations.
    BIGNUMERIC: Similar to NUMERIC but with even larger scale and precision. Designed for extreme precision in calculations.
    
    Character Types -
    STRING: Stores variable-length Unicode character sequences. Enclosed using single, double, or triple quotes.
    
    Boolean Type -
    BOOLEAN: Stores logical values of TRUE or FALSE (case-insensitive).
    
    Date and Time Types -
    DATE: Stores dates without associated time information.
    TIME: Stores time information independent of a specific date.
    DATETIME: Stores both date and time information (without timezone information).
    TIMESTAMP: Stores an exact moment in time with microsecond precision, including a timezone component for global accuracy.
    
    Other Types
    BYTES: Stores variable-length binary data. Distinguished from strings by using 'B' or 'b' prefix in values.
    GEOGRAPHY: Stores points, lines, and polygons representing locations on the Earth's surface.
    ARRAY: Stores an ordered collection of zero or more elements of the same (non-ARRAY) data type.
    STRUCT: Stores an ordered collection of fields, each with its own name and data type (can be nested).
    
    This list covers the most common datatypes in BigQuery.
    '''

class BQConnector(ABC):
    """
    Instantiates a BigQuery Connector.
    """
    connectorType: str = "Base"

    def __init__(self,
                 project_id:str,
                 region:str,
                 dataset_name:str,
                 opendataqna_dataset:str):


        self.project_id = project_id
        self.region = region
        self.dataset_name = dataset_name
        self.opendataqna_dataset = opendataqna_dataset
        self.client=self.getconn()

    def getconn(self):
        client = bigquery.Client(project=self.project_id)
        return client
    
    def retrieve_df(self,query):
        return self.client.query_and_wait(query).to_dataframe()
   
    
    def retrieve_matches(self, mode, schema, qe, similarity_threshold, limit): 
        """
        This function retrieves the most similar table_schema and column_schema.
        Modes can be either 'table', 'column', or 'example' 
        """
        matches = []

        if mode == 'table':
            sql = '''select base.content as tables_content from vector_search(TABLE `{}.table_details_embeddings`, "embedding", 
            (SELECT {} as qe), top_k=> {},distance_type=>"COSINE") where 1-distance > {} '''
        
        elif mode == 'column':
            sql='''select base.content as columns_content from vector_search(TABLE `{}.tablecolumn_details_embeddings`, "embedding",
            (SELECT {} as qe), top_k=> {}, distance_type=>"COSINE") where 1-distance > {} '''

        elif mode == 'example': 
            sql='''select base.example_user_question, base.example_generated_sql from vector_search ( TABLE `{}.example_prompt_sql_embeddings`, "embedding",
            (select {} as qe), top_k=> {}, distance_type=>"COSINE") where 1-distance > {} '''
    
        else: 
            raise ValueError("No valid mode. Must be either table, column, or example")


        results=self.client.query_and_wait(sql.format('{}.{}'.format(self.project_id,self.opendataqna_dataset),qe,limit,similarity_threshold)).to_dataframe()


        # CHECK RESULTS 
        if len(results) == 0:
            print("Did not find any results. Adjust the query parameters.")

        if mode == 'table': 
            name_txt = ''
            for _ , r in results.iterrows():
                name_txt=name_txt+r["tables_content"]+"\n"

        elif mode == 'column': 
            name_txt = '' 
            for _ ,r in results.iterrows():
                name_txt=name_txt+r["columns_content"]+"\n"

        elif mode == 'example': 
            name_txt = ''
            for _ , r in results.iterrows():
                example_user_question=r["example_user_question"]
                example_sql=r["example_generated_sql"]
                name_txt = name_txt + "\n Example_question: "+example_user_question+ "; Example_SQL: "+example_sql

        else: 
            raise ValueError("No valid mode. Must be either table, column, or example")

        matches.append(name_txt)
        

        return matches

    def getSimilarMatches(self, mode, schema, qe, num_matches, similarity_threshold):

        if mode == 'table': 
            match_result= self.retrieve_matches(mode, schema, qe, similarity_threshold, num_matches)
            match_result = match_result[0]
            # print(match_result)

        elif mode == 'column': 
            match_result= self.retrieve_matches(mode, schema, qe, similarity_threshold, num_matches)
            match_result = match_result[0]
        
        elif mode == 'example': 
            match_result= self.retrieve_matches(mode, schema, qe, similarity_threshold, num_matches)
            if len(match_result) == 0:
                match_result = None
            else:
                match_result = match_result[0]

        else:
            raise ValueError("No valid mode. Must be either table, column, or example")

        return match_result

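    def getExactMatches(self, query):
        """Checks if the exact question is already present in the example SQL set"""
        check_history_sql=f"""SELECT example_user_question,example_generated_sql FROM `{self.project_id}.{self.opendataqna_dataset}.example_prompt_sql_embeddings`
                          WHERE lower(example_user_question) = lower(@question) LIMIT 1; """
        # Pass the question as a query parameter so that quotes in it cannot break the SQL
        job_config = bigquery.QueryJobConfig(
            query_parameters=[bigquery.ScalarQueryParameter("question", "STRING", query)]
        )

        exact_sql_history = self.client.query_and_wait(check_history_sql, job_config=job_config).to_dataframe()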


        if exact_sql_history[exact_sql_history.columns[0]].count() != 0:
            sql_example_txt = ''
            exact_sql = ''
            for index, row in exact_sql_history.iterrows():
                example_user_question=row["example_user_question"]
                example_sql=row["example_generated_sql"]
                exact_sql=example_sql
                sql_example_txt = sql_example_txt + "\n Example_question: "+example_user_question+ "; Example_SQL: "+example_sql

            print("Found a matching question from the history!" + str(sql_example_txt))
            final_sql=exact_sql

        else: 
            print("No exact match found for the user prompt")
            final_sql = None

        return final_sql

    def test_sql_plan_execution(self, generated_sql):
        try:

            job_config=bigquery.QueryJobConfig(dry_run=True, use_query_cache=False)
            query_job = self.client.query(generated_sql,job_config=job_config)
            # print(query_job)
            exec_result_df=("This query will process {} bytes.".format(query_job.total_bytes_processed))
            correct_sql = True
            print(exec_result_df)
            return correct_sql, exec_result_df
        except Exception as e:
            return False,str(e)




In [None]:
connector = BQConnector(project_id,BQ_REGION,BQ_DATASET_NAME,BQ_OPENDATAQNA_DATASET_NAME)

### **Initial Step: Specifying Agent Classes**
Before moving forward with the standalone notebook, let's first implement the classes for the agents we are going to use.

#### Base Class for Agents
The `Agent` class is the foundation for building agents with Large Language Models (LLMs) on Vertex AI. It handles model initialization based on the specified model ID and provides a method to generate responses from the chosen LLM. This base class serves as a template for creating specialized agents within the Open Data QnA project, promoting code reusability and a structured approach to working with LLMs.
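`generate_llm_response` strips code fences with two `str.replace` calls, which leaves in place any prose the model adds around the query. As an alternative (a swapped-in technique, not the notebook's code), a regular expression can extract only the first fenced block. `clean_llm_sql` is a hypothetical helper, and the optional `sql` tag is an assumption about how the model labels its fences.

```python
import re

FENCE = "`" * 3  # a literal triple backtick

# First fenced block, with an optional "sql" language tag; DOTALL lets . span newlines
_FENCED_BLOCK = re.compile(FENCE + r"(?:sql)?\s*(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)


def clean_llm_sql(text: str) -> str:
    """Return the body of the first fenced block, or the whole text if there is none."""
    match = _FENCED_BLOCK.search(text)
    body = match.group(1) if match else text
    return body.strip()


raw = "Here is the query:\n" + FENCE + "sql\nSELECT 1\n" + FENCE + "\nHope this helps."
print(clean_llm_sql(raw))  # SELECT 1
```

When the model returns bare SQL with no fences, the helper falls back to the stripped input, so it behaves like the original `replace` chain in that case.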

In [None]:
"""
Provides the base class for all Agents 
"""

from abc import ABC
import vertexai
from vertexai.language_models import TextGenerationModel
from vertexai.language_models import CodeGenerationModel
from vertexai.language_models import CodeChatModel
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmCategory,HarmBlockThreshold
from vertexai.generative_models import GenerationConfig




class Agent(ABC):
    """
    The core class for all Agents
    """

    agentType: str = "Agent"

    def __init__(self,
                model_id:str):
        """
        Args:
            PROJECT_ID (str | None): GCP Project Id.
            dataset_name (str): 
            TODO
        """

        self.model_id = model_id 

        if model_id == 'code-bison-32k':
            self.model = CodeGenerationModel.from_pretrained('code-bison-32k')
        elif model_id == 'text-bison-32k':
            self.model = TextGenerationModel.from_pretrained('text-bison-32k')
        elif model_id == 'codechat-bison-32k':
            self.model = CodeChatModel.from_pretrained("codechat-bison-32k")
        elif model_id.startswith'(gemini-1'):
            self.model = GenerativeModel(model_id)
            self.safety_settings: Optional[dict] = {
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            }
        else:
            raise ValueError("Please specify a compatible model.")

    def generate_llm_response(self,prompt):
        context_query = self.model.generate_content(prompt,safety_settings=self.safety_settings,stream=False)
        return str(context_query.candidates[0].text).replace("```sql", "").replace("```", "")

#### Agent class for generating embeddings
The `EmbedderAgent` class is designed to generate text embeddings using Vertex AI's TextEmbeddingModel. It takes a string or a list of strings as input and returns the corresponding embeddings as a list of vectors. The agent's create method handles both single and multiple text inputs, ensuring that the embeddings are correctly retrieved and formatted.
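`create` calls `get_embeddings` once per question. One possible variation, sketched below under stated assumptions, sends several texts per call. `chunked` and `embed_in_batches` are hypothetical helpers, and the default `batch_size=5` is arbitrary: check the embedding model's documentation for its actual per-request limit. The embedder is injected as a callable so the sketch runs without Vertex AI.

```python
from typing import Callable, Iterator, List, Sequence


def chunked(items: Sequence[str], size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most `size` items."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def embed_in_batches(
    texts: Sequence[str],
    embed_batch: Callable[[Sequence[str]], List[List[float]]],
    batch_size: int = 5,
) -> List[List[float]]:
    """Embed texts with one call per batch, preserving input order."""
    vectors: List[List[float]] = []
    for batch in chunked(texts, batch_size):
        vectors.extend(embed_batch(batch))
    return vectors


# Stand-in embedder: one fake 1-d vector per text, recording each call's batch size.
calls: List[int] = []


def fake_embed(batch: Sequence[str]) -> List[List[float]]:
    calls.append(len(batch))
    return [[float(len(t))] for t in batch]


vectors = embed_in_batches(["a", "bb", "ccc", "dddd", "eeeee", "ffffff", "g"], fake_embed, batch_size=3)
print(calls)    # [3, 3, 1]
print(vectors)  # [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [1.0]]
```

With the real model, `embed_batch` could be `lambda batch: [e.values for e in model.get_embeddings(list(batch))]`, where `model` is a `TextEmbeddingModel`.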

In [None]:
from abc import ABC
from vertexai.language_models import TextEmbeddingModel



class EmbedderAgent(Agent, ABC): 
    """ 
    This Agent generates embeddings 
    """ 

    agentType: str = "EmbedderAgent"

    def __init__(self, mode, embeddings_model='textembedding-gecko@002'): 
        if mode == 'vertex': 
            self.mode = mode 
            self.model = TextEmbeddingModel.from_pretrained(embeddings_model)

        else: raise ValueError('EmbedderAgent mode must be vertex')



    def create(self, question): 
        """Text embedding with a Large Language Model."""

        if self.mode == 'vertex': 
            if isinstance(question, str): 
                embeddings = self.model.get_embeddings([question])
                for embedding in embeddings:
                    vector = embedding.values
                return vector
            
            elif isinstance(question, list):  
                vector = list() 
                for q in question: 
                    embeddings = self.model.get_embeddings([q])

                    for embedding in embeddings:
                        vector.append(embedding.values) 
                return vector
            
            else: raise ValueError('Input must be either str or list')
        

#### Agent class for building SQL

The `BuildSQLAgent` class specializes in constructing SQL queries for BigQuery (and PostgreSQL, on the main repo). It leverages an LLM to analyze user questions, table schemas, and column details to generate syntactically and semantically correct BigQuery-compliant SQL queries. It adheres to specific guidelines like minimal table joins, safe casting, and accurate column referencing. 
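The prompt asks for fully qualified, backtick-quoted table names. A cheap heuristic check on the generated SQL can catch violations before a dry run. This is an illustrative sketch, not part of the notebook: `unqualified_table_refs` is hypothetical, and its regular expression only looks at identifiers after `FROM` and `JOIN`, so it can raise false positives (for example, the column in `EXTRACT(YEAR FROM release_date)`) and is no substitute for BigQuery's own validation.

```python
import re
from typing import List

# FROM/JOIN followed by either a backticked identifier or a bare identifier
_TABLE_REF = re.compile(r"\b(?:FROM|JOIN)\s+(`[^`]+`|[A-Za-z_][\w.-]*)", re.IGNORECASE)


def unqualified_table_refs(sql: str) -> List[str]:
    """Return FROM/JOIN targets that are not written as `project.dataset.table`."""
    bad = []
    for ref in _TABLE_REF.findall(sql):
        # Accept only backticked names with exactly three dot-separated parts
        if not (ref.startswith("`") and ref.strip("`").count(".") == 2):
            bad.append(ref)
    return bad


sql = "SELECT a FROM `my_project.imdb.reviews` r JOIN ratings t ON r.id = t.id"
print(unqualified_table_refs(sql))  # ['ratings']
```

Subqueries such as `FROM (SELECT ...)` are skipped because the pattern requires an identifier right after the keyword.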

In [None]:
class BuildSQLAgent(Agent, ABC):
    """
    This Agent produces the SQL query 
    """

    agentType: str = "BuildSQLAgent"

    def build_sql(self,source_type, user_question,tables_schema,tables_detailed_schema, similar_sql, max_output_tokens=2048, temperature=0.4, top_p=1, top_k=32): 
        context_prompt = f"""
            You are a BigQuery SQL guru. Write a BigQuery-conformant SQL query that answers the following question while using the provided context to correctly refer to the BigQuery tables and the needed column names.

            Guidelines:
            - Join as few tables as possible.
            - When joining tables ensure all join columns are the same data_type.
            - Analyze the database and the table schema provided as parameters and understand the relations (column and table relations).
            - Always use SAFE_CAST. If performing a SAFE_CAST, use only BigQuery supported datatypes.
            - Always SAFE_CAST and then use aggregate functions
            - Don't include any comments in code.
            - Remove ```sql and ``` from the output and generate the SQL in single line.
            - Tables should be referred to using a fully qualified name enclosed in backticks (`) e.g. `project_id.owner.table_name`.
            - Use all the non-aggregated columns from the "SELECT" statement while framing "GROUP BY" block.
            - Return syntactically and semantically correct SQL for BigQuery with proper relation mapping i.e project_id, owner, table and column relation.
            - Use ONLY the column names (column_name) mentioned in Table Schema. DO NOT USE any other column names outside of this.
            - Associate column_name mentioned in Table Schema only to the table_name specified under Table Schema.
            - Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed.
            - Table names are case sensitive. DO NOT uppercase or lowercase the table names.
            - Always enclose subqueries and union queries in brackets.
            - Refer to the examples provided i.e. {similar_sql}


        Here are some examples of user-question and SQL queries:
        {similar_sql}

        question:
        {user_question}

        Table Schema:
        {tables_schema}

        Column Description:
        {tables_detailed_schema}

        """
        

        if 'gemini' in self.model_id:
            # Generation Config
            config = GenerationConfig(
                max_output_tokens=max_output_tokens, temperature=temperature, top_p=top_p, top_k=top_k
            )

            # Generate text
            context_query = self.model.generate_content(context_prompt, generation_config=config, stream=False)
            generated_sql = str(context_query.candidates[0].text)

        else:
            context_query = self.model.predict(context_prompt, max_output_tokens = max_output_tokens, temperature=temperature)
            generated_sql = str(context_query.candidates[0])


        return generated_sql

#### Agent class for debugging loop

The `DebugSQLAgent` is designed to troubleshoot and refine BigQuery SQL queries. It initiates a chat session with a language model (e.g., Gemini or CodeChat) and iteratively analyzes queries, identifying and addressing errors. If a query is invalid, the agent uses feedback from the model to generate alternative queries that adhere to best practices and answer the original question. The start_debugger method orchestrates this process, returning a potentially refined query along with validity information and an audit trail.
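The control flow of `start_debugger` (check, rewrite on failure, repeat for a bounded number of rounds) can be sketched independently of the LLM and BigQuery. In this minimal sketch, `debug_loop` is a hypothetical function: `check` plays the role of the LLM validation plus `test_sql_plan_execution` (which likewise returns a `(bool, str)` pair), and `rewrite` plays the role of `rewrite_sql_chat`. The counting of rewrite rounds here is an assumption and may differ from how `DEBUGGING_ROUNDS` is used.

```python
from typing import Callable, List, Tuple


def debug_loop(
    sql: str,
    check: Callable[[str], Tuple[bool, str]],
    rewrite: Callable[[str, str], str],
    max_rewrites: int = 2,
) -> Tuple[str, bool, List[str]]:
    """Check `sql`; on failure request a rewrite, at most `max_rewrites` times.

    Returns (final_sql, is_valid, audit_log).
    """
    audit: List[str] = []
    for attempt in range(max_rewrites + 1):
        ok, feedback = check(sql)
        if ok:
            audit.append(f"attempt {attempt}: valid")
            return sql, True, audit
        audit.append(f"attempt {attempt}: failed ({feedback})")
        if attempt < max_rewrites:
            sql = rewrite(sql, feedback)  # ask the chat model for an alternative
    return sql, False, audit


# Stubs: the check fails while the SQL contains "BAD"; the rewrite removes it.
def check(sql: str) -> Tuple[bool, str]:
    return ("BAD" not in sql, "unrecognized token BAD" if "BAD" in sql else "")


final_sql, valid, log = debug_loop("SELECT BAD 1", check, lambda s, fb: s.replace("BAD ", ""))
print(final_sql, valid)  # SELECT 1 True
print(log)  # ['attempt 0: failed (unrecognized token BAD)', 'attempt 1: valid']
```

Bounding the loop guarantees termination even when the model keeps returning invalid SQL; the caller then receives the last attempt together with `is_valid=False`.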

In [None]:
from abc import ABC

from vertexai.language_models import CodeChatModel
from vertexai.generative_models import GenerativeModel

import pandas as pd
import json  



class DebugSQLAgent(Agent, ABC): 
    """ 
    This Chat Agent runs the debugging loop.
    """ 

    agentType: str = "DebugSQLAgent"

    def __init__(self, chat_model_id = 'gemini-1.5-pro-001'): 
        self.chat_model_id = chat_model_id
        # self.model = CodeChatModel.from_pretrained("codechat-bison-32k")


    def init_chat(self, tables_schema,tables_detailed_schema,sql_example="-No examples provided..-"):
        context_prompt = f"""
        You are a BigQuery SQL guru. This session is trying to troubleshoot a BigQuery SQL query. As the user provides versions of the query and the errors returned by BigQuery,
        return a new alternative SQL query that fixes the errors. It is important that the query still answers the original question.


        Guidelines:
        - Join as few tables as possible.
        - When joining tables ensure all join columns are the same data_type.
        - Analyze the database and the table schema provided as parameters and understand the relations (column and table relations).
        - Always use SAFE_CAST. If performing a SAFE_CAST, use only BigQuery supported datatypes.
        - Always SAFE_CAST and then use aggregate functions
        - Don't include any comments in code.
        - Remove ```sql and ``` from the output and generate the SQL in single line.
        - Tables should be referred to using a fully qualified name enclosed in backticks (`) e.g. `project_id.owner.table_name`.
        - Use all the non-aggregated columns from the "SELECT" statement while framing "GROUP BY" block.
        - Return syntactically and semantically correct SQL for BigQuery with proper relation mapping i.e project_id, owner, table and column relation.
        - Use ONLY the column names (column_name) mentioned in Table Schema. DO NOT USE any other column names outside of this.
        - Associate column_name mentioned in Table Schema only to the table_name specified under Table Schema.
        - Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed.
        - Table names are case sensitive. DO NOT uppercase or lowercase the table names.
        - Always enclose subqueries and union queries in brackets.
        - Refer to the examples provided i.e. {sql_example}

        Parameters:
        - table metadata: {tables_schema}
        - column metadata: {tables_detailed_schema}
        - SQL example: {sql_example}

        """

        
        if self.chat_model_id == 'codechat-bison-32k':
            chat_model = CodeChatModel.from_pretrained("codechat-bison-32k")
            chat_session = chat_model.start_chat(context=context_prompt)
        elif self.chat_model_id == 'gemini-1.5-pro-001':
            chat_model = GenerativeModel("gemini-1.5-pro-001-001")
            chat_session = chat_model.start_chat(response_validation=False)
            chat_session.send_message(context_prompt)
        elif self.chat_model_id == 'gemini-ultra':
            chat_model = GenerativeModel("gemini-1.0-ultra-001")
            chat_session = chat_model.start_chat(response_validation=False)
            chat_session.send_message(context_prompt)
        else:
            raise ValueError('Invalid chat_model_id')
        
        return chat_session


    def rewrite_sql_chat(self, chat_session, question, error_df):


        context_prompt = f"""
            What is an alternative SQL statement to address the error mentioned below?
            Present a different SQL from previous ones. It is important that the query still answers the original question.
            All columns selected must be present on tables mentioned on the join section.
            Avoid repeating suggestions.

            Original SQL:
            {question}

            Error:
            {error_df}

            """

        if self.chat_model_id =='codechat-bison-32k':
            response = chat_session.send_message(context_prompt)
            resp_return = (str(response.candidates[0])).replace("```sql", "").replace("```", "")
        elif self.chat_model_id =='gemini-1.5-pro-001':
            response = chat_session.send_message(context_prompt, stream=False)
            resp_return = (str(response.text)).replace("```sql", "").replace("```", "")
        elif self.chat_model_id == 'gemini-ultra':
            response = chat_session.send_message(context_prompt, stream=False)
            resp_return = (str(response.text)).replace("```sql", "").replace("```", "")
        else:
            raise ValueError('Invalid chat_model_id')

        return resp_return


    def start_debugger  (self,
                        source_type,
                        query,
                        user_question, 
                        SQLChecker,
                        tables_schema, 
                        tables_detailed_schema,
                        AUDIT_TEXT, 
                        similar_sql="-No examples provided..-", 
                        DEBUGGING_ROUNDS = 2,
                        LLM_VALIDATION=True):
        i = 0  
        STOP = False 
        invalid_response = False 
        chat_session = self.init_chat(tables_schema,tables_detailed_schema,similar_sql)
        sql = query.replace("```sql","").replace("```","").replace("EXPLAIN ANALYZE ","")

        AUDIT_TEXT=AUDIT_TEXT+"\n\nEntering the debugging steps!"
        while (not STOP):

            # Check if LLM Validation is enabled 
            if LLM_VALIDATION: 
                # sql = query.replace("```sql","").replace("```","").replace("EXPLAIN ANALYZE ","")
                json_syntax_result = SQLChecker.check(user_question,tables_schema,tables_detailed_schema, sql) 

            else:
                # No LLM validation: treat the SQL as valid and go straight to the dry run
                json_syntax_result = {'valid': True}

            if json_syntax_result['valid'] is True:
                # Testing SQL Execution
                if LLM_VALIDATION: 
                    AUDIT_TEXT=AUDIT_TEXT+"\nGenerated SQL is syntactically correct as per LLM Validation!"
                
                else: 
                    AUDIT_TEXT=AUDIT_TEXT+"\nLLM Validation is deactivated. Jumping directly to dry run execution."

                    
                correct_sql, exec_result_df = connector.test_sql_plan_execution(sql)
                print("exec_result_df:" + str(exec_result_df))
                if not correct_sql:
                    AUDIT_TEXT=AUDIT_TEXT+"\nGenerated SQL failed on execution! Here is the feedback from the BigQuery dry run / explain plan:\n" + str(exec_result_df)
                    rewrite_result = self.rewrite_sql_chat(chat_session, sql, exec_result_df)
                    print('\n Rewritten and Cleaned SQL: ' + str(rewrite_result))
                    AUDIT_TEXT=AUDIT_TEXT+"\nRewritten and Cleaned SQL: \n" + str(rewrite_result)
                    sql = str(rewrite_result).replace("```sql","").replace("```","").replace("EXPLAIN ANALYZE ","")

                else: STOP = True
            else:
                print(f'\nGenerated query failed on syntax check as per LLM Validation!\nError Message from LLM:  {json_syntax_result} \nRewriting the query...')
                AUDIT_TEXT=AUDIT_TEXT+'\nGenerated query failed on syntax check as per LLM Validation! \nError Message from LLM:  '+ str(json_syntax_result) + '\nRewriting the query...'
                
                syntax_err_df = pd.read_json(json.dumps(json_syntax_result))
                rewrite_result=self.rewrite_sql_chat(chat_session, sql, syntax_err_df)
                print(rewrite_result)
                AUDIT_TEXT=AUDIT_TEXT+'\n Rewritten SQL: ' + str(rewrite_result)
                sql=str(rewrite_result).replace("```sql","").replace("```","").replace("EXPLAIN ANALYZE ","")
            i+=1
            # Give up once the round budget is spent without SQL that passed the dry run
            if not STOP and i > DEBUGGING_ROUNDS:
                AUDIT_TEXT=AUDIT_TEXT+ "\nExceeded the number of iterations for correction!"
                AUDIT_TEXT=AUDIT_TEXT+ "\nThe generated SQL can be invalid!"
                STOP = True
                invalid_response=True

        return sql, invalid_response, AUDIT_TEXT
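The debugging loop above alternates between LLM validation, a dry run, and LLM rewrites until the SQL passes or `DEBUGGING_ROUNDS` is exhausted. The sketch below isolates that control flow so it can be exercised without a model or a database. The names `debug_loop`, `validate`, `dry_run` and `rewrite` are hypothetical stand-ins for `SQLChecker.check`, `connector.test_sql_plan_execution` and `rewrite_sql_chat`; the sketch assumes, as in `start_debugger`, that every check-and-rewrite round counts toward the limit, and it reports the SQL as invalid only when the budget runs out before a dry run passes.

```python
from typing import Callable, Tuple


def debug_loop(
    sql: str,
    validate: Callable[[str], dict],
    dry_run: Callable[[str], Tuple[bool, str]],
    rewrite: Callable[[str, str], str],
    max_rounds: int = 2,
) -> Tuple[str, bool]:
    """Check and rewrite `sql` for at most max_rounds + 1 rounds.

    Returns (sql, invalid_response); invalid_response is True only when the
    budget runs out before the SQL passes both validation and the dry run.
    """
    for _ in range(max_rounds + 1):
        verdict = validate(sql)
        if not verdict["valid"]:
            # Validation failed: ask for a rewrite using the validator's errors
            sql = rewrite(sql, str(verdict["errors"]))
            continue
        passed, feedback = dry_run(sql)
        if passed:
            return sql, False
        # Dry run failed: feed the execution error back to the rewriter
        sql = rewrite(sql, feedback)
    return sql, True
```

Injecting the three callables keeps the retry policy separate from the model and database calls, so the policy can be unit-tested with simple stubs.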

#### Agent class for validation

The `ValidateSQLAgent` class uses an LLM to classify a SQL query as valid or invalid. Given the user question, the table schema, and the column descriptions, it checks the query against predefined guidelines (column existence, join conditions, table relationships, and formatting conventions) and returns a JSON object with a `valid` flag and any `errors` found, so that problems can be caught before the query is executed.
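`ValidateSQLAgent.check` strips Markdown fences from the model reply and passes the rest straight to `json.loads`, so any stray text raises an exception, and `start_debugger` accepts the result only when `valid` is the boolean `True`. A more defensive parser could look like the sketch below. `parse_validation_response` is a hypothetical helper, and treating an unparseable reply as a failed validation (instead of raising) is a design assumption, not the notebook's behavior.

```python
import json


def parse_validation_response(raw: str) -> dict:
    """Turn an LLM validation reply into {"valid": bool, "errors": ...}."""
    # Strip Markdown code fences the model may add despite the instructions
    cleaned = raw.replace("```json", "").replace("```", "").strip()
    try:
        result = json.loads(cleaned)
    except json.JSONDecodeError as exc:
        # Assumption: an unparseable reply counts as a failed validation
        return {"valid": False, "errors": f"Unparseable validator output: {exc}"}
    if not isinstance(result, dict) or "valid" not in result:
        return {"valid": False, "errors": f"Unexpected validator output: {cleaned}"}
    valid = result["valid"]
    if isinstance(valid, str):
        # Normalize "true"/"false" strings to real booleans
        valid = valid.strip().lower() == "true"
    return {"valid": bool(valid), "errors": result.get("errors", "")}
```

Returning a failed verdict keeps the debugging loop running (the rewrite step gets the parse error as feedback) rather than aborting the whole pipeline on one malformed reply.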

In [None]:
import json 
from abc import ABC


class ValidateSQLAgent(Agent, ABC): 
    """ 
    This Chat Agent checks the SQL for validity
    """ 

    agentType: str = "ValidateSQLAgent"


    def check(self, user_question, tables_schema, columns_schema, generated_sql):

        context_prompt = f"""

            Classify the SQL query: {generated_sql} as valid or invalid?

            Guidelines to be valid:
            - all column_name in the query must exist in the table_name.
            - If a join includes d.country_id and table_alias d is equal to table_name DEPT, then country_id column_name must exist with table_name DEPT in the table column metadata. If not, the sql is invalid
            - all join columns must be the same data_type.
            - table relationships must be correct.
            - Tables should be referred to using a fully qualified name including owner and table name.
            - Use table_alias.column_name when referring to columns. Example: dept_id=hr.dept_id
            - Capitalize the table names on SQL "where" condition.
            - Use the columns from the "SELECT" statement while framing "GROUP BY" block.
            - Always refer to the column name with the correctly mapped table name as seen in the table schema.
            - Must be syntactically and semantically correct SQL with proper relation mapping i.e owner, table and column relation.
            - The table should always be referred to as schema.table_name.


        Parameters:
        - SQL query: {generated_sql}
        - table schema: {tables_schema}
        - column description: {columns_schema}

        Respond using a valid JSON format with two elements valid and errors. Remove ```json and ``` from the output:
        {{ "valid": true or false, "errors":errors }}

        Initial user question:
        {user_question}


        """

        
        if self.model_id =='gemini-1.5-pro-001':
            context_query = self.model.generate_content(context_prompt, stream=False)
            generated_sql = str(context_query.candidates[0].text)

        else:
            context_query = self.model.predict(context_prompt, max_output_tokens = 8000, temperature=0)
            generated_sql = str(context_query.candidates[0])


        json_syntax_result = json.loads(str(generated_sql).replace("```json","").replace("```",""))

        # print('\n SQL Syntax Validity:' + str(json_syntax_result['valid']))
        # print('\n SQL Syntax Error Description:' +str(json_syntax_result['errors']) + '\n')
        
        return json_syntax_result

#### Agent for generating table and column descriptions

The `DescriptionAgent` class automatically generates concise descriptions for BigQuery tables and columns. It uses an LLM to analyze column metadata and table details, producing descriptions that can aid in understanding the data and improving SQL query generation. The agent iterates through dataframes containing table and column information, filling in missing descriptions with LLM-generated content.
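The fill-in logic of `generate_missing_descriptions` can be sketched without pandas or an LLM: walk the rows and, whenever a description is missing, call a generator and count the fill. `fill_missing` and `generate` are hypothetical; `generate` stands in for `generate_llm_response` with the table or column prompt. The notebook treats `None` and `'NA'` as missing for tables and `None` and `''` for columns; this sketch accepts all three by default. With a pandas DataFrame, NaN values would additionally need a `pd.isna` check.

```python
from typing import Callable, Dict, Iterable, List, Optional


def fill_missing(
    rows: List[Dict[str, Optional[str]]],
    key: str,
    generate: Callable[[Dict[str, Optional[str]]], str],
    missing: Iterable[Optional[str]] = (None, "", "NA"),
) -> int:
    """Fill row[key] with generate(row) wherever it is missing; return the fill count."""
    missing_values = tuple(missing)
    filled = 0
    for row in rows:
        if row.get(key) in missing_values:
            row[key] = generate(row)
            filled += 1
    return filled
```

Existing descriptions are never overwritten, so human-written metadata always takes precedence over generated text.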

In [None]:
from abc import ABC


class DescriptionAgent(Agent, ABC): 
    """ 
    Generates table and column descriptions. 
    """ 

    agentType: str = "DescriptionAgent"

    def generate_llm_response(self,prompt):
        context_query = self.model.generate_content(prompt,safety_settings=self.safety_settings,stream=False)
        return str(context_query.candidates[0].text).replace("```sql", "").replace("```", "")


    def generate_missing_descriptions(self,source,table_desc_df, column_name_df):
        llm_generated=0
        for index, row in table_desc_df.iterrows():
            if row['table_description'] is None or row['table_description']=='NA':
                q=f"table_name == '{row['table_name']}' and table_schema == '{row['table_schema']}'"
                if source=='bigquery':
                    context_prompt = f"""
                        Generate a short and crisp table description for the table {row['project_id']}.{row['table_schema']}.{row['table_name']}
                        Remember that these descriptions should help LLMs build better SQL for any queries related to this table.
                        Parameters:
                        - column metadata: {column_name_df.query(q).to_markdown(index = False)}
                        - table metadata: {table_desc_df.query(q).to_markdown(index = False)}
                        
                        DO NOT generate description more than two lines
                    """


                table_desc_df.at[index,'table_description']=self.generate_llm_response(context_prompt)
                # print(row['table_description'])
                llm_generated=llm_generated+1
        print("\nLLM generated "+ str(llm_generated) + " Table Descriptions")
        llm_generated = 0

        
        for index, row in column_name_df.iterrows():
            # print(row['column_description'])
            if row['column_description'] is None or row['column_description']=='':
                q=f"table_name == '{row['table_name']}' and table_schema == '{row['table_schema']}'"
                if source=='bigquery':
                    context_prompt = f"""
                    Generate short and crisp description for the column {row['project_id']}.{row['table_schema']}.{row['table_name']}.{row['column_name']}

                    Remember that this description should help LLMs generate better SQL for any queries related to these columns.

                    Consider the below information to generate a good comment

                    Name of the column : {row['column_name']}
                    Data type of the column is : {row['data_type']}
                    Details of the table of this column are below:
                    {table_desc_df.query(q).to_markdown(index=False)}
                    Column constraints of this column are: {row['column_constraints']}

                    DO NOT generate description more than two lines
                """
                

                column_name_df.at[index,'column_description']=self.generate_llm_response(prompt=context_prompt)
                # print(row['column_description'])
                llm_generated=llm_generated+1
        print("\nLLM generated "+ str(llm_generated) + " Column Descriptions")
        return table_desc_df,column_name_df

#### Response Agent
The `ResponseAgent` is a specialized chat agent that turns the result of a SQL query into a natural-language answer. It passes both the user's original question and the SQL output to a language model such as Gemini Pro, which phrases the result as a direct answer to the question. It calls `generate_content` for `gemini-1.5-pro-001` and `predict` for other models, so switching models requires only minor adjustments.
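`ResponseAgent.run` builds a prompt from the question and the stringified SQL result and sends it to the model. The sketch below isolates that prompt construction. `build_response_prompt` is hypothetical, and truncating long results with `max_result_chars` is an added assumption (large result sets can otherwise inflate the prompt); the agent as written sends the full `str(sql_result)`.

```python
def build_response_prompt(user_question: str, sql_result: object, max_result_chars: int = 4000) -> str:
    """Assemble the answer-generation prompt from the question and the SQL result."""
    result_text = str(sql_result)
    if len(result_text) > max_result_chars:
        # Assumption: cap the result size so the prompt stays bounded
        result_text = result_text[:max_result_chars] + " ...[truncated]"
    return (
        "You are a Data Assistant that helps to answer users' questions on their data within their databases.\n"
        f'The user has provided the following question in natural language: "{user_question}"\n'
        f'The system has returned the following result after running the SQL query: "{result_text}".\n'
        "Provide a natural sounding response to the user to answer the question with the SQL result provided to you."
    )
```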

In [None]:
import json 
from abc import ABC

class ResponseAgent(Agent, ABC): 
    """
    A specialized Chat Agent designed to provide natural language responses to user questions based on SQL query results.

    This agent acts as a bridge between structured data returned from SQL queries and the user's natural language input. It leverages a language model (e.g., Gemini Pro or others) to interpret the query results and craft informative, human-readable answers.

    Key Features:

    * **Natural Language Generation:**  Transforms SQL results into user-friendly responses.
    * **Model Flexibility:** Supports multiple language models (currently handles Gemini Pro and others with slight adjustments).
    * **Contextual Understanding:**  Incorporates the user's original question and the SQL results to provide accurate and relevant answers. 

    Attributes:
        agentType (str): Identifies this agent as a "ResponseAgent".
        model_id (str): Indicates the specific language model being used.

    Methods:
        run(user_question, sql_result):
            Generates a natural language response based on the user's question and the SQL results.
            
    Example:

        response_agent = ResponseAgent(model_id='gemini-1.5-pro-001')
        response = response_agent.run("How many customers are in California?", sql_result) 
        # response might be: "There are 153 customers in California based on the data."
    """

    agentType: str = "ResponseAgent"

    # TODO: Make the LLM Validator optional
    def run(self, user_question, sql_result):

        context_prompt = f"""

            You are a Data Assistant that helps to answer users' questions on their data within their databases.
            The user has provided the following question in natural language: "{str(user_question)}"

            The system has returned the following result after running the SQL query: "{str(sql_result)}".

            Provide a natural sounding response to the user to answer the question with the SQL result provided to you. 
        """

        
        if self.model_id =='gemini-1.5-pro-001':
            context_query = self.model.generate_content(context_prompt, stream=False)
            response_text = str(context_query.candidates[0].text)

        else:
            context_query = self.model.predict(context_prompt, max_output_tokens = 8000, temperature=0)
            response_text = str(context_query.candidates[0])
        
        return response_text

#### Helper function: chunking 

The `get_embedding_chunked` function generates embeddings for a list of text chunks. Each element of `textinput` is a dict with a `content` key; the function sends the contents to the `EmbedderAgent` in batches of `batch_size`, one batch after another, and stores each returned vector under the chunk's `embedding` key. It returns a pandas DataFrame containing the original chunk fields together with their embedding vectors, ready to be written to the vector store.
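The batching can be sketched with an injectable embedding function, which makes it testable without Vertex AI. `embed_in_batches` and `embed` are hypothetical; `embed` stands in for `EmbedderAgent('vertex').create` and is assumed to return one vector per input text, in order. The length check is an addition: `zip` would otherwise silently drop chunks if the embedder returned fewer vectors than inputs.

```python
from typing import Callable, Dict, List


def embed_in_batches(
    chunks: List[Dict],
    batch_size: int,
    embed: Callable[[List[str]], List[List[float]]],
) -> List[Dict]:
    """Attach an 'embedding' to every chunk, calling `embed` once per batch."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(chunks), batch_size):
        batch = chunks[start:start + batch_size]
        vectors = embed([chunk["content"] for chunk in batch])
        if len(vectors) != len(batch):
            raise ValueError("embedder returned a different number of vectors than inputs")
        for chunk, vector in zip(batch, vectors):
            chunk["embedding"] = vector
    return chunks
```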

In [None]:
def get_embedding_chunked(textinput, batch_size): 
    embedder = EmbedderAgent('vertex')

    for i in range(0, len(textinput), batch_size):
        request = [x["content"] for x in textinput[i : i + batch_size]]
        response = embedder.create(request) # Vertex Textmodel Embedder 

        # Store the retrieved vector embeddings for each chunk back.
        for x, e in zip(textinput[i : i + batch_size], response):
            x["embedding"] = e

    # Store the generated embeddings in a pandas dataframe.
    out_df = pd.DataFrame(textinput)
    return out_df

# **1. Vector Store Setup** (Run once)
---

This section walks through the Vector Store Setup needed for running the Open Data QnA application. 

It covers the following steps: 
> 1. Configuration: Environment and database setup, including logging on BigQuery for analytics

> 2. Creation of Table, Column and Known Good Query Embeddings in the Vector Store for Retrieval Augmented Generation (RAG)




## 📈 **1.1 Set Up your Data Source and Vector Store**

This section assumes that a data source is already set up in your GCP project. 


### ⚙️  **Database Setup for Vector Store: BigQuery-vector**

Create a dataset in BigQuery to store the embedding tables.
If BigQuery is the vector store, the same dataset is also used for logging. 

In [None]:
# Create a new data set on Bigquery to use as Vector store; the same will be used for logging as well
if VECTOR_STORE == 'bigquery-vector':
  BQ_OPENDATAQNA_DATASET_NAME = "opendataqna" #@param {type:"string"} - name of the dataset in Vector Store

  from google.cloud import bigquery
  import google.api_core 
  client=bigquery.Client(project=PROJECT_ID)
  dataset_ref = f"{PROJECT_ID}.{BQ_OPENDATAQNA_DATASET_NAME}"


  # Create the dataset if it does not exist already
  try:
      client.get_dataset(dataset_ref)
      print("Destination Dataset exists")
  except google.api_core.exceptions.NotFound:
      print("Cannot find the dataset hence creating.......")
      dataset=bigquery.Dataset(dataset_ref)
      dataset.location=DATASET_REGION
      client.create_dataset(dataset)
      print(str(dataset_ref)+" is created")

##  **1.2. Create Embeddings in Vector Store for RAG** 

### 🖋️ **Create Table and Column Embeddings**

In this step, table and column metadata is retrieved from the data source and embeddings are generated for both.
To do this, we first define helper functions that return the SQL for retrieving table and column schemas. 
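Both helper functions restrict the schema queries to the tables in `BQ_TABLE_LIST` by appending an `IN (...)` filter. The sketch below shows that clause construction on its own. `table_filter_clause` is hypothetical, and escaping quotes and backslashes in table names is an added precaution, since the helpers interpolate names directly into the SQL string.

```python
from typing import Optional, Sequence


def table_filter_clause(table_names: Optional[Sequence[str]], column: str = "TABLE_NAME") -> str:
    """Return 'AND <column> IN (...)' for the given tables, or '' for no filter."""
    if not table_names:
        return ""
    quoted = []
    for name in table_names:
        # Escape backslashes and single quotes so each name stays a valid BigQuery string literal
        escaped = name.replace("\\", "\\\\").replace("'", "\\'")
        quoted.append(f"'{escaped}'")
    return f"AND {column} IN ({', '.join(quoted)})"
```

The column schema query filters on the aliased `C.TABLE_NAME`, which corresponds to calling the sketch with `column="C.TABLE_NAME"`.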

Helper function to return table schema details: 

In [None]:
def return_table_schema_sql(project_id, dataset, table_names=None):
    """
    Returns the SQL query to get table schema info, optionally filtering by specific tables.
    """
    user_dataset = f"{project_id}.{dataset}"

    table_filter_clause = ""
    if table_names:
        # Quote each table name for the IN (...) filter
        formatted_table_names = [f"'{name}'" for name in table_names]
        table_filter_clause = f"""AND TABLE_NAME IN ({', '.join(formatted_table_names)})"""


    table_schema_sql = f"""
    (SELECT
        TABLE_CATALOG as project_id, TABLE_SCHEMA as table_schema , TABLE_NAME as table_name,  OPTION_VALUE as table_description,
        (SELECT STRING_AGG(column_name, ', ') from `{user_dataset}.INFORMATION_SCHEMA.COLUMNS` where TABLE_NAME= t.TABLE_NAME and TABLE_SCHEMA=t.TABLE_SCHEMA) as table_columns
    FROM
        `{user_dataset}.INFORMATION_SCHEMA.TABLE_OPTIONS` as t
    WHERE
        OPTION_NAME = "description"
        {table_filter_clause}
    ORDER BY
        project_id, table_schema, table_name)

    UNION ALL

    (SELECT
        TABLE_CATALOG as project_id, TABLE_SCHEMA as table_schema , TABLE_NAME as table_name,  "NA" as table_description,
        (SELECT STRING_AGG(column_name, ', ') from `{user_dataset}.INFORMATION_SCHEMA.COLUMNS` where TABLE_NAME= t.TABLE_NAME and TABLE_SCHEMA=t.TABLE_SCHEMA) as table_columns
    FROM
        `{user_dataset}.INFORMATION_SCHEMA.TABLES` as t 
    WHERE 
        NOT EXISTS (SELECT 1   FROM
        `{user_dataset}.INFORMATION_SCHEMA.TABLE_OPTIONS`  
    WHERE
        OPTION_NAME = "description" AND  TABLE_NAME= t.TABLE_NAME and TABLE_SCHEMA=t.TABLE_SCHEMA)
        {table_filter_clause}
    ORDER BY
        project_id, table_schema, table_name)
    """
    return table_schema_sql


Helper function to return column schema details: 

In [None]:

def return_column_schema_sql(project_id, dataset, table_names=None):
    """
    Returns the SQL query to get column schema info, optionally filtering by specific tables.
    """
    user_dataset = f"{project_id}.{dataset}"
    
    table_filter_clause = ""
    if table_names:
        # Quote each table name for the IN (...) filter
        formatted_table_names = [f"'{name}'" for name in table_names]
        table_filter_clause = f"""AND C.TABLE_NAME IN ({', '.join(formatted_table_names)})"""

    column_schema_sql = f"""
    SELECT
        C.TABLE_CATALOG as project_id, C.TABLE_SCHEMA as table_schema, C.TABLE_NAME as table_name, C.COLUMN_NAME as column_name,
        C.DATA_TYPE as data_type, C.DESCRIPTION as column_description, CASE WHEN T.CONSTRAINT_TYPE="PRIMARY KEY" THEN "This Column is a Primary Key for this table" WHEN 
        T.CONSTRAINT_TYPE = "FOREIGN KEY" THEN "This column is Foreign Key" ELSE NULL END as column_constraints
    FROM
        `{user_dataset}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` C 
    -- Match each column to the key constraints it belongs to (BigQuery keys are NOT ENFORCED)
    LEFT JOIN 
        `{user_dataset}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE` K
        ON C.TABLE_CATALOG = K.TABLE_CATALOG AND
           C.TABLE_SCHEMA = K.TABLE_SCHEMA AND
           C.TABLE_NAME = K.TABLE_NAME AND
           C.COLUMN_NAME = K.COLUMN_NAME
    LEFT JOIN 
        `{user_dataset}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS` T 
        ON K.CONSTRAINT_CATALOG = T.CONSTRAINT_CATALOG AND
           K.CONSTRAINT_SCHEMA = T.CONSTRAINT_SCHEMA AND
           K.CONSTRAINT_NAME = T.CONSTRAINT_NAME
    WHERE
        1=1
        {table_filter_clause} 
    ORDER BY
        project_id, table_schema, table_name, column_name;
"""

    return column_schema_sql


Retrieve table and column dataframes with schema details and descriptions: 

In [None]:
from google.cloud import bigquery

client = bigquery.Client(project=project_id)

table_schema_sql = return_table_schema_sql(project_id, BQ_DATASET_NAME, BQ_TABLE_LIST)
table_desc_df = client.query_and_wait(table_schema_sql).to_dataframe()

column_schema_sql = return_column_schema_sql(project_id, BQ_DATASET_NAME, BQ_TABLE_LIST) 
column_name_df = client.query_and_wait(column_schema_sql).to_dataframe()

descriptor = DescriptionAgent('gemini-1.5-pro-001')

# Generate descriptions for tables and columns that do not have one yet
table_desc_df,column_name_df= descriptor.generate_missing_descriptions('bigquery',table_desc_df,column_name_df)

In [None]:
table_desc_df['table_description']

In [None]:
table_desc_df.head(10)

In [None]:
column_name_df.head(10)

Function to generate embeddings: 

In [None]:
import pandas as pd 

def retrieve_embeddings(): 
    """ Augment all the DB schema blocks to create document for embedding """

    #TABLE EMBEDDINGS
    table_details_chunked = []

    for _, row_aug in table_desc_df.iterrows():
        cur_project_name =str(row_aug['project_id'])
        cur_table_name = str(row_aug['table_name'])
        cur_table_schema = str(row_aug['table_schema'])
        curr_col_names = str(row_aug['table_columns'])
        curr_tbl_desc = str(row_aug['table_description'])


        table_detailed_description=f"""
        Full Table Name : {cur_project_name}.{cur_table_schema}.{cur_table_name} |
        Table Columns List: [{curr_col_names}] |
        Table Description: {curr_tbl_desc} """

        r = {"table_schema": cur_table_schema,"table_name": cur_table_name,"content": table_detailed_description}
        table_details_chunked.append(r)

    table_details_embeddings = get_embedding_chunked(table_details_chunked, 10)


    ### COLUMN EMBEDDINGS ###
    # column_name_df holds one row per column with project_id, table_schema, table_name, column_name,
    # data_type, column_description and column_constraints for the dataset queried above

    column_details_chunked = []

    for _, row_aug in column_name_df.iterrows():
        cur_project_name =str(row_aug['project_id'])
        cur_table_name = str(row_aug['table_name'])
        cur_table_owner = str(row_aug['table_schema'])
        curr_col_name = str(row_aug['table_schema'])+'.'+str(row_aug['table_name'])+'.'+str(row_aug['column_name'])
        curr_col_datatype = str(row_aug['data_type'])
        curr_col_description = str(row_aug['column_description'])
        curr_col_constraints = str(row_aug['column_constraints'])
        curr_column_name = str(row_aug['column_name'])


        column_detailed_description=f"""
        Column Name: {curr_col_name}|
        Full Table Name : {cur_project_name}.{cur_table_owner}.{cur_table_name} |
        Data type: {curr_col_datatype}|
        Column description: {curr_col_description}|
        Column Constraints: {curr_col_constraints} """

        r = {"table_schema": cur_table_owner,"table_name": cur_table_name,"column_name":curr_column_name, "content": column_detailed_description}
        column_details_chunked.append(r)

    column_details_embeddings = get_embedding_chunked(column_details_chunked, 10)


    return table_details_embeddings, column_details_embeddings
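Each embedded document is a single text string that combines a table's (or column's) fully qualified name with its metadata. The sketch below builds the same fields as `retrieve_embeddings` from one metadata row, on a single line instead of the indented multi-line f-strings. `table_document` and `column_document` are hypothetical helpers; each takes the schema from its own row, so a column document always names the table it belongs to.

```python
from typing import Dict


def table_document(row: Dict) -> Dict:
    """Build the record that gets embedded for one table metadata row."""
    full_name = f"{row['project_id']}.{row['table_schema']}.{row['table_name']}"
    content = (
        f"Full Table Name : {full_name} | "
        f"Table Columns List: [{row['table_columns']}] | "
        f"Table Description: {row['table_description']}"
    )
    return {"table_schema": row["table_schema"], "table_name": row["table_name"], "content": content}


def column_document(row: Dict) -> Dict:
    """Build the record that gets embedded for one column metadata row."""
    # The schema comes from this column's own row, never from another table
    full_name = f"{row['project_id']}.{row['table_schema']}.{row['table_name']}"
    content = (
        f"Column Name: {row['table_schema']}.{row['table_name']}.{row['column_name']} | "
        f"Full Table Name : {full_name} | "
        f"Data type: {row['data_type']} | "
        f"Column description: {row['column_description']} | "
        f"Column Constraints: {row['column_constraints']}"
    )
    return {
        "table_schema": row["table_schema"],
        "table_name": row["table_name"],
        "column_name": row["column_name"],
        "content": content,
    }
```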

Generate embeddings: 

In [None]:
# Create Table and Column Embeddings
table_schema_embeddings, col_schema_embeddings = retrieve_embeddings()


print("Table and Column embeddings are created")

In [None]:
table_schema_embeddings.head(10)

### 💾 **Save the Table and Column Embeddings in the Vector Store**
The table and column embeddings created in the step above are saved to the chosen vector store.

In [None]:
from google.cloud import bigquery

async def store_schema_embeddings(table_details_embeddings, 
                            tablecolumn_details_embeddings, 
                            project_id,
                            schema):
    """ 
    Store the vectorised table and column details in the DB table.
    This code may run for a few minutes.  
    """
         
    client=bigquery.Client(project=project_id)

    #Store table embeddings
    client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.table_details_embeddings` (
        source_type string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, content string, embedding ARRAY<FLOAT64>)''')
    #job_config = bigquery.LoadJobConfig(write_disposition="WRITE_TRUNCATE")
    table_details_embeddings['source_type']='BigQuery'
    for _, row in table_details_embeddings.iterrows():
        client.query_and_wait(f'''DELETE FROM `{project_id}.{schema}.table_details_embeddings`
                WHERE table_schema= '{row["table_schema"]}' and table_name= '{row["table_name"]}' '''
                    )
    # Wait for the load job to finish so that any load error surfaces here
    client.load_table_from_dataframe(table_details_embeddings,f'{project_id}.{schema}.table_details_embeddings').result()


    #Store column embeddings
    client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.tablecolumn_details_embeddings` (
        source_type string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, column_name string NOT NULL,
        content string, embedding ARRAY<FLOAT64>)''')
    #job_config = bigquery.LoadJobConfig(write_disposition="WRITE_TRUNCATE")
    tablecolumn_details_embeddings['source_type']='BigQuery'
    for _, row in tablecolumn_details_embeddings.iterrows():
        client.query_and_wait(f'''DELETE FROM `{project_id}.{schema}.tablecolumn_details_embeddings`
                WHERE table_schema= '{row["table_schema"]}' and table_name= '{row["table_name"]}' and column_name= '{row["column_name"]}' '''
                    )
    # Wait for the load job to finish so that any load error surfaces here
    client.load_table_from_dataframe(tablecolumn_details_embeddings,f'{project_id}.{schema}.tablecolumn_details_embeddings').result()

    return "Embeddings are stored successfully"
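`store_schema_embeddings` makes the load re-runnable by deleting existing rows that share a key with the new data (`table_schema` and `table_name`, plus `column_name` for columns) before appending the new rows. The in-memory sketch below models that replace-by-key behavior; `upsert_by_key` is hypothetical and the list of dicts stands in for the BigQuery table. In BigQuery itself, a single `MERGE` statement could replace the per-row `DELETE` statements, but that is a suggested alternative, not what the notebook does.

```python
from typing import Dict, List, Sequence, Tuple


def upsert_by_key(existing: List[Dict], new_rows: List[Dict], key_fields: Sequence[str]) -> List[Dict]:
    """Drop existing rows whose key matches a new row, then append the new rows."""
    def key_of(row: Dict) -> Tuple:
        return tuple(row[field] for field in key_fields)

    new_keys = {key_of(row) for row in new_rows}
    # Mirrors the per-row DELETE statements: remove stale rows for the same keys
    kept = [row for row in existing if key_of(row) not in new_keys]
    # Mirrors load_table_from_dataframe, which appends the new rows
    return kept + list(new_rows)
```

Running the same upsert twice leaves the table unchanged, which is why the setup cells can be re-executed without duplicating embeddings.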


The next cell may take a while depending on the size of your data source. It stores the embeddings in the vector store.

In [None]:

await(store_schema_embeddings(table_details_embeddings=table_schema_embeddings, 
                                tablecolumn_details_embeddings=col_schema_embeddings, 
                                project_id=project_id,
                                schema=BQ_OPENDATAQNA_DATASET_NAME                               
                                ))


print("Table and Column embeddings are saved to vector store")

### 🗄️ **Load Known Good SQL into Vector Store**
Known Good Queries are used to build a query cache of few-shot examples. Creating a query cache is highly recommended for best outcomes! 

The following cell loads the natural language question and Known Good SQL pairs into the vector store. These pairs are loaded from the `known_good_sql.csv` file inside the scripts folder. If you have your own question-SQL examples, curate them in a .csv file with the format described below before running the cell. 

If no Known Good Queries are available yet to create the query cache, you can use [3_LoadKnownGoodSQL.ipynb](3_LoadKnownGoodSQL.ipynb) to load them later. An empty table for the KGQ embeddings will be created!



#### Format of the Known Good SQL File (known_good_sql.csv)

prompt | sql | database_name (3 columns)

prompt ==> User question 

sql ==> SQL for the user question (note that the SQL should be enclosed in quotes and written on a single line; please remove line breaks)

database_name ==> This name should exactly match the schema name for a PostgreSQL source or BQ_DATASET_NAME for BigQuery
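The CSV layout above can be checked before any embeddings are generated. The sketch below reads the file with Python's `csv` module, verifies the three required columns, skips incomplete rows (roughly what `df_kgq.dropna()` does), and collapses line breaks in `sql` to spaces, as `store_kgq_embeddings` does before inserting. `load_kgq_rows` is a hypothetical helper that takes the file contents as a string for easy testing.

```python
import csv
import io
from typing import Dict, List

REQUIRED_COLUMNS = ("prompt", "sql", "database_name")


def load_kgq_rows(csv_text: str) -> List[Dict[str, str]]:
    """Parse known_good_sql.csv contents into clean prompt/sql/database_name rows."""
    reader = csv.DictReader(io.StringIO(csv_text))
    missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"known_good_sql.csv is missing columns: {missing}")
    rows = []
    for record in reader:
        values = {c: (record.get(c) or "").strip() for c in REQUIRED_COLUMNS}
        if not all(values.values()):
            continue  # skip incomplete rows, roughly like df_kgq.dropna()
        # Collapse line breaks so each SQL statement sits on one line
        values["sql"] = " ".join(values["sql"].splitlines())
        rows.append(values)
    return rows
```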

In [None]:
import os
import pandas as pd
from google.cloud import bigquery


embedder = EmbedderAgent('vertex')


async def setup_kgq_table( project_id,
                            schema):
    """ 
    This function sets up or refreshes the Vector Store for Known Good Queries (KGQ)
    """

    # Create BQ Client
    client=bigquery.Client(project=project_id)

    # Delete an old table
    client.query_and_wait(f'''DROP TABLE IF EXISTS `{project_id}.{schema}.example_prompt_sql_embeddings`''')
    # Create a new empty table
    client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.example_prompt_sql_embeddings` (
                            table_schema string NOT NULL, example_user_question string NOT NULL, example_generated_sql string NOT NULL,
                            embedding ARRAY<FLOAT64>)''')
        


async def store_kgq_embeddings(df_kgq, 
                            project_id,
                            schema
                            ):
    """ 
    Create and save the Known Good Query Embeddings to Vector Store  
    """

    client=bigquery.Client(project=project_id)
    
    example_sql_details_chunked = []

    for _, row_aug in df_kgq.iterrows():

        example_user_question = str(row_aug['prompt'])
        example_generated_sql = str(row_aug['sql'])
        example_database_name = str(row_aug['database_name'])
        emb =  embedder.create(example_user_question)
        

        r = {"example_database_name":example_database_name,"example_user_question": example_user_question,"example_generated_sql": example_generated_sql,"embedding": emb}
        example_sql_details_chunked.append(r)

    example_prompt_sql_embeddings = pd.DataFrame(example_sql_details_chunked)

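    for _, row in example_prompt_sql_embeddings.iterrows():
        cleaned_sql = row["example_generated_sql"].replace("\n", " ")
        # Pass values as query parameters so quotes inside questions or SQL cannot break the statements
        key_params = [
            bigquery.ScalarQueryParameter("table_schema", "STRING", row["example_database_name"]),
            bigquery.ScalarQueryParameter("question", "STRING", row["example_user_question"]),
        ]
        client.query_and_wait(f'''DELETE FROM `{project_id}.{schema}.example_prompt_sql_embeddings`
            WHERE table_schema = @table_schema AND example_user_question = @question''',
            job_config=bigquery.QueryJobConfig(query_parameters=key_params))
        insert_params = key_params + [
            bigquery.ScalarQueryParameter("generated_sql", "STRING", cleaned_sql),
            bigquery.ArrayQueryParameter("embedding", "FLOAT64", [float(v) for v in row["embedding"]]),
        ]
        client.query_and_wait(f'''INSERT INTO `{project_id}.{schema}.example_prompt_sql_embeddings`
            VALUES (@table_schema, @question, @generated_sql, @embedding)''',
            job_config=bigquery.QueryJobConfig(query_parameters=insert_params))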
                    


The next cell stores the KGQ embeddings in the vector store:

In [None]:
if EXAMPLES:
    print("Examples are provided, creating KGQ embeddings and saving them to Vector store.....")
    
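    file_path = None
    current_dir = os.getcwd()
    stop_dirs = {os.path.expanduser('~'), os.path.abspath(os.sep)}  # never walk these whole trees

    # Search the current directory tree first, then each parent directory's tree
    while True:
        for dirpath, dirnames, filenames in os.walk(current_dir):
            if 'known_good_sql.csv' in filenames:
                file_path = os.path.join(dirpath, 'known_good_sql.csv')
                break  # stop at the first match
        parent_dir = os.path.dirname(current_dir)
        if file_path is not None or parent_dir in stop_dirs or parent_dir == current_dir:
            break
        current_dir = parent_dir

    if file_path is None:
        raise FileNotFoundError("known_good_sql.csv was not found in the working directory or its parent directories")

    print("Known Good SQL Found at Path :: "+file_path)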

    # Load the file
    df_kgq = pd.read_csv(file_path)
    df_kgq = df_kgq.loc[:, ["prompt", "sql", "database_name"]]
    df_kgq = df_kgq.dropna()

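    # Create (or refresh) the KGQ table, then add the KGQ embeddings to the vector store
    await(setup_kgq_table(project_id=project_id,
                          schema=BQ_OPENDATAQNA_DATASET_NAME))
    await(store_kgq_embeddings(df_kgq,
                                project_id=project_id,
                                schema=BQ_OPENDATAQNA_DATASET_NAME
                                ))

    print('Done!!')

else:
    print("⚠️ WARNING: No Known Good Queries are provided to create a query cache for few-shot examples!")
    print("Creating a query cache is highly recommended for best outcomes")
    # Create an empty KGQ table so that later similarity lookups have a table to query
    await(setup_kgq_table(project_id=project_id,
                          schema=BQ_OPENDATAQNA_DATASET_NAME))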

### 🥁 If all the above steps were executed successfully, the following should be set up:

* GCP project and all the required IAM permissions

* Environment to run the solution

* Data source and Vector store for the solution

__________________________________________________________________________________________________________________

# **2. Run the Open Data QnA Pipeline**

### 🔗 **2.1. Connect to the Data Source, Vector Store and Vertex AI**


In [None]:

from google.cloud import aiplatform
import vertexai

found_in_vector = 'N'
final_sql='Not Generated Yet'

vertexai.init(project=project_id, location=BQ_REGION)
aiplatform.init(project=project_id, location=BQ_REGION)

In [None]:
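# Fully qualify the dataset name (guarded so that re-running this cell does not add the project prefix twice)
if not BQ_DATASET_NAME.startswith(project_id + '.'):
    BQ_DATASET_NAME = project_id+'.'+BQ_DATASET_NAME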

###  ❓ **Ask your Natural Language Question**

In [None]:
print("\033[1mData Source:- "+ DATA_SOURCE)
print("Vector Store:- "+ VECTOR_STORE)
    
# Suggested question for 'fda_food' dataset: "What are the top 5 cities with highest recalls?"
# Suggested question for 'google_dei' dataset: "How many asian men were part of the leadership workforce in 2021?"

# user_question = input(prompt_for_question) #Uncomment if you want to ask question yourself
user_question = 'How many movies have a rating higher than four?' # Or Enter Question here

print("User Question:- "+user_question)

### 🏃 **Run the Pipeline**
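The pipeline retrieves tables, columns and known good queries by vector similarity, limited by counts such as `num_table_matches` and cut off by thresholds such as `table_similarity_threshold`. The exact scoring inside `connector.getSimilarMatches` is not shown in this section; the sketch below assumes cosine similarity, where a higher score means a closer match, and keeps at most `k` candidates scoring at or above the threshold. `cosine_similarity` and `top_matches` are hypothetical helpers.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_matches(
    query: Sequence[float],
    candidates: List[Tuple[str, Sequence[float]]],
    k: int,
    threshold: float,
) -> List[Tuple[str, float]]:
    """Return up to k (name, score) pairs with score >= threshold, best first."""
    scored = [(name, cosine_similarity(query, vector)) for name, vector in candidates]
    kept = [item for item in scored if item[1] >= threshold]
    kept.sort(key=lambda item: item[1], reverse=True)
    return kept[:k]
```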

In [None]:
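# Set to True only if the connector's lookup methods are coroutines that must be awaited
call_await = False

# Retrieval settings: maximum number of matches and similarity cut-off per vector search
num_table_matches = 5
num_column_matches = 10
table_similarity_threshold = 0.3
column_similarity_threshold = 0.3
example_similarity_threshold = 0.3
num_sql_matches = 3

# Pipeline switches
DEBUGGING_ROUNDS = 2       # number of rounds passed to the SQL debugger agent
RUN_DEBUGGER = True        # run the debugger on the generated SQL
LLM_VALIDATION = True      # passed to the debugger to enable LLM-based validation
EXECUTE_FINAL_SQL = True   # execute the final SQL against the data source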


# Fetch the embedding of the user's input question 
embedded_question = embedder.create(user_question)

# Database name passed to the vector-store lookups below
USER_DATABASE = BQ_OPENDATAQNA_DATASET_NAME

SQLBuilder = BuildSQLAgent('gemini-1.5-pro-001')
SQLDebugger = DebugSQLAgent('gemini-1.5-pro-001')
SQLChecker = ValidateSQLAgent('gemini-1.5-pro-001')

# Reset AUDIT_TEXT
AUDIT_TEXT = ''

AUDIT_TEXT = AUDIT_TEXT + "\nUser Question : " + str(user_question) + "\nUser Database : " + str(USER_DATABASE)
process_step = "\n\nGet Exact Match: "

# Look for exact matches in known questions if known good queries (KGQ) are enabled
if EXAMPLES:
    exact_sql_history = connector.getExactMatches(user_question)
else:
    exact_sql_history = None

# If exact user query has been found, retrieve the SQL and skip Generation Pipeline 
if exact_sql_history is not None:
    found_in_vector = 'Y' 
    final_sql = exact_sql_history
    invalid_response = False
    AUDIT_TEXT = AUDIT_TEXT + "\nExact match has been found! Going to retrieve the SQL query from cache and serve!"


else:
    # No exact match found. Look for similar entries in the vector store if KGQ is enabled
    if EXAMPLES:
        AUDIT_TEXT = AUDIT_TEXT + process_step + "\nNo exact match found in query cache, retrieving relevant schema and known good queries for few-shot examples using similarity search...."
        process_step = "\n\nGet Similar Match: "
        if call_await:
            similar_sql = await connector.getSimilarMatches('example', USER_DATABASE, embedded_question, num_sql_matches, example_similarity_threshold)
        else:
            similar_sql = connector.getSimilarMatches('example', USER_DATABASE, embedded_question, num_sql_matches, example_similarity_threshold)

    else:
        similar_sql = "No similar SQLs provided..."

    process_step = "\n\nGet Table and Column Schema: "
    # Retrieve matching tables and columns
    if call_await:
        table_matches = await connector.getSimilarMatches('table', USER_DATABASE, embedded_question, num_table_matches, table_similarity_threshold)
        column_matches = await connector.getSimilarMatches('column', USER_DATABASE, embedded_question, num_column_matches, column_similarity_threshold)
    else:
        table_matches = connector.getSimilarMatches('table', USER_DATABASE, embedded_question, num_table_matches, table_similarity_threshold)
        column_matches = connector.getSimilarMatches('column', USER_DATABASE, embedded_question, num_column_matches, column_similarity_threshold)

    AUDIT_TEXT = AUDIT_TEXT +  process_step + "\nRetrieved Similar Known Good Queries, Table Schema and Column Schema: \n" + '\nRetrieved Tables: \n' + str(table_matches) + '\n\nRetrieved Columns: \n' + str(column_matches) + '\n\nRetrieved Known Good Queries: \n' + str(similar_sql)
    
    
    # If similar table or column schemas were found:
    if len(table_matches.replace('Schema(values):','').replace(' ','')) > 0 or len(column_matches.replace('Column name(type):','').replace(' ','')) > 0 :

        # GENERATE SQL
        process_step = "\n\nBuild SQL: "
        generated_sql = SQLBuilder.build_sql(DATA_SOURCE,user_question,table_matches,column_matches,similar_sql)
        final_sql=generated_sql
        AUDIT_TEXT = AUDIT_TEXT + process_step +  "\nGenerated SQL : " + str(generated_sql)
        
        if 'unrelated_answer' in generated_sql:
            invalid_response=True

        # If agent assessment is valid, proceed with checks  
        else:
            invalid_response=False

            if RUN_DEBUGGER: 
                generated_sql, invalid_response, AUDIT_TEXT = SQLDebugger.start_debugger(DATA_SOURCE, generated_sql, user_question, SQLChecker, table_matches, column_matches, AUDIT_TEXT, similar_sql, DEBUGGING_ROUNDS, LLM_VALIDATION) 
                # AUDIT_TEXT = AUDIT_TEXT + '\n Feedback from Debugger: \n' + feedback_text

            final_sql=generated_sql
            AUDIT_TEXT = AUDIT_TEXT + "\nFinal SQL after Debugger: \n" +str(final_sql)


    # No matching table found 
    else:
        invalid_response=True
        print('No tables found in Vector ...')
        AUDIT_TEXT = AUDIT_TEXT + "\nNo tables have been found in the Vector DB. The question cannot be answered with the provided data source!"

print(f'\n\n AUDIT_TEXT: \n {AUDIT_TEXT}')
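The pipeline cell above hands the generated SQL to `SQLDebugger.start_debugger` together with the `SQLChecker` validator and `DEBUGGING_ROUNDS`. Its internals live in the `DebugSQLAgent` class; as a hedged illustration of the general pattern those arguments suggest (validate, and if invalid ask for a repair, for a bounded number of rounds), here is a minimal sketch. `debug_sql`, `validate` and `repair` are hypothetical stand-ins, not the agent's real API.

```python
from typing import Callable, Tuple


def debug_sql(
    sql: str,
    validate: Callable[[str], Tuple[bool, str]],
    repair: Callable[[str, str], str],
    max_rounds: int,
) -> Tuple[str, bool]:
    """Hypothetical bounded repair loop.

    validate(sql) -> (is_valid, error_message)
    repair(sql, error_message) -> new_sql
    Returns (final_sql, invalid_response), mirroring the notebook's variable names.
    """
    for _ in range(max_rounds):
        is_valid, error = validate(sql)
        if is_valid:
            return sql, False
        sql = repair(sql, error)
    # Check the last repair attempt once more before giving up
    is_valid, _ = validate(sql)
    return sql, not is_valid


# Toy validator and repairer: a query counts as "valid" once it has a LIMIT clause
def toy_validate(s: str) -> Tuple[bool, str]:
    return ("LIMIT" in s, "missing LIMIT")


def toy_repair(s: str, err: str) -> str:
    return s + " LIMIT 10"


print(debug_sql("SELECT * FROM movies", toy_validate, toy_repair, max_rounds=2))
# ('SELECT * FROM movies LIMIT 10', False)
```

Bounding the rounds caps latency and model cost: a query that still fails after the last round is reported as invalid instead of being retried indefinitely.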

In [None]:
final_sql

In [None]:
invalid_response
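When the vector search returns no schema at all, the pipeline sets `invalid_response` to `True` without calling the SQL builder. It decides this by removing the placeholder labels `'Schema(values):'` and `'Column name(type):'` and all spaces from the retrieved strings, then checking whether anything is left. The same check as a small, testable function (the name `has_schema_matches` is hypothetical):

```python
def has_schema_matches(table_matches: str, column_matches: str) -> bool:
    """Mirror of the pipeline's condition: True if either retrieval returned real content.

    Like the original condition, only spaces are removed, not newlines or tabs.
    """
    tables = table_matches.replace('Schema(values):', '').replace(' ', '')
    columns = column_matches.replace('Column name(type):', '').replace(' ', '')
    return len(tables) > 0 or len(columns) > 0


print(has_schema_matches('Schema(values): ', 'Column name(type):'))  # False
print(has_schema_matches('Schema(values): movies(id, title)', ''))   # True
```

Because only spaces are stripped, a result consisting of just a label and a newline still counts as non-empty; calling `.strip()` on each string would be the stricter choice if the connector can return such output.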

## Run against the database
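Before execution, the cell below strips the Markdown code-fence markers the model may wrap around the query, plus any `EXPLAIN ANALYZE ` prefix. The same cleanup as a standalone function, with a final `.strip()` added to drop the leftover newlines; `clean_generated_sql` is a hypothetical name used only for illustration.

```python
def clean_generated_sql(sql: str) -> str:
    """Remove Markdown fence markers and an EXPLAIN ANALYZE prefix from model output."""
    # Same replacement order as the execution cell: "```sql" first, then bare "```"
    cleaned = sql.replace("```sql", "").replace("```", "")
    cleaned = cleaned.replace("EXPLAIN ANALYZE ", "")
    return cleaned.strip()


raw = "```sql\nSELECT COUNT(*) FROM movies WHERE rating > 4\n```"
print(clean_generated_sql(raw))
# SELECT COUNT(*) FROM movies WHERE rating > 4
```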

In [None]:
Responder = ResponseAgent('gemini-1.5-pro-001')

# Ensure _resp exists even if SQL execution raises below
_resp = None

if not invalid_response:
    try:
        if EXECUTE_FINAL_SQL:
            # Remove Markdown fences and any EXPLAIN ANALYZE prefix left by the model
            clean_sql = final_sql.replace("```sql", "").replace("```", "").replace("EXPLAIN ANALYZE ", "")
            final_exec_result_df = connector.retrieve_df(clean_sql)
            print('\nQuestion: ' + user_question + '\n')
            # print('\n Final SQL Execution Result: \n')
            # print(final_exec_result_df)
            response = final_exec_result_df
            _resp = Responder.run(user_question, response)
            AUDIT_TEXT = AUDIT_TEXT + "\nModel says " + str(_resp)

        else:  # Do not execute final SQL
            print("Not executing final SQL since EXECUTE_FINAL_SQL variable is False\n")
            response = "Please enable the Execution of the final SQL so I can provide an answer"
            _resp = Responder.run(user_question, response)
            AUDIT_TEXT = AUDIT_TEXT + "\nModel says " + str(_resp)

    except ValueError as e:
        # Report the failure instead of silently printing an empty line
        print(f"Executing the final SQL failed: {e}")
    # except Exception as e:
    #     print(f"An error occurred. Aborting... Error Message: {e}")

else:  # Do not execute final SQL
    print("Not executing final SQL as it is invalid, please debug!")
    response = "I am sorry, I could not come up with a valid SQL."
    _resp = Responder.run(user_question, response)
    AUDIT_TEXT = AUDIT_TEXT + "\nModel says " + str(_resp)

print("Final Answer:" + str(_resp))

## 🗑 **Clean Up Notebook Resources**
Make sure to delete your BigQuery datasets when you are finished with this notebook to avoid further costs. 💸 💰

Uncomment and run the cell below to delete them.

In [None]:
# #delete BigQuery Dataset using bq utility
# !bq rm -r -f -d {BQ_DATASET_NAME}

# #delete BigQuery 'Open Data QnA' Vector Store Dataset using bq utility
# !bq rm -r -f -d {BQ_OPENDATAQNA_DATASET_NAME}
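If the `bq` CLI is not available, the Python BigQuery client provides `Client.delete_dataset(dataset, delete_contents=True, not_found_ok=True)`; `delete_contents=True` plays the role of `-r` (delete the tables too), and `not_found_ok=True` avoids an error if a dataset is already gone. A minimal sketch; the helper name `delete_notebook_datasets` is hypothetical, and the client is passed in so the function can be exercised without credentials. Like the cell above, it permanently deletes data.

```python
from typing import Iterable


def delete_notebook_datasets(client, dataset_ids: Iterable[str]) -> None:
    """Delete each dataset together with all of its tables.

    client: a google.cloud.bigquery.Client (or any object with the same delete_dataset method).
    dataset_ids: "project.dataset" or plain "dataset" strings (the latter use the client's project).
    """
    for dataset_id in dataset_ids:
        client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True)
        print(f"Deleted dataset {dataset_id} (if it existed)")


# Usage (uncomment to run; this permanently deletes data):
# from google.cloud import bigquery
# client = bigquery.Client(project=project_id)
# delete_notebook_datasets(client, [BQ_DATASET_NAME, BQ_OPENDATAQNA_DATASET_NAME])
```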

