In [None]:
import os; os.environ['CUDA_VISIBLE_DEVICES'] = '0';
import re
import numpy as np
import pandas as pd
import itertools
import codecs
import json
from collections import Counter
import xml.etree.ElementTree as ET

In [None]:
# Collect the predicate (middle field) of every triple in the train and dev splits.
# Layout implied by the paths below: <split>/<sub-directory>/<corpus file>.
predicates = []
for data_dir in ['ru/train', 'ru/dev']:
    for subdir_name in os.listdir(data_dir):
        subdir = os.path.join(data_dir, subdir_name)
        # Hidden entries (e.g. editor/OS metadata) and plain files are not corpus
        # folders; os.listdir() on them would fail or yield non-corpus files.
        if subdir_name.startswith('.') or not os.path.isdir(subdir):
            continue
        for file_name in os.listdir(subdir):
            if file_name.startswith('.'):
                continue
            tree = ET.parse(os.path.join(subdir, file_name))

            for entry in tree.iterfind('./entries/entry'):
                for mtriple in entry.findall('./modifiedtripleset/mtriple'):
                    # Each triple is stored as "subject | predicate | object".
                    parts = mtriple.text.split('|')
                    assert len(parts)==3

                    predicates.append(parts[1].strip())
print (len(predicates))

In [None]:
# Distinct predicates seen so far (at this point: train + dev only).
with_translation = set()
with_translation.update(predicates)

In [None]:
# Add the test-set predicates to `predicates` and print every occurrence of a
# predicate that was not seen in train/dev (i.e. is missing from `with_translation`).
root = ET.parse(
    'ru/rdf-to-text-generation-test-data-without-refs-ru.xml'
)

for e in root.iterfind('./entries/entry'):
    for mtriple in e.findall('./modifiedtripleset/mtriple'):
        # Each triple is stored as "subject | predicate | object".
        parts = mtriple.text.split('|')
        assert len(parts)==3

        predicate = parts[1].strip()
        predicates.append(predicate)
        if predicate not in with_translation:
            print (predicate)

In [None]:
# Occurrence count per predicate; the number of keys is the number of distinct predicates.
predicate_counts = Counter(predicates)
d = dict(predicate_counts)
print (len(predicate_counts))

In [None]:
# Replace the mapping by a sorted list of its keys (sorted() iterates the dict's keys).
# Re-running this cell is harmless: sorting an already sorted list returns the same list.
d = sorted(d)

In [None]:
# Read the translated predicate names: one per line, whitespace-trimmed and lower-cased.
# The encoding is given explicitly so the result does not depend on the platform default.
with open('ru/ru_predicates.txt', encoding='utf-8') as f:
    lines = [line.strip().lower() for line in f]

In [None]:
# predicates to translate version
# `d` (sorted distinct predicates) and `lines` (translation file) are paired purely by
# position. zip() would silently drop predicates if the file were too short, so check
# that every predicate has a line and that any surplus lines are only blank padding.
assert len(lines) >= len(d), (
    'translation file has {} lines but there are {} predicates'.format(len(lines), len(d))
)
assert not any(lines[len(d):]), (
    'translation file has non-empty lines beyond the {} predicates'.format(len(d))
)
predicate2translate = dict(zip(d, lines))
print (len(predicate2translate))

In [None]:
# import joblib

In [None]:
# joblib.dump(predicate2translate, 'all_predicates.pkl')

In [None]:
# predicate2translate = joblib.load('all_predicates.pkl')

In [None]:
def extract_triplets(e):
    """Return the triples of an entry element as a list of 3-item string lists.

    Every ``./modifiedtripleset/mtriple`` child is split on ``|``; the three
    fields are whitespace-stripped. An AssertionError is raised when a triple
    does not have exactly three fields.
    """
    triples = []
    for node in e.findall('./modifiedtripleset/mtriple'):
        raw_fields = node.text.split('|')
        assert len(raw_fields)==3
        triples.append([field.strip() for field in raw_fields])
    return triples

