In [9]:
import json
import os.path as osp
import random

import torch
from torch_geometric.data import Data
from torch_geometric.utils import dropout_edge, k_hop_subgraph, subgraph

In [10]:
# Location of the inductive Wikidata5M split and its id mappings.
DATA_PATH = '../data/wikidata5m_inductive'
ENTITY_2_ID_PATH = osp.join(DATA_PATH, 'uri_to_id.json')
RELATION_2_ID_PATH = osp.join(DATA_PATH, 'relation_uri_to_id.json')

# JSON text is UTF-8 by specification; pass the encoding explicitly so loading
# does not depend on the platform/locale default encoding.
# uri_to_id: entity URI -> integer id (the ids are used as graph node indices later).
with open(ENTITY_2_ID_PATH, encoding='utf-8') as entity_map_in:
    uri_to_id = json.load(entity_map_in)

# relation_uri_to_id: relation URI -> id (presumably the values stored in edge_type — TODO confirm).
with open(RELATION_2_ID_PATH, encoding='utf-8') as relation_map_in:
    relation_uri_to_id = json.load(relation_map_in)


In [11]:
# Load the train / val / test graph splits from the same directory as the id maps.
# NOTE(review): these files hold pickled PyG `Data` objects. On torch >= 2.6,
# `torch.load` defaults to weights_only=True, which may reject them. Confirm this
# against the installed torch/PyG versions, and only load files from a trusted source.
data_train = torch.load(osp.join(DATA_PATH, 'train.pt'))
data_val = torch.load(osp.join(DATA_PATH, 'val.pt'))
data_test = torch.load(osp.join(DATA_PATH, 'test.pt'))

In [30]:
# Shuffle all entity URIs; the next cell takes a prefix of this list as seed nodes.
# A fixed seed on a dedicated Random instance makes the sample (and the saved
# train_sample.pt) reproducible across Restart & Run All. It also leaves the
# global `random` state untouched.
ENTITY_SHUFFLE_SEED = 0
entities = list(uri_to_id)
random.Random(ENTITY_SHUFFLE_SEED).shuffle(entities)

In [71]:
# Number of shuffled entities used as seed nodes for the training-graph sample.
NUM_SEED_NODES = 100_000
subset_nodes = [uri_to_id[e] for e in entities[:NUM_SEED_NODES]]

# Extract the 1-hop subgraph around the seed nodes from the training graph
# (see PyG `k_hop_subgraph`). `relabel_nodes` is left at its default (False in
# PyG), so the sampled edge_index keeps the global ids from `uri_to_id`.
# NOTE(review): num_nodes assumes the ids in uri_to_id are 0..len(uri_to_id)-1 — confirm.
subset, train_edge_index_new, mapping, edge_mask = k_hop_subgraph(
    subset_nodes,
    num_hops=1,
    edge_index=data_train.edge_index.long(),  # int64 indices for k_hop_subgraph
    num_nodes=len(uri_to_id),
)

# edge_mask selects the kept edges of data_train; use it to carry edge types over.
train_edge_type_new = data_train.edge_type[edge_mask]

data_train_sample = Data(edge_index=train_edge_index_new, edge_type=train_edge_type_new)
torch.save(data_train_sample, osp.join(DATA_PATH, 'train_sample.pt'))

In [62]:
data_val.edge_index.shape

torch.Size([2, 6699])

In [72]:
def sample_edges(data: Data, keep_ratio: float, out_path: str) -> Data:
    """Randomly keep about a ``keep_ratio`` fraction of ``data``'s edges and save the result.

    Edges are dropped with ``dropout_edge(p=1 - keep_ratio)``, so the number of
    kept edges is random (only approximately ``keep_ratio * num_edges``). The
    returned mask also filters ``edge_type``, so types stay aligned with
    ``edge_index``.

    Args:
        data: graph with ``edge_index`` and ``edge_type`` attributes.
        keep_ratio: fraction of edges to keep, in [0, 1].
        out_path: file the sampled ``Data`` is written to with ``torch.save``.

    Returns:
        The sampled ``Data`` object (also written to ``out_path``).
    """
    sampled_edge_index, kept_mask = dropout_edge(data.edge_index, p=1 - keep_ratio)
    sample = Data(edge_index=sampled_edge_index, edge_type=data.edge_type[kept_mask])
    torch.save(sample, out_path)
    return sample


# Fraction of val/test edges to keep.
maintain = 0.1
# dropout_edge draws from torch's global RNG; seed it so the saved samples are reproducible.
EDGE_SAMPLE_SEED = 0
torch.manual_seed(EDGE_SAMPLE_SEED)

data_val_sample = sample_edges(data_val, maintain, osp.join(DATA_PATH, 'val_sample.pt'))
data_test_sample = sample_edges(data_test, maintain, osp.join(DATA_PATH, 'test_sample.pt'))

# Keep the module-level names this cell previously defined; a later cell inspects test_edge_index_new.
val_edge_index_new, val_edge_type_new = data_val_sample.edge_index, data_val_sample.edge_type
test_edge_index_new, test_edge_type_new = data_test_sample.edge_index, data_test_sample.edge_type

In [73]:
test_edge_index_new.shape

torch.Size([2, 656])