In [1]:
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from typing import Annotated, Dict, List, Optional, Sequence, Set
from dataclasses import dataclass, field, replace
from typing_extensions import TypedDict
import re

@dataclass
class InputState:
    """Graph input schema: the conversation so far plus the candidate table names.

    This is the narrow interface the outside world uses to start a run.

    Attributes:
        messages: Chat history. ``rewrite_question`` reads the ``content`` of the
            last message as the user's question.
        tables: Table names handed to the question-rewriting prompt.
    """

    # Annotated with langgraph's `add_messages` reducer for this channel.
    messages: Annotated[Sequence[AnyMessage], add_messages] = field(default_factory=list)
    tables: List[str] = field(default_factory=list)

@dataclass
class AgentState:
    """Working state threaded through the text-to-SQL graph.

    Fields are grouped by role:
    - counters (names suggest retry budgets; their use is not visible here),
    - the question variants,
    - retrieval results, plus the chunk ids already returned so repeats can be skipped,
    - SQL generation and execution outcomes.

    Field order defines the positional ``__init__`` signature, so keep it stable.
    """

    # Counters. NOTE(review): presumably decremented by graph routing elsewhere — confirm.
    remaining_datafetch: int = 1
    remaining_querygen: int = 3

    # Question variants; rewrite_question fills all three.
    question: str = ""
    rewritten_question: str = ""
    updated_question: str = ""

    # Retrieval outputs.
    relevant_queries: List[str] = field(default_factory=list)
    relevant_tables: List[str] = field(default_factory=list)
    relevant_columns: List[str] = field(default_factory=list)
    # Vector-search chunk ids already returned; get_table_and_columns skips these.
    already_seen_chunk_column: Set[str] = field(default_factory=set)
    already_seen_chunk_table: Set[str] = field(default_factory=set)

    # Query generation / execution outcomes.
    error_message: str = ""
    explanation: str = ""
    is_sufficient_data: bool = False
    sql_query: str = ""
    query_executed_successfully: bool = False
    result: List[str] = field(default_factory=list)

@dataclass
class OutputState:
    """Defines the output state for the agent, representing the expected response structure.

    This class is used to define the output format and structure of the agent's response.

    Attributes:
        sql_query: Generated SQL text (empty string until produced).
        explanation: Explanation accompanying the query.
        used_tables: Tables the final query relies on.
        used_columns: Columns the final query relies on.
    """
    # A plain dataclass, consistent with InputState and AgentState.
    # This class previously also inherited from TypedDict. TypedDict does not
    # support default values, and mixing it with @dataclass produced a dict
    # subclass whose values were stored only as attributes, never as dict items.
    # NOTE(review): AgentState has no `used_tables` / `used_columns` fields —
    # confirm how the graph populates these.
    sql_query: str = ""
    explanation: str = ""
    used_tables: List[str] = field(default_factory=list)
    used_columns: List[str] = field(default_factory=list)

In [2]:
##### Question rewriting and generation format class ####
from pydantic import BaseModel, Field
from typing import List

class TableColumnPlan(BaseModel):
    # Per-table planning entry: which table, what kind of data is needed from it,
    # and why. No class docstring on purpose: for a Pydantic model it becomes
    # schema text, so the Field descriptions below are kept verbatim.
    # NOTE(review): not referenced by QueryRewritePlan or anywhere else in the
    # visible notebook.

    table_name: str = Field(
        default=...,
        description="The name of the table that will be used in the SQL query.",
    )

    # Natural-language hints (not exact column names) for column retrieval.
    expected_columns: List[str] = Field(
        default=...,
        description=(
            "A list of descriptions of the types of data expected from this table. "
            "Do not use exact column names. Instead, describe what kind of information is needed. "
            "Example: 'user ID of the employee', 'Name of the employee', 'name of the skill', 'company identifier', 'Price of the item',etc. This type of short description"
        ),
    )

    purpose: str = Field(
        default=...,
        description="A short explanation of why this table is needed in the context of answering the query.",
    )

