# In [None]:
from BA3_loc import *
import pickle
from tqdm import tqdm
import os
import os.path as osp
import warnings
import numpy as np
import networkx as nx
import torch
from torch_geometric.data import Data
import matplotlib.pyplot as plt

# Silence every warning for the remainder of the session.
warnings.filterwarnings("ignore")

# Probability of drawing base type 1 ('tree') in create_data_list; the other
# two bases share the remaining probability equally.
global_b = 0.9  # "Hardly Biased Base": ~90% of graphs use base 1
global_b_slightly = 0.33333  # "Equally Distributed Base": roughly uniform over the 3 bases

# Output directories, one per dataset variant.
biased_dir = '../data/I_Hardly_Biased_Base'
labeleddata_dir = '../data/II_Labeled'
slightly_biased_data_dir = '../data/III_Equally_Distributed_Base'

for _out_dir in (biased_dir, labeleddata_dir, slightly_biased_data_dir):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir


def get_house(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):
    """Build a synthetic graph: a ``basis_type`` base with "house" motifs attached.

    Args:
        basis_type: base-graph family handed to ``synthetic_structsim.build_graph``
            (this file passes 'tree', 'ladder' or 'wheel').
        nb_shapes: number of house motifs to attach.
        width_basis: size parameter of the base graph.
        feature_generator: object exposing ``gen_node_features(G)``; a
            ``featgen.ConstFeatureGen(1)`` is used when omitted.
        m: not used by this function; kept for call compatibility.
        draw: if true, open a new matplotlib figure sized by ``figsize``.

    Returns:
        ``(G, role_id, name)`` with ``name`` formatted as
        ``"<basis_type>_<width_basis>_<nb_shapes>"``.
    """
    shape_spec = [["house"]] * nb_shapes
    if draw:
        plt.figure(figsize=figsize)

    graph, role_id, _ = synthetic_structsim.build_graph(
        width_basis,
        basis_type,
        shape_spec,
        start=0,
        rdm_basis_plugins=True,
    )
    # Perturbation ratio 0.00 is kept from the original pipeline (presumably a no-op).
    graph = perturb([graph], 0.00, id=role_id)[0]

    generator = featgen.ConstFeatureGen(1) if feature_generator is None else feature_generator
    generator.gen_node_features(graph)

    name = "_".join((basis_type, str(width_basis), str(nb_shapes)))
    return graph, role_id, name


def get_cycle(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):
    """Build a synthetic graph: a ``basis_type`` base with "dircycle" motifs attached.

    Args:
        basis_type: base-graph family handed to ``synthetic_structsim.build_graph``
            (this file passes 'tree', 'ladder' or 'wheel').
        nb_shapes: number of cycle motifs to attach.
        width_basis: size parameter of the base graph.
        feature_generator: object exposing ``gen_node_features(G)``; a
            ``featgen.ConstFeatureGen(1)`` is used when omitted.
        m: not used by this function; kept for call compatibility.
        draw: if true, open a new matplotlib figure sized by ``figsize``.

    Returns:
        ``(G, role_id, name)`` with ``name`` formatted as
        ``"<basis_type>_<width_basis>_<nb_shapes>"``.
    """
    shape_spec = [["dircycle"]] * nb_shapes
    if draw:
        plt.figure(figsize=figsize)

    graph, role_id, _ = synthetic_structsim.build_graph(
        width_basis,
        basis_type,
        shape_spec,
        start=0,
        rdm_basis_plugins=True,
    )
    # Perturbation ratio 0.00 is kept from the original pipeline (presumably a no-op).
    graph = perturb([graph], 0.00, id=role_id)[0]

    generator = featgen.ConstFeatureGen(1) if feature_generator is None else feature_generator
    generator.gen_node_features(graph)

    name = "_".join((basis_type, str(width_basis), str(nb_shapes)))
    return graph, role_id, name


