In [0]:


# Spark session: getOrCreate() hands back the already-active session when there is
# one, and only builds a new session otherwise.
# NOTE(review): `spark` is not referenced by any other cell visible in this notebook.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("YourAppName")
    .getOrCreate()
)

In [0]:
%pip install -r requirements.txt

In [0]:
# Restart Python so the packages just installed by %pip can be imported.
# Restarting discards previously defined Python variables, so this must stay ahead
# of the cells that build the LLM client, the SQL tools and the agents.
dbutils.library.restartPython()

In [0]:
from databricks_langchain import ChatDatabricks
from databricks.sdk import WorkspaceClient
import os

# Model Serving endpoint used by both agents in this notebook.
LLM_ENDPOINT = "databricks-llama-4-maverick"
# Lifetime of the personal access token minted below (1200 s = 20 min).
# NOTE(review): once it expires, anything relying on DATABRICKS_TOKEN presumably stops
# authenticating; re-run this cell to mint a fresh token.
TOKEN_LIFETIME_SECONDS = 1200

w = WorkspaceClient()

# Export host + token as environment variables; child processes launched from this
# notebook (such as the `streamlit run` cell) inherit them.
# Each run of this cell creates a new token; nothing here revokes the previous one.
os.environ["DATABRICKS_HOST"] = w.config.host
os.environ["DATABRICKS_TOKEN"] = w.tokens.create(
    comment="for model serving",
    lifetime_seconds=TOKEN_LIFETIME_SECONDS,
).token_value

llm = ChatDatabricks(endpoint=LLM_ENDPOINT)

In [0]:
from langchain.agents import initialize_agent, Tool
from langchain.tools import tool
from langchain.agents.agent_types import AgentType
# from langchain_anthropic import ChatAnthropic
import sqlite3
import os

def _run_readonly_query(db_path: str, sql: str, error_prefix: str) -> str:
    """Execute ``sql`` against the SQLite file ``db_path`` and format the result as text.

    Each result row is rendered as ``str(dict(column -> value))`` on its own line, and
    ``"No results."`` is returned when the query yields no rows. Any error is returned
    as ``error_prefix + message`` instead of being raised, so the calling agent can
    read it and retry with corrected SQL.
    """
    conn = None
    try:
        # Open read-only: the SQL is written by the LLM and must not be able to modify
        # or delete data. mode=ro also fails on a missing file instead of silently
        # creating an empty database.
        conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
        cursor = conn.execute(sql)
        if cursor.description is None:
            # The statement produced no result columns, so there is nothing to format.
            return "No results."
        columns = [column[0] for column in cursor.description]
        rows = cursor.fetchall()
        return "\n".join(str(dict(zip(columns, row))) for row in rows) or "No results."
    except Exception as e:
        return f"{error_prefix}{e}"
    finally:
        # Always release the connection, including when the query itself fails.
        if conn is not None:
            conn.close()


@tool
def query_flight_schedule(sql: str) -> str:
    """Run a read-only SQL query on the flight_status.db file."""
    return _run_readonly_query("flight_status.db", sql, "SQL Error: ")


@tool
def query_geotracking(sql: str) -> str:
    """Run a read-only SQL query on the geo_tracking.db file."""
    return _run_readonly_query("geo_tracking.db", sql, "SQL Error: ")


@tool
def query_weather(sql: str) -> str:
    """Run a read-only SQL query on the weather.db file."""
    return _run_readonly_query("weather.db", sql, "SQL Error in Weather Tool: ")

In [0]:
# Single-string-input tools for the ReAct agent below. The query_* helpers are already
# @tool objects; passing their underlying `.func` calls the plain function directly
# instead of going through the deprecated `BaseTool.__call__`.
tools = [
    Tool(
        name="ScheduleTrackerTool",
        func=query_flight_schedule.func,
        description="Use this tool to query scheduled flights and detect conflicts, delays, or tight arrival overlaps. Accepts SQL input."
    ),
    Tool(
        name="GeoTrackerTool",
        func=query_geotracking.func,
        description="Use this tool to query geospatial data about flight phases and deviations from expected routes. Accepts SQL input."
    ),
    Tool(
        name="WeatherTrackerTool",
        func=query_weather.func,
        description="Use this tool to query weather_by_flight table to get wind, visibility, storm/fog info, and help determine flight risk. Accepts SQL input."
    )
]

from langchain.memory import ConversationBufferMemory
from langchain.schema import SystemMessage, HumanMessage

