In [1]:
import json
import pickle
import random
from collections import Counter
import os
import copy

import nltk
from nltk.translate.bleu_score import SmoothingFunction
from nltk.translate.bleu_score import sentence_bleu

from Vocab import Vocab

# Data partition this notebook is meant to process.
# Available values: 'train', 'test', 'validation', 'tune'.
# NOTE(review): not referenced by any of the cells visible here.
file_group = 'validation'

# Load the pickled vocabulary object used by the token/word helpers below
# (they read `vocab.word2token` and `vocab.token2word`).
with open('vocab.pk', 'rb') as f:
    vocab = pickle.load(f)

In [2]:
def batch_words2sentence(words_list):
    """Join each word list in a batch into one space-separated sentence.

    para: words_list is list[list[str]] type
    return: list[str], one sentence per input word list (same order)
    """
    sentences = []
    for words in words_list:
        sentences.append(' '.join(words))
    return sentences
def batch_tokens2words(tokens_list, vocab):
    """Map every token id in a batch back to its word.

    para: tokens_list is list[list] type; vocab must expose a `token2word` mapping
    return: words_list corresponding to tokens (same nesting and order as the input)
    """
    words_list = []
    for tokens in tokens_list:
        words = [vocab.token2word[token] for token in tokens]
        words_list.append(words)
    return words_list

def batch_tokens_remove_eos(tokens_list, vocab):
    """Truncate each token sequence at its first '<eos>' token.

    para: tokens_list is list[list] type; vocab must expose a `word2token` mapping
    return: list of new lists holding only the tokens before the first '<eos>';
            '<eos>' itself and everything after it (e.g. padding) are dropped,
            and sequences without '<eos>' are copied unchanged
    """
    def _until_eos(tokens):
        # Yield tokens up to, but excluding, the first end-of-sequence marker.
        for token in tokens:
            if token == vocab.word2token['<eos>']:
                return
            yield token

    return [list(_until_eos(tokens)) for tokens in tokens_list]

def batch_tokens_bleu(references, candidates, smooth_epsilon=0.001):
    """Sentence-level BLEU for every (reference, candidate) pair.

    para: references and candidates are list[list] type (token sequences);
          pairs are formed with zip, so extra items in the longer list are ignored
    para: smooth_epsilon is passed to nltk's SmoothingFunction (method1)
    return: list of BLEU scores, one per pair; a pair where either side has
            fewer than 4 tokens scores 0 without calling nltk
    """
    def _pair_score(ref, candidate):
        if len(ref) < 4 or len(candidate) < 4:
            return 0
        smoother = SmoothingFunction(epsilon=smooth_epsilon).method1
        return sentence_bleu([ref], candidate, smoothing_function=smoother)

    return [_pair_score(ref, candidate) for ref, candidate in zip(references, candidates)]


def seqs_split(seqs, vocab):
    """Split each token sequence into two simple sentences at '<split>'.

    Each sequence is first truncated at its first '<eos>' (via
    batch_tokens_remove_eos). Tokens before the first '<split>' go to the
    first sentence and all remaining non-'<split>' tokens go to the second;
    every '<split>' token itself is dropped.

    para: seqs is list[list] type; vocab must expose a `word2token` mapping
    return: (first_sentences, second_sentences), two lists of token lists
    """
    first_halves, second_halves = [], []
    for seq in batch_tokens_remove_eos(seqs, vocab):
        halves = ([], [])
        which = 0  # index into `halves`: 0 until a '<split>' is seen, then 1
        for token in seq:
            if token == vocab.word2token['<split>']:
                which = 1
                continue
            halves[which].append(token)
        first_halves.append(halves[0])
        second_halves.append(halves[1])
    return first_halves, second_halves