def get_crane(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):
    """Build a synthetic graph: a ``basis_type`` base with "varcycle" (crane) motifs attached.

    Args:
        basis_type: base-graph family handed to ``synthetic_structsim.build_graph``
            (this file passes 'tree', 'ladder' or 'wheel').
        nb_shapes: number of crane motifs to attach.
        width_basis: size parameter of the base graph.
        feature_generator: object exposing ``gen_node_features(G)``; a
            ``featgen.ConstFeatureGen(1)`` is used when omitted.
        m: not used by this function; kept for call compatibility.
        draw: if true, open a new matplotlib figure sized by ``figsize``.

    Returns:
        ``(G, role_id, name)`` with ``name`` formatted as
        ``"<basis_type>_<width_basis>_<nb_shapes>"``.
    """
    shape_spec = [["varcycle"]] * nb_shapes
    if draw:
        plt.figure(figsize=figsize)

    graph, role_id, _ = synthetic_structsim.build_graph(
        width_basis,
        basis_type,
        shape_spec,
        start=0,
        rdm_basis_plugins=True,
    )
    # Perturbation ratio 0.00 is kept from the original pipeline (presumably a no-op).
    graph = perturb([graph], 0.00, id=role_id)[0]

    generator = featgen.ConstFeatureGen(1) if feature_generator is None else feature_generator
    generator.gen_node_features(graph)

    name = "_".join((basis_type, str(width_basis), str(nb_shapes)))
    return graph, role_id, name


# Generating rules for normal graphs (used in Train and Validation)
def graph_stats(base_num):
    """Return the base-graph family and width for normal-size graphs.

    Used by ``create_data_list`` when ``large`` is false (train/validation,
    and the equally-distributed test split).

    Args:
        base_num: 1 -> 'tree', 2 -> 'ladder', 3 -> 'wheel'.

    Returns:
        ``(base, width_basis)``. ``width_basis`` is drawn with
        ``np.random.choice`` from a one-element range, so it is effectively
        fixed (2, 3 and 6 respectively) but still advances NumPy's global RNG.

    Raises:
        ValueError: if ``base_num`` is not 1, 2 or 3.
    """
    rules = {
        1: ('tree', range(2, 3)),
        2: ('ladder', range(3, 4)),
        3: ('wheel', range(6, 7)),
    }
    try:
        base, widths = rules[base_num]
    except (KeyError, TypeError):
        raise ValueError(f"base_num must be 1, 2 or 3, got {base_num!r}") from None
    # Keep the random draw (instead of a constant) so the global RNG stream,
    # and therefore the generated datasets, match the previous behaviour.
    width_basis = np.random.choice(widths)
    return base, width_basis

# Generating rules for large graphs (used in Test)
def graph_stats_large(base_num):
    """Return the base-graph family and width for large (test-time) graphs.

    Used by ``create_data_list`` when ``large`` is true.

    Args:
        base_num: 1 -> 'tree', 2 -> 'ladder', 3 -> 'wheel'.

    Returns:
        ``(base, width_basis)``. ``width_basis`` is drawn with
        ``np.random.choice`` from a one-element range, so it is effectively
        fixed (3, 5 and 8 respectively) but still advances NumPy's global RNG.

    Raises:
        ValueError: if ``base_num`` is not 1, 2 or 3.
    """
    rules = {
        1: ('tree', range(3, 4)),
        2: ('ladder', range(5, 6)),
        3: ('wheel', range(8, 9)),
    }
    try:
        base, widths = rules[base_num]
    except (KeyError, TypeError):
        raise ValueError(f"base_num must be 1, 2 or 3, got {base_num!r}") from None
    # Keep the random draw (instead of a constant) so the global RNG stream,
    # and therefore the generated datasets, match the previous behaviour.
    width_basis = np.random.choice(widths)
    return base, width_basis


# For labeled data
def compute_node_labels(role_id):
    """Return a 0/1 ``int64`` label per node: 1 where the role id is > 0, else 0.

    ``role_id`` is first converted with ``np.array(...).astype(int)``, so any
    numeric array-like is accepted (floats are truncated toward zero).
    Presumably role 0 marks base-graph nodes and positive roles mark motif
    nodes — confirm against ``synthetic_structsim.build_graph``.
    """
    roles = np.array(role_id).astype(int)
    return np.greater(roles, 0).astype(np.int64)


