## Process pretrained node/edge embeddings for knowledge graph

In [12]:
import torch
from kge.model import KgeModel
from kge.util.io import load_checkpoint

In [2]:
# Folder that receives the processed graph tensors (see the torch.save calls further down).
save_folder = '../data/wn18rr_kg'
# Checkpoint to load; judging by the path this is a RESCAL model trained on WN18RR.
model_path = 'local/experiments/20210329-115948-wn18rr-rescal-train/checkpoint_best.pt'
checkpoint = load_checkpoint(model_path)
# Rebuild the model (and, through model.dataset, its dataset) from the checkpoint.
model = KgeModel.create_from(checkpoint)

Loading configuration of dataset wnrr from /home/wuyx/KGE/data/wnrr ...
Loaded 11 keys from map relation_ids
Setting reciprocal_relations_model.base_model.entity_embedder.dropout to 0., was set to -0.29821809173392233.


In [3]:
# Embedders later called with index tensors to obtain node and edge-type embeddings
# (presumably "s" = subject/entity and "p" = predicate/relation -- confirm against LibKGE).
entity_embedder = model.get_s_embedder()
relation_embedder = model.get_p_embedder()

In [4]:
def dataset_statics(dataset):
    """
    Print summary statistics of a knowledge-graph dataset.

    Input:  dataset -> [kge.dataset.Dataset]
            Dataset used for training the embeddings (WordNet)

    Output: [None]

    Prints the number of nodes and edge types, the names of the first
    5 nodes and first 3 edge types, and the number of triples in the
    train / test / valid splits.
    """
    node_count = dataset.num_entities()
    edge_type_count = dataset.num_relations()

    # Look up the names of the first few entities / relations as a sanity check.
    first_nodes = dataset.entity_strings(torch.arange(5))
    first_edge_types = dataset.relation_strings(torch.arange(3))

    print("Number of Nodes: {:d}".format(node_count))
    print("Semantics of First 5 Nodes: ", first_nodes)
    print()

    print("Number of Edge Types: {:d}".format(edge_type_count))
    print("Semantics of First 3 Edge Types: ", first_edge_types)
    print()

    for split_name in ("train", "test", "valid"):
        # load_triples fills dataset._triples[split_name], which is read right after.
        dataset.load_triples(split_name)
        edge_count = dataset._triples[split_name].size(0)
        print("Number of Edges in {:s}-set : {:d}".format(split_name, edge_count))
    print()

        
def get_vocab(dataset):
    """
    Build the word -> entity-id lookup table of a dataset.

    Input:  dataset -> [kge.dataset.Dataset]
            Dataset used for training the embeddings (WordNet)

    Output: vocab -> [Dict]
            {word: word_id}; if a word occurs more than once the
            largest id wins.
    """
    vocab = {}
    for word_id in range(dataset.num_entities()):
        word = dataset._meta['entity_strings'][word_id]
        vocab[word] = word_id
    return vocab


def tokens_to_ids(vocab, batched_data):
    """
    Translate tokenized sentences into lists of entity ids.

    Input:  vocab -> [Dict]
            {word: word_id}, e.g. as returned by get_vocab

            batched_data --> [list, np.array]
            Batched and tokenized data, one token sequence per sentence

    Output: entities_id_list -> [list of lists of int]
            One id list per sentence, in token order. Tokens missing from
            the vocabulary are skipped, so the lists can be shorter than
            the input sequences.
    """
    entities_id_list = []
    for sequence in batched_data:
        known_ids = []
        for token in sequence:
            if token in vocab:
                known_ids.append(vocab[token])
        entities_id_list.append(known_ids)
    return entities_id_list

def extract_relations(dataset):
    """
    Collect the triples of all splits as an edge list.

    Input:  dataset -> [kge.dataset.Dataset]
            Dataset used for training the embeddings (WordNet)

    Output: edge_index -> [torch.LongTensor] with shape (2, E)
            edge_type  -> [torch.LongTensor] with shape (E)

    Triples are read as (subject, relation, object) rows: columns 0 and 2
    become the edge endpoints and column 1 the edge type. The splits are
    concatenated in the order train, test, valid.
    """
    split_triples = []
    for split in ["train", "test", "valid"]:
        # load_triples fills dataset._triples[split], which is read right after.
        dataset.load_triples(split)
        # Cast to int64 up front: concatenating onto a float32 seed tensor would
        # promote the ids to float and silently round values above 2**24.
        split_triples.append(dataset._triples[split].long())
    all_rels = torch.cat(split_triples, dim=0)

    edge_index = all_rels[:, [0, 2]].t().contiguous()
    edge_type = all_rels[:, 1]
    return edge_index, edge_type


