# In [None]:
import random
import itertools
import time
import torch
import torch.nn as nn
import torch.optim as optim
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.nn import HypergraphConv
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score
from collections import defaultdict
import copy

################################################
# Hypergraph Data Generation and FD Manipulation
################################################

def generate_acyclic_hypergraph(num_attributes,
                                num_fds,
                                max_edge_size,
                                max_rhs_size,
                                max_attempts=10_000):
    """
    Generate a random acyclic set of FDs.

    Repeatedly samples a candidate FD X -> Y over the attributes "A0".."A{n-1}"
    and keeps it only if adding every edge x -> y (x in X, y in Y) to a running
    attribute dependency graph leaves that graph acyclic.

    :param num_attributes: number of attributes; they are named f"A{i}".
    :param num_fds: target number of distinct LHS keys in the result.
    :param max_edge_size: maximum |X| + |Y| of a sampled FD; clamped to
        num_attributes. If fewer than 2 attributes are available no FD can be
        formed and {} is returned.
    :param max_rhs_size: maximum |Y| of a sampled FD.
    :param max_attempts: sampling budget. When it runs out, the result may hold
        fewer than num_fds LHS keys.
    :return: dict mapping a sorted LHS tuple to a set of RHS attributes.
    """
    attributes = [f"A{i}" for i in range(num_attributes)]
    max_edge_size = min(max_edge_size, num_attributes)
    if max_edge_size < 2:
        return {}  # an FD needs at least one LHS and one RHS attribute

    # Shuffling only changes which attributes random.sample draws; it does not
    # impose an orientation on the generated FDs.
    random.shuffle(attributes)

    dependency_graph = nx.DiGraph()
    dependency_graph.add_nodes_from(attributes)

    hypergraph = {}  # dict: { tuple_of_LHS : set_of_RHS_attributes }
    attempts = 0

    while len(hypergraph) < num_fds and attempts < max_attempts:
        attempts += 1

        # Randomly pick how many attributes take part in the candidate FD,
        # then split them into a non-empty LHS and a non-empty RHS.
        edge_size = random.randint(2, max_edge_size)
        subset = random.sample(attributes, edge_size)
        lhs_size = random.randint(1, edge_size - 1)
        if edge_size - lhs_size > max_rhs_size:
            continue

        X = tuple(sorted(subset[:lhs_size]))
        Y = set(subset[lhs_size:])

        # The graph is acyclic before this step, so adding the x -> y edges
        # creates a cycle iff some y already reaches some x. Checking first
        # avoids tentatively adding and then reverting edges. That revert
        # could delete edges owned by earlier FDs.
        reachable_from_Y = set().union(
            *(nx.descendants(dependency_graph, y) for y in Y)
        )
        if reachable_from_Y.intersection(X):
            continue

        dependency_graph.add_edges_from((x, y) for x in X for y in Y)
        hypergraph.setdefault(X, set()).update(Y)

    return hypergraph

def compute_closure(fd_set, attributes):
    """
    Return the attribute closure of `attributes` under `fd_set`.

    :param fd_set: dict mapping an LHS (iterable of attributes) to a set of
        RHS attributes.
    :param attributes: iterable of starting attributes.
    :return: a new set holding every attribute derivable from `attributes`.
    """
    result = set(attributes)
    while True:
        grew = False
        for determinant, dependents in fd_set.items():
            # Fire an FD once its whole LHS is known and it still adds something.
            if set(determinant).issubset(result) and not dependents.issubset(result):
                result.update(dependents)
                grew = True
        if not grew:
            return result

def decompose_fds_to_singleton(fd_set):
    """
    Normalise an FD dict into grouped singleton form.

    Each X -> {A, B, ...} is logically the singleton FDs X -> A, X -> B, ...;
    because the dict is keyed by LHS, those singletons are stored grouped under
    the sorted LHS tuple. LHS keys that are permutations of each other are
    merged, and entries with an empty RHS are dropped.

    :return: a new dict {sorted LHS tuple: new set of RHS attributes}; the
        input is not modified.
    """
    grouped = {}
    for determinant, dependents in fd_set.items():
        key = tuple(sorted(determinant))
        for attr in dependents:
            grouped.setdefault(key, set()).add(attr)
    return grouped

