In [62]:
import json
import pickle
import random

import torch
from torch import nn, optim
from torch import autograd
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import torch.nn.utils.rnn as rnn_utils

import nltk
from nltk.translate.bleu_score import SmoothingFunction
from nltk.translate.bleu_score import sentence_bleu
import time
import copy

from Vocab import Vocab
from LanguageModel import LanguageModel

import torch
# Route CUDA work in this notebook to GPU index 1.
# NOTE(review): hard-coded device index; this presumably fails where GPU 1 does not exist.
CUDA_DEVICE_INDEX = 1
torch.cuda.set_device(CUDA_DEVICE_INDEX)
print('import over')

# NOTE(review): copy_thres is not referenced in the cells visible here — confirm it is still needed.
copy_thres = 1

import over


In [63]:
def batch_words2sentence(words_list):
    """Join each list of words with single spaces; returns one string per sample."""
    sentences = []
    for words in words_list:
        sentences.append(' '.join(words))
    return sentences
def batch_tokens2words(tokens_list, vocab):
    """Map every token id to its word through ``vocab.token2word``.

    Args:
        tokens_list: list of token-id sequences (list[list]).
        vocab: object exposing a ``token2word`` mapping.

    Returns:
        list[list] of words with the same nesting as ``tokens_list``.
    """
    def to_words(tokens):
        return [vocab.token2word[tok] for tok in tokens]

    return list(map(to_words, tokens_list))

def batch_tokens_remove_eos(tokens_list, vocab):
    """Truncate each token sequence at its first '<eos>' token.

    Args:
        tokens_list: list of token sequences (list[list]).
        vocab: object exposing ``word2token`` with an '<eos>' entry.

    Returns:
        list[list]: the tokens preceding the first '<eos>' of each sequence
        (the whole sequence if it has none); '<eos>' and everything after it are dropped.
    """
    def until_eos(seq):
        kept = []
        for tok in seq:
            if tok == vocab.word2token['<eos>']:
                return kept
            kept.append(tok)
        return kept

    return [until_eos(seq) for seq in tokens_list]

def batch_tokens_bleu(references, candidates, smooth_epsilon=0.001, min_len=4):
    """Sentence-level BLEU for each aligned (reference, candidate) pair.

    Args:
        references: list of token sequences, one reference per sample.
        candidates: list of token sequences aligned with ``references``.
        smooth_epsilon: epsilon for nltk ``SmoothingFunction`` (``method1`` is used).
        min_len: a pair where either side has fewer than ``min_len`` tokens is
            scored 0 instead of being passed to BLEU (default 4, the value that
            was previously hard-coded).

    Returns:
        list of BLEU scores, one per pair. Pairs are formed with ``zip``, so the
        result is as long as the shorter input.
    """
    # The smoothing function depends only on smooth_epsilon; build it once rather
    # than once per sample.
    smoothing = SmoothingFunction(epsilon=smooth_epsilon).method1
    bleu_scores = []
    for ref, candidate in zip(references, candidates):
        if min(len(ref), len(candidate)) < min_len:
            bleu_scores.append(0)
        else:
            bleu_scores.append(sentence_bleu([ref], candidate, smoothing_function=smoothing))
    return bleu_scores

# Load the shared Vocab object; the helpers here use its word2token / token2word maps.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
# NOTE(review): this path uses 'data_set/' while the data cells use 'data_set2/' — confirm they match.
VOCAB_PATH = 'data_set/vocab.pk'
with open(VOCAB_PATH, 'rb') as vocab_fh:
    vocab = pickle.load(vocab_fh)

    
def seqs_split(seqs, vocab):
    """Split each token sequence at its first '<split>' token.

    Everything from the first '<eos>' onward is removed first (via
    ``batch_tokens_remove_eos``). Tokens before the first '<split>' go to the
    first output; all remaining tokens go to the second. Any further '<split>'
    tokens are discarded.

    Returns:
        (simple_sent1s, simple_sent2s): two lists aligned with ``seqs``.
    """
    firsts, seconds = [], []
    for seq in batch_tokens_remove_eos(seqs, vocab):
        head, tail = [], []
        target = head
        for token in seq:
            if token == vocab.word2token['<split>']:
                # Switch to the second half; the marker itself is not kept.
                target = tail
                continue
            target.append(token)
        firsts.append(head)
        seconds.append(tail)
    return firsts, seconds

