In [17]:
import math
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

In [18]:
from utils import read_test1_data
from utils import gen_graph
from utils import prepare_synthetic
from utils import shuffle_graph
from utils import preprocessing_data
from utils import get_pairwise_ids

from utils import prepare_test1
from utils import top_n_acc

In [19]:
RANDOM_STATE = 11
SYNTHETIC_NUM = 100
# SYNTHETIC_NUM = 1000

# number of gen nodes
# NUM_MIN = 4000
# NUM_MAX = 4001
NUM_MIN = 200
NUM_MAX = 201


MAX_EPOCHS = 10000
LEARNING_RATE = 1e-4
EMBEDDING_SIZE = 128
DEPTH = 5
BATCH_SIZE = 16

TEST1_NUM = 30

## Read Graph

In [20]:
test1_g, test1_bc, test1_edgeindex = read_test1_data(0)

## Generate Synthetic Graph

In [21]:
train_g = gen_graph(500, 501)
print(len(train_g.edges()))

1984


In [22]:
# [train_g.degree(i) for i in range(train_g.number_of_nodes())]

In [23]:
# nx.betweenness_centrality(train_g)

In [24]:
(np.array(list(train_g.edges())) + 100)[:10]

array([[100, 104],
       [100, 105],
       [100, 106],
       [100, 107],
       [100, 108],
       [100, 109],
       [100, 111],
       [100, 114],
       [100, 129],
       [100, 131]])

In [25]:
# nx.betweenness_centrality(train_g)

## DrBC

In [26]:
from scipy import stats
# from model1 import DrBC
from model import DrBC
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [27]:
model = DrBC().to(device)
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.BCEWithLogitsLoss(reduction='sum')

In [28]:
model.parameters

<bound method Module.parameters of DrBC(
  (linear0): Linear(in_features=3, out_features=128, bias=True)
  (gcn): GCNConv()
  (gru): GRUCell(128, 128)
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)>

In [29]:
# list(model.parameters())[9].grad.data

In [30]:
pm = list(model.parameters())

for i, p in enumerate(pm):
    print(f"pm{i} shape: {p.shape}")

pm0 shape: torch.Size([128, 3])
pm1 shape: torch.Size([128])
pm2 shape: torch.Size([384, 128])
pm3 shape: torch.Size([384, 128])
pm4 shape: torch.Size([384])
pm5 shape: torch.Size([384])
pm6 shape: torch.Size([64, 128])
pm7 shape: torch.Size([64])
pm8 shape: torch.Size([1, 64])
pm9 shape: torch.Size([1])


In [31]:
# list(dict(nx.degree(train_g)).values())
# list(dict(nx.degree(train_g)).values())
# list(dict(nx.betweenness_centrality(train_g)).values())

In [32]:
def validate(model, v_data):
    model.eval()
    total_acc = 0.
    total_kendall = 0.
    for val_X, val_y, val_edge_index in v_data:
        val_X, val_edge_index = val_X.to(device), val_edge_index.to(device)
        
        with torch.no_grad():
            val_y_pred = model(val_X, val_edge_index)

        val_y_pred = val_y_pred.cpu().detach().numpy()
        val_y = val_y.detach().numpy()

        pred_index = val_y_pred.argsort()[::-1]
        true_index = val_y.argsort()[::-1]
        
        acc = top_n_acc(pred_index, true_index)
        kendall_t, _ = stats.kendalltau(val_y_pred, val_y)

        total_acc += acc
        total_kendall += kendall_t

    total_acc /= len(v_data)
    total_kendall /= len(v_data)
    return total_acc, total_kendall
    

