In [1]:
import os
from getpass import getpass

import gradio as gr
from huggingface_hub import login

from source.pipeline.config import PipelineConfig
from source.pipeline.controller import PipelineController
from source.pipeline.step.retrieval import RetrievalStep
from source.pipeline.step.generation import (
    GenerationStep, 
    AnswerGenerateOutputParser, 
    AnswerGeneratePromptGenerator,
    ThoughtGenerateOutputParser,
    ThoughtGeneratePromptGenerator,
)
from source.pipeline.step.end import EndStep
from source.pipeline.state import QuestionState
from source.module.generate.llama import LlamaGenerator, LlamaGeneratorConfig
from source.module.retrieve.dense import DenseRetriever, DenseRetrieverConfig
from source.module.index.index import Indexer, IndexerConfig
from source.utility.system_utils import seed_everything

# Authenticate with the Hugging Face Hub (presumably needed because the
# meta-llama checkpoint configured below is gated). The token is read from the
# HF_TOKEN environment variable; if that is unset, it is prompted for without
# echoing, so it never ends up in the notebook source or its saved outputs.
# (The previous `f"{your_hf_token}"` referenced an undefined name and raised
# NameError on a fresh kernel.)
login(token=os.environ.get("HF_TOKEN") or getpass("Hugging Face token: "))

# ----------------------------
# Config
# ----------------------------
# Fixed seed so generation/retrieval runs are repeatable
# (presumably seeds Python/NumPy/torch RNGs — see seed_everything).
seed_everything(100)

cfg = PipelineConfig(
    method="base",
    batch_size=1,
    generation_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct',
    generation_max_batch_size=1,
    generation_max_total_tokens=4096,
    generation_max_new_tokens=64,
    generation_min_new_tokens=1,
    retrieval_count=8,
    retrieval_query_type='full',
    dataset='musique',
    max_num_thought=6,
    # Pulls the final answer out of "... answer is: X." text. The capture is
    # non-greedy and must be followed by an optional period and then the end of
    # the line, so a trailing period is stripped ("Unknown." -> "Unknown").
    # The previous pattern ".* answer is:? (.*)\\.?" had a greedy group that
    # swallowed the period, making the "\\.?" dead.
    # NOTE(review): assumes AnswerGenerateOutputParser takes group(1) of a
    # re.search/re.match with this pattern — confirm.
    answer_regex=".* answer is:? (.*?)\\.?(?=\\n|$)",
)

# ----------------------------
# Module Init
# ----------------------------
# Generator: the LLM named in cfg, with batch size and token budgets taken from
# the pipeline config so the two cannot drift apart. vLLM is disabled and the
# model is placed on GPU 0.
llama_config = LlamaGeneratorConfig(
    model_name=cfg.generation_model_name,
    batch_size=cfg.generation_max_batch_size,
    max_total_tokens=cfg.generation_max_total_tokens,
    max_new_tokens=cfg.generation_max_new_tokens,
    min_new_tokens=cfg.generation_min_new_tokens,
    use_vllm=False,
    gpu=0,
)
generator = LlamaGenerator(llama_config)

# Retriever: Contriever (MS MARCO) query encoder. No separate passage encoder is
# configured — presumably passages are already embedded in the index below.
retriever_config = DenseRetrieverConfig(
    query_model_name_or_path='facebook/contriever-msmarco',
    passage_model_name_or_path=None,
    batch_size=32,
    training_strategy=None,
    use_fp16=False,
)
retriever = DenseRetriever(retriever_config)

# Index: prebuilt passage index loaded from cfg.database_path. embedding_sz has
# to match the retriever's embedding width (768 here for Contriever — update it
# together with the retriever model).
indexer_config = IndexerConfig(
    embedding_sz=768,
    database_path=cfg.database_path,
)
indexer = Indexer.load_local(indexer_config)

# ----------------------------
# Pipeline & Controller Setup
# ----------------------------
# The cells below drive this pipeline by position, so the order is part of the
# contract:
#   pipeline[0]  retrieval
#   pipeline[1]  answer generation (answer or "Unknown")
#   pipeline[2]  EndStep ("Check Answer" in the cells below)
#   pipeline[3]  intermediate-thought generation for the next hop
retrieval_step = RetrievalStep(cfg=cfg, retriever=retriever, indexer=indexer)
answer_step = GenerationStep(
    cfg=cfg,
    generator=generator,
    prompt_generator=AnswerGeneratePromptGenerator(cfg),
    output_parser=AnswerGenerateOutputParser(cfg),
)
end_step = EndStep(cfg=cfg)
thought_step = GenerationStep(
    cfg=cfg,
    generator=generator,
    prompt_generator=ThoughtGeneratePromptGenerator(cfg),
    output_parser=ThoughtGenerateOutputParser(cfg),
)
pipeline = [retrieval_step, answer_step, end_step, thought_step]

