In [47]:
from langchain_core.messages import SystemMessage, HumanMessage, RemoveMessage, AIMessage, AnyMessage
from langgraph.graph import StateGraph, START, END
from typing import List, Dict, Any, TypedDict, Annotated
from pydantic import BaseModel
from langgraph.constants import Send
import json
from operator import add
from fcgb.cfg.precompiled import get_llm, get_checkpointer
from fcgb.prompt_manager import PromptManager

In [27]:
import os
import requests
from urllib.parse import urlparse
import urllib.request as libreq
import feedparser
import fitz  # PyMuPDF
import re

def sanitize_query(query: str) -> str:
    """
    Sanitizes a query string for API calls by removing special characters
    and joining the remaining words with '+'.

    Any run of whitespace (spaces, tabs, newlines) is treated as one separator
    and leading/trailing whitespace is dropped, so the result never contains
    raw whitespace or empty '+' segments.

    Args:
        query (str): The input query string.

    Returns:
        str: The sanitized query string (empty if no word characters remain).
    """
    # Remove special characters except word characters and whitespace
    sanitized = re.sub(r'[^\w\s]', '', query)
    # str.split() without arguments splits on any whitespace run and drops empties
    return '+'.join(sanitized.split())

def create_path_if_not_exists(path: str):
    """
    Creates a directory path (including missing parents) if it doesn't exist.

    Uses `exist_ok=True` instead of a separate existence check, so a directory
    created by another process between the check and the creation does not
    raise. If `path` exists but is not a directory, `os.makedirs` raises
    FileExistsError.

    Args:
        path (str): The directory path to create.
    """
    os.makedirs(path, exist_ok=True)


def arxiv_pdf_link_extractor(links: List[Dict[str, str]]) -> str | None:
    """
    Extracts the PDF link from a list of links.
    Args:
        links (List[Dict[str, str]]): A list of dictionaries containing link information.
    Returns:
        str | None: The PDF link if found, otherwise None.
    """

    link = [link['href'] for link in links if link.get('title', '') == 'pdf']
    if len(link) > 0:
        return link[0]
    return None

def arxiv_search(search_query: str, start: int = 0, max_results: int = 5) -> List[Dict[str, Any]]:
    """
    Searches the arXiv API for papers matching the search query.

    The query is sanitized with `sanitize_query` and then percent-encoded, so
    non-ASCII word characters (which the sanitizer keeps) no longer make the
    request fail. Any error is printed and results in an empty list (best effort).

    Args:
        search_query (str): The query to search for.
        start (int): The starting index for results.
        max_results (int): The maximum number of results to return.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing paper titles, summaries,
            publish dates and PDF URLs. Entries without a PDF link are dropped.
    """
    # Local import: only `urlparse` is imported from urllib.parse at file level
    from urllib.parse import quote

    # '+' is kept as the word separator; ASCII word characters pass through unchanged
    encoded_query = quote(sanitize_query(search_query), safe='+')
    api_call = 'http://export.arxiv.org/api/query?search_query=all:%s&start=%i&max_results=%i' % (
        encoded_query, start, max_results)
    try:
        # Context manager guarantees the HTTP response is closed, even on error
        with libreq.urlopen(api_call) as response:
            if response.status != 200:
                raise Exception(f"Error fetching data from arXiv API: {response.status}")
            payload = response.read()
        results = feedparser.parse(payload)
        docs = [{
            'title': entry.title,
            'summary': entry.summary,
            'published': entry.published,
            'url': arxiv_pdf_link_extractor(entry.links)
            } for entry in results.entries]
        return [doc for doc in docs if doc['url'] is not None]
    except Exception as e:
        print(e)
        return []
    

def download_pdf(url: str, save_dir: str) -> str | None:
    """
    Downloads a PDF file from a URL into a directory.

    The file name is the last path segment of the URL, with '.pdf' appended
    only when it is not already there.

    Args:
        url (str): The URL of the PDF file.
        save_dir (str): The directory where the PDF file will be saved.
    Returns:
        str | None: The path where the PDF file was saved, or None if the download failed.
    """
    # Extract the filename from the URL; avoid producing 'name.pdf.pdf'
    filename = os.path.basename(urlparse(url).path)
    if not filename.lower().endswith('.pdf'):
        filename += '.pdf'
    save_path = os.path.join(save_dir, filename)

    # Ensure the directory exists (an empty dirname means the current directory)
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    try:
        # With stream=True the connection stays open until the response is closed,
        # so use the response as a context manager
        with requests.get(url, stream=True) as response:
            if response.status_code != 200:
                raise Exception(f"Failed to download file: {response.status_code}")
            with open(save_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=1024):
                    file.write(chunk)
        return save_path
    except Exception as e:
        print(f"Error downloading PDF: {e}")
        return None

In [28]:
# One arXiv search hit, with the keys produced by arxiv_search.
# NOTE: PapersRetrieveGraph.doc_download_node additionally stores 'path' and
# 'pages_count' next to these keys in the saved metadata JSON; they are not declared here.
class arXivDocSpec(TypedDict):
    title: str      # entry title from the arXiv feed
    summary: str    # entry summary from the arXiv feed
    published: str  # publication date string, taken verbatim from the feed
    url: str        # PDF link (the entry link whose title is 'pdf')

# State of PapersRetrieveGraph.
class PapersRetrieveGraphState(BaseModel):
    # Query text forwarded to arxiv_search
    search_query: str
    # Search hits written by the search node
    docs_specs: List[arXivDocSpec] = []
    # Paths of metadata JSON files; the `add` reducer concatenates the lists
    # returned by the individual download nodes
    metadata_files: Annotated[List[str], add] = []