def extract_translation_dict(e):
    """Build an English -> Russian name mapping from the link elements of an entry.

    Looks at ``./dbpedialinks/dbpedialink`` and then ``./links/link`` children.
    Only items with ``direction="en2ru"`` whose middle field is ``sameAs`` are
    used; underscores in both names are replaced by spaces. Later items
    overwrite earlier ones for the same English name.
    """
    mapping = {}
    candidates = e.findall('./dbpedialinks/dbpedialink') + e.findall('./links/link')
    for link in candidates:
        if link.get('direction')!='en2ru':
            continue
        fields = link.text.split('|')
        assert len(fields)==3
        if fields[1].strip()!='sameAs':
            continue
        english = fields[0].strip().replace('_', ' ')
        russian = fields[-1].strip().replace('_', ' ')
        mapping[english] = russian
    return mapping

In [None]:
# Build the training records: (translated triples joined by newlines,
# longest Russian reference text, entry id, all Russian reference texts).
recs = []
for data_dir in ['ru/train']:
    # sorted() makes the record order independent of the OS directory listing order.
    # The loop variables are deliberately not called `d`/`f`, so the notebook-level
    # `d` (sorted predicate list) is not overwritten by a path string.
    for subdir_name in sorted(os.listdir(data_dir)):
        subdir = os.path.join(data_dir, subdir_name)
        if subdir_name.startswith('.') or not os.path.isdir(subdir):
            continue
        for file_name in sorted(os.listdir(subdir)):
            if file_name.startswith('.'):
                continue
            root = ET.parse(os.path.join(subdir, file_name))

            for e in root.iterfind('./entries/entry'):
                category, eid, size = e.get('category'), e.get('eid'), e.get('size')
                idx = '_'.join([category, eid, size])

                triple_list = extract_triplets(e)
                en2ru = extract_translation_dict(e)

                # translate triples: entity names via the entry's own links (falling back
                # to the original name), predicates via the global translation table
                out_triples = []
                for triple in triple_list:
                    subject, obj = triple[0].replace('_', ' '), triple[-1].replace('_', ' ')

                    subject, obj = en2ru.get(subject, subject), en2ru.get(obj, obj)
                    predicate = predicate2translate[triple[1]]
                    out_triples.append( subject +' | '+ predicate +' | '+ obj )

                # extract the Russian reference texts
                lexs = [
                    item.text for item in e.findall('./lex')
                    if item.get('lang')=='ru' and item.text
                ]
                if not lexs:
                    # No usable reference text: nothing to train on for this entry
                    # (picking the longest of an empty list would raise).
                    continue
                # max() returns the first of the longest texts, like argmax over lengths.
                recs.append( ('\n'.join(out_triples), max(lexs, key=len), idx, lexs) )
print (len(recs))

In [None]:
# # add additional RDFs from Chinese translation
# for fname in ['ru/ch2ru_dev_data.json', 'ru/ch2ru_train_data.json']:
#     with codecs.open(fname, encoding='utf-8') as f:
#         for line in f:
#             data = json.loads(line)

#             out_triples = '\n'.join([e.strip() for e in data['ru_spo']])
#             lex = data['ru_text'].strip()

#             recs.append( (out_triples, lex, 0) )
# print (len(recs))

In [None]:
# Peek at the most recently appended record.
recs[-1]

In [None]:
# One row per record: 'phrase' = translated triples, 'question' = chosen reference text,
# 'id' = entry id, 'refs' = all reference texts of the entry.
df = pd.DataFrame(
    data=recs,
    columns=['phrase', 'question', 'id', 'refs'],
)

In [None]:
# Give every row the reference text of a *different* row as its distractor.
# A fixed seed keeps the pairing reproducible. Walking a random ordering of the rows
# as a cycle (each row takes the text of the next row in that ordering) guarantees
# that no row is paired with itself when there is more than one row; a plain random
# permutation can leave rows in place.
_rng = np.random.RandomState(0)
_order = _rng.permutation(len(df))
_distractors = np.empty(len(df), dtype=object)
_distractors[_order] = df.question.values[np.roll(_order, -1)]
df['distractor'] = _distractors
print (df.shape)

In [None]:
# Show the translated triple block of the row labelled 0.
df.phrase[0]

In [None]:
# Display five randomly chosen rows (no random_state is passed, so they differ between runs).
df.sample(5)

In [None]:
# The entry id is built as '<category>_<eid>_<size>'; the part before the first
# underscore is used as the topic.
df['topic'] = df.id.str.split('_').str[0]
df.topic.value_counts()

