In [1]:
#Imports

from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time
import tqdm

In [2]:
import data_processing as dp

In [3]:
# Load the raw training split and build the vocabulary, dataset and batches from it.
import json

TRAIN_JSON_PATH = 'json_datasets/train.json'

# Use a context manager so the file handle is closed as soon as parsing finishes;
# JSON is UTF-8 by spec, so don't depend on the platform's default encoding.
with open(TRAIN_JSON_PATH, 'r', encoding='utf-8') as f:
	raw_train = json.load(f)

vocab = dp.Vocabulary()
vocab.parseText(raw_train)

# NOTE(review): the cell output shows text2GraphDataset creating and parsing a second
# vocabulary internally — confirm it agrees with `vocab` passed to getBatches below.
dataset = dp.text2GraphDataset(raw_json_file = raw_train)
dataloader = dp.getBatches(vocab, dataset, batch_size = 8, shuffle = False)

Creating empty vocabulary object
Finished Parsing Text
Creating custom dataset for T2G task
Creating empty vocabulary object
Finished Parsing Text
Finished processing raw json file


In [85]:
# Sanity check: for each example in the first batch, print the token ids covered by each entity span.
for tokens, spans in zip(dataloader[0]['text'], dataloader[0]['entity_inds']):
	span_tokens = []
	for span in spans:
		span_tokens.append(tokens[span[0]:span[1]])
	print(span_tokens)

[tensor([1493, 1497]), tensor([1492, 1493, 1494, 1495, 1496])]
[tensor([1493, 1497]), tensor([1492, 1493, 1494, 1495, 1496])]
[tensor([1493, 1497]), tensor([1493])]
[tensor([1493, 1497]), tensor([1498])]
[tensor([1493, 1497]), tensor([1498])]
[tensor([1493, 1497]), tensor([1498])]
[tensor([1493, 1497]), tensor([1499])]
[tensor([1493, 1497]), tensor([1499])]


In [9]:
# The embedding table spans both the entity and the text vocabularies;
# the relation vocabulary sizes the decoder output.
inp_types = sum(len(part.wordlist) for part in (vocab.entities, vocab.text))
rel_types = len(vocab.relations.wordlist)

In [16]:
# Build the relation model.
# NOTE: ModelLSTM is defined in a later cell, so this cell only works after that
# definition has been executed (it fails under Restart & Run All in file order).
MODEL_DIM = 100  # embedding / BiLSTM width; must be even (two directions of MODEL_DIM // 2)
torch.manual_seed(0)  # make weight initialisation reproducible across runs
model = ModelLSTM(inp_types, rel_types, MODEL_DIM)

