In [None]:
from BA3_loc import *
import pickle
from tqdm import tqdm
import os
import os.path as osp
import warnings
import numpy as np
import networkx as nx
import torch
from torch_geometric.data import Data
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")

# Probability of drawing base id 1 ("tree") in create_data_list; the remaining
# probability mass is split evenly between base 2 ("ladder") and base 3 ("wheel").
global_b = 0.9                 # Heavily biased base
global_b_slightly = 0.33333    # Equally distributed base

# One output directory per dataset variant; all are created up front.
biased_dir = '../data/I_Heavily_Biased_Base'
labeleddata_dir = '../data/II_Labeled'
slightly_biased_data_dir = '../data/III_Equally_Distributed_Base'
biasly_connected_dir = '../data/IV_Biasly_Connected'

for _out_dir in (biased_dir, labeleddata_dir, slightly_biased_data_dir, biasly_connected_dir):
    os.makedirs(_out_dir, exist_ok=True)


def _build_motif_graph(shape_name, basis_type, nb_shapes, width_basis, feature_generator, draw):
    """Build a ``basis_type`` base graph with ``nb_shapes`` copies of ``shape_name`` attached.

    Shared implementation of get_house / get_cycle / get_crane, which differ
    only in the motif name passed to ``synthetic_structsim.build_graph``.

    Returns:
        (G, role_id, name), where ``name`` is "<basis_type>_<width_basis>_<nb_shapes>".
    """
    list_shapes = [[shape_name]] * nb_shapes

    if draw:
        # NOTE(review): ``figsize`` is not defined in this file; presumably it
        # comes from ``from BA3_loc import *`` -- confirm.
        plt.figure(figsize=figsize)

    G, role_id, _ = synthetic_structsim.build_graph(
        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True
    )
    # Perturbation ratio 0.00 -- presumably a no-op perturbation; verify in BA3_loc.
    G = perturb([G], 0.00, id=role_id)[0]

    if feature_generator is None:
        feature_generator = featgen.ConstFeatureGen(1)
    feature_generator.gen_node_features(G)

    name = f"{basis_type}_{width_basis}_{nb_shapes}"
    return G, role_id, name


def get_house(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):
    """Synthetic graph: a ``basis_type`` base with HOUSE-shaped motifs attached.

    ``m`` is accepted for interface compatibility and is not used.
    Returns (G, role_id, name).
    """
    return _build_motif_graph("house", basis_type, nb_shapes, width_basis, feature_generator, draw)


def get_cycle(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):
    """Synthetic graph: a ``basis_type`` base with "dircycle" motifs attached.

    ``m`` is accepted for interface compatibility and is not used.
    Returns (G, role_id, name).
    """
    return _build_motif_graph("dircycle", basis_type, nb_shapes, width_basis, feature_generator, draw)


def get_crane(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):
    """Synthetic graph: a ``basis_type`` base with "varcycle" (crane) motifs attached.

    ``m`` is accepted for interface compatibility and is not used.
    Returns (G, role_id, name).
    """
    return _build_motif_graph("varcycle", basis_type, nb_shapes, width_basis, feature_generator, draw)


# Base id -> base graph type, shared by the normal and large size rules.
_BASE_NAMES = {1: 'tree', 2: 'ladder', 3: 'wheel'}
# Half-open width ranges per base id; each currently holds a single width.
_WIDTH_RANGES = {1: range(2, 3), 2: range(3, 4), 3: range(6, 7)}
_WIDTH_RANGES_LARGE = {1: range(3, 4), 2: range(5, 6), 3: range(8, 9)}


def _pick_base(base_num, width_ranges):
    """Return (base name, width) for ``base_num``, drawing the width from ``width_ranges``.

    Raises:
        ValueError: if ``base_num`` is not 1, 2 or 3.
    """
    try:
        base = _BASE_NAMES[base_num]
    except KeyError:
        raise ValueError(f"base_num must be 1, 2 or 3, got {base_num!r}") from None
    # np.random.choice is kept (rather than the single value) so the global
    # NumPy random stream is consumed exactly as before.
    width_basis = np.random.choice(width_ranges[base_num])
    return base, width_basis


