In [1]:
import asyncio
import json
import os
import sys
import uuid
from operator import add

from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.constants import Send
from pydantic import BaseModel
from tqdm.notebook import tqdm as notebook_tqdm
from typing_extensions import List, Dict, Optional, TypedDict, Annotated

# Extend sys.path *before* the project-local imports below; appended after
# them (as before) it could not affect how they resolve.
sys.path.append('..')

from rag_eval.prompt_manager import PromptManager
from rag_eval.utils import create_path_if_not_exists
from rag_eval import ChunkDataHandler
from fcgb.cfg.precompiled import get_llm, get_checkpointer

In [2]:
# Structured-output schema for the LLM call in `HyDEQueriesGenerator.hyde_gen`:
# `answers` holds the generated HyDE passages for one base query.
# Documented with comments rather than a docstring: a pydantic class docstring
# becomes the schema description, which `with_structured_output` forwards to the LLM.
class HydeAnswerModel(BaseModel):
    answers: List[str]

class HyDEset(TypedDict):
    """One base query paired with the HyDE passages generated for it."""

    query: str               # the original (base) query
    hyde_queries: List[str]  # hypothetical passages produced by the LLM

class HyDEState(BaseModel):
    """Payload of a single `hyde_gen` task: one base query plus document context."""

    title: str
    summary: str = ""
    query: str = ""

class HyDEGenState(BaseModel):
    """
    Graph state for one document's HyDE generation run.

    `hyde_sets` is annotated with `operator.add`, so the lists returned by the
    parallel `hyde_gen` tasks are concatenated rather than overwritten.
    """

    doc_name: str
    title: str = ""
    summary: str = ""
    queries: List[str] = []  # base queries read from the document metadata
    hyde_sets: Annotated[List[HyDEset], add] = []