def reduce_fd_left_side(fd_set):
    """
    Remove extraneous attributes from the LHS of each FD (standard minimal-cover step).

    Attribute A of X is extraneous for X -> b when b lies in the closure of
    X - {A} computed under the *current* FD set (which still contains X -> b).
    Because several RHS attributes are grouped under one LHS key, the test is
    applied per RHS attribute. The derivable ones move to the reduced LHS and
    the rest stay on X. Each move keeps the FD set equivalent and strictly
    shrinks some LHS, so the loop terminates.

    :param fd_set: dict {sorted LHS tuple: set of RHS attributes}; mutated in place.
    :return: the same dict object, left-reduced.
    """
    changed = True
    while changed:
        changed = False
        for lhs in list(fd_set):
            if lhs not in fd_set:
                continue
            rhs_set = fd_set[lhs]
            for attr in lhs:
                reduced = tuple(sorted(set(lhs) - {attr}))
                if not reduced:
                    continue  # never reduce an LHS to the empty set

                # Test against the FD set *before* any modification.
                # Inserting reduced -> rhs first would make the test trivially true.
                movable = rhs_set & compute_closure(fd_set, reduced)
                if not movable:
                    continue

                remaining = rhs_set - movable
                if remaining:
                    fd_set[lhs] = remaining
                else:
                    del fd_set[lhs]
                # Merge into (never overwrite) an existing entry for the reduced LHS.
                fd_set.setdefault(reduced, set()).update(movable)
                changed = True
                break  # this LHS changed; re-examine it on the next pass
    return fd_set

def remove_redundant_fds(fd_set):
    """
    Drop singleton FDs that are implied by the remaining ones (minimal-cover step).

    Each LHS -> attribute pair is tentatively removed. If the attribute is
    still in the closure of the LHS, the pair stays removed; otherwise it is
    put back. LHS keys whose RHS becomes empty are deleted.

    :param fd_set: dict {LHS tuple: set of RHS attributes}; mutated in place.
    :return: the same dict object.
    """
    for determinant, dependents in list(fd_set.items()):
        for target in list(dependents):
            # Tentatively drop determinant -> target (via a fresh set, not in place).
            without_target = set(fd_set.get(determinant, []))
            without_target.remove(target)
            if without_target:
                fd_set[determinant] = without_target
            else:
                del fd_set[determinant]

            # Still derivable from what is left? Then it was redundant: keep it out.
            if target in compute_closure(fd_set, determinant):
                continue
            fd_set.setdefault(determinant, set()).add(target)
    return fd_set

def minimal_cover(fd_set):
    """
    Compute a minimal cover of `fd_set`.

    The three steps are:
      1. decompose_fds_to_singleton: normalise LHS order and group singletons;
         it builds a fresh dict, so the caller's FD set is left untouched.
      2. reduce_fd_left_side: strip extraneous LHS attributes.
      3. remove_redundant_fds: drop implied FDs.
    """
    cover = decompose_fds_to_singleton(fd_set)
    for step in (reduce_fd_left_side, remove_redundant_fds):
        cover = step(cover)
    return cover

def is_cover_minimal(fd_set):
    """
    Check whether `fd_set` is already a minimal cover.

    The set counts as minimal when running minimal_cover on it changes
    nothing. That means both must contain exactly the same singleton FDs
    (sorted LHS tuple, one RHS attribute). Comparing only the number of LHS
    keys would accept covers that still have extraneous LHS attributes or
    redundant RHS attributes.
    """
    def singletons(fds):
        return {(tuple(sorted(lhs)), attr) for lhs, rhs in fds.items() for attr in rhs}

    return singletons(fd_set) == singletons(minimal_cover(fd_set))

def is_cyclic_fd_graph(fd_set):
    """
    Return True if the attribute graph of `fd_set` contains a directed cycle.

    The graph has an edge x -> y for every x in an LHS and every y in that
    LHS's RHS. An empty FD set yields an empty (acyclic) graph.
    """
    graph = nx.DiGraph()
    graph.add_edges_from(
        (src, dst) for lhs, rhs in fd_set.items() for dst in rhs for src in lhs
    )
    return not nx.is_directed_acyclic_graph(graph)

def has_same_closure(fd_set1, fd_set2, attributes):
    """
    Return True if `attributes` has the same closure under both FD sets.
    """
    return compute_closure(fd_set1, attributes) == compute_closure(fd_set2, attributes)

def compute_3NF_decomposition(fd_set, attributes):
    """
    Build a 3NF synthesis-style decomposition from an FD set.

    One relation schema (LHS | RHS) is produced per FD entry. If no schema
    contains the candidate key returned by find_candidate_key, the key itself
    is appended as an extra schema.

    :return: list of attribute sets.
    """
    schemas = [set(lhs) | rhs for lhs, rhs in fd_set.items()]

    key = find_candidate_key(fd_set, attributes)
    key = key if isinstance(key, set) else set(key)

    if all(not key.issubset(schema) for schema in schemas):
        schemas.append(key)
    return schemas

def find_candidate_key(fd_set, attributes):
    """
    Find a candidate key (a minimal superkey) for a relation under `fd_set`.

    The search starts from the first LHS whose closure covers every attribute
    (if any), otherwise from the full attribute set. It then greedily drops
    each attribute whose removal still leaves a set that determines all
    attributes. By monotonicity of closure, the result is minimal: no single
    attribute can be removed from it.

    :param fd_set: the (preferably minimal) cover of functional dependencies.
    :param attributes: the full set of attributes in the relation.
    :return: a candidate key as a set of attributes (empty if `attributes` is empty).
    """
    universe = set(attributes)

    # A covering LHS is usually far smaller than the full attribute set,
    # so shrinking it is cheaper.
    key = set(universe)
    for lhs in fd_set:
        if compute_closure(fd_set, lhs) >= universe:
            key = set(lhs)
            break

    # Iterate in a fixed order so the chosen key does not depend on
    # (hash-randomised) set iteration order.
    for attr in sorted(key, key=str):
        trial = key - {attr}
        if compute_closure(fd_set, trial) >= universe:
            key = trial
    return key

