## Accuracy Study of Lossy MFG

We select some non-training nodes in the MFG (message flow graph) and drop their incoming edges to simulate the effect of node partitioning.

There are several possible strategies for choosing those nodes:
* random
* random but excluding important (i.e. high-degree) nodes
* based on the influence score


In [None]:
import sys
sys.path.append("..")
import torch, torch_geometric as pyg
import models.pyg as pyg_models

# Report library versions / install locations and the CPU thread count so
# that results can be tied to a concrete environment.
for _label, _lib in (("PyTorch:", torch), ("PyG:", pyg)):
    print(_label, _lib.__version__, *_lib.__path__)
print("CPU parallelism:", torch.get_num_threads())

In [None]:
import os
from pathlib import Path
from ogb.nodeproppred import PygNodePropPredDataset
import torch_geometric.transforms as T
from torch_geometric.datasets.reddit import Reddit
from torch_geometric.utils import mask_to_index
from torch_geometric.loader import NeighborLoader, NeighborSampler
from logger import Logger
import pickle

# Root directory that holds the downloaded datasets.
path = Path('/mnt/md0/datasets/')

# Alternative dataset (ogbn-arxiv with its standard split), kept for reference:
# name = 'arxiv'
# dataset = PygNodePropPredDataset(
#     root=path, name='ogbn-arxiv',
#     pre_transform=T.ToUndirected(),
#     transform=T.ToSparseTensor()
# )
# data = dataset[0]
# data.NID = torch.arange(0, data.num_nodes, dtype=torch.int32)
# num_classes = dataset.num_classes
# print(data)
# split = dataset.get_idx_split()
# train_nid, val_nid, test_nid = split['train'], split['valid'], split['test']

name = 'reddit'
dataset = Reddit(path / name, pre_transform=T.AddSelfLoops(), transform=T.ToSparseTensor())
data = dataset[0]
# Attach the global node id of every node as an explicit attribute.
data.NID = torch.arange(0, data.num_nodes, dtype=torch.int32)
num_classes = dataset.num_classes
print(data)

# Reuse the cached random permutation when it exists so that the split is
# identical across notebook sessions; otherwise draw one and cache it.
split_file = f'{name}_rand_split.pt'
if not os.path.exists(split_file):
    perm = torch.randperm(data.num_nodes)
    torch.save(perm, split_file)
else:
    perm = torch.load(split_file)

