In [12]:
import tqdm
from torch import nn
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero, Linear
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData
from torch import Tensor
from torch_geometric.datasets import DBLP
import numpy as np
import random


# Load DBLP. The `Constant` transform attaches a single constant feature to
# every `conference` node (see `conference.x` in the printed summary).
conference_features = T.Constant(node_types='conference')
dataset = DBLP('./data/dblp', transform=conference_features)
data = dataset[0]

# Give authors and papers a `node_id` that enumerates the rows of their
# feature matrix `x`.
for node_type in ("author", "paper"):
    num_rows = data[node_type].x.size(0)
    data[node_type].node_id = torch.arange(num_rows)

print(data)


def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds NumPy's global RNG, PyTorch (CPU and all CUDA devices) and Python's
    `random` module with the same value, and switches cuDNN to deterministic
    mode. `torch.backends.cudnn.benchmark` and
    `torch.use_deterministic_algorithms` are not changed here.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        random.seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057],
    node_id=[4057],
  },
  paper={
    x=[14328, 4231],
    node_id=[14328],
  },
  term={ x=[7723, 50] },
  conference={
    num_nodes=20,
    x=[20, 1],
  },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)


In [13]:
# Link-prediction split of the ("author", "to", "paper") edges:
#   * 80% training, 10% validation and 10% test edges,
#   * within the training edges, 70% are used for message passing and 30% for
#     supervision (`disjoint_train_ratio=0.3`),
#   * validation/test get fixed negative edges with a ratio of 2:1,
#   * training negatives are NOT generated here; they are sampled on-the-fly
#     by the loader (`add_negative_train_samples=False`).
# We can leverage the `RandomLinkSplit()` transform for this from PyG.

# Seed all RNGs before the first stochastic step so that the split (and the
# shuffling / negative sampling that follow) are reproducible when the
# notebook is restarted and run top to bottom.
set_seed(42)

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,
    edge_types=("author", "to", "paper"),
    rev_edge_types=("paper", "to", "author"),
)

train_data, val_data, test_data = transform(data)


In [14]:
# Inspect the training supervision labels. The split above was created with
# `add_negative_train_samples=False`, so these should all be ones (positives).
train_data["author", "to", "paper"].edge_label

tensor([1., 1., 1.,  ..., 1., 1., 1.])

In [15]:
# We can make use of the `loader.LinkNeighborLoader` from PyG:
from torch_geometric.loader import LinkNeighborLoader

# Seed edges for mini-batching are the training supervision edges.
seed_edge_type = ("author", "to", "paper")
seed_edge_store = train_data[seed_edge_type]
edge_label_index = seed_edge_store.edge_label_index
edge_label = seed_edge_store.edge_label

# Sampling configuration:
#   * at most 20 neighbors in the first hop and at most 10 in the second hop,
#   * negative edges are sampled on-the-fly during training with a 2:1 ratio.
loader_options = dict(
    num_neighbors=[20, 10],
    neg_sampling_ratio=2.0,
    batch_size=128,
    shuffle=True,
)

train_loader = LinkNeighborLoader(
    data=train_data,
    edge_label_index=(seed_edge_type, edge_label_index),
    edge_label=edge_label,
    **loader_options,
)

# Peek at one sampled mini-batch:
next(iter(train_loader))