def simple_sents_concat(simple_sent1s, simple_sent2s, vocab, max_length):
    """Join sentence pairs as ``sent1 + ['<split>'] + sent2`` at a fixed length.

    Args:
        simple_sent1s: list of token lists (first halves). Not modified.
        simple_sent2s: list of token lists (second halves), aligned with simple_sent1s.
        vocab: object exposing ``word2token`` with '<split>' and '<padding>' entries.
        max_length: length of every returned sequence (truncate, then right-pad).

    Returns:
        (simple_sents, simple_sent_lens): the joined fixed-length sequences and,
        for each, its length after truncation but before padding.

    The previous version aliased ``simple_sent1s`` and appended into its inner
    lists, so the caller's data was modified in place. This version copies instead.
    """
    split_token = vocab.word2token['<split>']
    pad_token = vocab.word2token['<padding>']

    # Work on copies so the caller's lists are left untouched.
    simple_sents = [list(sent) for sent in simple_sent1s]
    simple_sent_lens = []
    for i, sent2 in enumerate(simple_sent2s):
        joined = simple_sents[i] + [split_token] + list(sent2)
        # The join adds a '<split>' token, so it can exceed max_length; truncate.
        joined = joined[:max_length]
        simple_sent_lens.append(len(joined))
        joined.extend([pad_token] * (max_length - len(joined)))
        simple_sents[i] = joined

    return simple_sents, simple_sent_lens


def get_lm_inputs_and_labels(sents, vocab, max_length):
    """Build fixed-length language-model inputs and next-token labels.

    For each token sequence ``sent`` (keeping at most ``max_length - 1`` tokens):
      * input = ['<sos>'] + sent, right-padded with '<padding>' to ``max_length``
      * label = sent + ['<eos>'], right-padded with '<padding>' to ``max_length``

    Args:
        sents: list of token lists. Not modified.
        vocab: object exposing ``word2token`` with '<sos>', '<eos>' and '<padding>'.
        max_length: length of every returned input and label.

    Returns:
        (lm_inputs, lm_input_lens, lm_labels). ``lm_input_lens[i]`` is the
        unpadded length of ``lm_inputs[i]``, including '<sos>'.

    Previously, over-long sentences were "truncated" by rebinding the loop
    variable. The truncated copy was thrown away, and the original sentence was
    returned untruncated and unpadded, without '<sos>'/'<eos>'.
    """
    sos = vocab.word2token['<sos>']
    eos = vocab.word2token['<eos>']
    pad = vocab.word2token['<padding>']

    lm_inputs, lm_input_lens, lm_labels = [], [], []
    for sent in sents:
        # Leave room for the single '<sos>' / '<eos>' marker.
        body = list(sent[:max_length - 1])

        lm_input = [sos] + body
        lm_input_lens.append(len(lm_input))
        lm_input.extend([pad] * (max_length - len(lm_input)))
        lm_inputs.append(lm_input)

        lm_label = body + [eos]
        lm_label.extend([pad] * (max_length - len(lm_label)))
        lm_labels.append(lm_label)

    return lm_inputs, lm_input_lens, lm_labels


def duplicate_reconstruct_labels(sents, topk):
    """Repeat each element of ``sents`` ``topk`` times, keeping the original order.

    Example: [a, b] with topk=2 -> [a, a, b, b]. The repeats are the same
    objects, not copies.
    """
    repeated = []
    for sent in sents:
        repeated.extend([sent] * topk)
    return repeated


def batch_tokens_bleu_split_version(references, candidates, vocab, smooth_epsilon=0.001):
    """Per-sample BLEU for split sentences: the mean of the two halves' scores.

    Both sides are split at '<split>' by ``seqs_split``, which also drops
    everything from the first '<eos>', so callers need not strip '<eos>'
    themselves (unlike with ``batch_tokens_bleu``). Each half is scored with
    ``batch_tokens_bleu``, and the two half-scores are averaged.

    ``smooth_epsilon`` is now forwarded to ``batch_tokens_bleu``. Previously it
    was accepted but ignored, so the default 0.001 was always used.

    Returns:
        list of averaged BLEU scores, one per sample.
    """
    ref1, ref2 = seqs_split(references, vocab)
    cand1, cand2 = seqs_split(candidates, vocab)
    bleu_simple_sent1s = batch_tokens_bleu(ref1, cand1, smooth_epsilon=smooth_epsilon)
    bleu_simple_sent2s = batch_tokens_bleu(ref2, cand2, smooth_epsilon=smooth_epsilon)
    return [(b1 + b2) / 2 for b1, b2 in zip(bleu_simple_sent1s, bleu_simple_sent2s)]


def set_model_grad(model, is_grad):
    """Set ``requires_grad`` to ``is_grad`` on every parameter of ``model`` (freeze/unfreeze)."""
    params = model.parameters()
    for p in params:
        p.requires_grad = is_grad
            
def batch_tokens_remove_padding(tokens_list, vocab):
    """Truncate each token sequence at its first '<padding>' token.

    Args:
        tokens_list: list of token sequences (list[list]).
        vocab: object exposing ``word2token`` with a '<padding>' entry.

    Returns:
        list[list]: for each sequence, the tokens before its first '<padding>'
        (the whole sequence if it has none). The pad and everything after it,
        including any non-pad tokens, are dropped.
    """
    pad_token = vocab.word2token['<padding>']
    result = []
    for tokens in tokens_list:
        tokens = list(tokens)
        cut = tokens.index(pad_token) if pad_token in tokens else len(tokens)
        result.append(tokens[:cut])
    return result

