In [1]:
import json
import pickle
import random

import torch
from torch import nn, optim
from torch import autograd
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import torch.nn.utils.rnn as rnn_utils

import nltk
from nltk.translate.bleu_score import SmoothingFunction
from nltk.translate.bleu_score import sentence_bleu
import time
import copy

from Vocab import Vocab
from LanguageModel import LanguageModel

import torch
# Make CUDA device index 1 the current device for subsequent .cuda() calls.
# NOTE(review): hard-coded index; this fails on machines with fewer than two GPUs.
torch.cuda.set_device(1)

print('import over')

# Global threshold read by Decoder.forward: at every decoding step the copy
# mechanism is used when random.random() < copy_thres. random.random() returns
# values in [0, 1), so a threshold of 1 means the copy branch is always taken.
copy_thres=1

import over


In [2]:
def batch_words2sentence(words_list):
    """Join every word list in ``words_list`` into one space-separated string.

    Args:
        words_list: iterable of iterables of str.

    Returns:
        list of str, one sentence per input word list, in the same order.
    """
    return list(map(' '.join, words_list))
def batch_tokens2words(tokens_list, vocab):
    """Translate a batch of token-id sequences into word sequences.

    Args:
        tokens_list: list of lists of token ids.
        vocab: object whose ``token2word`` can be indexed by token id.

    Returns:
        list of lists of words, with the same nesting and order as ``tokens_list``.
    """
    words_list = []
    for tokens in tokens_list:
        words_list.append([vocab.token2word[token] for token in tokens])
    return words_list

def batch_tokens_remove_eos(tokens_list, vocab):
    """Cut every token sequence at its first '<eos>' token.

    Args:
        tokens_list: list of lists of token ids.
        vocab: object whose ``word2token`` maps '<eos>' to its token id.

    Returns:
        list of new lists; each holds the tokens that precede the first
        '<eos>' (the '<eos>' itself and everything after it are dropped).
        Sequences without '<eos>' are copied unchanged.
    """
    def _until_eos(tokens):
        kept = []
        for token in tokens:
            if token == vocab.word2token['<eos>']:
                return kept
            kept.append(token)
        return kept

    return [_until_eos(tokens) for tokens in tokens_list]

def batch_tokens_bleu(references, candidates, smooth_epsilon=0.001):
    """Compute one sentence-level BLEU score per (reference, candidate) pair.

    Args:
        references: list of token lists (one reference per sample).
        candidates: list of token lists, paired with ``references`` by position.
        smooth_epsilon: epsilon handed to nltk's ``SmoothingFunction`` (method1).

    Returns:
        list of scores, one per pair. A pair in which either sequence has
        fewer than 4 tokens is scored 0 without calling nltk.
    """
    def _pair_score(ref, candidate):
        if min(len(ref), len(candidate)) >= 4:
            smoother = SmoothingFunction(epsilon=smooth_epsilon).method1
            return sentence_bleu([ref], candidate, smoothing_function = smoother)
        return 0

    return [_pair_score(ref, candidate) for ref, candidate in zip(references, candidates)]

# Load the pickled vocabulary object used by all helpers below (they index its
# word2token / token2word mappings).
# NOTE(review): unpickling executes arbitrary code; only load trusted files.
# The pickle presumably holds a Vocab instance (see `from Vocab import Vocab`) - confirm.
with open('data_set/vocab.pk', 'rb') as f:
    vocab=pickle.load(f)

    
def seqs_split(seqs, vocab):
    """Split every token sequence into the parts before and after '<split>'.

    Each sequence is first cut at its first '<eos>' (via
    ``batch_tokens_remove_eos``). Tokens before the first '<split>' go to the
    first sentence and all later tokens go to the second one. '<split>' tokens
    themselves are never kept, including repeated ones.

    Args:
        seqs: list of lists of token ids.
        vocab: object whose ``word2token`` maps '<eos>' and '<split>' to ids.

    Returns:
        (simple_sent1s, simple_sent2s): two lists of token lists, aligned with
        ``seqs``. The second list is empty for sequences without '<split>'.
    """
    first_sents, second_sents = [], []
    for seq in batch_tokens_remove_eos(seqs, vocab):
        halves = ([], [])
        half_idx = 0
        for token in seq:
            if token == vocab.word2token['<split>']:
                half_idx = 1
            else:
                halves[half_idx].append(token)
        first_sents.append(halves[0])
        second_sents.append(halves[1])

    return first_sents, second_sents

def simple_sents_concat(simple_sent1s, simple_sent2s, vocab, max_length):
    """Join each sentence pair as ``sent1 + ['<split>'] + sent2``, fixed to ``max_length``.

    The joined sequence is truncated to ``max_length`` tokens and then
    right-padded with '<padding>' up to ``max_length``.

    Unlike the previous implementation, the caller's ``simple_sent1s`` (and the
    lists inside it) are NOT modified; fresh lists are returned.

    Args:
        simple_sent1s: list of token lists (first sentences).
        simple_sent2s: list of token lists (second sentences), paired by index.
        vocab: object whose ``word2token`` maps '<split>' and '<padding>' to ids.
        max_length: length of every returned sequence.

    Returns:
        (simple_sents, simple_sent_lens): the padded sequences and, for each
        pair, the number of real (non-padding) tokens after truncation.
    """
    split_token = vocab.word2token['<split>']
    padding_token = vocab.word2token['<padding>']

    # Copy up front so nothing the caller passed in is mutated.
    simple_sents = [list(sent) for sent in simple_sent1s]
    simple_sent_lens = []
    for i, second in enumerate(simple_sent2s):
        # The joined pair can exceed max_length, hence the truncation.
        merged = (simple_sents[i] + [split_token] + list(second))[:max_length]
        simple_sent_lens.append(len(merged))
        merged.extend([padding_token] * (max_length - len(merged)))
        simple_sents[i] = merged

    return simple_sents, simple_sent_lens


