In [None]:
import os
from typing import Optional

from dotenv import load_dotenv

from src.papers.io import db

# Load settings from a local .env file into the process environment.
load_dotenv()
# TODO: rely on environment variables and stop using relative paths
# Each value is None when the corresponding variable is not set.
EXTENDED_CRAWLER_DATA_PATH = os.getenv("EXTENDED_CRAWLER_DATA_PATH")
MILVUS_DB = os.getenv("MILVUS_DB")
MILVUS_COLLECTION = os.getenv("MILVUS_COLLECTION")
MILVUS_ALIAS = os.getenv("MILVUS_ALIAS")
MILVUS_HOST = os.getenv("MILVUS_HOST")
# NOTE(review): os.getenv returns a string; confirm db.Milvus accepts a string port.
MILVUS_PORT = os.getenv("MILVUS_PORT")

# Shared Milvus wrapper used by every search cell below.
# `new_collection=False` presumably reuses the existing collection — confirm in src.papers.io.db.
milvus_client = db.Milvus(
    # db=MILVUS_DB, 
    collection=MILVUS_COLLECTION, 
    alias=MILVUS_ALIAS,
    host=MILVUS_HOST,
    port=MILVUS_PORT, 
    new_collection=False
)   



In [None]:
query="Spot instances"

In [None]:
# Search the Milvus collection for `query`, asking for up to 100 hits.
# The exact meaning of `hybrid` / `hybrid_fields` is defined by the project
# wrapper in src.papers.io.db — presumably the four vector fields are combined.
nn_papers = milvus_client.search(
    text=query,
    # Scalar fields returned with every hit (read later via hit["entity"][...]).
    output_fields=[
        "Title",
        "TLDR",
        "Abstract",
        "KeyConcepts",
        "Year",
        "Conference",
        "Summary",
        "AuthorsAndInstitutions"
    ],
    limit=100,
    hybrid=True,
    hybrid_fields=[
        "AbstractVector", 
        "TitleVector", 
        "TLDRVector",
        "KeyConceptsVector"
    ],
    # Scalar filter expression; the year is written as a quoted string here.
    expr="Year in ['2023']"
)

In [None]:
papers = [paper["entity"] for paper in nn_papers]
for paper in papers:
    print(paper["Year"])

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama


In [None]:
# Minimal prompt -> chat model pipeline.
# `{question}` is the only template variable and is supplied at invoke time.
template = """Question: {question}

Answer: Let's think step by step."""

prompt = ChatPromptTemplate.from_template(template)

model = ChatOllama(model="llama3.2")

# Pipe composition: the formatted prompt is fed to the model.
chain = prompt | model

# Last expression of the cell, so the reply is rendered as the cell output.
chain.invoke({"question": "What is LangChain?"})

In [None]:
from langchain_core.messages import HumanMessage

response = model.invoke([HumanMessage(content="hi!")])
response

In [None]:
# Bind an empty tool list and compare the reply to a greeting with the reply
# to a question that would normally call for a tool.
tools = []
model_with_tools = model.bind_tools(tools)

response = model_with_tools.invoke([HumanMessage(content="Hi!")])

print(f"ContentString: {response.content}")
print(f"ToolCalls: {response.tool_calls}")

# Same check for a weather question; no weather tool is bound in this cell.
response = model_with_tools.invoke([HumanMessage(content="What's the weather in SF?")])

print(f"ContentString: {response.content}")
print(f"ToolCalls: {response.tool_calls}")

### Agente básico
Las funciones que puede utilizar el agente se le proporcionan por medio de tools. Cada tool recibe su input a partir de la query y devuelve un resultado al agente.

La salida se debe tipificar con pydantic; ese esquema se le proporciona al agente.

In [None]:

# Base agent
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel
from langchain_ollama import ChatOllama

model = ChatOllama(model="qwen3")

def get_person(person: str) -> str:
    """Get person information."""
    # Canned demo answer: the `person` argument is accepted but never consulted.
    person_info = "Jaimito is my cousin, lives in Madrid, he is 7 years old"
    return person_info

# Structured answer schema handed to the agent through `response_format`.
class WhoIsResponse(BaseModel):
    # Name of the person the answer is about.
    person: str
    # Place of residence.
    lives_in: str
    # Age as an integer.
    age: int
    
# Build a prebuilt ReAct agent with one tool and a typed final answer.
tools = [get_person]
agent = create_react_agent(
    model=model,
    tools=tools,
    # The final answer is exposed under response["structured_response"] (read below).
    response_format=WhoIsResponse,
    prompt="You are a helpful assistant that uses only the feedback you are provided",  
)
response = agent.invoke(
    {"messages": [{"role": "user", "content": "Who is Jaimito?"}]}
)

# Last expression of the cell: shows the structured answer.
response["structured_response"]



### Minimal workflow

In [None]:
from langgraph.prebuilt import create_react_agent
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from pydantic import BaseModel
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_ollama import ChatOllama

model = ChatOllama(model="llama3.1:8b")

def get_weather(city: str) -> str:
    """Get weather for a given city."""
    # Trace line so tool usage is visible in the cell output.
    print("weather")
    forecast = f"It's always sunny in {city}!"
    return forecast

def who_is(text: str) -> str:
    """who is person information."""
    # Trace line so tool usage is visible in the cell output.
    print("person")
    # Canned demo answer: `text` is accepted but never consulted.
    canned_answer = "Jaimito is my cousin, lives in Madrid, he is 7 years old"
    return canned_answer

# class DescriptionResponse(BaseModel):
#     person: str
#     lives_in: str
#     age: int

# agent = create_react_agent(
#     model=model,
#     tools=[get_weather],
#     prompt="You are a helpful assistant"
# )

tools = []

class State(TypedDict):
    """Graph state carrying the running conversation under the `messages` key."""
    # Messages have the type "list". The `add_messages` function
    # in the annotation defines how this state key should be updated
    # (in this case, it appends messages to the list, rather than overwriting them)
    messages: Annotated[list, add_messages]

# Tell the LLM which tools it can call
llm_with_tools = model.bind_tools(tools)

def chatbot(state: State):
    """Graph node: send the conversation so far to the tool-aware model.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update whose ``messages`` list holds the model reply.
    """
    # Debug trace of the incoming state.
    print(state)
    # Use the tool-bound model created above; calling the bare `model` would
    # ignore any tools later added to the `tools` list.
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

# Linear graph: START -> chatbot -> END.
graph_builder = StateGraph(State)
# The first argument is the unique node name.
# The second argument is the function or object that will be called whenever
# the node is used.
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)

# Tool routing is disabled in this minimal version; the lines below are kept
# commented out as a reminder of how to enable it.
# Add tool node
# tool_node = ToolNode(tools)
# graph_builder.add_node("tools", tool_node)

# Add condition to call a node or another
# graph_builder.add_conditional_edges(
#     "chatbot",
#     tools_condition,
# )

# graph_builder.add_edge("tools", "chatbot")
graph = graph_builder.compile()

In [None]:
from IPython.display import Image, display

# Rendering the graph needs optional extra dependencies, so a failure must not
# stop the notebook; report it instead of swallowing it silently.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as exc:
    print(f"Graph rendering skipped: {exc!r}")

