In [72]:
questions = [
	#'What are the closest blue birds near me?',
	'What do I need to go to see an Ostrich?',
	#'What are some easy birds to start off with in my area?',
	'What are some easy birds to start off with?',
	'How many red birds are there in Australia?',
	'What equipment will I need to start bird watching?'
]

In [None]:
import os
from typing import (
	TypedDict,
	Annotated,
	Sequence,
	Optional,
	Callable,
	Dict,
	List, 
	Tuple,
	Union,
)
from langchain_core.messages import (
    BaseMessage,
	SystemMessage,
	AIMessage,
	HumanMessage
)
from langgraph.graph import StateGraph, START, END, add_messages
from langgraph.graph.state import CompiledStateGraph
from langchain_ollama import ChatOllama
from neo4j import Query, Record
from neo4j.graph import Node
from langchain_neo4j import GraphCypherQAChain, Neo4jGraph
from langchain_neo4j.chains.graph_qa.cypher import extract_cypher
from langchain_core.tools import tool, StructuredTool
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.prompts import ChatPromptTemplate

from dotenv import load_dotenv

load_dotenv()


class BirdRetrieverState(TypedDict):
	messages: Annotated[Sequence[List[BaseMessage]], add_messages]
	image_uuids: Annotated[List[str], lambda prev, new: prev + new]


def serialize_node(node: Node) -> Dict[str, str]:
	label = list(node.labels)[0]
	allowed_keys = []

	if label == 'Bird':
		allowed_keys = ['family', 'order', 'genus', 'species', 'name'] 
	elif label == 'Fact':
		allowed_keys = ['bird_name', 'title', 'text']
	elif label == ['Image']:
		allowed_keys = ['bird_name', 'title']
	elif label in ['Family', 'Order', 'Genus']:
		allowed_keys = ['name']

	return { 'id': node.element_id } | {
		k: v
		for k, v in
		dict(node).items()
		if k in allowed_keys
	}


def serialize_value(record: Record) -> Dict[str, Union[str, Dict[str, str]]]:
	#print(value.keys())
	row = {}
	for attr in record.keys():
		value = record[attr]
		if hasattr(value, 'labels') and hasattr(value, 'items'):
			row[attr] = serialize_node(value)
		else:
			row[attr] = record[attr]
	return row


