In [1]:
import ast
import os
import sqlite3

import pandas as pd
from dotenv import load_dotenv
from langchain import hub
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.utilities import SQLDatabase
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from openai import AzureOpenAI


In [2]:
# Load variables from a local .env file into os.environ (the next cell reads
# API_KEY and AZURE_ENDPOINT); override=False is the default and keeps values
# that are already set in the process environment.
load_dotenv(override=False)

True

In [3]:
# Azure OpenAI chat model used by the SQL agent. Credentials are read from the
# environment (populated by the load_dotenv cell); temperature=0 minimises
# sampling randomness so agent runs are easier to compare.
AZURE_API_VERSION = "2024-03-01-preview"
AZURE_DEPLOYMENT = "myTalentX_GPTo"

api_key = os.environ.get('API_KEY')
azure_endpoint = os.environ.get('AZURE_ENDPOINT')

llm = AzureChatOpenAI(
    api_key=api_key,
    azure_endpoint=azure_endpoint,
    openai_api_version=AZURE_API_VERSION,
    azure_deployment=AZURE_DEPLOYMENT,
    temperature=0,
)



In [4]:
# Load the housing CSV with 'Id' as the index (chained set_index instead of an
# in-place mutation).
df = pd.read_csv('../df 1.csv').set_index('Id')

# The database file lives in the notebook's current working directory.
current_directory = os.getcwd()
db_filename = 'housing.db'
db_path = os.path.join(current_directory, db_filename)

# Copy the DataFrame into a 'housing' table, replacing any earlier copy so the
# cell can be re-run. try/finally guarantees the connection is released even
# if the write fails part-way.
conn = sqlite3.connect(db_path)
try:
    df.to_sql('housing', conn, if_exists='replace', index=True)
finally:
    conn.close()

# Print the SQLite connection string (URI)
print(f"SQLite Database URI: sqlite:///{db_path}")

# SQLAlchemy-backed handle used by the agent's SQL tools.
db_local = SQLDatabase.from_uri(f"sqlite:///{db_path}")



SQLite Database URI: sqlite:///d:\Projects\databricks-llm\notebooks\housing.db


In [5]:
# agent = create_sql_agent(llm, db=db_local, agent_type="openai-tools", verbose=True)

# thought_response = agent.invoke("What is the average price of a house in the dataset?")

# print(thought_response)

In [6]:
from typing import List, Optional, Dict, Any, Union, Type
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseToolkit, BaseTool
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
    QuerySQLDataBaseTool,
    BaseSQLDatabaseTool
)
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain.agents import create_sql_agent
from langchain_openai import ChatOpenAI


# import logging
from typing import Type, Optional, Any, List
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_community.tools.sql_database.tool import BaseSQLDatabaseTool
from langchain_core.tools import BaseTool
from langchain_core.callbacks import CallbackManagerForToolRun
from difflib import SequenceMatcher
import re
import math

# Set up logging
# logging.basicConfig(level=logging.DEBUG)
# logger = logging.getLogger(__name__)

class _ColumnValidationToolInput(BaseModel):
    # Single free-text argument; ColumnValidationTool splits it on commas into
    # table name, column name and the value to compare.
    query: str = Field(
        default=...,
        description=(
            "A string containing the table name, column name, and value to compare, "
            "separated by commas. Example: 'table_name, column_name, value_to_compare'"
        ),
    )


def jaccard_similarity(a: str, b: str) -> float:
    """Return |A ∩ B| / |A ∪ B| over the lower-cased character sets of two strings.

    Only distinct characters matter (order and repetition are ignored).
    Returns 0 when both strings are empty.
    """
    chars_a, chars_b = set(a.lower()), set(b.lower())
    union_size = len(chars_a | chars_b)
    if not union_size:
        return 0
    return len(chars_a & chars_b) / union_size