# Generating rules for normal graphs (used in Train and Validation)
def graph_stats(base_num):
    """Return (base type, width) for normal-sized graphs; raise ValueError for an unknown id."""
    return _pick_base(base_num, _WIDTH_RANGES)


# Generating rules for large graphs (used in Test)
def graph_stats_large(base_num):
    """Return (base type, width) for larger test graphs; raise ValueError for an unknown id."""
    return _pick_base(base_num, _WIDTH_RANGES_LARGE)


# For labeled data
def compute_node_labels(role_id):
    """Return int64 binary node labels: 1 where the (int-cast) role id is > 0, else 0.

    NOTE(review): role 0 presumably marks base-graph nodes and positive roles
    mark motif nodes -- confirm against BA3_loc.
    """
    roles = np.array(role_id).astype(int)
    return np.greater(roles, 0).astype(np.int64)


def create_data_list(nb_graphs, get_graph_func, bias, large=False):
    """Generate ``nb_graphs`` single-motif graphs with ``get_graph_func``.

    The base id of every graph is drawn from [1, 2, 3] (tree, ladder, wheel):
    with probabilities [bias, (1 - bias) / 2, (1 - bias) / 2] and sizes from
    ``graph_stats`` when ``large`` is False, or uniformly with sizes from
    ``graph_stats_large`` when ``large`` is True (``bias`` is then unused).

    Returns six parallel lists:
        edge indices (2 x E int64 arrays), graph labels (1 for get_house,
        0 for get_cycle / get_crane, -1 for any other generator), ``find_gd``
        results, role-id arrays, spring-layout node positions (N x 2 arrays),
        and per-node labels from ``compute_node_labels``.
    """
    # The graph label depends only on the generator, so resolve it once.
    if get_graph_func == get_house:
        graph_label = 1
    elif get_graph_func == get_cycle or get_graph_func == get_crane:
        graph_label = 0
    else:
        graph_label = -1

    edge_indices, graph_labels, ground_truths, role_ids, positions = [], [], [], [], []
    per_node_labels = []
    edge_counts, node_counts = [], []

    for _ in tqdm(range(nb_graphs)):
        # The order of the random draws below (base id, width, graph, layout)
        # is preserved because they all share the global NumPy random state.
        if large:
            base_num = np.random.choice([1, 2, 3])
            base, width_basis = graph_stats_large(base_num)
        else:
            others = (1 - bias) / 2
            base_num = np.random.choice([1, 2, 3], p=[bias, others, others])
            base, width_basis = graph_stats(base_num)

        G, role_id, _ = get_graph_func(basis_type=base, nb_shapes=1, width_basis=width_basis,
                                       feature_generator=None, m=3, draw=False)

        graph_labels.append(graph_label)
        edge_counts.append(len(G.edges))
        node_counts.append(len(G.nodes))

        roles = np.array(role_id)
        role_ids.append(roles)

        edge_index = np.array(list(G.edges), dtype=np.int64).T
        if edge_index.size == 0:
            # A graph without edges still gets a well-formed (2, 0) edge index.
            edge_index = np.empty((2, 0), dtype=np.int64)
        edge_indices.append(edge_index)

        if len(G.nodes) == 0:
            positions.append(np.empty((0, 2)))
        else:
            layout = nx.spring_layout(G)
            positions.append(np.array([layout[node] for node in sorted(G.nodes())]))

        ground_truths.append(find_gd(edge_index, roles))
        per_node_labels.append(compute_node_labels(roles))

    print("#Graphs: %d  #Nodes(avg): %.2f  #Edges(avg): %.2f " % (len(ground_truths), np.mean(node_counts), np.mean(edge_counts)))
    return edge_indices, graph_labels, ground_truths, role_ids, positions, per_node_labels