class BirdRetriever:
	def __init__(self):
		self.llm = ChatOllama(
			model='qwen3:30b',
			temperature=0.2,
			base_url='http://127.0.0.1:11434'
		)

		self.neo4j_graph = Neo4jGraph(
      		url=os.environ['NEO4J_URI'], 
        	username=os.environ['NEO4J_USER'], 
         	password=os.environ['NEO4J_PWD']
        )

		self.query_memory: Dict[str, List[Dict]] = {}

		self._generate_tools()
		self.llm_with_tools = self.llm.bind_tools(self.tools)

		self.app = self.build_graph()


	def build_graph(self) -> CompiledStateGraph:
		builder = StateGraph(BirdRetrieverState)
		builder.add_node('agent_node', self.agent_node)
		builder.add_node('tools', self.tools_node)

		builder.add_edge(START, 'agent_node')
		builder.add_conditional_edges(
			'agent_node',
			tools_condition
		)
		builder.add_edge('tools', 'agent_node')

		return builder.compile()


	def stream(self, message: str):
		'''SystemMessage(content=(
			'You are a retrieval hybrid search agent interacting with a neo4j graph database, with vector embedding fields. '	
			'The graph database you are interacting with contains information about birds, their relevant facts, images of those birds, and taxonomic information. '
			'Your goal is to return a desirable answer to the query provided, as well as any bird images that support or give examples of your answer. '
		)),'''
		initial_state = { 'messages': [
			HumanMessage(content=message)	
		]}	


		for event in self.app.stream(initial_state):
			for node_name, value in event.items():
				print(f"--- Node: {node_name} ---")

				# Check messages
				if "messages" in value:
					message = value["messages"][-1]
					if hasattr(message, "content") and message.content:
						print(f"Content: {message.content}")

					if hasattr(message, "tool_calls") and message.tool_calls:
						print(f"Tool Call: {message.tool_calls[0]['name']} -> {message.tool_calls[0]['args']}")

				print("\n")


	def agent_node(self, state: BirdRetrieverState) -> BirdRetrieverState:
		self._generate_tools()
		self.llm_with_tools = self.llm.bind_tools(self.tools)
		ai_msg = self.llm_with_tools.invoke(state['messages'])
		return { 'messages': [ai_msg] }
	

	def tools_node(self, state: BirdRetrieverState) -> BirdRetrieverState:
		self._generate_tools()
		tool_node = ToolNode(self.tools)
		return tool_node.invoke(state)


	def _generate_tools(self) -> None:
		cypher_tool = self._create_cypher_tool()

		self.tools = [ 
            cypher_tool 
        ]
	
	
	def _create_cypher_tool(self) -> StructuredTool:
		if self.query_memory:
			memory_keys = list(self.query_memory.keys())
			cypher_tool_desc = (
       			'Takes a natural language query as input and turns it into a cypher query based on the schema of the graph database.\n'
				'You can optionally start with the results of a previous query that exists in memory as a starting point.\n'
				'This tool will only ever return 100 results at a time.\n\n'
				'Args:\n'
				'\tquery: the natural language query to generate the cypher\n'
				'\tmemory_key_id (Optional): the EXACT key of the previous query stored in memory.\n\n'
				f'The memory key ids that can be used are: {memory_keys}'
			)
		else:
			cypher_tool_desc = (
       			'Takes a natural language query as input and turns it into a cypher query based on the schema of the graph database.\n'
				'This tool will only ever return 100 results at a time.\n\n'
				'Args:\n'
				'\tquery: the natural language query to generate the cypher\n'
				'\tmemory_key_id: (IGNORE) the memory key of the results of a previous query'
			)

		return StructuredTool.from_function(
			func=self._query_cypher_tool(),
			name='query_cypher',
			description=cypher_tool_desc
		)
	

	def _query_cypher_tool(self) -> Callable:
		def query_cypher(query: str, memory_key_id: Optional[List[str]] = None):

			memory_key_id = memory_key_id if memory_key_id in self.query_memory else None

			node_memory_ids = self._parse_query_memory_nodes(memory_key_id) if memory_key_id else None

			cypher, schema = self._generate_cypher(
				query, 
				memory_key_id=memory_key_id,
				is_node_memory=bool(node_memory_ids),
			)
			print('CYPHER ---- ', cypher)

			with self.neo4j_graph._driver.session() as session:
				records = session.execute_read(
					lambda tx: tx.run(
						cypher,
						{ 'startIds': node_memory_ids } if node_memory_ids else {}
					).data()
				)
				serialized_results = [ serialize_value(r) for r in records ]

				summary = self.summarize_query(cypher, schema)

				self.query_memory[summary] = serialized_results

				return { f'{summary}': serialized_results }

		return query_cypher
	

	def	summarize_query(self, cypher, schema) -> str:
		# TODO: Summarize query using fine tuned CodeT5+ model
		summary = self.llm.invoke((
			'Come up with a short title for the cypher query below, with a given schema.\n'
   			'The title should completely describe what the query does, expect it to be referenced with no knowledge of the graph schema or content.\n\n'
			'Schema:\n\n'
			f'{schema}\n\n'
			'Cypher:\n\n'
			f'{cypher}\n\n'
			'Do not explain your answer, ONLY return the short title that describes what the query does.'
		)).content
		if '</think>' in summary:
			summary = summary.split('</think>')[-1].strip()
		return summary



	def _generate_cypher(self, query: str, memory_key_id: Union[str, None], is_node_memory: bool) -> Tuple[str, str]:
		cypher_prompt = ChatPromptTemplate.from_template(
			'Task: Generate a **read-only** Cypher statement to query a Neo4j graph.\n\n' +
			((
				f'You have the following results from a previous query that can be summarized as "{memory_key_id}":\n'
				f'{self.query_memory[memory_key_id]}'
				'\n\n'
			) if not is_node_memory else '') + (
				'Instructions:\n'
				'- Use ONLY the provided schema.\n'
			) + ((
				f'- You are starting with the results of query that can be summarized as "{memory_key_id}"\n'
				'- Match these starting nodes by elementId() using a parameter $startIds.\n'
			) if is_node_memory else '') + (
				'- There should be a LIMIT of 100 nodes returned.\n'
				'- Name ALL variables in the RETURN statement, and variable names should be as descriptive as possible, e.g. "bird" is acceptable, "b" is not.\n'
				'- The query MUST be read-only: no CREATE, MERGE, DELETE, SET, REMOVE, etc.\n'
				'- Return ONLY the Cypher query, no explanation.\n\n'
				'Schema: {schema}\n\n'
				'Question: {question}'
			)
		) if memory_key_id else None

		cypher_chain = GraphCypherQAChain.from_llm(
			llm=self.llm, 
			graph=self.neo4j_graph, 
			cypher_prompt=cypher_prompt,
			allow_dangerous_requests=True
		)

		cypher = cypher_chain.cypher_generation_chain.invoke({
			'question': query,
			'schema': cypher_chain.graph_schema,
		})

		cypher = extract_cypher(cypher)
		print('CYPHER ------ ', cypher)

		if '</think>' in cypher:
			cypher = cypher.split('</think>')[-1].strip()

		if cypher_chain.cypher_query_corrector:
			cypher = cypher_chain.cypher_query_corrector(cypher)

		return cypher, cypher_chain.graph_schema


	def _parse_query_memory_nodes(self, memory_key_id: str) -> Union[List[str], None]:
		output_sample = self.query_memory[memory_key_id][0]
		node_fields = [ 
			k for k in output_sample.keys() 
			if 'id' in output_sample[k]
		]
		if not node_fields:
			return None
		
		out = set([])
		for row in self.query_memory[memory_key_id]:
			for node_field in node_fields:
				out.add(row[node_field])

		return list(out)