In [None]:
from data_utils.tokenization import RubertaBPETokenizer
# Tokenizer backed by the BPE vocabulary file stored in the model directory.
# Only 'model_path' is passed on to the tokenizer; 'vocab_size' is kept in the dict
# but not used in this cell.
model_dir = './ru-gpt2-medium-rdf-2-text/'
tokenizer_args = dict(
    model_path=os.path.join(model_dir, 'vocab_50000.bpe'),
    vocab_size=50048,
)
tokenizer = RubertaBPETokenizer(model_path=tokenizer_args['model_path'])

In [None]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import AdamW, GPT2PreTrainedModel, GPT2Model
from transformers.modeling_utils import SequenceSummary
from transformers import GPT2Config

class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    """GPT-2 transformer with a language-modelling head and a multiple-choice head.

    Both heads read the same hidden states produced by the shared transformer.
    """

    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the linear layer that maps hidden states to vocabulary logits."""
        return self.lm_head

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
    ):
        """Run the transformer and both heads.

        Returns a tuple laid out as
        ``(lm loss), (mc loss), lm logits, mc logits, <remaining transformer outputs>``;
        each loss is present only when the corresponding labels are passed.
        """
        backbone_out = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden = backbone_out[0]

        lm_logits = self.lm_head(hidden)
        mc_logits = self.multiple_choice_head(hidden, mc_token_ids).squeeze(-1)

        result = (lm_logits, mc_logits) + backbone_out[1:]

        if mc_labels is not None:
            mc_criterion = nn.CrossEntropyLoss()
            mc_loss = mc_criterion(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
            result = (mc_loss,) + result

        if lm_labels is not None:
            # Position t is scored against the label at position t+1.
            logits_but_last = lm_logits[..., :-1, :].contiguous()
            labels_but_first = lm_labels[..., 1:].contiguous()
            lm_criterion = nn.CrossEntropyLoss()
            lm_loss = lm_criterion(
                logits_but_last.view(-1, logits_but_last.size(-1)),
                labels_but_first.view(-1),
            )
            # Prepended last, so the LM loss ends up in front of the MC loss.
            result = (lm_loss,) + result

        return result

In [None]:
ph_len = 20                        # max tokens kept per triple field
max_tok4ph, max_tok4q = 254, 254   # caps for the RDF part / the text part (see notes below)
pad_len = 512                      # fixed length every sequence is padded to


def _rdf_prefix_chunks(rdf):
    """Tokenize an RDF block into a list of token-id chunks ending with ``[50005]``.

    ``rdf`` holds one triple per line with fields separated by ``' | '``. The
    result starts with ``[0]``, followed by one chunk (at most ``ph_len`` ids)
    per field.

    NOTE(review): ``max_tok4ph`` limits the number of *chunks*, not the number
    of tokens. This is kept as is because ``generate_q_set`` builds its prompt
    the same way.
    """
    chunks = [[0]]
    for triplet in rdf.split('\n'):
        for part in triplet.split(' | '):
            chunks.append(tokenizer.EncodeAsIds(part).tokenization[:ph_len])
    chunks = chunks[:max_tok4ph]
    chunks.append([50005])
    return chunks


def _build_padded_example(prefix_chunks, text):
    """Append ``text`` (at most ``max_tok4q`` ids, then id 2) to the prefix and pad with 1s.

    Returns ``(padded ids, unpadded length, token type ids)``, or ``None`` when
    the unpadded sequence is longer than ``pad_len``.
    """
    text_chunk = tokenizer.EncodeAsIds(text).tokenization[:max_tok4q]+[2]
    seq = list(itertools.chain.from_iterable(prefix_chunks + [text_chunk]))
    length = len(seq)
    if length > pad_len:
        return None
    seq = seq + (pad_len-length)*[1]

    # Type ids: 50010 before the first 50005, 50005 from there through the first 2, 1 elsewhere.
    start, end = seq.index(50005), seq.index(2)+1
    token_type = len(seq)*[1]
    token_type[0:start] = [50010]*start
    token_type[start:end] = [50005]*(end-start)
    return seq, length, token_type


def encode_pair(rdf, distractor, question):
    """Encode one (wrong text, correct text) pair for the double-heads model.

    Both candidates share the same RDF prefix, which is tokenized once
    (assumes the tokenizer is deterministic - TODO confirm). LM labels are set
    only for the correct candidate, on the positions after the first 50005
    through the first 2; every other label is -100.

    Returns ``([input_ids], [mc_token_ids], [lm_labels], [mc_labels], [token_type_ids])``
    with the wrong candidate at index 0 and the correct one at index 1
    (``mc_labels`` is 1), or ``None`` if either sequence exceeds ``pad_len``.
    """
    prefix = _rdf_prefix_chunks(rdf)

    wrong = _build_padded_example(prefix, distractor)
    if wrong is None:
        return None
    seq_wrong, mc_token_wrong, token_type_wrong = wrong
    lm_l_wrong = len(seq_wrong)*[-100]

    right = _build_padded_example(prefix, question)
    if right is None:
        return None
    seq, mc_token, token_type = right
    lm_l = len(seq)*[-100]
    start, end = seq.index(50005)+1, seq.index(2)+1
    lm_l[start:end] = seq[start:end]

    input_ids = [seq_wrong, seq]
    # unpadded length - 2 is the position just before the closing id 2
    mc_token_ids = [mc_token_wrong-2, mc_token-2]
    mc_labels = 1
    lm_labels = [lm_l_wrong, lm_l]
    token_type_ids = [token_type_wrong, token_type]

    tup = ([input_ids], [mc_token_ids], [lm_labels], [mc_labels], [token_type_ids])
    return tup

In [None]:
# Encode every (triples, distractor, reference) row. encode_pair() returns None for
# pairs that do not fit into the padded length; those are skipped so that
# `all_datasets` only holds tuples that can be stacked into tensors.
all_datasets = []
n_skipped = 0
for rdf, distractor, q in zip(df.phrase.values, df.distractor.values, df.question.values):
    tup = encode_pair(rdf, distractor, q)
    if tup is None:
        n_skipped += 1
        continue
    all_datasets.append(tup)
print (len(all_datasets))
if n_skipped:
    print ('skipped (too long):', n_skipped)

In [None]:
# batch_size = 2
# epochs = 3
# lr = 3e-5
# max_grad_norm = 1.0

# train_index, test_index = list(range(len(all_datasets))), list(range(4))

# train_index, test_index = set(train_index), set(test_index)
# config = GPT2Config.from_pretrained('gpt2-medium')
# config.vocab_size = 50048
# config.output_hidden_states = True

# model = GPT2DoubleHeadsModel(config)
# # ch2ru triplet model
# model = model.from_pretrained( 'ru_gpt2', output_hidden_states=True )
# model.to(device)

# tensor_datasets = [e for i, e in enumerate(all_datasets) if i in train_index]
# tensor_datasets_val = [e for i, e in enumerate(all_datasets) if i in test_index]

# train_dataset = TensorDataset(
#     torch.tensor([e[0] for e in tensor_datasets]),
#     torch.tensor([e[1] for e in tensor_datasets]),
#     torch.tensor([e[2] for e in tensor_datasets]),
#     torch.tensor([e[3] for e in tensor_datasets]),
#     torch.tensor([e[4] for e in tensor_datasets])
# )
# valid_dataset = TensorDataset(
#     torch.tensor([e[0] for e in tensor_datasets_val]),
#     torch.tensor([e[1] for e in tensor_datasets_val]),
#     torch.tensor([e[2] for e in tensor_datasets_val]),
#     torch.tensor([e[3] for e in tensor_datasets_val]),
#     torch.tensor([e[4] for e in tensor_datasets_val])
# )
# print (len(train_dataset), len(valid_dataset))

# train_sampler = RandomSampler(train_dataset)
# train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size, num_workers=4)

# prediction_sampler = SequentialSampler(valid_dataset)
# prediction_dataloader = DataLoader(valid_dataset, sampler=prediction_sampler, batch_size=batch_size*2, num_workers=4)

# model = model.cuda()
# param_optimizer = list(model.named_parameters())
# no_decay = ['bias', 'gamma', 'beta']
# optimizer_grouped_parameters = [
#     {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
#     {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
# ]

# optimizer = AdamW(optimizer_grouped_parameters, lr=lr, correct_bias=False)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(train_dataloader), epochs=epochs)

# lm_coef, mc_coef = 1., 1.

# train_loss = []
# for _ in range(epochs):
#     model.train(); torch.cuda.empty_cache()
#     # Tracking variables
#     tr_loss = 0
#     nb_tr_examples, nb_tr_steps = 0, 0

#     # Train the data for one epoch
#     for step, batch in enumerate(train_dataloader, start=1):
#         # Add batch to GPU
#         batch = tuple(t.to(device) for t in batch)

#         input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
#         optimizer.zero_grad()
#         # Forward pass
#         lm_loss, mc_loss, *__ = model(
#             input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
#             mc_labels=mc_labels, lm_labels=lm_labels
#         )
#         loss = (lm_loss * lm_coef + mc_loss * mc_coef)
#         train_loss.append(loss.item())
#         # Backward pass
#         loss.backward()
#         # Update parameters and take a step using the computed gradient
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#         optimizer.step()
#         scheduler.step()

#         # Update tracking variables
#         tr_loss += loss.item()
#         nb_tr_examples += input_ids.size(0)
#         nb_tr_steps += 1
#         if step%100==0:
#             print (step, tr_loss/nb_tr_steps)
#     print ( 'epoch {} Train loss: {}'.format(_, tr_loss/nb_tr_steps) )

#     ### val
#     model.eval()
#     # Tracking variables 
#     tr_loss, nb_tr_steps = 0, 0
#     for step, batch in enumerate(prediction_dataloader, start=1):
#         # Add batch to GPU
#         batch = tuple(t.to(device) for t in batch)
#         # Unpack the inputs from our dataloader
#         input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
#         # Telling the model not to compute or store gradients, saving memory and speeding up prediction
#         with torch.no_grad():
#             lm_loss, mc_loss, *__ = model(
#                 input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
#                 mc_labels=mc_labels, lm_labels=lm_labels
#             )
#             loss = (lm_loss * lm_coef + mc_loss * mc_coef)

#             tr_loss += loss.item()
#             nb_tr_steps += 1
#     print ( 'val loss: {}'.format(tr_loss/nb_tr_steps) )
# model.train();
# model.save_pretrained( 'ru-gpt2-medium-rdf-2-text' )

In [None]:
# Load the fine-tuned checkpoint directly. from_pretrained() is a classmethod that
# builds the model from the configuration saved next to the weights, so there is no
# need to fetch the 'gpt2-medium' config or to instantiate a throw-away model first.
model = GPT2DoubleHeadsModel.from_pretrained( 'ru-gpt2-medium-rdf-2-text', output_hidden_states=True )
config = model.config
model.to(device)
# NOTE(review): the model is left in train mode, which keeps dropout active.
# generate_q_set() always takes the arg-max token, so dropout looks like the only
# source of variation between its samples - confirm before switching to model.eval().
model.train();
None

In [None]:
def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Mask a 1-D logits tensor with top-k, nucleus (top-p) and threshold filters.

    Args:
        logits: 1-D tensor of logits; it is modified in place and returned.
        top_k: if > 0, keep only the ``top_k`` largest logits.
        top_p: if > 0.0, keep the smallest set of highest-probability tokens
            whose cumulative probability exceeds ``top_p``.
        threshold: logits below this value are masked as well.
        filter_value: value written into masked positions.

    Returns:
        The same tensor object, with masked positions set to ``filter_value``.
    """
    assert logits.dim() == 1  # single sequence only, no batch dimension
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Everything strictly below the k-th largest logit is dropped.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        ordered_logits, ordered_idx = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(ordered_logits, dim=-1), dim=-1)

        drop_mask = cumulative > top_p
        # Shift right by one so the first token that crosses top_p is still kept;
        # the most likely token is never dropped.
        drop_mask[..., 1:] = drop_mask[..., :-1].clone()
        drop_mask[..., 0] = 0

        # Map the mask from sorted order back to vocabulary positions.
        logits[ordered_idx[drop_mask]] = filter_value

    logits[logits < threshold] = filter_value

    return logits

