In [None]:
import os
import json
from typing import List, Tuple, Dict
from openai import OpenAI
from neo4j import GraphDatabase
import logging
from dotenv import load_dotenv
from transformers import pipeline
import spacy

# Run python-dotenv's loader first so the key lookup below can see any values
# it adds to the process environment.
load_dotenv()
API_KEY = os.environ.get("OPENAI_API_KEY")

# Root logger threshold: INFO and above.
logging.basicConfig(level=logging.INFO)

# Initialize NER model
def load_ner_model(model_name: str):
    """Build a ``token-classification`` pipeline for ``model_name``.

    The pipeline is created with ``aggregation_strategy="simple"``; the rest
    of this file reads ``word`` and ``entity_group`` from each result.
    """
    pipeline_options = {
        "model": model_name,
        "aggregation_strategy": "simple",
    }
    return pipeline("token-classification", **pipeline_options)

# Load spaCy model
def load_spacy_model():
    """Return the ``en_core_web_sm`` spaCy model, downloading it on first use.

    If loading raises ``OSError``, the model is downloaded through
    ``spacy.cli.download`` and loaded again.
    """
    model_name = "en_core_web_sm"
    try:
        nlp = spacy.load(model_name)
    except OSError:
        print("Downloading spaCy language model...")
        spacy.cli.download(model_name)
        nlp = spacy.load(model_name)
    return nlp

# Generate OpenAI embeddings
def generate_openai_embedding(text: str, openai_client) -> List[float]:
    """Return the ``text-embedding-ada-002`` embedding vector for ``text``.

    Args:
        text: The text to embed.
        openai_client: Client object exposing ``embeddings.create(model=, input=)``.

    Returns:
        The embedding as a list of floats, or ``[]`` if the request or the
        response handling raises (the error is logged, not propagated).
    """
    try:
        response = openai_client.embeddings.create(
            model="text-embedding-ada-002",
            input=text
        )
        # The client returns a typed response object, not a dict: subscripting
        # it raises "'CreateEmbeddingResponse' object is not subscriptable",
        # so the vector has to be read through attributes.
        return response.data[0].embedding
    except Exception as e:
        logging.error(f"Error generating embedding: {e}")
        return []

# Split text into chunks
def split_text_into_chunks(text: str, max_length: int = 512) -> List[str]:
    """Split ``text`` into chunks of at most ``max_length`` whitespace-separated words.

    Words within a chunk are re-joined with single spaces, so the original
    whitespace is not preserved. Text containing no words yields an empty list.
    """
    tokens = text.split()
    chunks = []
    for start in range(0, len(tokens), max_length):
        window = tokens[start:start + max_length]
        chunks.append(" ".join(window))
    return chunks

# Extract entities using both models
def extract_entities(text: str, ner_model, nlp) -> List[Tuple[str, str]]:
    """Collect ``(text, label)`` entity pairs from both NER sources.

    ``ner_model(text)`` is read through its ``word`` / ``entity_group`` keys
    and ``nlp(text).ents`` through ``.text`` / ``.label_``. Pairs are
    de-duplicated on (lower-cased text, label): a duplicate keeps the position
    of its first occurrence but the surface form of its last occurrence.
    """
    candidates = []
    for hit in ner_model(text):
        candidates.append((hit["word"], hit["entity_group"]))
    for span in nlp(text).ents:
        candidates.append((span.text, span.label_))

    deduped = {}
    for surface, label in candidates:
        deduped[(surface.lower(), label)] = (surface, label)
    return list(deduped.values())

# Generate relationships using OpenAI
def generate_relations(text: str, entities: List[Tuple[str, str]], openai_client, openai_model: str) -> List[Dict]:
    """Ask a chat model for relationships between ``entities`` found in ``text``.

    Args:
        text: The passage the entities were extracted from.
        entities: ``(entity, entity_type)`` pairs to list in the prompt.
        openai_client: Client object exposing ``chat.completions.create(...)``.
        openai_model: Name of the chat model to call.

    Returns:
        A list of relation dicts, each containing at least ``source``,
        ``target`` and ``relationship``. Returns ``[]`` if the request fails,
        the reply cannot be parsed as JSON, or the parsed value is not a list.
        Items that are not dicts or lack one of the three keys are dropped.
    """
    entities_str = "\n".join([
        f"{idx+1}. {entity} (Type: {entity_type})"
        for idx, (entity, entity_type) in enumerate(entities)
    ])

    # The example below must itself be valid JSON: the model tends to imitate
    # it, and the reply is fed to json.loads.
    prompt = f"""
        You are an expert in knowledge graph construction. Identify relationships between these entities:
        
        ### Text:
        {text}
        
        ### Entities:
        {entities_str}
        
        Output relationships in JSON format:
        [
            {{
                "source": "Entity1",
                "source_type": "Type1",
                "target": "Entity2",
                "target_type": "Type2",
                "relationship": "RELATIONSHIP_TYPE",
                "evidence": "Evidence from the text"
            }}
        ]
    """

    try:
        response = openai_client.chat.completions.create(
            model=openai_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1500,
            temperature=0.1
        )
        raw_content = response.choices[0].message.content.strip()
        relations = _parse_relations_json(raw_content)
    except Exception as e:
        logging.error(f"Error generating relationships: {e}")
        return []

    if not isinstance(relations, list):
        return []

    # Keep only entries that carry the keys the graph-writing step reads.
    required_keys = ("source", "target", "relationship")
    return [
        rel for rel in relations
        if isinstance(rel, dict) and all(key in rel for key in required_keys)
    ]


