In [1]:
import nest_asyncio
nest_asyncio.apply()
import uvicorn
import websockets

In [2]:

import asyncio
import json
from typing import AsyncGenerator, List, Dict, Annotated, TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_openai.chat_models.base import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse

# LLM 초기화
llm = ChatOpenAI(model="gpt-4o-mini", temperature=1)

class RealEstateState(TypedDict):
    real_estate_type: Annotated[str, "부동산 유형"]
    keywordlist: Annotated[List[Dict], "키워드 리스트"]
    messages: Annotated[list, add_messages]
    query_sql: Annotated[str, "SQL 쿼리"]
    results: Annotated[List[Dict], "쿼리 결과"]
    answers: Annotated[List[str], "최종 답변"]
    query_answer: Annotated[str, "답변"]

# 비동기 스트리밍
async def stream_llm_response(messages) -> AsyncGenerator[str, None]:
    async for chunk in llm.stream(messages):
        yield chunk["content"]

# SQLite 데이터베이스 초기화
def get_db_engine(db_path: str):
    return create_engine(
        f"sqlite:///{db_path}",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False}
    )

db_path = "./data/real_estate_(1).db"
db_engine = get_db_engine(db_path)
db = SQLDatabase(engine=db_engine, sample_rows_in_table_info=False)

# 그래프 정의
workflow = StateGraph(RealEstateState)

async def filter_node(state: RealEstateState) -> RealEstateState:
    system_prompt = "Classify if a question is related to real estate."
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=state["messages"][-1].content)
    ]
    response_content = ""
    async for chunk in stream_llm_response(messages):
        response_content += chunk
    real_estate_type = response_content.strip()
    return RealEstateState(real_estate_type=real_estate_type)

async def extract_keywords(state: RealEstateState) -> RealEstateState:
    system_prompt = "Extract relevant keywords based on database schema."
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=state["messages"][-1].content)
    ]
    response_content = ""
    async for chunk in stream_llm_response(messages):
        response_content += chunk
    keywordlist = json.loads(response_content)
    return RealEstateState(keywordlist=keywordlist)

async def generate_sql(state: RealEstateState) -> RealEstateState:
    table_info = db.get_table_info(table_names=["rentals", "property_info", "property_locations"])
    system_prompt = f"Available tables: {table_info}"
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=state["messages"][-1].content)
    ]
    response_content = ""
    async for chunk in stream_llm_response(messages):
        response_content += chunk
    query_sql = response_content.strip()
    return RealEstateState(query_sql=query_sql)

async def run_query(state: RealEstateState) -> RealEstateState:
    tool = QuerySQLDataBaseTool(db=db)
    results = tool._run(state["query_sql"])
    return RealEstateState(results=results)

async def generate_answers(state: RealEstateState) -> RealEstateState:
    results = state["results"]
    if not results:
        return RealEstateState(answers="검색 결과가 없습니다.")

    system_prompt = "Convert query results into human-readable answers."
    messages = [SystemMessage(content=system_prompt), HumanMessage(content=json.dumps(results))]
    response_content = ""
    async for chunk in stream_llm_response(messages):
        response_content += chunk
    answers = json.loads(response_content)
    return RealEstateState(answers=answers)

# 노드 추가
workflow.add_node("Filter Question", filter_node)
workflow.add_node("Extract Keywords", extract_keywords)
workflow.add_node("Generate SQL", generate_sql)
workflow.add_node("Run Query", run_query)
workflow.add_node("Generate Answers", generate_answers)

# 엣지 정의
workflow.add_edge(START, "Filter Question")
workflow.add_edge("Filter Question", "Extract Keywords")
workflow.add_edge("Extract Keywords", "Generate SQL")
workflow.add_edge("Generate SQL", "Run Query")
workflow.add_edge("Run Query", "Generate Answers")
workflow.add_edge("Generate Answers", END)

# 그래프 컴파일
app = workflow.compile()

# FastAPI 앱 초기화
fastapi_app = FastAPI()