def save_as_pt(edge_index_list, label_list, node_label_list, filename, unlabeled=False):
    """Convert parallel per-graph lists into ``Data`` objects and ``torch.save`` the list.

    Args:
        edge_index_list: per-graph edge-index arrays; an empty one becomes a
            (2, 0) long tensor.
        label_list: per-graph integer labels, stored as ``y`` with shape [1].
        node_label_list: per-graph node-label arrays; the length of each gives
            that graph's node count.
        filename: output path passed to ``torch.save``.
        unlabeled: if True, node features ``x`` are all zeros (num_nodes x 1);
            otherwise ``x`` holds the node labels as float32 (num_nodes x 1).

    Raises:
        ValueError: if the three per-graph lists differ in length.
    """
    n_graphs = len(edge_index_list)
    if len(label_list) != n_graphs or len(node_label_list) != n_graphs:
        # Mismatched lists would otherwise be silently truncated or fail
        # half-way with an IndexError.
        raise ValueError(
            f"length mismatch: {n_graphs} edge indices, {len(label_list)} labels, "
            f"{len(node_label_list)} node-label arrays"
        )

    data_list = []
    for ei, graph_label, node_labels in zip(edge_index_list, label_list, node_label_list):
        if ei.size != 0:
            edge_index = torch.tensor(ei, dtype=torch.long)
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)

        num_nodes = len(node_labels)
        if unlabeled:
            x = torch.zeros((num_nodes, 1), dtype=torch.float32)
        else:
            x = torch.tensor(node_labels, dtype=torch.float32).view(num_nodes, 1)

        y = torch.tensor([graph_label], dtype=torch.long)
        data_list.append(Data(x=x, edge_index=edge_index, y=y))

    torch.save(data_list, filename)
    print(f"Saved {len(data_list)} graphs to {filename}")


def _generate_split(bias, large=False, nb_graphs=1000):
    """Generate ``nb_graphs`` cycle, house and crane graphs (in that order) and merge them.

    All six lists returned by ``create_data_list`` are concatenated across the
    three motifs, so the returned lists stay index-aligned (3 * nb_graphs
    entries each). The motif order is fixed because ``create_data_list`` draws
    from the global ``np.random`` state, so reordering would change the data.
    """
    merged = ([], [], [], [], [], [])
    for get_graph_func in (get_cycle, get_house, get_crane):
        part = create_data_list(nb_graphs, get_graph_func, bias, large=large)
        for acc, chunk in zip(merged, part):
            acc.extend(chunk)
    return merged


# I (zero node features) and II (node-label features) share the same graphs:
# a heavily biased base for train/val; for test, larger graphs whose base is
# drawn uniformly (create_data_list ignores ``bias`` when large=True).
for _split, _bias, _large in (
    ("train", global_b, False),
    ("val", global_b, False),
    ("test", 0.333, True),
):
    (edge_index_list, label_list, ground_truth_list,
     role_id_list, pos_list, node_label_list) = _generate_split(_bias, large=_large)
    save_as_pt(edge_index_list, label_list, node_label_list,
               osp.join(labeleddata_dir, f"{_split}.pt"), unlabeled=False)
    save_as_pt(edge_index_list, label_list, node_label_list,
               osp.join(biased_dir, f"{_split}.pt"), unlabeled=True)

# III: equally distributed base for every split (zero node features).
# NOTE(review): unlike I/II and IV, this test split is not generated with
# large=True, so its graphs are the same size as train/val (the recorded
# output shows ~12.3 nodes on average for all three splits). Confirm intended.
for _split, _bias in (
    ("train", global_b_slightly),
    ("val", global_b_slightly),
    ("test", 0.333),
):
    (edge_index_list, label_list, ground_truth_list,
     role_id_list, pos_list, node_label_list) = _generate_split(_bias)
    save_as_pt(edge_index_list, label_list, node_label_list,
               osp.join(slightly_biased_data_dir, f"{_split}.pt"), unlabeled=True)


