# Compare Relationships across all labels between 2 versions

## Importing

In [None]:
import os
import logging
import configparser
import pandas as pd
from neo4j import GraphDatabase

## Get Current Path

In [None]:
# Locate the folder used for the config file and the log file. The notebook
# name is a relative path, so realpath() resolves it against the current
# working directory before the file-name component is stripped off.
current_path = os.path.split(os.path.realpath("relationships.ipynb"))[0]

## Read Config File 

In [None]:
# Load connection settings and the output location from config.ini, which is
# looked up in the folder held by current_path.
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so an empty list means config.ini was not
# loaded. Fail here with a clear message instead of a bare KeyError below.
if not config.read(f"{current_path}/config.ini"):
    raise FileNotFoundError(
        f"Could not read config file: {current_path}/config.ini"
    )

# Every setting is read from the [DEFAULT] section. The "old" and "new" graphs
# have their own URI and password; a single username is read for both.
neo4j_uri_old = config['DEFAULT']['Neo4j-Uri-old']
neo4j_uri_new = config['DEFAULT']['Neo4j-Uri-new']
username = config['DEFAULT']['Neo4j-Username']
password_old = config['DEFAULT']['Neo4j-Password-old']
password_new = config['DEFAULT']['Neo4j-Password-new']
# Folder that the final CSV report is written to.
output_folder = config['DEFAULT']['Output-Folder']

## Set Up Logging

In [None]:
# Route INFO-and-above records to a log file under current_path, formatted as
# "timestamp | level | message".
logging.basicConfig(
    format='%(asctime)s | %(levelname)s | %(message)s',
    level=logging.INFO,
    filename=f"{current_path}/relationship_diff_all.log",
)

# Console banner marking the start of the comparison run.
print("Comparing relationship differences by start label")

## Helper Functions to Query Relationships 

In [None]:
def run_query(driver, query):
    """Run ``query`` in a fresh session on ``driver`` and tabulate the rows.

    Args:
        driver: Object whose ``session()`` returns a context manager that
            yields a session exposing ``run(query)``.
        query: Query text passed unchanged to ``session.run``.

    Returns:
        A ``pandas.DataFrame`` with one row per returned record, each record
        converted with ``dict()``. A query that returns no records produces
        an empty DataFrame without any columns.
    """
    with driver.session() as session:
        # Materialise every record while the session is still open.
        rows = [dict(record) for record in session.run(query)]
    return pd.DataFrame(rows)

def get_graph_version(driver):
    """Return the ``version`` property of one ``:version`` node in the graph.

    Args:
        driver: Driver handed straight to ``run_query``.

    Returns:
        The ``version`` value of the first returned row, or the string
        ``"unknown"`` when the lookup fails for any reason (query error, no
        row returned, ...). Failures are logged as warnings, never raised.
    """
    version_query = "MATCH (n:version) RETURN n.version AS version LIMIT 1"
    try:
        first_row = run_query(driver, version_query).iloc[0]
        version = first_row['version']
    except Exception as e:
        # Best effort: the version is only used to name columns, so fall back
        # to a placeholder rather than aborting the comparison.
        logging.warning(f"Couldn't fetch version info: {e}")
        return "unknown"
    return version

def get_relationships_by_label(driver, label):
    """Count outgoing relationships of ``label`` nodes, grouped by their shape.

    Rows are grouped by the start node's ``source`` property, the relationship
    type, the target node's full label list and the target node's ``source``
    property.

    Args:
        driver: Driver handed straight to ``run_query``.
        label: Node label to start from. It is interpolated into the query
            text (both as a back-quoted label and as a string literal), so it
            must come from a trusted, fixed list.

    Returns:
        A ``pandas.DataFrame`` with the columns ``start_node_type``,
        ``start_node_source``, ``relationship_type``, ``target_node_labels``,
        ``target_node_source`` and ``count``. When the label has no outgoing
        relationships the frame is empty but still carries these columns.
    """
    result_columns = [
        "start_node_type", "start_node_source", "relationship_type",
        "target_node_labels", "target_node_source", "count",
    ]
    query = f"""
    MATCH (a:`{label}`)-[r]->(b)
    RETURN 
        '{label}' AS start_node_type,
        a.source AS start_node_source,
        type(r) AS relationship_type,
        labels(b) AS target_node_labels,
        b.source AS target_node_source,
        count(*) AS count
    """
    df = run_query(driver, query)
    if df.empty:
        # An aggregation with grouping keys returns zero rows when nothing
        # matches, and run_query then yields a frame with no columns at all.
        # Hand back the expected schema so column lookups and merges on this
        # result still work and the label shows up as "all added/removed".
        return pd.DataFrame(columns=result_columns).astype({"count": "int64"})
    return df

# Node labels that are compared; also the only values (besides "unknown")
# that a target node's label list is reduced to.
core_labels = ["researcher", "publication", "dataset", "grant", "organisation"]