In [None]:
def generate_q_set(raw_text, num_samples=17, temperature=.7, top_k=11, top_p=.9, max_lenght=64):
    """Generate candidate texts for an RDF block and rank them.

    Args:
        raw_text: triples, one per line, fields separated by ``' | '``.
        num_samples: number of generation passes.
        temperature: divisor applied to the next-token logits.
        top_k, top_p: forwarded to ``top_filtering``.
        max_lenght: maximum number of generated tokens per pass.

    Returns:
        The distinct decoded candidates sorted by ascending score, so the best
        candidate is the last element. A candidate's score is the best value,
        over the passes that produced it, of
        ``prod(token probabilities) / number of tokens``.

    Notes:
        The next token is always the arg-max of the filtered distribution, so
        passes differ only if the model's forward pass is stochastic.
        The torch seed is re-drawn from numpy's global RNG on every call.
    """
    with torch.no_grad():
        torch.manual_seed( np.random.randint(1000) )

        # Prompt: [0], one chunk of at most ph_len ids per triple field, then 50005.
        join_list = [[0]]
        for triplet in raw_text.split('\n'):
            for part in triplet.split(' | '):
                join_list.append(tokenizer.EncodeAsIds(part).tokenization[:ph_len])

        join_list = join_list[:max_tok4ph]
        join_list.append([50005])
        orig_input_ids = list(itertools.chain.from_iterable(join_list))
        orig_token_type_ids = len(orig_input_ids)*[50010]; orig_token_type_ids[-1] = 50005

        result = {}
        for j in range(num_samples):
            input_ids = orig_input_ids.copy()
            token_type_ids = orig_token_type_ids.copy()
            input_ids_prob = []

            for i in range(max_lenght):
                t_input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
                t_token_type_ids = torch.tensor(token_type_ids, device=device).unsqueeze(0)

                logits = model( t_input_ids, token_type_ids=t_token_type_ids )

                logits = logits[0] # the model returns a tuple; element 0 holds the LM logits
                logits = logits[0, -1, :] / temperature

                logits = top_filtering(logits, top_k=top_k, top_p=top_p)
                probs = F.softmax(logits, dim=-1)

                prev = torch.topk(probs, 1)[1]
                input_ids_prob.append( probs[prev].item() )

                tok = prev.item()
                input_ids.append( tok )
                token_type_ids.append( 50005 )
                if tok==2:
                    break

            if not 2 in set(input_ids): # no end token produced: force one at the last position
                input_ids[-1] = 2
            s,e = input_ids.index(50005)+1, input_ids.index(2)

            q = tokenizer.DecodeIds(input_ids[s:e])
            if not q in result:
                result[q] = []

            l = len(orig_input_ids)
            kept_probs = input_ids_prob[s-l:e-l]
            if kept_probs:
                p = np.prod( kept_probs ) / len(kept_probs)
            else:
                # The end token came first, so nothing was generated. Dividing by the
                # zero length would give inf and rank the empty text above every real
                # candidate; score it lowest instead.
                p = 0.0
            result[q].append(p)

        result = [k for k, v in sorted(result.items(), key=lambda item: np.max(item[1]))]
        return result