HeteroData(
  author={
    x=[784, 334],
    y=[784],
    train_mask=[784],
    val_mask=[784],
    test_mask=[784],
    node_id=[784],
    n_id=[784],
    num_sampled_nodes=[3],
  },
  paper={
    x=[6759, 4231],
    node_id=[6759],
    n_id=[6759],
    num_sampled_nodes=[3],
  },
  term={
    x=[2409, 50],
    n_id=[2409],
    num_sampled_nodes=[3],
  },
  conference={
    num_nodes=20,
    x=[20, 1],
    n_id=[20],
    num_sampled_nodes=[3],
  },
  (author, to, paper)={
    edge_index=[2, 2015],
    edge_label=[384],
    edge_label_index=[2, 384],
    e_id=[2015],
    num_sampled_edges=[2],
    input_id=[128],
  },
  (paper, to, author)={
    edge_index=[2, 2332],
    e_id=[2332],
    num_sampled_edges=[2],
  },
  (paper, to, term)={
    edge_index=[2, 8286],
    e_id=[8286],
    num_sampled_edges=[2],
  },
  (paper, to, conference)={
    edge_index=[2, 190],
    e_id=[190],
    num_sampled_edges=[2],
  },
  (term, to, paper)={
    edge_index=[2, 10025],
    e_id=[10025],
    num_sa

In [18]:
class GNN(torch.nn.Module):
    """Two stacked `SAGEConv` layers with a ReLU in between.

    Both layers map `hidden_channels` -> `hidden_channels`, so the node
    features passed to `forward` must already have `hidden_channels` columns.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        width = hidden_channels
        self.conv1 = SAGEConv(width, width)
        self.conv2 = SAGEConv(width, width)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Apply conv1 -> ReLU -> conv2; the final output is not activated."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        return self.conv2(hidden, edge_index)

# Our final classifier applies the dot-product between source and destination
# node embeddings to derive edge-level predictions:
class Classifier(torch.nn.Module):
    """Parameter-free dot-product decoder producing one score per edge."""

    def forward(self, x_user: Tensor, x_movie: Tensor, edge_label_index: Tensor) -> Tensor:
        """Score each supervision edge by the dot product of its endpoints.

        Args:
            x_user: Embeddings of the source node type; rows are selected by
                `edge_label_index[0]`. (In this notebook `Model` passes the
                author embeddings here.)
            x_movie: Embeddings of the destination node type; rows are
                selected by `edge_label_index[1]`. (In this notebook `Model`
                passes the paper embeddings here.)
            edge_label_index: Index tensor whose row 0 holds source indices
                and row 1 holds destination indices of the edges to score.

        Returns:
            The elementwise product of the selected source and destination
            embeddings, summed over the last dimension (raw scores, no
            sigmoid applied).
        """
        src_idx = edge_label_index[0]
        dst_idx = edge_label_index[1]
        return torch.sum(x_user[src_idx] * x_movie[dst_idx], dim=-1)


class Model(torch.nn.Module):
    """Link predictor for ("author", "to", "paper") edges.

    Pipeline: per-node-type lazy linear projection followed by ReLU ->
    `GNN` lifted to the heterogeneous graph with `to_hetero` -> dot-product
    `Classifier` evaluated on the supervision edges.

    Note: the constructor reads the notebook-level `data` object for the node
    types, the author/paper node counts and the graph metadata.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # `paper_lin`, `author_emb` and `paper_emb` belong to an alternative
        # input encoding (dense projection of the 4231-dimensional paper
        # features plus learned per-node ID embeddings). They are registered
        # here, but the active `forward` below does not use them.
        self.paper_lin = torch.nn.Linear(4231, hidden_channels)

        # One projection per node type; `-1` lets the layer infer its input
        # width on the first forward pass.
        self.lin_dict = torch.nn.ModuleDict()
        for ntype in data.node_types:
            self.lin_dict[ntype] = Linear(-1, hidden_channels)

        num_authors = data["author"].num_nodes
        num_papers = data["paper"].num_nodes
        self.author_emb = torch.nn.Embedding(num_authors, hidden_channels)
        self.paper_emb = torch.nn.Embedding(num_papers, hidden_channels)

        # Build the homogeneous GNN and convert it into its heterogeneous
        # variant for this graph's node and edge types.
        self.gnn = to_hetero(GNN(hidden_channels), metadata=data.metadata())

        self.classifier = Classifier()

    def forward(self, data: HeteroData) -> Tensor:
        """Return one raw score per ("author", "to", "paper") supervision edge.

        Args:
            data: A (sampled) heterogeneous graph providing `x_dict`,
                `edge_index_dict` and `edge_label_index` on the
                ("author", "to", "paper") edge type. This parameter shadows
                the notebook-level `data` inside the method.
        """
        # Project every node type's features to the hidden width, then apply
        # an in-place ReLU.
        projected = {}
        for ntype, feats in data.x_dict.items():
            projected[ntype] = self.lin_dict[ntype](feats).relu_()

        # Message passing over all node and edge types.
        embeddings = self.gnn(projected, data.edge_index_dict)

        # Score only the supervision edges.
        supervision_edges = data["author", "to", "paper"].edge_label_index
        return self.classifier(
            embeddings["author"],
            embeddings["paper"],
            supervision_edges,
        )

        
# Width shared by the input projections and both GNN layers.
HIDDEN_CHANNELS = 64

model = Model(hidden_channels=HIDDEN_CHANNELS)
print(model)

Model(
  (paper_lin): Linear(in_features=4231, out_features=64, bias=True)
  (lin_dict): ModuleDict(
    (author): Linear(-1, 64, bias=True)
    (paper): Linear(-1, 64, bias=True)
    (term): Linear(-1, 64, bias=True)
    (conference): Linear(-1, 64, bias=True)
  )
  (author_emb): Embedding(4057, 64)
  (paper_emb): Embedding(14328, 64)
  (gnn): GraphModule(
    (conv1): ModuleDict(
      (author__to__paper): SAGEConv(64, 64, aggr=mean)
      (paper__to__author): SAGEConv(64, 64, aggr=mean)
      (paper__to__term): SAGEConv(64, 64, aggr=mean)
      (paper__to__conference): SAGEConv(64, 64, aggr=mean)
      (term__to__paper): SAGEConv(64, 64, aggr=mean)
      (conference__to__paper): SAGEConv(64, 64, aggr=mean)
    )
    (conv2): ModuleDict(
      (author__to__paper): SAGEConv(64, 64, aggr=mean)
      (paper__to__author): SAGEConv(64, 64, aggr=mean)
      (paper__to__term): SAGEConv(64, 64, aggr=mean)
      (paper__to__conference): SAGEConv(64, 64, aggr=mean)
      (term__to__paper): SAG

In [19]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

model = model.to(device)

# The per-node-type input projections in `Model` are lazy (`Linear(-1, ...)`),
# so their weights have no shape until the first forward pass. Run one batch
# through the model without tracking gradients to materialize them BEFORE the
# optimizer is built over `model.parameters()`.
with torch.no_grad():
    init_batch = next(iter(train_loader)).to(device)
    model(init_batch)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Make the training mode explicit so re-running this cell after an evaluation
# cell (which may have called `model.eval()`) still trains correctly.
model.train()

for epoch in range(1, 6):
    # Running sums for the example-weighted mean loss of this epoch.
    total_loss = total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):
        optimizer.zero_grad()

        sampled_data = sampled_data.to(device)
        pred = model(sampled_data)

        # Labels of the sampled supervision edges (positives plus the
        # negatives drawn on-the-fly by the loader).
        ground_truth = sampled_data["author", "to", "paper"].edge_label
        loss = F.binary_cross_entropy_with_logits(pred, ground_truth)

        loss.backward()
        optimizer.step()

        # `loss` is a batch mean; weight it by batch size before summing.
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

Device: 'cuda'


  0%|          | 0/37 [00:00<?, ?it/s]

torch.Size([823, 64])
torch.Size([6614, 64])
{('author', 'to', 'paper'): EdgeIndex([[ 365,  366,  367,  ...,  359,  822,  363],
           [   1,    3,    5,  ..., 1858, 1859, 1859]], device='cuda:0',
          sparse_size=(823, 6614), nnz=2266, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 375,  376,  156,  ..., 2536, 2537,  374],
           [   0,    0,    1,  ...,  542,  542,  542]], device='cuda:0',
          sparse_size=(6614, 823), nnz=2503, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2538,  579, 2539,  ..., 4405,  374,  374],
           [   0,    0,    0,  ..., 1036, 1036, 1037]], device='cuda:0',
          sparse_size=(6614, 2594), nnz=8088, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6424,  432, 6487, 2623, 6488, 6489, 6490, 6491, 6492, 6493, 6494,
            6495, 6496, 6477, 6497, 6498,   46, 6499, 6500, 4581, 6501, 6502,
            6503, 6504, 6505,   56, 6506, 6507,  382, 6508, 2320, 1689, 6509,
            6510, 6511, 6512, 6513, 651

  3%|▎         | 1/37 [00:00<00:07,  4.82it/s]

torch.Size([834, 64])
torch.Size([6519, 64])
{('author', 'to', 'paper'): EdgeIndex([[  30,  363,  364,  ...,  358,  359,  476],
           [   1,    2,    3,  ..., 1684, 1685, 1685]], device='cuda:0',
          sparse_size=(834, 6519), nnz=2107, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 380,  381,  382,  ..., 2453, 2454, 2455],
           [   0,    0,    0,  ...,  560,  560,  560]], device='cuda:0',
          sparse_size=(6519, 834), nnz=2427, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2456, 1894, 2457,  ..., 1289, 6404,  379],
           [   0,    0,    0,  ..., 1038, 1038, 1038]], device='cuda:0',
          sparse_size=(6519, 2463), nnz=8140, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[5920, 5228, 6405, 6406, 6407, 6408, 5503, 3050, 4105, 6409, 6410,
            3292, 4637, 6230, 6411, 5484, 6412, 6413, 2193, 6087, 6414, 3149,
            6415, 6416, 6417, 6418, 3148, 6419,  386, 6420,  989,  985, 5895,
            6421, 6422,  961, 6423, 642

 11%|█         | 4/37 [00:00<00:02, 13.80it/s]

torch.Size([821, 64])
torch.Size([6535, 64])
{('author', 'to', 'paper'): EdgeIndex([[ 361,  362,  363,  ...,  355,  356,  359],
           [   3,    4,    5,  ..., 1739, 1740, 1741]], device='cuda:0',
          sparse_size=(821, 6535), nnz=2102, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 378,  379,  380,  ..., 1436,  377, 2422],
           [   0,    1,    1,  ...,  554,  554,  554]], device='cuda:0',
          sparse_size=(6535, 821), nnz=2389, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[   0, 2423,  146,  ...,  376,  376,  376],
           [   0,    0,    0,  ..., 1000, 1001, 1002]], device='cuda:0',
          sparse_size=(6535, 2479), nnz=8065, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6427,   14, 6428, 5364, 5236,  697,  696, 6429, 4680, 1775, 6430,
            6431, 6080, 6270, 6432, 1931, 6433, 6434, 6435, 5636, 6436, 4049,
            6437, 6438, 1253, 6439, 4048, 6440, 6441, 6442, 6443, 6444, 1612,
              79, 6445, 1041, 6446, 644

 19%|█▉        | 7/37 [00:00<00:01, 18.09it/s]

