<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/COVE_RAG_POC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q langchain huggingface_hub transformers langchain_community

!pip install sentence-transformers -q

!pip install faiss-gpu -q

!pip install -U langchain_huggingface -q

In [None]:
!pip show langchain

Name: langchain
Version: 0.3.3
Summary: Building applications with LLMs through composability
Home-page: https://github.com/langchain-ai/langchain
Author: 
Author-email: 
License: MIT
Location: /usr/local/lib/python3.10/dist-packages
Requires: aiohttp, async-timeout, langchain-core, langchain-text-splitters, langsmith, numpy, pydantic, PyYAML, requests, SQLAlchemy, tenacity
Required-by: langchain-community
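
The same version check can be done from Python for several packages at once, which is handy because LangChain import paths have moved between releases. The following is a minimal sketch, not part of the original notebook: the helper name `installed_versions` is hypothetical, and the package list is simply taken from the install cell above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment
            found[name] = None
    return found


# Distribution names taken from the pip install cell above
packages = [
    "langchain",
    "langchain-community",
    "langchain-huggingface",
    "transformers",
    "sentence-transformers",
]
for name, ver in installed_versions(packages).items():
    print(f"{name}: {ver or 'not installed'}")
```

The printed versions depend on the environment, so no expected output is shown here.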


In [None]:
!nvidia-smi

Fri Oct 11 19:53:32 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              49W / 400W |  28713MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
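
With the settings used in the next cell, the RAG chain returns the prompt, the retrieved context, and the answer as one string, with the answer following a `Helpful Answer:` marker (this is visible in the recorded output at the end of the notebook). A verification step usually wants only the answer, split into individual statements. The following is a minimal sketch of that parsing in plain Python; `extract_answer` and `split_statements` are hypothetical helper names, not functions from the notebook or from LangChain, and the sample string is an abbreviated copy of the recorded output.

```python
def extract_answer(result_text: str, marker: str = "Helpful Answer:") -> str:
    """Return the text after the last occurrence of `marker`, stripped.

    If the marker is absent, the whole text is returned stripped.
    """
    # rpartition splits on the LAST marker, so a prompt that embeds an
    # earlier answer still yields only the final answer.
    _, found, tail = result_text.rpartition(marker)
    return tail.strip() if found else result_text.strip()


def split_statements(answer: str) -> list[str]:
    """Split an answer into non-empty statements on '. ' boundaries."""
    return [s.strip().rstrip(".") for s in answer.split(". ") if s.strip()]


# Abbreviated copy of the chain output recorded in this notebook
sample = (
    "Use the following pieces of context to answer the question at the end.\n\n"
    "The Eiffel Tower is located in Paris.\n\n"
    "Question: Where is the Eiffel Tower?\n"
    "Helpful Answer: The Eiffel Tower is located in Paris."
)

answer = extract_answer(sample)
print(answer)                    # The Eiffel Tower is located in Paris.
print(split_statements(answer))  # ['The Eiffel Tower is located in Paris']
```

Splitting on `". "` is a deliberately simple heuristic; it will mis-split abbreviations and decimal numbers, so a real sentence splitter may be preferable for longer answers.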
                                                                    

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Import from the packages that own these classes; the langchain.llms,
# langchain.vectorstores and langchain.schema paths are deprecated shims.
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from langchain_core.documents import Document

import warnings
warnings.simplefilter("ignore")

# 1. Define the sample documents
docs = [
    Document(page_content="The capital of France is Paris.", metadata={"source": "doc1"}),
    Document(page_content="The Eiffel Tower is located in Paris.", metadata={"source": "doc2"}),
    Document(page_content="The Eiffel Tower is a wrought-iron lattice tower constructed in 1887.", metadata={"source": "doc3"}),
    Document(page_content="The Louvre Museum is a famous museum in Paris, known for its extensive art collection.", metadata={"source": "doc4"}),
    Document(page_content="The Eiffel Tower was the world's tallest man-made structure until 1930.", metadata={"source": "doc5"}),
]

# 2. Initialize Embeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")


# 3. Create FAISS vector store
db = FAISS.from_documents(docs, embeddings)

# 4. Choose the Mistral model from Hugging Face
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

# 5. Create a LangChain LLM wrapper for Mistral.
# from_model_id loads the tokenizer and the model itself, so a separate
# AutoModelForCausalLM.from_pretrained call would only load a second, unused copy.
mistral_llm = HuggingFacePipeline.from_model_id(
    model_id=model_id,
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 512, "do_sample": False},  # greedy decoding
    device=0,  # first CUDA GPU
)

# 6. Initialize RAG chain with Mistral LLM and FAISS
chain = RetrievalQA.from_chain_type(
    llm=mistral_llm, chain_type="stuff", retriever=db.as_retriever()
)

# 7. Define a Chain of Verification (CoVe) function
from langchain.chains import LLMChain  # PromptTemplate is already imported above

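def chain_of_verificationOLD(initial_response, docs):
    """
    Earlier version, not called below: refines the initial response with a
    single evidence question answered from the FAISS database.

    The `docs` argument is unused; retrieval goes through the global `db`.
    """

    print('\n')
    print("Chain of Verification")
    print('\n')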

    refined_response = initial_response.copy()
    result = refined_response['result']

    # Extract the main answer from the initial response
    main_answer = initial_response['result'].split("Helpful Answer:")[-1].split(".")[0].strip()

    # Retrieve relevant documents from FAISS based on the main answer
    retrieved_docs = db.similarity_search(main_answer)
    context = " ".join([doc.page_content for doc in retrieved_docs])

    # Generate a verification question about additional evidence
    prompt_template = "Based on this context: {context}, what additional evidence or details can you provide to support the statement: '{statement}'?"
    prompt = PromptTemplate(
        input_variables=["statement", "context"],
        template=prompt_template,
    )
    question_chain = LLMChain(llm=mistral_llm, prompt=prompt)
    question = question_chain.run({"statement": main_answer, "context": context})

    print(f"Verification Question: {question}")
    verification_answer = chain.invoke(question)
    print('\n')
    print(f"Verification Answer: {verification_answer}")

    # Extract and add supporting evidence
    if "Helpful Answer:" in verification_answer['result']:
        evidence = verification_answer['result'].split("Helpful Answer:")[-1].strip()
        result = result.replace(
            f"Helpful Answer: {main_answer}.",
            f"Helpful Answer: {main_answer}. {evidence}"
        )

    refined_response['result'] = result
    return refined_response

def chain_of_verification(initial_response, docs):
    """
    Dynamically generates verification questions and refines the initial response using FAISS.

    The `docs` argument is unused; retrieval goes through the global `db`.
    """

    verification_question_templates = {
        "evidence": "Based on this context: {context}, what additional evidence or details can you provide to support the statement: '{statement}'?",
        "contradiction": "Are there any sources that contradict this statement: '{statement}'? Consider this context: {context}",
    }

    # Build one question-generation chain per template, once, instead of on every iteration
    question_chains = {
        name: LLMChain(
            llm=mistral_llm,
            prompt=PromptTemplate(input_variables=["statement", "context"], template=template),
        )
        for name, template in verification_question_templates.items()
    }

    refined_response = initial_response.copy()
    # Keep the text before the "Helpful Answer:" marker, i.e. the echoed prompt
    # and retrieved context; the model's answer after the marker is dropped here.
    result = refined_response['result'].split("Helpful Answer:")[0]
    statements = result.split(". ")

    helpful_answer = ""

    for statement in statements:
        if not statement.strip():
            continue  # Skip empty fragments left over from the split

        retrieved_docs = db.similarity_search(statement)
        context = " ".join([doc.page_content for doc in retrieved_docs])
        inputs = {"statement": statement, "context": context}

        # Generate an evidence question and answer it with the RAG chain
        question = question_chains["evidence"].invoke(inputs)["text"]
        verification_answer = chain.invoke(question)

        # The chain echoes its prompt, which itself ends with "Helpful Answer:",
        # so this branch is normally taken for the first statement.
        if "Helpful Answer:" in verification_answer['result']:
            evidence = verification_answer['result'].split("Helpful Answer:")[-1].strip()
            helpful_answer = f" {statement} {evidence}"
            break  # Stop after finding evidence

        # Generate a contradiction question and answer it with the RAG chain
        question = question_chains["contradiction"].invoke(inputs)["text"]
        verification_answer = chain.invoke(question)

        answer_text = verification_answer['result'].lower()
        if "no contradictory" not in answer_text and "contradictory" in answer_text:
            result += " However, it's important to note that there might be conflicting information or varying perspectives on this statement."
            break  # Stop after finding a contradiction

    if helpful_answer:
        result += f"Helpful Answer - Refined Response:{helpful_answer}"

    refined_response['result'] = result
    return refined_response


# 8. Test Case
query = "Where is the Eiffel Tower?"
#initial_response = chain.run(query)
initial_response = chain.invoke(query)
print('\n')
#print(f"Initial Response: {initial_response}")

print(f"Query - Initial Response: {initial_response['query']}")
print(f"Result - Initial Response: {initial_response['result']}")

print('\n\n')

# 9. Apply Chain of Verification
refined_response = chain_of_verification(initial_response, docs)
#print(f"Refined Response: {refined_response}")
print('\n')

print(f"Query - Refined Response: {refined_response['query']}")
print(f"Result - Refined Response:{refined_response['result']}")

print('\n')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Query - Initial Response: Where is the Eiffel Tower?
Result - Initial Response: Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

The Eiffel Tower is located in Paris.

The Eiffel Tower was the world's tallest man-made structure until 1930.

The Eiffel Tower is a wrought-iron lattice tower constructed in 1887.

The capital of France is Paris.

Question: Where is the Eiffel Tower?
Helpful Answer: The Eiffel Tower is located in Paris.





Query - Refined Response: Where is the Eiffel Tower?
Result - Refined Response:Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

The Eiffel Tower is located in Paris.

The Eiffel Tower was the world's tallest man-made structure until 1930.

The Eiffel Tower is a wrought-iron lattice tower constructed in 1887.

The capital of F