In [None]:
import torch
from langchain_core.embeddings import Embeddings
from colpali_engine.models import BiQwen2_5, BiQwen2_5_Processor
from transformers.utils.import_utils import is_flash_attn_2_available

class NomicEmbeddings(Embeddings):
	"""LangChain `Embeddings` adapter around a colpali-engine `BiQwen2_5` model.

	The model and its processor are loaded once in `__init__`; both
	`embed_documents` and `embed_query` run text through the processor's
	`process_queries` pipeline and return the model output as plain Python
	lists of floats.

	NOTE(review): documents are embedded with `process_queries`, i.e. the same
	pipeline as queries. That is preserved here so new vectors stay comparable
	with whatever is already stored in the index -- confirm this is intended.
	"""

	def __init__(
		self, 
		device: str = 'cuda:0',
		use_flash_attn: bool = True,
		model_name: str = 'nomic-ai/nomic-embed-multimodal-3b',
		batch_size: int = 16
	):
		"""Load the model and processor.

		Args:
			device: Value forwarded as `device_map` to `from_pretrained`.
			use_flash_attn: Request FlashAttention-2; only honoured when
				`is_flash_attn_2_available()` reports it is installed.
			model_name: Checkpoint identifier used for both model and processor.
			batch_size: Maximum number of texts sent through the model in one
				forward pass. Bounds peak memory for large inputs.

		Raises:
			ValueError: If `batch_size` is smaller than 1.
		"""
		if batch_size < 1:
			raise ValueError('batch_size must be a positive integer')
		self.batch_size = batch_size

		# Fall back to the library default attention when flash-attn is
		# disabled or not installed.
		flash_enabled = use_flash_attn and is_flash_attn_2_available()
		self.model = BiQwen2_5.from_pretrained(
			model_name,
			torch_dtype=torch.bfloat16,
			device_map=device,
			attn_implementation='flash_attention_2' if flash_enabled else None
		)
		self.processor = BiQwen2_5_Processor.from_pretrained(model_name)

	def _embed(self, texts: list[str]) -> list[list[float]]:
		"""Embed `texts` in chunks of `self.batch_size`; returns one vector per text."""
		vectors: list[list[float]] = []
		with torch.no_grad():
			# An empty input produces no iterations and therefore an empty list.
			for start in range(0, len(texts), self.batch_size):
				chunk = texts[start:start + self.batch_size]
				batch = self.processor.process_queries(chunk).to(self.model.device)
				output = self.model(**batch)
				# Cast bfloat16 -> float32 and move to CPU before converting
				# to Python floats.
				vectors.extend(output.float().cpu().tolist())
		return vectors

	def embed_documents(self, texts: list[str]) -> list[list[float]]:
		"""Return one embedding per entry of `texts` (empty list for empty input)."""
		return self._embed(list(texts))
	
	def embed_query(self, text: str) -> list[float]:
		"""Return the embedding of a single query string."""
		return self._embed([text])[0]

# Shared embedding model for the rest of the notebook, built with the class
# defaults. Constructing it calls `from_pretrained` for both model and processor.
embedding_model = NomicEmbeddings()

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
import os
from langchain_community.vectorstores import Neo4jVector
from dotenv import load_dotenv

# Read NEO4J_URI / NEO4J_USER / NEO4J_PWD from a local .env file so that no
# credentials are hardcoded in the notebook.
load_dotenv()

# Attach to the already-populated vector index on `Fact` nodes.
vectorstore = Neo4jVector.from_existing_index(
    embedding=embedding_model,
    url=os.environ['NEO4J_URI'],
    username=os.environ['NEO4J_USER'],
    password=os.environ['NEO4J_PWD'],
    index_name='fact_embedding_idx',
    node_label='Fact',
    text_node_property='text',
    embedding_node_property='embedding',
)

# MMR retrieval: fetch 50 candidates, keep 5; lambda_mult=0.3 leans the
# trade-off towards diversity over pure similarity.
vector_retriever = vectorstore.as_retriever(
    search_type='mmr',
    search_kwargs={
        'k': 5,
        'fetch_k': 50,
        'lambda_mult': 0.3,
    },
)

# `invoke` is the Runnable entry point that replaces the deprecated
# `get_relevant_documents`.
docs = vector_retriever.invoke('What is the largest bird species?')

print(f"Got {len(docs)} docs")
for i, d in enumerate(docs, 1):
    print(f"\n--- Doc {i} ---")
    print("Content:", d.page_content)
    print("Metadata:", d.metadata)

Got 5 docs

--- Doc 1 ---
Content: Oldest recorded bird 12 years old.
Metadata: {'title': 'Life Span and Survivorship', 'created_at': neo4j.time.DateTime(2025, 11, 11, 16, 33, 51, 586000000, tzinfo=<UTC>), 'bird_name': 'Red-footed Falcon'}

