<a href="https://colab.research.google.com/github/lokeshparab/GenAI-Full-Course/blob/main/RAG/Graph_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import and Install Library

In [1]:
!pip install langchain-graph-retriever graph_rag_example_helpers transformers grandalf
!pip install langchain-openai langchain-groq langchain-anthropic langchain-google-genai langchain-huggingface

Collecting langchain-graph-retriever
  Downloading langchain_graph_retriever-0.8.0-py3-none-any.whl.metadata (4.1 kB)
Collecting graph_rag_example_helpers
  Downloading graph_rag_example_helpers-0.8.0-py3-none-any.whl.metadata (1.6 kB)
Collecting grandalf
  Downloading grandalf-0.8-py3-none-any.whl.metadata (1.7 kB)
Collecting backoff>=2.2.1 (from langchain-graph-retriever)
  Downloading backoff-2.2.1-py3-none-any.whl.metadata (14 kB)
Collecting graph-retriever (from langchain-graph-retriever)
  Downloading graph_retriever-0.8.0-py3-none-any.whl.metadata (1.6 kB)
Collecting astrapy>=1.5.2 (from graph_rag_example_helpers)
  Downloading astrapy-2.0.1-py3-none-any.whl.metadata (23 kB)
Collecting griffe>=1.5.7 (from graph_rag_example_helpers)
  Downloading griffe-1.7.3-py3-none-any.whl.metadata (5.0 kB)
Collecting python-dotenv>=1.0.1 (from graph_rag_example_helpers)
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting deprecation<2.2.0,>=2.1.0 (from astrapy>=1.5.

In [1]:
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFacePipeline,HuggingFaceEndpoint

from graph_rag_example_helpers.datasets.animals import fetch_documents
from graph_retriever.strategies import Eager
from langchain_graph_retriever import GraphRetriever

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.messages import (
    HumanMessage,
    SystemMessage,
)
# from lanchain_core.output_parse

import os
from google.colab import userdata

os.environ["GROQ_API_KEY"]=userdata.get('GROQ_API_KEY')
os.environ["GOOGLE_API_KEY"]=userdata.get('GOOGLE_API_KEY')
os.environ["OPENAI_API_KEY"]=userdata.get('OPENAI_API_KEY')
os.environ["ANTHROPIC_API_KEY"]=userdata.get('ANTHROPIC_API_KEY')
# os.environ["HUGGINGFACEHUB_API_TOKEN"]=userdata.get('HUGGINGFACE_API_KEY')
os.environ["HF_TOKEN"]=userdata.get('HUGGINGFACE_API_KEY')
# from huggingface_hub import login
# login(token=os.environ["HF_TOKEN"])


# Model loading

## Chat Model

In [4]:
#@markdown ### Google GenAI

model_name = "gemini-1.5-flash" # @param ["gemini-1.5-flash","gemini-1.5-flash-8b","gemini-1.5-pro","gemini-2.0-flash-lite","gemini-2.0-flash"] {"allow-input":true}

gemini_model = ChatGoogleGenerativeAI(model=model_name)
gemini_model.invoke("Hi How are you?")

AIMessage(content="I'm doing well, thank you for asking!  How are you today?", additional_kwargs={}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'model_name': 'gemini-1.5-flash', 'safety_ratings': []}, id='run--eb3e338e-00e0-4184-924d-7e2c8e7aacf3-0', usage_metadata={'input_tokens': 5, 'output_tokens': 18, 'total_tokens': 23, 'input_token_details': {'cache_read': 0}})

In [5]:
#@markdown ### Groq Model

model_name = "llama3-70b-8192" # @param ["deepseek-r1-distill-llama-70b","allam-2-7b","meta-llama/llama-4-maverick-17b-128e-instruct","meta-llama/llama-4-scout-17b-16e-instruct","meta-llama/Llama-Guard-4-12B","mistral-saba-24b","playai-tts","playai-tts-arabic","qwen-qwq-32b","llama3-8b-8192","llama3-70b-8192"] {"allow-input":true}
groq_model =ChatGroq(model=model_name)
groq_model.invoke("Hi How are you?")

