In [None]:
import os
from getpass import getpass

import numpy as np
import pandas as pd
from neo4j import GraphDatabase

In [None]:
# Read back the combined dataset saved earlier.
csv_file_path = "data/20newsgroups_full.csv"
data = pd.read_csv(csv_file_path)

# Every later cell groups and samples by this column, so fail fast if it is missing.
assert 'target' in data.columns, f"{csv_file_path} has no 'target' column"

# Display only the first few rows instead of dumping the whole frame.
data.head()

In [None]:
# Split the frame by class (target label) so each subset can be drawn per class.
classes = data['target'].unique()
class_data = {label: data[data['target'] == label] for label in classes}

# Report how many samples each class has.
for label in classes:
    print(f"Class {label} has {class_data[label].shape[0]} samples.")

num_classes = len(classes)
if num_classes == 0:
    raise ValueError("The dataset has no rows, so no class-balanced subsets can be drawn.")


def sample_per_class(class_frames, n_per_class, random_state=42):
    """Draw up to ``n_per_class`` rows from every class and concatenate them.

    Parameters
    ----------
    class_frames : dict
        Mapping of class label -> DataFrame holding only that class's rows.
    n_per_class : int
        Number of rows to draw from each class.
    random_state : int
        Seed forwarded to ``DataFrame.sample`` so the subsets are reproducible.

    Returns
    -------
    pandas.DataFrame
        The sampled rows of all classes, in the dict's iteration order.

    ``DataFrame.sample`` raises ``ValueError`` when asked for more rows than a
    class has (sampling is without replacement). Such a class is instead taken
    in full and a message is printed, so the subset can be slightly smaller
    than ``n_per_class * len(class_frames)``.
    """
    parts = []
    for label, frame in class_frames.items():
        if len(frame) < n_per_class:
            print(f"Class {label} has only {len(frame)} samples; "
                  f"taking all of them instead of {n_per_class}.")
        parts.append(frame.sample(min(n_per_class, len(frame)), random_state=random_state))
    return pd.concat(parts)


# Subset with (about) 8000 samples, balanced across classes.
samples_per_class_8000 = 8000 // num_classes
subdataset_8000 = sample_per_class(class_data, samples_per_class_8000)

# Subset with (about) 14000 samples, balanced across classes.
samples_per_class_14000 = 14000 // num_classes
subdataset_14000 = sample_per_class(class_data, samples_per_class_14000)

# The full dataset is the entire dataset read from the CSV.
subdataset_full = data

# Display the class distributions for each subdataset.
print("Subdataset 8000 samples class distribution:\n", subdataset_8000['target'].value_counts())
print("Subdataset 14000 samples class distribution:\n", subdataset_14000['target'].value_counts())
print("Full dataset class distribution:\n", subdataset_full['target'].value_counts())

In [None]:
# Before running this cell, create an empty database inside a Neo4j Desktop
# project; its name must match `database_name` below.
# Credentials come from environment variables (or an interactive prompt) instead
# of being written into the notebook, where they would leak whenever it is shared.
uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
username = os.environ.get("NEO4J_USER", "neo4j")  # "neo4j" unless changed when installing Neo4j Desktop
password = os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: ")
database_name = "d7.newsgroups"  # Database name

driver = GraphDatabase.driver(uri, auth=(username, password))

In [None]:
def check_connection():
    """Send a trivial query to ``database_name`` through ``driver`` and print the reply.

    Both ``driver`` and ``database_name`` are read from the notebook's globals.
    Any exception (connection refused, bad credentials, unknown database, ...)
    is caught and printed instead of raised, so this sanity check never aborts
    a Run-All.
    """
    probe_query = "RETURN 'Connection to database successful' AS message"
    try:
        with driver.session(database=database_name) as sess:
            for rec in sess.run(probe_query):
                print(rec["message"])
    except Exception as err:
        print("Error connecting to the database:", err)


# Verify connectivity before loading any data.
check_connection()

In [None]:
# Function to create nodes in the specified database with a dynamic label
def create_nodes(data, driver, label, database=None, batch_size=1000):
    """Create one node labelled ``label`` for every row of ``data``.

    Each column becomes a node property (column name -> cell value). Rows are
    sent in batches through ``UNWIND``, so a batch costs one query instead of
    one round-trip per row.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows to load; column names are used as property keys as-is.
    driver :
        Neo4j driver (from ``GraphDatabase.driver``). It is deliberately NOT
        closed here, so the same driver can be reused for several calls;
        the caller closes it once all loads are done.
    label : str
        Node label. It is backtick-quoted in the query, so unusual characters
        cannot break the Cypher statement.
    database : str, optional
        Target database name; defaults to the notebook-level ``database_name``.
    batch_size : int, optional
        Number of rows sent per query (default 1000). Must be >= 1.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.

    Errors raised while talking to the database are printed, not raised
    (best-effort, as before). Batches sent before the error may already be
    stored.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Quote the label in backticks (doubling any embedded backtick) so it is
    # always parsed as a single label name.
    safe_label = "`" + str(label).replace("`", "``") + "`"
    # `SET n = row` copies every key of the row map onto the node, so column
    # names never have to be spliced into the query text.
    query = f"UNWIND $rows AS row CREATE (n:{safe_label}) SET n = row"

    # One dict per row: {column name: value}.
    records = data.to_dict(orient="records")
    try:
        target_db = database_name if database is None else database
        with driver.session(database=target_db) as session:
            for start in range(0, len(records), batch_size):
                session.run(query, rows=records[start:start + batch_size])
    except Exception as e:
        print("Error during node creation:", e)

In [None]:
# Load each subset under its own node label: the 8000-sample subset, the
# 14000-sample subset and the full dataset.
# NOTE(review): the "_madelon" suffix on all labels and the "Dataset3" prefix on
# the full-dataset label look inconsistent with this newsgroups data and the
# "d7.newsgroups" database. They are kept unchanged in case other code queries
# these labels. Confirm, then rename.
subsets_by_label = {
    'Dataset7_8000_madelon': subdataset_8000,
    'Dataset7_14000_madelon': subdataset_14000,
    'Dataset3_full_madelon': subdataset_full,
}

try:
    for node_label, subset in subsets_by_label.items():
        create_nodes(subset, driver, node_label)
finally:
    # All loads share one driver, so close it exactly once, after the last load.
    driver.close()