torch.Size([778, 64])
torch.Size([6496, 64])
{('author', 'to', 'paper'): EdgeIndex([[ 360,  361,  362,  ...,  777,  358,  359],
           [   2,    5,    7,  ..., 1616, 1616, 1617]], device='cuda:0',
          sparse_size=(778, 6496), nnz=1906, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 382,  383,  384,  ...,  376, 2318, 2319],
           [   0,    0,    0,  ...,  557,  557,  557]], device='cuda:0',
          sparse_size=(6496, 778), nnz=2271, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2320, 2321, 2322,  ..., 1244,  381, 1764],
           [   0,    0,    0,  ..., 1054, 1054, 1054]], device='cuda:0',
          sparse_size=(6496, 2360), nnz=8250, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[  26, 6388,  728,  729, 6389, 6390, 6391, 6392, 1979, 6393, 1123,
            6394, 4813, 6395, 6396, 6397, 6398,  993, 2297, 6193, 6399, 6400,
            6401, 5352, 6402, 6403, 6404, 6405, 6406, 6407, 2159, 1402, 6408,
            6409, 6410, 1085, 6411, 641

 27%|██▋       | 10/37 [00:00<00:01, 20.46it/s]

torch.Size([805, 64])
torch.Size([6618, 64])
{('author', 'to', 'paper'): EdgeIndex([[ 358,  359,  360,  ...,  351,  804,  353],
           [   0,    2,    4,  ..., 1760, 1760, 1761]], device='cuda:0',
          sparse_size=(805, 6618), nnz=2138, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 382,  383,  384,  ..., 2413,  378, 2414],
           [   0,    0,    0,  ...,  528,  528,  528]], device='cuda:0',
          sparse_size=(6618, 805), nnz=2348, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[   0, 2415, 2416,  ..., 2571, 1686, 2627],
           [   0,    0,    0,  ..., 1071, 1071, 1071]], device='cuda:0',
          sparse_size=(6618, 2512), nnz=8413, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[5942,  415, 6506, 6507, 6508, 6509, 3911, 6510, 3991, 1969, 6511,
            1344, 5155, 4694, 6512,   47, 6513, 6514, 6515, 5970, 1034, 3943,
            6516, 6517, 6518, 6519, 3942, 6520, 6521, 6522, 1101, 1097, 6523,
            6524, 6525, 1071,   72, 652

 43%|████▎     | 16/37 [00:00<00:00, 22.72it/s]

torch.Size([841, 64])
torch.Size([6553, 64])
{('author', 'to', 'paper'): EdgeIndex([[ 360,  361,  362,  ...,  353,  354,  356],
           [   0,    1,    2,  ..., 1734, 1735, 1736]], device='cuda:0',
          sparse_size=(841, 6553), nnz=2121, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 379,  380,  381,  ...,  377, 2598, 2599],
           [   0,    0,    0,  ...,  567,  567,  567]], device='cuda:0',
          sparse_size=(6553, 841), nnz=2565, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2600,  376, 2601,  ..., 4551, 1158,  378],
           [   0,    0,    0,  ..., 1017, 1017, 1017]], device='cuda:0',
          sparse_size=(6553, 2492), nnz=7968, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[4544, 6244, 6447, 6448, 6449,   18, 1790, 6450, 5566, 5033, 1302,
            6451, 4743, 6452, 6453, 6454, 6455, 6456, 2340, 6241, 2330, 4122,
            6457, 6458, 6459, 6460, 4121, 6461, 5628, 6462, 6463, 6464, 5801,
            6465, 6466, 6467, 6468, 646

 51%|█████▏    | 19/37 [00:00<00:00, 23.31it/s]

torch.Size([801, 64])
torch.Size([6569, 64])
{('author', 'to', 'paper'): EdgeIndex([[ 361,  362,  363,  ...,  356,  358,  360],
           [   1,    3,    4,  ..., 1798, 1799, 1800]], device='cuda:0',
          sparse_size=(801, 6569), nnz=2191, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 377,  378,  379,  ...,  375, 2474, 2475],
           [   0,    0,    0,  ...,  541,  541,  541]], device='cuda:0',
          sparse_size=(6569, 801), nnz=2426, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2476, 2477, 1084,  ..., 4974,  376, 2492],
           [   0,    0,    0,  ..., 1040, 1040, 1040]], device='cuda:0',
          sparse_size=(6569, 2537), nnz=8280, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[1870, 2745, 1843,   16, 6460, 6461, 6462, 4872, 3998,  564, 6463,
            6464, 4517, 6465, 6466, 6467,  400, 6468, 6469, 4548, 6470, 4034,
            6471,   59, 6472,  381, 4033, 6473, 6474, 6475, 2276, 2272, 6476,
            6477, 6478, 1280, 6479, 648

 59%|█████▉    | 22/37 [00:01<00:00, 23.76it/s]

torch.Size([782, 64])
torch.Size([6464, 64])
{('author', 'to', 'paper'): EdgeIndex([[ 364,  365,  366,  ...,  781,  359,  362],
           [   3,    5,    5,  ..., 1658, 1659, 1660]], device='cuda:0',
          sparse_size=(782, 6464), nnz=1975, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 382,  383,  384,  ...,  379, 2392,  380],
           [   0,    0,    0,  ...,  545,  545,  546]], device='cuda:0',
          sparse_size=(6464, 782), nnz=2329, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2393, 2394, 2395,  ..., 6359,  381,  381],
           [   0,    0,    0,  ..., 1046, 1046, 1047]], device='cuda:0',
          sparse_size=(6464, 2431), nnz=8221, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6360, 6361,  733,  734, 6362, 6363, 3739,  580, 3112, 6364, 6365,
            1137, 6366, 6367, 6368, 1266, 6369, 6370, 1886, 1464, 6371, 1117,
            1119, 3936, 6372,  403, 5297, 6373,  388, 6374, 2236, 2232, 1516,
            6375, 6376, 6377, 6378, 637

 68%|██████▊   | 25/37 [00:01<00:00, 24.08it/s]

