In [32]:
import torch
import os.path as osp

DATA_DIR = '/home/gangda/workspace/graph_engine/intermediate'

In [27]:
res = []
for i in range(2):
    for j in range(8):
        res += torch.load(osp.join(DATA_DIR, 'data_{}_{}.pt'.format(i, j)))

len(res)

KeyboardInterrupt: 

In [36]:
import os
from ogb.nodeproppred import DglNodePropPredDataset

file_path = '/home/gangda/workspace/graph_engine/data/ogbn-products-p2'

def extract_core_global_ids(parts):
    part_core_global_ids = []
    for i in range(len(parts)):
        part = parts[i]
        core_mask = part.ndata['inner_node'].type(torch.bool)
        part_global_id = part.ndata['orig_id']
        part_core_global_ids.append(part_global_id[core_mask])
    return part_core_global_ids

og, y = DglNodePropPredDataset(name='ogbn-products', root='/data/gangda/dgl')[0]
dataset = dict(X=og.ndata['feat'], y=y)
dataset['edge_index'] = torch.load(os.path.join(file_path, 'dgl_edge_index.pt'))
parts = torch.load(os.path.join(file_path, 'metis_partitions.pt'))
dataset['part_core_global_ids'] = extract_core_global_ids(parts)

In [36]:
# [1235633, 1213396]
vids = torch.cat([torch.arange(1235633), torch.arange(1213396)])
sids = torch.cat([torch.zeros(1235633), torch.ones(1213396)])

def get_global_id(local_ids, shard_ids):
    local_ids = local_ids.to(torch.long)
    global_ids = torch.empty_like(local_ids)
    for j in range(len(dataset['part_core_global_ids'])):
        mask = shard_ids == j
        if mask.sum() == 0: continue
        global_ids[mask] = dataset['part_core_global_ids'][j][local_ids[mask]]
    return global_ids

gids = get_global_id(vids, sids)
gids

In [44]:
K = 150

global_topk = []
for i, r in enumerate(res):
    val, idx = torch.sort(r[2], descending=True)
    top_k_index = idx[:K]
    global_ids = get_global_id(r[0][top_k_index], r[1][top_k_index])
    vec = torch.full((K,), gids[i])
    vec[:global_ids.shape[0]] = global_ids
    global_topk.append(vec)
global_topk = torch.stack(global_topk)

ppr_matrix = torch.empty((dataset['X'].shape[0], K), dtype=torch.long)
ppr_matrix[gids] = global_topk
ppr_matrix

2449029

In [56]:
torch.save(ppr_matrix, osp.join(DATA_DIR, 'ppr_matrix.pt'))

tensor([ 769608,  228252,  960517,  132355,       0, 1043730,    4467,  657230,
          33678,  188057])

In [33]:
ppr_matrix = torch.load(osp.join(DATA_DIR, 'ogbn-products_ppr_matrix.pt'))
ppr_matrix

tensor([[      0,  152857,  194591,  ...,  186457, 2034637,   35148],
        [      1,   89825,  151342,  ..., 1137752,  620866, 2322836],
        [      2,  488076, 1598820,  ..., 2017855,  442068, 1514525],
        ...,
        [2449026,  149963,  148503,  ..., 1886107, 2210439, 2009035],
        [2449027,  739621,  306502,  ...,  188674,  307615,  706682],
        [2449028,  728426,   48117,  ..., 1815472, 1420498, 2088456]])

In [29]:
batch_index = torch.arange(1024, 1024 + 128)
batch_index

tensor([1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035,
        1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047,
        1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059,
        1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071,
        1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083,
        1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095,
        1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107,
        1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1116, 1117, 1118, 1119,
        1120, 1121, 1122, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131,
        1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143,
        1144, 1145, 1146, 1147, 1148, 1149, 1150, 1151])

In [38]:
from torch_sparse import SparseTensor
from pyg_lib.sampler import subgraph as libsubgraph

adj = SparseTensor.from_edge_index(dataset['edge_index'])
rowptr, col, _ = adj.csr()

In [39]:
subset = torch.cat([batch_index, ppr_matrix[batch_index].view(-1)])
subset, inv = subset.unique(return_inverse=True)

libsubgraph(rowptr, col, subset)

(tensor([     0,     96,    121,  ..., 545857, 545883, 545884]),
 tensor([  407,   520,   582,  ..., 16622, 17125,   121]),
 tensor([    10188,     10189,     10190,  ..., 123715627, 123715631,
         123716018]))

In [50]:
%%time

from torch_geometric.utils import subgraph
from torch_geometric.data import Data

subset = torch.cat([batch_index, ppr_matrix[batch_index].view(-1)])
subset, inv = subset.unique(return_inverse=True)

sub_rowptr, sub_col, _ = libsubgraph(rowptr, col, subset)
adj = SparseTensor(rowptr=sub_rowptr, col=sub_col)

# sub_edge_index = subgraph(subset, dataset['edge_index'])[0]
# sub_edge_index = sub_edge_index.unique(return_inverse=True)[1]

