In [None]:
import random
import itertools
from pathlib import Path

import numpy as np
import torch as th
import dgl
from dgl.data import AIFBDataset, MUTAGDataset, BGSDataset

In [None]:
# Notebook-wide configuration for the two partitioning strategies defined below
# (split by random edges / split by random edge types).
# Fraction of the original training nodes that is moved to the validation set.
nc_val_size = 0.2
# Numbers of clients to generate splits for; both split functions assert num_clients > 2.
settings = [3, 5, 10]  # must be larger than 2
# Transform used to add reverse edges before saving. Per the DGL AddReverse docs,
# copy_edata=False leaves reverse edges without copied edge features and
# sym_new_etype=True creates a new edge type even when source and destination node types match.
transform = dgl.AddReverse(copy_edata=False, sym_new_etype=True)
# Root directory under which one sub-directory per dataset is written.
save_path_prefix = Path("./data/")
# Seed Python's `random` module (the only RNG the split code uses) and snapshot its state
# so later cells can restore it with random.setstate() and stay reproducible on re-run.
random.seed(12345)
random_state = random.getstate()

In [None]:
# 1. randomly sample triplets (edges)
# Randomly and evenly distribute all edges into (num_clients + 2) groups:
# - each client exclusively owns one group of edges;
# - every edge of one further group is shared by k randomly chosen clients (1 < k < num_clients);
# - the final group (which also absorbs the division remainder) is shared by all clients.
def split_by_random_edges(g, num_clients):
    """Partition the edges of heterograph ``g`` into ``num_clients`` overlapping subgraphs.

    All edges (across every canonical edge type) are shuffled with the global
    ``random`` module and cut into ``num_clients + 2`` segments of equal size:

    * segment ``i`` (``i < num_clients``) belongs exclusively to client ``i``;
    * every edge of segment ``num_clients`` is given to ``k`` randomly chosen
      clients, with ``k`` drawn per edge from ``[2, num_clients - 1]``;
    * the last segment, together with the division remainder, goes to all clients.

    Args:
        g: Source heterograph; must expose ``canonical_etypes``, ``num_edges``,
            ``idtype`` and ``device``.
        num_clients: Number of client subgraphs to build; must be greater than 2.

    Returns:
        A list of ``num_clients`` graphs. Each is the edge-induced subgraph of
        ``g`` on that client's edges, restricted to the edge types that still
        contain at least one edge.

    Raises:
        AssertionError: If ``num_clients`` is not greater than 2.
    """
    assert num_clients > 2

    # Flatten the heterograph into (canonical_etype, edge_id) pairs.
    edges = [
        (cetype, eid)
        for cetype in g.canonical_etypes
        for eid in range(g.num_edges(etype=cetype))
    ]

    random.shuffle(edges)
    segment_size = len(edges) // (num_clients + 2)
    # The tail slice is open-ended so the division remainder is shared by everyone.
    shared_by_all = edges[(num_clients + 1) * segment_size:]
    all_client_edges = [
        edges[i * segment_size: (i + 1) * segment_size] + shared_by_all
        for i in range(num_clients)
    ]
    for edge in edges[num_clients * segment_size: (num_clients + 1) * segment_size]:
        k = random.randint(2, num_clients - 1)
        for client in random.sample(range(num_clients), k):
            all_client_edges[client].append(edge)

    g_list = []
    for client_edges in all_client_edges:
        # groupby only merges adjacent items, so the pairs must be sorted by edge type first.
        client_edges.sort()
        # dgl.edge_subgraph requires ID tensors with the graph's idtype and device;
        # a bare th.tensor(list) would always be int64 on CPU.
        eid_dict = {
            cetype: th.tensor([eid for _, eid in group], dtype=g.idtype, device=g.device)
            for cetype, group in itertools.groupby(client_edges, key=lambda edge: edge[0])
        }
        temp_g = dgl.edge_subgraph(g, eid_dict)
        # Drop edge types that ended up empty for this client.
        kept_cetypes = [cetype for cetype in temp_g.canonical_etypes if temp_g.num_edges(cetype) > 0]
        g_list.append(dgl.edge_type_subgraph(temp_g, kept_cetypes))

    return g_list

