# Import

In [None]:
import os
import numpy as np
import time

from Util import my_helper
import copy
import itertools
import random
import pickle as cPickle
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.python.layers import core as core_layers


from Model import LM
from Util.myAttWrapper import SelfAttWrapper
from Util import myResidualCell
from Util.bleu import BLEU
from Util.myUtil import *


# Emit TensorFlow log messages at INFO level and above.
tf.logging.set_verbosity(tf.logging.INFO)
# Session configuration passed to the tf.Session created in _init_model.
# allow_growth=True asks TensorFlow to grow its GPU memory allocation on demand
# rather than reserving it all when the session starts.
sess_conf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))




# Util

In [None]:
def _construct_blank(num, length, style='random'):
    """Pair each of the first `num` sequences in the global X_indices with blank positions.

    Args:
        num: how many leading entries of X_indices to use.
        length: for style 'random', the number of positions to draw (without
            replacement) from range(len(sequence) - 2); for style 'middle', a
            ratio that is multiplied by (len(sequence) - 1) to get the span size.
        style: 'random' or 'middle'. Sequences are skipped for any other value.

    Returns:
        A list of (sequence, sorted_position_list) tuples.
    """
    blanked = []
    for sentence in X_indices[:num]:
        if style == 'random':
            drawn = random.sample(range(len(sentence) - 2), length)
            positions = np.sort(drawn).tolist()
        elif style == 'middle':
            span = int(length * (len(sentence) - 1))
            # Centre the contiguous span inside the first len(sentence) - 1 positions.
            first = int((len(sentence) - 1 - span) / 2.0)
            positions = [first + offset for offset in range(span)]
        else:
            # Unknown style: contribute nothing for this sentence.
            continue
        blanked.append((sentence, positions))
    return blanked

def show_blank(idx, pos):
    """Render `idx` as space-joined words with the positions in `pos` blanked out.

    Blanked positions are looked up under the id len(id2w) + 5; like any other
    id missing from the global id2w, that lookup falls back to '_'.
    """
    blank_id = len(id2w) + 5
    words = []
    for position, token in enumerate(idx):
        key = blank_id if position in pos else token
        words.append(id2w.get(key, '_'))
    return ' '.join(words)

def idx2str(idx):
    """Map token ids to words via the global id2w ('_' for unknown ids), joined by single spaces."""
    words = [id2w.get(token, '_') for token in idx]
    return " ".join(words)

def str2idx(idx):
    """Convert a space-separated sentence into token ids using the global w2id.

    Leading/trailing whitespace is stripped first; the rest is split on single
    spaces, so every resulting piece must be a key of w2id (KeyError otherwise).
    """
    ids = []
    for word in idx.strip().split(' '):
        ids.append(w2id[word])
    return ids

def cal_candidate_list(inputs, pos):
    """Build one candidate sentence per vocabulary id by substituting it at `pos`.

    Candidate number t is a copy of `inputs` whose element at index `pos` is t,
    for t in range(len(model.dp.X_w2id)). `inputs` itself is not modified.

    Only position `pos` is rewritten. The earlier implementation marked the
    position with model.qid_list[0] and then replaced every occurrence of that
    id, which also clobbered any other token in `inputs` equal to it.

    Args:
        inputs: sequence of token ids.
        pos: index of the token to vary.

    Returns:
        A list of len(model.dp.X_w2id) lists of token ids.
    """
    vocab_size = len(model.dp.X_w2id)
    base = list(inputs)
    candidates = []
    for token_id in range(vocab_size):
        # A shallow copy is enough: the elements are replaced, never mutated.
        candidate = list(base)
        candidate[pos] = token_id
        candidates.append(candidate)
    return candidates


def replace_list(idx, pos_list, target):
    """Return a copy of `idx` with the positions in `pos_list` overwritten.

    When `target` is truthy, position pos_list[i] receives target[i];
    otherwise every listed position is set to -1. `idx` is left untouched.
    """
    filled = list(idx)
    use_target = bool(target)
    for slot, position in enumerate(pos_list):
        filled[position] = target[slot] if use_target else -1
    return filled

