In [None]:
# @ Author: Wenbo Duan
# @ Email: bobbyduanwenbo@live.com
# @ Date: 23, Sep 2023
# @ Function: 
#   1. Execute node elimination on the complete graph according to the ground truth order (generated in MATLAB)
#   2. Record each sub-graph with the label, store them as the raw files before converting to the train/test set


# generate node elimination results
import networkx as nx
import pandas as pd
from itertools import combinations
from collections import deque
import matplotlib.pyplot as plt
import os

def extract_graph(graph_id:int, edges_info:pd.DataFrame)->nx.Graph:
    """Build the graph stored under ``graph_id`` in the edge table.

    Rows whose ``graph_id`` column equals ``graph_id`` are selected and the
    first column is dropped; every remaining row must then hold exactly two
    values, the two endpoints of one edge.

    Args:
        graph_id: identifier compared against ``edges_info.graph_id``.
        edges_info: edge table containing a ``graph_id`` column.

    Returns:
        The graph produced by ``nx.parse_adjlist``. No node type conversion
        is requested, so node labels are the string form of the table values.
    """
    selected_rows = edges_info[edges_info.graph_id == graph_id]
    endpoint_pairs = selected_rows.iloc[:, 1:].values.tolist()

    # networkx reads each "u v" line as: node u followed by its neighbour v
    adjacency_lines = []
    for source, target in endpoint_pairs:
        adjacency_lines.append(f'{source} {target}')

    return nx.parse_adjlist(adjacency_lines)

def extract_order(graph_id:int, nodes_info: pd.DataFrame) -> list:
    """Return the node ids of one graph sorted by ascending ``node_label``.

    Args:
        graph_id: identifier compared against ``nodes_info.graph_id``.
        nodes_info: node table; after its first column is dropped it must
            still contain the ``node_id`` and ``node_label`` columns.

    Returns:
        List of ``node_id`` values, ordered by their ``node_label``.
    """
    belongs_to_graph = nodes_info.graph_id == graph_id
    ordered = (
        nodes_info[belongs_to_graph]
        .iloc[:, 1:]                      # drop the first column
        .sort_values(by=['node_label'])   # smallest label first
    )
    return ordered['node_id'].values.tolist()


class EliminationEngine:
    """Node elimination on a graph, with optional counting of fill-in edges.

    Eliminating a node first connects every pair of its neighbours that is
    not adjacent yet (these new edges are the fill-in) and then removes the
    node from the graph.
    """

    def __init__(self, graph:nx.Graph, recordON=True) -> None:
        """
        Args:
            graph: graph to eliminate on; it is modified in place.
            recordON: when True, each added fill-in edge increments
                ``fill_in_count``.
        """
        # working graph, mutated by every elimination
        self.graph = graph
        # snapshot of the input graph taken before any elimination
        self.graph_clone = graph.copy()
        # fill-in edges added so far (only updated while recordON is True)
        self.fill_in_count = 0
        self.recordON = recordON

    def _add_fill_in(self, target_node):
        """Connect all pairs of ``target_node``'s neighbours that are not adjacent."""
        # snapshot the neighbours so the graph can safely be edited below
        neighbors = list(self.graph.neighbors(target_node))
        for first, second in combinations(neighbors, 2):
            if self.graph.has_edge(first, second):
                continue
            # add the fill-in edge exactly once
            self.graph.add_edge(first, second)
            if self.recordON:
                self.fill_in_count += 1

    def eliminate(self, target_node):
        """Add the fill-in edges for ``target_node`` and remove it from the graph."""
        self._add_fill_in(target_node)
        self.graph.remove_node(target_node)


