In [4]:
from dotenv import load_dotenv
import os
from langchain.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_anthropic import ChatAnthropic
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain.graphs import Neo4jGraph
from langchain_neo4j import GraphCypherQAChain, Neo4jGraph
import sys
import os
sys.path.insert(0, os.path.abspath('..'))
from utils.visualizer import visualize_neo4j_graph

load_dotenv()
# Get API key from environment variable 
api_key = os.getenv("ANTHROPIC_API_KEY")

if not api_key:
    print("Warning: ANTHROPIC_API_KEY not found in environment variables")
    print("Please add ANTHROPIC_API_KEY=your_api_key_here to your .env file")
else:
    print("Anthropic API key loaded successfully")

# Use Claude 3.5 Sonnet with increased max_tokens
llm = ChatAnthropic(
    model="claude-3-5-sonnet-20241022",
    temperature=0.1,  # Very low temperature for consistent extraction
    max_tokens=8192,  # Increased max_tokens to avoid truncation
    anthropic_api_key=api_key
)

# Create graph transformer with the fixed LLM
graph_transformer = LLMGraphTransformer(
    llm=llm,
    node_properties=False,  # Disabled to reduce token usage
    relationship_properties=False  # Disabled to reduce token usage
)

print(f"✓ Claude LLM initialized with max_tokens=8192")
print(f"✓ Graph transformer created")



Anthropic API key loaded successfully
✓ Claude LLM initialized with max_tokens=8192
✓ Graph transformer created


In [7]:
import pandas as pd
from langchain_core.documents import Document

csv_file = r"/Users/kathisnehith/Downloads/healthcare_dataset.csv"
df = pd.read_csv(csv_file)

# ✅ Limit to first 1500 rows
df = df.head(100)

lc_docs = []

for idx, row in df.iterrows():
    # Convert each row to a readable string
    row_str = ", ".join([f"{col}: {val}" for col, val in row.items()])

    # Create LangChain Document with metadata
    lc_docs.append(Document(
        page_content=row_str,
        metadata={'row': idx, 'source': csv_file}
    ))

    # Optional: Print preview
    print(f"Row {idx} processed: {row_str[:10]}...")
print(f"✓ {len(lc_docs)} documents created from CSV rows")


Row 0 processed: Name: Tiff...
Row 1 processed: Name: Rube...
Row 2 processed: Name: Chad...
Row 3 processed: Name: Anto...
Row 4 processed: Name: Mrs....
Row 5 processed: Name: Patr...
Row 6 processed: Name: Char...
Row 7 processed: Name: Patt...
Row 8 processed: Name: Ryan...
Row 9 processed: Name: Shar...
Row 10 processed: Name: Amy ...
Row 11 processed: Name: Mrs....
Row 12 processed: Name: Chri...
Row 13 processed: Name: Will...
Row 14 processed: Name: Mich...
Row 15 processed: Name: Bria...
Row 16 processed: Name: Oliv...
Row 17 processed: Name: Tere...
Row 18 processed: Name: Desi...
Row 19 processed: Name: Sall...
Row 20 processed: Name: Will...
Row 21 processed: Name: Stev...
Row 22 processed: Name: Hale...
Row 23 processed: Name: Ange...
Row 24 processed: Name: Beve...
Row 25 processed: Name: Dani...
Row 26 processed: Name: Kimb...
Row 27 processed: Name: Fran...
Row 28 processed: Name: Ronn...
Row 29 processed: Name: Shan...
Row 30 processed: Name: Tere...
Row 31 processed: 

In [None]:

# Convert to graph documents
graph_documents_lc = graph_transformer.convert_to_graph_documents(lc_docs)
print(lc_docs)