def get_edge_embs(num_edge_types, egde_embedder):
    """
    Embed every edge type.

    Input:  num_edge_types -> [int] number of edge types
            egde_embedder  -> callable mapping a LongTensor of ids to embeddings

    Output: whatever `egde_embedder` returns for the ids 0 .. num_edge_types - 1
    """
    edge_type_ids = torch.arange(num_edge_types, dtype=torch.long)
    return egde_embedder(edge_type_ids)


def get_node_embs(num_node_types, node_embedder):
    """
    Embed every node.

    Input:  num_node_types -> [int] number of nodes
            node_embedder  -> callable mapping a LongTensor of ids to embeddings

    Output: whatever `node_embedder` returns for the ids 0 .. num_node_types - 1
    """
    node_ids = torch.arange(num_node_types, dtype=torch.long)
    return node_embedder(node_ids)


In [5]:
# Test above functions

# The dataset comes from the model restored from the checkpoint.
dataset = model.dataset
dataset_statics(dataset)

# Word -> entity-id table; looking up a known word checks that it was built.
vocab = get_vocab(dataset)
print("the entity id for 'world' is : ", vocab["world"])
print()

# Tokens that are not in the vocabulary are dropped, so each returned id list
# can be shorter than its input sentence.
entities_id_list = tokens_to_ids(vocab, [["hi", "i", "am", "your", "friend", "good", "day"],
                                         ["hate", "terrible", "weather"]])
print("input test:", entities_id_list)

# Edge list over the train, test and valid triples.
edge_index, edge_type = extract_relations(dataset)

# Embeddings for every node and for every edge type.
x = get_node_embs(dataset.num_entities(), entity_embedder)
edge_attr = get_edge_embs(dataset.num_relations(), relation_embedder)

print("edge_index ", edge_index.size())
print("edge_type ", edge_type.size())
print("x ", x.size())
print("edge_attr ", edge_attr.size())

Loaded 40943 keys from map entity_ids
Loaded 40943 keys from map entity_strings
Loaded 11 keys from map relation_strings
Number of Nodes: 40943
Semantics of First 5 Nodes:  ['land_reform' 'reform' 'cover' 'covering' 'phytology']

Number of Edge Types: 11
Semantics of First 3 Edge Types:  ['_hypernym' '_derivationally_related_form' '_instance_hypernym']

Loaded 86835 train triples
Number of Edges in train-set : 86835
Loaded 3134 test triples
Number of Edges in test-set : 3134
Loaded 3034 valid triples
Number of Edges in valid-set : 3034

the entity id for 'world' is :  27976

input test: [[7986, 28531, 4087, 36371], [21866, 34808]]
edge_index  torch.Size([2, 93003])
edge_type  torch.Size([93003])
x  torch.Size([40943, 128])
edge_attr  torch.Size([11, 16384])


In [6]:
# prepare learned knn graph
from torch_cluster.knn import knn_graph
# Build a k-nearest-neighbour graph (k=3) over the node embeddings x.
knn_edge_index = knn_graph(x, k=3)

In [7]:
# test knn graph
from torch_geometric.utils import remove_self_loops
# Drop edges that connect a node to itself (the second return value is discarded).
knn_edge_index, _ = remove_self_loops(knn_edge_index)
# Print the entity names at both ends of the first 10 kNN edges for a manual check.
row, col = knn_edge_index[:, :10]
row_sent = dataset.entity_strings(row)
col_sent = dataset.entity_strings(col)
print(row_sent)
print(col_sent)

['pronunciamento' 'chartist' 'passive_resister' 'event_planner'
 'paternalism' 'utopian' 'sergei_mikhailovich_eisenstein' 'samuel_goldwyn'
 'reform' 'tenderization']
['land_reform' 'land_reform' 'land_reform' 'land_reform' 'land_reform'
 'land_reform' 'land_reform' 'land_reform' 'reform' 'reform']


In [8]:
# save processed data
import os.path as osp

# These file names are the ones read back by the reload cell and by Sentence2Graph.__init__.
torch.save(vocab, osp.join(save_folder, 'vocab.pt'))
torch.save(x, osp.join(save_folder, 'wn18rr_x.pt'))
torch.save(edge_index, osp.join(save_folder, 'wn18rr_edge_index.pt'))
torch.save(knn_edge_index, osp.join(save_folder, 'wn18rr_knn_edge_index.pt'))
torch.save(edge_type, osp.join(save_folder, 'wn18rr_edge_type.pt'))
torch.save(edge_attr, osp.join(save_folder, 'wn18rr_edge_attr.pt'))

