In [2]:
import pandas as pd
import re
import pickle
import requests
import numpy as np
from tqdm import tqdm
import faiss
from sklearn.metrics.pairwise import cosine_similarity
from itertools import combinations
import math
from sentence_transformers import SentenceTransformer
tqdm.pandas()

splits = {'train': 'question-answer-passages/train-00000-of-00001.parquet', 'test': 'question-answer-passages/test-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/enelpol/rag-mini-bioasq/" + splits["train"])

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
df1 = pd.read_parquet("hf://datasets/enelpol/rag-mini-bioasq/text-corpus/test-00000-of-00001.parquet")
df1['passage']=df1['passage'].str.replace(r'[\n]', ' ', regex=True)
df['question']=df['question'].str.replace(r'[\n]', ' ', regex=True)

In [9]:
def gen_embs(qtext, model="nomic"):
    if model=="nomic":
    
        data = {
            "model": "nomic-embed-text",
            "prompt": qtext
        }
        return np.array(requests.post('http://localhost:11434/api/embeddings', json=data).json()['embedding'])
    else:
        return SentenceTransformer("abhinand/MedEmbed-large-v0.1").encode([qtext], convert_to_numpy=True)
    
# df1['embedding'] = df1['passage'].progress_apply(lambda x: gen_embs(x, model='medemb'))

# # Save the embeddings to a pickle file
# with open('embed_bioasq_medemb.pkl', 'wb') as f:
#     pickle.dump(df1['embedding'].tolist(), f)
# print("Embeddings saved to embed_bioasq_medemb.pkl")

In [5]:
with open("data/embed_bioasq_medemb.pkl", "rb") as f:
    embeddings = pickle.load(f)

In [6]:
def normalize_embeddings(embeddings):
    return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

embeddings = normalize_embeddings(embeddings)

def index_documents(method="faiss", index_name="recipes_nomic", es_host="http://localhost:9200"):
    if method == "faiss":
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        faiss.write_index(index, "data/bioasq_nomic_faiss.index")
        print("FAISS index saved.")
        return index
    elif method == "elasticsearch":
        es = Elasticsearch(es_host)
        mapping = {"mappings": {"properties": {"text": {"type": "text"}, "vector": {"type": "dense_vector", "dims": embeddings.shape[1]}}}}
        es.indices.create(index=index_name, body=mapping, ignore=400)
        for i, (text, vector) in enumerate(zip(documents, embeddings)):
            es.index(index=index_name, id=i, body={"text": text, "vector": vector.tolist()})
        print("Elasticsearch index created.")
        return es
# index_documents(method="faiss", index_name="bioasq_nomic_faiss", es_host="http://localhost:9200")
faiss_index = faiss.read_index("data/medemb_bioasq_faiss.index")


In [130]:
def retrieve_documents(query, k=5):
    query_embedding = gen_embs(query, model="medemb")
    query_embedding = normalize_embeddings(query_embedding.reshape(1, -1))
    scores, indices = faiss_index.search(query_embedding, k)
    return [df1['passage'][i] for i in indices[0]], scores

### 3️⃣ Query RAG Pipeline ###
def query_rag(query, retrieved_docs=None):
    """Query the Ollama API with a prompt based on provided documents."""
    if retrieved_docs:
        retrieved_text = "\n".join(retrieved_docs)
        prompt = f"Using only the following context, answer the question. \nContext:{retrieved_text} \n Be concise and short. If you can not find a relevant information from the provided, then just answer -I do not have this information.\n Question: {query}. Answer:"
    else:
        prompt= f'You are an expert in medical domain. Given the question provide a relevant answer. Question: {query}. Answer:'
    HEADERS = {"Content-Type": "application/json"}
    url = "http://localhost:11434/api/generate"
    data = {"model": "llama3.3", "prompt": prompt, 'stream': False}

    response = requests.post(url, headers=HEADERS, json=data)

    return response.json()['response']

In [138]:
query='List the common retinal diseases associated with circRNA and relate to tumorigenesis?'
docs, scores=retrieve_documents(query=query, k=5)

In [134]:
for a, b in zip(scores[0], docs):
    print(f"{a}\t{b[:230]}\n")

0.43578189611434937	Circular RNAs (circRNAs) are a novel class of endogenous non-coding RNAs  produced by back-splicing. They are found to be expressed in eukaryotic cells  and play certain roles in various cellular functions, including fibrosis, cel

0.4444296360015869	Retinal neovascularization is a complication which caused human vision loss  severely. It has been shown that circular RNAs (circRNAs) play essential roles  in gene regulation. However, circRNA expression profile and the underlyin

0.4829050302505493	A newly rediscovered subclass of noncoding RNAs, circular RNAs (circRNAs), is  produced by a back-splicing mechanism with a covalently closed loop structure.  They not only serve as the sponge for microRNAs (miRNAs) and proteins b

0.5035773515701294	In diabetic patients, diabetic retinopathy (DR) is the leading cause of  blindness and seriously affects the quality of life. However, current treatment  methods of DR are not satisfactory. Advances have been made in understand

In [139]:
[df1[df1['passage']==i]['id'] for i in docs]

[37541    31171902
 Name: id, dtype: int64,
 37864    31692917
 Name: id, dtype: int64,
 38997    33015046
 Name: id, dtype: int64,
 38574    32519377
 Name: id, dtype: int64,
 39461    33761053
 Name: id, dtype: int64]

In [119]:
resp=query_rag(query=query, retrieved_docs=docs)

In [140]:
resp=query_rag(query=query, retrieved_docs=docs)

In [141]:
print(resp)

Based on the context, here are some retinal diseases associated with circRNA:

1. Diabetic retinopathy
2. Retinoblastoma (related to tumorigenesis)
3. Retinal neovascularization 
4. Proliferative vitreoretinopathy 

Note: The exact relationship between these diseases and tumorigenesis is not fully explained in the provided context, except for retinoblastoma which is directly related to tumorigenesis.


In [None]:
def F(subset, full_set_embedding):
    """
    Cost function: Cosine similarity between the LLM's response for the subset and the full set.
    """
    if not subset:
        return 0.0  # Empty subset has no contribution
    
    # Query the LLM with the subset
    response = query_rag(query, subset)
    
    # Generate and normalize embeddings for the subset's response
    subset_embedding = normalize_embeddings(gen_embs(response, model='medemb').reshape(1, -1))
    
    # Compute cosine similarity with the full set's embedding
    return cosine_similarity(subset_embedding, full_set_embedding)

def shapley_values(S):
    """
    Compute Shapley values for a set of textual queries S.
    """
    S = list(S)
    n = len(S)
    
    # Query the LLM with the full set to get the reference embedding
    full_set_response = query_rag(query, S)
    full_set_embedding = normalize_embeddings(gen_embs(full_set_response, model='medemb').reshape(1, -1))
    
    # Precompute the cost for all subsets
    F_cache = {}
    for bitmask in tqdm(range(0, 1 << n), desc="Calculating cosine to full response"):
        subset = [S[i] for i in range(n) if (bitmask & (1 << i))]
        if len(subset)==n:
            F_cache[bitmask] = 1
        else:
            F_cache[bitmask] = F(subset, full_set_embedding)
    
    # Initialize Shapley values
    # shap = {element: 0.0 for element in S}
    shap = np.zeros(len(S))
    
    # Calculate contributions for each subset
    for bitmask in tqdm(range(0, 1 << n), desc="Calculating shap"):
        subset_size = bin(bitmask).count('1')
        if subset_size == 0:
            continue  # Skip empty subsets
        
        for i in range(n):
            if not (bitmask & (1 << i)):
                continue  # Skip subsets without the current element
            
            # Compute subset without the current element
            subset_without_i = bitmask ^ (1 << i)
            
            # Compute Shapley weight
            k = bin(subset_without_i).count('1')
            weight = (math.factorial(k) * math.factorial(n - k - 1)) / math.factorial(n)
            
            # Compute marginal contribution
            marginal = F_cache[bitmask] - F_cache[subset_without_i]
            shap[i] += marginal * weight
    
    return shap

def ragshap(values, retrival_type='max_shap'):
    if retrival_type=='max_shap':
        new_query=docs[values.argmax()]
    else:
        new_query=query_rag(query=query, retrieved_docs=docs[values.argmax()])
    new_docs, new_scores=retrieve_documents(query=new_query, k=5)
    return query_rag(query=query, retrieved_docs=docs)


In [160]:
shap=shapley_values(docs)

Calculating cosine to full response: 100%|██████████| 32/32 [06:13<00:00, 11.66s/it]
  shap[i] += marginal * weight
Calculating shap: 100%|██████████| 32/32 [00:00<00:00, 24105.20it/s]


In [165]:
ragshap(shap, retrival_type='max_shap')

'Based on the context, here are some common retinal diseases associated with circRNA:\n\n1. Diabetic retinopathy (DR)\n2. Retinoblastoma\n3. Retinal neovascularization \n4. Proliferative vitreoretinopathy \n\nThese circRNAs may play a role in tumorigenesis, particularly in retinoblastoma, which is a type of eye cancer.'