[Document(metadata={'row': 0, 'source': '/Users/kathisnehith/Downloads/healthcare_dataset.csv'}, page_content='Name: Tiffany Ramirez, Age: 81, Gender: Female, Blood Type: O-, Medical Condition: Diabetes, Date of Admission: 2022-11-17, Doctor: Patrick Parker, Hospital: Wallace-Hamilton, Insurance Provider: Medicare, Billing Amount: 37490.98336352819, Room Number: 146, Admission Type: Elective, Discharge Date: 2022-12-01, Medication: Aspirin, Test Results: Inconclusive'), Document(metadata={'row': 1, 'source': '/Users/kathisnehith/Downloads/healthcare_dataset.csv'}, page_content='Name: Ruben Burns, Age: 35, Gender: Male, Blood Type: O+, Medical Condition: Asthma, Date of Admission: 2023-06-01, Doctor: Diane Jackson, Hospital: Burke, Griffin and Cooper, Insurance Provider: UnitedHealthcare, Billing Amount: 47304.06484547511, Room Number: 404, Admission Type: Emergency, Discharge Date: 2023-06-15, Medication: Lipitor, Test Results: Normal'), Document(metadata={'row': 2, 'source': '/Users/k

In [14]:

# nodes and relationships extracted from the second document chunk
print(f"Nodes:{graph_documents_lc[0].nodes}")
print(f"Relationships:{graph_documents_lc[2].relationships}")

Nodes:[Node(id='Tiffany Ramirez', type='Person', properties={}), Node(id='Patrick Parker', type='Person', properties={}), Node(id='Wallace-Hamilton', type='Hospital', properties={}), Node(id='Medicare', type='Organization', properties={}), Node(id='Diabetes', type='Disease', properties={}), Node(id='Aspirin', type='Medicine', properties={}), Node(id='Room 146', type='Location', properties={})]
Relationships:[Relationship(source=Node(id='Chad Byrd', type='Person', properties={}), target=Node(id='Obesity', type='Condition', properties={}), type='HAS_CONDITION', properties={}), Relationship(source=Node(id='Paul Baker', type='Person', properties={}), target=Node(id='Chad Byrd', type='Person', properties={}), type='TREATS', properties={}), Relationship(source=Node(id='Chad Byrd', type='Person', properties={}), target=Node(id='Walton Llc', type='Hospital', properties={}), type='ADMITTED_TO', properties={}), Relationship(source=Node(id='Chad Byrd', type='Person', properties={}), target=Node(i

In [12]:
from pyvis.network import Network

def visualize_graph(graph_documents):

    # Create network
    net = Network(height="1200px", width="100%", directed=True,
                      notebook=False, bgcolor="#222222", font_color="white")
    
    nodes = graph_documents[0].nodes
    relationships = graph_documents[0].relationships

    # Build lookup for valid nodes
    node_dict = {node.id: node for node in nodes}
    
    # Filter out invalid edges and collect valid node IDs
    valid_edges = []
    valid_node_ids = set()
    for rel in relationships:
        if rel.source.id in node_dict and rel.target.id in node_dict:
            valid_edges.append(rel)
            valid_node_ids.update([rel.source.id, rel.target.id])


    # Track which nodes are part of any relationship
    connected_node_ids = set()
    for rel in relationships:
        connected_node_ids.add(rel.source.id)
        connected_node_ids.add(rel.target.id)

    # Add valid nodes
    for node_id in valid_node_ids:
        node = node_dict[node_id]
        try:
            net.add_node(node.id, label=node.id, title=node.type, group=node.type)
        except:
            continue  # skip if error

    # Add valid edges
    for rel in valid_edges:
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
        except:
            continue  # skip if error

    # Configure physics
    net.set_options("""
            {
                "physics": {
                    "forceAtlas2Based": {
                        "gravitationalConstant": -100,
                        "centralGravity": 0.01,
                        "springLength": 200,
                        "springConstant": 0.08
                    },
                    "minVelocity": 0.75,
                    "solver": "forceAtlas2Based"
                }
            }
            """)
        
    output_file = "knowledge_graph.html"
    net.save_graph(output_file)
    print(f"Graph saved to {os.path.abspath(output_file)}")

    # Try to open in browser
    try:
        import webbrowser
        webbrowser.open(f"file://{os.path.abspath(output_file)}")
    except:
        print("Could not open browser automatically")
        
# Run the function
visualize_graph(graph_documents_lc)

Graph saved to /Users/kathisnehith/Desktop/Med-GraphRAG-project/notebooks/knowledge_graph.html


In [16]:
graph = Neo4jGraph(url=os.getenv("NEO4J_URI"), 
                username=os.getenv("NEO4J_USERNAME"), 
                password=os.getenv("NEO4J_PASSWORD"),
                enhanced_schema=True)

In [17]:


# add the graph documents to the Neo4j graph
print("Adding graph documents to Neo4j...")
graph.add_graph_documents(graph_documents_lc, include_source=True)


# Get the schema of the graph
schema = graph.get_schema
print("Sucessfully! Added and Graph schema retrieved.........")
print("Graph schema: \n", schema)


Adding graph documents to Neo4j...
Sucessfully! Added and Graph schema retrieved.........
Graph schema: 
 Node properties:
- **Document**
  - `id`: STRING Example: "52852b3fe189de1d135bea41e887857e"
  - `text`: STRING Example: "Prompt  EngineeringAuthor: Lee Boonstra"
  - `page`: INTEGER Min: 0, Max: 67
  - `source`: STRING Available options: ['/Users/kathisnehith/Downloads/22365_3_Prompt Engin']
- **Person**
  - `id`: STRING Example: "Lee Boonstra"
- **Book**
  - `id`: STRING Available options: ['Prompt Engineering']
- **Topic**
  - `id`: STRING Example: "Introduction"
- **Technique**
  - `id`: STRING Example: "Step-Back Prompting"
- **Concept**
  - `id`: STRING Example: "Best Practices"
- **Practice**
  - `id`: STRING Available options: ['Provide Examples', 'Design With Simplicity', 'Be Specific About Output', 'Use Instructions Over Constraints', 'Control Max Token Length', 'Use Variables In Prompts', 'Experiment With Input Formats', 'Mix Up Classes In Few-Shot Prompting', 'Adapt To 

In [None]:


## Visualize the Knowledge-graph 
# Run the function with your Neo4j graph
print("🚀 Creating Neo4j visualization...")
result = visualize_neo4j_graph(graph, max_nodes=500, max_relationships=1000)

if result:
    print(f"\n🎉 Visualization completed!")
    print(f"   📊 Nodes: {result['nodes_count']}")
    print(f"   🔗 Relationships: {result['relationships_count']}")
    print(f"   📁 File: {result['output_file']}")

In [18]:
# Create the GraphCypherQAChain with the tracker
chain = GraphCypherQAChain.from_llm(
    llm=llm,                             # Use OpenAI LLM for question answering
    graph=graph,                                # Use the Neo4j graph
    #cypher_prompt=custom_cypher_prompt,
    verbose=True,                               # Enable verbose logging
    top_k=15,                                    # Return top 5 results
    allow_dangerous_requests=True,
    #callbacks=[tracker]                         # Add the tracker to the callbacks
)

In [19]:
question = "details about Tiffany Ramirez"
result = chain.invoke({"query": question})



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (p:Person {id: 'Tiffany Ramirez'})
OPTIONAL MATCH (p)-[r]->(n)
OPTIONAL MATCH (n2)-[r2]->(p)
RETURN p, type(r), n, type(r2), n2[0m
Full Context:
[32;1m[1;3m[{'p': {'id': 'Tiffany Ramirez'}, 'type(r)': 'ADMITTED_TO', 'n': {'id': 'Wallace-Hamilton'}, 'type(r2)': 'MENTIONS', 'n2': {'id': '8f3321e466358b25925d4a6d89f757e9', 'text': 'Name: Tiffany Ramirez, Age: 81, Gender: Female, Blood Type: O-, Medical Condition: Diabetes, Date of Admission: 2022-11-17, Doctor: Patrick Parker, Hospital: Wallace-Hamilton, Insurance Provider: Medicare, Billing Amount: 37490.98336352819, Room Number: 146, Admission Type: Elective, Discharge Date: 2022-12-01, Medication: Aspirin, Test Results: Inconclusive', 'source': '/Users/kathisnehith/Downloads/healthcare_dataset.csv', 'row': 0}}, {'p': {'id': 'Tiffany Ramirez'}, 'type(r)': 'INSURED_BY', 'n': {'id': 'Medicare'}, 'type(r2)': 'MENTIONS', 'n2': {'id': '8f3321e466358b

In [21]:
print(f"\n🔍 Question: {question}")
print(f"💡 Answer: {result['result']}")


🔍 Question: details about Tiffany Ramirez
💡 Answer: Tiffany Ramirez is an 81-year-old female patient who was admitted to Wallace-Hamilton Hospital on November 17, 2022, for an elective procedure. She has diabetes and is taking Aspirin as medication. She was assigned to Room 146 during her stay and was discharged on December 1, 2022. Her blood type is O-negative, and she is insured by Medicare. Dr. Patrick Parker was her attending physician, and her test results were inconclusive. The billing amount for her hospital stay was $37,490.98.
