
# 📌 Step 1: Install Dependencies


In [4]:
!pip install transformers beir sentence-transformers faiss-cpu wikipedia-api torch scikit-learn



# 📌 Step 2: Import Libraries

In [18]:
import logging
import json
import random
import numpy as np
import torch
import wikipediaapi
import faiss
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
# Note: `util` below is beir.util (it provides download_and_unzip), not
# sentence_transformers.util, which this notebook does not use.
from beir import util # This line imports the beir.util module which contains download_and_unzip
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from sklearn.metrics import accuracy_score, f1_score

# 📌 Step 3: Define Reward Model

In [6]:
class RewardModel:
    """Score a reasoning path with a single-logit sequence-classification model.

    NOTE(review): with the default ``bert-base-uncased`` checkpoint,
    ``from_pretrained(..., num_labels=1)`` creates a newly initialised
    classification head, so scores carry no learned preference signal.
    Pass a fine-tuned reward checkpoint via ``model_name`` for meaningful rewards.

    Args:
        model_name: Hugging Face checkpoint used for both model and tokenizer.
    """

    def __init__(self, model_name: str = "bert-base-uncased"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=1
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def score(self, reasoning_path: str) -> float:
        """Return the model's single logit for ``reasoning_path`` as a Python float.

        Inputs longer than 512 tokens are truncated.
        """
        inputs = self.tokenizer(
            reasoning_path, return_tensors="pt", truncation=True, max_length=512
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # One input and num_labels=1 -> logits of shape (1, 1); squeeze to a scalar.
        return outputs.logits.squeeze().item()

# 📌 Step 4: Define Knowledge Base (Wikipedia + FAISS)

In [7]:
import hashlib
import json
import logging
import os
import re
from typing import List

import faiss
import numpy as np
import wikipediaapi
from sentence_transformers import SentenceTransformer

# Module-level logger used by KnowledgeBase's error handlers below.
logger = logging.getLogger(__name__)

class KnowledgeBase:
    """Wikipedia-backed passage store searched with a FAISS L2 index.

    Passages fetched for a query are embedded with ``all-mpnet-base-v2`` and
    appended to ``self.index`` and ``self.texts`` in the same order, so FAISS
    row ``i`` corresponds to ``self.texts[i]``. Each query's passages and
    embeddings are cached as a JSON file under ``cache_dir``.

    Args:
        cache_dir: directory for the JSON cache (created if missing); the
            default targets Google Colab.
    """

    def __init__(self, cache_dir: str = "/content/knowledge_cache"):
        # One encoder embeds both passages and queries so their vectors are comparable.
        self.retriever = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
        self.embedding_dim = self.retriever.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(self.embedding_dim)

        # Identify this client to the Wikipedia API via a user agent.
        user_agent = "AirRAG-Research-Bot/1.0"
        self.wiki = wikipediaapi.Wikipedia(
            user_agent=user_agent, language='en', extract_format=wikipediaapi.ExtractFormat.WIKI
        )

        self.texts = []
        # Queries already indexed this session; re-adding one would duplicate index rows.
        self._added_queries = set()
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def clean_query(self, query: str) -> str:
        """Remove the characters ``? ! . ,`` and surrounding whitespace from ``query``."""
        return re.sub(r'[?!.,]', '', query).strip()

    def extract_search_terms(self, query: str) -> List[str]:
        """Split ``query`` into candidate Wikipedia titles.

        Scanning left to right, two adjacent non-stop-words are kept together as
        one bigram term; otherwise each non-stop-word becomes its own term.
        Returns ``[query]`` unchanged when every word is a stop word.
        """
        stop_words = {'what', 'is', 'the', 'where', 'when', 'who', 'how', 'why',
                      'and', 'or', 'in', 'on', 'at', 'to', 'for', 'of', 'with'}

        terms = self.clean_query(query.lower()).split()
        meaningful_terms = []

        i = 0
        while i < len(terms):
            if i + 1 < len(terms):
                combined = f"{terms[i]} {terms[i+1]}"
                if not any(word in stop_words for word in combined.split()):
                    meaningful_terms.append(combined)
                    i += 2
                    continue

            if terms[i] not in stop_words:
                meaningful_terms.append(terms[i])
            i += 1

        return meaningful_terms if meaningful_terms else [query]

    @staticmethod
    def _lead_paragraphs(page, limit: int) -> List[str]:
        """Return up to ``limit`` paragraphs longer than 50 chars from ``page``, each cut to 300 chars."""
        paragraphs = [p.strip() for p in page.text.split('\n\n') if len(p.strip()) > 50][:limit]
        return [p[:300] for p in paragraphs]

    def get_wiki_content(self, query: str) -> List[str]:
        """Fetch passages for ``query`` from Wikipedia.

        Tries the whole cleaned, lower-cased query as a page title first (up to 3
        paragraphs), then each term from ``extract_search_terms`` in order (up to
        2 paragraphs from the first page that exists). Returns ``[]`` when nothing
        matches or on error (errors are logged).
        """
        try:
            clean_query = self.clean_query(query.lower())
            page = self.wiki.page(clean_query)
            if page.exists():
                return self._lead_paragraphs(page, 3)

            for term in self.extract_search_terms(query):
                page = self.wiki.page(term)
                if page.exists():
                    return self._lead_paragraphs(page, 2)

            return []

        except Exception as e:
            logger.error(f"Error in get_wiki_content: {str(e)}")
            return []

    def _cache_path(self, query: str) -> str:
        """Return the cache file for ``query``, named by a stable SHA-256 digest."""
        # Built-in hash() of a str is salted per interpreter process, so it cannot
        # name a file that must be found again after a kernel restart.
        digest = hashlib.sha256(query.encode('utf-8')).hexdigest()
        return os.path.join(self.cache_dir, f"{digest}.json")

    def add_knowledge(self, query: str):
        """Add passages for ``query`` (from cache, else Wikipedia) to the index.

        A query already added in this session is skipped, so calling this twice
        does not duplicate passages. Errors are logged and swallowed.
        """
        if query in self._added_queries:
            return
        try:
            cache_file = self._cache_path(query)

            if os.path.exists(cache_file):
                with open(cache_file, 'r', encoding='utf-8') as f:
                    cached_data = json.load(f)
                embeddings = np.array(cached_data['embeddings'], dtype='float32')
                # Index first, then texts: a failed add must not leave texts longer than the index.
                self.index.add(embeddings)
                self.texts.extend(cached_data['texts'])
                self._added_queries.add(query)
                return

            paragraphs = self.get_wiki_content(query)
            if paragraphs:
                embeddings = self.retriever.encode(paragraphs, convert_to_numpy=True).astype('float32')
                self.index.add(embeddings)
                self.texts.extend(paragraphs)
                self._added_queries.add(query)

                with open(cache_file, 'w', encoding='utf-8') as f:
                    json.dump({'texts': paragraphs, 'embeddings': embeddings.tolist()}, f, ensure_ascii=False)

        except Exception as e:
            logger.error(f"Error in add_knowledge: {str(e)}")

    def retrieve(self, query: str, k: int = 3) -> List[str]:
        """Return up to ``k`` stored passages nearest to ``query`` in embedding space.

        Returns ``[]`` when nothing has been added yet, when ``k <= 0``, or on
        error (errors are logged).
        """
        try:
            if not self.texts or k <= 0:
                return []

            query_embedding = self.retriever.encode([query], convert_to_numpy=True)
            _, indices = self.index.search(query_embedding.astype('float32'), min(k, len(self.texts)))

            # FAISS pads missing hits with -1; a negative id must not wrap to texts[-1].
            return [self.texts[i] for i in indices[0] if 0 <= i < len(self.texts)]

        except Exception as e:
            logger.error(f"Error in retrieve: {str(e)}")
            return []


# 📌 Step 5: Define Action Manager for Reasoning Actions

In [8]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
from typing import List, Dict, Optional

# Same named logger as the knowledge-base cell (getLogger returns one instance
# per name); ActionManager's error handlers log through it.
logger = logging.getLogger(__name__)

class ActionManager:
    """Reasoning actions backed by a causal language model (``microsoft/phi-2`` by default).

    Every action builds a text prompt, samples a continuation via ``_generate``
    and returns only the newly generated text, never the echoed prompt.
    Failures are logged and mapped to an empty result (or, for
    ``query_transformation``, the unchanged query).
    """

    def __init__(self, model_name: str = "microsoft/phi-2"):
        """Load ``model_name`` and its tokenizer onto one device (CUDA when available)."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # No device_map: the model is placed explicitly with .to(), which conflicts
        # with accelerate's automatic dispatch when weights land on several devices.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Over-long prompts lose their beginning rather than the trailing
        # "Answer:"-style cue the model is asked to complete.
        self.tokenizer.truncation_side = "left"

    def _generate(self, prompt: str, max_prompt_tokens: int, max_new_tokens: int) -> str:
        """Sample one continuation of ``prompt`` and return only the generated text.

        Args:
            prompt: full prompt text.
            max_prompt_tokens: the tokenised prompt is truncated to this length.
            max_new_tokens: generation budget, independent of the prompt length
                (``max_length`` would count prompt tokens and leave no room).
        """
        inputs = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=max_prompt_tokens
        ).to(self.device)
        input_ids = inputs["input_ids"]
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                # temperature/top_p only take effect when sampling is enabled.
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                # Reuse EOS as the pad id so generate() does not need a pad token.
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # generate() returns prompt + continuation; keep only the continuation.
        new_tokens = outputs[0][input_ids.shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def system_analysis(self, query: str) -> List[str]:
        """Decompose ``query`` into sub-queries, one per non-empty generated line."""
        try:
            prompt = f"Decompose the following question into meaningful sub-queries:\nQuestion: {query}\nSub-Queries:"
            response = self._generate(prompt, max_prompt_tokens=256, max_new_tokens=128)
            return [q.strip() for q in response.split("\n") if q.strip()]
        except Exception as e:
            logger.error(f"Error in system_analysis: {str(e)}")
            return []

    def direct_answer(self, query: str) -> str:
        """Answer ``query`` from the model's internal knowledge alone."""
        try:
            prompt = f"Answer the following question concisely and accurately:\nQuestion: {query}\nAnswer:"
            return self._generate(prompt, max_prompt_tokens=256, max_new_tokens=128)
        except Exception as e:
            logger.error(f"Error in direct_answer: {str(e)}")
            return ""

    def retrieval_answer(self, query: str, context: List[str], reasoning_context: Optional[Dict] = None) -> str:
        """Answer ``query`` using retrieved passages and, optionally, earlier reasoning.

        Args:
            query: the question to answer.
            context: retrieved passages; each is cut to 200 characters plus "...".
            reasoning_context: optional dict whose ``"reasoning_paths"`` entry is a
                list of paths, each a list of step dicts with a ``"state"`` key.
        """
        try:
            truncated_context = [c[:200] + "..." if len(c) > 200 else c for c in context]
            context_text = " ".join(truncated_context)

            reasoning_text = ""
            if reasoning_context and "reasoning_paths" in reasoning_context:
                reasoning_text = "\nReasoning steps:\n"
                for path in reasoning_context["reasoning_paths"]:
                    path_text = " -> ".join([step["state"] for step in path])
                    reasoning_text += f"- {path_text}\n"

            prompt = (
                f"Based on the following context and reasoning, answer the question concisely:\n"
                f"Context: {context_text}\n"
                f"{reasoning_text}"
                f"Question: {query}\n"
                f"Answer:"
            )
            return self._generate(prompt, max_prompt_tokens=512, max_new_tokens=128)
        except Exception as e:
            logger.error(f"Error in retrieval_answer: {str(e)}")
            return ""

    def query_transformation(self, query: str) -> str:
        """Rewrite ``query`` to be more retrieval-friendly.

        Falls back to the original query when generation fails or yields nothing.
        """
        try:
            prompt = (
                f"Rewrite the following question to make it more specific and retrieval-friendly:\n"
                f"Original Question: {query}\n"
                f"Transformed Question:"
            )
            transformed = self._generate(prompt, max_prompt_tokens=256, max_new_tokens=64)
            return transformed or query
        except Exception as e:
            logger.error(f"Error in query_transformation: {str(e)}")
            return query  # Return original query if transformation fails

    def summary_answer(self, query: str, reasoning_steps: List[Dict]) -> str:
        """Produce a final answer from reasoning steps (dicts with ``action`` and ``state`` keys)."""
        try:
            reasoning_context = "\n".join(
                f"Step {i+1}: {step['action']} - {step['state']}" for i, step in enumerate(reasoning_steps)
            )
            prompt = (
                f"Given the following reasoning steps, provide a final answer to the query:\n"
                f"Query: {query}\n"
                f"Reasoning Steps:\n{reasoning_context}\n"
                f"Final Answer:"
            )
            return self._generate(prompt, max_prompt_tokens=512, max_new_tokens=128)
        except Exception as e:
            logger.error(f"Error in summary_answer: {str(e)}")
            return ""


# 📌 Step 6: Define Monte Carlo Tree Search (MCTS)

In [9]:
class ReasoningNode:
    """One state in the MCTS reasoning tree.

    Attributes:
        state: text of this reasoning step.
        parent: node this one was expanded from (``None`` for the root).
        children: nodes expanded from this one (starts empty).
        visits: number of times a reward was backpropagated through this node.
        value: running sum of those rewards.
        reasoning_chain: list of step dicts (read via their ``'state'`` key when
            the path is scored); starts empty.
    """

    def __init__(self, state: str, parent=None):
        self.state, self.parent = state, parent
        self.children = []
        self.reasoning_chain = []
        self.visits, self.value = 0, 0.0

class MonteCarloTreeSearch:
    """Monte Carlo Tree Search over textual reasoning states.

    Each iteration descends from the root to a leaf by UCB1, expands that leaf
    with every action, scores one freshly created child with the reward model,
    and backpropagates the reward up to the root.
    """

    def __init__(self, root, knowledge_base, action_manager):
        self.root = root
        self.kb = knowledge_base
        self.am = action_manager
        self.evaluator = RewardModel()

    def select(self, node: "ReasoningNode"):
        """Follow the highest-UCB1 child from ``node`` down to a node with no children."""
        while node.children:
            log_parent = np.log(node.visits + 1)
            # The 1e-6 guards avoid division by zero and make the exploration term
            # very large for unvisited children, strongly favouring them.
            ucb_values = [
                child.value / (child.visits + 1e-6)
                + np.sqrt(2 * log_parent / (child.visits + 1e-6))
                for child in node.children
            ]
            node = node.children[int(np.argmax(ucb_values))]
        return node

    def expand(self, node: "ReasoningNode", actions):
        """Add one child to ``node`` per action that returns a non-empty result.

        Each action is called as ``action(node.state, context)`` where ``context``
        is up to three passages retrieved for ``node.state``. A child's
        ``reasoning_chain`` is its parent's chain plus an ``{"action", "state"}``
        record for the new step.
        """
        context = self.kb.retrieve(node.state, k=3)
        for action in actions:
            result = action(node.state, context)
            if result:
                child = ReasoningNode(result, parent=node)
                step = {"action": getattr(action, "__name__", repr(action)), "state": result}
                child.reasoning_chain = list(node.reasoning_chain) + [step]
                node.children.append(child)

    def prune_paths(self, paths: list):
        """Drop duplicate paths, keeping first occurrences in their original order."""
        return list(dict.fromkeys(paths))

    def simulate(self, node: "ReasoningNode"):
        """Score the reasoning chain leading to ``node`` (just its state if the chain is empty)."""
        chain = node.reasoning_chain or [{"state": node.state}]
        return self.evaluator.score(" -> ".join(step['state'] for step in chain))

    def backpropagate(self, node: "ReasoningNode", reward: float):
        """Add one visit and ``reward`` to ``node`` and every ancestor up to the root."""
        while node:
            node.visits += 1
            node.value += reward
            node = node.parent

    def run(self, max_iterations, actions):
        """Run ``max_iterations`` MCTS iterations and return the best root child's state.

        "Best" is the highest mean reward among visited children. Returns the
        root's own state when no child was ever created.
        """
        for _ in range(max_iterations):
            leaf = self.select(self.root)
            self.expand(leaf, actions)
            # Roll out a freshly created child when expansion produced one.
            target = leaf.children[0] if leaf.children else leaf
            reward = self.simulate(target)
            self.backpropagate(target, reward)

        if not self.root.children:
            return self.root.state
        best_child = max(
            self.root.children,
            key=lambda c: c.value / c.visits if c.visits > 0 else -float('inf'),
        )
        return best_child.state


# 📌 Step 7: Load BEIR Dataset (SciFact)

In [19]:
from beir import util as beir_util # Import util from beir and rename it

def load_beir_dataset(dataset: str = "scifact", out_dir: str = "/content/datasets", split: str = "test"):
    """Fetch a BEIR dataset archive into ``out_dir`` and load one split.

    Args:
        dataset: BEIR dataset name (the archive is ``{dataset}.zip``).
        out_dir: directory the archive is downloaded and unzipped into.
        split: split passed to ``GenericDataLoader.load``.

    Returns:
        ``(queries, corpus, qrels)`` — note this differs from the
        ``(corpus, queries, qrels)`` order returned by ``GenericDataLoader.load``.
    """
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
    data_path = beir_util.download_and_unzip(url, out_dir)
    corpus, queries, qrels = GenericDataLoader(data_path).load(split=split)
    return queries, corpus, qrels

# 📌 Step 8: Evaluate a Dense-Retrieval Baseline (all-mpnet-base-v2) on BEIR SciFact


In [None]:
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from sentence_transformers import SentenceTransformer

# Create a wrapper class to provide required methods
class SentenceTransformerWrapper:
    """Adapt a SentenceTransformer to the ``encode_queries`` / ``encode_corpus``
    interface expected by BEIR's ``DenseRetrievalExactSearch``.

    Args:
        model: object with an ``encode(texts, batch_size=..., **kwargs)`` method.
        sep: separator placed between a document's title and text.
    """

    def __init__(self, model, sep: str = " "):
        self.model = model
        self.sep = sep

    def encode_queries(self, queries, batch_size=16, **kwargs):
        """Encode a list of query strings; extra kwargs are forwarded to ``encode``."""
        return self.model.encode(queries, batch_size=batch_size, **kwargs)

    def encode_corpus(self, corpus, batch_size=16, **kwargs):
        """Encode corpus documents as ``title + sep + text``.

        Accepts a list of ``{"title", "text"}`` dicts, a dict of parallel
        ``"title"``/``"text"`` lists, or a list of plain strings.
        """
        return self.model.encode(self._corpus_texts(corpus), batch_size=batch_size, **kwargs)

    def _corpus_texts(self, corpus):
        """Flatten any supported corpus layout into a list of strings."""
        if isinstance(corpus, dict):
            texts = corpus["text"]
            titles = corpus.get("title") or [""] * len(texts)
            return [self._join(title, text) for title, text in zip(titles, texts)]
        return [self._join(doc.get("title"), doc["text"]) if isinstance(doc, dict) else doc
                for doc in corpus]

    def _join(self, title, text):
        """Combine an optional title with the body text, as BEIR's reference encoder does."""
        return ((title or "") + self.sep + text).strip()

# Load the BEIR dataset
queries, corpus, qrels = load_beir_dataset()

# Dense baseline: all-mpnet-base-v2 embeddings with exhaustive (exact) search.
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
wrapped_model = SentenceTransformerWrapper(model)
retriever = DRES(wrapped_model, batch_size=16)

# Retrieve the top-10 documents per query by cosine similarity.
results = retriever.search(corpus, queries, top_k=10, score_function='cos_sim')

# EvaluateRetrieval.evaluate returns a 4-tuple (NDCG, MAP, Recall, Precision),
# each a dict keyed like "NDCG@10"; label them so the printed report is readable.
evaluator = EvaluateRetrieval()
metrics = evaluator.evaluate(qrels, results, k_values=[1, 3, 5, 10])
ndcg, _map, recall, precision = metrics

print(json.dumps({"NDCG": ndcg, "MAP": _map, "Recall": recall, "Precision": precision}, indent=4))

  0%|          | 0/5183 [00:00<?, ?it/s]

Batches:   0%|          | 0/19 [00:00<?, ?it/s]

Batches:   0%|          | 0/324 [00:00<?, ?it/s]