torch.Size([801, 64])
torch.Size([6560, 64])
{('author', 'to', 'paper'): EdgeIndex([[  50,  358,  359,  ...,  353,  354,  356],
           [   2,    3,    4,  ..., 1754, 1755, 1756]], device='cuda:0',
          sparse_size=(801, 6560), nnz=2130, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 379,  380,  381,  ..., 2467, 2468, 2469],
           [   0,    0,    1,  ...,  550,  550,  550]], device='cuda:0',
          sparse_size=(6560, 801), nnz=2461, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2470, 2471, 2472,  ...,  378, 5566,  378],
           [   0,    0,    0,  ..., 1014, 1015, 1015]], device='cuda:0',
          sparse_size=(6560, 2439), nnz=8116, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6450, 6451, 6452, 5279, 6453, 6454, 2645, 6280, 4479, 6455, 6456,
            6457, 6179, 4495, 6458, 6459, 6460, 6461, 1090, 5545, 6462, 3219,
            6463, 6464, 6465,  391, 3218, 6466, 1878, 6467, 1888, 1885, 6468,
              64, 6469, 2254, 6470, 647

 84%|████████▍ | 31/37 [00:01<00:00, 24.53it/s]

torch.Size([777, 64])
torch.Size([6479, 64])
{('author', 'to', 'paper'): EdgeIndex([[ 360,  361,  362,  ...,  355,  357,  357],
           [   4,    6,    7,  ..., 1649, 1650, 1651]], device='cuda:0',
          sparse_size=(777, 6479), nnz=1963, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 380,  381,  382,  ...,  567, 2366, 2367],
           [   0,    0,    0,  ...,  545,  545,  545]], device='cuda:0',
          sparse_size=(6479, 777), nnz=2313, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2368,  156, 2369,  ..., 3214, 2348,  379],
           [   0,    0,    0,  ..., 1035, 1035, 1035]], device='cuda:0',
          sparse_size=(6479, 2385), nnz=8096, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6361, 6362, 6363, 6364, 6365, 6366, 6367, 6368, 4536, 2380, 6369,
            6370, 4233, 5480,   46, 6371, 6372, 1038, 6373, 3659, 2098, 6374,
            6375, 6376, 6377, 6378, 6379, 6380, 6381,  381, 6382, 6383, 6384,
            6385, 6386, 1708, 6387, 638

 92%|█████████▏| 34/37 [00:01<00:00, 24.58it/s]