In [34]:
import re
import os
import networkx as nx
import os.path as osp
import numpy as np
import torch
from torch_geometric.utils.num_nodes import maybe_num_nodes

# Reload the tensors written by the save cell.
save_folder = '../data/wn18rr_kg'
vocab = torch.load(osp.join(save_folder, 'vocab.pt'))
all_x = torch.load(osp.join(save_folder, 'wn18rr_x.pt'))
all_edge_index = torch.load(osp.join(save_folder, 'wn18rr_edge_index.pt'))
all_edge_type = torch.load(osp.join(save_folder, 'wn18rr_edge_type.pt'))
all_edge_attr = torch.load(osp.join(save_folder, 'wn18rr_edge_attr.pt'))
knn_edge_index = torch.load(osp.join(save_folder, 'wn18rr_knn_edge_index.pt'))

# simply cat all_edge_index and knn_edge_index, and use the mean edge_attr as knn_edge_attr
all_edge_index = torch.cat([all_edge_index, knn_edge_index], dim=1)
# The kNN edges get a new edge type: one past the largest existing type id.
# Tensor.max() reduces in one call instead of iterating the tensor element by
# element as the builtin max() does.
knn_edge_type = all_edge_type.max() + 1
knn_edge_types = torch.full((knn_edge_index.size(1),), int(knn_edge_type), dtype=torch.long)
all_edge_type = torch.cat([all_edge_type, knn_edge_types], dim=0)
# The embedding of the new type is the mean over the existing edge-type embeddings.
all_edge_attr = torch.cat([all_edge_attr, all_edge_attr.mean(dim=0, keepdim=True)], dim=0)

In [37]:
# Shapes after the kNN edges and the extra kNN edge type have been appended.
print('all_x:', all_x.size())
print('all_edge_index:', all_edge_index.size())
print('all_edge_type:', all_edge_type.size())
print('all_edge_attr:', all_edge_attr.size())

all_x: torch.Size([40943, 128])
all_edge_index: torch.Size([2, 420547])
all_edge_type: torch.Size([420547])
all_edge_attr: torch.Size([12, 16384])


## Debugging subgraph extraction

In [12]:

import re
import os
import networkx as nx
import os.path as osp
import numpy as np
import torch
from torch_geometric.utils.num_nodes import maybe_num_nodes