def _parse_relations_json(raw_content: str):
    """Parse a model reply as JSON, tolerating Markdown fences or surrounding prose.

    The reply is first parsed as-is. If that fails, the span from the first
    ``[`` to the last ``]`` is parsed instead. The original
    ``json.JSONDecodeError`` is re-raised when no such span exists.
    """
    try:
        return json.loads(raw_content)
    except json.JSONDecodeError:
        start = raw_content.find("[")
        end = raw_content.rfind("]")
        if start == -1 or end <= start:
            raise
        return json.loads(raw_content[start:end + 1])

# Create nodes in Neo4j
def create_nodes_with_embeddings(tx, entity_metadata):
    """MERGE one ``Entity`` node per metadata record and set its properties.

    Each record must provide ``entity`` (a ``(name, type)`` pair),
    ``embedding`` and ``chunk_text``. The node is matched on name and type;
    ``embedding``, ``chunk`` and ``description`` (the first 200 characters of
    the chunk text) are overwritten on every call.
    """
    for record in entity_metadata:
        name, label = record["entity"]
        vector = record["embedding"]
        source_chunk = record["chunk_text"]
        params = {
            "name": name,
            "type": label,
            "embedding": vector,
            "chunk": source_chunk,
            "description": source_chunk[:200],
        }
        tx.run(
            """
            MERGE (n:Entity {name: $name, type: $type})
            SET n.embedding = $embedding,
                n.chunk = $chunk,
                n.description = $description
            """,
            params
        )


# Create relationships in Neo4j
def create_relationships(tx, relations):
    """MERGE a typed relationship between existing ``Entity`` nodes for each relation.

    Args:
        tx: Transaction-like object exposing ``run(query, params)``.
        relations: Iterable of dicts with ``source``, ``target`` and
            ``relationship`` keys, plus an optional ``evidence`` key.

    Relations that are not dicts, lack a source or target, or whose
    relationship type is empty after normalisation are skipped with a warning,
    so one bad entry does not abort the whole batch.
    """
    for relation in relations:
        if not isinstance(relation, dict):
            logging.warning("Skipping malformed relation: %r", relation)
            continue

        source = relation.get("source")
        target = relation.get("target")
        relationship_type = _normalize_relationship_type(relation.get("relationship"))
        if not (source and target and relationship_type):
            logging.warning("Skipping malformed relation: %r", relation)
            continue

        # Cypher cannot take a relationship type as a query parameter, so it is
        # interpolated here; the normalisation above guarantees it contains no
        # backtick that could close the quoted identifier.
        query = f"""
        MATCH (source:Entity {{name: $source}})
        MATCH (target:Entity {{name: $target}})
        MERGE (source)-[r:`{relationship_type}`]->(target)
        SET r.evidence = $evidence
        """
        tx.run(query, {
            "source": source,
            "target": target,
            "evidence": relation.get("evidence", "No specific evidence")
        })


def _normalize_relationship_type(raw_type) -> str:
    """Turn a free-form label into an UPPER_SNAKE relationship type.

    Every run of non-alphanumeric characters becomes a single underscore.
    Returns ``""`` for non-string input or when nothing alphanumeric remains.
    """
    if not isinstance(raw_type, str):
        return ""
    spaced = "".join(ch if ch.isalnum() else " " for ch in raw_type)
    return "_".join(spaced.split()).upper()

# Clear Neo4j database
def clear_neo4j_database(uri: str, username: str, password: str):
    """Delete every node, and the relationships attached to them, from the database.

    Args:
        uri: Connection URI passed to ``GraphDatabase.driver``.
        username: Neo4j user name.
        password: Neo4j password.

    The session is opened without naming a database.
    """
    driver = GraphDatabase.driver(uri, auth=(username, password))
    try:
        with driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")
            logging.info("Neo4j database cleared.")
    finally:
        # Close the driver even when the connection or the query fails.
        driver.close()