{('author', 'to', 'paper'): EdgeIndex([[ 353,  354,  355,  ...,  796,  346,  350],
           [   1,    2,    4,  ..., 1691, 1692, 1693]], device='cuda:0',
          sparse_size=(797, 6647), nnz=2076, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 379,  380,  381,  ..., 2387, 2388, 2389],
           [   0,    0,    0,  ...,  550,  550,  550]], device='cuda:0',
          sparse_size=(6647, 797), nnz=2361, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[   0, 2390, 2391,  ...,  377,  378,  377],
           [   0,    0,    0,  ..., 1056, 1056, 1057]], device='cuda:0',
          sparse_size=(6647, 2392), nnz=8519, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[  29, 6534, 6535, 6536, 6537, 6538, 6539, 6540, 4922,  679, 1288,
            6541, 5095, 6302, 6542, 5511, 6543, 6544, 6545, 4611, 2167, 6546,
            6547, 6548, 6549,  401, 6550, 6551,  385, 6552, 6553, 2195, 6554,
            6555, 6556, 6557, 6558, 6559, 1784, 6560, 6561, 6562, 2841, 6563,
      

100%|██████████| 37/37 [00:01<00:00, 22.24it/s]


Epoch: 001, Loss: 0.5547


  0%|          | 0/37 [00:00<?, ?it/s]

torch.Size([832, 64])
torch.Size([6779, 64])


  8%|▊         | 3/37 [00:00<00:01, 23.32it/s]

{('author', 'to', 'paper'): EdgeIndex([[ 358,   41,  359,  ...,  353,  354,  831],
           [   2,    3,    4,  ..., 1752, 1753, 1753]], device='cuda:0',
          sparse_size=(832, 6779), nnz=2160, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 376,  377,  378,  ...,  374, 2472, 2473],
           [   0,    1,    1,  ...,  574,  574,  574]], device='cuda:0',
          sparse_size=(6779, 832), nnz=2463, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2474, 2475, 1593,  ..., 3758,  374,  375],
           [   0,    0,    0,  ..., 1034, 1034, 1035]], device='cuda:0',
          sparse_size=(6779, 2448), nnz=8073, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6662, 2953, 6663, 6664, 6665, 6666, 3989, 6667, 4109, 3133, 6668,
            4673, 6669, 6417, 6670, 5800,   45, 6671, 6672, 3828, 6673, 6674,
            6675, 6676, 6677, 6678, 6679, 6680, 6681, 6682, 6683, 6684, 6685,
            5238,   64, 6686, 6687, 6688, 6689, 1115, 2113, 6690, 6691, 6692,
      

 24%|██▍       | 9/37 [00:00<00:01, 24.08it/s]