def create_data_list(nb_graphs, get_graph_func, bias, large=False):
    """Generate ``nb_graphs`` synthetic graphs that all carry one motif type.

    Args:
        nb_graphs: number of graphs to generate.
        get_graph_func: ``get_house``, ``get_cycle`` or ``get_crane``; any
            other callable yields graph label -1.
        bias: probability of base type 1 ('tree'); bases 2 and 3 each get
            ``(1 - bias) / 2``. Ignored when ``large`` is true.
        large: if true, draw the base uniformly and take widths from
            ``graph_stats_large``; otherwise use ``graph_stats``.

    Returns:
        Six parallel lists with one entry per graph: edge indices (``(2, E)``
        int64 arrays), graph labels, ground truth from ``find_gd``, role-id
        arrays, spring-layout node positions (``(N, 2)`` arrays), and per-node
        0/1 labels from ``compute_node_labels``.
    """
    # The graph-level label depends only on which generator is used, so it is
    # resolved once rather than on every iteration.
    # NOTE(review): crane graphs share label 0 with cycle graphs, which makes
    # this a house-vs-rest task. If the intent is the three-class
    # Spurious-Motif setup (cycle=0, house=1, crane=2), crane should be 2 —
    # confirm before changing, as it alters every saved dataset.
    if get_graph_func is get_house:
        label = 1
    elif get_graph_func is get_cycle or get_graph_func is get_crane:
        label = 0
    else:
        label = -1

    edge_index_list, label_list, ground_truth_list, role_id_list, pos_list = [], [], [], [], []
    node_label_list = []
    edge_counts, node_counts = [], []

    for _ in tqdm(range(nb_graphs)):
        if large:
            # Test-time graphs: bases are drawn uniformly and ``bias`` is unused.
            base_num = np.random.choice([1, 2, 3])
        else:
            base_num = np.random.choice([1, 2, 3], p=[bias, (1 - bias) / 2, (1 - bias) / 2])
        base, width_basis = graph_stats_large(base_num) if large else graph_stats(base_num)

        G, role_id, _ = get_graph_func(basis_type=base, nb_shapes=1, width_basis=width_basis,
                                       feature_generator=None, m=3, draw=False)

        label_list.append(label)
        edge_counts.append(len(G.edges))
        node_counts.append(len(G.nodes))

        role_id = np.array(role_id)
        role_id_list.append(role_id)

        edge_index = np.array(list(G.edges), dtype=np.int64).T
        if edge_index.size == 0:
            # Edgeless graph: keep an explicit (2, 0) shape instead of (0,).
            edge_index = np.empty((2, 0), dtype=np.int64)
        edge_index_list.append(edge_index)

        if len(G.nodes) > 0:
            pos_dict = nx.spring_layout(G)
            pos_array = np.array([pos_dict[i] for i in sorted(G.nodes())])
        else:
            pos_array = np.empty((0, 2))
        pos_list.append(pos_array)

        ground_truth_list.append(find_gd(edge_index, role_id))
        node_label_list.append(compute_node_labels(role_id))

    # np.mean([]) returns nan (with a RuntimeWarning); report 0.00 instead.
    avg_nodes = np.mean(node_counts) if node_counts else 0.0
    avg_edges = np.mean(edge_counts) if edge_counts else 0.0
    print("#Graphs: %d  #Nodes(avg): %.2f  #Edges(avg): %.2f " % (len(ground_truth_list), avg_nodes, avg_edges))
    return edge_index_list, label_list, ground_truth_list, role_id_list, pos_list, node_label_list


