In [None]:
import argparse
import sys

import torch
from torch_geometric.nn import Node2Vec
from torch_geometric.utils import to_undirected

from ogb.nodeproppred import PygNodePropPredDataset


def save_embedding(model):
    """Persist the model's embedding matrix to ``embedding.pt``.

    Args:
        model: object exposing ``embedding.weight.data`` (a tensor).

    The tensor is copied to the CPU before being written, and the file is
    created in (or overwritten in) the current working directory.
    """
    weights = model.embedding.weight.data
    cpu_weights = weights.cpu()
    torch.save(cpu_weights, 'embedding.pt')


In [None]:
def _build_parser():
    """Build the argument parser holding every option that ``main`` reads.

    Returns:
        argparse.ArgumentParser: parser whose defaults are this script's
        command-line defaults.
    """
    parser = argparse.ArgumentParser(description='OGBN-Arxiv (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    # The following five flags are not read by main(); they are kept so that
    # existing command lines that pass them keep parsing.
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=500)
    # Node2Vec / loader hyper-parameters that main() reads.
    # NOTE(review): defaults are meant to mirror the OGB ogbn-arxiv node2vec
    # example -- confirm they suit this experiment.
    parser.add_argument('--embedding_dim', type=int, default=128)
    parser.add_argument('--walk_length', type=int, default=80)
    parser.add_argument('--context_size', type=int, default=20)
    parser.add_argument('--walks_per_node', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=256)
    return parser


def _filter_split(split, keep, device):
    """Return the node ids of ``split`` that are members of ``keep``.

    Args:
        split: 1-D index container with a ``tolist()`` method (e.g. a tensor).
        keep: set of Python ints to retain.
        device: device on which the resulting tensor is created.

    Returns:
        torch.Tensor: int64 tensor of the retained ids, in their original
        order; empty (but still int64) when nothing matches.
    """
    kept = [node for node in split.tolist() if node in keep]
    return torch.tensor(kept, dtype=torch.long, device=device)


def main():
    """Train Node2Vec embeddings on ogbn-arxiv and save them to disk.

    Options come from the command line. When ``sys.argv[0]`` contains
    ``"ipykernel"`` (treated as "running inside a notebook kernel"), the
    process arguments are ignored and the parser defaults are used with
    ``epochs=20`` and ``runs=5`` instead.

    Side effects: prints the parsed options, the filtered train/valid/test
    node ids and the training loss, and writes ``embedding.pt`` via
    ``save_embedding`` every 100 steps and at the end of every epoch.
    """
    parser = _build_parser()
    if "ipykernel" in sys.argv[0]:
        # Pass an explicit argv so the kernel's own arguments are not parsed;
        # only the notebook-specific overrides differ from the CLI defaults.
        args = parser.parse_args(['--epochs', '20', '--runs', '5'])
    else:
        # Argument parsing for command-line execution
        args = parser.parse_args()

    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    node_list = [
        110223, 146929, 2940, 104544, 62326, 29759, 96890, 47025, 117732, 163206,
        61450, 20589, 145422, 33882, 4523, 81254, 82143, 85138, 167093, 125903,
        116417, 158097, 95232, 81835, 84070, 53849, 112263, 17237, 34055, 10707,
        164818, 32412, 75064, 139893, 161232, 100687, 148029, 55852, 23076, 152166,
        44377, 111682, 145090, 132166, 83408, 153864, 143857, 111282, 150656, 36239,
        142479, 112753, 149468, 50327, 163800, 45303, 155755, 151189, 160714, 8875,
        116826, 142383, 143573, 136795, 277, 120988, 120990, 151641, 44082, 94916,
        32814, 117613, 98216, 41936, 16625, 163884, 69260, 29036, 91369, 31451,
        161505, 27948, 29235, 15854, 162540, 161899, 157799, 90013, 120275, 57309,
        92874, 140989, 11692, 78367, 62285, 80502, 72672, 121790, 12662, 90677,
    ]
    # Set membership is O(1) per lookup, versus a linear scan of the list.
    node_set = set(node_list)

    # Node2Vec is constructed from an edge_index tensor, so the graph is kept
    # in edge_index form and symmetrized with to_undirected rather than being
    # converted to a sparse adjacency (which would drop edge_index).
    dataset = PygNodePropPredDataset(name='ogbn-arxiv')
    data = dataset[0]
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)
    data = data.to(device)

    split_idx = dataset.get_idx_split()

    # Filter the split indices based on `node_list`
    train_idx = _filter_split(split_idx['train'], node_set, device)
    valid_idx = _filter_split(split_idx['valid'], node_set, device)
    test_idx = _filter_split(split_idx['test'], node_set, device)

    print(train_idx)
    print(valid_idx)
    print(test_idx)

    model = Node2Vec(data.edge_index, args.embedding_dim, args.walk_length,
                     args.context_size, args.walks_per_node,
                     sparse=True).to(device)

    loader = model.loader(batch_size=args.batch_size, shuffle=True,
                          num_workers=4)
    # sparse=True above yields sparse gradients, hence SparseAdam.
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=args.lr)

    model.train()
    for epoch in range(1, args.epochs + 1):
        for i, (pos_rw, neg_rw) in enumerate(loader):
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()

            if (i + 1) % args.log_steps == 0:
                print(f'Epoch: {epoch:02d}, Step: {i+1:03d}/{len(loader)}, '
                      f'Loss: {loss:.4f}')

            if (i + 1) % 100 == 0:  # Save model every 100 steps.
                save_embedding(model)
        # Also checkpoint at the end of every epoch.
        save_embedding(model)


# Start training only when this file is the program being executed,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