# 2. randomly sample edge types
# Randomly and evenly distribute all edge types into (num_clients + 2) groups:
# - each client exclusively owns one group of edge types;
# - every edge type of one further group is shared by k randomly chosen clients (1 < k < num_clients);
# - the final group (which also absorbs the division remainder) is shared by all clients.
def split_by_random_etypes(g, num_clients):
    """Partition the canonical edge types of ``g`` into ``num_clients`` overlapping subgraphs.

    The edge types are shuffled with the global ``random`` module. When there
    are at least ``num_clients + 2`` of them they are cut into
    ``num_clients + 2`` equal segments: one exclusive segment per client, one
    segment whose edge types are each handed to ``k`` random clients
    (``2 <= k <= num_clients - 1``), and a final segment (plus remainder) given
    to every client. With exactly ``num_clients + 1`` edge types, each client
    gets one exclusive edge type plus the last one, which all clients share.

    Args:
        g: Source heterograph exposing ``canonical_etypes``.
        num_clients: Number of client subgraphs; must be greater than 2 and
            smaller than the number of canonical edge types.

    Returns:
        A list of ``num_clients`` graphs, each the compacted edge-type subgraph
        of ``g`` on that client's (sorted) edge types.

    Raises:
        AssertionError: If either size precondition is violated.
    """
    assert num_clients > 2
    total_etypes = len(g.canonical_etypes)
    assert total_etypes > num_clients

    shuffled = list(g.canonical_etypes)  # work on a copy; g's own list is left untouched
    random.shuffle(shuffled)

    if total_etypes <= num_clients + 1:
        # Exactly num_clients + 1 edge types: one private type each, last one shared.
        shared_etype = shuffled[-1]
        assignments = [[shuffled[idx], shared_etype] for idx in range(num_clients)]
    else:
        seg = total_etypes // (num_clients + 2)
        common_tail = shuffled[(num_clients + 1) * seg:]
        assignments = [
            shuffled[idx * seg: (idx + 1) * seg] + common_tail
            for idx in range(num_clients)
        ]
        partially_shared = shuffled[num_clients * seg: (num_clients + 1) * seg]
        for cetype in partially_shared:
            how_many = random.randint(2, num_clients - 1)
            owners = random.sample(range(num_clients), how_many)
            for owner in owners:
                assignments[owner].append(cetype)

    subgraphs = []
    for client_cetypes in assignments:
        ordered = sorted(client_cetypes)
        etype_subgraph = dgl.edge_type_subgraph(g, ordered)
        subgraphs.append(dgl.compact_graphs(etype_subgraph))

    return subgraphs

In [None]:
# node classification
# AIFBDataset, MUTAGDataset, BGSDataset
# Restore the RNG snapshot taken right after seeding so that re-running this cell
# yields the same train/val split and the same client partitions.
random.setstate(random_state)
nc_datasets = [AIFBDataset, MUTAGDataset, BGSDataset]

for DatasetClass in nc_datasets:
    # load dataset (reverse edges are added later via `transform`, not by the loader)
    dataset = DatasetClass(insert_reverse=False, force_reload=True)
    g = dataset[0]
    target_ntype = dataset.predict_category
    num_classes = dataset.num_classes
    train_mask = g.nodes[target_ntype].data["train_mask"]
    # train/val split: move a random `nc_val_size` fraction of the training nodes to validation
    train_idx = train_mask.nonzero().flatten()
    index = list(range(len(train_idx)))
    random.shuffle(index)
    val_idx = train_idx[index[:round(len(index) * nc_val_size)]]
    val_mask = th.zeros_like(train_mask)
    val_mask[val_idx] = True
    train_mask[val_idx] = False
    # Write both masks back explicitly: the in-place edit above only reaches the graph if the
    # fetched tensor aliases the graph's storage, which should not be relied upon.
    g.nodes[target_ntype].data["train_mask"] = train_mask
    g.nodes[target_ntype].data["val_mask"] = val_mask

    # centralized setting
    g_with_rev = transform(g)  # add reverse edges
    # save graph with num_classes
    save_path = str(save_path_prefix / dataset.name / f"{dataset.name}_centralized_1.bin")
    num_classes_info = {"num_classes": th.tensor([num_classes])}
    dgl.save_graphs(save_path, [g_with_rev], num_classes_info)

    for num_clients in settings:
        # split the graph (1. random edges, 2. random etypes)
        g_list_edges = split_by_random_edges(g, num_clients)
        g_list_etypes = split_by_random_etypes(g, num_clients)
        # make sure each subgraph has train/val/test nodes
        # NOTE(review): only the random-edges split is checked here; the random-etypes
        # subgraphs are saved without this validation - confirm that is intended.
        for sub_g in g_list_edges:
            assert "train_mask" in sub_g.ndata
            assert "val_mask" in sub_g.ndata
            assert "test_mask" in sub_g.ndata
            assert sub_g.nodes[target_ntype].data["train_mask"].any()
            assert sub_g.nodes[target_ntype].data["val_mask"].any()
            assert sub_g.nodes[target_ntype].data["test_mask"].any()
        # add reverse edges
        g_list_edges = [transform(sub_g) for sub_g in g_list_edges]
        g_list_etypes = [transform(sub_g) for sub_g in g_list_etypes]
        # save graphs with num_classes (one entry per client graph)
        save_path = str(save_path_prefix / dataset.name / f"{dataset.name}_random-edges_{num_clients}.bin")
        num_classes_info = {"num_classes": th.tensor([num_classes] * num_clients)}
        dgl.save_graphs(save_path, g_list_edges, num_classes_info)
        save_path = str(save_path_prefix / dataset.name / f"{dataset.name}_random-etypes_{num_clients}.bin")
        dgl.save_graphs(save_path, g_list_etypes, num_classes_info)

