In [6]:
import csv
import os
import re
import time

import numpy as np
import pandas as pd
import psycopg2
import torch
from psycopg2 import sql
from torch_geometric.utils import sort_edge_index
from tqdm.notebook import tqdm

In [2]:
## TODO: Check whether we would prefer to scale the features to the range 0 to 1

In [3]:
# from ogb.nodeproppred import NodePropPredDataset

# # Download and process data at './dataset/ogbg_molhiv/'
# dataset = NodePropPredDataset(name = "ogbn-papers100M", root = 'ogbn_dataset/')
# print(dataset[0])

In [4]:
# Name of the edge-list CSV; the loaders below read it from the data/ directory.
edge_file_name = "ogbn_paper100M_edge_index.csv"
# One-off export of the OGB edge index to that CSV (needs `dataset` from the commented-out loader cell):
# pd.DataFrame(dataset[0][0]["edge_index"].transpose(), columns = ["source_id", "target_id"]).to_csv(f"data/{edge_file_name}", index = False)

In [4]:
# Name of the node CSV (features plus label) under data/; its stem is also used as the database name.
node_file_name = "ogbn_paper100M_features.csv"
# One-off export of node features and labels to that CSV (needs `dataset` from the commented-out loader cell).
# Feature columns are named f_<n>; missing labels are written as -1.
# X_and_y = pd.DataFrame(dataset[0][0]["node_feat"]).astype(np.float32)
# X_and_y.columns = list(map(lambda col: f"f_{col}", X_and_y.columns))
# X_and_y["label"] = dataset[0][1].squeeze()
# X_and_y["label"] = X_and_y["label"].fillna(-1).astype(np.int16)
# X_and_y.to_csv(f"data/{node_file_name}", sep = ",", index = True)

In [7]:
# Read only the header row of the node CSV so the column names are known
# without loading the whole file.
with open(f"data/{node_file_name}", newline='') as f:
  reader = csv.reader(f)
  row1 = next(reader)

In [13]:
# Drop the first header field (presumably the index column written by to_csv(index=True) — verify);
# the rest are the feature columns followed by the label column.
columns = row1[1:]

In [47]:
def extract_execution_time(explain_output):
    """
    Sum the execution times reported in EXPLAIN ANALYZE output.

    Every row is searched for a fragment of the form 'Execution Time: X ms';
    all matches are added up, so the output of several concatenated plans
    yields their combined execution time.

    Parameters:
    explain_output (iterable): Rows of the plan output, e.g. the result of
        `cursor.fetchall()`. Only the first element of each row (the plan
        line as a string) is inspected.

    Returns:
    float: The total execution time in milliseconds. Returns 0.0 when no
           execution time is found.
    """
    # Accept both "12.345" and "12": the fractional part is optional.
    pattern = re.compile(r"Execution Time: (\d+(?:\.\d+)?) ms")

    execution_time = 0.0
    for row in explain_output:
        match = pattern.search(row[0])
        if match:
            execution_time += float(match.group(1))
    return execution_time
    
def connect_to_postgres(dbname = "postgres"):
    """
    Connect to the PostgreSQL database server.

    User, password and host are taken from the PGUSER, PGPASSWORD and PGHOST
    environment variables when they are set, so credentials do not have to
    live in the notebook. When a variable is unset the previous hard-coded
    value ('postgres' / 'password' / 'localhost') is used.

    Parameters:
    dbname (str): Database to connect to. The default 'postgres' maintenance
        database is what is used to create or drop other databases.

    Returns:
    The open psycopg2 connection.

    Raises:
    Exception: Any error raised by psycopg2.connect is printed and re-raised.
    """
    try:
        conn = psycopg2.connect(
            dbname=dbname,
            user=os.environ.get("PGUSER", "postgres"),
            password=os.environ.get("PGPASSWORD", "password"),
            host=os.environ.get("PGHOST", "localhost")
        )
        print("Connection successful.")
        return conn
    except Exception as e:
        print(f"Error connecting to database: {e}")
        raise

