In [15]:
#Imports

from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time
import tqdm
import json

In [16]:
# importing the module
import json
  
# Load the training split. The context manager closes the file handle as soon
# as parsing finishes (the bare open() call used before left it open for the
# lifetime of the kernel), and the explicit encoding keeps the result
# independent of the platform default.
with open('json_datasets/train.json', 'r', encoding='utf-8') as f:
	raw_train = json.load(f)

In [17]:
raw_train[0]

{'relations': [[['Aarhus', 'Airport'],
   'cityServed',
   ['``', 'Aarhus', ',', 'Denmark', "''"]]],
 'text': 'the <ENT_1> of <ENT_0> .',
 'entities': [['``', 'Aarhus', ',', 'Denmark', "''"], ['Aarhus', 'Airport']]}

In [28]:
class VocabCategory:
	"""Empty set of lookup containers for one kind of token.

	The containers are created empty here and are filled in by the code
	that owns the instance.
	"""

	def __init__(self):
		# Tokens in insertion order; a token's position is used as its id.
		self.wordlist = list()
		# Token -> integer id.
		self.word2idx = dict()
		# Token -> occurrence count.
		self.wordfreq = Counter()

class Vocabulary():
	"""Vocabularies for the text, entity and relation fields of the dataset.

	One VocabCategory is kept per field (``text``, ``entities``,
	``relations``). Every category contains the special tokens "<UNK>",
	"<SOS>", "<EOS>" and "<EMPTY>"; the relations category additionally
	starts with "<NO_RELATION>", which therefore receives id 0.
	"""

	def __init__(self):
		print("Creating empty vocabulary object")
		self.text = VocabCategory()
		self.entities = VocabCategory()
		self.relations = VocabCategory()

		# Registered before the shared special tokens so that it gets id 0
		# in the relations vocabulary.
		self._add_word(self.relations, "<NO_RELATION>")

		self.init_vocab(self.text)
		self.init_vocab(self.entities)
		self.init_vocab(self.relations)

	@staticmethod
	def _add_word(vocab_category, word):
		"""Give `word` the next free id in `vocab_category` unless it already has one."""
		if word not in vocab_category.word2idx:
			vocab_category.word2idx[word] = len(vocab_category.wordlist)
			vocab_category.wordlist.append(word)

	def init_vocab(self, vocab_category):
		"""Register the UNK, SOS, EOS and EMPTY tokens in `vocab_category`.

		Tokens that are already present keep their id, so calling this
		more than once on the same category has no further effect.
		"""
		for token in ("<UNK>", "<SOS>", "<EOS>", "<EMPTY>"):
			self._add_word(vocab_category, token)

	def parseSentence(self, raw_json_sentence):
		"""Add the tokens of one dataset record to the vocabularies.

		Args:
			raw_json_sentence: mapping with the keys 'relations' (iterable of
				3-element items whose element 1 is the relation label),
				'text' (whitespace-separated string) and 'entities'
				(iterable of token lists).

		Raises:
			AssertionError: if a relation does not have exactly 3 elements.
		"""
		for relation in raw_json_sentence['relations']: #Relation parsing here
			assert len(relation) == 3, "CHECK THIS!"
			self._add_word(self.relations, relation[1])
			self.relations.wordfreq[relation[1]] += 1

		words = raw_json_sentence['text'].split() #Word parsing here
		for word in words:
			self._add_word(self.text, word)
		# Counter.update only touches the keys being added; the previous
		# `+= Counter(...)` rescanned the whole counter on every sentence.
		self.text.wordfreq.update(words)

		for entity in raw_json_sentence['entities']:
			for e in entity:
				self._add_word(self.entities, e)
			self.entities.wordfreq.update(entity)

	def parseText(self, raw_json):
		"""Run parseSentence on every record in `raw_json`."""
		for raw_sentence in raw_json:
			self.parseSentence(raw_sentence)
		print("Finished Parsing Text")

In [19]:
# Build the vocabularies from every record of the training split.
vocab = Vocabulary()
vocab.parseText(raw_train)

Creating empty vocabulary object


In [23]:

def entity2Indices(vocab, entity):
	"""Encode the tokens of one entity with the entity vocabulary.

	Args:
		vocab: object exposing ``entities.word2idx`` (token -> id).
		entity: sequence of tokens.

	Returns:
		1-D ``torch.long`` tensor holding one id per token; tokens that are
		not in the vocabulary are encoded with the id of "<UNK>".
	"""
	lookup = vocab.entities.word2idx
	ids = [lookup[token] if token in lookup else lookup["<UNK>"] for token in entity]
	return torch.tensor(ids, dtype=torch.long)
		
def text2Indices(vocab, text):
	"""Encode a whitespace-separated sentence with the text vocabulary.

	Args:
		vocab: object exposing ``text.word2idx`` (token -> id).
		text: sentence string; it is tokenised with ``str.split()``.

	Returns:
		1-D ``torch.long`` tensor: the "<SOS>" id, one id per token
		("<UNK>" id for tokens missing from the vocabulary), then the
		"<EOS>" id.
	"""
	lookup = vocab.text.word2idx
	ids = [lookup["<SOS>"]]
	ids.extend(lookup[token] if token in lookup else lookup["<UNK>"] for token in text.split())
	ids.append(lookup["<EOS>"])
	return torch.tensor(ids, dtype=torch.long)

def relation2Indices(vocab, raw_json_sentence):
	"""Build the symmetric entity-by-entity relation matrix of one record.

	Args:
		vocab: object exposing ``relations.word2idx`` (label -> id).
		raw_json_sentence: mapping with 'entities' (list of token lists) and
			'relations' (items indexed as [head tokens, label, tail tokens]).

	Returns:
		``torch.long`` tensor of shape (n_entities, n_entities). Cell
		[i, j] and cell [j, i] hold the id of the relation between entity i
		and entity j; all other cells are 0. Labels missing from the
		vocabulary are encoded with the id of "<UNK>", matching how
		entity2Indices and text2Indices treat unknown tokens.

	Raises:
		KeyError: if a relation refers to an entity that is not listed in
			'entities'.
	"""
	entities = raw_json_sentence['entities']
	n_entities = len(entities)
	ret = torch.zeros((n_entities, n_entities), dtype = torch.long)

	# Key on the token tuple rather than on "".join(tokens): joining would
	# merge distinct entities such as ['ab', 'c'] and ['a', 'bc'] into one key.
	entity_position = {tuple(entity): i for i, entity in enumerate(entities)}

	relation_ids = vocab.relations.word2idx
	for relation in raw_json_sentence['relations']:
		ind1 = entity_position[tuple(relation[0])]
		ind2 = entity_position[tuple(relation[2])]
		label = relation[1]
		label_id = relation_ids[label] if label in relation_ids else relation_ids["<UNK>"]
		ret[ind1, ind2] = label_id
		ret[ind2, ind1] = label_id
	return ret

# Sanity check on one training record: show the raw record, then the relation
# matrix computed for it.
print(raw_train[54])
relation2Indices(vocab, raw_train[54])

		

{'relations': [[['Abilene', 'Regional', 'Airport'], 'runwayLength', ['2195.0']]], 'text': 'the runway length of <ENT_0> is <ENT_1> .', 'entities': [['Abilene', 'Regional', 'Airport'], ['2195.0']]}


tensor([[0, 9],
        [9, 0]])

In [59]:
def concatTextEntities(vocab, raw_json_sentence, entity_indices):
	"""Replace each entity placeholder in the encoded text by the entity's tokens.

	Args:
		vocab: vocabulary object passed on to text2Indices / entity2Indices;
			``len(vocab.text.wordlist)`` is added to every entity token id
			(presumably so entity ids do not overlap text ids).
		raw_json_sentence: mapping with 'text' and 'entities'.
		entity_indices: dict mapping the text-vocabulary id of a placeholder
			token to the position of the entity in 'entities'.

	Returns:
		Tuple ``(modified_input, entity_locations)``. ``modified_input`` is a
		1-D ``torch.long`` tensor with the encoded sentence in which every
		placeholder id has been replaced by the shifted ids of its entity.
		``entity_locations`` is a ``torch.long`` tensor of shape
		(n_placeholders, 2) holding the [start, end) span of each inserted
		entity inside ``modified_input``; it has shape (0, 2) when the text
		contains no placeholder.
	"""
	sent = text2Indices(vocab, raw_json_sentence['text'])
	entity_offset = len(vocab.text.wordlist)

	pieces = []
	entity_locations = []
	lbound = 0
	# Net growth of the output relative to `sent` so far: every placeholder
	# (one token) is replaced by len(entity) tokens.
	additional_words = 0
	for index, token_id in enumerate(sent.tolist()):
		if token_id not in entity_indices:
			continue
		entity_tokens = raw_json_sentence['entities'][entity_indices[token_id]]
		entity_ids = entity2Indices(vocab, entity_tokens) + entity_offset
		pieces.append(sent[lbound:index])
		pieces.append(entity_ids)
		start = index + additional_words
		entity_locations.append([start, start + len(entity_ids)])
		additional_words += len(entity_ids) - 1
		lbound = index + 1
	pieces.append(sent[lbound:])
	modified_input = torch.cat(pieces, dim = 0)

	# An explicit dtype and shape keep the result an integer (k, 2) tensor even
	# when k == 0; torch.tensor([]) alone would be a float tensor of shape (0,).
	locations = torch.tensor(entity_locations, dtype = torch.long).reshape(len(entity_locations), 2)
	return modified_input, locations