In [None]:
def stream_graph_updates(user_input: str):
    """Stream `user_input` through the compiled graph and print each node's latest message."""
    initial_state = {"messages": [{"role": "user", "content": user_input}]}
    for update in graph.stream(initial_state):
        # Each streamed update is a mapping whose values carry a "messages" list.
        for node_output in update.values():
            latest_message = node_output["messages"][-1]
            print("Assistant:", latest_message.content)


In [None]:
stream_graph_updates("Spot instances papers")

### Prototipo de agente

In [None]:

# Base agent
from langgraph.prebuilt import create_react_agent
from langchain.chat_models import init_chat_model
from pydantic import BaseModel, Field
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage

query = "Which papers were published in 2022 in nsdi related to serverless cloud computing? Select only those from Germany. Omit also citations from 2023"
model = ChatOllama(model="qwen3", temperature=0, top_k=20, top_p=0.8)

# NOTE(review): this model is used for structured output, so the docstring below
# is presumably forwarded to the LLM as the schema description; it is left untouched.
class FilterOptions(BaseModel):
    """
    Class for filter option. 
    - value corresponds for to the value to be used in filtering
    - equal indicates if the condition indicates equality
    - citation indicates if this condition is referred only to cited papers
    """
    # Literal value to filter on (always a string, including years).
    value: str
    # True for "must match", False for "must differ".
    equal: bool
    # True when the condition applies to cited papers rather than source papers.
    citation: bool

# Filter conditions extracted from a natural-language query.
# Every field is Optional: the extraction prompt tells the model to leave
# unmentioned fields as None, and a non-Optional annotation would make pydantic
# reject an explicit null returned by the model.
class FilterParameters(BaseModel):
    # Research topic used as the search text.
    main_concept: Optional[str] = None
    # One entry per condition; see FilterOptions for the meaning of each flag.
    authors: Optional[list[FilterOptions]] = None
    institutions: Optional[list[FilterOptions]] = None
    countries: Optional[list[FilterOptions]] = None
    years: Optional[list[FilterOptions]] = None
    conferences: Optional[list[FilterOptions]] = None
    
def get_query_filter_conditions(query: str, model: ChatOllama) -> FilterParameters:
    """Extract structured filter conditions from a natural-language query.

    Runs two model calls: a free-form pass that answers with JSON-like text, and
    a second pass that re-reads that answer under `with_structured_output` so the
    result is validated against `FilterParameters`.

    Args:
        query: The user's question.
        model: Chat model used for both passes.

    Returns:
        The structured output of the second pass.
    """
    # These strings are sent verbatim (no template formatting is applied), so the
    # JSON example uses single braces. The example year matches its description.
    messages = [
        SystemMessage("""
        You are a helpful assistant that uses only the input the user provides. Your role is to help research paper data related to the topics the user asks about.
        Fields with list of FilterOptions type must be filled as follows:
        - value: The value of the field to be filled
        - equal: False if it is asked to be different from the value
        - citation: If the field value refers to paper citations
        Return JSON with (if not mentioned, leave it as None):
        - main_concept (string). Topic of research
        - years (list[FilterOptions]). Publish years
        - authors (list[FilterOptions]). List of paper authors
        - institutions (list[FilterOptions]). All must be valid institutions laike universities or corporations.
        - countries (list[FilterOptions]). All must be valid countries. Translate them to usual abbreviaions like US for United Staes of America or DE for Germany
        - conferences (list[FilterOptions]). List of valid conferences: IEEECloud, Middleware, SIGCOMM, eurosys
        Example response for 2024 papers about spot instances with 100+ citations published in middleware. Authors must be from Harvard University. Cited papers must not be from Canada:
        {
            "main_concept": "spot instances",
            "authors": None,
            "institutions": [{ "value": "Harvard University", "equal": True, "citation": False}],
            "countries": [{"value": "CA", "equal": False, "citation": True}],
            "years": [{"value": "2024", "equal": True, "citation": False}],
            "conferences": [{"value": "Middleware", "equal": True, "citation": False}],
        }
        /no_think
        """),
        HumanMessage(query),
    ]
    # First pass: unconstrained answer.
    response = model.invoke(messages)
    print(response)

    # Second pass: feed only the text of the first answer (not the whole message
    # object with its metadata) back to the model and request structured output.
    messages_for_structure = [
        messages[0],
        HumanMessage(f""" 
        Fields with list of FilterOptions type must be filled as follows:
        - value: The value of the field to be filled
        - equal: False if it is asked to be different from the value
        - citation: If the field value refers to paper citations
        Return JSON with (if not mentioned, leave it as None):
        - main_concept (string). Topic of research
        - years (list[FilterOptions]). Publish years
        - authors (list[FilterOptions]). List of paper authors
        - institutions (list[FilterOptions]). All must be valid institutions laike universities or corporations.
        - countries (list[FilterOptions]). All must be valid countries. Translate them to usual abbreviaions like US for United Staes of America or DE for Germany
        - conferences (list[FilterOptions]). List of valid conferences: IEEECloud, Middleware, SIGCOMM, eurosys
        Using these previous result:
        {response.content}
        /no_think
        """
        )
    ]
    structured_response = model.with_structured_output(FilterParameters).invoke(messages_for_structure)
    print(structured_response)

    return structured_response

filter_conditions = get_query_filter_conditions(query, model)

In [None]:
def build_expr_for_filter_field(filter_options: "list[FilterOptions] | None", field_name: str, json_field: bool = False) -> str:
    """Build the filter-expression fragment for one field.

    Only options with ``citation`` set to False are used; conditions that refer
    to cited papers are ignored here.

    Args:
        filter_options: Conditions for the field; None or empty yields "".
        field_name: Name of the field to filter on.
        json_field: When True the field is treated as a JSON array and matched
            with ``json_contains_any``; otherwise ``in`` / ``not in`` is used.

    Returns:
        The inclusion and exclusion clauses joined with " and ", or "" when no
        condition applies. Values are rendered with Python list formatting
        (single-quoted strings).
    """
    if not filter_options:
        return ""

    source_options = [option for option in filter_options if not option.citation]
    included = [option.value for option in source_options if option.equal]
    excluded = [option.value for option in source_options if not option.equal]

    clauses = []
    if included:
        if json_field:
            clauses.append(f"json_contains_any({field_name}, {included})")
        else:
            clauses.append(f"{field_name} in {included}")
    if excluded:
        # Exclusions are emitted even when inclusions are present.
        if json_field:
            clauses.append(f"not json_contains_any({field_name}, {excluded})")
        else:
            clauses.append(f"{field_name} not in {excluded}")

    return " and ".join(clauses)

# One expression fragment per filterable field.
expr_years = build_expr_for_filter_field(
    filter_options=filter_conditions.years, field_name="Year"
)
expr_authors = build_expr_for_filter_field(
    filter_options=filter_conditions.authors, field_name="Authors", json_field=True
)
expr_countries = build_expr_for_filter_field(
    filter_options=filter_conditions.countries, field_name="Countries", json_field=True
)
expr_institution = build_expr_for_filter_field(
    filter_options=filter_conditions.institutions, field_name="Institutions", json_field=True
)
expr_conferences = build_expr_for_filter_field(
    filter_options=filter_conditions.conferences, field_name="Conferences"
)

