In [8]:
from data_utils import load_graph, load_queries_by_formula
from metamodel import MetaModel
from loader import collate_query_data, QueryBatch

In [9]:
embed_dim = 128
data_dir = "./data/AIFB/processed/"
graph, embedding_modules, node_maps = load_graph(data_dir, embed_dim)

In [13]:
train_queries = load_queries_by_formula(data_dir + "/train_edges.pkl")
for i in range(2, 4):
    train_queries.update(load_queries_by_formula(data_dir + "/train_queries_{:d}.pkl".format(i)))

In [45]:
chainqueries = train_queries['1-chain']
first_formula = list(chainqueries.keys())[0]
formqueries = chainqueries[first_formula]

In [44]:
import numpy as np
from graph import _reverse_relation
from data_utils import RGCNQueryDataset
import torch
from torch_geometric.data import Data, Batch

def get_query_graph(formula, queries, rel_ids, mode_ids):
    batch_size = len(queries)
    n_anchors = len(formula.anchor_modes)

    anchor_ids = np.empty([batch_size, n_anchors]).astype(np.int)
    # First rows of x contain embeddings of all anchor nodes
    for i, anchor_mode in enumerate(formula.anchor_modes):
        anchors = [q.anchor_nodes[i] for q in queries]
        anchor_ids[:, i] = anchors

    # The rest of the rows contain generic mode embeddings for variables
    all_nodes = formula.get_nodes()
    print(all_nodes) # ['project', 'publication']
    var_idx = RGCNQueryDataset.variable_node_idx[formula.query_type]
    print(var_idx) # [0]
    var_ids = np.array([mode_ids[all_nodes[i]] for i in var_idx],
                        dtype=np.int)
    print(var_ids) # [5]

    edge_index = RGCNQueryDataset.query_edge_indices[formula.query_type]
    edge_index = torch.tensor(edge_index, dtype=torch.long)

    rels = formula.get_rels()
    rel_idx = RGCNQueryDataset.query_edge_label_idx[formula.query_type]
    edge_type = [rel_ids[_reverse_relation(rels[i])] for i in rel_idx]
    edge_type = torch.tensor(edge_type, dtype=torch.long)

    edge_data = Data(edge_index=edge_index)
    edge_data.edge_type = edge_type
    edge_data.num_nodes = n_anchors + len(var_idx)
    graph = Batch.from_data_list([edge_data for i in range(batch_size)])
    print(graph.edge_index)

    return (torch.tensor(anchor_ids, dtype=torch.long),
            torch.tensor(var_ids, dtype=torch.long),
            graph)

In [17]:
mode_ids = {}
mode_id = 0
for mode in graph.mode_weights:
    mode_ids[mode] = mode_id
    mode_id += 1

rel_ids = {}
id_rel = 0
for r1 in graph.relations:
    for r2 in graph.relations[r1]:
        rel = (r1, r2[1], r2[0])
        rel_ids[rel] = id_rel
        id_rel += 1

In [42]:
anchor_ids, var_ids, q_graphs = get_query_graph(first_formula, formqueries, rel_ids, mode_ids)

['project', 'publication']
[0]
[5]
tensor([[   0,    2,    4,  ..., 1532, 1534, 1536],
        [   1,    3,    5,  ..., 1533, 1535, 1537]])


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


In [35]:
# length = 769
anchor_ids

tensor([[ 296],
        [1059],
        [1401],
        [1338],
        [1145],
        [1661],
        [ 299],
        [1928],
        [1770],
        [1354],
        [1145],
        [ 233],
        [ 829],
        [ 975],
        [ 802],
        [ 290],
        [1530],
        [ 392],
        [ 928],
        [ 997],
        [ 479],
        [ 996],
        [2281],
        [ 534],
        [ 920],
        [ 634],
        [1086],
        [1134],
        [1495],
        [  94],
        [1295],
        [1184],
        [  28],
        [ 144],
        [1576],
        [1308],
        [1215],
        [1112],
        [ 587],
        [ 664],
        [ 134],
        [  95],
        [1017],
        [1579],
        [ 290],
        [1354],
        [ 838],
        [1724],
        [  42],
        [ 955],
        [ 412],
        [1770],
        [ 836],
        [ 154],
        [1209],
        [ 166],
        [1115],
        [ 907],
        [2008],
        [ 928],
        [ 851],
        [ 821],
        

In [36]:
var_ids

tensor([5])

In [27]:
q_graphs

DataBatch(edge_index=[2, 769], edge_type=[769], num_nodes=1538, batch=[1538], ptr=[770])

In [28]:
q_graphs.edge_index

tensor([[   0,    2,    4,  ..., 1532, 1534, 1536],
        [   1,    3,    5,  ..., 1533, 1535, 1537]])

In [38]:
q_graphs.edge_type

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

In [31]:
q_graphs.num_nodes

1538

In [32]:
q_graphs.batch

tensor([  0,   0,   1,  ..., 767, 768, 768])

In [33]:
q_graphs.ptr

tensor([   0,    2,    4,    6,    8,   10,   12,   14,   16,   18,   20,   22,
          24,   26,   28,   30,   32,   34,   36,   38,   40,   42,   44,   46,
          48,   50,   52,   54,   56,   58,   60,   62,   64,   66,   68,   70,
          72,   74,   76,   78,   80,   82,   84,   86,   88,   90,   92,   94,
          96,   98,  100,  102,  104,  106,  108,  110,  112,  114,  116,  118,
         120,  122,  124,  126,  128,  130,  132,  134,  136,  138,  140,  142,
         144,  146,  148,  150,  152,  154,  156,  158,  160,  162,  164,  166,
         168,  170,  172,  174,  176,  178,  180,  182,  184,  186,  188,  190,
         192,  194,  196,  198,  200,  202,  204,  206,  208,  210,  212,  214,
         216,  218,  220,  222,  224,  226,  228,  230,  232,  234,  236,  238,
         240,  242,  244,  246,  248,  250,  252,  254,  256,  258,  260,  262,
         264,  266,  268,  270,  272,  274,  276,  278,  280,  282,  284,  286,
         288,  290,  292,  294,  296,  2