In [None]:
# Build the inference records from the test file:
# (translated triples joined by newlines, Russian reference texts if any, entry id).
recs = []

f = 'ru/rdf-to-text-generation-test-data-without-refs-ru.xml'
root = ET.parse(f)

for e in root.iterfind('./entries/entry'):
    idx = '_'.join([e.get('category'), e.get('eid'), e.get('size')])

    triple_list = extract_triplets(e)
    en2ru = extract_translation_dict(e)

    # translate triples: entity names via the entry's own links (falling back to the
    # original name), predicates via the global translation table
    out_triples = []
    for raw_subject, raw_predicate, raw_obj in triple_list:
        subject = raw_subject.replace('_', ' ')
        obj = raw_obj.replace('_', ' ')
        translated = (
            en2ru.get(subject, subject),
            predicate2translate[raw_predicate],
            en2ru.get(obj, obj),
        )
        out_triples.append(' | '.join(translated))

    # extract the Russian reference texts
    lexs = [item.text for item in e.findall('./lex') if item.get('lang')=='ru']
    recs.append( ('\n'.join(out_triples), lexs, idx) )
print (len(recs))

In [None]:
# Generate ranked candidates for every test record; the record index is printed as progress.
hypothesis, references = [], []
for i, record in enumerate(recs):
    ph, refs, idx = record
    candidates = generate_q_set(ph, num_samples=19)
    hypothesis += [candidates]
    references += [refs]
    print (i)

