In [1]:
import os

import pandas as pd
import tiktoken

from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import (
    LocalSearchMixedContext,
)
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore

from tqdm.notebook import tqdm

In [2]:
# --- Index location ---------------------------------------------------------
# Folder holding the GraphRAG indexer's parquet outputs.
INPUT_DIR = "../benchmark/output/"
# LanceDB database with the entity-description embeddings, inside INPUT_DIR.
LANCEDB_URI = INPUT_DIR + "/lancedb"

# --- Parquet table names (file stem, without ".parquet") -------------------
ENTITY_TABLE = "create_final_nodes"  # nodes: community and degree data
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
TEXT_UNIT_TABLE = "create_final_text_units"
COVARIATE_TABLE = "create_final_covariates"

# Community hierarchy level used when reading entities and reports.
COMMUNITY_LEVEL = 2

In [None]:
# Nodes table (community + degree data) and the entity table, read in that order,
# are combined by graphrag into Entity objects at COMMUNITY_LEVEL.
entity_df, entity_embedding_df = (
    pd.read_parquet(f"{INPUT_DIR}/{table}.parquet")
    for table in (ENTITY_TABLE, ENTITY_EMBEDDING_TABLE)
)
entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)

# Open the LanceDB collection of entity-description embeddings stored at the
# local path LANCEDB_URI (pass a remote URI instead to use a hosted database).
description_embedding_store = LanceDBVectorStore(collection_name="default-entity-description")
description_embedding_store.connect(db_uri=LANCEDB_URI)

# NOTE(review): this counts rows of the nodes table, not len(entities) — confirm
# which number is meant by "Entity count".
print(f"Entity count: {len(entity_df)}")

In [None]:
# Relationships between entities, converted to graphrag Relationship objects.
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
print("Relationship count: " + str(len(relationship_df)))

In [None]:
# Community reports at COMMUNITY_LEVEL; the nodes table supplies community data.
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)
print("Report records: {}".format(len(report_df)))

In [None]:
# Text units (source chunks). Rows with ANY missing value are discarded first.
# NOTE: the document mappings built later come from this filtered frame, so
# entities/relationships that only occur in dropped rows will not be mapped.
text_unit_df = (
    pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
    .dropna(how="any")
)
text_units = read_indexer_text_units(text_unit_df)
print(f"Text unit records: {text_unit_df.shape[0]}")

In [7]:
# Model names are pinned here; the trailing comments show the environment
# variables they would otherwise be read from. No api_key is passed explicitly.
# api_key = os.environ["GRAPHRAG_API_KEY"]
llm_model = "gpt-4o-mini"  # os.environ["GRAPHRAG_LLM_MODEL"]
embedding_model = "text-embedding-3-small"  # os.environ["GRAPHRAG_EMBEDDING_MODEL"]

# Chat model used by the local-search engine.
llm = ChatOpenAI(
    model=llm_model,
    api_type=OpenaiApiType.OpenAI,  # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI
    max_retries=20,
    # api_key=api_key,
)

# Tokenizer shared by the context builder and the search engine.
token_encoder = tiktoken.get_encoding("cl100k_base")

# Embedding client handed to the context builder; the deployment name mirrors the model name.
text_embedder = OpenAIEmbedding(
    model=embedding_model,
    deployment_name=embedding_model,
    api_type=OpenaiApiType.OpenAI,
    api_base=None,
    max_retries=20,
    # api_key=api_key,
)

In [8]:
# Mixed local-search context: entities, relationships, reports and text units,
# with entity descriptions retrieved through the LanceDB embedding store.
context_builder = LocalSearchMixedContext(
    entities=entities,
    relationships=relationships,
    community_reports=reports,
    text_units=text_units,
    # Covariates are not used; pass them here only if they were produced during indexing.
    covariates=None,
    entity_text_embeddings=description_embedding_store,
    # Use EntityVectorStoreKey.TITLE instead if the vector store is keyed by entity title.
    embedding_vectorstore_key=EntityVectorStoreKey.ID,
    text_embedder=text_embedder,
    token_encoder=token_encoder,
)

