In [2]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import  ChatGoogleGenerativeAI
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage
import os
from dotenv import load_dotenv

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from langchain.schema import BaseOutputParser
import re

class SQLQueryParser(BaseOutputParser):
    """Output parser that reduces an LLM response to one line of raw SQL.

    If the response contains a Markdown code fence (with or without the
    ``sql`` language tag), only the fenced body is kept; otherwise the whole
    response is used. In both cases surrounding whitespace is stripped and
    newlines are replaced with spaces.
    """

    def parse(self, text: str) -> str:
        """Return the SQL contained in ``text`` as a single line.

        Args:
            text: The model output. Normally a ``str``; an object exposing a
                ``content`` attribute (such as an ``AIMessage``) is also
                accepted and its content is parsed instead.

        Returns:
            The extracted SQL with leading/trailing whitespace removed and
            every newline replaced by a space.
        """
        # The regex and str methods below need a plain string, so unwrap
        # message-like objects instead of failing with a TypeError.
        content = getattr(text, "content", text)
        if not isinstance(content, str):
            content = str(content)

        # Accept ```sql and bare ``` fences, optional trailing blanks after
        # the tag, and either LF or CRLF line endings. The trailing newline
        # before the closing fence is removed by strip() below.
        fenced = re.search(
            r"```(?:sql)?[ \t]*\r?\n(.*?)```",
            content,
            re.DOTALL | re.IGNORECASE,
        )
        sql = fenced.group(1) if fenced else content
        return sql.strip().replace('\n', ' ')


parser = SQLQueryParser()

In [4]:
# Load environment variables via python-dotenv, then read the API key.
# `key` is None when GEMINI_API_KEY is not set.
load_dotenv()
key = os.environ.get('GEMINI_API_KEY')

In [5]:
# Prompt for the text-to-SQL step. {schema} and {question} are filled in by
# the chain. The "SQL Query" label is also the marker a later cell splits
# the model output on, so keep the wording in sync with that cell.
# NOTE(review): the prompt asks for PostgreSQL while the connection URI used
# in this notebook is MySQL -- confirm which dialect is intended.
template = (
    "Based on the table schema provided below, write only the PostgreSQL query that would answer the user's question:\n"
    "{schema}\n"
    "\n"
    "Question : {question}\n"
    "SQL Query:\n"
)
prompt = ChatPromptTemplate.from_template(template)

In [6]:
def get_schema(db):
    """Return the table information reported by ``db.get_table_info()``."""
    return db.get_table_info()

In [7]:
# Prefer a connection URI supplied through the environment (MYSQL_URI, e.g.
# from the .env file loaded earlier by load_dotenv()) so credentials do not
# have to live in the notebook. The literal below is only the local
# development fallback and keeps the previous behaviour when the variable is
# unset.
mysql_uri = os.getenv(
    'MYSQL_URI',
    'mysql+pymysql://root:password@localhost:3306/HealthInsuraceEnquirySystem',
)
db = SQLDatabase.from_uri(mysql_uri)

In [8]:
# Display the schema text that the chain substitutes into the prompt's {schema} slot.
get_schema(db)

'\nCREATE TABLE `Claim` (\n\t`claimID` INTEGER NOT NULL, \n\t`memberID` INTEGER, \n\t`serviceID` INTEGER, \n\tservice_date DATE, \n\tprovider_name VARCHAR(100), \n\tprovider_id INTEGER, \n\tclaim_amount DECIMAL(10, 2), \n\tPRIMARY KEY (`claimID`)\n)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB\n\n/*\n3 rows from Claim table:\nclaimID\tmemberID\tserviceID\tservice_date\tprovider_name\tprovider_id\tclaim_amount\n301\t1\t201\t2021-02-15\tCity Hospital\t501\t90.00\n302\t2\t202\t2020-06-20\tDowntown Clinic\t502\t130.00\n303\t3\t203\t2019-04-10\tER Center\t503\t450.00\n*/\n\n\nCREATE TABLE `Coverage` (\n\t`planID` INTEGER, \n\t`serviceID` INTEGER NOT NULL, \n\tservice_name VARCHAR(100), \n\tallowed_amount DECIMAL(10, 2), \n\tcopay DECIMAL(5, 2), \n\tcoinsurance DECIMAL(5, 2), \n\tPRIMARY KEY (`serviceID`)\n)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB\n\n/*\n3 rows from Coverage table:\nplanID\tserviceID\tservice_name\tallowed_amount\tcopay\tcoinsurance

In [9]:
# Text-to-SQL model served through Google Generative AI (a tuned model id).
# A later cell rebinds `sql_llm` to an OpenAI-compatible endpoint, so this
# instance is only used if that cell is skipped.
# NOTE(review): confirm "text/sql" is an accepted response_mime_type and that
# generation_config is honoured by this constructor.
sql_llm = ChatGoogleGenerativeAI(
    model="tunedModels/texttopostgres-weounsumvwtl",
    google_api_key=key,
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    generation_config={"response_mime_type": "text/sql"},
)

