# Mount drive

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

!ls /content/gdrive/My\ Drive

Mounted at /content/gdrive
 aimas2020
'Automatic Generation of Topic Labels.gslides'
'Colab Notebooks'
 cvdl2020
 iir_book.pdf
 ir_final
'Medical AI'
'Paper Slides'
 Q56094077
 res18_diabete_noaug.pth
'Towards Better Text Understanding and Retrieval through Kernel Entity Saliency Modeling.gslides'
 tsai.ipynb
 獎助學金
 申請資料


In [None]:
# !unzip /content/gdrive/MyDrive/Q56094077/snrs/hw1_0319/hw1_data.zip -d /content/gdrive/MyDrive/Q56094077/snrs/hw1_0319

# Import Library

In [1]:
import os

import torch
import torch.nn as nn

import pandas as pd
import numpy as np
import json

from tqdm import tqdm
from datetime import datetime

In [2]:
import torch_geometric
from torch_geometric.data import Data, DataLoader
import torch_geometric.utils as utils

# Setting

In [3]:
class Setting:
    _root = os.getcwd()

    _data = os.path.join(_root, "hw1_data")

    data_synthetic = os.path.join(_data, "Synthetic", "5000")
    data_youtube = os.path.join(_data, "youtube")
    
    
     # Create dir for train/test
    date_time = datetime.strftime(datetime.now(), "%Y-%m-%d %H-%M")
    root = os.path.join(_root,  date_time)
    if os.path.exists(root):
        pass
    else:
        os.makedirs(root)


    ## Save plt info
    train_info_p = os.path.join(root, "train.json")
    val_info_p = os.path.join(root, "valid.json")
    test_info_p = os.path.join(root, "test.json")

    ## Save plt img
    result_plt_p = os.path.join(root, "train_plt.png")
    test_plt_p = os.path.join(root, "test_plt.png")
    sum_box_p = os.path.join(root, "sum_box.png")
        

    # Setting of training
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    epochs = 10000
    batch = 4

    c = 3
    depth = 5
    p = 128 # embedding dimension of hidden state
    q = int(p/2)

    epochs = 1000
    save_model = os.path.join(root, "weight.pth")
    criterion = torch.nn.BCEWithLogitsLoss()


In [4]:
setting = Setting()

# Dataset

## Data

- data.x	节点特征，维度是[num_nodes, num_node_features]。
- data.edge_index	维度是[2, num_edges]，描述图中节点的关联关系，每一列对应的两个元素，分别是边的起点和重点。数据类型是torch.long。需要注意的是，data.edge_index是定义边的节点的张量（tensor），而不是节点的列表（list）。
- data.edge_attr	边的特征矩阵，维度是[num_edges, num_edge_features]
- data.y	训练目标（维度可以是任意的）。对于节点相关的任务，维度为[num_nodes, *]；对于图相关的任务，维度为[1,*]。
- data.position	节点位置矩阵（Node position matrix），维度为[num_nodes, num_dimensions]。

