In [512]:
# def get_nested_relation_locs(rootNode, entity_pos_dict):
#     temp = []
#     for e in rootNode.arguments:
#         if e.isEntity():
#             temp.append(entity_pos_dict[e.entity])
#         else:
#             temp.append(get_locs(e,entity_pos_dict))
#     return temp

# def get_relation_locs_from_sentence(sentence):
#     entities = [ent for ent in sentence.entities if 'RELATIONSHIP' not in ent.type.name]
#     entity_pos_dict = {e: i for i,e in enumerate(entities)}
#     relation_locs = []
#     for f in sentence.formulas:
#         if f.rootNode.isPredicate() and not f.rootNode.isEntity():
#             relation_locs.append(get_nested_relation_locs(f.rootNode, entity_pos_dict))
#         else:
#             raise ValueError
#     return relation_locs

In [420]:
import sys
from os import path
import json
import pickle
import numpy as np
import hiddenlayer as hl
import torch
from torch import nn
from torch.nn import functional as functional
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torchtext.vocab import Vocab
from sklearn.model_selection import train_test_split
from itertools import combinations
from collections import Counter, namedtuple
from tqdm import tqdm
sys.path.append('../py')
sys.path.append("../lib/BioInfer_software_1.0.1_Python3/")

In [421]:
from config import (
    ENTITY_PREFIX,
    PREDICATE_PREFIX,
    EPOCHS,
    WORD_EMBEDDING_DIM,
    VECTOR_DIM,
    HIDDEN_DIM,
    RELATION_EMBEDDING_DIM,
    BATCH_SIZE,
)

In [422]:
BATCH_SIZE = 4

In [423]:
from classes import BioInferTaskConfiguration

In [424]:
def get_sentences(percent_test):
    with open("../data/text_sentences.txt") as f:
        sentences = f.read().splitlines()

    entities = json.load(open("../data/entity_labels.json", "r"))

    with open('../data/relation_labels.json','r') as f:
        relations = json.load(f)

    relations = [tuple((p, tuple(args)) for p, args in relation) for relation in relations]

    zipped_data = list(zip(sentences, entities, relations))

    # deletes sentences with no relations
    for i in range(len(zipped_data) - 1, -1, -1):
        if zipped_data[i][2] == []:
            del zipped_data[i]

    train_data, test_data = train_test_split(
        zipped_data, test_size=percent_test, random_state=0
    )

    return train_data, test_data

In [425]:
def sent_to_idxs(sentence,vocab_dict):
    token_list = sentence.split()

    index_list = []
    for token in token_list:
        if token in vocab_dict:
            index_list.append(vocab_dict[token])
        else:
            index_list.append(vocab_dict["UNK"])
    return th.LongTensor(index_list)

In [958]:
from torch.utils.data import Dataset, DataLoader

In [959]:
from torchvision import transforms, utils

In [960]:
from BIParser import BIParser

In [1158]:
class BioInferDataset(Dataset):
    def __init__(self,xml_file,entity_prefix=ENTITY_PREFIX,predicate_prefix=PREDICATE_PREFIX):
        self.entity_prefix = entity_prefix
        self.predicate_prefix = predicate_prefix
        self.parser = BIParser()
        with open(xml_file,'r') as f:
            self.parser.parse(f)
            
        self.vocab_dict = self.create_vocab_dictionary(self.parser)
        entities = self.get_entities(self.parser)
        predicates = self.get_predicates(self.parser)
        elements = entities + predicates
        self.element_to_idx = {elements[i]: i for i in range(len(elements))}
        self.schema = self.get_schema(self.parser,self.element_to_idx)
        self.inverse_schema = self.invert_schema(self.schema)
        
    def __len__(self):
        return len(self.parser.bioinfer.sentences.sentences)
    
    def __getitem__(self,idx):
        sentence = self.parser.bioinfer.sentences.sentences[idx]
        entities, entity_locs = self.get_entities_from_sentence(sentence)
