# Source attribution detection for RAG-based natural language question responses using watsonx

## Notebook content
This notebook contains the steps and code to demonstrate support for Retrieval Augmented Generation (RAG) in watsonx.ai and to identify source attribution. It introduces commands for data retrieval, building and querying a knowledge base, and testing the model. Some familiarity with Python is helpful. This notebook uses Python 3.10 and is based on [this notebook](https://github.com/IBM/watson-openscale-samples/blob/main/WatsonX.Governance/Cloud/GenAI/samples/source-attribution-using-protodash-for-rag-usecase%20.ipynb) by Sowmaya Kollipara.

### About Retrieval Augmented Generation
Retrieval Augmented Generation (RAG) is a versatile pattern that can unlock a number of use cases requiring factual recall of information, such as querying a knowledge base in natural language.

In its simplest form, RAG requires 3 steps:

- Index knowledge base passages (once)
- Retrieve relevant passage(s) from knowledge base (for every user query)
- Generate a response by feeding retrieved passage into a large language model (for every user query)
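The three steps can be sketched in a few lines of plain Python. This is a toy illustration under stated assumptions, not this notebook's pipeline: a bag-of-words count vector stands in for a real embedding model (the notebook uses `HuggingFaceEmbeddings` with a Chroma vector store), the passages are invented, and `build_prompt` uses a hypothetical template. The sketch stops at the prompt that would be sent to the large language model.

```python
import math
import re
from collections import Counter

def embed(text):
    """Toy 'embedding': a bag-of-words term-count vector (stand-in for a real model)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

# Step 1 (once): index the knowledge base passages
passages = [
    "The Lombardi Trophy is awarded to the winner of the Super Bowl.",
    "Chroma is a vector database used to store embeddings.",
    "Greedy decoding picks the most likely token at each step.",
]
index = [(p, embed(p)) for p in passages]

def retrieve(query, k=1):
    """Step 2 (every query): rank indexed passages by similarity to the query."""
    q = embed(query)
    ranked = sorted(index, key=lambda item: cosine(q, item[1]), reverse=True)
    return [p for p, _ in ranked[:k]]

def build_prompt(query, context):
    """Step 3 (every query): feed the retrieved passages to the LLM inside the prompt."""
    joined = "\n\n".join(context)
    return f"Answer using only this context:\n{joined}\n\nQuestion: {query}\nAnswer:"

question = "Which trophy goes to the Super Bowl winner?"
prompt = build_prompt(question, retrieve(question))
print(prompt)
```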

#### Source Attribution Detection
Source attribution detection identifies the part(s) of the context that could have contributed to the foundation model's response.
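As a point of comparison only, a naive lexical baseline scores each context passage by the share of the response's words that it contains. This is not how watsonx.governance computes attribution (this notebook uses ProtoDash over embeddings); the passages and response below are invented, and `overlap_attribution` is a hypothetical helper.

```python
import re

def tokens(text):
    """Lowercased set of word tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def overlap_attribution(response, context):
    """Naive baseline: fraction of the response's distinct tokens found in each passage."""
    resp = tokens(response)
    if not resp:
        return [0.0] * len(context)
    return [len(resp & tokens(passage)) / len(resp) for passage in context]

context = [
    "Vince Lombardi coached the Green Bay Packers to two early Super Bowl wins.",
    "The halftime show features musical performers.",
]
response = "The trophy is named for Vince Lombardi, who coached the Packers."
scores = overlap_attribution(response, context)
print(scores)  # the first passage shares far more words with the response
```

A baseline like this only rewards shared wording, which is one reason an embedding-based method is preferable for paraphrased answers.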

#### The flow of this notebook is as follows:
1. Build a knowledge base.
2. Retrieve the relevant context from the vector database for each question the user wants answered.
3. Construct the prompt from each question and its relevant context.
4. Generate the retrieval-augmented response to each question using the foundation models hosted on watsonx.ai.
5. Initialize the Watson OpenScale (WOS) client and supply the configuration needed to identify source attribution.
6. Identify the source attribution for the RAG-based responses.

### Install and import the dependencies

In [None]:
!pip install "langchain==0.0.345" | tail -n 1
!pip install wget | tail -n 1
!pip install sentence-transformers | tail -n 1
!pip install "chromadb==0.3.26" | tail -n 1
!pip install "ibm-watson-machine-learning>=1.0.335" | tail -n 1
!pip install "pydantic==1.10.0" | tail -n 1
!pip install --upgrade ibm-metrics-plugin  --no-cache | tail -n 1
!pip install --upgrade ibm-watson-openscale --no-cache | tail -n 1
!pip install --upgrade pyspark==3.3.1 | tail -n 1
!pip install -U "torch==2.0.0" | tail -n 1


### Edit the two values below with your API key and Project ID

Refer to the [lab guide](https://github.com/ericmartens/rag-explain/blob/main/RAG_Explanations_Lab_Guide.pdf) for instructions on gathering the correct values for API\_KEY and PROJECT\_ID. Paste the values into the cell below, between the quotation marks.

In [None]:
API_KEY = "___PASTE API KEY HERE___"
PROJECT_ID = "___PASTE PROJECT ID HERE___"
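A cheap guard against running the rest of the notebook with the template values still in place is to check the two strings before they are used. `check_credentials` is a hypothetical helper added for illustration; it only inspects the strings and does not validate them against IBM Cloud.

```python
PLACEHOLDER_PREFIX = "___PASTE"

def check_credentials(api_key: str, project_id: str) -> None:
    """Raise ValueError if either value is blank or still holds the template placeholder."""
    for name, value in (("API_KEY", api_key), ("PROJECT_ID", project_id)):
        if not value.strip() or value.startswith(PLACEHOLDER_PREFIX):
            raise ValueError(f"{name} has not been set; paste your value between the quotation marks.")

# In the notebook you would call: check_credentials(API_KEY, PROJECT_ID)
check_credentials("example-key", "example-project-id")  # passes silently
```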

## Foundation Models on `watsonx.ai`

IBM watsonx foundation models are among the <a href="https://python.langchain.com/docs/integrations/llms/watsonxllm" target="_blank" rel="noopener no referrer">LLMs supported by LangChain</a>. This example shows how to communicate with the <a href="https://newsroom.ibm.com/2023-09-28-IBM-Announces-Availability-of-watsonx-Granite-Model-Series,-Client-Protections-for-IBM-watsonx-Models" target="_blank" rel="noopener no referrer">Granite model series</a> using <a href="https://python.langchain.com/docs/get_started/introduction" target="_blank" rel="noopener no referrer">LangChain</a>.

In [None]:
from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes

model_id = ModelTypes.GRANITE_13B_CHAT_V2

### Defining the model parameters
The model parameters below control text generation: greedy decoding, at least 1 and at most 100 new tokens, and `<|endoftext|>` as a stop sequence.

In [None]:
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods

parameters = {
    GenParams.DECODING_METHOD: DecodingMethods.GREEDY,
    GenParams.MIN_NEW_TOKENS: 1,
    GenParams.MAX_NEW_TOKENS: 100,
    GenParams.STOP_SEQUENCES: ["<|endoftext|>"]
}
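To make the four parameters concrete, here is a toy greedy decoder. It is an illustration only: the probability table is invented, a dictionary lookup replaces the model, and two behaviors are assumptions of the sketch rather than documented watsonx.ai internals (the stop sequence is kept in the returned text, and it only ends generation once `MIN_NEW_TOKENS` tokens exist).

```python
def greedy_decode(next_token_probs, min_new_tokens, max_new_tokens, stop_sequences):
    """Toy greedy decoder driven by a lookup table instead of a real model.

    next_token_probs maps the text generated so far to a {token: probability} dict.
    """
    text, generated = "", 0
    while generated < max_new_tokens:          # MAX_NEW_TOKENS caps the output length
        probs = next_token_probs.get(text)
        if not probs:                           # the toy table has no continuation
            break
        text += max(probs, key=probs.get)       # GREEDY: always take the most likely token
        generated += 1
        # In this sketch a stop sequence only ends generation after MIN_NEW_TOKENS tokens
        if generated >= min_new_tokens and any(text.endswith(s) for s in stop_sequences):
            break
    return text

table = {
    "": {"Paris": 0.7, "London": 0.3},
    "Paris": {".": 0.6, " is": 0.4},
    "Paris.": {"<|endoftext|>": 0.9, " It": 0.1},
    "Paris.<|endoftext|>": {" extra": 1.0},
}
print(greedy_decode(table, min_new_tokens=1, max_new_tokens=100, stop_sequences=["<|endoftext|>"]))
```

Lowering `max_new_tokens` to 2 in this toy truncates the output to `"Paris."`, which is how a too-small token budget can cut off an answer.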

### LangChain CustomLLM wrapper for watsonx model
Initialize LangChain's `WatsonxLLM` class with the parameters defined above and the `ibm/granite-13b-chat-v2` model.

In [None]:
from langchain.llms import WatsonxLLM

watsonx_granite = WatsonxLLM(
    model_id=model_id.value,
    url="https://us-south.ml.cloud.ibm.com",
    apikey=API_KEY,
    project_id=PROJECT_ID,
    params=parameters
)

### Set up the OpenScale client

Explanations use the watsonx.governance monitoring (OpenScale) service to calculate which of the context paragraphs contributed to the model's answer. The next cell uses the supplied credentials to authenticate with the OpenScale client.

In [None]:
from ibm_cloud_sdk_core.authenticators import IAMAuthenticator, BearerTokenAuthenticator

from ibm_watson_openscale import *
from ibm_watson_openscale.supporting_classes.enums import *
from ibm_watson_openscale.supporting_classes import *


authenticator = IAMAuthenticator(apikey=API_KEY)
client = APIClient(authenticator=authenticator)
client.version

## Import libraries

In [None]:
import wget
import os

from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.schema.document import Document

from ibm_metrics_plugin.common.utils.constants import ExplainabilityMetricType
from ibm_metrics_plugin.metrics.explainability.entity.explain_config import ExplainConfig
from ibm_metrics_plugin.common.utils.constants import InputDataType, ProblemType

import pandas as pd
import nltk
import warnings
import json

## Source attribution detection for RAG-based LLM responses

Source attribution for RAG-based responses is computed with the ProtoDash explainer. The computation needs the following information:
1. The response data for which source attribution is to be identified. This is treated as the input data.
2. The context retrieved through RAG. This is treated as the reference data.

From this information, ProtoDash identifies prototypes of the input: the context passages that best represent the response. These prototypes are the sources in the context that contributed to the response.
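The idea behind ProtoDash (Gurumoorthy et al., 2019) can be sketched with NumPy: choose up to `m` candidates and nonnegative weights `w` that maximize `w^T mu - 0.5 * w^T K w`, where `K` holds kernel similarities between candidates and `mu` holds each candidate's mean similarity to the targets, adding one candidate at a time by the largest gradient. This is a simplified sketch, not the ibm-metrics-plugin implementation: the Gaussian kernel, the projected-gradient solver, and the toy 2-D vectors standing in for response and passage embeddings are all assumptions.

```python
import numpy as np

def rbf_kernel(A, B, gamma=1.0):
    """Gaussian kernel matrix with entries exp(-gamma * ||a - b||^2)."""
    sq = (A ** 2).sum(axis=1)[:, None] + (B ** 2).sum(axis=1)[None, :] - 2.0 * A @ B.T
    return np.exp(-gamma * np.maximum(sq, 0.0))

def nonneg_qp(K, mu, iters=500):
    """Maximize w.mu - 0.5 * w.K.w subject to w >= 0 by projected gradient ascent."""
    step = 1.0 / max(np.linalg.eigvalsh(K).max(), 1e-12)
    w = np.zeros(len(mu))
    for _ in range(iters):
        w = np.maximum(w + step * (mu - K @ w), 0.0)
    return w

def protodash_sketch(targets, candidates, m, gamma=1.0):
    """Greedily select up to m candidate rows (prototypes) and nonnegative weights
    whose weighted kernel mean best matches the target rows."""
    K = rbf_kernel(candidates, candidates, gamma)
    mu = rbf_kernel(targets, candidates, gamma).mean(axis=0)  # mean similarity to the targets
    chosen, w = [], np.zeros(0)
    for _ in range(min(m, len(candidates))):
        full_w = np.zeros(len(candidates))
        if chosen:
            full_w[chosen] = w
        grad = mu - K @ full_w          # gradient of the objective at the current weights
        if chosen:
            grad[chosen] = -np.inf      # consider only candidates not yet selected
        best = int(np.argmax(grad))
        if grad[best] <= 0:             # no remaining candidate improves the objective
            break
        chosen.append(best)
        w = nonneg_qp(K[np.ix_(chosen, chosen)], mu[chosen])
    return chosen, w

candidates = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])  # toy "passage embeddings"
targets = np.array([[0.1, 0.0]])                               # toy "response embedding"
idx, w = protodash_sketch(targets, candidates, m=2)
print(idx, w / w.sum())  # normalized weights sum to 1
```

With `m=1`, the single selected prototype's normalized weight is 1.0; the far-away third candidate is never chosen because the nearer ones explain the target better.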

## Explanation

Source attribution can be understood through the weights (the attribution, or contribution, factor) and the prototypes (the relevant context, or source) that influenced the foundation model's response. For example, a single prototype with a weight of 1.0 indicates that one paragraph of the context contributed to the response. Likewise, three prototypes with differing weights indicate that three paragraphs contributed, with a larger weight indicating a larger contribution. The prototype values are the paragraphs supplied as part of the relevant context.
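Reading an explanation comes down to ordering the prototypes by weight. The sketch below assumes you have already copied the weights and their matching prototype paragraphs out of one printed entry into two parallel lists; it does not rely on the exact JSON field names, and `rank_sources` is a hypothetical helper. Normalizing is harmless if the weights already sum to 1.

```python
def rank_sources(weights, prototypes):
    """Pair each prototype passage with its weight, normalized to sum to 1, largest first."""
    total = sum(weights)
    if total <= 0:
        return []
    pairs = [(weight / total, passage) for weight, passage in zip(weights, prototypes)]
    return sorted(pairs, key=lambda pair: pair[0], reverse=True)

# Hypothetical values, not real output from this notebook
ranked = rank_sources([1.0, 3.0, 1.0], ["Paragraph A", "Paragraph B", "Paragraph C"])
for share, passage in ranked:
    print(f"{share:.2f}  {passage}")
```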

## Define the function to get the RAG response

In [None]:
def get_rag_response(rag_documents, rag_question_list):
    # Split the documents into chunks of up to ~1200 characters
    text_splitter = CharacterTextSplitter(chunk_size=1200, chunk_overlap=0)
    texts = text_splitter.split_documents(rag_documents)

    # Embed the chunks and index them in an in-memory Chroma vector store
    embeddings = HuggingFaceEmbeddings()
    docsearch = Chroma.from_documents(texts, embeddings)
    retriever = docsearch.as_retriever()

    # Build a "stuff" QA chain: retrieved chunks are inserted into the prompt sent to Granite
    qa = RetrievalQA.from_chain_type(llm=watsonx_granite, chain_type="stuff", retriever=retriever)

    responses = []
    contexts = []
    for query in rag_question_list:
        # Retrieve the relevant context for each question from the vector database
        docs = retriever.get_relevant_documents(query)

        # Keep only the text of each retrieved chunk
        contexts.append([doc.page_content for doc in docs])

        # Run the chain and capture the generated answer
        responses.append(qa.run(query))

    # Print each question with its response
    for query, response in zip(rag_question_list, responses):
        print(f"{query} \n {response} \n")

    return pd.DataFrame({"generated_text": responses, "context": contexts})
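With `chain_type="stuff"`, `RetrievalQA` places all retrieved chunks into a single prompt for the model. By default the Chroma retriever returns the top 4 chunks, so every prompt carries four chunks' worth of context. The helper below is a hypothetical, simplified sketch of that assembly step, not LangChain's code, and its template is illustrative; to send fewer chunks, the retriever could be created with `docsearch.as_retriever(search_kwargs={"k": 2})`.

```python
def stuff_prompt(question, passages, template=None):
    """'Stuff' every retrieved passage into one prompt string, in retrieval order."""
    # Illustrative template only; LangChain's built-in QA prompt wording may differ
    template = template or (
        "Use the following pieces of context to answer the question at the end.\n\n"
        "{context}\n\nQuestion: {question}\nAnswer:"
    )
    return template.format(context="\n\n".join(passages), question=question)

prompt = stuff_prompt("Who is the Lombardi Trophy named for?",
                      ["First retrieved chunk.", "Second retrieved chunk."])
print(prompt)
```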
    

## Define the function to compute the explanations

In [None]:
def compute_protodash(rag_dataframe, rag_question_list):
    embeddings = HuggingFaceEmbeddings()
    # Configuration for source attribution with the ProtoDash explainer
    config_json = {
        "configuration": {
            "input_data_type": InputDataType.TEXT.value,
            "problem_type": ProblemType.QA.value,
            "feature_columns": ["context"],
            "prediction": "generated_text",  # Column that holds the foundation model's response
            "context_column": "context",
            "explainability": {
                "metrics_configuration": {
                    ExplainabilityMetricType.PROTODASH.value: {
                        # Supply an embedding function; otherwise a TF-IDF vectorizer is used
                        "embedding_fn": embeddings.embed_documents
                    }
                }
            }
        }
    }
    warnings.filterwarnings("ignore")
    results = client.ai_metrics.compute_metrics(configuration=config_json, data_frame=rag_dataframe)

    # Extract the ProtoDash explanations from the metrics result
    protodash_results = results.get("metrics_result").get("explainability").get("protodash")

    # Print each question and response, followed by its explanation
    for idx, entry in enumerate(protodash_results):
        print(f"====idx:{idx}: Question:{rag_question_list[idx]} Response:{rag_dataframe['generated_text'][idx]}====")
        print(json.dumps(entry, indent=4))

# USE CASE 1: Mafia History

Download the first document, a history of mafia involvement in Las Vegas, from a Las Vegas Review-Journal article by Jeff German.

In [None]:
# Get the file
filename = 'mafia_history.txt'
url = 'https://raw.githubusercontent.com/ericmartens/rag-explain/refs/heads/main/mafia_history.txt'

if not os.path.isfile(filename):
    wget.download(url, out=filename)

# Build up the knowledge base
loader = TextLoader(filename)
documents = loader.load()

query0 = "Why was Bugsy Siegel killed?"
query1 = "Which political figures have tried to curtail mob activities in Las Vegas?"
query2 = "What casinos were owned by the mob?"

question_list = [query0, query1, query2]

rag_data = get_rag_response(documents, question_list)
compute_protodash(rag_data, question_list)

# USE CASE 2: The Super Bowl

Download the second document, a Wikipedia entry for the origins and overview of the Super Bowl.

In [None]:
# Get the file
filename = 'super_bowl.txt'
url = 'https://raw.githubusercontent.com/ericmartens/rag-explain/refs/heads/main/super_bowl.txt'

if not os.path.isfile(filename):
    wget.download(url, out=filename)

# Build up the knowledge base
loader = TextLoader(filename)
documents = loader.load()

q0 = 'Who is the Lombardi Trophy named for?'
q1 = 'Which teams have never appeared in a Super Bowl?'
q2 = 'Where did the Super Bowl get its name?'

question_list = [q0, q1, q2]

rag_data = get_rag_response(documents, question_list)
compute_protodash(rag_data, question_list)

# Custom content

Please refer to your lab guide for instructions on using custom text in the lab.

In [None]:
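# Build up the knowledge base
# `streaming_body_1` is created by the cell that loads your custom text file (see the lab guide);
# run that cell before this one, otherwise this line raises a NameError
custom_text = streaming_body_1.read().decode('utf-8', errors='ignore')
text_splitter = CharacterTextSplitter(chunk_size=1200, chunk_overlap=100)
documents = [Document(page_content=x) for x in text_splitter.split_text(custom_text)]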

query0 = "Question text 1"
query1 = "Question text 2"
query2 = "Question text 3"

question_list = [query0, query1, query2]

rag_data = get_rag_response(documents, question_list)
compute_protodash(rag_data, question_list)
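The `chunk_size` and `chunk_overlap` arguments are easiest to see with fixed-width character windows. This is a simplification: `CharacterTextSplitter` first splits on a separator (`"\n\n"` by default) and then merges the pieces up to `chunk_size`, so its boundaries follow paragraph breaks rather than exact character counts. `sliding_chunks` is a hypothetical helper for illustration.

```python
def sliding_chunks(text, chunk_size, chunk_overlap):
    """Split text into fixed-width windows; each window starts chunk_overlap characters
    before the previous one ends, so neighboring chunks share that many characters."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, max(len(text) - chunk_overlap, 1), step)]

print(sliding_chunks("abcdefghij", chunk_size=4, chunk_overlap=1))  # ['abcd', 'defg', 'ghij']
```

With the custom-content settings (`chunk_size=1200`, `chunk_overlap=100`), consecutive windows in this simplified model would share 100 characters, which helps keep a sentence that straddles a boundary retrievable from either chunk.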