## How the basic seq2seq decoding stage is implemented, using google/seq2seq's BasicDecoder + TrainingHelper as an example

![alt text](figure/seq2seq2014.png)

**Notation**:

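In this document, the decoding method shown above, which feeds the ground-truth response to the decoder as input, is called the **basic** decoding method; the method that feeds the word predicted at the previous timestep as the input at the next timestep is called the **upgraded** decoding method.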

### Code from lecture4_seq2seq_part2A that implements the basic decoder
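```python
import tensorflow as tf
from seq2seq.decoders import basic_decoder
from seq2seq.contrib.seq2seq import helper as decode_helper

# decoder_hidden_units, input_embeddings, decoder_inputs, decoder_inputs_length,
# decoder_targets, vocab_size, mode and encoder_output are defined earlier in the notebook.

# Start from the default BasicDecoder hyperparameters and override a few of them
decode_params = basic_decoder.BasicDecoder.default_params()
decode_params["rnn_cell"]["cell_params"]["num_units"] = decoder_hidden_units
decode_params["max_decode_length"] = 16
decode_params  # displays the parameter dict in the notebook

# Embed the batch-major decoder inputs ("<eos>" + ground-truth response): [batch, time, emb]
decoder_inputs_embedded = tf.nn.embedding_lookup(input_embeddings, decoder_inputs)

with tf.name_scope('minibatch'):
    # TrainingHelper feeds the ground-truth token at every timestep
    helper_ = decode_helper.TrainingHelper(
        inputs=decoder_inputs_embedded,
        sequence_length=decoder_inputs_length)

decoder_fn = basic_decoder.BasicDecoder(params=decode_params,
                                        mode=mode,
                                        vocab_size=vocab_size)

# Calling the decoder runs RNNDecoder._build(); the encoder's final state initializes the decoder
decoder_output, decoder_state = decoder_fn(encoder_output.final_state, helper_)

# decoder_output.logits is time-major [time, batch, vocab]; transpose it to batch-major
# [batch, time, vocab] so that it lines up with decoder_targets [batch, time]
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32),
        logits=tf.transpose(decoder_output.logits, perm=[1, 0, 2]))
)
```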

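#### Some observations
1. The input data is batch-major, while the output predictions are time-major.
2. Through `decode_helper` we pass in `decoder_inputs_embedded`, i.e. we use `"<eos> + ground-truth response"` as the decoder input, rather than feeding the word sampled at the previous timestep as the input at the next timestep.
3. Invoking the decoder involves two classes, `BasicDecoder()` and `TrainingHelper()`; the placeholder data `decoder_inputs_embedded, decoder_inputs_length` reaches the decoding process through `TrainingHelper()`.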
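Observation 1 is why the loss above transposes the logits. The pure-Python sketch below (nested lists stand in for tensors; `to_batch_major` is a hypothetical name) shows what `tf.transpose(..., perm=[1, 0, 2])` does to a `[max_time, batch_size, vocab_size]` array: it swaps the first two axes so that entry `[b][t]` lines up with the batch-major targets.

```python
def to_batch_major(time_major):
    """Swap axes 0 and 1 of a nested list, like tf.transpose(x, perm=[1, 0, 2])."""
    max_time = len(time_major)
    batch_size = len(time_major[0])
    # result[b][t] == time_major[t][b]; the innermost (vocab) axis is untouched
    return [[time_major[t][b] for t in range(max_time)] for b in range(batch_size)]


# Toy logits with max_time=3, batch_size=2, vocab_size=2; each entry is labelled [t, b]
logits_tm = [[[t, b] for b in range(2)] for t in range(3)]
logits_bm = to_batch_major(logits_tm)
print(len(logits_bm), len(logits_bm[0]))  # 2 3
print(logits_bm[1][2])                    # [2, 1]
```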



## First, look at `decoder_fn`, i.e. the operations inside `BasicDecoder`

![alt text](figure/basic_decoder.png)

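1. From `decoder_output, decoder_state = decoder_fn(encoder_output.final_state, helper_)` we can tell that BasicDecoder must have a `__call__()` method that runs the decoding. Following the inheritance chain BasicDecoder --> RNNDecoder --> GraphModule, we find that BasicDecoder's `__call__()` (inherited from GraphModule) calls RNNDecoder's `_build()` method.
2. From `self.initialize` and `self.step` we can see that `TrainingHelper` prepares the RNN cell's input word for each timestep, which the decoder pairs with the cell state: `(first_inputs, initial_state)` are returned by `self.initialize()`, and `(next_inputs, next_state)` by `self.step()`.
3. `self.step()` performs the following operations:
   + Update the cell state at the current timestep: `self.cell(inputs, state)`.
   + Predict the word at the current timestep: `logits` + `sample_ids`. In the basic RNN decoder, `sample_ids` is only used as output; in the upgraded RNN decoder, `sample_ids` is also used as the input word at the next timestep. In both versions, `logits` is used to compute the loss.
   + Prepare the RNN cell's input state and word for the next timestep.
4. Finally, look at how the base class RNNDecoder builds decoding of the whole sequence on top of `step()`: `RNNDecoder._build()`.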
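The three sub-steps of `self.step()` can be sketched in plain Python as follows. This is not the google/seq2seq code: `toy_cell`, `toy_logits` and `ToyTrainingHelper` are hypothetical stand-ins for the RNN cell, the output projection and `TrainingHelper`, and lists replace tensors. The point is the data flow: the cell output becomes `logits`, the helper turns `logits` into `sample_ids`, and the helper (not the decoder) chooses the next input.

```python
def toy_cell(inputs, state):
    """Hypothetical RNN cell: new_state = state + input; the output equals the new state."""
    new_state = [s + x for s, x in zip(state, inputs)]
    return new_state, new_state


def toy_logits(cell_output, vocab_size=3):
    """Hypothetical output projection: score 1.0 for id (output % vocab_size), 0.0 elsewhere."""
    return [[1.0 if v == o % vocab_size else 0.0 for v in range(vocab_size)]
            for o in cell_output]


class ToyTrainingHelper:
    """Hypothetical stand-in for TrainingHelper: always reads the ground truth."""

    def __init__(self, inputs, sequence_length):
        self.inputs = inputs                    # time-major: inputs[t][b]
        self.sequence_length = sequence_length

    def sample(self, time, outputs):
        # argmax over the vocabulary axis of each example's logits
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]

    def next_inputs(self, time, outputs, state, sample_ids):
        next_time = time + 1
        finished = [next_time >= n for n in self.sequence_length]
        if all(finished):
            return finished, [0] * len(self.sequence_length), state
        return finished, self.inputs[next_time], state   # ground truth, not sample_ids


def step(time, inputs, state, helper):
    """Mirrors the three sub-steps of BasicDecoder.step()."""
    cell_output, cell_state = toy_cell(inputs, state)         # 1. update the cell state
    logits = toy_logits(cell_output)                          # 2. predict logits ...
    sample_ids = helper.sample(time, logits)                  #    ... and sample_ids
    finished, next_inputs, next_state = helper.next_inputs(   # 3. prepare the next input
        time, logits, cell_state, sample_ids)
    return (logits, sample_ids), next_state, next_inputs, finished


helper = ToyTrainingHelper(inputs=[[1, 2], [1, 1], [0, 1]], sequence_length=[3, 2])
(logits, sample_ids), state, next_inputs, finished = step(0, [1, 2], [0, 0], helper)
print(sample_ids, next_inputs, finished)  # [1, 2] [1, 1] [False, False]
```

Because this toy helper reads the ground truth, `next_inputs` is the row `inputs[1]` and `sample_ids` is only reported as output; an upgraded helper would build `next_inputs` from `sample_ids` instead.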

![alt text](figure/rnn_decoder.png)

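`_build` has the same structure as `raw_rnn`: an initialization followed by a loop over timesteps.
1. `_setup`: initialization.
2. **Iteration**: calls `contrib.seq2seq.decoder.dynamic_decode` to run the loop; each iteration executes the `step` operation described above.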
  
```python
outputs, final_state = dynamic_decode(
    decoder=self,  # dynamic_decode reaches self.step() through decoder=self
    output_time_major=True,
    impute_finished=False,
    maximum_iterations=maximum_iterations)
```
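A heavily simplified sketch of the loop that `dynamic_decode` runs around `step()`, assuming plain Python lists instead of `tf.while_loop` and TensorArrays; `toy_dynamic_decode` and `CountingDecoder` are hypothetical names. As in the real function, the finished flags are OR-ed across steps, the loop stops once every example is finished or after `maximum_iterations` steps, and the outputs are collected time-major.

```python
def toy_dynamic_decode(decoder, maximum_iterations=None):
    """Simplified sketch of dynamic_decode: a Python loop instead of tf.while_loop."""
    finished, inputs, state = decoder.initialize()
    if maximum_iterations is not None and maximum_iterations <= 0:
        finished = [True] * len(finished)
    outputs, time = [], 0
    while not all(finished):
        step_output, state, inputs, step_finished = decoder.step(time, inputs, state)
        # once an example is finished, it stays finished
        finished = [f or sf for f, sf in zip(finished, step_finished)]
        if maximum_iterations is not None and time + 1 >= maximum_iterations:
            finished = [True] * len(finished)
        outputs.append(step_output)    # output_time_major=True: outputs[t][b]
        time += 1
    return outputs, state


class CountingDecoder:
    """Hypothetical decoder: adds its input to the state; example b is finished
    after sequence_length[b] steps."""

    def __init__(self, sequence_length):
        self.sequence_length = sequence_length

    def initialize(self):
        batch = len(self.sequence_length)
        finished = [n == 0 for n in self.sequence_length]
        return finished, [1] * batch, [0] * batch      # (finished, first_inputs, initial_state)

    def step(self, time, inputs, state):
        next_state = [s + x for s, x in zip(state, inputs)]
        finished = [time + 1 >= n for n in self.sequence_length]
        return list(next_state), next_state, inputs, finished


outputs, state = toy_dynamic_decode(CountingDecoder([3, 1]))
print(outputs)  # [[1, 1], [2, 2], [3, 3]]
```

With `impute_finished=False`, the second example (length 1) is still stepped after it has finished, so its outputs at later timesteps are not zeroed out.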

## Next, how TrainingHelper handles the RNN cell's inputs and outputs

![alt text](figure/demo_helper.png)

```python
# Imports for this excerpt (TF 1.x internals). Helper comes from google/seq2seq's
# vendored copy of tf.contrib.seq2seq.
from seq2seq.contrib.seq2seq.helper import Helper
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.rnn import _transpose_batch_time
from tensorflow.python.util import nest


def _unstack_ta(inp):
    # Split a time-major tensor [T, B, ...] into a TensorArray of T elements of shape [B, ...]
    return tensor_array_ops.TensorArray(
        dtype=inp.dtype, size=array_ops.shape(inp)[0],
        element_shape=inp.get_shape()[1:]).unstack(inp)


class TrainingHelper(Helper):
    """A helper for use during training.  Only reads decoder inputs.

    Returned sample_ids are the argmax of the RNN output logits.
    """

    # Note-1. inputs (i.e. decoder_inputs) and sequence_length (i.e. decoder_inputs_length)
    # are held by the helper object; the BasicDecoder object never reads or processes them directly.
    def __init__(self, inputs, sequence_length, time_major=False, name=None):
        """Initializer.

        Args:
            inputs: A (structure of) input tensors.
            sequence_length: An int32 vector tensor.
            time_major: Python bool.  Whether the tensors in `inputs` are time major.
              If `False` (default), they are assumed to be batch major.
            name: Name scope for any created operations.

        Raises:
            ValueError: if `sequence_length` is not a 1D tensor.
        """
        # Note-2. Process decoder_input.
        # This helper feeds the ground-truth response as the decoder input,
        # converting it from a Tensor into a TensorArray.
        with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
            inputs = ops.convert_to_tensor(inputs, name="inputs")
            # Note-2 (continued). A TensorArray indexes along axis 0, so it manages
            # time-major data; batch-major inputs must first be transposed to time-major.
            if not time_major:
                inputs = nest.map_structure(_transpose_batch_time, inputs)
            self._input_tas = nest.map_structure(_unstack_ta, inputs)

            self._sequence_length = ops.convert_to_tensor(
                sequence_length, name="sequence_length")
            if self._sequence_length.get_shape().ndims != 1:
                raise ValueError(
                    "Expected sequence_length to be a vector, but received shape: %s" %
                    self._sequence_length.get_shape())

            # All-zero input of shape [B, ...], fed once every sequence has finished
            self._zero_inputs = nest.map_structure(
                lambda inp: array_ops.zeros_like(inp[0, :]), inputs)

            self._batch_size = array_ops.size(sequence_length)

    @property
    def batch_size(self):
        return self._batch_size

    # Note-3. From this class and the CustomHelper template below we can see that a
    # helper object processes the RNN cell's output (e.g. sample()) and prepares the
    # RNN cell's input (e.g. next_inputs()). Custom helpers make decoding flexible:
    # because step() handles the cell's inputs and outputs only through the helper,
    # BasicDecoder itself stays generic.
    def initialize(self, name=None):
        with ops.name_scope(name, "TrainingHelperInitialize"):
            # An example with length 0 is finished before decoding starts
            finished = math_ops.equal(0, self._sequence_length)
            all_finished = math_ops.reduce_all(finished)
            next_inputs = control_flow_ops.cond(
                all_finished, lambda: self._zero_inputs,
                lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
        return (finished, next_inputs)

    def sample(self, time, outputs, name=None, **unused_kwargs):
        with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
            # Greedy argmax over the vocabulary; only reported, never fed back
            sample_ids = math_ops.cast(
                math_ops.argmax(outputs, axis=-1), dtypes.int32)
        return sample_ids

    # Note-4. This next_inputs method is very close to the raw_rnn example from lecture 4:
    # * decoding has finished for an example once the next timestep reaches its sequence length.
    def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
        """next_inputs_fn for TrainingHelper."""
        with ops.name_scope(name, "TrainingHelperNextInputs",
                            [time, outputs, state]):
            next_time = time + 1
            finished = (next_time >= self._sequence_length)
            all_finished = math_ops.reduce_all(finished)

            def read_from_ta(inp):
                # Ground-truth input at the next timestep (teacher forcing)
                return inp.read(next_time)

            next_inputs = control_flow_ops.cond(
                all_finished, lambda: self._zero_inputs,
                lambda: nest.map_structure(read_from_ta, self._input_tas))
        return (finished, next_inputs, state)
```
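To see what `initialize()` and `next_inputs()` produce over a whole batch, here is a pure-Python re-enactment. It is a sketch, not the TensorFlow code: `training_helper_trace` is a hypothetical name, lists replace TensorArrays, and scalar token ids replace embeddings. Example 0 has length 3 and example 1 has length 1.

```python
def training_helper_trace(inputs_bm, sequence_length):
    """Replays TrainingHelper.initialize() and next_inputs() on plain lists (sketch)."""
    # batch-major [B][T] -> time-major [T][B], as _transpose_batch_time + _unstack_ta do
    input_ta = [list(col) for col in zip(*inputs_bm)]
    zero_inputs = [0] * len(sequence_length)   # stands in for zeros_like(inputs[0, :])

    # initialize(): an example with length 0 is finished before decoding starts
    finished = [n == 0 for n in sequence_length]
    next_inputs = zero_inputs if all(finished) else input_ta[0]
    trace = [(finished, next_inputs)]

    time = 0
    while not all(finished):
        # next_inputs(): compare next_time with every example's length
        next_time = time + 1
        finished = [next_time >= n for n in sequence_length]
        next_inputs = zero_inputs if all(finished) else input_ta[next_time]
        trace.append((finished, next_inputs))
        time = next_time
    return trace


trace = training_helper_trace([[5, 6, 7], [8, 0, 0]], sequence_length=[3, 1])
for finished, next_inputs in trace:
    print(finished, next_inputs)
# [False, False] [5, 8]
# [False, True] [6, 0]
# [False, True] [7, 0]
# [True, True] [0, 0]
```

Example 1 is marked finished after its first step, yet it keeps receiving its (padding) inputs until the whole batch is finished; only then does the helper switch to the zero inputs.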

```python
# Imports for this excerpt (TF 1.x internals); Helper comes from google/seq2seq's
# vendored copy of tf.contrib.seq2seq.
from seq2seq.contrib.seq2seq.helper import Helper
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


class CustomHelper(Helper):
    """Base abstract class that allows the user to customize sampling."""

    def __init__(self, initialize_fn, sample_fn, next_inputs_fn):
        """Initializer.

        Args:
            initialize_fn: callable that returns `(finished, next_inputs)` for the first iteration.
            sample_fn: callable that takes `(time, outputs, state)` and emits tensor `sample_ids`.
            next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)` and emits `(finished, next_inputs, next_state)`.
        """
        self._initialize_fn = initialize_fn
        self._sample_fn = sample_fn
        self._next_inputs_fn = next_inputs_fn
        self._batch_size = None

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    def initialize(self, name=None):
        with ops.name_scope(name, "%sInitialize" % type(self).__name__):
            (finished, next_inputs) = self._initialize_fn()
            # The batch size is inferred from the first `finished` vector
            if self._batch_size is None:
                self._batch_size = array_ops.size(finished)
        return (finished, next_inputs)

    def sample(self, time, outputs, state, name=None):
        with ops.name_scope(
                name, "%sSample" % type(self).__name__, (time, outputs, state)):
            return self._sample_fn(time=time, outputs=outputs, state=state)

    def next_inputs(self, time, outputs, state, sample_ids, name=None):
        with ops.name_scope(
                name, "%sNextInputs" % type(self).__name__, (time, outputs, state)):
            return self._next_inputs_fn(
                time=time, outputs=outputs, state=state, sample_ids=sample_ids)
```

## Next, practice implementing upgraded decoding with a custom helper
![alt text](figure/nct-seq2seq.png)
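As a starting point for the exercise, here is a pure-Python sketch of the **upgraded** (greedy) decoding written as the three callables that `CustomHelper` takes. Everything in it is hypothetical: `EMBEDDING` is a toy embedding table, `GO_ID`/`EOS_ID` are assumed special ids, and `fake_model` replaces the RNN cell plus output projection. In TensorFlow, `next_inputs_fn` would look up `sample_ids` with `tf.nn.embedding_lookup` instead of indexing a dict.

```python
EMBEDDING = {0: [0.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0], 3: [1.0, 1.0]}  # hypothetical id -> vector
GO_ID, EOS_ID = 1, 0                 # hypothetical special ids
NEXT_ID = {1: 2, 2: 3, 3: 0, 0: 0}   # hypothetical "model": which id it predicts after which
BATCH_SIZE = 2


def fake_model(inputs):
    """Hypothetical cell + projection: recover each input id, give NEXT_ID[id] the top score."""
    ids = [next(k for k, v in EMBEDDING.items() if v == vec) for vec in inputs]
    return [[1.0 if v == NEXT_ID[i] else 0.0 for v in range(len(EMBEDDING))] for i in ids]


def initialize_fn():
    # every example starts from the <go> embedding and nothing is finished yet
    return [False] * BATCH_SIZE, [EMBEDDING[GO_ID]] * BATCH_SIZE


def sample_fn(time, outputs, state):
    # greedy sampling: argmax over each example's logits
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def next_inputs_fn(time, outputs, state, sample_ids):
    # upgraded decoding: the word predicted now is the input at the next timestep
    finished = [i == EOS_ID for i in sample_ids]
    next_inputs = [EMBEDDING[i] for i in sample_ids]
    return finished, next_inputs, state


finished, inputs = initialize_fn()
state, predictions, time = None, [], 0
while not all(finished):
    logits = fake_model(inputs)
    sample_ids = sample_fn(time, logits, state)
    step_finished, inputs, state = next_inputs_fn(time, logits, state, sample_ids)
    finished = [f or sf for f, sf in zip(finished, step_finished)]
    predictions.append(sample_ids)
    time += 1
print(predictions)  # [[2, 2], [3, 3], [0, 0]]
```

The only real difference from the basic helper is `next_inputs_fn`: it builds the next input from `sample_ids` rather than reading the ground-truth response. For reference, TF 1.x `tf.contrib.seq2seq` also ships a ready-made `GreedyEmbeddingHelper` for this case.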

## Case study: `ScheduledOutputTrainingHelper`

![alt text](figure/scheduled_sampling.jpg)

![alt text](figure/scheduled_sampling_2.jpg)

Reference: [Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks](https://arxiv.org/pdf/1506.03099.pdf)
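The paper decays `epsilon_i`, the probability of feeding the ground-truth previous token at training step `i`, with a linear, exponential, or inverse sigmoid schedule. The sketch below implements those three formulas; the default values of `k`, `c` and `eps` are arbitrary placeholders, not values from the paper. Note the direction: the helper's `sampling_probability` is the probability of feeding the model's own output, i.e. `1 - epsilon_i`.

```python
import math


def linear_decay(i, k=1.0, c=1e-4, eps=0.1):
    """epsilon_i = max(eps, k - c * i); k, c, eps are placeholder values."""
    return max(eps, k - c * i)


def exponential_decay(i, k=0.9999):
    """epsilon_i = k ** i, with k < 1 (placeholder value)."""
    return k ** i


def inverse_sigmoid_decay(i, k=1000.0):
    """epsilon_i = k / (k + exp(i / k)), with k >= 1 (placeholder value)."""
    return k / (k + math.exp(i / k))


def sampling_probability(epsilon_i):
    # epsilon_i: probability of feeding the ground truth;
    # the helper wants the probability of feeding the model output instead
    return 1.0 - epsilon_i


for i in (0, 5000, 20000):
    print(i, linear_decay(i), exponential_decay(i), inverse_sigmoid_decay(i))
```

Early in training all three schedules keep `epsilon_i` close to 1 (almost pure teacher forcing), and they move toward feeding the model's own predictions as `i` grows.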


```python
# Imports for this excerpt (TF 1.x internals). TrainingHelper and _unstack_ta come
# from the same helper module as above (google/seq2seq's vendored tf.contrib.seq2seq).
from seq2seq.contrib.seq2seq.helper import TrainingHelper, _unstack_ta
from tensorflow.contrib.distributions.python.ops import bernoulli
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.rnn import _transpose_batch_time
from tensorflow.python.util import nest


class ScheduledOutputTrainingHelper(TrainingHelper):
    """A training helper that adds scheduled sampling directly to outputs.

    Returns False for sample_ids where no sampling took place; True elsewhere.
    """

    def __init__(self, inputs, sequence_length, sampling_probability,
                 time_major=False, seed=None, next_input_layer=None,
                 auxiliary_inputs=None, name=None):
        """Initializer.

        Args:
            inputs: A (structure) of input tensors.
            sequence_length: An int32 vector tensor.
            sampling_probability: A 0D `float32` tensor: the probability of sampling
                from the outputs instead of reading directly from the inputs.
            time_major: Python bool.  Whether the tensors in `inputs` are time major.
                If `False` (default), they are assumed to be batch major.
            seed: The sampling seed.
            next_input_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
                `tf.layers.Dense`.  Optional layer to apply to the RNN output to create
                the next input.
            auxiliary_inputs: An optional (structure of) auxiliary input tensors with
                a shape that matches `inputs` in all but (potentially) the final
                dimension. These tensors will be concatenated to the sampled output or
                the `inputs` when not sampling for use as the next input.
            name: Name scope for any created operations.

        Raises:
            ValueError: if `sampling_probability` is not a scalar or vector.
        """
        with ops.name_scope(name, "ScheduledOutputTrainingHelper",
                            [inputs, auxiliary_inputs, sampling_probability]):
            self._sampling_probability = ops.convert_to_tensor(
                sampling_probability, name="sampling_probability")
            if self._sampling_probability.get_shape().ndims not in (0, 1):
                raise ValueError(
                    "sampling_probability must be either a scalar or a vector. "
                    "saw shape: %s" % (self._sampling_probability.get_shape()))

            if auxiliary_inputs is None:
                maybe_concatenated_inputs = inputs
            else:
                # Convert and concatenate only when auxiliary inputs are given
                inputs = ops.convert_to_tensor(inputs, name="inputs")
                auxiliary_inputs = ops.convert_to_tensor(
                    auxiliary_inputs, name="auxiliary_inputs")
                maybe_concatenated_inputs = nest.map_structure(
                    lambda x, y: array_ops.concat((x, y), -1),
                    inputs, auxiliary_inputs)
                if not time_major:
                    auxiliary_inputs = nest.map_structure(
                        _transpose_batch_time, auxiliary_inputs)

            self._auxiliary_input_tas = (
                nest.map_structure(_unstack_ta, auxiliary_inputs)
                if auxiliary_inputs is not None else None)

            self._seed = seed
            if (next_input_layer is not None and not isinstance(next_input_layer,
                                                                layers_base._Layer)):  # pylint: disable=protected-access
                raise TypeError("next_input_layer must be a Layer, received: %s" %
                                type(next_input_layer))
            # Set unconditionally: next_inputs() reads it even when it is None
            self._next_input_layer = next_input_layer

            super(ScheduledOutputTrainingHelper, self).__init__(
                inputs=maybe_concatenated_inputs,
                sequence_length=sequence_length,
                time_major=time_major,
                name=name)

    def initialize(self, name=None):
        return super(ScheduledOutputTrainingHelper, self).initialize(name=name)

    def sample(self, time, outputs, state, name=None):
        with ops.name_scope(name, "ScheduledOutputTrainingHelperSample",
                            [time, outputs, state]):
            # One Bernoulli draw per example: True means "feed the model output next"
            sampler = bernoulli.Bernoulli(probs=self._sampling_probability)
            return math_ops.cast(
                sampler.sample(sample_shape=self.batch_size, seed=self._seed),
                dtypes.bool)

    def next_inputs(self, time, outputs, state, sample_ids, name=None):
        with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                            [time, outputs, state, sample_ids]):
            # Ground-truth next inputs and finished flags from TrainingHelper
            (finished, base_next_inputs, state) = (
                super(ScheduledOutputTrainingHelper, self).next_inputs(
                    time=time,
                    outputs=outputs,
                    state=state,
                    sample_ids=sample_ids,
                    name=name))

            def maybe_sample():
                """Perform scheduled sampling."""

                def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
                    """Concatenate outputs with auxiliary inputs, if they exist."""
                    if self._auxiliary_input_tas is None:
                        return outputs_

                    next_time = time + 1
                    auxiliary_inputs = nest.map_structure(
                        lambda ta: ta.read(next_time), self._auxiliary_input_tas)
                    if indices is not None:
                        auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices)
                    return nest.map_structure(
                        lambda x, y: array_ops.concat((x, y), -1),
                        outputs_, auxiliary_inputs)

                if self._next_input_layer is None:
                    # Per example: model output where sample_ids is True, ground truth elsewhere
                    return array_ops.where(
                        sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
                        base_next_inputs)

                # Otherwise apply next_input_layer only to the sampled rows, then
                # scatter both groups back into a full [batch, ...] tensor
                where_sampling = math_ops.cast(
                    array_ops.where(sample_ids), dtypes.int32)
                where_not_sampling = math_ops.cast(
                    array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
                outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
                inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                          where_not_sampling)
                sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
                    self._next_input_layer(outputs_sampling), where_sampling)

                base_shape = array_ops.shape(base_next_inputs)
                return (array_ops.scatter_nd(indices=where_sampling,
                                             updates=sampled_next_inputs,
                                             shape=base_shape)
                        + array_ops.scatter_nd(indices=where_not_sampling,
                                               updates=inputs_not_sampling,
                                               shape=base_shape))

            all_finished = math_ops.reduce_all(finished)
            next_inputs = control_flow_ops.cond(
                all_finished, lambda: base_next_inputs, maybe_sample)
            return (finished, next_inputs, state)
```
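Stripped of TensorFlow, `sample()` plus `maybe_sample()` reduce to a per-example coin flip. The sketch below (hypothetical name `scheduled_next_inputs`; lists replace tensors, `random.Random` replaces the Bernoulli sampler, and auxiliary inputs are ignored) feeds the model output where the coin says `True` and the ground truth elsewhere. When a `next_input_layer` is given, it applies the layer only to the sampled rows and scatters them back, mimicking the `gather_nd`/`scatter_nd` branch.

```python
import random


def scheduled_next_inputs(outputs, base_next_inputs, sampling_probability, rng,
                          next_input_layer=None):
    """Sketch of ScheduledOutputTrainingHelper.sample() + maybe_sample() on lists."""
    # sample(): one Bernoulli(sampling_probability) draw per example
    sample_ids = [rng.random() < sampling_probability for _ in outputs]

    if next_input_layer is None:
        # array_ops.where(sample_ids, outputs, base_next_inputs), row by row
        next_inputs = [out if use_output else truth
                       for use_output, out, truth
                       in zip(sample_ids, outputs, base_next_inputs)]
        return sample_ids, next_inputs

    # gather_nd: pick only the sampled rows and run them through the layer
    where_sampling = [b for b, use_output in enumerate(sample_ids) if use_output]
    transformed = next_input_layer([outputs[b] for b in where_sampling])
    # scatter_nd: put the transformed rows back; the other rows keep the ground truth
    next_inputs = list(base_next_inputs)
    for b, row in zip(where_sampling, transformed):
        next_inputs[b] = row
    return sample_ids, next_inputs


rng = random.Random(0)
outputs = [[0.9, 0.1], [0.2, 0.8]]
truth = [[1.0, 0.0], [0.0, 1.0]]
print(scheduled_next_inputs(outputs, truth, 0.0, rng))  # ([False, False], [[1.0, 0.0], [0.0, 1.0]])
print(scheduled_next_inputs(outputs, truth, 1.0, rng))  # ([True, True], [[0.9, 0.1], [0.2, 0.8]])
```

With a `sampling_probability` strictly between 0 and 1, each example independently mixes the two sources, which is what lets a schedule move training gradually from teacher forcing toward feeding back the model's own outputs.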