class Sentence2Graph(object):
    """
    Map batches of sentences onto subgraphs of a pretrained knowledge graph.

    The constructor loads the vocabulary, node embeddings, edge list, edge
    types and per-type edge embeddings saved under `kg_folder`, and appends
    the kNN edges as one additional edge type.

    `reduce` returns, per sentence, the subgraph induced by the entities found
    in the sentence. `reduce_connected` additionally connects those entities
    through shortest paths in the knowledge graph.
    """

    # Kept as a public class attribute for backward compatibility; the methods
    # below do not read it.
    LARGE_NUM = 1e10

    def __init__(self, kg_folder, num_nodes=None):
        """
        Input:  kg_folder -> [str] folder holding vocab.pt and the wn18rr_*.pt files
                num_nodes -> [int, optional] number of graph nodes; when None it is
                             derived from the edge list by maybe_num_nodes

        Attributes:
        all_x          -> [torch.FloatTensor]  shape: (N, node_emb_dim)
        all_edge_index -> [torch.LongTensor]   shape: (2, E)   (KG edges followed by kNN edges)
        all_edge_type  -> [torch.LongTensor]   shape: (E)
        all_edge_attr  -> [torch.FloatTensor]  shape: (n_edge_type, edge_emb_dim)
        vocab          -> [Dict] {word: word_id}
        num_nodes      -> [int]
        G              -> [nx.Graph] undirected graph over the KG edges only
                          (the kNN edges are not added to it)
        """
        self.vocab = torch.load(osp.join(kg_folder, 'vocab.pt'))
        self.all_x = torch.load(osp.join(kg_folder, 'wn18rr_x.pt'))
        all_edge_index = torch.load(osp.join(kg_folder, 'wn18rr_edge_index.pt'))
        all_edge_type = torch.load(osp.join(kg_folder, 'wn18rr_edge_type.pt'))
        all_edge_attr = torch.load(osp.join(kg_folder, 'wn18rr_edge_attr.pt'))
        knn_edge_index = torch.load(osp.join(kg_folder, 'wn18rr_knn_edge_index.pt'))

        # simply cat all_edge_index and knn_edge_index, and use the mean edge_attr as knn_edge_attr
        self.all_edge_index = torch.cat([all_edge_index, knn_edge_index], dim=1)
        # New type id for kNN edges: one past the largest existing id. Tensor.max()
        # avoids iterating the tensor element by element like the builtin max().
        knn_edge_type = int(all_edge_type.max()) + 1
        knn_edge_types = torch.full(
            (knn_edge_index.size(1),), knn_edge_type,
            dtype=torch.long, device=all_edge_type.device,
        )
        self.all_edge_type = torch.cat([all_edge_type, knn_edge_types], dim=0)
        self.all_edge_attr = torch.cat([all_edge_attr, all_edge_attr.mean(dim=0, keepdim=True)], dim=0)

        self.num_nodes = maybe_num_nodes(all_edge_index, num_nodes)
        self.G = nx.Graph() # undirected
        self.G.add_nodes_from(range(self.num_nodes))
        # tolist() yields plain Python ints, so the nodes on returned paths are ints too.
        self.G.add_edges_from(all_edge_index.t().tolist())

    @staticmethod
    def split_line(line):
        '''split given line/phrase into list of words

        Input:   line   -> [str]
                string representing phrase to be split

        Output: strings -> [list]
                list of strings; each is either a run of word characters and
                apostrophes, or a single one of the punctuation marks . , ! ? ;
        '''
        return re.findall(r"[\w']+|[.,!?;]", line)

    @staticmethod
    def tokens_to_ids(vocab, batched_data):
        """
        Input:  vocab -> [Dict] {word: word_id}

                batched_data --> [list, np.array]
                Batched and tokenized data, one token sequence per sentence

        Output: entities_id_list -> [list of lists of int]
                Per sentence, the distinct ids of the tokens found in `vocab`.
                Unknown tokens are dropped and duplicates are removed; the
                order of the ids is that of a Python set, not token order.
        """
        def tokens_to_word_ids(tokens, vocab):
            return list(set([vocab[word] for word in tokens if word in vocab]))

        entities_id_list = [tokens_to_word_ids(seq, vocab) for seq in batched_data]
        return entities_id_list

    @staticmethod
    def subgraph(subset, edge_index, edge_type=None, edge_attr=None, num_nodes=None):
        """
        Returns the induced subgraph of :obj:`(edge_index, edge_attr)`
        containing the nodes in :obj:`subset`.

        Input:  subset         -> [torch.LongTensor, torch.BoolTensor, list, tuple]
                                  node ids, or a boolean/uint8 node mask
                edge_index     -> [torch.LongTensor]   shape: (2, E)
                edge_type      -> [torch.LongTensor]   shape: (E), optional
                edge_attr      -> [torch.FloatTensor]  shape: (n_edge_type, edge_emb_dim);
                                  required when edge_type is given
                num_nodes      -> [int, optional] size of the node mask when `subset`
                                  holds ids; derived from edge_index when None

        Output:
                masked_edge_index    -> [torch.LongTensor]  edges with both ends in subset
                masked_edge_attr     -> [torch.FloatTensor] embedding of each kept edge,
                                        or None when edge_type is None
                masked_edge_type     -> [torch.LongTensor]  type of each kept edge,
                                        or None when edge_type is None
        """
        device = edge_index.device

        if isinstance(subset, (list, tuple)):
            subset = torch.tensor(subset, dtype=torch.long)

        if subset.dtype == torch.bool or subset.dtype == torch.uint8:
            n_mask = subset
        else:
            num_nodes = maybe_num_nodes(edge_index, num_nodes)
            # Build the mask on the device of edge_index so the lookup below
            # never mixes devices.
            n_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
            n_mask[subset.to(device)] = True

        mask = n_mask[edge_index[0]] & n_mask[edge_index[1]]
        masked_edge_index = edge_index[:, mask]

        if edge_type is not None:
            assert edge_attr is not None
            masked_edge_type = edge_type[mask]
            masked_edge_attr = edge_attr[masked_edge_type]
        else:
            # Both stay None; previously masked_edge_type was left undefined
            # here and the return statement raised UnboundLocalError.
            masked_edge_type = None
            masked_edge_attr = None

        return masked_edge_index, masked_edge_attr, masked_edge_type

    def __relabel__(self, sub_nodes, edge_index):
        """
        Rewrite global node ids in `edge_index` as positions within `sub_nodes`.

        Input:  sub_nodes  -> [list, torch.LongTensor] global ids of the subgraph nodes
                edge_index -> [torch.LongTensor] shape: (2, E), global node ids

        Output: relabeled_edge_index -> [torch.LongTensor] shape: (2, E);
                a node that is not in `sub_nodes` maps to -1
        """
        row, col = edge_index
        # as_tensor avoids the copy-construct warning when sub_nodes is already a tensor.
        sub_nodes = torch.as_tensor(sub_nodes, dtype=torch.long, device=row.device)
        # remapping the nodes in the explanatory subgraph to new ids.
        node_idx = row.new_full((self.num_nodes,), -1)
        node_idx[sub_nodes] = torch.arange(sub_nodes.size(0), device=row.device)
        relabeled_edge_index = node_idx[edge_index]
        return relabeled_edge_index

    def _edge_type_between(self, u, v):
        """Return the type (int) of the first stored edge u->v, or else of the first edge v->u."""
        row, col = self.all_edge_index
        hits = torch.nonzero((row == u) & (col == v)).view(-1)
        if hits.numel() == 0:
            hits = torch.nonzero((row == v) & (col == u)).view(-1)
        return int(self.all_edge_type[hits[0]])

    def _pairwise_shortest_paths(self, subset):
        """
        Compute shortest paths in self.G between every pair of nodes in `subset`.

        Returns (g, paths): `g` is a DiGraph on `subset` with an edge in both
        directions for every pair, weighted by the number of nodes on the pair's
        path; `paths[i][j]` (i < j) is the node path from subset[i] to subset[j].
        Returns None as soon as one pair has no path or a node is missing from self.G.
        """
        n = len(subset)
        g = nx.DiGraph()
        g.add_nodes_from(subset)
        paths = {i: {} for i in range(n)}
        for i in range(n):
            for j in range(i + 1, n):
                try:
                    path_ij = nx.shortest_path(self.G, source=subset[i], target=subset[j])
                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    return None
                g.add_weighted_edges_from([(subset[i], subset[j], len(path_ij)),
                                           (subset[j], subset[i], len(path_ij))])
                paths[i][j] = path_ij
        return g, paths

    def reduce(self, batched_sentence, device, raw=False):
        """
        Build, per sentence, the subgraph induced by the entities it mentions.

        Input:  batched_sentence -> [list]

                tokenized, with raw = False:
                     [["good", "day", "i", "am", "your", "friend"],
                      ["hate", "terrible", "weather"]]

                or plain strings, with raw = True (each string is lower-cased
                and tokenized with split_line):
                     ["good day, i am your friend",
                      "hate terrible weather"]

                device -> device the returned tensors are moved to

        Output:  x, edge_index, edge_attr: information for subgraphs
                 (node ids in edge_index are positions in x)
                batch             ->   [torch.LongTensor]  shape: (N), graph index of each node
                num_nodes         ->   [torch.LongTensor]  shape: (N_graphs)
                num_edges         ->   [torch.LongTensor]  shape: (N_graphs)
                flat_entities_id  ->   [torch.LongTensor]  shape: (N), KG id of each node
                edge_type         ->   [torch.LongTensor]  shape: (E)
        """
        if raw:
            batched_sentence = [self.split_line(sentence.lower()) for sentence in batched_sentence]
        entities_id = self.tokens_to_ids(self.vocab, batched_sentence)

        node_counts = [len(entities) for entities in entities_id]
        num_nodes = torch.tensor(node_counts, dtype=torch.long)
        batch = torch.tensor(
            [i for i, count in enumerate(node_counts) for _ in range(count)],
            dtype=torch.long,
        )

        # Seed every accumulator with an empty tensor of the right dtype so that
        # an empty batch still concatenates, and concatenate once at the end.
        x_parts = [self.all_x.new_zeros(0)]
        edge_index_parts = [self.all_edge_index.new_zeros((2, 0))]
        edge_attr_parts = [self.all_edge_attr.new_zeros(0)]
        edge_type_parts = [self.all_edge_type.new_zeros(0)]
        num_edges = []
        flat_entities_id = []
        offset = 0  # number of nodes contributed by the previous sentences

        for entities in entities_id:
            x_parts.append(self.all_x[entities])
            node_idx = torch.unique(torch.tensor(entities, dtype=torch.long))
            masked_edge_index, masked_edge_attr, masked_edge_type = self.subgraph(
                node_idx, self.all_edge_index, self.all_edge_type, self.all_edge_attr)
            # Local ids follow the order of `entities`, shifted by the batch offset.
            masked_edge_index = self.__relabel__(entities, masked_edge_index) + offset

            edge_index_parts.append(masked_edge_index)
            edge_type_parts.append(masked_edge_type)
            edge_attr_parts.append(masked_edge_attr)
            num_edges.append(masked_edge_index.size(1))
            flat_entities_id.extend(entities)
            offset += len(entities)

        x = torch.cat(x_parts, dim=0).to(device)
        edge_index = torch.cat(edge_index_parts, dim=1).long().to(device)
        edge_type = torch.cat(edge_type_parts, dim=0).long().to(device)
        edge_attr = torch.cat(edge_attr_parts, dim=0).to(device)
        num_nodes = num_nodes.to(device)
        num_edges = torch.tensor(num_edges, dtype=torch.long).to(device)
        # batch is moved as well so that all outputs live on the same device,
        # as in reduce_connected.
        batch = batch.to(device)
        flat_entities_id = torch.tensor(flat_entities_id, dtype=torch.long).to(device)

        assert x.size(0) == flat_entities_id.size(0)
        assert x.size(0) == batch.size(0)

        assert edge_index.size(1) == edge_attr.size(0)
        assert edge_index.size(1) == edge_type.size(0)
        assert len(num_nodes) == len(num_edges)

        return x, edge_index, edge_attr, batch, num_nodes, num_edges, flat_entities_id, edge_type

    def reduce_connected(self, batched_sentence, device, raw=False):
        """
        Build, per sentence, a connected subgraph linking the entities it mentions.

        For each sentence the shortest paths between all entity pairs are computed
        in self.G, a minimum spanning arborescence is taken over the pairwise path
        lengths, and the paths behind its edges are merged into one subgraph.
        A sentence with fewer than two entities, or with a pair of entities that
        cannot be connected, contributes its entities as nodes without edges.

        Input:  batched_sentence -> [list]

                tokenized, with raw = False:
                     [["good", "day", "i", "am", "your", "friend"],
                      ["hate", "terrible", "weather"]]

                or plain strings, with raw = True (each string is lower-cased
                and tokenized with split_line):
                     ["good day, i am your friend",
                      "hate terrible weather"]

                device -> device the returned tensors are moved to

        Output:  x, edge_index, edge_attr: information for subgraphs
                 (node ids in edge_index are positions in x)
                batch             ->   [torch.LongTensor]  shape: (N), graph index of each node
                num_nodes         ->   [torch.LongTensor]  shape: (N_graphs)
                num_edges         ->   [torch.LongTensor]  shape: (N_graphs)
                flat_entities_id  ->   [torch.LongTensor]  shape: (N), KG id of each node
                edge_type         ->   [torch.LongTensor]  shape: (E)
        """
        # 1. preprocess the input sentences into lists of entities
        if raw:
            batched_sentence = [self.split_line(sentence.lower()) for sentence in batched_sentence]
        entities_id = self.tokens_to_ids(self.vocab, batched_sentence)

        num_nodes = []
        num_edges = []
        flat_entities_id = []
        # Everything is gathered next to the stored KG tensors and moved to
        # `device` once at the end, so stored tensors are never indexed with
        # indices that live on another device.
        x_parts = [self.all_x.new_zeros(0)]
        edge_index_parts = [self.all_edge_index.new_zeros((2, 0))]
        edge_attr_parts = [self.all_edge_attr.new_zeros(0)]
        edge_type_parts = [self.all_edge_type.new_zeros(0)]
        kg_device = self.all_edge_index.device
        offset = 0  # number of nodes contributed by the previous sentences

        for subset in entities_id:
            # 2. for each node subset, we calculate the pair-wise distance
            # ... and get the minimum spanning tree whose nodes contains subset
            connected = self._pairwise_shortest_paths(subset) if len(subset) >= 2 else None
            if connected is None:
                # Too few entities or no connecting path: keep the entities as isolated nodes.
                num_nodes.append(len(subset))
                num_edges.append(0)
                x_parts.append(self.all_x[subset])
                flat_entities_id.extend(subset)
                offset += len(subset)
                continue

            g, paths = connected
            sg_edges = list(nx.minimum_spanning_arborescence(g).edges)
            position = {node: idx for idx, node in enumerate(subset)}

            # 3. collect the paths (each path represents an edge in the spanning tree)
            # ... into one graph
            sub_nodes = []
            path_edges = []
            path_edge_types = []
            for u, v in sg_edges:
                # Paths are stored once per pair, keyed by (smaller, larger) position.
                i, j = sorted((position[u], position[v]))
                p = paths[i][j]
                sub_nodes.extend(p)
                for a, b in zip(p[:-1], p[1:]):
                    path_edges.append([a, b])
                    path_edge_types.append(self._edge_type_between(a, b))

            # 4. aggregate the subgraph into batch
            sub_nodes = torch.tensor(sub_nodes, dtype=torch.long, device=kg_device).unique()
            return_x = self.all_x[sub_nodes]

            return_edge_index = torch.tensor(path_edges, dtype=torch.long, device=kg_device).t()
            return_edge_index = self.__relabel__(sub_nodes, return_edge_index) + offset

            return_edge_type = torch.tensor(path_edge_types, dtype=torch.long, device=kg_device)
            return_edge_attr = self.all_edge_attr[return_edge_type]

            num_nodes.append(sub_nodes.size(0))
            num_edges.append(return_edge_index.size(1))
            offset += sub_nodes.size(0)

            x_parts.append(return_x)
            edge_index_parts.append(return_edge_index)
            edge_attr_parts.append(return_edge_attr)
            edge_type_parts.append(return_edge_type)
            flat_entities_id.extend(sub_nodes.tolist())

        x = torch.cat(x_parts, dim=0).to(device)
        edge_index = torch.cat(edge_index_parts, dim=1).long().to(device)
        edge_attr = torch.cat(edge_attr_parts, dim=0).to(device)
        edge_type = torch.cat(edge_type_parts, dim=0).long().to(device)
        batch = torch.tensor(
            [i for i, count in enumerate(num_nodes) for _ in range(count)],
            dtype=torch.long,
        ).to(device)
        num_nodes = torch.tensor(num_nodes, dtype=torch.long).to(device)
        num_edges = torch.tensor(num_edges, dtype=torch.long).to(device)
        flat_entities_id = torch.tensor(flat_entities_id, dtype=torch.long).to(device)

        assert x.size(0) == flat_entities_id.size(0)
        assert x.size(0) == batch.size(0)

        assert edge_index.size(1) == edge_attr.size(0)
        assert edge_index.size(1) == edge_type.size(0)
        assert len(num_nodes) == len(num_edges)

        return x, edge_index, edge_attr, batch, num_nodes, num_edges, flat_entities_id, edge_type