class QueryRewritePlan(BaseModel):
    # Structured output of rewrite_question. No class docstring on purpose
    # (it would become schema text); Field descriptions are kept verbatim.
    # rewrite_question copies `rewritten_question` and `updated_question` into
    # AgentState. get_table_and_columns uses `rewritten_question` as the
    # vector-search text.

    original_question: str = Field(
        default=...,
        description="The user's original natural language question.",
    )

    rewritten_question: str = Field(
        default=...,
        description=(
            "A rewritten version of the original question that clearly explains the intent, "
            "mentions relevant entities, and makes relationships between tables obvious. "
            "This should not be a SQL query, but a detailed version of the natural language question."
        ),
    )

    updated_question: str = Field(
        default=...,
        description=(
            "An enhanced version of the original question that explicitly mentions what the user "
            "expects to see in the result — including useful fields like names, IDs, totals, etc. "
            "If you think it required filtering or sorting then also mention that data on which you want to filter or sort. "
            "Keep it in plain natural language (no SQL terms), and do not mention specific table or column names."
        ),
    )


In [3]:
from typing import Literal
from pydantic import BaseModel, Field

class SQLQueryResponse(BaseModel):
    """
    Structured response for SQL query generation.
    """
    # NOTE(review): if used with `llm.with_structured_output` (as QueryRewritePlan
    # is), the docstring above and the Field descriptions below become part of
    # the schema shown to the model. Treat their text as prompt content.

    # Fenced ```sql block with the query, or "N/A" when the schema is insufficient.
    sql_query: str = Field(
        default=...,
        description="""The generated SQL query string. the SQL query string wrapped in backticks like:
        ```sql
        SELECT * FROM ... WHERE department = 'Science';
        ```
        If insufficient data, write N/A instead of query."""
    )

    # Reasoning behind the query, or what is missing.
    explanation: str = Field(
        default=...,
        description="Step-by-step explanation of how the query was created, or if insufficient data, explain what is missing.",
    )

    # Model's own verdict on whether the supplied tables/columns were enough.
    is_sufficient_data: bool = Field(
        default=...,
        description="True if all required tables and columns were provided to generate a valid query, else False.",
    )

In [4]:
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, AIMessage

def rewrite_question(state : InputState)-> AgentState:
    """Turn the latest user message into an explicit, SQL-planning-friendly question.

    The last entry of ``state.messages`` is the question and ``state.tables`` is
    the list of candidate tables. Sends both through a prompt to
    ``llm.with_structured_output(QueryRewritePlan)``; ``llm`` is defined elsewhere
    in the notebook. Seeds a fresh AgentState from the result.

    Args:
        state: Graph input carrying the conversation messages and table names.

    Returns:
        A new AgentState with the following set:
        - ``question``, ``rewritten_question`` and ``updated_question``,
        - the counters at 1 / 3,
        - every collection and outcome field empty or False.
    """
    # An empty conversation yields an empty question rather than an IndexError.
    if state.messages:
        question = state.messages[-1].content
    else:
        question = ""
    table_names = state.tables or []

    # The template text is sent to the model as-is; keep it verbatim.
    planning_prompt = PromptTemplate.from_template("""
    You are a SQL planning assistant.

    Given a user's natural language question:
    1. Rewrite it in a clear, explicit and detailed form that makes all the necessary table and column relationships obvious.
    2. Make sure the rewritten version **asks to display all useful columns**
    3. If the original question is vague (e.g., “top students”), clarify the logic (e.g., “students with the highest average grade”).
    4. Use terminology similar to column or table names when possible (e.g., "student name" instead of "who" or "person").
    5. Identify which tables are required.
    6. For each table, list the descriptions of columns you expect to use and why.
    7. A form that can help a retriever or another AI system better identify relevant tables or columns.
                
    Return your result in the required structured format.
                                        
    ## This are the tables : {table}


    Original Question: {question}
    """)

    rewrite_chain = planning_prompt | llm.with_structured_output(QueryRewritePlan)
    plan = rewrite_chain.invoke({"question": question, "table": table_names})

    return AgentState(
        question=question,
        rewritten_question=plan.rewritten_question,
        updated_question=plan.updated_question,
        # Every remaining field starts from a clean slate for a new question.
        remaining_datafetch=1,
        remaining_querygen=3,
        relevant_queries=[],
        relevant_tables=[],
        relevant_columns=[],
        already_seen_chunk_column=set(),
        already_seen_chunk_table=set(),
        error_message="",
        explanation="",
        is_sufficient_data=False,
        sql_query="",
        query_executed_successfully=False,
        result=[],
    )


