In [1]:
import bisect
import typing as ty

import dgl
import torch

Using backend: pytorch


In [2]:
# Load the FB15k-237 dataset through DGL; element 0 of the dataset is the
# graph object used throughout the rest of the notebook.
fb15 = dgl.data.FB15k237Dataset()
graph = fb15[0]

# entities: 14541
# relations: 237
# training edges: 272115
# validation edges: 17535
# testing edges: 20466
Done loading data from cached files.


In [3]:
# Display the graph summary: node/edge counts and the node/edge feature schemes.
graph

Graph(num_nodes=14541, num_edges=620232,
      ndata_schemes={'ntype': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'etype': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'test_edge_mask': Scheme(shape=(), dtype=torch.bool), 'valid_edge_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_edge_mask': Scheme(shape=(), dtype=torch.bool)})

In [4]:
def get_split(graph: dgl.DGLGraph, split_key: str) -> ty.Tuple[dgl.DGLGraph, torch.Tensor]:
    """Extract the edge-induced subgraph for one dataset split.

    Args:
        graph: Graph whose ``edata`` contains a boolean
            ``"{split_key}_edge_mask"`` entry and an ``"etype"`` entry.
        split_key: Prefix of the mask to use, e.g. ``"train"``, ``"valid"``
            or ``"test"``.

    Returns:
        A ``(split_graph, split_edges_index)`` tuple: the subgraph built from
        the masked edges (all nodes are kept via ``preserve_nodes=True``) and
        the 1-D tensor of selected edge IDs in ``graph``.
    """
    split_mask = graph.edata[f"{split_key}_edge_mask"]
    # nonzero() on a 1-D mask returns shape (k, 1). Squeeze only that trailing
    # axis: a bare .squeeze() would also collapse the k axis when exactly one
    # edge is selected, producing a 0-d tensor instead of a 1-element index.
    split_edges_index = torch.nonzero(split_mask, as_tuple=False).squeeze(1)

    split_graph = graph.edge_subgraph(split_edges_index, preserve_nodes=True)
    # Carry the relation type of every selected edge over to the subgraph.
    split_graph.edata["etype"] = graph.edata["etype"][split_edges_index]

    return split_graph, split_edges_index


In [5]:
# Build one subgraph (plus the IDs of its edges in `graph`) per split, using
# the "train"/"valid"/"test" *_edge_mask entries of graph.edata.
train_g, train_edges = get_split(graph, "train")
val_g, val_edges = get_split(graph, "valid")
test_g, test_edges = get_split(graph, "test")

In [6]:
# Count the edges flagged by `train_mask`. The result (544230) is exactly twice
# the 272115 training edges reported when the dataset was loaded --
# presumably reverse edges are included; TODO confirm against the DGL docs.
graph.edata["train_mask"].sum()

tensor(544230)

In [7]:
# Count the edges flagged by `train_edge_mask`; this matches the 272115
# training edges reported when the dataset was loaded.
graph.edata["train_edge_mask"].sum()

tensor(272115)

In [8]:
# Sanity check: number of source nodes (equals the 14541 entities reported at load time).
graph.number_of_src_nodes()

14541

In [9]:
# Sanity check: number of destination nodes (equals the 14541 entities reported at load time).
graph.number_of_dst_nodes()

14541

In [10]:
# Sanity check: total number of nodes (equals the 14541 entities reported at load time).
graph.number_of_nodes()

14541

In [11]:
import itertools
from torch.utils.data import Dataset


