# Create Energy Graph

This notebook will create a graph representing an energy grid with customer installations.

It is split into two sections

1. __Structured Source__: Tabular data. Ingestion is done with Cypher query templates (ordinary ETL)
2. __Unstructured Source__: Text data, in this case maintenance records. Ingestion involves an LLM-powered entity extraction step prior to loading with Cypher query templates

## Structured Source

In [None]:
from neo4j import GraphDatabase
import os
import requests
from dotenv import load_dotenv
from getpass import getpass

from pydantic import BaseModel, Field

# Load connection settings from the env file (override=True is passed to
# load_dotenv), then fall back to an interactive prompt for every value that
# is still missing or empty.
load_dotenv('target-db.env', override=True)

uri = os.getenv('NEO4J_URI') or getpass("Please enter your Neo4j URI: ")
username = os.getenv('NEO4J_USERNAME') or getpass("Please enter your Neo4j username: ")
password = os.getenv('NEO4J_PASSWORD') or getpass("Please enter your Neo4j password: ")

# Single driver instance shared by every cell below.
driver = GraphDatabase.driver(uri, auth=(username, password))

In [None]:
from neo4j import RoutingControl

# (label, key property) pairs, listed in the order the constraints are created.
# Every label is keyed on `id` except Ticket, which is keyed on `ticketNumber`.
NODE_KEY_CONSTRAINTS = [
    ('Generator', 'id'),
    ('Bus', 'id'),
    ('Transformer', 'id'),
    ('Link', 'id'),
    ('Station', 'id'),
    ('Customer', 'id'),
    ('Installation', 'id'),
    ('Region', 'id'),
    ('Consumption', 'id'),
    ('Ticket', 'ticketNumber'),
    ('MaintenanceRecord', 'id'),
    ('Alert', 'id'),
]

# Create each node-key constraint unless it already exists (IF NOT EXISTS
# makes re-running this cell harmless).
for label, key_property in NODE_KEY_CONSTRAINTS:
    driver.execute_query(
        f'CREATE CONSTRAINT IF NOT EXISTS FOR (n:{label}) REQUIRE (n.{key_property}) IS NODE KEY',
        routing_=RoutingControl.WRITE,
    )

In [None]:
# Directory holding every CSV / JSON input read by this notebook.
data_dir = "source-data"


def chunks(xs, n=1_000):
    """Split a sliceable sequence into consecutive batches.

    Args:
        xs: Sequence supporting ``len()`` and slicing.
        n: Maximum batch size; values below 1 are treated as 1.

    Returns:
        A list of slices of ``xs``, each holding at most ``n`` items. The
        last batch may be shorter; an empty ``xs`` yields an empty list.
    """
    n = max(1, n)
    batches = []
    for start in range(0, len(xs), n):
        batches.append(xs[start:start + n])
    return batches

In [None]:
import pandas as pd
from neo4j import RoutingControl

generator_df = pd.read_csv(os.path.join(data_dir,'generators.csv'))

# Upsert generators in batches. Each row creates/updates the Generator node,
# makes sure the referenced Bus exists, and connects the two.
for records in chunks(generator_df.to_dict(orient='records')):
    # Only `res` receives the query result, so `generator_df` keeps referring
    # to the DataFrame loaded above instead of being overwritten per batch.
    res = driver.execute_query("""
      UNWIND $records as rec
      MERGE (g:Generator {id:rec.ID})
      MERGE (b:Bus {id:rec.BUS_ID})
      MERGE (g)-[r:CONNECTED]->(b)
      SET
        g.capacity = rec.CAPACITY,
        g.category = rec.CATEGORY,
        g.geometry = point({latitude: rec.LATITUDE, longitude: rec.LONGITUDE}),
        g.mb_symbol = rec.MB_SYMBOL,
        g.name_eng = rec.NAME_ENG,
        g.name_nat = rec.NAME_NAT,
        g.symbol = rec.SYMBOL,
        g.tso = rec.TSO,
        g.visible = rec.VISIBLE
      RETURN count(rec) AS records_upserted
    """,
        routing_=RoutingControl.WRITE,
        # Turn the driver result into a plain list of dicts for printing.
        result_transformer_=lambda r: r.data(),
        records=records
    )
    print(res)

