In [15]:
# import
from simplet5 import SimpleT5


# instantiate
g2t_model = SimpleT5()

# load (supports t5, mt5, byT5 models)
g2t_model.from_pretrained("t5","t5-base")

g2t_model.T5Model

AttributeError: 'SimpleT5' object has no attribute 'T5Model'

In [18]:
#g2t_model.model

In [19]:
#Imports

from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time
import tqdm




In [20]:
import data_processing as dp

In [7]:
# importing the module
import json
  
# Opening JSON file
f = open('json_datasets/train.json', 'r')

raw_train = json.load(f)

vocab = dp.Vocabulary()
vocab.parseText(raw_train)

dataset = dp.text2GraphDataset(raw_json_file = raw_train)
dataloader = dp.getBatches(vocab, dataset, batch_size = 8, shuffle = False)

Creating empty vocabulary object
Finished Parsing Text
Creating custom dataset for T2G task
Creating empty vocabulary object
Finished Parsing Text
Finished processing raw json file


In [8]:
inp_types = len(vocab.entities.wordlist) + len(vocab.text.wordlist)
rel_types = len(vocab.relations.wordlist)


In [9]:
class ModelLSTM(nn.Module):
	def __init__(self, input_types, relation_types, model_dim, dropout = 0.5):
		super().__init__()

		self.word_types = input_types
		self.relation_types = relation_types
		self.dropout = dropout
		self.model_dim = model_dim

		self.emb = nn.Embedding(input_types, self.model_dim) # 40000 because we use the Bert tokenizer
		self.lstm = nn.LSTM(self.model_dim, self.model_dim//2, batch_first=True, bidirectional=True, num_layers=2)
		self.relation_layer1 = nn.Linear(self.model_dim , self.model_dim)
		self.relation_layer2 = nn.Linear(self.model_dim , self.model_dim)
		self.drop = nn.Dropout(self.dropout)
		self.projection = nn.Linear(self.model_dim , self.model_dim)
		self.decoder = nn.Linear(self.model_dim , self.relation_types)
		self.layer_norm = nn.LayerNorm(self.model_dim)

		self.init_params()

	def init_params(self):
		nn.init.xavier_normal_(self.relation_layer1.weight.data)
		nn.init.xavier_normal_(self.relation_layer2.weight.data)
		nn.init.xavier_normal_(self.projection.weight.data)
		nn.init.xavier_normal_(self.decoder.weight.data)

		nn.init.constant_(self.relation_layer1.bias.data , 0)
		nn.init.constant_(self.relation_layer2.bias.data , 0)
		nn.init.constant_(self.projection.bias.data , 0)
		nn.init.constant_(self.decoder.bias.data , 0)

	def forward(self, batch):
		sents = batch['text']
		sents, (c_0, h_0) = self.lstm(self.emb(sents))

		bs, _, hidden_dim = sents.shape
		max_ents = max([len(x) for x in batch['entity_inds']])
		
		cont_word_mask = sents.new_zeros(bs, max_ents)
		cont_word_embs = sents.new_zeros(bs, max_ents, hidden_dim)

		for b, (sent,entind) in enumerate(zip(sents,batch['entity_inds'])):
			for n_ent, wordemb in enumerate([sent[z[0]:z[1]] for z in entind]):
				cont_word_embs[b, n_ent] = torch.mean(wordemb, dim = 0)
				cont_word_mask[b, n_ent] = 1

		# bs x max_ents x model_dim
		cont_word_embs = self.layer_norm(cont_word_embs)

		rel1 = self.relation_layer1(cont_word_embs)
		rel2 = self.relation_layer2(cont_word_embs)

		#bs x max_ents x max_ents x model_dim
		out = rel1.unsqueeze(1) + rel2.unsqueeze(2)

		out = F.relu(self.drop(out))
		out = F.relu(self.projection(out))
		out = self.decoder(out)

		out = out * cont_word_mask.view(bs,max_ents,1,1) * cont_word_mask.view(bs,1,max_ents,1)

		return torch.log_softmax(out, -1)

In [12]:
t2g_model = ModelLSTM(inp_types, rel_types, 100)

#print(dataloader[0])
t2g_model.forward(dataloader[0]).shape

torch.Size([8, 2, 2, 249])

In [22]:
# importing the module
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time
import tqdm
import json
import re
import pandas as pd
  
# Opening JSON file
f_train = open('json_datasets/train.json', 'r')
raw_train = json.load(f_train)
f_train.close()

f_test = open('json_datasets/test.json', 'r')
raw_test = json.load(f_test)
f_test.close()

In [23]:
def removeQuotes(lst):
    ret = []
    for s in lst:
        if s != '``' and s != "''":
            ret.append(s)
    return ret

def camelCaseSplit(identifier):
    matches = re.finditer('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', identifier)
    d = [m.group(0) for m in matches]
    new_d = []
    for token in d:
        token = token.replace('(', '')
        token = token.replace(')', '')
        token_split = token.split('_')
        for t in token_split:
            #new_d.append(t.lower())
            new_d.append(t)
    return new_d

In [24]:
def g2tPreprocess(raw):
    df = []
    for item in raw:
        graph = 'g2t:'
        for relation in item['relations']:
            graph += ' <H> ' + ' '.join(removeQuotes(relation[0])) + ' <R> '
            graph += ' '.join(camelCaseSplit(relation[1])) + ' <T> '
            graph += ' '.join(removeQuotes(relation[2]))

        ents = [' '.join(removeQuotes(entity)) for entity in item['entities']]
        text = item['text']
        for i in range(len(ents)):
            text = text.replace('<ENT_'+str(i)+'>', ents[i])
        sample = [graph, text]
        df.append(sample)
    return pd.DataFrame(df, columns=['source_text', 'target_text'])

def g2tPreprocessNoText(raw):
    df = []
    graphs = []
    entities = []
    raw_ents = []
    for item in raw:
        graph = 'g2t:'
        for relation in item['relations']:
            graph += ' <H> ' + ' '.join(removeQuotes(relation[0])) + ' <R> '
            graph += ' '.join(camelCaseSplit(relation[1])) + ' <T> '
            graph += ' '.join(removeQuotes(relation[2]))

        ents = [' '.join(removeQuotes(entity)) for entity in item['entities']]
        graphs.append(graph)
        entities.append(ents)
        raw_ents.append(item['entities'])
    return graphs, entities, raw_ents

In [25]:
train_df = g2tPreprocess(raw_train)
test_df = g2tPreprocess(raw_test)
print(test_df)

source_text  \
0     g2t: <H> Abilene Regional Airport <R> city Ser...   
1     g2t: <H> Abilene Regional Airport <R> city Ser...   
2     g2t: <H> Adolfo Suárez Madrid–Barajas Airport ...   
3     g2t: <H> Adolfo Suárez Madrid–Barajas Airport ...   
4     g2t: <H> Adolfo Suárez Madrid–Barajas Airport ...   
...                                                 ...   
4923  g2t: <H> Twilight ( band ) <R> genre <T> Black...   
4924  g2t: <H> Twilight ( band ) <R> genre <T> Black...   
4925  g2t: <H> Uruguay <R> leader Name <T> Raúl Fern...   
4926  g2t: <H> Uruguay <R> leader Name <T> Raúl Fern...   
4927  g2t: <H> Uruguay <R> leader Name <T> Raúl Fern...   

                                            target_text  
0     Abilene , Texas is served by the Abilene Regio...  
1     Abilene Regional Airport serves the city of Ab...  
2     Adolfo Suárez Madrid–Barajas Airport can be fo...  
3     Adolfo Suárez Madrid–Barajas Airport is locate...  
4     Adolfo Suárez Madrid–Barajas Airport is

In [26]:
def single_g2t(graph, ents, g2t_model):
    predText = g2t_model.predict(graph)
    for i in range(len(ents)):
        if ents[i] in text:
            predText.replace(ents[i], "<ENT_" + str(i) + ">")
        else:
            print("WARNING: ENTITY " + ents[i] + " NOT FOUND IN PREDICTED TEXT")
    return {'text' : predText, 'entities' : raw_ents}

# input: batch of graphs (list of dicts with relations and entities)
# output: predicted texts with original entities taken out (list of dicts with text and entities)
def predict_g2t(graphs, g2t_model):
    pGraphs, ents, raw_ents = g2tPreprocessNoText(graphs) # processed graphs, entities
    print(pGraphs)
    print(ents)
    hyps = [single_g2t(graphs[i], ents[i], raw_ents[i], g2t_model) for i in range(len(graphs))]
    # ret = bleu.compute_score(dev_df['target_text'], hyp)
    #print(hyp[:10])
    return hyps


In [27]:
g2t_sample_out = []
for raw in raw_train[0:8]:
    g2t = {'text' : raw['text'], 'entities' : raw['entities']}
    g2t_sample_out.append(g2t)
g2t_sample_out # just text and entities


print(g2t_sample_out[1]['text'])
dp.text2Indices(vocab, g2t_sample_out[1]['text'])

<ENT_1> serves the city of <ENT_0> .


tensor([ 2,  5,  9,  4, 10,  6,  7,  8,  3])

In [28]:
def t_cycle(text_batch, g2t_model, t2g_model, g2t_opt): # optimizes g2t
    t2g_model.eval()
    g2t_model.model.train()
    with torch.no_grad():
        pred_graphs = t2g_model.forward(text_batch)
    # syn_batch???
    g2t_opt.zero_grad()
    pred_text = predict_g2t(pred_graphs, g2t_model)
    # convert pred_text to tensor of word indices
    loss = F.nll_loss(pred_text.reshape(-1, pred_text.shape[-1]), text_batch.reshape(-1), ignore_index=0) # could be wrong, again
    loss.backward()
    #nn.utils.clip_grad_norm_(g2t_model.parameters(), config['clip'])
    g2t_opt.step()
    return loss.item()
    

In [34]:
def g_cycle(graph_batch, g2t_model, t2g_model, t2g_opt): # optimizes t2g
    g2t_model.model.eval()
    t2g_model.train()
    with torch.no_grad():
        pred_text = predict_g2t(graph_batch, g2t_model)
    # convert pred_text to correct format to input into t2g
    # syn_batch???
    t2g_opt.zero_grad()
    pred_graphs = t2g_model.forward(pred_text)
    loss = F.nll_loss(pred_graphs.contiguous().view(-1, pred_graphs.shape[-1]), graph_batch.contiguous().view(-1), ignore_index=0) # could be wrong, again
    loss.backward()
    #nn.utils.clip_grad_norm_(g2t_model.parameters(), config['clip'])
    t2g_opt.step()
    return loss.item()
    

In [35]:
def back_translation(text_batch, graph_batch, g2t_model, t2g_model, g2t_opt, t2g_opt):
    loss1 = g_cycle(graph_batch, g2t_model, t2g_model, t2g_opt)
    loss2 = t_cycle(text_batch, g2t_model, t2g_model, g2t_opt)
    return loss1, loss2

In [39]:
def train(epochs):
    t2g_opt = torch.optim.Adam(t2g_model.parameters())
    g2t_opt = torch.optim.Adam(g2t_model.model.parameters())

    for i in range(epochs):
        with tqdm.tqdm(dataloader) as tqb:
            for j, x in enumerate(tqb):
                # need pairings of text/graph batches (unparallel)
                text_batch = x
                graph_batch = train_df[(j*8):(j*8+7)] # needs to be changed
                print(graph_batch)
                back_translation(text_batch, graph_batch, g2t_model, t2g_model, g2t_opt, t2g_opt)

train(1)

0%|          | 0/1630 [00:00<?, ?it/s]                                         source_text  \
0  g2t: <H> Aarhus Airport <R> city Served <T> Aa...   
1  g2t: <H> Aarhus Airport <R> city Served <T> Aa...   
2  g2t: <H> Aarhus Airport <R> city Served <T> Aa...   
3  g2t: <H> Aarhus Airport <R> elevation Above Th...   
4  g2t: <H> Aarhus Airport <R> elevation Above Th...   
5  g2t: <H> Aarhus Airport <R> elevation Above Th...   
6  g2t: <H> Aarhus Airport <R> location <T> Tirstrup   

                                         target_text  
0           the Aarhus Airport of Aarhus , Denmark .  
1  Aarhus Airport serves the city of Aarhus , Den...  
2         Aarhus Airport serves the city of Aarhus .  
3    Aarhus Airport is 25.0 metres above sea level .  
4  Aarhus Airport is at an elevation of 25.0 metr...  
5  Aarhus Airport is 25.0 metres above the sea le...  
6            Aarhus Airport is located in Tirstrup .  



TypeError: string indices must be integers