In [None]:
def GetShiftingWindows(thelist, size=2):
    """Return all consecutive slices of length ``size``, advancing one position at a time.

    An empty list is returned when ``thelist`` is shorter than ``size``.
    """
    windows = []
    last_start = len(thelist) - size
    for begin in range(last_start + 1):
        windows.append(thelist[begin:begin + size])
    return windows

class Detokenizer:
    """Undo spacing artefacts in generated text, using the source RDF block as a guide."""

    def __init__(self):
        # Raw strings keep the regex escapes (\[, \d, \.) out of Python's own
        # string-escape processing; the compiled patterns are unchanged.
        # A span from an opening quote/bracket to the nearest closing one.
        self.paired_pattern = re.compile(r'["“(«<{\[].*?["”)»>}\]]')
        # Digits, a '.' or ',', a space, digits - e.g. "15, 5".
        self.float_sep_pattern = re.compile(r'\d*[\.,] \d*')
        # A decimal number not directly preceded by a Latin letter or ':'.
        self.float_pattern = re.compile(r'(?<![a-zA-Z:])[-+]?\d*[\.,]\d+')

    def translate(self, s, rdf):
        """Return ``s`` with three kinds of spacing repaired.

        1. ``"15, 5"`` becomes ``"15.5"`` when that number occurs in a subject
           or object of ``rdf``.
        2. Spaces directly inside paired quotes/brackets are removed.
        3. ``"a - b"`` becomes ``"a-b"`` when ``a-b`` occurs in ``rdf``.

        Args:
            s: generated text.
            rdf: triples, one per line, fields separated by ``' | '``.
        """
        dash_elements = []
        float_values = []
        for line in rdf.split('\n'):
            # three-character snippets around every '-' in the source line
            dash_elements.extend(re.findall('.-.', line))
            parts = line.split(' | ')
            # find float numbers in subject or object strings
            float_values.extend(self.float_pattern.findall( parts[0] ))
            float_values.extend(self.float_pattern.findall( parts[-1] ))
        float_values = set([v.replace(',', '.') for v in float_values])
        dash_elements = set(dash_elements)

        ### normalize float values
        # `points` alternates non-match / match boundaries, so odd windows are matches.
        points = [0]
        for m in self.float_sep_pattern.finditer(s):
            points.extend(m.span())
        points.append(len(s))

        to_join = []
        for i, (start, end) in enumerate(GetShiftingWindows(points)):
            chunk = s[start:end]
            if i%2:
                replace_chunk = chunk.replace(',', '.').replace(' ', '')
                if replace_chunk in float_values:
                    chunk = replace_chunk
            to_join.append(chunk)
        s = ''.join(to_join)

        ### collapse spaces for paired punctuations
        points = [0]
        for m in self.paired_pattern.finditer(s):
            points.extend(m.span())
        points.append(len(s))

        to_join = []
        for i, (start, end) in enumerate(GetShiftingWindows(points)):
            chunk = s[start:end]
            if i%2:
                ch_start, ch_end = chunk[0], chunk[-1]
                chunk = ch_start+chunk[1:-1].strip()+ch_end
            to_join.append(chunk)
        res = ''.join(to_join)

        ### 'a - b' cases
        for found in re.findall('. - .', res):
            wo_space = found.replace(' ', '')
            if wo_space in dash_elements:
                # Literal replacement: `found` is text taken from the output and may
                # contain regex metacharacters ('+', '(', '?', ...), so it must not be
                # used as a pattern.
                res = res.replace(found, wo_space)

        return res

detok = Detokenizer()

In [None]:
# Write the top-ranked candidate of every test entry (the last element of each ranked
# list), detokenized against its RDF block: one per line, no trailing newline.
# The lines are computed before the file is opened, so a failure does not leave a
# truncated results file, and the encoding is explicit because the output is Cyrillic.
result_lines = [
    detok.translate(h[-1], recs[i][0])
    for i, h in enumerate(hypothesis)
]
with open('results.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(result_lines))