def create_database(conn, new_db_name):
    """
    Create a new database.

    CREATE DATABASE cannot run inside a transaction block, so the connection
    is switched to autocommit for the duration of the call. The autocommit
    setting the connection had on entry is restored afterwards (previously it
    was always forced to False).

    Parameters:
    conn: Open psycopg2 connection (typically to the 'postgres' database).
    new_db_name (str): Name of the database to create; quoted as an identifier.

    Raises:
    Exception: Any database error is printed and re-raised.
    """
    previous_autocommit = conn.autocommit
    try:
        conn.autocommit = True
        with conn.cursor() as cursor:
            cursor.execute(sql.SQL("CREATE DATABASE {};").format(sql.Identifier(new_db_name)))
            print(f"Database '{new_db_name}' created successfully.")
    except Exception as e:
        print(f"Error creating database: {e}")
        raise
    finally:
        # Hand the connection back in the mode the caller had it in.
        conn.autocommit = previous_autocommit

def create_schema(conn, schema_sql):
    """
    Execute a schema (DDL) statement and commit it.

    Parameters:
    conn: Open database connection.
    schema_sql (str): SQL text to execute.

    Raises:
    Exception: On any failure the transaction is rolled back, the error is
        printed, and the exception is re-raised.
    """
    try:
        with conn.cursor() as cur:
            cur.execute(schema_sql)
            conn.commit()
            print("Schema created successfully.")
    except Exception as err:
        conn.rollback()
        failure = f"Error creating schema: {err}"
        print(failure)
        raise

def upload_csv_to_table(conn, csv_file_path, table_name):
    """
    Bulk-load a CSV file (with header row) into a table via COPY ... FROM STDIN.

    Parameters:
    conn: Open psycopg2 connection.
    csv_file_path (str): Path of the comma-delimited CSV file to stream.
    table_name (str): Target table; quoted as an identifier.

    Raises:
    Exception: On any failure the transaction is rolled back, the error is
        printed, and the exception is re-raised.
    """
    try:
        with conn.cursor() as cur:
            with open(csv_file_path, 'r') as csv_handle:
                copy_stmt = sql.SQL("""
                        COPY {} FROM STDIN WITH (FORMAT CSV, HEADER TRUE, DELIMITER ',');
                    """).format(sql.Identifier(table_name))
                cur.copy_expert(copy_stmt, csv_handle)
            conn.commit()
            print(f"Data from '{csv_file_path}' uploaded to table '{table_name}' successfully using COPY.")
    except Exception as err:
        conn.rollback()
        print(f"Error uploading CSV data using COPY: {err}")
        raise

def create_index(conn, table_name, column_name):
    """
    Create an index named '<table>_<column>_idx' on one column and commit.

    Parameters:
    conn: Open psycopg2 connection.
    table_name (str): Table to index.
    column_name (str): Column to index.

    Raises:
    Exception: On any failure the transaction is rolled back, the error is
        printed, and the exception is re-raised.
    """
    try:
        with conn.cursor() as cur:
            idx = f"{table_name}_{column_name}_idx"
            # Every name is passed as a quoted identifier, never spliced in as raw text.
            quoted = [sql.Identifier(part) for part in (idx, table_name, column_name)]
            statement = sql.SQL("CREATE INDEX {} ON {} ({});").format(*quoted)
            cur.execute(statement)
            conn.commit()
            print(f"Index '{idx}' created successfully.")
    except Exception as err:
        conn.rollback()
        print(f"Error creating index: {err}")
        raise

def read_data(conn, query):
    """
    Run a query and return all result rows.

    Parameters:
    conn: Open database connection.
    query (str): SQL text to execute.

    Returns:
    list: The rows returned by `cursor.fetchall()`.

    Raises:
    Exception: On any failure the transaction is rolled back, the error is
        printed, and the exception is re-raised.
    """
    try:
        with conn.cursor() as cursor:
            cursor.execute(query)
            results = cursor.fetchall()
            print("Data read successfully.")
            return results
    except Exception as e:
        # A failed statement leaves the transaction aborted; roll back (as the
        # other helpers do) so the connection stays usable for later queries.
        conn.rollback()
        print(f"Error reading data: {e}")
        raise

def update_data(conn, query, params):
    """
    Execute a parameterized write statement and commit it.

    Parameters:
    conn: Open database connection.
    query (str): SQL text containing placeholders.
    params: Values bound to the placeholders by the driver.

    Raises:
    Exception: On any failure the transaction is rolled back, the error is
        printed, and the exception is re-raised.
    """
    try:
        with conn.cursor() as cur:
            cur.execute(query, params)
            conn.commit()
            print("Data updated successfully.")
    except Exception as err:
        conn.rollback()
        failure = f"Error updating data: {err}"
        print(failure)
        raise

