In [None]:
import networkx as nx

def load_query(query_load_path):
    """Read a query graph file into an undirected ``networkx.Graph``.

    Only two kinds of line are interpreted; every other line is ignored:

    * ``v nodeID labelID ...`` -- a vertex; only the first label token is used.
    * ``e srcID dstID labelID ...`` -- an edge; only the first label token is used.

    A label of ``-1`` means "unlabelled" and is stored as an empty list.

    Args:
        query_load_path: path to the query graph file.

    Returns:
        ``networkx.Graph`` whose nodes and edges carry a ``"labels"`` list attribute.
    """
    nodes_list = []
    edges_list = []

    # ``with`` guarantees the file is closed even if a malformed line makes int() raise.
    with open(query_load_path) as file:
        for line in file:
            stripped = line.strip()
            if stripped.startswith("v"):
                tokens = stripped.split()
                node_id = int(tokens[1])
                # Only one label per query node; the multi-label variant would be
                # [int(token) for token in tokens[2:]].
                tmp_labels = [int(tokens[2])]
                labels = [] if -1 in tmp_labels else tmp_labels
                nodes_list.append((node_id, {"labels": labels}))
            elif stripped.startswith("e"):
                tokens = stripped.split()
                src, dst = int(tokens[1]), int(tokens[2])
                # Only the first edge label is used (see the vertex case above).
                tmp_labels = [int(tokens[3])]
                labels = [] if -1 in tmp_labels else tmp_labels
                edges_list.append((src, dst, {"labels": labels}))

    query = nx.Graph()
    query.add_nodes_from(nodes_list)
    query.add_edges_from(edges_list)
    return query

In [None]:
# from graph_operation import match_bfs
import os
import subprocess
import time
from collections import defaultdict, deque
from copy import deepcopy

# Codec used by Filtering.cpp_GQL to decode the SubgraphMatching binary's stdout.
encoding = 'utf-8'

def _load_count_prefixed_graph(path):
    """Parse the count-prefixed graph format shared by load_g_graph and load_p_data.

    File layout:
      * line 1: the number of vertices ``n``;
      * the next ``n`` lines: ``node_id node_label``;
      * then, until EOF, repeated groups: a line holding a count ``k`` followed
        by ``k`` edge lines ``u v [... edge_label]``. The edge label defaults to 1
        when a line has only two tokens; otherwise the last token is the label.

    Edges are stored as given (directed): only ``v`` is appended to ``u``'s
    neighbour list. Neighbour lists are indexed by ``u`` directly, so edge
    endpoints are assumed to lie in ``0..n-1``.

    Returns:
        ``[0: ids, 1: labels, 2: per-group edge counts (named "indeg" upstream),
        3: [src list, dst list], 4: edge labels, 5: neighbour lists,
        6: label -> list of vertex indices]``.
    """
    nid, nlabel, nindeg = [], [], []
    elabel, e_u, e_v = [], [], []
    with open(path) as f:
        num_nodes = int(f.readline().rstrip())
        v_neigh = [[] for _ in range(num_nodes)]
        for _ in range(num_nodes):
            node_id, node_label = f.readline().rstrip().split()
            nid.append(int(node_id))
            nlabel.append(int(node_label))
        while True:
            line = f.readline()
            if not line:  # EOF
                break
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # crashing on int('').
                continue
            temp_indeg = int(line.strip())
            nindeg.append(temp_indeg)
            for _ in range(temp_indeg):
                edge_info = f.readline().rstrip().split()
                src, dst = int(edge_info[0]), int(edge_info[1])
                edge_label = 1 if len(edge_info) == 2 else int(edge_info[-1])
                e_u.append(src)
                e_v.append(dst)
                v_neigh[src].append(dst)
                elabel.append(edge_label)

    # Group vertex *indices* (not ids) by label.
    label_dict = defaultdict(list)
    for index, label in enumerate(nlabel):
        label_dict[label].append(index)
    return [nid, nlabel, nindeg, [e_u, e_v], elabel, v_neigh, label_dict]


def load_g_graph(g_file):
    """Load a data graph in the count-prefixed format.

    See ``_load_count_prefixed_graph`` for the file layout and the 7-element
    return value.
    """
    return _load_count_prefixed_graph(g_file)


def load_p_data(p_file):
    """Load a pattern (query) graph in the count-prefixed format.

    See ``_load_count_prefixed_graph`` for the file layout and the 7-element
    return value.
    """
    return _load_count_prefixed_graph(p_file)

def load_graph(file_path):
    """Load a graph in the ``v``/``e`` line format into the 7-element graph list.

    Line formats (all other lines, e.g. a ``t`` header, are ignored):
      * ``v id label degree``
      * ``e src dst [label]`` -- the edge label defaults to 0 when absent.

    Each ``e`` line is treated as undirected: ``dst`` is appended to ``src``'s
    neighbour list and ``src`` to ``dst``'s. Neighbour lists are indexed by the
    order of the ``v`` lines, so vertex ids are assumed to be 0..n-1 in file
    order and declared before their edges -- TODO confirm for all datasets.

    Returns:
        ``[0: ids, 1: labels, 2: degrees, 3: [src list, dst list],
        4: edge labels, 5: neighbour lists, 6: label -> list of vertex indices]``.
    """
    nid, nlabel, nindeg = [], [], []
    elabel, e_u, e_v = [], [], []
    v_neigh = []

    # Stream line by line rather than materialising the file with readlines();
    # data graphs can be large.
    with open(file_path) as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('v'):
                tokens = line.split()
                nid.append(int(tokens[1]))
                nlabel.append(int(tokens[2]))
                nindeg.append(int(tokens[3]))
                v_neigh.append([])
            elif line.startswith('e'):
                tokens = line.split()
                src = int(tokens[1])
                dst = int(tokens[2])
                elabel.append(int(tokens[3]) if len(tokens) > 3 else 0)
                e_u.append(src)
                e_v.append(dst)
                v_neigh[src].append(dst)
                v_neigh[dst].append(src)

    # Group vertex indices by label.
    label_dict = defaultdict(list)
    for index, label in enumerate(nlabel):
        label_dict[label].append(index)

    # The lists above are fresh locals, so they are returned directly; the old
    # deepcopies only doubled peak memory.
    return [nid, nlabel, nindeg, [e_u, e_v], elabel, v_neigh, label_dict]