################################################
# FASTR Implementation
################################################

def compute_fastr_closure(fd_set, attributes):
    """
    Return the attribute closure of `attributes` under `fd_set`.

    This is a plain fixpoint closure, not partition-based. Each FD is fired
    once its LHS is fully known and is then dropped from the worklist.
    Scanning stops when a full pass fires nothing.

    :param fd_set: dict mapping an LHS (iterable of attributes) to a set of RHS attributes.
    :param attributes: iterable of starting attributes.
    :return: a new set of all derivable attributes.
    """
    derived = set(attributes)
    unfired = list(fd_set.items())
    while True:
        waiting = []
        for lhs, rhs in unfired:
            if derived.issuperset(lhs):
                derived.update(rhs)
            else:
                waiting.append((lhs, rhs))
        if len(waiting) == len(unfired):
            return derived
        unfired = waiting

def refine_partitions(partitions, X, Y):
    """
    Fold the partitions of all attributes in X into the partition of every y in Y.

    No dependency check is made here; the caller decides when X -> Y holds.
    The sets in `partitions` for Y's attributes are updated in place; those
    for X are left unchanged. Returns None.
    """
    source = set().union(*(partitions[attr] for attr in X))
    for y in Y:
        partitions[y].update(source)

def fastr_minimal_cover(fd_set):
    """
    Remove right-redundant FD attributes.

    X -> y is dropped when y is still in the closure of X computed from the
    remaining FDs. Redundancy is judged against the progressively reduced set,
    not the original one. Otherwise two FDs that each imply the other (e.g.
    A -> C and B -> C given A <-> B) would both be removed, changing the closure.
    Attributes are visited in sorted order so the result is reproducible.

    :param fd_set: dict {LHS tuple: set of RHS attributes}; not modified.
    :return: defaultdict(set) containing only the retained X -> y pairs.
    """
    # Shallow per-entry copies are enough: keys are tuples, values are sets.
    working = {lhs: set(rhs) for lhs, rhs in fd_set.items()}
    min_cover = defaultdict(set)

    for lhs in list(working):
        for y in sorted(working[lhs], key=str):
            working[lhs].discard(y)
            if y not in compute_fastr_closure(working, lhs):
                # Not derivable without it: necessary, so restore and keep it.
                working[lhs].add(y)
                min_cover[lhs].add(y)

    return min_cover

def run_FASTR(hypergraph):
    """
    Run the FASTR pipeline on one FD hypergraph and return its reduced cover.

    For every FD X -> Y of `hypergraph` it does three things:
      * expand X to everything it determines (closure of X minus X itself);
      * pass X (sorted) and that expansion to refine_partitions;
      * accumulate X -> expansion under the sorted LHS tuple.
    The accumulated FDs are then passed through fastr_minimal_cover.

    NOTE(review): the `partitions` mapping is updated but never read
    afterwards, so it has no effect on the returned cover (it only costs time).

    :param hypergraph: dict mapping an LHS tuple to a set of RHS attributes.
    :return: the result of fastr_minimal_cover on the expanded FDs.
    """
    every_attribute = {
        attr for lhs, rhs in hypergraph.items() for attr in lhs + tuple(rhs)
    }
    partitions = {attr: {attr} for attr in every_attribute}

    expanded = defaultdict(set)
    for lhs in hypergraph:
        key = tuple(sorted(lhs))
        implied = compute_fastr_closure(hypergraph, lhs) - set(key)
        refine_partitions(partitions, key, implied)
        expanded[key] |= implied

    return fastr_minimal_cover(expanded)

################################################
# Create Dataset with Labels
################################################

def encode_hypergraph_torch(hypergraph, num_attributes):
    """
    Convert an FD hypergraph into a (node, hyperedge) incidence index tensor.

    Each FD X -> Y becomes one hyperedge (numbered in dict order) whose members
    are the attributes of X followed by those of Y. Attributes are numbered
    0..k-1 in sorted order (by str), and RHS members are emitted sorted. The
    same hypergraph therefore always gets the same encoding, independent of
    Python's per-process string-hash randomisation of set order.

    :param hypergraph: dict mapping an LHS tuple to a collection of RHS attributes.
    :param num_attributes: number of node feature rows the caller provides;
        node indices must stay below it.
    :return: LongTensor of shape (2, num_incidences). Row 0 holds node indices
        and row 1 holds hyperedge indices. The shape is (2, 0) for an empty
        hypergraph.
    :raises ValueError: if the hypergraph has more than `num_attributes`
        distinct attributes.
    """
    distinct = {attr for lhs, rhs in hypergraph.items() for attr in lhs + tuple(rhs)}
    attribute_to_index = {
        attr: idx for idx, attr in enumerate(sorted(distinct, key=str))
    }
    if len(attribute_to_index) > num_attributes:
        raise ValueError(
            f"hypergraph has {len(attribute_to_index)} distinct attributes, "
            f"more than num_attributes={num_attributes}"
        )

    incidences = []
    for hyperedge_id, (lhs, rhs) in enumerate(hypergraph.items()):
        for attr in lhs:
            incidences.append((attribute_to_index[attr], hyperedge_id))
        for attr in sorted(rhs, key=str):
            incidences.append((attribute_to_index[attr], hyperedge_id))

    if not incidences:
        # torch.tensor([]) would be 1-D; keep the (2, E) layout even when E == 0.
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(incidences, dtype=torch.long).t().contiguous()