# Combine the non-empty fragments with " and ".
# `expr_conferences` is computed above but deliberately left out of the
# combined expression, as in the previous version of this cell.
expr = " and ".join(
    fragment
    for fragment in (expr_years, expr_authors, expr_countries, expr_institution)
    if len(fragment) > 0
)

print(expr)

In [None]:
# Search again, this time using the topic extracted from the user query.
nn_papers = milvus_client.search(
    text=filter_conditions.main_concept,
    output_fields=[
        "Title",
        "TLDR",
        "Abstract",
        "KeyConcepts",
        "Year",
        "Conference",
        "Summary",
        "AuthorsAndInstitutions"
    ],
    limit=100,
    hybrid=True,
    hybrid_fields=[
        "AbstractVector", 
        "TitleVector", 
        "TLDRVector",
        "KeyConceptsVector"
    ],
    # NOTE(review): the expression built in the previous cell is not applied
    # here; an empty expression is passed instead.
    # expr=expr
    expr=""
)
# Quick sanity check: number of hits and the author data of the first one.
if len(nn_papers) > 0:
    print(len(nn_papers))
    print(nn_papers[0]["entity"]["AuthorsAndInstitutions"])

In [None]:
# Structured verdict returned by the topic-relevance check.
class TopicCheck(BaseModel):
    # True when the paper is judged to be about the requested topic.
    is_valid: bool
    # reason: str
    

def check_paper_topic(topic: str, paper: dict, model: ChatOllama) -> TopicCheck:
    """Ask the model whether a paper is about `topic`.

    Args:
        topic: Topic the paper must discuss.
        paper: Mapping with "Title", "Abstract" and "TLDR" keys; falsy
            abstract/TLDR values are sent as empty strings.
        model: Chat model; it is wrapped with structured output for `TopicCheck`.

    Returns:
        The structured verdict produced by the model.
    """
    model_with_structure = model.with_structured_output(TopicCheck)
    title = paper["Title"]
    abstract = paper["Abstract"] or ""
    tldr = paper["TLDR"] or ""
    # The system text is sent verbatim (no template formatting), so the JSON
    # example uses single braces.
    messages = [
        SystemMessage("""You are a strict assistant. Your role is to help determine if papers discuss the topic the user asks about. 
        Only accept those that strictly mention the topic since a mistake will kill all your family. If in doubt is it better to determine a paper as not valid.
        Return JSON with (if not mentioned, leave it as None):
        - is_valid (bool). True if the paper data is about the topic requested
        Example response for paper with abstract 'Kappa proposes a framework for simplified serverless development using checkpointing to handle timeouts and providing concurrency mechanisms for parallel' and topic 'serverless development':
        {
            "is_valid": true,
        }  
        """),
        HumanMessage(f"""Determine if the following paper content is about the topic: <topic>{topic}</topic>\n
        Title: {title}\n
        TLDR: {tldr}\n
        Abstract: {abstract}\n
        """),
    ]
    return model_with_structure.invoke(messages)

def filter_papers(topic: str, papers: list[dict], model: ChatOllama) -> list[str]:
    """Return the entities of the search hits that the model judges to be about `topic`."""
    accepted = []
    for hit in papers:
        candidate = hit["entity"]
        verdict = check_paper_topic(topic, candidate, model)
        # Trace every verdict so rejected papers are visible too.
        print(verdict)
        if not verdict.is_valid:
            continue
        accepted.append(candidate)
    return accepted
            
valid_papers = filter_papers(topic=filter_conditions.main_concept, papers=nn_papers, model=model)
print(len(valid_papers))

In [None]:
model = ChatOllama(model="qwen3:0.6b", temperature=0)

In [None]:
# Paralallell
from langchain_core.runnables import RunnableParallel
from langchain_core.prompts import ChatPromptTemplate

# Build one prompt|model chain per search hit and run them all through a
# single RunnableParallel call.

# Loop-invariant pieces are created once instead of once per paper.
topic = filter_conditions.main_concept
model_with_structure = model.with_structured_output(TopicCheck)

map_chain_elements = {}
for index, paper in enumerate(nn_papers):
    entity = paper["entity"]
    title = entity["Title"]
    abstract = entity["Abstract"] or ""
    tldr = entity["TLDR"] or ""
    # Message objects are used as-is (no template formatting), so the JSON
    # example uses single braces and paper text containing braces is safe.
    messages = [
        SystemMessage("""You are a strict assistant. Your role is to help determine if papers discuss the topic the user asks about. 
        Only accept those that strictly mention the topic since a mistake will kill all your family. If in doubt is it better to determine a paper as not valid.
        Return JSON with (if not mentioned, leave it as None):
        - is_valid (bool). True if the paper data is about the topic requested
        Example response for paper with abstract 'Kappa proposes a framework for simplified serverless development using checkpointing to handle timeouts and providing concurrency mechanisms for parallel' and topic 'serverless development':
        {
            "is_valid": true,
        }  
        """),
        HumanMessage(f"""Determine if the following paper content is about the topic: <topic>{topic}</topic>\n
        Title: {title}\n
        TLDR: {tldr}\n
        Abstract: {abstract}\n
        """)
    ]
    # Keys follow the hit order so results can be matched back by index.
    map_chain_elements[f"chain_{index}"] = (
        ChatPromptTemplate.from_messages(messages)
        | model_with_structure
    )

# Skip the parallel call entirely when the search returned nothing.
if map_chain_elements:
    map_chain = RunnableParallel(map_chain_elements)
    response = map_chain.invoke({})
else:
    response = {}
print(response)

# Keep the entity of every accepted hit, matching what `filter_papers` returns.
valid_papers = [
    paper["entity"]
    for index, paper in enumerate(nn_papers)
    if response[f"chain_{index}"].is_valid
]

print(len(valid_papers))
    
    


In [None]:
entity = nn_papers[0]["entity"]

print(entity["Title"])
print(entity["Abstract"])

In [None]:
model = ChatOllama(model="qwen3", temperature=0)

# Structured answer describing which aggregations a query asks for.
class AggregationInQueryChecker(BaseModel):
    # True when the query requests any data aggregation.
    requests_aggregation: bool
    # Names of the requested aggregations (allowed values are listed in the prompt).
    aggregations_requested: list[str]
    

def check_aggregations_requested(query: str, model: ChatOllama) -> AggregationInQueryChecker:
    """Ask the model which data aggregations, if any, `query` requests.

    Args:
        query: The user's question.
        model: Chat model; it is wrapped with structured output for
            `AggregationInQueryChecker`.

    Returns:
        The structured answer produced by the model.
    """
    model_with_structure = model.with_structured_output(AggregationInQueryChecker)

    # The system text is sent verbatim (no template formatting), so the JSON
    # example uses single braces.
    messages = [
        SystemMessage("""You are a helpful AI assistant that uses only the feedback you are provided. Your role is to help in custom searches of scientific papers. Determine if in the provided query data aggregations are requested.
        An example would be "Provide the most cited papers of every country relating spot instances", which is requesting an aggregation by country
        Return JSON with the following properties:
        - requests_aggregation (bool). True if the query requests data aggregations
        - aggregations_requested (list[string]). List containing the expected aggregations. The possible values are the following ones: citations, papers, country, conference, institution, author
        Example response for the query 'Provide the most cited papers of every country relating spot instances in conference Middleware':
        {
            "requests_aggregation": true,
            "aggregations_requested": ["citations", "country"]
        }  
        """),
        HumanMessage(f"""Determine the aggregations requested in the following query: <query>{query}</query>:\n"""),
    ]
    return model_with_structure.invoke(messages)