In [13]:
import numpy as np

def batch_seq(sentences, intervel=1):
    """
    Expand every sentence into the list of all its prefixes.

    Input:  sentences -> [list of lists] tokenized sentences
            intervel  -> unused; accepted for backward compatibility

    Output: res_seq -> [list of lists of str]
            For each sentence in order, its prefixes of length 1, 2, ...,
            len(sentence); tokens are converted to str.

    Implemented in plain Python: the previous version relied on `np.str`
    (removed from NumPy) and on building an array from ragged nested lists
    (rejected by current NumPy). Sentences may now differ in length.
    """
    res_seq = []
    for sentence in sentences:
        tokens = [str(token) for token in sentence]
        for end in range(1, len(tokens) + 1):
            res_seq.append(tokens[:end])
    return res_seq
# Demo: each sentence is expanded into its growing prefixes; `seq` is reused
# as the input batch of reduce_connected in a later cell.
seq = batch_seq([['you', 'are', 'a', 'kind', 'man'], ['what', 'a', 'beautiful', 'day', "!"]])
print(seq)

[['you'], ['you', 'are'], ['you', 'are', 'a'], ['you', 'are', 'a', 'kind'], ['you', 'are', 'a', 'kind', 'man'], ['what'], ['what', 'a'], ['what', 'a', 'beautiful'], ['what', 'a', 'beautiful', 'day'], ['what', 'a', 'beautiful', 'day', '!']]