In [9]:
# Retrieval knobs and token budget for building the local-search context.
local_context_params = dict(
    text_unit_prop=0.5,
    community_prop=0.1,
    conversation_history_max_turns=5,
    conversation_history_user_turns_only=True,
    top_k_mapped_entities=10,
    top_k_relationships=10,
    include_entity_rank=True,
    include_relationship_weight=True,
    include_community_rank=False,
    return_candidate_context=False,
    # Use EntityVectorStoreKey.TITLE instead if the vector store is keyed by entity title.
    embedding_vectorstore_key=EntityVectorStoreKey.ID,
    # Adjust to your model's context limit (for a model with an 8k limit, ~5000 is a good setting).
    max_tokens=12_000,
)

# Generation settings passed to the chat model.
llm_params = dict(
    # Adjust to your model's limit (for a model with an 8k limit, ~1000-1500 is a good setting).
    max_tokens=2_000,
    temperature=0.0,
)

In [10]:
# Local-search engine wiring the chat model to the context builder above.
search_engine = LocalSearch(
    llm=llm,
    llm_params=llm_params,
    token_encoder=token_encoder,
    context_builder=context_builder,
    context_builder_params=local_context_params,
    # Free-form description of the answer format, e.g. "prioritized list",
    # "single paragraph", "multiple paragraphs" or "multiple-page report".
    response_type="multiple paragraphs",
)

In [None]:
# Lookup tables from graph objects (by human-readable id) to the document they
# come from. The *_r tables and reports_documents_mapping are never filled below.
chunks_document_mapping = {}
chunks_document_mapping_r = {}
entity_document_mapping = {}
entity_document_mapping_r = {}
relationship_document_mapping = {}
relationship_document_mapping_r = {}
reports_documents_mapping = {}

entities_df = pd.read_parquet('../benchmark/output/create_final_entities.parquet')

# GraphRAG ids -> human-readable ids (later duplicates overwrite earlier ones).
entity_mapper = dict(zip(entity_df['id'], entity_df['human_readable_id']))
relationship_mapper = dict(zip(relationship_df['id'], relationship_df['human_readable_id']))

for chunk_hr_id, chunk_doc_ids, chunk_entity_ids, chunk_relationship_ids in zip(
    text_unit_df['human_readable_id'],
    text_unit_df['document_ids'],
    text_unit_df['entity_ids'],
    text_unit_df['relationship_ids'],
):
    source_doc = chunk_doc_ids[0]  # only the first document of each text unit is used
    chunks_document_mapping[chunk_hr_id] = source_doc
    # Last writer wins: an entity/relationship that appears in text units from
    # several documents ends up mapped to the document of the LAST such unit.
    entity_document_mapping.update((entity_mapper[e], source_doc) for e in chunk_entity_ids)
    relationship_document_mapping.update(
        (relationship_mapper[r], source_doc) for r in chunk_relationship_ids
    )

# Spot checks on a few hard-coded ids.
print(chunks_document_mapping[1])
print(entity_document_mapping[641])
print(relationship_document_mapping[638])

In [None]:
import json

# Map every benchmark document id to its title (used to label access-control checks).
with open('../benchmark/results-benchmark.json', 'r') as fh:
    data = json.load(fh)

id_to_title = dict((item['id'], item['title']) for item in data)
id_to_title

# Create Responses

## Set Access Control

