In [1]:
from dotenv import load_dotenv
import os
from google.cloud import dlp_v2
from google.cloud.dlp_v2 import types
from langchain_openai import ChatOpenAI
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.utilities import SQLDatabase
from langchain_experimental.utilities import PythonREPL
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage
from langchain_core.prompts import (
    SystemMessagePromptTemplate,
    PromptTemplate,
    FewShotPromptTemplate,
)
from langchain_core.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)

from langchain.memory import ConversationBufferMemory
from langchain.tools import Tool

In [2]:
# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies OPENAI_API_KEY for the OpenAI
# clients created further down — confirm against the project's .env.
load_dotenv()

True

In [3]:
# Setup: point the Google Cloud client libraries at the service-account key
# file located in the current working directory.
service_account_file = "/".join(
    (os.getcwd(), "round-booking-276105-a4524c60591f.json")
)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = service_account_file

In [4]:
# GCP project id and BigQuery dataset. Both are interpolated into the
# fully-qualified table references of the example queries below; `project`
# is also read by `mask_pii_data` to build the DLP parent resource.
project='round-booking-276105'
dataset='L1'
# Example Queries
# Few-shot examples: each entry pairs a natural-language request ("input")
# with a BigQuery SQL query ("query"). The list is indexed by the
# semantic-similarity example selector, keyed on "input".
# The queries are f-strings, so `{project}` and `{dataset}` are resolved
# when this cell runs, not when the prompt is formatted.
# NOTE(review): the examples reference tables `customer`, `contact`,
# `customer_address`, `address` and `job_states`, while the recorded agent
# run lists L1.CUSTOMER_MASTER, L1.TRPURC_OBS_CART_L and L1.TRPURC_OBS_ORD_L
# — confirm the examples match the connected dataset.
sql_examples = [
    {
        "input": "Count of Customers by Source System",
        "query": f"""
            SELECT
                source_system_name,
                COUNT(*) AS customer_count
            FROM
                `{project}.{dataset}.customer`
            GROUP BY
                source_system_name
            ORDER BY
                customer_count DESC;
        """,
    },
    {
        "input": "Average Age of Customers by Gender",
        "query": f"""
            SELECT
                gender,
                AVG(EXTRACT(YEAR FROM CURRENT_DATE()) - EXTRACT(YEAR FROM dob)) AS average_age
            FROM
                `{project}.{dataset}.customer`
            GROUP BY
                gender;
        """,
    },
    {
        "input": "Count of Customers with Email and/or Phone",
        "query": f"""
            SELECT
                c.customer_key,
                c.first_name,
                c.last_name,
                SUM(CASE WHEN ct.type = 'email' THEN 1 ELSE 0 END) AS email_count,
                SUM(CASE WHEN ct.type = 'phone' THEN 1 ELSE 0 END) AS phone_count
            FROM
                `{project}.{dataset}.customer` c
            LEFT JOIN
                `{project}.{dataset}.contact` ct
            ON
                c.customer_key = ct.customer_key
            GROUP BY
                c.customer_key, c.first_name, c.last_name
            ORDER BY
                email_count DESC, phone_count DESC;
        """,
    },
    {
        "input": "List of Customers with Addresses",
        "query": f"""
            SELECT
                c.customer_key,
                c.first_name,
                c.last_name,
                a.full_address,
                a.state,
                a.country
            FROM
                `{project}.{dataset}.customer` c
            JOIN
                `{project}.{dataset}.customer_address` ca
            ON
                c.customer_key = ca.customer_key
            JOIN
                `{project}.{dataset}.address` a
            ON
                ca.address_key = a.address_key;
        """,
    },
    {
        "input": "Job States Summary",
        "query": f"""
            SELECT
                batch_id,
                status,
                record_count,
                load_timestamp,
                JSON_EXTRACT_SCALAR(job_summary, '$.SYS1') AS sys1_count,
                JSON_EXTRACT_SCALAR(job_summary, '$.SYS2') AS sys2_count,
                JSON_EXTRACT_SCALAR(job_summary, '$.SYS3') AS sys3_count,
                JSON_EXTRACT_SCALAR(job_summary, '$.SYS4') AS sys4_count,
                JSON_EXTRACT_SCALAR(job_summary, '$.SYS5') AS sys5_count
            FROM
                `{project}.{dataset}.job_states`
            ORDER BY
                load_timestamp DESC;
        """,
    },
    {
        "input": "Top 5 Most Populated States",
        "query": f"""
            SELECT
                state,
                COUNT(*) AS address_count
            FROM
                `{project}.{dataset}.address`
            GROUP BY
                state
            ORDER BY
                address_count DESC
            LIMIT 5;
        """,
    },
    {
        "input": "Total Contacts (Emails and Phones) by Source System",
        "query": f"""
            SELECT
                c.source_system_name,
                SUM(CASE WHEN ct.type = 'email' THEN 1 ELSE 0 END) AS total_emails,
                SUM(CASE WHEN ct.type = 'phone' THEN 1 ELSE 0 END) AS total_phones
            FROM
                `{project}.{dataset}.customer` c
            JOIN
                `{project}.{dataset}.contact` ct
            ON
                c.customer_key = ct.customer_key
            GROUP BY
                c.source_system_name;
        """,
    },
    {
        "input": "Distribution of Customers by Age Groups",
        "query": f"""
            SELECT
                CASE
                    WHEN age < 20 THEN 'Under 20'
                    WHEN age BETWEEN 20 AND 29 THEN '20-29'
                    WHEN age BETWEEN 30 AND 39 THEN '30-39'
                    WHEN age BETWEEN 40 AND 49 THEN '40-49'
                    WHEN age BETWEEN 50 AND 59 THEN '50-59'
                    ELSE '60 and above'
                END AS age_group,
                COUNT(*) AS customer_count
            FROM
                (SELECT
                    EXTRACT(YEAR FROM CURRENT_DATE()) - EXTRACT(YEAR FROM dob) AS age
                FROM
                    `{project}.{dataset}.customer`)
            GROUP BY
                age_group
            ORDER BY
                customer_count DESC;
        """,
    },
    {
        "input": "Customers with Multiple Source Systems",
        "query": f"""
            SELECT
                first_name,
                last_name,
                COUNT(DISTINCT source_system_name) AS source_system_count
            FROM
                `{project}.{dataset}.customer`
            GROUP BY
                first_name, last_name
            HAVING
                source_system_count > 1;
        """,
    },
    {
        "input": "Recent Job Runs with Their Status",
        "query": f"""
            SELECT
                batch_id,
                status,
                record_count,
                load_timestamp
            FROM
                `{project}.{dataset}.job_states`
            ORDER BY
                load_timestamp DESC
            LIMIT 10;
        """,
    },
]


