In [None]:
import os

from easydict import EasyDict as edict
from neo4j import GraphDatabase, basic_auth

In [None]:
# Hyperparameters read by the greedy repair routines in this notebook.
# fraud_number caps how many greedy repair iterations are executed.
hypp = edict({"fraud_number": 20})

# Relabeling

### Relabeling Cost Function
$$ \Delta_l (G^\prime, G)= \sum_{v \in V} \delta_l (\lambda(v), \lambda^\prime(v)) $$

## Greedy (Version 1)
- In each step, find the vertex with the most violations ($T(v, l)$, Sect. 4.1).
- Find a new label $\lambda^\prime(v)$ that eliminates violations (aim to eliminate as many violations as possible with each relabeling).


In [None]:
# Text file that is read as one candidate author label per line.
# Forward slashes keep this relative path valid on Windows and POSIX alike
# (a backslash-separated path only resolves on Windows).
AUTHORS_PATH = "datasets/temp/authors_20260106-170423.txt"

with open(AUTHORS_PATH) as f:
    authors = f.read()

# Keep every non-blank line; entries are stored exactly as written in the file
# (surrounding whitespace is only used to detect blank lines, not removed).
all_possible_authors = [line for line in authors.split("\n") if line.strip()]

In [None]:
def get_violation_set_of_vertex(driver, instance_db, constraint_db, vertex): # for version 1
    """
    Return the co-authors a vertex has in the instance graph but not in the constraint graph.

    Args:
        driver: Neo4j driver (any object exposing a compatible ``execute_query``).
        instance_db: Name of the database that holds the instance graph.
        constraint_db: Name of the database that holds the constraint graph.
        vertex: Value of the ``name`` property that identifies the vertex.

    Returns:
        list: De-duplicated names of CO_AUTHOR neighbors found in ``instance_db``
        that are not CO_AUTHOR neighbors in ``constraint_db``, in the order the
        instance query returned them.
    """
    # The name travels as a query parameter instead of being spliced into the
    # Cypher text, so names containing quotes or backslashes cannot break or
    # alter the query.
    query = """
    MATCH (a {name: $name})-[:CO_AUTHOR]-(b)
    RETURN b.name AS name
    """
    params = {"name": vertex}
    constraint_co_authors = {
        record["name"]
        for record in driver.execute_query(query, parameters_=params, database_=constraint_db).records
    }
    instance_co_authors = [
        record["name"]
        for record in driver.execute_query(query, parameters_=params, database_=instance_db).records
    ]
    # dict.fromkeys removes duplicates while keeping first-seen order, which
    # makes the returned list deterministic for a given query result.
    return [name for name in dict.fromkeys(instance_co_authors) if name not in constraint_co_authors]

def get_potential_violations(driver, instance_db, constraint_db, vertex_in_instance, test_label):
    """
    Calculates how many violations a vertex WOULD have if its name was changed to test_label.

    Args:
        driver: Neo4j driver (any object exposing a compatible ``execute_query``).
        instance_db: Name of the database that holds the instance graph.
        constraint_db: Name of the database that holds the constraint graph.
        vertex_in_instance: Current ``name`` of the vertex in the instance graph.
        test_label: Candidate name to evaluate for that vertex.

    Returns:
        int: Number of instance-graph neighbors of ``vertex_in_instance`` that are
        neither CO_AUTHOR neighbors of ``test_label`` in the constraint graph nor
        equal to ``test_label`` itself. Neighbors returned more than once by the
        instance query are counted each time.
    """
    # Names are passed as query parameters rather than interpolated into the
    # Cypher text, so quotes or backslashes in a name cannot break the query.
    query = """
    MATCH (a {name: $name})-[:CO_AUTHOR]-(b)
    RETURN b.name AS name
    """

    # 1. Actual neighbors of the vertex in the (dirty) instance graph.
    neighbors = [
        record["name"]
        for record in driver.execute_query(
            query, parameters_={"name": vertex_in_instance}, database_=instance_db
        ).records
    ]

    # 2. Co-authors the constraint graph allows for the candidate label
    #    (a set, so each membership test below is constant time).
    allowed_co_authors = {
        record["name"]
        for record in driver.execute_query(
            query, parameters_={"name": test_label}, database_=constraint_db
        ).records
    }

    # A neighbor that carries the candidate label itself is not counted
    # (it would otherwise look like a self-loop violation).
    return sum(1 for n in neighbors if n not in allowed_co_authors and n != test_label)


