In [25]:
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.optim import Adam, lr_scheduler

import pytorch_lightning as pl
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GATConv

from general.utils import set_seeds, standardize_data


class Args(Dataset):
    """Attribute bag holding the hyper-parameters of one training run.

    Every constructor argument is stored, unchanged, as an attribute with the
    same name (``args.seed``, ``args.batch_size``, ...).

    NOTE(review): nothing visible here uses the ``Dataset`` base class (no
    ``__getitem__``/``__len__`` are defined); it is kept only so that any
    ``isinstance`` checks elsewhere keep behaving the same.
    """

    def __init__(
        self, seed, dataset, optimizer_type, optimizer_lr,
        optimizer_decay, epochs, hidden_channel, dropout,
        nlayers, heads_in, heads_out, batch_size, n_neighbors, num_workers
    ):
        # (attribute name, value) pairs, in the order they are assigned.
        fields = (
            ('seed', seed),
            ('dataset', dataset),
            ('optimizer_type', optimizer_type),
            ('optimizer_lr', optimizer_lr),
            ('optimizer_decay', optimizer_decay),
            ('epochs', epochs),
            ('hidden_channel', hidden_channel),
            ('dropout', dropout),
            ('nlayers', nlayers),
            ('heads_in', heads_in),
            ('heads_out', heads_out),
            ('batch_size', batch_size),
            ('n_neighbors', n_neighbors),
            ('num_workers', num_workers),
        )
        for attr_name, attr_value in fields:
            setattr(self, attr_name, attr_value)



# Hyper-parameters for this run.
args = Args(
    seed=42,
    dataset='arxiv',
    optimizer_type='Adam',
    optimizer_lr=0.005,
    optimizer_decay=0.0005,
    epochs=5,
    hidden_channel=8,
    dropout=0.6,
    nlayers=2,
    heads_in=8,
    heads_out=1,
    batch_size=100,
    n_neighbors=100,
    num_workers=1,
)

# Seed before any loading / shuffling happens (set_seeds is project-local;
# presumably seeds python/numpy/torch RNGs — confirm in general.utils).
set_seeds(args.seed)

# Load the preprocessed graph for the chosen dataset, then standardize it.
path = f'data/{args.dataset}/{args.dataset}_sign_k0.pth'
data = standardize_data(torch.load(path), args.dataset)

# Options common to both loaders below.
_loader_opts = dict(batch_size=args.batch_size, num_workers=args.num_workers)

# Training loader: one sampled hop per layer (args.nlayers hops), at most
# args.n_neighbors neighbours per hop, seeded only from training nodes.
# The trailing incomplete batch is dropped.
train_loader = NeighborLoader(
    data,
    num_neighbors=[args.n_neighbors] * args.nlayers,
    input_nodes=data.train_mask,  # accepts a boolean mask or node indices
    shuffle=True,
    drop_last=True,
    **_loader_opts,
)

# Inference loader: full neighbourhoods (-1 = keep all neighbours) for every
# node, without shuffling, so the leading `batch_size` seed nodes of each
# batch come in node order.
subgraph_loader = NeighborLoader(
    # Shallow copy, presumably so attributes can be altered on the loader's
    # graph without touching `data` — confirm.
    copy.copy(data),
    num_neighbors=[-1] * args.nlayers,
    input_nodes=None,
    shuffle=False,
    drop_last=False,
    **_loader_opts,
)


NameError: name 'copy' is not defined