In [3]:
# Load the sentence-split training set: complex sentences, their lengths,
# pseudo labels and reference labels (all pickled).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
def _load_pickle(path):
    with open(path, 'rb') as fh:
        return pickle.load(fh)

split_train_set_inputs = _load_pickle('./data_set2/split_data_set/train_complex_sents.pk')
split_train_set_input_lens = _load_pickle('./data_set2/split_data_set/train_complex_sent_lens.pk')
split_pseudo_train_set_labels = _load_pickle('./data_set2/split_data_set/train_pseudo_labels.pk')

# Supervised variants (currently disabled):
#   train_complex_sents_supervised.pk     -> split_train_set_inputs_supervised
#   train_complex_sent_lens_supervised.pk -> split_train_set_input_lens_supervised
#   train_labels_supervised.pk            -> split_train_set_labels_supervised

split_train_set_labels = _load_pickle('./data_set2/split_data_set/train_labels.pk')

In [64]:
print(*(len(seq_list) for seq_list in (split_train_set_inputs, split_train_set_labels)))

791956 791956


In [12]:
# Show one random training pair as text: the source without padding, the label cut at '<eos>'.
idx = random.randint(0, len(split_train_set_inputs) - 1)
sample_input = split_train_set_inputs[idx]
sample_label = split_train_set_labels[idx]

a = batch_words2sentence(
    batch_tokens2words(batch_tokens_remove_padding([sample_input], vocab), vocab))
label = batch_words2sentence(
    batch_tokens2words(batch_tokens_remove_eos([sample_label], vocab), vocab))
print(a)
print(label)

['residents of the neighborhood said two brothers who were hamas fighters were in the area at the time of the attack but that the mortar fire had not come from the school compound , but from elsewhere in the neighborhood .']
['residents of the neighborhood said two brothers who were hamas fighters were in the area at the time of the attack . <split> but the residents also said the mortar fire had not come from the school compound , but from elsewhere in the neighborhood .']


In [15]:
# BLEU of each unsplit source against its split label over the whole training set.
input_tokens = batch_tokens_remove_padding(split_train_set_inputs, vocab)
label_tokens = batch_tokens_remove_eos(split_train_set_labels, vocab)
bleus = batch_tokens_bleu(references=label_tokens, candidates=input_tokens)
print('haha')

# Later cells index these per sample: `a` / `label` are word lists, `inputs` / `labels` strings.
a, label = (batch_tokens2words(seqs, vocab) for seqs in (input_tokens, label_tokens))
inputs, labels = batch_words2sentence(a), batch_words2sentence(label)
del input_tokens, label_tokens
    


haha


In [90]:
# Keep samples whose source-vs-label BLEU lies strictly between 0.3 and 0.45, where both
# sides have fewer than 30 words and neither contains '<low_freq>'.
indices = [
    idx for idx, bleu in enumerate(bleus)
    if 0.3 < bleu < 0.45
    and len(a[idx]) < 30 and len(label[idx]) < 30
    and '<low_freq>' not in a[idx] and '<low_freq>' not in label[idx]
]
cnt = len(indices)
print(cnt)

2795


In [99]:
# Print one randomly chosen selected pair (source sentence, then split label).
idx = random.choice(indices)
for text in (inputs[idx], labels[idx]):
    print(text)

cape breton north was a provincial electoral district in cape breton , nova scotia , canada , that elected one member of the nova scotia house of assembly .
cape breton north is a former provincial electoral district in nova scotia , canada . <split> it elected one member to the nova scotia house of assembly .


In [102]:
# Sanity check on a hand-written pair: unsplit sentence (kk) scored against its split rewrite (kkk).
kk = 'the colony grew quickly upon its founding , but internal disputes and lack of funding spelled its demise by 1850 .'
kkk = 'the colony grew quickly upon its founding . <split> internal disputes , lack of funding and the draw of urban jobs led to its decline by 1850 .'
reference_words = kkk.split(' ')
candidate_words = kk.split(' ')
print(batch_tokens_bleu(references=[reference_words], candidates=[candidate_words]))

[0.3175938828949709]


In [101]:
# Average length in words of the source sentences (`a`) and of the split labels (`label`).
# Each average now divides by its own list's length; the label average previously
# iterated over and divided by len(a), which is only correct when both lists match.
lengths = sum(len(words) for words in a)
print(lengths / len(a))

lengths = sum(len(words) for words in label)
print(lengths / len(label))

33.41148498148887
37.89943001883943