#         relations = self.get_relations_from_sentence(sentence)
#         relation_locs = self.get_relation_locs_from_sentence(sentence)
        entity_names, entity_locs = self.entities_to_tensors(entities, entity_locs)
        graphs, nkis, node_idx_to_element_idxs = self.get_relation_graphs_from_sentence(sentence,entity_locs)
        
        sample = {'text': sentence.getText(),
                'tokens': self.sent_to_idxs(sentence.getText(),self.vocab_dict),
                'element_names': entity_names,
                'element_locs': entity_locs,
#                 'relations': relations,
#                 'relation_locs': relation_locs,
                'relation_graphs': graphs,
                'node_idx_to_element_idxs': node_idx_to_element_idxs,
#                 'nkis': nkis 
                 }
        return sample
    
    def create_vocab_dictionary(self,parser):
        vocab = {u"UNK"}

        for s in parser.bioinfer.sentences.sentences:
            for token in s.tokens:
                vocab.add(token.getText())

        vocab_size = len(vocab)
        vocab_index_list = [index for index in range(vocab_size)]

        vocab_dict = dict(zip(vocab, vocab_index_list))
        return vocab_dict
    
    def get_entities(self, parser):
        entities = set()
        for s in parser.bioinfer.sentences.sentences:
            for e in s.entities:
                entity_type = e.type.name
                if 'RELATIONSHIP' not in entity_type:
                    entities.add(f"{self.entity_prefix}{entity_type}")
                    
        return list(entities)
    
    def get_predicates(self, parser):
        predicates = set()
        for s in parser.bioinfer.sentences.sentences:
            for f in s.formulas:
                for argument in f.rootNode.arguments:
                    if argument.isPredicate():
                        predicates.add(f"{self.predicate_prefix}{argument.predicate.name}")
                predicate_type = f.rootNode.predicate.name
                predicates.add(f"{self.predicate_prefix}{predicate_type}")
        return list(predicates)

    def get_entities_from_sentence(self,sentence):
        entity_locs = {}
        entities = []
        i = 0
        for e in sentence.entities:
            entity_type = e.type.name
            if 'RELATIONSHIP' not in entity_type:
                entity = (f"{ENTITY_PREFIX}{entity_type}",tuple([st.token.sequence for st in e.subTokens]))
                entities.append(entity)
                entity_locs[e.id] = i
                i += 1
        return entities, entity_locs
    
#     def get_nested_relation_locs(self,rootNode, entity_pos_dict):
#         temp = []
#         for e in rootNode.arguments:
#             if e.isEntity():
#                 temp.append(entity_pos_dict[e.entity])
#             else:
#                 temp.append(self.get_nested_relation_locs(e,entity_pos_dict))
#         return temp

#     def get_relation_locs_from_sentence(self,sentence):
#         entities = [ent for ent in sentence.entities if 'RELATIONSHIP' not in ent.type.name]
#         entity_pos_dict = {e: i for i,e in enumerate(entities)}
#         relation_locs = []
#         for f in sentence.formulas:
#             if f.rootNode.isPredicate() and not f.rootNode.isEntity():
#                 relation_locs.append(self.get_nested_relation_locs(f.rootNode, entity_pos_dict))
#             else:
#                 raise ValueError
#         return relation_locs
    
    def get_relation_graphs_from_sentence(self,sentence,element_locs):
        graphs = []
        nkis = []
        node_idx_to_element_idxs = []
        for f in s.formulas:
            g, nki = pairs_to_graph(self.construct_graph_pairs(f.rootNode)) 
            nkis.append(nki)
            g.ndata['element_indices'] = torch.tensor([element_locs[nki[n.item()]] if nki[n.item()] in element_locs.keys() else -2 for n in g.nodes()])
            graphs.append(g)
            node_idx_to_element_idxs.append({k:v for k,v in zip(g.nodes().tolist(),g.ndata['element_indices'].tolist())})
        return graphs, nkis, node_idx_to_element_idxs
    