# Biasly Connected: each motif is attached preferentially to its own base type.
# Probabilities are over base ids [1, 2, 3] = [tree, ladder, wheel].
_MOTIF_BASE_PROBS = {
    'house': [0.50, 0.25, 0.25],  # mostly tree
    'cycle': [0.25, 0.50, 0.25],  # mostly ladder
    'crane': [0.25, 0.25, 0.50],  # mostly wheel
}


def _generate_motif_graphs(nb_graphs, get_graph_func, probs, large):
    """Generate ``nb_graphs`` graphs, drawing each base id from [1, 2, 3] with ``probs``.

    Sizes come from ``graph_stats_large`` when ``large`` is True, else from
    ``graph_stats``. Returns the same six parallel lists as ``create_data_list``
    (graph label: 1 for get_house, 0 for any other generator).
    """
    edge_index_list, label_list, ground_truth_list, role_id_list, pos_list = [], [], [], [], []
    node_label_list = []
    e_mean, n_mean = [], []

    # The label depends only on the generator, so resolve it once.
    label = 1 if get_graph_func == get_house else 0
    stats = graph_stats_large if large else graph_stats

    for _ in tqdm(range(nb_graphs)):
        # Keep p=probs even when uniform: passing p changes how NumPy samples,
        # and thus the generated data.
        base_num = np.random.choice([1, 2, 3], p=probs)
        base, width_basis = stats(base_num)

        G, role_id, _ = get_graph_func(basis_type=base, nb_shapes=1,
                                       width_basis=width_basis, feature_generator=None,
                                       m=3, draw=False)

        label_list.append(label)
        e_mean.append(len(G.edges))
        n_mean.append(len(G.nodes))

        role_id = np.array(role_id)
        role_id_list.append(role_id)

        edge_index = np.array(list(G.edges), dtype=np.int64).T
        if edge_index.size == 0:
            edge_index = np.empty((2, 0), dtype=np.int64)
        edge_index_list.append(edge_index)

        if len(G.nodes) > 0:
            pos_dict = nx.spring_layout(G)
            pos_array = np.array([pos_dict[i] for i in sorted(G.nodes())])
        else:
            pos_array = np.empty((0, 2))
        pos_list.append(pos_array)

        ground_truth_list.append(find_gd(edge_index, role_id))
        node_label_list.append(compute_node_labels(role_id))

    print("#Graphs: %d  #Nodes(avg): %.2f  #Edges(avg): %.2f " % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))
    return edge_index_list, label_list, ground_truth_list, role_id_list, pos_list, node_label_list


def create_data_list_motif_specific(nb_graphs, get_graph_func, motif_name, large=False):
    """Generate train/val graphs whose base type is biased towards ``motif_name``'s preferred base.

    Raises:
        KeyError: if ``motif_name`` is not 'house', 'cycle' or 'crane'.
    """
    probs = _MOTIF_BASE_PROBS[motif_name]
    return _generate_motif_graphs(nb_graphs, get_graph_func, probs, large)


def create_data_list_test_equal(nb_graphs, get_graph_func, motif_name, large=True):
    """Generate test graphs whose base type is drawn with equal probability (1/3 each).

    ``motif_name`` is accepted for signature parity with
    ``create_data_list_motif_specific`` and is unused. ``large`` selects
    ``graph_stats_large`` (default) or ``graph_stats`` for the base size.
    """
    probs = [1/3, 1/3, 1/3]
    return _generate_motif_graphs(nb_graphs, get_graph_func, probs, large)


def _merge_motif_parts(*parts):
    """Concatenate the edge indices, graph labels and node labels of several generator results.

    Each ``part`` is the 6-tuple returned by create_data_list_motif_specific /
    create_data_list_test_equal; only items 0 (edge indices), 1 (graph labels)
    and 5 (node labels) are saved, so only those are merged, in ``parts`` order.
    """
    edge_indices = [ei for part in parts for ei in part[0]]
    graph_labels = [y for part in parts for y in part[1]]
    node_labels = [nl for part in parts for nl in part[5]]
    return edge_indices, graph_labels, node_labels


