In [15]:
import numpy as np
import os
import torch
from torch_geometric.utils import to_networkx, degree, to_dense_adj, to_scipy_sparse_matrix
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from scipy import sparse as sp
import dgl
import networkx as nx
def torch_save(base_dir, filename, data):
    """Serialize ``data`` with ``torch.save`` to ``base_dir/filename``.

    ``filename`` may itself contain sub-directories (for example
    ``"Ogbn_disjoint/5/init_0.pt"``); every missing directory on the way to
    the target file is created, not only ``base_dir``.

    Args:
        base_dir: Root directory of the output.
        filename: File name, or relative path, below ``base_dir``.
        data: Any object accepted by ``torch.save``.
    """
    fpath = os.path.join(base_dir, filename)
    # Create the file's own parent directory. Creating only ``base_dir`` makes
    # torch.save fail whenever ``filename`` points into a sub-directory that
    # does not exist yet.
    parent_dir = os.path.dirname(fpath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(data, fpath)
def torch_load(base_dir, filename):
    """Load an object written by ``torch.save`` from ``base_dir/filename``.

    Tensors in the file are mapped onto the CPU regardless of the device they
    were saved from.

    NOTE(review): ``torch.load`` unpickles the file; only use it on files
    from a trusted source.

    Args:
        base_dir: Directory containing the file.
        filename: File name, or relative path, below ``base_dir``.

    Returns:
        The deserialized object.
    """
    cpu = torch.device('cpu')
    return torch.load(os.path.join(base_dir, filename), map_location=cpu)
def get_data(client_id, clients, base_dir="/home/1005wjy/datasets/",
             dataset="Ogbn", mode="disjoint"):
    """Load the stored partition of one client.

    Reads ``{base_dir}/{dataset}_{mode}/{clients}/partition_{client_id}.pt``
    through ``torch_load`` and returns the ``'client_data'`` entry of the
    loaded mapping.

    Args:
        client_id: Index of the client whose partition file is read.
        clients: Total number of clients; selects the sub-directory.
        base_dir: Root directory of the datasets. Defaults to the location
            that used to be hard-coded here.
        dataset: Dataset prefix of the partition directory.
        mode: Partition-mode suffix of the partition directory.

    Returns:
        A one-element list containing the ``'client_data'`` entry.
    """
    partition_file = f'{dataset}_{mode}/{clients}/partition_{client_id}.pt'
    return [torch_load(base_dir, partition_file)['client_data']]



def init_graphrepair(g, num_labels, rw_steps=16):
    """Attach a per-node structural encoding to ``g`` as ``g['stc_enc']``.

    The encoding is the column-wise concatenation of two parts:

    * random-walk part, ``rw_steps`` columns: the diagonals of ``M**k`` for
      ``k = 1 .. rw_steps`` where ``M = A @ D^-1``; ``A`` is the adjacency
      matrix built from ``g.edge_index`` and ``D`` holds how often each node
      appears as an edge source;
    * neighbour-label part, ``num_labels`` columns: for every edge ``(u, v)``
      the count at ``[u, y[v]]`` and at ``[v, y[u]]`` is incremented by one.

    Args:
        g: Graph object exposing ``edge_index``, ``num_nodes`` and ``y`` and
            supporting item assignment (presumably a torch_geometric ``Data``
            -- confirm against the caller). It is modified in place.
        num_labels: Number of columns of the neighbour-label part; every
            label in ``g.y`` must be a valid column index.
        rw_steps: Number of random-walk powers to record. The default of 16
            reproduces the previously hard-coded behaviour.

    Returns:
        The same ``g`` with ``stc_enc`` of shape
        ``(num_nodes, rw_steps + num_labels)``.

    Raises:
        ValueError: If ``rw_steps`` is smaller than 1.
    """
    if rw_steps < 1:
        raise ValueError("rw_steps must be at least 1")

    src, dst = g.edge_index[0], g.edge_index[1]

    # random walk embedding
    A = to_scipy_sparse_matrix(g.edge_index, num_nodes=g.num_nodes)
    deg_inv = degree(src, num_nodes=g.num_nodes) ** -1.0
    # Nodes that never occur as an edge source have degree 0 and therefore an
    # infinite inverse; zero it so it cannot inject inf/nan into the walk.
    deg_inv[torch.isinf(deg_inv)] = 0.0
    # `@` is always the matrix product, for scipy sparse matrices and arrays.
    M = A @ sp.diags(deg_inv.numpy())

    SE = [torch.from_numpy(M.diagonal()).float()]
    M_power = M
    # NOTE(review): repeated sparse products fill in quickly; on large graphs
    # M_power may approach a dense num_nodes x num_nodes matrix -- confirm the
    # memory budget before running on big partitions.
    for _ in range(rw_steps - 1):
        M_power = M_power @ M
        SE.append(torch.from_numpy(M_power.diagonal()).float())

    random_emb = torch.stack(SE, dim=-1)

    # homogeneity embedding
    # Vectorised form of "for every edge: bump both end points' counters";
    # accumulate=True makes repeated (node, label) pairs add up.
    labels = g.y.reshape(-1).long()
    homo_emb = torch.zeros([g.num_nodes, num_labels])
    one_per_edge = homo_emb.new_ones(src.numel())
    homo_emb.index_put_((src, labels[dst]), one_per_edge, accumulate=True)
    homo_emb.index_put_((dst, labels[src]), one_per_edge, accumulate=True)

    g['stc_enc'] = torch.cat([random_emb, homo_emb], dim=1)

    return g

# Width of the neighbour-label part built by init_graphrepair
# (presumably the number of classes of the Ogbn dataset -- TODO confirm).
num_labels = 40
# Number of client partitions: selects the partition directory and bounds the
# per-client loop.
clients = 5


In [None]:
# For every client: load its partition, attach the structural encoding and
# store the result as init_<client>.pt under the same dataset root.
for c in range(clients):
    data_path = "/home/1005wjy/datasets/"
    # get_data returns a one-element list; unwrap the graph.
    client_graph = get_data(c,clients)[0]
    print("client ",c," get data over")
    client_graph = init_graphrepair(client_graph, num_labels)
    # Saved with the same 'client_data' key that get_data reads from the
    # partition files, plus the client index.
    torch_save(data_path, f'{"Ogbn"}_disjoint/{clients}/init_{c}.pt', {
            'client_data': client_graph,
            'client_id': c
        })
    print("client ",c," over")

client  0  get data over


In [4]:
# Ad-hoc check: load partition 1 of the 5-client split; `a` is the
# one-element list returned by get_data.
a = get_data(1,5)

In [None]:
pip install dgl

Collecting dgl
  Using cached dgl-2.1.0-cp310-cp310-manylinux1_x86_64.whl (8.5 MB)
Collecting torchdata>=0.5.0
  Downloading torchdata-0.8.0-cp310-cp310-manylinux1_x86_64.whl (2.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 5.7 MB/s eta 0:00:00
Collecting networkx>=2.1
  Downloading networkx-3.3-py3-none-any.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting torch>=2
  Downloading torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl (797.2 MB)
[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m748.0/797.2 MB[0m [31m3.4 MB/s[0m eta [36m0:00:15[0m