def greedy_v1(driver, instance_db, constraint_db):
    """
    Greedy repair, version 1: repeatedly relabel the vertex with the most violations.

    Each iteration scans the module-level ``all_possible_authors`` for the vertex
    with the largest violation set, then tries every entry of that list as a new
    label and applies the one that leaves the fewest violations. At most
    ``hypp.fraud_number`` iterations run; the loop stops earlier as soon as an
    iteration cannot improve anything, because the graph is then unchanged and
    every further iteration would repeat the same queries with the same result.

    Args:
        driver: Neo4j driver used for all queries and for the update.
        instance_db: Name of the database holding the instance graph (modified in place).
        constraint_db: Name of the database holding the constraint graph (read only).

    Returns:
        None. Progress is reported with ``print``.
    """
    # Names are sent as parameters, never interpolated into the Cypher text, so
    # quotes or backslashes in a name cannot break or alter the update.
    update_query = """
    MATCH (a {name: $old_name})
    SET a.name = $new_name
    """

    for i in range(hypp.fraud_number):
        print(f"\n--- Greedy iteration {i+1} ---")

        # 1. Pick the vertex with the most violations (first one wins on ties).
        vertex_with_max_violations = None
        max_violations = -1
        for author in all_possible_authors:
            num_violations = len(
                get_violation_set_of_vertex(driver, instance_db, constraint_db, author)
            )
            if num_violations > max_violations:
                max_violations = num_violations
                vertex_with_max_violations = author

        # -1 means the author list is empty, 0 means no vertex has a violation:
        # in both cases no relabeling can reduce anything.
        if max_violations <= 0:
            print("No better label found that reduces violations.")
            break

        # 2. Search for the label that leaves the fewest violations, T(v, l').
        # NOTE(review): candidates are not checked against labels already used in
        # the instance graph; a relabel may therefore produce two vertices with
        # the same name - confirm this is intended.
        min_new_violations = max_violations
        best_relabel = vertex_with_max_violations  # Default to current
        for potential_name in all_possible_authors:
            new_v_count = get_potential_violations(
                driver, instance_db, constraint_db,
                vertex_with_max_violations, potential_name,
            )
            if new_v_count < min_new_violations:
                min_new_violations = new_v_count
                best_relabel = potential_name
                # A label with zero violations cannot be beaten; stop searching.
                if min_new_violations == 0:
                    break

        # 3. Apply the relabeling, or stop when nothing improves.
        if best_relabel == vertex_with_max_violations:
            print("No better label found that reduces violations.")
            break

        driver.execute_query(
            update_query,
            parameters_={"old_name": vertex_with_max_violations, "new_name": best_relabel},
            database_=instance_db,
        )
        print(f"Applied repair: {vertex_with_max_violations} -> {best_relabel}")
        print(f"Violations reduced to: {min_new_violations}")

In [None]:
# URI = "neo4j://127.0.0.1:7687"
# AUTH = ("neo4j", "12345678")

# with GraphDatabase.driver(URI, auth=AUTH) as driver:
#     driver.verify_connectivity()
#     greedy_v1(driver, "test-instance-graph", "test")

## Greedy (Version 2)
- Choose the vertex relabeling that eliminates the most violations: over all vertices and labels, maximize (number of original violations - number remaining after relabeling).
    - The gain is additionally normalized, i.e. divided by the relabeling cost.