- [Learning to Identify High Betweenness Centrality Nodes from
Scratch: A Novel Graph Neural Network Approach](https://arxiv.org/pdf/1905.10418.pdf)
- node initial feature = [$(d_v), 1, 1]

In [5]:
synthetic = []
between = []
for f in os.listdir(setting.data_synthetic):
    if "score" in f:
        # ground truth of betweenness centrality
        p = os.path.join(setting.data_synthetic, f)
        between.append(p)
    else:
        p = os.path.join(setting.data_synthetic, f)
        synthetic.append(p)

between.sort()
synthetic.sort()

In [6]:
data_list = []

for index, f in enumerate(synthetic):
    edge_index = torch_geometric.io.read_txt_array(f, dtype=torch.long)
    edge_index = edge_index.t().contiguous()
    edge_index = utils.to_undirected(edge_index)

    row, col = edge_index  
    deg = utils.degree(col) # must use col to get degree, why?
    deg = deg.numpy()  

    vertice = []
    for d in deg:
        vertice.append([d, 1, 1])
    vertice = np.array(vertice, dtype=np.float)
    vertice = torch.from_numpy(vertice)
    
    ### between centrality
    bcs = []
    bc = torch_geometric.io.read_txt_array(between[index], dtype=torch.double)
    bc = bc.t().contiguous()
    row, col = bc
    bc = col
    bc = bc.numpy()
    for b in bc:
        bcs.append([b])

#     bcs = np.array(bcs)
    data = Data(x=vertice, edge_index=edge_index, y=bcs)

    data_list.append(data)

loader = DataLoader(data_list, batch_size=setting.batch)
# print(loader)

# Model

In [7]:
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from torch_geometric.nn import global_max_pool
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.transforms import Distance

In [8]:
class Net(MessagePassing):
    def __init__(self, c, p, q, num_layers, device, aggr="add"):
        super(Net, self).__init__(aggr=aggr)
        
        self.num_layers = num_layers
        self.w_0 = torch.nn.Linear(in_features=c, out_features=p).double()
        
        self.rnn = torch.nn.GRUCell(p, p).double()
  
        self.w_4 = torch.nn.Linear(in_features=p, out_features=q).double()
        self.w_5 = torch.nn.Linear(in_features=q, out_features=1).double()
        self.device = device

    def forward(self, data):
        
        x, edge_index = data.x, data.edge_index
        
        # h_0 = x

        # h_1
        x = self.w_0(x)
        x = F.normalize(x, p=2, dim=1)
        
        row, col = edge_index
        deg = utils.degree(col, x.size(0), dtype=x.dtype)
        deg = torch.add(deg, 1)
        deg_inv_sqrt = torch.pow(deg, -0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        h_s = [x]
        
        
        for i in range(self.num_layers-1):
            # internally calls the message(), aggregate() and update() functions
            m = self.propagate(edge_index, x=x, norm=norm)
            x = self.rnn(m, x)
            x = F.normalize(x, p=2, dim=1) 
           
            h_s.append(x)
        
        h_s = torch.stack(h_s, dim=-1)

        # Use torch.max to replace max_pooling
        z, _ = torch.max(h_s, dim=-1)
        # z = global_max_pool(h_s, torch.tensor([0], dtype=torch.long).to(self.device))
        
        
        ### Decoder
        z = self.w_4(z)
        z = F.relu(z)
        z = self.w_5(z)
        
        return z

    def message(self, x_j, norm: OptTensor):
        return x_j if norm is None else norm.view(-1, 1) * x_j
    

# Train

In [9]:
def load_checkpoint(filepath, device, **params):

    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    model = Net(c=params["c"], p=params["p"], q=params["q"], num_layers=params["depth"], device=device).to(device)

    if os.path.exists(filepath):
        checkpoint = torch.load(filepath)
        model.load_state_dict(checkpoint['model_stat'])
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)
        optimizer.load_state_dict(checkpoint['optimizer_stat'])

    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

    return model, optimizer

In [10]:
model, optimizer = load_checkpoint(
                    setting.save_model,
                    setting.device,
                    c=setting.c, 
                    p=setting.p, 
                    q = setting.q, 
                    depth=setting.depth)

In [11]:
model

Net(
  (w_0): Linear(in_features=3, out_features=128, bias=True)
  (rnn): GRUCell(128, 128)
  (w_4): Linear(in_features=128, out_features=64, bias=True)
  (w_5): Linear(in_features=64, out_features=1, bias=True)
)

In [None]:
def top_n(model, data, n=5):

    hit = 0
    
    for index, row in enumerate(x_test_index):

        
        
        x_in = x_test_input[index]
        _user_test = x_in["user_input"]
        _movie_test = x_in["item_input"]
 
        movie_index = np.array([ X[row, 1] ], dtype=int)
        movie_index = map_to_index(movie_index, movie_unique)        

        result = model.predict(x=x_in, batch_size=BATCH_SIZE)
        result = np.reshape(result, ITEMS+LEAVE)
    
        top_n_index = result.argsort()[-TOP_N:]
    
        if movie_index in _movie_test[top_n_index]:
            hit += 1

    print("total hit: %d, \t accuracy: %f " % (hit, hit/len(x_test_index)) )

In [12]:
train_info = {
       "bce": []
}

In [None]:
min_bce = 10000

for epoch in range(setting.epochs):
    
    bce_loss = 0.0
    graph_cnt = 0
    for data in tqdm(loader):
        
        optimizer.zero_grad()
        
        data = data.to(setting.device)
        bc_pr = model(data)
        bc_gt = data.y
        bc_gt = np.array(bc_gt)
        bc_gt = torch.from_numpy(bc_gt).squeeze()
        bc_gt = torch.reshape(bc_gt, (-1, ))
        
        ### random sample 5|V| nodes
        src = (torch.rand(25000) * 4999).long()
        det = (torch.rand(25000) * 4999).long()
        for b in range(len(data.batch)//5000-1):
            src = torch.cat((src, (torch.rand(25000) * 4999).long()+(b+1)*5000))
            det = torch.cat((det, (torch.rand(25000) * 4999).long()+(b+1)*5000))
                        
        y_gt = (bc_gt[det] - bc_gt[src]).squeeze().to(setting.device)
        y_pr = (bc_pr[det] - bc_pr[src]).squeeze()
   
        loss = setting.criterion(y_pr, y_gt)
        loss.backward()
        optimizer.step()

        bce_loss += loss.item()
        graph_cnt += 1
        
#         bce_loss = torch.tensor(0, dtype=torch.float).to(setting.device)
        
#         for b in range(setting.batch):
#             index = picked[b]
#             for i in range(len(index)):
#                 s1, s2 = index[i]
                
#                 y_gt = bc_gt[b][s2] - bc_gt[b][s1]
#                 y_pr = bc_pr[b][s2] - bc_pr[b][s1]
                
#                 y_gt = torch.from_numpy(y_gt).to(setting.device)
#                 loss = setting.criterion(y_pr, y_gt)
#                 bce_loss += loss
                
#         bce_loss += data.num_graphs * loss.item()
#         graph_cnt += data.num_graphs
        
#         bce_loss.backward()
        
        
    l = bce_loss/graph_cnt
    print("Epoch = {}, loss = {}".format(epoch, l))
    
    train_info["bce"].append(l)
    with open(setting.train_info_p, 'w') as f:
        json.dump(train_info, f)

    
    if l < min_bce:
        checkpoint = {
            'model_stat': model.state_dict(),
            'optimizer_stat': optimizer.state_dict(),
        }
        torch.save(checkpoint, setting.save_model)

100%|██████████| 8/8 [00:02<00:00,  3.78it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 0, loss = 0.6931509327813485


100%|██████████| 8/8 [00:01<00:00,  4.38it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 1, loss = 0.6931488629992566


100%|██████████| 8/8 [00:01<00:00,  4.62it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 2, loss = 0.6931497713462723


100%|██████████| 8/8 [00:01<00:00,  4.42it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 3, loss = 0.6931471846216292


100%|██████████| 8/8 [00:01<00:00,  4.37it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 4, loss = 0.6931464541857284


100%|██████████| 8/8 [00:01<00:00,  4.41it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 5, loss = 0.6931461773972476


100%|██████████| 8/8 [00:01<00:00,  4.59it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 6, loss = 0.6931468715734843


100%|██████████| 8/8 [00:01<00:00,  4.33it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 7, loss = 0.6931446732864703


100%|██████████| 8/8 [00:01<00:00,  4.40it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 8, loss = 0.693144645510898


100%|██████████| 8/8 [00:01<00:00,  4.39it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 9, loss = 0.693143742951595


100%|██████████| 8/8 [00:01<00:00,  4.54it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 10, loss = 0.6931440721052742


100%|██████████| 8/8 [00:01<00:00,  4.30it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 11, loss = 0.6931429348508946


100%|██████████| 8/8 [00:01<00:00,  4.32it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 12, loss = 0.6931432860439365


100%|██████████| 8/8 [00:01<00:00,  4.37it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 13, loss = 0.6931460573822468


100%|██████████| 8/8 [00:01<00:00,  4.58it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 14, loss = 0.6931440404870453


100%|██████████| 8/8 [00:01<00:00,  4.36it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 15, loss = 0.6931431933572934


100%|██████████| 8/8 [00:01<00:00,  4.35it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 16, loss = 0.6931446744474531


100%|██████████| 8/8 [00:01<00:00,  4.41it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 17, loss = 0.6931410503026343


100%|██████████| 8/8 [00:01<00:00,  4.62it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 18, loss = 0.6931455676352883


100%|██████████| 8/8 [00:01<00:00,  4.38it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 19, loss = 0.6931438753043707


100%|██████████| 8/8 [00:01<00:00,  4.38it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 20, loss = 0.6931415333917978


100%|██████████| 8/8 [00:01<00:00,  4.46it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 21, loss = 0.6931413394924253


100%|██████████| 8/8 [00:01<00:00,  4.39it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 22, loss = 0.6931430973074895


100%|██████████| 8/8 [00:01<00:00,  4.39it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 23, loss = 0.69314447400693


100%|██████████| 8/8 [00:01<00:00,  4.39it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 24, loss = 0.6931405055524923


100%|██████████| 8/8 [00:01<00:00,  4.59it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 25, loss = 0.6931437878977469


100%|██████████| 8/8 [00:01<00:00,  4.34it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 26, loss = 0.6931432724150339


100%|██████████| 8/8 [00:01<00:00,  4.39it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 27, loss = 0.6931439221596263


100%|██████████| 8/8 [00:01<00:00,  4.60it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 28, loss = 0.6931438007469133


100%|██████████| 8/8 [00:01<00:00,  4.39it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 29, loss = 0.6931413188185316


100%|██████████| 8/8 [00:01<00:00,  4.35it/s]
  0%|          | 0/8 [00:00<?, ?it/s]

Epoch = 30, loss = 0.6931426531055153


 12%|█▎        | 1/8 [00:00<00:01,  4.82it/s]

In [None]:
# test fit model
predict = model(data.to(device))
predict.shape


### Sampling nodes

In our experiments, we randomly sample 5|V | source nodes and 5|V |
target nodes with replacement

In [274]:
picked = (torch.rand(25000, 2) * 4999).long()
for b in range(batch-1):
    picked = torch.stack((picked, (torch.rand(25000, 2) * 4999).long()))
    
picked.shape

torch.Size([25000, 2])

In [None]:
for data in loader:
    print(data)

In [192]:
(loader[0].y-loader.y[1]).shape

AttributeError: 'DataLoader' object has no attribute 'y'