In [1]:
import math
import numpy as np
import torch
import torch_geometric
import networkx as nx
import matplotlib.pyplot as plt


In [2]:
from utils import read_test1_data
from utils import gen_graph

In [3]:
RANDOM_STATE = 11

SYNTHETIC_NUM = 1000
# SYNTHETIC_NUM = 100

# number of gen nodes
NUM_MIN = 100
NUM_MAX = 200

LEARNING_RATE = 1e-4
EMBEDDING_SIZE = 128
DEPTH = 5
BATCH_SIZE = 4


## Read Graph

In [4]:
test1_X, test1_bc = read_test1_data(0)

## Generate Synthetic Graph

In [5]:
train_g = gen_graph(NUM_MIN, NUM_MAX)
print(len(train_g.edges()))

622


In [6]:
list(train_g.neighbors(0))

[4, 8, 9, 16, 19, 81, 128, 142]

In [7]:
ls = []
for node in list(train_g.nodes())[:5]:
    ls.append(list(train_g.neighbors(node)))


In [8]:
# nx.betweenness_centrality(train_g)

## DrBC

In [9]:
import torch
from torch.nn import Module, Linear, GRUCell, Sequential, ReLU, functional as t_F

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
def prepare_synthetic():
    g_list = []
    dg_list = []
    bc_list = []
    for i in range(SYNTHETIC_NUM):
        g = gen_graph(NUM_MIN, NUM_MAX)
        g_list.append(g)
        dg_list.append(nx.degree(g))
        bc_list.append(nx.betweenness_centrality(g))
        
    return g_list, dg_list, bc_list

def preprocessing_data(train_g:list, train_dg:list, train_bc:list):
    X = []
    y = []
    nb = []
    pre_index = 0
    for i in range(len(train_bc)):
        assert len(train_dg[i]) == len(train_bc[i]) == len(train_g[i].nodes())
        # make suer is has same nodes number.
        num_node = len(train_dg[i])
        for node_id in range(num_node):
            node_nb = [pre_index+n for n in train_g[i].neighbors(node_id)]
            X.append([train_dg[i][node_id], 1., 1.])
            y.append(train_bc[i][node_id])
            nb.append(node_nb)
        pre_index += num_node
    X = torch.Tensor(X)
    y = torch.Tensor(y)
    # print(X.shape, y.shape)

    return X, y, nb

def get_pairwise_ids(g_list):
    s_ids = np.zeros(shape=(0, ), dtype=int)
    t_ids = np.zeros(shape=(0, ), dtype=int)
    pre_index = 0
    for g in g_list:
        num_node = len(g.nodes())
        ids_1 = np.repeat(np.arange(pre_index, pre_index+num_node), 5)
        ids_2 = np.repeat(np.arange(pre_index, pre_index+num_node), 5)

        np.random.shuffle(ids_1)
        np.random.shuffle(ids_2)

        s_ids = np.append(s_ids, ids_1, axis=0)
        t_ids = np.append(t_ids, ids_2, axis=0)
        pre_index += num_node
    return s_ids, t_ids