In [None]:

check_aggregations_requested(query=query, model=model)

In [None]:
from src.papers.domain.citations_analyzer import CitationAnalyzer
import polars as pl

### Dynamic query
# Connection settings are read from the environment so credentials do not have
# to live in the notebook; the previous literals remain as local-dev defaults.
# NOTE(review): the four positional arguments look like URI, user, password and
# database name — confirm against db.Neo4j.
neo4j_client = db.Neo4j(
    os.getenv("NEO4J_URI", "bolt://localhost:7687"),
    os.getenv("NEO4J_USER", "neo4j"),
    os.getenv("NEO4J_PASSWORD", "password"),
    os.getenv("NEO4J_DATABASE", "middleware"),
)
citations_analyzer = CitationAnalyzer(graph_client=neo4j_client)

# Feed the entities of the latest search hits to the citation analyzer.
papers = [paper["entity"] for paper in nn_papers]
df_citations = citations_analyzer.process_papers(papers)


In [None]:
# Call the method: the bare attribute only shows the bound method.
df_citations.head()

In [None]:
filter_conditions = get_query_filter_conditions(query, model)

In [None]:

    
def dynamic_filter(filters: list[dict], df: pl.DataFrame) -> pl.DataFrame:
    """Apply every filter spec to `df` in order and return the narrowed frame.

    Each spec is a dict with "field" (column name), "values" (list of accepted
    or rejected values) and "equal" (True keeps matching rows, False drops them).
    """
    result = df
    for spec in filters:
        # Trace each applied spec.
        print(spec)
        keep_matches = spec["equal"]
        membership = pl.col(spec["field"]).is_in(spec["values"])
        result = result.filter(membership if keep_matches else ~membership)
    return result

def dynamic_aggregation(columns: list[str], df: pl.DataFrame) -> pl.DataFrame:
    """Count the rows of `df` per combination of `columns`.

    Args:
        columns: Column names to group by.
        df: Frame to aggregate.

    Returns:
        One row per group with the group keys and a "total_citations" count.
    """
    # `pl.len()` is the current spelling of the row count (argument-less
    # `pl.count()` is deprecated).
    return df.group_by(columns).agg(pl.len().alias("total_citations"))
    
    # authors: list[FilterOptions] = None
    # institutions: list[FilterOptions] = None
    # countries: list[FilterOptions] = None
    # years: list[FilterOptions] = None
    # conferences: list[FilterOptions] = None
def process_paper_data(filter_parameters: FilterParameters, df:pl.DataFrame) -> pl.DataFrame:
    """Filter the citation frame by the extracted conditions and aggregate it.

    Args:
        filter_parameters: Conditions extracted from the user query; fields that
            are None or empty are skipped.
        df: Citation frame with the source_*/cited_* columns named below.

    Returns:
        The filtered frame aggregated by "cited_country".
    """
    # (conditions, column for the source paper, column for the cited paper)
    field_specs = [
        (filter_parameters.authors, "source_author", "cited_author"),
        (filter_parameters.institutions, "source_institution", "cited_institution"),
        (filter_parameters.countries, "source_country", "cited_country"),
        (filter_parameters.years, "source_year", "cited_year"),
        (filter_parameters.conferences, "source_conference", "cited_conference"),
    ]

    df_filters = []
    for filter_options, source_field, cited_field in field_specs:
        if not filter_options:
            continue
        df_filters.extend(
            map_filter_options(
                filter_options=filter_options,
                source_field=source_field,
                cited_field=cited_field,
            )
        )

    df_filtered = dynamic_filter(filters=df_filters, df=df)
    df_aggregated = dynamic_aggregation(columns=["cited_country"], df=df_filtered)
    return df_aggregated

def map_filter_options(filter_options: "list[FilterOptions]", source_field: str, cited_field: str) -> list[dict]:
    """Group filter options into per-column filter specs.

    Options are bucketed by whether they target the cited paper (``citation``)
    and whether they require equality (``equal``). Each non-empty bucket becomes
    one dict with "field", "values" and "equal" keys.

    Args:
        filter_options: Conditions for one logical field.
        source_field: Column name used for conditions on the source paper.
        cited_field: Column name used for conditions on the cited paper.

    Returns:
        Specs in the order: source equal, source not-equal, cited equal,
        cited not-equal. Empty buckets are omitted.
    """
    # Key: (targets cited paper, requires equality)
    buckets = {
        (False, True): [],
        (False, False): [],
        (True, True): [],
        (True, False): [],
    }
    for filter_option in filter_options:
        key = (bool(filter_option.citation), bool(filter_option.equal))
        buckets[key].append(filter_option.value)

    emission_order = [
        (source_field, False, True),
        (source_field, False, False),
        (cited_field, True, True),
        (cited_field, True, False),
    ]
    return [
        {"field": field, "values": buckets[(is_cited, is_equal)], "equal": is_equal}
        for field, is_cited, is_equal in emission_order
        if buckets[(is_cited, is_equal)]
    ]
    
         

In [None]:
process_paper_data(filter_parameters=filter_conditions, df=df_citations).head()

In [None]:
# NOTE(review): `df_filtered` is not assigned at notebook scope anywhere in the
# visible cells (it only exists as a local inside process_paper_data), so this
# cell depends on leftover kernel state — confirm and define it explicitly.
df_aggregated = dynamic_aggregation(
    columns=["source_year", "source_title", "source_conference"],
    df=df_filtered
)
df_aggregated.head()

In [None]:
nn_papers

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import create_react_agent
from langchain.agents import create_structured_chat_agent


# NOTE(review): this cell is an unfinished draft and does not parse:
#   - `prompt` is passed positionally after keyword arguments;
#   - the function definition at the end has no body.
# It is left untouched because completing the stub would redefine
# `check_aggregations_requested` and shadow the working version defined earlier.
model = ChatOllama(model="qwen3", temperature=0)

agent = create_structured_chat_agent(
    llm = model,
    tools=[],
    prompt
)

def check_aggregations_requested(query: str, model: ChatOllama) -> dict:
     

In [None]:
from langchain.agents import create_structured_chat_agent
from langchain.prompts import PromptTemplate
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field
import polars as pl
from langchain.tools import StructuredTool
from langchain import hub
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Define your tools with structured inputs
# Argument schema for the country-citation tool: it declares no fields.
class CitationAnalysisTool(BaseModel):
    pass  # No input needed for this example

# Argument schema for the publication-trend tool.
class PublicationTrendTool(BaseModel):
    # Optional so that an omitted or null range validates; a plain `list[int]`
    # annotation would reject an explicit None.
    year_range: Optional[list[int]] = Field(
        None, 
        description="Optional year range filter [start_year, end_year]"
    )