In [128]:
bird_retriever_agent = BirdRetriever()

bird_retriever_agent.stream(questions[2])


--- Node: agent_node ---
Content: Okay, the user is asking, "How many red birds are there in Australia?" Let me think about how to approach this.

First, I need to figure out what data is available. The tools provided include a function called query_cypher, which converts natural language queries into Cypher queries for a graph database. But the user's question is about counting red birds in Australia. 

Wait, the problem is that the graph database schema isn't specified here. The function's description mentions it's based on the schema, but since I don't have access to the schema, I have to assume what entities and relationships might exist. For example, maybe there are nodes for birds, countries, colors, etc.

But the user is asking for a count. So the Cypher query would likely involve filtering birds that are red and located in Australia. However, without knowing the exact schema, it's a bit tricky. Let's assume the graph has a 'Bird' node with properties like 'color' and 'country',

In [None]:
def generate_cypher(q: str) -> None:
	llm = ChatOllama(
		model='gpt-oss:20b',
		temperature=0.5,
		base_url='http://127.0.0.1:11434'
	)

	neo4j_graph = Neo4jGraph(
		url=os.environ['NEO4J_URI'], 
		username=os.environ['NEO4J_USER'], 
		password=os.environ['NEO4J_PWD']
	)

	cypher_prompt = ChatPromptTemplate.from_template((
		'Task: Generate a **read-only** Cypher statement to query a Neo4j graph.\n\n' +
		'Instructions:\n'
		'- Use ONLY the provided schema.\n'
		'- The query MUST be read-only: no CREATE, MERGE, DELETE, SET, REMOVE, etc.\n'
		'- Return ONLY the Cypher query, DO NOT EXPLAIN.\n\n'
		'- Every at'
		'Schema: {schema}\n\n'
		'Question: {question}'
	))

	#'- There should be a LIMIT of 100 nodes returned.\n'
	#'- Any field on any node that has the title "embedding" can be used in semantic similarity search for text or images. Just pretend the embedding for a given query is inserted in the paraemter $embedding to compare.\n'
	cypher_chain = GraphCypherQAChain.from_llm(
		llm=llm, 
		graph=neo4j_graph, 
		cypher_prompt=cypher_prompt,
		allow_dangerous_requests=True
	)

	#print(cypher_chain.graph_schema)

	cypher = cypher_chain.cypher_generation_chain.invoke({
		'question': q,
		'schema': cypher_chain.graph_schema,
	})

	print('CYPHER ----- ', cypher)

#print(llm.invoke('How many red birds are there in Austrialia?').content.split('</think>')[-1].strip())
generate_cypher('How many red birds are there in Australia?')


CYPHER -----  MATCH (b:Bird)-[:HAS_FACT]->(d:Distribution)
WHERE toLower(b.name) CONTAINS 'red'
  AND (toLower(d.text) CONTAINS 'australia' OR toLower(d.title) CONTAINS 'australia')
RETURN count(DISTINCT b) AS redBirdsInAustralia
LIMIT 100


In [111]:
llm = ChatOllama(
	model='gpt-oss:20b',
	temperature=0.2,
	base_url='http://127.0.0.1:11434'
)

'''neo4j_graph = Neo4jGraph(
	url=os.environ['NEO4J_URI'], 
	username=os.environ['NEO4J_USER'], 
	password=os.environ['NEO4J_PWD']
)

cypher_chain = GraphCypherQAChain.from_llm(
	llm=llm, 
	graph=neo4j_graph, 
	allow_dangerous_requests=True
)

print(cypher_chain.graph_schema)

cypher = cypher_chain.cypher_generation_chain.invoke({
	'question': 'How many red birds are there in Australia?',
	'schema': cypher_chain.graph_schema,
})

print(cypher)'''

