In [1]:
import google.auth
import os
from google.cloud import bigquery
from google.cloud import bigquery_connection_v1 as bq_connection
from abc import ABC
from datetime import datetime
import google.auth
import pandas as pd
from google.cloud.exceptions import NotFound
from dotenv import load_dotenv
# from pandas_gbq import to_gbq
from abc import ABC
import vertexai
from vertexai.language_models import TextGenerationModel
from vertexai.language_models import CodeGenerationModel
from vertexai.language_models import CodeChatModel
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmCategory,HarmBlockThreshold
from vertexai.generative_models import GenerationConfig
from vertexai.language_models import TextEmbeddingModel
import time
load_dotenv()
import yaml


# project_id = os.getenv("PROJECT_ID")
# dataset = os.getenv("DATABASE_ID")

with open('../sql_config.yml' , 'r') as f:
    sql_config = yaml.load(f, Loader=yaml.FullLoader)

project_id = sql_config['bigquery']['project_id']
dataset = sql_config['bigquery']['dataset_id']
location = sql_config['bigquery']['region']

vertexai.init(project=sql_config['bigquery']['project_id'], location=sql_config['bigquery']['region'])





In [3]:

def return_table_schema_sql(project_id, dataset, table_names=None):
    """
    Returns the SQL query to get table schema info, optionally filtering by specific tables.
    """
    user_dataset = f"{project_id}.{dataset}"

    table_filter_clause = ""
    if table_names:
        # Extract individual table names from the input string
        #table_names = [name.strip() for name in table_names[1:-1].split(",")]  # Handle the string as a list
        formatted_table_names = [f"'{name}'" for name in table_names]
        table_filter_clause = f"""AND TABLE_NAME IN ({', '.join(formatted_table_names)})"""


    table_schema_sql = f"""
    (SELECT
        TABLE_CATALOG as project_id, TABLE_SCHEMA as table_schema , TABLE_NAME as table_name,  OPTION_VALUE as table_description,
        (SELECT STRING_AGG(column_name, ', ') from `{user_dataset}.INFORMATION_SCHEMA.COLUMNS` where TABLE_NAME= t.TABLE_NAME and TABLE_SCHEMA=t.TABLE_SCHEMA) as table_columns
    FROM
        `{user_dataset}.INFORMATION_SCHEMA.TABLE_OPTIONS` as t
    WHERE
        OPTION_NAME = "description"
        {table_filter_clause}
    ORDER BY
        project_id, table_schema, table_name)

    UNION ALL

    (SELECT
        TABLE_CATALOG as project_id, TABLE_SCHEMA as table_schema , TABLE_NAME as table_name,  "NA" as table_description,
        (SELECT STRING_AGG(column_name, ', ') from `{user_dataset}.INFORMATION_SCHEMA.COLUMNS` where TABLE_NAME= t.TABLE_NAME and TABLE_SCHEMA=t.TABLE_SCHEMA) as table_columns
    FROM
        `{user_dataset}.INFORMATION_SCHEMA.TABLES` as t 
    WHERE 
        NOT EXISTS (SELECT 1   FROM
        `{user_dataset}.INFORMATION_SCHEMA.TABLE_OPTIONS`  
    WHERE
        OPTION_NAME = "description" AND  TABLE_NAME= t.TABLE_NAME and TABLE_SCHEMA=t.TABLE_SCHEMA)
        {table_filter_clause}
    ORDER BY
        project_id, table_schema, table_name)
    """
    return table_schema_sql

def return_column_schema_sql(project_id, dataset, table_names=None):
    """
    Returns the SQL query to get column schema info, optionally filtering by specific tables.
    """
    user_dataset = f"{project_id}.{dataset}"
    
    table_filter_clause = ""
    if table_names:
        # table_names = [name.strip() for name in table_names[1:-1].split(",")]  # Handle the string as a list
        formatted_table_names = [f"'{name}'" for name in table_names]
        table_filter_clause = f"""AND C.TABLE_NAME IN ({', '.join(formatted_table_names)})"""

    column_schema_sql = f"""
    SELECT
        C.TABLE_CATALOG as project_id, C.TABLE_SCHEMA as table_schema, C.TABLE_NAME as table_name, C.COLUMN_NAME as column_name,
        C.DATA_TYPE as data_type, C.DESCRIPTION as column_description, CASE WHEN T.CONSTRAINT_TYPE="PRIMARY KEY" THEN "This Column is a Primary Key for this table" WHEN 
        T.CONSTRAINT_TYPE = "FOREIGN_KEY" THEN "This column is Foreign Key" ELSE NULL END as column_constraints
    FROM
        `{user_dataset}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` C 
    LEFT JOIN 
        `{user_dataset}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS` T 
        ON C.TABLE_CATALOG = T.TABLE_CATALOG AND
           C.TABLE_SCHEMA = T.TABLE_SCHEMA AND 
           C.TABLE_NAME = T.TABLE_NAME AND  
           T.ENFORCED ='YES'
    LEFT JOIN 
        `{user_dataset}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE` K
        ON K.CONSTRAINT_NAME=T.CONSTRAINT_NAME AND C.COLUMN_NAME = K.COLUMN_NAME 
    WHERE
        1=1
        {table_filter_clause} 
    ORDER BY
        project_id, table_schema, table_name, column_name;
"""

    return column_schema_sql


