# Test Retrieval Graph

This notebook tests the retrieval graph by connecting to the LangGraph server.

**Prerequisites:**
1. Start the LangGraph dev server: `cd backend && npm run dev`
2. Make sure the server is running at `http://localhost:2024`

In [None]:
# Install required packages (run once)
# %pip install langgraph-sdk
# Prerequisite: the tests below use Ollama models (e.g. ollama/llama3.2:1b), so Ollama must be running locally.

In [None]:
import asyncio

from langgraph_sdk import get_client

# Connect to the local LangGraph server
# Make sure you've started it with: cd backend && npm run dev
client = get_client(url="http://localhost:2024")

print("‚úÖ Connected to LangGraph server")


In [None]:
# Create a new thread (conversation session)
thread = await client.threads.create()
print(f"‚úÖ Created thread: {thread['thread_id']}")


In [None]:
## Test 1: Simple Query (Non-streaming)

# Test with a simple query
test_query = "What is the capital of France?"

print(f"Query: {test_query}\n")

result = await client.runs.create(
    thread_id=thread['thread_id'],
    assistant_id="retrieval_graph",
    input={"query": test_query},
    
)

print("‚úÖ Graph execution completed!")
print(f"\nResult: {result}")
print(f"\nStatus: {result.get('status', 'unknown')}")


In [None]:
# Search threads on the server (no filters are passed) and print the raw result
threads = await client.threads.search()
print(threads)

In [None]:
# Get the thread state which contains messages
thread_state = await client.threads.get(thread_id=thread['thread_id'])

# Access messages from the thread's state
messages = thread_state.get('values', {}).get('messages', [])

print("Messages in thread:")
for msg in messages:
    msg_type = msg.get('type') if isinstance(msg, dict) else getattr(msg, 'type', 'unknown')
    msg_content = msg.get('content') if isinstance(msg, dict) else getattr(msg, 'content', '')
    print(f"\n- {msg_type}: {msg_content}")

In [None]:
import asyncio

# Test with a simple query
test_query = "Add 600 + 3000"

print(f"Query: {test_query}\n")

result = await client.runs.create(
    thread_id=thread['thread_id'],
    assistant_id="retrieval_graph",
    input={"query": test_query},
    # config={
    #     "configurable": {
    #         "queryModel": "ollama/llama3.2:1b"
    #     }
    # }
)

# cancel the run after 100ms
run_id = result.get('run_id') if isinstance(result, dict) else getattr(result, 'run_id', None)
thread_id = thread['thread_id']
if run_id:
    await asyncio.sleep(0.1)
    
    target_run_id = run_id or globals().get('current_run_id')
    if not target_run_id:
        print("‚ùå No run ID found.")
    else:
        # Cancel the run
        await client.runs.cancel(thread_id=thread_id, run_id=target_run_id)
        print(f"‚úÖ Canceled")


# Get the thread state which contains messages
thread_state = await client.threads.get(thread_id=thread['thread_id'])

# Access messages from the thread's state
messages = thread_state.get('values', {}).get('messages', [])

if messages:
    # Get the latest message (last in the list)
    latest = messages[-1]
    latest_content = latest.get('content') if isinstance(latest, dict) else getattr(latest, 'content', None)
    if latest_content:
        print(f"Response: {latest_content}")
print(thread_state.get('values', {}).get('messages', []))

In [None]:
# Get the thread state which contains messages
thread_state = await client.threads.get(thread_id=thread['thread_id'])

# Access messages from the thread's state
messages = thread_state.get('values', {}).get('messages', [])

if messages:
    # Get the latest message (last in the list)
    latest = messages[-1]
    latest_content = latest.get('content') if isinstance(latest, dict) else getattr(latest, 'content', None)
    if latest_content:
        print(f"Response: {latest_content}")
print(thread_state.get('values', {}).get('messages', []))

In [None]:
## Test 1: Simple Query (Non-streaming)