def get_most_cited_countries(df: pl.DataFrame) -> dict:
    """Get top cited countries from the dataframe"""
    # Rows whose cited country is the "UNKNOWN" placeholder are ignored.
    has_known_country = pl.col("cited_country") != "UNKNOWN"
    rows_per_country = pl.len().alias("citation_count")
    ranking = (
        df.filter(has_known_country)
        .group_by("cited_country")
        .agg(rows_per_country)
        .sort("citation_count", descending=True)
    )
    # One dict per country, most cited first.
    return ranking.to_dicts()

def get_publications_per_year(df: pl.DataFrame, year_range: list = None) -> dict:
    """Get publication count by year with optional filtering"""
    # Keep only rows whose source year is non-null after the integer cast.
    rows = df.filter(pl.col("source_year").cast(int).is_not_null())

    if year_range:
        # The first two items of `year_range` are the bounds passed to `is_between`.
        lower, upper = year_range[0], year_range[1]
        rows = rows.filter(pl.col("source_year").cast(int).is_between(lower, upper))

    per_year = (
        rows.group_by("source_year")
        .agg(pl.len().alias("publication_count"))
        .sort("source_year")
    )
    # One dict per year, in ascending year order.
    return per_year.to_dicts()

# Create structured tools
# The wrapped function is invoked with keyword arguments named after the
# `args_schema` fields, so each wrapper's signature must mirror its schema:
# no parameters for CitationAnalysisTool, `year_range` for PublicationTrendTool.
tools = [
    StructuredTool.from_function(
        func=lambda: get_most_cited_countries(df_citations),
        name="CountryCitationAnalysis",
        description="Analyze citation patterns by country",
        args_schema=CitationAnalysisTool
    ),
    StructuredTool.from_function(
        func=lambda year_range=None: get_publications_per_year(df_citations, year_range),
        name="PublicationTrendAnalysis",
        description="Analyze publication trends over years",
        args_schema=PublicationTrendTool
    )
]

# Initialize the agent
llm = ChatOllama(model="qwen3", temperature=0, top_k=20, top_p=0.8)
# System prompt for the structured chat agent.
# `{tools}` and `{tool_names}` are template variables filled in by the agent;
# doubled braces are escapes that render as literal JSON braces.
# The single-action example must show the `action` / `action_input` keys the
# model is expected to emit.
system = '''Respond to the human as helpfully and accurately as possible. You have access to the following tools:

{tools}

Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).

Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}

Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation /no_think'''

# Human turn template: `{input}` is the user request and `{agent_scratchpad}`
# is a second template variable placed after it.
human = '''{input}

{agent_scratchpad}

(reminder to respond in a JSON blob no matter what) /no_think'''

# Two-message prompt: system instructions followed by the human turn.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        # MessagesPlaceholder("chat_history", optional=True),
        ("human", human),
    ]
)

# Agent built from the model, the structured tools and the prompt above.
agent = create_structured_chat_agent(
    llm=llm,
    tools=tools,
    prompt=prompt
)

# Example usage
def process_query(query: str):
    """Invoke the agent once on `query` with an empty step history and return its output."""
    agent_input = {
        "input": f"Analyze this research query: {query}",
        # No previous steps are supplied on this call.
        "intermediate_steps": [],
    }
    return agent.invoke(agent_input)

# Test queries
print(process_query("Which countries are most cited in cloud research?"))
print(process_query("Show me publication trends between 2010 and 2020"))

In [None]:
def get_most_cited_countries(text) -> str:
    """Get top cited countries from the dataframe"""
    # `text` is accepted but not used; the data comes from the notebook-level
    # `df_citations`.
    print("tool 1")
    df = (
        df_citations.filter(pl.col("cited_country") != "UNKNOWN")
        .group_by(["cited_country"])
        # `pl.len()` is the current spelling of the row count (argument-less
        # `pl.count()` is deprecated).
        .agg(pl.len().alias("citation_count"))
        .sort("citation_count", descending=True)
        .limit(10)
    )
    # Render the (at most 10-row) table as markdown without dtype/shape noise.
    with pl.Config(
        tbl_formatting="MARKDOWN",
        tbl_hide_column_data_types=True,
        tbl_hide_dataframe_shape=True,
    ):
        return "Here are the most cited countries in markdown format:  \n" + str(df)
def get_publications_per_year(text) -> str:
    """Get publication count by year with optional filtering"""
    # Placeholder tool: `text` is not used and a fixed message is returned.
    print("tool 2")
    return "Don't have enough data"

In [None]:
from langchain_core.messages import HumanMessage
# Bind the two plain-function tools and ask a question that should trigger one.
model = ChatOllama(model="llama3.2", temperature=0.6, top_k=20, top_p=0.8)
tools = [get_most_cited_countries, get_publications_per_year]
model_with_tools = model.bind_tools(tools)
# NOTE(review): `data` is not used in this cell; the <context> section below is
# filled with an empty string.
data = df_citations.to_dicts()

response = model_with_tools.invoke([HumanMessage(f"""
Use the data enclosed in <context> to answer the question asked enclosed in <query>.
If further information is needed use the following tools {tools}
<context>
{""}
</context>
<query>
Tell me the country most publications per year
</query>
/no_think
""")])

In [None]:
response.dict()

In [None]:
df_citations.head()

In [None]:
get_most_cited_countries("")

In [None]:
# Count rows per (cited title, country), then sum those counts per country
# and show the ten largest.
(
    df_citations.filter(pl.col("cited_country") != "UNKNOWN")
    .group_by(["cited_title", "cited_country"])
    # `pl.len()` replaces the deprecated argument-less `pl.count()`.
    .agg(pl.len().alias("citation_count"))
    .group_by(["cited_country"])
    .agg(pl.sum("citation_count").alias("citation_count"))
    .sort("citation_count", descending=True)
    .limit(10)
)

In [None]:
# NOTE(review): rows are de-duplicated on the same three columns they are then
# grouped by, so every group holds exactly one row and `citation_count` is
# always 1 — confirm this is intended.
df = (
    df_citations.filter(pl.col("cited_country") != "UNKNOWN")
    .unique(["source_title", "cited_title", "cited_country"])
    .group_by(["source_title", "cited_title", "cited_country"])
    # `pl.len()` replaces the deprecated argument-less `pl.count()`.
    .agg(pl.len().alias("citation_count"))
    .sort("source_title", descending=True)
)

# Render the first ten rows as a markdown table without dtype/shape noise.
with pl.Config(
    tbl_formatting="MARKDOWN",
    tbl_hide_column_data_types=True,
    tbl_hide_dataframe_shape=True,
):
    st = str(df.head(10))
    

In [None]:
from langgraph.prebuilt import create_react_agent
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from pydantic import BaseModel
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_ollama import ChatOllama
import polars as pl

# Chat model that drives the agent graph built further down in this cell.
model = ChatOllama(model="qwen3", temperature=0.2)

# Graph state: a single "messages" list. The `add_messages` annotation is the
# reducer that tells LangGraph how updates to this key are merged into the state.
class State(TypedDict):
    messages: Annotated[list, add_messages]