# System-prompt prefix for the SQL agent. It is used as the `prefix` of the
# few-shot prompt template, so `{top_k}` is a template placeholder that is
# filled in later; every other brace-free sentence is sent as written.
# The tool names mentioned here must match the names the tools are registered
# under (`python_repl`, `mask_pii_data`), and the dataset name must be kept in
# sync with the `dataset` variable (currently 'L1').
PREFIX = """
You are a SQL expert. You have access to a BigQuery database.
Identify which tables can be used to answer the user's question and write and execute a SQL query accordingly.
Given an input question, create a syntactically correct SQL query to run against the dataset L1, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table; only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the information returned by these tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

If the user asks for a visualization of the results, use the python_repl tool to create and display the visualization.

After obtaining the results, you must use the mask_pii_data tool to mask the results before providing the final answer.
"""

# Seed "first thought" for the agent. It is wrapped in a static AIMessage
# (not a prompt template) when the chat prompt is assembled, so nothing in it
# is ever substituted: `{chat_history}`, `{input}` or `{agent_scratchpad}`
# placeholders would reach the model as literal text. Chat history, the user
# input and the scratchpad are already supplied by their own prompt entries,
# so only the plain thought sentence belongs here.
SUFFIX = (
    "I should look at the tables in the database to see what I can query.  "
    "Then I should query the schema of the most relevant tables."
)