# Both file paths are None — presumably no log/prediction files are written and
# results are only visible through the prints in the cells below.
controller = PipelineController(
    pipeline=pipeline,
    logging_file_path=None,
    prediction_file_path=None,
)

  from .autonotebook import tqdm as notebook_tqdm
2025-11-19 00:25:13,123	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.20s/it]
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


[    Indexer     ] Deserializing Index from ./data/database/contriever_msmarco/musique...  Deserializing Complete!


In [23]:
# Demo question: presumably a multi-hop, MuSiQue-style question (region with the
# second largest rain-forest -> its dominant religion -> a missionary).
# Note: the next cell redefines both names before running hop 1.
user_input = (
    "Which missionary helped spread the religion widely practiced in region "
    "having the second largest rain-forest in the world?"
)
start_state = QuestionState(question=user_input, question_id="1")

In [25]:
# Hop 1, step by step: seed the controller with the question, then run each
# pipeline stage by index so every intermediate result can be inspected.
user_input = (
    "Which missionary helped spread the religion widely practiced in region "
    "having the second largest rain-forest in the world?"
)

start_state = QuestionState(question_id="1", question=user_input)

controller.update([start_state])
paths = controller.next()

# 1st-hop Retrieve (pipeline[0]): list the titles of the retrieved passages.
next_states = controller.pipeline[0](paths)
for document in next_states[0].documents[:cfg.retrieval_count]:
    print(document.metadata['title'])
controller.update(next_states)
paths = controller.next()

# 1st-hop Answer (pipeline[1])
next_states = controller.pipeline[1](paths)
first_hop_answer = next_states[0].answer
print(f"Model Prediction: {first_hop_answer}")
controller.update(next_states)
paths = controller.next()

# 1st-hop Check Answer (pipeline[2], EndStep)
next_states = controller.pipeline[2](paths)
controller.update(next_states)
paths = controller.next()

# 1st-hop Think (pipeline[3]): only needed while the answer is still "Unknown"
# (surrounding whitespace and a trailing period are ignored). Otherwise the
# question is already answered and the multi-hop loop below need not be run.
if str(first_hop_answer).strip().rstrip(".") == "Unknown":
    next_states = controller.pipeline[3](paths)
    print(f"Intermediate Thought: {next_states[0].thought}")
    controller.update(next_states)
    paths = controller.next()
else:
    print("Answered at hop 1; no intermediate thought needed.")

In [None]:
# ---- Loop for hop >= 2 ----
# Continues from the `paths` left by the hop-1 cell; only meaningful if hop 1
# ended with an "Unknown" answer. Each hop runs the same four steps as hop 1
# (retrieve -> answer -> check -> think) and stops as soon as the model commits
# to an answer other than "Unknown", or after MAX_HOPS hops.
MAX_HOPS = 5
final_answer = None

for hop in range(2, MAX_HOPS + 1):
    # Retrieve (pipeline[0])
    next_states = controller.pipeline[0](paths)
    # TODO: we use a buffer of 32 to remove redundant documents, this is not implemented in this demo
    titles = [doc.metadata['title'] for doc in next_states[0].documents[:cfg.retrieval_count]]
    print(f"Retrieved: {titles}")
    controller.update(next_states)
    paths = controller.next()

    # Answer (pipeline[1])
    next_states = controller.pipeline[1](paths)
    answer = next_states[0].answer
    print(f"{hop}-hop Answer: {answer}")
    controller.update(next_states)
    paths = controller.next()

    # Check (pipeline[2], EndStep): run on every hop, exactly as in hop 1, so an
    # answered state also passes through the end-of-question check before we stop.
    next_states = controller.pipeline[2](paths)
    controller.update(next_states)
    paths = controller.next()

    # Ignore surrounding whitespace and a trailing period so "Unknown." still
    # counts as "no answer yet".
    if str(answer).strip().rstrip(".") != "Unknown":
        final_answer = answer
        break

    # Think (pipeline[3]): intermediate thought that feeds the next hop.
    next_states = controller.pipeline[3](paths)
    print(f"{hop}-hop Thought: {next_states[0].thought}")
    controller.update(next_states)
    paths = controller.next()
else:
    # Reached only when the loop never hit `break`, i.e. every answer was "Unknown".
    print(f"No answer found within {MAX_HOPS} hops.")

final_answer