#     def get_relations_from_sentence(self,sentence):
#         sentence_relation_labels = []
#         for f in sentence.formulas:
#             if f.rootNode.isPredicate() and not f.rootNode.isEntity():
#                 predicate_name = f.rootNode.predicate.name
#                 arguments = self.get_relnode_argument_types(f.rootNode)
#                 sentence_relation_labels.append((f"{PREDICATE_PREFIX}{predicate_name}",arguments))
#             else:
#                 raise ValueError
#         return sentence_relation_labels
    
    def entities_to_tensors(self,entities, entity_locs):
        entity_names = th.tensor([th.tensor(self.element_to_idx[e[0]]) for e in entities])
        entity_names = entity_names.reshape(-1,1)
        
#         entity_locs = [th.tensor(e[1]) for e in entities]
#         entity_locs = th.stack([functional.pad(e,pad=(0,5-len(e)),mode='constant',value=-1) for e in entity_locs])
        
        return entity_names, entity_locs
                
    def get_relnode_argument_types(self, relnode) -> tuple:
        arguments = set()
        for a in relnode.arguments:
            if a.isEntity():
                arguments.add(f"{self.entity_prefix}{a.entity.type.name}")
            elif a.isPredicate():
                arguments.add(f"{self.predicate_prefix}{a.predicate.name}")
            else:
                raise ValueError

        return tuple(sorted(list(arguments)))
    
    def sent_to_idxs(self,sentence,vocab_dict):
        token_list = sentence.split()

        index_list = []
        for token in token_list:
            if token in vocab_dict:
                index_list.append(vocab_dict[token])
            else:
                index_list.append(vocab_dict["UNK"])
        return th.LongTensor(index_list)
    

    def get_schema(self,parser,element_to_idx):
        schema = {}
        for s in parser.bioinfer.sentences.sentences:
            for f in s.formulas:
                if f.rootNode.isPredicate() and not f.rootNode.isEntity():
                    predicate_name = f.rootNode.predicate.name
                    key = f"{self.predicate_prefix}{predicate_name}"
                    num_key = self.element_to_idx[key]
                    if num_key not in schema.keys():
                        schema[num_key] = Counter()
                    arguments = self.get_relnode_argument_types(f.rootNode)
                    try:
                        arguments = tuple(sorted([self.element_to_idx[arg] for arg in arguments]))
                    except:
                        print(f.rootNode)
                    schema[num_key][arguments] += 1
                else:
                    raise ValueError("formula rootNode should not be Entity")
        return schema
    
    def invert_schema(self, schema):
        inverted_schema = {}

        for rel, argsets in schema.items():
            for argset in argsets:
                if argset not in inverted_schema.keys():
                    inverted_schema[argset] = Counter()
                inverted_schema[argset][rel] += 1

        return inverted_schema
    
    def construct_graph_pairs(self,node):
        pairs = []
#         element_names = []
        if node.isPredicate():
            node_type = f"{self.predicate_prefix}{node.predicate.name}"
        else:
            node_type = f"{self.entity_prefix}{node.entity.type.name}"
#         element_names.append(node_type)
        for arg in node.arguments:

            pairs.append([node.entity.id,arg.entity.id,self.element_to_idx[node_type]])
            pairs += construct_graph_pairs(arg)
        return pairs
    
    
def collate_fn(data):
    print(data)

In [1214]:
def pairs_to_graph(pairs):
    g = np.array(pairs)
    node_keys = {v:i for i, v in enumerate(np.unique(g[:,:2].flatten()))}
    u = g[:,0]
    v = g[:,1]
    u = np.vectorize(node_keys.get)(u)
    v = np.vectorize(node_keys.get)(v)
    g[:,0] = u
    g[:,1] = v
    element_names = [-1 for _ in np.unique(g[:,:2])]
    for row in g:
        element_names[int(row[0])] = int(row[2])
    g = dgl.graph((u,v))
    node_keys_inverse = {v:k for k,v in node_keys.items()}
    g.ndata['element_names'] = th.tensor(element_names)
    return g, node_keys_inverse

In [1215]:
dataset = BioInferDataset('../data/BioInfer_corpus_1.1.1.xml')

In [1216]:
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [1217]:
a = 1 #1st sentence