#Create Descrition Agent
"""
Provides the base class for all Agents 
"""





class Agent(ABC):
    """
    The core class for all Agents
    """

    agentType: str = "Agent"

    def __init__(self,
                model_id:str):
        """
        Args:
            PROJECT_ID (str | None): GCP Project Id.
            dataset_name (str): 
            TODO
        """

        self.model_id = model_id 

        if model_id == 'code-bison-32k':
            self.model = CodeGenerationModel.from_pretrained('code-bison-32k')
        elif model_id == 'text-bison-32k':
            self.model = TextGenerationModel.from_pretrained('text-bison-32k')
        elif model_id == 'codechat-bison-32k':
            self.model = CodeChatModel.from_pretrained("codechat-bison-32k")
        elif model_id == 'gemini-2.0-flash-001' or model_id == 'gemini-1.5-flash-001':
            self.model = GenerativeModel("gemini-2.0-flash-001")
        elif model_id == 'text-embedding-005':
            self.model = TextEmbeddingModel.from_pretrained('text-embedding-005')
        elif model_id == 'gemini-2.0-flash-lite-001':
            self.model = GenerativeModel("gemini-2.0-flash-lite-001")
        else:
            raise ValueError("Please specify a compatible model.")

    def generate_llm_response(self,prompt):
        context_query = self.model.generate_content(prompt,stream=False)
        return str(context_query.candidates[0].text).replace("```sql", "").replace("```", "")

class EmbedderAgent(Agent, ABC): 
    """ 
    This Agent generates embeddings 
    """ 

    agentType: str = "EmbedderAgent"

    def __init__(self, mode, embeddings_model='text-embedding-005'): 
        if mode == 'vertex': 
            self.mode = mode 
            self.model = TextEmbeddingModel.from_pretrained(embeddings_model)

        else: raise ValueError('EmbedderAgent mode must be vertex')



    def create(self, question): 
        """Text embedding with a Large Language Model."""

        if self.mode == 'vertex': 
            if isinstance(question, str): 
                embeddings = self.model.get_embeddings([question])
                for embedding in embeddings:
                    vector = embedding.values
                return vector
            
            elif isinstance(question, list):  
                vector = list() 
                for q in question: 
                    embeddings = self.model.get_embeddings([q])

                    for embedding in embeddings:
                        vector.append(embedding.values) 
                return vector
            
            else: raise ValueError('Input must be either str or list')