In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): kg_folder is 'wn18rr_kg' here, while the tensors were saved to
# '../data/wn18rr_kg' earlier -- confirm which working directory this cell assumes.
sentence2graph = Sentence2Graph(kg_folder='wn18rr_kg')
# `seq` holds already tokenized sentences, so raw is left at its default (False).
x, edge_index, edge_attr, batch, num_nodes, num_edges, entities_id, edge_type = sentence2graph.reduce_connected(batched_sentence=seq, device=device)


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [15]:
# One embedding row per returned edge (the first dimension equals the number of edges).
edge_attr.size()

torch.Size([7, 16384])

### Toy Example  (without SubgraphFinder)

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the recorded output below (3 subgraphs, 7 nodes) does not match a
# batch holding one empty sentence; it was presumably produced with a different
# batched_sentence -- re-run or restore the original input.
batched_sentence = [[]]
x, edge_index, edge_attr, batch, num_nodes, num_edges, entities_id, edge_type = sentence2graph.reduce(batched_sentence, device=device)
print('x:', x.size())
print('edge_index:', edge_index.size())
print('edge_attr:', edge_attr.size())
print('batch:', batch.size())
print('num subgraphs:', len(torch.unique(batch)))
print('num nodes:', num_nodes)
print('num edges:', num_edges)