################################################
# HGNN Model and Loss Functions
################################################

class HNN(nn.Module):
    """
    Two-layer hypergraph convolutional network.

    conv1 maps in_channels -> hidden_channels (followed by ReLU); conv2 maps
    hidden_channels -> out_channels. The two history lists are plain
    bookkeeping containers that training/validation code may append to.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = HypergraphConv(in_channels, hidden_channels)
        self.conv2 = HypergraphConv(hidden_channels, out_channels)
        # Per-epoch losses recorded by the training loop.
        self.loss_history = []
        self.validation_loss_history = []

    def forward(self, x, edge_index):
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

def compute_join_cost_loss(decomposed_relations, original_table_size):
    """
    Heuristic join-cost penalty for a decomposition.

    For every unordered pair of relations sharing at least one attribute it
    adds original_table_size * selectivity * |R_i| * |R_j|, where
    selectivity = min(0.1, 1 / |shared attributes|). The total is divided by
    original_table_size.

    :param decomposed_relations: sequence of attribute sets.
    :param original_table_size: normaliser; must be non-zero.
    """
    pair_costs = []
    for left, right in itertools.combinations(decomposed_relations, 2):
        shared = left.intersection(right)
        if not shared:
            continue
        selectivity = min(0.1, 1.0 / len(shared))  # Approximate selectivity
        surviving = original_table_size * selectivity
        pair_costs.append(surviving * (len(left) * len(right)))
    return sum(pair_costs) / original_table_size  # Normalize by total tuples

def left_reduction_loss(fd_set):
    """
    Count extraneous LHS attributes across all FDs.

    Attribute A of X counts once when X - {A} is non-empty and its closure
    under `fd_set` already contains the whole RHS of X. The reduced LHS does
    not need to be a key of `fd_set`. Requiring that would miss most
    extraneous attributes (e.g. B in ABC -> D with A -> B).
    """
    loss = 0
    for lhs, rhs_set in fd_set.items():
        for attr in lhs:
            reduced = set(lhs) - {attr}
            if reduced and compute_closure(fd_set, reduced) >= rhs_set:
                loss += 1  # Penalize redundant LHS attributes
    return loss

def right_reduction_loss(fd_set):
    """
    Count RHS attributes beyond the first in each FD.

    A minimal cover has singleton right-hand sides, so every FD contributes
    max(|RHS| - 1, 0).
    """
    return sum(max(len(rhs) - 1, 0) for rhs in fd_set.values())

def minimal_cover_consistency_loss(predicted_FDs, original_FDs, attributes):
    """
    Penalise closure disagreement between a predicted and an original FD set.

    Returns the size of the symmetric difference between the closures of
    `attributes` under the two sets (0 when they agree).
    """
    mismatch = compute_closure(predicted_FDs, attributes) ^ compute_closure(original_FDs, attributes)
    return len(mismatch)

def recall_boost_loss(predicted_FDs, ground_truth_FDs, attributes):
    """
    Count attributes derivable from `attributes` under the ground truth but not
    under the prediction (i.e. missed dependencies). Always >= 0.
    """
    missing = compute_closure(ground_truth_FDs, attributes) - compute_closure(predicted_FDs, attributes)
    return len(missing)

def total_loss(predicted_FDs, original_FDs, decomposed_relations, original_table_size, lambda_join=1):
    """
    Compute total loss for HNN optimization, combining redundancy penalties,
    closure consistency, recall and a weighted join-cost term.

    Closure-based terms are evaluated on the set of attributes appearing in
    the original LHSs. The LHS tuples themselves must not be used as the
    starting "attributes": every closure would then be a set of tuples, and
    both terms would be identically 0.

    NOTE(review): the result is a fresh leaf tensor built from Python numbers.
    It has no autograd path to any model parameters, so calling .backward()
    on it cannot train a model. requires_grad=True is kept only for
    compatibility with existing callers.

    :param lambda_join: weight of the join-cost term.
    :return: 0-dim float32 tensor.
    """
    lhs_attributes = {attr for lhs in original_FDs for attr in lhs}

    left_loss = left_reduction_loss(predicted_FDs)
    right_loss = right_reduction_loss(predicted_FDs)
    minimality_loss = minimal_cover_consistency_loss(predicted_FDs, original_FDs, lhs_attributes)
    recall_loss = recall_boost_loss(predicted_FDs, original_FDs, lhs_attributes)
    join_loss = compute_join_cost_loss(decomposed_relations, original_table_size)

    loss = left_loss + right_loss + minimality_loss + lambda_join * join_loss + recall_loss
    return torch.tensor(loss, dtype=torch.float32, requires_grad=True)

def train_hnn(model, train_hypergraphs, optimizer, criterion, num_attributes, epochs=100):
    """
    Train `model` on a list of FD hypergraphs.

    Each hypergraph becomes one sample with identity node features
    (torch.eye(num_attributes)) and its incidence index. Samples are built
    once up front, because encoding does not change between epochs. The
    average loss per epoch is appended to model.loss_history and printed
    every 10 epochs.

    NOTE(review): the target is a constant all-ones (num_attributes, 1) tensor,
    so the model is not trained against minimal-cover labels. Confirm this is
    intended.

    :raises ValueError: if train_hypergraphs is empty (the average would divide by zero).
    """
    if not train_hypergraphs:
        raise ValueError("train_hypergraphs must contain at least one hypergraph.")

    samples = [
        Data(
            x=torch.eye(num_attributes),
            edge_index=encode_hypergraph_torch(graph, num_attributes),
            y=torch.ones((num_attributes, 1)),
        )
        for graph in train_hypergraphs
    ]

    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for train_data in samples:
            optimizer.zero_grad()
            output = model(train_data.x, train_data.edge_index)
            loss = criterion(output, train_data.y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(samples)
        model.loss_history.append(avg_loss)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Avg Train Loss: {avg_loss:.10f}")

################################################
# Evaluation
################################################

def closure_equivalent(fd_set_pred, fd_set_true, attributes):
    """
    Check whether two FD sets are equivalent (imply exactly the same FDs).

    F and G are equivalent iff every LHS of F has the same closure under F
    and G, and likewise for every LHS of G. Checking only one side's LHSs would
    miss extra dependencies that exist only in the other set. The true set's
    LHSs are checked first, and the check stops at the first mismatch.

    `attributes` is accepted for API compatibility and is not used.
    """
    return all(
        has_same_closure(fd_set_pred, fd_set_true, lhs)
        for lhs in itertools.chain(fd_set_true, fd_set_pred)
    )

def count_redundant_dependencies(fd_set):
    """
    Count redundant dependencies: extraneous LHS attributes plus extra RHS attributes.

    Left: attribute A of X counts once if X - {A} is non-empty and its closure
    under `fd_set` contains the whole RHS of X. The reduced LHS does not need
    to be a key of `fd_set`.
    Right: every RHS attribute beyond the first counts once.
    """
    redundant_count = 0
    for lhs, rhs_set in fd_set.items():
        for attr in lhs:
            reduced = set(lhs) - {attr}
            if reduced and compute_closure(fd_set, reduced) >= rhs_set:
                redundant_count += 1

        if len(rhs_set) > 1:
            redundant_count += len(rhs_set) - 1  # Each extra attribute in RHS is redundant
    return redundant_count

def breakdown_redundancy_per_FD(fd_set):
    """
    Return a per-FD breakdown of left vs. right redundancy.

    :return: {lhs: {"left_redundant": [...], "right_redundant": [...]}}, keyed by
        the original LHS tuple. Left-redundant attributes use the same test as
        count_redundant_dependencies. For the right side, the smallest RHS
        attribute (by str) is kept and the others are listed, so the report is
        deterministic.
    """
    redundancy_report = {}
    for lhs, rhs_set in fd_set.items():
        left_redundancy = []
        for attr in lhs:
            reduced = set(lhs) - {attr}
            if reduced and compute_closure(fd_set, reduced) >= rhs_set:
                left_redundancy.append(attr)

        right_redundancy = sorted(rhs_set, key=str)[1:]

        # Key by the (hashable) LHS tuple; a list key would raise TypeError.
        redundancy_report[lhs] = {
            "left_redundant": left_redundancy,
            "right_redundant": right_redundancy
        }
    return redundancy_report

def hamming_loss(predicted_FD_sets, ground_truth_FD_sets, attributes):
    """
    Average number of missing or extra FD entries per hypergraph.

    Each FD set is viewed as a set of (sorted LHS tuple, sorted RHS tuple)
    entries. Per pair it counts entries present in exactly one side. The
    total is divided by len(predicted_FD_sets). Pairs are matched by position
    via zip, so surplus items in the longer list are ignored.

    `attributes` is accepted for API compatibility and is not used.
    :return: float average, or 0.0 when there are no predictions.
    """
    def as_entries(fds):
        return {(tuple(sorted(X)), tuple(sorted(Y))) for X, Y in fds.items()}

    running = 0
    for predicted_FDs, ground_truth_FDs in zip(predicted_FD_sets, ground_truth_FD_sets):
        # Extra entries (false positives) + missing entries (false negatives).
        running += len(as_entries(predicted_FDs) ^ as_entries(ground_truth_FDs))

    num_samples = len(predicted_FD_sets)
    return running / num_samples if num_samples > 0 else 0.0

def compute_join_cost(decomposed_relations, original_table_size):
    """
    Compute a join-cost estimate for a 3NF decomposition.

    Sums |R_i| * |R_j| / original_table_size over every unordered pair of
    relations (only the relation sizes are used).

    :param decomposed_relations: sequence of relation schemas (attribute collections).
    :param original_table_size: normaliser; must be non-zero when there are two or more relations.
    :return: join cost (0 when there are fewer than two relations).
    """
    widths = [len(relation) for relation in decomposed_relations]
    return sum(
        (a * b) / original_table_size
        for a, b in itertools.combinations(widths, 2)
    )

def evaluate_HNN(model, test_hypergraphs, ground_truth_FD_sets, num_attributes):
    """
    Evaluate HNN model, apply FASTR post-processing to its predictions, and compare results.

    Three FD-set families are built and scored against `ground_truth_FD_sets`
    (matched by position):
      1. HNN:         FD sets decoded from thresholded model outputs.
      2. FASTR:       run_FASTR applied to each original test hypergraph.
      3. HNN + FASTR: run_FASTR applied to each HNN-decoded FD set.
    For each family it prints:
      * Hamming loss and average join cost of a 3NF decomposition;
      * redundancy counts, non-minimal-cover and closure-mismatch counts;
      * average change in the number of LHS keys;
      * average wall-clock time.
    Finally it plots `model.loss_history` on a log scale.

    :param model: trained model, called as model(x, edge_index).
    :param test_hypergraphs: non-empty list of FD dicts {LHS tuple: RHS set}.
    :param ground_truth_FD_sets: reference FD dicts, one per test hypergraph.
    :param num_attributes: number of node feature rows (identity features).
    :raises ValueError: if test_hypergraphs is empty or its first item is not a dict.
    """
    model.eval()

    if not test_hypergraphs or not isinstance(test_hypergraphs[0], dict):
        raise ValueError("Error: test_hypergraphs is empty or not properly formatted as dictionaries.")

    # NOTE(review): extract_attributes is not defined in the visible source;
    # presumably it lives elsewhere. It is indexed below with integer ids
    # taken from the encoded edge index. Confirm its ordering matches the
    # node numbering used by encode_hypergraph_torch.
    attributes = extract_attributes(test_hypergraphs)

    hnn_time = 0
    hnn_avg_len_change = 0
    predicted_FD_sets = []

    # HNN Evaluation
    for test_hypergraph in test_hypergraphs:
        # "Length" here is the number of LHS keys in the FD dict.
        start_len = len(test_hypergraph)
        start_hnn = time.time()
        # Timed section covers encoding, Data construction and the forward pass.
        with torch.no_grad():
            test_edge_index = encode_hypergraph_torch(test_hypergraph, num_attributes)
            test_data = Data(x=torch.eye(num_attributes), edge_index=test_edge_index, y=torch.ones((num_attributes, 1)))
            test_output = model(test_data.x, test_data.edge_index)
        end_hnn = time.time()

        hnn_time += (end_hnn - start_hnn)

        predicted_FDs = {}  # New predicted FD set
        # NOTE(review): if the model was trained with BCEWithLogitsLoss (as in
        # train_and_evaluate), its outputs are logits, so 0.5 is not a
        # probability-0.5 cut-off.
        threshold = 0.5  # Use a threshold for binarization

        # NOTE(review): test_output presumably has one row per node (x is
        # torch.eye(num_attributes)), yet `idx` is used as a column index into
        # test_edge_index. Row 1 of that index holds hyperedge ids, not
        # attribute ids, so the decoded (lhs, rhs) pairs do not obviously
        # correspond to FDs. Confirm the intended decoding.
        # score.item() assumes each output row holds a single value (out_channels == 1).
        for idx, score in enumerate(test_output):
            if score.item() > threshold:  # Only keep significant dependencies
                lhs_idx = test_edge_index[0, idx].item()
                rhs_idx = test_edge_index[1, idx].item()
                lhs_attr = attributes[lhs_idx]
                rhs_attr = attributes[rhs_idx]

                if (lhs_attr,) not in predicted_FDs:
                    predicted_FDs[(lhs_attr,)] = set()
                predicted_FDs[(lhs_attr,)].add(rhs_attr)

        predicted_FD_sets.append(predicted_FDs)
        hnn_avg_len_change += start_len - len(predicted_FDs)

    hnn_avg_len_change /= len(test_hypergraphs)
    hnn_time /= len(test_hypergraphs)

    # FASTR Evaluation on Original Test Set
    fastr_time = 0
    fastr_avg_len_change = 0
    fastr_FD_sets = []

    for test_hypergraph in test_hypergraphs:
        start_len = len(test_hypergraph)
        start_fastr = time.time()
        fastr_FDs = run_FASTR(test_hypergraph)
        end_fastr = time.time()

        fastr_time += (end_fastr - start_fastr)
        fastr_FD_sets.append(fastr_FDs)

        fastr_avg_len_change += start_len - len(fastr_FDs)

    fastr_avg_len_change /= len(test_hypergraphs)
    fastr_time /= len(test_hypergraphs)

    # Post-processing: Apply FASTR to HNN Predictions
    hnn_fastr_time = 0
    hnn_fastr_FD_sets = []
    hnn_fastr_avg_len_change = 0

    for hnn_pred in predicted_FD_sets:
        # start_len is not used below; len(hnn_pred) is recomputed instead (same value).
        start_len = len(hnn_pred)
        start_hnn_fastr = time.time()
        hnn_fastr_FDs = run_FASTR(hnn_pred)
        end_hnn_fastr = time.time()

        hnn_fastr_time += (end_hnn_fastr - start_hnn_fastr)
        hnn_fastr_FD_sets.append(hnn_fastr_FDs)

        hnn_fastr_avg_len_change += len(hnn_pred) - len(hnn_fastr_FDs)

    hnn_fastr_avg_len_change /= len(predicted_FD_sets)
    hnn_fastr_time /= len(predicted_FD_sets)

    # Compute Metrics for HNN
    hamming_hnn = hamming_loss(predicted_FD_sets, ground_truth_FD_sets, attributes)

    # Compute Redundancy for HNN
    total_redundant_HNN = sum(count_redundant_dependencies(fd_set) for fd_set in predicted_FD_sets)
    left_redundancy_HNN = sum(left_reduction_loss(fd_set) for fd_set in predicted_FD_sets)
    right_redundancy_HNN = sum(right_reduction_loss(fd_set) for fd_set in predicted_FD_sets)
    not_min_covers_HNN = sum(not is_cover_minimal(fd_set) for fd_set in predicted_FD_sets)
    not_same_closure_HNN = sum(
        not closure_equivalent(pred_fd, gt_fd, attributes)
        for pred_fd, gt_fd in zip(predicted_FD_sets, ground_truth_FD_sets)
    )


    # Compute Join Cost for HNN
    # The number of LHS keys, len(fd_set), is passed as original_table_size.
    join_cost_hnn = sum(
        compute_join_cost(compute_3NF_decomposition(fd_set, attributes), len(fd_set))
        for fd_set in predicted_FD_sets
    ) / len(predicted_FD_sets)

    # Compute Metrics for FASTR
    hamming_fastr = hamming_loss(fastr_FD_sets, ground_truth_FD_sets, attributes)

    # Compute Redundancy for FASTR
    total_redundant_FASTR = sum(count_redundant_dependencies(fd_set) for fd_set in fastr_FD_sets)
    left_redundancy_FASTR = sum(left_reduction_loss(fd_set) for fd_set in fastr_FD_sets)
    right_redundancy_FASTR = sum(right_reduction_loss(fd_set) for fd_set in fastr_FD_sets)
    not_min_covers_FASTR = sum(not is_cover_minimal(fd_set) for fd_set in fastr_FD_sets)
    not_same_closure_FASTR = sum(
        not closure_equivalent(fastr_fd, gt_fd, attributes)
        for fastr_fd, gt_fd in zip(fastr_FD_sets, ground_truth_FD_sets)
    )


    # Compute Join Cost for FASTR
    join_cost_fastr = sum(
        compute_join_cost(compute_3NF_decomposition(fd_set, attributes), len(fd_set))
        for fd_set in fastr_FD_sets
    ) / len(fastr_FD_sets)

    # Compute Metrics for HNN + FASTR Post-processing
    hamming_hnn_fastr = hamming_loss(hnn_fastr_FD_sets, ground_truth_FD_sets, attributes)

    # Compute Redundancy for HNN + FASTR
    total_redundant_HNN_FASTR = sum(count_redundant_dependencies(fd_set) for fd_set in hnn_fastr_FD_sets)
    left_redundancy_HNN_FASTR = sum(left_reduction_loss(fd_set) for fd_set in hnn_fastr_FD_sets)
    right_redundancy_HNN_FASTR = sum(right_reduction_loss(fd_set) for fd_set in hnn_fastr_FD_sets)
    not_min_covers_HNN_FASTR = sum(not is_cover_minimal(fd_set) for fd_set in hnn_fastr_FD_sets)
    not_same_closure_HNN_FASTR = sum(
        not closure_equivalent(hnn_fastr_fd, gt_fd, attributes)
        for hnn_fastr_fd, gt_fd in zip(hnn_fastr_FD_sets, ground_truth_FD_sets)
    )

    # Compute Join Cost for HNN + FASTR
    join_cost_hnn_fastr = sum(
        compute_join_cost(compute_3NF_decomposition(fd_set, attributes), len(fd_set))
        for fd_set in hnn_fastr_FD_sets
    ) / len(hnn_fastr_FD_sets)

    print("\n🔍 **Evaluation Summary** 🔍")
    print(f"📌 HNN - Hamming Loss: {hamming_hnn:.4f}, Join Cost: {join_cost_hnn:.2f}")
    print(f"📌 FASTR - Hamming Loss: {hamming_fastr:.4f}, Join Cost: {join_cost_fastr:.2f}")
    print(f"📌 HNN + FASTR - Hamming Loss: {hamming_hnn_fastr:.4f}, Join Cost: {join_cost_hnn_fastr:.2f}")

    print("\n📊 **Redundancy Analysis** 📊")
    print(f"📌 HNN - Avg Length Change in covers: {hnn_avg_len_change}")
    print(f"📌 HNN - Not Minimal Covers: {not_min_covers_HNN}")
    print(f"📌 HNN - Not Same Closure: {not_same_closure_HNN}")
    print(f"📌 HNN - Total Redundant Dependencies: {total_redundant_HNN}")
    print(f"📌 HNN - Left Redundancy Loss: {left_redundancy_HNN}")
    print(f"📌 HNN - Right Redundancy Loss: {right_redundancy_HNN}")
    print(f"📌 FASTR - Avg Length Change in covers: {fastr_avg_len_change}")
    print(f"📌 FASTR - Not Minimal Covers: {not_min_covers_FASTR}")
    print(f"📌 FASTR - Not Same Closure: {not_same_closure_FASTR}")
    print(f"📌 FASTR - Total Redundant Dependencies: {total_redundant_FASTR}")
    print(f"📌 FASTR - Left Redundancy Loss: {left_redundancy_FASTR}")
    print(f"📌 FASTR - Right Redundancy Loss: {right_redundancy_FASTR}")
    print(f"📌 HNN + FASTR - Avg Additional Length Change in covers: {hnn_fastr_avg_len_change}")
    print(f"📌 HNN + FASTR - Not Minimal Covers: {not_min_covers_HNN_FASTR}")
    print(f"📌 HNN + FASTR - Not Same Closure: {not_same_closure_HNN_FASTR}")
    print(f"📌 HNN + FASTR - Total Redundant Dependencies: {total_redundant_HNN_FASTR}")
    print(f"📌 HNN + FASTR - Left Redundancy Loss: {left_redundancy_HNN_FASTR}")
    print(f"📌 HNN + FASTR - Right Redundancy Loss: {right_redundancy_HNN_FASTR}")

    print(f"\n⚡ **Inference Speed Comparison** ⚡")
    print(f"🚀 HNN Avg Time: {hnn_time:.4f} seconds")
    print(f"🛠️ FASTR Avg Time: {fastr_time:.4f} seconds")
    print(f"🔥 HNN + FASTR Post-processing Avg Additional Time: {hnn_fastr_time:.4f} seconds")

    # Training-loss curve recorded by the training loop (log scale).
    plt.plot(model.loss_history)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("HGNN Training Loss")
    plt.show()

def train_and_evaluate(num_attributes=10,
                       num_fds=10,
                       max_edge_size=5,
                       max_rhs_size=3,
                       train_size=100,
                       val_size=20,
                       test_size=20,
                       epochs=50,
                       hidden_dim=16,
                       lr=1e-2):
    """
    Main pipeline: generate synthetic FD hypergraphs, train an HNN on the
    training split, then evaluate it (alongside FASTR) against minimal covers
    of the validation split.

    NOTE(review): train_hnn fits every node to a constant target of 1, so the
    model is not trained against minimal-cover labels. The test split is still
    generated, which keeps the random stream unchanged, but it is not
    evaluated; evaluation runs on the validation split.
    """
    def sample_split(size):
        # One list of independently generated acyclic FD hypergraphs.
        return [
            generate_acyclic_hypergraph(num_attributes, num_fds, max_edge_size, max_rhs_size)
            for _ in range(size)
        ]

    train_hypergraphs = sample_split(train_size)
    val_hypergraphs = sample_split(val_size)
    test_hypergraphs = sample_split(test_size)  # not evaluated; see NOTE above

    # Only the validation covers are consumed (by evaluate_HNN).
    minimal_val_FDs = [minimal_cover(graph) for graph in val_hypergraphs]

    # Initialize model + optimizer
    model = HNN(in_channels=num_attributes, hidden_channels=hidden_dim, out_channels=1)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    # Train model
    train_hnn(model, train_hypergraphs, optimizer, criterion, num_attributes=num_attributes, epochs=epochs)

    # Evaluate model
    evaluate_HNN(model, val_hypergraphs, minimal_val_FDs, num_attributes)
    
################################################
# Pipeline Execution
################################################

# Run the full pipeline only when executed as a script / notebook cell,
# not when this module is imported.
if __name__ == "__main__":
    train_and_evaluate(
        num_attributes=50,
        num_fds=100,
        max_edge_size=10,
        max_rhs_size=10,
        train_size=200,
        val_size=10,
        test_size=50,
        epochs=100,
        hidden_dim=64,
        lr=1e-2
    )

# Output: Epoch 0, Avg Train Loss: 0.0207507331