In [32]:
def getEntityIndices(vocab):
	"""Map the text-vocabulary ids of the entity placeholders to entity numbers.

	Placeholders are looked up as '<ENT_0>', '<ENT_1>', ... in
	``vocab.text.word2idx``; the scan stops at the first number whose
	placeholder is not in the vocabulary.

	Returns:
		dict mapping the vocabulary id of '<ENT_i>' to ``i``.
	"""
	lookup = vocab.text.word2idx
	mapping = {}
	number = 0
	placeholder = '<ENT_' + str(number) + '>'
	while placeholder in lookup:
		mapping[lookup[placeholder]] = number
		number += 1
		placeholder = '<ENT_' + str(number) + '>'
	return mapping


# Look up the vocabulary ids of the <ENT_i> placeholders, then print the
# tokens stored at those ids to confirm they are the placeholders.
entity_indices = getEntityIndices(vocab)
print(entity_indices)
print([vocab.text.wordlist[k] for k in entity_indices])

{7: 0, 5: 1, 725: 2, 963: 3, 1135: 4, 1347: 5, 1449: 6, 1475: 7}
['<ENT_0>', '<ENT_1>', '<ENT_2>', '<ENT_3>', '<ENT_4>', '<ENT_5>', '<ENT_6>', '<ENT_7>']


In [57]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class text2GraphDataset(Dataset):
	"""Text-to-graph dataset built from a list of raw json records.

	A Vocabulary is built from the records passed in, after which every
	record is encoded once up front:

	* ``inputs[i]`` is the result of concatTextEntities for record i;
	* ``labels[i]`` is the result of relation2Indices for record i.
	"""

	def __init__(self, raw_json_file):
		print("Creating custom dataset for T2G task")

		self.vocab = Vocabulary()
		self.vocab.parseText(raw_json_file)

		self.inputs, self.labels = [], []

		# Placeholder id -> entity number, shared by every record.
		self.entity_indices = getEntityIndices(self.vocab)

		for sentence in raw_json_file:
			label = relation2Indices(self.vocab, sentence)
			encoded = concatTextEntities(self.vocab, sentence, self.entity_indices)
			self.labels.append(label)
			self.inputs.append(encoded)

		print("Finished processing raw json file")

	def __len__(self):
		"""Number of encoded records."""
		return len(self.inputs)

	def __getitem__(self, idx):
		"""Return the ``(input, label)`` pair stored at position `idx`."""
		return (self.inputs[idx], self.labels[idx])

In [60]:
# Build the dataset (it constructs its own Vocabulary from raw_train) and wrap
# it in a shuffling loader that yields one record per batch.
dataset = text2GraphDataset(raw_json_file = raw_train)
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

Creating custom dataset for T2G task
Creating empty vocabulary object
Finished Parsing Text
Finished processing raw json file


In [61]:
# Pull a single batch to inspect it; `inputs` carries the two values returned
# by concatTextEntities (token ids and entity spans).
inputs, labels = next(iter(train_dataloader))
print(inputs)

[tensor([[   1,  311,   21, 3747,   37,    4, 2178, 1964, 2179,   83,   20,   16,
         2186, 2187, 1494, 1930,    8,    2]]), tensor([[[ 3,  4],
         [ 6,  9],
         [12, 16]]])]
