In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


### GraphSAGE(pytorch-version)

#### import Module 

In [3]:
import torch
print(torch.__version__)

1.13.0+cu116


In [4]:
!pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.13.0+cu116.html
!pip install torch-geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.0+cu116.html
Collecting pyg-lib
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/pyg_lib-0.1.0%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 12.9 MB/s 
[?25hCollecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.0%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (9.4 MB)
[K     |████████████████████████████████| 9.4 MB 79.5 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_sparse-0.6.15%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (4.6 MB)
[K     |████████████████████████████████| 4.6 MB 95.0 MB/s 
Installing collected packages: torch-sparse, torch-scatter, pyg-lib
Successfully installed pyg-lib-0.1.0+pt113cu116 torch-scatter-2.1.0+pt113cu116 torch-sparse-0.6.15+pt113cu116
Looking in i

#### configuration

In [5]:
import logging
import pickle
import torch

def make_logger(name=None):
    # https://hwangheek.github.io/2019/python-logging/
    logger = logging.getLogger(name) # 로거 객체를 생성
    logger.setLevel(logging.DEBUG) # 로그의 레벨
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(message)s") # LogRecord의 출력 형태 지정

    # Handler는 로그 메시지를 출력하는 역할
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(formatter)

    logger.addHandler(console)
    return logger

def dump_pickle(address, file):
    with open(address, 'wb') as f:
        pickle.dump(file, f)

def load_pickle(address):
    with open(address, 'rb') as f:
        data = pickle.load(f)
    return data

class EarlyStopping:
    def __init__(self, patience=10, verbose=False, save_path='checkpoint.pt'):
        """
        :param patience: how many times you will wait before earlystopping
        :param save_path: where to save checkpoint
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.save_path = save_path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model, val_loss)
        elif score < self.best_score:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model, val_loss)
            self.counter = 0 # reset

    def save_checkpoint(self, model, val_loss):
        if self.verbose:
            print(f"val loss: ({self.val_loss_min:.6f} -> {val_loss:.6f})")
        torch.save(model.state_dict(), self.save_path)
        self.val_loss_min = val_loss

In [6]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.datasets import Reddit
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import SAGEConv

logger = make_logger(name='graphsage_logger')

In [7]:
%cd gdrive/MyDrive/GraphSAGE

[Errno 2] No such file or directory: 'gdrive/MyDrive/GraphSAGE'
/content


In [8]:
# Load Reddit Dataset
import os
path = os.path.join(os.getcwd(),'data','Reddit')

dataset = Reddit(path)
data = dataset[0]

Downloading https://data.dgl.ai/dataset/reddit.zip
Extracting /content/data/Reddit/raw/reddit.zip
Processing...
Done!


In [9]:
print(data)

Data(x=[232965, 602], edge_index=[2, 114615892], y=[232965], train_mask=[232965], val_mask=[232965], test_mask=[232965])


In [10]:
# Data verification
logger.info(f"Node Feature MAtrix InFo : # Nodes: {data.x.shape[0]}, # Node Features : {data.x.shape[1]}")

# Edge index
# Graph Connectivity in Coo Format with shape (2, num_edges)
logger.info(f"Edge index shape : {data.edge_index.shape}")
logger.info(f"Edge weight: {data.edge_attr}")

# train_mask denotes against which nodes to train
print(len(data.train_mask))
print(data.train_mask.sum())

2022-12-13 06:38:00,371 - graphsage_logger - Node Feature MAtrix InFo : # Nodes: 232965, # Node Features : 602
INFO:graphsage_logger:Node Feature MAtrix InFo : # Nodes: 232965, # Node Features : 602
2022-12-13 06:38:00,373 - graphsage_logger - Edge index shape : torch.Size([2, 114615892])
INFO:graphsage_logger:Edge index shape : torch.Size([2, 114615892])
2022-12-13 06:38:00,376 - graphsage_logger - Edge weight: None
INFO:graphsage_logger:Edge weight: None


232965
tensor(153431)


In [11]:
# Define Sampler
train_loader = NeighborSampler(
    data.edge_index, node_idx=data.train_mask,
    sizes=[25, 10], batch_size=1024, shuffle=True, num_workers=12)

subgraph_loader = NeighborSampler(
    data.edge_index, node_idx=None,
    sizes=[-1], batch_size=1024, shuffle=False, num_workers=12)




In [12]:
print(train_loader)
print(subgraph_loader)

NeighborSampler(sizes=[25, 10])
NeighborSampler(sizes=[-1])


In [13]:
next(iter(train_loader))

(1024,
 tensor([ 56763, 156564, 104494,  ...,  80749, 195147, 146918]),
 [EdgeIndex(edge_index=tensor([[10866, 10878, 19709,  ..., 21668, 21671, 21672],
          [    0,     0,     0,  ..., 21673, 21673, 21673]]), e_id=tensor([86646886, 34334275, 76882858,  ..., 88386657,  9981250,  7102880]), size=(106277, 21674)),
  EdgeIndex(edge_index=tensor([[  457,  1024,  1025,  ..., 21671, 21672, 21673],
          [    0,     0,     0,  ...,  1023,  1023,  1023]]), e_id=tensor([20415259, 43538122, 51235531,  ...,  9978656,  7102819, 35764977]), size=(21674, 1024))])

In [14]:
# Look
batch_size, n_id, adjs = next(iter(train_loader))

# 1) batch_size
# 현재 batch size를 의미함 (integer)
logger.info(f"Current Batch Size: {batch_size}")

# 2) n_id
# 이번 Subgraph에서 사용된 모든 node id
# batch_size개의 Row를 예측하기 위해서 이에 대한 1차 이웃 node A개가 필요하고
# 1차 이웃 node A개를 위해서는 2차 이웃 node B개가 필요함
# n_id.shape = batch_size + A + B
logger.info(f"현재 Subgraph에서 사용된 모든 node id의 개수: {n_id.shape[0]}")

# 3) adjs
# 아래와 같이 Layer의 수가 2개이면 adjs는 길이 2의 List가 된다.
# head node가 있고 1-hop neighbors와 2-hop neighbors가 있다고 할 때
# adjs[1]이 head node와 1-hop neighbors의 관계를 설명하며  (1번째 Layer)
# adjs[0]이 1-hop neighbors와 2-hop neighbors의 관계를 설명한다. (2번째 Layer)
logger.info(f"Layer의 수: {len(adjs)}")

# 각 리스트에는 아래와 같은 튜플이 들어있다.
# (edge_index, e_id, size)
# edge_index: source -> target nodes를 기록한 bipartite edges
# e_id: 위 edge_index에 들어있는 index가 Full Graph에서 갖는 node id

# size: 위 edge_index에 들어있는 node의 수를 튜플로 나타낸 것으로
# head -> 1-hop 관계를 예시로 들면,
# head node의 수가 a개, 1-hop node의 수가 b개라고 했을 때
# size = (a+b, a)
# 또한 target node의 경우 source nodes의 리스트의 시작 부분에 포함되어 있어
# skip-connections나 self-loops를 쉽게 사용할 수 있게 되어 있음
A = adjs[1].size[0] - batch_size
B = adjs[0].size[0] - A - batch_size

logger.info(f"진행 방향: {B}개의 2-hop neighbors ->"
            f"{A}개의 1-hop neighbors -> {batch_size}개의 Head Nodes")

2022-12-13 06:38:39,666 - graphsage_logger - Current Batch Size: 1024
INFO:graphsage_logger:Current Batch Size: 1024
2022-12-13 06:38:39,669 - graphsage_logger - 현재 Subgraph에서 사용된 모든 node id의 개수: 107034
INFO:graphsage_logger:현재 Subgraph에서 사용된 모든 node id의 개수: 107034
2022-12-13 06:38:39,671 - graphsage_logger - Layer의 수: 2
INFO:graphsage_logger:Layer의 수: 2
2022-12-13 06:38:39,675 - graphsage_logger - 진행 방향: 85156개의 2-hop neighbors ->20854개의 1-hop neighbors -> 1024개의 Head Nodes
INFO:graphsage_logger:진행 방향: 85156개의 2-hop neighbors ->20854개의 1-hop neighbors -> 1024개의 Head Nodes


In [15]:
print(adjs)

[EdgeIndex(edge_index=tensor([[  1024,   1044,   1203,  ..., 107031, 107032, 107033],
        [     0,      1,      1,  ...,  21877,  21877,  21877]]), e_id=tensor([ 22854837, 103644544,  74461702,  ...,   9388556,  50869465,
         34961078]), size=(107034, 21878)), EdgeIndex(edge_index=tensor([[ 1024,  1025,  1026,  ..., 21875, 21876, 21877],
        [    0,     1,     1,  ...,  1023,  1023,  1023]]), e_id=tensor([ 22854837, 113115088,  37512591,  ...,  71936222,  12968253,
        101202791]), size=(21878, 1024))]


In [16]:


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [17]:
# Define Model
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()

        self.num_layers = 2

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, adjs):
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)

            # 마지막 Layer는 Dropout을 적용하지 않는다.
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x.log_softmax(dim=-1)

    def inference(self, x_all):
        pbar = tqdm(total=x_all.size(0) * self.num_layers)
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        for i in range(self.num_layers):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                xs.append(x.cpu())

                pbar.update(batch_size)

            x_all = torch.cat(xs, dim=0)

        pbar.close()
        return x_all


In [18]:
model = GraphSAGE(dataset.num_features, 256, dataset.num_classes)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

x = data.x.to(device)
y = data.y.squeeze().to(device)


In [19]:

def train(epoch):
    model.train()

    pbar = tqdm(total=int(data.train_mask.sum()))
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = total_correct = 0
    for batch_size, n_id, adjs in train_loader:
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        adjs = [adj.to(device) for adj in adjs]

        optimizer.zero_grad()
        out = model(x[n_id], adjs)
        loss = F.nll_loss(out, y[n_id[:batch_size]])
        loss.backward()
        optimizer.step()

        total_loss += float(loss)
        total_correct += int(out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum())
        pbar.update(batch_size)

    pbar.close()

    loss = total_loss / len(train_loader)
    approx_acc = total_correct / int(data.train_mask.sum())
    return loss, approx_acc


@torch.no_grad()
def test():
    model.eval()
    out = model.inference(x)

    y_true = y.cpu().unsqueeze(-1)
    y_pred = out.argmax(dim=-1, keepdim=True)

    results = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        results += [int(y_pred[mask].eq(y_true[mask]).sum()) / int(mask.sum())]
    return results


In [21]:

for epoch in range(1, 31):
    loss, acc = train(epoch)
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
    train_acc, val_acc, test_acc = test()
    print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

Epoch 01: 100%|██████████| 153431/153431 [00:16<00:00, 9131.56it/s] 


Epoch 01, Loss: 0.5766, Approx. Train: 0.9314


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15515.94it/s]


Train: 0.9634, Val: 0.9519, Test: 0.9508


Epoch 02: 100%|██████████| 153431/153431 [00:16<00:00, 9240.10it/s] 


Epoch 02, Loss: 0.5094, Approx. Train: 0.9341


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15482.93it/s]


Train: 0.9641, Val: 0.9528, Test: 0.9517


Epoch 03: 100%|██████████| 153431/153431 [00:16<00:00, 9238.40it/s] 


Epoch 03, Loss: 0.5552, Approx. Train: 0.9334


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15396.85it/s]


Train: 0.9667, Val: 0.9531, Test: 0.9520


Epoch 04: 100%|██████████| 153431/153431 [00:16<00:00, 9161.21it/s] 


Epoch 04, Loss: 0.5468, Approx. Train: 0.9354


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15477.33it/s]


Train: 0.9663, Val: 0.9509, Test: 0.9508


Epoch 05: 100%|██████████| 153431/153431 [00:16<00:00, 9166.52it/s] 


Epoch 05, Loss: 0.5338, Approx. Train: 0.9346


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15506.38it/s]


Train: 0.9681, Val: 0.9530, Test: 0.9533


Epoch 06: 100%|██████████| 153431/153431 [00:16<00:00, 9257.12it/s] 


Epoch 06, Loss: 0.5386, Approx. Train: 0.9346


Evaluating: 100%|██████████| 465930/465930 [00:29<00:00, 15554.93it/s]


Train: 0.9680, Val: 0.9528, Test: 0.9520


Epoch 07: 100%|██████████| 153431/153431 [00:17<00:00, 8757.41it/s] 


Epoch 07, Loss: 0.5258, Approx. Train: 0.9354


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15405.95it/s]


Train: 0.9671, Val: 0.9514, Test: 0.9513


Epoch 08: 100%|██████████| 153431/153431 [00:16<00:00, 9110.50it/s] 


Epoch 08, Loss: 0.5395, Approx. Train: 0.9348


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15368.08it/s]


Train: 0.9677, Val: 0.9527, Test: 0.9516


Epoch 09: 100%|██████████| 153431/153431 [00:16<00:00, 9058.69it/s] 


Epoch 09, Loss: 0.5800, Approx. Train: 0.9356


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15349.58it/s]


Train: 0.9680, Val: 0.9517, Test: 0.9507


Epoch 10: 100%|██████████| 153431/153431 [00:16<00:00, 9144.81it/s] 


Epoch 10, Loss: 0.5072, Approx. Train: 0.9360


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15381.62it/s]


Train: 0.9695, Val: 0.9525, Test: 0.9509


Epoch 11: 100%|██████████| 153431/153431 [00:16<00:00, 9029.21it/s] 


Epoch 11, Loss: 0.4930, Approx. Train: 0.9385


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15416.57it/s]


Train: 0.9706, Val: 0.9524, Test: 0.9531


Epoch 12: 100%|██████████| 153431/153431 [00:16<00:00, 9146.54it/s] 


Epoch 12, Loss: 0.5410, Approx. Train: 0.9374


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15347.38it/s]


Train: 0.9699, Val: 0.9519, Test: 0.9518


Epoch 13: 100%|██████████| 153431/153431 [00:16<00:00, 9052.86it/s] 


Epoch 13, Loss: 0.4940, Approx. Train: 0.9386


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15359.30it/s]


Train: 0.9718, Val: 0.9531, Test: 0.9525


Epoch 14: 100%|██████████| 153431/153431 [00:18<00:00, 8403.56it/s] 


Epoch 14, Loss: 0.5590, Approx. Train: 0.9389


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15308.01it/s]


Train: 0.9713, Val: 0.9528, Test: 0.9518


Epoch 15: 100%|██████████| 153431/153431 [00:16<00:00, 9051.45it/s] 


Epoch 15, Loss: 0.5261, Approx. Train: 0.9388


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15354.47it/s]


Train: 0.9715, Val: 0.9527, Test: 0.9521


Epoch 16: 100%|██████████| 153431/153431 [00:16<00:00, 9104.57it/s] 


Epoch 16, Loss: 0.5327, Approx. Train: 0.9384


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15326.83it/s]


Train: 0.9712, Val: 0.9515, Test: 0.9524


Epoch 17: 100%|██████████| 153431/153431 [00:16<00:00, 9041.78it/s] 


Epoch 17, Loss: 0.4713, Approx. Train: 0.9400


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15294.79it/s]


Train: 0.9723, Val: 0.9514, Test: 0.9513


Epoch 18: 100%|██████████| 153431/153431 [00:17<00:00, 9019.50it/s] 


Epoch 18, Loss: 0.4891, Approx. Train: 0.9407


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15363.23it/s]


Train: 0.9720, Val: 0.9539, Test: 0.9516


Epoch 19: 100%|██████████| 153431/153431 [00:16<00:00, 9148.80it/s] 


Epoch 19, Loss: 0.5029, Approx. Train: 0.9396


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15340.51it/s]


Train: 0.9718, Val: 0.9525, Test: 0.9522


Epoch 20: 100%|██████████| 153431/153431 [00:16<00:00, 9033.30it/s] 


Epoch 20, Loss: 0.4942, Approx. Train: 0.9397


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15165.69it/s]


Train: 0.9724, Val: 0.9523, Test: 0.9514


Epoch 21: 100%|██████████| 153431/153431 [00:16<00:00, 9054.80it/s] 


Epoch 21, Loss: 0.5234, Approx. Train: 0.9406


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15303.54it/s]


Train: 0.9741, Val: 0.9541, Test: 0.9526


Epoch 22: 100%|██████████| 153431/153431 [00:16<00:00, 9105.12it/s] 


Epoch 22, Loss: 0.4710, Approx. Train: 0.9416


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15332.84it/s]


Train: 0.9744, Val: 0.9543, Test: 0.9522


Epoch 23: 100%|██████████| 153431/153431 [00:16<00:00, 9092.66it/s] 


Epoch 23, Loss: 0.5066, Approx. Train: 0.9419


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15300.43it/s]


Train: 0.9742, Val: 0.9544, Test: 0.9527


Epoch 24: 100%|██████████| 153431/153431 [00:16<00:00, 9189.08it/s] 


Epoch 24, Loss: 0.5067, Approx. Train: 0.9404


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15377.75it/s]


Train: 0.9721, Val: 0.9510, Test: 0.9507


Epoch 25: 100%|██████████| 153431/153431 [00:16<00:00, 9184.25it/s] 


Epoch 25, Loss: 0.4844, Approx. Train: 0.9413


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15494.27it/s]


Train: 0.9743, Val: 0.9531, Test: 0.9517


Epoch 26: 100%|██████████| 153431/153431 [00:16<00:00, 9103.22it/s] 


Epoch 26, Loss: 0.4888, Approx. Train: 0.9421


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15267.68it/s]


Train: 0.9720, Val: 0.9517, Test: 0.9508


Epoch 27: 100%|██████████| 153431/153431 [00:16<00:00, 9184.87it/s] 


Epoch 27, Loss: 0.5573, Approx. Train: 0.9399


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15187.66it/s]


Train: 0.9739, Val: 0.9529, Test: 0.9533


Epoch 28: 100%|██████████| 153431/153431 [00:16<00:00, 9137.31it/s] 


Epoch 28, Loss: 0.4918, Approx. Train: 0.9414


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15348.26it/s]


Train: 0.9752, Val: 0.9535, Test: 0.9527


Epoch 29: 100%|██████████| 153431/153431 [00:17<00:00, 8988.56it/s] 


Epoch 29, Loss: 0.4840, Approx. Train: 0.9417


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15380.73it/s]


Train: 0.9743, Val: 0.9527, Test: 0.9521


Epoch 30: 100%|██████████| 153431/153431 [00:16<00:00, 9078.52it/s] 


Epoch 30, Loss: 0.5294, Approx. Train: 0.9423


Evaluating: 100%|██████████| 465930/465930 [00:30<00:00, 15374.66it/s]

Train: 0.9748, Val: 0.9527, Test: 0.9533