batch_data = Data(dataset['X'][subset],
                  edge_index=adj,
                  y=dataset['y'][batch_index].view(-1),
                  ego_index=inv[:batch_index.shape[0]])
batch_data

CPU times: user 8.18 s, sys: 342 ms, total: 8.52 s
Wall time: 64.2 ms


Data(x=[18918, 100], edge_index=[18918, 18918, nnz=545884], y=[128], ego_index=[128])

In [169]:
%%time
from torch_geometric.data import Batch

datas = []
for bid in batch_index:
    subset, inv = torch.cat([bid.unsqueeze(dim=0), ppr_matrix[bid]]).unique(return_inverse=True)
    ego_index = inv[0]
    # subset = ppr_matrix[bid].unique()
    # ego_index=torch.tensor([0])

    edge_id = libsubgraph(rowptr, col, subset, return_edge_id=True)[-1]
    sub_edge_index = dataset['edge_index'][:, edge_id]
    # sub_edge_index = subgraph(ppr_matrix[bid], dataset['edge_index'])[0]
    sub_edge_index = sub_edge_index.unique(return_inverse=True)[1]

    ego_data = Data(dataset['X'][subset],
                    sub_edge_index,
                    y=dataset['y'][bid],
                    ego_index=ego_index
                    )
    datas.append(ego_data)

batch = Batch.from_data_list(datas)
batch

CPU times: user 8min 7s, sys: 0 ns, total: 8min 7s
Wall time: 31 s


DataBatch(x=[11491, 100], edge_index=[2, 158175], y=[128], ego_index=[128], batch=[11491], ptr=[129])

---

In [200]:
d = DglNodePropPredDataset(name='ogbn-products', root='/data/gangda/dgl')
og, y = d[0]
dataset = dict(X=og.ndata['feat'], y=y)
dataset['train_index'] = d.get_idx_split()['train']
dataset['valid_index'] = d.get_idx_split()['valid']
dataset['test_index'] = d.get_idx_split()['test']

data_list = torch.load('/home/gangda/workspace/graph_engine/intermediate/egograph_list.pt')
print('Data Loading Complete!')

Data Loading Complete!


In [201]:
class Dict(dict):
    def __getattr__(self, key):
        return self.get(key)
    def __setattr__(self, key, value):
        self[key] = value

args = Dict({
    'batch_size': 128,
    'runs': 5,
    'epochs': 35,
    'lr': 0.001,
})

In [269]:
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.utils import subgraph, dropout_edge
from torchmetrics import Accuracy
from tqdm import tqdm
from torch_geometric.nn import GATConv
import numpy as np


class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=4,
                 num_layers=5, dropout=0.35, dropedge=0.1, pooling='center'):
        super().__init__()
        self.dropout = dropout
        self.dropedge = dropedge
        self.pooling = pooling

        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels, heads))
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(heads * hidden_channels, hidden_channels, heads))
        self.convs.append(GATConv(heads * hidden_channels, out_channels, heads))

    def forward(self, x: Tensor, edge_index: Tensor, ego_index: Tensor) -> Tensor:
        for i, conv in enumerate(self.convs):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x, dropout_edge(edge_index, p=self.dropedge, training=self.training)[0])
            if i < len(self.convs) - 1:
                x = x.relu_()

        if self.pooling == 'center':
            x = x[ego_index]

        return x


class ShadowLoader(DataLoader):
    def __init__(self, node_idx, data_list, **kwargs):
        self.data_list = data_list
        if node_idx.dtype == torch.bool:
            node_idx = node_idx.nonzero(as_tuple=False).view(-1)
        super().__init__(node_idx.tolist(), collate_fn=self.__collate__, **kwargs)

    def __collate__(self, batch_nodes):
        batch_data_list = []
        for nid in batch_nodes:
            batch_data_list.append(self.data_list[nid])
        return Batch.from_data_list(batch_data_list)


