<a href="https://colab.research.google.com/github/manjunathgoli/Deep_Research_AI_Agent/blob/main/Deep_Research_AI_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
!pip install langchain-community
!pip install langgraph
from IPython import get_ipython
from IPython.display import display
import os
import torch
from langchain.llms import OpenAI
from langchain.tools import TavilySearchResults
from langgraph.graph import StateGraph
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer



In [38]:
# Read the Tavily API key from the environment and fail fast when it is absent,
# so later cells never try to search without credentials.
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
if TAVILY_API_KEY is None or TAVILY_API_KEY == "":
    raise ValueError("TAVILY_API_KEY is missing. Set it in environment variables.")

In [39]:
# Choose the compute device (GPU when CUDA is available, otherwise CPU) and
# load the tokenizer and seq2seq checkpoint that the agents below share.
model_name = "facebook/bart-large-cnn"
device = "cpu" if not torch.cuda.is_available() else "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = model.to(device)

In [40]:
# Research Agent
class ResearchAgent:
    """Graph node that runs a web search for ``state["question"]``.

    On return the state carries:
      * ``"research_data"`` - the raw result list from the search tool
        (empty list when nothing usable came back);
      * ``"source_urls"``   - the ``"url"`` value of every dict result that has one.

    Args:
        search_tool: Optional object exposing ``run(query)``. When omitted, a
            ``TavilySearchResults`` tool is created on each call using the
            module-level ``TAVILY_API_KEY`` (the original behaviour).
    """

    def __init__(self, search_tool=None):
        self._search_tool = search_tool

    def __call__(self, state):
        search_tool = self._search_tool
        if search_tool is None:
            search_tool = TavilySearchResults(tavily_api_key=TAVILY_API_KEY)

        results = search_tool.run(state["question"])

        # NOTE(review): the Tavily tool appears to return a plain string (an
        # error repr) instead of a list when the request fails - confirm against
        # the installed langchain-community version. A string must not be
        # stored as research data, otherwise the error text gets summarized.
        if isinstance(results, str):
            import warnings

            warnings.warn(f"Search returned no structured results: {results}")
            results = []

        if not results:
            state["research_data"] = []
            state["source_urls"] = []
            return state

        state["research_data"] = results
        # Only dict results that actually carry a "url" key contribute a source.
        state["source_urls"] = [
            item["url"] for item in results if isinstance(item, dict) and "url" in item
        ]
        return state

In [41]:
# Term Expansion Agent
class TermExpansionAgent:
    """Graph node that looks up definitions for terms flagged as unknown.

    ``extract_unknown_terms`` decides which terms need a lookup. When it
    yields nothing the state is returned untouched; otherwise each term is
    searched with a Tavily tool and the results are stored under
    ``state["definitions"]`` keyed by term.
    """

    def __call__(self, state):
        terms = extract_unknown_terms(state["research_data"], state["question"])
        if not terms:
            return state

        tavily = TavilySearchResults(tavily_api_key=TAVILY_API_KEY)
        looked_up = {}
        for term in terms:
            looked_up[term] = tavily.run(term)

        state["definitions"] = looked_up
        return state

In [42]:
# Summarization Agent
class SummarizationAgent:
    """Graph node that condenses the gathered research into ``state["summary"]``.

    The research results (plus any non-empty term definitions) are flattened
    into one string, run through the module-level ``tokenizer`` / ``model``
    pair on ``device``, and the decoded output is stored under ``"summary"``.
    When there is no text to summarize, a fixed placeholder message is stored
    instead and the model is not invoked.
    """

    @staticmethod
    def _build_context(state):
        """Flatten research data and non-empty definitions into one string."""
        research = state.get("research_data") or []
        # A bare string is kept whole; any other iterable is stringified per item.
        if isinstance(research, str):
            parts = [research]
        else:
            parts = [str(item) for item in research]

        definitions = state.get("definitions")
        if definitions:
            # Appended as ONE item to a fresh list. Using ``+=`` between the
            # research list and a string would extend the list stored in the
            # state, in place, with the string's individual characters.
            parts.append(str(definitions))
        return " ".join(parts)

    def __call__(self, state):
        context = self._build_context(state)

        if not context.strip():
            state["summary"] = "No relevant research data available."
            return state

        # NOTE(review): the "summarize: " prefix is a T5-style convention; the
        # BART checkpoint loaded above presumably does not need it - confirm
        # before changing, since it alters the model input.
        input_text = "summarize: " + context
        # Inputs are truncated to 1024 tokens, so long contexts are cut off.
        inputs = tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True).to(device)

        if inputs.input_ids.shape[1] == 0:
            state["summary"] = "Failed to generate a summary due to empty input."
            return state

        summary_ids = model.generate(**inputs, max_length=512, min_length=150, length_penalty=2.0)
        state["summary"] = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return state

