In [8]:
import argparse
import sys

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv

from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

from logger import Logger


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(GCN, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels, cached=True))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, cached=True))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels, cached=True))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        final_embeddings = x

        print(f"embeddings: {final_embeddings}")
        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)


class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(SAGE, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)


def train(model, data, train_idx, optimizer):
    model.train()

    optimizer.zero_grad()
    out = model(data.x, data.adj_t)[train_idx]
    loss = F.nll_loss(out, data.y.squeeze(1)[train_idx])
    loss.backward()
    optimizer.step()

    return loss.item()


@torch.no_grad()
def test(model, data, split_idx, evaluator, node_list):
    model.eval()

    out = model(data.x, data.adj_t)
    y_pred = out.argmax(dim=-1, keepdim=True)

    # Create a mask for the specific nodes
    mask = torch.zeros(data.y.size(0), dtype=torch.bool)
    mask[node_list] = True  # Set the specific nodes to True

    # Ensure split indices are tensors
    train_idx = torch.tensor(split_idx['train'], dtype=torch.long)
    valid_idx = torch.tensor(split_idx['valid'], dtype=torch.long)
    test_idx = torch.tensor(split_idx['test'], dtype=torch.long)

    # Use the mask to filter indices correctly for each split
    train_mask = mask[train_idx]
    valid_mask = mask[valid_idx]
    test_mask = mask[test_idx]

    # Filter the true and predicted labels
    y_true_train = data.y[train_idx][train_mask]
    y_pred_train = y_pred[train_idx][train_mask]

    y_true_valid = data.y[valid_idx][valid_mask]
    y_pred_valid = y_pred[valid_idx][valid_mask]

    y_true_test = data.y[test_idx][test_mask]
    y_pred_test = y_pred[test_idx][test_mask]

    # Debugging outputs
    print(f"Train: {y_true_train.numel()}, {y_pred_train.numel()}")
    print(f"Valid: {y_true_valid.numel()}, {y_pred_valid.numel()}")
    print(f"Test: {y_true_test.numel()}, {y_pred_test.numel()}")

    # Evaluate using the filtered true and predicted labels
    train_acc = evaluator.eval({'y_true': y_true_train, 'y_pred': y_pred_train})['acc'] if y_true_train.numel() > 0 else 0.0
    valid_acc = evaluator.eval({'y_true': y_true_valid, 'y_pred': y_pred_valid})['acc'] if y_true_valid.numel() > 0 else 0.0
    test_acc = evaluator.eval({'y_true': y_true_test, 'y_pred': y_pred_test})['acc'] if y_true_test.numel() > 0 else 0.0

    return train_acc, valid_acc, test_acc



In [9]:
def main():
    # Detect if running in Jupyter and simulate argument parsing
    if "ipykernel" in sys.argv[0]:
        class Args:
            device = 0
            log_steps = 1
            use_sage = False
            num_layers = 3
            hidden_channels = 256
            dropout = 0.5
            lr = 0.01
            epochs = 20
            runs = 5
        args = Args()
    else:
        # Argument parsing for command-line execution
        parser = argparse.ArgumentParser(description='OGBN-Arxiv (GNN)')
        parser.add_argument('--device', type=int, default=0)
        parser.add_argument('--log_steps', type=int, default=1)
        parser.add_argument('--use_sage', action='store_true')
        parser.add_argument('--num_layers', type=int, default=3)
        parser.add_argument('--hidden_channels', type=int, default=256)
        parser.add_argument('--dropout', type=float, default=0.5)
        parser.add_argument('--lr', type=float, default=0.01)
        parser.add_argument('--epochs', type=int, default=500)
        parser.add_argument('--runs', type=int, default=10)
        args = parser.parse_args()
    
    print(args)

    # Your main code (rest of your logic)...
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    node_list = [110223, 146929, 2940, 104544, 62326, 29759, 96890, 47025, 117732, 163206, 61450, 20589, 145422, 33882, 4523, 81254, 82143, 85138, 167093, 125903, 116417, 158097, 95232, 81835, 84070, 53849, 112263, 17237, 34055, 10707, 164818, 32412, 75064, 139893, 161232, 100687, 148029, 55852, 23076, 152166, 44377, 111682, 145090, 132166, 83408, 153864, 143857, 111282, 150656, 36239, 142479, 112753, 149468, 50327, 163800, 45303, 155755, 151189, 160714, 8875, 116826, 142383, 143573, 136795, 277, 120988, 120990, 151641, 44082, 94916, 32814, 117613, 98216, 41936, 16625, 163884, 69260, 29036, 91369, 31451, 161505, 27948, 29235, 15854, 162540, 161899, 157799, 90013, 120275, 57309, 92874, 140989, 11692, 78367, 62285, 80502, 72672, 121790, 12662, 90677]
    
    dataset = PygNodePropPredDataset(name='ogbn-arxiv', transform=T.ToSparseTensor())
    data = dataset[0]
    data.adj_t = data.adj_t.to_symmetric()
    data = data.to(device)

    split_idx = dataset.get_idx_split()

    # Filter the split indices based on your `node_list`
    train_idx = torch.tensor([n for n in split_idx['train'] if n in node_list], device=device)
    valid_idx = torch.tensor([n for n in split_idx['valid'] if n in node_list], device=device)
    test_idx = torch.tensor([n for n in split_idx['test'] if n in node_list], device=device)

    print(train_idx)
    print(valid_idx)
    print(test_idx)

    if args.use_sage:
        model = SAGE(data.num_features, args.hidden_channels,
                     dataset.num_classes, args.num_layers,
                     args.dropout).to(device)
    else:
        model = GCN(data.num_features, args.hidden_channels,
                    dataset.num_classes, args.num_layers,
                    args.dropout).to(device)

    evaluator = Evaluator(name='ogbn-arxiv')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, data, train_idx, optimizer)
            result = test(model, data, {'train': train_idx, 'valid': valid_idx, 'test': test_idx}, evaluator, node_list)
            logger.add_result(run, result)

            if epoch % args.log_steps == 0:
                train_acc, valid_acc, test_acc = result
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {100 * train_acc:.2f}%, '
                      f'Valid: {100 * valid_acc:.2f}%, '
                      f'Test: {100 * test_acc:.2f}%')

        logger.print_statistics(run)
    logger.print_statistics()