{('author', 'to', 'paper'): EdgeIndex([[ 367,  368,  369,  ...,  357,  362,  365],
           [   0,    2,    3,  ..., 1848, 1849, 1850]], device='cuda:0',
          sparse_size=(836, 6697), nnz=2267, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 380,  381,  382,  ..., 2577, 2578, 2579],
           [   0,    0,    0,  ...,  562,  562,  562]], device='cuda:0',
          sparse_size=(6697, 836), nnz=2563, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2580, 2581, 2270,  ..., 6575, 1967,  377],
           [   0,    0,    0,  ...,  992,  992,  992]], device='cuda:0',
          sparse_size=(6697, 2564), nnz=7764, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[5998, 6576, 6577, 6578, 6579, 2626, 6580, 6581, 4466, 6582, 6583,
            6584, 6585, 4289, 6586, 5745, 6587, 6588,   51, 6589, 6590, 6591,
            6592, 6593, 6594,  392, 6595, 6596, 6597, 6598, 1722, 2410, 6599,
            6600, 5838, 2012, 6601, 6602, 6603, 6604, 6605, 2472, 6606, 6607,
      

 32%|███▏      | 12/37 [00:00<00:01, 24.24it/s]

{('author', 'to', 'paper'): EdgeIndex([[  38,  370,  371,  ...,  361,  363,  369],
           [   1,    5,    5,  ..., 1737, 1738, 1739]], device='cuda:0',
          sparse_size=(815, 6628), nnz=2137, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 381,  382,  383,  ...,  380, 2468, 2469],
           [   0,    0,    1,  ...,  558,  558,  558]], device='cuda:0',
          sparse_size=(6628, 815), nnz=2421, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[   0, 2470,  375,  ...,  380, 6511,  380],
           [   0,    0,    0,  ..., 1061, 1062, 1062]], device='cuda:0',
          sparse_size=(6628, 2496), nnz=8454, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6512, 6513, 6514, 6515, 4173, 1749,   12, 6516, 5132,  637, 6517,
            6518, 5033, 6519, 6520, 1301, 6521, 6522, 2423, 4715, 6523, 1877,
              70, 6524, 6525, 6526, 6527, 6528, 6529, 6530, 1867, 1864, 6531,
            6532, 6533, 1200, 6534, 6535, 6536, 2003, 6537, 1042,  804, 6538,
      

 41%|████      | 15/37 [00:00<00:00, 24.11it/s]