In [13]:
def create_document_access_mapping():
    """Return a fresh ``{document_id: [allowed user ids]}`` table for the benchmark.

    ``admin`` may read every document; ``user1`` may read all but two of them.
    A new dict with new lists is built on every call, so callers may mutate it.
    """
    # (document id, readable only by admin?)
    entries = (
        ("2a5249a0a6455998a6380194c2e0396894f20be186335de35b3a05dd1ee3aa0ffccb7acaf5cae36971bf8e41403e369e5cd8727aaceb6d6739c0a809fe513cf6", False),
        ("60dd3abebec49fd661cf23fa4abbe211d9cdac2d404e1d3c1c6033abdac768f7bce7f45c8ddca2fcb1a5f30328d32329ba3f945fb5b64c01db8b8f98e7def4a6", False),
        ("8b197120432108f3a0b055d6623358e7472338fcacbcf2116628b36091940f4880948496d9554442010bc02f25250af299e30399b9eaa857db5b552e215ca48b", True),
        ("913c8b875eac2a1e19c34de9f136ec6d79577f71af2a37a7094e0d6312833c3a48904ac32ab9d79d6b68aa9cd281ecd5b2ab0c3b03beaef417db7f50d8854be2", False),
        ("9685dd41c9837a9f1b5561eaa29359a3e911d202232ed51665ddcd062d7900928e0705d17f5f932ea81fd89d5251c29984038fcf61598dd25e2a0635f21e6415", False),
        ("b4bfb7c4d56b91f3c1901805f5c3ff3b9d80a26d20023b3614af41bde57839e2db37e303d51946bf91738bdfee19014ed55eef91bf079b941dc3d791eb20a392", False),
        ("bec5c29288289e3d4be72eaf2d10228c199a545cb363344bb35c4e8f0024a3243e4f8abdf196ed3288832e72a93cc9e042a53eb35c862c123681f0239ecbb0b4", False),
        ("c6b30efb0aededdf34730d9958cf585ac6c304fb5290881d28588b4f26b9197fd0da9877829af4a11386f23eb79ba3e4afb2aa5053ccfc8261579267ddfe5a4f", True),
        ("e8429b1a8c6a7f5151f887658dea4e957abb4366e169e55bd9b4c1ead6f37fc6eb9f9aca80d8cf5764c4fad3f745644edc5fd5a01d7bbcdb32f69dcaa3ca55f2", False),
    )
    return {
        doc_id: (["admin"] if admin_only else ["user1", "admin"])
        for doc_id, admin_only in entries
    }


## Response Filter according to Access Control

In [14]:
import textwrap

def _rows_in_docs(frame, doc_mapping, accessible_docs):
    """Return the rows (Series) of ``frame`` whose integer ``id`` maps to an accessible document."""
    return [
        row for _, row in frame.iterrows()
        if doc_mapping.get(int(row["id"])) in accessible_docs
    ]


def _filter_reports(reports_frame, accessible_docs):
    """Return one "title, newline, content" string per community report the user may see.

    A community report summarises every entity in its community, so it is released
    only when ALL documents those entities come from are accessible; releasing it
    when just one of them is would leak content from restricted documents.
    Entities with no known source document count as inaccessible, and a community
    with no member entities in ``entity_df`` is withheld. Each report appears at
    most once.
    """
    released = []
    for community_id in reports_frame["id"].unique():
        member_ids = entity_df.loc[entity_df["community"] == int(community_id), "human_readable_id"]
        source_docs = {entity_document_mapping.get(hr_id) for hr_id in member_ids}
        if not source_docs or not source_docs <= accessible_docs:
            continue
        rows = reports_frame[reports_frame["id"] == community_id]
        released.extend(
            f"{title}\n{content}" for title, content in zip(rows["title"], rows["content"])
        )
    return released