def delete_database(conn, db_name):
    """
    Delete the specified database including all its schemas, tables, and indexes.

    DROP DATABASE cannot run inside a transaction block, so the connection is
    switched to autocommit for the duration of the call. The autocommit
    setting the connection had on entry is restored afterwards (previously it
    was always forced to False).

    Parameters:
    conn: Open psycopg2 connection to a database other than the one being dropped.
    db_name (str): Database to drop; quoted as an identifier. Nothing happens
        if it does not exist (IF EXISTS).

    Raises:
    Exception: Any database error is printed and re-raised, e.g. when other
        sessions are still connected to the database.
    """
    previous_autocommit = conn.autocommit
    try:
        conn.autocommit = True
        with conn.cursor() as cursor:
            cursor.execute(sql.SQL("DROP DATABASE IF EXISTS {};").format(sql.Identifier(db_name)))
            print(f"Database '{db_name}' deleted successfully.")
    except Exception as e:
        print(f"Error deleting database: {e}")
        raise
    finally:
        # Hand the connection back in the mode the caller had it in.
        conn.autocommit = previous_autocommit

def create_edges_table(conn):
    """
    Create the `edges` table and commit.

    Both `source_id` and `target_id` are NOT NULL integers with foreign keys
    to `nodes(id)` (ON DELETE CASCADE), so the `nodes` table must already exist.

    Parameters:
    conn: Open database connection.

    Raises:
    Exception: On any failure the transaction is rolled back, the error is
        printed, and the exception is re-raised.
    """
    ddl = """
    CREATE TABLE edges (
        source_id INTEGER NOT NULL,
        target_id INTEGER NOT NULL,
        FOREIGN KEY (source_id) REFERENCES nodes(id) ON DELETE CASCADE,
        FOREIGN KEY (target_id) REFERENCES nodes(id) ON DELETE CASCADE
    );
    """
    try:
        with conn.cursor() as cur:
            cur.execute(ddl)
            conn.commit()
            print("Edges table schema created successfully.")
    except Exception as err:
        conn.rollback()
        failure = f"Error creating edges table: {err}"
        print(failure)
        raise

def upload_edges_csv_to_table(conn, csv_file_path):
    """
    Bulk-load an edge-list CSV (with header row) into `edges` via COPY ... FROM STDIN.

    The file's two columns are loaded into `source_id` and `target_id`.

    Parameters:
    conn: Open psycopg2 connection.
    csv_file_path (str): Path of the comma-delimited CSV file to stream.

    Raises:
    Exception: On any failure the transaction is rolled back, the error is
        printed, and the exception is re-raised.
    """
    try:
        with conn.cursor() as cur:
            with open(csv_file_path, 'r') as csv_handle:
                copy_stmt = sql.SQL("""
                        COPY edges (source_id, target_id) FROM STDIN WITH (FORMAT CSV, HEADER TRUE, DELIMITER ',');
                    """)
                cur.copy_expert(copy_stmt, csv_handle)
            conn.commit()
            print(f"Edges data from '{csv_file_path}' uploaded successfully.")
    except Exception as err:
        conn.rollback()
        print(f"Error uploading edges CSV data: {err}")
        raise

def recurse_edge_index_iterative(source_nodes, edge_index, max_depth):
    """
    Collect the edges reachable by walking backwards from the source nodes.

    Starting with `source_nodes` as the frontier, each hop gathers every edge
    whose target (row 1 of `edge_index`) lies in the frontier; the sources of
    those edges (row 0) that have not been seen before form the next frontier.
    The walk stops after `max_depth` hops or as soon as the frontier is empty.

    Parameters:
    source_nodes (list): Seed node ids.
    edge_index (np.ndarray): Array of shape (2, E) holding [source; target] ids.
    max_depth (int): Maximum number of hops.

    Returns:
    np.ndarray: Array of shape (2, K) with the collected edges, hop by hop.
    """
    seen = set(source_nodes)
    frontier = np.array(source_nodes)
    per_hop_edges = []

    hop = 0
    while hop < max_depth:
        hop += 1

        # Edges pointing into the current frontier.
        hop_edges = edge_index[:, np.isin(edge_index[1], frontier)]
        per_hop_edges.append(hop_edges)

        # Their not-yet-visited sources become the next frontier.
        frontier = np.setdiff1d(hop_edges[0], list(seen))
        seen.update(frontier)

        if frontier.size == 0:
            break

    if not per_hop_edges:
        return np.empty((2, 0), dtype=edge_index.dtype)
    return np.concatenate(per_hop_edges, axis=1)