In [None]:
bus_csv_path = os.path.join(data_dir,'buses.csv')
bus_df = pd.read_csv(bus_csv_path)

# Upserts each Bus, ensures its Station exists, links them with IN_STATION
# and copies the descriptive columns onto the Bus node.
BUS_UPSERT_QUERY = """
      UNWIND $records AS rec
      MERGE (b:Bus {id: rec.ID})
      MERGE (s:Station {id: rec.STATION_ID})
      MERGE (b)-[:IN_STATION]->(s)
      SET
        b.category = rec.CATEGORY,
        b.geometry = point({latitude: rec.LATITUDE, longitude: rec.LONGITUDE}),
        b.mb_symbol = rec.MB_SYMBOL,
        b.name_eng = rec.NAME_ENG,
        b.name_nat = rec.NAME_NAT,
        b.symbol = rec.SYMBOL,
        b.tso = rec.TSO,
        b.visible = rec.VISIBLE,
        b.voltage = rec.VOLTAGE
      RETURN count(rec) AS records_upserted
    """

# One write query per batch of rows; the per-batch summary is printed.
for batch in chunks(bus_df.to_dict(orient='records')):
    res = driver.execute_query(
        BUS_UPSERT_QUERY,
        routing_=RoutingControl.WRITE,
        result_transformer_=lambda result: result.data(),
        records=batch,
    )
    print(res)

In [None]:
transformer_csv_path = os.path.join(data_dir,'transformers.csv')
transformer_df = pd.read_csv(transformer_csv_path)

# Upserts each Transformer, ensures the referenced Bus exists, connects them
# and copies the voltage / DC / geometry columns onto the Transformer node.
TRANSFORMER_UPSERT_QUERY = """
      UNWIND $records AS rec
      MERGE (t:Transformer {id: rec.ID})
      MERGE (b:Bus {id: rec.BUS_ID})
      MERGE (t)-[:CONNECTED]->(b)
      SET
        t.dst_dc = rec.DST_DC,
        t.dst_voltage = rec.DST_VOLTAGE,
        t.geometry = point({latitude: rec.LATITUDE, longitude: rec.LONGITUDE}),
        t.src_dc = rec.SRC_DC,
        t.src_voltage = rec.SRC_VOLTAGE,
        t.symbol = rec.SYMBOL
      RETURN count(rec) AS records_upserted
    """

# One write query per batch of rows; the per-batch summary is printed.
for batch in chunks(transformer_df.to_dict(orient='records')):
    res = driver.execute_query(
        TRANSFORMER_UPSERT_QUERY,
        routing_=RoutingControl.WRITE,
        result_transformer_=lambda result: result.data(),
        records=batch,
    )
    print(res)

In [None]:
link_csv_path = os.path.join(data_dir,'links.csv')
link_df = pd.read_csv(link_csv_path)

# Upserts each Link, ensures the referenced Bus exists, connects them and
# copies the line attributes onto the Link node.
LINK_UPSERT_QUERY = """
      UNWIND $records AS rec
      MERGE (l:Link {id: rec.ID})
      MERGE (b:Bus {id: rec.BUS_ID})
      MERGE (l)-[:CONNECTED]->(b)
      SET
        l.circuits = rec.CIRCUITS,
        l.dc = rec.DC,
        l.length_m = rec.LENGTH_M,
        l.shape_leng = rec.SHAPE_LENG,
        l.symbol = rec.SYMBOL,
        l.t9_code = rec.T9_CODE,
        l.underground = rec.UNDERGROUND,
        l.visible = rec.VISIBLE,
        l.voltage = rec.VOLTAGE
      RETURN count(rec) AS records_upserted
    """

# One write query per batch of rows; the per-batch summary is printed.
for batch in chunks(link_df.to_dict(orient='records')):
    res = driver.execute_query(
        LINK_UPSERT_QUERY,
        routing_=RoutingControl.WRITE,
        result_transformer_=lambda result: result.data(),
        records=batch,
    )
    print(res)

In [None]:
station_csv_path = os.path.join(data_dir,'stations.csv')
station_df = pd.read_csv(station_csv_path)