# Stub tool container used to exercise the agent loop. Both tool methods print a
# trace line and return a hard-coded sentence; neither reads `self.df` yet.
# The one-line method docstrings are left untouched — presumably they are what
# the model receives as tool descriptions (confirm before rewording).
class AgentTools:
    def __init__(self, df: pl.DataFrame):
        # Frame the tools are meant to query (currently unused by the stubs).
        self.df = df
    def get_most_cited_countries(self, text: str):
        """Get top cited countries from the dataframe"""
        print("Uso de tool 1")
        return "Most cited country is USA"
    def get_publications_per_year(self, text: str):
        """Get publication count by year with optional filtering"""
        print("Uso de tool 2")
        return "Most publications done in 2022"
    
    
agent_tools = AgentTools(df=df_citations)

tools = [agent_tools.get_most_cited_countries, agent_tools.get_publications_per_year]

# Tell the LLM which tools it can call
llm_with_tools = model.bind_tools(tools)


def chatbot(state: State):
    """Run one LLM turn over the accumulated messages and return the reply as a state update."""
    print(state)  # trace of the state handed to the model on every turn
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


graph_builder = StateGraph(State)
graph_builder.add_node("chatbot", chatbot)

tool_node = ToolNode(tools)
graph_builder.add_node("tools", tool_node)

graph_builder.add_edge(START, "chatbot")
# `tools_condition` sends the flow to "tools" when the last AI message contains
# tool calls and to END otherwise, so no separate static chatbot -> END edge is needed.
graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
)
# After a tool has run, hand its result back to the model.
graph_builder.add_edge("tools", "chatbot")

graph = graph_builder.compile()



In [None]:
# Advertise the tools by name. Interpolating the list itself would embed
# "<bound method ... at 0x...>" reprs (including memory addresses) in the prompt.
tool_names = ", ".join(tool.__name__ for tool in tools)

user_input = f"""
Use the right tool to answer the question asked enclosed in <query>.
If further information is needed use the following tools {tool_names}
<query>
Tell me if the country most cited in papers about stateless cloud computing and
</query>
/no_think
"""
# Run the whole agent loop once; `response` holds the final graph state.
response = graph.invoke({"messages": [{"role": "user", "content": user_input}]})

In [None]:
# Text of the last message in the final graph state.
print(response["messages"][-1].content)

In [None]:
def stream_graph_updates(user_input: str):
    """Stream one user turn through `graph`, printing the newest message of each update.

    Args:
        user_input: Text sent to the graph as a single user message.
    """
    initial_state = {"messages": [{"role": "user", "content": user_input}]}
    for event in graph.stream(initial_state):
        for node_update in event.values():
            newest = node_update["messages"][-1]
            print("Assistant:", newest.content)
            
# Advertise the tools by name. Interpolating the list itself would embed
# "<bound method ... at 0x...>" reprs (including memory addresses) in the prompt.
tool_names = ", ".join(tool.__name__ for tool in tools)

# The <context> section is intentionally interpolated with an empty string for now.
stream_graph_updates(f"""
Use the data enclosed in <context> to answer the question asked enclosed in <query>.
If further information is needed use the following tools {tool_names}
<context>
{""} 
</context>
<query>
Tell me the country most cited in papers about stateless cloud computing and te publications per year
</query>
/no_think
""")

In [None]:
from IPython.display import Image, display

# Render the compiled graph inline as a Mermaid PNG. Any failure is ignored on
# purpose so the notebook keeps running when the rendering extras are missing.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
    # This requires some extra dependencies and is optional
    pass

In [None]:

from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
import json

# Chat model used by extract_filters() below to turn a free-text query into JSON filters.
llm = ChatOllama(model="qwen3", temperature=0.2, top_k=20, top_p=0.8)

def extract_filters(query):
    """Ask the LLM to translate a natural-language query into structured filters.

    Args:
        query: Free-text user question.

    Returns:
        The JSON object produced by the model; the prompt asks for the keys
        "topic", "source_filters", "cited_filters" and "aggregation". When the
        reply does not contain a parseable JSON object, a fallback with empty
        filters and no aggregation is returned instead.
    """
    prompt = ChatPromptTemplate.from_template(
        """Extract structured parameters from this query. Return ONLY JSON.
        Available filters:
          - year (list of years)
          - country (use 2-letter abbreviations: US, DE, CA, etc.)
          - institution (list of universities, companies or corporations)
          - conference (list of conferences where the paper was published)
          - author (list of authors who published papers)
          - Aggregation: 'most cited' implies country aggregation

        Operators: $eq, $ne, $in, $nin, $gt, $gte, $lt, $lte, $between

        Output structure:
        {{
          "topic": "research topic of query",
          "source_filters": {{
            "field1": {{"$operator": value}},
            "field2": {{"$operator": value}}
          }},
          "cited_filters": {{
            "field3": {{"$operator": value}}
          }},
          "aggregation": {{
            "type": "count", 
            "group_by": "field",
            "sort": "desc"
          }}
        }}
        
        Special handling:
        - if $between operator first and last value are the same, use $eq instead

        Important: Convert all country names to 2-letter abbreviations! If filters apply to citaitons add them to cited_filters else to source_filters
        Example conversion:
          "USA" → "US"
          "Germany" → "DE"
          "Canada" → "CA"

        Query: {query}
        Output: 
        /no_think"""
    )
    chain = prompt | llm
    response = chain.invoke({"query": query})

    fallback = {"source_filters": {}, "cited_filters": {}, "aggregation": None}

    # Drop a leading <think>...</think> block if the model emitted one. rpartition
    # returns the whole text when the tag is absent, so a reply without a think
    # block is still parsed instead of being reduced to an empty string.
    payload = response.content.rpartition("</think>")[2].strip()

    # Keep only the outermost {...} span so markdown fences or stray text around
    # the JSON object do not break parsing.
    start, end = payload.find("{"), payload.rfind("}")
    if start != -1 and end > start:
        payload = payload[start:end + 1]

    try:
        filters = json.loads(payload)
    except json.JSONDecodeError:
        return fallback

    # Callers index the result by key, so anything other than an object is unusable.
    return filters if isinstance(filters, dict) else fallback
    


In [None]:
# Smoke-test the extraction on a query that carries a topic but no explicit filters.
extract_filters("MAin papers about stateful serverless computing")

In [None]:
# Filter fields accepted by the normalisation and query-building helpers below.
# Each entry gives the value type ("int", "str" or "list") and the name of the
# field on the Milvus side and on the Neo4j side.
# NOTE(review): the extraction prompt advertises "institution", "conference" and
# "author", none of which is a key here, so such filters are dropped — confirm.
SUPPORTED_FIELDS = {
    "year": {"type": "int", "milvus_field": "year", "neo4j_field": "year"},
    "country": {"type": "str", "milvus_field": "country", "neo4j_field": "country"},
    "authors": {"type": "list", "milvus_field": "authors", "neo4j_field": "authors"},
    "venue": {"type": "str", "milvus_field": "venue", "neo4j_field": "venue"},
    "keywords": {"type": "list", "milvus_field": "keywords", "neo4j_field": "keywords"},
    "citation_count": {"type": "int", "milvus_field": "citation_count", "neo4j_field": "citationCount"},
    "document_type": {"type": "str", "milvus_field": "doc_type", "neo4j_field": "documentType"}
}