In [None]:
def nc_dataset_statistics(dataset_name, data_dir=Path("data"), client_counts=(3, 5, 10), with_std=False):
    """Print averaged graph statistics for every saved split of a node-classification dataset.

    For the centralized graph and for each (split type, client count) pair, the
    file ``<data_dir>/<dataset_name>/<dataset_name>_<split_type>_<num_clients>.bin``
    is loaded and, after a header line, seven statistics averaged over the
    graphs in that file are printed in this order: number of node types, edge
    types, nodes, edges, training nodes, validation nodes and test nodes.

    Args:
        dataset_name: Name used both as the sub-directory and as the file-name prefix.
        data_dir: Directory that contains the per-dataset sub-directories.
        client_counts: Client counts to report for the "random-edges" and
            "random-etypes" splits. The default repeats the values the
            notebook generates and must be kept in sync by the caller if those change.
        with_std: If True, print labelled ``name = mean$\\pm$std`` lines
            (LaTeX plus-minus markup) instead of the bare means.

    Returns:
        None. Results are written to standard output.
    """
    data_dir = Path(data_dir)
    split_configs = [("centralized", 1)]
    split_configs.extend(
        (split_type, num_clients)
        for split_type in ["random-edges", "random-etypes"]
        for num_clients in client_counts
    )
    for split_type, num_clients in split_configs:
        # Build the file name directly instead of str.format-ing a brace-escaped template,
        # which would break on dataset names that contain braces.
        file_path = data_dir / dataset_name / f"{dataset_name}_{split_type}_{num_clients}.bin"
        g_list, _ = dgl.load_graphs(str(file_path))
        # The target node type is taken as the first node type that carries a train mask.
        # NOTE(review): assumes ndata["train_mask"] is a per-node-type mapping and that every
        # graph in the file still contains this node type - confirm for the etype splits.
        target_ntype = list(g_list[0].ndata["train_mask"].keys())[0]
        stats = {
            "num_ntypes": [],
            "num_etypes": [],
            "num_nodes": [],
            "num_edges": [],
            "num_train": [],
            "num_val": [],
            "num_test": [],
        }
        for g in g_list:
            stats["num_ntypes"].append(len(g.ntypes))
            stats["num_etypes"].append(len(g.etypes))
            stats["num_nodes"].append(g.num_nodes())
            # Without an edge type, num_edges() counts the edges of all types.
            stats["num_edges"].append(g.num_edges())
            stats["num_train"].append(g.ndata["train_mask"][target_ntype].sum().item())
            stats["num_val"].append(g.ndata["val_mask"][target_ntype].sum().item())
            stats["num_test"].append(g.ndata["test_mask"][target_ntype].sum().item())
        print(f"{dataset_name}_{split_type}_{num_clients}")
        for stat_name, values in stats.items():
            if with_std:
                print(f"{stat_name} = {np.mean(values):.1f}$\\pm${np.std(values):.1f}")
            else:
                print(f"{np.mean(values):.1f}")

In [None]:
# Print the split statistics for the files stored under the "bgs-hetero" dataset directory
# (presumably the `dataset.name` of BGSDataset used when saving - confirm).
nc_dataset_statistics("bgs-hetero")