# TRAIN: Biasly Connected (each motif prefers its own base type)
house_train = create_data_list_motif_specific(1000, get_house, 'house', large=False)
cycle_train = create_data_list_motif_specific(1000, get_cycle, 'cycle', large=False)
crane_train = create_data_list_motif_specific(1000, get_crane, 'crane', large=False)
edge_index_list_train, label_list_train, node_label_list_train = _merge_motif_parts(
    house_train, cycle_train, crane_train)
save_as_pt(edge_index_list_train, label_list_train, node_label_list_train,
           osp.join(biasly_connected_dir, 'train.pt'), unlabeled=True)

# VALIDATION: Biasly Connected (same biased connection rule as train)
house_val = create_data_list_motif_specific(1000, get_house, 'house', large=False)
cycle_val = create_data_list_motif_specific(1000, get_cycle, 'cycle', large=False)
crane_val = create_data_list_motif_specific(1000, get_crane, 'crane', large=False)
edge_index_list_val, label_list_val, node_label_list_val = _merge_motif_parts(
    house_val, cycle_val, crane_val)
save_as_pt(edge_index_list_val, label_list_val, node_label_list_val,
           osp.join(biasly_connected_dir, 'val.pt'), unlabeled=True)

# TEST: Biasly Connected -> equally connected, larger bases
house_test = create_data_list_test_equal(1000, get_house, 'house', large=True)
cycle_test = create_data_list_test_equal(1000, get_cycle, 'cycle', large=True)
crane_test = create_data_list_test_equal(1000, get_crane, 'crane', large=True)
edge_index_list_test, label_list_test, node_label_list_test = _merge_motif_parts(
    house_test, cycle_test, crane_test)
save_as_pt(edge_index_list_test, label_list_test, node_label_list_test,
           osp.join(biasly_connected_dir, 'test.pt'), unlabeled=True)

print("\nData generation completed!")


  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:26<00:00, 374.30it/s]


#Graphs: 10000  #Nodes(avg): 14.56  #Edges(avg): 14.90 


100%|██████████| 10000/10000 [00:26<00:00, 371.98it/s]


#Graphs: 10000  #Nodes(avg): 14.60  #Edges(avg): 15.96 


100%|██████████| 10000/10000 [00:26<00:00, 370.39it/s]


#Graphs: 10000  #Nodes(avg): 14.60  #Edges(avg): 16.96 
Saved 30000 graphs to ../BIGdata_09/II_Labeled/train.pt
Saved 30000 graphs to ../BIGdata_09/I_Hardly_Biased_Base/train.pt


100%|██████████| 10000/10000 [00:27<00:00, 369.32it/s]


#Graphs: 10000  #Nodes(avg): 14.60  #Edges(avg): 14.96 


100%|██████████| 10000/10000 [00:26<00:00, 370.40it/s]


#Graphs: 10000  #Nodes(avg): 14.61  #Edges(avg): 15.98 


100%|██████████| 10000/10000 [00:26<00:00, 370.54it/s]


#Graphs: 10000  #Nodes(avg): 14.57  #Edges(avg): 16.94 
Saved 30000 graphs to ../BIGdata_09/II_Labeled/val.pt
Saved 30000 graphs to ../BIGdata_09/I_Hardly_Biased_Base/val.pt


100%|██████████| 10000/10000 [00:34<00:00, 290.71it/s]


#Graphs: 10000  #Nodes(avg): 20.21  #Edges(avg): 23.87 


100%|██████████| 10000/10000 [00:34<00:00, 289.87it/s]


#Graphs: 10000  #Nodes(avg): 20.24  #Edges(avg): 24.86 


100%|██████████| 10000/10000 [00:34<00:00, 290.53it/s]