print(llm.invoke('How many red birds are there in Austrialia?').content)

**Short answer:**  
There is no single, definitive number for “how many red birds” live in Australia (or “Austrialia” – I’m assuming you meant Australia). The country is home to dozens of bird species that display red plumage, and each species has its own population size that can fluctuate year‑to‑year. Because of this, scientists can only give *estimates* for individual species, and even those estimates are based on limited surveys, breeding‑pair counts, and modeling.

---

## Why an exact number is impossible

| Reason | What it means for the count |
|--------|-----------------------------|
| **Many species** | Australia hosts ~1,500 bird species, of which roughly 30–40 have red or predominantly red plumage (e.g., Crimson Rosella, Red Wattlebird, Red Honeyeater, Red‑breasted Myna, etc.). |
| **Population estimates vary** | For most species, population estimates come from *point counts*, *breeding‑pair surveys*, or *remote‑sensing* data. These methods sample only a fraction of the hab

In [None]:
llm = ChatOllama(
	model='gpt-oss:20b',
	temperature=0.7,
	base_url='http://127.0.0.1:11434'
)

graph_description = '''
- (:Bird) nodes have the following attributes: {order: STRING, genus: STRING, family: STRING, species: STRING, name: STRING}

- (:Image) nodes:
-- are connected to birds by (:Bird)-[:HAS_IMAGE]->(:Image)
-- have an {embedding: LIST} attribute, which allows them to be matched to similar text/descriptions

- (:Genus), (:Order), and (:Family) nodes are connected to (:Bird) nodes by:
-- (:Bird)-[:IN_ORDER]->(:Order)
-- (:Bird)-[:IN_FAMILY]->(:Family)
-- (:Bird)-[:IN_GENUS]->(:Genus)

- The following nodes have label (:Fact) and the attributes: {bird_name: STRING, embedding: LIST, text: STRING}, where the embedding attribute can be used for semantic similarity search. They also have the relationship (:Bird)-[:HAS_FACT]->(:Fact). Each (:Fact) also has one of the following labels:
-- (:Introduction)
-- (:Identification)
-- (:Similar_Species)
-- (:Systematics_History)
-- (:Geographic_Variation)
-- (:Subspecies) 
-- (:Distribution) 
-- (:General_Habitat)
-- (:Movements_and_Migration) 
-- (:Diet_and_Foraging) 
-- (:Sounds_and_Vocal_Behavior) 
-- (:Breeding) 
-- (:Conservation_Status) 
-- (:Vernacular_Names) 
-- (:Hybridization) 
-- (:Eggs) 
-- (:Parental_Care) 
-- (:Plumages) 
-- (:Migration_Overview) 
-- (:Other) 
-- (:Nest) 
-- (:Predation)
-- (:Measures_of_Breeding_Activity) 
-- (:Causes_of_Mortality) 
-- (:Population_Regulation)
-- (:Historical_Changes_to_the_Distribution) 
-- (:Sexual_Behavior) 
-- (:Population_Spatial_Metrics) 
-- (:Social_and_Interspecific_Behavior) 
-- (:Diet) 
-- (:Agonistic_Behavior) 
-- (:Phenology) 
-- (:Nest_Site) 
-- (:Incubation) 
-- (:Hatching) 
-- (:Cooperative_Breeding) 
-- (:Population_Status) 
-- (:Effects_of_Human_Activity) 
-- (:Young_Birds) 
-- (:Life_Span_and_Survivorship) 
-- (:Bare_Parts) 
-- (:Feeding)
-- (:Vocalizations) 
-- (:Fledgling_Stage) 
-- (:Management) 
-- (:Related_Species) 
-- (:Fossils) 
-- (:Locomotion) 
-- (:Molts) 
-- (:Measurements) 
-- (:Nonvocal_Sounds) 
-- (:Behavior) 
-- (:Demography_and_Populations) 
-- (:Priorities_for_Future_Research) 
-- (:Food_Selection_and_Storage) 
-- (:Nutrition_and_Energetics) 
-- (:Metabolism_and_Temperature_Regulation)
-- (:Drinking_Pellet_Casting_and_Defecation) 
-- (:Brood_Parasitism) 
-- (:Similar_Species_Summary) 
-- (:Dispersal_and_Site_Fidelity) 
-- (:Pathogens_and_Parasites) 
-- (:Self_Maintenance) 
-- (:Habitat)
-- (:Nonmigratory_Movements) 
-- (:Habitat_in_Breeding_Range) 
-- (:Habitat_in_Nonbreeding_Range)
-- (:Timing_and_Routes_of_Migration)
-- (:Migratory_Behavior) 
-- (:Control_and_Physiology_of_Migration)
-- (:Field_Identification)

- All (:Fact) nodes have the VECTOR INDEX called "fact_embedding_idx" on the "embedding" field, which has cosine similarity
- All (:Image) nodes have the VECTOR INDEX called "image_embedding_idx" on the "embedding" field, which as cosine similarity
'''
plans = []