def save_as_pt(edge_index_list, label_list, node_label_list, filename, unlabeled=False):
    """Convert per-graph arrays to PyG ``Data`` objects and ``torch.save`` the list.

    Args:
        edge_index_list: per-graph ``(2, E)`` integer arrays.
        label_list: per-graph integer class labels, stored as ``y`` with shape ``[1]``.
        node_label_list: per-graph node-label arrays (as produced by
            ``compute_node_labels``); their length sets each graph's node count.
        filename: output path for the saved list.
        unlabeled: if true, every node feature is 0.0; otherwise each node's
            single feature is its node label cast to float.
    """
    graphs = []
    for idx, edges in enumerate(edge_index_list):
        node_labels = node_label_list[idx]

        if edges.size == 0:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        else:
            edge_index = torch.tensor(edges, dtype=torch.long)

        n_nodes = 0 if node_labels.size == 0 else len(node_labels)

        if unlabeled:
            features = torch.zeros((n_nodes, 1), dtype=torch.float32)
        else:
            features = torch.tensor(node_labels, dtype=torch.float32).view(n_nodes, 1)

        target = torch.tensor([label_list[idx]], dtype=torch.long)
        graphs.append(Data(x=features, edge_index=edge_index, y=target))

    torch.save(graphs, filename)
    print(f"Saved {len(graphs)} graphs to {filename}")


# Dataset splits to generate, in order. Each entry is
#   (split name, base bias, large graphs?, [(output dir, unlabeled?), ...]).
# The "Hardly Biased Base" splits are written twice from the same graphs: once
# with node labels as features (II_Labeled) and once with all-zero features
# (I_Hardly_Biased_Base).
# With large=True, create_data_list draws bases uniformly and ignores the bias
# argument, so the hardly-biased test split is base-balanced and uses larger
# bases (graph_stats_large).
_SPLITS = [
    ('train', global_b, False, [(labeleddata_dir, False), (biased_dir, True)]),
    ('val', global_b, False, [(labeleddata_dir, False), (biased_dir, True)]),
    ('test', global_b_slightly, True, [(labeleddata_dir, False), (biased_dir, True)]),
    ('train', global_b_slightly, False, [(slightly_biased_data_dir, True)]),
    ('val', global_b_slightly, False, [(slightly_biased_data_dir, True)]),
    ('test', global_b_slightly, False, [(slightly_biased_data_dir, True)]),
]

for _split, _bias, _large, _targets in _SPLITS:
    # One call per motif class, always cycle -> house -> crane (this order fixes
    # the global RNG stream). The cycle lists then receive the house and crane
    # graphs appended. The same module-level names as before are bound, so later
    # cells can still inspect the last split generated.
    edge_index_list, label_list, ground_truth_list, role_id_list, pos_list, node_label_list = create_data_list(1000, get_cycle, _bias, large=_large)
    edge_index_list2, label_list2, ground_truth_list2, role_id_list2, pos_list2, node_label_list2 = create_data_list(1000, get_house, _bias, large=_large)
    edge_index_list3, label_list3, ground_truth_list3, role_id_list3, pos_list3, node_label_list3 = create_data_list(1000, get_crane, _bias, large=_large)

    edge_index_list.extend(edge_index_list2)
    edge_index_list.extend(edge_index_list3)
    label_list.extend(label_list2)
    label_list.extend(label_list3)
    node_label_list.extend(node_label_list2)
    node_label_list.extend(node_label_list3)

    for _out_dir, _unlabeled in _targets:
        save_as_pt(edge_index_list, label_list, node_label_list, osp.join(_out_dir, f'{_split}.pt'), unlabeled=_unlabeled)

print("\nData generation completed!")


# In [None]:
import torch

# Sanity check: load one generated split and print the first few graphs.
# The path must match biased_dir in the generator cell ('../data/I_Hardly_Biased_Base').
unlabeled_data_path = '../data/I_Hardly_Biased_Base/test.pt'
# weights_only=False is presumably required because the file holds
# torch_geometric Data objects, not plain tensors. It unpickles arbitrary
# objects, so only load files you generated yourself.
data_list = torch.load(unlabeled_data_path, weights_only=False)

for i, data in enumerate(data_list[:5]):
    print(f"Graph {i}:")
    print(data)  # Data(x=[N, 1], edge_index=[2, E], y=[1])
    print(f"x shape: {data.x.shape}")
    print(f"edge_index shape: {data.edge_index.shape}")
    print(f"y shape: {data.y.shape}")
    print("---")