# Upserts each Station and sets its English name and geographic point.
STATION_UPSERT_QUERY = """
      UNWIND $records AS rec
      MERGE (s:Station {id: rec.ID})
      SET
        s.name_eng = rec.NAME_ENG,
        s.geometry = point({latitude: rec.LATITUDE, longitude: rec.LONGITUDE})
      RETURN count(rec) AS records_upserted
    """

# One write query per batch of rows; the per-batch summary is printed.
for batch in chunks(station_df.to_dict(orient='records')):
    res = driver.execute_query(
        STATION_UPSERT_QUERY,
        routing_=RoutingControl.WRITE,
        result_transformer_=lambda result: result.data(),
        records=batch,
    )
    print(res)


In [None]:
customer_csv_path = os.path.join(data_dir,'customers.csv')
customer_df = pd.read_csv(customer_csv_path)

# Upserts each Customer and sets its name and type.
CUSTOMER_UPSERT_QUERY = """
      UNWIND $records AS rec
      MERGE (c:Customer {id: rec.ID})
      SET
        c.name = rec.NAME,
        c.type = rec.TYPE
      RETURN count(rec) AS records_upserted
    """

# One write query per batch of rows; the per-batch summary is printed.
for batch in chunks(customer_df.to_dict(orient='records')):
    res = driver.execute_query(
        CUSTOMER_UPSERT_QUERY,
        routing_=RoutingControl.WRITE,
        result_transformer_=lambda result: result.data(),
        records=batch,
    )
    print(res)


In [None]:
installation_csv_path = os.path.join(data_dir,'installations.csv')
installation_df = pd.read_csv(installation_csv_path)

# Step 1: the Installation nodes themselves.
INSTALLATION_UPSERT_QUERY = """
      UNWIND $records AS rec
      MERGE (i:Installation {id: rec.ID})
      SET
        i.installationDate = rec.INSTALLATIONDATE,
        i.nome = rec.NOME,
        i.type = rec.TYPE
      RETURN count(rec) AS records_upserted
    """

# Step 2: Link -> Installation.
INSTALLATION_LINK_QUERY = """
      UNWIND $records AS rec
      MERGE (i:Installation {id: rec.ID})
      MERGE (l:Link {id: rec.LINK_ID})
      MERGE (l)-[:LINK_HAS_INSTALLATION]->(i)
      RETURN count(rec) AS records_upserted
    """

# Step 3: Customer -> Installation.
INSTALLATION_CUSTOMER_QUERY = """
      UNWIND $records AS rec
      MERGE (i:Installation {id: rec.ID})
      MERGE (c:Customer {id: rec.CUSTOMER_ID})
      MERGE (c)-[:CUSTOMER_HAS_INSTALLATION]->(i)
      RETURN count(rec) AS records_upserted
    """

# Step 4: Installation -> Region.
INSTALLATION_REGION_QUERY = """
      UNWIND $records AS rec
      MERGE (i:Installation {id: rec.ID})
      MERGE (r:Region {id: rec.REGION_ID})
      MERGE (i)-[:INSTALL_HAS_REGION]->(r)
      RETURN count(rec) AS records_upserted
    """

# Due to data quality issues a minority of installations lack a customer,
# link and/or region. Each relationship step therefore only receives the rows
# whose foreign-key column is populated; the node step receives every row.
INSTALLATION_LOAD_STEPS = [
    (None, INSTALLATION_UPSERT_QUERY),
    ('LINK_ID', INSTALLATION_LINK_QUERY),
    ('CUSTOMER_ID', INSTALLATION_CUSTOMER_QUERY),
    ('REGION_ID', INSTALLATION_REGION_QUERY),
]

for required_column, step_query in INSTALLATION_LOAD_STEPS:
    if required_column is None:
        step_rows = installation_df
    else:
        step_rows = installation_df[installation_df[required_column].notna()]
    for batch in chunks(step_rows.to_dict(orient='records')):
        res = driver.execute_query(
            step_query,
            routing_=RoutingControl.WRITE,
            result_transformer_=lambda result: result.data(),
            records=batch,
        )
        print(res)


In [None]:
region_csv_path = os.path.join(data_dir,'regions.csv')
region_df = pd.read_csv(region_csv_path)