def graph_depth(graph, starting_vertex, *, verbose=True):
    """Return the BFS depth (eccentricity) of ``starting_vertex`` in ``graph``.

    The result is the number of BFS *levels* below ``starting_vertex``, i.e. the
    largest shortest-path distance to any reachable vertex. The previous
    version counted dequeued vertices instead, which overestimates the depth
    whenever a level has more than one vertex. It also raised IndexError on
    disconnected graphs; unreachable vertices are now simply ignored.

    Args:
        graph: 7-element graph list; only ``graph[5]`` (neighbour lists) is read.
        starting_vertex: index of the vertex to start from.
        verbose: print the level number and the vertices visited so far after
            each level. Defaults to True to keep the existing diagnostic output.

    Returns:
        int depth; 0 for an isolated vertex.
    """
    neigh_info = graph[5]
    visited = {starting_vertex}         # O(1) membership tests
    visit_order = [starting_vertex]     # kept only for the diagnostic print
    frontier = [starting_vertex]
    depth = 0
    while True:
        next_frontier = []
        for node in frontier:
            for u in neigh_info[node]:
                if u not in visited:
                    visited.add(u)
                    visit_order.append(u)
                    next_frontier.append(u)
        if not next_frontier:
            return depth
        depth += 1
        frontier = next_frontier
        if verbose:
            print('depth, visited: ', depth, visit_order)

class Filtering:
    """Candidate filtering for subgraph matching (NLF + GraphQL-style refinement).

    ``pattern`` and ``data_graph`` are 7-element graph lists:
    ``[0: id, 1: label, 2: degree, 3: edge_info, 4: edge_label,
    5: vertex neighbour lists, 6: label_dict (label -> vertex indices)]``.
    """

    def __init__(self, pattern, data_graph):
        self.pattern = pattern
        self.data_graph = data_graph

    def GQL_filter(self):
        """Compute candidate sets with GraphQL-style global refinement.

        Starts from the NLF candidates (label, degree and neighbour-label
        frequency). It then runs two refinement rounds: data vertex ``v`` stays
        a candidate of query vertex ``u`` only if ``u``'s neighbourhood has a
        semi-perfect matching into ``v``'s neighbourhood under the current
        candidate sets.

        Returns:
            (candidates, candidate_count): per query vertex, the surviving data
            vertices and how many there are.
        """
        local_candidates, candidate_count = self.generate_general_candidates()
        invalid_vertex_id = -1

        query_vertex_num = len(self.pattern[0])
        data_vertex_num = len(self.data_graph[0])
        query_max_degree = max(self.pattern[2])
        data_max_degree = max(self.data_graph[2])

        # valid_candidate[u][v] is True while v is still a candidate of u.
        valid_candidate = []
        for u in range(query_vertex_num):
            row = [False] * data_vertex_num
            for v in local_candidates[u]:
                row[v] = True
            valid_candidate.append(row)

        # Global refinement (two passes). Pruned entries are marked with
        # invalid_vertex_id and removed at the end by compact_candidate.
        for _ in range(2):
            for query_vertex in range(query_vertex_num):
                for j in range(candidate_count[query_vertex]):
                    data_vertex = local_candidates[query_vertex][j]
                    if data_vertex == invalid_vertex_id:
                        continue
                    if not self.verify_exact_twig_iso(query_vertex, data_vertex, query_vertex_num, data_vertex_num,
                                                      query_max_degree, data_max_degree, valid_candidate):
                        local_candidates[query_vertex][j] = invalid_vertex_id
                        valid_candidate[query_vertex][data_vertex] = False

        return self.compact_candidate(local_candidates, query_vertex_num)

    def generate_general_candidates(self):
        """Return NLF candidates per query vertex and their counts.

        A data vertex is a candidate of query vertex ``u`` when it has the same
        label, at least ``u``'s degree, and passes ``check_NLF``.
        """
        p_label = self.pattern[1]
        p_degree = self.pattern[2]
        g_degree = self.data_graph[2]
        candidates = []
        for u, label in enumerate(p_label):
            candidates.append([
                v for v in self.get_vertices_by_label(label)
                if g_degree[v] >= p_degree[u] and self.check_NLF(u, v)
            ])
        return candidates, [len(c) for c in candidates]

    def get_vertices_by_label(self, label):
        """Return the data vertices with ``label`` (empty list if none).

        ``.get`` avoids inserting missing labels into the data graph's
        label dict as a side effect.
        """
        return self.data_graph[6].get(label, [])

    def verify_exact_twig_iso(self, query_vertex, data_vertex, query_vertex_num, data_vertex_num,
                                                      query_max_degree, data_max_degree, valid_candidates):
        """Return True if N(query_vertex) has a semi-perfect matching into N(data_vertex).

        The left side of the bipartite graph is the neighbours of
        ``query_vertex``. The right side is the neighbours of ``data_vertex``,
        taken by position. Edge (u', j) exists when data neighbour j is still a
        valid candidate of u'. Every left vertex must be matched to a distinct
        right vertex; augmenting paths (Kuhn's algorithm) find such a matching.

        The size/max-degree arguments are unused and kept for interface
        compatibility.
        """
        q_neighbors = self.pattern[5][query_vertex]
        d_neighbors = self.data_graph[5][data_vertex]
        if len(q_neighbors) > len(d_neighbors):
            return False  # an injective matching is impossible

        adjacency = []
        for u in q_neighbors:
            allowed = valid_candidates[u]
            edges = [j for j, v in enumerate(d_neighbors) if allowed[v]]
            if not edges:
                return False  # this query neighbour cannot be matched at all
            adjacency.append(edges)

        right_to_left = [None] * len(d_neighbors)

        def augment(left, seen):
            # Try to match `left`, re-routing earlier matches along an augmenting path.
            for right in adjacency[left]:
                if right in seen:
                    continue
                seen.add(right)
                if right_to_left[right] is None or augment(right_to_left[right], seen):
                    right_to_left[right] = left
                    return True
            return False

        return all(augment(left, set()) for left in range(len(adjacency)))

    def compact_candidate(self, local_candidates, query_vertex_num):
        """Drop entries marked -1 and return (candidates, candidate_count)."""
        new_candidates = [[v for v in local_candidates[u] if v != -1] for u in range(query_vertex_num)]
        return new_candidates, [len(c) for c in new_candidates]

    def check_NLF(self, query_vertex, data_vertex):
        """Neighbourhood label frequency test.

        Return True if, for every label, ``data_vertex`` has at least as many
        neighbours with that label as ``query_vertex`` has.
        """
        q_label_frequency = defaultdict(int)
        for u in self.pattern[5][query_vertex]:
            q_label_frequency[self.pattern[1][u]] += 1
        d_label_frequency = defaultdict(int)
        for v in self.data_graph[5][data_vertex]:
            d_label_frequency[self.data_graph[1][v]] += 1
        # Labels absent from the data neighbourhood count as 0 (defaultdict).
        return all(d_label_frequency[label] >= count for label, count in q_label_frequency.items())

    def update_query(self, pattern_info):
        self.pattern = pattern_info

    def cpp_GQL(self, query_graph_file, data_graph_file):
        """Run the external SubgraphMatching binary with the GQL filter and parse its stdout.

        Returns:
            candidates: ints from the line after ``Candidate set version:``
                (empty list if that section is absent).
            candidate_count: the even-indexed entries of ``candidate_info`` as ints.
            induced_subgraph_list: ints from the line after ``Subgraph List is :``.
            neighbor_offset: ints from the line after ``Offset is :``.
            candidate_info: the ``2 * |V(query)|`` raw lines after ``Candidate set is:``.

        Raises:
            RuntimeError: if ``Candidate set is:``, ``Subgraph List is :`` or
                ``Offset is :`` is missing from the output (e.g. the binary failed).
                Previously this surfaced as an UnboundLocalError.
        """
        num_query_vertices = len(self.pattern[0])
        # NOTE(review): machine-specific library path that mutates the
        # process-wide environment; presumably should be configurable.
        os.environ['LD_LIBRARY_PATH'] = '/home/lxhq/Documents/workspace/NeurSC/build_with_subgraph/graph:/home/lxhq/Documents/workspace/NeurSC/build_with_subgraph/utility'
        base_command = ['build_with_subgraph/filter/SubgraphMatching.out', '-d', data_graph_file, '-q', query_graph_file, '-filter', 'GQL']
        output = subprocess.run(base_command, capture_output=True)
        baseline_visit = output.stdout.decode(encoding).split('\n')

        candidates = list()
        candidate_info = None
        induced_subgraph_list = None
        neighbor_offset = None
        for i, line in enumerate(baseline_visit):
            if 'Candidate set is:' in line:
                candidate_info = baseline_visit[i + 1: i + 2 * num_query_vertices + 1]
            elif 'Candidate set version:' in line:
                candidates = [int(tok) for tok in baseline_visit[i + 1].split()]
            elif 'Subgraph List is :' in line:
                induced_subgraph_list = [int(tok) for tok in baseline_visit[i + 1].split()]
            elif 'Offset is :' in line:
                neighbor_offset = [int(tok) for tok in baseline_visit[i + 1].split()]
            elif 'Filter vertices' in line:
                print(line)

        missing = [marker for marker, value in (('Candidate set is:', candidate_info),
                                                ('Subgraph List is :', induced_subgraph_list),
                                                ('Offset is :', neighbor_offset))
                   if value is None]
        if missing:
            raise RuntimeError(
                'SubgraphMatching output is missing section(s) {} (exit code {}): {}'.format(
                    missing, output.returncode, output.stderr.decode(encoding, errors='replace')))

        # candidate_info alternates a count line and a vertex-id line per query vertex.
        candidate_count = [int(candidate_info[i]) for i in range(0, len(candidate_info), 2)]
        return candidates, candidate_count, induced_subgraph_list, neighbor_offset, candidate_info