def get_subgraph_from_in_mem_graph_optimized(X, y, i, edge_index, hops):
    """
    Extract the `hops`-hop subgraph around node `i` from the in-memory graph.

    Parameters:
    X: Feature table, indexed positionally with `.iloc` (assumed to be a pandas DataFrame).
    y: Label table, indexed positionally with `.iloc` (assumed to be a pandas DataFrame).
    i: Seed node id.
    edge_index (np.ndarray): Array of shape (2, E) holding [source; target] ids.
    hops (int): Number of hops to expand.

    Returns:
    tuple: (edge index relabelled to 0..n-1 with shape (2, K), features,
            labels, sorted unique global node ids of the subgraph).
    """
    global_edges = recurse_edge_index_iterative([i], edge_index, hops)

    # np.unique yields the sorted node ids plus, for every edge endpoint,
    # its position in that sorted list, which is the local id.
    node_ids, local_positions = np.unique(global_edges, return_inverse=True)
    local_edges = local_positions.reshape(2, -1)

    node_features = X.iloc[node_ids, :].values
    node_labels = y.iloc[node_ids, :].values.squeeze()

    return local_edges, node_features, node_labels, node_ids

def create_db(node_file_name):
    """
    Create a database named after the node file and connect to it.

    The database name is the part of `node_file_name` before the first '.'.
    The temporary connection to the 'postgres' database is always closed,
    including when creating the database fails.

    Parameters:
    node_file_name (str): Node CSV file name, e.g. 'graph_features.csv'.

    Returns:
    tuple: (connection to the new database, name of the new database).
    """
    # Create a new database
    admin_conn = connect_to_postgres(dbname = "postgres")
    try:
        new_db_name = node_file_name.split(".")[0]
        create_database(admin_conn, new_db_name)
    finally:
        # Never leave the maintenance session open, even if creation failed.
        admin_conn.close()
    conn = connect_to_postgres(new_db_name)
    return conn, new_db_name

def create(conn, node_file_name, edge_file_name, columns = columns):
    """
    Load the edge list into the database, index it, and report the elapsed time.

    Only the edge upload and the index on `edges.target_id` are active. The
    node-table steps are kept below as commented-out scaffolding, so
    `node_file_name` and `columns` are currently unused.

    Parameters:
    conn: Open psycopg2 connection to the target database.
    node_file_name (str): Node CSV file name (unused while the node steps are disabled).
    edge_file_name (str): Edge CSV file name, read from the data/ directory.
    columns (list): Node column names (unused while the node steps are disabled).

    Returns:
    float: Elapsed wall-clock time in seconds.
    """
    # --- Disabled: node schema creation, node upload and edges table creation ---
    # column_types = ["id SERIAL PRIMARY KEY"]
    # for col in X_and_y.columns:
    #     if col == "label":
    #         column_types.append(f"{col} INTEGER")
    #         continue
    #     column_types.append(f"{col} REAL")
    #
    # node_schema = f"""
    # CREATE TABLE nodes (
    #     {",".join(column_types)}
    # );
    # """
    # create_schema(conn, node_schema)
    # create_edges_table(conn)
    #
    # csv_file_path = f"data/{node_file_name}"  # Replace with your CSV file path
    # upload_csv_to_table(conn, csv_file_path, "nodes")
    # create_index(conn, "nodes", "id")
    began = time.time()
    upload_edges_csv_to_table(conn, f"data/{edge_file_name}")
    create_index(conn, "edges", "target_id")
    return time.time() - began

def _remap_to_local_index(global_edges, node_ids):
    """Translate the global node ids in `global_edges` into their positions within `node_ids`."""
    # Sort once and binary-search every endpoint instead of building a dense
    # (edges x nodes) comparison matrix. Assumes every endpoint occurs in
    # `node_ids` and that `node_ids` holds no duplicates.
    order = np.argsort(node_ids, kind="stable")
    positions = np.searchsorted(node_ids, global_edges, sorter=order)
    return order[positions]