In [5]:
def mask_pii_data(text):
    """Mask PII in ``text`` with Google Cloud DLP and return the masked text.

    The text is inspected for e-mail addresses, phone numbers, dates of
    birth, last names, street addresses and locations; findings are
    transformed with a ``*`` character mask.

    Args:
        text: The text to de-identify.

    Returns:
        The de-identified text taken from the DLP response. Empty or ``None``
        input is returned unchanged without contacting DLP.

    Raises:
        Whatever the DLP client raises when the request fails (credentials,
        permissions, quota, ...); errors are not swallowed here.
    """
    # Nothing to mask: skip client construction and the network round-trip.
    if not text:
        return text

    dlp = dlp_v2.DlpServiceClient()

    # The DLP parent resource is derived from the module-level `project`.
    parent = f"projects/{project}"

    info_types = [
        {"name": "EMAIL_ADDRESS"},
        {"name": "PHONE_NUMBER"},
        {"name": "DATE_OF_BIRTH"},
        {"name": "LAST_NAME"},
        {"name": "STREET_ADDRESS"},
        {"name": "LOCATION"},
    ]

    # A single transformation with no info-type filter, so it is applied to
    # the findings of the inspect config below.
    # NOTE(review): number_to_mask=0 is the DLP default and is documented as
    # masking every character of a match — confirm against the DLP reference.
    deidentify_config = types.DeidentifyConfig(
        info_type_transformations=types.InfoTypeTransformations(
            transformations=[
                types.InfoTypeTransformations.InfoTypeTransformation(
                    primitive_transformation=types.PrimitiveTransformation(
                        character_mask_config=types.CharacterMaskConfig(
                            masking_character="*", number_to_mask=0, reverse_order=False
                        )
                    )
                )
            ]
        )
    )

    request = {
        "parent": parent,
        "inspect_config": {"info_types": info_types},
        "deidentify_config": deidentify_config,
        "item": {"value": text},
    }

    response = dlp.deidentify_content(request=request)

    return response.item.value


# Shared Python REPL instance; its `run` method backs the "python_repl" tool.
# NOTE(review): this lets the agent execute model-generated Python inside the
# notebook kernel — confirm that is acceptable for this environment.
python_repl = PythonREPL()

In [6]:
def sql_agent_tools():
    """Return the extra tools offered to the SQL agent.

    The list holds, in order, the DLP-backed ``mask_pii_data`` tool and the
    ``python_repl`` tool that forwards its input to ``python_repl.run``.
    """
    pii_masker = Tool.from_function(
        func=mask_pii_data,
        name="mask_pii_data",
        description="Masks PII data in the input text using Google Cloud DLP.",
    )
    # The backslash continuations are part of the literal: the leading
    # whitespace of each continued line ends up inside the description.
    python_shell = Tool(
        name="python_repl",
        description="A Python shell. Use this to execute python commands. \
              Input should be a valid python command. \
              If you want to see the output of a value, \
              you should print it out with `print(...)`.",
        func=python_repl.run,
    )
    return [pii_masker, python_shell]


In [7]:
# Embed the "input" text of every SQL example with OpenAI embeddings, index
# the vectors in FAISS, and pick the 2 most similar examples per question.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=sql_examples,
    embeddings=OpenAIEmbeddings(),
    vectorstore_cls=FAISS,
    input_keys=["input"],
    k=2,
)


In [8]:
# System prompt body: PREFIX followed by the examples chosen by
# `example_selector`, each rendered as a "User input / SQL query" pair.
few_shot_prompt = FewShotPromptTemplate(
    prefix=PREFIX,
    suffix="",
    example_separator="\n\n",
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\n" "SQL query: {query}"
    ),
    input_variables=["input", "top_k"],
)


In [9]:

# Chat prompt layout, in order: few-shot system message, prior conversation,
# the user's question, a seed AI message, then the agent scratchpad.
# SUFFIX is wrapped in a plain AIMessage, so its content is sent as-is and
# is not treated as a template.
messages = [
    SystemMessagePromptTemplate(prompt=few_shot_prompt),
    MessagesPlaceholder("chat_history"),
    HumanMessagePromptTemplate.from_template("{input}"),
    AIMessage(content=SUFFIX),
    MessagesPlaceholder("agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)

In [10]:
# Build the additional tools (PII masking and the Python REPL) that are
# handed to the agent through `create_sql_agent(extra_tools=...)`.
extra_tools = sql_agent_tools()

In [11]:
# Conversation buffer for the agent executor: history is exposed to the
# prompt under "chat_history" as message objects, keyed on the "input" field.
memory = ConversationBufferMemory(
    input_key="input",
    memory_key="chat_history",
    return_messages=True,
)

  memory = ConversationBufferMemory(


In [12]:
# Open the BigQuery connection used by the agent's SQL tools.
# NOTE(review): the project id in this URI repeats the value of `project`;
# keep the two in sync.
db = SQLDatabase.from_uri(database_uri="bigquery://round-booking-276105")

In [13]:
# Chat model that drives the agent: gpt-4o at temperature 0.
model = ChatOpenAI(temperature=0, model="gpt-4o")

In [14]:

# Assemble the SQL agent executor from the model, the BigQuery connection,
# the custom chat prompt, the extra tools and the conversation memory.
agent_executor = create_sql_agent(
    llm=model,
    db=db,
    prompt=prompt,
    extra_tools=extra_tools,
    agent_type="openai-tools",
    top_k=10,
    input_variables=["input", "agent_scratchpad", "chat_history"],
    agent_executor_kwargs={"memory": memory, "handle_parsing_errors": True},
    verbose=True,
)


In [15]:
# Ask the agent a question; the question is passed under the executor's
# "input" key and the resulting dict is displayed as the cell output.
agent_executor.invoke({"input": "L1.CUSTOMER_MASTER 의 고객 연령의 평균은 얼마야?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mL1.CUSTOMER_MASTER, L1.TRPURC_OBS_CART_L, L1.TRPURC_OBS_ORD_L[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'L1.CUSTOMER_MASTER'}`


[0m[33;1m[1;3m
CREATE TABLE `L1.CUSTOMER_MASTER` (
	`PSST_CUST_ID` STRING OPTIONS(description='영구고객아이디'), 
	`CO_CODE` STRING OPTIONS(description='관계사코드'), 
	`CO_NAME` STRING OPTIONS(description='관계사이름'), 
	`EMAIL_ADDR` STRING OPTIONS(description='전자우편주소'), 
	`FNAME_NAME` STRING OPTIONS(description='FIRSTNAME이름'), 
	`MNAME_NAME` STRING OPTIONS(description='MIDDLENAME이름'), 
	`LNAME_NAME` STRING OPTIONS(description='LASTNAME이름'), 
	`MOBL_PHN_NO` STRING OPTIONS(description='모바일전화번호'), 
	`PHN_NO` STRING OPTIONS(description='전화번호'), 
	`CUST_AGE` STRING OPTIONS(description='고객연령'), 
	`BDAY_DATE` STRING OPTIONS(description='생일일자'), 
	`SEX_CODE` STRING OPTIONS(description='성별코드'), 
	`SEX_NAME` STRING OPTIONS(descr

{'input': 'L1.CUSTOMER_MASTER 의 고객 연령의 평균은 얼마야?',
 'chat_history': [HumanMessage(content='L1.CUSTOMER_MASTER 의 고객 연령의 평균은 얼마야?', additional_kwargs={}, response_metadata={}),
  AIMessage(content='The average age of customers in the L1.CUSTOMER_MASTER table is not available or cannot be calculated due to missing or invalid data.', additional_kwargs={}, response_metadata={})],
 'output': 'The average age of customers in the L1.CUSTOMER_MASTER table is not available or cannot be calculated due to missing or invalid data.'}