class SampleSubgraph:
    """Extract candidate-restricted subgraphs of a data graph.

    ``query`` and ``data_graph`` are 7-element graph lists:
    ``[0: id, 1: label, 2: degree, 3: edge_info, 4: edge_label,
    5: vertex neighbour lists, 6: label_dict]``.
    """

    def __init__(self, query, data_graph):
        self.query = query
        self.data_graph = data_graph

    def find_subgraph(self, start_query_vertex, candidates):
        """Grow one bounded-BFS subgraph per unexplored candidate of ``start_query_vertex``.

        For each candidate of ``start_query_vertex`` not yet seen as a
        neighbour in an earlier search, run a BFS over the data graph restricted
        to the union of all candidate sets. The search stops at the first vertex
        dequeued at a level greater than ``graph_depth(self.query, start_query_vertex)``.

        Returns:
            Six parallel lists, one entry per grown subgraph:
            - vertex lists;
            - ``{vertex: label}`` dicts;
            - ``{vertex: degree}`` dicts;
            - ``[src, dst]`` edge lists, with every edge stored in both directions;
            - edge-label lists, always 1;
            - ``{vertex: neighbours}`` dicts. Each edge is recorded once there,
              under the endpoint that joined the subgraph first.
        """
        output_vertices = []
        output_v_label = []
        output_degree = []
        output_edges = []
        output_edge_label = []
        output_v_neigh = []
        depth = graph_depth(self.query, start_query_vertex)
        # Union of all candidate sets, as a set for O(1) membership tests.
        all_candidates = {v for cand_list in candidates for v in cand_list}
        data_label = self.data_graph[1]
        data_neigh = self.data_graph[5]
        # Insertion-ordered set of start vertices still to explore. A vertex seen
        # as a neighbour during any search is not used as a later start.
        all_need_visited = dict.fromkeys(candidates[start_query_vertex])
        while all_need_visited:
            start_data_vertex = next(iter(all_need_visited))
            del all_need_visited[start_data_vertex]

            new_graph_vertices = [start_data_vertex]
            in_new_graph = {start_data_vertex}
            new_graph_v_label = {start_data_vertex: data_label[start_data_vertex]}
            new_graph_v_degree = defaultdict(lambda: 0)
            new_e_u = []
            new_e_v = []
            new_edge_label = []
            new_graph_v_neigh = defaultdict(list)

            queue = deque([(start_data_vertex, 0)])
            while queue:
                current_data_vertex, search_depth = queue.popleft()
                if search_depth > depth:
                    break
                for v in data_neigh[current_data_vertex]:
                    all_need_visited.pop(v, None)
                    if v in in_new_graph or v not in all_candidates:
                        continue
                    new_graph_vertices.append(v)
                    in_new_graph.add(v)
                    new_graph_v_label[v] = data_label[v]
                    queue.append((v, search_depth + 1))
                    for neigh_v in data_neigh[v]:
                        if neigh_v in in_new_graph:
                            # Store the undirected edge in both directions.
                            new_e_u.extend((v, neigh_v))
                            new_e_v.extend((neigh_v, v))
                            new_graph_v_degree[v] += 1
                            new_graph_v_degree[neigh_v] += 1
                            # The neighbour relation is recorded only once.
                            new_graph_v_neigh[neigh_v].append(v)
                            new_edge_label.extend((1, 1))

            output_vertices.append(new_graph_vertices)
            output_v_label.append(new_graph_v_label)
            output_degree.append(new_graph_v_degree)
            output_edges.append([new_e_u, new_e_v])
            output_edge_label.append(new_edge_label)
            output_v_neigh.append(new_graph_v_neigh)
        return output_vertices, output_v_label, output_degree, output_edges, output_edge_label, output_v_neigh

    def find_subgraph_induced(self, candidates):
        """Split the subgraph induced by all candidate vertices into connected components.

        Prints the wall-clock time of three stages. Returns the same six
        per-component lists as ``_split_graph``.
        """
        t_0 = time.time()
        # Deduplicated candidates; the list order defines the component order.
        all_candidates = list({v for cand_list in candidates for v in cand_list})
        # Set for O(1) membership; list membership made stage 2 O(E * n).
        candidate_set = set(all_candidates)

        data_label = self.data_graph[1]
        data_neigh = self.data_graph[5]

        new_graph_vertices = list(all_candidates)
        new_graph_v_label = {v: data_label[v] for v in new_graph_vertices}
        t_1 = time.time()
        print('sample satage 1: {}s'.format(t_1 - t_0))

        new_e_u = []
        new_e_v = []
        new_edge_label = []
        new_graph_v_degree = defaultdict(lambda: 0)
        new_graph_v_neigh = defaultdict(list)
        for vertex in new_graph_vertices:
            # Keep an edge only if both endpoints are candidates. Each undirected
            # edge is seen from both endpoints, so it is recorded once per direction.
            for u in data_neigh[vertex]:
                if u in candidate_set:
                    new_e_u.append(u)
                    new_e_v.append(vertex)
                    new_graph_v_degree[vertex] += 1
                    new_graph_v_neigh[vertex].append(u)
                    new_edge_label.append(1)

        t_2 = time.time()
        print('sample stage 2: {}s'.format(t_2 - t_1))
        check_info = [new_graph_vertices, new_graph_v_label, new_graph_v_degree,
                      [new_e_u, new_e_v], new_edge_label, new_graph_v_neigh]
        result = self._split_graph(check_info)
        t_3 = time.time()
        print('sample stage 3: {}s'.format(t_3 - t_2))
        return result

    def load_induced_subgraph(self, candidates, induced_subgraph_list, neighbor_offset):
        """Build the candidate-induced subgraph from a CSR-style adjacency and split it.

        The neighbours of ``candidates[i]`` are
        ``induced_subgraph_list[neighbor_offset[i]:neighbor_offset[i + 1]]``.
        Returns the same six per-component lists as ``_split_graph``.
        """
        data_label = self.data_graph[1]

        new_graph_vertices = list(candidates)
        new_graph_v_label = {v: data_label[v] for v in new_graph_vertices}

        new_e_u = []
        new_e_v = []
        new_edge_label = []
        new_graph_v_degree = defaultdict(lambda: 0)
        new_graph_v_neigh = defaultdict(list)
        for i, vertex in enumerate(candidates):
            # Explicit indexing (not slicing) so malformed offsets still raise.
            for j in range(neighbor_offset[i], neighbor_offset[i + 1]):
                u = induced_subgraph_list[j]
                new_e_u.append(u)
                new_e_v.append(vertex)
                # One direction per entry; the reverse appears under u's entries.
                new_graph_v_degree[vertex] += 1
                new_graph_v_neigh[vertex].append(u)
                new_edge_label.append(1)

        check_info = [new_graph_vertices, new_graph_v_label, new_graph_v_degree,
                      [new_e_u, new_e_v], new_edge_label, new_graph_v_neigh]
        return self._split_graph(check_info)

    def _split_graph(self, graph_info):
        """Split a graph into connected components via BFS.

        Components are discovered in the order of ``graph_info[0]``; vertices
        within a component appear in BFS order. Every neighbour entry becomes a
        one-way edge with label 1 (the reverse direction comes from the other
        endpoint's entry).

        Returns:
            Six per-component lists: vertex lists, ``{vertex: label}`` dicts,
            ``{vertex: degree}`` dicts, ``[src, dst]`` edge lists, edge-label
            lists and ``{vertex: neighbours}`` dicts.
        """
        vertices_id = graph_info[0]
        vertices_label = graph_info[1]
        vertices_neighbor = graph_info[5]
        remaining = set(vertices_id)  # O(1) membership/removal instead of list scans

        output_vertices = []
        output_v_label = []
        output_v_degree = []
        output_edges = []
        output_e_label = []
        output_v_neigh = []

        for start_node in vertices_id:
            if start_node not in remaining:
                continue
            out_temp_vertices = []
            out_temp_v_label = {}
            out_temp_v_degree = defaultdict(lambda: 0)
            out_temp_e_u = []
            out_temp_e_v = []
            out_temp_e_label = []
            out_temp_v_neigh = defaultdict(list)

            queue = deque([start_node])
            while queue:
                current_node = queue.popleft()
                # A vertex may be queued more than once; process it only the first time.
                if current_node not in remaining:
                    continue
                remaining.remove(current_node)
                out_temp_vertices.append(current_node)
                out_temp_v_label[current_node] = vertices_label[current_node]
                for v in vertices_neighbor[current_node]:
                    out_temp_e_u.append(current_node)
                    out_temp_e_v.append(v)
                    out_temp_e_label.append(1)
                    out_temp_v_degree[current_node] += 1
                    out_temp_v_neigh[current_node].append(v)
                    if v in remaining:
                        queue.append(v)

            output_vertices.append(out_temp_vertices)
            output_v_label.append(out_temp_v_label)
            output_v_degree.append(out_temp_v_degree)
            output_edges.append([out_temp_e_u, out_temp_e_v])
            output_e_label.append(out_temp_e_label)
            output_v_neigh.append(out_temp_v_neigh)

        return output_vertices, output_v_label, output_v_degree, output_edges, output_e_label, output_v_neigh

    def update_query(self, query):
        self.query = query