AIMessage(content="I'm just a language model, I don't have emotions or personal experiences, but I'm functioning properly and ready to assist you with any questions or topics you'd like to discuss. How can I help you today?", additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 45, 'prompt_tokens': 15, 'total_tokens': 60, 'completion_time': 0.128571429, 'prompt_time': 0.000189919, 'queue_time': 0.22068280199999998, 'total_time': 0.128761348}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_dd4ae1c591', 'finish_reason': 'stop', 'logprobs': None}, id='run--95b3cc47-e2e5-421a-a608-ac71dd14ba2d-0', usage_metadata={'input_tokens': 15, 'output_tokens': 45, 'total_tokens': 60})

In [6]:
#@markdown ### OpenAI model

model_name = "o3" # @param ["o4-mini","o3","o3-mini","o1","o1-pro","gpt-4o","gpt-4.1","gpt-4o-mini","chatgpt-4o-latest","gpt-4.1-mini","gpt-4.1-nano"] {"allow-input":true}
openai_model = ChatOpenAI(model=model_name)
openai_model.invoke("Hi How are you?")

AIMessage(content='Hello! I’m doing well, thank you for asking. How can I help you today?', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 101, 'prompt_tokens': 11, 'total_tokens': 112, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 64, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'o3-2025-04-16', 'system_fingerprint': None, 'id': 'chatcmpl-Bc6Lehr6rKZWMqsKhkSzD2CNQ03sh', 'service_tier': 'default', 'finish_reason': 'stop', 'logprobs': None}, id='run--f8675d44-5b95-44c3-964f-d521edbaf8bc-0', usage_metadata={'input_tokens': 11, 'output_tokens': 101, 'total_tokens': 112, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 64}})

In [7]:
#@markdown ### Anthropic AI

model_name = "claude-3-opus-latest" # @param ["claude-2","claude-3-5-sonnet-latest","claude-3-opus-latest","claude-3-5-haiku-latest","claude-3-7-sonnet-latest"] {"allow-input":true}
claude_model = ChatAnthropic(model=model_name)
claude_model.invoke("Hi How are you?")

