In [None]:
from neo4j import GraphDatabase, basic_auth
from easydict import EasyDict as edict
from dotenv import load_dotenv
import os

In [None]:
env_path = "env.txt"
load_dotenv(dotenv_path=env_path, override=True)
# def _strip_quotes(v):
#     if v is None:
#         return None
#     return v.strip().strip('"').strip("'")
# URI = _strip_quotes(os.getenv("NEO4J_URI"))
# USERNAME = _strip_quotes(os.getenv("NEO4J_USERNAME"))
# PASSWORD = _strip_quotes(os.getenv("NEO4J_PASSWORD"))
# AUTH = (USERNAME, PASSWORD)
# ITERATION_NUMBER = int(os.getenv("ITERATION_NUMBER") or 30)
# CONSTRAINT_DB = _strip_quotes(os.getenv("NEO4J_CONSTRAINT_DB"))
# INSTANCE_DB   = _strip_quotes(os.getenv("NEO4J_INSTANCE_DB"))

URI = "neo4j://127.0.0.1:7687"
AUTH = ("neo4j", "12345678")
ITERATION_NUMBER = 10
CONSTRAINT_DB = "test"
INSTANCE_DB   = "test-instance-graph"

# Relabeling

### Relabeling Cost Function
$$ \Delta_l (G^\prime, G)= \sum_{v \in V} \delta_l (\lambda(v), \lambda^\prime(v)) $$

## Greedy (Version 1)
- In each step, find the vertex $v$ with the most violations ($|T(v, \lambda(v))|$, Sect. 4.1).
- Choose a new label $\lambda^\prime(v)$ that minimizes the remaining violations $|T(v, \lambda^\prime(v))|$, so that each relabeling eliminates as many violations as possible.


In [None]:
AUTHORS_PATH = f"datasets\\temp\\authors_20260106-170423.txt"

with open(AUTHORS_PATH) as f:
    authors = f.read()

all_possible_authors = [line for line in authors.split("\n") if line.strip()]

In [None]:
def get_violation_set_of_vertex(driver, instance_db, constraint_db, vertex): # for version 1
    """Return the co-authors of `vertex` in the instance graph that the constraint graph does not allow.

    Parameters
    ----------
    driver : neo4j.Driver
        Open driver used to query both databases.
    instance_db : str
        Name of the (possibly dirty) instance database.
    constraint_db : str
        Name of the reference constraint database.
    vertex : str
        Author name, matched against the `name` property in both databases.

    Returns
    -------
    list[str]
        Names adjacent to `vertex` via CO_AUTHOR in `instance_db` but not in
        `constraint_db`; deduplicated, in no particular order.
    """
    # The name is sent as a query parameter instead of being spliced into the Cypher text,
    # so names containing quotes or backslashes cannot break (or inject into) the query.
    query = """
    MATCH (a {name: $name})-[:CO_AUTHOR]-(b)
    RETURN b.name AS name
    """
    constraint_co_authors = {record["name"] for record in driver.execute_query(query, name=vertex, database_=constraint_db).records}
    instance_co_authors = {record["name"] for record in driver.execute_query(query, name=vertex, database_=instance_db).records}
    return list(instance_co_authors - constraint_co_authors)

def get_potential_violations(driver, instance_db, constraint_db, vertex_in_instance, test_label):
    """
    Calculates how many violations a vertex WOULD have if its name was changed to test_label.

    Parameters
    ----------
    driver : neo4j.Driver
        Open driver used to query both databases.
    instance_db, constraint_db : str
        Names of the instance and constraint databases.
    vertex_in_instance : str
        Current name of the vertex in the instance graph.
    test_label : str
        Candidate new name.

    Returns
    -------
    int
        Number of the vertex's instance-graph CO_AUTHOR neighbors (counted per returned
        row, not deduplicated) that are not CO_AUTHOR neighbors of `test_label` in the
        constraint graph, excluding neighbors whose name equals `test_label`.
    """
    # Names are passed as parameters, never interpolated, so quotes in names are safe.
    query = """
    MATCH (a {name: $name})-[:CO_AUTHOR]-(b)
    RETURN b.name AS name
    """
    # 1. Current neighbors of this vertex in the dirty instance graph.
    neighbors = [record["name"] for record in driver.execute_query(query, name=vertex_in_instance, database_=instance_db).records]

    # 2. Neighbors the constraint graph allows for `test_label`. A set gives O(1) lookups
    #    in the count below.
    allowed_co_authors = {record["name"] for record in driver.execute_query(query, name=test_label, database_=constraint_db).records}

    # A neighbor carrying the same name as test_label is not counted (unless self-loops
    # should be penalized, which this model does not do).
    return sum(1 for n in neighbors if n not in allowed_co_authors and n != test_label)


def greedy_v1(driver, instance_db, constraint_db):
    """Greedy relabeling, version 1: repair the worst vertex first.

    Each of up to ITERATION_NUMBER rounds:
      1. finds the author in `all_possible_authors` with the largest violation set
         (the first such author wins ties);
      2. tries every author name as a candidate label and keeps the one with the fewest
         potential violations, stopping the scan as soon as one reaches zero;
      3. renames the vertex in `instance_db` if that candidate strictly improves on the
         current violation count.

    The loop ends early when no author has any violation, or when a round finds no
    improving label. In the latter case the database is unchanged, so every later round
    would repeat the same computation. Progress is printed; returns None.
    """
    if not all_possible_authors:
        print("No authors loaded; nothing to repair.")
        return

    # Parameterized so author names containing quotes cannot break the Cypher statement.
    # NOTE(review): this renames *every* node whose name matches, not a single vertex.
    update_query = """
    MATCH (a {name: $old_name})
    SET a.name = $new_name
    """

    for i in range(ITERATION_NUMBER):
        print(f"\n--- Greedy iteration {i+1} ---")

        # 1.-2. Vertex with the most violations.
        vertex_with_max_violations = None
        max_violations = -1
        for author in all_possible_authors:
            num_violations = len(get_violation_set_of_vertex(driver, instance_db, constraint_db, author))
            if num_violations > max_violations:
                max_violations = num_violations
                vertex_with_max_violations = author

        if max_violations == 0:
            print("No violations left.")
            break

        # 3. Best relabeling of that vertex: minimize T(v, l').
        min_new_violations = max_violations
        best_relabel = vertex_with_max_violations  # default: keep the current label
        for potential_name in all_possible_authors:
            new_v_count = get_potential_violations(driver, instance_db, constraint_db,
                                                   vertex_with_max_violations, potential_name)
            if new_v_count < min_new_violations:
                min_new_violations = new_v_count
                best_relabel = potential_name
            if min_new_violations == 0:
                break  # zero violations cannot be beaten

        # 4. Apply the relabeling with the fewest violations.
        if best_relabel == vertex_with_max_violations:
            print("No better label found that reduces violations.")
            break

        driver.execute_query(update_query, old_name=vertex_with_max_violations,
                             new_name=best_relabel, database_=instance_db)
        print(f"Applied repair: {vertex_with_max_violations} -> {best_relabel}")
        print(f"Violations reduced to: {min_new_violations}")

In [None]:
# with GraphDatabase.driver(URI, auth=AUTH) as driver:
#     driver.verify_connectivity()
#     greedy_v1(driver, INSTANCE_DB, CONSTRAINT_DB)

## Greedy (Version 2)
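- Over all vertices $v$ and candidate labels $l^\prime$, choose the relabeling that eliminates the most violations, i.e. that maximizes $|T(v, \lambda(v))| - |T(v, l^\prime)|$ (original violations minus those remaining after relabeling)
    - normalized by the relabeling cost $\delta_l(\lambda(v), l^\prime)$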


In [None]:
def choose_best_relabeling(driver, instance_db, constraint_db, all_possible_labels):
    """Scan every (vertex, label) pair for the relabeling with the best cost-normalized gain.

    A vertex is considered only if it currently has violations. A candidate label other
    than the vertex's own name is scored by
        (current violations - potential violations) / relabeling cost,
    and only candidates that strictly reduce the violation count are scored.

    Returns
    -------
    (best_action, best_score)
        best_action is (vertex_name, new_label), or None if no candidate reduces any
        vertex's violations (best_score is then -inf).
    """
    top_score = -float('inf')
    top_action = None

    for vertex_name in all_possible_labels:
        # |T(v, lambda(v))|
        violations_now = get_violation_set_of_vertex(driver, instance_db, constraint_db, vertex_name)
        print(f"Current violations for {vertex_name}: {violations_now}")
        before = len(violations_now)
        if not before:
            continue

        for candidate in all_possible_labels:
            if candidate == vertex_name:
                continue

            # delta_l; the zero guard keeps the division safe should the cost model change.
            cost = relabeling_cost(vertex_name, candidate)
            cost = 1 if cost == 0 else cost

            # |T(v, lambda'(v))|
            after = get_potential_violations(driver, instance_db, constraint_db, vertex_name, candidate)
            if after >= before:
                continue  # only strict reductions count as repairs

            score = (before - after) / cost
            if score > top_score:
                top_score, top_action = score, (vertex_name, candidate)

    return top_action, top_score


def relabeling_cost(vertex, new_label):
    """Unit relabeling cost delta_l: 0 if the label is unchanged, 1 for any change."""
    return 0 if new_label == vertex else 1


def greedy_v2(driver, instance_db, constraint_db):
    """
    Greedily selects the vertex relabeling that eliminates the most
    violations normalized by cost.

    Runs at most ITERATION_NUMBER rounds. Each round asks choose_best_relabeling for the
    best (vertex, label) action over `all_possible_authors` and renames that vertex in
    `instance_db`. It stops early when no action has a positive score.

    Returns the string "Repair complete.".
    """
    # Parameterized so author names containing quotes cannot break the Cypher statement.
    # NOTE(review): renames every node whose name matches target_vertex; if new_label is
    # already used by another node, the two become indistinguishable by name afterwards.
    update_query = """
    MATCH (a {name: $old_name})
    SET a.name = $new_name
    """

    for iteration in range(ITERATION_NUMBER):
        print(f"--- Greedy V2 Iteration {iteration + 1} ---")

        # Step 1: find the globally best repair action.
        best_action, score = choose_best_relabeling(driver, instance_db, constraint_db, all_possible_authors)

        # A positive score is required so every applied repair strictly reduces violations.
        if not best_action or score <= 0:
            print("No more beneficial repairs found.")
            break

        target_vertex, new_label = best_action

        # Step 2: apply the repair.
        driver.execute_query(update_query, old_name=target_vertex, new_name=new_label, database_=instance_db)
        print(f"Repaired {target_vertex} -> {new_label} with efficiency score {score:.2f}")

    return "Repair complete."

In [None]:
def get_all_violations(driver, instance_db, constraint_db):
    """
    Map each author in `all_possible_authors` that has violations to its violation set.

    Authors with an empty violation set are left out. The dict and its size are printed
    before being returned.
    """
    per_author = (
        (author, get_violation_set_of_vertex(driver, instance_db, constraint_db, author))
        for author in all_possible_authors
    )
    violations_dict = {author: vset for author, vset in per_author if vset}
    print(violations_dict, len(violations_dict))
    return violations_dict

In [None]:
# with GraphDatabase.driver(URI, auth=AUTH) as driver:
#     driver.verify_connectivity()
#     get_all_violations(driver, INSTANCE_DB, CONSTRAINT_DB)
#     greedy_v2(driver, INSTANCE_DB, CONSTRAINT_DB)
#     get_all_violations(driver, INSTANCE_DB, CONSTRAINT_DB)

## Contraction

In [None]:
# create super node and vertex class
class SuperNode:
    """A super node R of the contraction algorithm: a host vertex h(R) plus absorbed guests.

    Attributes (all set in __init__):
        label: label given at construction; changed only through set_label.
        host: host vertex h(R); must expose `.id` and `.get_label()`.
        guests: list of SuperNodes absorbed into this one.
        stored_cost: cost recorded via set_stored_cost (0 until then).
    """

    def __init__(self, label, host):
        self.label = label
        self.host = host
        self.guests = []
        self.stored_cost = 0

    def get_host(self):
        """Return the host vertex h(R)."""
        return self.host

    @property
    def id(self):
        """Identifier of the super node: the id of its host vertex."""
        return self.host.id

    def get_stored_cost(self):
        return self.stored_cost

    def set_stored_cost(self, new_cost):
        self.stored_cost = new_cost

    def get_guests(self):
        """Return the guest list itself (callers append to it)."""
        return self.guests

    def set_host(self, new_host):
        self.host = new_host

    def set_guests(self, new_guests):
        self.guests = new_guests

    def set_label(self, new_label):
        self.label = new_label

    def get_label(self):
        return self.label

    def get_all_vertices(self):
        """Formula: V(R) = {h(R)} U (Union of V(Ri) for all Ri in guests)

        Returns the unique vertices in first-seen order: the host first, then each guest's
        vertices depth-first. Order matters because callers iterate this list to issue
        database queries and to break ties.
        """
        verts = [self.host]
        for guest in self.guests:
            verts.extend(guest.get_all_vertices())
        # dict.fromkeys dedupes with the same hash/eq semantics as set(), but keeps insertion
        # order. list(set(...)) ordered vertices by object hash, which varies between runs.
        return list(dict.fromkeys(verts))

    def get_cost(self, candidate_label):
        """
        Implements Formula 7 from the paper.

        Cost = (number of vertices in V(R) whose current label differs from
        `candidate_label`, using the unit cost delta_l = 1 per change)
        - (sum of the stored costs of the direct guests, i.e. earlier internal contractions).
        """
        current_relabel_cost = sum(1 for v in self.get_all_vertices() if v.get_label() != candidate_label)
        previous_guests_costs = sum(guest.stored_cost for guest in self.get_guests())
        return current_relabel_cost - previous_guests_costs

class Vertex:
    """In-memory copy of one instance-graph node.

    Attributes:
        id: node identifier (the Neo4j elementId when built by create_all_vertices).
        label: current, possibly relabeled, author name.
        neighbors: neighbor identifiers, appended through add_neighbor.
    """

    def __init__(self, id, label):
        self.id, self.label = id, label
        self.neighbors = []

    def get_neighbors(self):
        """Return the neighbor list object itself (not a copy)."""
        return self.neighbors

    def add_neighbor(self, neighbor):
        """Append one neighbor identifier in place."""
        self.neighbors += [neighbor]

    def get_label(self):
        return self.label

    def set_label(self, new_label):
        self.label = new_label


In [None]:
def get_violation_set_of_vertex_contract(driver, instance_db, constraint_db, vertex):
    """Violation names for a Vertex object, looked up by its current label.

    Returns the CO_AUTHOR neighbor names that nodes named `vertex.label` have in
    `instance_db` but not in `constraint_db` (deduplicated, no particular order).
    The constraint database is queried first, then the instance database.
    """
    # The label is bound as the $name parameter rather than formatted into the text.
    cypher = """
    MATCH (a {name: $name})-[:CO_AUTHOR]-(b)
    RETURN b.name AS neighbor_name
    """

    def _neighbor_names(db):
        records = driver.execute_query(cypher, name=vertex.label, database_=db).records
        return {rec["neighbor_name"] for rec in records}

    allowed = _neighbor_names(constraint_db)
    observed = _neighbor_names(instance_db)
    # Violation: present in the instance but not permitted by the constraint graph.
    return list(observed - allowed)

def create_all_vertices(driver, instance_db):
    """Load every node of `instance_db` into a dict of in-memory Vertex objects.

    Returns dict[str, Vertex] keyed by Neo4j elementId. Each Vertex's label is the node's
    `name`, and its neighbors list gets one elementId per returned row with a CO_AUTHOR
    neighbor (nodes without such a relationship keep an empty list).
    """
    cypher = """
    MATCH (a)
    OPTIONAL MATCH (a)-[:CO_AUTHOR]-(b)
    RETURN elementId(a) AS v_id, a.name AS name, elementId(b) AS neighbor_id
    """

    vertex_map = {}
    for row in driver.execute_query(cypher, database_=instance_db).records:
        node_id = row["v_id"]
        vertex = vertex_map.get(node_id)
        if vertex is None:
            vertex = Vertex(node_id, row["name"])
            vertex_map[node_id] = vertex

        # OPTIONAL MATCH yields a null neighbor for nodes without CO_AUTHOR edges.
        if row["neighbor_id"]:
            vertex.add_neighbor(row["neighbor_id"])

    return vertex_map

# for each vertex, create its super node
def create_super_nodes(vertices):
    """Wrap every vertex in its own SuperNode (host = that vertex, no guests).

    Returns a dict with the same keys as `vertices`, each mapped to its new SuperNode,
    whose label is the vertex's label.
    """
    return {key: SuperNode(vertex.label, vertex) for key, vertex in vertices.items()}

def get_node_pair_most_violations(driver, instance_db, constraint_db, super_nodes, verbose=True):
    """Find the pair of super nodes connected by the most violations.

    For every pair (R1, R2) = (super_nodes[i], super_nodes[j]) with i < j, the count is
    taken over each vertex v of R1: the names in v's violation set
    (get_violation_set_of_vertex_contract) that equal the label of some vertex of R2.

    Parameters
    ----------
    super_nodes : list[SuperNode]
        Active super nodes, indexed by position.
    verbose : bool, default True
        Print each vertex's violation set (once per vertex) and every pair's count.

    Returns
    -------
    (SuperNode, SuperNode)
        The first pair reaching the maximum count. A pair is returned even if that count
        is 0; (None, None) is returned when fewer than two super nodes are given.
    """
    best_pair = (None, None)
    max_violations = -1

    # A vertex's violation set does not depend on the pair under inspection, so query it
    # once per vertex instead of once per pair (previously O(n^2) database round-trips).
    violation_cache = {}

    def violations_of(v):
        if v.id not in violation_cache:
            names = get_violation_set_of_vertex_contract(driver, instance_db, constraint_db, v)
            if verbose:
                print(f"Violation set for vertex {v.id}: {names}")
            violation_cache[v.id] = names
        return violation_cache[v.id]

    for i in range(len(super_nodes)):
        R1 = super_nodes[i]
        r1_vertices = R1.get_all_vertices()
        for j in range(i + 1, len(super_nodes)):
            R2 = super_nodes[j]
            # Labels of R2 are fixed for this pair; compute them once, not once per v in R1.
            r2_labels = {vert.label for vert in R2.get_all_vertices()}

            current_pair_violations = sum(
                1
                for v in r1_vertices
                for v_name in violations_of(v)
                if v_name in r2_labels
            )

            if verbose:
                print(f"Current pair ({R1.label}, {R2.label}) has {current_pair_violations} violations.")

            if current_pair_violations > max_violations:
                max_violations = current_pair_violations
                best_pair = (R1, R2)

    return best_pair

# def get_all_neighbors_in_instance(driver, instance_db, vertex):
#     query = f"""
#     MATCH (a {{name: "{vertex.label}"}})-[:CO_AUTHOR]-(b)
#     RETURN elementId(b) AS name
#     """
#     neighbors = [record["name"] for record in driver.execute_query(query, database_=instance_db).records]
#     return neighbors

def get_all_neighbors_in_instance(driver, instance_db, vertex, all_vertices_dict):
    """Return the Vertex objects adjacent to `vertex` via CO_AUTHOR in `instance_db`.

    The node is looked up by elementId (`vertex.id`) rather than by name, because names
    can be shared by several nodes. Each returned neighbor elementId is resolved through
    `all_vertices_dict`; an id missing from that dict raises KeyError.
    """
    by_element_id = """
    MATCH (a) WHERE elementId(a) = $v_id
    MATCH (a)-[:CO_AUTHOR]-(b)
    RETURN elementId(b) AS neighbor_id
    """
    records = driver.execute_query(by_element_id, v_id=vertex.id, database_=instance_db).records

    neighbor_objects = []
    for rec in records:
        neighbor_objects.append(all_vertices_dict[rec["neighbor_id"]])
    return neighbor_objects

def check_satisfaction(driver, constraint_db, l1, l2):
    """Return True if an edge between labels `l1` and `l2` is allowed by the constraint graph.

    Identical labels are always allowed (no database query). Otherwise the labels are
    compatible if some node named `l1` has a CO_AUTHOR relationship, in either direction,
    with some node named `l2` in `constraint_db`.
    """
    # 1. Simple equality check first (saves a database hit).
    if l1 == l2:
        return True

    # 2. Search for the edge directly. This replaces two separate MATCH clauses (a cartesian
    #    product when names repeat, of which only the first row was inspected) plus the
    #    legacy EXISTS(pattern) function. Any matching path means "satisfied"; LIMIT 1 stops
    #    at the first hit.
    query = """
    MATCH (a {name: $l1})-[:CO_AUTHOR]-(b {name: $l2})
    RETURN true AS is_satisfied
    LIMIT 1
    """

    result = driver.execute_query(
        query,
        l1=l1,
        l2=l2,
        database_=constraint_db
    )

    return bool(result.records)

def get_candidate_label(driver, instance_db, constraint_db, R_to_repair, R_target, all_vertices_dict):
    """Choose the label l' for all vertices of `R_to_repair` that maximizes the gain (Formula 6).

    Candidate labels are the CO_AUTHOR neighbors of the target host's label h(R_target) in
    `constraint_db`, followed by that host label itself. Equal labels always pass
    check_satisfaction, and including the host label guarantees a candidate even when it
    has no constraint-graph neighbors.

    For each candidate, gain = |T(R2)| - |T(R2, l')|, where
      * |T(R2)| is the summed size of the current violation sets of R_to_repair's vertices;
      * |T(R2, l')| counts instance-graph neighbors of those vertices whose label fails
        check_satisfaction against l'.

    Returns the candidate with the largest gain (the earliest wins ties; the host label is
    tried last, so it wins only when strictly better), or None if there is no candidate.
    """
    # h(R1) is the host vertex of the target node.
    host_label = R_target.get_host().get_label()

    # 1. Candidate labels L' from the constraint graph (name passed as a parameter).
    query_candidates = """
    MATCH (l1 {name: $name})-[:CO_AUTHOR]-(l2)
    RETURN DISTINCT l2.name AS label
    """
    possible_labels = [r["label"] for r in driver.execute_query(query_candidates, name=host_label, database_=constraint_db).records]
    if host_label is not None and host_label not in possible_labels:
        possible_labels.append(host_label)

    # Neither quantity below depends on the candidate, so compute each once instead of
    # once per candidate.
    repair_vertices = R_to_repair.get_all_vertices()
    neighbor_labels = [
        neighbor.get_label()
        for v in repair_vertices
        for neighbor in get_all_neighbors_in_instance(driver, instance_db, v, all_vertices_dict)
    ]
    current_violations_count = sum(
        len(get_violation_set_of_vertex_contract(driver, instance_db, constraint_db, v))
        for v in repair_vertices
    )

    # check_satisfaction hits the database; memoize per (candidate, neighbor label).
    satisfied = {}

    best_label = None
    max_gain = -float('inf')

    # 2. Evaluate each candidate l' (Formula 6).
    for l_prime in possible_labels:
        new_violations_count = 0
        for n_label in neighbor_labels:
            key = (l_prime, n_label)
            if key not in satisfied:
                satisfied[key] = check_satisfaction(driver, constraint_db, l_prime, n_label)
            if not satisfied[key]:
                new_violations_count += 1

        # Gain = |T(R2)| - |T(R2, l')|
        gain = current_violations_count - new_violations_count
        if gain > max_gain:
            max_gain = gain
            best_label = l_prime

    return best_label

In [None]:
def count_total_inter_node_violations(driver, instance_db, constraint_db, super_nodes_list):
    """Count violating edges between distinct super nodes.

    For every pair (R1, R2) where R1 comes before R2 in `super_nodes_list`, each vertex v
    of R1 contributes one for every vertex of R2 whose label appears in v's violation set
    (get_violation_set_of_vertex_contract). Violations within a single super node are not
    counted. Returns the total as an int.
    """
    # A vertex's violation set is independent of the pair being examined, so query it once
    # per vertex instead of once per (vertex, later super node) combination.
    violations_by_vertex = {}

    def violation_names_of(v):
        if v.id not in violations_by_vertex:
            violations_by_vertex[v.id] = set(
                get_violation_set_of_vertex_contract(driver, instance_db, constraint_db, v)
            )
        return violations_by_vertex[v.id]

    members = [R.get_all_vertices() for R in super_nodes_list]

    total_violations = 0
    for i, r1_vertices in enumerate(members):
        # Only later super nodes are compared, so each unordered pair is visited once.
        for r2_vertices in members[i + 1:]:
            for v in r1_vertices:
                names = violation_names_of(v)
                total_violations += sum(1 for v_target in r2_vertices if v_target.label in names)

    return total_violations

In [None]:
def contract(driver, instance_db, constraint_db, numm_iterations):
    """Contraction-based repair on an in-memory copy of the instance graph.

    Starts with one SuperNode per instance vertex. Each of up to `numm_iterations` rounds:
      1. picks the pair (R1, R2) with the most violations;
      2. computes a candidate label for each side;
      3. relabels every vertex of the cheaper side (per SuperNode.get_cost);
      4. absorbs that side as a guest of the other.

    Only in-memory Vertex/SuperNode objects are modified; neither database is written.

    Parameters
    ----------
    numm_iterations : int
        Maximum number of contraction steps; 0 returns the uncontracted partition.

    Returns
    -------
    (vertices, super_nodes)
        All Vertex objects keyed by elementId, and the super nodes still active (not
        absorbed as guests), keyed by their host's id.
    """
    # Step 1: create all vertices from the instance graph.
    vertices = create_all_vertices(driver, instance_db)
    # Step 2: create one super node per vertex.
    super_nodes = create_super_nodes(vertices)

    for i in range(numm_iterations):
        print(f"\n--- Contraction iteration {i+1} ---")
        # TODO: loop until G satisfies the constraints instead of a fixed iteration count.
        # NOTE(review): relabelings happen only in memory, but the violation helpers look up
        # neighbors in the databases by a vertex's (possibly new) label, not by its identity.
        # Confirm this is intended.

        if len(super_nodes) < 2:
            print("Fewer than two super nodes left; nothing to contract.")
            break

        r_1, r_2 = get_node_pair_most_violations(driver, instance_db, constraint_db, list(super_nodes.values()))
        l_1 = get_candidate_label(driver, instance_db, constraint_db, r_1, r_2, vertices)
        l_2 = get_candidate_label(driver, instance_db, constraint_db, r_2, r_1, vertices)
        print(f"Chosen pair: R1 label={r_1.get_label()}, R2 label={r_2.get_label()}")
        print(f"Candidate labels: l1={l_1}, l2={l_2}")

        if l_1 is None and l_2 is None:
            print("No candidate label for either super node; stopping.")
            break

        # A side without a candidate must never be relabeled (to None): give it infinite cost.
        cost_r1 = r_1.get_cost(l_1) if l_1 is not None else float("inf")
        cost_r2 = r_2.get_cost(l_2) if l_2 is not None else float("inf")

        # Relabel the cheaper side (ties relabel r_2). After this, r_2 is the absorbed side.
        if cost_r2 > cost_r1:
            r_1, r_2 = r_2, r_1  # Swap
            chosen_label = l_1
            applied_cost = cost_r1
        else:
            chosen_label = l_2
            applied_cost = cost_r2

        r_2.set_stored_cost(applied_cost)
        for v in r_2.get_all_vertices():
            v.set_label(chosen_label)
        r_2.set_label(chosen_label)  # keep the super node's label in sync with its vertices

        r_1.get_guests().append(r_2)
        # r_2 now lives only as a guest of r_1; drop it from the active set.
        super_nodes.pop(r_2.id, None)

    return vertices, super_nodes

In [None]:
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    driver.verify_connectivity()

    # Baseline: zero iterations leave one super node per vertex, so this counts the
    # violations before any contraction.
    vertices, supernodes = contract(driver, INSTANCE_DB, CONSTRAINT_DB, 0)
    origin_num_violations = count_total_inter_node_violations(
        driver, INSTANCE_DB, CONSTRAINT_DB, list(supernodes.values())
    )

    # Repair run: contract() reloads the vertices from the database and then performs up
    # to ITERATION_NUMBER contraction steps in memory.
    vertices, supernodes = contract(driver, INSTANCE_DB, CONSTRAINT_DB, ITERATION_NUMBER)
    num_violations = count_total_inter_node_violations(
        driver, INSTANCE_DB, CONSTRAINT_DB, list(supernodes.values())
    )

In [None]:
# Violations before contraction, then after.
print(origin_num_violations, num_violations, sep="\n")

In [None]:
# Fresh, uncontracted snapshot of the instance graph, for comparison with the result above.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    driver.verify_connectivity()
    old_vertices = create_all_vertices(driver, INSTANCE_DB)
    # One singleton super node per vertex (no guests, stored cost 0).
    old_supernodes = create_super_nodes(old_vertices)

In [None]:
def print_supernodes(nodes):
    """Print label, host, guests, member vertices (labels and ids) and stored cost of each super node."""
    for supernode in nodes:
        # Fetch the members once so the "All Vertices" and "All VerticesID" lines list
        # the same vertices in the same order.
        members = supernode.get_all_vertices()
        print(f"SuperNode Label: {supernode.get_label()}")
        print(f" Host: {supernode.get_host().label}")
        print(f" Guests: {[guest.get_host().label for guest in supernode.get_guests()]}")
        print(f" All Vertices: {[v.label for v in members]}")
        print(f" All VerticesID: {[v.id for v in members]}")
        print(f" Stored Cost: {supernode.get_stored_cost()}")
        print("-----")


# Contracted result first, then the untouched snapshot.
print_supernodes(supernodes.values())
print("---------------------------------")
print_supernodes(old_supernodes.values())

In [None]:
for vertex in vertices.values():
    # Neighbors are stored as ids; resolve each through `vertices` to display its label.
    neighbor_labels = list(map(lambda n_id: vertices[n_id].label, vertex.get_neighbors()))
    print(f"  Current Label: {vertex.label}; Neighbor Labels: {neighbor_labels}")