In [1]:
import torch
import numpy as np
from collections import defaultdict
from tt import get_geometric_data
from torch.nn import GRU
from torch_geometric.data import Data

In [2]:
HIDDEN_SIZE = 256

In [3]:
class TFG:
    def __init__(self, data):
        self.data = data
        self.feature_size = self.data.x.shape[1]
        self.word_bigru = GRU(self.feature_size, HIDDEN_SIZE, bidirectional=True)
        self.relationship_bigru = GRU(self.feature_size, HIDDEN_SIZE, bidirectional=True)
        
    def train(self):
        sv_map = defaultdict(list)
        vo_map = defaultdict(list)
        
        # assume retrieve relationship data
        edge_index = self.data.edge_index.numpy()
        node_attr = self.data.y.numpy()
        new_edge_index = []
        for i, edge_type in enumerate(self.data.edge_attr.numpy()):
            if edge_type == 1:
                if node_attr[edge_index[i][0]] == 1:
                    vo_map[edge_index[i][0]].append(edge_index[i][1])
                else:
                    sv_map[edge_index[i][1]].append(edge_index[i][0])
            else:
                new_edge_index.append(edge_index[i])
        
        
        print(vo_map)
        # convert relationship data to be Tensor
        relationship_vector = []
        x = self.data.x.numpy()
        for key, subjects in sv_map.items():
            for subj in subjects:
                if not vo_map[key]:
                    relationship_vector.append([x[subj], x[key], np.zeros(self.feature_size)])
                    new_edge_index.append([subj, key])
                for obj in vo_map[key]:
                    relationship_vector.append([x[subj], x[key], x[obj]])
                    new_edge_index.append([subj, obj])
                    
        for key, objects in vo_map.items():
            for obj in objects:
                if not sv_map[key]:
                    relationship_vector.append([np.zeros(self.feature_size), x[key], x[obj]])
                    new_edge_index.append([key, obj])             
            
        print(new_edge_index)
        print("relationship:",relationship_vector)
        relationships = torch.Tensor(relationship_vector)
        
        # run two kinds of bigru
        out1, hidden1 = self.word_bigru(self.data.x.unsqueeze(0))
        out2, hidden2 = self.relationship_bigru(relationships) if relationships.shape[0] > 0 else (torch.tensor([]), torch.tensor([]))
        
        word_level_vector = (torch.sum(hidden1, 0) / 2)[1:]
        path_level_vector = torch.sum(hidden2, 0) / 2
        combined_vector = torch.cat((word_level_vector, path_level_vector), 0)
        
        data = Data(x=self.data.x, y=self.data.y, edge_index=torch.tensor(new_edge_index), edge_attr=combined_vector)
        return data

In [4]:
data = get_geometric_data()
print(data)
tfg_model = TFG(data)
tfg_model.train()

Data(edge_attr=[16], edge_index=[16, 2], x=[12, 3], y=[12, 1])
defaultdict(<class 'list'>, {2: [3], 5: [7], 8: [11]})
[array([0, 1]), array([1, 2]), array([2, 3]), array([3, 4]), array([4, 5]), array([5, 6]), array([6, 7]), array([7, 8]), array([8, 9]), array([ 9, 10]), array([10, 11]), [1, 3], [1, 7], [8, 11]]
relationship: [[array([0., 0., 0.], dtype=float32), array([0., 0., 0.], dtype=float32), array([0., 0., 0.], dtype=float32)], [array([0., 0., 0.], dtype=float32), array([0., 0., 0.], dtype=float32), array([0., 0., 0.], dtype=float32)], [array([0., 0., 0.]), array([0., 0., 0.], dtype=float32), array([0., 0., 0.], dtype=float32)]]


Data(edge_attr=[14, 256], edge_index=[14, 2], x=[12, 3], y=[12, 1])