for ind in range(5):
	print(f'\nGenerating plan {ind+1}...\n')
	plan = llm.invoke([
		SystemMessage(content=(
			'You are a planning assistant for an autonomous agent that has access to a Neo4j graph database containing information pertaining to different bird species.\n'
			'Your database is hybrid search, meaning that you can search using cypher to get information from structured fields, but can also search based on semantic similarity based on embedding fields for text and images.\n'
			'With a user input, return a step by step plan that an agent with access to this database would take to gather the information required to respond to the input.\n\n'
			'This is a description of the graph SCHEMA:\n\n'
			f'{graph_description}\n\n'
			'Your agent can generate cypher to search this graph, and search by semantic similarity on embedding fields.\n'
			'Generate a plan, with bullet pointed steps, that an agent could take to get all the information necessary to answer this question as accurately as possible.\n'
			'Plans should only output a collection of nodes, not just a count. That way, the agent can explain its reasoning and/or give examples.\n'
			'You do not need to output any cypher, only high-level steps in natural language.'
			'DO NOT EXPLAIN YOUR ANSWER. Just return your bullet pointed plan. DO NOT INCLUDE SUB-STEPS, ONLY TOP-LEVEL BULLET POINTS.'
			#'Your agent has the following tools as its disposal:\n'
			#'- Cypher Query: Create a cypher query given the SCHEMA above.\n'
			#'- Semantic Search: Search for bird facts using the embedding field on the Fact nodes. This will return bird facts similar to a given string.\n'
			#'- Image Search: Search the embedding field on the Image nodes, finding images similar to a natural language query.\n\n'
			#'Come up with 5 possible plans, which should represent ways you could address the user\'s query. Separate each plan by a new line.\n'
			#'Return each plan in non-numbered bullet points. DO NOT EXPLAIN ANY ANSWER. JUST RETURN 5 possible plans.\n'
		)),
		HumanMessage(content='How many red birds are there in Australia?')
	]).content.split('</think')[-1].strip()
	print(plan)
	plans.append(plan)


Generating plan 1...

- Retrieve all Bird nodes connected to Fact nodes labeled (Plumages, Description, Field_Identification) whose embeddings are semantically similar to “red”.  
- Retrieve all Bird nodes connected to Fact nodes labeled Distribution whose text or embeddings indicate occurrence in Australia.  
- Intersect the two result sets to identify Bird nodes that are both red‑colored and present in Australia.  
- Count the distinct species (or subspecies) within this intersection to determine the number of red bird species found in Australia.  
- Optionally, gather example species names and a short excerpt from their distribution or plumage facts for illustration.

Generating plan 2...