def batch_tokens_bleu_split_version(references, candidates, vocab, smooth_epsilon=0.001):
    """Per-sample BLEU for split outputs, averaged over the two simple sentences.

    Each reference/candidate token sequence is truncated at its first '<eos>'
    and divided at '<split>' into two simple sentences (see `seqs_split`).
    BLEU is computed separately for the first and for the second simple
    sentences (see `batch_tokens_bleu`, including its short-sentence rule),
    and the two scores are averaged per sample.

    para: references and candidates are list[list] of token ids
    para: vocab provides the '<eos>' / '<split>' token ids via `word2token`
    para: smooth_epsilon is forwarded to `batch_tokens_bleu`
    return: list of per-sample BLEU scores
    """
    ref1, ref2 = seqs_split(references, vocab)
    cand1, cand2 = seqs_split(candidates, vocab)
    # Forward the caller's epsilon; it used to be silently ignored, so the
    # default of batch_tokens_bleu was always used regardless of the argument.
    bleu_first = batch_tokens_bleu(ref1, cand1, smooth_epsilon=smooth_epsilon)
    bleu_second = batch_tokens_bleu(ref2, cand2, smooth_epsilon=smooth_epsilon)
    return [(b1 + b2) / 2 for b1, b2 in zip(bleu_first, bleu_second)]

In [3]:
# Split data set: load the complex-sentence training inputs, their lengths,
# their pseudo labels and their gold labels, in that order.
_split_loaded = []
for _path in ('../data_set/split_data_set/train_complex_sents.pk',
              '../data_set/split_data_set/train_complex_sent_lens.pk',
              '../data_set/split_data_set/train_pseudo_labels.pk',
              '../data_set/split_data_set/train_labels.pk'):
    with open(_path, 'rb') as f:
        _split_loaded.append(pickle.load(f))
train_complex_sents, train_complex_sent_lens, train_pseudo_labels, train_labels = _split_loaded
del _path, _split_loaded  # keep the notebook namespace free of loop temporaries


In [4]:
# Draw the held-out indices: `sample_rate` of the training samples become the
# supervised subset, the rest stay in the pseudo-labelled pool. The same
# `indices_choice` is reused later for the fusion data set.
sample_rate = 0.2
# Fix the RNG seed so Restart & Run All reproduces the same split; without it
# every run drew a different subset and overwrote the saved pickles with it.
# Change the value to draw a different (but still reproducible) split.
SPLIT_SEED = 1234
random.seed(SPLIT_SEED)
sample_num = int(sample_rate * len(train_pseudo_labels))
all_indices = list(range(len(train_pseudo_labels)))
indices_choice = random.sample(all_indices, sample_num)

# random.sample draws without replacement, so both numbers should match.
print(sample_num, len(set(indices_choice)))

197988 197988


In [5]:
# Partition the split data set: samples whose index was drawn above form the
# supervised subset (complex sentence, length, gold label); all others form the
# pseudo-labelled pool (the "_tmp" lists), which keeps pseudo AND gold labels.
train_complex_sents_supervised = []
train_complex_sent_lens_supervised = []
train_labels_supervised = []

train_complex_sents_tmp = []
train_complex_sent_lens_tmp = []
train_pseudo_labels_tmp = []
train_labels_tmp = []

indices_choice_set = set(indices_choice)  # set for O(1) membership tests
for idx, pseudo_label in enumerate(train_pseudo_labels):
    if idx in indices_choice_set:
        train_complex_sents_supervised.append(train_complex_sents[idx])
        train_complex_sent_lens_supervised.append(train_complex_sent_lens[idx])
        train_labels_supervised.append(train_labels[idx])
        continue
    train_complex_sents_tmp.append(train_complex_sents[idx])
    train_complex_sent_lens_tmp.append(train_complex_sent_lens[idx])
    train_pseudo_labels_tmp.append(pseudo_label)
    train_labels_tmp.append(train_labels[idx])

print(len(train_labels_supervised), len(train_pseudo_labels_tmp))

197988 791956