In [None]:
import networkx as nx

if __name__ == '__main__':
    data_graph_path = '../dataset/yeast/data_graph/yeast.graph'
    query_graph_path = '../dataset/yeast/query_graph/query_dense_4_1.graph'
    data_graph = load_graph(data_graph_path)
    query_graph = load_graph(query_graph_path)
    # Named gql_filter (not `filter`) so the builtin filter() is not shadowed
    # for every later notebook cell.
    gql_filter = Filtering(query_graph, data_graph)
    subgraph_sampler = SampleSubgraph(query_graph, data_graph)
    candidates, candidate_count, induced_subgraph_list, neighbor_offset, candidate_info = gql_filter.cpp_GQL(query_graph_path, data_graph_path)
    # Heuristic start: the query vertex with the fewest candidates...
    starting_vertex = candidate_count.index(min(candidate_count))
    # ...deliberately overridden to always start from query vertex 0.
    starting_vertex = 0
    print('candidates', candidates)
    print('candidate_count', candidate_count)
    print('induced_subgraph_list', induced_subgraph_list)
    print('neighbor_offset', neighbor_offset)
    print('candidate_info', candidate_info)
    print('starting_vertex', starting_vertex)

    # candidate_info alternates a count line and a vertex-id line per query
    # vertex; keep the vertex-id lines (odd positions).
    vertices_candidates = [list(map(int, row.split())) for row in candidate_info[1::2]]
    new_vertices, new_v_label, new_degree, new_edges, new_e_label, new_v_neigh = subgraph_sampler.find_subgraph(starting_vertex, candidates=vertices_candidates)
    print(new_vertices)
    print(new_v_label)
    print(new_degree)
    print(new_edges)
    print(new_e_label)
    print(new_v_neigh)

