In [1]:
# import
from simplet5 import SimpleT5

#Imports

from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time
import tqdm

import data_processing as dp
import json

import t2g
import g2t

Global seed set to 42


In [2]:
class CycleModel():
	"""Couples a text-to-graph and a graph-to-text model for cycle training.

	Each direction owns a separate Adam optimizer. ``t_cycle`` updates only
	the graph-to-text model and ``g_cycle`` updates only the text-to-graph
	model; ``back_translation`` runs one step of each.
	"""

	def __init__(self, vocab):
		"""Build both models and their optimizers from one vocabulary.

		Args:
			vocab: object handed to both model constructors. Its ``raw_data``
				attribute is read later by ``train``.
		"""
		# NOTE(review): 768 is presumably a hidden/embedding size -- confirm
		# against the t2g.T2GModel signature.
		self.t2g_model = t2g.T2GModel(vocab, 768)
		self.g2t_model = g2t.G2TModel(vocab)
		# Optimizers start at Adam's default learning rate; ``train`` overrides
		# it with the rate it is given.
		self.t2g_opt = torch.optim.Adam(self.t2g_model.model.parameters())
		self.g2t_opt = torch.optim.Adam(self.g2t_model.t5_model.model.parameters())
		self.vocab = vocab

	def t_cycle(self, text_batch): # optimizes g2t
		"""Run one text -> graph -> text step and update the g2t model.

		Args:
			text_batch: batch of reference text; it is fed to the t2g model and
				also used as the loss target.
				NOTE(review): the loss call needs this to be a tensor of class
				indices -- confirm the dataloader actually yields that.

		Returns:
			The step's loss as a Python float.
		"""
		self.t2g_model.eval()
		self.g2t_model.train()
		# The t2g model only produces synthetic graphs here; no gradients are
		# tracked through it.
		with torch.no_grad():
			pred_graphs = self.t2g_model.predict(text_batch)
		# TODO: the original author left an open question here ("syn_batch???").
		self.g2t_opt.zero_grad()
		pred_text = self.g2t_model.predict(pred_graphs)
		# TODO: convert pred_text to a tensor of word indices.
		# NOTE(review): nll_loss expects log-probabilities, and ignore_index=0
		# presumably skips a padding index -- the author flagged this line as
		# possibly wrong; verify shapes and semantics.
		loss = F.nll_loss(pred_text.reshape(-1, pred_text.shape[-1]), text_batch.reshape(-1), ignore_index=0)
		loss.backward()
		# TODO: gradient clipping was sketched here but is disabled; it relied
		# on a `config['clip']` value that is not defined in this notebook.
		self.g2t_opt.step()
		return loss.item()

	def g_cycle(self, graph_batch): # optimizes t2g
		"""Run one graph -> text -> graph step and update the t2g model.

		Args:
			graph_batch: batch of reference graphs; it is fed to the g2t model
				and also used as the loss target.
				NOTE(review): the loss call needs this to be a tensor of class
				indices -- confirm the dataloader actually yields that.

		Returns:
			The step's loss as a Python float.
		"""
		self.g2t_model.eval()
		self.t2g_model.train()
		# The g2t model only produces synthetic text here; no gradients are
		# tracked through it.
		with torch.no_grad():
			pred_text = self.g2t_model.predict(graph_batch)
		# TODO: convert pred_text to the input format t2g expects.
		self.t2g_opt.zero_grad()
		pred_graphs = self.t2g_model.predict(pred_text)
		# NOTE(review): same caveats as in t_cycle -- the author flagged this
		# loss computation as possibly wrong.
		loss = F.nll_loss(pred_graphs.contiguous().view(-1, pred_graphs.shape[-1]), graph_batch.contiguous().view(-1), ignore_index=0)
		loss.backward()
		# TODO: gradient clipping was sketched here but is disabled; it relied
		# on a `config['clip']` value that is not defined in this notebook.
		self.t2g_opt.step()
		return loss.item()

	def back_translation(self, text_batch, graph_batch):
		"""Run one g_cycle step followed by one t_cycle step.

		Returns:
			Tuple ``(g_loss, t_loss)`` in that order.
		"""
		g_loss = self.g_cycle(graph_batch)
		t_loss = self.t_cycle(text_batch)
		return g_loss, t_loss

	@staticmethod
	def _set_learning_rate(optimizer, learning_rate):
		"""Write ``learning_rate`` into every parameter group of ``optimizer``."""
		for group in optimizer.param_groups:
			group['lr'] = learning_rate

	def train(self, epochs, batch_size, learning_rate, shuffle):
		"""Train both directions for ``epochs`` passes over the cycle data.

		Args:
			epochs: number of passes over the data.
			batch_size: forwarded to ``dp.create_cycle_dataloader``.
			learning_rate: applied to both optimizers before training starts.
				Pass ``None`` to keep whatever rate they currently have.
			shuffle: forwarded to ``dp.create_cycle_dataloader``.

		Returns:
			A list with one entry per epoch; each entry is a list of
			``(g_loss, t_loss)`` tuples, one per batch pair.
		"""
		# Previously this argument was accepted but never used, so training
		# always ran at the optimizers' default rate.
		if learning_rate is not None:
			self._set_learning_rate(self.t2g_opt, learning_rate)
			self._set_learning_rate(self.g2t_opt, learning_rate)

		history = []
		for epoch in range(epochs):
			# Loaders are rebuilt every epoch, as in the original loop.
			tcycle_dataloader, gcycle_dataloader = dp.create_cycle_dataloader(raw_json_file=self.vocab.raw_data, batch_size=batch_size, shuffle=shuffle)
			dataloader = list(zip(tcycle_dataloader, gcycle_dataloader))
			epoch_losses = []
			# Iterating the list directly lets tqdm show a total.
			for tbatch, gbatch in tqdm.tqdm(dataloader, desc=f"epoch {epoch + 1}/{epochs}"):
				epoch_losses.append(self.back_translation(tbatch, gbatch))
			history.append(epoch_losses)
		return history