In [None]:
class LitModel(pl.LightningModule):
    def __init__(
        self,
        in_channel=None,
        out_channel=None,
        hidden_channel=8,
        dropout=0.6,
        nlayers=2,
        heads_in=8,
        heads_out=1,
        optimizer_lr=0.005,
        optimizer_decay=0.0005,
    ):
        """Build a stack of ``GATConv`` layers and store training settings.

        These values used to be read from undefined module-level names, which
        made construction fail with ``NameError``. They are now explicit
        keyword arguments. The defaults mirror the ``args`` object built in
        the setup cell.

        Args:
            in_channel: Number of input node features. Required.
            out_channel: Number of output classes; ``forward`` applies
                ``log_softmax`` over this dimension. Required.
            hidden_channel: Features per attention head in non-final layers.
            dropout: Used both as the GATConv attention dropout and as the
                feature dropout applied between layers in ``forward``.
            nlayers: Number of GATConv layers; clamped to at least 2.
            heads_in: Attention heads in every layer except the last; their
                outputs are concatenated.
            heads_out: Attention heads in the last layer; averaged
                (``concat=False``).
            optimizer_lr: Learning rate read by ``configure_optimizers``.
            optimizer_decay: Weight decay read by ``configure_optimizers``.

        Raises:
            ValueError: If ``in_channel`` or ``out_channel`` is not given.
        """
        super().__init__()
        if in_channel is None or out_channel is None:
            raise ValueError(
                "LitModel requires in_channel and out_channel "
                f"(got in_channel={in_channel!r}, out_channel={out_channel!r})"
            )
        # Store the constructor arguments in `self.hparams` and in checkpoints,
        # so `load_from_checkpoint` can rebuild the model.
        self.save_hyperparameters()

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.hidden_channel = hidden_channel
        self.dropout = dropout
        # Always keep at least the input layer and the output layer.
        self.nlayers = max(2, nlayers)
        self.heads_in = heads_in
        self.heads_out = heads_out
        self.optimizer_lr = optimizer_lr
        self.optimizer_decay = optimizer_decay

        # Non-final layers concatenate their heads, so every layer after the
        # first receives hidden_channel * heads_in features.
        hidden_width = self.hidden_channel * self.heads_in

        self.convs = torch.nn.ModuleList()
        self.convs.append(
            GATConv(
                self.in_channel,
                self.hidden_channel,
                heads=self.heads_in,
                dropout=self.dropout,
            ))
        # Use the clamped layer count so the loop agrees with self.nlayers.
        for _ in range(self.nlayers - 2):
            self.convs.append(
                GATConv(
                    hidden_width,
                    self.hidden_channel,
                    heads=self.heads_in,
                    dropout=self.dropout,
                ))
        # Output layer: averaging the heads (concat=False) gives exactly
        # out_channel scores per node.
        self.convs.append(
            GATConv(
                hidden_width,
                self.out_channel,
                heads=self.heads_out,
                dropout=self.dropout,
                concat=False,
            ))


    def forward(self, x, edge_index):
        """Run the GAT stack and return per-node log-probabilities.

        Every layer except the last is followed by an in-place ReLU and then
        feature dropout (active only while ``self.training``). The last
        layer's output goes straight into ``log_softmax``.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity, passed unchanged to every layer.

        Returns:
            ``log_softmax`` of the final layer's output along dim 1.
        """
        final_depth = len(self.convs)
        for depth, layer in enumerate(self.convs, start=1):
            x = layer(x, edge_index)
            if depth == final_depth:
                break  # output layer: no activation or dropout
            x = F.dropout(x.relu_(), p=self.dropout, training=self.training)
        return F.log_softmax(x, dim=1)


    def configure_optimizers(self):
        """Create an Adam optimizer and a StepLR scheduler for Lightning.

        Reads ``self.optimizer_lr`` (learning rate) and
        ``self.optimizer_decay`` (Adam ``weight_decay``). Both must have been
        set by the constructor.

        Returns:
            ``([optimizer], [scheduler])``, the two-list form Lightning
            accepts from ``configure_optimizers``.
        """
        optimizer = Adam(
            self.parameters(),
            lr=self.optimizer_lr,
            weight_decay=self.optimizer_decay,
        )
        # Named `scheduler`, not `lr_scheduler`: rebinding the imported
        # module's name here would make it a local for the whole function,
        # so `lr_scheduler.StepLR` would raise UnboundLocalError.
        # NOTE(review): StepLR's default gamma=0.1 divides the learning rate
        # by 10 on every scheduler step (each epoch under Lightning's default
        # interval); confirm this aggressive decay is intended.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]


    def training_step(self, batch, batch_idx):
        """Compute the NLL loss on the seed nodes of one sampled mini-batch.

        Args:
            batch: A ``NeighborLoader`` mini-batch. Only its first
                ``batch.batch_size`` nodes (the seed nodes) enter the loss.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The scalar loss tensor. Lightning back-propagates it; returning
            ``None`` would make Lightning skip the optimisation step.
        """
        batch_size = batch.batch_size

        # Call this module itself (not a global `model`), and move inputs to
        # the module's own device (not an undefined global `device`).
        logits = self(
            batch.x.to(self.device),
            batch.edge_index.to(self.device),
        )[:batch_size]

        y = batch.y[:batch_size].to(logits.device)
        loss = F.nll_loss(logits, y)
        return loss

    def 