In [None]:
# remove query files without a true count

import os
import shutil

def load_true_card(true_card_path):
    """Return the set of query names listed in a true-cardinality CSV file.

    The name is the first comma-separated field of each line, with surrounding
    whitespace removed. Stripping matters because the result decides which
    query files are deleted. A line with no comma would otherwise keep its
    trailing newline, never match a file name, and that query would be removed.
    Blank lines are skipped.
    """
    with open(true_card_path, 'r') as f:
        names = (line.split(',')[0].strip() for line in f)
        return {name for name in names if name}

true_path = 'data/yeast/query_graph.csv'
true_cards = load_true_card(true_path)

# Delete every query file whose base name (text before the first '.') has no
# true count recorded in the CSV above.
query_dir = 'data/yeast/query_graph/'
orphan_files = [entry for entry in os.listdir(query_dir) if entry.split('.')[0] not in true_cards]
for entry in orphan_files:
    os.remove(query_dir + entry)

In [None]:
# clean the outputs env

import shutil
import os

# Wipe the model/param/result directories (missing ones are fine), then
# recreate them empty.
for stale_dir in ['saved_models', 'saved_params', 'saved_results']:
    shutil.rmtree(stale_dir, ignore_errors=True)

for fresh_dir in ['saved_models/', 'saved_params/', 'saved_results/']:
    os.mkdir(fresh_dir)

# Same for the per-dataset output tree.
shutil.rmtree('outputs', ignore_errors=True)
for fresh_dir in ['outputs/', 'outputs/yeast/']:
    os.mkdir(fresh_dir)

In [None]:
# draw figures for NeurSC results

import seaborn as sns
import pandas as pd
import math
import os
import matplotlib.pyplot as plt

def read_result_file(result_dir, params_dir, result_name):
    """Parse one NeurSC result file, plus its params file, into plotting rows.

    ``result_name`` is split on ``'_'``. Token 4 is the epoch count. Token 7
    is either ``'aug'`` (then token 8 selects aug 1 vs aug 2 and token 9 is
    the percentage) or the percentage of original training queries.

    The params file ``params_dir + result_name`` supplies the training query
    count from its first line starting with ``training query number``. If no
    such line exists, that column is the int 0.

    The result file ``result_dir + result_name`` has one header line, then
    whitespace-separated rows whose fields are query file, q-error, pred and card.

    Returns:
        list of rows ``[epochs, train_query_number, '<type>_<prec>',
        query_size, query_type, pred, card, signed_log10_q_error]``. The
        log10 q-error is negated for under-estimates (pred < card).
    """
    name_tokens = result_name.split('_')
    epochs = int(name_tokens[4])

    if name_tokens[7] != 'aug':
        training_type = 'original'
        training_prec = name_tokens[7]
        train_query_suffix = '{}% original\nqueries'.format(training_prec)
    else:
        training_prec = name_tokens[9]
        if name_tokens[8] == '1':
            training_type = 'aug_1'
            train_query_suffix = 'All original\nqueries +\n{}% aug 1\nqueries'.format(training_prec)
        else:
            training_type = 'aug_2'
            train_query_suffix = 'All original\nqueries +\nall aug 1\nqueries +\n{}% aug 2\nqueries'.format(training_prec)
    type_and_prec = training_type + '_' + training_prec

    train_query_number = 0
    with open(params_dir + result_name, 'r') as params_file:
        for line in params_file:
            if line.startswith('training query number'):
                count = int(line.split(':')[1].strip())
                train_query_number = str(count) + '\n' + train_query_suffix
                break

    rows = []
    with open(result_dir + result_name, 'r') as result_file:
        next(result_file, None)  # skip the header line
        for line in result_file:
            fields = line.split()
            stem_tokens = fields[0].split('.')[0].split('_')
            query_size = int(stem_tokens[2])
            query_type = stem_tokens[1]
            pred = float(fields[2])
            card = float(fields[3])
            log_q_error = math.log10(float(fields[1]))
            signed_q_error = -log_q_error if pred < card else log_q_error
            rows.append([epochs, train_query_number, type_and_prec, query_size,
                         query_type, pred, card, signed_q_error])
    return rows