# Upserts each Region and sets its name.
REGION_UPSERT_QUERY = """
      UNWIND $records AS rec
      MERGE (r:Region {id: rec.ID})
      SET r.name = rec.NAME
      RETURN count(rec) AS records_upserted
    """

# One write query per batch of rows; the per-batch summary is printed.
for batch in chunks(region_df.to_dict(orient='records')):
    res = driver.execute_query(
        REGION_UPSERT_QUERY,
        routing_=RoutingControl.WRITE,
        result_transformer_=lambda result: result.data(),
        records=batch,
    )
    print(res)

In [None]:
import pandas as pd
consumption_df = pd.read_csv(os.path.join(data_dir,'consumption_logs.csv'))
consumption_records = consumption_df.to_dict(orient='records')

# Pass 1: upsert every Consumption node, make sure its Installation exists and
# attach the two. No NEXT linking happens here, so `records_upserted` reports
# every row of the batch.
for records in chunks(consumption_records):
    res = driver.execute_query("""
      UNWIND $records AS rec
      MERGE (c:Consumption {id: rec.ID})
      MERGE (i:Installation {id: rec.INSTALLATION_ID})
      MERGE (i)-[:INSTALL_HAS_CONSUMPTION]->(c)
      SET
        c.seqId = rec.SEQ_ID,
        c.referenceDate = rec.REFERENCEDATE,
        c.quantity = rec.QUANTIDADE,
        c.consumptionValue = rec.CONSUMPTIONVALUE,
        c.invoiceValue = rec.INVOICEVALUE,
        c.newConsumptionValue = rec.NEWCONSUMPTIONVALUE
      RETURN count(rec) AS records_upserted
    """, routing_=RoutingControl.WRITE, result_transformer_=lambda r: r.data(), records=records)
    print(res)

# Pass 2: add NEXT relationships between consecutive consumption entries
# (seqId -> seqId + 1) of the same installation. This runs only after all
# batches are loaded, so a successor that was written by a different batch is
# still found, whatever the row order of the CSV. Only the ids are sent.
for records in chunks(consumption_records):
    res = driver.execute_query("""
      UNWIND $ids AS consumption_id
      MATCH (i:Installation)-[:INSTALL_HAS_CONSUMPTION]->(c:Consumption {id: consumption_id})
      MATCH (i)-[:INSTALL_HAS_CONSUMPTION]->(c_next:Consumption {seqId: c.seqId + 1})
      MERGE (c)-[r:NEXT]->(c_next)
      RETURN count(r) AS relationships_written
    """, routing_=RoutingControl.WRITE, result_transformer_=lambda r: r.data(),
        ids=[rec['ID'] for rec in records])
    print(res)

# First: consumption entries with no incoming NEXT start an installation's chain.
res = driver.execute_query("""
MATCH(i:Installation)-[:INSTALL_HAS_CONSUMPTION]->(c)
WHERE COUNT{()-[:NEXT]->(c)} = 0
MERGE (i)-[r:FIRST]->(c)
RETURN count(r) AS relationships_written
""", routing_=RoutingControl.WRITE, result_transformer_=lambda r: r.data())
print(res)

# Last: consumption entries with no outgoing NEXT end an installation's chain.
res = driver.execute_query("""
MATCH(i:Installation)-[:INSTALL_HAS_CONSUMPTION]->(c)
WHERE COUNT{(c)-[:NEXT]->()} = 0
MERGE (i)-[r:LAST]->(c)
RETURN count(r) AS relationships_written
""", routing_=RoutingControl.WRITE, result_transformer_=lambda r: r.data())
print(res)

In [None]:
ticket_csv_path = os.path.join(data_dir,'tickets.csv')
ticket_df = pd.read_csv(ticket_csv_path)

# Upserts each Ticket (keyed on ticketNumber), ensures the Customer exists,
# links them with CREATED_TICKET and copies the ticket attributes.
TICKET_UPSERT_QUERY = """
      UNWIND $records AS rec
      MERGE (t:Ticket {ticketNumber: rec.TICKETNUMBER})
      MERGE (c:Customer {id: rec.CUSTOMER_ID})
      MERGE (c)-[:CREATED_TICKET]->(t)
      SET
        t.createdDate = rec.CREATEDATE,
        t.resolutionDate = rec.RESOLUTIONDATE,
        t.severity = rec.SEVERITY,
        t.status = rec.STATUS
      RETURN count(rec) AS records_upserted
    """

