## SIGN: Scalable Inception Graph Networks  
- Experiments on the Flickr dataset  

In [2]:
import os
import sys
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch.utils.data import DataLoader

from torch_geometric.datasets import Flickr
import torch_geometric.transforms as T

from torch_sparse import SparseTensor

import pytorch_lightning as pl

# Make the project-local helper modules in the parent directory importable.
# Guarded so that re-running this cell does not keep appending the same entry.
_PROJECT_ROOT = os.path.abspath('../')
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)
# `nn` is used later (nn.CosineSimilarity, nn.BatchNorm1d); import it explicitly
# instead of relying on the star imports below (which still take precedence).
from torch import nn
# NOTE(review): these star imports presumably provide make_logger, GraphInspector,
# get_device, get_num_params and get_parameter_report — confirm and import by name.
from utils import *
from torch_custom_funcs import *
logger = make_logger(name='sign_logger')

#### 1. Data 변화 체크하기

In [5]:
# Preparing Data
# Only feature normalisation is applied here: the propagation is reproduced by
# hand in the cells below so that A^k and A^k X can be inspected directly.
# (The T.SIGN(K) transform is applied in section 2.)
K = 4  # SIGN hop count (set again in section 2)

path = os.path.join(os.getcwd(), 'data', 'Flickr')
dataset = Flickr(path, transform=T.NormalizeFeatures())
data = dataset[0]

In [6]:
# Check: basic statistics of the raw graph (node and edge counts) via the
# project's GraphInspector helper.
inspector = GraphInspector(data)
inspector.get_basic_info()

{'num_nodes': 89250, 'num_edges': 899756}


In [7]:
# Inspect edge_index; the helper reports the node that appears most often and its count.
inspector.inspect('edge_index')

most_freq_appeared_node: 50 with 5425


Adjacency Matrix의 변화에 대해 확인한다.

In [8]:
n = data.num_nodes
row, col = data.edge_index
# Transposed adjacency: every edge (row, col) is stored at position (col, row).
adj_t = SparseTensor(row=col, col=row, sparse_sizes=(n, n))

# D^{-1/2} from the row sums of adj_t. A zero-degree row gives 1/sqrt(0) = inf,
# which is mapped to 0 so that such nodes simply contribute nothing.
deg = adj_t.sum(dim=1).to(torch.float)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)

# Symmetric normalisation: scale rows and columns by D^{-1/2}.
adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1)

# First three powers of the normalised matrix (used for the density check).
a1, a2 = adj_t, adj_t @ adj_t
a3 = a2 @ a1

In [13]:
print("Density Change: " + ", ".join("{:.4f}".format(power.density()) for power in (a1, a2, a3)))

Density Change: 0.0001, 0.0097, 0.1114


$A^kX$ ($k = 1, \dots, 4$)의 변화를 확인하고, hop 간 feature의 평균 cosine 유사도를 비교한다.

In [14]:
# A^k X for k = 1..N_HOPS: repeatedly propagate the normalised features with
# the normalised adjacency built above.
N_HOPS = 4
x_list = []
h = data.x
for _ in range(N_HOPS):
    h = adj_t @ h
    x_list.append(h)

# Mean row-wise cosine similarity between every pair of propagated feature
# matrices. Cosine similarity is symmetric, so only the upper triangle is
# computed and mirrored. The diagonal is still computed explicitly because an
# all-zero row contributes 0 rather than 1 to the mean.
N = len(x_list)
output = np.zeros((N, N))
for i in range(N):
    for j in range(i, N):
        sim = F.cosine_similarity(x_list[i], x_list[j], dim=1, eps=1e-8).mean().item()
        output[i, j] = output[j, i] = sim

# Label rows/columns by hop so the table is not read as 0-based hop counts.
hop_labels = [f'A^{k}X' for k in range(1, N + 1)]
print(pd.DataFrame(output, index=hop_labels, columns=hop_labels))

          0         1         2         3
0  1.000000  0.806187  0.883690  0.824013
1  0.806187  1.000000  0.961076  0.984005
2  0.883690  0.961076  1.000000  0.985965
3  0.824013  0.984005  0.985965  1.000000


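$A^2X$, $A^3X$, $A^4X$ 사이의 평균 cosine 유사도는 0.96 이상인 반면, $A^1X$와의 유사도는 0.81~0.88에 그친다. 즉 2-hop 이상 전파하면 node feature가 서로 비슷해지는(over-smoothing) 경향이 있는 것으로 보인다.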

#### 2. Wrapping with DataLoader

In [15]:
K = 4
BATCH_SIZE = 1024
SEED = 42

# Seed the python / numpy / torch RNGs so that shuffling and weight
# initialisation are reproducible. The Trainer below is created with
# deterministic=True, which only yields repeatable runs with a fixed seed.
pl.seed_everything(SEED)

# T.SIGN(K) adds the propagated feature attributes x1..xK that the training
# steps slice per mini-batch.
path = os.path.join(os.getcwd(), 'data', 'Flickr')
transform = T.Compose([T.NormalizeFeatures(), T.SIGN(K)])
dataset = Flickr(path, transform=transform)
data = dataset[0]

device = get_device()

# Node indices of each split; the loaders yield mini-batches of these indices.
train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
val_idx = data.val_mask.nonzero(as_tuple=False).view(-1)
test_idx = data.test_mask.nonzero(as_tuple=False).view(-1)

train_loader = DataLoader(train_idx, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_idx, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_idx, batch_size=BATCH_SIZE)

#### 3. Model & Lightning Module

In [16]:
class GNN(torch.nn.Module):
    """SIGN classifier.

    Each of the K + 1 inputs [x, x1, ..., xK] goes through its own
    Linear -> BatchNorm1d -> ReLU -> Dropout branch. The branch outputs are
    concatenated and mapped to class log-probabilities by a final Linear layer.

    Args:
        hidden_size: width of every per-hop branch.
        drop_rate: dropout probability applied after each branch.
        num_features: input feature size; defaults to the global
            ``dataset.num_node_features``.
        num_classes: number of output classes; defaults to the global
            ``dataset.num_classes``.
        num_hops: K, so the model expects K + 1 inputs; defaults to the global ``K``.
    """

    def __init__(self, hidden_size, drop_rate, num_features=None, num_classes=None, num_hops=None):
        super(GNN, self).__init__()

        # Fall back to the notebook globals so GNN(hidden_size, drop_rate) keeps working.
        if num_features is None:
            num_features = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes
        if num_hops is None:
            num_hops = K

        self.hidden_size = hidden_size
        self.drop_rate = drop_rate
        self.num_hops = num_hops

        # MLP for downstream tasks: one branch per hop (0..K).
        self.lins = torch.nn.ModuleList(
            [Linear(num_features, hidden_size) for _ in range(num_hops + 1)])
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(num_features=hidden_size) for _ in range(num_hops + 1)])

        self.linear_final = Linear((num_hops + 1) * hidden_size, num_classes)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable layer, including the output layer."""
        for lin in self.lins:
            lin.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.linear_final.reset_parameters()

    def forward(self, xs):
        """Classify a batch of nodes.

        Args:
            xs: sequence of K + 1 feature tensors [x, x1, ..., xK], one row per node.

        Returns:
            Log-probabilities with one row per node and one column per class.

        Raises:
            ValueError: if the number of inputs does not equal K + 1.
        """
        if len(xs) != len(self.lins):
            raise ValueError(f"expected {len(self.lins)} feature tensors, got {len(xs)}")

        hs = []
        for lin, bn, x in zip(self.lins, self.bns, xs):
            h = F.relu(bn(lin(x)))
            hs.append(F.dropout(h, p=self.drop_rate, training=self.training))
        h = self.linear_final(torch.cat(hs, dim=-1))
        return h.log_softmax(dim=-1)


class TorchGraph(pl.LightningModule):
    """Lightning wrapper that trains a GNN on mini-batches of node indices.

    For every batch, features [x, x1, ..., xK] and labels y are sliced from the
    global ``data``. HIDDEN_SIZE / DROP_RATE are read at construction time and
    LR when the optimizer is configured.
    """

    def __init__(self, device):
        super(TorchGraph, self).__init__()
        self.model = GNN(hidden_size=HIDDEN_SIZE, drop_rate=DROP_RATE).to(device)

    def _gather(self, idx):
        """Slice the inputs and labels for node indices `idx`.

        The results are placed on the device of the model parameters.
        """
        # Lightning may move `idx` to the GPU while `data` stays where it was
        # loaded: index on data's device, then move the slices to the model.
        idx = idx.to(data.x.device)
        target = next(self.model.parameters()).device
        xs = [data.x[idx].to(target)]
        xs += [data[f'x{i}'][idx].to(target) for i in range(1, K + 1)]
        y_true = data.y[idx].to(target)
        return xs, y_true

    def forward(self, idx):
        """Return class log-probabilities for the nodes in `idx` (inference, eval mode)."""
        self.model.eval()
        xs, _ = self._gather(idx)
        return self.model(xs)

    def _shared_step(self, batch, stage):
        """Compute and log the NLL loss of one batch under the name '<stage>_loss'."""
        xs, y_true = self._gather(batch)
        loss = F.nll_loss(self.model(xs), y_true)
        self.log(name=f"{stage}_loss", value=loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        # batch = tensor of training node indices
        self.model.train()
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self.model.eval()
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=LR)
        return optimizer

#### 4. Train

In [17]:
# HIDDEN_SIZE / DROP_RATE are read when TorchGraph is constructed,
# LR when its optimizer is configured, EPOCHS by the Trainer below.
HIDDEN_SIZE, DROP_RATE = 1024, 0.5
LR = 0.01
EPOCHS = 200

sign_graph = TorchGraph(device=device)
num_params = get_num_params(sign_graph.model)
logger.info(f"num model parameters: {num_params}")

2021-09-08 20:29:11,375 - sign_logger - num model parameters: 2611207


In [18]:
# Layer structure of the underlying GNN: per-hop Linear / BatchNorm1d branches and the final classifier.
print(sign_graph.model)

GNN(
  (lins): ModuleList(
    (0): Linear(in_features=500, out_features=1024, bias=True)
    (1): Linear(in_features=500, out_features=1024, bias=True)
    (2): Linear(in_features=500, out_features=1024, bias=True)
    (3): Linear(in_features=500, out_features=1024, bias=True)
    (4): Linear(in_features=500, out_features=1024, bias=True)
  )
  (bns): ModuleList(
    (0): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (linear_final): Linear(in_features=5120, out_features=7, bias=True)
)


In [19]:
# Per-tensor breakdown (name, shape, parameter count) of the GNN's learnable parameters.
param_report = get_parameter_report(sign_graph.model)
print(param_report)

                   name        shape  num_param
0         lins.0.weight  [1024, 500]     512000
1           lins.0.bias       [1024]       1024
2         lins.1.weight  [1024, 500]     512000
3           lins.1.bias       [1024]       1024
4         lins.2.weight  [1024, 500]     512000
5           lins.2.bias       [1024]       1024
6         lins.3.weight  [1024, 500]     512000
7           lins.3.bias       [1024]       1024
8         lins.4.weight  [1024, 500]     512000
9           lins.4.bias       [1024]       1024
10         bns.0.weight       [1024]       1024
11           bns.0.bias       [1024]       1024
12         bns.1.weight       [1024]       1024
13           bns.1.bias       [1024]       1024
14         bns.2.weight       [1024]       1024
15           bns.2.bias       [1024]       1024
16         bns.3.weight       [1024]       1024
17           bns.3.bias       [1024]       1024
18         bns.4.weight       [1024]       1024
19           bns.4.bias       [1024]    

In [20]:
# Train on one GPU when CUDA is available and fall back to the CPU (gpus=0)
# otherwise, so the notebook also runs on machines without a GPU.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    auto_scale_batch_size=None,
    deterministic=True,  # repeatable only together with a fixed seed
    max_epochs=EPOCHS
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [21]:
# Fit for EPOCHS epochs; train_loader shuffles training node indices, val_loader is used for validation.
trainer.fit(model=sign_graph, train_dataloaders=train_loader, val_dataloaders=val_loader)

  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | GNN  | 2.6 M 
-------------------------------
2.6 M     Trainable params
0         Non-trainable params
2.6 M     Total params
10.445    Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

  rank_zero_warn(
  rank_zero_warn(


Epoch 199: 100%|██████████| 66/66 [00:01<00:00, 66.60it/s, loss=0.082, v_num=0, train_loss_step=0.0843, val_loss_step=6.270, val_loss_epoch=6.320, train_loss_epoch=0.0845]


최종 epoch 기준으로 Training Loss는 약 0.085까지 빠르게 줄어든 반면, Validation Loss는 약 6.32로 오히려 발산하고 있다. 일반화 가능한 특징보다 Training Data 자체에만 맞춰 학습(과적합)하고 있는 것으로 추측된다.

#### 5. Evaluate

In [22]:
@torch.no_grad()
def test(model, loader):
    model.eval()

    total_correct = total_examples = 0
    for idx in loader:
        xs = [data.x[idx].to(device)]
        xs += [data[f'x{i}'][idx].to(device) for i in range(1, K + 1)]
        y = data.y[idx].to(device)

        out = model(xs)
        total_correct += int((out.argmax(dim=-1) == y).sum())
        total_examples += idx.numel()

    return total_correct / total_examples

# Evaluate the trained GNN on each split; the model is moved to `device` once.
eval_model = sign_graph.model.to(device)
train_acc, val_acc, test_acc = [
    test(eval_model, split_loader)
    for split_loader in (train_loader, val_loader, test_loader)
]

print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

Train: 0.9758, Val: 0.4814, Test: 0.4837


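Train accuracy(0.9758)에 비해 Val/Test accuracy는 약 0.48에 그친다. Overfitting이 심하여 Generalization이 잘 되지 않는 결과를 보여준다. Flickr는 node마다 하나의 label을 갖는 다중 분류 문제이므로(위에서 class index에 대한 NLL loss를 사용), micro-averaged F1은 accuracy와 같다. 따라서 논문이 보고한 F1 Score가 micro-F1이라면 위 accuracy와 직접 비교할 수 있다. 다만 논문의 실험 결과에서는 그 F1 Score가 Train/Val/Test 중 어떤 데이터셋에서 나온 결과물인지 명확히 밝히고 있지 않다.  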