def get_lm_inputs_and_labels(sents, vocab, max_length):
    """Build fixed-length language-model inputs and labels from token sequences.

    For every sentence the input is ``['<sos>'] + tokens`` and the label is
    ``tokens + ['<eos>']``, both right-padded with '<padding>' to exactly
    ``max_length``. Sentences with ``max_length`` or more tokens are truncated
    to their first ``max_length - 1`` tokens so the marker still fits.

    The previous version rebound the loop variable when truncating
    (``sent = sent[:max_length-1]``), so over-long sentences were returned
    untruncated and without '<sos>'/'<eos>' while the reported length described
    the truncated copy. Every returned row now really has ``max_length`` entries.

    Args:
        sents: list of token lists; not modified.
        vocab: object whose ``word2token`` maps '<sos>', '<eos>', '<padding>'.
        max_length: length of every returned sequence; must be >= 1.

    Returns:
        (lm_inputs, lm_input_lens, lm_labels) where ``lm_input_lens[i]`` is the
        number of non-padding tokens in ``lm_inputs[i]`` ('<sos>' included).

    Raises:
        ValueError: if ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError('max_length must be at least 1 to hold <sos>/<eos>')

    sos_token = vocab.word2token['<sos>']
    eos_token = vocab.word2token['<eos>']
    padding_token = vocab.word2token['<padding>']

    lm_inputs = []
    lm_input_lens = []
    lm_labels = []
    for sent in sents:
        # Leave one slot for <sos> (inputs) / <eos> (labels).
        body = list(sent[:max_length - 1])
        padding = [padding_token] * (max_length - 1 - len(body))

        lm_inputs.append([sos_token] + body + padding)
        lm_input_lens.append(len(body) + 1)
        lm_labels.append(body + [eos_token] + padding)

    return lm_inputs, lm_input_lens, lm_labels


def duplicate_reconstruct_labels(sents, topk):
    """Repeat every element of ``sents`` ``topk`` times, keeping order.

    The result looks like ``[s0, s0, ..., s1, s1, ...]``. The repeated entries
    are references to the same objects, not copies.

    Args:
        sents: iterable of items (here: label sequences).
        topk: number of consecutive repetitions per item.

    Returns:
        list with ``len(sents) * topk`` entries.
    """
    repeated = []
    for sent in sents:
        repeated.extend([sent] * len(range(topk)))
    return repeated


def batch_tokens_bleu_split_version(references, candidates, vocab, smooth_epsilon=0.001):
    """BLEU for split outputs: average of the BLEU of the two simple sentences.

    Both references and candidates are divided at '<split>' with
    ``seqs_split``, which also strips '<eos>' and everything after it, so
    callers do not need to remove '<eos>' beforehand (unlike
    ``batch_tokens_bleu``). The first and second halves are scored separately
    with ``batch_tokens_bleu`` and the two scores are averaged per sample.

    ``smooth_epsilon`` is now forwarded to ``batch_tokens_bleu``; it used to be
    accepted but silently ignored.

    Args:
        references: list of token lists.
        candidates: list of token lists, paired with ``references`` by position.
        vocab: vocabulary object passed through to ``seqs_split``.
        smooth_epsilon: smoothing epsilon for the sentence-level BLEU.

    Returns:
        list of floats, one averaged BLEU score per sample.
    """
    ref1, ref2 = seqs_split(references, vocab)
    cand1, cand2 = seqs_split(candidates, vocab)
    bleu_simple_sent1s = batch_tokens_bleu(ref1, cand1, smooth_epsilon=smooth_epsilon)
    bleu_simple_sent2s = batch_tokens_bleu(ref2, cand2, smooth_epsilon=smooth_epsilon)
    return [(first + second) / 2
            for first, second in zip(bleu_simple_sent1s, bleu_simple_sent2s)]


def set_model_grad(model, is_grad):
    """Set ``requires_grad`` to ``is_grad`` on every parameter of ``model``.

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``).
        is_grad: value assigned to each parameter's ``requires_grad`` flag.
    """
    parameters = model.parameters()
    for parameter in parameters:
        setattr(parameter, 'requires_grad', is_grad)

In [3]:
# Smoke test of the helper functions on a few hand-written token sequences.
# Judging from the recorded output below, the ids are presumably 0=<sos>,
# 1=<padding>, 2=<eos>, 5=<split> in the loaded vocab - confirm against vocab.pk.
seqs=[[8,9,90,5,3,2,1], [5,8,9,90,5,3,2,1], [8,2,9,40,5,3,2,2,1], [8,9,90,5,3,2,1], [8,9,90]]
# a: tokens before '<split>', b: tokens after it (both cut at the first '<eos>').
a,b = seqs_split(seqs, vocab)

print(a)
print(b)

# Language-model inputs / lengths / labels for the first simple sentences ...
lm_in, lm_in_lens, lm_labels=get_lm_inputs_and_labels(a,vocab, max_length=6)
print(lm_in)
print(lm_in_lens)
print(lm_labels)
# ... and for the second simple sentences.
lm_in, lm_in_lens, lm_labels=get_lm_inputs_and_labels(b,vocab, max_length=6)
print(lm_in)
print(lm_in_lens)
print(lm_labels)

# Re-join both halves with '<split>'; max_length=3 is small on purpose so the
# truncation path is exercised.
c,d=simple_sents_concat(a,b,vocab, 3)
print(c)
print(d)


# Last expression of the cell: BLEU of a single reference/candidate pair.
batch_tokens_bleu([[1,2,3,4,5,6]], [[2,3,1,4,5]])

[[8, 9, 90], [], [8], [8, 9, 90], [8, 9, 90]]
[[3], [8, 9, 90, 3], [], [3], []]
[[0, 8, 9, 90, 1, 1], [0, 1, 1, 1, 1, 1], [0, 8, 1, 1, 1, 1], [0, 8, 9, 90, 1, 1], [0, 8, 9, 90, 1, 1]]
[4, 1, 2, 4, 4]
[[8, 9, 90, 2, 1, 1], [2, 1, 1, 1, 1, 1], [8, 2, 1, 1, 1, 1], [8, 9, 90, 2, 1, 1], [8, 9, 90, 2, 1, 1]]
[[0, 3, 1, 1, 1, 1], [0, 8, 9, 90, 3, 1], [0, 1, 1, 1, 1, 1], [0, 3, 1, 1, 1, 1], [0, 1, 1, 1, 1, 1]]
[2, 5, 1, 2, 1]
[[3, 2, 1, 1, 1, 1], [8, 9, 90, 3, 2, 1], [2, 1, 1, 1, 1, 1], [3, 2, 1, 1, 1, 1], [2, 1, 1, 1, 1, 1]]
[[8, 9, 90], [5, 8, 9], [8, 5, 1], [8, 9, 90], [8, 9, 90]]
[3, 3, 2, 3, 3]


[0.013910597740964967]

In [4]:
#fusion data set

def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE(review): unpickling executes arbitrary code; only load trusted files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# fusion data set: pseudo-labelled and supervised training data
fusion_pseudo_train_set_inputs = _load_pickle('./data_set2/fusion_data_set/train_pseudo_simple_sents.pk')
fusion_pseudo_train_set_input_lens = _load_pickle('./data_set2/fusion_data_set/train_pseudo_simple_sent_lens.pk')
fusion_pseudo_train_set_labels = _load_pickle('./data_set2/fusion_data_set/train_pseudo_labels.pk')
fusion_train_set_inputs_supervised = _load_pickle('./data_set2/fusion_data_set/train_simple_sents_supervised.pk')
fusion_train_set_input_lens_supervised = _load_pickle('./data_set2/fusion_data_set/train_simple_sent_lens_supervised.pk')
fusion_train_set_labels_supervised = _load_pickle('./data_set2/fusion_data_set/train_labels_supervised.pk')

# fusion data set: validation data
fusion_pseudo_valid_set_inputs = _load_pickle('./data_set2/fusion_data_set/validation_simple_sents.pk')
fusion_pseudo_valid_set_input_lens = _load_pickle('./data_set2/fusion_data_set/validation_simple_sent_lens.pk')
fusion_pseudo_valid_set_labels = _load_pickle('./data_set2/fusion_data_set/validation_labels.pk')

# split data set: pseudo-labelled and supervised training data
split_train_set_inputs = _load_pickle('./data_set2/split_data_set/train_complex_sents.pk')
split_train_set_input_lens = _load_pickle('./data_set2/split_data_set/train_complex_sent_lens.pk')
split_pseudo_train_set_labels = _load_pickle('./data_set2/split_data_set/train_pseudo_labels.pk')
split_train_set_inputs_supervised = _load_pickle('./data_set2/split_data_set/train_complex_sents_supervised.pk')
split_train_set_input_lens_supervised = _load_pickle('./data_set2/split_data_set/train_complex_sent_lens_supervised.pk')
split_train_set_labels_supervised = _load_pickle('./data_set2/split_data_set/train_labels_supervised.pk')

# split data set: validation data
split_valid_set_inputs = _load_pickle('./data_set2/split_data_set/validation_complex_sents.pk')
split_valid_set_input_lens = _load_pickle('./data_set2/split_data_set/validation_complex_sent_lens.pk')
split_pseudo_valid_set_labels = _load_pickle('./data_set2/split_data_set/validation_labels.pk')


In [5]:
# Sanity check: within every line the three numbers (inputs, input lengths,
# labels) should be equal, otherwise the pickled lists are misaligned.
# Pseudo-labelled training sets (split, then fusion):
print(len(split_train_set_inputs), len(split_train_set_input_lens), len(split_pseudo_train_set_labels))
print(len(fusion_pseudo_train_set_inputs), len(fusion_pseudo_train_set_input_lens), len(fusion_pseudo_train_set_labels))

# Supervised training sets (split, then fusion):
print(len(split_train_set_inputs_supervised), len(split_train_set_input_lens_supervised), len(split_train_set_labels_supervised))
print(len(fusion_train_set_inputs_supervised), len(fusion_train_set_input_lens_supervised), len(fusion_train_set_labels_supervised))

791956 791956 791956
791956 791956 791956
197988 197988 197988
197988 197988 197988


In [6]:
class Encoder(nn.Module):
    """Bidirectional single-layer LSTM encoder over embedded token sequences.

    Args:
        use_cuda: if truthy, inputs and index tensors are moved to the GPU.
        hidden_dim: hidden size of each LSTM direction.
        input_dim: embedding size (LSTM input size).
        vocab: object whose ``word2token`` length gives the vocabulary size.
        embedding_path: pickle file holding the pre-trained embedding matrix
            copied into ``self.embed``. Defaults to the path that used to be
            hard-coded; pass ``None`` to keep the random initialisation.
    """
    def __init__(self, use_cuda, hidden_dim, input_dim, vocab,
                 embedding_path='data_set/pre_trained_token_embedding.pk'):
        super(Encoder, self).__init__()

        self.use_cuda = use_cuda
        self.input_dim=input_dim
        self.hidden_dim=hidden_dim
        self.vocab = vocab

        self.lstm=torch.nn.LSTM(input_size=self.input_dim,
                                hidden_size= self.hidden_dim,
                                bidirectional=True,
                                batch_first=True
                               )

        # Embedding table; trainable (it is not frozen after loading).
        self.embed=nn.Embedding(len(self.vocab.word2token), input_dim)
        if embedding_path is not None:
            # NOTE(review): unpickling executes arbitrary code; only load trusted files.
            with open(embedding_path, 'rb') as f:
                pre_train_word_embedding = pickle.load(f)
            self.embed.weight.data.copy_(torch.FloatTensor(pre_train_word_embedding))

    def _to_device(self, var):
        """Move ``var`` to the GPU when ``use_cuda`` is set, else return it as is."""
        return var.cuda() if self.use_cuda else var

    def order(self, inputs, inputs_len):    #inputs: tensor, inputs_len: 1D tensor
        """Sort the batch by decreasing length, as required for packing.

        Returns:
            (sorted inputs, sorted lengths, true_order_ids) where indexing a
            sorted batch with ``true_order_ids`` restores the original order.
        """
        inputs_len, sort_ids = torch.sort(inputs_len, dim=0, descending=True)
        inputs = inputs.index_select(0, self._to_device(Variable(sort_ids)))

        # Sorting the permutation itself yields its inverse.
        _, true_order_ids = torch.sort(sort_ids, dim=0, descending=False)

        return inputs, inputs_len, true_order_ids

    def forward(self, inputs, inputs_len):
        """Encode a padded batch of token ids.

        Args:
            inputs: (batch, seq_len) tensor of token ids (batch-first).
            inputs_len: 1-D tensor with the real length of every row.

        Returns:
            outputs: (batch, max(inputs_len), 2*hidden_dim) LSTM outputs,
                zero at padded positions.
            (hn, cn): final hidden and cell states, each (batch, 2*hidden_dim),
                forward and backward direction concatenated.
            All results are in the original (unsorted) batch order.
        """
        inputs = self._to_device(Variable(inputs))

        inputs, sort_len, true_order_ids = self.order(inputs, inputs_len)

        in_vecs=self.embed(inputs)

        # tolist() yields plain Python ints, which every PyTorch version accepts.
        packed = rnn_utils.pack_padded_sequence(input=in_vecs, lengths=sort_len.tolist(), batch_first =True)

        outputs, (hn,cn) = self.lstm(packed)
        outputs, _ = rnn_utils.pad_packed_sequence(outputs, batch_first=True)

        # outputs, hn and cn are ordered by sentence length; restore the
        # caller's batch order with the inverse permutation.
        restore_ids = self._to_device(Variable(true_order_ids))
        outputs = outputs.index_select(0, restore_ids)

        # hn/cn are (2, batch, hidden_dim): index 0 = forward, 1 = backward.
        hn = torch.cat((hn[0], hn[1]), dim=1).index_select(0, restore_ids)
        cn = torch.cat((cn[0], cn[1]), dim=1).index_select(0, restore_ids)

        return outputs, (hn,cn)

In [7]:
def _inflate(tensor, times, dim):
    """
    Examples::
        >> a = torch.LongTensor([[1, 2], [3, 4]])
        >> a
        1   2
        3   4
        [torch.LongTensor of size 2x2]
        >> b = ._inflate(a, 2, dim=1)
        >> b
        1   2   1   2
        3   4   3   4
        [torch.LongTensor of size 2x4]
    """
    repeat_dims = [1] * tensor.dim()
    repeat_dims[dim] = times
    return tensor.repeat(*repeat_dims)

class Decoder(nn.Module):
    def __init__(self, use_cuda, encoder, hidden_dim, max_length=25):
        super(Decoder, self).__init__()
        
        self.use_cuda = use_cuda
        self.hidden_dim=hidden_dim
        self.input_dim = encoder.input_dim
        self.max_length = max_length
        self.vocab = encoder.vocab
        self.weight = [1]*len(self.vocab.word2token)
        self.weight[self.vocab.word2token['<padding>']]=0
        #self.weight[self.vocab.word2token['<eos>']]=1.01
        #self.weight[self.vocab.word2token['<split>']]=1.01
        
        self.hidden_size = self.hidden_dim
        self.V = len(self.vocab.word2token)
        self.SOS = self.vocab.word2token['<sos>']
        self.EOS = self.vocab.word2token['<eos>']
        self.log_softmax = nn.LogSoftmax(dim=1)
        
        self.lstmcell = torch.nn.LSTMCell(input_size=self.input_dim, hidden_size=self.hidden_dim*2, bias=True)
        
        #embedding
        self.embed=encoder.embed# reference share
        #fcnn: projection for crossentroy loss
        self.fcnn = nn.Linear(in_features = self.hidden_dim*2+hidden_dim*2, out_features = len(self.vocab.word2token))
        
        self.softmax = nn.Softmax(dim=1)
        self.cost_func = nn.CrossEntropyLoss(weight=torch.Tensor(self.weight), reduce=False)
        self.nll_loss = nn.NLLLoss(weight=torch.Tensor(self.weight), reduce=False)

        print('init lookup embedding matrix size: ', self.embed.weight.data.size())
        
        #copy
        out_features_dim=self.hidden_dim
        self.attent_wh = nn.Linear(in_features = self.hidden_dim*2, out_features = out_features_dim, bias = 0)
        self.attent_ws = nn.Linear(in_features = self.hidden_dim*2, out_features = out_features_dim, bias = 1)
        self.tanh = nn.Tanh()
        self.attent_vt = nn.Linear(in_features = out_features_dim, out_features = 1, bias=0)
        
        self.prob_wh = nn.Linear(in_features = self.hidden_dim*2, out_features = 1, bias=0)
        self.prob_ws = nn.Linear(in_features = self.hidden_dim*2, out_features = 1, bias=0)
        self.prob_wx = nn.Linear(in_features = input_dim, out_features = 1, bias=1)
        self.sigmoid = nn.Sigmoid()
    
    def get_context_vec(self, enc_outputs, this_timestep_input, dec_state):
        batch_size = enc_outputs.size(dim = 0)
        
        wh = self.attent_wh(enc_outputs)
        ws = self.attent_ws(dec_state).unsqueeze(dim=1)
#         print('wh, ws size: ', wh.size(), ws.size())
        ws = ws.expand(ws.size(0), wh.size(1), ws.size(2))
#         print('ws size: ', ws.size())
        weight = self.attent_vt(self.tanh(wh+ws))
#         print('weight size: ', weight.size())
        weight = self.softmax(weight.squeeze(dim=2))
#         print('weight size: ', weight.size())
        context_v = torch.bmm(weight.unsqueeze(dim=1), enc_outputs)
#         print('context_v size: ', context_v.size())
        context_v = context_v.squeeze(dim=1)
        return context_v, weight
    
    def copy_mechanism(self, enc_outputs, this_timestep_input, dec_state, inputs_one_hot, context_v, weight):
        batch_size = enc_outputs.size(dim = 0)
        
#         wh = self.attent_wh(enc_outputs)
#         ws = self.attent_ws(dec_state).unsqueeze(dim=1)
# #         print('wh, ws size: ', wh.size(), ws.size())
#         ws = ws.expand(ws.size(0), wh.size(1), ws.size(2))
# #         print('ws size: ', ws.size())
#         weight = self.attent_vt(self.tanh(wh+ws))
# #         print('weight size: ', weight.size())
#         weight = self.softmax(weight.squeeze(dim=2))
# #         print('weight size: ', weight.size())
#         context_v = torch.bmm(weight.unsqueeze(dim=1), enc_outputs)
# #         print('context_v size: ', context_v.size())
#         context_v = context_v.squeeze(dim=1)
        
        p_wh = self.prob_wh(context_v)
        p_ws = self.prob_ws(dec_state)
        p_wx = self.prob_wx(this_timestep_input)
        if_copy = self.sigmoid(p_wh+p_ws+p_wx)
#         if_copy = 0.3*if_copy
#         if_copy = self._tocuda(Variable(torch.ones(batch_size, 1), requires_grad=0))
#         print('if_copy size: ', if_copy.size())
        
        prob_copy = torch.bmm(inputs_one_hot, weight.unsqueeze(dim=2))
        prob_copy = prob_copy.squeeze(dim=2)
#         prob_copy = self._tocuda(Variable(torch.rand(batch_size, len(self.vocab.word2token)), requires_grad=0))
#         prob_copy = self.softmax(prob_copy)

#         print('prob_copy size: ', prob_copy.size())
#         print(torch.sum(prob_copy, dim=1))
#         print(torch.mean(if_copy))
        
#         if random.random()<0.005:
#             print('if_copy mean: ', torch.mean(if_copy))
#             _, max_ids = torch.max(prob_copy, dim=1)
#             print(self.vocab.token2word[max_ids.data[0]], self.vocab.token2word[max_ids.data[1]], self.vocab.token2word[max_ids.data[2]])
            
            
        return if_copy, prob_copy

    def forward(self, enc_outputs, sent_lens, h0_and_c0, labels, inputs, teaching_rate=0.6, is_train=1):
        """Greedy decoding for ``max_length`` steps with optional teacher forcing.

        At every step the LSTM cell is advanced, an attention context is
        computed, and the next token is the argmax of either the mixed
        generate/copy distribution (when ``random.random() < copy_thres``, a
        module-level global) or the plain logits.

        Args:
            enc_outputs: encoder outputs, indexed as (batch, src_len, features).
            sent_lens: not used inside this method.
            h0_and_c0: (h0, c0) initial state for the decoder LSTM cell.
            labels: target token ids; column ``ii`` is the target of step
                ``ii``. Wrapped in a Variable and moved to the GPU if
                ``use_cuda``. It is indexed even when ``is_train`` is falsy
                only to be wrapped, never read.
            inputs: source token ids; truncated to ``enc_outputs.size(1)``
                columns and scattered into a one-hot tensor for the copy
                mechanism. Presumably a CPU LongTensor, since it indexes a CPU
                FloatTensor - confirm.
            teaching_rate: during training, probability that a step is fed the
                gold previous token instead of the model's previous prediction.
            is_train: if truthy, the loss is computed and returned.

        Returns:
            If ``is_train``: ``(loss, predicts)`` where ``loss`` is the sum of
            the per-step, per-sample losses divided by ``max_length`` (the
            loss modules use ``reduce=False``, so this is not reduced over the
            batch) and ``predicts`` is a nested list of predicted token ids,
            one row per sample.
            Otherwise: only ``predicts``.
        """
        labels = Variable(labels)
        if self.use_cuda:
            labels = labels.cuda()

        all_loss = 0
        predicts = []
        max_probs=[]  # never used below
        batch_size = enc_outputs.size(dim = 0)
        final_hidden_states = h0_and_c0[0]  # never used below
#         print('enc_outputs size:', enc_outputs.size())

        # One-hot encoding of the source tokens, transposed to
        # (batch, vocab_size, src_len) so that bmm with the attention weights
        # yields a distribution over vocabulary ids.
        sents_len = enc_outputs.size(1)
        inputs = inputs[:,:sents_len].unsqueeze(dim=2)
        one_hot = torch.FloatTensor(batch_size, sents_len, len(self.vocab.word2token)).zero_()
        one_hot.scatter_(2, inputs, 1)
        one_hot = one_hot.transpose(1,2)
        one_hot = self._tocuda(Variable(one_hot, requires_grad = 0))
#         print('one_hot size: ', one_hot.size())

        for ii in range(self.max_length):
            if ii==0:
                # First step: feed '<sos>' and start from the encoder state.
                zero_timestep_input = Variable(torch.LongTensor([self.vocab.word2token['<sos>']]*batch_size))
                if self.use_cuda:
                    zero_timestep_input = zero_timestep_input.cuda()

                zero_timestep_input = self.embed(zero_timestep_input)#size: batch_size * self.input_dim

                last_timestep_hidden_state,cx = self.lstmcell(zero_timestep_input, h0_and_c0)
                #print('last_timestep_hidden_state: ', last_timestep_hidden_state.size(), cx.size())

                #get context vector
                context_vec, weight = self.get_context_vec(enc_outputs=enc_outputs, this_timestep_input=-1,
                                                            dec_state = last_timestep_hidden_state)
                logits = self.fcnn(torch.cat([last_timestep_hidden_state, context_vec], dim=1))

                #copy or not
                copy_control=random.random()
                if copy_control<copy_thres:
                    if_copy, prob_copy = self.copy_mechanism(enc_outputs=enc_outputs, this_timestep_input=zero_timestep_input,
                                                            dec_state = last_timestep_hidden_state, inputs_one_hot = one_hot,
                                                            context_v=context_vec,
                                                            weight = weight)
                    # Mix generation and copy distributions; the clamp keeps
                    # the later torch.log away from log(0).
                    score = (1-if_copy)*self.softmax(logits)+if_copy*prob_copy
                    score = torch.clamp(score, min=10**(-30), max=1)

                #for saving time: no training, no loss calculating
                if is_train:
                    if copy_control<copy_thres:
                        loss = self.nll_loss(torch.log(score), labels[:,0])
                    else:
                        loss = self.cost_func(logits, labels[:,0])
                    all_loss+=loss

                #get predicts
                if copy_control<copy_thres:
                    _, max_idxs = torch.max(score, dim=1)
                else:
                    _, max_idxs = torch.max(logits, dim=1)
                predicts.append(torch.unsqueeze(max_idxs, dim=0))


            else:
                if is_train:
                    # Teacher forcing with probability teaching_rate; the draw
                    # is made once per step for the whole batch.
                    rand = random.random()
                    if rand<teaching_rate:
                        this_timestep_input = self.embed(labels[:,ii-1])#label teaching, lookup embedding
                    else:
                        this_timestep_input = self.embed(max_idxs)#last_timestep output, and then look up word embedding
                else:
                    this_timestep_input = self.embed(max_idxs)#last_timestep output, and then look up word embedding

                last_timestep_hidden_state,cx = self.lstmcell(this_timestep_input, (last_timestep_hidden_state,cx))

                #get context vector
                context_vec, weight = self.get_context_vec(enc_outputs=enc_outputs, this_timestep_input=this_timestep_input,
                                                            dec_state = last_timestep_hidden_state)
                logits = self.fcnn(torch.cat([last_timestep_hidden_state, context_vec], dim=1))

                #copy or not
                copy_control=random.random()
                if copy_control<copy_thres:
                    if_copy, prob_copy = self.copy_mechanism(enc_outputs=enc_outputs, this_timestep_input=this_timestep_input,
                                                            dec_state = last_timestep_hidden_state, inputs_one_hot = one_hot,
                                                             context_v=context_vec,
                                                            weight = weight)
                    score = (1-if_copy)*self.softmax(logits)+if_copy*prob_copy
                    score = torch.clamp(score, min=10**(-30), max=1)

                #for saving time: no training, no loss calculating
                if is_train:
                    if copy_control<copy_thres:
                        loss = self.nll_loss(torch.log(score), labels[:,ii])
                    else:
                        loss = self.cost_func(logits, labels[:,ii])
                    all_loss+=loss

                #get predicts
                if copy_control<copy_thres:
                    _, max_idxs = torch.max(score, dim=1)
                else:
                    _, max_idxs = torch.max(logits, dim=1)
                predicts.append(torch.unsqueeze(max_idxs, dim=0))

        # predicts holds one (1, batch) row per step: stack to
        # (max_length, batch), then transpose to (batch, max_length).
        predicts = torch.cat(predicts, dim=0)
        predicts = torch.transpose(predicts, 0, 1)

        if is_train:  #training
#             all_loss = torch.cat(all_loss, dim=1)
#             all_loss = torch.mean(all_loss, dim=1)
#             loss = torch.mean(all_loss)
            # Average over decoding steps only; '<padding>' targets have loss
            # weight 0 (see __init__), so they add nothing to the sum.
            loss = all_loss/self.max_length

            #print('loss size: ', loss.size())
            #torch.cuda.empty_cache()
            if self.use_cuda:
                return loss, predicts.data.cpu().tolist()
            else:
                return loss, predicts.data.tolist()
        else:   #testing
            if self.use_cuda:
                return predicts.data.cpu().tolist()
            else:
                return predicts.data.tolist()
#         if is_train:  #training
#             if self.use_cuda:
#                 return all_loss/(self.max_length+1), predicts.data.cpu().numpy()
#             else:
#                 return all_loss/(self.max_length+1), predicts.data.numpy()
#         else:   #testing
#             if self.use_cuda:
#                 return predicts.data.cpu().numpy()
#             else:
#                 return predicts.data.numpy()
    
    def decode_topk_seqs(self, encoder, inputs, input_lens, topk=3):
        enc_outputs, (enc_hn, enc_cn) = encoder(inputs, input_lens)
        batch_size = enc_outputs.size(dim = 0)
        
        #one hot of inputs
        sents_len = enc_outputs.size(1)
        inputs = inputs[:,:sents_len].unsqueeze(dim=2)
        one_hot = torch.FloatTensor(batch_size, sents_len, len(self.vocab.word2token)).zero_()
        one_hot.scatter_(2, inputs, 1)
        one_hot = one_hot.transpose(1,2)
        one_hot = self._tocuda(Variable(one_hot, requires_grad = 0))
        
        metadata = self.decode_by_beamsearch(encoder_hidden=(enc_hn, enc_cn), encoder_outputs=enc_outputs, inputs_one_hot=one_hot,topk = topk)
        results = metadata['topk_sequence']
        results =torch.cat(results, dim = 2)
        results=results.view(batch_size*topk, -1)
        if self.use_cuda:
            results = results.data.cpu().tolist()
        else:
            results = results.data.tolist()
#         results=batch_tokens_remove_eos(results, self.vocab)

#         labels = [x for x in labels for ii in range(topk)]
#         labels = batch_tokens_remove_eos(labels, self.vocab)
#         bleu_scores = batch_tokens_bleu(references=labels, candidates=results, smooth_epsilon=0.01)
        
#         bleu_scores = torch.FloatTensor(bleu_scores).view(batch_size, topk)
#         bleu_max, _ = torch.max(bleu_scores, dim=1)
        
#         bleu_mean = torch.mean(bleu_scores, dim=1).unsqueeze(dim=1)
#         bleu_scores = bleu_scores-bleu_mean
#         bleu_scores = bleu_scores.view(-1)
        
#         bleu_scores = self._tocuda(Variable(bleu_scores, requires_grad = 0))
#         log_probs = metadata['score']
#         log_probs = log_probs.view(batch_size*topk)
#         loss = -torch.dot(log_probs, bleu_scores)/batch_size/topk
#         return loss, results, torch.mean(bleu_mean.squeeze()), torch.mean(bleu_max)

        log_probs = metadata['score']
        log_probs = log_probs.view(batch_size*topk)
        
        return results, log_probs
        
        
        
    def _tocuda(self, var):
        if self.use_cuda:
            return var.cuda()
        else:
            return var
    def decode_by_beamsearch(self, encoder_hidden=None, encoder_outputs=None, inputs_one_hot=None, topk = 10):
        """Beam search with the attention + copy decoder.

        Every sample is replicated ``topk`` times (beams of one sample are
        adjacent in the flattened batch). At each of ``max_length`` steps all
        ``topk * V`` continuations of a sample are scored by accumulated
        log-probability and the best ``topk`` are kept.

        Fixes over the previous version:
          * the predecessor beam index is computed with an explicit integer
            division, so it stays a LongTensor on PyTorch versions where ``/``
            is true division (``index_select`` rejects float indices);
          * the emitted symbols are flattened before the embedding lookup, so
            the LSTM cell always receives a 2-D (batch*k, input_dim) input
            instead of a 3-D one after the first step;
          * unused bookkeeping lists and an unreachable non-tuple hidden-state
            branch were removed.

        Args:
            encoder_hidden: (h, c) tuple, each (batch, 2*hidden_dim).
            encoder_outputs: (batch, src_len, 2*hidden_dim) encoder outputs.
            inputs_one_hot: (batch, vocab_size, src_len) one-hot source tokens.
            topk: beam width.

        Returns:
            dict with the values produced by ``_backtrack``: 'score',
            'topk_length', 'topk_sequence', plus 'length' and 'sequence',
            which hold the first entry of every item of 'topk_length' /
            'topk_sequence'.

        Side effects:
            Sets ``self.k`` and ``self.pos_index``.
        """
        self.k = topk
        batch_size = encoder_outputs.size(dim=0)

        # Offset of every sample's first beam in the flattened (batch*k) axis.
        self.pos_index = self._tocuda(Variable(torch.LongTensor(range(batch_size)) * self.k).view(-1, 1))

        # Tile along dim 1 and fold into the batch axis: the k copies of one
        # sample end up in consecutive rows.
        hidden = tuple([_inflate(h, self.k, 1).view(batch_size*self.k, -1) for h in encoder_hidden])

        encoder_outputs = _inflate(encoder_outputs, self.k, 1).view(batch_size*self.k, encoder_outputs.size(1), encoder_outputs.size(2))
        inputs_one_hot = _inflate(inputs_one_hot, self.k, 1).view(batch_size*self.k, inputs_one_hot.size(1), inputs_one_hot.size(2))

        # Initialize the scores; for the first step,
        # ignore the inflated copies to avoid duplicate entries in the top k
        sequence_scores = torch.Tensor(batch_size * self.k, 1)
        sequence_scores.fill_(-float('Inf'))
        sequence_scores.index_fill_(0, torch.LongTensor([i * self.k for i in range(0, batch_size)]), 0.0)
        sequence_scores = self._tocuda(Variable(sequence_scores))

        # Initialize the input vector
        input_var = self._tocuda(Variable(torch.LongTensor([self.SOS] * batch_size * self.k)))

        # Store decisions for backtracking
        stored_scores = list()
        stored_predecessors = list()
        stored_emitted_symbols = list()

        for ii in range(0, self.max_length):
            # Run the RNN one step forward. input_var is (batch*k,) at the
            # first step and (batch*k, 1) afterwards; flatten it so the
            # embedding is always 2-D.
            input_vec = self.embed(input_var.view(-1))

            hidden = self.lstmcell(input_vec, hidden)

            #get context vector
            context_vec, weight = self.get_context_vec(enc_outputs=encoder_outputs, this_timestep_input=-1,
                                                       dec_state = hidden[0])

            logits = self.fcnn(torch.cat([hidden[0], context_vec], dim=1))
            if_copy, prob_copy = self.copy_mechanism(enc_outputs=encoder_outputs, this_timestep_input=input_vec,
                                                     dec_state = hidden[0], inputs_one_hot = inputs_one_hot,
                                                     context_v = context_vec, weight=weight)

            # Mix generation and copy distributions; clamp to keep log finite.
            score = (1-if_copy)*self.softmax(logits)+if_copy*prob_copy
            score = torch.clamp(score, min=10**(-30), max=1)

            # To get the full sequence scores for the new candidates, add the local scores for t_i to the predecessor scores for t_(i-1)
            sequence_scores = _inflate(sequence_scores, self.V, 1)
            sequence_scores += torch.log(score).squeeze(1)
            scores, candidates = sequence_scores.view(batch_size, -1).topk(self.k, dim=1)

            # A candidate index encodes beam * V + symbol.
            symbols = candidates % self.V

            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
            input_var = symbols.view(batch_size * self.k, 1)
            sequence_scores = scores.view(batch_size * self.k, 1)

            # Update fields for next timestep. (candidates - symbols) is an
            # exact multiple of V, so the division is exact whether `/` is
            # integer or true division; .long() restores the index dtype.
            beam_ids = ((candidates - symbols) / self.V).long()
            predecessors = (beam_ids + self.pos_index.expand_as(candidates)).view(batch_size * self.k, 1)
            hidden = tuple([h.index_select(0, predecessors.view(-1)) for h in hidden])

            # Update sequence scores and erase scores for end-of-sentence symbol so that they aren't expanded
            stored_scores.append(sequence_scores.clone())
            eos_indices = input_var.data.eq(self.EOS)
            if eos_indices.nonzero().dim() > 0:
                sequence_scores.data.masked_fill_(eos_indices, -float('inf'))

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            stored_emitted_symbols.append(input_var)

        # Do backtracking to return the optimal values
        output, h_t, h_n, s, l, p = self._backtrack(hidden,
                                                    stored_predecessors, stored_emitted_symbols,
                                                    stored_scores, batch_size, self.hidden_size)

        metadata = {}

        metadata['score'] = s
        metadata['topk_length'] = l
        metadata['topk_sequence'] = p
        metadata['length'] = [seq_len[0] for seq_len in l]
        metadata['sequence'] = [seq[0] for seq in p]

        return metadata

    def _backtrack(self, hidden, predecessors, symbols, scores, b, hidden_size):
        """Backtracks over batch to generate optimal k-sequences.

        Args:
            nw_output [(batch*k, vocab_size)] * sequence_length: A Tensor of outputs from network
            nw_hidden [(num_layers, batch*k, hidden_size)] * sequence_length: A Tensor of hidden states from network
            predecessors [(batch*k)] * sequence_length: A Tensor of predecessors
            symbols [(batch*k)] * sequence_length: A Tensor of predicted tokens
            scores [(batch*k)] * sequence_length: A Tensor containing sequence scores for every token t = [0, ... , seq_len - 1]
            b: Size of the batch
            hidden_size: Size of the hidden state

        Returns:
            output [(batch, k, vocab_size)] * sequence_length: A list of the output probabilities (p_n)
            from the last layer of the RNN, for every n = [0, ... , seq_len - 1]

            h_t [(batch, k, hidden_size)] * sequence_length: A list containing the output features (h_n)
            from the last layer of the RNN, for every n = [0, ... , seq_len - 1]

            h_n(batch, k, hidden_size): A Tensor containing the last hidden state for all top-k sequences.

            score [batch, k]: A list containing the final scores for all top-k sequences

            length [batch, k]: A list specifying the length of each sequence in the top-k candidates

            p (batch, k, sequence_len): A Tensor containing predicted sequence
        """

        lstm = isinstance(hidden, tuple)

        # initialize return variables given different types
        output = list()
        h_t = list()
        p = list()
        # Placeholder for last hidden state of top-k sequences.
        # If a (top-k) sequence ends early in decoding, `h_n` contains
        # its hidden state when it sees EOS.  Otherwise, `h_n` contains
        # the last hidden state of decoding.
        if lstm:
            state_size = hidden[0].size()
            h_n = tuple([torch.zeros(state_size), torch.zeros(state_size)])
        else:
            h_n = torch.zeros(nw_hidden[0].size())
        l = [[self.max_length] * self.k for _ in range(b)]  # Placeholder for lengths of top-k sequences
                                                                # Similar to `h_n`

        # the last step output of the beams are not sorted
        # thus they are sorted here
        sorted_score, sorted_idx = scores[-1].view(b, self.k).topk(self.k)
        # initialize the sequence scores with the sorted last step beam scores
        s = sorted_score.clone()

        batch_eos_found = [0] * b   # the number of EOS found
                                    # in the backward loop below for each batch

        t = self.max_length - 1
        # initialize the back pointer with the sorted order of the last step beams.
        # add self.pos_index for indexing variable with b*k as the first dimension.
        t_predecessors = (sorted_idx + self.pos_index.expand_as(sorted_idx)).view(b * self.k)
        while t >= 0:
            # Re-order the variables with the back pointer
            current_symbol = symbols[t].index_select(0, t_predecessors)
            # Re-order the back pointer of the previous step with the back pointer of
            # the current step
            t_predecessors = predecessors[t].index_select(0, t_predecessors).squeeze()

            # This tricky block handles dropped sequences that see EOS earlier.
            # The basic idea is summarized below:
            #
            #   Terms:
            #       Ended sequences = sequences that see EOS early and dropped
            #       Survived sequences = sequences in the last step of the beams
            #
            #       Although the ended sequences are dropped during decoding,
            #   their generated symbols and complete backtracking information are still
            #   in the backtracking variables.
            #   For each batch, everytime we see an EOS in the backtracking process,
            #       1. If there is survived sequences in the return variables, replace
            #       the one with the lowest survived sequence score with the new ended
            #       sequences
            #       2. Otherwise, replace the ended sequence with the lowest sequence
            #       score with the new ended sequence
            #
            eos_indices = symbols[t].data.squeeze(1).eq(self.EOS).nonzero()
            if eos_indices.dim() > 0:
                for i in range(eos_indices.size(0)-1, -1, -1):
                    # Indices of the EOS symbol for both variables
                    # with b*k as the first dimension, and b, k for
                    # the first two dimensions
                    idx = eos_indices[i]
                    b_idx = int(idx[0] / self.k)
                    # The indices of the replacing position
                    # according to the replacement strategy noted above
                    res_k_idx = self.k - (batch_eos_found[b_idx] % self.k) - 1
                    batch_eos_found[b_idx] += 1
                    res_idx = b_idx * self.k + res_k_idx

                    # Replace the old information in return variables
                    # with the new ended sequence information
                    t_predecessors[res_idx] = predecessors[t][idx[0]]

                    current_symbol[res_idx, :] = symbols[t][idx[0]]
                    s[b_idx, res_k_idx] = scores[t][idx[0]]
                    l[b_idx][res_k_idx] = t + 1

            # record the back tracked results
            p.append(current_symbol)
            t -= 1

        # Sort and re-order again as the added ended sequences may change
        # the order (very unlikely)
        s, re_sorted_idx = s.topk(self.k)
        for b_idx in range(b):
            l[b_idx] = [l[b_idx][k_idx.data[0]] for k_idx in re_sorted_idx[b_idx,:]]

        re_sorted_idx = (re_sorted_idx + self.pos_index.expand_as(re_sorted_idx)).view(b * self.k)

        # Reverse the sequences and re-order at the same time
        # It is reversed because the backtracking happens in reverse time order
#         output = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(output)]
        p = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(p)]
        #    --- fake output ---
        output = None
        #    --- fake ---
        return output, h_t, h_n, s, l, p

    def _mask_symbol_scores(self, score, idx, masking_score=-float('inf')):
            score[idx] = masking_score

    def _mask(self, tensor, idx, dim=0, masking_score=-float('inf')):
        if len(idx.size()) > 0:
            indices = idx[:, 0]
            tensor.index_fill_(dim, indices, masking_score)

In [8]:
class Seq2Seq(nn.Module):
    """Encoder/decoder pair with a supervised forward pass and a reward-weighted training step."""

    def __init__(self, use_cuda, input_dim, hidden_dim, vocab, max_length = 25):
        """Build the encoder and a decoder tied to it, moving both to the GPU when `use_cuda` is truthy."""
        super(Seq2Seq, self).__init__()

        self.use_cuda = use_cuda
        self.enc = Encoder(use_cuda=use_cuda, hidden_dim=hidden_dim, input_dim=input_dim, vocab=vocab)
        self.dec = Decoder(use_cuda=use_cuda, encoder=self.enc, hidden_dim=hidden_dim, max_length=max_length)
        if use_cuda:
            self.enc = self.enc.cuda()
            self.dec = self.dec.cuda()

    def forward(self, inputs, input_lens, labels, is_train=1, teaching_rate=1):
        """Encode `inputs` and run the decoder against `labels`.

        Args:
            inputs: token ids accepted by ``torch.LongTensor``.
            input_lens: length of every input sequence.
            labels: target token ids accepted by ``torch.LongTensor``.
            is_train: truthy to run the decoder in training mode.
            teaching_rate: forwarded to the decoder. The original code ignored
                this argument and always passed 1.

        Returns:
            ``(loss, predicts)`` when `is_train` is truthy, otherwise whatever
            the decoder returns in inference mode.
        """
        enc_outputs, (enc_hn, enc_cn) = self.enc(torch.LongTensor(inputs), torch.LongTensor(input_lens))
        dec_result = self.dec(enc_outputs=enc_outputs,
                              h0_and_c0=(enc_hn, enc_cn),
                              sent_lens=input_lens,
                              labels=torch.LongTensor(labels),
                              is_train=1 if is_train else 0,
                              teaching_rate=teaching_rate,
                              inputs=inputs)
        if is_train:
            loss, predicts = dec_result
            return loss, predicts
        return dec_result

    def tocuda(self, x):
        """Return `x` moved to the GPU when this model was built with `use_cuda`, else `x` unchanged."""
        if self.use_cuda:
            return x.cuda()
        else:
            return x

    def _lm_perplexity(self, sents, language_model):
        """Score `sents` with `language_model` and return its per-sentence perplexities."""
        lm_inputs, lm_input_lens, lm_labels = get_lm_inputs_and_labels(sents, self.enc.vocab, self.dec.max_length)
        return language_model.get_sentences_ppl(torch.LongTensor(lm_inputs),
                                                torch.LongTensor(lm_input_lens),
                                                torch.LongTensor(lm_labels))

    def train_using_reward(self, inputs, input_lens, reconstruct_labels, reconstruct_model, language_model, topk=3, loss_ratio=0.5):
        """Compute a reward-weighted loss over `topk` decoded candidates per input.

        Each candidate is split into two sentences. Two rewards are then
        combined: the negated loss of `reconstruct_model` when it rebuilds
        `reconstruct_labels` from the two sentences, and the mean inverse
        perplexity of the two sentences under `language_model`. Both rewards
        are centred by subtracting the mean over every group of `topk`
        consecutive candidates.

        Args:
            inputs, input_lens: batch handed to ``self.dec.decode_topk_seqs``.
            reconstruct_labels: targets for `reconstruct_model`.
                # assumes one row per decoded candidate (batch * topk) — TODO confirm
            reconstruct_model: model whose ``forward`` returns ``(loss, predicts)``.
            language_model: object exposing ``get_sentences_ppl``.
            topk: number of candidates decoded per input.
            loss_ratio: weight of the reconstruction reward; the language-model
                reward gets ``1 - loss_ratio``.

        Returns:
            ``(loss, reconstruct_loss, rm_reward_mean, lm_reward_mean)`` where
            `reconstruct_loss` is returned exactly as `reconstruct_model`
            produced it and the two means are taken before centring.
        """
        dec_seqs, log_probs = self.dec.decode_topk_seqs(self.enc, inputs, input_lens, topk=topk)
        simple_sent1s, simple_sent2s = seqs_split(dec_seqs, self.enc.vocab)

        simple_sent1s_ppl = self._lm_perplexity(simple_sent1s, language_model)
        simple_sent2s_ppl = self._lm_perplexity(simple_sent2s, language_model)

        simple_inputs, simple_input_lens = simple_sents_concat(simple_sent1s, simple_sent2s, self.enc.vocab, self.dec.max_length)
        # reconstruct labels
        reconstruct_loss, predicts = reconstruct_model.forward(torch.LongTensor(simple_inputs),
                                     torch.LongTensor(simple_input_lens),
                                     labels=reconstruct_labels,
                                     is_train=1, teaching_rate=1)

        # rm_rewards: reconstruct model rewards
        # lm_rewards: language model rewards
        rm_rewards = -reconstruct_loss.data
        lm_rewards = (1/self.tocuda(torch.Tensor(simple_sent1s_ppl))+1/self.tocuda(torch.Tensor(simple_sent2s_ppl)))/2

        # Subtract the per-group mean so each candidate is rewarded relative
        # to the other candidates decoded from the same input.
        rm_rewards_mean = torch.mean(rm_rewards.view(-1, topk), dim=1)
        lm_rewards_mean = torch.mean(lm_rewards.view(-1, topk), dim=1)
        rm_rewards = rm_rewards.view(-1, topk) - rm_rewards_mean.unsqueeze(dim=1)
        lm_rewards = lm_rewards.view(-1, topk) - lm_rewards_mean.unsqueeze(dim=1)

        rm_rewards = rm_rewards.view(-1)
        lm_rewards = lm_rewards.view(-1)

        # sum both rewards up
        rewards = loss_ratio*rm_rewards+(1-loss_ratio)*lm_rewards
        rewards = Variable(rewards, requires_grad=False)

        # regarding rewards as weights of every seq
        loss = -torch.dot(log_probs, rewards)/log_probs.size(0)

        return loss, reconstruct_loss, torch.mean(rm_rewards_mean), torch.mean(lm_rewards_mean)
    
    


In [None]:
# ---- language model (used only for scoring, kept in eval mode) ----
lm_hidden_dim=512
lm_input_dim=300
use_cuda=1

language_model = LanguageModel(use_cuda = use_cuda, input_dim = lm_input_dim, hidden_dim = lm_hidden_dim, vocab = vocab)
#512
model_path = './models_language_model/time-[2019-02-26-13-18-56]-info=[language_model]-loss=4.003012180-bleu=-1.0000-hidden_dim=512-input_dim=300-epoch=24-batch_size=100-batch_id=[1-[of]-9899]-lr=0.0050'
#2048
# model_path = './models_language_model/time-[2019-02-28-07-04-08]-info=[language_model]-loss=3.475848675-bleu=-1.0000-hidden_dim=2048-input_dim=300-epoch=4-batch_size=100-batch_id=[1-[of]-9899]-lr=0.0050'
# #1024
# model_path = './models_language_model/time-[2019-02-27-21-58-23]-info=[language_model]-loss=4.111208439-bleu=-1.0000-hidden_dim=1024-input_dim=300-epoch=6-batch_size=100-batch_id=[1-[of]-9899]-lr=0.0050'

# Load on the CPU first so the checkpoint does not depend on the GPU it was saved from.
pre_train = torch.load(model_path, map_location='cpu')
language_model.load_state_dict(pre_train)

if use_cuda:
    language_model = language_model.cuda()

language_model.eval()

print('finish loading pre-train weight for language model.')


# ---- split / fusion sequence-to-sequence models ----
hidden_dim = 256
input_dim = 100
lr=0.005

split_model = Seq2Seq(use_cuda = use_cuda, input_dim = input_dim, hidden_dim = hidden_dim,
                          vocab = vocab, max_length = 61)

fusion_model = Seq2Seq(use_cuda = use_cuda, input_dim = input_dim, hidden_dim = hidden_dim,
                          vocab = vocab, max_length = 51)
#pre train para
# Exactly one pair of checkpoint paths should be active. The 20per pair used
# to be assigned and then immediately overwritten by the 10per pair below, so
# it is commented out like the other alternatives.
# #20per
# split_model_path = './models_saved/time-[2019-03-24-21-30-26]-info=[pre-trained_split_model-20per]-loss=0.543618917-bleu=0.6642-hidden_dim=256-input_dim=100-epoch=2-batch_size=100-batch_id=[1-[of]-1979]-lr=0.0050'
# fusion_model_path = './models_saved/time-[2019-03-24-21-30-31]-info=[pre-trained_fusion_model-20per]-loss=0.430331796-bleu=0.7581-hidden_dim=256-input_dim=100-epoch=2-batch_size=100-batch_id=[1-[of]-1979]-lr=0.0050'

#10per
split_model_path = './models_saved/time-[2019-04-01-18-26-31]-info=[pre-trained_split_model-10per]-loss=0.421364784-bleu=0.6692-hidden_dim=256-input_dim=100-epoch=3-batch_size=100-batch_id=[501-[of]-989]-lr=0.0050'
fusion_model_path = './models_saved/time-[2019-04-01-18-26-36]-info=[pre-trained_fusion_model-10per]-loss=0.293084979-bleu=0.7202-hidden_dim=256-input_dim=100-epoch=3-batch_size=100-batch_id=[501-[of]-989]-lr=0.0050'

# #5per
# split_model_path = './models_saved/time-[2019-03-25-13-32-25]-info=[pre-trained_split_model-5per]-loss=0.368953973-bleu=0.6889-hidden_dim=256-input_dim=100-epoch=5-batch_size=100-batch_id=[1-[of]-494]-lr=0.0050'
# fusion_model_path = './models_saved/time-[2019-03-25-13-32-30]-info=[pre-trained_fusion_model-5per]-loss=0.342007816-bleu=0.7406-hidden_dim=256-input_dim=100-epoch=5-batch_size=100-batch_id=[1-[of]-494]-lr=0.0050'

# #unsuper
# split_model_path = './models_saved/time-[2019-03-30-13-37-54]-info=[pre-trained_split_model-unsuper]-loss=0.000472637-bleu=0.6470-hidden_dim=256-input_dim=100-epoch=0-batch_size=100-batch_id=[6501-[of]-7919]-lr=0.0050'
# fusion_model_path = './models_saved/time-[2019-03-30-13-37-59]-info=[pre-trained_fusion_model-unsuper]-loss=0.000631156-bleu=0.5244-hidden_dim=256-input_dim=100-epoch=0-batch_size=100-batch_id=[6501-[of]-7919]-lr=0.0050'


pre_train = torch.load(split_model_path, map_location='cpu')
split_model.load_state_dict(pre_train)
pre_train = torch.load(fusion_model_path, map_location='cpu')
fusion_model.load_state_dict(pre_train)

if use_cuda:
    split_model = split_model.cuda()
    fusion_model = fusion_model.cuda()

# One Adam optimizer per model, over the parameters that currently require gradients.
split_optimizer = optim.Adam(filter(lambda p: p.requires_grad, split_model.parameters()), lr=lr)
fusion_optimizer = optim.Adam(filter(lambda p: p.requires_grad, fusion_model.parameters()), lr=lr)

# set_model_grad(fusion_model, False)

finish loading pre-train weight for language model.
init lookup embedding matrix size:  torch.Size([44380, 100])
init lookup embedding matrix size:  torch.Size([44380, 100])


In [None]:
# ---- training configuration ----
epochs = 10000
batch_size = 40      # examples per reward-driven (unlabeled) step
sup_bsize = 30       # examples per supervised step
topk = 2             # candidates decoded per input
loss_ratio = 0.3     # weight of the reconstruction reward
asy_cnt = 40         # period, in batches, of the split/fusion alternation
# other (batch_size, topk) pairs noted by the author: (35, 3) or (14, 6)

# The "/1" divisor is kept as written; presumably a knob for training on a fraction of the data.
split_train_set_size = int(len(split_train_set_inputs) / 1)
dataset_times = int(split_train_set_size / len(split_train_set_inputs_supervised))

train_bleu_mean = -1
train_bleu_max = -1

start_time = time.time()

def model_train(epoch, batch_size, train_set_size):
    """Run one pass over the unlabeled split training set.

    Every iteration performs a supervised step followed by a reward-driven
    step. Which model is updated depends on ``batch_id % asy_cnt``: the split
    model while the remainder is at most ``asy_cnt / 2``, the fusion model
    otherwise. After batches 1, 21, 41, ... a few training samples are printed;
    after batches 1, 81, 161, ... BLEU is measured on a fixed validation slice
    and both models are saved under ``./models_saved/``.

    The function reads and updates notebook globals (models, optimizers, data
    sets, `topk`, `loss_ratio`, `asy_cnt`, `sup_bsize`, `hidden_dim`,
    `input_dim`, `lr`, `vocab`) and returns nothing.

    Args:
        epoch: epoch number, used only in log messages and checkpoint names.
        batch_size: number of unlabeled examples per reward-driven step.
        train_set_size: number of unlabeled examples to iterate over; a final
            partial batch is skipped.
    """
    batch_id = 0
    valid_bleu = 0
    for start_idx in range(0, train_set_size-batch_size+1, batch_size):
#         if batch_id<=1199 and epoch==0:
#             batch_id+=1
#             continue


#         now = int(round(time.time()*1000))
#         time_stamp = time.strftime(' --->  starting time-[%Y-%m-%d-%H-%M-%S]-',time.localtime(now/1000))
#         print(time_stamp)
        
        #supervised learning
        if batch_id%asy_cnt<=int(asy_cnt/2):
            # NOTE(review): set_model_grad presumably toggles requires_grad on the model's parameters — confirm.
            set_model_grad(split_model, True)
            set_model_grad(fusion_model, False)
            split_optimizer.zero_grad()#clear  
            # The supervised window start wraps around within roughly the first half of the supervised set.
            sup_idx = (batch_id*sup_bsize)%(int(len(split_train_set_inputs_supervised)/2)-1-sup_bsize)
            split_loss, predicts = split_model.forward(torch.LongTensor(split_train_set_inputs_supervised[sup_idx:sup_idx+sup_bsize]), 
                                         torch.LongTensor(split_train_set_input_lens_supervised[sup_idx:sup_idx+sup_bsize]), 
                                         labels=torch.LongTensor(split_train_set_labels_supervised[sup_idx:sup_idx+sup_bsize]), 
                                         is_train=1, teaching_rate=1)
            split_loss=torch.mean(split_loss)
            split_loss.backward()#retain_graph=True)
            split_optimizer.step()
        else:
            set_model_grad(fusion_model, True)
            set_model_grad(split_model, False)
            fusion_optimizer.zero_grad()#clear
            # Same window arithmetic as the split branch; it is still sized from the split supervised set.
            sup_idx = (batch_id*sup_bsize)%(int(len(split_train_set_inputs_supervised)/2)-1-sup_bsize)
            fusion_loss, predicts = fusion_model.forward(torch.LongTensor(fusion_train_set_inputs_supervised[sup_idx:sup_idx+sup_bsize]), 
                                         torch.LongTensor(fusion_train_set_input_lens_supervised[sup_idx:sup_idx+sup_bsize]), 
                                         labels=torch.LongTensor(fusion_train_set_labels_supervised[sup_idx:sup_idx+sup_bsize]), 
                                         is_train=1, teaching_rate=1)
            fusion_loss = torch.mean(fusion_loss)
            fusion_loss.backward()#retain_graph=True)
            fusion_optimizer.step()
        
        
        #unsupervised learning
        if batch_id%asy_cnt<=int(asy_cnt/2):
#             a=time.time()
            end_idx = start_idx + batch_size
            split_optimizer.zero_grad()#clear
            # Split-model update: back-propagate the reward-weighted loss.
            total_loss, reconstruct_loss, rm_rewards, lm_rewards=split_model.train_using_reward(inputs=torch.LongTensor(split_train_set_inputs[start_idx:end_idx]), 
                                   input_lens=torch.LongTensor(split_train_set_input_lens[start_idx:end_idx]), 
                                   reconstruct_labels=torch.LongTensor(duplicate_reconstruct_labels(fusion_pseudo_train_set_labels[start_idx:end_idx],topk)), 
                                   reconstruct_model=fusion_model, 
                                   language_model=language_model, 
                                   topk=topk, loss_ratio=loss_ratio)
            # The mean is kept only for logging here.
            reconstruct_loss = torch.mean(reconstruct_loss)
            total_loss.backward()#retain_graph=True)
            split_optimizer.step()
#             print('split: all time: ', time.time()-a)
        else:
#             a=time.time()
            end_idx = start_idx + batch_size
            fusion_optimizer.zero_grad()#clear
            # Fusion-model update: the candidates are still decoded by the split
            # model, but only the scaled reconstruction loss is back-propagated
            # and only the fusion optimizer steps.
            total_loss, reconstruct_loss, rm_rewards, lm_rewards=split_model.train_using_reward(inputs=torch.LongTensor(split_train_set_inputs[start_idx:end_idx]), 
                                   input_lens=torch.LongTensor(split_train_set_input_lens[start_idx:end_idx]), 
                                   reconstruct_labels=torch.LongTensor(duplicate_reconstruct_labels(fusion_pseudo_train_set_labels[start_idx:end_idx],topk)), 
                                   reconstruct_model=fusion_model, 
                                   language_model=language_model, 
                                   topk=topk, loss_ratio=loss_ratio)
            reconstruct_loss = loss_ratio*torch.mean(reconstruct_loss)
            reconstruct_loss.backward()#retain_graph=True)
            fusion_optimizer.step()
#             print('fusion: all time: ', time.time()-a)
        #update batch_id
        batch_id+=1
        #timestamp
#         now = int(round(time.time()*1000))
#         time_stamp = time.strftime('time-[%Y-%m-%d-%H-%M-%S]-',time.localtime(now/1000))
#         print(time_stamp)

#         torch.cuda.empty_cache()
        #
        # Periodic progress report: batch_id has already been incremented, so
        # this fires after batches 1, 21, 41, ...
        if batch_id%20==1:
            split_model.eval()
            fusion_model.eval()
            set_model_grad(split_model, False)
            set_model_grad(fusion_model, False)
            sample_num = 5
            rand_idx = random.randint(0, train_set_size-sample_num-1)
            
            print('--------split model training sampling display--------')
            #teaching forcing
            loss_, predicts = split_model.forward(torch.LongTensor(split_train_set_inputs[rand_idx:rand_idx+sample_num]), 
                                             torch.LongTensor(split_train_set_input_lens[rand_idx:rand_idx+sample_num]), 
                                             labels=torch.LongTensor(split_pseudo_train_set_labels[rand_idx:rand_idx+sample_num]), 
                                             is_train=1, teaching_rate=1)
            del loss_
            
            predicts = batch_tokens_remove_eos(predicts, vocab)
            labels = batch_tokens_remove_eos(split_pseudo_train_set_labels[rand_idx:rand_idx+sample_num], vocab)
            
            predicts = batch_tokens2words(predicts, vocab)
            labels = batch_tokens2words(labels, vocab)
            
            predicts_sents = batch_words2sentence(predicts)
            labels_sents = batch_words2sentence(labels)
            
            # Line "1" is the model prediction, line "2" the pseudo label.
            for (predict_sent, label_sent) in zip(predicts_sents, labels_sents):
                print(' 1----> ', predict_sent)
                print(' 2----> ', label_sent)
                print('\n')
            
            now = int(round(time.time()*1000))
            time_stamp = time.strftime('time-[%Y-%m-%d-%H-%M-%S]-',time.localtime(now/1000))
            info_stamp = 'info=[{:s}]-total_loss={:2.9f}-rec_loss={:2.9f}-lm_rewards={:5.4f}-hidden_dim={:n}-input_dim={:n}-epoch={:n}-batch_size={:n}-batch_id=[{:n}-[of]-{:n}]-lr={:1.4f}'.format(
                              'split_model', total_loss.data[0], reconstruct_loss.data[0], lm_rewards, 
                            hidden_dim, input_dim, epoch, batch_size, batch_id, int(train_set_size/batch_size), lr)
            print(time_stamp, info_stamp)
            
            # Validation and checkpointing: fires after batches 1, 81, 161, ...
            if batch_id%80==1:
                #ground truth
#                 rand_idx=random.randint(0, len(split_valid_set_inputs)-batch_size-1-1)
                # NOTE(review): fixed offset; assumes the validation set holds at least
                # 2333 + batch_size examples, otherwise the slices below come back short or empty.
                rand_idx=2333
                loss_, predicts = split_model.forward(torch.LongTensor(split_valid_set_inputs[rand_idx:rand_idx+batch_size]), 
                                                 torch.LongTensor(split_valid_set_input_lens[rand_idx:rand_idx+batch_size]), 
                                                 labels=torch.LongTensor(split_pseudo_valid_set_labels[rand_idx:rand_idx+batch_size]), 
                                                 is_train=1, teaching_rate=1)
                del loss_
#                 predicts = batch_tokens_remove_eos(predicts, vocab)
#                 labels = batch_tokens_remove_eos(split_pseudo_valid_set_labels[rand_idx:rand_idx+batch_size], vocab)
                
#                 bleu_scores = batch_tokens_bleu(references=labels, candidates=predicts, smooth_epsilon=0.001)
                #split version
                bleu_scores = batch_tokens_bleu_split_version(references=split_pseudo_valid_set_labels[rand_idx:rand_idx+batch_size], 
                                                              candidates=predicts, smooth_epsilon=0.001, vocab=vocab)

                # Mean BLEU of the teacher-forced predictions.
                valid_bleu = 0
                for x in bleu_scores:
                    valid_bleu+=x
                valid_bleu/=len(bleu_scores)
                
                #beam search
                dec_seqs, log_probs = split_model.dec.decode_topk_seqs(split_model.enc, 
                                                                       inputs=torch.LongTensor(split_valid_set_inputs[rand_idx:rand_idx+batch_size]), 
                                                                         input_lens=torch.LongTensor(split_valid_set_input_lens[rand_idx:rand_idx+batch_size]),
                                                                         topk=topk)
                # Keep only the first candidate of every group of `topk` decoded sequences.
                predicts = []
                for ii in range(len(dec_seqs)):
                    if ii%topk==0:
                        predicts.append(dec_seqs[ii])
               
                bleu_scores = batch_tokens_bleu_split_version(references = split_pseudo_valid_set_labels[rand_idx:rand_idx+batch_size],
                                                             candidates = predicts,
                                                             smooth_epsilon=0.001,
                                                             vocab=vocab)
                # Mean BLEU of the decoded (beam-search) predictions.
                valid_bleu_beam_search=0
                for x in bleu_scores:
                    valid_bleu_beam_search+=x
                valid_bleu_beam_search/=len(bleu_scores)


                info_stamp = 'info=[{:s}]-total_loss={:2.9f}-rec_loss={:2.9f}-lm_rewards={:5.4f}-bleu={:1.4f}-bleu_bs={:1.4f}-hidden_dim={:n}-input_dim={:n}-epoch={:n}-batch_size={:n}-batch_id=[{:n}-[of]-{:n}]-lr={:1.4f}-loss_ratio={:1.4f}'.format(
                              'split_model-semi-10per', total_loss.data[0], reconstruct_loss.data[0], lm_rewards, valid_bleu, valid_bleu_beam_search, 
                            hidden_dim, input_dim, epoch, batch_size, batch_id, int(train_set_size/batch_size), lr, loss_ratio)
                
                print(info_stamp, valid_bleu, valid_bleu_beam_search)
                
                # The checkpoint file name is the timestamp followed by the info string.
                now = int(round(time.time()*1000))
                time_stamp = time.strftime('time-[%Y-%m-%d-%H-%M-%S]-',time.localtime(now/1000))
                torch.save(split_model.state_dict(), ''.join(['./models_saved/', time_stamp, info_stamp]))
                torch.save(fusion_model.state_dict(), ''.join(['./models_saved/', time_stamp, 'info=[fusion_model-semi-10per]']))
            # Put both models back into training mode before the next batch.
            set_model_grad(split_model, True)
            set_model_grad(fusion_model, True)
            split_model.train()
            fusion_model.train()
#             torch.cuda.empty_cache()
# Train for the configured number of epochs, then report the wall-clock time.
for epoch in range(epochs):
    model_train(epoch=epoch, batch_size=batch_size, train_set_size=split_train_set_size)

print('running time: %.2f mins' % ((time.time() - start_time) / 60))

--------split model training sampling display--------
 1---->  <low_freq> continued to have the companies perform near each other . he <split> he hoped to reunite the companies , but ultimately was unsuccessful .
 2---->  <low_freq> continued to have the companies perform near each other ; . <split> he hoped to reunite the companies , but ultimately was unsuccessful .


 1---->  huron centre was a federal federal electoral district represented in the canadian house <split> it of commons from and located in the province of ontario .
 2---->  huron centre was a former federal electoral district represented in the canadian . <split> house of commons , and located in the province of ontario .


 1---->  huron county was continued for electoral purposes in 1845 . and the district itself ( which <split> existed existed for judicial and municipal purposes ) was abolished at the beginning of 1850 .
 2---->  huron county was continued for electoral purposes in 1845 , and the district itself ( .

--------split model training sampling display--------
 1---->  it 's rumored that this team will start a series of their own , the grand prix ( world ) <split> it championship a series that will be ruled by <low_freq> ( the teams ) and not by fia .
 2---->  it 's rumored that this team will start a series of their own , the grand prix ( world . <split> ) championship a series that will be ruled by <low_freq> ( the teams ) and not by fia .


 1---->  it 's said that <low_freq> '' goddess of the earth '' <low_freq> the women on the islands first to stop el from ruling it all <split> that , that 's why the <low_freq> is ruled by women , leadership passes from grandmother to mother and from mother to daughter .
 2---->  it 's said that <low_freq> '' goddess of the earth '' <low_freq> the women on the islands first to stop el from ruling it . <split> all , that 's why the <low_freq> is ruled by women , leadership passes from grandmother to mother and from mother to daughter .


 1---->  it 

--------split model training sampling display--------
 1---->  <low_freq> mansion , formerly known as lamont mansion , is a now derelict mansion - house <split> the on the shores of west bay lake in <low_freq> county , north east wisconsin .
 2---->  <low_freq> mansion , formerly known as lamont mansion , is a now derelict mansion - . <split> house on the shores of west bay lake in <low_freq> county , north east wisconsin .


 1---->  summit , a major station is next with two high <split> the - and the station building above the tracks .
 2---->  summit , a major station is next with two . <split> high platforms and the station building above the tracks .


 1---->  summit airport is a public airport in located summit , <low_freq> - <low_freq> borough , <split> the it is about six miles south - southwest of cantwell , alaska ,
 2---->  summit airport is a public airport in located summit , <low_freq> - <low_freq> borough . <split> , alaska , about six miles south - southwest of cantwel

--------split model training sampling display--------
 1---->  this movie stars fresh actors y and y and the well - trained <low_freq> p. p '' before the film <split> the , sir alongside director - sir tra - sir tra will be a massive breakthrough in my career .
 2---->  this movie stars fresh actors y and y and the well - trained <low_freq> p. p stated before the . <split> film ' acting alongside director - sir tra - sir , will be a massive breakthrough in my career .


 1---->  this movie takes place in rural wales with a large cast of welsh characters . but <split> it was actually filmed in hollywood with canadian , american , irish and english actors .
 2---->  this movie takes place in rural wales with a large cast of welsh characters , . <split> but was actually filmed in hollywood with canadian , american , irish and english actors .


 1---->  this movie touches on the difficulties of marriage , however , few critics considered <split> few '' <low_freq> <low_freq> '' a daring or

--------split model training sampling display--------
 1---->  in the parking lot , ray fights tony until the police arrive . and he <split> he tries to convince them tony did the robbery and set him up .
 2---->  in the parking lot , ray fights tony until the police arrive , and . <split> he tries to convince them tony did the robbery and set him up .


 1---->  in the parliament , clinton used to comment , proudly , about johnson 's indian management . whereas , once <split> once , he was <low_freq> by <low_freq> , about the indian mistresses , who were <low_freq> by johnson .
 2---->  in the parliament , clinton used to comment , proudly , about johnson 's indian management , whereas , . <split> once , he was <low_freq> by <low_freq> , about the indian mistresses , who were had by johnson .


 1---->  in the parliament , which sat first on business , 12 november 1747 , he was chosen for <low_freq> . and died <split> he member for lyme regis , 28 may 1757 , aged fifty nine , was buri

--------split model training sampling display--------
 1---->  a person throws a rolling pin at them . <split> they they start rolling on the wires .
 2---->  a person throws a rolling pin at them . <split> then they start rolling on the wires .


 1---->  a person under the age of 21 consuming non-alcoholic beer is subject to survive for <split> for arrest drinking <low_freq> are more restrictive than those of most other states .
 2---->  a person under the age of 21 consuming non-alcoholic beer is subject to arrest . <split> for underage drinking <low_freq> are more restrictive than those of most other states .


 1---->  a person viewing by cable or satellite may not know what kind of organization is responsible for the program , especially <split> especially if it is syndicated , so what seems to be a station or a network or be neither .
 2---->  a person viewing by cable or satellite may not know what kind of organization is responsible for the program , . <split> especially if it

--------split model training sampling display--------
 1---->  <low_freq> was educated as a cadet and joined the prussian army in 1737 . because of <split> because his small body size he was deployed at <low_freq> for garrison service only .
 2---->  <low_freq> was educated as a cadet and joined the prussian army in 1737 , because . <split> of his small body size he was deployed at <low_freq> for garrison service only .


 1---->  <low_freq> loses her final appeal , and is apparently at peace with her impending death . even enlisting the help of tim mcmanus <split> hanlon , the tells her that hanlon has been murdered and promises her that he will find out who did it .
 2---->  <low_freq> loses her final appeal , and is apparently at peace with her impending death , even enlisting the help of tim . <split> mcmanus , who tells her that hanlon has been murdered and promises her that he will find out who did it .


 1---->  <low_freq> then drove in the go - ahead run with a double in the t

--------split model training sampling display--------
 1---->  on the 4th of december 2013 , the romanian ministry for european funds sent to the european commission a request to the the financing of this <split> for was to for a total of 210 million euro ( of which 142 million euro from the <low_freq> mechanism the <low_freq> mechanism ) .
 2---->  on the 4th of december 2013 , the romanian ministry for european funds sent to the european commission a request to evaluate the financing of . <split> this sector , for a total of 210 million euro ( of which 142 million euro from the eu through the <low_freq> mechanism ) .


 1---->  on the 4th of november <low_freq> <low_freq> won the silver medal for pakistan in the miss earth 2012 trivia <split> she challenge and she received an award for press favorite and won the title for miss <low_freq> .
 2---->  on the 4th of november <low_freq> <low_freq> won the silver medal for pakistan in the miss earth 2012 . <split> trivia challenge and she 

--------split model training sampling display--------
 1---->  the following year she appeared in '' the jerk '' as the mother to steve martin 's <low_freq> received mostly guest spots on tv shows including <split> she '' fantasy island '' , '' the <low_freq> '' , '' amazing stories '' and '' tales from the <low_freq> '' among others .
 2---->  the following year she appeared in '' the jerk '' as the mother to steve martin 's <low_freq> received mostly guest spots on tv shows . <split> including '' fantasy island '' , '' the <low_freq> '' , '' amazing stories '' and '' tales from the <low_freq> '' among others .


 1---->  the following year she joined the ziegfeld follies . <split> she she was also in earl carroll <low_freq> .
 2---->  the following year she joined the ziegfeld follies . <split> , and was also in earl carroll <low_freq> .


 1---->  the following year she left her long - time producer mario <low_freq> and recorded with <low_freq> <low_freq> her masterpiece <split> she

--------split model training sampling display--------
 1---->  bedford is home of the national d - day memorial . it was selected for that honor . <split> it it lost more residents per capita in the normandy landings community any other american community .
 2---->  bedford is home of the national d - day memorial ; it was selected for that honor . <split> because it lost more residents per capita in the normandy landings than any other american community .


 1---->  bedford is served by bedfordshire police the police and crime <split> crime crime force that force is <low_freq> martins . <split> crime crime martins martins of of of of of of martins martins martins martins martins martins martins martins . martins martins martins
 2---->  bedford is served by bedfordshire police the police and . <split> crime commissioner of that force is <low_freq> martins .


 1---->  bedford is the county town of bedfordshire in southern england . <split> it it the main settlement in the borough of 

--------split model training sampling display--------
 1---->  <low_freq> is a dialect of the shona language . . by the <low_freq> cultures <split> it are the eastern part of zimbabwe and across the border in mozambique .
 2---->  <low_freq> is a dialect of the shona language largely spoken by the <low_freq> . <split> people in the eastern part of zimbabwe and across the border in mozambique .


 1---->  <low_freq> entered the police force in san sebastián in 1941 . eventually becoming commander of the '' <split> he <low_freq> <low_freq> - social '' ( the <low_freq> political police division ) in san sebastián .
 2---->  <low_freq> entered the police force in san sebastián in 1941 , eventually becoming commander of the . <split> '' <low_freq> <low_freq> - social '' ( the <low_freq> political police division ) in san sebastián .


 1---->  manzanillo is the state 's primary port and tourist destination . it on the <split> it tourist coast less than two hour 's drive from the capital .
 

--------split model training sampling display--------
 1---->  the east village is a neighborhood in the borough of manhattan in new york city in is east of <split> it manhattan village , south of gramercy and stuyvesant town , south north of the lower east side .
 2---->  the east village is a neighborhood in the borough of manhattan in new york city which lies east . <split> of greenwich village , south of gramercy and stuyvesant town , and north of the lower east side .


 1---->  the east west line ( <low_freq> ) is a high capacity mrt line . <split> the the second mass rapid transit line to be built in singapore .
 2---->  the east west line ( <low_freq> ) is a high capacity mrt line . <split> and the second mass rapid transit line to be built in singapore .


 1---->  the east west line was the second line to be opened , which is from city hall to <low_freq> <split> it extended extended extended to clementi on 12 march 1988 and subsequently to lakeside on 5 november 1988 .
 2----

--------split model training sampling display--------
 1---->  the white canons followed a code of austerity similar to that of cistercian <low_freq> . unlike <split> unlike monks of other orders , they were exempt from the strict episcopal discipline .
 2---->  the white canons followed a code of austerity similar to that of cistercian <low_freq> , . <split> unlike monks of other orders , they were exempt from the strict episcopal discipline .


 1---->  the white city company opened the <split> the white city opened gardens .
 2---->  the white city company opened . <split> the white city pleasure gardens .


 1---->  the white city at shepherds bush now occupies the old site of the wood lane brickworks . from <split> from which more than a million stock bricks were supplied for the foundations of the albert memorial .
 2---->  the white city at shepherds bush now occupies the old site of the wood lane brickworks , . <split> from which more than a million stock bricks were supplied f

--------split model training sampling display--------
 1---->  <low_freq> is a village 16 miles north - north - north of derby , <split> it it <low_freq> and ashbourne , and has a population of about 580 .
 2---->  <low_freq> is a village 16 miles north - north - west of derby . <split> , between <low_freq> and ashbourne , and has a population of about 580 .


 1---->  <low_freq> bald is partly in towns county , in union county . and in part of <split> it the blue ridge mountains and the <low_freq> ranger district of the chattahoochee national forest .
 2---->  <low_freq> bald is partly in towns county and in union county , and is part . <split> of the blue ridge mountains and the <low_freq> ranger district of the chattahoochee national forest .


 1---->  <low_freq> bald is partly in towns county and in union county <split> the , <low_freq> the peak is in towns county .
 2---->  <low_freq> bald is partly in towns county and in union . <split> county , but the peak is in towns county .

--------split model training sampling display--------
 1---->  there were approximately 17,000 at - risk vietnamese on embassy rolls . which using an average of seven dependents per family meant that <split> there the number requiring evacuation was <low_freq> and taken with other categories of vietnamese the number quickly increased to over 200,000 .
 2---->  there were approximately 17,000 at - risk vietnamese on embassy rolls , which using an average of seven dependents per family meant . <split> that the number requiring evacuation was <low_freq> and taken with other categories of vietnamese the number quickly increased to over 200,000 .


 1---->  there were approximately 20 - 60 worker 's guards in finland between 31 august and 30 september 1917 . but on 20 october 1917 30 <split> after 20 defeat in the october parliamentary elections , the finnish labour union proclaimed the need to establish worker 's guards in the country .
 2---->  there were approximately 20 - 60 worker 's g

--------split model training sampling display--------
 1---->  <low_freq> went on to win '' best asia - pacific vocalist '' in the '' original chinese music <split> in '' '' held in china in 2010 with her album '' back to basics <low_freq> '' .
 2---->  <low_freq> went on to win '' best asia - pacific vocalist '' in the '' original chinese . <split> music awards '' held in china in 2010 with her album '' back to basics <low_freq> '' .


 1---->  <low_freq> was a centre forward who did not play league football until he was 23 <split> he , old , he made 277 league appearances and scored 100 goals .
 2---->  <low_freq> was a centre forward who did not play league football until he was . <split> 23 years old , he made 277 league appearances and scored 100 goals .


 1---->  <low_freq> chen ( born 28 august 1961 ) is the anchor of al jazeera america <split> he is flagship evening news show america tonight , which launched august 2013 .
 2---->  <low_freq> chen ( born 28 august 1961 ) is the

--------split model training sampling display--------
 1---->  brandenburg appeared on the group 's second album '' faithful friends '' , which had higher production values <split> it and their first and contained many of the songs that were part of the live act .
 2---->  brandenburg appeared on the group 's second album '' faithful friends '' , which had higher production . <split> values than their first and contained many of the songs that were part of their live act .


 1---->  brandenburg is a fellow of the audio engineering society ( aes ) along with herr <low_freq> and josh <low_freq> . head <split> the of the aes standards committee working group sc - 06 - 04 '' internet audio delivery systems '' .
 2---->  brandenburg is a fellow of the audio engineering society ( aes ) along with herr <low_freq> and josh <low_freq> and . <split> head of the aes standards committee working group sc - 06 - 04 '' internet audio delivery systems '' .


 1---->  <low_freq> is a village and civil

--------split model training sampling display--------
 1---->  some say a cross between missy elliot and a female version of eminem , making her <split> her stand out among her female counter parts for creativity , versatility and personalized flow .
 2---->  some say a cross between missy elliot and a female version of eminem , making . <split> her stand out among her female counter parts for creativity , versatility and personalized flow .


 1---->  some say a local man named <low_freq> woodard painted the mule . other residents contend <split> other that it was painted as an advertisement of a local stock farm .
 2---->  some say a local man named <low_freq> woodard painted the mule ; other residents . <split> contend that it was painted as an advertisement of a local stock farm .


 1---->  some say he could not handle the fame . <split> others others hint at a more sinister reason .
 2---->  some say he could not handle the fame . <split> , others hint at a more sinister reason .

--------split model training sampling display--------
 1---->  sabine agrees to take him with them . assuming that stepan is bean <split> stepan stepan son , while stepan thinks sabine is bean 's fiancée .
 2---->  sabine agrees to take him with them , assuming that stepan is . <split> bean 's son , while stepan thinks sabine is bean 's fiancée .


 1---->  sabine was married for seven years to frederic jules robert <low_freq> ( born 1946 - died 1991 ) . german <split> he citizen , an entrepreneur , former financial banker / broker and owner of the <low_freq> bank in germany .
 2---->  sabine was married for seven years to frederic jules robert <low_freq> ( born 1946 - died 1991 ) a . <split> german citizen , an entrepreneur , former financial banker / broker and owner of the <low_freq> bank in germany .


 1---->  <low_freq> & silver argued that <low_freq> and <low_freq> are the two key features of <low_freq> . putting a task <split> putting , is not <low_freq> , they argue , if there

--------split model training sampling display--------
 1---->  <low_freq> desmond ( born october 7 , 1977 ) is a new zealand actress . played the <split> she actress discord on three television shows ' , ' , ' '' young hercules '' .
 2---->  <low_freq> desmond ( born october 7 , 1977 ) is a new zealand actress who played . <split> the goddess discord on three television shows ' , ' , and '' young hercules '' .


 1---->  <low_freq> ( pronunciation : '' <low_freq> '' - rhymes with '' eggs '' ) grew up on a farm near <low_freq> <split> he grew where , where he was more interested in the mechanical devices used to raise crops than actually farming .
 2---->  <low_freq> ( pronunciation : '' <low_freq> '' - rhymes with '' eggs '' ) grew up on a farm near . <split> <low_freq> , illinois , where he was more interested in the mechanical devices used to raise crops than actually farming .


 1---->  <low_freq> , a georgian who had served under lee in the u.s. army and who hated his fellow south

--------split model training sampling display--------
 1---->  single chimneys were initially fitted to <low_freq> -- <low_freq> when built but <split> they they were replaced with double chimneys between 1939 and 1944 .
 2---->  single chimneys were initially fitted to <low_freq> -- <low_freq> when built . <split> but they were replaced with double chimneys between 1939 and 1944 .


 1---->  single combat is a duel between two single warriors which takes place in the context of a battle <split> the war two armies , with the two often considered the champions of their respective sides .
 2---->  single combat is a duel between two single warriors which takes place in the context of a . <split> battle between two armies , with the two often considered the champions of their respective sides .


 1---->  single cylinder development began in 1937 under project engineer harry wood <split> using using a test unit designed by sir harry ricardo .
 2---->  single cylinder development began in 

--------split model training sampling display--------
 1---->  it is in the parish of <low_freq> and in the barony of few lower . is the <split> it after the fortified house with defended courtyard that was built by john hamilton in 1619 .
 2---->  it is in the parish of <low_freq> and in the barony of few lower and is . <split> named after the fortified house with defended courtyard that was built by john hamilton in 1619 .


 1---->  it is in the parish of nesting , the parish <split> it was is to the east near <low_freq> .
 2---->  it is in the parish of nesting , the . <split> parish church lying to the east near <low_freq> .


 1---->  it is in the plateau region . <split> it has an agricultural topography .
 2---->  it is in the plateau region . <split> and has an agricultural topography .


 1---->  it is in the pursuit of the never - ending goal of reaching and achieving fa - class that articles on <split> it wikipedia are rated on an assessment scale ; this scale serves as an 

--------split model training sampling display--------
 1---->  it was settled initially initially migrant workers who had come to work on the airport in 1991 . <split> he who stayed behind and reclaimed a piece of airport land that was marshy and otherwise unusable .
 2---->  it was settled initially by migrant workers who had come to work on the airport in 1991 . <split> , who stayed behind and reclaimed a piece of airport land that was marshy and otherwise unusable .


 1---->  it was settled mainly by the irish and was named after the <split> it was a seaport town in county derry , ireland .
 2---->  it was settled mainly by the irish and was named after . <split> coleraine , a seaport town in county derry , ireland .


 1---->  it was several more years before a functional analysis of the vertebrate '' slit '' and '' robo '' . <split> it , performed , and demonstrated that slit robo robo signaling regulates <low_freq> axon guidance in vertebrates as well . in it guidance guidance g

--------split model training sampling display--------
 1---->  however , they also announced that over one - hundred landmarks in l.a. were featured in the game , <split> in in their exact geographical locations , such as the los angeles convention center and the staples center .
 2---->  however , they also announced that over one - hundred landmarks in l.a. were featured in the game . <split> , in their exact geographical locations , such as the los angeles convention center and the staples center .


 1---->  however , they also found that many beers have extremely high levels of gluten <split> if so , if unsure , <low_freq> are advised to avoid beer .
 2---->  however , they also found that many beers have extremely high levels of . <split> gluten so , if unsure , <low_freq> are advised to avoid beer .


 1---->  however , they also have their disadvantages . as they cool they are harder <split> they to bend so one has to work in margins of time .
 2---->  however , they also have 

--------split model training sampling display--------
 1---->  south of the church is <low_freq> manor house . this jacobean house <split> this is built of brick with a blue brick <low_freq> pattern .
 2---->  south of the church is <low_freq> manor house , this jacobean . <split> house is built of brick with a blue brick <low_freq> pattern .


 1---->  south of the city centre lies a park around the former <low_freq> estate of the graf <split> the centre estate , of which only the foundation walls and some yard buildings remain .
 2---->  south of the city centre lies a park around the former <low_freq> estate of the . <split> graf <low_freq> family , of which only the foundation walls and some yard buildings remain .


 1---->  south of the city centre lies a park around the former schloss / manor <split> it is of which only the foundation walls and some yard buildings remain .
 2---->  south of the city centre lies a park around the former schloss / . <split> manor <low_freq> of whi

--------split model training sampling display--------
 1---->  douglas thomas ( born 1966 ) is associate professor at the annenberg school for communication at <split> he the university of southern california , he studies technology , communication , and culture .
 2---->  douglas thomas ( born 1966 ) is associate professor at the annenberg school for communication . <split> at the university of southern california where he studies technology , communication , and culture .


 1---->  douglas tilden ( may 1 , 1861 to august 5 , 1935 ) was a world - famous deaf sculptor . <split> he went to the california school for the deaf in berkeley , california ( now in fremont , california ) .
 2---->  douglas tilden ( may 1 , 1861 to august 5 , 1935 ) was a world - famous deaf sculptor . <split> who went to the california school for the deaf in berkeley , california ( now in fremont , california ) .


 1---->  douglas van <low_freq> ( 1901 -- 1995 ) was an american quaker <low_freq> . was born on

--------split model training sampling display--------
 1---->  it is the third sm <low_freq> in the province of cavite . was opened on <split> it november 2005 , a land area of and a floor area of and
 2---->  it is the third sm <low_freq> in the province of cavite and was opened . <split> on november 2005 with a land area of and a floor area of .


 1---->  it is the third sm <low_freq> in the province of pampanga after sm city pampanga and sm city <split> the , and also the second sm <low_freq> located at city of san fernando , pampanga .
 2---->  it is the third sm <low_freq> in the province of pampanga after sm city pampanga and sm . <split> city clark and also the second sm <low_freq> located at city of san fernando , pampanga .


 1---->  it is the third and final installment of the '' castlevania '' series on the game boy advance of was <split> it released in north america on may 6 , 2003 and released in japan on may 6 , 2003 .
 2---->  it is the third and final installment of t

--------split model training sampling display--------
 1---->  depicted in the film as the figurehead of shocker , he is a mysterious elderly man <split> he is vampire - like fangs and dons a long black and red cape .
 2---->  depicted in the film as the figurehead of shocker , he is a mysterious elderly . <split> man with vampire - like fangs and dons a long black and red cape .


 1---->  depicted often as a girl or boy , but was as a woman <split> sometimes or a man , s / he carries a woman , .
 2---->  depicted often as a girl or boy , but sometimes as a . <split> woman or a man , s / he carries a flower basket .


 1---->  depicting a family that owns a japanese confectionery <split> owns owns owns lead character named tsubasa .
 2---->  depicting a family that owns a japanese . <split> confectionery store , lead character named tsubasa .


 1---->  depicting color was of great importance to vincent . in letters to his brother , <split> he he he he often described objects in his p

--------split model training sampling display--------
 1---->  however , abrams was never able to demonstrate that his devices were effective ; no <low_freq> device has been found effective in the diagnosis or <split> the the of any disease , no the united states food and drug administration does not recognize any legitimate medical uses for such devices .
 2---->  however , abrams was never able to demonstrate that his devices were effective ; no <low_freq> device has been found effective in the diagnosis . <split> or treatment of any disease , and the united states food and drug administration does not recognize any legitimate medical uses for such devices .


 1---->  however , abu al - hassan went too far in attempting to impose more authority over the <split> however arab tribes went who revolted and in april 1348 defeated his army near <low_freq> .
 2---->  however , abu al - hassan went too far in attempting to impose more authority over . <split> the arab tribes , who revolted 

--------split model training sampling display--------
 1---->  estimates range between 1,600 and 3,000 different police forces in total are in <split> there operation , there there are over 350,000 police agents in mexico .
 2---->  estimates range between 1,600 and 3,000 different police forces in total are . <split> in operation , and there are over 350,000 police agents in mexico .


 1---->  estimates range from 30,000 to 55,000 killed between spring and autumn 1945 '' , mostly prisoners of war repatriated by <split> mostly the british military authorities from austria , they they had fled , where in these post-war summary executions .
 2---->  estimates range from 30,000 to 55,000 killed between spring and autumn 1945 '' , mostly prisoners of war repatriated . <split> by the british military authorities from austria , where they had fled , died in these post-war summary executions .


 1---->  estimates vary , but there are perhaps <low_freq> speakers . though counting is difficul

--------split model training sampling display--------
 1---->  gene burns ( born december 3 , 1940 ) is an american talk radio host who is <split> he programs broadcast from the studios of <low_freq> ( am <low_freq> ) in san francisco .
 2---->  gene burns ( born december 3 , 1940 ) is an american talk radio host who . <split> two programs broadcast from the studios of <low_freq> ( am <low_freq> ) in san francisco .


 1---->  gene c. mckinney one of five mckinney brothers born in monticello , <split> he all , all of whom served in the army .
 2---->  gene c. mckinney one of five mckinney brothers born in monticello . <split> , florida , all of whom served in the army .


 1---->  gene <low_freq> is an american jazz drummer , born in new york , but currently residing in the united kingdom . he <split> he is a visiting tutor at the birmingham conservatoire , the royal academy of music , trinity and the guildhall .
 2---->  gene <low_freq> is an american jazz drummer , born in new york ,

--------split model training sampling display--------
 1---->  the first life cube , an eight - foot square , was planned in 2010 and erected in 2011 <split> the the and the artist was motivated to build another version of the life cube the following year .
 2---->  the first life cube , an eight - foot square , was planned in 2010 and erected in . <split> 2011 , and the artist was motivated to build another version of the life cube the following year .


 1---->  the first limited overs international was played in 1971 and the governing international cricket council ( icc ) <split> ) , seeing its potential , staged the first limited of cricket world cup in 1975 .
 2---->  the first limited overs international was played in 1971 and the governing international cricket council ( icc . <split> ) , seeing its potential , staged the first limited overs cricket world cup in 1975 .


 1---->  the first <low_freq> convention was signed in <low_freq> , togo in 1975 . arose out of europe 's wis

--------split model training sampling display--------
 1---->  it is now on display at somerset <low_freq> london , <split> voice voice over narration by actor tony <low_freq> .
 2---->  it is now on display at somerset <low_freq> london . <split> , voice over narration by actor tony <low_freq> .


 1---->  it is now one of the best known cultural destinations in south korea . and viewing of <split> viewing the sunrise over the sea of japan ( east sea ) is especially popular .
 2---->  it is now one of the best known cultural destinations in south korea , and viewing . <split> of the sunrise over the sea of japan ( east sea ) is especially popular .


 1---->  it is now operating at the san marcos national fair in <low_freq> park also began showing '' stargate <low_freq> '' in <split> it its early 3d turbo theater , it finally received new branding and was no longer called '' <low_freq> '' .
 2---->  it is now operating at the san marcos national fair in <low_freq> park also began show

--------split model training sampling display--------
 1---->  a phoenix working for <low_freq> , found frozen on <low_freq> <split> <low_freq> found found searching for the shield card .
 2---->  a phoenix working for <low_freq> , found frozen on . <split> <low_freq> mountain while searching for the shield card .


 1---->  a phone with either no or limited internet <split> ability internet or ability to run apps .
 2---->  a phone with either no or limited . <split> internet capabilities or ability to run apps .


 1---->  a <low_freq> is a quantum mechanical description of a special type of vibrational motion . known as normal <split> known modes in classical mechanics , in which a lattice uniformly <low_freq> at the same frequency .
 2---->  a <low_freq> is a quantum mechanical description of a special type of vibrational motion , known as . <split> normal modes in classical mechanics , in which a lattice uniformly <low_freq> at the same frequency .


 1---->  a photo - essay , '' 

--------split model training sampling display--------
 1---->  some releases have '' scottish air & <low_freq> '' as the final track while others have '' <split> others i spent my last $ 10 ( on birth control & beer ) '' .
 2---->  some releases have '' scottish air & <low_freq> '' as the final track while others have . <split> '' i spent my last $ 10 ( on birth control & beer ) '' .


 1---->  some religions are based completely on the use of certain drugs , such as <low_freq> . which are <split> which mostly <low_freq> , being either <low_freq> or <low_freq> , but some are also stimulants and <low_freq> .
 2---->  some religions are based completely on the use of certain drugs , such as <low_freq> , which . <split> are mostly <low_freq> , being either <low_freq> or <low_freq> , but some are also stimulants and <low_freq> .


 1---->  some religious communities require people to remove shoes before they enter holy of the <split> they craft of the are called <low_freq> , <low_freq> , o

In [None]:
# NOTE(review): `stop` is not defined in the visible part of this notebook, so
# evaluating it presumably raises NameError. This looks like a deliberate guard
# that halts "Run All" at this point — confirm before removing it.
stop

In [None]:
# Number of validation examples to decode, and candidates requested per example.
sample_num = 2
topk = 4

# Top-k decoding over the first `sample_num` validation inputs.
predicts, log_probs = split_model.dec.decode_topk_seqs(
    split_model.enc,
    inputs=torch.LongTensor(split_valid_set_inputs[:sample_num]),
    input_lens=torch.LongTensor(split_valid_set_input_lens[:sample_num]),
    topk=topk,
)

# Token ids -> cut at the first <eos> -> vocabulary words.
predicts = batch_tokens2words(batch_tokens_remove_eos(predicts, vocab), vocab)
labels = batch_tokens2words(
    batch_tokens_remove_eos(split_pseudo_valid_set_labels[:sample_num], vocab),
    vocab,
)

# Words -> space-joined sentences for display.
predicts_sents = batch_words2sentence(predicts)
labels_sents = batch_words2sentence(labels)

# Every candidate is printed as ' 1----> '; once the last candidate of a group of
# `topk` has been shown, the matching reference is printed as ' 2----> '.
# NOTE(review): this assumes `predicts` is a flat list holding `topk` consecutive
# candidates per input — confirm against decode_topk_seqs.
for idx, sent in enumerate(predicts_sents):
    print(' 1----> ', sent)
    sample_idx, rank = divmod(idx, topk)
    if rank == topk - 1:
        print(' 2----> ', labels_sents[sample_idx])
        print('\n')

In [None]:
# Disabled cell (kept commented out for reference): calls split_model.forward on
# the first `sample_num` training examples with is_train=1 / teaching_rate=1, then
# prints each prediction (' 1----> ') next to its pseudo label (' 2----> ').
# copy_thres=1.0
# split_loss, predicts = split_model.forward(torch.LongTensor(split_train_set_inputs[0:sample_num]), 
#                                      torch.LongTensor(split_train_set_input_lens[0:sample_num]), 
#                                      labels=torch.LongTensor(split_pseudo_train_set_labels[0:sample_num]), 
#                                      is_train=1, teaching_rate=1)

# predicts = batch_tokens_remove_eos(predicts, vocab)
# labels = batch_tokens_remove_eos(split_pseudo_train_set_labels[0:sample_num], vocab)

# predicts = batch_tokens2words(predicts, vocab)
# labels = batch_tokens2words(labels, vocab)

# predicts_sents = batch_words2sentence(predicts)
# labels_sents = batch_words2sentence(labels)

# for (predict_sent, label_sent) in zip(predicts_sents, labels_sents):
#     print(' 1----> ', predict_sent)
#     print(' 2----> ', label_sent)
#     print('\n')

In [None]:
# NOTE(review): `stop` is not defined in the visible part of this notebook, so
# evaluating it presumably raises NameError. This looks like a deliberate guard
# that halts execution at this point — confirm before removing it.
stop