# One write query per batch of rows; the per-batch summary is printed.
for batch in chunks(ticket_df.to_dict(orient='records')):
    res = driver.execute_query(
        TICKET_UPSERT_QUERY,
        routing_=RoutingControl.WRITE,
        result_transformer_=lambda result: result.data(),
        records=batch,
    )
    print(res)

In [None]:
alert_csv_path = os.path.join(data_dir,'alerts.csv')
alert_df = pd.read_csv(alert_csv_path)

# Looks up the equipment node whose label is given by EQUIPMENT_TYPE and whose
# id is EQUIPMENT_ID, upserts the Alert and links them with HAS_ALERT. Rows
# whose equipment is not found produce nothing because of the MATCH. A DATE
# that is NaN, null or empty is stored as NULL rather than parsed.
ALERT_UPSERT_QUERY = """
      UNWIND $records AS rec
      MATCH (eq:$(rec.EQUIPMENT_TYPE) {id:rec.EQUIPMENT_ID})
      MERGE (a:Alert {id:rec.ID})
      MERGE (eq)-[r:HAS_ALERT]->(a)
      SET
        a.type = rec.TYPE,
        a.date = CASE
            WHEN toString(rec.DATE) = 'NaN' OR rec.DATE IS NULL OR rec.DATE = ''
            THEN NULL
            ELSE date(rec.DATE)
            END

    return count(rec) AS records_upserted
    """

# One write query per batch of rows; the per-batch summary is printed.
for batch in chunks(alert_df.to_dict(orient='records')):
    res = driver.execute_query(
        ALERT_UPSERT_QUERY,
        routing_=RoutingControl.WRITE,
        result_transformer_=lambda result: result.data(),
        records=batch,
    )
    print(res)

## Unstructured Source

In [None]:
from pprint import pprint
import json

# Read from JSON file into array of Python objects.
# The encoding is pinned to UTF-8 so that non-English text in the records is
# decoded the same way on every platform instead of using the locale default.
with open(os.path.join(data_dir,'maintenance_records.json'), 'r', encoding='utf-8') as file:
    maintenance_record_texts = json.load(file)
print(f"Loaded {len(maintenance_record_texts)} records")

# Show the first three raw records as a quick sanity check.
print("Sample:")
for rec in maintenance_record_texts[:3]:
    print('-----')
    print(rec)


In [None]:
from enum import Enum
from pydantic import BaseModel, Field

# Closed set of maintenance categories accepted for MaintenanceRecord.type.
# Members are str-valued (the class mixes in `str`), so each member compares
# equal to its capitalised label.
# NOTE: documented with comments rather than a docstring on purpose, so the
# class __doc__ is left untouched.
class MaintenanceType(str, Enum):
    PREDICTIVE = "Predictive"
    CORRECTIVE = "Corrective"
    PREVENTIVE = "Preventive"
    EMERGENCY = "Emergency"


# Structured form of one free-text maintenance record. This model is handed to
# `with_structured_output(...)` further down, and the dumped fields are later
# written to the graph as a MaintenanceRecord node attached to its equipment.
# The `description=` strings are runtime text, not comments.
# NOTE: documented with comments rather than a docstring on purpose, so the
# class __doc__ is left untouched.
class MaintenanceRecord(BaseModel):
    # Becomes the key of the MaintenanceRecord node.
    id: str = Field(..., description="The maintenance record id")
    # equipmentId / equipmentType together select the equipment node in the
    # load query: equipmentType is used as the node label, equipmentId as its id.
    equipmentId: int = Field(..., description="The equipment id")
    equipmentType: str = Field(..., description="The equipment type")
    description: str = Field(..., description="The maintenance record description in English.  "
                                              "Translate as necessary. "
                                              "Some description may be in different languages such "
                                              "as Portuguese, Spanish, etc.")
    # NOTE(review): the load query parses this with Cypher `date(rec.date)`,
    # but the description does not ask for a particular date format - confirm
    # the extracted values are always in a format `date()` accepts.
    date: str = Field(..., description="the date of the maintenance record")
    downtimeInHours: int = Field(..., description="The downtime in hours.  These may not be labeled as hours, but numbers without units are hours by default.")
    type: MaintenanceType = Field(..., description="The maintenance record type")
    rootCause: str = Field(..., description="The root cause of the maintenance")