x: torch.Size([7, 128])
edge_index: torch.Size([2, 2])
edge_attr: torch.Size([2, 16384])
batch: torch.Size([7])
num subgraphs: 3
num nodes: tensor([3, 2, 2], device='cuda:0')
num edges: tensor([0, 0, 2], device='cuda:0')


#### check the returned graphs

In [17]:
from kge.model import KgeModel
from kge.util.io import load_checkpoint

# Reload the model only to get access to its dataset for id -> name lookups.
save_folder = 'wn18rr_kg'
model_path = 'local/experiments/20210329-115948-wn18rr-rescal-train/checkpoint_best.pt'
checkpoint = load_checkpoint(model_path)
model = KgeModel.create_from(checkpoint)

dataset = model.dataset
# Names of all nodes in the batch.
print(dataset.entity_strings(entities_id))

# edge_index holds positions into entities_id, so each end is mapped back to
# its KG id before the name lookup.
for i in range(edge_index.size(1)):
    print("%s-->%s-->%s" % (dataset.entity_strings(entities_id[edge_index[0, i]]), 
                            dataset.relation_strings(edge_type[i]), 
                            dataset.entity_strings(entities_id[edge_index[1, i]])))

Loading configuration of dataset wnrr from /home/wuyx/MetaKG/KGE/data/wnrr ...
Loaded 11 keys from map relation_ids
Setting reciprocal_relations_model.base_model.entity_embedder.dropout to 0., was set to -0.29821809173392233.
Loaded 40943 keys from map entity_ids
Loaded 40943 keys from map entity_strings
['friend' 'day' 'good' 'weather' 'hate' 'phytology' 'botanize']
Loaded 11 keys from map relation_strings
phytology-->_derivationally_related_form-->botanize
botanize-->_derivationally_related_form-->phytology