def read(conn, hops, columns = columns):
    """
    Time the extraction of the `hops`-hop subgraph around every node in the database.

    For each id in `nodes`, a recursive CTE follows `edges` backwards (from
    target to source) for up to `hops` hops and returns, in a single row, the
    features, labels, edges and ids of the nodes it touched. The edges are
    then relabelled to positions within the returned node id list. Seed nodes
    without incoming edges produce no subgraph and are skipped; they are not
    included in the measured time.

    Parameters:
    conn: Open psycopg2 connection to the graph database.
    hops (int): Number of hops to expand around each seed node.
    columns (list): Column names of `nodes` without `id`. The last entry is
        treated as the label column, all others as feature columns. The names
        are spliced into the SQL text, so they must be trusted identifiers
        and must not contain '%'.

    Returns:
    tuple: (complete_time, complete_test_time) in seconds. `complete_test_time`
           stays 0 while the in-memory comparison below is commented out.

    Raises:
    Exception: On any failure the transaction is rolled back, the error is
        printed, and the exception is re-raised.
    """
    with conn.cursor() as cursor:
        cursor.execute("SELECT id FROM nodes")
        results = cursor.fetchall()
    seed_node_ids = [row[0] for row in results]

    # The SQL text is identical for every seed node, so build it once. The seed
    # id and the depth limit are bound by the driver instead of being formatted
    # into the string.
    all_columns = ", ".join(columns)
    feature_columns = ", ".join(columns[:-1])
    label_column = columns[-1]
    query = f"""
        WITH RECURSIVE NestedTargets AS (
            SELECT 0 AS depth, source_id, target_id
            FROM edges
            WHERE target_id = %(seed_node_id)s

            UNION ALL

            SELECT nt.depth + 1, e.source_id, e.target_id
            FROM edges e
            JOIN NestedTargets nt ON e.target_id = nt.source_id
            WHERE nt.depth < %(max_depth)s
        ),

        node_ids AS (
            SELECT DISTINCT id FROM (
                SELECT source_id AS id FROM NestedTargets
                UNION
                SELECT target_id AS id FROM NestedTargets
            ) AS combined_ids
        ),

        node_data AS (
            SELECT
                id,
                {all_columns}
            FROM nodes
            WHERE id IN (SELECT id FROM node_ids)
            ORDER BY id
        )

        SELECT
            (SELECT array_agg(array[id, {feature_columns}]) FROM node_data) AS node_table,
            (SELECT array_agg({label_column}) FROM node_data) AS label_table,
            (SELECT array_agg(array[source_id, target_id])
             FROM (SELECT DISTINCT source_id, target_id FROM NestedTargets) AS edges) AS edge_table,
            (SELECT array_agg(id) FROM node_data) AS node_ids;
    """
    # The ids are fetched a second time as `node_ids` because the `id` inside
    # `node_table` is cast to float together with the features, which loses
    # precision for large ids (e.g., 10,000,323 -> 10,000,320).
    query_params = {"max_depth": int(hops) - 1}

    with conn.cursor() as cursor:
        complete_time = 0
        complete_test_time = 0
        for seed_node_id in tqdm(seed_node_ids):
            try:
                start = time.time()
                query_params["seed_node_id"] = seed_node_id
                cursor.execute(query, query_params)
                results = cursor.fetchall()[0]
                if results[0] is None:
                    # No incoming edges: there is no subgraph for this seed.
                    continue

                labels = np.array(results[1])
                subgraph_node_features = np.array(results[0])
                subgraph_edges = np.array(results[2]).transpose()
                node_ids = np.array(results[-1])

                remapped_edge_index = _remap_to_local_index(subgraph_edges, node_ids)
                # Column 0 of the node table is the (float-cast) id; drop it.
                features = subgraph_node_features[:, 1:]
                overall_run_time = time.time() - start

                complete_time += overall_run_time

                # Testing (disabled): compare against the in-memory implementation.
                # test_time = time.time()
                # remapped_edge_index_test, features_test, labels_test, unique_node_ids = get_subgraph_from_in_mem_graph_optimized(X, y, seed_node_id, edge_index, hops)
                # complete_test_time += time.time() - test_time
                # assert (sort_edge_index(torch.from_numpy(remapped_edge_index_test)) == sort_edge_index(torch.from_numpy(remapped_edge_index))).sum() / (remapped_edge_index_test.shape[-1] * remapped_edge_index_test.shape[0]), "Edges doesnt match"
                # assert np.allclose(features, features_test), "features doe not match"
                # assert np.allclose(labels_test, labels), "Labels does not match"
                # print(f"Fetched {remapped_edge_index.shape} edges, {labels.shape} labels, {features.shape} features in ({overall_run_time} s)")
            except Exception as e:
                conn.rollback()
                print(f"Error reading subgraphs: {e}")
                raise
    return (complete_time, complete_test_time)