In [None]:
import asyncio
from typing import List
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel
from tqdm.asyncio import tqdm as tqdm_async


class TextExtractor:
    """Run structured extraction over texts with bounded concurrency.

    Holds an LLM runnable (expected to expose ``ainvoke``) and the prompt
    template used to wrap the raw text before it is sent to the model.
    """

    def __init__(self,
                 llm_with_struct_output,
                 prompt_template: PromptTemplate):
        self.llm = llm_with_struct_output
        self.prompt_template = prompt_template

    async def extract(self, texts: List[str], semaphore) -> BaseModel:
        """Extract one entity from a group of texts.

        The texts are joined with blank lines, rendered into the prompt under
        the ``texts`` variable and sent to the LLM while ``semaphore`` is held.

        Args:
            texts: Texts forming a single extraction request.
            semaphore: Async context manager limiting concurrent requests.

        Returns:
            Whatever the LLM's ``ainvoke`` returns for the rendered prompt.
        """
        joined_text = '\n\n'.join(texts)
        async with semaphore:
            prompt = self.prompt_template.invoke({'texts': joined_text})
            # Use structured LLM for extraction
            return await self.llm.ainvoke(prompt)

    async def extract_all(self, texts: List[str], chunk_size=1, max_workers=10) -> List[BaseModel]:
        """Extract entities from all texts, ``chunk_size`` texts per request.

        Args:
            texts: Every text to process.
            chunk_size: Number of texts combined into one extraction request.
            max_workers: Upper bound on requests in flight at the same time.

        Returns:
            The extracted entities in completion order, which is not
            necessarily the order of ``texts``.
        """
        # A single semaphore shared by all requests caps the concurrency.
        limiter = asyncio.Semaphore(max_workers)
        pending = [self.extract(batch, limiter) for batch in chunks(texts, chunk_size)]

        extracted: List[BaseModel] = []
        with tqdm_async(total=len(pending), desc="extracting texts") as progress:
            for finished in asyncio.as_completed(pending):
                extracted.append(await finished)
                progress.update(1)  # one tick per completed request
        return extracted

In [None]:
from langchain_openai import ChatOpenAI

#Get LLM api key
load_dotenv('source-db.env', override=True)
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI AI API key: ")

# Define Prompt and LLM with structured output
prompt_template = PromptTemplate.from_template("""
Extract and structure the maintenance information from the following text:

# Text
{texts}
""")
llm = ChatOpenAI(model="gpt-4.1", temperature=0).with_structured_output(MaintenanceRecord)

# Perform entity extraction
text_extractor = TextExtractor(llm, prompt_template)
maintenance_records = await text_extractor.extract_all(maintenance_record_texts)

In [None]:
# Spot-check: pretty-print the first three extracted records as plain dicts.
for extracted_record in maintenance_records[:3]:
    pprint(extracted_record.model_dump())

In [None]:
# Write the extracted maintenance records to the graph in batches. Each record
# is attached to the equipment node whose label is `equipmentType` and whose
# id is `equipmentId`; records whose equipment is not found produce nothing
# because of the MATCH.
for record_objects in chunks(maintenance_records):
    # The driver needs plain dicts, not pydantic models.
    records = [i.model_dump() for i in record_objects]
    # Only `res` receives the query result; `generator_df` is left untouched
    # so it is not overwritten by an unrelated query result.
    res = driver.execute_query("""
      UNWIND $records as rec
      MATCH (eq:$(rec.equipmentType) {id:rec.equipmentId})
      MERGE (m:MaintenanceRecord {id:rec.id})
      MERGE (eq)-[r:HAS_MAINTENANCE_RECORD]->(m)
      SET
        m.description = rec.description,
        m.type = rec.type,
        m.date = date(rec.date),
        m.downtimeInHours = rec.downtimeInHours,
        m.rootCause = rec.rootCause
    return count(rec) AS records_upserted
    """,
        routing_=RoutingControl.WRITE,
        # Turn the driver result into a plain list of dicts for printing.
        result_transformer_=lambda r: r.data(),
        records=records
    )
    print(res)