AIMessage(content="Hello! As an AI language model, I don't have feelings, but I'm functioning properly and ready to assist you. How can I help you today?", additional_kwargs={}, response_metadata={'id': 'msg_01A3aefD6XCokJ5k6ZqBf915', 'model': 'claude-3-opus-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, 'input_tokens': 12, 'output_tokens': 35, 'server_tool_use': None, 'service_tier': 'standard'}, 'model_name': 'claude-3-opus-20240229'}, id='run--16f0c957-a290-4a60-87c6-505e7ee5bb1f-0', usage_metadata={'input_tokens': 12, 'output_tokens': 35, 'total_tokens': 47, 'input_token_details': {'cache_read': 0, 'cache_creation': 0}})

In [None]:
# del llm

In [None]:
# # llm = HuggingFacePipeline.from_model_id(
# #     model_id='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
# #     task='text-generation',
# #     pipeline_kwargs={
# #         "do_sample": False,
# #         "max_new_tokens":100
# #         }
# # )

# llm = HuggingFaceEndpoint(
#     repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#     task="text-generation",
#     max_new_tokens=512,
#     do_sample=False,
#     repetition_penalty=1.03,
# )
# messages = [
#     SystemMessage(content="You're a helpful assistant"),
#     HumanMessage(
#         content="Hi I am Lokesh"
#     ),
# ]

# model = ChatHuggingFace(llm=llm)
# model.invoke(messages)


##Embeding Model

In [43]:
# @markdown # OpenAI

embedding_model = "text-embedding-3-large" # @param ["text-embedding-3-large","text-embedding-3-small","text-embedding-ada-002"]
dimensions = 128 #@param {type:"integer"}
set_dimension = True # @param {type:"boolean"}
query = "India is a growing country" # @param {"type":"string","placeholder":"India is a growing country"}

if set_dimension:
  openai_embedding = OpenAIEmbeddings(
      model=embedding_model,
      dimensions=dimensions,
  )
else:
  openai_embedding = OpenAIEmbeddings(
      model=embedding_model,
  )

result = openai_embedding.embed_query(query)
print(len(result),result)

128 [-0.15703527629375458, 0.1968393623828888, -0.013897349126636982, 0.03334273397922516, 0.005019812379032373, 0.001335038454271853, -0.11188763380050659, 0.06423179060220718, 0.022900978103280067, -0.08653298020362854, 0.0598151721060276, -0.03255210444331169, 0.08069868385791779, -0.07344670593738556, -0.0737193375825882, 0.010571255348622799, 0.06297768652439117, -0.14602099359035492, -0.02182408608496189, -0.07704543322324753, -0.14056839048862457, -0.08200731128454208, 0.032006844878196716, -0.12639158964157104, -0.15245507657527924, 0.06401368975639343, -0.02585902065038681, 0.0024656038731336594, -0.10992469638586044, 0.09329422563314438, 0.06706714630126953, 0.009944204241037369, 0.04356638342142105, -0.03857724368572235, -0.022451138123869896, 0.19618503749370575, -0.0173393115401268, -0.038822609931230545, -0.053217511624097824, -0.059324439615011215, -0.005176575388759375, 0.01108925323933363, -0.07988078892230988, 0.28724369406700134, 0.13358904421329498, 0.00750415958464

In [9]:
# @markdown # Google GenAi

embedding_model = "models/gemini-embedding-exp-03-07" # @param ["models/gemini-embedding-exp-03-07","models/text-embedding-004","models/embedding-001"]
task_type = "retrieval_query" # @param ["None","task_type_unspecified","retrieval_query","retrieval_document","semantic_similarity","classification","clustering"]
transport = "None" # @param ["None","rest","grpc","grpc_asyncio"]
query = "India is a growing country" # @param {"type":"string","placeholder":"India is a growing country"}

func = lambda x : None if x=="None" else x
task_type = func(task_type)
transport = func(transport)

google_embedding = GoogleGenerativeAIEmbeddings(
    model=embedding_model,
    task_type=task_type,
    transport=transport
)

result = google_embedding.embed_query(query)
print(len(result),result)


3072 [-0.01328858733177185, -0.008076989091932774, -0.02803654596209526, -0.0385795421898365, -0.0012177267344668508, -0.009753282181918621, 0.018366066738963127, 0.0390302836894989, 0.005043548997491598, -0.023734668269753456, -0.011821064166724682, 0.006893169600516558, 0.019229836761951447, 0.03314683958888054, 0.11351391673088074, 0.020717905834317207, 0.0073205651715397835, -0.019075468182563782, 0.006459957454353571, -0.0205469261854887, -0.003892261302098632, 0.01249883696436882, 0.01834138110280037, 0.004717701114714146, -0.005373682361096144, -0.0032776377629488707, 0.006898732855916023, 0.02193414606153965, 0.021265285089612007, 0.014482101425528526, 0.005145769566297531, -0.027150483801960945, 0.023643454536795616, 0.012725071050226688, 0.017535502091050148, 0.015545469708740711, 0.022201204672455788, 0.0026000298094004393, 0.02127012610435486, 0.005257639102637768, -0.02616986259818077, 0.009095565415918827, -0.022166412323713303, -0.007171641103923321, -0.02820508368313312

In [None]:
# @markdown # Hugging Face
model_name = "jinaai/jina-embeddings-v2-base-en" # @param ["BAAI/bge-en-icl","all-MiniLM-L6-v2","jinaai/jina-embeddings-v3","jinaai/jina-embeddings-v2-base-en"]
query = "India is a growing country" # @param {"type":"string","placeholder":"India is a growing country"}
huggingface_embeddings=HuggingFaceEmbeddings(model_name=model_name)

result = huggingface_embeddings.embed_query(query)
print(len(result),result)

# Ready Example

In [12]:
animals=fetch_documents()
print("Length of documents:",len(animals))
animals[:3]

Length of documents: 99


[Document(id='aardvark', metadata={'type': 'mammal', 'number_of_legs': 4, 'keywords': ['burrowing', 'nocturnal', 'ants', 'savanna'], 'habitat': 'savanna', 'tags': [{'a': 5, 'b': 7}, {'a': 8, 'b': 10}]}, page_content='the aardvark is a nocturnal mammal known for its burrowing habits and long snout used to sniff out ants.'),
 Document(id='albatross', metadata={'type': 'bird', 'number_of_legs': 2, 'keywords': ['seabird', 'wingspan', 'ocean'], 'habitat': 'marine', 'tags': [{'a': 5, 'b': 8}, {'a': 8, 'b': 10}]}, page_content='the albatross is a large seabird with the longest wingspan of any bird, allowing it to glide effortlessly over oceans.'),
 Document(id='alligator', metadata={'type': 'reptile', 'number_of_legs': 4, 'keywords': ['reptile', 'jaws', 'wetlands'], 'diet': 'carnivorous', 'nested': {'a': 5}}, page_content='alligators are large reptiles with powerful jaws and are commonly found in freshwater wetlands.')]

In [44]:
vector_store = InMemoryVectorStore.from_documents(
    documents=animals,
    embedding=openai_embedding,
)

traversal_retriever = GraphRetriever(
    store=vector_store,
    edges = [
        ("habitat", "habitat"), ("origin", "origin")
    ],
    strategy=Eager(select_k=5,start_k=1,max_depth=10)
)

traversal_retriever

GraphRetriever(store=<langchain_core.vectorstores.in_memory.InMemoryVectorStore object at 0x7c468466d550>, edges=[('habitat', 'habitat'), ('origin', 'origin')], strategy=Eager(select_k=5, start_k=1, adjacent_k=10, max_traverse=None, max_depth=10, _query_embedding=[]), adapter=<langchain_graph_retriever.adapters.in_memory.InMemoryAdapter object at 0x7c4684698050>)

In [45]:
print(traversal_retriever.get_graph().draw_ascii())

+---------------------+  
| GraphRetrieverInput |  
+---------------------+  
            *            
            *            
            *            
   +----------------+    
   | GraphRetriever |    
   +----------------+    
            *            
            *            
            *            
+----------------------+ 
| GraphRetrieverOutput | 
+----------------------+ 


In [46]:
results=traversal_retriever.invoke("what animal could be found near a anaconda?")

print(len(results))
results

1


[Document(id='anteater', metadata={'_depth': 0, '_similarity_score': np.float64(0.45615814038534386), 'type': 'mammal', 'number_of_legs': 4, 'keywords': ['ants', 'tongue', 'termites'], 'diet': 'insectivore', 'nested': {'b': 5}}, page_content='anteaters use their long tongues to eat thousands of ants and termites each day.')]

In [27]:
prompt = ChatPromptTemplate.from_template(
"""Answer the question based only on the context provided.

Context: {context}

Question: {question}"""
)

def format_docs(docs):
  return "\n\n".join(
      f"text :{doc.page_content} metadata: {doc.metadata}"
      for doc in docs
  )

chain = (
    {"context": traversal_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | openai_model
    | StrOutputParser()
)

chain.get_graph().print_ascii()

        +---------------------------------+        
        | Parallel<context,question>Input |        
        +---------------------------------+        
                 **              ***               
              ***                   **             
            **                        ***          
+----------------+                       **        
| GraphRetriever |                        *        
+----------------+                        *        
         *                                *        
         *                                *        
         *                                *        
  +-------------+                 +-------------+  
  | format_docs |                 | Passthrough |  
  +-------------+                 +-------------+  
                 **              **                
                   ***        ***                  
                      **    **                     
       +----------------------------------+        
       | Par

In [28]:
chain.invoke("what animal could be found near a anaconda?")

'An anteater'

In [29]:
chain.invoke("what animal could be found near a tiger?")

'Based on the context, a deer could be found near a tiger.'

In [33]:
chain.invoke("what animal can be found in north America?")

'The animal is the bison.'

In [32]:
chain.invoke("what all animal can be found in north America give me all the names?")

'From the information given, the only animal identified as living in North America is the bison.'