In [5]:
# Table name -> natural-language description for the university schema.
# Passed verbatim as `table_descriptions` to the table-selection prompt.
tables_descrptions = dict(
    Departments="The Departments table holds academic department information. Each department may be associated with multiple students, professors, and courses. Use it when filtering by department or analyzing department-specific activity.",
    Students="The Students table contains personal and academic enrollment details for students. Each student is associated with a department and can be enrolled in multiple courses. Use this table to filter students by department or analyze student enrollment.",
    Professors="The Professors table stores data about academic faculty. Each professor belongs to a department and may teach one or more courses. Use this when identifying course instructors or analyzing departmental faculty.",
    Courses="The Courses table contains the list of academic courses offered by departments. Each course is taught by a professor and belongs to a department. Use it for curriculum planning, instructor assignments, or course scheduling.",
    Enrollments="The Enrollments table tracks which students are enrolled in which courses, including the date of enrollment. Use it to analyze student course load, popularity of courses, or to join student and course data.",
    Classrooms="The Classrooms table defines physical classrooms, including building and room number, and capacity. It supports the scheduling system for courses. Use it to analyze classroom utilization or room assignments.",
    Schedules="The Schedules table links courses to classrooms and timeslots, including day of the week and start/end times. Use it to retrieve timetable information or detect scheduling conflicts.",
    Grades="The Grades table records student performance in courses by linking enrollment to a grade and grade point. Use it for GPA calculation, performance tracking, and academic reports.",
    Assignments="The Assignments table contains homework or project records tied to specific courses. Each assignment has a title, due date, and total marks. Use this table for academic workload or submission tracking.",
    Submissions="The Submissions table tracks which students submitted which assignments and the marks they received. Use this to evaluate assignment performance and submission timing.",
    Clubs="The Clubs table holds information about student clubs, including their faculty advisor and founding year. Use this when analyzing extracurricular involvement or managing student organizations.",
    ClubMembers="The ClubMembers table links students to clubs and records their role (e.g., President, Member). Use this table to retrieve club rosters, roles, or member count.",
)