# First 10% of the permutation -> train, next 40% -> test, last 50% -> val.
train_nid = perm[:data.num_nodes//10]
test_nid = perm[data.num_nodes//10:data.num_nodes//2]
val_nid = perm[data.num_nodes//2:]

print(f"train: {len(train_nid)}, val: {len(val_nid)}, test: {len(test_nid)}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model / sampling / training hyperparameters (read as globals below).
num_layers = 3
num_hidden = 128
fanout = [-1 for _ in range(num_layers)]       # -1 per layer
eval_fanout = [-1 for _ in range(num_layers)]
batch_size = 4096
n_epochs = 50
n_runs = 5

def get_info():
    '''
    Snapshot the current experiment configuration.

    The values are read from the module-level globals at call time, so
    re-assigning e.g. `n_runs` in a later cell is reflected in the result.
    '''
    info = {'dataset': dataset.name}
    info['num_layers'] = num_layers
    info['num_hidden'] = num_hidden
    info['fanout'] = fanout
    info['batch_size'] = batch_size
    info['n_epochs'] = n_epochs
    info['n_runs'] = n_runs
    return info


In [None]:
from graphutils.rw import lazy_rw

def influential(data, nids, k=3, topk=100):
    '''
    Rank nodes by the score `lazy_rw` produces after `k` steps when started
    from a uniform distribution over `nids`.

    data: graph object providing `num_nodes` and `adj_t`
    nids: tensor of seed node ids; the score vector is created on its device
    k:    forwarded to `lazy_rw` (presumably the number of walk steps -- confirm)
    topk: how many best-scoring nodes to return

    Returns `(indices, values)` of the `topk` largest scores, both on CPU.
    '''
    seed_score = torch.zeros((data.num_nodes,), device=nids.device)
    seed_score[nids] = 1 / nids.size(0)
    walked_score = lazy_rw(data.adj_t, seed_score, k=k)
    best = walked_score.cpu().topk(topk)
    return best.indices, best.values

def importance(data, nids, k=3, topk=100):
    '''
    Like `influential`, but ranks nodes by the average of the scores obtained
    after each of the `k` single-step `lazy_rw` applications instead of the
    final score only.

    Returns `(indices, values)` of the `topk` largest averaged scores (on CPU).
    '''
    start = torch.zeros((data.num_nodes,), device=nids.device)
    start[nids] = 1 / nids.size(0)
    accumulated = torch.zeros_like(start)
    current = start
    for _ in range(k):
        # advance one step and add the intermediate result to the running sum
        current = lazy_rw(data.adj_t, current, k=1)
        accumulated += current
    averaged = accumulated / k
    best = averaged.cpu().topk(topk)
    return best.indices, best.values

# Per-training-node influence table: for every training node t, the 100 nodes
# with the highest `influential` score when the walk is seeded at t alone.
# The table is cached on disk because it needs one walk per training node.
# NOTE(review): PyG's `Data.cuda()` may move `data` itself in place -- confirm.
data_cuda = data.cuda()
train_cuda = train_nid.cuda()
pt_name = f'train_topk_infl_{name}.pt'
if os.path.exists(pt_name):
    train_topk, train_topk_scores = torch.load(pt_name)
else:
    # Rows of non-training nodes are never written and stay uninitialised
    # (`torch.empty`); only rows indexed by training node ids are meaningful.
    train_topk = torch.empty((data.num_nodes, 100), dtype=torch.long)
    train_topk_scores = torch.empty((data.num_nodes, 100), dtype=torch.float)
    # Import the module (not `from tqdm import tqdm`): other cells call
    # `tqdm.tqdm(...)`, which breaks if the global name `tqdm` is rebound to
    # the function by re-running this cell.
    import tqdm
    for t in tqdm.tqdm(train_cuda.tolist()):
        t_topk = influential(data_cuda, torch.tensor([t], device='cuda'), k=3, topk=100)
        train_topk[t][:] = t_topk[0]
        train_topk_scores[t][:] = t_topk[1]
    torch.save([train_topk, train_topk_scores], pt_name)

train_topk, train_topk_scores = train_topk.cuda(), train_topk_scores.cuda()

In [None]:
import tqdm
from pyinstrument import Profiler
import torch.nn.functional as F

# val_dataloader = NeighborLoader(
#     data, num_neighbors=fanout, shuffle=False,
#     input_nodes=val_nid, batch_size=batch_size,
# )
# test_dataloader = NeighborLoader(
#     data, num_neighbors=fanout, shuffle=False,
#     input_nodes=test_nid, batch_size=batch_size,
# )

def train(model, optimizer, dataloader, description='train'):
    '''
    Run one training epoch over `dataloader`.

    Every batch is expected to carry `batch_size`, `x`, `y`, `adj_t`,
    `num_nodes`, plus the `alive_nodes` and `score` attributes that
    `transform_fn` attaches. `description` is currently unused.

    Returns a 6-tuple:
        (mean loss per target node, training accuracy,
         batch nodes per iteration, alive nodes per iteration,
         edges (nnz of adj_t) per iteration,
         summed batch score divided by the number of target nodes)
    '''
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    edge_sum = 0
    node_sum = 0
    alive_sum = 0
    score_sum = 0
    for batch in dataloader:
        n_targets = batch.batch_size
        # attributes whose key ends in '_mask' are not transferred
        movable = [key for key in batch.keys if not key.endswith('_mask')]
        batch = batch.to(device, *movable, non_blocking=True)

        optimizer.zero_grad()
        labels = batch.y[:n_targets].long().view(-1)
        # only the first `n_targets` output rows belong to the target nodes
        pred = model(batch.x, batch.adj_t)[:n_targets]
        loss = F.nll_loss(pred, labels)
        loss.backward()
        optimizer.step()

        # collect stats
        loss_sum += float(loss) * n_targets
        n_correct += int((pred.argmax(dim=-1) == labels).sum())
        n_seen += n_targets
        edge_sum += batch.adj_t.nnz()
        node_sum += batch.num_nodes
        alive_sum += batch.alive_nodes
        score_sum += batch.score

    n_iters = len(dataloader)
    # loss, acc, batch_nodes, rem_nodes, rem_edges, batch_score
    return (
        loss_sum / n_seen,
        n_correct / n_seen,
        node_sum / n_iters,
        alive_sum / n_iters,
        edge_sum / n_iters,
        score_sum / n_seen,
    )

@torch.no_grad()
def eval_batch(model, dataloader, description='eval'):
    '''
    Evaluate `model` mini-batch by mini-batch (with a tqdm progress bar).

    Returns `(mean loss per target node, accuracy)` over all batches.
    `description` is currently unused.
    '''
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    for batch in tqdm.tqdm(dataloader):
        n_targets = batch.batch_size
        # attributes whose key ends in '_mask' are not transferred
        movable = [key for key in batch.keys if not key.endswith('_mask')]
        batch = batch.to(device, *movable, non_blocking=True)
        labels = batch.y[:n_targets].long().view(-1)
        pred = model(batch.x, batch.adj_t)[:n_targets]
        loss = F.nll_loss(pred, labels)
        # collect stats
        loss_sum += float(loss) * n_targets
        n_correct += int((pred.argmax(dim=-1) == labels).sum())
        n_seen += n_targets
    accuracy = n_correct / n_seen
    return loss_sum / n_seen, accuracy

@torch.no_grad()
def eval_full(model, data, masks, description='eval'):
    '''
    Evaluate `model` on the full graph in a single forward pass.

    model: module called as `model(x, adj_t)`
    data:  object providing `x`, `adj_t` and `y`
    masks: iterable of node selectors (index or boolean tensors); one
           `(loss, accuracy)` pair is produced per selector
    description: currently unused

    The forward pass runs on the device the model's parameters live on
    (falling back to the device of `data.x` for a parameter-free model)
    instead of unconditionally on CUDA, so it also works on CPU-only hosts.

    Returns a list of `(nll loss, accuracy)` tuples, in the order of `masks`.
    '''
    model.eval()
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else data.x.device
    y_hat = model(data.x.to(target), data.adj_t.to(target))
    out = []
    for mask in masks:
        # keep labels on the same device as the predictions
        y = data.y[mask].long().view(-1).to(y_hat.device)
        pred = y_hat[mask]
        loss = float(F.nll_loss(pred, y))
        acc = int((pred.argmax(dim=-1) == y).sum()) / y.shape[0]
        out.append((loss, acc))
    return out


In [None]:
import numpy as np
from torch_sparse import SparseTensor
from torch_geometric.data import Data

# dropping strategies

# random drop node with prob roughly p
def drop_random(batch, p: float) -> Data:
    '''
    Randomly drop roughly a fraction `p` of the non-target nodes of a batch.

    Target nodes (the first `batch.batch_size` entries) are never dropped.
    For `p < 0.25` the victims are drawn with replacement (so slightly fewer
    than `n_drop` distinct nodes may be removed); otherwise they are drawn
    without replacement from a random permutation.

    Every edge with at least one dropped endpoint is removed from
    `batch.adj_t`. The batch additionally receives `n_score` (per-node
    influence score accumulated from the module-level `train_topk` /
    `train_topk_scores` tables) and `node_mask` (True = kept).
    '''
    sizes = batch.adj_t.sizes()
    n_total = sizes[0]
    n_targets = batch.batch_size
    dev = batch.adj_t.device()

    keep = torch.ones((n_total,), dtype=torch.bool, device=dev)
    n_drop = int((n_total - n_targets) * p)
    if p < 0.25:
        victims = torch.randint(n_targets, n_total, size=(n_drop,), device=dev)
    else:
        victims = torch.randperm(n_total - n_targets)[:n_drop] + n_targets
    keep[victims] = False

    # accumulate, over all target nodes, the cached top-k influence scores
    targets = batch.n_id[:n_targets]
    global_score: torch.Tensor = torch.zeros(data.num_nodes, device=dev)
    global_score.scatter_add_(
        dim=0,
        index=train_topk[targets].view(-1),
        src=train_topk_scores[targets].view(-1),
    )
    batch.n_score = global_score[batch.n_id]

    # keep an edge only if both of its endpoints survive
    row, col, _ = batch.adj_t.coo()
    edge_keep = keep[row] & keep[col]
    batch.adj_t = SparseTensor(
        row=row[edge_keep], col=col[edge_keep],
        sparse_sizes=sizes,
        is_sorted=True,
    )
    batch.node_mask = keep
    return batch

# drop based on train_topk, train_topk_scores
def drop_by_influence(batch, p: float):
    '''
    Drop the fraction `p` of non-target nodes with the lowest influence score.

    The per-node score is accumulated, over all target nodes of the batch,
    from the module-level `train_topk` / `train_topk_scores` tables and is
    stored on the batch as `n_score`.

    Target nodes (the first `batch.batch_size` entries) are always kept:
    the drop budget `n_drop` is computed from the non-target nodes only, so
    the ranking is restricted to them as well. This matches `drop_random`
    and `drop_hop`, which never drop a target node either.

    Every edge with at least one dropped endpoint is removed from
    `batch.adj_t`; `batch.node_mask` marks the surviving nodes.
    '''
    bn = batch.adj_t.sizes()[0]
    bs = batch.batch_size
    device = batch.adj_t.device()
    n_drop = int((bn - bs) * p)

    # influence score of every node in the batch
    n_score: torch.Tensor = torch.zeros(data.num_nodes, device=device)
    target_id = batch.n_id[:bs]
    topk_nids = train_topk[target_id].view(-1)
    topk_scores = train_topk_scores[target_id].view(-1)
    n_score.scatter_add_(dim=0, index=topk_nids, src=topk_scores)
    n_score = n_score[batch.n_id]
    batch.n_score = n_score

    # targets always survive; among the rest keep the best-scoring ones
    node_mask = torch.zeros((bn,), dtype=torch.bool, device=device)
    node_mask[:bs] = True
    n_candidates = bn - bs
    n_keep = min(max(n_candidates - n_drop, 0), n_candidates)
    if n_keep > 0:
        kept = n_score[bs:bn].topk(n_keep).indices + bs
        node_mask[kept] = True

    src, dst, _ = batch.adj_t.coo()
    mask = node_mask[dst] & node_mask[src]
    batch.adj_t = SparseTensor(
        row=src[mask], col=dst[mask],
        sparse_sizes=batch.adj_t.sizes(),
        is_sorted=True,
    )
    batch.node_mask = node_mask
    return batch

def drop_hop(batch, p: float, hop=1) -> Data:
    '''
    Drop a fraction `p` of the non-target nodes, preferring nodes that lie
    beyond the first `hop` hops of the target nodes in the MFG.

    Nodes with a local index >= the hop boundary are dropped first (randomly).
    Only when the budget exceeds their number are *all* of them dropped and
    the remainder taken randomly from the closer non-target nodes.

    NOTE(review): the boundary computation assumes batch nodes are ordered
    targets first and then hop by hop, and that the COO entries of `adj_t`
    are sorted by their first index -- confirm against the sampler.

    The batch additionally receives `n_score` and `node_mask`, and every edge
    with at least one dropped endpoint is removed from `batch.adj_t`.
    '''
    def hop_bound(hop_num):
        # Local index one past the last node reachable within `hop_num` hops.
        first, second, _ = batch.adj_t.coo()
        boundary = batch.batch_size
        remaining = hop_num
        while remaining > 0:
            n_edges = (first < boundary).sum().item()
            boundary = second[:n_edges].max().item() + 1
            remaining -= 1
        return boundary

    sizes = batch.adj_t.sizes()
    n_total = sizes[0]
    n_targets = batch.batch_size
    dev = batch.adj_t.device()
    keep = torch.ones((n_total,), dtype=torch.bool, device=dev)
    n_drop = int((n_total - n_targets) * p)

    n_close = hop_bound(hop)
    n_far = n_total - n_close
    if n_drop < n_far:
        # enough distant nodes: sample the victims among them only
        victims = torch.randperm(n_far)[:n_drop] + n_close
    else:
        # drop every distant node, then the rest from the close non-targets
        keep[n_close:] = False
        extra = n_drop - n_far
        victims = torch.randperm(n_close - n_targets)[:extra] + n_targets
    keep[victims] = False

    # influence score of every node in the batch
    targets = batch.n_id[:n_targets]
    global_score: torch.Tensor = torch.zeros(data.num_nodes, device=dev)
    global_score.scatter_add_(
        dim=0,
        index=train_topk[targets].view(-1),
        src=train_topk_scores[targets].view(-1),
    )
    batch.n_score = global_score[batch.n_id]

    # keep an edge only if both of its endpoints survive
    row, col, _ = batch.adj_t.coo()
    edge_keep = keep[col] & keep[row]
    batch.adj_t = SparseTensor(
        row=row[edge_keep], col=col[edge_keep],
        sparse_sizes=sizes,
        is_sorted=True,
    )
    batch.node_mask = keep
    return batch

def transform_fn(batch, drop_fn, k=3):
    '''
    Apply `drop_fn` to a batch and attach two summary statistics:

    score:       sum of `n_score` over the nodes that survived (`node_mask`)
    alive_nodes: number of surviving nodes

    `k` is accepted but not used.
    '''
    dropped = drop_fn(batch)
    survivors = dropped.node_mask
    dropped.score = dropped.n_score[survivors].sum().item()
    dropped.alive_nodes = int(survivors.sum())
    return dropped


In [None]:
# from torch_geometric.utils import degree
# from functools import lru_cache

# data_degrees = degree(data.adj_t.coo()[0], data.num_nodes, dtype=torch.int32)
# @lru_cache
# def topk_degree(topk):
#     return torch.topk(data_degrees, int(topk * data.num_nodes)).values[-1]
# def drop_minor(batch, p: float, topk=0.01) -> Data:
#     '''
#     only drop nodes that are not important, i.e. of low degree
#     '''
#     if p == 0:
#         batch.drop_nodes = 0
#         return batch
#     num_nodes = batch.adj_t.sizes()[0]
#     node_mask = torch.ones((num_nodes,), dtype=torch.bool)
#     drop_nodes = torch.from_numpy(np.random.choice(
#         torch.arange(batch.batch_size, num_nodes).numpy(),
#         size=int((num_nodes - batch.batch_size) * p),
#         replace=False
#     ))
#     node_mask[drop_nodes] = False
#     topk_cutout = topk_degree(topk)
#     node_mask |= (data_degrees[batch.NID.long()] >= topk_cutout)
#     src, dst, _ = batch.adj_t.coo()
#     mask = torch.gather(input=node_mask, dim=0, index=src)
#     mask = mask & torch.gather(input=node_mask, dim=0, index=dst)
#     batch.adj_t = SparseTensor(
#         row=src[mask], col=dst[mask],
#         sparse_sizes=batch.adj_t.sizes(),
#         is_sorted=True
#     )
#     batch.node_mask = node_mask
#     return batch

# def compute_score(batch):
#     '''
#     compute the importance score of each node in the current batch
#     '''
#     num_nodes = batch.num_nodes
#     e_t = torch.zeros((num_nodes,), dtype=torch.float)
#     e_t[:batch.batch_size] = 1
#     deg = degree(batch.adj_t.coo()[0], batch.num_nodes)
#     deg[deg==0] = 1
#     adj = batch.adj_t.t()
#     scores = [e_t]
#     for _ in range(3):
#         pi = scores[-1] / deg
#         scores.append(adj.spmm(pi.view(num_nodes, 1)).view(-1))
#     per_node = sum(scores[1:])
#     return per_node / per_node.sum()

# def transform_fn(batch, drop_fn):
#     score = compute_score(batch)
#     batch = drop_fn(batch)
#     batch.quality_score = score[batch.node_mask].sum().item()
#     batch.alive_nodes = int(batch.node_mask.sum())
#     return batch

In [None]:
def train_with_drop(data, drop_p, fn=drop_random, print_once=False):
    '''
    Train a SAGE model `n_runs` times for `n_epochs` epochs each while the
    dropping strategy `fn(batch, drop_p)` is applied to every training batch.

    data:       full graph; used for sampling and for full-graph evaluation
    drop_p:     fraction of non-target nodes to drop, forwarded to `fn`
    fn:         dropping strategy (`drop_random`, `drop_by_influence`, ...)
    print_once: if True, print the extra training statistics of the very
                first epoch (batch_nodes, rem_nodes, rem_edges, batch_score)

    `n_runs`, `n_epochs`, `fanout`, `batch_size` etc. are read from the
    module-level globals at call time.

    Returns the `Logger` holding per-epoch train/val/test metrics per run.
    '''
    model = pyg_models.SAGE(data.x.shape[1], num_hidden, num_classes, num_layers)
    model = model.to(device)
    drop_fn = lambda batch: fn(batch, drop_p)
    transform = lambda batch: transform_fn(batch, drop_fn)
    dataloader = NeighborLoader(
        data, num_neighbors=fanout, shuffle=True,
        input_nodes=train_nid, batch_size=batch_size,
        transform=transform,
        # num_workers=torch.get_num_threads()-2,
        # persistent_workers=True
    )

    logger = Logger(get_info() | {'drop_p': drop_p, "method": fn.__name__})
    print(logger.info)
    for run in range(n_runs):
        logger.set_run(run)
        model.reset_parameters()
        # A fresh optimizer per run: reusing one Adam instance would carry its
        # moment estimates and step counts over from the previous run, so the
        # runs would not be independent repetitions.
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        best_val = final_test = 0
        best_epoch = 0
        pbar = tqdm.tqdm(range(n_epochs))
        description = '{:.4f}|{:.4f}|{:.4f}'
        pbar.set_description(f"Run {run}")
        pbar.set_postfix({'acc': description.format(0, best_val, 0)})
        for epoch in pbar:
            # if epoch < 2:
            #     profiler = Profiler()
            #     profiler.start()
            train_loss, train_acc, *info = train(model, optimizer, dataloader, description=f"Epoch {epoch}")
            if print_once:
                # batch_nodes, rem_nodes, rem_edges, batch_score
                print(info)
                print_once = False
            logger.add(epoch, data={'train': {'acc': train_acc, 'loss': train_loss, 'extra': info}})
            val, test = eval_full(model, data, [val_nid, test_nid])
            logger.add(epoch, data={
                'val': {'loss': val[0], 'acc': val[1]},
                'test': {'loss': test[0], 'acc': test[1]},
            })
            # model selection: report the test accuracy at the best val epoch
            if val[1] > best_val:
                best_val, final_test = val[1], test[1]
                best_epoch = epoch
            pbar.set_postfix({'acc': description.format(train_acc, best_val, final_test)})
            # if epoch < 2:
            #     profiler.stop()
            #     profiler.print()
        pbar.close()
        logger.add(epoch, data={'best-epoch': best_epoch})
    # explicitly shutdown workers when persistent_workers is True
    del dataloader._iterator
    return logger

In [None]:
# Sweep the drop ratio for the RANDOM strategy and persist all loggers.
# `n_runs` / `n_epochs` are globals read by `train_with_drop`.
n_epochs = 50
n_runs = 5
ps = [0.98, 0.95, 0.9, 0.8, 0.7, 0.5, 0]
loggers = [
    train_with_drop(data, fn=drop_random, drop_p=drop_prob, print_once=True)
    for drop_prob in ps
]
with open(f"{name}_drop_random.pkl", "wb") as fp:
    pickle.dump(loggers, fp)

In [None]:
# Sweep the drop ratio for the INFLUENCE strategy and persist all loggers.
# `n_runs` / `n_epochs` are globals read by `train_with_drop`.
n_epochs = 50
n_runs = 5
ps = [0.98, 0.95, 0.9, 0.8, 0.7, 0.5, 0]
loggers = [
    train_with_drop(data, fn=drop_by_influence, drop_p=drop_prob, print_once=True)
    for drop_prob in ps
]
with open(f"{name}_drop_infl.pkl", "wb") as fp:
    pickle.dump(loggers, fp)

In [None]:
# Sweep the drop ratio for the HOP-based strategy and persist all loggers.
# `n_runs` / `n_epochs` are globals read by `train_with_drop`.
n_epochs = 50
n_runs = 5
ps = [0.98, 0.95, 0.9, 0.8, 0.7, 0.5, 0]
loggers = [
    train_with_drop(data, fn=drop_hop, drop_p=drop_prob, print_once=True)
    for drop_prob in ps
]
with open(f"{name}_drop_hop.pkl", "wb") as fp:
    pickle.dump(loggers, fp)

In [None]:
# Quick single-run probe of the INFLUENCE strategy at an extreme drop ratio.
# Full sweep settings, kept for reference:
# ps = [0.98, 0.95, 0.9, 0.8, 0.7, 0.5, 0]
# n_runs = 5
# n_epochs = 50
n_epochs = 50
n_runs = 1
ps = [0.99]
loggers = [
    train_with_drop(data, fn=drop_by_influence, drop_p=drop_prob, print_once=True)
    for drop_prob in ps
]

In [None]:
# Quick single-run probe of the RANDOM strategy at a single drop ratio.
n_epochs = 50
n_runs = 1
ps = [0.95]
loggers = [
    train_with_drop(data, fn=drop_random, drop_p=drop_prob, print_once=True)
    for drop_prob in ps
]

In [None]:
def stdmean(logger: Logger, *labels, summarize=None):
    '''
    Aggregate logged data across all runs of `logger`.

    For every run, `logger.get_data(run, *labels)` is passed through
    `summarize` (identity when omitted), which must return a mapping
    `label -> value`. The values collected per label over all runs are then
    reduced to `[mean, std]`.

    Returns a dict `label -> [mean, std]`.
    '''
    if summarize is None:
        def summarize(x):
            return x

    # label -> list with one summarized value per run
    collected = {}
    for run in logger:
        summary = summarize(logger.get_data(run, *labels))
        for key in summary:
            collected.setdefault(key, []).append(summary[key])

    result = {}
    for key, per_run in collected.items():
        values = torch.tensor(per_run)
        result[key] = [values.mean().item(), values.std().item()]
    return result

def stdmean_acc(logger: Logger):
    '''
    Mean/std (in percent, across runs) of the best validation accuracy and of
    the test accuracy measured at that same best-validation epoch.
    '''
    def get_acc(val_test):
        val_curve = 100 * torch.tensor(val_test['val/acc'])
        best_epoch = val_curve.argmax()
        return {
            'val/acc': val_curve.max().item(),
            'test/acc': 100 * val_test['test/acc'][best_epoch],
        }
    return stdmean(logger, 'val/acc', 'test/acc', summarize=get_acc)

import pickle
# Load the pickled logger lists written by the sweep cells, keyed by strategy.
# NOTE: pickle executes code on load -- only open files produced by this notebook.
loggers = {}
for _method, _pkl_path in (
    ('random', f'{name}_drop_random.pkl'),
    ('influence', f'{name}_drop_infl.pkl'),
    ('neighbor', f'{name}_drop_hop.pkl'),
):
    with open(_pkl_path, 'rb') as fp:
        loggers[_method] = pickle.load(fp)


def extract_train_info(train_extra):
    '''
    Transpose the per-epoch extra training statistics into per-metric series.

    Each entry of `train_extra['train/extra']` is expected to be
    (batch_nodes, rem_nodes, rem_edges, batch_score); the first column is
    discarded and the last one is reported as 'coverage'.
    '''
    columns = list(zip(*train_extra['train/extra']))
    # column 0 (batch_nodes) is intentionally skipped
    wanted = ('rem_nodes', 'rem_edges', 'coverage')
    return {label: columns[pos] for pos, label in enumerate(wanted, start=1)}

# Reduce every logger to [mean, std] summaries, grouped by dropping strategy:
#   proc_data[method][metric] -> one entry per drop ratio
# The loop variable is deliberately NOT called `loggers`, so the dict that is
# being iterated is not shadowed and stays intact after the loop.
proc_data = {}
for method, method_loggers in loggers.items():
    m_data = {
        k: [] for k in
        ('p', 'train/acc', 'val/acc', 'test/acc', 'rem_nodes', 'rem_edges', 'coverage')
    }
    proc_data[method] = m_data
    for logger in method_loggers:
        m_data['p'].append(logger.info['drop_p'])
        # best val accuracy and the test accuracy at that epoch
        acc = stdmean_acc(logger)
        for k in acc:
            m_data[k].append(acc[k])
        # best training accuracy of each run, in percent
        train_acc = stdmean(logger, 'train/acc',
            summarize=lambda x: {k: 100*max(v) for k, v in x.items()})
        for k in train_acc:
            m_data[k].append(train_acc[k])
        # rem_nodes / rem_edges / coverage
        train_stat = stdmean(logger, 'train/extra', summarize=extract_train_info)
        for k in train_stat:
            m_data[k].append(train_stat[k])


In [None]:
import matplotlib.pyplot as plt

# One figure per dropping strategy: train/val/test accuracy (mean +/- std)
# against the drop ratio, the remaining nodes and the coverage score.
for method, m_data in proc_data.items():
    fig, axs = plt.subplots(1, 3, figsize=(18, 5), sharey=True)
    plt.suptitle(f"Dropping nodes by {method.upper()}")
    plt.ylim([55, 75])
    for i, xlabel in enumerate(['p', 'rem_nodes', 'coverage']):
        # x values are either plain numbers ('p') or [mean, std] pairs, in
        # which case the mean is used as the x coordinate
        x_entries = m_data[xlabel]
        if isinstance(x_entries[0], (list, tuple)):
            xs = [entry[0] for entry in x_entries]
        else:
            xs = x_entries
        for k in ('train/acc', 'val/acc', 'test/acc'):
            mean, std = [torch.tensor(a) for a in zip(*m_data[k])]
            axs[i].plot(xs, mean, marker='o', label=k.split('/')[0])
            axs[i].fill_between(xs, mean-std, mean+std, alpha=0.5, interpolate=True)
        axs[i].legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt
# One figure per metric comparing all dropping strategies: accuracy
# (mean +/- std) against the remaining nodes and against the coverage score.
for metric, ylim in (('val/acc', None), ('test/acc', None)):
    fig, axs = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
    # title follows the active dataset, like the saved file name below
    plt.suptitle(f"{name}-standard ({metric})", fontsize=20)
    if ylim is not None:
        plt.ylim(ylim)
    for i, xlabel in enumerate(['rem_nodes', 'coverage']):
        for k in proc_data:
            k_data = proc_data[k]
            # x values are either plain numbers or [mean, std] pairs, in
            # which case the mean is used as the x coordinate
            x_entries = k_data[xlabel]
            if isinstance(x_entries[0], (list, tuple)):
                xs = [entry[0] for entry in x_entries]
            else:
                xs = list(x_entries)
            mean, std = [torch.tensor(a) for a in zip(*k_data[metric])]

            # plot with ascending x
            if xs[0] > xs[1]:
                xs = list(reversed(xs))
                mean = mean.flip(0)
                std = std.flip(0)
            # the first point of each series is left out of the plot
            xs, mean, std = xs[1:], mean[1:], std[1:]

            axs[i].plot(xs, mean, marker='o', label=k.split('/')[0])
            axs[i].fill_between(xs, mean-std, mean+std, alpha=0.1, interpolate=True)
            axs[i].set_xlabel(xlabel, fontsize=16)
            axs[i].autoscale(enable=True, axis='x', tight=True)
            axs[i].tick_params(axis='y', which='major', labelsize=14)
            axs[i].tick_params(axis='x', which='major', rotation=30, labelsize=14)
            if i == 0:
                axs[i].set_ylabel('accuracy', fontsize=16)
    # legend on the last subplot only
    axs[i].legend(fontsize=16)

    if 'train' not in metric:
        plt.savefig(f"{name}_drop_{metric.replace('/', '_')}.pdf", dpi=160, bbox_inches='tight')
plt.show()

In [None]:
# Draw the first (unshuffled) training mini-batch with three layers of
# unrestricted (-1) neighbor sampling, for the analysis in the next cell.
iterator = iter(
    NeighborLoader(
        data,
        input_nodes=train_nid,
        num_neighbors=[-1, -1, -1],
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )
)
batch = next(iterator)

In [None]:
# How many of the highest-scoring non-target nodes of a batch already lie in
# the 1-hop neighborhood of the targets, for different propagation depths?

# `degree` is not imported anywhere else in the live code of this notebook.
from torch_geometric.utils import degree

num_nodes = batch.num_nodes
# start with unit mass on every target node
e_t = torch.zeros((num_nodes,), dtype=torch.float)
e_t[:batch.batch_size] = 1
print(batch)
deg = degree(batch.adj_t.coo()[0], batch.num_nodes)
deg[deg==0] = 1  # avoid division by zero
adj = batch.adj_t.t()

# Local index of the last 1-hop neighbor; independent of the depth below.
src, dst, _ = batch.adj_t.coo()
hop1 = dst[src<batch.batch_size]
hop1_max = dst[src==batch.batch_size-1].max().item()
assert hop1.max().item() == hop1_max

# The loop variable is not called `num_layers`, so the global model
# hyperparameter of that name is left untouched.
for n_hops in range(2, 8):
    scores = [e_t]
    for _ in range(n_hops):
        pi = scores[-1] / deg
        scores.append(adj.spmm(pi.view(num_nodes, 1)).view(-1))

    # total score over all steps, ignoring the target nodes themselves
    per_node = sum(scores[1:])
    per_node[:batch.batch_size] = 0
    best = torch.topk(per_node, k=batch.batch_size*4)
    nid = torch.sort(best.indices).values
    topk_sum = best.values.sum().item()
    all_sum = per_node.sum().item()

    n_top = nid.shape[0]
    in_hop1 = (nid <= hop1_max).sum().item()
    print(f"num_layers={n_hops}")
    print(f"{in_hop1}/{n_top}={in_hop1/n_top*100:.2f}%; {topk_sum:.0f}/{all_sum:.0f}")

In [None]:
# Draw the first (unshuffled) training mini-batch with two layers of
# unrestricted (-1) neighbor sampling.
dataloader = iter(
    NeighborLoader(
        data,
        input_nodes=train_nid,
        num_neighbors=[-1, -1],
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )
)
# next(dataloader)
batch = next(dataloader)

In [None]:
# Display the strategy's function name; `train_with_drop` stores
# `fn.__name__` as the "method" entry of the logger info.
drop_random.__name__