In [11]:
class ModelLSTM(nn.Module):
	"""BiLSTM relation classifier over entity spans.

	Token ids are embedded, encoded by a 2-layer bidirectional LSTM, and each entity
	span is mean-pooled into one vector. Every ordered pair of entities is then scored
	against ``relation_types`` relation labels.

	Args:
		input_types: size of the embedding vocabulary (number of distinct token ids).
		relation_types: number of relation labels produced by the decoder.
		model_dim: embedding width and BiLSTM output width. Must be even: the
			bidirectional LSTM concatenates two ``model_dim // 2`` directions, and the
			result is fed to ``LayerNorm(model_dim)``.
		dropout: dropout probability applied to the pairwise entity features.

	Raises:
		ValueError: if ``model_dim`` is odd.
	"""
	def __init__(self, input_types, relation_types, model_dim, dropout = 0.5):
		super().__init__()

		if model_dim % 2 != 0:
			# An odd width would otherwise only fail later, inside LayerNorm, with a
			# confusing shape error, since the BiLSTM emits 2 * (model_dim // 2) features.
			raise ValueError(f"model_dim must be even for the bidirectional LSTM, got {model_dim}")

		self.word_types = input_types
		self.relation_types = relation_types
		self.dropout = dropout
		self.model_dim = model_dim

		# Vocabulary size comes from the caller (input_types), not from a fixed tokenizer.
		self.emb = nn.Embedding(input_types, self.model_dim)
		# Two directions of model_dim // 2 each -> per-token outputs of width model_dim.
		self.lstm = nn.LSTM(self.model_dim, self.model_dim//2, batch_first=True, bidirectional=True, num_layers=2)
		self.relation_layer1 = nn.Linear(self.model_dim , self.model_dim)
		self.relation_layer2 = nn.Linear(self.model_dim , self.model_dim)
		self.drop = nn.Dropout(self.dropout)
		self.projection = nn.Linear(self.model_dim , self.model_dim)
		self.decoder = nn.Linear(self.model_dim , self.relation_types)
		self.layer_norm = nn.LayerNorm(self.model_dim)

		self.init_params()

	def init_params(self):
		"""Xavier-normal init for all Linear weights and zero init for their biases."""
		nn.init.xavier_normal_(self.relation_layer1.weight.data)
		nn.init.xavier_normal_(self.relation_layer2.weight.data)
		nn.init.xavier_normal_(self.projection.weight.data)
		nn.init.xavier_normal_(self.decoder.weight.data)

		nn.init.constant_(self.relation_layer1.bias.data , 0)
		nn.init.constant_(self.relation_layer2.bias.data , 0)
		nn.init.constant_(self.projection.bias.data , 0)
		nn.init.constant_(self.decoder.bias.data , 0)

	def forward(self, batch):
		"""Score every ordered pair of entities in the batch.

		Args:
			batch: dict with
				'text': token-id tensor fed to the embedding and the batch_first LSTM
					(presumably a (bs, seq_len) LongTensor — TODO confirm against dp.getBatches);
				'entity_inds': per-example list of spans; each span ``z`` selects
					``[z[0]:z[1]]`` along the sequence axis.

		Returns:
			Tensor of shape (bs, max_ents, max_ents, relation_types) holding
			log-probabilities over relations on the last axis. Pairs that involve a
			padded entity slot have their logits zeroed before the softmax, so their
			distribution comes out uniform.
		"""
		# nn.LSTM returns (output, (h_n, c_n)); only the per-token output is needed.
		sents, _ = self.lstm(self.emb(batch['text']))

		bs, _, hidden_dim = sents.shape
		max_ents = max(len(x) for x in batch['entity_inds'])

		# 1 for real entity slots, 0 for padding up to max_ents.
		cont_word_mask = sents.new_zeros(bs, max_ents)
		cont_word_embs = sents.new_zeros(bs, max_ents, hidden_dim)

		for b, (sent, entind) in enumerate(zip(sents, batch['entity_inds'])):
			for n_ent, wordemb in enumerate([sent[z[0]:z[1]] for z in entind]):
				# NOTE(review): an empty span (z[0] == z[1]) yields a NaN mean here,
				# which then propagates into the loss.
				cont_word_embs[b, n_ent] = torch.mean(wordemb, dim = 0)
				cont_word_mask[b, n_ent] = 1

		# bs x max_ents x model_dim
		cont_word_embs = self.layer_norm(cont_word_embs)

		rel1 = self.relation_layer1(cont_word_embs)
		rel2 = self.relation_layer2(cont_word_embs)

		# bs x max_ents x max_ents x model_dim, via broadcasting:
		# out[b, i, j] = relation_layer2(ent_i) + relation_layer1(ent_j)
		out = rel1.unsqueeze(1) + rel2.unsqueeze(2)

		out = F.relu(self.drop(out))
		out = F.relu(self.projection(out))
		out = self.decoder(out)

		# Zero the logits of any pair where either entity slot is padding.
		out = out * cont_word_mask.view(bs,max_ents,1,1) * cont_word_mask.view(bs,1,max_ents,1)

		# bs x max_ents x max_ents x relation_types
		return torch.log_softmax(out, -1)

In [25]:
# Train the relation model on the full training dataloader.
train_model(model, num_relations=rel_types, dataloader=dataloader)

 20%|█▉        | 322/1630 [00:09<00:36, 35.54it/s]


KeyboardInterrupt: 

In [24]:

def train_model(model, num_relations, dataloader, learning_rate = 1e-3, epochs = 30):
	"""Train ``model`` with Adam on the negative log-likelihood of relation labels.

	Args:
		model: module called as ``model(batch)`` that returns log-probabilities
			reshapeable to (-1, num_relations).
		num_relations: number of relation classes (size of the model's last axis).
		dataloader: indexable sequence of batches supporting ``len()``; each batch is
			passed to ``model`` and must hold a 'labels' tensor reshapeable to (-1,).
		learning_rate: Adam step size. The old default of 1e10 made every update
			diverge; 1e-3 matches ``torch.optim.Adam``'s own default.
		epochs: number of full passes over ``dataloader``.

	Returns:
		list[float]: mean per-batch training loss for each epoch.
	"""
	model.train()  # ensure dropout is active even if the model was left in eval mode
	optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
	criterion = nn.NLLLoss()

	num_batches = len(dataloader)
	epoch_losses = []
	for epoch in range(epochs):
		loss_this_epoch = 0.0
		for batch_idx in tqdm.tqdm(range(num_batches), desc=f"epoch {epoch + 1}/{epochs}"):
			batch = dataloader[batch_idx]
			log_probs = model(batch)
			labels = batch['labels']

			# Flatten every (example, entity pair) into one row for NLLLoss.
			loss = criterion(log_probs.reshape(-1, num_relations), labels.reshape(-1))
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()
			loss_this_epoch += loss.item()

		# Guard against an empty dataloader.
		mean_loss = loss_this_epoch / max(num_batches, 1)
		epoch_losses.append(mean_loss)
		print(f"epoch {epoch + 1}/{epochs}: mean loss {mean_loss:.4f}")

	return epoch_losses