In [1]:
import collections
import os

import numpy as np
import tensorflow as tf

In [None]:
class BatchedInput(
    collections.namedtuple(
        "BatchedInput",
        "initializer source target_input target_output "
        "source_sequence_length target_sequence_length")):
  """Immutable bundle of the tensors/ops returned by `get_iterator`.

  Fields:
    initializer: op that (re)initializes the underlying dataset iterator.
    source: padded batch of source token rows.
    target_input: padded batch of target rows prefixed with a start token.
    target_output: padded batch of target rows suffixed with an end token.
    source_sequence_length: unpadded length of each source row.
    target_sequence_length: unpadded length of each target_input row.
  """

In [None]:
def get_iterator(src_dataset,
                 tgt_dataset,
                 src_vocab_table,
                 tgt_vocab_table,
                 batch_size,
                 sos,
                 eos,
                 random_seed,
                 num_buckets,
                 src_max_len=None,
                 tgt_max_len=None,
                 num_parallel_calls=4,
                 output_buffer_size=None,
                 skip_count=None,
                 num_shards=1,
                 shard_index=0,
                 reshuffle_each_iteration=True,
                 use_char_encode=False):
  """Build a padded, batched (source, target) iterator over two text datasets.

  The two line datasets are zipped line by line, each line is split on
  whitespace into string tokens, and rows are emitted in batches of
  `batch_size`, padded to the longest row of the batch.

  NOTE: the vocabulary lookup is disabled, so every token tensor is a
  *string* tensor of raw words, not an id tensor. `src_vocab_table` and
  `tgt_vocab_table` are accepted for interface compatibility but unused.

  Args:
    src_dataset: `tf.data.Dataset` of source lines.
    tgt_dataset: `tf.data.Dataset` of target lines.
    src_vocab_table, tgt_vocab_table: unused (see NOTE above).
    batch_size: number of rows per batch.
    sos: token prepended to every target_input row.
    eos: token appended to every target_output row; also the padding value
      for the source and target token rows.
    random_seed, reshuffle_each_iteration, use_char_encode: accepted for
      interface compatibility; shuffling and character encoding are not
      implemented here.
    num_buckets: if > 1, rows are grouped by length into this many buckets
      before batching; otherwise a single stream is batched.
    src_max_len, tgt_max_len: if truthy, truncate source / target token
      rows to at most this many tokens (before sos/eos are added).
    num_parallel_calls: parallelism for the `map` transformations.
    output_buffer_size: prefetch buffer size; if falsy, `batch_size * 1000`.
    skip_count: if not None, number of leading line pairs to skip.
    num_shards, shard_index: if num_shards > 1, keep only shard
      `shard_index` of `num_shards` of the zipped dataset.

  Returns:
    A `BatchedInput`; run its `initializer` before fetching its tensors.
  """
  # Only fall back to the default when the caller did not supply a size.
  if not output_buffer_size:
    output_buffer_size = batch_size * 1000

  # Tokens stay strings (no vocab lookup), so the special tokens are the
  # caller-supplied strings rather than numeric ids.
  src_eos_id = eos
  tgt_sos_id = sos
  tgt_eos_id = eos

  src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
  if num_shards > 1:
    src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
  if skip_count is not None:
    src_tgt_dataset = src_tgt_dataset.skip(skip_count)

  # Turn each line into a vector of whitespace-separated word tokens.
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt: (
          tf.string_split([src]).values, tf.string_split([tgt]).values),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

  # Filter zero length input sequences.
  src_tgt_dataset = src_tgt_dataset.filter(
      lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))

  if src_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src[:src_max_len], tgt),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
  if tgt_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src, tgt[:tgt_max_len]),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

  # Vocabulary lookup (word string -> id) is intentionally disabled. To
  # enable it, map each side through its table, e.g.
  #   lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
  #                     tf.cast(tgt_vocab_table.lookup(tgt), tf.int32))
  # and convert sos/eos to ids the same way.

  # (src, tgt) -> (src, <sos> + tgt, tgt + <eos>)
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt: (src,
                        tf.concat(([tgt_sos_id], tgt), 0),
                        tf.concat((tgt, [tgt_eos_id]), 0)),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

  # Append unpadded lengths:
  # (src, tgt_in, tgt_out) -> (src, tgt_in, tgt_out, len(src), len(tgt_in))
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt_in, tgt_out: (
          src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

  def batching_func(x):
    """Batch `x` by `batch_size`, padding token rows to the batch maximum."""
    return x.padded_batch(
        batch_size,
        # The first three entries are variable-length token vectors; the
        # last two are the scalar row lengths.
        padded_shapes=(
            tf.TensorShape([None]),  # src
            tf.TensorShape([None]),  # tgt_input
            tf.TensorShape([None]),  # tgt_output
            tf.TensorShape([]),  # src_len
            tf.TensorShape([])),  # tgt_len
        # Pad token rows with eos; downstream code is expected to mask out
        # positions past the true lengths anyway.
        padding_values=(
            src_eos_id,  # src
            tgt_eos_id,  # tgt_input
            tgt_eos_id,  # tgt_output
            0,  # src_len -- unused
            0))  # tgt_len -- unused

  if num_buckets > 1:
    # Bucket rows of similar length together (widths of src_max_len /
    # num_buckets, or 10 when unbounded) to reduce padding per batch.
    def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
      if src_max_len:
        bucket_width = (src_max_len + num_buckets - 1) // num_buckets
      else:
        bucket_width = 10
      bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
      return tf.cast(tf.minimum(num_buckets, bucket_id), tf.int64)

    def reduce_func(unused_key, windowed_data):
      return batching_func(windowed_data)

    batched_dataset = src_tgt_dataset.apply(
        tf.contrib.data.group_by_window(
            key_func=key_func, reduce_func=reduce_func,
            window_size=batch_size))
  else:
    batched_dataset = batching_func(src_tgt_dataset)

  batched_iter = batched_dataset.make_initializable_iterator()
  (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len,
   tgt_seq_len) = batched_iter.get_next()
  return BatchedInput(
      initializer=batched_iter.initializer,
      source=src_ids,
      target_input=tgt_input_ids,
      target_output=tgt_output_ids,
      source_sequence_length=src_seq_len,
      target_sequence_length=tgt_seq_len)

In [None]:
# Training corpus (source .en, target .vi) and per-language vocab files all
# live in one directory; edit DATA_DIR to point the notebook elsewhere.
DATA_DIR = '/home/zhangxk/AIProject/nmt/nmt/scripts/iwsl515'

src_dataset = tf.data.TextLineDataset(os.path.join(DATA_DIR, 'train.en'))
tgt_dataset = tf.data.TextLineDataset(os.path.join(DATA_DIR, 'train.vi'))
src_vocab_table = tf.contrib.lookup.index_table_from_file(
    os.path.join(DATA_DIR, 'vocab.en'))
tgt_vocab_table = tf.contrib.lookup.index_table_from_file(
    os.path.join(DATA_DIR, 'vocab.vi'))
batch_size = 3
sos, eos = '<sos>', '<eos>'

In [None]:
# Build the batched input pipeline: one bucket, no sharding, fixed order.
batchinput = get_iterator(
    src_dataset=src_dataset,
    tgt_dataset=tgt_dataset,
    src_vocab_table=src_vocab_table,
    tgt_vocab_table=tgt_vocab_table,
    batch_size=batch_size,
    sos=sos,
    eos=eos,
    random_seed=0,
    num_buckets=1,
    num_shards=1,
    shard_index=0,
    reshuffle_each_iteration=False,
    use_char_encode=False,
)

In [None]:
# Show the symbolic tensors (static shape/dtype only; nothing is run yet).
for tensor in (batchinput.source,
               batchinput.target_input,
               batchinput.source_sequence_length):
    print(tensor)

In [None]:
# Pull two batches: print the shapes of the three padded token matrices,
# then the two per-row length vectors. `x` holds the last batch afterwards.
fetches = [batchinput.source,
           batchinput.target_input,
           batchinput.target_output,
           batchinput.source_sequence_length,
           batchinput.target_sequence_length]
with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    sess.run(batchinput.initializer)
    for i in range(2):
        x = sess.run(fetches)
        for padded in x[:3]:
            print(padded.shape)
        for lengths in x[3:]:
            print(lengths)
        print('xxxxxxxxx')

In [None]:
# Re-display the last batch fetched by the previous cell.
for arr in x[:3]:
    print(arr.shape)
for arr in x[3:]:
    print(arr)

In [2]:
# Unroll a 3-layer LSTM stack by hand, feeding it one time step per call.
tf.reset_default_graph()
BATCH, T, D = 22, 10, 64
X = tf.placeholder(dtype=tf.float32, shape=[BATCH, T, D])
# `reuse` must be passed by keyword: the second positional parameter of
# tf.variable_scope is `default_name`, so ('my', tf.AUTO_REUSE) did not
# actually enable AUTO_REUSE.
with tf.variable_scope('my', reuse=tf.AUTO_REUSE):
    ndims = [128, 256, 512]
    cells = [tf.contrib.rnn.BasicLSTMCell(d) for d in ndims]
    # With state_is_tuple=True the returned state is a tuple with one entry
    # per layer; each entry is an LSTMStateTuple holding the c and h tensors.
    complete_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
    zero_state = complete_cell.zero_state(BATCH, dtype=tf.float32)

    states = zero_state
    outputs = []
    for t in range(T):
        # X[:, t, :] is the (BATCH, D) input for step t; the same cell object
        # is applied at every step, threading the state through.
        output, states = complete_cell(X[:, t, :], states)
        outputs.append(output)

Instructions for updating:
This class is deprecated, please use tf.nn.rnn_cell.LSTMCell, which supports all the feature this cell currently has. Please replace the existing code with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').


In [3]:
print(outputs)  # one top-layer output per step, each (BATCH, ndims[-1]) = (22, 512)
print(len(outputs))  # == T
print(states)  # (layer1_state, layer2_state, layer3_state), each LSTMStateTuple(c, h)

# Summary: calling a MultiRNNCell for a single time step
#   Inputs:
#     x_t:   (N, D) -- one time slice of the input, here X[:, t, :]
#     state: (l1_state, l2_state, ..., ln_state),
#            li_state = LSTMStateTuple(c: Tensor, h: Tensor)
#   Outputs:
#     output: (N, ln_ndims) -- the top layer's h at this step
#     state:  same structure as the input state, fed to the next step

[<tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_2:0' shape=(22, 512) dtype=float32>, <tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_5:0' shape=(22, 512) dtype=float32>, <tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_8:0' shape=(22, 512) dtype=float32>, <tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_11:0' shape=(22, 512) dtype=float32>, <tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_14:0' shape=(22, 512) dtype=float32>, <tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_17:0' shape=(22, 512) dtype=float32>, <tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_20:0' shape=(22, 512) dtype=float32>, <tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_23:0' shape=(22, 512) dtype=float32>, <tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_26:0' shape=(22, 512) dtype=float32>, <tf.Tensor 'my/my/multi_rnn_cell/cell_2/basic_lstm_cell/Mul_29:0' shape=(22, 512) dtype=float32>]
10
(LSTMStateTuple(c=<

'\n#总结:对于MultiRNNCell单元,\n\n输入:\n    X:[N,T]\n    state:(l1_state,l2_state,...ln_state)\n    li_state:LSTMStateTuple(c:Tensor,h:Tensor)\n输出\n    output:(N,ln_ndims),最后一次的h\n    state:(l1_state,l2_state,...ln_state),li_state:LSTMStateTuple(c:Tensor,h:Tensor)\n'

In [17]:
# Same kind of LSTM stack, but let tf.nn.dynamic_rnn run the time loop.
tf.reset_default_graph()
BATCH, T, D = 22, 10, 64
X = tf.placeholder(dtype=tf.float32, shape=[BATCH, T, D])
with tf.variable_scope('mMmM', reuse=tf.AUTO_REUSE):
    ndims = [32, 64]
    cells = [tf.contrib.rnn.BasicLSTMCell(units) for units in ndims]
    complete_cell = tf.contrib.rnn.MultiRNNCell(cells)

    # Batch-major input (time_major=False): X is (BATCH, T, D). No
    # initial_state is given, so dtype is needed to build the zero state;
    # sequence_length is left at its default (every row runs all T steps).
    output, states = tf.nn.dynamic_rnn(
        complete_cell, X,
        dtype=tf.float32,
        swap_memory=True,
        time_major=False)

In [18]:
# output: top-layer outputs for every step; states: final (c, h) per layer.
for item in (output, states):
    print(item)

Tensor("mMmM/rnn/transpose_1:0", shape=(22, 10, 64), dtype=float32)
(LSTMStateTuple(c=<tf.Tensor 'mMmM/rnn/while/Exit_3:0' shape=(22, 32) dtype=float32>, h=<tf.Tensor 'mMmM/rnn/while/Exit_4:0' shape=(22, 32) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'mMmM/rnn/while/Exit_5:0' shape=(22, 64) dtype=float32>, h=<tf.Tensor 'mMmM/rnn/while/Exit_6:0' shape=(22, 64) dtype=float32>))


In [40]:
# Bidirectional RNN: independent forward and backward LSTM stacks.
tf.reset_default_graph()
BATCH, T, D = 22, 10, 64
X = tf.placeholder(dtype=tf.float32, shape=[BATCH, T, D])

with tf.variable_scope('bi', reuse=tf.AUTO_REUSE):
    ndims = [32, 64]
    forward_cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.BasicLSTMCell(d) for d in ndims])
    # Build fresh cell objects for the backward direction; reusing the
    # forward cells would not create a separate set of weights.
    cells = [tf.contrib.rnn.BasicLSTMCell(d) for d in ndims]
    backward_cell = tf.contrib.rnn.MultiRNNCell(cells)

    output, states = tf.nn.bidirectional_dynamic_rnn(
        forward_cell, backward_cell, X,
        dtype=tf.float32, time_major=False)

In [43]:
# Repeat each batch entry of the final states beam_width times (batch 22 -> 66),
# the layout expected when decoding with a beam.
beam_width = 3
print(states)
beam_states = tf.contrib.seq2seq.tile_batch(states, multiplier=beam_width)
print(beam_states)

((LSTMStateTuple(c=<tf.Tensor 'bi/bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(22, 32) dtype=float32>, h=<tf.Tensor 'bi/bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(22, 32) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'bi/bidirectional_rnn/fw/fw/while/Exit_5:0' shape=(22, 64) dtype=float32>, h=<tf.Tensor 'bi/bidirectional_rnn/fw/fw/while/Exit_6:0' shape=(22, 64) dtype=float32>)), (LSTMStateTuple(c=<tf.Tensor 'bi/bidirectional_rnn/bw/bw/while/Exit_3:0' shape=(22, 32) dtype=float32>, h=<tf.Tensor 'bi/bidirectional_rnn/bw/bw/while/Exit_4:0' shape=(22, 32) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'bi/bidirectional_rnn/bw/bw/while/Exit_5:0' shape=(22, 64) dtype=float32>, h=<tf.Tensor 'bi/bidirectional_rnn/bw/bw/while/Exit_6:0' shape=(22, 64) dtype=float32>)))
((LSTMStateTuple(c=<tf.Tensor 'tile_batch/Reshape:0' shape=(66, 32) dtype=float32>, h=<tf.Tensor 'tile_batch/Reshape_1:0' shape=(66, 32) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'tile_batch/Reshape_2:0' shape=(66, 64) dtype

In [34]:
# Trainable variables present in the current default graph.
trainable_vars = tf.trainable_variables()
trainable_vars

[<tf.Variable 'bi/bidirectional_rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0' shape=(96, 128) dtype=float32_ref>,
 <tf.Variable 'bi/bidirectional_rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'bi/bidirectional_rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0' shape=(96, 256) dtype=float32_ref>,
 <tf.Variable 'bi/bidirectional_rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0' shape=(256,) dtype=float32_ref>]

In [39]:
# Standalone check: tile_batch also accepts a tuple of tensors and tiles
# each one along its first dimension (20 -> 60 here).
tf.reset_default_graph()
s, h = [tf.constant(value=1, shape=[20, 5, 2]) for _ in range(2)]
print(s)
c = tf.contrib.seq2seq.tile_batch((s, h), 3)
print(c)

Tensor("Const:0", shape=(20, 5, 2), dtype=int32)
(<tf.Tensor 'tile_batch/Reshape:0' shape=(60, 5, 2) dtype=int32>, <tf.Tensor 'tile_batch/Reshape_1:0' shape=(60, 5, 2) dtype=int32>)