- Retrieve all `(:Bird)` nodes that have a `(:Distribution)` fact mentioning “Australia”.  
- For each of those birds, fetch all `(:Fact)` nodes (e.g., `(:Color)`, `(:Description)`, `(:Introduction)`) linked via `(:Bird)-[:HAS_FACT]->(:Fact)`.  
- Perform a semantic similarity search on the `embe

In [154]:
import random
from collections import defaultdict
from itertools import combinations

pairs = list(combinations(range(len(plans)), 2))
print(f'Evaluating {len(pairs)} Pairs...\n')
plan_tallies = defaultdict(int)

grading_llm = ChatOllama(
	model='gpt-oss:20b',
	temperature=0.2,
	base_url='http://127.0.0.1:11434'
)

grading_prompt = ChatPromptTemplate.from_messages([
	('system',
		'You are a plan comparing agent. Your goal is to assess which plan (out of two plans) is better for answering/completing a particular query. '
		'You will accept the QUERY, PLAN 1, and PLAN 2 as input. You will return the integer 1 if PLAN 1 is better than PLAN 2, and return 2 if PLAN 2 is better. '
		'Evaluate the plans based on the following criteria: '
		'-- A plan should be thorough, reasonably accounting for any edge cases that affect the quality of the end result. '
		'-- A plan should avoid unneccessary steps and redundancies. '
		'-- A plan should be able to be reasonably executed. '
		'In your answer, ONLY return 1 or 2. DO NOT EXPLAIN YOUR ANSWER.'
	),
	('human', 
		'QUERY: {query}\n'
		'PLAN 1:\n\n{plan1}\n\n'
		'PLAN 2:\n\n{plan2}\n\n'
  	)
])

for pair in pairs:
	first_option = random.randint(0,1)
	second_option = 0 if first_option else 1
	print(f'GRADING PLAN {pair[first_option]} AGAINST {pair[second_option]}')
	plan = (grading_prompt | grading_llm).invoke({
		'query': questions[3],
		'plan1': plans[pair[first_option]],
		'plan2': plans[pair[second_option]]
	})
	if plan.content == '1':
		print(f'PLAN {pair[first_option]} WON')
		plan_tallies[pair[first_option]] += 1
	elif plan.content == '2':
		print(f'PLAN {pair[second_option]} WON')
		plan_tallies[pair[second_option]] += 1
	else:
		print('INCORRECT FORMAT')

print(plan_tallies)

Evaluating 10 Pairs...

GRADING PLAN 0 AGAINST 1
PLAN 1 WON
GRADING PLAN 0 AGAINST 2
PLAN 0 WON
GRADING PLAN 0 AGAINST 3
PLAN 3 WON
GRADING PLAN 0 AGAINST 4
PLAN 4 WON
GRADING PLAN 1 AGAINST 2
PLAN 1 WON
GRADING PLAN 3 AGAINST 1
PLAN 1 WON
GRADING PLAN 1 AGAINST 4
PLAN 1 WON
GRADING PLAN 2 AGAINST 3
PLAN 2 WON
GRADING PLAN 4 AGAINST 2
PLAN 2 WON
GRADING PLAN 3 AGAINST 4
PLAN 3 WON
defaultdict(<class 'int'>, {1: 4, 0: 1, 3: 2, 4: 1, 2: 2})


In [161]:
for step in plans[1].split('\n'):
    print(step)

- Retrieve all `(:Bird)` nodes that have a `(:Distribution)` fact mentioning “Australia”.  
- For each of those birds, fetch all `(:Fact)` nodes (e.g., `(:Color)`, `(:Description)`, `(:Introduction)`) linked via `(:Bird)-[:HAS_FACT]->(:Fact)`.  
- Perform a semantic similarity search on the `embedding` of each fact to find those that are most similar to the query “red” or contain the keyword “red”.  
- Filter the birds whose most similar fact(s) indicate a red coloration.  
- Compile a list of the distinct species (using the `species` or `name` property) that satisfy the red coloration criterion.  
- Count the number of species in that list to obtain the total number of red birds in Australia.


In [182]:
cypher_llm = ChatOllama(
	model='gpt-oss:20b',
	temperature=0.2,
	base_url='http://127.0.0.1:11434'
)

initial_message = SystemMessage(content=(
	'You are a cypher query generator. Your goal is to create a query that interacts with a Neo4j database, modifying it step-by-step based on user inputs.\n'
	'This is a description of the graph that you are interacting with:\n\n'	
	f'{graph_description}\n\n'
	'If you need to perform a vector similarity search on the embedding for a phrase or phrases, add a comment to the first line of the query. Call the function embed() on the word or phrase you want to embed, and assign it to the parameter for use in the query\n'
	'For instance:\n'
	'// $colorEmbedding = embed(Green)\n'
	'CALL db.index.vector.queryNodes(\'fact_embedding_idx\', 1000, $colorEmbedding) // (example)\n'
	'...rest of query...\n\n'
	'For each new embedding, add a new commented line.\n'
	'At any time, ONLY return the cypher with any comments, and modify the existing query if it exists. DO NOT explain. DO NOT format the cypher, return only as plain text.'
))

messages = [initial_message]

for step in plans[1].split('\n'):
	f_step = ('-'.join(step.split('-')[1:]) if step.startswith('-') else step).strip()
	print(f'HUMAN -- {f_step}\n\n')
	messages.append(HumanMessage(content=f_step))
	ai_msg = cypher_llm.invoke(messages)
	print(f'AI -- {ai_msg.content}\n\n')
	messages.append(ai_msg)

HUMAN -- Retrieve all `(:Bird)` nodes that have a `(:Distribution)` fact mentioning “Australia”.


AI -- MATCH (b:Bird)-[:HAS_FACT]->(f:Fact:Distribution)
WHERE toLower(f.text) CONTAINS "australia"
RETURN b;


HUMAN -- For each of those birds, fetch all `(:Fact)` nodes (e.g., `(:Color)`, `(:Description)`, `(:Introduction)`) linked via `(:Bird)-[:HAS_FACT]->(:Fact)`.


AI -- MATCH (b:Bird)-[:HAS_FACT]->(dist:Fact:Distribution)
WHERE toLower(dist.text) CONTAINS "australia"
MATCH (b)-[:HAS_FACT]->(f:Fact)
RETURN b, f;


HUMAN -- Perform a semantic similarity search on the `embedding` of each fact to find those that are most similar to the query “red” or contain the keyword “red”.


AI -- ```
// $redEmbedding = embed("red")
CALL db.index.vector.queryNodes('fact_embedding_idx', 1000, $redEmbedding) YIELD node AS simFact, score
WITH simFact, score
MATCH (b:Bird)-[:HAS_FACT]->(dist:Fact:Distribution)
WHERE toLower(dist.text) CONTAINS "australia"
MATCH (b)-[:HAS_FACT]->(f:Fact)
WHERE f = simFa