In [12]:
class DrBC(Module):
    def __init__(self, embedding_size=EMBEDDING_SIZE, depth=DEPTH):
        super(DrBC, self).__init__()
        self.embedding_size = embedding_size
        self.depth = depth
        self.linear0 = Linear(3, self.embedding_size)
        self.gru = GRUCell(self.embedding_size, self.embedding_size)
        # decoder
        self.mlp = Sequential(
            Linear(self.embedding_size, self.embedding_size // 2),
            ReLU(),
            Linear(self.embedding_size // 2, 1)
        )
        
    def neighbor_aggre(self, X, all_nb, h):
        # nb aggre
        h_aggre = []
        for node_id in range(X.shape[0]):
            d_v = X[node_id, 0]
            node_nb = all_nb[node_id]
            node_aggre = torch.Tensor([0.] * self.embedding_size).to(device)
            for nb_id in node_nb:
                # for node all nb
                node_aggre += (1 / (math.sqrt(d_v + 1) * math.sqrt(X[nb_id, 0] + 1))) * h[nb_id]
            h_aggre.append(torch.unsqueeze(node_aggre, dim=0))
        h_aggre = torch.cat(h_aggre, dim=0)
        # print('h_aggre shape: ', h_aggre.shape)
        return h_aggre # tensor format

    def forward(self, X, all_nb):
        all_h = []
        h = self.linear0(X)
        h = torch.relu(h)
        h = t_F.normalize(h, p=2, dim=-1) # l2-norm
        all_h.append(torch.unsqueeze(h, dim=0))

        # GRUCell
        for i in range(self.depth-1):
            # neighborhood aggregation
            h_aggre = self.neighbor_aggre(X, all_nb, h)
            h = self.gru(h_aggre, h)
            h = t_F.normalize(h, p=2, dim=-1) # l2-norm
            all_h.append(torch.unsqueeze(h, dim=0))
        
        # max pooling
        all_h = torch.cat(all_h, dim=0)
        h_max = torch.max(all_h, dim=0).values
        # print('h_max shape: ', h_max.shape)

        # Decoder
        out = self.mlp(h_max)
        out = torch.squeeze(out)
        # print('out shape: ', out.shape)
        return out
        


        
model = DrBC().to(device)
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.BCEWithLogitsLoss()

In [13]:
model.parameters

<bound method Module.parameters of DrBC(
  (linear0): Linear(in_features=3, out_features=128, bias=True)
  (gru): GRUCell(128, 128)
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)>

In [14]:
def train():
    g_list, dg_list, bc_list  = prepare_synthetic()
    print('-'*20, 'prepare systhetic done')
    batch_cnt = len(g_list) // BATCH_SIZE

    for i in range(batch_cnt):
        s_index = i*BATCH_SIZE
        e_index = (i+1)*BATCH_SIZE
        train_g, train_dg, train_bc = g_list[s_index: e_index], dg_list[s_index: e_index], bc_list[s_index: e_index]
        X, y, all_nb = preprocessing_data(train_g, train_dg, train_bc)
        X, y = X.to(device), y.to(device)
        out = model(X, all_nb)

        # pairwise-loss
        s_ids, t_ids = get_pairwise_ids(train_g)
        out_diff = out[s_ids] - out[t_ids]
        y_diff = y[s_ids] - y[t_ids]
        loss = loss_fn(out_diff, torch.sigmoid(y_diff))

        # optim
        optim.zero_grad()
        loss.backward()
        optim.step()
        print(f"Batch {i + 1}: Loss = {loss.item()}")
        
def validate():
    pass

_ = train()

-------------------- prepare systhetic done
Batch 1: Loss = 0.6931463479995728
Batch 2: Loss = 0.6931358575820923
Batch 3: Loss = 0.6931314468383789
Batch 4: Loss = 0.6931201219558716
Batch 5: Loss = 0.6931133270263672
Batch 6: Loss = 0.6931107044219971
Batch 7: Loss = 0.6931028366088867
Batch 8: Loss = 0.6930952668190002
Batch 9: Loss = 0.6930988430976868
Batch 10: Loss = 0.693091094493866
Batch 11: Loss = 0.6930738687515259
Batch 12: Loss = 0.6930953860282898
Batch 13: Loss = 0.693079948425293
Batch 14: Loss = 0.6930753588676453
Batch 15: Loss = 0.6930761933326721
Batch 16: Loss = 0.6930736899375916
Batch 17: Loss = 0.6930694580078125
Batch 18: Loss = 0.6930606365203857
Batch 19: Loss = 0.6930750012397766
Batch 20: Loss = 0.6930649876594543
Batch 21: Loss = 0.6930702328681946
Batch 22: Loss = 0.6930771470069885
Batch 23: Loss = 0.6930841207504272
Batch 24: Loss = 0.6930593848228455
Batch 25: Loss = 0.6930658221244812
Batch 26: Loss = 0.6930628418922424
Batch 27: Loss = 0.693049371242

KeyboardInterrupt: 

## To-Do List
* (done) loss_fn 再加上 sigmoid
* (done) pairwise 目前跨圖了
* (done) h 要 normalized
* aggregate 改成 MessagePassing