--- Doc 2 ---
Content: 135–150 cm (1); 4310–4468 g (2, 1); wingspan 210–230 cm. Sexes alike, although female is generally smaller and has shorter bill (1). Largest heron, with large and very deep bill  , mainly rufous-chestnut head, neck and underparts  , with chin to upper breast white streaked black over foreneck and breast  , and slate-grey upperparts  and lanceolate plumes on scapulars and mantle; very long tibia 2 exaggerates impression of large bird, while flight is slow and ponderous, with both wings and legs sagging below horizontal. Only likely to be confused with A. purpurea, but present species is almost twice as large and has rufous, not black, crown and much larger and heavier bill. Bill can be all blackish, but usually has horn-coloure

In [None]:
from operator import itemgetter
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel 
from langsmith import traceable


class RAGChain():
	"""Retrieval-augmented QA chain: retriever -> prompt -> Ollama chat model.

	`invoke` returns a dict with three keys:
	  - 'context': the list of retrieved documents,
	  - 'question': the original question string,
	  - 'answer': the model's answer as a string.
	"""

	def __init__(
		self,
		model: str,
		temperature: float,
		ollama_uri: str = 'http://127.0.0.1:11434',
		retriever=None
	):
		"""Build the chain.

		Args:
			model: Ollama model name passed to `ChatOllama`.
			temperature: Sampling temperature passed to `ChatOllama`.
			ollama_uri: Base URL of the Ollama server.
			retriever: Retriever used to fetch context documents. When omitted,
				the notebook-level `vector_retriever` is used.
		"""
		if retriever is None:
			# Backward-compatible default: the retriever defined earlier in
			# the notebook.
			retriever = vector_retriever

		prompt = ChatPromptTemplate.from_messages([
			("system", 
				"You are a helpful assistant for bird-related questions. "
				"Answer using ONLY the provided context when possible. "
				"If the answer is not in the context, say you don't know."
			),
			("human",
				"Context:\n{context}\n\n"
				"Question: {question}\n\n"
				"Answer:"
			),
		])

		llm = ChatOllama(
			model=model,
			temperature=temperature,
			base_url=ollama_uri
		)

		# Only the prompt sees the flattened text. Without this step the
		# template would be filled with the repr of a list of Document
		# objects (metadata included) instead of the passage text.
		answer_chain = (
			RunnablePassthrough.assign(
				context=itemgetter('context') | RunnableLambda(self._format_docs)
			)
			| prompt
			| llm
			| StrOutputParser()
		)

		self.chain = (
			itemgetter('question')
			| RunnableParallel(
				context = retriever, question = RunnablePassthrough()
			)
			# The outer assign keeps the raw documents under 'context' in the
			# returned dict and adds the generated 'answer'.
			| RunnablePassthrough.assign(
				answer=answer_chain
			)
		)

	@traceable(name='rag_chain_invoke')
	def invoke(self, question: str):
		"""Run the chain for `question` and return the result dict."""
		return self.chain.invoke({ 'question': question })

	def _format_docs(self, docs):
		"""Join the `page_content` of each document with blank lines."""
		return "\n\n".join(d.page_content for d in docs)


# Build the chain with the default Ollama URL and run one sample question.
# The trailing expression makes the notebook display the dict returned by `invoke`.
chain = RAGChain(model='gpt-oss:20b', temperature=0.3)
chain.invoke('What is the wingspan of an Ostrich?')

{'context': [Document(metadata={'title': 'Identification', 'created_at': neo4j.time.DateTime(2025, 11, 10, 17, 51, 28, 756000000, tzinfo=<UTC>), 'bird_name': 'European Shag'}, page_content='65–80 cm; male 1760–2154 g, female 1407–1788 g; wingspan 90–105 cm.'),
  Document(metadata={'title': 'Systematics History', 'created_at': neo4j.time.DateTime(2025, 11, 13, 6, 28, 7, 87000000, tzinfo=<UTC>), 'bird_name': 'Brown Oriole'}, page_content='See O. sagittatus. Monotypic.'),
  Document(metadata={'title': 'Identification', 'created_at': neo4j.time.DateTime(2025, 11, 11, 4, 18, 22, 700000000, tzinfo=<UTC>), 'bird_name': 'Alor Boobook'}, page_content='A fairly typical Ninox owl with vivid yellow eyes set in a brownish head, mottled brown-and-white underparts, and brown wings with extensive white spots.'),
  Document(metadata={'title': 'Life Span and Survivorship', 'created_at': neo4j.time.DateTime(2025, 11, 11, 16, 33, 51, 586000000, tzinfo=<UTC>), 'bird_name': 'Red-footed Falcon'}, page_conten

In [None]:
from langsmith import Client
# LangSmith client created with no explicit arguments.
# NOTE(review): presumably configured from environment variables -- confirm.
ls_client = Client()