system_message = SystemMessage(
    content=(
        "You are an intelligent ATC assistant.\n"
        "When a flight sends a landing request like 'This is flight UN002 requesting to land', "
        "you must:\n"
        "1. Identify the flight ID.\n"
        "2. Use tools to check:\n"
        "   - If the trajectory is clear and aligned.\n"
        "   - Whether weather risk is safe (risk score < 7).\n"
        "   - If any deviation or path correction is causing time conflict with other scheduled flights.\n"
        "3. Respond clearly with:\n"
        "   - Clearance to land or not\n"
        "   - Any issues (e.g., rerouting, weather, gate conflict)\n"
        "   - Suggestions if needed.\n"
    )
)

# ZERO_SHOT_REACT_DESCRIPTION builds one text prompt from prefix, tool list, format
# instructions and suffix. It has no `system_message` option (that kwarg was silently
# ignored), so the ATC instructions are passed as the prompt prefix instead.
ATC_PROMPT_PREFIX = system_message.content + "\nYou have access to the following tools:"

# The default ReAct suffix has no slot for memory, so {chat_history} is added
# explicitly; otherwise the buffer below is recorded but never shown to the model.
ATC_PROMPT_SUFFIX = (
    "Conversation so far:\n{chat_history}\n\n"
    "Begin!\n\n"
    "Question: {input}\n"
    "Thought:{agent_scratchpad}"
)

# return_messages=False renders history as plain "Human: ... / AI: ..." text, which
# fits the string-based ReAct prompt (a message list would be inserted as its repr).
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=False)

atc_agent = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    memory=memory,
    handle_parsing_errors=True,
    agent_kwargs={
        "prefix": ATC_PROMPT_PREFIX,
        "suffix": ATC_PROMPT_SUFFIX,
        "input_variables": ["input", "chat_history", "agent_scratchpad"],
    },
)


# Interactive console loop: blocks on input() until the user types exit/quit.
print("👨‍✈️ ATC Agent Ready. Type 'exit' to quit.\n")

while True:
    user_input = input("🧑‍💻 You: ").strip()
    if user_input.lower() in {"exit", "quit"}:
        break
    if not user_input:
        # Nothing typed: prompt again instead of spending an LLM call on an empty question.
        continue
    # invoke() replaces the deprecated Chain.run(); the agent's answer is under "output".
    response = atc_agent.invoke({"input": user_input})["output"]
    print("🤖 ATC Agent:", response)



# flight_id = "UN002"  # replace with one from your DB

# prompt = f"""
# Give me a summary of ATC report for flight {flight_id}.

# - Check for any delays or schedule conflicts.
# - Check if it's deviated or off-course in geospatial data.
# - Fetch current weather conditions at the flight's current position.
# - Check the weather based on the risk score and summarize all issues found.
# """

# response = atc_agent.run(prompt)
# print("📡 ATC Summary for", flight_id)
# print(response)

In [0]:
# `databricks-genai` is deliberately not installed at this point. A mid-notebook
# %pip install followed by %restart_python discards `llm`, `Tool` and the query_*
# SQL tools, which the LangGraph cell below still uses. The package is only referenced
# by commented-out code; if you need it, add it to requirements.txt so the first
# %pip cell installs it before any state is created.

In [0]:
%%writefile streamlit_app.py
import streamlit as st
from atc_agent import atc_agent  # Make sure this is importable

st.title("🛫 ATC Agent Chat")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.chat_input("Ask the ATC Agent...")

if user_input:
    st.session_state.chat_history.append({"role": "user", "text": user_input})
    with st.spinner("Thinking..."):
        response = atc_agent.run(user_input)
    st.session_state.chat_history.append({"role": "agent", "text": response})

for msg in st.session_state.chat_history:
    if msg["role"] == "user":
        st.chat_message("user").markdown(msg["text"])
    else:
        st.chat_message("assistant").markdown(msg["text"])


In [0]:
# `streamlit run` starts a long-running web server: this cell does not finish on its
# own, so the cells below it will not run (e.g. under Run All) until it is stopped.
!streamlit run streamlit_app.py


In [0]:
from langgraph.graph import MessagesState
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage

from langchain_core.tools import tool

from langgraph.graph import StateGraph, MessagesState, END, START
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from typing_extensions import Literal
from IPython.display import Image, display


from langchain_core.tools import StructuredTool