class PapersRetrieveGraph:
    """
    Graph that searches arXiv for a single query and downloads the found PDFs.

    Flow: START -> search_docs -> one doc_downloader per search hit (via Send) -> END.
    Every successful download writes a metadata JSON file; the paths of those
    files are collected in the `metadata_files` state key.
    """
    def __init__(
            self,
            docs_path,
            docs_metadata_path,
            memory=None,
            max_results: int = 5,
    ):
        """
        Args:
            docs_path: Directory where downloaded PDFs are saved (created if missing).
            docs_metadata_path: Directory where metadata JSON files are saved (created if missing).
            memory: Optional checkpointer passed to the compiled graph.
            max_results (int): Maximum number of results requested from arXiv per query.
        """
        self.docs_path = docs_path
        self.docs_metadata_path = docs_metadata_path
        self.memory = memory
        self.max_results = max_results

        create_path_if_not_exists(self.docs_path)
        create_path_if_not_exists(self.docs_metadata_path)

        self.build_graph()

    def search_node(self, state: PapersRetrieveGraphState):
        """
        Searches for papers on arXiv based on the search query in the state.
        """
        print(f"Searching for papers with query: {state.search_query}")
        docs_specs = arxiv_search(state.search_query, max_results=self.max_results, start=0)
        return {'docs_specs': docs_specs}

    def download_router(self, state: PapersRetrieveGraphState):
        """
        Distributes docs specifications into download nodes.
        """
        return [Send("doc_downloader", doc_spec) for doc_spec in state.docs_specs]

    def doc_download_node(self, doc_spec: arXivDocSpec):
        """
        Downloads a PDF file based on the document specification provided and saves its metadata.

        The metadata JSON contains the document specification plus 'path' (local
        PDF path) and 'pages_count' (None when the PDF cannot be read).

        Returns:
            dict: {'metadata_files': [path]} on success, {'metadata_files': []} when the download failed.
        """
        print(f"Downloading PDF for: {doc_spec['title']}")
        result = download_pdf(doc_spec['url'], self.docs_path)
        if result is None:
            return {'metadata_files': []}

        # Work on a copy so the payload received from the router is not mutated
        metadata = dict(doc_spec)
        metadata['path'] = result
        # Get pdf file pages count with PyMuPDF
        try:
            pdf_document = fitz.open(result)
            try:
                metadata['pages_count'] = pdf_document.page_count
            finally:
                # Always release the file handle, even if reading the page count fails
                pdf_document.close()
        except Exception as e:
            print(f"Error reading PDF file {result}: {e}")
            metadata['pages_count'] = None

        # Metadata file shares the PDF's base name: '<name>.pdf' -> '<name>.json'
        stem = os.path.splitext(os.path.basename(result))[0]
        metadata_file = os.path.join(self.docs_metadata_path, f"{stem}.json")
        with open(metadata_file, 'w') as f:
            json.dump(metadata, f, indent=4)
        return {'metadata_files': [metadata_file]}

    def build_graph(self):
        """
        Builds the state graph for retrieving papers from arXiv.
        """
        workflow = StateGraph(PapersRetrieveGraphState)
        workflow.add_node('search_docs', self.search_node)
        workflow.add_node('doc_downloader', self.doc_download_node)

        workflow.add_edge(START, 'search_docs')
        workflow.add_conditional_edges('search_docs', self.download_router, ['doc_downloader'])
        workflow.add_edge('doc_downloader', END)

        self.graph = workflow.compile(checkpointer=self.memory)

    def run(self, search_query: str, thread_id: str | None = None):
        """
        Runs the graph with the provided search query.

        Args:
            search_query (str): The query to search for papers on arXiv.
            thread_id (str | None): Optional thread ID; when given it is passed to the graph config.

        Returns:
            The final graph state as returned by `graph.invoke`.
        """
        config = {'configurable': {'thread_id': thread_id}} if thread_id else None
        return self.graph.invoke({'search_query': search_query}, config=config)

# Structured-output schema: the LLM is asked (via with_structured_output) to return a list of queries.
# NOTE(review): documented with comments rather than a docstring — presumably a class
# docstring would become part of the schema sent to the LLM; verify before adding one.
class QueriesModel(BaseModel):
    queries: List[str]  # generated search queries

# State of RandomQueriesPaperSearchGraph.
class RandomQueriesPaperSearchGraphState(BaseModel):
    # Queries generated by the LLM. The default is an empty QueriesModel so that
    # `state.queries.queries` is always valid; the previous `[]` default did not
    # match the declared type.
    queries: QueriesModel = QueriesModel(queries=[])
    # Paths of metadata JSON files; the `add` reducer concatenates the lists
    # coming from the parallel paper-search branches
    metadata_files: Annotated[List[str], add] = []