In [165]:
import torch
from transformers.utils.import_utils import is_flash_attn_2_available

from colpali_engine.models import BiQwen2_5, BiQwen2_5_Processor

embed_model_name = "nomic-ai/nomic-embed-multimodal-3b"

embed_model = BiQwen2_5.from_pretrained(
    embed_model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
).eval()

embed_processor = BiQwen2_5_Processor.from_pretrained(embed_model_name)

`torch_dtype` is deprecated! Use `dtype` instead!


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.


In [166]:
from typing import Union
import torch
from PIL import Image


def generate_txt_embeddings(queries: Union[str, list[str]]) -> torch.Tensor:
	with torch.no_grad():
		batch_queries = embed_processor.process_queries(queries).to(embed_model.device)
		return embed_model(**batch_queries)


def generate_img_embeddings(imgs: Union[Image.Image, list[Image.Image]]) -> torch.Tensor:
	with torch.no_grad():
		batch_imgs = embed_processor.process_images(imgs).to(embed_model.device)
		return embed_model(**batch_imgs)


def score_embeddings(target_embeds: torch.Tensor, reference_embeds: torch.Tensor) -> torch.Tensor:
	return embed_processor.score(
    	list(torch.unbind(target_embeds)), 
     	list(torch.unbind(reference_embeds))
    )


images = [
    Image.new("RGB", (128, 128), color="white"),
    Image.new("RGB", (64, 32), color="black"),
]
queries = [
    "What is the organizational structure for our R&D department?",
    "Can you provide a breakdown of last year’s financial performance?",
]


image_embeddings = generate_img_embeddings(images)
query_embeddings = generate_txt_embeddings(queries)
display(image_embeddings.shape, query_embeddings.shape)

score_embeddings(image_embeddings, query_embeddings)

torch.Size([2, 2048])

torch.Size([2, 2048])

tensor([[-0.0084,  0.0361],
        [-0.0138,  0.0432]], device='cuda:0')

In [183]:
import os
from neo4j import GraphDatabase
from dotenv import load_dotenv
load_dotenv()

graph = GraphDatabase.driver(
	os.environ['NEO4J_URI'],
	auth=(os.environ['NEO4J_USER'], os.environ['NEO4J_PWD'])
)
database = os.environ['NEO4J_DATABASE']


red_embed = generate_txt_embeddings(['red']).tolist()[0]
with graph.session(database=database) as session:
	result = session.run('''
CALL db.index.vector.queryNodes('fact_embedding_idx', 1000, $redEmbedding) YIELD node AS f, score
MATCH (b:Bird)-[:HAS_FACT]->(dist:Fact:Distribution)
WHERE toLower(dist.text) CONTAINS "australia"
WITH b, f
WHERE f:Color OR toLower(f.text) CONTAINS "red"
RETURN COUNT(DISTINCT coalesce(b.species, b.name)) AS totalRedBirdsInAustralia;
	''', redEmbedding=red_embed)
	print(len([ r for r in result ]))

1


In [115]:
print(llm.invoke([
	SystemMessage(content=(
		'You are a helpful assistant. Your goal is to determine if you need external data to answer the user\'s question.\n'
		'If you can answer it confidently using your internal knowledge, output \'INTERNAL: [Your Answer]\'\n'
		'If you need to look up specific statistics, counts, or recent data from the bird database, output \'RETRIEVE\''
	)),
	HumanMessage(content='How many red birds are there in Australia?')
]).content)

RETRIEVE


In [63]:
import os
from typing import TypedDict, Dict, Optional, Any, Iterable, Union
from neo4j import GraphDatabase, Record
from neo4j.graph import Node
from dotenv import load_dotenv

load_dotenv()

