In [2]:
#Imports

from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time
import tqdm
import json

In [3]:
# importing the module
import json
  
# Opening JSON file
f = open('json_datasets/train.json', 'r')

raw_train = json.load(f)

In [28]:
class VocabCategory():
	def __init__(self):
		self.wordlist = []
		self.word2idx = {}
		self.wordfreq = Counter()

class Vocabulary():
	def __init__(self):
		print("Creating empty vocabulary object")
		self.text = VocabCategory()
		self.entities = VocabCategory()
		self.relations = VocabCategory()
	def parseSentence(self, raw_json_sentence):
		for relation in raw_json_sentence['relations']: #Relation parsing here
			assert len(relation) == 3, "CHECK THIS!"
			if relation[1] not in self.relations.word2idx:
				self.relations.word2idx[relation[1]] = len(self.relations.wordlist)+1
				self.relations.wordlist.append(relation[1])
			self.relations.wordfreq.update({relation[1]: 1})
		
		for word in raw_json_sentence['text'].split(): #Word parsing here
			if word not in self.text.word2idx:
				self.text.word2idx[word] = len(self.text.wordlist)
				self.text.wordlist.append(word)
		self.text.wordfreq += Counter(raw_json_sentence['text'].split())

		for entity in raw_json_sentence['entities']:
			for e in entity:
				if e not in self.entities.word2idx:
					self.entities.word2idx[e] = len(self.entities.wordlist)
					self.entities.wordlist.append(e)
			self.entities.wordfreq += Counter(entity)
	
	def parseText(self, raw_json):
		for raw_sentence in raw_json:
			self.parseSentence(raw_sentence)

In [29]:
vocab = Vocabulary()
vocab.parseText(raw_train)

Creating empty vocabulary object


In [18]:
raw_train[54]

{'relations': [[['Abilene', 'Regional', 'Airport'],
   'runwayLength',
   ['2195.0']]],
 'text': 'the runway length of <ENT_0> is <ENT_1> .',
 'entities': [['Abilene', 'Regional', 'Airport'], ['2195.0']]}

In [51]:
def text2Relation(vocab, raw_json_sentence):
	l = len(raw_json_sentence['entities'])
	ret = torch.zeros((l,l), dtype = torch.long)
	entitydict = {}
	for i, entity in enumerate(raw_json_sentence['entities']):
		entitydict["".join(entity)] = i
	print(entitydict)
	for relation in raw_json_sentence['relations']:
		print(relation)
		ind1 = entitydict["".join(relation[0])]
		ind2 = entitydict["".join(relation[2])]
		ret[ind1][ind2] = ret[ind2][ind1] = vocab.relations.word2idx[relation[1]]
	return ret

text2Relation(vocab, raw_train[54])


{'AbileneRegionalAirport': 0, '2195.0': 1}
[['Abilene', 'Regional', 'Airport'], 'runwayLength', ['2195.0']]


tensor([[0, 5],
        [5, 0]])

In [89]:
def entity2Indices(vocab, entity):
	temp = torch.zeros(len(entity), dtype = torch.long)
	for ind, word in enumerate(entity):
		if word not in vocab.entities.word2idx:
			temp[ind] = -1
		else:
			temp[ind] = vocab.entities.word2idx[word]
	return temp
		
def text2Indices(vocab, text):
	temp = torch.zeros(len(text.split()), dtype=torch.long)
	for ind, word in enumerate(text.split()):
		if word not in vocab.text.word2idx:
			temp[ind] = -1
		else:
			temp[ind] = vocab.text.word2idx[word]
	return temp
		

In [54]:
raw_train[27]

{'relations': [[['Abilene', ',', 'Texas'],
   'isPartOf',
   ['Jones', 'County', ',', 'Texas']]],
 'text': '<ENT_0> is part of <ENT_1> .',
 'entities': [['Abilene', ',', 'Texas'], ['Jones', 'County', ',', 'Texas']]}

1484

In [117]:
def concatTextEntities(vocab, raw_json_sentence, entity_indices):
	sent = text2Indices(vocab, raw_json_sentence['text'])
	modified_input = torch.LongTensor([0])
	lbound = 0
	entity_locations = []
	additional_words = 0
	for index, value in enumerate(sent):
		if value.item() in entity_indices:
			temp = entity2Indices(vocab, raw_json_sentence['entities'][entity_indices[value.item()]])
			temp += len(vocab.text.wordlist)
			modified_input = torch.cat((modified_input, sent[lbound:index], temp), dim = 0)
			entity_locations.append((index + additional_words, index + additional_words + len(temp)))
			additional_words += len(temp) - 1
			lbound = index + 1
	modified_input = torch.cat((modified_input, sent[lbound:]), dim = 0)[1:]

	return modified_input, entity_locations

In [72]:
entity_indices = {}
i = 0
while True:
	if '<ENT_' + str(i) + '>' in vocab.text.word2idx:
		entity_indices[(vocab.text.word2idx['<ENT_' + str(i) + '>'])] = i
		i += 1
	else:
		break

print(entity_indices)
print([vocab.text.wordlist[k] for k in entity_indices])

{3: 0, 1: 1, 721: 2, 959: 3, 1131: 4, 1343: 5, 1445: 6, 1471: 7}
['<ENT_0>', '<ENT_1>', '<ENT_2>', '<ENT_3>', '<ENT_4>', '<ENT_5>', '<ENT_6>', '<ENT_7>']