In [6]:
# Mean split-aware BLEU of the pseudo labels against the reference split labels (train set).
bleu_score = batch_tokens_bleu_split_version(
    references=split_train_set_labels,
    candidates=split_pseudo_train_set_labels,
    vocab=vocab,
)

# Explicit left-to-right accumulation (sum() on floats uses compensated summation in Python 3.12+).
s = 0
for value in bleu_score:
    s += value
print(s / len(bleu_score))

0.6569447777209572


In [41]:
# Re-score the pseudo labels with smooth_epsilon=0.0001, scoring each half separately
# (inlined here rather than calling batch_tokens_bleu_split_version).
# NOTE(review): `[:num]` with num = len - 1 leaves out the last sample — confirm this is intended.
num = len(split_train_set_labels) - 1
references = split_train_set_labels[:num]
candidates = split_pseudo_train_set_labels[:num]

ref1, ref2 = seqs_split(references, vocab)
cand1, cand2 = seqs_split(candidates, vocab)
bleu_simple_sent1s = batch_tokens_bleu(ref1, cand1, smooth_epsilon=0.0001)
bleu_simple_sent2s = batch_tokens_bleu(ref2, cand2, smooth_epsilon=0.0001)

bleu = [(b1 + b2) / 2 for b1, b2 in zip(bleu_simple_sent1s, bleu_simple_sent2s)]

s = 0
for x in bleu:
    s += x
print(s / len(bleu))

0.6568139995249009


In [65]:
# Load the sentence-split test set: complex sentences, their lengths, reference labels (pickled).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
_test_files = (
    './data_set2/split_data_set/test_complex_sents.pk',
    './data_set2/split_data_set/test_complex_sent_lens.pk',
    './data_set2/split_data_set/test_labels.pk',
)
_loaded = []
for _path in _test_files:
    with open(_path, 'rb') as _fh:
        _loaded.append(pickle.load(_fh))
split_test_set_inputs, split_test_set_input_lens, split_test_set_labels = _loaded
del _loaded


In [66]:
print(*map(len, (split_test_set_labels, split_test_set_inputs)))

5000 5000


In [67]:
# Baseline pseudo labels for the test set: cut each sentence at half its length and
# insert '.' '<split>' there. Prints any sample that did not get exactly two inserted
# tokens, and dumps sample 110 for inspection.
pseudo_labels = []
for idx, src_tokens in enumerate(split_test_set_inputs):
    src_len = split_test_set_input_lens[idx]
    cut_idx = int(src_len / 2)

    pseudo_label = []
    for pos in range(src_len):
        pseudo_label.append(src_tokens[pos])
        if pos == cut_idx - 1:
            pseudo_label += [vocab.word2token['.'], vocab.word2token['<split>']]
    pseudo_labels.append(pseudo_label)

    inserted = len(pseudo_label) - src_len
    if inserted != 2:
        print(inserted)
    if idx == 110:
        print(cut_idx)
        print(split_test_set_inputs[idx])
        print(pseudo_label, len(pseudo_label), src_len, len(split_test_set_inputs[idx]))


14
[11206, 4309, 41089, 10934, 14725, 40533, 21467, 19385, 3, 28954, 10934, 872, 30803, 13264, 25932, 27513, 31824, 25168, 38312, 10934, 29270, 500, 13264, 25932, 13264, 3, 13264, 13069, 37734, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[11206, 4309, 41089, 10934, 14725, 40533, 21467, 19385, 3, 28954, 10934, 872, 30803, 13264, 37734, 5, 25932, 27513, 31824, 25168, 38312, 10934, 29270, 500, 13264, 25932, 13264, 3, 13264, 13069, 37734] 31 29 50


In [68]:
# Mean split-aware BLEU of the half-way pseudo labels on the test set.
scores = batch_tokens_bleu_split_version(
    references=split_test_set_labels,
    candidates=pseudo_labels,
    vocab=vocab,
)
s = 0
for value in scores:
    s += value
print(s / len(scores))

0.6552914947849449


In [69]:
# Re-score the test pseudo labels after removing every '<low_freq>' token from both sides.
# `a` holds the filtered references and `b` the filtered pseudo labels.
# NOTE: this rebinds `a`, which earlier cells used for word lists.
def _drop_token(seqs, token):
    """Return copies of ``seqs`` with every occurrence of ``token`` removed."""
    return [[tok for tok in seq if tok != token] for seq in seqs]

low_freq_token = vocab.word2token['<low_freq>']
a = _drop_token(split_test_set_labels, low_freq_token)
b = _drop_token(pseudo_labels, low_freq_token)

print('h')

scores = batch_tokens_bleu_split_version(references=a, candidates=b, vocab=vocab)
s = 0
for score in scores:
    s += score
print(s / len(scores))

h
0.6501506333778686
