In [1]:
import mysql.connector as connector
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import time
import torch
from torch_geometric.utils import sort_edge_index
import os
import subprocess

  _torch_pytree._register_pytree_node(


In [2]:
connection = connector.connect(
        user='root',
        password='password',
        host='localhost'
    )
connection.autocommit = True

cursor = connection.cursor()

cursor.execute("""SHOW DATABASES;""")
print(cursor.fetchall())
connection.commit()
cursor.close()
connection.close()

[('X_and_y_10000_5',), ('information_schema',), ('mysql',), ('performance_schema',), ('sys',)]


In [24]:
connection = connector.connect(
        user='root',
        password='password',
        host='localhost',
    database="mysql"
    )
connection.autocommit = True

cursor = connection.cursor()

cursor.execute("""show global variables like 'local_infile';""")
print(cursor.fetchall())
connection.commit()
cursor.close()
connection.close()

[('local_infile', 'ON')]


In [None]:
OPT_LOCAL_INFILE

In [43]:
import os
import mysql.connector as connector

def connect_to_mysql(db_name=None):
    """Connect to the MySQL database server."""
    try:
        conn = connector.connect(
            user='root',
            password='password',
            host='localhost',
            database=db_name,
            allow_local_infile=True
        )
        print("Connection successful.")
        return conn
    except Exception as e:
        print(f"Error connecting to database: {e}")
        raise

def create_database(conn, new_db_name):
    """Create a new database."""
    try:
        cursor = conn.cursor()
        cursor.execute(f"CREATE DATABASE IF NOT EXISTS {new_db_name}")
        print(f"Database '{new_db_name}' created successfully.")
    except Exception as e:
        print(f"Error creating database: {e}")
        raise

def create_table(conn, schema_sql):
    """Create a table with the specified schema."""
    try:
        cursor = conn.cursor()
        cursor.execute(schema_sql)
        print(f"Table created successfully.")
    except Exception as e:
        print(f"Error creating table: {e}")
        raise

def load_data_infile(csv_file_path, table_name, conn):
    """
    Use the LOAD DATA INFILE command to efficiently load data into a MySQL table.
    Args:
        csv_file_path (str): Path to the CSV file to import.
        table_name (str): Name of the target table.
        conn: Active MySQL connection object.
    """
    try:
        cursor = conn.cursor()
        query = (
            f"""
            LOAD DATA LOCAL INFILE '{csv_file_path}'
            INTO TABLE {table_name}
            FIELDS TERMINATED BY ','
            ENCLOSED BY '"'
            LINES TERMINATED BY '\n'
            IGNORE 1 ROWS;
            """
        )
        cursor.execute(query)
        conn.commit()
        print(f"Data from '{csv_file_path}' loaded into '{table_name}' successfully.")
    except Exception as e:
        conn.rollback()
        print(f"Error loading data: {e}")
        raise

def create_index(conn, table_name, column_name):
    """Create an index on a specific column."""
    try:
        cursor = conn.cursor()
        index_name = f"{table_name}_{column_name}_idx"
        cursor.execute(f"CREATE INDEX {index_name} ON {table_name} ({column_name})")
        conn.commit()
        print(f"Index '{index_name}' created successfully.")
    except Exception as e:
        conn.rollback()
        print(f"Error creating index: {e}")
        raise

def read_data(conn, query):
    """Read data from the database."""
    try:
        cursor = conn.cursor()
        cursor.execute(query)
        results = cursor.fetchall()
        print("Data read successfully.")
        return results
    except Exception as e:
        print(f"Error reading data: {e}")
        raise

def update_data(conn, query, params):
    """Update data in the database."""
    try:
        cursor = conn.cursor()
        cursor.execute(query, params)
        conn.commit()
        print("Data updated successfully.")
    except Exception as e:
        conn.rollback()
        print(f"Error updating data: {e}")
        raise

def delete_database(conn, db_name):
    """Delete the specified database."""
    try:
        cursor = conn.cursor()
        cursor.execute(f"DROP DATABASE IF EXISTS {db_name}")
        print(f"Database '{db_name}' deleted successfully.")
    except Exception as e:
        print(f"Error deleting database: {e}")
        raise


In [44]:
data_root_dir = "/home/dwalke/git/shweta/Master-Thesis-Code/data/syn_data"
num_nodes = 10000
num_edges = 5
X = pd.read_csv(f"{data_root_dir}/X_{num_nodes}_nodes_{num_edges}_edges.csv")
y = pd.read_csv(f"{data_root_dir}/y_{num_nodes}_nodes_{num_edges}_edges.csv")
edge_file_name = f"data/edges_{num_nodes}_{num_edges}.csv"
edges = pd.read_csv(f"{data_root_dir}/edge_index_{num_nodes}_nodes_{num_edges}_edges.csv")
edges.columns = ["source_id", "target_id"]
# edges = edges+1
edges.to_csv(edge_file_name, sep = ",", index = False)
X_and_y = X.copy()
X_and_y.columns = list(map(lambda col: f"f_{col}", X_and_y.columns))
X_and_y["label"] = y["0"]
node_file_name = f"X_and_y_{num_nodes}_{num_edges}.csv"
X_and_y.to_csv(f"data/{node_file_name}", sep = ",", index = True)

In [45]:
# edge_file_name = "ppi_edge_index.csv"
# node_file_name = "ppi.csv"
# X = pd.read_csv(f"data/ppi_x.csv")
# y = pd.read_csv(f"data/ppi_y.csv")
# edges = pd.read_csv("data/" + edge_file_name)
# edges.columns = ["source_id", "target_id"]
# y_comb = y.copy()
# y_comb['Combined'] = y_comb.apply(lambda row: [int(row[column]) for column in y.columns], axis=1)
# X_and_y = X.copy()
# X_and_y.columns = list(map(lambda col: f"f_{col}", X_and_y.columns))
# X_and_y["label"] = y_comb["Combined"].apply(lambda x: f"{{{','.join(map(str, x))}}}")
# node_file_name = "X_y_ppi.csv"
# X_and_y.to_csv(f"data/{node_file_name}", sep = ",", index = True)

In [46]:
X_and_y["label"]

0       0
1       0
2       0
3       0
4       0
       ..
9995    1
9996    1
9997    1
9998    1
9999    1
Name: label, Length: 10000, dtype: int64

In [47]:
def recurse_edge_index(source_nodes, edge_index, max_depth, depth = 0):
    assert max_depth >= 1, "Max depth should be above or equal to one"
    target_mask = edge_index[1][:, None] == source_nodes
    target_mask = target_mask.sum(-1).astype(np.bool_)
    subgraph_edge_index = edge_index[:, target_mask]
    depth = depth + 1
    
    if depth == max_depth:
        return subgraph_edge_index

    source_nodes = np.concatenate([subgraph_edge_index[0, :], source_nodes])    
    return recurse_edge_index(source_nodes,edge_index, max_depth, depth)

def get_subgraph_from_in_mem_graph(X,y, i, edge_index, hops):
    
    subgraph_edge_index = recurse_edge_index([i], edge_index, hops)
    unique_node_ids = np.unique(subgraph_edge_index)
    
    features = X.iloc[unique_node_ids, :].values
    labels = y.iloc[unique_node_ids, :].values.squeeze()

    node_ids = unique_node_ids
    _, cols_source = np.nonzero((subgraph_edge_index[0] == node_ids[:, None]).transpose())
    _, cols_target = np.nonzero((subgraph_edge_index[1] == node_ids[:, None]).transpose())
    remapped_edge_index = np.concatenate([np.expand_dims(cols_source, axis = 0), np.expand_dims(cols_target, axis = 0)], axis = 0)
    return remapped_edge_index, features, labels, unique_node_ids
    

# hops = 2
# overall_run_time = 0
# edge_index = edges.values.astype(np.int64).transpose()
# for i in tqdm(range(X.shape[0])):
#     start = time.time()
#     remapped_edge_index, features, labels, _ = get_subgraph_from_in_mem_graph(X, y, i, edge_index, hops)
#     overall_run_time += time.time() - start
#     print(f"Fetched {remapped_edge_index.shape} edges, {labels.shape} labels, {features.shape} features in {overall_run_time} s")

In [48]:
column_types = ["id SERIAL PRIMARY KEY"]
for col in X_and_y.columns:
    if col == "label":
        column_types.append(f"{col} INTEGER")
        continue
    column_types.append(f"{col} DOUBLE")

In [49]:
def reconnect():
    global conn
    conn.close()
    conn = connect_to_postgres("X_and_y_1000_5")

In [50]:
# Create a new database
conn = connect_to_mysql(db_name = "mysql")
new_db_name = node_file_name.split(".")[0]
create_database(conn, new_db_name)

# # Connect to the new database
conn.close()
conn = connect_to_mysql(new_db_name)

# Create schema
node_schema = f"""
CREATE TABLE IF NOT EXISTS nodes (
    {",".join(column_types)}
);
"""
create_table(conn, node_schema)
edge_schema = """
CREATE TABLE IF NOT EXISTS edges (
    source_id BIGINT UNSIGNED NOT NULL,
    target_id BIGINT UNSIGNED NOT NULL,
    PRIMARY KEY (source_id, target_id),
    FOREIGN KEY (source_id) REFERENCES nodes(id) ON DELETE CASCADE,
    FOREIGN KEY (target_id) REFERENCES nodes(id) ON DELETE CASCADE
);

"""
create_table(conn, edge_schema)

Connection successful.
Database 'X_and_y_10000_5' created successfully.
Connection successful.
Table created successfully.
Table created successfully.


In [51]:
# Upload CSV data and create index

csv_file_path = f"data/{node_file_name}"  # Replace with your CSV file path
load_data_infile(csv_file_path, "nodes", conn)

Data from 'data/X_and_y_10000_5.csv' loaded into 'nodes' successfully.


In [None]:

upload_csv_to_table(conn, csv_file_path, "nodes")
create_index(conn, "nodes", "id")
upload_edges_csv_to_table(conn, edge_file_name)
create_index(conn, "edges", "target_id")

In [18]:
with conn.cursor() as cursor:
    cursor.execute("SELECT id FROM nodes")
    results = cursor.fetchall()
seed_node_ids = list(map(lambda res_data: res_data[0], results))

In [19]:
X_and_y.columns[:-1]

Index(['f_0', 'f_1', 'f_2', 'f_3', 'f_4', 'f_5', 'f_6', 'f_7', 'f_8', 'f_9',
       'f_10', 'f_11', 'f_12', 'f_13', 'f_14', 'f_15', 'f_16', 'f_17', 'f_18',
       'f_19', 'f_20', 'f_21', 'f_22', 'f_23', 'f_24', 'f_25', 'f_26', 'f_27',
       'f_28', 'f_29', 'f_30', 'f_31', 'f_32', 'f_33', 'f_34', 'f_35', 'f_36',
       'f_37', 'f_38', 'f_39', 'f_40', 'f_41', 'f_42', 'f_43', 'f_44', 'f_45',
       'f_46', 'f_47', 'f_48', 'f_49'],
      dtype='object')

In [None]:
with conn.cursor() as cursor:
    i = 0 
    for seed_node_id in tqdm(seed_node_ids):
        i+=1
        if i<280: continue
        hops = 2
        try:
            start = time.time()
            cursor.execute(f"""
            WITH RECURSIVE NestedTargets AS (
            SELECT 0 AS depth, source_id, target_id
            FROM edges
            WHERE target_id = {seed_node_id}
            
            UNION ALL
            
            SELECT nt.depth + 1, e.source_id, e.target_id
            FROM edges e
            JOIN NestedTargets nt ON e.target_id = nt.source_id
            WHERE nt.depth < {hops -1 }
            )
            SELECT
                (SELECT array_agg(array[id, {", ".join(X_and_y.columns[:-1])}]) FROM 
                (SELECT id, {", ".join(X_and_y.columns[:-1])} FROM nodes
                 WHERE id IN (
                     (SELECT DISTINCT source_id FROM NestedTargets)
                     UNION
                     (SELECT DISTINCT target_id FROM NestedTargets)
                 ) ORDER BY id ) nodes) AS node_table,

                 (SELECT array_agg({X_and_y.columns[-1]}) FROM 
                (SELECT id, {X_and_y.columns[-1]} FROM nodes
                 WHERE id IN (
                     (SELECT DISTINCT source_id FROM NestedTargets)
                     UNION
                     (SELECT DISTINCT target_id FROM NestedTargets)
                 ) ORDER BY id) nodes) AS label_table,
                 
                (SELECT array_agg(array[source_id, target_id]) FROM
                (SELECT DISTINCT source_id, target_id
                 FROM NestedTargets) edges) AS edge_table;
                """)    
            results = cursor.fetchall()[0]
            labels = np.array(results[1])
            subgraph_node_features = np.array(results[0])
            if results[0] is None:
                continue
            
            subgraph_edges = np.array(results[2]).transpose()
            
            node_ids = subgraph_node_features[:, 0]
            _, cols_source = np.nonzero((subgraph_edges[0] == node_ids[:, None]).transpose())
            _, cols_target = np.nonzero((subgraph_edges[1] == node_ids[:, None]).transpose())
            remapped_edge_index = np.concatenate([np.expand_dims(cols_source, axis = 0), np.expand_dims(cols_target, axis = 0)], axis = 0)
            # labels = subgraph_node_features[:, -1]
            features = subgraph_node_features[:, 1:]
            overall_run_time = time.time() - start 
            ## Testing
            # remapped_edge_index_test, features_test, labels_test, unique_node_ids = get_subgraph_from_in_mem_graph(X, y, seed_node_id, edge_index, hops)                    
            # assert (sort_edge_index(torch.from_numpy(remapped_edge_index_test)) == sort_edge_index(torch.from_numpy(remapped_edge_index))).sum() / (remapped_edge_index_test.shape[-1] * remapped_edge_index_test.shape[0]), "Edges doesnt match"
            # assert np.allclose(features, features_test), "features doe not match"
            # assert np.allclose(labels_test, labels), "Labels does not match"
            
            print(f"Fetched {remapped_edge_index.shape} edges, {labels.shape} labels, {features.shape} features in ({overall_run_time} s)")
        except Exception as e:
            conn.rollback()
            print(f"Error uploading edges CSV data: {e}")
            raise    

  0%|          | 0/44906 [00:00<?, ?it/s]

Fetched (2, 786) edges, (482, 121) labels, (482, 50) features in (0.05923295021057129 s)
Fetched (2, 379) edges, (296, 121) labels, (296, 50) features in (0.03044581413269043 s)
Fetched (2, 1634) edges, (701, 121) labels, (701, 50) features in (0.08704566955566406 s)
Fetched (2, 3647) edges, (1044, 121) labels, (1044, 50) features in (0.1405472755432129 s)
Fetched (2, 403) edges, (290, 121) labels, (290, 50) features in (0.02264118194580078 s)
Fetched (2, 2755) edges, (939, 121) labels, (939, 50) features in (0.09304380416870117 s)
Fetched (2, 196) edges, (185, 121) labels, (185, 50) features in (0.01415252685546875 s)
Fetched (2, 124) edges, (115, 121) labels, (115, 50) features in (0.00949239730834961 s)
Fetched (2, 1486) edges, (684, 121) labels, (684, 50) features in (0.06747627258300781 s)
Fetched (2, 517) edges, (330, 121) labels, (330, 50) features in (0.024547576904296875 s)
Fetched (2, 4476) edges, (1067, 121) labels, (1067, 50) features in (0.1332547664642334 s)
Fetched (2, 2

In [None]:
# # Update data
# update_query = "UPDATE example_table SET city = %s WHERE id = %s;"
# update_params = ("NewCity", 1)
# update_data(conn, update_query, update_params)

In [15]:
# Delete the created database
conn.close()
conn = connect_to_postgres()
delete_database(conn, new_db_name)
conn.close()

Connection successful.
Database 'X_y_ppi' deleted successfully.