class NodeRef(TypedDict):
	id: str
	name: str
	label: Optional[str]
	properties: Dict[str, Any]
	score: Optional[float]
	source_tool: str

def parse_node(node: Record) -> NodeRef:
	print(node)
	label = node.labels[0]
	return NodeRef(
		id=node.element_id,
		label=node.labels[0],
		properties=''
	)

def serialize_node(node: Node) -> Dict[str, str]:
	label = list(node.labels)[0]
	allowed_keys = []

	if label == 'Bird':
		allowed_keys = ['family', 'order', 'genus', 'species', 'name'] 
	elif label == 'Fact':
		allowed_keys = ['bird_name', 'title', 'text']
	elif label == ['Image']:
		allowed_keys = ['bird_name', 'title']
	elif label in ['Family', 'Order', 'Genus']:
		allowed_keys = ['name']

	return { 'id': node.element_id } | {
		k: v
		for k, v in
		dict(node).items()
		if k in allowed_keys
	}

def serialize_value(record: Record) -> Dict[str, Union[str, Dict[str, str]]]:
	#print(value.keys())
	row = {}
	for attr in record.keys():
		value = record[attr]
		if hasattr(value, 'labels') and hasattr(value, 'items'):
			row[attr] = serialize_node(value)
		else:
			row[attr] = record[attr]
	return row

graph = GraphDatabase.driver(
	os.environ['NEO4J_URI'],
	auth=(os.environ['NEO4J_USER'], os.environ['NEO4J_PWD'])
)
database = os.environ['NEO4J_DATABASE']

with graph.session(database=database) as session:
	results = session.run('MATCH (b:Bird) RETURN b.name AS bird_name, b.species AS species LIMIT 3')
	print([ serialize_value(r) for r in results ])
	results = session.run('MATCH (bird:Bird) RETURN bird LIMIT 2')
	print([ serialize_value(r) for r in results ])
	

[{'bird_name': 'Common Ostrich', 'species': 'Camelus'}, {'bird_name': 'Somali Ostrich', 'species': 'Molybdophanes'}, {'bird_name': 'Southern Cassowary', 'species': 'Casuarius'}]
[{'bird': {'id': '4:a554891f-4e4e-45d4-beca-f7796c51940f:3', 'genus': 'Struthio', 'species': 'Camelus', 'name': 'Common Ostrich', 'family': 'Struthionidae', 'order': 'Struthioniformes'}}, {'bird': {'id': '4:a554891f-4e4e-45d4-beca-f7796c51940f:118', 'genus': 'Struthio', 'species': 'Molybdophanes', 'name': 'Somali Ostrich', 'family': 'Struthionidae', 'order': 'Struthioniformes'}}]


In [None]:
import time
from langchain_neo4j.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_core.prompts import ChatPromptTemplate

bird_retriever_llm = ChatOllama(
	model='gpt-oss:20b',
	temperature=0.2,
	base_url='http://127.0.0.1:11434'
)
yourmom = time.perf_counter()
neo4j_graph = Neo4jGraph(
	url=os.environ['NEO4J_URI'], 
	username=os.environ['NEO4J_USER'], 
	password=os.environ['NEO4J_PWD']
)
print(f'INIT GRAPH: {time.perf_counter() - yourmom}')
yourmom = time.perf_counter()

cypher_template = ChatPromptTemplate.from_template((
	'Task: Generate a **read-only** Cypher statement to query a Neo4j graph.\n\n'
	'Instructions:\n'
	'- Use ONLY the provided schema.\n'
	'- You are starting with the results of query that can be summarized as "birds in Australia"\n'
	'- Match these starting nodes by elementId() using a parameter $startIds\n'
	'- The query MUST be read-only: no CREATE, MERGE, DELETE, SET, REMOVE, etc.\n'
	'- Return ONLY the Cypher query, no explanation.\n\n'
	'Schema: {schema}\n\n'
	'Question: {question}'
))

cypher_chain = GraphCypherQAChain.from_llm(
	llm=bird_retriever_llm, 
	graph=neo4j_graph, 
	cypher_prompt=cypher_template,
	allow_dangerous_requests=True
)
print(f'INIT CHAIN: {time.perf_counter() - yourmom}')
cypher_chain.cypher_generation_chain.invoke({
	'question': 'List all red birds',
	'schema': cypher_chain.graph_schema
})

INIT GRAPH: 0.1260071249998873
INIT CHAIN: 0.0012473470014811028


"MATCH (b:Bird)\nWHERE id(b) IN $startIds\nMATCH (b)-[:HAS_FACT]->(f:Fact)\nWHERE f.text CONTAINS 'red'\nRETURN DISTINCT b;"