def train(model, optim, loss_fn, epochs:int):
    g_list, dg_list, bc_list  = prepare_synthetic(SYNTHETIC_NUM, (NUM_MIN, NUM_MAX))
    v_data = prepare_test1(TEST1_NUM)
    
    ls_metric = []
    batch_cnt = len(g_list) // BATCH_SIZE
    for e in range(epochs + 1):
        model.train()
        g_list, dg_list, bc_list = shuffle_graph(g_list, dg_list, bc_list)
        batch_bar = tqdm(range(batch_cnt))
        batch_bar.set_description(f'Epochs {e:<5}')
        train_loss = 0
        pair_cnt = 0
        for i in batch_bar:
            # batch
            s_index, e_index = i*BATCH_SIZE, (i+1)*BATCH_SIZE
            train_g, train_dg, train_bc = g_list[s_index: e_index], dg_list[s_index: e_index], bc_list[s_index: e_index]
            X, y, edge_index = preprocessing_data(train_g, train_dg, train_bc)
            X, y, edge_index = X.to(device), y.to(device), edge_index.to(device)
            out = model(X, edge_index)

            # pairwise-loss
            s_ids, t_ids = get_pairwise_ids(train_g)
            out_diff = out[s_ids] - out[t_ids]
            y_diff = y[s_ids] - y[t_ids]
            loss = loss_fn(out_diff, torch.sigmoid(y_diff))

            # optim
            optim.zero_grad()
            loss.backward()
            optim.step()

            pair_cnt += s_ids.shape[0]
            train_loss += (loss.item() * s_ids.shape[0])
            if i == (batch_cnt - 1):
                # last batch
                train_loss /= pair_cnt
                batch_bar.set_postfix(loss=train_loss) 

        if e % 50 == 0:
            val_acc, val_kendall = validate(model, v_data)
            ls_metric.append([e, val_acc, val_kendall])
            print(f"Val Acc: {val_acc * 100:.2f} % | Val KendallTau: {val_kendall:.4f}")
        

_ = train(model, optim, loss_fn, 200)

[Generating new training graph]: 100%|██████████| 100/100 [00:11<00:00,  8.54it/s]
[Reading test1 graph]: 100%|██████████| 30/30 [00:07<00:00,  3.80it/s]
Epochs 0    : 100%|██████████| 6/6 [00:00<00:00, 43.54it/s, loss=1.11e+4]


Val Acc: 87.20 % | Val KendallTau: 0.7299