# Table name -> list of column metadata dicts, each with the keys
# "column_name", "data_type" and "description", in the order listed below.
# Written as compact (name, type, description) triples and expanded into dicts.
columns_descrptions = {
    table_name: [
        {"column_name": col_name, "data_type": col_type, "description": col_desc}
        for col_name, col_type, col_desc in column_rows
    ]
    for table_name, column_rows in {
        "Departments": [
            ("department_id", "INTEGER", "Unique identifier for each academic department."),
            ("department_name", "TEXT", "The name of the department."),
            ("head_of_department_id", "INTEGER", "The ID of the professor who is the head of the department."),
            ("founding_year", "INTEGER", "The year the department was founded."),
        ],
        "Students": [
            ("student_id", "INTEGER", "Unique identifier for each student."),
            ("first_name", "TEXT", "The student's first name."),
            ("last_name", "TEXT", "The student's last name."),
            ("email", "TEXT", "The student's email address."),
            ("department_id", "INTEGER", "The ID of the department the student is enrolled in."),
            ("enrollment_date", "DATE", "The date the student was enrolled."),
        ],
        "Professors": [
            ("professor_id", "INTEGER", "Unique identifier for each professor."),
            ("first_name", "TEXT", "The professor's first name."),
            ("last_name", "TEXT", "The professor's last name."),
            ("email", "TEXT", "The professor's email address."),
            ("department_id", "INTEGER", "The ID of the department the professor belongs to."),
            ("hire_date", "DATE", "The date the professor was hired."),
            ("salary", "REAL", "The professor's annual salary."),
        ],
        "Courses": [
            ("course_id", "INTEGER", "Unique identifier for each course."),
            ("course_name", "TEXT", "The full name of the course."),
            ("course_code", "TEXT", "A unique code for the course."),
            ("department_id", "INTEGER", "The ID of the department that offers the course."),
            ("professor_id", "INTEGER", "The ID of the professor teaching the course."),
            ("credits", "INTEGER", "The number of credits for the course."),
        ],
        "Enrollments": [
            ("enrollment_id", "INTEGER", "Unique identifier for each enrollment."),
            ("student_id", "INTEGER", "The ID of the student who is enrolled."),
            ("course_id", "INTEGER", "The ID of the course the student is enrolled in."),
            ("enrollment_date", "DATE", "The date of enrollment."),
        ],
        "Classrooms": [
            ("classroom_id", "INTEGER", "Unique identifier for each classroom."),
            ("building", "TEXT", "The building name of the classroom."),
            ("room_number", "TEXT", "The room number of the classroom."),
            ("capacity", "INTEGER", "The maximum capacity of the classroom."),
        ],
        "Schedules": [
            ("schedule_id", "INTEGER", "Unique identifier for each schedule entry."),
            ("course_id", "INTEGER", "The ID of the course being scheduled."),
            ("classroom_id", "INTEGER", "The ID of the classroom assigned for the course."),
            ("day_of_week", "TEXT", "The day of the week for the class."),
            ("start_time", "TIME", "The start time of the class."),
            ("end_time", "TIME", "The end time of the class."),
        ],
        "Grades": [
            ("grade_id", "INTEGER", "Unique identifier for each grade record."),
            ("enrollment_id", "INTEGER", "The ID of the enrollment record this grade belongs to."),
            ("grade_letter", "TEXT", "The final grade letter (e.g., 'A', 'B+')."),
            ("grade_point", "REAL", "The GPA value for the grade."),
            ("grade_date", "DATE", "The date the grade was issued."),
        ],
        "Assignments": [
            ("assignment_id", "INTEGER", "Unique identifier for each assignment."),
            ("course_id", "INTEGER", "The ID of the course the assignment belongs to."),
            ("assignment_title", "TEXT", "The title of the assignment."),
            ("due_date", "DATE", "The due date for the assignment."),
            ("total_marks", "INTEGER", "The maximum marks for the assignment."),
        ],
        "Submissions": [
            ("submission_id", "INTEGER", "Unique identifier for each submission."),
            ("assignment_id", "INTEGER", "The ID of the assignment being submitted."),
            ("student_id", "INTEGER", "The ID of the student who submitted the assignment."),
            ("submission_date", "DATE", "The date of the submission."),
            ("marks_obtained", "INTEGER", "The marks the student received on the assignment."),
        ],
        "Clubs": [
            ("club_id", "INTEGER", "Unique identifier for each club."),
            ("club_name", "TEXT", "The name of the student club."),
            ("faculty_advisor_id", "INTEGER", "The ID of the professor who advises the club."),
            ("founding_year", "INTEGER", "The year the club was founded."),
        ],
        "ClubMembers": [
            ("member_id", "INTEGER", "Unique identifier for each club member record."),
            ("student_id", "INTEGER", "The ID of the student who is a member of the club."),
            ("club_id", "INTEGER", "The ID of the club the student is a member of."),
            ("role", "TEXT", "The role of the student in the club (e.g., 'President')."),
            ("join_date", "DATE", "The date the student joined the club."),
        ],
    }.items()
}


In [None]:

# NOTE(review): a byte-for-byte copy of `QueryRewritePlan` used to be re-declared
# in this cell. The model is already defined in the "Question rewriting and
# generation format class" cell. rewrite_question looks the name up at call
# time, so a second copy only silently shadows the first and invites the two
# to drift apart. Edit the model in that earlier cell instead.


In [6]:
# NOTE(review): not referenced anywhere in the visible notebook; kept so any
# other cell that expects this name still finds it.
reformat_columns_descriptions = []


