In [13]:
from dotenv import load_dotenv
load_dotenv()

# Access environment variables
import os
openai_key = os.getenv("API_KEY")
db_url = os.getenv("DB_URL")

In [2]:
#Using opneai-tools Agent
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from langchain_community.agent_toolkits import SQLDatabaseToolkit

from dotenv import load_dotenv
load_dotenv()

import psycopg2

# Setup PostgreSQL database using SQLDatabaseToolkit
db = SQLDatabase.from_uri(
    db_url
)
print(db.dialect)
print(db.get_usable_table_names())


postgresql
['ADM2022', 'C2022DEP', 'C2022_A', 'C2022_B', 'C2022_C', 'EFFY2022', 'EFFY2022_DIST', 'EFIA2022', 'FLAGS2022', 'GR200_22', 'GR2022', 'GR2022_L2', 'GR2022_PELL_SSL', 'HD2022', 'IC2022', 'IC2022_AY', 'IC2022_CAMPUSES', 'IC2022_PY', 'OM2022', 'SFA2122', 'SFAV2122']


In [3]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0, max_tokens= 4000, )


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")


table_names = "\n".join(db.get_usable_table_names())
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
table_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)
table_chain.invoke({"input": "The total number of branch campuses offering graduate programs."})

[Table(name='C2022DEP'), Table(name='IC2022_CAMPUSES')]

In [4]:
system = f"""Return the names of the SQL tables that are relevant to the user question. \
The tables are:
Admissions
Degrees Awarded
Programs Offered
Enrollment
Instructional Activity
Institutaional Characteristics
Student Financial Aid
Graduation Rates"""
category_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)
category_chain.invoke({"input":  "List of all the Institutes in Massachusetts."})

[Table(name='Institutaional Characteristics')]

In [5]:
from typing import List
def get_tables(categories: List[Table]) -> List[str]:
    tables = []
    for category in categories:
        if category.name == "Admissions":
            tables.extend(
                [
                    "ADM2022",
                    "IC2022_CAMPUSES",
                ]
            )
        elif category.name == "Degrees Awarded":
            tables.extend(
                [
                    "C2022_A",
                    "C2022_B",
                    "C2022_C",
                ]
            )
        elif category.name == "Enrollment":
            tables.extend(
                [
                    "EFFY2022",
                    "EFFY2022_DIST",
                ]
            )
        elif category.name == "Programs Offered":
            tables.extend([
                "C2022DEP",
                "IC2022",
            ])
        elif category.name == "Instructional Activity":
            tables.extend(
                [
                    
                    "EFIA2022",
                    "IC2022",
                ]
            )
        elif  category.name == "Institutional Characteristics":
            tables.extend(
                [
                    "FLAGS2022",
                    "HD2022",
                    "IC2022",
                    "IC2022_CAMPUSES",
                    "IC2022_PY",      
                ]
            )
        elif category.name == "Student Financial Aid":
            tables.extend(
                [
                    "SFA2122",
                    "SFAV2122",
                ]
            )
        elif category.name == "Graduation Rates":
            tables.extend(
                [
                    "GR2022",
                    "GR200_22",
                    "GR2022_PELL_SSL",
                ]
            )
    return tables

table_chain = category_chain | get_tables  # noqa
table_chain.invoke({"input": "The total number of branch campuses offering graduate programs."})

['C2022DEP', 'IC2022']

In [6]:
from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough

query_chain = create_sql_query_chain(llm, db)
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain

In [7]:
query = full_chain.invoke(
    {"question": "The total number of branch campuses offering graduate programs."}
)
print(query)

SELECT COUNT(*) AS total_branch_campuses
FROM "C2022DEP"
WHERE pmastr > 0;


In [8]:
#unique values for each entity we want, for which we define a function that parses the result into a list of elements
import ast
import re


def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return res


proper_nouns = query_as_list(db, "SELECT INSTNM, ADDR, CITY, STABBR, ZIP FROM \"HD2022\" WHERE STABBR = 'MA'")
proper_nouns += query_as_list(db, "SELECT INSTNM, WEBADDR FROM \"HD2022\" WHERE GROFFER = 1;")
proper_nouns += query_as_list(db, "SELECT PCCITY, PCSTABBR FROM \"IC2022_CAMPUSES\";")
len(proper_nouns)
proper_nouns[:5]

['Empire Beauty School-Boston', 'West Street', 'Boston', 'MA', '']

In [9]:
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 15})

In [10]:
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

system = """You are a PostgreSQL expert. Given an input question, create a syntactically \
correct PostgreSQL query to run. Unless otherwise specificed, do not return more than \
{top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nHere is a non-exhaustive \
list of possible feature values. \n\n DO NOT MAKE ANY DML COMMANDS(INSERT, UPDATE, DELETE) \n\n If filtering on a feature value make sure to check its spelling \
against this list first:\n\n{proper_nouns}"""

prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{input}")])

query_chain = create_sql_query_chain(llm, db, prompt=prompt)
retriever_chain = (
    itemgetter("question")
    | retriever
    | (lambda docs: "\n".join(doc.page_content for doc in docs))
)
chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain

In [11]:
# Without retrieval
query = query_chain.invoke(
    {"question": "The total number of branch campuses offering graduate programs.", "proper_nouns": ""}
)
print(query)
db.run(query)

BadRequestError: Error code: 400 - {'error': {'message': "This model's maximum context length is 16385 tokens. However, your messages resulted in 43127 tokens. Please reduce the length of the messages.", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}

In [12]:
# With retrieval
query = chain.invoke({"question": "The total number of branch campuses offering graduate programs."})
print(query)
db.run(query)

BadRequestError: Error code: 400 - {'error': {'message': "This model's maximum context length is 16385 tokens. However, your messages resulted in 43239 tokens. Please reduce the length of the messages.", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}