In [1]:
import numpy as np
from copy import deepcopy
import random

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import DataLoader

In [4]:
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.nodeproppred.evaluate import Evaluator

In [5]:
from dataset import temp_partition_arxiv

In [6]:
dataset_name = 'ogbn-arxiv'
dataset = PygNodePropPredDataset(name = dataset_name, root='/home/hhchung/data/ogb-data')

In [7]:
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

## Model Structure ##

In [8]:
class TwoLayerGraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, out_dim)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout)
        
        x = self.conv2(x, edge_index)
        x = F.elu(x)
        return x

class ThreeLayerGraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.conv3 = SAGEConv(hidden_dim, out_dim)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout)
        
        x = self.conv2(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout)
        
        x = self.conv3(x, edge_index)
        return x
    

    
class MLPHead(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.linear1 = nn.Linear(in_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, out_dim)
    
    def forward(self, x):
        x = self.linear1(x)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout)
        x = self.linear2(x)
        # x = F.softmax(x, dim=1)
        return x

In [9]:
def train(encoder, mlp, optimizer, data):
    encoder.train()
    mlp.train()
    
    out = F.log_softmax(mlp(encoder(data.x, data.edge_index)), dim=1)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
    
@torch.no_grad()
def test(encoder, mlp, data, evaluator):
    encoder.eval()
    mlp.eval()
    
    out = F.log_softmax(mlp(encoder(data.x, data.edge_index)), dim=1)
    val_loss = F.nll_loss(out[data.val_mask], data.y[data.val_mask]).item()
    y_pred = out.argmax(dim=-1, keepdim=True)
    val_acc = evaluator.eval({
        'y_true': data.y[data.val_mask].unsqueeze(1),
        'y_pred': y_pred[data.val_mask],
    })['acc']
    
    return val_loss, val_acc

## Load Data ##

In [10]:
dataset = PygNodePropPredDataset(name = dataset_name, root='/home/hhchung/data/ogb-data')
data = dataset[0]
data.edge_index = to_undirected(data.edge_index, data.num_nodes) # mimicking barlow twins repo

## Data Partition ##

* Train: 0-2011
* Val: 2012
* Test:
** 2013-2014 (then adapt 2012-2013 adapt-val 2014)
** 2015-2016 (then adapt 2014-2015 adapt-val 2016)
** 2017-2018 (then adapt 2016-2017 adapt-val 2018)
** 2019-2020

## Source Training Stage ##

* Train: 0-2011
* Val: 2012