def resolve_target_node_type(df):
    """Replace ``target_node_labels`` with a single ``target_node_type``.

    Each label list is reduced to the first entry of ``core_labels`` (in
    ``core_labels`` order) that it contains, or ``"unknown"`` when it contains
    none of them.

    Different label lists can reduce to the same type, which would leave
    several rows with identical key columns. When a ``count`` column is
    present those rows are merged and their counts summed, so every
    combination of the remaining columns appears at most once. Without this,
    a later join on those columns pairs every duplicate on one side with
    every duplicate on the other.

    Args:
        df: DataFrame with a ``target_node_labels`` column holding label
            lists. It is not modified.

    Returns:
        A new DataFrame without ``target_node_labels`` and with
        ``target_node_type`` appended as the last column.
    """
    def match_label(label_list):
        for lbl in core_labels:
            if lbl in label_list:
                return lbl
        return "unknown"

    # drop() returns a new frame, so the caller's frame stays untouched.
    resolved = df.drop(columns=["target_node_labels"])
    resolved["target_node_type"] = df["target_node_labels"].apply(match_label)

    if resolved.empty or "count" not in resolved.columns:
        return resolved

    key_cols = [col for col in resolved.columns if col != "count"]
    # dropna=False keeps rows whose source property is null; they would
    # otherwise be silently discarded by the grouping.
    collapsed = resolved.groupby(
        key_cols, as_index=False, sort=False, dropna=False
    )["count"].sum()
    # Restore the column order the un-aggregated frame had.
    return collapsed[list(resolved.columns)]

## Compare Relationships for All Labels

In [None]:
# Labels to compare
labels = core_labels
all_diffs = []

# Columns that identify one kind of relationship; the old and new results are
# joined on exactly these.
merge_cols = [
    "start_node_type", "start_node_source",
    "relationship_type", "target_node_type", "target_node_source"
]

# Connect to both graphs
driver_old = GraphDatabase.driver(neo4j_uri_old, auth=(username, password_old))
driver_new = GraphDatabase.driver(neo4j_uri_new, auth=(username, password_new))

try:
    # Get version names dynamically
    version_old = get_graph_version(driver_old)
    version_new = get_graph_version(driver_new)

    # The count columns are named after the graph versions. When both graphs
    # report the same version (including the "unknown" fallback) those names
    # would collide, so an explicit old/new marker is appended in that case.
    if version_old == version_new:
        count_old_col = f"count_{version_old}_old"
        count_new_col = f"count_{version_new}_new"
    else:
        count_old_col = f"count_{version_old}"
        count_new_col = f"count_{version_new}"
    count_cols = [count_old_col, count_new_col]
    ordered_cols = merge_cols + count_cols + ["diff"]

    for label in labels:
        try:
            df_old = resolve_target_node_type(
                get_relationships_by_label(driver_old, label)
            )
            df_new = resolve_target_node_type(
                get_relationships_by_label(driver_new, label)
            )

            # Outer join keeps relationship kinds that exist in only one graph.
            merged_df = pd.merge(
                df_old.rename(columns={"count": count_old_col}),
                df_new.rename(columns={"count": count_new_col}),
                on=merge_cols,
                how="outer",
            )

            # A kind missing from one graph has a count of 0 there. Only the
            # count columns are filled, so null source values in the key
            # columns are not turned into 0.
            merged_df[count_cols] = (
                merged_df[count_cols].fillna(0).astype("int64")
            )
            merged_df["diff"] = merged_df[count_new_col] - merged_df[count_old_col]

            all_diffs.append(merged_df[ordered_cols])
            logging.info(f"✅ Compared relationships for: {label}")

        except Exception as e:
            # One failing label must not stop the others; the traceback is
            # logged so the failure can be diagnosed afterwards.
            logging.exception(f"❌ Failed comparing relationships for {label}: {e}")
finally:
    # Release both connections even if something unexpected escaped above.
    driver_old.close()
    driver_new.close()

if all_diffs:
    # Within each start type and relationship type, largest increase first.
    final_relationship_diff = (
        pd.concat(all_diffs)
        .sort_values(
            by=["start_node_type", "relationship_type", "diff"],
            ascending=[True, True, False],
        )
        .reset_index(drop=True)
    )
else:
    print("⚠️ No relationship differences to report.")
    final_relationship_diff = pd.DataFrame()
# Last expression of the cell: shows the result in the notebook.
final_relationship_diff

## Save to CSV

In [None]:
# Destination of the combined report inside the configured output folder.
output_file = f"{output_folder}/relationship_diff_all_labels.csv"

try:
    # Create the output folder if it does not exist yet; otherwise to_csv
    # fails on a missing directory. Done inside the try so that a failure is
    # reported the same way as a failed write.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    final_relationship_diff.to_csv(output_file, index=False)
    logging.info(f"Saved relationship diff to {output_file}")
    print(f"Saved relationship diff to {output_file}")
except Exception as e:
    # Best effort: report the problem in the log and on the console instead
    # of raising.
    logging.error(f"Failed to save CSV: {e}")
    print(f"Error writing file: {e}")