In [None]:
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFaceEndpoint
from langchain_community.retrievers.tfidf import TFIDFRetriever
from langchain.schema import Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from typing import List, Dict, Any
import os

In [None]:
class ReasoningAugmentedSystem:
    """Answer questions by prompting an LLM with explicit facts and rules.

    Facts and rules are stored twice:

    - as plain strings, for prompting with the full knowledge base;
    - as ``Document`` objects tagged with ``metadata["type"]`` (``"fact"`` or
      ``"rule"``), for TF-IDF retrieval of a relevant subset.
    """

    def __init__(self, model_name="google/flan-t5-base"):
        """
        Initialize the reasoning-augmented generation system.

        Args:
            model_name: Hugging Face Hub repo id of the model queried through
                ``HuggingFaceEndpoint``.
        """
        # HuggingFaceEndpoint caps generated output with ``max_new_tokens``;
        # it has no ``max_length`` field.
        self.llm = HuggingFaceEndpoint(
            repo_id=model_name,
            max_new_tokens=512,
            temperature=0.1,
        )

        # Knowledge base: raw strings for full-context prompting, plus tagged
        # Documents forming the retrieval corpus.
        self.facts = []
        self.rules = []
        self.documents = []

        # TF-IDF retriever over ``self.documents``. It is built lazily and reset
        # to None whenever the knowledge base changes, so it is never stale.
        self.retriever = None

        # Built from explicit lines so the prompt sent to the model does not
        # inherit this method's source-code indentation.
        self.reasoning_template = (
            "Given the following facts and rules:\n"
            "\n"
            "Facts:\n"
            "{facts}\n"
            "\n"
            "Rules:\n"
            "{rules}\n"
            "\n"
            "Question: {question}\n"
            "\n"
            "Please reason step by step to answer the question:\n"
        )

        self.prompt = PromptTemplate(
            template=self.reasoning_template,
            input_variables=["facts", "rules", "question"],
        )

    def add_fact(self, fact: str) -> None:
        """Add a fact to the knowledge base and to the retrieval corpus."""
        self.facts.append(fact)
        self.documents.append(Document(page_content=fact, metadata={"type": "fact"}))
        # The corpus changed, so any previously built retriever is out of date.
        self.retriever = None

    def add_rule(self, rule: str) -> None:
        """Add a rule to the knowledge base and to the retrieval corpus."""
        self.rules.append(rule)
        self.documents.append(Document(page_content=rule, metadata={"type": "rule"}))
        # The corpus changed, so any previously built retriever is out of date.
        self.retriever = None

    def setup_retriever(self) -> None:
        """Build a TF-IDF retriever over the current documents.

        Raises:
            ValueError: If no facts or rules have been added yet.
        """
        if not self.documents:
            raise ValueError(
                "Cannot build a retriever: no facts or rules have been added."
            )
        self.retriever = TFIDFRetriever.from_documents(self.documents)

    @staticmethod
    def _format_items(items: List[str]) -> str:
        """Render items as a ``- item`` bulleted list, or ``(none)`` if empty."""
        if not items:
            return "(none)"
        return "\n".join(f"- {item}" for item in items)

    def _generate(self, facts: List[str], rules: List[str], question: str) -> str:
        """Fill the reasoning prompt and return the model's stripped answer."""
        prompt_text = self.prompt.format(
            facts=self._format_items(facts),
            rules=self._format_items(rules),
            question=question,
        )
        response = self.llm.generate([prompt_text])
        return response.generations[0][0].text.strip()

    def answer_question(self, question: str) -> str:
        """
        Answer a question using reasoning over all stored facts and rules.

        Args:
            question: The question to answer.

        Returns:
            str: The reasoned answer (whitespace-stripped model output).
        """
        return self._generate(self.facts, self.rules, question)

    def answer_with_retrieval(self, question: str, k: int = 3) -> str:
        """
        Answer a question using only the facts/rules most relevant to it.

        Args:
            question: The question to answer.
            k: Number of documents to retrieve; must be at least 1.

        Returns:
            str: The reasoned answer (whitespace-stripped model output).

        Raises:
            ValueError: If ``k`` is less than 1 or the knowledge base is empty.
        """
        # The TF-IDF retriever slices its ranking with ``[-k:]``, so k=0 would
        # return every document instead of none.
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")

        if self.retriever is None:
            self.setup_retriever()

        # TFIDFRetriever takes its result count from its own ``k`` field, so set
        # it here rather than passing ``k`` to the retrieval call.
        self.retriever.k = k
        relevant_docs = self.retriever.invoke(question)

        # Split the retrieved documents back into facts and rules via their tag.
        retrieved_facts = [
            doc.page_content for doc in relevant_docs
            if doc.metadata.get("type") == "fact"
        ]
        retrieved_rules = [
            doc.page_content for doc in relevant_docs
            if doc.metadata.get("type") == "rule"
        ]

        return self._generate(retrieved_facts, retrieved_rules, question)

In [None]:
if __name__ == "__main__":
    # Demo: build a small knowledge base about John, then answer the same
    # question with the full context and with retrieved context only.
    demo = ReasoningAugmentedSystem(model_name="google/flan-t5-large")

    demo_facts = (
        "John is a student at MIT.",
        "MIT is located in Cambridge.",
        "Cambridge is in Massachusetts.",
        "Massachusetts is in the United States.",
        "John studies computer science.",
    )
    demo_rules = (
        "If someone studies at a university, they are a student of that university.",
        "If a place A is located in place B, and place B is in place C, then place A is in place C.",
        "If someone studies a subject, they have knowledge about that subject.",
    )

    for demo_fact in demo_facts:
        demo.add_fact(demo_fact)
    for demo_rule in demo_rules:
        demo.add_rule(demo_rule)

    query = "Where is John studying and what can we infer about his knowledge?"
    print(f"Question: {query}")

    # Full context: every fact and rule goes into the prompt.
    full_context_answer = demo.answer_question(query)
    print("\nAnswer using all facts and rules:")
    print(full_context_answer)

    # Retrieved context: only the top-4 TF-IDF matches go into the prompt.
    retrieval_answer = demo.answer_with_retrieval(query, k=4)
    print("\nAnswer using retrieval:")
    print(retrieval_answer)