def cal_optimal(idx, pos, max_it=-1):
    """Rank every vocabulary id as a filler for position `pos` by model.batch_loss.

    Candidates come from cal_candidate_list(idx, pos). They are scored in one
    session run, or in slices of `max_it` sentences when max_it > 0, with both
    keep-probabilities fed as 1.

    Returns:
        (sort_loss, sort_idx): the losses in ascending order and, aligned with
        them, the token found at `pos` in the corresponding candidate.
    """
    X_batch = cal_candidate_list(idx, pos)
    X_batch_len = [len(x) for x in X_batch]

    def _score(lo, hi):
        # One forward pass over candidates[lo:hi].
        feed = {model.X: X_batch[lo:hi],
                model.X_seq_len: X_batch_len[lo:hi],
                model.output_keep_prob: 1,
                model.input_keep_prob: 1}
        return model.sess.run(model.batch_loss, feed).tolist()

    if max_it > 0:
        batch_loss = []
        start = 0
        # Full-size slices first; whatever remains is scored in one final run.
        while start + max_it < len(X_batch):
            batch_loss.extend(_score(start, start + max_it))
            start += max_it
        batch_loss.extend(_score(start, None))
    else:
        batch_loss = _score(None, None)

    order = np.argsort(batch_loss)
    sort_loss = [batch_loss[rank] for rank in order]
    sort_idx = [X_batch[rank][pos] for rank in order]
    return sort_loss, sort_idx

def cal_dist(vector1, vector2):
    """Return (L2 norm of vector1 - vector2, dot product divided by the product of the two norms)."""
    euclidean = np.linalg.norm(vector1 - vector2)
    norm_product = np.linalg.norm(vector1) * (np.linalg.norm(vector2))
    cosine = np.dot(vector1, vector2) / norm_product
    return euclidean, cosine


def _init_blank(idx, pos):
    """Fill the positions in `pos` left to right using the global model.

    Non-blank positions are copied from `idx`. A blank at position 0 takes the
    word model.generate(1)[1][0][0] (looked up in w2id). Any other blank is
    filled from model.infer() run on the prefix built so far: the candidate
    with the largest len() is chosen and its i-th token is used, or '<PAD>' if
    that candidate has no i-th token.
    NOTE(review): model.infer is assumed to return a sequence of
    space-separated strings -- confirm against Model.LM.

    Returns:
        (filled_ids, filled_sentence_string, words placed at the sorted blank positions)
    """
    source = copy.deepcopy(idx)
    blanks = np.sort(pos).tolist()
    filled = []
    for i in range(len(idx)):
        if i not in blanks:
            filled.append(source[i])
            continue
        if i == 0:
            # No prefix to condition on: take the model's generated first word.
            filled.append(w2id[model.generate(1)[1][0][0]])
            continue
        candidates = model.infer(idx2str(filled))
        longest = candidates[np.argmax([len(c) for c in candidates])]
        longest_ids = str2idx(longest)
        if len(longest_ids) <= i:
            filled.append(w2id['<PAD>'])
        else:
            filled.append(longest_ids[i])
    init_word = [id2w[filled[i]] for i in blanks]
    return filled, idx2str(filled), init_word