In [11]:
feat_dim = data.x.shape[1]
class_dim = data.y
hidden_dim = 128
emb_dim = 256
encoder = ThreeLayerGraphSAGE(feat_dim, hidden_dim, emb_dim)
mlp = MLPHead(emb_dim, emb_dim // 4, 40)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(mlp.parameters()), lr=1e-3)
epochs = 500

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = encoder.to(device)
mlp = mlp.to(device)

In [12]:
data_2012_2013 = temp_partition_arxiv(data, year_bound=[-1,2012,2013], proportion=1.0)
data_2012_2013 = data_2012_2013.to(device)

In [13]:
best_acc = 0
best_encoder = None
best_mlp = None
evaluator = Evaluator(name='ogbn-arxiv')
for e in range(1, epochs + 1):
    train_loss = train(encoder, mlp, optimizer, data_2012_2013)
    val_loss, val_acc = test(encoder, mlp, data_2012_2013, evaluator)
    print(f"Epoch:{e}/{epochs} Train Loss:{round(train_loss,4)} Val Loss:{round(val_loss,4)} Val Acc:{round(val_acc, 4)}")
    if val_acc > best_acc:
        best_acc = val_acc
        best_encoder = deepcopy(encoder)
        best_mlp = deepcopy(mlp)

Epoch:1/500 Train Loss:3.6856 Val Loss:3.6195 Val Acc:0.1966
Epoch:2/500 Train Loss:3.5901 Val Loss:3.5396 Val Acc:0.2044
Epoch:3/500 Train Loss:3.4895 Val Loss:3.4454 Val Acc:0.2045
Epoch:4/500 Train Loss:3.368 Val Loss:3.3583 Val Acc:0.2045
Epoch:5/500 Train Loss:3.2441 Val Loss:3.3238 Val Acc:0.2045
Epoch:6/500 Train Loss:3.1718 Val Loss:3.3318 Val Acc:0.2045
Epoch:7/500 Train Loss:3.1591 Val Loss:3.2723 Val Acc:0.2045
Epoch:8/500 Train Loss:3.0964 Val Loss:3.2117 Val Acc:0.2033
Epoch:9/500 Train Loss:3.0419 Val Loss:3.2022 Val Acc:0.1841
Epoch:10/500 Train Loss:3.0427 Val Loss:3.188 Val Acc:0.1699
Epoch:11/500 Train Loss:3.0402 Val Loss:3.1491 Val Acc:0.1918
Epoch:12/500 Train Loss:3.0042 Val Loss:3.1139 Val Acc:0.2061
Epoch:13/500 Train Loss:2.9534 Val Loss:3.0913 Val Acc:0.2057
Epoch:14/500 Train Loss:2.9167 Val Loss:3.0838 Val Acc:0.2045
Epoch:15/500 Train Loss:2.9075 Val Loss:3.0642 Val Acc:0.2047
Epoch:16/500 Train Loss:2.8882 Val Loss:3.0173 Val Acc:0.2067
Epoch:17/500 Train 

Epoch:147/500 Train Loss:1.4733 Val Loss:1.7418 Val Acc:0.5217
Epoch:148/500 Train Loss:1.465 Val Loss:1.7313 Val Acc:0.5187
Epoch:149/500 Train Loss:1.4676 Val Loss:1.7253 Val Acc:0.5235
Epoch:150/500 Train Loss:1.4662 Val Loss:1.7294 Val Acc:0.5256
Epoch:151/500 Train Loss:1.4597 Val Loss:1.7305 Val Acc:0.52
Epoch:152/500 Train Loss:1.4582 Val Loss:1.7274 Val Acc:0.5285
Epoch:153/500 Train Loss:1.461 Val Loss:1.719 Val Acc:0.5228
Epoch:154/500 Train Loss:1.4537 Val Loss:1.7302 Val Acc:0.5228
Epoch:155/500 Train Loss:1.4552 Val Loss:1.7212 Val Acc:0.5226
Epoch:156/500 Train Loss:1.4535 Val Loss:1.7305 Val Acc:0.5254
Epoch:157/500 Train Loss:1.4523 Val Loss:1.7202 Val Acc:0.5276
Epoch:158/500 Train Loss:1.4455 Val Loss:1.7195 Val Acc:0.5271
Epoch:159/500 Train Loss:1.4445 Val Loss:1.7109 Val Acc:0.5254
Epoch:160/500 Train Loss:1.4483 Val Loss:1.7184 Val Acc:0.5271
Epoch:161/500 Train Loss:1.439 Val Loss:1.7219 Val Acc:0.5229
Epoch:162/500 Train Loss:1.4448 Val Loss:1.7167 Val Acc:0.526

Epoch:281/500 Train Loss:1.2514 Val Loss:1.5722 Val Acc:0.558
Epoch:282/500 Train Loss:1.2441 Val Loss:1.5783 Val Acc:0.5518
Epoch:283/500 Train Loss:1.244 Val Loss:1.5827 Val Acc:0.5576
Epoch:284/500 Train Loss:1.2438 Val Loss:1.5763 Val Acc:0.5639
Epoch:285/500 Train Loss:1.2442 Val Loss:1.563 Val Acc:0.5596
Epoch:286/500 Train Loss:1.2427 Val Loss:1.5702 Val Acc:0.5618
Epoch:287/500 Train Loss:1.2425 Val Loss:1.571 Val Acc:0.5622
Epoch:288/500 Train Loss:1.2354 Val Loss:1.5716 Val Acc:0.5566
Epoch:289/500 Train Loss:1.2391 Val Loss:1.5705 Val Acc:0.5621
Epoch:290/500 Train Loss:1.239 Val Loss:1.5751 Val Acc:0.5573
Epoch:291/500 Train Loss:1.2379 Val Loss:1.5692 Val Acc:0.5611
Epoch:292/500 Train Loss:1.2351 Val Loss:1.5626 Val Acc:0.5596
Epoch:293/500 Train Loss:1.2349 Val Loss:1.5769 Val Acc:0.5602
Epoch:294/500 Train Loss:1.2317 Val Loss:1.5747 Val Acc:0.5546
Epoch:295/500 Train Loss:1.2344 Val Loss:1.5692 Val Acc:0.561
Epoch:296/500 Train Loss:1.2312 Val Loss:1.5733 Val Acc:0.561

Epoch:413/500 Train Loss:1.1169 Val Loss:1.5312 Val Acc:0.5778
Epoch:414/500 Train Loss:1.1165 Val Loss:1.5217 Val Acc:0.5782
Epoch:415/500 Train Loss:1.1233 Val Loss:1.5297 Val Acc:0.5708
Epoch:416/500 Train Loss:1.1214 Val Loss:1.5292 Val Acc:0.5782
Epoch:417/500 Train Loss:1.1173 Val Loss:1.5288 Val Acc:0.573
Epoch:418/500 Train Loss:1.1106 Val Loss:1.5304 Val Acc:0.5807
Epoch:419/500 Train Loss:1.117 Val Loss:1.5356 Val Acc:0.575
Epoch:420/500 Train Loss:1.112 Val Loss:1.5391 Val Acc:0.577
Epoch:421/500 Train Loss:1.1153 Val Loss:1.5303 Val Acc:0.5787
Epoch:422/500 Train Loss:1.1169 Val Loss:1.5277 Val Acc:0.5806
Epoch:423/500 Train Loss:1.1051 Val Loss:1.5384 Val Acc:0.5812
Epoch:424/500 Train Loss:1.115 Val Loss:1.5341 Val Acc:0.5772
Epoch:425/500 Train Loss:1.1119 Val Loss:1.5217 Val Acc:0.5775
Epoch:426/500 Train Loss:1.1057 Val Loss:1.5282 Val Acc:0.5782
Epoch:427/500 Train Loss:1.1094 Val Loss:1.5267 Val Acc:0.5765
Epoch:428/500 Train Loss:1.1106 Val Loss:1.5461 Val Acc:0.573

In [14]:
best_acc

0.5867909867909868

In [15]:
test_acc_list = []

## 2013-2014 ##

In [16]:
data_2013_2015 = temp_partition_arxiv(data, year_bound=[-1,2013,2015], proportion=1.0)
data_2013_2015 = data_2013_2015.to(device)

In [17]:
val_loss, val_acc = test(best_encoder, best_mlp, data_2013_2015, evaluator)
test_acc_list.append(val_acc)

In [18]:
val_loss, val_acc

(1.4982811212539673, 0.5856903233269709)

## 2015-2016 ##

In [19]:
data_2015_2017 = temp_partition_arxiv(data, year_bound=[-1,2015,2017], proportion=1.0)
data_2015_2017 = data_2015_2017.to(device)

In [20]:
val_loss, val_acc = test(best_encoder, best_mlp, data_2015_2017, evaluator)
test_acc_list.append(val_acc)

In [21]:
val_loss, val_acc

(1.5647748708724976, 0.5557200253753436)

## 2017-2018 ##

In [22]:
data_2017_2019 = temp_partition_arxiv(data, year_bound=[-1,2017,2019], proportion=1.0)
data_2017_2019 = data_2017_2019.to(device)

In [23]:
val_loss, val_acc = test(best_encoder, best_mlp, data_2017_2019, evaluator)
test_acc_list.append(val_acc)

In [24]:
val_loss, val_acc

(1.6256769895553589, 0.5425538143283699)

## 2019-2020 ##

In [25]:
data_2019_2021 = temp_partition_arxiv(data, year_bound=[-1,2019,2021], proportion=1.0)
data_2019_2021 = data_2019_2021.to(device)

In [26]:
val_loss, val_acc = test(best_encoder, best_mlp, data_2019_2021, evaluator)
test_acc_list.append(val_acc)

In [27]:
val_loss, val_acc

(1.5660314559936523, 0.5557475875974734)

In [28]:
print(test_acc_list)

[0.5856903233269709, 0.5557200253753436, 0.5425538143283699, 0.5557475875974734]


In [29]:
print(sum(test_acc_list) / len(test_acc_list))

0.5599279376570395