{('author', 'to', 'paper'): EdgeIndex([[ 366,  367,  368,  ...,  825,  360,  363],
           [   0,    1,    2,  ..., 1664, 1665, 1666]], device='cuda:0',
          sparse_size=(826, 7450), nnz=2011, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 377,  378,  379,  ..., 2366, 2367, 2368],
           [   0,    1,    1,  ...,  548,  548,  548]], device='cuda:0',
          sparse_size=(7450, 826), nnz=2281, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2369,  674, 2370,  ..., 7344, 7345,  375],
           [   0,    0,    0,  ..., 1021, 1021, 1021]], device='cuda:0',
          sparse_size=(7450, 2366), nnz=8114, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[7346, 5779, 7347, 7348, 1989, 7349, 7350, 6351, 7351, 6103, 7352,
            7353, 7354, 4526, 2456, 6707, 2303, 7355, 6252, 6765, 7356, 7357,
            3095, 7358, 7359, 7360, 7361, 5441, 3785, 3137, 1738, 7362, 7363,
            7364, 1239, 7365, 7366, 6284, 5749, 6488, 7367, 7368, 5706, 7369,
      

 49%|████▊     | 18/37 [00:00<00:00, 24.33it/s]

{('author', 'to', 'paper'): EdgeIndex([[ 367,  368,  369,  ...,  511,  704,  366],
           [   0,    2,    3,  ..., 1857, 1857, 1858]], device='cuda:0',
          sparse_size=(826, 6678), nnz=2291, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 381,  382,  383,  ..., 2514, 2515, 2516],
           [   0,    0,    0,  ...,  548,  548,  548]], device='cuda:0',
          sparse_size=(6678, 826), nnz=2463, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2517, 2495, 2518,  ..., 2880, 4055,  380],
           [   0,    0,    0,  ..., 1059, 1059, 1059]], device='cuda:0',
          sparse_size=(6678, 2609), nnz=8310, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6559, 6560, 6561, 6562, 6563, 6564, 6565, 6566,   11, 6567, 6568,
            6569, 6570, 6361, 6571, 5669, 6572, 6573, 6574, 4641, 6575, 6576,
            6577,   56, 6578,  411, 6579, 6580,  387, 2093, 2300, 2297, 1680,
            6581, 6582,  994, 6583, 6584, 6585, 2078, 6586, 6587, 6588, 6589,
      

 65%|██████▍   | 24/37 [00:00<00:00, 24.64it/s]

{('author', 'to', 'paper'): EdgeIndex([[ 366,  367,  368,  ...,  361,  538,  364],
           [   0,    2,    2,  ..., 1677, 1677, 1678]], device='cuda:0',
          sparse_size=(780, 6472), nnz=2037, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 379,  380,  381,  ..., 2336, 2337,  378],
           [   0,    0,    0,  ...,  530,  530,  530]], device='cuda:0',
          sparse_size=(6472, 780), nnz=2265, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[   0,  785, 2338,  ...,  377, 1557,  378],
           [   0,    0,    0,  ..., 1023, 1023, 1024]], device='cuda:0',
          sparse_size=(6472, 2405), nnz=8071, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6353, 6354, 6355, 5427, 6356,  658,  657, 1993, 4295, 2004, 6357,
            6358, 5630, 6248, 6359, 6360, 6361, 6362, 6363, 1476, 2155, 6364,
            6365, 6366, 6367, 6368, 6369, 6370, 1794,  380, 1568, 1495,   62,
            5446, 6371, 1807, 1514,   61, 1253, 6372, 6373, 6374, 6375, 6376,
      

 73%|███████▎  | 27/37 [00:01<00:00, 24.64it/s]

{('author', 'to', 'paper'): EdgeIndex([[ 361,   41,  362,  ...,  359,  360,  360],
           [   1,    4,    7,  ..., 1786, 1787, 1788]], device='cuda:0',
          sparse_size=(789, 6549), nnz=2139, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 378,  379,  380,  ..., 2430, 2431, 2432],
           [   0,    0,    1,  ...,  533,  533,  533]], device='cuda:0',
          sparse_size=(6549, 789), nnz=2357, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2433, 1935, 2434,  ..., 3793, 3385,  377],
           [   0,    0,    0,  ..., 1025, 1025, 1025]], device='cuda:0',
          sparse_size=(6549, 2510), nnz=8163, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6434, 6435, 6436, 6437, 6438, 6439, 6440, 6441, 6165,  594, 6442,
            1287, 4699, 6292, 6443, 1895, 6444, 6445, 6446, 5805, 6447, 6448,
            6449, 6450, 6451, 6452, 6453, 6454, 5594, 2038, 1862, 6455, 6456,
            6457, 6458, 1029, 6459, 6460, 6461, 6462, 1870, 6463, 1871, 6464,
      

 81%|████████  | 30/37 [00:01<00:00, 24.72it/s]

torch.Size([799, 64])
torch.Size([6685, 64])


 89%|████████▉ | 33/37 [00:01<00:00, 24.74it/s]