class HyDEQueriesGenerator(ChunkDataHandler):
    """
    Generates HyDE (hypothetical document) passages for the base queries stored
    in each document's metadata and saves them as one JSON file per document.

    Workflow (LangGraph): START -> load_metadata -> one `hyde_gen` task per base
    query (fanned out with `Send`) -> save_hyde_queries -> END.
    """

    def __init__(
            self,
            llm,
            data_path: str,
            prompts_config: Dict,
            memory=None,
            hyde_queries: int = 5,
            prompt_manager_spec: Optional[Dict] = None
    ):
        """
        Initializes the HyDEQueriesGenerator and compiles its workflow graph.

        Args:
            llm: Language model used for generation; must support
                `with_structured_output`.
            data_path (str): Main path of the chunk dataset. HyDE files are written
                to `<data_path>/hyde_queries`, which is created if missing.
            prompts_config (Dict): Configuration for prompts. Must include:
                'path': main path where prompts are stored (passed to PromptManager),
                'hyde_queries': identifier of the HyDE-query prompt passed to
                `PromptManager.get_prompts`.
            memory: Optional LangGraph checkpointer used when compiling the graph.
            hyde_queries (int): Number of HyDE passages requested per base query
                (default 5).
            prompt_manager_spec (Dict, optional): Version configuration for the
                prompt manager. `None` (the default) means an empty dict.
        """
        super().__init__(output_path=data_path)

        self.llm = llm
        self.hyde_queries = hyde_queries
        self.memory = memory
        self.prompts_config = prompts_config
        # Fresh dict per instance instead of one shared mutable default.
        self.prompt_manager_spec = {} if prompt_manager_spec is None else prompt_manager_spec

        self.hyde_path = os.path.join(data_path, 'hyde_queries')
        create_path_if_not_exists(self.hyde_path)

        self._set_prompts()
        self.build_graph()

    def _set_prompts(self):
        """
        Initializes the prompt manager and loads the HyDE-query prompt into
        `self.prompts`.
        """
        prompt_manager = PromptManager(
            version_config=self.prompt_manager_spec,
            path=self.prompts_config['path']
        )

        # NOTE(review): `hyde_gen` reads self.prompts['hyde_queries']; presumably
        # get_prompts keys its result by prompt name — confirm.
        self.prompts = prompt_manager.get_prompts([self.prompts_config[name] for name in ['hyde_queries']])

    def load_metadata_node(self, state: HyDEGenState):
        """
        Loads title, summary and base queries from the metadata of `state.doc_name`.
        Missing keys fall back to '' / [].
        """
        metadata = self.get_metadata_file(state.doc_name)

        return {
            'title': metadata.get('title', ''),
            'summary': metadata.get('summary', ''),
            'queries': metadata.get('queries', []),
        }

    def queries_routing(self, state: HyDEGenState):
        """
        Fans out one `hyde_gen` task per base query.

        If the document has no queries, an empty list is returned, so no
        `hyde_gen` task runs, `save_hyde_queries` is never reached and no file
        is written for that document.
        """
        return [Send('hyde_gen', HyDEState(
                    title=state.title,
                    summary=state.summary,
                    query=query)
                    ) for query in state.queries]

    def hyde_gen(self, state: HyDEState):
        """
        Generates HyDE passages for a single base query (one task per query,
        scheduled in parallel by `queries_routing`).

        Returns:
            dict: {'hyde_sets': [HyDEset]}; the list is appended to the graph
            state through the `add` reducer on `HyDEGenState.hyde_sets`.
        """
        hyde_prompt = self.prompts['hyde_queries'].format(
            title=state.title,
            summary=state.summary,
            query=state.query,
            hyde_queries_num=self.hyde_queries
        )

        hyde_queries = self.llm.with_structured_output(HydeAnswerModel).invoke(hyde_prompt).answers

        hyde_output = HyDEset(
            query=state.query,
            hyde_queries=hyde_queries
        )
        return {'hyde_sets': [hyde_output]}

    def save_hyde_queries(self, state: HyDEGenState):
        """
        Writes the collected HyDE sets to `<hyde_path>/<doc_name>.json` as
        {'queries_set': [...]}.
        """
        hyde_file_path = os.path.join(self.hyde_path, f"{state.doc_name}.json")
        with open(hyde_file_path, 'w', encoding='utf-8') as f:
            json.dump({'queries_set': state.hyde_sets}, f, indent=4)

        return {}

    def build_graph(self):
        """
        Builds and compiles the HyDE generation workflow into `self.graph`.
        """
        workflow = StateGraph(HyDEGenState)
        workflow.add_node('load_metadata', self.load_metadata_node)
        workflow.add_node('hyde_gen', self.hyde_gen)
        workflow.add_node('save_hyde_queries', self.save_hyde_queries)

        workflow.add_edge(START, 'load_metadata')
        # path_map must be a list/dict of possible destinations (a bare string
        # is not a supported path_map type).
        workflow.add_conditional_edges('load_metadata', self.queries_routing, ['hyde_gen'])
        workflow.add_edge('hyde_gen', 'save_hyde_queries')
        workflow.add_edge('save_hyde_queries', END)

        self.graph = workflow.compile(checkpointer=self.memory)

    def run(self, doc_name: str, thread_id: Optional[str] = None):
        """
        Runs the HyDE queries generation workflow for a given document name.

        Args:
            doc_name (str): The document for which to generate HyDE queries.
            thread_id (str, optional): Checkpointer thread identifier. When a
                checkpointer (`memory`) is configured and no thread_id is given,
                a fresh random one is generated, because a checkpointed graph
                needs a thread_id in its config.

        Note:
            Reusing a thread_id with a checkpointer resumes that thread's saved
            state. Since `hyde_sets` is reduced with `operator.add`, sets from the
            earlier run are kept and saved again alongside the new ones.

        Returns:
            dict: The final graph state values as returned by `graph.invoke`.
        """
        if not thread_id and self.memory is not None:
            thread_id = uuid.uuid4().hex
        config = {'configurable': {'thread_id': thread_id}} if thread_id else None
        return self.graph.invoke({'doc_name': doc_name}, config=config)

    async def add_hyde_queries(self, target_docs: int, concurrent_runs=5):
        """
        Generates HyDE files for documents that lack them until `target_docs`
        documents have one (or no candidates remain).

        Each blocking `run` call executes in a worker thread via
        `asyncio.to_thread`, so up to `concurrent_runs` documents are processed
        at the same time. The LLM client and checkpointer are therefore used
        from several threads concurrently. Per-document failures are printed
        and skipped.

        Args:
            target_docs (int): Desired total number of documents with HyDE files.
            concurrent_runs (int): Maximum number of documents processed in
                parallel; values below 1 are treated as 1.
        """
        docs_to_process = self._docs_without_hyde()
        current_docs_processed = len(self._docs_with_hyde())
        docs_needed = max(target_docs - current_docs_processed, 0)

        docs_to_process = docs_to_process[:docs_needed]

        process_pbar = notebook_tqdm(
            total=len(docs_to_process),
            desc='HyDE queries generation',
            unit='doc',
            postfix={'Target docs': target_docs, 'Processed docs': current_docs_processed}
        )

        queue = asyncio.Queue()
        for doc_name in docs_to_process:
            queue.put_nowait(doc_name)

        async def worker():
            # The queue is fully populated before any worker starts, so an empty
            # queue means there is no more work.
            while True:
                try:
                    doc_name = queue.get_nowait()
                except asyncio.QueueEmpty:
                    return
                try:
                    # Run the blocking graph invocation off the event loop so that
                    # workers actually overlap instead of running one after another.
                    await asyncio.to_thread(self.run, doc_name, thread_id=uuid.uuid4().hex)
                    process_pbar.update(1)
                    process_pbar.set_postfix({'Target docs': target_docs, 'Processed docs': len(self._docs_with_hyde())})
                except Exception as e:
                    print(f"Error processing {doc_name}: {e}")

        # At least one worker (0 workers previously made queue.join() hang
        # forever), and no more than there are documents.
        n_workers = max(1, min(concurrent_runs, len(docs_to_process)))
        try:
            await asyncio.gather(*(worker() for _ in range(n_workers)))
        finally:
            process_pbar.close()

In [3]:
# Backing services for the generator: the 'google' LLM preset and a local checkpointer.
hyde_llm = get_llm(llm_model='google')
hyde_memory = get_checkpointer(checkpointer_mode='local')