def process_text_with_embeddings(text: str, ner_model, nlp, openai_client, openai_model, neo4j_config):
    """Extract entities and relations from ``text`` and write them to Neo4j.

    The text is split into chunks. Each chunk is embedded, run through entity
    extraction and sent to the chat model for relation extraction. Every
    unique entity is stored as a node carrying the embedding and text of the
    first chunk it appeared in; relations are then stored as edges.

    Args:
        text: Full document text.
        ner_model: Callable NER pipeline passed to ``extract_entities``.
        nlp: spaCy pipeline passed to ``extract_entities``.
        openai_client: Client used for embeddings and chat completions.
        openai_model: Chat model name used for relation extraction.
        neo4j_config: Dict with ``uri``, ``username`` and ``password``.

    Returns:
        ``(unique_entities, unique_relations)``: a list of ``(name, type)``
        tuples in first-seen order, and a list of relation dicts.
    """
    # First-seen metadata per entity. A dict gives O(1) lookup and keeps
    # insertion order, so the result is deterministic.
    metadata_by_entity = {}
    # Keyed on direction and type so A->B and B->A, or two different relation
    # types between the same pair, are not collapsed into one entry.
    relations_by_key = {}

    for chunk in split_text_into_chunks(text):
        embedding = generate_openai_embedding(chunk, openai_client)
        entities = extract_entities(chunk, ner_model, nlp)
        relations = generate_relations(chunk, entities, openai_client, openai_model)

        for entity in entities:
            metadata_by_entity.setdefault(entity, {
                "entity": entity,
                "embedding": embedding,
                "chunk_text": chunk
            })

        for rel in relations:
            if not isinstance(rel, dict):
                continue
            source = rel.get("source")
            target = rel.get("target")
            rel_type = rel.get("relationship")
            # Skip entries that cannot be used as a dict key or written as an edge.
            if not all(isinstance(part, str) for part in (source, target, rel_type)):
                continue
            relations_by_key[(source, target, rel_type)] = rel  # last duplicate wins

    unique_entities = list(metadata_by_entity)
    unique_relations = list(relations_by_key.values())
    final_metadata = list(metadata_by_entity.values())

    # NOTE(review): the caller in this file also puts a "database" key in
    # neo4j_config, but the session below is opened without naming a database.
    # Confirm which is intended.
    driver = GraphDatabase.driver(neo4j_config["uri"], auth=(neo4j_config["username"], neo4j_config["password"]))
    try:
        with driver.session() as session:
            # Nodes first: the relationship query only MATCHes existing nodes.
            session.execute_write(create_nodes_with_embeddings, final_metadata)
            session.execute_write(create_relationships, unique_relations)
    finally:
        # Close the driver even if a write fails.
        driver.close()

    return unique_entities, unique_relations

if __name__ == "__main__":
    # NOTE(review): text_file_path is not referenced anywhere in the visible
    # script — confirm before removing.
    text_file_path = 'data/sample.txt'

    # Connection settings can be overridden through the environment (including
    # a .env file); the defaults are the values previously hard-coded here.
    neo4j_url = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_username = os.getenv("NEO4J_USERNAME", "neo4j")
    neo4j_password = os.getenv("NEO4J_PASSWORD", "123456789")
    ner_model_name = "Clinical-AI-Apollo/Medical-NER"
    openai_model_name = "gpt-4"

    # Load models
    ner_model = load_ner_model(ner_model_name)
    nlp = load_spacy_model()
    openai_client = OpenAI(api_key=API_KEY)

    # Clear the Neo4j database once, before any file is processed
    clear_neo4j_database(neo4j_url, neo4j_username, neo4j_password)

    # Directory with text files. os.listdir returns entries in arbitrary
    # order, so sort to keep the processing order and db_N numbering stable.
    paper_text_dir = 'paper_16_paragraphs'
    text_files = sorted(
        os.path.join(paper_text_dir, file)
        for file in os.listdir(paper_text_dir)
        if file.endswith('.txt')
    )

    for idx, text_file in enumerate(text_files):
        # NOTE(review): "database" is passed along, but the visible
        # process_text_with_embeddings never reads it — confirm intent.
        db_name = f"db_{idx + 1}"
        neo4j_config = {
            "uri": neo4j_url,
            "username": neo4j_username,
            "password": neo4j_password,
            "database": db_name
        }

        with open(text_file, 'r', encoding='utf-8') as f:
            text = f.read()

        if not text.strip():  # Skip empty files
            logging.warning(f"File {text_file} is empty. Skipping.")
            continue

        logging.info(f"Processing file: {text_file}")
        entities, relations = process_text_with_embeddings(
            text, ner_model, nlp, openai_client, openai_model_name, neo4j_config
        )
        logging.info(
            "Stored %d entities and %d relations from %s",
            len(entities), len(relations), text_file
        )


INFO:root:Neo4j database cleared.
INFO:root:Processing file: paper_16_paragraphs\text_1.txt
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
ERROR:root:Error generating embedding: 'CreateEmbeddingResponse' object is not subscriptable
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:root:Processing file: paper_16_paragraphs\text_10.txt
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
ERROR:root:Error generating embedding: 'CreateEmbeddingResponse' object is not subscriptable
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:root:Processing file: paper_16_paragraphs\text_11.txt
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
ERROR:root:Error generatin