# Test with a simple query
test_query = "What is the capital of France?"

print(f"Query: {test_query}\n")

result = await client.runs.create(
    thread_id=thread['thread_id'],
    assistant_id="retrieval_graph",
    input={"query": test_query},
    
)

print("‚úÖ Graph execution completed!")
print(f"\nResult: {result}")
print(f"\nStatus: {result.get('status', 'unknown')}")


In [None]:
# Get the thread state which contains messages
thread_state = await client.threads.get(thread_id=thread['thread_id'])

# Access messages from the thread's state
messages = thread_state.get('values', {}).get('messages', [])

print("Messages in thread:")
for msg in messages:
    msg_type = msg.get('type') if isinstance(msg, dict) else getattr(msg, 'type', 'unknown')
    msg_content = msg.get('content') if isinstance(msg, dict) else getattr(msg, 'content', '')
    print(f"\n- {msg_type}: {msg_content}")

In [None]:
## Test 2: Streaming Response

# Test with streaming
test_query_2 = "Tell me a joke"

print(f"Query: {test_query_2}\n")
print("Streaming response:")

# Remove 'await' here - stream() returns an async generator directly
stream = client.runs.stream(
    thread_id=thread['thread_id'],
    assistant_id="retrieval_graph",
    input={"query": test_query_2},
    stream_mode=["messages", "updates"],
    config={
        "configurable": {
            "queryModel": "ollama/llama3.2:1b"
        }
    }
)

async for chunk in stream:
    chunk_event = chunk.get('event') if isinstance(chunk, dict) else getattr(chunk, 'event', None)
    if chunk_event == 'messages':
        chunk_data = chunk.get('data') if isinstance(chunk, dict) else getattr(chunk, 'data', None)
        if chunk_data:
            for msg in chunk_data:
                msg_content = msg.get('content') if isinstance(msg, dict) else getattr(msg, 'content', None)
                if msg_content:
                    print(f"üì® {msg_content}")
    elif chunk_event == 'updates':
        print(f"üîÑ Update: {chunk}")
    else:
        print(f"üì¶ Chunk: {chunk}")

print("\n‚úÖ Streaming completed!")

In [None]:
## Test 3: Multiple Queries in Same Thread

thread = await client.threads.create()
# Test multiple queries in the same thread (conversation history)
queries = [
    "What is 2 + 2?",
    "What about 3 + 3?",
    "Can you add those two answers together?"
]

import asyncio

for i, query in enumerate(queries, 1):
    print(f"\n{'='*50}")
    print(f"Query {i}: {query}")
    print('='*50)
    
    result = await client.runs.create(
        thread_id=thread['thread_id'],
        assistant_id="retrieval_graph",
        input={"query": query},
        config={
            "configurable": {
                # "queryModel": "ollama/qwen3:4b"
                # "queryModel": "ollama/llama3.2:1b"
            }
        }
    )
    
    # Wait for the run to complete by polling its status
    run_id = result.get('run_id') if isinstance(result, dict) else getattr(result, 'run_id', None)
    if run_id:
        while True:
            run_status = await client.runs.get(
                thread_id=thread['thread_id'],
                run_id=run_id
            )
            status = run_status.get('status') if isinstance(run_status, dict) else getattr(run_status, 'status', None)
            if status in ['success', 'error', 'cancelled']:
                break
            await asyncio.sleep(0.5)  # Poll every 500ms
    
    # Get the thread state which contains messages
    thread_state = await client.threads.get(thread_id=thread['thread_id'])
    
    # Access messages from the thread's state
    messages = thread_state.get('values', {}).get('messages', [])
    
    if messages:
        # Get the latest message (last in the list)
        latest = messages[-1]
        latest_content = latest.get('content') if isinstance(latest, dict) else getattr(latest, 'content', None)
        if latest_content:
            print(f"Response: {latest_content}")

print("\n‚úÖ All queries completed!")