def extract_english_words(input_string):
    """Return the purely alphabetic first-column values from a stringified query result.

    ``input_string`` is expected to be the ``str()`` of a list of row tuples,
    e.g. ``"[('Urban',), ('Rural',)]"`` -- the form ColumnValidationTool passes
    in from ``self.db.run_no_throw(...)``.

    Args:
        input_string: Stringified rows; an empty/blank string means "no rows".

    Returns:
        The first element of each row, as ``str``, when it consists solely of
        ASCII letters. SQL NULLs (``None``) are skipped.

    Raises:
        ValueError: If ``input_string`` is not a Python literal (for example an
            "Error: ..." message from the database layer).
    """
    # An empty result string means the query returned no rows.
    if not input_string or not input_string.strip():
        return []

    try:
        # literal_eval only accepts Python literals, so database content can
        # never be executed as code (unlike eval).
        rows = ast.literal_eval(input_string)
    except (ValueError, SyntaxError) as exc:
        raise ValueError(
            f"Could not parse query result as a list of rows: {input_string!r}"
        ) from exc

    english_words = []
    for row in rows:
        first = row[0]
        # SQL NULL arrives as None; str(None) would otherwise pass as the word "None".
        if first is None:
            continue
        term = str(first)
        # fullmatch: '^...$' with match() would also accept a trailing newline.
        if re.fullmatch(r'[a-zA-Z]+', term):
            english_words.append(term)

    return english_words

class ColumnValidationTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for validating column values using Jaccard similarity.

    Input is ``'table_name, column_name, value_to_compare'``. The tool reads up
    to ``max_distinct_values`` distinct values of the column, keeps the purely
    alphabetic ones (via ``extract_english_words``) and reports the one whose
    character-set Jaccard similarity to ``value_to_compare`` is highest, if
    that similarity is strictly above ``similarity_threshold``.
    """

    name: str = "sql_db_column_validation"
    description: str = "Validate a value against a column in a table using Jaccard similarity. Input should be 'table_name, column_name, value_to_compare'."
    args_schema: Type[BaseModel] = _ColumnValidationToolInput
    # A match is reported only when its similarity is strictly above this value.
    similarity_threshold: float = 0.5
    # Maximum number of distinct column values fetched for comparison.
    max_distinct_values: int = 1000

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Validate the given value against the column values using Jaccard similarity.

        Args:
            query: ``'table_name, column_name, value_to_compare'``. Everything
                after the second comma is the value, so values may contain commas.
            run_manager: Unused; part of the BaseTool interface.

        Returns:
            A message naming the closest existing value, or explaining why none
            was found. Errors are returned as messages rather than raised.
        """
        try:
            # maxsplit=2 keeps commas inside the comparison value intact.
            parts = [part.strip() for part in query.split(',', 2)]
            if len(parts) != 3:
                raise ValueError("Input should be 'table_name, column_name, value_to_compare'")

            table_name, column_name, value_to_compare = parts

            if table_name not in self.db.get_usable_table_names():
                return f"Error: The table '{table_name}' does not exist in the database."

            sql_query = (
                f"SELECT DISTINCT {column_name} FROM {table_name} "
                f"LIMIT {self.max_distinct_values}"
            )
            result = self.db.run_no_throw(sql_query)
            if not isinstance(result, str):
                return f"Unexpected result type: {type(result)}. Content: {str(result)}"

            # run_no_throw returns SQL failures (e.g. an unknown column) as an
            # "Error: ..." string; surface it instead of parsing it as rows.
            if result.startswith("Error:"):
                return (
                    f"Error while reading column '{column_name}' of table "
                    f"'{table_name}': {result}"
                )

            # An empty result string means the query returned no rows.
            distinct_values = extract_english_words(result) if result.strip() else []

            highest_similarity: float = 0
            most_similar_term: Optional[str] = None
            for value in distinct_values:
                similarity = jaccard_similarity(value_to_compare, value)
                if similarity > highest_similarity:
                    highest_similarity = similarity
                    most_similar_term = value

            if highest_similarity > self.similarity_threshold:
                return f"Highest similarity match found. '{value_to_compare}' is most similar to existing value '{most_similar_term}' with Jaccard similarity of {highest_similarity:.2f}"
            return f"No similar values found above the threshold of {self.similarity_threshold} for '{value_to_compare}' in '{column_name}' of '{table_name}'."

        except ValidationError as e:
            return f"Input validation error: {e}"
        except Exception as e:
            return f"An error occurred while validating the column value: {str(e)}"