# HTML 페이지 제공
@fastapi_app.get("/")
async def get():
    return HTMLResponse("""
    <!DOCTYPE html>
    <html>
    <head>
        <title>WebSocket Test</title>
    </head>
    <body>
        <h1>WebSocket Test</h1>
        <textarea id="messageInput" placeholder="Enter your message"></textarea>
        <button onclick="sendMessage()">Send</button>
        <pre id="messages"></pre>
        <script>
            const ws = new WebSocket("ws://127.0.0.1:8000/ws");
            ws.onmessage = event => {
                const messages = document.getElementById("messages");
                messages.textContent += event.data + "\\n";
            };
            function sendMessage() {
                const input = document.getElementById("messageInput");
                ws.send(input.value);
                input.value = "";
            }
        </script>
    </body>
    </html>
    """)

# WebSocket 엔드포인트
@fastapi_app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Received: {data}")
    except Exception as e:
        await websocket.send_text(f"Error: {str(e)}")
    finally:
        await websocket.close()

# WebSocket 클라이언트 테스트
async def test_websocket():
    uri = "ws://127.0.0.1:8000/ws"
    async with websockets.connect(uri) as websocket:
        await websocket.send("서울 아파트 매매 가격 알려줘")
        response = await websocket.recv()
        print(f"Server response: {response}")
        await websocket.send("exit")  # 종료 메시지 전송

# 서버 실행 및 테스트 클라이언트 실행
async def main():
    # FastAPI 서버 실행
    config = uvicorn.Config(fastapi_app, host="127.0.0.1", port=8000, log_level="info")
    server = uvicorn.Server(config)
    server_task = asyncio.create_task(server.serve())

    # 서버 준비 대기 (약간의 지연을 추가)
    await asyncio.sleep(2)

    # WebSocket 클라이언트 테스트 실행
    await test_websocket()

# 전체 실행
if __name__ == "__main__":
    asyncio.run(main())


INFO:     Started server process [5232]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO:     ('127.0.0.1', 55779) - "WebSocket /ws" [accepted]
INFO:     connection open


Server response: Received: 서울 아파트 매매 가격 알려줘


INFO:     127.0.0.1:55783 - "GET / HTTP/1.1" 200 OK


INFO:     ('127.0.0.1', 55785) - "WebSocket /ws" [accepted]
INFO:     connection open


In [2]:
import asyncio
import websockets

async def test_websocket():
    uri = "ws://127.0.0.1:8000/ws"
    async with websockets.connect(uri) as websocket:
        await websocket.send("서울 아파트 매매 가격 알려줘")
        while True:
            response = await websocket.recv()
            print(response)

asyncio.run(test_websocket())




ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "c:\Users\USER\anaconda3\envs\nlp\lib\site-packages\uvicorn\protocols\websockets\websockets_impl.py", line 330, in asgi_send
    await self.send(data)  # type: ignore[arg-type]
  File "c:\Users\USER\anaconda3\envs\nlp\lib\site-packages\websockets\legacy\protocol.py", line 628, in send
    await self.ensure_open()
  File "c:\Users\USER\anaconda3\envs\nlp\lib\site-packages\websockets\legacy\protocol.py", line 938, in ensure_open
    raise self.connection_closed_exc()
websockets.exceptions.ConnectionClosedOK: received 1000 (OK); then sent 1000 (OK)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "c:\Users\USER\anaconda3\envs\nlp\lib\site-packages\starlette\websockets.py", line 85, in send
    await self._send(message)
  File "c:\Users\USER\anaconda3\envs\nlp\lib\site-packages\starlette\_exception_handler.py", line 39, in sender
    await send(m

In [None]:
from fastapi.responses import HTMLResponse

@fastapi_app.get("/")
async def get():
    return HTMLResponse("""
    <!DOCTYPE html>
    <html>
    <head>
        <title>WebSocket Test</title>
    </head>
    <body>
        <h1>WebSocket Test</h1>
        <textarea id="messageInput" placeholder="Enter your message"></textarea>
        <button onclick="sendMessage()">Send</button>
        <pre id="messages"></pre>
        <script>
            const ws = new WebSocket("ws://127.0.0.1:8000/ws");
            ws.onmessage = event => {
                const messages = document.getElementById("messages");
                messages.textContent += event.data + "\\n";
            };
            function sendMessage() {
                const input = document.getElementById("messageInput");
                ws.send(input.value);
                input.value = "";
            }
        </script>
    </body>
    </html>
    """)