def draw_box_plot(dataframe, title, fontsize=16, vertical_lines=(8.5, 11.5)):
    """Draw a box plot of signed log10 q-errors grouped by training-query label.

    One box is drawn per distinct 'Train Query Number' label, ordered by the
    integer on the label's first line. Whiskers span the 1st-99th percentiles.
    The 10th/90th percentiles are overlaid as 'x' markers, and the percentile
    table is shown via ``display`` (available in IPython/Jupyter).

    Args:
        dataframe: DataFrame with a 'Train Query Number' column (string labels
            whose first line is an integer) and a 'q-error' column (signed
            log10 q-errors; negative = under-estimate).
        title: dataset name used as the plot title; it also selects the y-tick
            range. Must be one of 'yeast', 'youtube', 'wordnet', 'eu2005'.
        fontsize: font size for the title and axis labels.
        vertical_lines: x positions (between box indices) of dashed red
            separator lines.

    Raises:
        ValueError: if ``title`` is not a recognized dataset name.
    """
    # Symmetric y-tick range per dataset: (largest |exponent|, tick step).
    ytick_specs = {
        'yeast': (6, 2),
        'youtube': (8, 2),
        'wordnet': (5, 1),
        'eu2005': (3, 1),
    }
    if title not in ytick_specs:
        raise ValueError('Not recognized dataset')
    max_exp, step = ytick_specs[title]
    ticks = list(range(-max_exp, max_exp + 1, step))
    # Only the magnitude is printed; the sign is explained by the y-axis label.
    tick_labels = ['$10^{}$'.format(abs(t)) for t in ticks]

    def leading_count(labels):
        # The first line of each label is the integer number of training queries.
        return labels.map(lambda label: int(label.split('\n')[0]))

    # Per-group percentiles, ordered numerically by query count. The sort is stable,
    # so ties are deterministic, and the same order is passed to the box plot so the
    # percentile markers line up with their boxes.
    pcntls = dataframe.groupby('Train Query Number')['q-error'].describe(percentiles=[0.1, 0.9])
    pcntls = pcntls.sort_values(by='Train Query Number', key=leading_count, kind='mergesort')
    order = list(pcntls.index)

    sns.set(rc={'figure.figsize': (20.7, 8.27)})
    bp = sns.boxplot(data=dataframe, x='Train Query Number', y='q-error', whis=[1, 99], order=order)
    for x in vertical_lines:
        plt.axvline(x=x, color='red', linestyle='--')
    plt.grid(visible=True, linestyle='--')
    plt.title(title, fontsize=fontsize)
    plt.axhline(0, color='green', linestyle='dashed')
    plt.ylabel('Under estimate <--- q-error ---> Over estimate', fontsize=fontsize)
    plt.xlabel('Number of training queries', fontsize=fontsize)
    plt.yticks(ticks=ticks, labels=tick_labels)

    display(pcntls)
    positions = range(len(order))
    bp.scatter(data=pcntls, x=positions, y='10%', marker='x')
    bp.scatter(data=pcntls, x=positions, y='90%', marker='x')
    plt.show()


def draw_line_plot(df, title, fontsize=16):
    """Plot the median q-error magnitude for each training-query label.

    ``df['q-error']`` is expected to hold signed log10 q-errors (negative for
    under-estimates). For each ('Train Query Number', 'Epochs') group, the
    median of their absolute values is plotted against the label. Labels are
    ordered by the integer on their first line.

    Args:
        df: DataFrame with 'Train Query Number', 'Epochs' and 'q-error' columns.
        title: plot title.
        fontsize: font size for the title, axis labels and y tick labels.
    """
    # Take magnitudes before the median. |median(signed)| would let under- and
    # over-estimates cancel out and understate the error.
    magnitudes = df.assign(**{'q-error': df['q-error'].abs()})
    medians = (magnitudes.groupby(by=['Train Query Number', 'Epochs'])['q-error']
               .median()
               .reset_index(name='median q-error'))
    medians = medians.sort_values(
        by='Train Query Number',
        key=lambda labels: labels.map(lambda label: int(label.split('\n')[0])),
        kind='mergesort',  # stable: deterministic order for equal query counts
    )
    sns.lineplot(data=medians, x='Train Query Number', y='median q-error')
    plt.ylabel('log10(median q-error)', fontsize=fontsize)
    plt.title(title, fontsize=fontsize)
    plt.xlabel('Number of training queries', fontsize=fontsize)
    plt.tick_params(axis='y', labelsize=fontsize)
    plt.show()

if __name__ == '__main__':
    data_graph = 'yeast'
    result_dir = 'saved_results/'
    params_dir = 'saved_params/'
    target_epochs = 80  # only models trained for this many epochs are plotted
    # model_data = [['Epochs', 'Train Query Number', 'Train Query Number Type', 'Query Size', 'Query Type', 'Pred', 'Card', 'q-error'], [...], ...]
    model_data = []
    for file in os.listdir(result_dir):
        if file.startswith(data_graph):
            model_data.extend(read_result_file(result_dir, params_dir, file))
    df = pd.DataFrame(data=model_data,
                      columns=['Epochs', 'Train Query Number', 'Train Query Number Type', 'Query Size', 'Query Type', 'Pred', 'Card', 'q-error'])
    df = df.loc[df['Epochs'] == target_epochs]
    if df.empty:
        raise ValueError('no {}-epoch results for {} in {}'.format(target_epochs, data_graph, result_dir))
    # Order rows by the integer query count on the first line of each label.
    df = df.sort_values(by='Train Query Number',
                        key=lambda x: x.apply(lambda y: int(y.split('\n')[0])),
                        kind='mergesort')
    draw_box_plot(df, data_graph)
    draw_line_plot(df, data_graph)


In [None]:
# draw figures for NeurSC results

import seaborn as sns
import pandas as pd
import math
import os
import matplotlib.pyplot as plt