Epochs 1    : 100%|██████████| 6/6 [00:00<00:00, 44.78it/s, loss=1.11e+4]
Epochs 2    : 100%|██████████| 6/6 [00:00<00:00, 47.34it/s, loss=1.11e+4]
Epochs 3    : 100%|██████████| 6/6 [00:00<00:00, 42.59it/s, loss=1.11e+4]
Epochs 4    : 100%|██████████| 6/6 [00:00<00:00, 42.26it/s, loss=1.11e+4]
Epochs 5    : 100%|██████████| 6/6 [00:00<00:00, 43.68it/s, loss=1.11e+4]
Epochs 6    : 100%|██████████| 6/6 [00:00<00:00, 47.21it/s, loss=1.11e+4]
Epochs 7    : 100%|██████████| 6/6 [00:00<00:00, 42.92it/s, loss=1.11e+4]
Epochs 8    : 100%|██████████| 6/6 [00:00<00:00, 44.06it/s, loss=1.11e+4]
Epochs 9    : 100%|██████████| 6/6 [00:00<00:00, 43.78it/s, loss=1.11e+4]
Epochs 10   : 100%|██████████| 6/6 [00:00<00:00, 43.98it/s, loss=1.11e+4]
Epochs 11   : 100%|██████████| 6/6 [00:00<00:00, 42.73it/s, loss=1.11e+4]
Epochs 12   : 100%|██████████| 6/6 [00:00<00:00, 42.62it/s, loss=1.11e+4]
Epochs 13   : 100%|██████████| 6/6 [00:00<00:00, 45.89it/s, loss=1.11e+4]
Epochs 14   : 100%|██████████| 6/6 [00

Val Acc: 94.27 % | Val KendallTau: 0.4120


Epochs 51   : 100%|██████████| 6/6 [00:00<00:00, 45.31it/s, loss=1.11e+4]
Epochs 52   : 100%|██████████| 6/6 [00:00<00:00, 46.77it/s, loss=1.11e+4]
Epochs 53   : 100%|██████████| 6/6 [00:00<00:00, 47.58it/s, loss=1.11e+4]
Epochs 54   : 100%|██████████| 6/6 [00:00<00:00, 45.30it/s, loss=1.11e+4]
Epochs 55   : 100%|██████████| 6/6 [00:00<00:00, 46.18it/s, loss=1.11e+4]
Epochs 56   : 100%|██████████| 6/6 [00:00<00:00, 43.19it/s, loss=1.11e+4]
Epochs 57   : 100%|██████████| 6/6 [00:00<00:00, 47.44it/s, loss=1.11e+4]
Epochs 58   : 100%|██████████| 6/6 [00:00<00:00, 41.29it/s, loss=1.11e+4]
Epochs 59   : 100%|██████████| 6/6 [00:00<00:00, 39.49it/s, loss=1.11e+4]
Epochs 60   : 100%|██████████| 6/6 [00:00<00:00, 44.24it/s, loss=1.11e+4]
Epochs 61   : 100%|██████████| 6/6 [00:00<00:00, 47.53it/s, loss=1.11e+4]
Epochs 62   : 100%|██████████| 6/6 [00:00<00:00, 41.76it/s, loss=1.11e+4]
Epochs 63   : 100%|██████████| 6/6 [00:00<00:00, 46.94it/s, loss=1.11e+4]
Epochs 64   : 100%|██████████| 6/6 [00

Val Acc: 94.33 % | Val KendallTau: 0.6398


Epochs 101  : 100%|██████████| 6/6 [00:00<00:00, 41.10it/s, loss=1.11e+4]
Epochs 102  : 100%|██████████| 6/6 [00:00<00:00, 42.92it/s, loss=1.11e+4]
Epochs 103  : 100%|██████████| 6/6 [00:00<00:00, 43.68it/s, loss=1.11e+4]
Epochs 104  : 100%|██████████| 6/6 [00:00<00:00, 39.24it/s, loss=1.11e+4]
Epochs 105  : 100%|██████████| 6/6 [00:00<00:00, 43.75it/s, loss=1.11e+4]
Epochs 106  : 100%|██████████| 6/6 [00:00<00:00, 44.35it/s, loss=1.11e+4]
Epochs 107  : 100%|██████████| 6/6 [00:00<00:00, 44.25it/s, loss=1.11e+4]
Epochs 108  : 100%|██████████| 6/6 [00:00<00:00, 40.55it/s, loss=1.11e+4]
Epochs 109  : 100%|██████████| 6/6 [00:00<00:00, 44.07it/s, loss=1.11e+4]
Epochs 110  : 100%|██████████| 6/6 [00:00<00:00, 43.60it/s, loss=1.11e+4]
Epochs 111  : 100%|██████████| 6/6 [00:00<00:00, 44.05it/s, loss=1.11e+4]
Epochs 112  : 100%|██████████| 6/6 [00:00<00:00, 43.29it/s, loss=1.11e+4]
Epochs 113  : 100%|██████████| 6/6 [00:00<00:00, 41.85it/s, loss=1.11e+4]
Epochs 114  : 100%|██████████| 6/6 [00

Val Acc: 94.27 % | Val KendallTau: 0.6652


Epochs 151  : 100%|██████████| 6/6 [00:00<00:00, 39.47it/s, loss=1.11e+4]
Epochs 152  : 100%|██████████| 6/6 [00:00<00:00, 44.13it/s, loss=1.11e+4]
Epochs 153  : 100%|██████████| 6/6 [00:00<00:00, 42.53it/s, loss=1.11e+4]
Epochs 154  : 100%|██████████| 6/6 [00:00<00:00, 43.35it/s, loss=1.11e+4]
Epochs 155  : 100%|██████████| 6/6 [00:00<00:00, 39.01it/s, loss=1.11e+4]
Epochs 156  : 100%|██████████| 6/6 [00:00<00:00, 40.54it/s, loss=1.11e+4]
Epochs 157  : 100%|██████████| 6/6 [00:00<00:00, 44.10it/s, loss=1.11e+4]
Epochs 158  : 100%|██████████| 6/6 [00:00<00:00, 43.80it/s, loss=1.11e+4]
Epochs 159  : 100%|██████████| 6/6 [00:00<00:00, 44.68it/s, loss=1.11e+4]
Epochs 160  : 100%|██████████| 6/6 [00:00<00:00, 42.46it/s, loss=1.11e+4]
Epochs 161  : 100%|██████████| 6/6 [00:00<00:00, 44.03it/s, loss=1.11e+4]
Epochs 162  : 100%|██████████| 6/6 [00:00<00:00, 36.59it/s, loss=1.11e+4]
Epochs 163  : 100%|██████████| 6/6 [00:00<00:00, 43.56it/s, loss=1.11e+4]
Epochs 164  : 100%|██████████| 6/6 [00

Val Acc: 94.27 % | Val KendallTau: 0.6705


In [33]:
# g = _[2]
# g.degree(list(range(99, 105)))

## To-Do List
* (done) loss_fn 再加上 sigmoid
* (done) pairwise 目前跨圖了
* (done) h 要 normalized
* (done) aggregate 改成 MessagePassing
* (done) synthetic graph 後，shuffle graph 的順序
* (done) 加入 Epochs
* Metric: top1, 5, 10
* Metric: kendall tau distance
* wall-clock running time
* test step
* (done) change to leaky relu -> back to relu