### Toy Example  (with SubgraphFinder)

In [120]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A plain (untokenized) sentence, hence raw=True in the call below.
batched_sentence = ["Thank you for the love and encourgement, sweetheart"]
                    # [["good", "day", "i", "am", "your", "friend"],
                    # ["hate", "terrible", "weather"],
                    # ["phytology", "botanize"]]
print(sentence2graph.num_nodes)
x, edge_index, edge_attr, batch, num_nodes, num_edges, entities_id, edge_type = sentence2graph.reduce_connected(batched_sentence, device=device, raw=True)
print('x:', x.size())
print('edge_index:', edge_index.size())
print('edge_attr:', edge_attr.size())
print('batch:', batch.size())
print('num subgraphs:', len(torch.unique(batch)))
print('num nodes:', num_nodes)
print('num edges:', num_edges)

40943
x: torch.Size([6, 128])
edge_index: torch.Size([2, 5])
edge_attr: torch.Size([5, 16384])
batch: torch.Size([6])
num subgraphs: 1
num nodes: tensor([6], device='cuda:0')
num edges: tensor([5], device='cuda:0')



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



#### Check whether the reduced connected graph makes sense

In [121]:
print(dataset.entity_strings(entities_id))

for i in range(edge_index.size(1)):
    
    # The dataset has 11 relation types (ids 0..10); Sentence2Graph assigns the
    # next id to kNN edges, so any type above 10 is printed as 'knn-link'.
    _type = 'knn-link' if edge_type[i] > 10 else dataset.relation_strings(edge_type[i])
    print("%s-->%s-->%s" % (dataset.entity_strings(entities_id[edge_index[0, i]]), 
                            _type, 
                            dataset.entity_strings(entities_id[edge_index[1, i]])))

['woman' 'soul' 'lover' 'female_person' 'love' 'sweetheart']
love-->_derivationally_related_form-->lover
lover-->_hypernym-->soul
soul-->_hypernym-->female_person
female_person-->_hypernym-->woman
woman-->_hypernym-->sweetheart


In [124]:
# Names of relation ids 0..9 (the dataset reports 11 relation types, so the last id is not listed).
dataset.relation_strings(torch.LongTensor([i for i in range(10)]))

array(['_hypernym', '_derivationally_related_form', '_instance_hypernym',
       '_also_see', '_member_meronym', '_synset_domain_topic_of',
       '_has_part', '_member_of_domain_usage', '_member_of_domain_region',
       '_verb_group'], dtype='<U28')