For each sentence, we sort all relations and candidates according to their topological orders and put entities and relations into an element list. If an element has operands, its operands must positioned before it in the list. Then each element has an index in list from $0 \text{ to } n_a^e +n_a^e − 1$

In [1218]:
# import networkx as nx
# import matplotlib.pyplot as plt

# def plot_tree(g):
#     # this plot requires pygraphviz package
#     pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
#     nx.draw(g, pos, with_labels=False, node_size=10,
#             node_color=[[.5, .5, .5]], arrowsize=4)
#     plt.show()

In [1219]:
# plot_tree(g.to_networkx())

children

In [1220]:
# [element_locs[nki[n.item()]] if nki[n.item()] in element_locs.keys() else -1 for n in g.nodes()]

In [1221]:
MAX_LAYERS = 2 # TODO change to 3

In [1222]:
sample = dataset[1]

In [1223]:
g = np.array([['a','b','c'],['a','b','c']])
for x in g:
    print(x)

['a' 'b' 'c']
['a' 'b' 'c']


In [1224]:
sample

{'text': 'A binary complex of birch profilin and skeletal muscle actin could be isolated by gel chromatography .',
 'tokens': tensor([1594, 1461, 3997, 2373,  122, 2644, 3398, 2624, 2187, 4806, 1863, 2996,
         4017, 1943, 2805, 3553, 4874]),
 'element_names': tensor([[12],
         [12],
         [12],
         [52]]),
 'element_locs': {'e.4.1': 0, 'e.4.2': 1, 'e.4.3': 2, 'e.4.4': 3},
 'relation_graphs': [Graph(num_nodes=3, num_edges=2,
        ndata_schemes={'element_names': Scheme(shape=(), dtype=torch.int64), 'element_indices': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={}),
  Graph(num_nodes=3, num_edges=2,
        ndata_schemes={'element_names': Scheme(shape=(), dtype=torch.int64), 'element_indices': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={})],
 'node_idx_to_element_idxs': [{0: -2, 1: 0, 2: 3}, {0: -2, 1: 2, 2: 3}]}

In [1225]:
element_names

array([ 12,  12,  12,  52, 119, 106, 124,  72, 117,  85, 126,  86, 119,
       106, 124,  72, 117,  85, 126,  86, 119, 106, 124,  72, 117,  85,
       126,  86, 119, 106, 124,  72, 117,  85, 126,  86,  71,  81,  71,
        81,  76, 125,  75,  71, 117,  76,  71,  81,  71,  81,  76, 125,
        75,  71, 117,  76,  71,  81,  71,  81,  76, 125,  75,  71, 117,
        76, 119, 106, 124,  72, 117,  85, 126,  86,  71,  81,  71,  81,
        76, 125,  75,  71, 117,  76,  71,  81,  71,  81,  76, 125,  75,
        71, 117,  76,  71,  81,  71,  81,  76, 125,  75,  71, 117,  76,
       119, 106, 124,  72, 117,  85, 126,  86,  71,  81,  71,  81,  76,
       125,  75,  71, 117,  76,  71,  81,  71,  81,  76, 125,  75,  71,
       117,  76,  71,  81,  71,  81,  76, 125,  75,  71, 117,  76,  84,
        84,  84,  76,  76,  76,  76, 125,  71,  76,  76, 125,  71,  76,
        76, 125,  71,  76,  76, 125,  76,  76, 125,  76,  76,  76,  71,
        71,  76,  76,  76, 125,  71,  76,  76, 125,  71,  76,  7

In [1226]:
def get_child_indices(g, node_idx):
    return torch.stack(g.out_edges(node_idx))[1].tolist()

In [1240]:
element_names = sample['element_names'].numpy()
j = len(element_names)
element_indices = torch.arange(j)

S_temp = [nn.functional.pad(e,pad=(0,2-len(e)),mode='constant',value=-1) for e in list(element_indices.chunk(j))]
T_temp = element_indices.tolist()

A_temp = [a for _ in element_indices]
labels_temp = [0.0 for _ in element_indices]

max_layers = MAX_LAYERS

for _ in range(max_layers):
    ttt = torch.tensor(T_temp)
    for c in torch.combinations(ttt):
        e_names = torch.tensor(element_names)[c]
        key = tuple(sorted(e.item() for e in e_names))
        if key in dataset.inverse_schema.keys():
            for predicate in dataset.inverse_schema[key].keys():
                S_temp.append(c)
                T_temp.append(j)
                A_temp.append(a)
                element_names = np.append(element_names,predicate)
                L = 0.0 # default label is false
                for i, g in enumerate(sample['relation_graphs']):
                    for n in g.nodes():
                        child_idx = get_child_indices(g, node_idx=n)
                        child_idx = th.tensor([sample['node_idx_to_element_idxs'][i][idx] for idx in child_idx])
                        if child_idx.shape == c.shape:
                            # TODO ordering
                            if child_idx.tolist() == c.tolist() and element_names[j] == g.ndata['element_names'][n]: # check if children match and the predicate type is correct 
                                print(c)
                                sample['node_idx_to_element_idxs'][i][n.item()] = j 
                                L = 1.0 # this label is true because we found this candidate in the gold standard relation graphs
                labels_temp.append(L)
                j += 1

labels = th.tensor(labels_temp)            

tensor([0, 3])
tensor([2, 3])
tensor([0, 3])
tensor([2, 3])


In [1233]:
element_names[labels == 1.0]

array([119, 119, 119, 119])

In [1237]:
dataset.element_to_idx['p-CONTAIN']

119

In [1234]:
sample

{'text': 'A binary complex of birch profilin and skeletal muscle actin could be isolated by gel chromatography .',
 'tokens': tensor([1594, 1461, 3997, 2373,  122, 2644, 3398, 2624, 2187, 4806, 1863, 2996,
         4017, 1943, 2805, 3553, 4874]),
 'element_names': tensor([[12],
         [12],
         [12],
         [52]]),
 'element_locs': {'e.4.1': 0, 'e.4.2': 1, 'e.4.3': 2, 'e.4.4': 3},
 'relation_graphs': [Graph(num_nodes=3, num_edges=2,
        ndata_schemes={'element_names': Scheme(shape=(), dtype=torch.int64), 'element_indices': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={}),
  Graph(num_nodes=3, num_edges=2,
        ndata_schemes={'element_names': Scheme(shape=(), dtype=torch.int64), 'element_indices': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={})],
 'node_idx_to_element_idxs': [{0: 28, 1: 0, 2: 3}, {0: 104, 1: 2, 2: 3}]}

In [1235]:
L_temp

[0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0

In [934]:
len(A_temp)

190

In [933]:
sum(L_temp)

32.0

In [932]:
len(L_temp)

190

In [928]:
sample

{'text': 'A binary complex of birch profilin and skeletal muscle actin could be isolated by gel chromatography .',
 'tokens': tensor([1594, 1461, 3997, 2373,  122, 2644, 3398, 2624, 2187, 4806, 1863, 2996,
         4017, 1943, 2805, 3553, 4874]),
 'element_names': tensor([[12],
         [12],
         [12],
         [52]]),
 'element_locs': {'e.4.1': 0, 'e.4.2': 1, 'e.4.3': 2, 'e.4.4': 3},
 'relation_graphs': [Graph(num_nodes=3, num_edges=2,
        ndata_schemes={'element_indices': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={}),
  Graph(num_nodes=3, num_edges=2,
        ndata_schemes={'element_indices': Scheme(shape=(), dtype=torch.int64)}
        edata_schemes={})],
 'node_idx_to_element_idxs': [{0: 35, 1: 0, 2: 3}, {0: 111, 1: 2, 2: 3}]}

In [929]:
s = dataset.parser.bioinfer.sentences.sentences[1]

In [930]:
s.formulas[0].rootNode.arguments

[<BasicClasses.EntityNode at 0x7fec6d3b5820>,
 <BasicClasses.EntityNode at 0x7fec6d5255b0>]

In [888]:
sum(L_temp)

40.0

tensor([0, 1, 2, 3, 4, 5])

In [692]:
for n in g.nodes():
    print(g.out_edges(n))

(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
(tensor([1, 1, 1, 1, 1, 1]), tensor([0, 2, 3, 4, 5, 6]))
(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))


In [668]:
L_temp

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

In [329]:
# def update_rlocs(rlocs,arguments_term,relation_term):
#     found = False
#     for i, args in enumerate(rlocs):
#         if arguments_term == args:
#             rlocs[i] = relation_term
#             found = True
#     return found, rlocs

In [319]:
for r in sample['relation_locs']:
    print(r)

[0, 3]
[[0, [1, 5]], 3]
[[0, [1, 2]], 3]


In [505]:
# for i, batch in enumerate(dataloader):
#     print(batch)
#     print('\n\n\n\n')

In [642]:
g.out_edges(1)

(tensor([1, 1, 1, 1, 1, 1]), tensor([0, 2, 3, 4, 5, 6]))

In [47]:
S_temp

[tensor([ 0, -1]),
 tensor([ 1, -1]),
 tensor([ 2, -1]),
 tensor([ 3, -1]),
 tensor([ 4, -1]),
 tensor([ 5, -1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 3]),
 tensor([0, 3]),
 tensor([0, 3]),
 tensor([0, 3]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 3]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([2, 3]),
 tensor([2, 3]),
 tensor([2, 3]),
 tensor([2, 3]),
 tensor([3, 4]),
 tensor([3, 4]),
 tensor([3, 4]),
 tensor([3, 4]),
 tensor([3, 5]),
 tensor([3, 5]),
 tensor([3, 5]),
 tensor([3, 5]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 t

In [44]:
T_temp

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,


In [50]:
S_temp

[tensor([ 0, -1]),
 tensor([ 1, -1]),
 tensor([ 2, -1]),
 tensor([ 3, -1]),
 tensor([ 4, -1]),
 tensor([ 5, -1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 3]),
 tensor([0, 3]),
 tensor([0, 3]),
 tensor([0, 3]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 2]),
 tensor([1, 3]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 4]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([1, 5]),
 tensor([2, 3]),
 tensor([2, 3]),
 tensor([2, 3]),
 tensor([2, 3]),
 tensor([3, 4]),
 tensor([3, 4]),
 tensor([3, 4]),
 tensor([3, 4]),
 tensor([3, 5]),
 tensor([3, 5]),
 tensor([3, 5]),
 tensor([3, 5]),
 tensor([0, 1]),
 tensor([0, 1]),
 tensor([0, 1]),
 t

In [49]:
sample

{'sentence': tensor([3472, 2147, 3195, 1088, 1943, 3620,  795, 2373, 1449, 4127,  103, 3997,
         4874]),
 'element_names': tensor([[12],
         [52],
         [12],
         [47],
         [12],
         [12]]),
 'element_locs': tensor([[ 0, -1, -1, -1, -1],
         [ 9,  9,  9, 10, 10],
         [ 9, 10, -1, -1, -1],
         [ 2,  3, -1, -1, -1],
         [ 2, -1, -1, -1, -1],
         [ 9, -1, -1, -1, -1]]),
 'relations': [('p-SUPPRESS', ('e-Function_property', 'e-Individual_protein')),
  ('p-SUPPRESS', ('e-Function_property', 'p-PREVENT')),
  ('p-SUPPRESS', ('e-Function_property', 'p-PREVENT'))]}

In [9]:
train_data, test_data = get_sentences(0.2)

In [10]:
vocab_dict = eval(open("../data/vocab_dict.txt", "r").read())
config = BioInferTaskConfiguration().from_json("../data/configuration.json")
element_to_idx = config.element_to_idx

In [None]:
class INNModel(nn.Module):
    def __init__(
        self,
        vocab_dict,
        word_embedding_dim,
        element_embedding_dim,
        hidden_dim,
        schema,
        inverted_schema,
        element_to_idx,
        max_layer_height,
    ):
        super().__init__()
        self.vocab_dict = vocab_dict

In [None]:
optimizer = th.optim.Adadelta(model.parameters(), lr=1.0)
criterion = nn.NLLLoss()