class RandomQueriesPaperSearchGraph:
    """
    A graph for generating random queries and searching for papers on arXiv.
    This graph generates random queries, retrieves papers based on those queries,
    and generates additional queries based on the retrieved papers.

    Flow: START -> main_queries -> one paper_search (PapersRetrieveGraph) per query
    -> one paper_queries per metadata file -> END.
    """
    def __init__(
            self,
            llm,
            prompts_config: Dict,
            docs_path: str,
            docs_metadata_path: str,
            memory=None,
            main_queries_num: int = 5,
            paper_queries_num: int = 10,
            max_results: int = 5,
            prompt_manager_spec: Dict | None = None
    ):
        """
        Initializes the RandomQueriesPaperSearchGraph.
        Args:
            llm: The language model to use for generating queries.
            prompts_config (Dict): Configuration for prompts.
                Configuration should include:
                'path': main path where prompts are stored,
                'random_queries': title of the prompt for generating random queries,
                'paper_queries': title of the prompt for generating queries based on papers.
            docs_path (str): Path to save downloaded documents.
            docs_metadata_path (str): Path to save metadata of downloaded documents.
            memory: Optional memory for the graph.
            main_queries_num (int): Number of main queries to generate.
            paper_queries_num (int): Number of queries to generate for each paper.
            max_results (int): Maximum number of results to retrieve from arXiv.
            prompt_manager_spec (Dict | None): Specification for the prompt manager, especially prompts versions.
                ex.
                {'version_config': {'random_queries': 'v1.0', 'paper_queries': 'v1.1'}}
                None (the default) is treated as an empty specification.
        """
        self.llm = llm
        self.prompts_config = prompts_config
        self.docs_path = docs_path
        self.docs_metadata_path = docs_metadata_path
        self.memory = memory
        self.main_queries_num = main_queries_num
        self.paper_queries_num = paper_queries_num
        # A fresh dict per instance: a `{}` default in the signature would be shared by all instances
        self.prompt_manager_spec = prompt_manager_spec if prompt_manager_spec is not None else {}
        self.max_results = max_results

        create_path_if_not_exists(self.docs_path)
        create_path_if_not_exists(self.docs_metadata_path)

        self._set_promps()
        self._set_downloader()

        self.build_graph()

    def _set_downloader(self):
        """
        Initializes the downloader graph for retrieving papers from arXiv.
        """
        self.downloader = PapersRetrieveGraph(
            docs_path=self.docs_path,
            docs_metadata_path=self.docs_metadata_path,
            max_results=self.max_results
        )

    def _set_promps(self):
        """
        Initializes the prompt manager and retrieves the prompts for generating queries.
        """
        # NOTE(review): the constructor docstring shows the spec nested under a
        # 'version_config' key, yet the whole spec is passed as `version_config`
        # here — confirm which shape PromptManager expects.
        prompt_manager = PromptManager(
            version_config=self.prompt_manager_spec,
            path=self.prompts_config['path']
        )

        self.prompts = prompt_manager.get_prompts([self.prompts_config[name] for name in ['random_queries', 'paper_queries']])

    def main_queries_node(self, state: RandomQueriesPaperSearchGraphState):
        """
        Generates a set of random queries for searching papers.
        """
        print("Generating random queries for paper search...")
        query_prompt = self.prompts['random_queries'].format(
            main_queries_num=self.main_queries_num
        )
        queries = self.llm.with_structured_output(QueriesModel).invoke(query_prompt)
        return {'queries': queries}

    def query_routing(self, state: RandomQueriesPaperSearchGraphState):
        """
        Distributes queries into paper search nodes.
        """
        return [Send("paper_search", {'search_query': query}) for query in state.queries.queries]

    def papers_routing(self, state: RandomQueriesPaperSearchGraphState):
        """
        Distributes metadata files paths into paper queries nodes.
        """
        return [Send("paper_queries", file) for file in state.metadata_files]

    def paper_queries(self, metadata_path: str):
        """
        Generates paper queries based on the metadata file.

        Reads the paper's title and summary from the metadata JSON, asks the LLM
        for queries and writes them back to the same file under the 'queries' key.
        Returns an empty state update.
        """
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        print(f"Generating queries for paper: {metadata['title']}")

        query_prompt = self.prompts['paper_queries'].format(
            title=metadata['title'],
            summary=metadata['summary'],
            paper_queries_num=self.paper_queries_num
        )

        queries = self.llm.with_structured_output(QueriesModel).invoke(query_prompt)

        metadata['queries'] = queries.queries
        # Save updated metadata with queries
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=4)
        return {}

    def build_graph(self):
        """
        Builds the state graph for retrieving papers based on random queries.
        """
        workflow = StateGraph(RandomQueriesPaperSearchGraphState)
        workflow.add_node('main_queries', self.main_queries_node)
        # The compiled downloader graph is used as a node (subgraph)
        workflow.add_node('paper_search', self.downloader.graph)
        workflow.add_node('paper_queries', self.paper_queries)

        workflow.add_edge(START, 'main_queries')
        workflow.add_conditional_edges('main_queries', self.query_routing, ['paper_search'])
        workflow.add_conditional_edges('paper_search', self.papers_routing, ['paper_queries'])
        workflow.add_edge('paper_queries', END)

        self.graph = workflow.compile(checkpointer=self.memory)

    def run(self, thread_id: str | None = None):
        """
        Runs the graph to generate random queries and search for papers.

        Args:
            thread_id (str | None): Optional thread ID for memory management.

        Returns:
            The final graph state as returned by `graph.invoke`.
        """
        config = {'configurable': {'thread_id': thread_id}} if thread_id else None
        return self.graph.invoke({}, config=config)

In [14]:
# Stand-alone downloader: PDFs are saved under ../docs, one metadata JSON per paper
# under ../docs_metadata, and at most 3 arXiv results are requested per query.
downloader = PapersRetrieveGraph(
    docs_path='../docs',
    docs_metadata_path='../docs_metadata',
    max_results=3,
)

In [15]:
# Single search-and-download pass; the returned graph state is shown as the cell output.
downloader.run(search_query='Chain of Thought')

Searching for papers with query: Chain of Thought
Downloading PDF for: Contrastive Chain-of-Thought Prompting
Downloading PDF for: Beyond Chain-of-Thought, Effective Graph-of-Thought Reasoning in
  Language Models
Downloading PDF for: CoT is Not True Reasoning, It Is Just a Tight Constraint to Imitate: A
  Theory Perspective
Error downloading PDF: Failed to download file: 404


