In [2]:
import argparse
import sys

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from logger import Logger


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super(GCN, self).__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels, cached=True))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels, cached=True))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels, cached=True))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t, return_embeddings=False):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        embeddings = x  # Save embeddings before the final layer
        x = self.convs[-1](x, adj_t)
        if return_embeddings:
            return x.log_softmax(dim=-1), embeddings
        return x.log_softmax(dim=-1)


class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super(SAGE, self).__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t, return_embeddings=False):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        embeddings = x  # Save embeddings before the final layer
        x = self.convs[-1](x, adj_t)
        if return_embeddings:
            return x.log_softmax(dim=-1), embeddings
        return x.log_softmax(dim=-1)


def train(model, data, train_idx, optimizer, save_embeddings=False):
    model.train()

    optimizer.zero_grad()
    out, embeddings = model(data.x, data.adj_t, return_embeddings=True)
    loss = F.nll_loss(out[train_idx], data.y.squeeze(1)[train_idx])
    loss.backward()
    optimizer.step()

    if save_embeddings:
        # Save embeddings for all nodes
        torch.save(embeddings, "final_embeddings.pt")
        print("Embeddings saved to 'final_embeddings.pt'")

    return loss.item()


@torch.no_grad()
def test(model, data, split_idx, evaluator):
    model.eval()

    out = model(data.x, data.adj_t)
    y_pred = out.argmax(dim=-1, keepdim=True)

    train_idx = split_idx['train']
    valid_idx = split_idx['valid']
    test_idx = split_idx['test']

    train_acc = evaluator.eval({'y_true': data.y[train_idx], 'y_pred': y_pred[train_idx]})['acc']
    valid_acc = evaluator.eval({'y_true': data.y[valid_idx], 'y_pred': y_pred[valid_idx]})['acc']
    test_acc = evaluator.eval({'y_true': data.y[test_idx], 'y_pred': y_pred[test_idx]})['acc']

    return train_acc, valid_acc, test_acc


def main():
    # Detect if running in Jupyter or via command-line
    if "ipykernel" in sys.argv[0]:
        class Args:
            device = 0
            log_steps = 1
            use_sage = False
            num_layers = 3
            hidden_channels = 256
            dropout = 0.5
            lr = 0.01
            epochs = 20
            runs = 5
        args = Args()
    else:
        parser = argparse.ArgumentParser(description='OGBN-Arxiv (GNN)')
        parser.add_argument('--device', type=int, default=0)
        parser.add_argument('--log_steps', type=int, default=1)
        parser.add_argument('--use_sage', action='store_true')
        parser.add_argument('--num_layers', type=int, default=3)
        parser.add_argument('--hidden_channels', type=int, default=256)
        parser.add_argument('--dropout', type=float, default=0.5)
        parser.add_argument('--lr', type=float, default=0.01)
        parser.add_argument('--epochs', type=int, default=500)
        parser.add_argument('--runs', type=int, default=10)
        args = parser.parse_args()

    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygNodePropPredDataset(name='ogbn-arxiv', transform=T.ToSparseTensor())
    data = dataset[0]
    data.adj_t = data.adj_t.to_symmetric()
    data = data.to(device)

    split_idx = dataset.get_idx_split()
    train_idx = torch.tensor(split_idx['train'], device=device)

    if args.use_sage:
        model = SAGE(data.num_features, args.hidden_channels, dataset.num_classes, args.num_layers, args.dropout).to(device)
    else:
        model = GCN(data.num_features, args.hidden_channels, dataset.num_classes, args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbn-arxiv')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        for epoch in range(1, 1 + args.epochs):
            save_embeddings = (run == args.runs - 1 and epoch == args.epochs)
            loss = train(model, data, train_idx, optimizer, save_embeddings=save_embeddings)
            train_acc, valid_acc, test_acc = test(model, data, split_idx, evaluator)

            if epoch % args.log_steps == 0:
                print(f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}, '
                      f'Train: {100 * train_acc:.2f}%, Valid: {100 * valid_acc:.2f}%, '
                      f'Test: {100 * test_acc:.2f}%')

    print("Training complete. Final embeddings saved to 'final_embeddings.pt'.")


if __name__ == "__main__":
    main()


<__main__.main.<locals>.Args object at 0x7db268285960>


  self.data, self.slices = torch.load(self.processed_paths[0])
  train_idx = torch.tensor(split_idx['train'], device=device)
  return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())


Run: 01, Epoch: 01, Loss: 4.1579, Train: 20.33%, Valid: 26.53%, Test: 23.89%
Run: 01, Epoch: 02, Loss: 2.3036, Train: 19.68%, Valid: 19.87%, Test: 25.61%
Run: 01, Epoch: 03, Loss: 1.9341, Train: 29.37%, Valid: 25.27%, Test: 29.78%
Run: 01, Epoch: 04, Loss: 1.7596, Train: 36.05%, Valid: 34.45%, Test: 36.59%
Run: 01, Epoch: 05, Loss: 1.6553, Train: 30.45%, Valid: 20.84%, Test: 19.52%
Run: 01, Epoch: 06, Loss: 1.5689, Train: 29.59%, Valid: 17.90%, Test: 15.75%
Run: 01, Epoch: 07, Loss: 1.4996, Train: 30.67%, Valid: 19.15%, Test: 17.40%
Run: 01, Epoch: 08, Loss: 1.4482, Train: 30.34%, Valid: 19.18%, Test: 17.62%
Run: 01, Epoch: 09, Loss: 1.4049, Train: 29.58%, Valid: 19.42%, Test: 17.77%
Run: 01, Epoch: 10, Loss: 1.3687, Train: 29.84%, Valid: 22.42%, Test: 23.19%
Run: 01, Epoch: 11, Loss: 1.3419, Train: 32.02%, Valid: 26.50%, Test: 30.04%
Run: 01, Epoch: 12, Loss: 1.3172, Train: 35.92%, Valid: 29.96%, Test: 34.16%
Run: 01, Epoch: 13, Loss: 1.2973, Train: 40.38%, Valid: 34.30%, Test: 38.26%

In [3]:
embeddings = torch.load("final_embeddings.pt")
print(embeddings.shape)


torch.Size([169343, 256])


  embeddings = torch.load("final_embeddings.pt")


In [4]:
import torch

# Loading the .pt file
file_path = 'final_embeddings.pt'
data = torch.load(file_path)
print(data.shape)

print(type(data))
print(data)      


torch.Size([169343, 256])
<class 'torch.Tensor'>
tensor([[3.3821, 5.6448, 0.0000,  ..., 0.0000, 3.5918, 0.0000],
        [0.0000, 0.0000, 0.1437,  ..., 0.0000, 0.6505, 0.0000],
        [0.1314, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 1.5421, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.8261, 0.0000, 0.0000,  ..., 0.0000, 0.6053, 0.7923],
        [0.0000, 0.0000, 1.3019,  ..., 0.0000, 0.0000, 0.2466]],
       requires_grad=True)


  data = torch.load(file_path)