# Tools for the LangGraph agent, which uses native tool calling via bind_tools().
# StructuredTool.from_function infers the argument schema from the helper's signature,
# so the model sees a named `sql: str` parameter instead of the generic single-string
# input of a plain `Tool`. Passing `.func` calls the undecorated helper directly rather
# than nesting one tool object inside another.
tools = [
    StructuredTool.from_function(
        func=query_flight_schedule.func,
        name="ScheduleTrackerTool",
        description="Use this tool to query scheduled flights and detect conflicts, delays, or tight arrival overlaps. Accepts SQL input.",
    ),
    StructuredTool.from_function(
        func=query_geotracking.func,
        name="GeoTrackerTool",
        description="Use this tool to query geospatial data about flight phases and deviations from expected routes. Accepts SQL input.",
    ),
    StructuredTool.from_function(
        func=query_weather.func,
        name="WeatherTrackerTool",
        description="Use this tool to query weather_by_flight table to get wind, visibility, storm/fog info, and help determine flight risk. Accepts SQL input.",
    ),
]

# Lookup used by tool_node to dispatch the model's tool calls by name.
tools_by_name = {tool.name: tool for tool in tools}

llm_with_tools = llm.bind_tools(tools)

# llm_with_tools = initialize_agent(
#     tools=tools,
#     llm=llm,
#     agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
#     verbose=True,
#     handle_parsing_errors=True
# )


# Nodes
def llm_call(state: MessagesState):
    """Ask the tool-bound model for its next step.

    The model sees a fixed ATC system prompt followed by the conversation so far, and
    replies either with a final answer or with tool calls. The reply is returned as a
    one-element ``messages`` update for the graph state.
    """
    system_prompt = SystemMessage(
        content="You are an ATC Agent. Analyze schedule, geo, and weather risk using tools."
    )
    conversation = [system_prompt, *state["messages"]]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}


def tool_node(state: dict):
    """Execute every tool call requested by the most recent AI message.

    Returns ``{"messages": [...]}`` with one ToolMessage per tool call, tagged with the
    originating ``tool_call_id`` so the model can match results to requests. An unknown
    tool name, or a tool that raises, yields an error ToolMessage instead of aborting the
    whole graph run, so the model can correct itself on its next turn.
    """
    results = []
    for tool_call in state["messages"][-1].tool_calls:
        selected_tool = tools_by_name.get(tool_call["name"])
        if selected_tool is None:
            observation = (
                f"Error: unknown tool '{tool_call['name']}'. "
                f"Available tools: {', '.join(tools_by_name)}."
            )
        else:
            try:
                observation = selected_tool.invoke(tool_call["args"])
            except Exception as exc:
                # Surface the failure to the model rather than crashing the graph.
                observation = f"Error while running {tool_call['name']}: {exc}"
        results.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
    return {"messages": results}


def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Route to the tool node while the model keeps requesting tools; otherwise finish."""
    latest = state["messages"][-1]
    return "Action" if latest.tool_calls else END


# Build workflow: llm_call and environment alternate until the model stops requesting tools.
# NOTE(review): the loop is presumably bounded only by LangGraph's recursion limit -- confirm.
agent_builder = StateGraph(MessagesState)

# Add nodes: "llm_call" asks the model for its next step; "environment" runs the requested tools.
agent_builder.add_node("llm_call", llm_call)
agent_builder.add_node("environment", tool_node)

# Add edges to connect nodes
agent_builder.add_edge(START, "llm_call")
# After each model turn, should_continue returns "Action" (tool calls pending) or END.
agent_builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {
        # Name returned by should_continue : Name of next node to visit
        "Action": "environment",
        END: END,
    },
)
# Tool results are fed back to the model for its next step.
agent_builder.add_edge("environment", "llm_call")

# Compile the agent
agent = agent_builder.compile()

# Show the agent. draw_mermaid_png() renders through an external web service by default,
# so fall back to printing the Mermaid source if rendering fails (e.g. no network access).
graph_view = agent.get_graph(xray=True)
try:
    display(Image(graph_view.draw_mermaid_png()))
except Exception as exc:
    print(f"Could not render graph image ({exc}); Mermaid source follows:")
    print(graph_view.draw_mermaid())

# Invoke
# NOTE(review): other cells use "UN002" as the example flight ID -- confirm "UA002" exists in the data.
messages = [HumanMessage(content="What is the status of flight UA002?")]
final_state = agent.invoke({"messages": messages})

for m in final_state["messages"]:
    m.pretty_print()