def filter_response_by_access(response_dict, access_level, user_id=None):
    """
    Filter a local-search context according to an access level.

    Args:
        response_dict (dict): The search result's ``context_data``. Its
            "entities", "relationships" and "sources" entries are read as
            DataFrames with an integer-convertible "id" column; "reports" also
            needs "title" and "content", and its ids must match the
            ``community`` column of the nodes table ``entity_df``.
        access_level (str): One of 'KG_ONLY', 'CHUNKS', 'FULL', 'DOCUMENT_LEVEL'.
        user_id (str): Required for DOCUMENT_LEVEL access checking.

    Returns:
        dict: KG_ONLY -> relationships + entities; CHUNKS -> relationships +
        entities + sources; FULL -> ``response_dict`` itself. DOCUMENT_LEVEL ->
        lists of the entity / relationship / source rows whose source document
        the user may read, plus "reports" as a list of strings (title and content
        joined by a newline) for communities whose documents are all readable.

    Raises:
        ValueError: Unknown ``access_level``, or DOCUMENT_LEVEL without ``user_id``.
    """
    if access_level == "KG_ONLY":
        return {
            "relationships": response_dict["relationships"],
            "entities": response_dict["entities"],
        }

    if access_level == "CHUNKS":
        return {
            "relationships": response_dict["relationships"],
            "entities": response_dict["entities"],
            "sources": response_dict["sources"],
        }

    if access_level == "FULL":
        return response_dict

    if access_level != "DOCUMENT_LEVEL":
        raise ValueError("Invalid access level")

    if not user_id:
        raise ValueError("user_id is required for DOCUMENT_LEVEL access")

    accessible_docs = {
        doc_id
        for doc_id, users in create_document_access_mapping().items()
        if user_id in users
    }

    filtered_response = {}
    if "entities" in response_dict:
        filtered_response["entities"] = _rows_in_docs(
            response_dict["entities"], entity_document_mapping, accessible_docs
        )
    if "relationships" in response_dict:
        filtered_response["relationships"] = _rows_in_docs(
            response_dict["relationships"], relationship_document_mapping, accessible_docs
        )
    if "sources" in response_dict:
        filtered_response["sources"] = _rows_in_docs(
            response_dict["sources"], chunks_document_mapping, accessible_docs
        )
    if "reports" in response_dict:
        filtered_response["reports"] = _filter_reports(response_dict["reports"], accessible_docs)
    return filtered_response

## Access Control Test

In [15]:
from questions.apple import quiz_questions as apple_questions
from questions.cs2 import quiz_questions as cs2_questions

In [None]:
# Inspect the first Apple quiz record (question / answer / reference) before searching on it.
apple_questions['1']

In [17]:
# Run GraphRAG local search for Apple question 1; `result.context_data` holds the
# context tables (entities, relationships, sources, reports) the access filters below consume.
result = search_engine.search(apple_questions['1']['question'])

In [None]:
# Exercise every access level on the first query's context; the KG_ONLY,
# CHUNKS and FULL results are kept for interactive inspection.
kg_only_result, chunks_result, full_result = (
    filter_response_by_access(result.context_data, level)
    for level in ("KG_ONLY", "CHUNKS", "FULL")
)
print(result.context_data.keys())
print()

# Per user: how many entities, relationships and sources survive the document filter.
for user in ["user1", "admin"]:
    document_level_result = filter_response_by_access(
        result.context_data, "DOCUMENT_LEVEL", user_id=user
    )
    print(f"User {user}:")
    for section in ("entities", "relationships", "sources"):
        print(len(document_level_result[section]))
    print()
    print('#' * 10)

In [None]:
# Show which document the first two entities visible to user1 come from.
document_level_result = filter_response_by_access(
    result.context_data, "DOCUMENT_LEVEL", user_id='user1'
)

for entity in document_level_result['entities'][:2]:
    owning_doc = entity_document_mapping.get(int(entity['id']))
    print(f"Entity: {entity['id']} belongs to document: {id_to_title[owning_doc]}")

# Generate responses

In [20]:
from pathlib import Path
import ollama

def generate_search_prompt(
    context_data: str,
    response_type: str = "multiple paragraphs",
    template_path: Path = Path("../benchmark/prompts/local_search_system_prompt.txt"),
) -> str:
    """
    Fill the local-search system prompt template with context data.

    Args:
        context_data: Text substituted for the ``{context_data}`` placeholder.
        response_type: Text substituted for ``{response_type}`` (desired format/length).
        template_path: Template file to read; defaults to the benchmark's
            local-search system prompt. Read as UTF-8 so the result does not
            depend on the platform's default encoding.

    Returns:
        str: The template with both placeholders replaced.
    """
    template = Path(template_path).read_text(encoding="utf-8")
    # Substitute {response_type} first so a literal "{response_type}" inside the
    # context data is never rewritten.
    prompt = template.replace("{response_type}", response_type)
    return prompt.replace("{context_data}", context_data)