In [10]:
# Text-to-SQL model behind an OpenAI-compatible HTTP endpoint. This rebinds
# `sql_llm`, replacing the instance created in the previous cell.
# NOTE(review): `key` is read from GEMINI_API_KEY -- confirm the endpoint
# accepts (or ignores) that value.
sql_llm = ChatOpenAI(
    openai_api_base="http://192.168.10.73:1234/v1",
    openai_api_key=key,
    model_name="qwen-2.5-3b-text_to_sql",
)

In [11]:
# Question -> SQL pipeline: add the schema to the input dict, render the
# prompt, call the model (stopping at "\nSQLResult:"), then parse out the SQL.
chain = (
    RunnablePassthrough.assign(schema=lambda _inputs: get_schema(db))
    | prompt
    | sql_llm.bind(stop=["\nSQLResult:"])
    | parser
)

In [40]:
# Run the text-to-SQL chain on a sample question; `query` holds the parsed output.
user_query = 'How many members are there in the member table'
query = chain.invoke(
    {'question': user_query}
)

In [41]:
def convert_mysql_to_postgres(sql_query: str) -> str:
    """Rewrite backtick-quoted identifiers as double-quoted identifiers.

    Every `` `name` `` span in ``sql_query`` becomes ``"name"``; all other
    text, including an unmatched backtick, is returned unchanged.
    """
    def _requote(match):
        # Keep the identifier text, swap only the surrounding quote characters.
        return '"' + match.group(1) + '"'

    backtick_identifier = re.compile(r'`([^`]*)`')
    return backtick_identifier.sub(_requote, sql_query)

In [42]:
# Presumably the model sometimes echoes the prompt's "SQL Query" label before
# the statement; when it does, keep the text that follows the first label.
# The guard matters for two reasons: a response without the label would
# otherwise raise IndexError, and because this cell overwrites `query`,
# re-running it would hit the same error.
_parts = query.split('SQL Query')
if len(_parts) > 1:
    query = _parts[1].strip(':')
print(convert_mysql_to_postgres(query))

 SELECT COUNT(*) as total_members FROM Member;


In [168]:
# Prompt for the answer-writing step. The chain fills {schema}, {question},
# {query} and {response}. The text is kept exactly as originally written
# because it is sent verbatim to the model.
full_template = (
    " You are a HealthCare Claims Inquiry Agent. You are supposed to answer queries realted to the claims \n"
    "raised by members. Based on the table schema below, question, sql query, and sql response, write a natural language response:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Response: {response}\n"
)
full_prompt = ChatPromptTemplate.from_template(full_template)

In [169]:
def run_query(query):
    """Print ``query`` and return the result of running it on the notebook's ``db``.

    The statement is passed to ``db.run`` unchanged. In this notebook it is
    produced by the LLM chain, so it is not validated before execution.
    """
    print("Query : ", query)
    result = db.run(query)
    return result

In [170]:
# End-to-end pipeline: generate SQL for the question, attach the schema and
# the result of executing that SQL, render the answer prompt, and call the
# answering model.
# NOTE(review): `reason_llm` is not defined in the cells visible here --
# confirm the cell that creates it still exists and runs before this one.
full_chain = (
    RunnablePassthrough
    .assign(query=chain)
    .assign(
        schema=lambda _inputs: get_schema(db),
        response=lambda inputs: run_query(inputs["query"]),
    )
    | full_prompt
    | reason_llm
)


In [171]:
# Ask a sample question end to end and display the model's text answer.
ques = "How many claims are present in the databse ?"
full_chain.invoke(
    {"question": ques}
).content

Query :  SELECT   COUNT(*) FROM Claim;


'There are 10 claims in the database.'

In [43]:
def get_schema(db=None, path='schema.txt'):
    """Return the schema text used to fill the prompts.

    This definition replaces the earlier ``get_schema(db)`` in the notebook,
    yet the chain lambdas still call ``get_schema(db)``. Accepting an
    optional ``db`` keeps those calls working, while a bare ``get_schema()``
    still reads the schema file.

    Args:
        db: Optional object exposing ``get_table_info()``. When given, its
            table information is returned and the file is not read.
        path: File to read when ``db`` is not given. Defaults to
            ``'schema.txt'`` in the current working directory.

    Returns:
        The schema as a string.

    Raises:
        OSError: If ``db`` is not given and ``path`` cannot be opened.
    """
    if db is not None:
        return db.get_table_info()
    # The context manager closes the file handle even if read() fails.
    with open(path, 'r') as schema_file:
        return schema_file.read()