class DirectedDglGraphDataset(Dataset):
    """Map-style dataset over every ordered node pair ``(src, dst)``, ``src != dst``.

    Pairs are enumerated source-major: index ``i`` belongs to source node
    ``i // (n - 1)``, and within one source the destinations run over all
    other nodes in ascending order (the source itself is skipped). Nothing is
    materialised; each pair is derived arithmetically from the index.
    """

    def __init__(self, graph: dgl.DGLGraph):
        """Store the graph, its node count and its adjacency matrix.

        Args:
            graph: Graph whose node pairs are enumerated.
        """
        self.graph = graph
        self.number_of_nodes = graph.number_of_nodes()
        self.adjacency_mat = graph.adj()

    def __getitem__(self, i):
        """Return the ``i``-th ordered node pair and its adjacency entry.

        Args:
            i: Position in ``[-len(self), len(self))``; negative values count
                from the end, as with built-in sequences.

        Returns:
            A dict with ``"src_node_index"``, ``"dst_node_index"`` and
            ``"relation"``, the adjacency-matrix entry at ``(src, dst)``.

        Raises:
            IndexError: If ``i`` is outside the valid range.
        """
        total_pairs = len(self)
        # Build a new value rather than using "+=" so that an index object
        # passed by the caller is never modified in place.
        position = i + total_pairs if i < 0 else i
        if not 0 <= position < total_pairs:
            # Explicit check: without it a negative index silently yields a
            # negative source node, and a single-node graph divides by zero.
            raise IndexError(
                f"index {i} is out of range for a dataset of {total_pairs} node pairs"
            )

        src_node_index, dst_node_index = divmod(position, self.number_of_nodes - 1)

        # The remainder ranges over n - 1 slots; shift the slots at or after
        # the source by one so the self-pair (src, src) is skipped.
        if src_node_index <= dst_node_index:
            dst_node_index += 1

        return {
            "src_node_index": src_node_index,
            "dst_node_index": dst_node_index,
            "relation": self.adjacency_mat[src_node_index, dst_node_index]
        }

    def __len__(self):
        """Return the number of ordered pairs without self-pairs: ``n * (n - 1)``."""
        return self.number_of_nodes * (self.number_of_nodes - 1)


# Very model-specific dataset, since it fixes which entity is the head and which is the tail.
class UndirectedDglGraphDataset(Dataset):
    """Map-style dataset over every unordered node pair ``(i, j)`` with ``i < j``.

    Pairs follow the row-major order of the strict upper triangle:
    ``(0, 1), (0, 2), ..., (0, n - 1), (1, 2), ...``. The pair for a given
    index is computed on demand from a table of per-row starting offsets, so
    construction costs O(n) time and memory instead of materialising all
    ``n * (n - 1) / 2`` tuples.
    """

    def __init__(self, graph: dgl.DGLGraph):
        """Store the graph, its node count, adjacency matrix and row offsets.

        Args:
            graph: Graph whose node pairs are enumerated.
        """
        self.graph = graph
        self.number_of_nodes = graph.number_of_nodes()
        self.adjacency_mat = graph.adj()

        n = self.number_of_nodes
        # _row_offsets[r] is the linear index of the first pair whose smaller
        # node is r, i.e. the number of pairs (a, b) with a < r. The last node
        # never appears as the smaller element, hence range(n - 1).
        self._row_offsets = [r * (2 * n - r - 1) // 2 for r in range(n - 1)]
        self._sample = None

    @property
    def sample(self):
        """Full list of ``(i, j)`` pairs, built lazily on first access.

        Kept for code that reads the ``sample`` attribute directly. Accessing
        it materialises every pair, which is expensive for large graphs;
        ``__getitem__`` does not need it.
        """
        if self._sample is None:
            self._sample = list(
                itertools.combinations(range(self.number_of_nodes), 2)
            )
        return self._sample

    def __getitem__(self, i):
        """Return the ``i``-th unordered node pair and its adjacency entry.

        Args:
            i: Position in ``[-len(self), len(self))``; negative values count
                from the end, as with built-in sequences.

        Returns:
            A dict with ``"src_node_index"`` (the smaller node),
            ``"dst_node_index"`` (the larger node) and ``"relation"``, the
            adjacency-matrix entry at ``(src, dst)``.

        Raises:
            IndexError: If ``i`` is outside the valid range.
        """
        total_pairs = len(self)
        position = i + total_pairs if i < 0 else i
        if not 0 <= position < total_pairs:
            raise IndexError(
                f"index {i} is out of range for a dataset of {total_pairs} node pairs"
            )

        # Last row whose starting offset does not exceed the position.
        src_node_index = bisect.bisect_right(self._row_offsets, position) - 1
        # Column inside the row; the first column of row r is r + 1.
        dst_node_index = (
            position - self._row_offsets[src_node_index] + src_node_index + 1
        )

        return {
            "src_node_index": src_node_index,
            "dst_node_index": dst_node_index,
            "relation": self.adjacency_mat[src_node_index, dst_node_index]
        }

    def __len__(self):
        """Return the number of unordered pairs: ``n * (n - 1) / 2``."""
        return (self.number_of_nodes * (self.number_of_nodes - 1)) // 2
