In [None]:
import pandas as _hex_pandas
import datetime as _hex_datetime
import json as _hex_json

In [None]:
# Hex run-metadata parameters below are JSON-decoded from literals.
# NOTE(review): these appear to be injected by the Hex export — edit them in Hex, not here.
# Whether this run was scheduled (decoded from the JSON literal `false`).
hex_scheduled = _hex_json.loads("false")

In [None]:
# Email of the user running the project (placeholder value in this export).
hex_user_email = _hex_json.loads("\"example-user@example.com\"")

In [None]:
# Execution context label reported by Hex (here "logic").
hex_run_context = _hex_json.loads("\"logic\"")

In [None]:
# Timezone name; also passed to the Hex dropdown input calls further down.
hex_timezone = _hex_json.loads("\"UTC\"")

In [None]:
# Project identifier (UUID string).
hex_project_id = _hex_json.loads("\"9f3e2ca6-e2d9-4be5-b2b1-d761a410618b\"")

In [None]:
# Human-readable project name.
hex_project_name = _hex_json.loads("\"DataSight AI\"")

In [None]:
# Project status label.
hex_status = _hex_json.loads("\"In Progress\"")

In [None]:
# Project category tags (a list of strings).
hex_categories = _hex_json.loads("[\"External\"]")

In [None]:
# Default chart colour palette as a list of hex colour codes.
hex_color_palette = _hex_json.loads("[\"#4C78A8\",\"#F58518\",\"#E45756\",\"#72B7B2\",\"#54A24B\",\"#EECA3B\",\"#B279A2\",\"#FF9DA6\",\"#9D755D\",\"#BAB0AC\"]")

This app helps users quickly get oriented in the contents of a database. It uses standard data tools and generative AI agents to surface information about the data. It currently offers:
* AI generated descriptions for each table
* Data preview and column stats 
* AI generated table relationship details 
* Natural language querying via an LLM-based agent 

(This is meant to be a proof of concept for how AI tools can help users better understand and work with their data.)

Coming Soon 🗺️:
* AI generated ERDs
* Advanced query agent with error handling and recovery
* AI recommended visualizations

