In [1]:
from elasticsearch import Elasticsearch, helpers
import json
import pandas as pd
import requests
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from itertools import combinations

# Load a pre-trained sentence transformer model
model = SentenceTransformer('all-MiniLM-L6-v2')
url = "http://localhost:11434/api/generate"
es_client = Elasticsearch('http://localhost:9200')

# Create an index for arXiv papers
index_name = "arxiv_papers"
mapping = {
    "mappings": {
        "properties": {
            "id": {"type": "keyword"},
            "title": {"type": "text"},
            "abstract": {"type": "text"},
            "authors": {"type": "text"},
            "categories": {"type": "keyword"},
            "published_date": {"type": "date"},
            "updated_date": {"type": "date"},
            "doi": {"type": "keyword"},
            "journal_ref": {"type": "text"},
            "comments": {"type": "text"}
        }
    }
}

# Create the index with the mapping
if not es_client.indices.exists(index=index_name):
    es_client.indices.create(index=index_name, body=mapping)

df = pd.read_json('data/arxiv.json', lines=True)
print('Data is read')
# Convert DataFrame to a list of dictionaries
papers = df.to_dict(orient='records')
if es_client.indices.exists(index=index_name):
    print(f"Index '{index_name}' already exists. Skipping indexing.")
else:
    print(f"Index '{index_name}' does not exist. Proceeding with indexing.")

  from .autonotebook import tqdm as notebook_tqdm


Data is read
Index 'arxiv_papers' already exists. Skipping indexing.


In [None]:
# Index the papers into Elasticsearch
actions = [
    {
        "_index": index_name,
        "_id": paper['id'],  # Use the arXiv ID as the document ID
        "_source": paper
    }
    for paper in tqdm(papers)
]

helpers.bulk(es_client, actions)

In [2]:
def retrieve_documents(query, top_k=5):
    search_query = {
        "query": {
            "multi_match": {
                "query": query,
                "fields": ["title", "abstract", "authors"]
            }
        },
        "size": top_k
    }

    response = es_client.search(index=index_name, body=search_query)
    return [hit['_source'] for hit in response['hits']['hits']]

def generate_response(query, documents):
    # Combine the documents into a single context
    context = "\n\n".join([f"Title: {doc['title']}\nAbstract: {doc['abstract']}" for doc in documents])

    # Prepare the prompt for Llama3
    prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"

    # Send the prompt to Llama3
    data = {
        "model": "llama3.3",
        "prompt": prompt,
        "stream": False
    }

    response = requests.post(url, json=data).json()
    return response["response"]

def compute_similarity(text1, text2):
    """
    Compute cosine similarity between two texts using sentence embeddings.
    """
    embeddings = model.encode([text1, text2])
    similarity = cosine_similarity([embeddings[0]], [embeddings[1]])[0][0]
    return similarity

def marginal_contribution(query, documents, document_index):
    """
    Compute the marginal contribution of a document to the model's output.
    """
    # Generate the response with all documents
    response_with_all = generate_response(query, documents)

    # Generate the response without the specified document
    documents_without = [doc for i, doc in enumerate(documents) if i != document_index]
    response_without = generate_response(query, documents_without)

    # Compute the difference in responses (e.g., using cosine similarity)
    similarity = compute_similarity(response_with_all, response_without)
    return similarity

def compute_exact_shapley_values(query, documents):
    """
    Compute exact Shapley values for the retrieved documents.
    """
    num_documents = len(documents)
    shapley_values = np.zeros(num_documents)

    # Generate all possible subsets of documents
    all_subsets = []
    for subset_size in range(num_documents + 1):
        all_subsets.extend(combinations(range(num_documents), subset_size))

    # Iterate over each document
    for i in tqdm(range(num_documents)):
        contributions = []

        # Iterate over all subsets that exclude the current document
        for subset in all_subsets:
            if i not in subset:
                # Include the current document in the subset
                subset_with = list(subset) + [i]

                # Exclude the current document from the subset
                subset_without = list(subset)

                # Compute the marginal contribution
                response_with = generate_response(query, [documents[j] for j in subset_with])
                response_without = generate_response(query, [documents[j] for j in subset_without])
                similarity = compute_similarity(response_with, response_without)
                contributions.append(similarity)

        # Average the contributions to get the Shapley value
        shapley_values[i] = np.mean(contributions)

    return shapley_values

def visualize_datashap_values(datashap_values, documents):
    """
    Visualize the DataSHAP values for the retrieved documents.
    """
    # Plot the Shapley values
    plt.figure(figsize=(10, 6))
    plt.bar(range(len(datashap_values)), datashap_values, tick_label=[f"Document {i+1}" for i in range(len(datashap_values))])
    plt.xlabel("Documents")
    plt.ylabel("Shapley Value")
    plt.title("DataSHAP Values for Retrieved Documents")
    plt.show()

def rag_pipeline_with_exact_shapley(query):
    """
    RAG pipeline with exact Shapley value computation.
    """
    # Step 1: Retrieve relevant documents
    documents = retrieve_documents(query)

    # Step 2: Generate response using Llama3
    response = generate_response(query, documents)

    # Step 3: Compute exact Shapley values
    shapley_values = compute_exact_shapley_values(query, documents)

    # Step 4: Visualize Shapley values
    visualize_datashap_values(shapley_values, documents)

    return response, documents, shapley_values

In [3]:
# Example query
query = "Are explainability methods susceptible to class outliers?"

# Run the RAG pipeline with DataSHAP
response, documents, datashap_values = rag_pipeline_with_exact_shapley(query)

# Print the response
print("Response:", response)

# Print the DataSHAP values
print("DataSHAP Values:", datashap_values)

100%|██████████| 5/5 [33:40<00:00, 404.16s/it]


NameError: name 'plt' is not defined

In [None]:
data = {
    "model": "llama3.3",
    "prompt": "Are explainability methods susceptible to class outliers?",
    "stream": False
}

response = requests.post(url, json=data).json()
print(response["response"])  # Print only the response text


ConnectionError: HTTPConnectionPool(host='localhost', port=11434): Max retries exceeded with url: /api/generate (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f613374eab0>: Failed to establish a new connection: [Errno 111] Connection refused'))