{('author', 'to', 'paper'): EdgeIndex([[   8,  363,  364,  ...,  356,  360,  361],
           [   0,    2,    3,  ..., 1885, 1886, 1887]], device='cuda:0',
          sparse_size=(799, 6685), nnz=2268, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 378,  379,  380,  ..., 2503, 2504, 2505],
           [   0,    1,    1,  ...,  534,  534,  534]], device='cuda:0',
          sparse_size=(6685, 799), nnz=2446, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[   0, 2506, 2507,  ..., 3140,  376,  377],
           [   0,    0,    0,  ..., 1057, 1057, 1058]], device='cuda:0',
          sparse_size=(6685, 2578), nnz=8390, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[4676,  412, 6576, 2651, 6577, 6578, 6579, 6580, 4866, 6581, 6582,
            1464, 6583, 4570, 6584, 6585,  401, 6586, 6587, 4405, 2379, 4172,
            6588, 6589, 1397,  383, 4171, 6590, 6591, 6592, 2391, 2387, 6593,
            6594, 6595, 1139, 6596, 6597, 6598, 6599, 6600, 1214, 6601, 6602,
      

100%|██████████| 37/37 [00:01<00:00, 24.52it/s]


{('author', 'to', 'paper'): EdgeIndex([[ 362,  363,  364,  ...,  358,  359,  360],
           [   0,    1,    2,  ..., 1706, 1707, 1708]], device='cuda:0',
          sparse_size=(809, 6671), nnz=2080, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 379,  380,  381,  ..., 2442, 2443, 2444],
           [   0,    0,    0,  ...,  539,  539,  539]], device='cuda:0',
          sparse_size=(6671, 809), nnz=2379, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[   0, 2445,  340,  ..., 4625,  378, 3979],
           [   0,    0,    0,  ..., 1057, 1057, 1057]], device='cuda:0',
          sparse_size=(6671, 2452), nnz=8469, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[ 583, 6555, 6556, 6557, 6558, 6559, 3223, 5167,  837,  621, 6560,
            6561, 6562, 6316, 6563, 2175, 6564, 6565, 6566, 6410, 6567, 6568,
            6569, 6570, 6571, 6572, 6573,   71, 3591, 1889, 1895, 6574, 6575,
            6576, 6440, 6577, 1905, 6578, 6579, 6580, 1913, 2411,  840, 6581,
      

  8%|▊         | 3/37 [00:00<00:01, 24.91it/s]

torch.Size([809, 64])
torch.Size([6603, 64])
{('author', 'to', 'paper'): EdgeIndex([[ 364,  365,  365,  ...,  361,  362,  363],
           [   2,    3,    4,  ..., 1709, 1710, 1711]], device='cuda:0',
          sparse_size=(809, 6603), nnz=2062, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 379,  380,  381,  ..., 2535, 2536, 2537],
           [   0,    0,    0,  ...,  565,  565,  565]], device='cuda:0',
          sparse_size=(6603, 809), nnz=2501, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2538, 2539, 2540,  ..., 2466, 4130,  378],
           [   0,    0,    0,  ..., 1047, 1047, 1047]], device='cuda:0',
          sparse_size=(6603, 2382), nnz=8268, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[3783, 6500, 6501, 3552, 6502, 6503, 6504, 2097, 5926,  513, 6505,
            6506, 6207, 6131,   46, 6507, 6508, 6509, 2439, 4551, 1769, 2906,
            6510, 6511, 6512, 6513, 2905, 6514, 1733, 6515, 6516, 1549, 6517,
            6518,   62, 1184, 6519, 652

 22%|██▏       | 8/37 [00:00<00:01, 23.09it/s]

{('author', 'to', 'paper'): EdgeIndex([[ 365,  366,   55,  ...,  362,  806,  363],
           [   0,    2,    3,  ..., 1760, 1760, 1761]], device='cuda:0',
          sparse_size=(807, 6433), nnz=2117, sort_order=col), ('paper', 'to', 'author'): EdgeIndex([[ 381,  382,  383,  ...,  651, 2258,  657],
           [   0,    0,    0,  ...,  552,  552,  552]], device='cuda:0',
          sparse_size=(6433, 807), nnz=2411, sort_order=col), ('paper', 'to', 'term'): EdgeIndex([[2466,    0, 2467,  ..., 2368, 2504,  380],
           [   0,    0,    0,  ..., 1003, 1003, 1003]], device='cuda:0',
          sparse_size=(6433, 2430), nnz=7881, sort_order=col), ('paper', 'to', 'conference'): EdgeIndex([[6310, 6311, 6312, 6313, 6314,  842,  841, 3287, 4413, 6315, 6316,
            6317, 6318, 6319,   52, 6320, 6321, 6322, 6323, 6324, 6325, 6326,
            6327, 6328, 1221, 6329, 6330, 6331, 6332,  382, 6333, 6334, 6335,
            6336, 6337, 2263, 6338, 6339, 6340, 6341, 6342, 2140, 4486, 6343,
      




KeyboardInterrupt: 