You can get technical details and access to the code on [GitHub](https://github.com/brayden-s-haws/data_sight_ai/tree/main).

In [None]:
"""
Import non-standard packages
"""

!pip install langchain langchain-experimental
!pip install sqlalchemy-bigquery

Collecting langchain-experimental
  Downloading langchain_experimental-0.0.47-py3-none-any.whl.metadata (1.9 kB)
Collecting langchain
  Downloading langchain-0.0.353-py3-none-any.whl.metadata (13 kB)
INFO: pip is looking at multiple versions of langchain-experimental to determine which version is compatible with other requirements. This could take a while.
Collecting langchain-experimental
  Downloading langchain_experimental-0.0.46-py3-none-any.whl.metadata (1.9 kB)
Collecting langchain-core<0.1,>=0.0.11 (from langchain)
  Downloading langchain_core-0.0.13-py3-none-any.whl.metadata (978 bytes)
Collecting langchain-community<0.1,>=0.0.2 (from langchain)
  Downloading langchain_community-0.0.7-py3-none-any.whl.metadata (7.3 kB)
INFO: pip is looking at multiple versions of langchain to determine which version is compatible with other requirements. This could take a while.
Collecting langchain
  Downloading langchain-0.0.352-py3-none-any.whl.metadata (13 kB)
  Downloading langchain-0.0.35

In [None]:
"""
Create the BigQuery client (used for data previews) and the SQLAlchemy engine
(used by the LangChain SQL chains) from a single service-account key file.
"""

import json

from google.cloud import bigquery
from google.oauth2 import service_account
from sqlalchemy import create_engine  # Table, MetaData, Integer, String, Column, ForeignKey, Float, Date

# Service-account key file, read relative to the working directory.
service_account_file = 'bq_config.json'

# Parse the key once; the same dict feeds both the BigQuery client and SQLAlchemy.
with open(service_account_file, 'r') as key_file:
    bq_creds_dict = json.load(key_file)

credentials = service_account.Credentials.from_service_account_info(bq_creds_dict)
client = bigquery.Client(credentials=credentials)

# Connectivity smoke test: list the datasets of the public BigQuery project.
datasets = list(client.list_datasets(project='bigquery-public-data'))
print("Found datasets" if datasets else "No datasets found.")

# Engine consumed by SQLDatabase(engine) in the LangChain cells below.
# NOTE(review): the URL names no project/dataset, so schema reflection presumably
# targets the credentials' default project rather than the selected public
# dataset — confirm what table info the LLM actually sees.
engine = create_engine("bigquery://", credentials_info=bq_creds_dict)

Found datasets


In [None]:
# Friendly dataset label chosen in the app (presumably a Hex dropdown input);
# mapped to a fully-qualified BigQuery dataset ID in the next cell.
select_dataset = _hex_json.loads("\"StackOverflow\"")

In [None]:
"""
Translate the friendly dataset label into its fully-qualified BigQuery dataset ID.
Unknown labels map to an empty string, which the next cell reports as invalid.
"""

DATASET_IDS_BY_LABEL = {
    'Austin Bikeshare': 'bigquery-public-data.austin_bikeshare',
    'US Census': 'bigquery-public-data.census_bureau_usa',
    'FHIR': 'bigquery-public-data.fhir_synthea',
    'Google Analytics 4': 'bigquery-public-data.ga4_obfuscated_sample_ecommerce',
    'FAA': 'bigquery-public-data.faa',
    'Google Cloud Release Notes': 'bigquery-public-data.google_cloud_release_notes',
    'Iowa Liquor Sales Forecast': 'bigquery-public-data.iowa_liquor_sales_forecasting',
    'Medicare': 'bigquery-public-data.medicare',
    'NCAA Basketball': 'bigquery-public-data.ncaa_basketball',
    'San Francisco Film Locations': 'bigquery-public-data.san_francisco_film_locations',
    'StackOverflow': 'bigquery-public-data.stackoverflow',
    'USA Popular Names': 'bigquery-public-data.usa_names',
    'Wikipedia': 'bigquery-public-data.wikipedia',
}

prompt_dataset_name = DATASET_IDS_BY_LABEL.get(select_dataset, '')

In [None]:
"""
Get the table names for the specified dataset
"""

# Defined before any branch so the table dropdown and the query prompt (later
# cells) always find a list, even when the dataset name is empty or the lookup fails.
dataset_table_names = []

if prompt_dataset_name.strip() == '':
    print("Dataset name is empty. Please provide a valid dataset name.")
else:
    try:
        # list_tables returns a lazy page iterator, which is always truthy.
        # Materialise it so the emptiness check below is meaningful.
        dataset_tables = list(client.list_tables(prompt_dataset_name))

        if dataset_tables:
            print(f"Found tables in dataset {prompt_dataset_name}:")
            for table in dataset_tables:
                dataset_table_names.append(table.table_id)
                print(f"Table found: {table.table_id}")
        else:
            print(f"No tables found in dataset {prompt_dataset_name}.")

        print("All table names:", dataset_table_names)

    except Exception as e:
        # Best-effort lookup: report the failure and leave dataset_table_names empty.
        print(f"An error occurred: {e}")

Found tables in dataset bigquery-public-data.stackoverflow:
Table found: badges
Table found: comments
Table found: post_history
Table found: post_links
Table found: posts_answers
Table found: posts_moderator_nomination
Table found: posts_orphaned_tag_wiki
Table found: posts_privilege_wiki
Table found: posts_questions
Table found: posts_tag_wiki
Table found: posts_tag_wiki_excerpt
Table found: posts_wiki_placeholder
Table found: stackoverflow_posts
Table found: tags
Table found: users
Table found: votes
All table names: ['badges', 'comments', 'post_history', 'post_links', 'posts_answers', 'posts_moderator_nomination', 'posts_orphaned_tag_wiki', 'posts_privilege_wiki', 'posts_questions', 'posts_tag_wiki', 'posts_tag_wiki_excerpt', 'posts_wiki_placeholder', 'stackoverflow_posts', 'tags', 'users', 'votes']


In [None]:
# Hex-generated dropdown input: binds `prompt_table` to the selected table name,
# using `dataset_table_names` as the options list (the exported default selection is "posts_answers").
# NOTE(review): relies on Hex-internal names (_hex_pks, _hex_types, _hex_kernel) not defined in this file — regenerate in Hex rather than editing by hand.
import json as _hex_json
prompt_table = _hex_pks.kernel_execution.input_cell.run_dropdown_dynamic(args=_hex_types.DropdownDynamicArgs.from_dict({**_hex_json.loads("{\"dataframe_column\":null,\"ui_selected_value\":\"posts_answers\"}"), **{_hex_json.loads("\"options_variable\""):_hex_kernel.variable_or_none("dataset_table_names", scope_getter=lambda: globals())}}), app_session_token=_hex_APP_SESSION_TOKEN, python_kernel_init_status=_hex_python_kernel_init_status, hex_timezone=_hex_kernel.variable_or_none("hex_timezone", scope_getter=lambda: globals()), interrupt_event=locals().get("_hex_interrupt_event"))

# Hex-generated companion call: presumably registers `dataset_table_names` (capped by
# max_size / max_size_in_bytes) as the value source for the dropdown above.
import json as _hex_json
_hex_pks.kernel_execution.input_cell.filled_dynamic_value(args=_hex_types.FilledDynamicValueArgs.from_dict({**_hex_json.loads("{\"variable_name\":\"dataset_table_names\",\"dataframe_column\":null,\"max_size\":10000,\"max_size_in_bytes\":5242880}"), **{_hex_json.loads("\"variable\""):_hex_kernel.variable_or_none("dataset_table_names", scope_getter=lambda: globals())}}), app_session_token=_hex_APP_SESSION_TOKEN, python_kernel_init_status=_hex_python_kernel_init_status, hex_timezone=_hex_kernel.variable_or_none("hex_timezone", scope_getter=lambda: globals()), interrupt_event=locals().get("_hex_interrupt_event"))

In [None]:
# Number of rows to preview; used as the SQL LIMIT and as the .head() size in later cells.
rows_to_preview = _hex_json.loads("15")

In [None]:
"""
Generate a preview of the selected table and sample size
"""

# Fully-qualified `project.dataset.table` ID of the table chosen in the dropdown.
table_to_query = prompt_dataset_name + "." + prompt_table

# Normalise the row count before it is interpolated into SQL text. A float such as
# 15.0 would render an invalid LIMIT clause and also break `.head(rows_to_preview)`
# later on, so coerce to int and reject negatives up front.
rows_to_preview = int(rows_to_preview)
if rows_to_preview < 0:
    raise ValueError(f"rows_to_preview must be >= 0, got {rows_to_preview}")

# NOTE(review): in BigQuery a LIMIT clause does not reduce the bytes scanned, so this
# preview is billed like a full scan of the table; client.list_rows(..., max_results=...)
# is a cheaper alternative if query cost matters.
preview_query = f"""
SELECT * 
FROM `{table_to_query}`
LIMIT {rows_to_preview}
"""

query_job = client.query(preview_query)

In [None]:
"""
Use Langchain to generate and store a table description. First look up if a description exists, if so display it. If one does not exist then generate a description.
"""

import csv
import os
from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain

# Local CSV cache of generated descriptions with columns "key" (fully-qualified
# table ID) and "description".
csv_filename = "table_descriptions.csv"
describe_prompt = f' describe {table_to_query}. Give a general description of the table and what it is for. Highlight the columns that are in the table, what each column contains.'

db = SQLDatabase(engine) #, include_tables=prompt_tables
llm = OpenAI(temperature=0, verbose=True)

db_chain = SQLDatabaseChain.from_llm(llm, verbose=False,db=db, use_query_checker=True, top_k=10)

# Load any cached descriptions. UTF-8 is explicit so LLM text containing
# non-ASCII characters round-trips regardless of the platform default encoding.
if os.path.exists(csv_filename):
    with open(csv_filename, mode="r", newline="", encoding="utf-8") as csvfile:
        reader = csv.DictReader(csvfile)
        description_dict = {row["key"]: row["description"] for row in reader}
else:
    description_dict = {}

# Serve the cached description when present; otherwise ask the LLM and cache the answer.
if table_to_query in description_dict:
    print(description_dict[table_to_query])
else:
    description_output = db_chain.run(describe_prompt)
    print(description_output)

    # Emit the header only for a new or zero-length file. Checking whether the
    # cache dict is empty instead would append a second header whenever the file
    # exists but holds no data rows yet.
    write_header = not os.path.exists(csv_filename) or os.path.getsize(csv_filename) == 0
    with open(csv_filename, mode="a", newline="", encoding="utf-8") as csvfile:
        fieldnames = ["key", "description"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow({"key": table_to_query, "description": description_output})

The table posts_answers contains the answers to questions posted on Stack Overflow. The columns in the table are id, title, body, accepted_answer_id, answer_count, comment_count, community_owned_date, creation_date, favorite_count, last_activity_date, last_edit_date, last_editor_display_name, last_editor_user_id, owner_display_name, owner_user_id, parent_id, post_type_id, score, tags, view_count.


In [None]:
"""
Materialise the preview query's rows as a pandas DataFrame for exploration
"""

import pandas as pd

# Pull the rows returned by the preview job into memory.
dataframe_table = query_job.to_dataframe()

# Show at most `rows_to_preview` rows as this cell's rich output.
dataframe_table.head(n=rows_to_preview)

Unnamed: 0,id,title,body,accepted_answer_id,answer_count,comment_count,community_owned_date,creation_date,favorite_count,last_activity_date,last_edit_date,last_editor_display_name,last_editor_user_id,owner_display_name,owner_user_id,parent_id,post_type_id,score,tags,view_count
0,666919,,<p>I think that DBDesigner don't draw links pr...,,,0,NaT,2009-03-20 16:28:36.123000+00:00,,2009-03-20 16:28:36.123000+00:00,NaT,,,noman,,637935,2,0,,
1,666956,,<p>There are some real issues with argument al...,,,1,NaT,2009-03-20 16:39:48.810000+00:00,,2009-03-20 16:39:48.810000+00:00,NaT,,,Robert,,242894,2,0,,
2,666966,,<p>Windows Server 2008 supports VPN capabiliti...,,,0,NaT,2009-03-20 16:42:52.970000+00:00,,2009-03-20 16:42:52.970000+00:00,NaT,,,jtdrummerboy,,19721,2,0,,
3,667041,,<p>Looks like Silverlight 3 supports direct PC...,,,0,NaT,2009-03-20 16:59:03.210000+00:00,,2009-03-20 16:59:03.210000+00:00,NaT,,,,,585868,2,0,,
4,667297,,<p>Login to the server that runs the IIS using...,,,0,NaT,2009-03-20 18:05:43.113000+00:00,,2009-03-20 18:05:43.113000+00:00,NaT,,,Nate Vasquez,,168946,2,0,,
5,667670,,<p>If you get two copies of splitter.py runnin...,,,2,NaT,2009-03-20 19:40:18.253000+00:00,,2009-03-20 19:40:18.253000+00:00,NaT,,,Andy V,,667500,2,0,,
6,668127,,<p>The easy way to do it is to have the user t...,,,1,NaT,2009-03-20 21:44:14.710000+00:00,,2009-03-20 21:44:14.710000+00:00,NaT,,,Lucius Kwok,,79445,2,0,,
7,668487,,<p>You can create some form of persistence usi...,,,0,NaT,2009-03-21 00:40:35.927000+00:00,,2009-03-21 00:40:35.927000+00:00,NaT,,,Daniel Luyo,,667891,2,0,,
8,668609,,<p>-- What about EXCEPT? (if this is SQL Serve...,,,0,NaT,2009-03-21 02:15:55.930000+00:00,,2009-03-21 02:15:55.930000+00:00,NaT,,,,,666595,2,0,,
9,668671,,<p>alternatively you could use lastfm web serv...,,,0,NaT,2009-03-21 03:16:12.587000+00:00,,2009-03-21 03:16:12.587000+00:00,NaT,,,maxim,,664771,2,0,,


In [None]:
"""
Display stats for all columns in the table
"""

import inspect

import pandas as pd

# `datetime_is_numeric=True` is only accepted by older pandas; pandas 2.x dropped the
# keyword (datetimes are always summarised numerically there) and raises TypeError
# if it is passed. Forward it only when this pandas version still supports it.
describe_kwargs = {"include": "all"}
if "datetime_is_numeric" in inspect.signature(pd.DataFrame.describe).parameters:
    describe_kwargs["datetime_is_numeric"] = True

dataframe_table.describe(**describe_kwargs)

# Note: This is currently only showing the stats for the preview set of data, needs to be expanded to show all data but it was crashing the notebook due to lack of memory

Unnamed: 0,id,title,body,accepted_answer_id,answer_count,comment_count,community_owned_date,creation_date,favorite_count,last_activity_date,last_edit_date,last_editor_display_name,last_editor_user_id,owner_display_name,owner_user_id,parent_id,post_type_id,score,tags,view_count
count,15.0,0.0,15,0.0,0.0,15.0,0,15,0.0,15,0,0.0,0.0,13,0.0,15.0,15.0,15.0,0.0,0.0
unique,,0.0,15,0.0,0.0,,,,0.0,,,0.0,,13,,,,,0.0,0.0
top,,,<p>I think that DBDesigner don't draw links pr...,,,,,,,,,,,noman,,,,,,
freq,,,1,,,,,,,,,,,1,,,,,,
mean,668104.866667,,,,,0.266667,NaT,2009-03-21 00:25:08.485333248+00:00,,2009-03-21 00:25:08.485333248+00:00,NaT,,,,,507846.133333,2.0,0.0,,
min,666919.0,,,,,0.0,NaT,2009-03-20 16:28:36.123000+00:00,,2009-03-20 16:28:36.123000+00:00,NaT,,,,,19721.0,2.0,0.0,,
25%,667169.0,,,,,0.0,NaT,2009-03-20 17:32:23.161499904+00:00,,2009-03-20 17:32:23.161499904+00:00,NaT,,,,,395801.5,2.0,0.0,,
50%,668487.0,,,,,0.0,NaT,2009-03-21 00:40:35.927000064+00:00,,2009-03-21 00:40:35.927000064+00:00,NaT,,,,,660463.0,2.0,0.0,,
75%,668756.5,,,,,0.0,NaT,2009-03-21 04:46:55.232000+00:00,,2009-03-21 04:46:55.232000+00:00,NaT,,,,,667695.5,2.0,0.0,,
max,669132.0,,,,,2.0,NaT,2009-03-21 11:09:59.330000+00:00,,2009-03-21 11:09:59.330000+00:00,NaT,,,,,669105.0,2.0,0.0,,


In [None]:
# Render the whole preview DataFrame (at most `rows_to_preview` rows, per the LIMIT) as this cell's output.
dataframe_table

In [None]:
# Natural-language question for the SQL chain (presumably bound to a Hex text input).
user_prompt = _hex_json.loads("\"Which post had the most answers and what was the question, and what was the most popular answer and which user answered it \"")

In [None]:
# Gate for the query cell below: the chain only runs when this is true (presumably a Hex button/toggle).
run_prompt = _hex_json.loads("false")

In [None]:
"""
Use Langchain to query the database for the user. Converts their natural language prompt to SQL.
"""

# TODO: Switch to agentic model that can handle errors

# Tables other than the focus table, offered to the LLM as possible join targets.
# The focus table is excluded so it is not listed as an "other" table.
other_table_names = [name for name in dataset_table_names if name != prompt_table]

# Instructions appended after the user's question: names the focus table and the
# candidate join tables, and asks for a BigQuery query to be generated and run.
system_prompt = f" in {table_to_query}. You are a BigQuery expert. You are able to quickly review the tables in a dataset and understand the contents of each table along with their relation. You will be asked a question for which you need to generate and execute a query. The table in the question is the main focus of the question, but you may also need to join to other tables, so keep them in mind as you create your plan. The other tables are {other_table_names}. The column names may not match 1:1 in the prompt, use your best reasoning to select a column (for instance a user may ask for an account but in the table the column is account_name). Ensure that the columns you use in the query exist in the table. As you answer the user's question, consider what other columns may be additive to their question and include those in your response"
full_prompt = user_prompt + system_prompt


if run_prompt:
    from langchain.utilities import SQLDatabase
    from langchain.llms import OpenAI
    from langchain_experimental.sql import SQLDatabaseChain

    db = SQLDatabase(engine) # include_tables=dataset_tables_to_query
    llm = OpenAI(temperature=0, verbose=True)

    db_chain = SQLDatabaseChain.from_llm(llm, verbose=True,db=db, use_query_checker=True, top_k=10)

    # Keep the answer rather than discarding it. An expression nested inside `if`
    # is never auto-displayed as the cell output, so print it explicitly.
    query_output = db_chain.run(full_prompt)
    print(query_output)

else:
    print("Waiting on you to run the query")



[1m> Entering new SQLDatabaseChain chain...[0m
Which post had the most answers and what was the question, and what was the most popular answer and which user answered it  in bigquery-public-data.stackoverflow.posts_answers.You are a BigQuery expert. You are able quickly review the tables in a dataset and understand the contents of each table along with their relation. You will be asked a question for which you need to generate and execute a query. The table in the question is the main focus of the question, but you may also need to join to other tables, so keep them in mind as your create your plan. The other tables are ['badges', 'comments', 'post_history', 'post_links', 'posts_answers', 'posts_moderator_nomination', 'posts_orphaned_tag_wiki', 'posts_privilege_wiki', 'posts_questions', 'posts_tag_wiki', 'posts_tag_wiki_excerpt', 'posts_wiki_placeholder', 'stackoverflow_posts', 'tags', 'users', 'votes']. The column names may not match 1:1 in the prompt, use your best reasoning to s

In [None]:
"""
Ask the LLM-backed SQL chain to explain how the selected table relates to the
other tables in its dataset, and print the answer.
"""

# TODO: Store these like we do the descriptions
# TODO: Add error handling
# TODO: Replace with AI generated ERD

# SQLDatabase, OpenAI and SQLDatabaseChain come from the table-description cell's imports.
llm = OpenAI(temperature=0, verbose=True)
db = SQLDatabase(engine)  # , include_tables=prompt_tables

relationship_prompt = (
    f' Describe the relationship between {table_to_query} '
    'and the other tables in the dataset.'
)

db_chain = SQLDatabaseChain.from_llm(
    llm,
    db=db,
    verbose=False,
    use_query_checker=True,
    top_k=1,
)

relationship_output = db_chain.run(relationship_prompt)
print(relationship_output)

The relationship between bigquery-public-data.stackoverflow.posts_answers and the other tables in the dataset is that the posts_answers table contains answers to questions in the posts_questions table, with each answer having a score and body associated with it.