In [17]:
# Display element [1] of the graph-to-text preprocessing result for one graph batch.
# NOTE(review): `cycle_model` and `gbatch` are created in the cell that loads
# json_datasets/train.json, which sits further down; run that cell first.
cycle_model.g2t_model.g2t_preprocess(gbatch)[1]

[['Germans of Romania', '1 Decembrie 1918 University', 'Romania'],
 ['Germans of Romania', '1 Decembrie 1918 University', 'Romania'],
 ['1 Decembrie 1918 University', 'Klaus Iohannis', 'Romania'],
 ['1 Decembrie 1918 University', 'Andrew the Apostle', 'Romania'],
 ['1 Decembrie 1918 University', 'Andrew the Apostle', 'Romania'],
 ['Denmark',
  '737',
  'School of Business and Social Sciences at the Aarhus University'],
 ['Denmark',
  '737',
  'School of Business and Social Sciences at the Aarhus University'],
 ['1928',
  '737',
  'School of Business and Social Sciences at the Aarhus University']]

In [4]:
# NOTE(review): the recorded output of this cell is a TypeError
# ("TextEncodeInput must be Union[TextInputSequence, ...]"). The raw graph batch
# is presumably not something the tokenizer can encode directly -- confirm what
# input type t5_model.predict accepts.
cycle_model.g2t_model.t5_model.predict(gbatch)

TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]

In [10]:
# NOTE(review): feeding element [0] of the preprocessing result still raised the
# same TypeError in the recorded output. predict presumably wants a single
# string rather than a batch -- verify against the t5_model API.
cycle_model.g2t_model.t5_model.predict(cycle_model.g2t_model.g2t_preprocess(gbatch)[0])

TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]

In [3]:
# Load the raw training data. The `with` block closes the file handle as soon
# as parsing finishes; previously the handle was left open.
with open('json_datasets/train.json', 'r') as f:
	raw_train = json.load(f)

# Build the vocabulary from the raw training records.
vocab = dp.Vocabulary()
vocab.parseText(raw_train)

# Create the cycle model (both directions share the same vocabulary).
cycle_model = CycleModel(vocab)

# Build the text-cycle and graph-cycle loaders in a fixed order (shuffle=False).
tcycle_dataloader, gcycle_dataloader = dp.create_cycle_dataloader(vocab.raw_data, batch_size = 8, shuffle=False)

# Keep the first batch of each loader around for interactive inspection.
tbatch = tcycle_dataloader[0]
gbatch = gcycle_dataloader[0]

Creating empty vocabulary object
Finished Parsing Text


In [5]:
# Pair text-cycle batches with graph-cycle batches, then print only the first
# text batch as a quick sanity check.
dataloader = [*zip(tcycle_dataloader, gcycle_dataloader)]
for index, batch_pair in tqdm.tqdm(enumerate(dataloader)):
	tbatch, gbatch = batch_pair
	print(tbatch)
	break

0it [00:00, ?it/s]

[{'relations': [[['Aarhus', 'Airport'], 'cityServed', ['``', 'Aarhus', ',', 'Denmark', "''"]]], 'text': 'the <ENT_1> of <ENT_0> .', 'entities': [['``', 'Aarhus', ',', 'Denmark', "''"], ['Aarhus', 'Airport']]}
 {'relations': [[['Aarhus', 'Airport'], 'cityServed', ['``', 'Aarhus', ',', 'Denmark', "''"]]], 'text': '<ENT_1> serves the city of <ENT_0> .', 'entities': [['``', 'Aarhus', ',', 'Denmark', "''"], ['Aarhus', 'Airport']]}
 {'relations': [[['Aarhus', 'Airport'], 'cityServed', ['Aarhus']]], 'text': '<ENT_1> serves the city of <ENT_0> .', 'entities': [['Aarhus'], ['Aarhus', 'Airport']]}
 {'relations': [[['Aarhus', 'Airport'], 'elevationAboveTheSeaLevel_(in_metres)', ['25.0']]], 'text': '<ENT_0> is <ENT_1> metres above sea level .', 'entities': [['Aarhus', 'Airport'], ['25.0']]}
 {'relations': [[['Aarhus', 'Airport'], 'elevationAboveTheSeaLevel_(in_metres)', ['25.0']]], 'text': '<ENT_0> is at an elevation of <ENT_1> metres above seal level .', 'entities': [['Aarhus', 'Airport'], ['25.0




In [12]:
# Display the paired (text batch, graph batch) list.
# NOTE(review): the recorded output is an empty list, although another cell
# printed a batch from a list built the same way. This presumably reflects
# out-of-order execution -- re-check after Restart & Run All.
dataloader

[]