def get_required_columns(
    table_name: str,
    schema: Optional[Dict[str, List[Dict[str, str]]]] = None,
) -> Optional[Dict[str, object]]:
    """Return the column metadata bundle for a single table.

    Args:
        table_name: Exact (case-sensitive) table name, e.g. ``"Students"``.
        schema: Mapping of table name -> list of column-description dicts.
            Defaults to the notebook-level ``columns_descrptions``.

    Returns:
        ``{"table_name": table_name, "columns": [...]}``, where ``columns`` is the
        list stored in ``schema`` itself (not a copy). Returns ``None`` when the
        table is not in ``schema``.
    """
    if schema is None:
        schema = columns_descrptions
    # Direct key lookup instead of scanning every table.
    if table_name not in schema:
        return None
    return {"table_name": table_name, "columns": schema[table_name]}

In [None]:
# Prompts for two-stage, LLM-based schema selection: first pick tables from their
# descriptions, then pick columns from the chosen tables.
# The system prompts must not contain `{`/`}` (they are treated as templates).
# The user prompts take `{table_descriptions}`/`{question}` and `{question}`/`{relevant_tables}`.
# The input field names in the system prompts match the data this notebook builds:
# - `tables_descrptions` is a table-name -> description mapping,
# - `get_required_columns` yields `{"table_name", "columns"}` bundles, whose
#   columns use `column_name`, `data_type` and `description`.
table_system_prompt = """You are an expert database assistant. Your task is to identify and select the most relevant tables from a given database schema to answer the user's question.
You will be provided with the database schema as a mapping from each table name to its table description. Analyze the user's natural language query and return a JSON list of the table names that are essential for answering it.
Additionally, for each selected table, provide a reason explaining why it is relevant to the user's question. Focus solely on matching the user's intent with the table descriptions and justify your selection."""
table_user_prompt = """**Database Schema:**
{table_descriptions}
**Query:**
{question}
"""

column_system_prompt = """You are an intelligent SQL query builder. You have already identified the relevant tables to answer a user's query. Now, your task is to select the specific columns required from those tables.
You will be provided with the user's original query and a list of the relevant tables, each with a `table_name` and its `columns`, where every column has a `column_name`, `data_type` and `description`. Analyze the user's query and the column descriptions, then return a JSON list of the fully qualified column names (in `tableName.columnName` format) that are needed to fulfill the user's request.
Additionally, for each selected column, provide a reason explaining why it is required to answer the user's question."""

columns_user_prompt = """**Original Query:**
{question}
**Relevant Tables:**
{relevant_tables}
"""

from langchain.prompts import ChatPromptTemplate

# Render the table-selection prompt for a sample question to eyeball its content.
# `prompt` stays a module-level name because later code in this cell calls it.
prompt = ChatPromptTemplate.from_messages([
    ("system", table_system_prompt),
    ("user", table_user_prompt),
])

# Previously the rendered value was discarded and `print(p)` raised NameError
# (`p` was never assigned); keep the result and print that instead.
table_prompt_preview = prompt.invoke({
    "table_descriptions": tables_descrptions,
    "question": "What are the names of all students in the Science department?",
})

print(table_prompt_preview)




def get_table_and_columns(state : AgentState)->AgentState:
    """Placeholder for an LLM-based table-selection node; currently a no-op.

    NOTE(review): the retrieval-based ``get_table_and_columns`` defined in a
    later cell shadows this definition. This one only takes effect if this cell
    runs after that one. It used to build an unused ``ChatPromptTemplate`` and
    implicitly return ``None`` despite the ``-> AgentState`` annotation.

    Args:
        state: Current agent state.

    Returns:
        The same ``state`` object, unmodified.
    """
    # TODO: build the table-selection prompt (table_system_prompt /
    # table_user_prompt), invoke a model with structured output (e.g.
    # TableselectionPlan), and store the chosen tables on `state`.
    return state


# NOTE(review): a second, identical `prompt.invoke(...)` for the sample
# question sat here. It was not the cell's last expression, so its result was
# discarded; it was removed as dead code.