In [6]:
#save
# Persist the partitioned split data set: the pseudo-labelled pool keeps the
# original file names, the supervised subset gets a "_supervised" suffix.
# open(..., 'wb') does not create missing directories, so make sure the
# output directory exists first (no-op if it already does).
os.makedirs('./split_data_set', exist_ok=True)
for data_obj, out_path in (
    (train_complex_sents_tmp, './split_data_set/train_complex_sents.pk'),
    (train_complex_sent_lens_tmp, './split_data_set/train_complex_sent_lens.pk'),
    (train_pseudo_labels_tmp, './split_data_set/train_pseudo_labels.pk'),
    (train_labels_tmp, './split_data_set/train_labels.pk'),
    (train_complex_sents_supervised, './split_data_set/train_complex_sents_supervised.pk'),
    (train_complex_sent_lens_supervised, './split_data_set/train_complex_sent_lens_supervised.pk'),
    (train_labels_supervised, './split_data_set/train_labels_supervised.pk'),
):
    with open(out_path, 'wb') as f:
        pickle.dump(data_obj, f)

In [7]:
# Sanity check: show the first sample of each partition.
# Pseudo-labelled pool: complex sentence, its length, its pseudo label.
for column in (train_complex_sents_tmp, train_complex_sent_lens_tmp, train_pseudo_labels_tmp):
    print(column[0])
print('\n')
# Supervised subset: complex sentence, its length, its gold label.
for column in (train_complex_sents_supervised, train_complex_sent_lens_supervised, train_labels_supervised):
    print(column[0])