# Upper-case country names/aliases mapped to 2-letter codes.
# NOTE(review): only referenced from commented-out code in this part of the notebook.
COUNTRY_MAPPING = {
    "USA": "US", "UNITED STATES": "US", "AMERICA": "US",
    "DEUTSCHLAND": "DE", "GERMANY": "DE",
    "CANADA": "CA", "FRANCE": "FR", "UK": "GB", 
    "UNITED KINGDOM": "GB", "ENGLAND": "GB", "CHINA": "CN",
    "JAPAN": "JP", "AUSTRALIA": "AU", "INDIA": "IN", "BRAZIL": "BR"
}

# Comparison operator keys mapped to the symbols used in Milvus expressions.
OPERATOR_KEYS_TO_SYMBOLS = {
    "$gt": ">", 
    "$gte": ">=", 
    "$lt": "<", 
    "$lte": "<=",
}

def normalize_field_value(field, value):
    """Coerce a filter value to the shape declared for `field` in SUPPORTED_FIELDS.

    For fields typed "list", a non-list value is split on commas and each piece
    is stripped of surrounding whitespace. Every other value is returned as is.
    Country normalisation is not applied here.

    Raises:
        KeyError: if `field` is not a key of SUPPORTED_FIELDS.
    """
    expects_list = SUPPORTED_FIELDS[field]["type"] == "list"
    if not expects_list or isinstance(value, list):
        return value
    return [piece.strip() for piece in value.split(",")]

def validate_and_normalize_filters(filters):
    """Validate and normalize the filter sections extracted by the LLM.

    Only the "source_filters" and "cited_filters" sections are kept, and inside
    them only fields listed in SUPPORTED_FIELDS. "$between" with a two-item list
    is expanded into "$gte"/"$lte"; other values go through normalize_field_value.

    The input comes from model-generated JSON, so malformed shapes are tolerated
    instead of raising: a missing or non-dict section is skipped, a null
    condition is skipped, and a bare value is read as shorthand
    (scalar -> "$eq", list -> "$in").

    Args:
        filters: Parsed filter object; anything that is not a dict is treated as empty.

    Returns:
        {"source_filters": {...}, "cited_filters": {...}} mapping each kept field
        to its {operator: value} conditions.
    """
    normalized = {"source_filters": {}, "cited_filters": {}}
    if not isinstance(filters, dict):
        return normalized

    for filter_type in normalized:
        section = filters.get(filter_type)
        if not isinstance(section, dict):
            continue

        for field, conditions in section.items():
            if field not in SUPPORTED_FIELDS or conditions is None:
                continue

            # Shorthand such as "year": 2023 or "country": ["US", "DE"].
            if not isinstance(conditions, dict):
                shorthand_op = "$in" if isinstance(conditions, list) else "$eq"
                conditions = {shorthand_op: conditions}

            normalized_conds = {}
            for op, value in conditions.items():
                if op == "$between" and isinstance(value, list) and len(value) == 2:
                    # Expand the range into the two comparisons the builders understand.
                    normalized_conds["$gte"] = value[0]
                    normalized_conds["$lte"] = value[1]
                elif isinstance(value, list) and op not in ("$in", "$nin"):
                    # List given to a non-membership operator: normalize item by item.
                    normalized_conds[op] = [normalize_field_value(field, v) for v in value]
                else:
                    # Scalars, and the whole list for $in / $nin.
                    normalized_conds[op] = normalize_field_value(field, value)

            if normalized_conds:
                normalized[filter_type][field] = normalized_conds

    return normalized

In [None]:
def build_milvus_expr(filters, supported_fields=None):
    """Build a Milvus boolean filter expression from normalized filters.

    Args:
        filters: Mapping of field name -> {operator: value}. Fields missing from
            the schema and unknown operators are ignored.
        supported_fields: Optional schema with the same shape as SUPPORTED_FIELDS
            (each entry providing "type" and "milvus_field"). Defaults to the
            module-level SUPPORTED_FIELDS.

    Returns:
        The clauses joined with " and ", or None when no clause was produced.
    """
    schema = SUPPORTED_FIELDS if supported_fields is None else supported_fields

    def literal(value, field_type):
        """Render one value; values of "str" fields are quoted and escaped."""
        if field_type != "str":
            return str(value)
        # Values originate from user/LLM text: escape backslashes and quotes so
        # they cannot terminate the string literal early.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def literal_list(value, field_type):
        """Render the items of a membership test; a lone scalar counts as one item."""
        items = value if isinstance(value, (list, tuple)) else [value]
        return ",".join(literal(item, field_type) for item in items)

    conditions = []

    for field, spec in filters.items():
        if field not in schema:
            continue

        milvus_field = schema[field]["milvus_field"]
        field_type = schema[field]["type"]
        # NOTE(review): fields typed "list" are rendered unquoted, as before;
        # array fields probably need dedicated array operators — confirm.

        for op, value in spec.items():
            if op == "$eq":
                conditions.append(f"{milvus_field} == {literal(value, field_type)}")
            elif op == "$ne":
                conditions.append(f"{milvus_field} != {literal(value, field_type)}")
            elif op == "$in":
                conditions.append(f"{milvus_field} in [{literal_list(value, field_type)}]")
            elif op == "$nin":
                # Use the dedicated `not in` operator rather than prefixing `not`,
                # which would bind to the field instead of the membership test.
                conditions.append(f"{milvus_field} not in [{literal_list(value, field_type)}]")
            elif op in OPERATOR_KEYS_TO_SYMBOLS:
                operator_symbol = OPERATOR_KEYS_TO_SYMBOLS[op]
                conditions.append(f"{milvus_field} {operator_symbol} {literal(value, field_type)}")

    return " and ".join(conditions) if conditions else None

def build_neo4j_conditions(filters, alias="cited"):
    """Build a parameterised Cypher WHERE clause from normalized filters.

    Args:
        filters: Mapping of field name -> {operator: value}. Fields missing from
            SUPPORTED_FIELDS and unknown operators produce no condition.
        alias: Node variable the properties are read from.

    Returns:
        A (where_clause, params) tuple. `where_clause` is "" when nothing was
        produced, otherwise "WHERE ..." with the conditions joined by " AND ";
        `params` maps the generated parameter names to the filter values.
    """
    # Operators that translate to "<property> <symbol> $param".
    comparison_symbols = {
        "$eq": "=",
        "$ne": "<>",
        "$gt": ">",
        "$gte": ">=",
        "$lt": "<",
        "$lte": "<=",
    }
    clauses = []
    params = {}
    next_index = 0

    for field, spec in filters.items():
        print(f"field  {field}, spec {spec}")
        if field not in SUPPORTED_FIELDS:
            continue

        prop = f"{alias}.{SUPPORTED_FIELDS[field]['neo4j_field']}"

        for op, value in spec.items():
            # A parameter index is consumed for every operator, recognised or not.
            placeholder = f"param{next_index}"
            next_index += 1

            print(f"op {op}, value {value}")

            if op in comparison_symbols:
                clauses.append(f"{prop} {comparison_symbols[op]} ${placeholder}")
            elif op == "$in":
                clauses.append(f"{prop} IN ${placeholder}")
            elif op == "$nin":
                clauses.append(f"NOT {prop} IN ${placeholder}")
            else:
                continue
            params[placeholder] = value

    if not clauses:
        return "", params
    return "WHERE " + " AND ".join(clauses), params