def create_llm_response(model, prompt, question, timeout=1000):
    """Send the search prompt and the question to an Ollama model and return its reply.

    Args:
        model: Ollama model tag to chat with.
        prompt: Rendered local-search prompt, sent as the first user message.
        question: The question, sent as the second user message.
        timeout: HTTP request timeout in seconds for the Ollama client.

    Returns:
        The raw ``chat`` response; callers read ``response['message']['content']``,
        ``'total_duration'`` and ``'eval_count'`` from it.
    """
    # `options` holds model parameters (temperature, num_ctx, ...); a "timeout" key
    # there is not a request timeout, so set it on the client instead.
    client = ollama.Client(timeout=timeout)
    return client.chat(
        model=model,
        messages=[
            {'role': 'user', 'content': prompt},
            {'role': 'user', 'content': question},
        ],
    )

def save_llm_response(model: str, access: str, question_num: int, response: str):
    """Save one LLM response to ``../benchmark/llm_output/<model>/<access>_<question_num>_2.txt``.

    Colons in the model tag become underscores so it can serve as a folder name.
    Missing folders are created and an existing file is overwritten. The text is
    written as UTF-8 so non-ASCII model output does not depend on the platform's
    default encoding.

    Returns:
        str: Path of the written file.
    """
    base_dir = "../benchmark/llm_output"
    model_dir = f"{base_dir}/{model.replace(':', '_')}"
    os.makedirs(model_dir, exist_ok=True)

    filepath = os.path.join(model_dir, f"{access}_{question_num}_2.txt")
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(response)
    return filepath

### Test

In [None]:
# Select one locally installed Ollama model by its position in `ollama.list()`.
# NOTE(review): index 7 depends on what is installed locally — confirm it is the intended model.
model_1 = dict(dict(ollama.list())['models'][7])['model']
model_1

In [None]:
# Smoke test: answer Apple question 1 as `admin`, using only the community
# reports that survive the document-level filter as context.
document_level_result = filter_response_by_access(
    result.context_data, "DOCUMENT_LEVEL", user_id='admin'
)
prompt = generate_search_prompt(
    context_data=str(document_level_result['reports']),
    response_type="multiple paragraphs",
)
response = create_llm_response(model_1, prompt, apple_questions['1']['question'])
print(response['message']['content'])

In [None]:
# Ollama reports total_duration in nanoseconds; eval_count is the generated-token count.
print("Time taken: {} seconds".format(round(response['total_duration'] / 1_000_000_000, 2)))
print(response['message']['content'])
print(response['eval_count'])

# Run the Benchmark

In [56]:
# Tags of every locally installed Ollama model, from a single `ollama.list()` request
# (the previous first call's result was immediately discarded).
ollama_models = [dict(x)['model'] for x in dict(ollama.list())['models']]

In [None]:
# Apple questions followed by CS2 questions, renumbered "1", "2", ... in that order.
merged_dict = {
    str(position): record
    for position, record in enumerate(
        [*apple_questions.values(), *cs2_questions.values()], start=1
    )
}
merged_dict.items()

In [26]:
# Define the expected structure
expected_keys = {'question', 'answer', 'reference'}

# Function to check and fix dictionary structure
def validate_and_fix_dict(input_dict):
    fixed_dict = {}
    
    for key, value in input_dict.items():
        # Create a new entry for this item
        fixed_entry = {}
        
        # Check if any key needs to be fixed
        for k, v in value.items():
            # If 'ion' is found, change it to 'question'
            if k == 'ion':
                fixed_entry['question'] = v
            else:
                fixed_entry[k] = v
        
        # Check if all expected keys are present
        for expected_key in expected_keys:
            if expected_key not in fixed_entry:
                print(f"Warning: Missing key '{expected_key}' in entry {key}")
                fixed_entry[expected_key] = ''  # Add empty string for missing keys
        
        fixed_dict[int(key)] = fixed_entry
    
    return fixed_dict

# Normalised question set, keyed by integer id 1..N.
fixed_dict = validate_and_fix_dict(input_dict=merged_dict)

In [27]:
# One GraphRAG local search per question; the stored results (and their
# context_data) are reused for every model/user combination below.
kg_responses = {}
for question_id, record in fixed_dict.items():
    kg_responses[int(question_id)] = search_engine.search(record['question'])