class _DistanceCalculationToolInput(BaseModel):
    # Single free-text argument holding six comma-separated fields, which
    # DistanceCalculationTool parses positionally.
    query: str = Field(
        default=...,
        description=(
            "A string containing table name, latitude column, longitude column, "
            "reference latitude, reference longitude, and radius, all separated by commas. "
            "Example: 'properties,Latitude,Longitude,40.7128,-74.0060,5'"
        ),
    )

class DistanceCalculationTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for finding locations within a specified radius.

    Runs a haversine great-circle distance query in SQL and returns the matching
    rows (all columns plus ``distance_km``), nearest first.

    NOTE: the query uses SQLite math functions (asin, sqrt, sin, cos, pow),
    which exist only when SQLite was built with them (3.35+,
    SQLITE_ENABLE_MATH_FUNCTIONS). Without them the query fails and the SQL
    error message is returned to the agent.
    """

    name: str = "sql_db_distance_calculation"
    description: str = (
        "Find locations within a specified radius from reference coordinates. "
        "Input should be 'table_name,latitude_column,longitude_column,ref_latitude,ref_longitude,radius_km'. "
        "Example: 'properties,Latitude,Longitude,40.7128,-74.0060,5' will find locations within 5km of the reference point."
    )
    args_schema: Type[BaseModel] = _DistanceCalculationToolInput
    # Optional cap on the number of rows returned (nearest first); None = no cap.
    max_results: Optional[int] = None

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Find locations within the specified radius.

        Args:
            query: ``'table_name,latitude_column,longitude_column,ref_latitude,ref_longitude,radius_km'``.
            run_manager: Unused; part of the BaseTool interface.

        Returns:
            A message listing the matching rows, or an error/"not found" message.
            Errors are returned as messages rather than raised.
        """
        try:
            # Parse input
            parts = [part.strip() for part in query.split(',')]
            if len(parts) != 6:
                raise ValueError(
                    "Input should be 'table_name,latitude_column,longitude_column,ref_latitude,ref_longitude,radius_km'"
                )

            table_name, lat_col, lon_col, ref_lat, ref_lon, radius = parts

            # Validate table exists
            if table_name not in self.db.get_usable_table_names():
                return f"Error: The table '{table_name}' does not exist in the database."

            # Convert coordinates and radius to float
            try:
                ref_lat = float(ref_lat)
                ref_lon = float(ref_lon)
                radius = float(radius)
            except ValueError:
                return "Error: Coordinates and radius must be valid numbers."
            # float() also accepts 'nan'/'inf', which would be spliced into the SQL.
            if not all(math.isfinite(v) for v in (ref_lat, ref_lon, radius)):
                return "Error: Coordinates and radius must be valid numbers."
            if not (-90.0 <= ref_lat <= 90.0 and -180.0 <= ref_lon <= 180.0):
                return "Error: Latitude must be between -90 and 90 and longitude between -180 and 180."
            if radius < 0:
                return "Error: Radius must not be negative."

            deg_to_rad = math.pi / 180.0
            limit_clause = "" if self.max_results is None else f"LIMIT {int(self.max_results)}"

            # Haversine formula with Earth radius 6371 km; min(1.0, ...) keeps
            # asin's argument in its domain when rounding pushes it just above 1.
            sql_query = f"""
            WITH DistanceCalculation AS (
                SELECT 
                    *,
                    (6371.0 * 2 * asin(
                        min(1.0, sqrt(
                            pow(sin(({ref_lat} - {lat_col}) * {deg_to_rad} / 2), 2) +
                            cos({lat_col} * {deg_to_rad}) * 
                            cos({ref_lat} * {deg_to_rad}) * 
                            pow(sin(({ref_lon} - {lon_col}) * {deg_to_rad} / 2), 2)
                        ))
                    )) as distance_km
                FROM {table_name}
                WHERE {lat_col} IS NOT NULL 
                AND {lon_col} IS NOT NULL
            )
            SELECT 
                *
            FROM DistanceCalculation
            WHERE distance_km <= {radius}
            ORDER BY distance_km
            {limit_clause};
            """

            result = self.db.run_no_throw(sql_query)

            # run_no_throw returns SQL failures (unknown column, missing math
            # functions, ...) as an "Error: ..." string rather than raising;
            # don't present that as a list of found locations.
            if isinstance(result, str) and result.startswith("Error:"):
                return f"An error occurred while calculating distances: {result}"

            if not result or result == "[]":
                return (
                    f"No locations found within {radius}km of the reference point "
                    f"({ref_lat}, {ref_lon}) in table '{table_name}'"
                )

            return (
                f"Found locations within {radius}km radius of ({ref_lat}, {ref_lon}):\n"
                f"{result}"
            )

        except ValueError as e:
            return f"Input validation error: {e}"
        except Exception as e:
            return f"An error occurred while calculating distances: {str(e)}"