In [43]:
# Answer Agent
class AnswerAgent:
    """Graph node that produces the final ``state["answer_text"]``.

    The research results (plus any non-empty term definitions) are flattened
    into one string, run through the module-level ``tokenizer`` / ``model``
    pair on ``device``, and the decoded output becomes the answer. When
    ``state["source_urls"]`` is non-empty, a "Sources:" list is appended.
    When there is no text to work from, a fixed placeholder message is stored
    instead and the model is not invoked.
    """

    @staticmethod
    def _build_context(state):
        """Flatten research data and non-empty definitions into one string."""
        research = state.get("research_data") or []
        # A bare string is kept whole; any other iterable is stringified per item.
        if isinstance(research, str):
            parts = [research]
        else:
            parts = [str(item) for item in research]

        definitions = state.get("definitions")
        if definitions:
            # Appended as ONE item to a fresh list. Using ``+=`` between the
            # research list and a string would extend the list stored in the
            # state, in place, with the string's individual characters.
            parts.append(str(definitions))
        return " ".join(parts)

    def __call__(self, state):
        context = self._build_context(state)

        if not context.strip():
            state["answer_text"] = "No content available for answering."
            return state

        # The answer is generated with a summarization prompt over the context;
        # the user's question itself is not part of the model input.
        input_text = "summarize: " + context
        # Inputs are truncated to 1024 tokens, so long contexts are cut off.
        inputs = tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True).to(device)

        if inputs.input_ids.shape[1] == 0:
            state["answer_text"] = "Failed to generate an answer due to empty input."
            return state

        summary_ids = model.generate(**inputs, max_length=512, min_length=50)
        state["answer_text"] = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        sources = state.get("source_urls", [])
        if sources:
            state["answer_text"] += "\n\nSources:\n" + "\n".join(sources)

        return state

In [44]:
def extract_unknown_terms(research_data, query):
    """Return the terms that need a definition lookup.

    Placeholder implementation: neither argument is inspected and the result
    is always an empty list, so no term expansion is ever requested.
    """
    unknown_terms = []
    return unknown_terms

In [45]:
# Define the state structure
from typing import Any, Dict, List, TypedDict

class State(TypedDict):
    """Shared state passed between the graph nodes."""

    # The user's research question.
    question: str
    # Raw search results; ResearchAgent reads a "url" key from the dict items,
    # so entries are not plain strings.
    research_data: List[Any]
    # Search results keyed by looked-up term (written by TermExpansionAgent).
    definitions: Dict[str, Any]
    # Text written by SummarizationAgent.
    summary: str
    # Final text written by AnswerAgent, including any appended source list.
    answer_text: str
    # URLs collected from the search results.
    source_urls: List[str]

# Construct the graph: one node per agent, all sharing the State schema.
graph = StateGraph(State)
graph.add_node("research", ResearchAgent())
graph.add_node("term_expansion", TermExpansionAgent())
graph.add_node("summarization", SummarizationAgent())
graph.add_node("answer", AnswerAgent())

# Define transitions.
# After research, detour through term expansion only when unknown terms were
# found; otherwise go straight to summarization.
graph.add_conditional_edges(
    "research",
    lambda state: "term_expansion" if extract_unknown_terms(state["research_data"], state["question"]) else "summarization",
    # Declare the possible targets explicitly so the graph can be validated
    # (and drawn) without guessing where the router may lead.
    {"term_expansion": "term_expansion", "summarization": "summarization"},
)
graph.add_edge("term_expansion", "summarization")
graph.add_edge("summarization", "answer")
# Mark "answer" as the terminal node instead of relying on it merely having
# no outgoing edge.
graph.set_finish_point("answer")

graph.set_entry_point("research")
compiled_graph = graph.compile()

In [46]:
# Function to run the system
def runsystem(query):
    """Run the compiled research graph for ``query`` and print the outcome.

    Args:
        query: The research question; stored under ``"question"`` in the
            initial state handed to ``compiled_graph.invoke``.

    Prints either a notice that no research data came back, or the generated
    answer text (which already includes the source list when AnswerAgent
    appended one). Returns ``None``.
    """
    initial_state = {
        "question": query,
        "research_data": [],
        "definitions": {},
        "summary": "",
        "answer_text": "",
        "source_urls": [],
    }
    final_state = compiled_graph.invoke(initial_state)

    if not final_state.get("research_data"):
        print("No research data available for summarization.")
        return

    print("\nGenerated Answer:\n")
    print(final_state.get("answer_text", "No answer generated."))


In [47]:
if __name__ == "__main__":
    # Ask for a question interactively and feed it straight into the pipeline.
    runsystem(input("Enter your research question: "))

Enter your research question:  describe about lions

Generated Answer:

The lion (Panthera leo) is a large cat of the genus Panthera, native to Africa and India. Nearly all wild lions live in Africa, but one small population exists elsewhere. Male lions can weigh 30 stone. Up to 80 percent of lion cubs die within their first 2 years of life.

Sources:
https://en.wikipedia.org/wiki/Lion
https://www.wwf.org.uk/learn/fascinating-facts/lions
https://www.britannica.com/animal/lion
https://nationalzoo.si.edu/animals/lion