def delete(conn, new_db_name):
    """
    Drop the benchmark database and report how long the teardown took.

    The given connection is closed first, because a database cannot be dropped
    while sessions are connected to it. The drop is then issued through a
    fresh connection to the default 'postgres' database, which is always
    closed again, including when the drop fails.

    Parameters:
    conn: Open connection to the database that is about to be dropped.
    new_db_name (str): Name of the database to drop.

    Returns:
    float: Elapsed wall-clock time in seconds.
    """
    start = time.time()
    conn.close()
    admin_conn = connect_to_postgres()
    try:
        delete_database(admin_conn, new_db_name)
    finally:
        # Never leave the maintenance session open, even if the drop failed.
        admin_conn.close()
    return time.time() - start

In [48]:
# Example usage
# Settings for the (currently disabled) in-memory baseline below.
hops = 2
overall_run_time = 0
# edge_index = dataset[0][0]["edge_index"] # edges.values.astype(np.int64).transpose()

# In-memory baseline: extract the subgraph around every node and accumulate the time taken.
# for i in tqdm(range(X.shape[0])):
#     start = time.time()
#     remapped_edge_index, features, labels, node_ids = get_subgraph_from_in_mem_graph_optimized(X, y, i, edge_index, hops)
#     overall_run_time += time.time() - start
    
#     print(f"Fetched {remapped_edge_index.shape} edges, {labels.shape} labels, {features.shape} features in {overall_run_time:.2f} s")

In [49]:
# conn, new_db_name = create_db(node_file_name)
# create_time = create(conn, node_file_name, edge_file_name, X_and_y)

In [51]:
# Time 3-hop subgraph extraction for every node.
# NOTE(review): `conn` is only assigned by the commented-out create_db cell or by the
# connect_to_postgres cell further down; one of them must be run first.
read_time, read_time_mem = read(conn, 3, columns)

  0%|          | 0/111059956 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Show the timings from the read() call above.
read_time, read_time_mem

In [None]:
# Benchmark subgraph extraction for 1, 2 and 3 hops; results are keyed by hop count.
read_times = dict()
read_times_mem = dict()
for hops in tqdm(range(1, 4)):
    # read() returns (database time, in-memory comparison time) in seconds.
    read_time, read_time_mem = read(conn, hops)
    read_times[hops] = read_time
    read_times_mem[hops] = read_time_mem


In [17]:
# Drop the benchmark database and record how long the teardown takes.
# NOTE(review): the recorded output below shows this failing with ObjectInUse because
# another session was still connected to the database.
delete_time = delete(conn, new_db_name)

Connection successful.
Connection successful.
Error deleting database: database "ogbn_paper100M_features" is being accessed by other users
DETAIL:  There is 1 other session using the database.



ObjectInUse: database "ogbn_paper100M_features" is being accessed by other users
DETAIL:  There is 1 other session using the database.


In [53]:
# Accumulated database read time in seconds, keyed by number of hops.
read_times

{1: 270.018767118454, 2: 3174.6598134040833, 3: 49607.77852463722}

In [54]:
# Accumulated in-memory comparison time in seconds, keyed by number of hops.
# NOTE(review): read() as currently written never updates this timer (the comparison is
# commented out), so the non-zero values recorded below presumably come from an earlier version.
read_times_mem

{1: 331.7824068069458, 2: 604.4689948558807, 3: 1079.6200346946716}

In [55]:
# NOTE(review): `create_time` is only assigned in the commented-out create() cell,
# so this fails on a fresh kernel.
create_time

13.816965341567993

In [56]:
# Teardown time in seconds as returned by delete().
delete_time

0.10642385482788086

In [17]:
# Reconnect to the already existing benchmark database; its name is the node file name
# up to the first '.', matching what create_db() uses.
new_db_name = node_file_name.split(".")[0]
conn = connect_to_postgres(new_db_name)

Connection successful.


In [9]:
# Sanity check: count the rows loaded into the edges table.
with conn.cursor() as cursor:
    cursor.execute("SELECT count(*) FROM edges ")
    print(cursor.fetchall())

[(1615685872,)]


In [None]:
## TODO Check if everything runs without problems
## TODO Update function
## TODO Neo4j impl
## TODO MySQL impl
## TODO large graph

In [None]:
# # Update data
# update_query = "UPDATE example_table SET city = %s WHERE id = %s;"
# update_params = ("NewCity", 1)
# update_data(conn, update_query, update_params)