#Graphs: 10000  #Nodes(avg): 20.14  #Edges(avg): 25.79 
Saved 30000 graphs to ../BIGdata_09/II_Labeled/test.pt
Saved 30000 graphs to ../BIGdata_09/I_Hardly_Biased_Base/test.pt


100%|██████████| 10000/10000 [00:25<00:00, 397.46it/s]


#Graphs: 10000  #Nodes(avg): 12.33  #Edges(avg): 14.65 


100%|██████████| 10000/10000 [00:25<00:00, 399.84it/s]


#Graphs: 10000  #Nodes(avg): 12.34  #Edges(avg): 15.69 


100%|██████████| 10000/10000 [00:25<00:00, 399.82it/s]


#Graphs: 10000  #Nodes(avg): 12.35  #Edges(avg): 16.65 
Saved 30000 graphs to ../BIGdata_09/III_Equally_Distributed_Base/train.pt


100%|██████████| 10000/10000 [00:25<00:00, 399.74it/s]


#Graphs: 10000  #Nodes(avg): 12.34  #Edges(avg): 14.68 


100%|██████████| 10000/10000 [00:25<00:00, 399.28it/s]


#Graphs: 10000  #Nodes(avg): 12.36  #Edges(avg): 15.67 


100%|██████████| 10000/10000 [00:25<00:00, 398.52it/s]


#Graphs: 10000  #Nodes(avg): 12.32  #Edges(avg): 16.63 
Saved 30000 graphs to ../BIGdata_09/III_Equally_Distributed_Base/val.pt


100%|██████████| 10000/10000 [00:25<00:00, 395.29it/s]


#Graphs: 10000  #Nodes(avg): 12.34  #Edges(avg): 14.67 


100%|██████████| 10000/10000 [00:25<00:00, 393.32it/s]


#Graphs: 10000  #Nodes(avg): 12.34  #Edges(avg): 15.66 


100%|██████████| 10000/10000 [00:25<00:00, 397.85it/s]


#Graphs: 10000  #Nodes(avg): 12.36  #Edges(avg): 16.66 
Saved 30000 graphs to ../BIGdata_09/III_Equally_Distributed_Base/test.pt


100%|██████████| 10000/10000 [00:27<00:00, 370.18it/s]


#Graphs: 10000  #Nodes(avg): 14.61  #Edges(avg): 15.96 


100%|██████████| 10000/10000 [00:24<00:00, 414.21it/s]


#Graphs: 10000  #Nodes(avg): 11.21  #Edges(avg): 13.25 


100%|██████████| 10000/10000 [00:24<00:00, 410.47it/s]


#Graphs: 10000  #Nodes(avg): 11.20  #Edges(avg): 17.80 
Saved 30000 graphs to ../BIGdata_09/IV_Biasly_Connected/train.pt


100%|██████████| 10000/10000 [00:27<00:00, 369.04it/s]


#Graphs: 10000  #Nodes(avg): 14.56  #Edges(avg): 15.91 


100%|██████████| 10000/10000 [00:24<00:00, 408.83it/s]


#Graphs: 10000  #Nodes(avg): 11.19  #Edges(avg): 13.25 


100%|██████████| 10000/10000 [00:24<00:00, 413.11it/s]


#Graphs: 10000  #Nodes(avg): 11.20  #Edges(avg): 17.80 
Saved 30000 graphs to ../BIGdata_09/IV_Biasly_Connected/val.pt


100%|██████████| 10000/10000 [00:35<00:00, 278.38it/s]


#Graphs: 10000  #Nodes(avg): 20.23  #Edges(avg): 24.90 


100%|██████████| 10000/10000 [00:34<00:00, 286.05it/s]


#Graphs: 10000  #Nodes(avg): 20.07  #Edges(avg): 23.74 


100%|██████████| 10000/10000 [00:35<00:00, 284.01it/s]


#Graphs: 10000  #Nodes(avg): 20.25  #Edges(avg): 25.89 
Saved 30000 graphs to ../BIGdata_09/IV_Biasly_Connected/test.pt

Data generation completed!