def train(model, optimizer, metric, train_loader, epoch):
    model.train()
    metric.reset()

    pbar = tqdm(total=int(len(train_loader.dataset)))
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = 0
    for batch in train_loader:
        batch.to(metric.device)

        optimizer.zero_grad()
        y_hat = model(batch.x, batch.edge_index, batch.ego_index)
        loss = F.cross_entropy(y_hat, batch.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() / batch.y.shape[0]
        metric.update(y_hat.argmax(dim=-1), batch.y)
        pbar.update(batch.y.shape[0])
    pbar.close()

    return total_loss, metric.compute()


@torch.no_grad()
def mini_test(model, metric, *loaders):
    model.eval()
    ms = []
    for loader in loaders:
        metric.reset()
        for data in tqdm(loader):
            data.to(metric.device)
            y_hat = model(data.x, data.edge_index, data.ego_index)
            metric.update(y_hat.argmax(dim=-1), data.y)
        ms.append(metric.compute())
    return ms


def main(args, dataset, data_list):
    num_features, num_classes = dataset['X'].shape[-1], dataset['y'].max().item()+1

    kwargs = {'batch_size': args.batch_size, 'shuffle': True, 'num_workers': 1, 'persistent_workers': True}
    train_loader = ShadowLoader(dataset['train_index'], data_list, **kwargs)
    val_loader = ShadowLoader(dataset['valid_index'], data_list, **kwargs)
    test_loader = ShadowLoader(dataset['test_index'], data_list, **kwargs)

    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
    metric = Accuracy(task="multiclass", num_classes=num_classes)
    metric.to(device)

    # runs
    best_val, best_test = [], []
    for i in range(1, args.runs + 1):
        model = GAT(num_features, 256, num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0)

        best_val_acc = test_acc = 0
        test_accs = []
        print(f'------------------------{i}------------------------')

        for epoch in range(1, args.epochs + 1):
            loss, train_acc = train(model, optimizer, metric, train_loader, epoch)
            print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {train_acc:.4f}')
            if epoch > 20 and epoch % 5 == 0:
                val_acc, tmp_test_acc = mini_test(model, metric, val_loader, test_loader)
                test_accs.append(round(tmp_test_acc.item(), 3))
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    test_acc = tmp_test_acc
                print(f'Epoch: {epoch:02d}, Loss: {loss: .4f}, Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')

        best_val.append(float(best_val_acc))
        best_test.append(float(test_acc))
        print(test_accs)

    print(f'Valid: {np.mean(best_val):.4f} +- {np.std(best_val):.4f}')
    print(f'Test: {np.mean(best_test):.4f} +- {np.std(best_test):.4f}')

In [None]:
%%time
main(args, dataset, data_list)

------------------------1------------------------


Epoch 01:   0%|          | 128/196615 [00:02<54:42, 59.86it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2604071ca0>
Traceback (most recent call last):
  File "/home/gangda/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/home/gangda/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers
    if w.is_alive():
  File "/home/gangda/anaconda3/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2604071ca0>
Traceback (most recent call last):
  File "/home/gangda/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/home/gangda/anaconda3/lib/pyth

Epoch 01, Loss: 6.2611, Approx. Train: 0.8625


Epoch 02: 100%|██████████| 196615/196615 [02:22<00:00, 1382.78it/s]


Epoch 02, Loss: 4.2208, Approx. Train: 0.9046


Epoch 03: 100%|██████████| 196615/196615 [02:22<00:00, 1383.35it/s]


Epoch 03, Loss: 3.7545, Approx. Train: 0.9132


Epoch 04: 100%|██████████| 196615/196615 [02:22<00:00, 1377.26it/s]


Epoch 04, Loss: 3.3817, Approx. Train: 0.9203


Epoch 05: 100%|██████████| 196615/196615 [02:22<00:00, 1379.16it/s]


Epoch 05, Loss: 3.7369, Approx. Train: 0.9154


Epoch 06: 100%|██████████| 196615/196615 [02:22<00:00, 1380.05it/s]


Epoch 06, Loss: 3.4559, Approx. Train: 0.9204


Epoch 07: 100%|██████████| 196615/196615 [02:22<00:00, 1379.80it/s]


Epoch 07, Loss: 3.3202, Approx. Train: 0.9216


Epoch 08: 100%|██████████| 196615/196615 [02:22<00:00, 1376.17it/s]


Epoch 08, Loss: 3.8664, Approx. Train: 0.9145


Epoch 09: 100%|██████████| 196615/196615 [02:22<00:00, 1383.27it/s]


Epoch 09, Loss: 3.7716, Approx. Train: 0.9113


Epoch 10: 100%|██████████| 196615/196615 [02:22<00:00, 1380.24it/s]


Epoch 10, Loss: 3.7703, Approx. Train: 0.9107


Epoch 11:  51%|█████     | 100480/196615 [01:12<01:13, 1314.27it/s]

In [263]:
kwargs = {'batch_size': args.batch_size, 'shuffle': True, 'num_workers': 1, 'persistent_workers': True}

In [251]:
train_loader = ShadowLoader(dataset['train_index'], data_list, **kwargs)
iter = train_loader.__iter__()

In [262]:
%%time
batch = next(iter)
batch

CPU times: user 3.9 ms, sys: 50 µs, total: 3.95 ms
Wall time: 3.97 ms


DataBatch(x=[11139, 100], edge_index=[2, 149450], y=[128], ego_index=[128], batch=[11139], ptr=[129])

---

In [8]:
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.loader import NeighborLoader

data = PygNodePropPredDataset('ogbn-products', root='/data/gangda/ogb')[0]

In [18]:
import torch
data.n_id = torch.arange(data.num_nodes)
subgraph_loader = NeighborLoader(
    data,
    input_nodes=None,
    num_neighbors=[-1],
    batch_size=4096,
    num_workers=1,
    persistent_workers=True,
)

it = iter(subgraph_loader)
batch = next(it)

In [19]:
batch.n_id

tensor([     0,      1,      2,  ..., 113059, 380489, 634640])

In [20]:
from dgl.data import CoraGraphDataset

og = CoraGraphDataset(raw_dir='/data/gangda/dgl', verbose=False)[0]
og

Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})

In [27]:
og.ndata['label'].max()

tensor(6)