if __name__ == "__main__":
    main()

<__main__.main.<locals>.Args object at 0x74e9f17e6f30>


  self.data, self.slices = torch.load(self.processed_paths[0])


tensor([   277,   2940,  10707,  11692,  12662,  27948,  29235,  31451,  32814,
         33882,  34055,  36239,  41936,  44082,  45303,  47025,  55852,  57309,
         61450,  69260,  72672,  78367,  82143,  83408,  85138,  91369,  92874,
         96890,  98216, 100687, 104544, 110223, 111282, 112753, 116826, 120990,
        121790, 125903, 139893, 142383, 142479, 145090, 148029, 151189, 152166,
        153864, 155755, 157799, 161505, 161899, 162540, 163206, 167093])
tensor([  4523,  16625,  23076,  32412,  44377,  80502,  81254,  90677, 111682,
        120988, 136795, 143573, 143857, 158097, 160714, 163800])
tensor([  8875,  15854,  17237,  20589,  29036,  29759,  50327,  53849,  62285,
         62326,  75064,  81835,  84070,  90013,  94916,  95232, 112263, 116417,
        117613, 117732, 120275, 132166, 140989, 145422, 146929, 149468, 150656,
        151641, 161232, 163884, 164818])
embeddings: tensor([[0.0000, 0.0000, 0.0000,  ..., 9.2168, 0.0000, 0.0000],
        [0.0000, 1.9514, 

  train_idx = torch.tensor(split_idx['train'], dtype=torch.long)
  valid_idx = torch.tensor(split_idx['valid'], dtype=torch.long)
  test_idx = torch.tensor(split_idx['test'], dtype=torch.long)


embeddings: tensor([[6.3981, 0.0000, 1.3919,  ..., 8.7285, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0806, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.6280, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.2642, 0.0000],
        [0.0000, 0.6493, 0.0000,  ..., 0.0000, 0.0000, 3.6422]],
       grad_fn=<MulBackward0>)
embeddings: tensor([[0.7681, 0.0000, 0.1222,  ..., 0.0748, 0.2687, 0.0000],
        [0.1911, 0.0000, 0.0302,  ..., 0.0000, 0.0362, 0.0416],
        [0.2648, 0.0000, 0.0389,  ..., 0.0000, 0.0888, 0.0474],
        ...,
        [0.1316, 0.0000, 0.0000,  ..., 0.0000, 0.0759, 0.0000],
        [0.1866, 0.0000, 0.0153,  ..., 0.0000, 0.0972, 0.0000],
        [0.1318, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
Train: 53, 53
Valid: 16, 16
Test: 31, 31
Run: 01, Epoch: 02, Loss: 1.6055, Train: 22.64%, Valid: 6.25%, Test: 9.68%
embeddings: tensor([[8.6045, 0.00

In [3]:
pip install ogb


Collecting ogb
  Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)
Collecting scikit-learn>=0.20.0 (from ogb)
  Using cached scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting pandas>=0.24.0 (from ogb)
  Downloading pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
Collecting outdated>=0.2.0 (from ogb)
  Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)
Collecting littleutils (from outdated>=0.2.0->ogb)
  Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)
Collecting pytz>=2020.1 (from pandas>=0.24.0->ogb)
  Using cached pytz-2024.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas>=0.24.0->ogb)
  Downloading tzdata-2024.2-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting joblib>=1.2.0 (from scikit-learn>=0.20.0->ogb)
  Using cached joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-lear