In [None]:
# End-to-end check of the filter pipeline, printing every intermediate result:
# query -> LLM-extracted filters -> normalized filters -> Milvus expr / Neo4j WHERE.
query = "stateful serverless computing most cited countries"
filters = extract_filters(query)
print(filters)

valid_filters = validate_and_normalize_filters(filters)
print(valid_filters)

# Source-paper filters feed the Milvus expression.
expr = build_milvus_expr(valid_filters.get("source_filters", {}))
print(expr)

# Cited-paper filters feed the Neo4j conditions; the result is a (clause, params) tuple.
neo4j_where = build_neo4j_conditions(valid_filters.get("cited_filters", {}))
print(neo4j_where)

In [None]:
def build_milvus_expr(filters):
    """Build a Milvus expression from `year` ($in) and `country` ($ne / $in) filters only.

    Any other field or operator is ignored. Returns the clauses joined with
    " and ", or None when there are none.

    NOTE(review): this redefinition replaces the more general build_milvus_expr
    defined earlier in the notebook, and it relies on `normalize_country`, whose
    definition is not visible in this section — confirm both are intended.
    """
    clauses = []
    for name, spec in filters.items():
        if name == "year":
            if "$in" in spec:
                clauses.append("year in {}".format(spec["$in"]))
            continue
        if name != "country":
            continue
        if "$ne" in spec:
            clauses.append("country != '{}'".format(normalize_country(spec["$ne"])))
        if "$in" in spec:
            codes = [normalize_country(country) for country in spec["$in"]]
            clauses.append("country in {}".format(codes))
    if not clauses:
        return None
    return " and ".join(clauses)

# Agent with multiple conditions

In [None]:
# Question that drives both the vector search below and the agent prompt later on.
query = "Provide me a list of the main papers about serverless cloud computing"


from src.papers.io import db
from dotenv import load_dotenv
import os
import polars as pl

load_dotenv()
# TODO: take everything from environment variables and avoid relative paths
EXTENDED_CRAWLER_DATA_PATH = os.getenv("EXTENDED_CRAWLER_DATA_PATH")
MILVUS_DB = os.getenv("MILVUS_DB")
MILVUS_COLLECTION = os.getenv("MILVUS_COLLECTION")
MILVUS_ALIAS = os.getenv("MILVUS_ALIAS")
MILVUS_HOST = os.getenv("MILVUS_HOST")
MILVUS_PORT = os.getenv("MILVUS_PORT")

# Connect to the existing collection (new_collection=False); the database name
# argument is currently disabled.
milvus_client = db.Milvus(
    # db=MILVUS_DB, 
    collection=MILVUS_COLLECTION, 
    alias=MILVUS_ALIAS,
    host=MILVUS_HOST,
    port=MILVUS_PORT, 
    new_collection=False
)   

# Hybrid search over the four vector fields, returning up to 100 hits together
# with the listed output fields. The filter expression is left empty here.
nn_papers = milvus_client.search(
    text=query,
    output_fields=[
        "Title",
        "TLDR",
        "Abstract",
        "KeyConcepts",
        "Year",
        "Conference",
        "Summary",
        "AuthorsAndInstitutions"
    ],
    limit=100,
    hybrid=True,
    hybrid_fields=[
        "AbstractVector", 
        "TitleVector", 
        "TLDRVector",
        "KeyConceptsVector"
    ],
    # expr=expr
    expr=""
)

# Search hits as a polars frame; later cells select its "id" and "distance" columns.
df = pl.DataFrame(nn_papers)

##################################################################

from langgraph.prebuilt import create_react_agent
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from pydantic import BaseModel
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_ollama import ChatOllama
import polars as pl
from src.papers.domain.agent_tools import AgentTools

# AgentTools is the implementation imported from src.papers.domain.agent_tools;
# it is constructed with the search results and provides the tool methods listed below.
# NOTE: keep using the imported class in this cell — get_other_citation_requested
# is expected to come from it.
agent_tools = AgentTools(df)

model = ChatOllama(model="qwen3", temperature=0.2)


class State(TypedDict):
    # Messages have the type "list". The `add_messages` function
    # in the annotation defines how this state key should be updated
    # (in this case, it appends messages to the list, rather than overwriting them)
    messages: Annotated[list, add_messages]


tools = [agent_tools.get_most_cited_countries, agent_tools.get_publications_per_year, agent_tools.get_other_citation_requested]

# Tell the LLM which tools it can call
llm_with_tools = model.bind_tools(tools)


def chatbot(state: State):
    """Run one LLM turn over the accumulated messages and return the reply as a state update."""
    print(state)  # trace of the state handed to the model on every turn
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


graph_builder = StateGraph(State)
graph_builder.add_node("chatbot", chatbot)

tool_node = ToolNode(tools)
graph_builder.add_node("tools", tool_node)

graph_builder.add_edge(START, "chatbot")
# `tools_condition` sends the flow to "tools" when the last AI message contains
# tool calls and to END otherwise, so no separate static chatbot -> END edge is needed.
graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
)
# After a tool has run, hand its result back to the model.
graph_builder.add_edge("tools", "chatbot")

graph = graph_builder.compile()



In [None]:
SYSTEM_MESSAGE="""
You are an intelligent AI assisstant which is responsible of helping the user consult research paper data.
Your role is to use the provided tools to retrieve the needed information about papers and answer using it.
IMPORTANT: Never use the same tool twice for the same question. 
"""

# Advertise the tools by name. Interpolating the list itself would embed
# "<bound method ... at 0x...>" reprs (including memory addresses) in the prompt.
tool_names = ", ".join(tool.__name__ for tool in tools)

USER_MESSAGE = f"""
Use the right tool to answer the question asked enclosed in <query>. If not adecuate tool is found, use get_other_citation_requested
If further information is needed use the following tools {tool_names}
<query>
{query}
</query>
/no_think
"""

# Run the agent loop once with the system and user messages; `response` holds
# the final graph state.
response = graph.invoke({
    "messages": [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": USER_MESSAGE},
    ]
})
        

In [None]:
# Print the final answer without the model's <think>...</think> preamble.
# rpartition returns the whole text when the tag is absent, so an answer without
# a think block is still printed (partition(...)[2] would yield an empty string).
final_answer = response["messages"][-1].content
print(final_answer.rpartition("</think>")[2].lstrip())

In [None]:
# Ids and distances of the search hits (the search above asked for at most 100).
df.select("id", "distance").head(100)

In [None]:
# Print every hit: tbl_rows=-1 lifts the row display limit. The frame is printed
# whole because a negative head() count drops rows from the end (head(-1) would
# silently omit the last hit).
with pl.Config(tbl_rows=-1):
    print(df.select("id", "distance"))

In [None]:
# Same table with polars' default (truncated) display. The frame is printed whole
# because a negative head() count drops rows from the end (head(-1) would omit the last hit).
print(df.select("id", "distance"))