# class _DistanceCalculationToolInput(BaseModel):
#     query: str = Field(
#         ...,
#         description="A string containing table name, latitude column, longitude column, reference latitude, reference longitude, and radius, all separated by commas. Example: 'properties,Latitude,Longitude,40.7128,-74.0060,5'"
#     )

# class DistanceCalculationTool(BaseSQLDatabaseTool, BaseTool):
#     """Tool for constructing location-based distance queries."""

#     name: str = "sql_db_distance_calculation"
#     description: str = (
#         "Creates a SQL query to find locations within a specified radius from reference coordinates. "
#         "Input should be 'table_name,latitude_column,longitude_column,ref_latitude,ref_longitude,radius_km'. "
#         "Example: 'properties,Latitude,Longitude,40.7128,-74.0060,5' will create a query to find locations within 5km of the reference point. "
#         "Use this tool first to construct the query, then use query_sql_database_tool to execute it."
#     )
#     args_schema: Type[BaseModel] = _DistanceCalculationToolInput

#     def _run(
#         self,
#         query: str,
#         run_manager: Optional[CallbackManagerForToolRun] = None,
#     ) -> str:
#         """Construct a SQL query for finding locations within the specified radius."""
#         try:
#             # Parse input
#             parts = [part.strip() for part in query.split(',')]
#             if len(parts) != 6:
#                 raise ValueError(
#                     "Input should be 'table_name,latitude_column,longitude_column,ref_latitude,ref_longitude,radius_km'"
#                 )
            
#             table_name, lat_col, lon_col, ref_lat, ref_lon, radius = parts
            
#             # Validate table exists
#             if table_name not in self.db.get_usable_table_names():
#                 return f"Error: The table '{table_name}' does not exist in the database."
            
#             # Convert coordinates and radius to float
#             try:
#                 ref_lat = float(ref_lat)
#                 ref_lon = float(ref_lon)
#                 radius = float(radius)
#             except ValueError:
#                 return "Error: Coordinates and radius must be valid numbers."

#             # Construct SQLite-compatible Haversine formula query
#             sql_query = (
#                 f"WITH DistanceCalculation AS ("
#                 f"    SELECT "
#                 f"        *,"
#                 f"        (6371.0 * 2 * asin("
#                 f"            sqrt("
#                 f"                pow(sin(({ref_lat} - {lat_col}) * 0.0174533 / 2), 2) +"
#                 f"                cos({lat_col} * 0.0174533) * "
#                 f"                cos({ref_lat} * 0.0174533) * "
#                 f"                pow(sin(({ref_lon} - {lon_col}) * 0.0174533 / 2), 2)"
#                 f"            )"
#                 f"        )) as distance_km"
#                 f"    FROM {table_name}"
#                 f"    WHERE {lat_col} IS NOT NULL "
#                 f"    AND {lon_col} IS NOT NULL"
#                 f")"
#                 f"SELECT "
#                 f"    *"
#                 f"FROM DistanceCalculation"
#                 f"WHERE distance_km <= {radius}"
#                 f"ORDER BY distance_km"
#                 f"LIMIT 10;"
#             )
            
