In [1]:
import os

import pandas as pd
import tiktoken

from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import (
    LocalSearchMixedContext,
)
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore

In [2]:
INPUT_DIR = "./output"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"

# Parquet table names written by the GraphRAG indexer into INPUT_DIR.
COMMUNITY_REPORT_TABLE = "community_reports"
ENTITY_TABLE = "entities"
COMMUNITY_TABLE = "communities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
# Community hierarchy level passed to the entity and report readers below.
COMMUNITY_LEVEL = 2


# Read the entities and communities tables; the entity reader needs both.
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")

entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

# Connect to the entity-description embeddings stored in the LanceDB location
# at LANCEDB_URI. NOTE(review): the collection name must match the one used at
# index time — confirm if the indexing config was customized.
description_embedding_store = LanceDBVectorStore(
    collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)

print(f"Entity count: {len(entity_df)}")
# A bare `.head()` in the middle of a cell is discarded by Jupyter (only the last
# expression is rendered), so render the previews explicitly with display().
display(entity_df.head())

relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"Relationship count: {len(relationship_df)}")
display(relationship_df.head())

# Optional: uncomment only if claim extraction (covariates) was enabled during
# indexing, then pass `covariates` to LocalSearchMixedContext instead of None.
# covariate_df = pd.read_parquet(f"{INPUT_DIR}/{COVARIATE_TABLE}.parquet")
# claims = read_indexer_covariates(covariate_df)
# print(f"Claim records: {len(claims)}")
# covariates = {"claims": claims}


report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, community_df, COMMUNITY_LEVEL)

print(f"Report records: {len(report_df)}")
display(report_df.head())

text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)

print(f"Text unit records: {len(text_unit_df)}")



Entity count: 577
Relationship count: 467
Report records: 72
Text unit records: 58


In [3]:
from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager

# Credentials and model names come from the environment; a missing variable
# raises KeyError here, before any model object is created.
api_key = os.environ["GRAPHRAG_API_KEY"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]
embedding_model = os.environ["GRAPHRAG_EMBEDDING_MODEL"]

chat_config = LanguageModelConfig(
    api_key=api_key,
    type=ModelType.OpenAIChat,
    model=llm_model,
    max_retries=20,
)
chat_model = ModelManager().get_or_create_chat_model(
    name="local_search",
    model_type=ModelType.OpenAIChat,
    config=chat_config,
)

# tiktoken only maps model names it knows about; custom or deployment-specific
# names raise KeyError. Fall back to a general-purpose encoding so token
# counting still works (counts may be approximate for such models).
try:
    token_encoder = tiktoken.encoding_for_model(llm_model)
except KeyError:
    token_encoder = tiktoken.get_encoding("cl100k_base")

embedding_config = LanguageModelConfig(
    api_key=api_key,
    type=ModelType.OpenAIEmbedding,
    model=embedding_model,
    max_retries=20,
)

text_embedder = ModelManager().get_or_create_embedding_model(
    name="local_search_embedding",
    model_type=ModelType.OpenAIEmbedding,
    config=embedding_config,
)

In [4]:
context_builder = LocalSearchMixedContext(
    entities=entities,
    relationships=relationships,
    text_units=text_units,
    community_reports=reports,
    # Claim covariates are not loaded in this notebook (the loading code above is
    # commented out); if you did not run covariates during indexing, keep None.
    covariates=None,
    # Entity-description embeddings, looked up by entity id. If the vectorstore
    # uses entity title as ids, switch this to EntityVectorStoreKey.TITLE.
    entity_text_embeddings=description_embedding_store,
    embedding_vectorstore_key=EntityVectorStoreKey.ID,
    text_embedder=text_embedder,
    token_encoder=token_encoder,
)

In [5]:
# Settings handed to the context builder (via context_builder_params below).
local_context_params = {
    "text_unit_prop": 0.5,
    "community_prop": 0.1,
    "conversation_history_max_turns": 5,
    "conversation_history_user_turns_only": True,
    "top_k_mapped_entities": 10,
    "top_k_relationships": 10,
    "include_entity_rank": True,
    "include_relationship_weight": True,
    "include_community_rank": False,
    "return_candidate_context": False,
    # Presumably should match the key given to LocalSearchMixedContext. Set this to
    # EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids.
    "embedding_vectorstore_key": EntityVectorStoreKey.ID,
    # Change this based on the token limit you have on your model
    # (if you are using a model with 8k limit, a good setting could be 5000).
    "max_tokens": 12_000,
}