In [None]:
# Planned chat calls: models x 4 access levels x 2 users x questions.
len(ollama_models) * 4 * 2 * len(fixed_dict)

In [None]:
import signal
from contextlib import contextmanager
import time

class TimeoutException(Exception):
    """Raised inside a ``timeout(...)`` block when its alarm fires."""


@contextmanager
def timeout(seconds):
    """Raise ``TimeoutException`` if the enclosed block runs longer than ``seconds``.

    Relies on ``SIGALRM``/``signal.alarm``, so it is Unix-only and must be used
    from the main thread; ``seconds`` is a whole number of seconds (0 disables
    the alarm). On exit the alarm is cancelled and the SIGALRM handler that was
    installed before entry is restored.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        # Disable the alarm, then put back whatever handler was there before.
        signal.alarm(0)
        if previous_handler is not None:  # None: previous handler was not set from Python
            signal.signal(signal.SIGALRM, previous_handler)

# Result columns: the i-th element of every list describes the same model call.
n_model = []
n_access = []
n_user = []
n_question = []
n_response = []
n_answer = []
n_reference = []
n_response_time = []
n_reponse_length = []  # (sic) name kept: the export cell reads it
n_response_tokens = []

access_levels = ['DOCUMENT_LEVEL']  # full set: ["KG_ONLY", "CHUNKS", "FULL", "DOCUMENT_LEVEL"]
users = ["user1", "admin"]
# Number of (model, access, user, question) combinations this run will attempt,
# so progress reflects the actual loop instead of a hard-coded total.
total_runs = len(ollama_models) * len(access_levels) * len(users) * len(fixed_dict)

counter = 0
for m in ollama_models:
    for access in access_levels:
        for u in users:
            for idx, q in fixed_dict.items():
                run_label = f"{m} - {access} - {u} - {idx}"
                try:
                    with timeout(60):  # abandon a single combination after 60 seconds
                        # Filter the cached KG context for this user, then answer from
                        # the surviving community reports only.
                        document_level_result = filter_response_by_access(
                            kg_responses[idx].context_data, access, user_id=u
                        )
                        prompt = generate_search_prompt(
                            context_data=str(document_level_result['reports']),
                            response_type="multiple paragraphs",
                        )
                        response = ollama.chat(
                            model=m,
                            messages=[
                                {'role': 'user', 'content': prompt},
                                {'role': 'user', 'content': q['question']},
                            ],
                        )
                    # Read every field before appending anything (and outside the alarm),
                    # so a timeout or malformed response cannot leave the lists misaligned.
                    content = response['message']['content']
                    seconds_taken = round(response['total_duration'] / 1000000000, 2)
                    tokens = response['eval_count']
                    answer, reference = q['answer'], q['reference']

                    n_model.append(m)
                    n_access.append(access)
                    n_user.append(u)
                    n_question.append(q['question'])
                    n_response.append(content)
                    n_answer.append(answer)
                    n_reference.append(reference)
                    n_response_time.append(seconds_taken)
                    n_reponse_length.append(len(content))
                    n_response_tokens.append(tokens)

                    print(run_label)
                except TimeoutException:
                    print(f"Timeout occurred for {run_label}")
                except Exception as e:
                    print(f"Error for {run_label}: {e}")

                # Every attempt counts toward progress, successful or not.
                counter += 1
                if counter % 100 == 0 or counter == total_runs:
                    print(f"Progress: {100 * counter / total_runs:.1f} %")

In [59]:
# Persist one row per successful model call; to_csv does not create missing folders.
os.makedirs('./results', exist_ok=True)
pd.DataFrame({
    'model': n_model,
    'access': n_access,
    'user': n_user,
    'question': n_question,
    'response': n_response,
    'answer': n_answer,
    'reference': n_reference,
    'response_time': n_response_time,
    'response_length': n_reponse_length,
    'response_tokens': n_response_tokens
}).to_csv('./results/results_0902025.csv', index=False)

In [None]:
# Reload the saved benchmark table and preview its first rows.
df = pd.read_csv("./results/results_0902025.csv")
df.head(5)