# 'path': None — presumably lets PromptManager use its default prompt location; confirm.
hyde_gen = HyDEQueriesGenerator(
    llm=hyde_llm,
    data_path='../data',
    prompts_config={'path': None, 'hyde_queries': 'hyde_queries'},
    memory=hyde_memory,
    hyde_queries=5,          # HyDE passages requested per base query
    prompt_manager_spec={},
)

In [4]:
# Dataset coverage: evaluated docs and how many of them already have HyDE queries.
hyde_gen.summary()

All docs: 115
Docs to evaluate: 15 - 13.04% of all docs
Docs evaluated: 100 - 86.96% of all docs
Docs with HyDE queries: 100 - 100.00% of evaluated docs


### Database extension: generate HyDE queries until the target document count is reached

In [9]:
# Generate HyDE files for docs lacking them until 100 docs have one, using 5 worker tasks.
await hyde_gen.add_hyde_queries(target_docs=100, concurrent_runs=5)

HyDE queries generation:   0%|          | 0/80 [00:00<?, ?doc/s, Processed docs=20, Target docs=100]

### Single-document run

In [4]:
docs_names = hyde_gen._evaluated_docs()

# Use a fresh thread_id on every execution: with a checkpointer, reusing a fixed
# id resumes the saved thread state, and the `add` reducer on `hyde_sets` would
# append the new sets to the previous ones (and save the duplicates to disk).
hyde_gen.run(
    doc_name=docs_names[2],
    thread_id=uuid.uuid4().hex
)

{'doc_name': '1705.03998v1',
 'title': 'Mining Functional Modules by Multiview-NMF of Phenome-Genome Association',
 'summary': 'Background: Mining gene modules from genomic data is an important step to\ndetect gene members of pathways or other relations such as protein-protein\ninteractions. In this work, we explore the plausibility of detecting gene\nmodules by factorizing gene-phenotype associations from a phenotype ontology\nrather than the conventionally used gene expression data. In particular, the\nhierarchical structure of ontology has not been sufficiently utilized in\nclustering genes while functionally related genes are consistently associated\nwith phenotypes on the same path in the phenotype ontology. Results: We propose\na hierarchal Nonnegative Matrix Factorization (NMF)-based method, called\nConsistent Multiple Nonnegative Matrix Factorization (CMNMF), to factorize\ngenome-phenome association matrix at two levels of the hierarchical structure\nin phenotype ontology for m

### Output processing: inspect the saved HyDE queries

In [5]:
# Documents that already have HyDE queries (as reported by _docs_with_hyde).
hyde_docs = hyde_gen._docs_with_hyde()
print(f"HyDE queries generated for {len(hyde_docs)} documents.")

HyDE queries generated for 100 documents.


In [6]:
# Load the HyDE record ({'queries_set': [...]}) of the first such document.
hyde_doc = hyde_gen.get_hyde_file(hyde_docs[0])

In [10]:
# Display the loaded record.
hyde_doc

{'queries_set': [{'query': 'What is complex-valued federated learning?',
   'hyde_queries': ['complex-valued federated learning combines federated learning with complex-valued data processing techniques, particularly relevant in fields like MRI.',
    'complex-valued federated learning uses differential privacy to protect sensitive data in medical applications like MRI.',
    'complex-valued Gaussian mechanism is introduced to address the application of differential privacy to complex-valued data.',
    'DP stochastic gradient descent is generalized to complex-valued neural networks in complex-valued federated learning.',
    'complex-valued federated learning trains complex-valued neural networks with differential privacy on MRI pulse sequence classification in k-space.']},
  {'query': 'How does Differential Privacy enhance federated learning?',
   'hyde_queries': ['Differential Privacy (DP) protects sensitive data during federated learning.',
    'DP adds noise to gradients, ensuring

In [9]:
# Per base query: the query itself plus its HyDE passages (hence the 1 +).
queries_lengths = [1 + len(hyde_set['hyde_queries']) for hyde_set in hyde_doc['queries_set']]
print(queries_lengths)

# Flatten into one list: each base query immediately followed by its HyDE passages.
concatenated_queries = [
    text
    for hyde_set in hyde_doc['queries_set']
    for text in (hyde_set['query'], *hyde_set['hyde_queries'])
]
concatenated_queries

[6, 6, 6, 6, 6]


['What is complex-valued federated learning?',
 'complex-valued federated learning combines federated learning with complex-valued data processing techniques, particularly relevant in fields like MRI.',
 'complex-valued federated learning uses differential privacy to protect sensitive data in medical applications like MRI.',
 'complex-valued Gaussian mechanism is introduced to address the application of differential privacy to complex-valued data.',
 'DP stochastic gradient descent is generalized to complex-valued neural networks in complex-valued federated learning.',
 'complex-valued federated learning trains complex-valued neural networks with differential privacy on MRI pulse sequence classification in k-space.',
 'How does Differential Privacy enhance federated learning?',
 'Differential Privacy (DP) protects sensitive data during federated learning.',
 'DP adds noise to gradients, ensuring individual data privacy in federated settings.',
 'Federated learning with DP allows traini