{'search_query': 'Chain of Thought',
 'docs_specs': [{'title': 'Contrastive Chain-of-Thought Prompting',
   'summary': 'Despite the success of chain of thought in enhancing language model\nreasoning, the underlying process remains less well understood. Although\nlogically sound reasoning appears inherently crucial for chain of thought,\nprior studies surprisingly reveal minimal impact when using invalid\ndemonstrations instead. Furthermore, the conventional chain of thought does not\ninform language models on what mistakes to avoid, which potentially leads to\nmore errors. Hence, inspired by how humans can learn from both positive and\nnegative examples, we propose contrastive chain of thought to enhance language\nmodel reasoning. Compared to the conventional chain of thought, our approach\nprovides both valid and invalid reasoning demonstrations, to guide the model to\nreason step-by-step while reducing reasoning mistakes. To improve\ngeneralization, we introduce an automatic method t

In [29]:
# Query-generation pipeline: 3 LLM-generated search queries, up to 5 arXiv results
# per query, and 10 follow-up queries requested for every downloaded paper.
# Prompts are loaded from ../prompts.
query_runner = RandomQueriesPaperSearchGraph(
    llm=get_llm(llm_model='google'),
    prompts_config={
        'path': '../prompts',
        'random_queries': 'random_queries',
        'paper_queries': 'paper_queries'
    },
    docs_path='../docs',
    docs_metadata_path='../docs_metadata',
    main_queries_num=3,
    paper_queries_num=10,
    max_results=5,
)

In [30]:
# Runs the whole pipeline (LLM calls + arXiv downloads); the final state is shown as the cell output.
query_runner.run()

Generating random queries for paper search...
Searching for papers with query: explainable AI methods for fraud detection in banking
Searching for papers with query: federated learning privacy preserving techniques healthcare
Searching for papers with query: transformer networks time series forecasting stock market
Downloading PDF for: MASTER: Market-Guided Stock Transformer for Stock Price Forecasting
Downloading PDF for: Quantum-Enhanced Forecasting: Leveraging Quantum Gramian Angular Field
  and CNNs for Stock Return Predictions
Downloading PDF for: An Evaluation of Deep Learning Models for Stock Market Trend Prediction
Downloading PDF for: Transformer Based Time-Series Forecasting for Stock
Downloading PDF for: Stock Market Telepathy: Graph Neural Networks Predicting the Secret
  Conversations between MINT and G7 Countries
Downloading PDF for: Vision Through the Veil: Differential Privacy in Federated Learning for
  Medical Image Classification
Downloading PDF for: Robust Aggregati

{'queries': QueriesModel(queries=['explainable AI methods for fraud detection in banking', 'federated learning privacy preserving techniques healthcare', 'transformer networks time series forecasting stock market']),
 'metadata_files': ['../docs_metadata/2406.11389v1.json',
  '../docs_metadata/1811.08212v1.json',
  '../docs_metadata/2206.12415v1.json',
  '../docs_metadata/2409.13406v1.json',
  '../docs_metadata/2108.02501v3.json',
  '../docs_metadata/2306.17794v1.json',
  '../docs_metadata/2009.08294v1.json',
  '../docs_metadata/2406.15962v1.json',
  '../docs_metadata/2110.03478v2.json',
  '../docs_metadata/2406.05517v1.json',
  '../docs_metadata/2312.15235v1.json',
  '../docs_metadata/2310.07427v3.json',
  '../docs_metadata/2408.12408v1.json',
  '../docs_metadata/2502.09625v1.json',
  '../docs_metadata/2506.01945v1.json']}

## Chunk labeling

In [243]:
# Evaluation record for a single chunk, as stored in ChunkEvalState.chunks_eval.
class ChunkEvalModel(TypedDict):
    idx: int     # zero-based index of the chunk within the document
    text: str    # chunk text that was evaluated
    label: bool  # evaluation returned by the LLM for this chunk

# Structured-output schema for the chunk evaluation LLM call (used via with_structured_output).
# NOTE(review): documented with comments rather than a docstring — presumably a class
# docstring would become part of the schema sent to the LLM; verify before adding one.
class ChunkEvalLabel(BaseModel):
    evaluation: bool  # the LLM's verdict for the current chunk

# State of ChunkEvalGraph: one (query, document) evaluation run.
class ChunkEvalState(BaseModel):
    query: str          # query the chunks are evaluated against
    query_idx: int      # index of the query; also used in the result file name
    doc_file: str       # path of the PDF to chunk and evaluate
    title: str          # document title (prompt input)
    summary: str        # document summary (prompt input)
    context: str = ''   # running, LLM-written context carried from chunk to chunk
    chunks: List[str] = []  # chunk texts produced by the chunking node
    chunks_num: int = 0     # number of chunks
    chunk_idx: int = 0      # index of the chunk currently being processed
    # Per-chunk results; the `add` reducer appends each node's list to the existing one
    chunks_eval: Annotated[List[ChunkEvalModel], add] = []

In [282]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from tqdm import tqdm
from tqdm.asyncio import tqdm as async_tqdm
from tqdm.notebook import tqdm as notebook_tqdm
import asyncio

class ChunkEvalGraph:
    """
    Graph that evaluates every chunk of a PDF document against a query.

    Flow: chunking -> chunk_eval -> context_update -> [context_aggregation] -> chunk_eval ...
    After the last chunk has been evaluated: save_chunks_eval -> END.
    A running `context` string written by the LLM is carried from chunk to chunk
    and periodically aggregated.
    """
    def __init__(
            self,
            llm,
            prompts_config: Dict,
            docs_metadata_path: str,
            saving_path: str,
            memory=None,
            chunk_size: int = 600,
            chunk_overlap: int = 0,
            context_agg_interval: int = 5,
            prompt_manager_spec: Dict | None = None
    ):
        """
        Initializes the ChunkEvalGraph for evaluating chunks of text based on a query.

        Args:
            llm: The language model to use for evaluating chunks.
            prompts_config (Dict): Configuration for prompts.
                Configuration should include:
                'path': main path where prompts are stored,
                'chunk_eval_system': title of the system prompt for chunk evaluation.
                'chunk_eval_task': title of the prompt for chunk evaluation task.
                'doc_context_system': title of the prompt for document's current context system message.
                'doc_context_update': title of the prompt for updating document's current context.
                'doc_context_aggregation': title of the prompt for aggregating document's context.
            docs_metadata_path (str): Path to save metadata of downloaded documents.
            saving_path (str): Path to save evaluation results (created if missing).
            memory: Optional memory for the graph. Required by the `run_with_progress*`
                methods, which pause and resume the graph.
            chunk_size (int): Size of each chunk to evaluate.
            chunk_overlap (int): Overlap between chunks.
            context_agg_interval (int): Interval (in chunks) for aggregating context;
                0 or None disables aggregation.
            prompt_manager_spec (Dict | None): Specification for the prompt manager, especially prompts versions.
                None (the default) is treated as an empty specification.
        """
        self.llm = llm
        self.prompts_config = prompts_config
        self.docs_metadata_path = docs_metadata_path
        self.saving_path = saving_path
        self.memory = memory
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.context_agg_interval = context_agg_interval
        # A fresh dict per instance: a `{}` default in the signature would be shared by all instances
        self.prompt_manager_spec = prompt_manager_spec if prompt_manager_spec is not None else {}

        create_path_if_not_exists(self.saving_path)

        self._set_text_splitter()
        self._set_promps()
        self.build_graph()

    def _set_promps(self):
        """
        Initializes the prompt manager and retrieves the prompts for chunk evaluation.
        """
        prompt_manager = PromptManager(
            version_config=self.prompt_manager_spec,
            path=self.prompts_config['path']
        )

        self.prompts = prompt_manager.get_prompts([self.prompts_config[name] for name in ['chunk_eval_system', 'chunk_eval_task', 'doc_context_system', 'doc_context_update', 'doc_context_aggregation']])

    def _set_text_splitter(self):
        """
        Initializes the text splitter for chunking documents.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len
        )

    def chunking_node(self, state: ChunkEvalState):
        """
        Loads the PDF as a single text and splits it into chunks.

        On any loading/splitting error the error is printed and the document is
        treated as having no chunks. Also resets `chunk_idx` and `context`.
        """
        try:
            loader = PyPDFLoader(state.doc_file, mode='single', extraction_mode='plain')
            content = loader.load()[0].page_content
            chunks = self.text_splitter.split_text(content)

            chunks_num = len(chunks)
        except Exception as e:
            print(f"Error loading or chunking document {state.doc_file}: {e}")
            chunks = []
            chunks_num = 0

        return {
            'chunks': chunks,
            'chunks_num': chunks_num,
            'chunk_idx': 0,
            'context': '',
        }

    def _get_template_inputs(self, state: ChunkEvalState):
        """
        Prepares the template inputs for the system and task messages.

        Previous/current/next chunk default to '' when the index is out of range.
        `current_chunk_index` is one-based for display in prompts.
        """
        prev_chunk = state.chunks[state.chunk_idx - 1] if state.chunk_idx > 0 else ''
        current_chunk = state.chunks[state.chunk_idx] if state.chunk_idx < state.chunks_num else ''
        next_chunk = state.chunks[state.chunk_idx + 1] if state.chunk_idx + 1 < state.chunks_num else ''

        return {
            'title': state.title,
            'summary': state.summary,
            'query': state.query,
            'context': state.context,
            'prev_chunk': prev_chunk,
            'current_chunk': current_chunk,
            'next_chunk': next_chunk,
            'current_chunk_index': state.chunk_idx + 1,
            'total_chunks_num': state.chunks_num
        }

    def chunk_eval_node(self, state: ChunkEvalState):
        """
        Evaluates the current chunk against the query and records the result.

        When there is no chunk at `chunk_idx` (e.g. the document could not be
        loaded), nothing is evaluated and no LLM call is made.
        """
        if state.chunk_idx >= state.chunks_num:
            return {'chunks_eval': []}

        template_inputs = self._get_template_inputs(state)
        system_msg = SystemMessage(content=self.prompts['chunk_eval_system'].format(**template_inputs))
        task_msg = HumanMessage(content=self.prompts['chunk_eval_task'].format(**template_inputs))

        eval_result = self.llm.with_structured_output(ChunkEvalLabel).invoke([system_msg, task_msg]).evaluation

        chunk_eval = ChunkEvalModel(
            idx=state.chunk_idx,
            text=template_inputs['current_chunk'],
            label=eval_result
        )

        return {'chunks_eval': [chunk_eval]}

    def context_update_node(self, state: ChunkEvalState):
        """
        Asks the LLM for an updated document context and advances to the next chunk.
        """
        template_inputs = self._get_template_inputs(state)
        system_msg = SystemMessage(content=self.prompts['doc_context_system'].format(**template_inputs))
        update_msg = HumanMessage(content=self.prompts['doc_context_update'].format(**template_inputs))

        context_update = self.llm.invoke([system_msg, update_msg]).content

        return {
            'context': context_update,
            'chunk_idx': state.chunk_idx + 1
        }

    def context_aggregation_node(self, state: ChunkEvalState):
        """
        Aggregates current summaries
        """
        template_inputs = self._get_template_inputs(state)
        agg_msg = HumanMessage(content=self.prompts['doc_context_aggregation'].format(**template_inputs))

        context_update = self.llm.invoke([agg_msg]).content

        return {'context': context_update}

    def context_aggregation_routing(self, state: ChunkEvalState):
        """
        Determines whether to aggregate context based on the current chunk index.
        If the current chunk index is a multiple of the context aggregation interval and greater than 0,
        it returns 'context_aggregation', otherwise it returns 'chunk_eval'.
        An interval of 0 (or None) disables aggregation.
        """
        if not self.context_agg_interval:
            # avoids ZeroDivisionError in the modulo below
            return 'chunk_eval'
        if state.chunk_idx % self.context_agg_interval == 0 and state.chunk_idx > 0:
            return 'context_aggregation'
        return 'chunk_eval'

    def routing_edge(self, state: ChunkEvalState):
        """
        Determines the next node after a chunk evaluation.
        If the evaluated chunk was the last one, it returns 'save_chunks_eval',
        otherwise it returns 'context_update'.
        """
        if state.chunk_idx >= state.chunks_num - 1:
            return 'save_chunks_eval'
        return 'context_update'

    def save_chunks_eval_node(self, state: ChunkEvalState):
        """
        Writes the evaluation results to '<doc base name>_<query_idx>.json' in `saving_path`.

        Also bumps `chunk_idx` so it equals `chunks_num` once the run is finished;
        the `run_with_progress*` loops use this as their stop condition.
        """
        print(f"Saving evaluation results for query {state.query_idx} on document {state.doc_file}...")
        doc_stem = os.path.splitext(os.path.basename(state.doc_file))[0]
        eval_file = os.path.join(self.saving_path, f"{doc_stem}_{state.query_idx}.json")
        with open(eval_file, 'w') as f:
            json.dump({
                'query': state.query,
                'query_idx': state.query_idx,
                'doc_file': state.doc_file,
                'title': state.title,
                'chunks_eval': state.chunks_eval
            }, f, indent=4)

        return {'chunk_idx': state.chunk_idx + 1}

    def build_graph(self):
        """
        Builds the state graph for chunk evaluation.
        """
        workflow = StateGraph(ChunkEvalState)
        workflow.add_node('chunking', self.chunking_node)
        workflow.add_node('chunk_eval', self.chunk_eval_node)
        workflow.add_node('context_update', self.context_update_node)
        workflow.add_node('context_aggregation', self.context_aggregation_node)
        workflow.add_node('save_chunks_eval', self.save_chunks_eval_node)

        workflow.add_edge(START, 'chunking')
        workflow.add_edge('chunking', 'chunk_eval')
        workflow.add_conditional_edges('chunk_eval', self.routing_edge, ['context_update', 'save_chunks_eval'])
        workflow.add_conditional_edges('context_update', self.context_aggregation_routing, ['context_aggregation', 'chunk_eval'])
        workflow.add_edge('context_aggregation', 'chunk_eval')
        workflow.add_edge('save_chunks_eval', END)

        self.graph = workflow.compile(checkpointer=self.memory)

    def run(self, query: str, query_idx: int, doc_file: str, title: str, summary: str, thread_id: str | None = None):
        """
        Runs the graph to evaluate chunks of a document based on a query.

        Args:
            query (str): The query to evaluate chunks against.
            query_idx (int): The index of the query in the list of queries.
            doc_file (str): The path to the document file to be evaluated.
            title (str): The title of the document.
            summary (str): The summary of the document.
            thread_id (str | None): Optional thread ID for memory management.

        Returns:
            The final graph state as returned by `graph.invoke`.
        """
        config = {'configurable': {'thread_id': thread_id}} if thread_id else None
        return self.graph.invoke({
            'query': query,
            'query_idx': query_idx,
            'doc_file': doc_file,
            'title': title,
            'summary': summary
        }, config=config)

    def run_with_progress(self, query: str, query_idx: int, doc_file: str, title: str, summary: str, thread_id: str | None = None):
        """
        Runs the graph with a progress bar to track chunk evaluations.

        The graph is paused before every 'chunk_eval' step and resumed in a loop,
        so each resume evaluates one chunk and advances the bar by one.

        Args:
            query (str): The query to evaluate chunks against.
            query_idx (int): The index of the query in the list of queries.
            doc_file (str): The path to the document file to be evaluated.
            title (str): The title of the document.
            summary (str): The summary of the document.
            thread_id (str | None): Thread ID used to resume the paused graph.

        Returns:
            dict: The final state after all chunks are evaluated.
        """
        config = {'configurable': {'thread_id': thread_id}} if thread_id else None
        # interrupt_before takes a list of node names
        state = self.graph.invoke({
            'query': query,
            'query_idx': query_idx,
            'doc_file': doc_file,
            'title': title,
            'summary': summary
        },
        config=config,
        interrupt_before=['chunk_eval'])

        chunks_num = state['chunks_num']
        positive_evaluations = 0
        negative_evaluations = 0
        resume_config = {'configurable': {'thread_id': thread_id}}

        with tqdm(total=chunks_num, desc="Evaluating Chunks", postfix={"Positive": positive_evaluations, "Negative": negative_evaluations, "Query_id": query_idx}) as pbar:
            while state['chunk_idx'] < chunks_num:
                # Invoke the next step in the graph (input=None resumes from the checkpoint)
                state = self.graph.invoke(input=None, config=resume_config, interrupt_before=['chunk_eval'])

                # Update evaluations
                positive_evaluations = sum(1 for chunk in state['chunks_eval'] if chunk['label'])
                negative_evaluations = sum(1 for chunk in state['chunks_eval'] if not chunk['label'])

                # Update progress bar (keep Query_id visible, as in the async variant)
                pbar.update(1)
                pbar.set_postfix(Positive=positive_evaluations, Negative=negative_evaluations, Query_id=query_idx)

        return state

    async def run_with_progress_async(self, query: str, query_idx: int, doc_file: str, title: str, summary: str, thread_id: str | None = None):
        """
        Runs the graph asynchronously with a progress bar to track chunk evaluations.

        Each blocking `graph.invoke` call is executed in a worker thread via
        `asyncio.to_thread`, so several evaluations can be awaited concurrently.

        Args:
            query (str): The query to evaluate chunks against.
            query_idx (int): The index of the query in the list of queries.
            doc_file (str): The path to the document file to be evaluated.
            title (str): The title of the document.
            summary (str): The summary of the document.
            thread_id (str | None): Thread ID used to resume the paused graph.

        Returns:
            dict: The final state after all chunks are evaluated.
        """
        config = {'configurable': {'thread_id': thread_id}} if thread_id else None
        state = await asyncio.to_thread(self.graph.invoke, {
            'query': query,
            'query_idx': query_idx,
            'doc_file': doc_file,
            'title': title,
            'summary': summary
        }, config=config, interrupt_before=['chunk_eval'])

        chunks_num = state['chunks_num']
        positive_evaluations = 0
        negative_evaluations = 0
        resume_config = {'configurable': {'thread_id': thread_id}}

        # Initialize the progress bar
        pbar = notebook_tqdm(total=chunks_num, desc="Evaluating Chunks", postfix={"Positive": positive_evaluations, "Negative": negative_evaluations, "Query_id": query_idx})
        try:
            while state['chunk_idx'] < chunks_num:
                # Invoke the next step in the graph asynchronously
                state = await asyncio.to_thread(self.graph.invoke, input=None, config=resume_config, interrupt_before=['chunk_eval'])

                # Update evaluations
                positive_evaluations = sum(1 for chunk in state['chunks_eval'] if chunk['label'])
                negative_evaluations = sum(1 for chunk in state['chunks_eval'] if not chunk['label'])

                # Update progress bar
                pbar.update(1)
                pbar.set_postfix(Positive=positive_evaluations, Negative=negative_evaluations, Query_id=query_idx)
        finally:
            pbar.close()

        return state

In [283]:
from uuid import uuid4

async def run_all_queries_async(chunk_eval_graph, metadata_file_path):
    """
    Runs run_with_progress_async concurrently for all queries in the metadata file.

    Every query is evaluated under its own freshly generated thread ID, so the
    concurrent runs do not share a thread in the graph's memory.

    Args:
        chunk_eval_graph (ChunkEvalGraph): The ChunkEvalGraph instance to run evaluations.
        metadata_file_path (str): Path to the metadata file containing queries and document information.

    Returns:
        dict: Maps each query index to the final state returned by
            run_with_progress_async for that query.

    Raises:
        ValueError: If 'queries', 'path', 'title' or 'summary' is missing or empty
            in the metadata file.
    """
    # Load metadata file (JSON is decoded as UTF-8 regardless of the platform locale)
    with open(metadata_file_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    queries = metadata.get('queries', [])
    doc_file = metadata.get('path')
    title = metadata.get('title')
    summary = metadata.get('summary')

    if not queries or not doc_file or not title or not summary:
        raise ValueError("Metadata file is missing required fields.")

    async def run_query_async(query_idx, query):
        # A unique thread ID per query keeps the concurrent runs separated.
        thread_id = uuid4().hex
        state = await chunk_eval_graph.run_with_progress_async(
            query=query,
            query_idx=query_idx,
            doc_file=doc_file,
            title=title,
            summary=summary,
            thread_id=thread_id,
        )
        return query_idx, state

    # Run all queries concurrently; gather yields one (query_idx, state) pair per query
    tasks = [run_query_async(idx, query) for idx, query in enumerate(queries)]
    results = await asyncio.gather(*tasks)

    # Return the results as documented instead of silently dropping them
    return dict(results)

In [287]:
# Checkpointer instance: passed to ChunkEvalGraph as `memory` below and queried
# later with memory.get(...) to read back the state stored for a thread.
memory = get_checkpointer(checkpointer_mode='local', mode='test')

In [288]:
# Instantiate the ChunkEvalGraph used by the evaluation cells below.
chunk_eval_graph = ChunkEvalGraph(
    llm=get_llm(llm_model='google'),
    # Prompt location plus the prompt names for chunk evaluation and document-context
    # handling (how they are resolved is defined in ChunkEvalGraph, not visible here).
    prompts_config={
        'path': '../prompts',
        'chunk_eval_system': 'chunk_eval_system',
        'chunk_eval_task': 'chunk_eval_task',
        'doc_context_system': 'doc_context_system',
        'doc_context_update': 'doc_context_update',
        'doc_context_aggregation': 'doc_context_aggregation'
    },
    # Checkpointer created in the previous cell.
    memory=memory,
    docs_metadata_path='../docs_metadata',
    saving_path='../chunk_eval_results',
    # Chunking settings; the run recorded below reports 59 chunks for the evaluated paper.
    chunk_size=600,
    chunk_overlap=0,
    # presumably the number of chunks between context aggregations — TODO confirm against ChunkEvalGraph
    context_agg_interval=5,
    prompt_manager_spec={}
)

In [289]:
await run_all_queries_async(chunk_eval_graph, "../docs_metadata/2502.09625v1.json")

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=3]

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=4]

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=2]

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=5]

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=0]

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=6]

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=7]

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=1]

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=8]

Evaluating Chunks:   0%|          | 0/59 [00:00<?, ?it/s, Negative=0, Positive=0, Query_id=9]

Saving evaluation results for query 8 on document ../docs/2502.09625v1.pdf...
Saving evaluation results for query 1 on document ../docs/2502.09625v1.pdf...
Saving evaluation results for query 4 on document ../docs/2502.09625v1.pdf...
Saving evaluation results for query 2 on document ../docs/2502.09625v1.pdf...
Saving evaluation results for query 0 on document ../docs/2502.09625v1.pdf...
Saving evaluation results for query 5 on document ../docs/2502.09625v1.pdf...
Saving evaluation results for query 6 on document ../docs/2502.09625v1.pdf...
Saving evaluation results for query 7 on document ../docs/2502.09625v1.pdf...
Saving evaluation results for query 3 on document ../docs/2502.09625v1.pdf...
Saving evaluation results for query 9 on document ../docs/2502.09625v1.pdf...


In [250]:
# Select one paper and one of its queries for a single (non-concurrent) evaluation run.
docs_metadata_path = '../docs_metadata'

metadata_file = '2502.09625v1.json'
query_idx = 1

# Read the JSON explicitly as UTF-8 so decoding does not depend on the
# platform's default locale encoding.
with open(os.path.join(docs_metadata_path, metadata_file), 'r', encoding='utf-8') as f:
    metadata = json.load(f)

paper_query = metadata['queries'][query_idx]

# Keyword arguments for the graph run; unpacked into run_with_progress in a later cell.
inputs = {
    'query': paper_query,
    'query_idx': query_idx,
    'doc_file': metadata['path'],
    'title': metadata['title'],
    'summary': metadata['summary']
}

# Show the assembled inputs as the cell output.
inputs

{'query': 'How does Stockformer differ from the original Transformer model?',
 'query_idx': 1,
 'doc_file': '../docs/2502.09625v1.pdf',
 'title': 'Transformer Based Time-Series Forecasting for Stock',
 'summary': 'To the naked eye, stock prices are considered chaotic, dynamic, and\nunpredictable. Indeed, it is one of the most difficult forecasting tasks that\nhundreds of millions of retail traders and professional traders around the\nworld try to do every second even before the market opens. With recent advances\nin the development of machine learning and the amount of data the market\ngenerated over years, applying machine learning techniques such as deep\nlearning neural networks is unavoidable. In this work, we modeled the task as a\nmultivariate forecasting problem, instead of a naive autoregression problem.\nThe multivariate analysis is done using the attention mechanism via applying a\nmutated version of the Transformer, "Stockformer", which we created.'}

In [251]:
# Thread ID for the next run; it is also used afterwards to read the saved
# state for this run back from `memory`.
thread_id = 'Stockformer1'

In [252]:
# Run the chunk evaluation for the selected query under the thread ID set above;
# the progress bar and the save message appear in the cell output.
state = chunk_eval_graph.run_with_progress(**inputs, thread_id=thread_id)

Chunking document: ../docs/2502.09625v1.pdf into 59 chunks.


Evaluating Chunks: 100%|██████████| 59/59 [02:26<00:00,  2.49s/it, Negative=48, Positive=11]         

Saving evaluation results for query 1 on document ../docs/2502.09625v1.pdf...





In [253]:
# Read the state stored for this thread back from the checkpointer; this replaces
# the `state` value returned by the run above.
state = memory.get({'configurable': {'thread_id': thread_id}})['channel_values']

In [254]:
# Tally the evaluated chunks by the truthiness of their 'label'.
positive_evaluations = sum(bool(chunk['label']) for chunk in state['chunks_eval'])
negative_evaluations = sum(not chunk['label'] for chunk in state['chunks_eval'])
print(f"Positive evaluations: {positive_evaluations}\nNegative evaluations: {negative_evaluations}")

# Print the context stored in the state, one line at a time.
print("\nCONTEXT:")
for line in state['context'].split('\n'):
    print(line)

Positive evaluations: 11
Negative evaluations: 48

CONTEXT:
```
## Summary of previous sections:
- Stockformer is a modified Transformer model for hourly stock price forecasting, utilizing ProbSparse Attention and Self-attention Distilling for efficiency.
- Stockformer's performance benefits from larger embedding sizes and more attention heads.
- Stockformer with ProbSparse Attention is more efficient than Full Attention and more profitable than LSTM models.
- Future research will focus on incorporating more stock tickers, dynamic training methods, and exploring Time2Vec and TimeFrame2Vec temporal encoding to improve the model's profitability and adaptability.
- The paper provides a list of references used in the research.

## Current Section:
The paper provides a list of references used in the research. This section lists academic papers and publications that were cited in the study. The list of references continues from the previous chunk, completing the bibliography.
```


In [255]:
# Print the text of every chunk whose label is truthy, each followed by a separator.
for chunk in state['chunks_eval']:
    if not chunk['label']:
        # Chunks with a falsy label are not printed.
        continue
    print(f"\nChunk {chunk['idx']} (Positive):")
    for line in chunk['text'].split('\n'):
        print(line)
    print("\n---\n")


Chunk 1 (Positive):
the market generated over years, applying machine learning
techniques such as deep learning neural networks is unavoidable.
In this work, we modeled the task as a multivariate forecast-
ing problem, instead of a naive autoregression problem. The
multivariate analysis is done using the attention mechanism via
applying a mutated version of the Transformer, ”Stockformer”,
which we created.
I. I NTRODUCTION
Predicting the financial time series such as stock price
means predicting the behavior of the stock price steps ahead of
the series with the help of various variables. By knowing the

---


Chunk 14 (Positive):
In this project, we implemented Stockformer on the top of
the Transformer, discussed issues of naive Transformer, and
changed the original architecture to fit with the financial ticker
forecasting task.
III. P ROBLEM FORMULATION
Although the goal for the neural network is to predict the
stock price ahead, the task for this project is to assist traders
to make

In [171]:
# Thread ID for the manual, step-by-step graph run in the following cells.
thread_id = 'Byzantine7'

In [172]:
# Start the graph manually for this thread, requesting an interrupt before 'chunk_eval'.
# NOTE(review): the recorded output refers to ../docs/2009.08294v1.pdf, not the paper
# used for `inputs` earlier in this notebook, so `inputs` was presumably redefined
# in a cell that is not shown here — confirm before re-running.
# NOTE(review): LangGraph documents interrupt_before as a list of node names (or "*");
# a bare string is passed here — confirm it is matched as intended.
state = chunk_eval_graph.graph.invoke(
    input=inputs,
    config={'configurable': {'thread_id': thread_id}},
    interrupt_before='chunk_eval'
)

Chunking document: ../docs/2009.08294v1.pdf into 44 chunks.


In [182]:
# Continue the graph for this thread (input=None), again interrupting before 'chunk_eval'.
state = chunk_eval_graph.graph.invoke(
    input=None,
    config={'configurable': {'thread_id': thread_id}},
    interrupt_before='chunk_eval'
)

# Tally the chunks evaluated so far by the truthiness of their 'label'.
positive_evaluations = sum(bool(chunk['label']) for chunk in state['chunks_eval'])
negative_evaluations = sum(not chunk['label'] for chunk in state['chunks_eval'])
print(f"Positive evaluations: {positive_evaluations}\nNegative evaluations: {negative_evaluations}")

# Print the context stored in the state, one line at a time.
print("\nCONTEXT:")
for line in state['context'].split('\n'):
    print(line)

Positive evaluations: 3
Negative evaluations: 7

CONTEXT:
## Summary of previous sections:
- The paper examines robust aggregation methods in federated learning (FL) for healthcare, aiming to protect privacy and defend against malicious clients by detecting and discarding them during training.
- Experiments use healthcare datasets to evaluate robust FL aggregation against poisoning attacks, demonstrating that privacy-preserving methods can be combined with Byzantine-robust aggregation without significantly impacting learning.
- The paper addresses the challenges of sharing private patient data and introduces federated learning (FL) as a solution, while also acknowledging privacy and robustness concerns like biased datasets and poisoning attacks.
- The paper's approach involves using k-anonymity and differential privacy (DP), modeling poisoning attack strategies, and using healthcare datasets to evaluate robust aggregation methods in FL.
- The main contributions include evaluating robus

# Experiments