class TableselectionPlan(BaseModel):
    # Structured answer for the table-selection step: the chosen table names
    # plus a single justification. No class docstring on purpose (it would
    # become schema text); Field descriptions are kept verbatim.

    relevant_tables: List[str] = Field(
        default=...,
        description=(
            "A list of table names that are relevant to the user's question. "
            "These tables should be selected based on their descriptions and relevance to the question. "
            "The table names should be in a simple list format, without any additional information."
        ),
    )

    reason: str = Field(
        default=...,
        description=(
            "A brief explanation of why the selected tables are relevant to the user's question. "
            "This should justify the inclusion of each table based on its description and the context of the question."
        ),
    )

class ColumnsSelectionPlan(BaseModel):
    # Structured answer for the column-selection step: fully qualified
    # `tableName.columnName` entries plus a single justification. No class
    # docstring on purpose (it would become schema text); Field descriptions
    # are kept verbatim.

    relevant_columns: List[str] = Field(
        default=...,
        description=(
            "A list of fully qualified column names (in `tableName.columnName` format) that are required to answer the user's question. "
            "These columns should be selected based on their relevance to the question and the previously selected tables."
        ),
    )

    reason: str = Field(
        default=...,
        description=(
            "A brief explanation of why each selected column is necessary to answer the user's question. "
            "This should justify the inclusion of each column based on its description and the context of the question."
        ),
    )



# NOTE(review): this line held the pasted, truncated printed output of the
# rendered table-selection prompt: `messages=[SystemMessage(...), HumanMessage(...)]`.
# It was not valid Python:
# - the string literal was unterminated, a SyntaxError that stopped this entire
#   cell from running;
# - `SystemMessage` is never imported.
# The captured system message also predated the current `table_system_prompt`,
# since it lacked the "Additionally, for each selected table..." sentence.
# Re-run the prompt preview above to see the current rendering.

In [None]:
def get_table_and_columns(state : AgentState)->AgentState:
    """Retrieve table and column description chunks for the rewritten question.

    Uses ``state.rewritten_question`` as the search text against the
    module-level ``index`` (defined elsewhere in the notebook). It searches the
    ``table-details-university`` namespace (top 5) and then the
    ``column-details-university`` namespace (top 10). Hits whose ``_id`` is
    already in ``state.already_seen_chunk_table`` /
    ``state.already_seen_chunk_column`` are skipped, and new ids are added there.

    Args:
        state: Current agent state; mutated in place.

    Returns:
        The same ``state`` object, with:
        - ``relevant_tables`` set to the ``fields`` of the newly seen table hits,
        - ``relevant_columns`` set to the ``chunk_text`` field (``""`` if absent)
          of the newly seen column hits.
    """
    query_text = state.rewritten_question

    def _search(namespace, top_k):
        # Both lookups share the request shape; only namespace and top_k differ.
        return index.search(
            namespace=namespace,
            query={"inputs": {"text": query_text}, "top_k": top_k},
        )

    table_fetch = _search("table-details-university", 5)
    column_fetch = _search("column-details-university", 10)

    new_tables = []
    for hit in table_fetch.result.hits:
        chunk_id = hit._id
        if chunk_id in state.already_seen_chunk_table:
            continue
        # NOTE(review): the whole `fields` mapping is kept here, while columns keep
        # only `chunk_text`, and AgentState annotates relevant_tables as List[str].
        # Confirm which shape downstream consumers expect.
        new_tables.append(hit.fields)
        state.already_seen_chunk_table.add(chunk_id)

    new_columns = []
    for hit in column_fetch.result.hits:
        chunk_id = hit._id
        if chunk_id in state.already_seen_chunk_column:
            continue
        new_columns.append(hit.fields.get("chunk_text", ""))
        state.already_seen_chunk_column.add(chunk_id)

    # NOTE(review): these assignments replace rather than extend the previous
    # lists, so on a repeated fetch chunks seen earlier are dropped from
    # relevant_tables / relevant_columns. Confirm that this is intended.
    state.relevant_tables = new_tables
    state.relevant_columns = new_columns
    return state