class DescriptionAgent(Agent, ABC): 
    """ 
    Generates table and column descriptions. 
    """ 

    agentType: str = "DescriptionAgent"

    def generate_llm_response(self,prompt):
        context_query = self.model.generate_content(prompt,stream=False)
        return str(context_query.candidates[0].text).replace("```sql", "").replace("```", "")


    def generate_missing_descriptions(self,source,table_desc_df, column_name_df):
        llm_generated=0
        for index, row in table_desc_df.iterrows():
            print(f"Table Description for {row['table_name']}")

            if row['table_description'] is None or row['table_description']=='NA':
                q=f"table_name == '{row['table_name']}' and table_schema == '{row['table_schema']}'"
                if source=='bigquery':
                    context_prompt = f"""
                        Generate table description short and crisp for the table {row['project_id']}.{row['table_schema']}.{row['table_name']}
                        Remember that these desciprtion should help LLMs to help build better SQL for any quries related to this table.
                        Parameters:
                        - column metadata: {column_name_df.query(q).to_markdown(index = False)}
                        - table metadata: {table_desc_df.query(q).to_markdown(index = False)}
                        
                        DO NOT generate description more than two lines
                    """


                table_desc_df.at[index,'table_description']=self.generate_llm_response(context_prompt)
                # print(row['table_description'])
                llm_generated=llm_generated+1
                time.sleep(13)
        print("\nLLM generated "+ str(llm_generated) + " Table Descriptions")
        llm_generated = 0

        
        for index, row in column_name_df.iterrows():
            # print(row['column_description'])
            print(f"column_description for {row['column_name']} at Table {row['table_name']}")
            if row['column_description'] is None or row['column_description']=='':
                q=f"table_name == '{row['table_name']}' and table_schema == '{row['table_schema']}'"
                if source=='bigquery':
                    context_prompt = f"""
                    Generate short and crisp description for the column {row['project_id']}.{row['table_schema']}.{row['table_name']}.{row['column_name']}

                    Remember that this description should help LLMs to help generate better SQL for any queries related to these columns.

                    Consider the below information to generate a good comment

                    Name of the column : {row['column_name']}
                    Data type of the column is : {row['data_type']}
                    Details of the table of this column are below:
                    {table_desc_df.query(q).to_markdown(index=False)}
                    Column Contrainst of this column are : {row['column_constraints']}

                    DO NOT generate description more than two lines
                """
                

                column_name_df.at[index,'column_description']=self.generate_llm_response(prompt=context_prompt)
                # print(row['column_description'])
                llm_generated=llm_generated+1
                time.sleep(13)
        print("\nLLM generated "+ str(llm_generated) + " Column Descriptions")
        
        return table_desc_df,column_name_df



In [15]:

def get_embedding_chunked(textinput, batch_size): 
    embedder = EmbedderAgent('vertex')

    for i in range(0, len(textinput), batch_size):
        request = [x["content"] for x in textinput[i : i + batch_size]]
        response = embedder.create(request) # Vertex Textmodel Embedder 

        # Store the retrieved vector embeddings for each chunk back.
        for x, e in zip(textinput[i : i + batch_size], response):
            x["embedding"] = e

    # Store the generated embeddings in a pandas dataframe.
    out_df = pd.DataFrame(textinput)
    return out_df


#Function to generate embeddings:
def retrieve_embeddings(): 
    """ Augment all the DB schema blocks to create document for embedding """

    #TABLE EMBEDDINGS
    table_details_chunked = []

    for _, row_aug in table_desc_df.iterrows():
        cur_project_name =str(row_aug['project_id'])
        cur_table_name = str(row_aug['table_name'])
        cur_table_schema = str(row_aug['table_schema'])
        curr_col_names = str(row_aug['table_columns'])
        curr_tbl_desc = str(row_aug['table_description'])


        table_detailed_description=f"""
        Full Table Name : {cur_project_name}.{cur_table_schema}.{cur_table_name} |
        Table Columns List: [{curr_col_names}] |
        Table Description: {curr_tbl_desc} """

        r = {"table_schema": cur_table_schema,"table_name": cur_table_name,"content": table_detailed_description}
        table_details_chunked.append(r)

    table_details_embeddings = get_embedding_chunked(table_details_chunked, 10)


    ### COLUMN EMBEDDING ###
    """
    This SQL returns a df containing the cols table_schema, table_name, column_name, data_type, column_description, table_description, primary_key, column_constraints
    for the schema specified above, e.g. 'retail'
    """

    column_details_chunked = []

    for _, row_aug in column_name_df.iterrows():
        cur_project_name =str(row_aug['project_id'])
        cur_table_name = str(row_aug['table_name'])
        cur_table_owner = str(row_aug['table_schema'])
        curr_col_name = str(row_aug['table_schema'])+'.'+str(row_aug['table_name'])+'.'+str(row_aug['column_name'])
        curr_col_datatype = str(row_aug['data_type'])
        curr_col_description = str(row_aug['column_description'])
        curr_col_constraints = str(row_aug['column_constraints'])
        curr_column_name = str(row_aug['column_name'])


        column_detailed_description=f"""
        Column Name: {curr_col_name}|
        Full Table Name : {cur_project_name}.{cur_table_schema}.{cur_table_name} |
        Data type: {curr_col_datatype}|
        Column description: {curr_col_description}|
        Column Constraints: {curr_col_constraints} """

        r = {"table_schema": cur_table_owner,"table_name": cur_table_name,"column_name":curr_column_name, "content": column_detailed_description}
        column_details_chunked.append(r)

    column_details_embeddings = get_embedding_chunked(column_details_chunked, 10)


    return table_details_embeddings, column_details_embeddings