[33280, 752, 31001, 43095, 10934, 16112, 6111, 13264, 25168, 43095, 10934, 16156, 43872, 25168, 38165, 13264, 21467, 10754, 26356, 34730, 39262, 37807, 406, 40780, 30014, 39650, 21467, 33280, 36955, 13264, 10130, 12834, 8082, 406, 24796, 27513, 2079, 37734, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
38
[33280, 752, 31001, 43095, 10934, 16112, 6111, 13264, 25168, 43095, 10934, 16156, 43872, 25168, 38165, 13264, 21467, 10754, 26356, 37734, 5, 34730, 39262, 37807, 406, 40780, 30014, 39650, 21467, 33280, 36955, 13264, 10130, 12834, 8082, 406, 24796, 27513, 2079, 37734, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


[9194, 12421, 1818, 15472, 40167, 4493, 13126, 42074, 13264, 12421, 20589, 13126, 14912, 21467, 43223, 31508, 25168, 21698, 27513, 43017, 14054, 4435, 18038, 13126, 21467, 7258, 27513, 4367, 37586, 40848, 25168, 34399, 42273, 13126, 19682, 4910, 9240, 27513, 14754, 21449, 13222, 37734, 1, 1, 1, 1, 1, 1, 1, 1]
42
[9194, 12421, 1818, 15472, 40167, 4493, 13126, 42074, 13

In [8]:
# Fusion data set: load gold simple sentences, their lengths and gold labels,
# then the pseudo simple sentences, their lengths and pseudo labels.
# NOTE(review): this rebinds `train_labels` and `train_pseudo_labels`, replacing
# the split-data-set values loaded earlier; re-running the split-data cells
# after this one would silently operate on the fusion data.
_fusion_loaded = []
for _path in ('../data_set/fusion_data_set/train_simple_sents.pk',
              '../data_set/fusion_data_set/train_simple_sent_lens.pk',
              '../data_set/fusion_data_set/train_labels.pk',
              '../data_set/fusion_data_set/train_pseudo_simple_sents.pk',
              '../data_set/fusion_data_set/train_pseudo_simple_sent_lens.pk',
              '../data_set/fusion_data_set/train_pseudo_labels.pk'):
    with open(_path, 'rb') as f:
        _fusion_loaded.append(pickle.load(f))
(train_simple_sents, train_simple_sent_lens, train_labels,
 train_pseudo_simple_sents, train_pseudo_simple_sent_lens, train_pseudo_labels) = _fusion_loaded
del _path, _fusion_loaded  # keep the notebook namespace free of loop temporaries

In [9]:
# Partition the fusion data set with the SAME held-out indices drawn for the
# split data set (`indices_choice`), so both tasks share one supervised subset.
# NOTE(review): assumes the fusion pickles are index-aligned with the split
# pickles (the sample printouts suggest so) — confirm.
train_simple_sents_supervised = []
train_simple_sent_lens_supervised = []
train_labels_supervised = []

train_pseudo_simple_sents_tmp = []
train_pseudo_simple_sent_lens_tmp = []
train_pseudo_labels_tmp = []

indices_choice_set = set(indices_choice)  # set for O(1) membership tests
for idx, pseudo_label in enumerate(train_pseudo_labels):
    if idx in indices_choice_set:
        # Supervised subset keeps the gold simple sentences and gold labels.
        train_simple_sents_supervised.append(train_simple_sents[idx])
        train_simple_sent_lens_supervised.append(train_simple_sent_lens[idx])
        train_labels_supervised.append(train_labels[idx])
        continue
    # Pool keeps the pseudo simple sentences and pseudo labels.
    train_pseudo_simple_sents_tmp.append(train_pseudo_simple_sents[idx])
    train_pseudo_simple_sent_lens_tmp.append(train_pseudo_simple_sent_lens[idx])
    train_pseudo_labels_tmp.append(pseudo_label)

print(len(train_labels_supervised), len(train_pseudo_labels_tmp))

197988 791956


In [10]:
#save
# Persist the partitioned fusion data set: the pseudo pool keeps the original
# file names, the supervised subset gets a "_supervised" suffix.
# open(..., 'wb') does not create missing directories, so make sure the
# output directory exists first (no-op if it already does).
os.makedirs('./fusion_data_set', exist_ok=True)
for data_obj, out_path in (
    (train_pseudo_simple_sents_tmp, './fusion_data_set/train_pseudo_simple_sents.pk'),
    (train_pseudo_simple_sent_lens_tmp, './fusion_data_set/train_pseudo_simple_sent_lens.pk'),
    (train_pseudo_labels_tmp, './fusion_data_set/train_pseudo_labels.pk'),
    (train_simple_sents_supervised, './fusion_data_set/train_simple_sents_supervised.pk'),
    (train_simple_sent_lens_supervised, './fusion_data_set/train_simple_sent_lens_supervised.pk'),
    (train_labels_supervised, './fusion_data_set/train_labels_supervised.pk'),
):
    with open(out_path, 'wb') as f:
        pickle.dump(data_obj, f)

In [11]:
# Sanity check: show the first sample of each fusion partition.
# Pseudo pool: pseudo simple sentences, their length, the pseudo label.
for column in (train_pseudo_simple_sents_tmp, train_pseudo_simple_sent_lens_tmp, train_pseudo_labels_tmp):
    print(column[0])
print('\n')
# Supervised subset: gold simple sentences, their length, the gold label.
for column in (train_simple_sents_supervised, train_simple_sent_lens_supervised, train_labels_supervised):
    print(column[0])

[33280, 752, 31001, 43095, 10934, 16112, 6111, 13264, 25168, 43095, 10934, 16156, 43872, 25168, 38165, 13264, 21467, 10754, 26356, 37734, 5, 34730, 39262, 37807, 406, 40780, 30014, 39650, 21467, 33280, 36955, 13264, 10130, 12834, 8082, 406, 24796, 27513, 2079, 37734, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
40
[33280, 752, 31001, 43095, 10934, 16112, 6111, 13264, 25168, 43095, 10934, 16156, 43872, 25168, 38165, 13264, 21467, 10754, 26356, 34730, 39262, 37807, 406, 40780, 30014, 39650, 21467, 33280, 36955, 13264, 10130, 12834, 8082, 406, 24796, 27513, 2079, 37734, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


[9194, 12421, 1818, 15472, 40167, 4493, 13126, 42074, 13264, 12421, 20589, 13126, 14912, 21467, 43223, 31508, 25168, 21698, 27513, 43017, 14054, 4435, 18038, 13126, 21467, 7258, 27513, 4367, 37734, 5, 3, 13264, 37586, 40848, 25168, 34399, 42273, 13126, 19682, 4910, 9240, 27513, 14754, 21449, 13222, 37734, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
46
[9194, 12421, 1818, 15472, 40167, 4493, 13