#             return (
#                 f"I've constructed a SQL query to find locations within {radius}km radius of coordinates "
#                 f"({ref_lat}, {ref_lon}). You can now use query_sql_database_tool to execute this query:\n\n"
#                 f"{sql_query}"
#             )

#         except ValueError as e:
#             return f"Input validation error: {e}"
#         except Exception as e:
#             return f"An error occurred while constructing the query: {str(e)}"


class _TableContextToolInput(BaseModel):
    # Single argument: the name of one table whose text columns are summarised.
    query: str = Field(
        default=...,
        description=("Name of the table to analyze. " "Example: 'housing'"),
    )


class InfoSQLDatabaseTool_2(BaseSQLDatabaseTool, BaseTool):
    """Tool for understanding table context by listing text columns and a sample of their distinct values.

    Complements the standard schema tool (InfoSQLDatabaseTool, which returns
    the schema and sample rows). This tool shows which text values exist in
    each text column, so the agent can pick the right column and literal for
    a WHERE clause. The column lookup uses ``pragma_table_info`` and is
    therefore SQLite-specific.
    """

    name: str = "sql_db_schema_2"
    description: str = (
        "Analyzes a table and returns text-based columns along with their distinct values. "
        "Input should be just the table name. Output format is 'column_name: value1, value2, value3'. "
        "Always identify relevant column and column value for user query."
    )
    args_schema: Type[BaseModel] = _TableContextToolInput
    # Maximum number of distinct values listed per text column.
    max_distinct_values: int = 15

    def _quote_identifier(self, identifier: str) -> str:
        """Return ``identifier`` as a double-quoted SQL identifier.

        Embedded double quotes are doubled. Quoting is needed because column
        names (e.g. from a CSV header) may contain spaces or other characters
        that are not valid in a bare identifier.
        """
        return '"' + identifier.replace('"', '""') + '"'

    def _parse_rows(self, result: Any) -> List:
        """Parse the string returned by ``run_no_throw`` into a list of row tuples.

        Returns an empty list for non-string results, empty results, SQL error
        strings ("Error: ...") and anything that is not a Python literal.
        """
        if not isinstance(result, str) or not result.strip() or result.startswith("Error:"):
            return []
        try:
            # literal_eval never executes code, unlike eval on database content.
            rows = ast.literal_eval(result)
        except (ValueError, SyntaxError):
            return []
        return list(rows) if isinstance(rows, (list, tuple)) else []

    def _get_column_info(
        self, 
        table_name: str,
        run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> List[Dict]:
        """Get information about all columns in the table.

        Returns:
            ``[{"name": ..., "type": ...}, ...]`` or ``[]`` if the lookup fails.
        """
        # pragma_table_info takes the table name as a string literal; escape quotes.
        table_literal = table_name.replace("'", "''")
        sql_query = f"""
        SELECT 
            name, type
        FROM pragma_table_info('{table_literal}')
        """
        rows = self._parse_rows(self.db.run_no_throw(sql_query))
        try:
            return [{"name": col[0], "type": col[1]} for col in rows]
        except (IndexError, TypeError):
            return []

    def _get_distinct_values(
        self, 
        table_name: str, 
        column_name: str,
        run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> List:
        """Get up to ``max_distinct_values`` non-blank distinct values of a text column.

        Returns:
            A list of 1-tuples as produced by the database layer, or ``[]`` on failure.
        """
        # Names come from pragma_table_info / the usable-table list, so they
        # exist; quoting only protects against spaces and special characters.
        column = self._quote_identifier(column_name)
        table = self._quote_identifier(table_name)
        sql_query = f"""
        SELECT DISTINCT {column}
        FROM {table}
        WHERE {column} IS NOT NULL
        AND {column} != ''
        AND LENGTH(TRIM({column})) > 0
        ORDER BY {column}
        LIMIT {int(self.max_distinct_values)}
        """
        return self._parse_rows(self.db.run_no_throw(sql_query))

    def _is_text_column(self, column_type: str) -> bool:
        """Check if the column is a text-based column (by substring of its declared type)."""
        text_types = {'char', 'text', 'varchar', 'nvarchar', 'string', 'enum'}
        declared = (column_type or "").lower()
        return any(text_type in declared for text_type in text_types)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Analyze table structure and content.

        Args:
            query: A single table name; surrounding whitespace and quotes are ignored.
            run_manager: Forwarded to the helper methods (unused there).

        Returns:
            One ``'column_name: v1, v2, ...'`` line per text column, or an error
            message. Errors are returned as messages rather than raised.
        """
        try:
            # Tolerate whitespace or quotes wrapped around the table name.
            table_name = query.strip().strip("'\"")

            # Verify table exists
            if table_name not in self.db.get_usable_table_names():
                return f"Error: The table '{table_name}' does not exist in the database."

            # Get column information
            columns = self._get_column_info(table_name, run_manager)
            if not columns:
                return f"Error: Could not retrieve column information for table '{table_name}'"

            # Filter for text columns only
            text_columns = [col for col in columns if self._is_text_column(col["type"])]

            if not text_columns:
                return f"No text columns found in table '{table_name}'"

            # Build simple output with just column names and values
            output_lines = []

            for col in text_columns:
                column_name = col["name"]
                distinct_values = self._get_distinct_values(table_name, column_name, run_manager)

                if distinct_values:
                    values = [str(val[0]) for val in distinct_values]
                    output_lines.append(f"{column_name}: {', '.join(values)}")

            return "\n".join(output_lines) if output_lines else "No distinct values found in text columns"

        except Exception as e:
            return f"An error occurred while analyzing the table: {str(e)}"


class ExtendedSQLDatabaseToolkit(BaseToolkit):
    """Toolkit for interacting with SQL databases, including custom tools.

    Bundles LangChain's standard SQL tools (list tables, schema, query, query
    checker) with this notebook's custom tools: ColumnValidationTool,
    DistanceCalculationTool and InfoSQLDatabaseTool_2.
    """

    db: SQLDatabase = Field(exclude=True)
    llm: BaseLanguageModel = Field(exclude=True)

    @property
    def dialect(self) -> str:
        """Return string representation of SQL dialect to use."""
        return self.db.dialect

    class Config:
        """Configuration for this pydantic object."""
        # db and llm are not pydantic types.
        arbitrary_types_allowed = True

    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit, including custom tools.

        Descriptions reference other tools by their ``name`` so the agent is
        told which tool to call before which. Adjacent string literals are
        concatenated, so each sentence ends with an explicit space.
        """
        list_sql_database_tool = ListSQLDatabaseTool(db=self.db)

        info_sql_database_tool_description = (
            "Input to this tool is a comma-separated list of tables, output is the "
            "schema and sample rows for those tables. "
            "Be sure that the tables actually exist by calling "
            f"{list_sql_database_tool.name} first! "
            "Example Input: table1, table2, table3"
        )
        info_sql_database_tool = InfoSQLDatabaseTool(
            db=self.db, description=info_sql_database_tool_description
        )

        # InfoSQLDatabaseTool_2 accepts exactly one table name.
        understand_table_context_tool = InfoSQLDatabaseTool_2(
            db=self.db,
            description=(
                "Provides the text columns of a single table and a sample of each "
                "column's distinct values. Input should be just the table name. "
                "Always use this tool when you need more context on the database, such as "
                "checking which column needs to be used to answer a question. "
                "Be sure that the table actually exists by calling "
                f"{list_sql_database_tool.name} first!"
            ),
        )

        query_sql_database_tool_description = (
            "Input to this tool is a detailed and correct SQL query, output is a "
            "result from the database. If the query is not correct, an error message "
            "will be returned. If an error is returned, rewrite the query, check the "
            "query, and try again. If you encounter an issue with Unknown column "
            f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
            "to query the correct table fields."
        )
        query_sql_database_tool = QuerySQLDataBaseTool(
            db=self.db, description=query_sql_database_tool_description
        )

        query_sql_checker_tool_description = (
            "Use this tool to double check if your query is correct before executing "
            "it. Always use this tool before executing a query with "
            f"{query_sql_database_tool.name}!"
        )
        query_sql_checker_tool = QuerySQLCheckerTool(
            db=self.db, llm=self.llm, description=query_sql_checker_tool_description
        )

        # NOTE(review): this description says both to use the tool only after a
        # WHERE query returned nothing AND to always use it before executing a
        # query; the two instructions conflict -- decide which one is intended.
        column_validation_tool_description = (
            "Use this tool when you have created a sql query with WHERE clause, but after "
            "executing SQL, the result is None or no result or error, and you want to "
            "validate if the column values are correct for specific column name. "
            "Input should be 'table_name, column_name, value_to_compare'. "
            "Always use this tool before executing a query with "
            f"{query_sql_database_tool.name}"
        )
        column_validation_tool = ColumnValidationTool(
            db=self.db, description=column_validation_tool_description
        )

        distance_calculation_tool = DistanceCalculationTool(
            db=self.db,
            description=(
                "Use this tool to find properties within a specified radius from reference coordinates. "
                "Input should be 'table_name,latitude_column,longitude_column,ref_latitude,ref_longitude,radius_km'."
            ),
        )

        return [
            query_sql_database_tool,
            info_sql_database_tool,
            list_sql_database_tool,
            query_sql_checker_tool,
            column_validation_tool,
            distance_calculation_tool,
            understand_table_context_tool,
        ]

    def get_context(self) -> dict:
        """Return db context that you may want in agent prompt."""
        return self.db.get_context()

def create_sql_agent_with_extra_tools(
    llm: BaseLanguageModel,
    db: SQLDatabase,
    **kwargs: Any
) -> Any:
    """Create a SQL agent whose toolkit is ExtendedSQLDatabaseToolkit.

    Args:
        llm: Language model driving the agent (also used by the query checker tool).
        db: Database the tools operate on.
        **kwargs: Forwarded to ``create_sql_agent`` (e.g. ``handle_parsing_errors``).
            ``verbose`` defaults to True but may be overridden.

    Returns:
        Whatever ``create_sql_agent`` returns (the agent executor).
    """
    toolkit = ExtendedSQLDatabaseToolkit(db=db, llm=llm)

    # Default to verbose output while still letting callers pass verbose=...
    # (hard-coding it raised "got multiple values for keyword argument").
    kwargs.setdefault("verbose", True)

    return create_sql_agent(
        llm=llm,
        toolkit=toolkit,
        **kwargs
    )



In [7]:
# Create the agent
agent_executor = create_sql_agent_with_extra_tools(
    llm=llm,
    db=db_local,
    handle_parsing_errors=True
)

# thought_response = agent_executor.invoke("what is the average house price in Urban12")
# thought_response = agent_executor.invoke("Find Facility within 50km of coordinates 2.8,101.5")
# thought_response = agent_executor.invoke("List down all facilty related to healthcare")
thought_response = agent_executor.invoke("How many tools do you have")
print(thought_response)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThe question "How many tools do you have" does not seem related to the database. It appears to be asking about the capabilities of the agent rather than querying the database itself.

Final Answer: I don't know[0m

[1m> Finished chain.[0m
{'input': 'How many tools do you have', 'output': "I don't know"}