async def store_schema_embeddings(table_details_embeddings, 
                            tablecolumn_details_embeddings, 
                            project_id,
                            schema):
    """ 
    Store the vectorised table and column details in the DB table.
    This code may run for a few minutes.  
    """
         
    client=bigquery.Client(project=project_id)

    #Store table embeddings
    client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.table_details_embeddings` (
        source_type string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, content string, embedding ARRAY<FLOAT64>)''')
    #job_config = bigquery.LoadJobConfig(write_disposition="WRITE_TRUNCATE")
    table_details_embeddings['source_type']='BigQuery'
    for _, row in table_details_embeddings.iterrows():
        client.query_and_wait(f'''DELETE FROM `{project_id}.{schema}.table_details_embeddings`
                WHERE table_schema= '{row["table_schema"]}' and table_name= '{row["table_name"]}' '''
                    )
    client.load_table_from_dataframe(table_details_embeddings,f'{project_id}.{schema}.table_details_embeddings')


    #Store column embeddings
    client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.tablecolumn_details_embeddings` (
        source_type string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, column_name string NOT NULL,
        content string, embedding ARRAY<FLOAT64>)''')
    #job_config = bigquery.LoadJobConfig(write_disposition="WRITE_TRUNCATE")
    tablecolumn_details_embeddings['source_type']='BigQuery'
    for _, row in tablecolumn_details_embeddings.iterrows():
        client.query_and_wait(f'''DELETE FROM `{project_id}.{schema}.tablecolumn_details_embeddings`
                WHERE table_schema= '{row["table_schema"]}' and table_name= '{row["table_name"]}' and column_name= '{row["column_name"]}' '''
                    )
    client.load_table_from_dataframe(tablecolumn_details_embeddings,f'{project_id}.{schema}.tablecolumn_details_embeddings')

    return "Embeddings are stored successfully"


In [4]:
project_id

'energyagentai'

In [16]:

###Get the Tables and Columns in Database
client = bigquery.Client(project=project_id)
table_schema_sql = return_table_schema_sql(project_id, dataset)
table_desc_df = client.query_and_wait(table_schema_sql).to_dataframe()
column_schema_sql = return_column_schema_sql(project_id, dataset)
column_desc_df = client.query_and_wait(column_schema_sql).to_dataframe()

In [17]:

#print(table_desc_df.head())
descriptor = DescriptionAgent('gemini-2.0-flash-001')
#GENERATE MISSING DESCRIPTIONS
table_desc_df,column_name_df= descriptor.generate_missing_descriptions('bigquery',table_desc_df,column_desc_df)


Table Description for market_data


Table Description for products
Table Description for smartmeter_data
Table Description for customer_base

LLM generated 3 Table Descriptions
column_description for account_number at Table customer_base
column_description for age at Table customer_base
column_description for annual_income at Table customer_base
column_description for app_usage_monthly at Table customer_base
column_description for avg_monthly_bill at Table customer_base
column_description for bill_payment_consistency at Table customer_base
column_description for bill_shock_events at Table customer_base
column_description for business_size at Table customer_base
column_description for business_type at Table customer_base
column_description for call_center_interactions_12m at Table customer_base
column_description for churn_destination at Table customer_base
column_description for churn_reason at Table customer_base
column_description for city at Table customer_base
column_description for communication_preference at Table 

In [20]:
table_desc_df.to_csv('table_desc_df.csv', index=False)
column_name_df.to_csv('column_name_df.csv', index=False)

In [None]:

##Get EMbeddings added to table adn columns description
table_schema_embeddings, col_schema_embeddings = retrieve_embeddings()
#store the embeddings back to the vector db.
await(store_schema_embeddings(table_details_embeddings=table_schema_embeddings, 
                                tablecolumn_details_embeddings=col_schema_embeddings, 
                                project_id=project_id,
                                schema='alberta_energy_ai'                               
                                ))
print("Table and Column embeddings are saved to vector store")



Table and Column embeddings are saved to vector store


: 

Add Known Good SQL

In [4]:
embedder = EmbedderAgent('vertex')

def escape_single_quotes(value):
    return value.replace("'", "''")


async def setup_kgq_table( project_id,
                            schema):
    """ 
    This function sets up or refreshes the Vector Store for Known Good Queries (KGQ)
    """

    # Create BQ Client
    client=bigquery.Client(project=project_id)

    # Delete an old table
    client.query_and_wait(f'''DROP TABLE IF EXISTS `{project_id}.{schema}.example_prompt_sql_embeddings`''')
    # Create a new emptry table
    client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.example_prompt_sql_embeddings` (
                            table_schema string NOT NULL, example_user_question string NOT NULL, example_generated_sql string NOT NULL,
                            embedding ARRAY<FLOAT64>)''')
        