# Generation settings passed to LocalSearch as model_params.
model_params = {
    # Change this based on the token limit you have on your model
    # (if you are using a model with 8k limit, a good setting could be 1000-1500).
    "max_tokens": 2_000,
    "temperature": 0.0,
}

search_engine = LocalSearch(
    model=chat_model,
    context_builder=context_builder,
    token_encoder=token_encoder,
    model_params=model_params,
    context_builder_params=local_context_params,
    # Free-form text describing the response type and format; can be anything,
    # e.g. prioritized list, single paragraph, multiple paragraphs, multiple-page report.
    response_type="single paragraph",
)

In [6]:
result = await search_engine.search("Hot oil fell on my arm and burned me. I have a blister on my arm. What should I do?")  # query	")
print(result.response)	

If you have a burn from hot oil that has resulted in a blister, it's important to take immediate steps to care for the burn and prevent complications. First, cool the burn by running cool (not cold) water over the affected area for 10 to 20 minutes to reduce pain and swelling. Avoid using ice, as it can cause further damage to the skin. Do not pop the blister, as this can increase the risk of infection. Instead, cover the burn with a clean, non-stick bandage or dressing to protect it. Over-the-counter pain relief, such as ibuprofen or acetaminophen, can help manage discomfort. Keep the burn clean and dry, and monitor for signs of infection, such as increased redness, swelling, or pus. If the burn is large, very painful, or shows signs of infection, seek medical attention promptly. It's also important to consider that burns can lead to complications like hypovolemia and hypothermia, especially if they cover a large area or are severe [Data: Reports (11, 43); Entities (392, 393, 398); Re

In [7]:
# First rows of the entity table from the context data behind this answer.
result.context_data["entities"].head(n=5)

Unnamed: 0,id,entity,description,number of relationships,in_context
0,393,THERMAL BURN,"A type of burn caused by exposure to heat, inc...",1,True
1,392,BURN INJURY,Burn injury refers to damage to the skin or ot...,15,True
2,395,CHEMICAL BURN,"Burns caused by exposure to chemicals, which c...",1,True
3,400,CIRCUMFERENTIAL BURNS,"Burns that encircle a body part, potentially l...",1,True
4,397,RADIATION BURN,"Burns caused by exposure to radiation, requiri...",1,True


In [8]:
# First rows of the relationship table from the context data behind this answer.
result.context_data["relationships"].head(n=5)


Unnamed: 0,id,source,target,description,weight,links,in_context
0,295,BURN INJURY,HYPOTHERMIA,Hypothermia is a risk in burn patients due to ...,7.0,1.0,True
1,286,BURN INJURY,INHALATION INJURY,Inhalation injuries are often associated with ...,7.0,1.0,True
2,285,BURN INJURY,THERMAL BURN,Thermal burns are a type of burn injury caused...,8.0,,True
3,287,BURN INJURY,CHEMICAL BURN,Chemical burns are a type of burn injury cause...,8.0,1.0,True
4,288,BURN INJURY,ELECTRICAL BURN,Electrical burns are a type of burn injury cau...,8.0,1.0,True


In [9]:
# The "reports" key may be absent from context_data, hence the guard. An
# expression inside an `if` body is never auto-rendered by Jupyter, so call
# display() explicitly; otherwise the cell shows nothing.
if "reports" in result.context_data:
    display(result.context_data["reports"].head())

In [10]:
# First rows of the source text units from the context data behind this answer.
result.context_data["sources"].head(n=5)

Unnamed: 0,id,text
0,37,. Place skin in anatomic position if flat avul...
1,46,are linear. At times care must be adjusted \...
2,36,\nG. Knee \n1. Vascular and nerve damage \n2....
3,43,risk of injury and \nalter the patient’s resp...
4,48,Multiple organ injury common \n \nV. Extremit...