def _init_data(name):
    """Load the vocabulary maps and index sequences of dataset `name`.

    Reads 'Data/<name>/w2id_id2w.pkl' (a pickled (w2id, id2w) pair) and
    'Data/<name>/index.pkl'. Both files are closed as soon as they are read.

    Args:
        name: dataset directory name under 'Data/'.

    Returns:
        (X_indices, w2id, id2w)
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted data files.
    with open('Data/%s/w2id_id2w.pkl' % name, 'rb') as vocab_file:
        w2id, id2w = cPickle.load(vocab_file)
    with open('Data/%s/index.pkl' % name, 'rb') as index_file:
        X_indices = cPickle.load(index_file)
    return X_indices, w2id, id2w

def _init_model(name, lr=10.0, l1_reg_lambda=0.00, l2_reg_lambda=0.00, close_loss_rate=0.00):
    """Build the LM for dataset `name` in its own graph/session and restore its checkpoint.

    Reads the notebook-level globals w2id, X_indices and sess_conf, loads the
    placeholder words from 'Data/<name>/qid_list.pkl', constructs the model and
    restores 'Model/<name>/model-<iteration>'.

    Args:
        name: one of 'Poem', 'Daily', 'APRC'.
        lr, l1_reg_lambda, l2_reg_lambda, close_loss_rate: forwarded to LM.

    Returns:
        The restored LM instance.

    Raises:
        AssertionError: if `name` is not a supported dataset. This is checked
            before any file is touched.
    """
    # name -> (rnn_size, n_layers, max_infer_length, checkpoint iteration)
    settings = {
        'Poem': (512, 2, 33, 30),
        'Daily': (512, 1, 50, 30),
        'APRC': (1024, 1, 36, 20),
    }
    assert name in ['Poem', 'Daily', 'APRC']
    rnn_size, n_layers, max_infer_length, model_iter = settings[name]

    # NOTE: unpickling can execute arbitrary code -- only load trusted data files.
    with open('Data/%s/qid_list.pkl' % name, 'rb') as qid_file:
        qid_words = cPickle.load(qid_file)
    qid_list = [w2id[w] for w in qid_words]

    BATCH_SIZE = 256
    NUM_EPOCH = 30
    dp = LM_DP(X_indices, w2id, BATCH_SIZE, n_epoch=NUM_EPOCH)
    g = tf.Graph()
    sess = tf.Session(graph=g, config=sess_conf)
    with sess.as_default():
        with sess.graph.as_default():
            model = LM(
                dp=dp,
                rnn_size=rnn_size,
                n_layers=n_layers,
                decoder_embedding_dim=rnn_size,
                cell_type='lstm',
                close_loss_rate=close_loss_rate,
                max_infer_length=max_infer_length,
                att_type='B',
                qid_list=qid_list,
                lr=lr,
                l1_reg_lambda=l1_reg_lambda,
                l2_reg_lambda=l2_reg_lambda,
                is_save=False,
                residual=True,
                is_jieba=False,
                sess=sess
            )

    # The result is unused here; the call is kept in case the LM_util
    # constructor has side effects on dp/model -- TODO confirm and drop if not.
    LM_util(dp=dp, model=model)
    model.restore('Model/%s/model-%d' % (name, model_iter))  # restore pre-train model
    return model



def _reload(name):
    """Restore the global `model` from the pretrained checkpoint of dataset `name`.

    Args:
        name: one of 'Poem', 'Daily', 'APRC'.

    Raises:
        AssertionError: if `name` is not a supported dataset.
    """
    # Checkpoint iteration saved for each dataset.
    model_iter = {'Poem': 30, 'Daily': 30, 'APRC': 20}

    assert name in ['Poem', 'Daily', 'APRC']

    model.restore('Model/%s/model-%d' % (name, model_iter[name]))

# TIGS

In [None]:
def cal_optimizer_gibbs(inputs, pos_list, init_word, name='Nesterov', epoch=10, is_show=True, top_k=100, upper_size=2, distance='l2'):
    """Fill the blanks at `pos_list` by alternating an embedding update and a discrete search.

    Blank k is tied to the placeholder token model.qid_list[k]. For every
    epoch and every blank the function runs:
      * O-step: model.update_op[name + '_%d' % k] on the templated sentence,
        with the currently filled sentence as target Y.
      * P-step: scores the placeholder's top_k nearest vocabulary entries plus
        the current word with model.batch_loss and keeps the lowest-loss one.

    Args:
        inputs: sequence of token ids; it is deep-copied and not modified.
        pos_list: positions to fill; len(pos_list) must not exceed len(model.qid_list).
        init_word: words (keys of w2id) used as the initial fill. If falsy, the
            placeholder ids model.qid_list[:K] are used instead.
        name: prefix of the key into model.update_op (presumably the optimizer
            name -- TODO confirm against Model.LM).
        epoch: maximum number of outer iterations.
        is_show: print the sentence at the start of each epoch and after each blank.
        top_k: number of nearest neighbours scored per blank in the P-step.
        upper_size: number of consecutive epochs (checked only once i > 1) whose
            starting sentence equals the previous epoch's before stopping early.
        distance: 'cos' uses model.nearby_val / nearby_idx; any other value uses
            model.eu_nearby_val / eu_nearby_idx (presumably Euclidean -- TODO confirm).

    Returns:
        (pre_p_list, idx2str(pre_p_list), loss): the chosen token id per blank,
        those ids rendered as words, and model.loss on the final filled sentence.

    Note:
        The *_time accumulators, total_time, ep_tic, next_p_list and `v` are
        assigned but never read or returned.
    """
    # Timing bookkeeping; none of these values leave the function.
    total_tic = time.time()
    init_time = 0.0
    update_time = 0.0
    assign_time = 0.0
    search_time = 0.0
    cal_pre_time = 0.0
    cal_next_time = 0.0
    
    tic = time.time()
    # Number of consecutive epochs that started from an unchanged sentence.
    upper_cnt = 0
    
    # prepare
    K = len(pos_list)
    assert K <= len(model.qid_list)
    # Template sentence: each blank position holds its own placeholder token.
    idx = copy.deepcopy(inputs)
    for i,pos in enumerate(pos_list):
        idx[pos] = model.qid_list[i]
    # pre_p_list holds the token id currently chosen for each blank.
    if init_word:
        pre_p_list = [w2id[t] for t in init_word]
    else:
        pre_p_list = model.qid_list[:K]
    next_p_list = []
    pre_sentence = replace_list(idx, pos_list, pre_p_list)
    epoch_sentence  = replace_list(idx, pos_list, pre_p_list)
    # Snapshot of the embedding table, taken once before any update op runs.
    word_emb = model.sess.run(model.decoder_embedding)
        
    # init specific embedding 
    if init_word:
        feed_dict = dict()
        for j in range(K):
            # Indexing with a one-element list; assumes word_emb is a NumPy array,
            # in which case a leading axis of size 1 is kept.
            feed_dict[model.assgin_placeholder_list[j]] = word_emb[[w2id[init_word[j]]]]
        model.sess.run(model.assign_op_list[:K],feed_dict)
    init_time += time.time()-tic

    # search
    for i in range(epoch):
        # Early stop: epoch_sentence still holds the sentence from the start of the
        # previous epoch; if nothing changed since, count it towards upper_size.
        if i > 1 and epoch_sentence == replace_list(idx, pos_list, pre_p_list):
            upper_cnt += 1
            if upper_cnt >= upper_size:
                if is_show:
                    print('total_epoch %d'% (i+1))
                break
        else:
            upper_cnt = 0
        epoch_sentence = replace_list(idx, pos_list, pre_p_list)
        if is_show:
            print('epoch %d :' %(i+1), idx2str(epoch_sentence))
        ep_tic = time.time()    
        for k in range(K):
            # O-step
            pre_sentence = replace_list(idx, pos_list, pre_p_list)
            tic = time.time()
            # Look up the neighbours of placeholder k; entry 1 is used as "nearest"
            # (entry 0 is presumably the placeholder itself -- TODO confirm).
            if distance == 'cos':
                v, o = model.sess.run([model.nearby_val, model.nearby_idx], {model.nearby_word:[model.qid_list[k]]})
                nearset = o[0][1]
            else:
                v, o = model.sess.run([model.eu_nearby_val, model.eu_nearby_idx], {model.nearby_word:[model.qid_list[k]]})
                nearset = o[1]
            # Run the update op for slot k: input is the template, target is the
            # currently filled sentence, plus the snapshot embedding of the nearest word.
            loss, _ = model.sess.run([model.update_loss, 
                                      model.update_op[name+'_%d' % k]], 
                                         {model.X: [idx], 
                                          model.X_seq_len: [len(idx)], 
                                          model.Y:[pre_sentence],
                                          model.output_keep_prob:1,
                                          model.input_keep_prob:1,
                                          model.nearest_emb_placeholder:word_emb[[nearset]]})
            update_time += time.time() - tic


            # P-step
            # i % 1 is always 0, so the P-step runs on every epoch.
            if i % 1 == 0:
                # candidate
                tic = time.time()
                # Neighbours are queried again, after the update op above has run;
                # the current choice is appended so it always competes.
                if distance == 'cos':
                    v, o = model.sess.run([model.nearby_val, model.nearby_idx], {model.nearby_word:[model.qid_list[k]]})
                    candi_pos = o[0][1:top_k+1].tolist() + [pre_p_list[k]]
                else:
                    v, o = model.sess.run([model.eu_nearby_val, model.eu_nearby_idx], {model.nearby_word:[model.qid_list[k]]})
                    candi_pos = o[1:top_k+1].tolist() + [pre_p_list[k]]
                    
                # One fill per candidate: only slot k varies, other slots keep their current choice.
                candi_list = [[pre_p_list[j] if j!=k else t for j in range(len(pre_p_list))] for t in candi_pos]
                next_sentences = [replace_list(idx, pos_list, candi) for candi in candi_list]
                search_time += time.time() - tic
                
                # cal loss
                tic = time.time()
                next_loss_list = model.sess.run(model.batch_loss, {model.X: next_sentences, 
                                        model.X_seq_len: [len(idx) for j in range(len(next_sentences))], 
                                        model.output_keep_prob:1,
                                        model.input_keep_prob:1})
                argmin_idx = np.argmin(next_loss_list)
                next_p_pos = candi_pos[argmin_idx]
                cal_next_time += time.time()-tic
                # update
                tic = time.time()
                # Only when the winner differs: assign its snapshot embedding to
                # slot k and record the new choice.
                if next_p_pos != pre_p_list[k]:
                    model.sess.run(model.assign_op_list[k],{model.assgin_placeholder_list[k]:word_emb[[next_p_pos]]})
                    pre_p_list[k] = next_p_pos
                    
                assign_time += time.time() - tic
                if is_show:
                    print('epoch %d_%d :' % (i+1, k),idx2str(replace_list(idx, pos_list, pre_p_list)))
                
    # Final loss of the fully filled sentence.
    tic = time.time()    
    pre_sentence = replace_list(idx, pos_list, pre_p_list)
    loss = model.sess.run(model.loss, {model.X: [pre_sentence], 
                        model.X_seq_len: [len(idx)], 
                        model.output_keep_prob:1,
                        model.input_keep_prob:1})
    cal_pre_time += time.time() - tic
    total_time = time.time() - total_tic
    
    return pre_p_list, idx2str(pre_p_list),loss

In [None]:
# Blank initialisations produced by left-to-right greedy beam search, loaded from disk.
# NOTE: unpickling can execute arbitrary code -- only load trusted result files.
with open('results/_URNN-f_res.pkl', 'rb') as _f_init_file:
    f_init = cPickle.load(_f_init_file)

# Inference

In [None]:
# Dataset to run inference on; the other supported values are 'Daily' and 'Poem'.
task_name = 'APRC' # 'Daily', 'Poem'
assert task_name in ['Poem','Daily', 'APRC']
# Notebook-level globals: the helper functions defined earlier read
# X_indices, w2id and id2w directly rather than taking them as arguments.
X_indices, w2id, id2w = _init_data(task_name)
# Builds the LM and restores its checkpoint; `model` is likewise read as a global by the helpers.
model = _init_model(task_name, lr=10.0)

                

In [None]:
# Run TIGS on one randomly chosen example for every (blank ratio, blank style) setting.
# `random` is already imported in the notebook's first cell.
is_init = True  # True: start from the precomputed f_init fills; False: start from random words
for length_ratio in [0.25, 0.5, 0.75]:
    for style in ['random', 'middle']:
        prefix = '%s_%d_%s' % (task_name, int(length_ratio*100), style)
        # NOTE: unpickling can execute arbitrary code -- only load trusted data files.
        with open('Data/%s/%d_%s.pkl' % (task_name, int(length_ratio*100), style), 'rb') as blank_file:
            blank_data = cPickle.load(blank_file)
        # Pick among the first 5000 examples, but never index past the end of the file's data.
        i = random.randrange(min(5000, len(blank_data)))
        idx, pos_list = blank_data[i]
        model._opt_init()
        if is_init:
            init_word = [id2w[f_init[prefix+'_URNN-f'][i][p]] for p in pos_list]
        else:
            # random.sample needs a sequence; a dict keys view raises TypeError on Python 3.11+.
            init_word = random.sample(list(w2id.keys()), len(pos_list))
        sid, sw, loss = cal_optimizer_gibbs(idx, pos_list, init_word = init_word, is_show=False)
        print('Template:', show_blank(idx, pos_list))
        print('GroundTruth:', idx2str(idx))
        print('TIGS:', idx2str(replace_list(idx, pos_list, sid)))
        print('')