async def store_kgq_embeddings(df_kgq, 
                            project_id,
                            schema
                            ):
    """ 
    Create and save the Known Good Query Embeddings to Vector Store  
    """

    client=bigquery.Client(project=project_id)
    
    example_sql_details_chunked = []

    for _, row_aug in df_kgq.iterrows():

        example_user_question = str(row_aug['prompt'])
        example_generated_sql = str(row_aug['sql'])
        example_database_name = str(row_aug['database_name'])
        emb =  embedder.create(example_user_question)
        

        r = {"example_database_name":example_database_name,"example_user_question": example_user_question,"example_generated_sql": example_generated_sql,"embedding": emb}
        example_sql_details_chunked.append(r)

    example_prompt_sql_embeddings = pd.DataFrame(example_sql_details_chunked)
    # Preprocess the DataFrame to escape single quotes and replace newlines
    example_prompt_sql_embeddings['example_database_name'] = example_prompt_sql_embeddings['example_database_name'].str.replace("'", "''")
    example_prompt_sql_embeddings['example_user_question'] = example_prompt_sql_embeddings['example_user_question'].str.replace("'", "''").str.replace("\n", " ")
    example_prompt_sql_embeddings['example_generated_sql'] = example_prompt_sql_embeddings['example_generated_sql'].str.replace("'", "''").str.replace("\n", " ")


    for _, row in example_prompt_sql_embeddings.iterrows():
            
            print(f"Example SQL Embeddings for {row['example_user_question']}")
            print(f"Example SQL Embeddings for {row['example_database_name']}")
            print(f"Example SQL Embeddings for {row['example_generated_sql']}")
            client.query_and_wait(f'''DELETE FROM `{project_id}.{schema}.example_prompt_sql_embeddings`
                        WHERE table_schema= '{row["example_database_name"]}' and example_user_question= '{row["example_user_question"]}' '''
                            )
                # embedding=np.array(row["embedding"])
            cleaned_sql = row["example_generated_sql"].replace("\n", " ")
            client.query_and_wait(f'''INSERT INTO `{project_id}.{schema}.example_prompt_sql_embeddings` 
                VALUES ("{row["example_database_name"]}","{row["example_user_question"]}" , 
                "{cleaned_sql}",{row["embedding"]} )''')
                    

In [8]:

# Load the file
df_kgq = pd.read_csv("..//data//knowngoodsql.csv")
df_kgq = df_kgq.loc[:, ["prompt", "sql", "database_name"]]
df_kgq = df_kgq.dropna()
# Preprocess the DataFrame to escape single quotes and replace newlines
df_kgq['database_name'] = df_kgq['database_name'].str.replace("'", "''")
df_kgq['prompt'] = df_kgq['prompt'].str.replace("'", "''")
df_kgq['sql'] = df_kgq['sql'].str.replace("'", "''").str.replace("\n", " ")



In [10]:
##Setup the KGSQL Table
await(setup_kgq_table( project_id,'alberta_energy_ai'))

In [12]:

# Add KGQ to the vector store
await(store_kgq_embeddings(df_kgq,
                            project_id=project_id,
                            schema='alberta_energy_ai'
                            ))

print('Done!!')


Example SQL Embeddings for What is the average monthly bill for customers who have churned?
Example SQL Embeddings for energyagentai
Example SQL Embeddings for SELECT AVG(avg_monthly_bill) AS average_monthly_bill FROM `energyagentai.alberta_energy_ai.customer_base` WHERE is_churned = 1;
Example SQL Embeddings for Show me all customer data for customer ID CUST_001234
Example SQL Embeddings for energyagentai
Example SQL Embeddings for SELECT * FROM `energyagentai.alberta_energy_ai.customer_base` WHERE customer_id = ''''CUST_001234'''';
Example SQL Embeddings for Explain the churn prediction for customer CUST_001234
Example SQL Embeddings for energyagentai
Example SQL Embeddings for SELECT * FROM `energyagentai.alberta_energy_ai.customer_base` WHERE customer_id = ''''CUST_001234'''';
Example SQL Embeddings for What factors contribute to churn risk for high-usage customers?
Example SQL Embeddings for energyagentai
Example SQL Embeddings for SELECT * FROM `energyagentai.alberta_energy_ai.cu