def read_result_file(result_dir, params_dir, result_name):
    """Parse one saved result file into rows for the q-error DataFrame.

    The training configuration is decoded from ``result_name`` split on '_':
    token 4 is the epoch count. If token 7 is 'aug', token 8 selects the
    augmentation level ('1' -> 'aug_1', anything else -> 'aug_2') and token 9
    is the percentage; otherwise token 7 is the percentage of original queries.
    The file with the same name in ``params_dir`` must contain a line starting
    with 'training query number' whose value after ':' is an integer.

    After a header line, each non-blank line of the result file is read as
    ``<query file> <q-error> <pred> <card>``. The query file stem is split on
    '_': token 1 is the query type and token 2 the query size.

    Args:
        result_dir: directory holding the result file.
        params_dir: directory holding the params file with the same name.
        result_name: file name shared by the result and params files.

    Returns:
        A list of rows ``[epochs, train_query_number, training_type_prec,
        query_size, query_type, pred, card, q_error]``.
        ``train_query_number`` is a multi-line plot label whose first line is
        the integer number of training queries. ``q_error`` is log10 of the
        reported q-error, negated when the prediction under-estimates
        (pred < card).

    Raises:
        ValueError: if the params file has no 'training query number' line.
    """
    name_tokens = result_name.split('_')
    epochs = int(name_tokens[4])
    if name_tokens[7] == 'aug':
        training_prec = name_tokens[9]
        if name_tokens[8] == '1':
            training_type = 'aug_1'
            train_query_suffix = 'All original\nqueries +\n{}% aug 1\nqueries'.format(training_prec)
        else:
            training_type = 'aug_2'
            train_query_suffix = 'All original\nqueries +\nall aug 1\nqueries +\n{}% aug 2\nqueries'.format(training_prec)
    else:
        training_type = 'original'
        training_prec = name_tokens[7]
        train_query_suffix = '{}% original\nqueries'.format(training_prec)
    training_label = training_type + '_' + training_prec

    params_path = os.path.join(params_dir, result_name)
    train_query_number = None
    with open(params_path, 'r') as f:
        for line in f:
            if line.startswith('training query number'):
                # Plot label: the query count on the first line, the description below it.
                train_query_number = str(int(line.split(':')[1].strip())) + '\n' + train_query_suffix
                break
    if train_query_number is None:
        # Fail here with a clear message instead of producing a non-string label
        # that breaks the numeric label sort later on.
        raise ValueError("no 'training query number' line in {}".format(params_path))

    res = []
    with open(os.path.join(result_dir, result_name), 'r') as f:
        next(f, None)  # skip the header line
        for line in f:
            line_tokens = line.split()
            if not line_tokens:
                continue  # tolerate blank / trailing lines
            query_name_tokens = line_tokens[0].split('.')[0].split('_')
            query_size = int(query_name_tokens[2])
            query_type = query_name_tokens[1]
            pred = float(line_tokens[2])
            card = float(line_tokens[3])
            q_error = math.log10(float(line_tokens[1]))
            if pred < card:
                q_error = -q_error  # negative values mark under-estimates
            res.append([epochs, train_query_number, training_label, query_size, query_type, pred, card, q_error])
    return res

def draw_box_plot(dataframe, title, fontsize=16, vertical_lines=(8.5, 11.5)):
    """Draw a box plot of signed log10 q-errors grouped by training-query label.

    One box is drawn per distinct 'Train Query Number' label, ordered by the
    integer on the label's first line. Whiskers span the 1st-99th percentiles.
    The 10th/90th percentiles are overlaid as 'x' markers, and the percentile
    table is shown via ``display`` (available in IPython/Jupyter).

    Args:
        dataframe: DataFrame with a 'Train Query Number' column (string labels
            whose first line is an integer) and a 'q-error' column (signed
            log10 q-errors; negative = under-estimate).
        title: dataset name used as the plot title; it also selects the y-tick
            range. Must be one of 'yeast', 'youtube', 'wordnet', 'eu2005'.
        fontsize: font size for the title and axis labels.
        vertical_lines: x positions (between box indices) of dashed red
            separator lines.

    Raises:
        ValueError: if ``title`` is not a recognized dataset name.
    """
    # Symmetric y-tick range per dataset: (largest |exponent|, tick step).
    ytick_specs = {
        'yeast': (6, 2),
        'youtube': (8, 2),
        'wordnet': (5, 1),
        'eu2005': (3, 1),
    }
    if title not in ytick_specs:
        raise ValueError('Not recognized dataset')
    max_exp, step = ytick_specs[title]
    ticks = list(range(-max_exp, max_exp + 1, step))
    # Only the magnitude is printed; the sign is explained by the y-axis label.
    tick_labels = ['$10^{}$'.format(abs(t)) for t in ticks]

    def leading_count(labels):
        # The first line of each label is the integer number of training queries.
        return labels.map(lambda label: int(label.split('\n')[0]))

    # Per-group percentiles, ordered numerically by query count. The sort is stable,
    # so ties are deterministic, and the same order is passed to the box plot so the
    # percentile markers line up with their boxes.
    pcntls = dataframe.groupby('Train Query Number')['q-error'].describe(percentiles=[0.1, 0.9])
    pcntls = pcntls.sort_values(by='Train Query Number', key=leading_count, kind='mergesort')
    order = list(pcntls.index)

    sns.set(rc={'figure.figsize': (20.7, 8.27)})
    bp = sns.boxplot(data=dataframe, x='Train Query Number', y='q-error', whis=[1, 99], order=order)
    for x in vertical_lines:
        plt.axvline(x=x, color='red', linestyle='--')
    plt.grid(visible=True, linestyle='--')
    plt.title(title, fontsize=fontsize)
    plt.axhline(0, color='green', linestyle='dashed')
    plt.ylabel('Under estimate <--- q-error ---> Over estimate', fontsize=fontsize)
    plt.xlabel('Number of training queries', fontsize=fontsize)
    plt.yticks(ticks=ticks, labels=tick_labels)

    display(pcntls)
    positions = range(len(order))
    bp.scatter(data=pcntls, x=positions, y='10%', marker='x')
    bp.scatter(data=pcntls, x=positions, y='90%', marker='x')
    plt.show()