class Eliminator(EliminationEngine):
    """Drives node elimination over one graph following a fixed node order.

    The order can be consumed one node at a time with ``step`` (optionally
    drawing the graph before each elimination) or in one go with
    ``auto_step``, which also appends a record of every intermediate graph
    to two CSV files.
    """

    # column layouts of the two CSV files written by ``_record``
    _EDGE_COLUMNS = ['graph_id', 'source_node_id', 'destination_node_id']
    _GRAPH_COLUMNS = ['graph_id', 'graph_label']

    def __init__(self, graph_id, graph: nx.Graph, eliminate_seq:list, visualize=True, recordON=True) -> None:
        """
        Args:
            graph_id: id of the input graph; ``graph_id - 1`` becomes the
                prefix of every recorded sub-graph id.
            graph: graph to eliminate on (modified in place).
            eliminate_seq: node ids in the order they are to be eliminated.
            visualize: when True, ``step`` draws the graph before eliminating.
            recordON: forwarded to ``EliminationEngine``; also makes ``step``
                print a summary once the order is exhausted.
        """
        super().__init__(graph, recordON)
        # remaining nodes, consumed from the left
        self.nodes = deque(eliminate_seq)
        self.visualize = visualize
        # father id: previous graph id
        # NOTE(review): presumably shifts 1-based ids to 0-based -- confirm
        self.father_id = graph_id-1
        # child id: the index of the eliminated graph from the father graph
        self.child_id = 0

    def _current_id(self):
        """Id of the graph as it stands now, formatted ``<father_id>_<child_id>``."""
        return f'{self.father_id}_{self.child_id}'

    def step(self):
        """Eliminate the next node of the order, or report that none is left.

        With ``visualize`` set, the graph is drawn before the node is removed.
        When the order is exhausted, the summary is printed if ``recordON`` is
        set, otherwise ``finished`` is printed.
        """
        if self.nodes:
            # graph nodes are looked up by their string label
            target_node = str(self.nodes.popleft())

            if self.visualize:
                plt.figure()
                # same "<father>_<child>" id as the recorded files, so that
                # e.g. father 1 / child 12 and father 11 / child 2 differ
                plt.title(f'Graph ID: {self._current_id()}\neliminating {target_node}')
                nx.draw(self.graph, with_labels=True)
            self.eliminate(target_node)
            self.child_id += 1

        elif self.recordON:
            self.summary()
        else:
            print("finished")

    @staticmethod
    def _append_csv(frame, save_path):
        """Append ``frame`` to ``save_path``; the header is written only for a new file."""
        write_header = not os.path.exists(save_path)
        frame.to_csv(save_path, mode='a', header=write_header, index=False)

    def _record(self,edge_save_path, graph_save_path, target_node):
        """Append the current graph and its label to the two CSV files.

        The graph is copied and its nodes are renumbered 0..n-1 in node order,
        which detaches the stored sub-graph from the father graph's labels.
        One row per edge goes to ``edge_save_path``; one row holding the
        renumbered index of ``target_node`` (the node eliminated next) goes to
        ``graph_save_path``.

        Raises:
            KeyError: if ``target_node`` is not a node of the current graph.
                Raised before anything is written, so no partial record is
                left behind.
        """
        cur_id = self._current_id()
        target_node = str(target_node)
        cur_graph = self.graph.copy()

        # reset the graph label, reset index starting from 0, detach from the father graph
        rename_label = {node: index for index, node in enumerate(cur_graph.nodes)}
        if target_node not in rename_label:
            raise KeyError(f'node {target_node!r} is not part of graph {cur_id}')
        nx.relabel_nodes(cur_graph, rename_label, copy=False)

        # save edge info; explicit columns keep the header valid even when the
        # graph has no edges left
        edge_rows = [
            (cur_id, source_id, destination_id)
            for source_id, destination_id in cur_graph.edges
        ]
        self._append_csv(pd.DataFrame(edge_rows, columns=self._EDGE_COLUMNS), edge_save_path)

        # save graph label
        graph_rows = [(cur_id, rename_label[target_node])]
        self._append_csv(pd.DataFrame(graph_rows, columns=self._GRAPH_COLUMNS), graph_save_path)

    def auto_step(self, edge_save_path, graph_save_path):
        """Run the whole elimination order, recording the graph before each elimination."""
        while self.nodes:
            target_node = str(self.nodes.popleft())
            self._record(edge_save_path, graph_save_path, target_node)
            self.eliminate(target_node)
            self.child_id += 1

    def summary(self):
        """Print the father graph id, the input graph's node count and the fill-in total."""
        print(f'Global Graph ID: {self.father_id}\nInput Node Size: {len(self.graph_clone.nodes)}\nTotal fill-in: {self.fill_in_count}')


### Loading the original files

In [None]:
# original (input) tables
edges_path, nodes_path = (
    './data/GC_10/original/edges.csv',
    './data/GC_10/original/nodes.csv',
)

# destinations of the raw per-step records
edge_save_path, graph_save_path = (
    './data/GC_10/raw/edges.csv',
    './data/GC_10/raw/graphes.csv',
)

# load both input tables into memory
edges_info, nodes_info = (pd.read_csv(path) for path in (edges_path, nodes_path))

In [None]:
from tqdm import tqdm

# Eliminator appends to the save files, so records left over from an earlier
# run of this cell would be duplicated. Start from empty files and make sure
# the output directory exists.
for save_path in (edge_save_path, graph_save_path):
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    if os.path.exists(save_path):
        os.remove(save_path)

graph_ids = nodes_info.graph_id.unique()
for graph_id in tqdm(graph_ids):
    graph = extract_graph(graph_id, edges_info)
    nodes_seq = extract_order(graph_id, nodes_info)
    # auto_step never draws, so plotting is switched off explicitly
    eliminator = Eliminator(graph_id, graph, nodes_seq, visualize=False)
    eliminator.auto_step(edge_save_path, graph_save_path)

### Single step example

In [None]:
# edge table of the original graphs
edges_path = './data/GC_10/original/edges.csv'
edges_info = pd.read_csv(edges_path)

# node table holding the elimination labels
nodes_path = './data/GC_10/original/nodes.csv'
nodes_info = pd.read_csv(nodes_path)

# prepare a step-by-step eliminator for a single graph
graph_id = 1
graph = extract_graph(graph_id=graph_id, edges_info=edges_info)
nodes_seq = extract_order(graph_id=graph_id, nodes_info=nodes_info)
eliminator = Eliminator(graph_id=graph_id, graph=graph, eliminate_seq=nodes_seq)

In [None]:
# Take the next node from the elimination order; because this eliminator was
# built with the default visualize=True, the graph is drawn before the node is removed.
eliminator.step()

In [None]:
# Each call eliminates exactly one node, so the plot shows the graph
# including the fill-in edges added by the eliminations performed so far.
eliminator.step()

In [None]:
# Advance the elimination by one more node (draw the graph, add fill-in edges, remove the node).
eliminator.step()

In [None]:
# The plot title reports the graph id and the node that is about to be eliminated.
eliminator.step()

In [None]:
# Once the elimination order is exhausted, step() stops eliminating and prints
# the summary instead (recordON defaults to True).
eliminator.step()