In [None]:
def choose_best_relabeling(driver, instance_db, constraint_db, all_possible_labels):
    """
    Search every (vertex, label) pair for the relabeling with the best gain per unit cost.

    Args:
        driver: Neo4j driver forwarded to the violation helpers.
        instance_db: Name of the database holding the instance graph.
        constraint_db: Name of the database holding the constraint graph.
        all_possible_labels: Names used both as the vertices to inspect and as
            the candidate labels to try.

    Returns:
        tuple: ``(best_action, best_score)`` where ``best_action`` is
        ``(vertex_name, new_label)``. When no relabeling lowers a vertex's
        violation count the result is ``(None, -inf)``.
    """
    top_score = -float('inf')
    top_action = None

    for vertex in all_possible_labels:
        # |T(v, lambda(v))|: violations under the current label.
        existing = len(get_violation_set_of_vertex(driver, instance_db, constraint_db, vertex))
        if existing == 0:
            # Nothing to repair for this vertex.
            continue

        for candidate in all_possible_labels:
            if candidate == vertex:
                continue

            # Relabeling cost delta_l; a zero cost is bumped to 1 so the
            # division below is always defined.
            price = relabeling_cost(vertex, candidate)
            if price == 0:
                price = 1

            # |T(v, lambda'(v))|: violations that would remain after relabeling.
            remaining = get_potential_violations(driver, instance_db, constraint_db, vertex, candidate)

            gain = existing - remaining
            if gain <= 0:
                # Only relabelings that strictly reduce violations are scored.
                continue

            efficiency = gain / price
            if efficiency > top_score:
                top_score = efficiency
                top_action = (vertex, candidate)

    return top_action, top_score


def relabeling_cost(vertex, new_label):
    """Unit relabeling cost: 0 when the label stays the same, 1 otherwise."""
    return 0 if new_label == vertex else 1


def greedy_v2(driver, instance_db, constraint_db):
    """
    Greedily selects the vertex relabeling that eliminates the most 
    violations normalized by cost.

    Runs at most ``hypp.fraud_number`` iterations and stops earlier when
    ``choose_best_relabeling`` finds no relabeling with a positive score.
    Candidate vertices and labels come from the module-level
    ``all_possible_authors``.

    Args:
        driver: Neo4j driver used for all queries and for the update.
        instance_db: Name of the database holding the instance graph (modified in place).
        constraint_db: Name of the database holding the constraint graph (read only).

    Returns:
        str: The fixed message ``"Repair complete."``.
    """
    # Names are sent as parameters, never interpolated into the Cypher text, so
    # quotes or backslashes in a name cannot break or alter the update.
    update_query = """
    MATCH (a {name: $old_name})
    SET a.name = $new_name
    """

    # The iteration cap guarantees termination even if repairs keep being found.
    for iteration in range(hypp.fraud_number):
        print(f"--- Greedy V2 Iteration {iteration + 1} ---")

        # Step 1: Find the globally best repair action for the current graph.
        best_action, score = choose_best_relabeling(driver, instance_db, constraint_db, all_possible_authors)

        if not best_action or score <= 0:
            print("No more beneficial repairs found.")
            break

        target_vertex, new_label = best_action

        # Step 2: Apply the repair.
        driver.execute_query(
            update_query,
            parameters_={"old_name": target_vertex, "new_name": new_label},
            database_=instance_db,
        )
        print(f"Repaired {target_vertex} -> {new_label} with efficiency score {score:.2f}")

    return "Repair complete."

In [None]:
def get_all_violations(driver, instance_db, constraint_db):
    """
    Collect the non-empty violation sets of all authors in ``all_possible_authors``.

    Prints the resulting mapping together with its size, then returns it.
    Authors without any violation are left out of the mapping.
    """
    # Lazily pair each author with its violation set, in list order.
    per_author = (
        (name, get_violation_set_of_vertex(driver, instance_db, constraint_db, name))
        for name in all_possible_authors
    )
    violations_dict = {name: found for name, found in per_author if found}
    print(violations_dict, len(violations_dict))
    return violations_dict

In [None]:
# Connection settings can be overridden through environment variables, so real
# credentials do not have to be written into the notebook. The fallbacks are
# the local development values this notebook has used so far.
URI = os.environ.get("NEO4J_URI", "neo4j://127.0.0.1:7687")
AUTH = (
    os.environ.get("NEO4J_USER", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "12345678"),
)

# Database names: the instance graph is repaired in place, the constraint
# graph is only read.
INSTANCE_DB = "test-instance-graph"
CONSTRAINT_DB = "test"

with GraphDatabase.driver(URI, auth=AUTH) as driver:
    driver.verify_connectivity()
    # Report violations before and after the repair to make its effect visible.
    get_all_violations(driver, INSTANCE_DB, CONSTRAINT_DB)
    greedy_v2(driver, INSTANCE_DB, CONSTRAINT_DB)
    get_all_violations(driver, INSTANCE_DB, CONSTRAINT_DB)