def draw_line_plot(df, title, fontsize=16):
    """Plot the median q-error magnitude for each training-query label.

    ``df['q-error']`` is expected to hold signed log10 q-errors (negative for
    under-estimates). For each ('Train Query Number', 'Epochs') group, the
    median of their absolute values is plotted against the label. Labels are
    ordered by the integer on their first line.

    Args:
        df: DataFrame with 'Train Query Number', 'Epochs' and 'q-error' columns.
        title: plot title.
        fontsize: font size for the title, axis labels and y tick labels.
    """
    # Take magnitudes before the median. |median(signed)| would let under- and
    # over-estimates cancel out and understate the error.
    magnitudes = df.assign(**{'q-error': df['q-error'].abs()})
    medians = (magnitudes.groupby(by=['Train Query Number', 'Epochs'])['q-error']
               .median()
               .reset_index(name='median q-error'))
    medians = medians.sort_values(
        by='Train Query Number',
        key=lambda labels: labels.map(lambda label: int(label.split('\n')[0])),
        kind='mergesort',  # stable: deterministic order for equal query counts
    )
    sns.lineplot(data=medians, x='Train Query Number', y='median q-error')
    plt.ylabel('log10(median q-error)', fontsize=fontsize)
    plt.title(title, fontsize=fontsize)
    plt.xlabel('Number of training queries', fontsize=fontsize)
    plt.tick_params(axis='y', labelsize=fontsize)
    plt.show()

if __name__ == '__main__':
    data_graph = 'yeast'
    result_dir = 'saved_results/'
    params_dir = 'saved_params/'
    target_epochs = 80  # only models trained for this many epochs are plotted
    # model_data = [['Epochs', 'Train Query Number', 'Train Query Number Type', 'Query Size', 'Query Type', 'Pred', 'Card', 'q-error'], [...], ...]
    model_data = []
    for file in os.listdir(result_dir):
        if file.startswith(data_graph):
            model_data.extend(read_result_file(result_dir, params_dir, file))
    df = pd.DataFrame(data=model_data,
                      columns=['Epochs', 'Train Query Number', 'Train Query Number Type', 'Query Size', 'Query Type', 'Pred', 'Card', 'q-error'])
    df = df.loc[df['Epochs'] == target_epochs]
    if df.empty:
        raise ValueError('no {}-epoch results for {} in {}'.format(target_epochs, data_graph, result_dir))
    # Order rows by the integer query count on the first line of each label.
    df = df.sort_values(by='Train Query Number',
                        key=lambda x: x.apply(lambda y: int(y.split('\n')[0])),
                        kind='mergesort')
    draw_box_plot(df, data_graph)
    draw_line_plot(df, data_graph)


In [None]:
import networkx as nx
import os

if __name__ == '__main__':
    data_set = 'eu2005'
    output_path = 'outputs/research/{}.csv'.format(data_set)
    data_graph_path = '../dataset/{}/data_graph/{}.graph'.format(data_set, data_set)
    query_graph_dir = '../dataset/{}/query_graph/'.format(data_set)
    # Create the output directory if needed so that open() below does not fail.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    data_graph = load_graph(data_graph_path)
    # The with-block closes (and flushes) the CSV even if processing a query raises.
    with open(output_path, 'w') as output_file:
        # Sorted so the CSV row order is reproducible across runs.
        for file in sorted(os.listdir(query_graph_dir)):
            # Token 2 of the '_'-split query file name is parsed as the query size.
            query_size = int(file.split('_')[2])
            query_graph_path = query_graph_dir + file
            query_graph = load_graph(query_graph_path)
            graph_filter = Filtering(query_graph, data_graph)
            # Only the candidate list is recorded; the other GQL outputs are unused here.
            candidates, _, _, _, _ = graph_filter.cpp_GQL(query_graph_path, data_graph_path)
            # Row: query size, number of candidates, len(data_graph[0]), query file name.
            output_file.write('{},{},{},{}\n'.format(query_size, len(candidates), len(data_graph[0]), file))

In [None]:
# interpreting the test results
import os
import numpy as np
import math

def get_prediction_statistics(errors: list):
    """Print a summary profile of prediction errors and return their IQR.

    Prints the count, min/max, mean, median and 25%/75% quantiles of
    ``errors`` between two separator lines.

    Args:
        errors: sequence of numeric prediction errors.

    Returns:
        The interquartile range (75% quantile minus 25% quantile) of
        ``errors``, or ``nan`` if ``errors`` is empty.
    """
    # len() rather than truthiness so numpy arrays are accepted as well.
    if len(errors) == 0:
        print("Predict Result Profile of 0 Queries: no data", flush=True)
        return float('nan')
    lower, upper = np.quantile(errors, [0.25, 0.75])
    print("<" * 80, flush=True)
    print("Predict Result Profile of {} Queries:".format(len(errors)), flush=True)
    print("Min/Max: {:.4f} / {:.4f}".format(np.min(errors), np.max(errors)), flush=True)
    print("Mean: {:.4f}".format(np.mean(errors)), flush=True)
    print("Median: {:.4f}".format(np.median(errors)), flush=True)
    print("25%/75% Quantiles: {:.4f} / {:.4f}".format(lower, upper), flush=True)
    print(">" * 80, flush=True)
    interquartile_range = abs(upper - lower)
    return interquartile_range

if __name__ == '__main__':
    results_dir = './saved_results/'
    # Sorted for a reproducible report order; all_results[i] belongs to result_files[i].
    result_files = sorted(os.listdir(results_dir))
    all_results = []
    for file in result_files:
        results = {}  # results <- {name: [pred, true]}
        with open(os.path.join(results_dir, file)) as f:
            for line in f:
                # Lines starting with 'f' are treated as the header; blank lines carry no data.
                if line.startswith('f') or not line.strip():
                    continue
                tokens = line.split()  # any run of whitespace separates fields
                results[tokens[0]] = [float(tokens[2]), float(tokens[3])]
        all_results.append(results)

    for file, results in zip(result_files, all_results):
        # Signed log2 ratio of prediction to true cardinality; a zero prediction is
        # treated as 1 so that the logarithm is defined.
        log2_q_errors = [math.log2(pred if pred != 0 else 1) - math.log2(card)
                         for pred, card in results.values()]
        total_log2_mse_loss = sum(err ** 2 for err in log2_q_errors)
        total_log2_l1_loss = sum(abs(err) for err in log2_q_errors)
        print("Result file: {}".format(file))
        print("Evaluation result of Eval dataset: Total Loss= {:.4f}, Total L1 Loss= {:.4f}".format(total_log2_mse_loss, total_log2_l1_loss))
        if log2_q_errors:
            get_prediction_statistics(log2_q_errors)
        else:
            print("No queries found; skipping statistics.")
        print()
