In [1]:
import tensorflow as tf

  return f(*args, **kwds)


In [2]:
# Copyright 2007 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Abstract Base Classes (ABCs) according to PEP 3119."""

from _weakrefset import WeakSet


def abstractmethod(funcobj):
    """Flag *funcobj* as an abstract method and hand it back unchanged.

    Only classes whose metaclass is ABCMeta (or derives from it) honour
    the flag: such a class cannot be instantiated while any abstract
    method is left without an override.  The abstract implementation can
    still be reached through the ordinary 'super' call mechanisms.

    Usage:

        class C(metaclass=ABCMeta):
            @abstractmethod
            def my_abstract_method(self, ...):
                ...
    """
    setattr(funcobj, "__isabstractmethod__", True)
    return funcobj


class abstractclassmethod(classmethod):
    """
    Deprecated decorator that declares an abstract classmethod.

    Behaves like abstractmethod, but for classmethods.

    Usage:

        class C(metaclass=ABCMeta):
            @abstractclassmethod
            def my_abstract_classmethod(cls, ...):
                ...

    'abstractclassmethod' is deprecated. Use 'classmethod' with
    'abstractmethod' instead.
    """

    # Both the descriptor and the wrapped function carry the flag.
    __isabstractmethod__ = True

    def __init__(self, callable):
        setattr(callable, "__isabstractmethod__", True)
        super(abstractclassmethod, self).__init__(callable)


class abstractstaticmethod(staticmethod):
    """
    Deprecated decorator that declares an abstract staticmethod.

    Behaves like abstractmethod, but for staticmethods.

    Usage:

        class C(metaclass=ABCMeta):
            @abstractstaticmethod
            def my_abstract_staticmethod(...):
                ...

    'abstractstaticmethod' is deprecated. Use 'staticmethod' with
    'abstractmethod' instead.
    """

    # Both the descriptor and the wrapped function carry the flag.
    __isabstractmethod__ = True

    def __init__(self, callable):
        setattr(callable, "__isabstractmethod__", True)
        super(abstractstaticmethod, self).__init__(callable)


class abstractproperty(property):
    """
    Deprecated decorator that declares an abstract property.

    A class whose metaclass is ABCMeta (or a subclass of it) refuses to be
    instantiated until every abstract property has been overridden.  The
    abstract property can still be reached through the normal 'super'
    call mechanisms.

    Usage:

        class C(metaclass=ABCMeta):
            @abstractproperty
            def my_abstract_property(self):
                ...

    The decorator form gives a read-only property.  For a read-write
    abstract property, use the long form of the property declaration:

        class C(metaclass=ABCMeta):
            def getx(self): ...
            def setx(self, value): ...
            x = abstractproperty(getx, setx)

    'abstractproperty' is deprecated. Use 'property' with 'abstractmethod'
    instead.
    """

    # ABCMeta looks this flag up on the descriptor object itself.
    __isabstractmethod__ = True


class ABCMeta(type):

    """Metaclass for defining Abstract Base Classes (ABCs).

    Use this metaclass to create an ABC.  An ABC can be subclassed
    directly, and then acts as a mix-in class.  You can also register
    unrelated concrete classes (even built-in classes) and unrelated
    ABCs as 'virtual subclasses' -- these and their descendants will
    be considered subclasses of the registering ABC by the built-in
    issubclass() function, but the registering ABC won't show up in
    their MRO (Method Resolution Order) nor will method
    implementations defined by the registering ABC be callable (not
    even via super()).

    Every class created by this metaclass gets a frozenset
    ``__abstractmethods__`` plus three WeakSets used by the subclass
    checks: ``_abc_registry`` (explicitly registered virtual subclasses),
    ``_abc_cache`` (positive results) and ``_abc_negative_cache``
    (negative results, valid only while ``_abc_negative_cache_version``
    matches the global invalidation counter).

    """

    # A global counter that is incremented each time a class is
    # registered as a virtual subclass of anything.  It forces the
    # negative cache to be cleared before its next use.
    # Note: this counter is private. Use `abc.get_cache_token()` for
    #       external code.
    _abc_invalidation_counter = 0

    def __new__(mcls, name, bases, namespace, **kwargs):
        """Create the class, then compute its abstract methods and caches.

        A name is abstract if its value in ``namespace`` is flagged with
        ``__isabstractmethod__``, or if it is listed in a base's
        ``__abstractmethods__`` and the attribute resolved on the new
        class is still flagged.
        """
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
        # Compute set of abstract method names
        # (the loop below rebinds `name`; the class-name argument is no
        # longer needed once super().__new__ has run)
        abstracts = {name
                     for name, value in namespace.items()
                     if getattr(value, "__isabstractmethod__", False)}
        for base in bases:
            for name in getattr(base, "__abstractmethods__", set()):
                # Resolve through the new class's MRO so that an override
                # in this class (or an earlier base) clears the flag.
                value = getattr(cls, name, None)
                if getattr(value, "__isabstractmethod__", False):
                    abstracts.add(name)
        cls.__abstractmethods__ = frozenset(abstracts)
        # Set up inheritance registry
        cls._abc_registry = WeakSet()
        cls._abc_cache = WeakSet()
        cls._abc_negative_cache = WeakSet()
        # Record which counter value the (empty) negative cache reflects.
        cls._abc_negative_cache_version = ABCMeta._abc_invalidation_counter
        return cls

    def register(cls, subclass):
        """Register a virtual subclass of an ABC.

        Returns the subclass, to allow usage as a class decorator.

        Raises:
            TypeError: if ``subclass`` is not a class.
            RuntimeError: if registering would create an inheritance cycle.
        """
        if not isinstance(subclass, type):
            raise TypeError("Can only register classes")
        if issubclass(subclass, cls):
            return subclass  # Already a subclass
        # Subtle: test for cycles *after* testing for "already a subclass";
        # this means we allow X.register(X) and interpret it as a no-op.
        if issubclass(cls, subclass):
            # This would create a cycle, which is bad for the algorithm below
            raise RuntimeError("Refusing to create an inheritance cycle")
        cls._abc_registry.add(subclass)
        ABCMeta._abc_invalidation_counter += 1  # Invalidate negative cache
        return subclass

    def _dump_registry(cls, file=None):
        """Debug helper to print the ABC registry.

        Prints the class's qualified name, the global invalidation counter
        and every ``_abc_*`` attribute defined directly on the class, to
        ``file`` (stdout when None).
        """
        print("Class: %s.%s" % (cls.__module__, cls.__qualname__), file=file)
        print("Inv.counter: %s" % ABCMeta._abc_invalidation_counter, file=file)
        for name in sorted(cls.__dict__.keys()):
            if name.startswith("_abc_"):
                value = getattr(cls, name)
                print("%s: %r" % (name, value), file=file)

    def __instancecheck__(cls, instance):
        """Override for isinstance(instance, cls).

        Consults the caches for ``instance.__class__`` directly and
        otherwise defers to ``__subclasscheck__``.  When ``type(instance)``
        differs from ``instance.__class__`` (possible if ``__class__`` is
        overridden), the check succeeds if either one is a subclass.
        """
        # Inline the cache checking
        subclass = instance.__class__
        if subclass in cls._abc_cache:
            return True
        subtype = type(instance)
        if subtype is subclass:
            # The negative cache is trusted only if no register() call has
            # bumped the global counter since the cache was last reset.
            if (cls._abc_negative_cache_version ==
                ABCMeta._abc_invalidation_counter and
                subclass in cls._abc_negative_cache):
                return False
            # Fall back to the subclass check.
            return cls.__subclasscheck__(subclass)
        return any(cls.__subclasscheck__(c) for c in {subclass, subtype})

    def __subclasscheck__(cls, subclass):
        """Override for issubclass(subclass, cls).

        Checks, in order: the positive cache, the (possibly stale) negative
        cache, ``cls.__subclasshook__``, a real MRO relationship, registered
        virtual subclasses, and finally real subclasses of ``cls``.  The
        result is recorded in the positive or negative cache.
        """
        # Check cache
        if subclass in cls._abc_cache:
            return True
        # Check negative cache; may have to invalidate
        if cls._abc_negative_cache_version < ABCMeta._abc_invalidation_counter:
            # Invalidate the negative cache
            cls._abc_negative_cache = WeakSet()
            cls._abc_negative_cache_version = ABCMeta._abc_invalidation_counter
        elif subclass in cls._abc_negative_cache:
            return False
        # Check the subclass hook
        ok = cls.__subclasshook__(subclass)
        if ok is not NotImplemented:
            # The hook must answer True, False or NotImplemented.
            assert isinstance(ok, bool)
            if ok:
                cls._abc_cache.add(subclass)
            else:
                cls._abc_negative_cache.add(subclass)
            return ok
        # Check if it's a direct subclass
        if cls in getattr(subclass, '__mro__', ()):
            cls._abc_cache.add(subclass)
            return True
        # Check if it's a subclass of a registered class (recursive)
        for rcls in cls._abc_registry:
            if issubclass(subclass, rcls):
                cls._abc_cache.add(subclass)
                return True
        # Check if it's a subclass of a subclass (recursive)
        for scls in cls.__subclasses__():
            if issubclass(subclass, scls):
                cls._abc_cache.add(subclass)
                return True
        # No dice; update negative cache
        cls._abc_negative_cache.add(subclass)
        return False


class ABC(metaclass=ABCMeta):
    """Convenience base class: inheriting from ABC gives a class ABCMeta as
    its metaclass without writing ``metaclass=ABCMeta`` explicitly.
    """


def get_cache_token():
    """Return an opaque token identifying the current ABC cache version.

    Tokens support equality testing only.  A new value is produced every
    time ``register()`` is called on any ABC, so a cached issubclass()
    answer is still valid as long as the token has not changed.
    """
    token = ABCMeta._abc_invalidation_counter
    return token


In [None]:
def get_initializer(init_op, seed=None, init_weight=None):
    """Create a variable initializer.

    Args:
      init_op: one of "uniform", "glorot_normal" or "glorot_uniform".
      seed: optional random seed forwarded to the initializer.
      init_weight: only used by "uniform", which samples from
        [-init_weight, init_weight].

    Returns:
      A TensorFlow initializer object.

    Raises:
      ValueError: if init_op is unknown, or if init_op is "uniform" and
        init_weight is missing or zero.
    """
    if init_op == "uniform":
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently build a degenerate initializer.
        if not init_weight:
            raise ValueError(
                "init_op 'uniform' requires a non-zero init_weight, got %r"
                % (init_weight,))
        return tf.random_uniform_initializer(-init_weight, init_weight, seed=seed)
    if init_op == "glorot_normal":
        # Glorot (Xavier) normal: samples from a truncated normal centred at 0.
        return tf.keras.initializers.glorot_normal(seed=seed)
    if init_op == "glorot_uniform":
        # Glorot (Xavier) uniform initializer.
        return tf.keras.initializers.glorot_uniform(seed=seed)
    raise ValueError("Unknown init_op %s" % init_op)
def get_device_str(device_id, num_gpus):
    """Map a logical device index onto a TensorFlow device string.

    With no GPUs everything goes to "/cpu:0"; otherwise indices wrap
    around the available GPUs modulo ``num_gpus``.
    """
    if num_gpus != 0:
        return "/gpu:%d" % (device_id % num_gpus)
    return "/cpu:0"
def _single_cell(unit_type, num_units, forget_bias, dropout, mode,
                 residual_connection=False, device_str=None, residual_fn=None):
    """Build one RNN cell and apply the optional wrappers.

    Args:
      unit_type: "lstm", "gru", "layer_norm_lstm" or "nas".
      num_units: hidden size of the cell.
      forget_bias: forget-gate bias (passed to the LSTM variants only).
      dropout: drop probability (1 - keep_prob) on the cell inputs; it is
        replaced by 0.0 unless mode is TRAIN.
      mode: a tf.contrib.learn.ModeKeys value.
      residual_connection: if truthy, wrap in ResidualWrapper, which adds
        the cell inputs to its outputs.
      device_str: if truthy, pin the cell to this device via DeviceWrapper.
      residual_fn: optional combine function forwarded to ResidualWrapper.

    Raises:
      ValueError: for an unknown unit_type.
    """
    is_training = mode == tf.contrib.learn.ModeKeys.TRAIN
    if not is_training:
        dropout = 0.0

    if unit_type == "lstm":
        cell = tf.contrib.rnn.BasicLSTMCell(
            num_units=num_units, forget_bias=forget_bias)
    elif unit_type == "gru":
        cell = tf.contrib.rnn.GRUCell(num_units)
    elif unit_type == "layer_norm_lstm":
        # LSTM cell with layer normalization enabled.
        cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
            num_units, forget_bias=forget_bias, layer_norm=True)
    elif unit_type == "nas":
        # Recurrent cell whose architecture was discovered by Neural
        # Architecture Search (an RNN controller proposes architectures and
        # is updated using their accuracy as the reward signal).
        cell = tf.contrib.rnn.NASCell(num_units)
    else:
        raise ValueError("Unknown unit type %s!" % unit_type)

    # Wrappers, innermost first: dropout -> residual -> device placement.
    if dropout > 0.0:
        cell = tf.contrib.rnn.DropoutWrapper(
            cell=cell, input_keep_prob=(1.0 - dropout))
    if residual_connection:
        cell = tf.contrib.rnn.ResidualWrapper(cell, residual_fn=residual_fn)
    if device_str:
        cell = tf.contrib.rnn.DeviceWrapper(cell, device_str)
    return cell
    

def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers,
                    forget_bias, dropout, mode, num_gpus, base_gpu=0,
                    single_cell_fn=None, residual_fn=None):
    """Create a (possibly multi-layer) RNN cell.

    Args:
      unit_type, num_units, forget_bias, dropout, mode: forwarded to each
        per-layer cell (see _single_cell).
      num_layers: number of stacked layers.
      num_residual_layers: how many of the top layers get residual wrappers.
      num_gpus: number of GPUs to spread the layers over (0 = CPU only).
      base_gpu: device index of the first layer.
      single_cell_fn: optional replacement for the per-layer cell builder.
      residual_fn: optional combine function for residual layers; defaults
        to None, matching the previous behavior.

    Returns:
      The single cell when num_layers == 1, otherwise a MultiRNNCell
      stacking all layers.
    """
    cell_list = _cell_list(unit_type=unit_type,
                           num_units=num_units,
                           num_layers=num_layers,
                           num_residual_layers=num_residual_layers,
                           forget_bias=forget_bias,
                           dropout=dropout,
                           mode=mode,
                           num_gpus=num_gpus,
                           base_gpu=base_gpu,
                           single_cell_fn=single_cell_fn,
                           residual_fn=residual_fn)

    if len(cell_list) == 1:
        return cell_list[0]
    return tf.contrib.rnn.MultiRNNCell(cell_list)

def _cell_list(unit_type, num_units, num_layers, num_residual_layers,
               forget_bias, dropout, mode, num_gpus, base_gpu=0,
               single_cell_fn=None, residual_fn=None):
    """Create a list of RNN cells, one per layer.

    The last ``num_residual_layers`` layers get residual connections, and
    layer ``k`` is placed on device ``get_device_str(k + base_gpu, num_gpus)``.
    ``single_cell_fn`` falls back to ``_single_cell`` when falsy.
    """
    make_cell = single_cell_fn or _single_cell
    return [
        make_cell(
            unit_type=unit_type,
            num_units=num_units,
            forget_bias=forget_bias,
            dropout=dropout,
            mode=mode,
            residual_connection=(layer >= num_layers - num_residual_layers),
            device_str=get_device_str(layer + base_gpu, num_gpus),
            residual_fn=residual_fn)
        for layer in range(num_layers)
    ]


def _create_attention_images_summary(final_context_state):
    """Build an image summary of the decoder's attention alignments.

    Stacks ``final_context_state.alignment_history``, reorders it to
    (batch, src_seq_len, tgt_seq_len, 1), scales by 255, and returns the
    ``tf.summary.image`` op named "attention_images".
    """
    history = final_context_state.alignment_history.stack()
    # Transpose [1, 2, 0] and append a channel axis -> (batch, src, tgt, 1).
    images = tf.expand_dims(tf.transpose(history, [1, 2, 0]), -1)
    images *= 255  # scale to range [0, 255]
    return tf.summary.image("attention_images", images)
def create_attention_mechanism(attention_option, num_units, memory,source_sequence_length, mode):
    """Create the attention mechanism selected by ``attention_option``.

    Supported options:
      "luong"        -> LuongAttention
      "scaled_luong" -> LuongAttention with scale=True
      "bahdanau"     -> BahdanauAttention with normalize=True

    ``mode`` is accepted for interface compatibility and ignored.

    Raises:
      ValueError: for any other option.
    """
    del mode  # unused
    if attention_option == "luong":
        return tf.contrib.seq2seq.LuongAttention(
            num_units, memory,
            memory_sequence_length=source_sequence_length)
    if attention_option == "scaled_luong":
        return tf.contrib.seq2seq.LuongAttention(
            num_units, memory,
            memory_sequence_length=source_sequence_length,
            scale=True)
    if attention_option == "bahdanau":
        # NOTE: normalization is always enabled for "bahdanau" here.
        return tf.contrib.seq2seq.BahdanauAttention(
            num_units, memory,
            memory_sequence_length=source_sequence_length,
            normalize=True)
    raise ValueError("Unknown attention option %s" % attention_option)


In [None]:
'''
Helper类是抽象基础类，其中定义了几个抽象方法，如initialize、sample、next_inputs，接下来所有的具体类都是继承自Helper抽象类。
CustomHelper类虽然是继承自Helper类的一个具体类，但是这个类没有外加太多约束，它需要用户自定义initialize_fn, sample_fn, 
next_inputs_fn这三个函数，而InferenceHelper类我们可以看成是CustomHelper类的一个特殊情况，由于这个类只在推断的时候使用，
因此在next_inputs函数中只需要将前一时刻的抽样结果作为下一时刻的输入即可；GreedyEmbeddingHelper类也是用于推理过程，
不过它是采取argmax抽样算法来得到输出id，并且经过embedding层作为下一时刻的输入；
而SampleEmbeddingHelper是继承自GreedyEmbeddingHelper类的一个类，与GreedyEmbeddingHelper类不同的是，
SampleEmbeddingHelper是通过抽样算法来得到解码器的输出。
TrainingHelper类也是继承自Helper类的一个具体类，在sample过程中，它采用的是最简单的argmax算法；
而ScheduledEmbeddingTrainingHelper类是继承自TrainingHelper类，其中的sample算法采取的是广义伯努利算法，
并且并不是每一个时刻都会采样，同时这里添加了embedding操作，即根据解码器的输出id从embedding矩阵中查找其对应的embedding向量；
ScheduledOutputTrainingHelper类同样也是继承自TrainingHelper类，没有embedding操作，直接对输出进行抽样。
TrainingHelper：适用于训练的helper。
InferenceHelper：适用于测试的helper。
GreedyEmbeddingHelper：适用于测试中采用Greedy策略sample的helper。
CustomHelper：用户自定义的helper。

4. beam search decoder
除了上面提到的argmax算法和伯努利抽样算法以外，我们还可以使用Beam Search的抽样方法来获得最终的解码序列，
在beam_search_decoder.py文件中，BeamSearchDecoder类是继承自Decoder类的，与之前的BasicDecoder类不同的是，
BasicDecoder类需要设定helper参数，而这里的BeamSearchDecoder没有helper参数，因为它所采用的算法是Beam Search，
其细节在该文件中有实现。
'''


In [9]:
from tensorflow.python.layers import core
# 基础的seq2seq model
class BaseModel(object):
    """Base class for a dynamic-RNN sequence-to-sequence model (TF 1.x graph mode).

    The whole graph for the requested ``mode`` is built inside ``__init__``.
    Subclasses supply the encoder (``_build_encoder``) and the decoder cell
    together with its initial state (``_build_decoder_cell``).
    """

    def __init__(
        self,
        hparams,
        mode,
        iterator,
        source_vocab_table,
        target_vocab_table,
        reverse_target_vocab_table=None,
        scope=None,
        extra_args=None):
        """Create the model graph.

        Args:
          hparams: hyper-parameters object read throughout graph construction.
          mode: one of tf.contrib.learn.ModeKeys.TRAIN / EVAL / INFER.
          iterator: batched input iterator (not an iteration count); this
            class reads its target_input, target_output,
            source_sequence_length and target_sequence_length attributes.
          source_vocab_table: lookup table for source words.
          target_vocab_table: lookup table mapping target words to ids
            (used to resolve the sos/eos ids).
          reverse_target_vocab_table: lookup table mapping ids to target
            words. Only required in INFER mode. Defaults to None.
          scope: optional variable-scope name for the model.
          extra_args: optional object whose ``single_cell_fn`` attribute
            replaces the default per-layer cell constructor.
        """
        self.iterator = iterator
        self.mode = mode
        self.src_vocab_table = source_vocab_table
        self.tgt_vocab_table = target_vocab_table

        self.src_vocab_size = hparams.src_vocab_size
        self.tgt_vocab_size = hparams.tgt_vocab_size
        self.num_gpus = hparams.num_gpus
        self.time_major = hparams.time_major

        self.single_cell_fn = None
        # extra_args: to make it flexible for adding external customizable code
        if extra_args:
            self.single_cell_fn = extra_args.single_cell_fn
        # Set num layers
        self.num_encoder_layers = hparams.num_encoder_layers
        self.num_decoder_layers = hparams.num_decoder_layers
        # Set num residual layers
        if hasattr(hparams, "num_residual_layers"):  # compatible common_test_utils
            self.num_encoder_residual_layers = hparams.num_residual_layers
            self.num_decoder_residual_layers = hparams.num_residual_layers
        else:
            self.num_encoder_residual_layers = hparams.num_encoder_residual_layers
            self.num_decoder_residual_layers = hparams.num_decoder_residual_layers

        # Initializer
        initializer = get_initializer(hparams.init_op, hparams.random_seed, hparams.init_weight)
        tf.get_variable_scope().set_initializer(initializer)

        # Embeddings: the encoder and _build_decoder look up
        # self.embedding_encoder / self.embedding_decoder, so they must exist
        # before build_graph runs.
        self.init_embeddings(hparams, scope)
        # Dynamic batch size: one length entry per source sentence in the batch.
        self.batch_size = tf.size(self.iterator.source_sequence_length)

        # Projection
        with tf.variable_scope(scope or 'build_network'):
            with tf.variable_scope("decoder/output_projection"):
                self.output_layer = core.Dense(hparams.tgt_vocab_size, use_bias=False, name="output_projection")

        ## Train graph
        res = self.build_graph(hparams, scope=scope)

        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.train_loss = res[1]
            self.word_count = tf.reduce_sum(self.iterator.source_sequence_length) + tf.reduce_sum(self.iterator.target_sequence_length)
        elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
            self.eval_loss = res[1]
        elif self.mode == tf.contrib.learn.ModeKeys.INFER:
            self.infer_logits, _, self.final_context_state, self.sample_id = res
            self.sample_words = reverse_target_vocab_table.lookup(tf.to_int64(self.sample_id))

        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            self.predict_count = tf.reduce_sum(self.iterator.target_sequence_length)

        self.global_step = tf.Variable(0, trainable=False)

        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.learning_rate = tf.constant(hparams.learning_rate)
            # warm-up
            self.learning_rate = self._get_learning_rate_warmup(hparams)
            # decay
            self.learning_rate = self._get_learning_rate_decay(hparams)
            # NOTE(review): train() also fetches self.update, self.train_summary
            # and self.grad_norm (optimizer step, gradient clipping, summaries),
            # none of which are created in this class yet -- TODO.

    def build_graph(self, hparams, scope=None):
        """Build the encoder, the decoder and (outside INFER) the loss.

        Subclasses normally customise _build_encoder / _build_decoder_cell
        rather than overriding this method.

        Returns:
          A tuple (logits, loss, final_context_state, sample_id); loss is
          None in INFER mode.
        """
        dtype = tf.float32
        with tf.variable_scope(scope or "dynamic_seq2seq", dtype=dtype):
            # Encoder
            encoder_outputs, encoder_state = self._build_encoder(hparams)
            ## Decoder
            logits, sample_id, final_context_state = self._build_decoder(encoder_outputs, encoder_state, hparams)
            ## Loss
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                with tf.device(get_device_str(self.num_encoder_layers - 1, self.num_gpus)):
                    loss = self._compute_loss(logits)
            else:
                loss = None
            return logits, loss, final_context_state, sample_id

    def _compute_loss(self, logits):
        """Return the padding-masked cross-entropy, summed and divided by batch size."""
        target_output = self.iterator.target_output
        if self.time_major:
            target_output = tf.transpose(target_output)
        max_time = self.get_max_time(target_output)
        # sparse_softmax_cross_entropy_with_logits takes integer class ids as labels.
        crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_output, logits=logits)
        # Zero out padded positions, e.g.
        #   tf.sequence_mask([1, 3, 2], 5) -> [[T, F, F, F, F],
        #                                      [T, T, T, F, F],
        #                                      [T, T, F, F, F]]
        target_weights = tf.sequence_mask(self.iterator.target_sequence_length, max_time, dtype=logits.dtype)
        if self.time_major:
            target_weights = tf.transpose(target_weights)

        loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(self.batch_size)
        return loss

    def get_max_time(self, tensor):
        """Return the length of the time axis of ``tensor``.

        The time axis is 0 when time_major, else 1.  The static size is used
        when known; otherwise the dynamic ``tf.shape`` value is returned.
        """
        time_axis = 0 if self.time_major else 1
        return tensor.shape[time_axis].value or tf.shape(tensor)[time_axis]

    def _get_learning_rate_warmup(self, hparams):
        """Apply the learning-rate warm-up schedule.

        With the "t2t" scheme, while global_step < warmup_steps the rate is
        multiplied by warmup_factor ** (warmup_steps - step), where
        warmup_factor = exp(log(0.01) / warmup_steps); the multiplier is 0.01
        at step 0 and approaches 1 at step warmup_steps.

        Raises:
          ValueError: for any other warmup_scheme.
        """
        warmup_steps = hparams.warmup_steps
        warmup_scheme = hparams.warmup_scheme
        if warmup_scheme == "t2t":
            # 0.01^(1/warmup_steps): we start with a lr, 100 times smaller
            warmup_factor = tf.exp(tf.log(0.01) / warmup_steps)
            inv_decay = warmup_factor**(tf.to_float(warmup_steps - self.global_step))
        else:
            raise ValueError("Unknown warmup scheme %s" % warmup_scheme)
        return tf.cond(self.global_step < hparams.warmup_steps,
                       lambda: inv_decay * self.learning_rate,
                       lambda: self.learning_rate,
                       name="learning_rate_warump_cond")

    def _get_learning_rate_decay(self, hparams):
        """Apply the learning-rate decay schedule.

        "luong5" / "luong10" start halving at half of num_train_steps (5 or
        10 times); "luong234" starts at 2/3 of num_train_steps (4 times).
        An empty/None decay_scheme means no decay.

        Raises:
          ValueError: for any other non-empty decay_scheme.
        """
        if hparams.decay_scheme in ["luong5", "luong10", "luong234"]:
            decay_factor = 0.5
            if hparams.decay_scheme == "luong5":
                start_decay_step = int(hparams.num_train_steps / 2)
                decay_times = 5
            elif hparams.decay_scheme == "luong10":
                start_decay_step = int(hparams.num_train_steps / 2)
                decay_times = 10
            elif hparams.decay_scheme == "luong234":
                start_decay_step = int(hparams.num_train_steps * 2 / 3)
                decay_times = 4
            remain_steps = hparams.num_train_steps - start_decay_step
            decay_steps = int(remain_steps / decay_times)
        elif not hparams.decay_scheme:  # no decay
            start_decay_step = hparams.num_train_steps
            decay_steps = 0
            decay_factor = 1.0
        elif hparams.decay_scheme:
            raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme)
        return tf.cond(self.global_step < start_decay_step,
                       lambda: self.learning_rate,
                       lambda: tf.train.exponential_decay(
                           self.learning_rate,
                           (self.global_step - start_decay_step),
                           decay_steps, decay_factor, staircase=True),
                       name="learning_rate_decay_cond")

    def init_embeddings(self, hparams, scope):
        """Create self.embedding_encoder and self.embedding_decoder.

        NOTE(review): relies on a ``model_helper`` module that is not
        defined or imported in this section of the notebook -- make sure
        it is in scope before constructing a model.
        """
        self.embedding_encoder, self.embedding_decoder = (
            model_helper.create_emb_for_encoder_and_decoder(
                share_vocab=hparams.share_vocab,
                src_vocab_size=self.src_vocab_size,
                tgt_vocab_size=self.tgt_vocab_size,
                src_embed_size=hparams.num_units,
                tgt_embed_size=hparams.num_units,
                num_partitions=hparams.num_embeddings_partitions,
                src_vocab_file=hparams.src_vocab_file,
                tgt_vocab_file=hparams.tgt_vocab_file,
                src_embed_file=hparams.src_embed_file,
                tgt_embed_file=hparams.tgt_embed_file,
                scope=scope,))

    def train(self, sess):
        """Run one training step in ``sess`` and return the fetched values."""
        assert self.mode == tf.contrib.learn.ModeKeys.TRAIN
        return sess.run([self.update,
                         self.train_loss,
                         self.predict_count,
                         self.train_summary,
                         self.global_step,
                         self.word_count,
                         self.batch_size,
                         self.grad_norm,
                         self.learning_rate])

    def eval(self, sess):
        """Return [eval_loss, predict_count, batch_size] for one eval batch."""
        assert self.mode == tf.contrib.learn.ModeKeys.EVAL
        return sess.run([self.eval_loss,
                         self.predict_count,
                         self.batch_size])

    def _get_infer_maximum_iterations(self, hparams, source_sequence_length):
        """Maximum decoding steps at inference time.

        Uses hparams.tgt_max_len_infer when set; otherwise caps decoding at
        twice the longest source sentence in the batch, so a decoder that
        never emits the end token still stops.
        """
        if hparams.tgt_max_len_infer:
            maximum_iterations = hparams.tgt_max_len_infer
        else:
            # TODO(thangluong): add decoding_length_factor flag
            decoding_length_factor = 2.0
            max_encoder_length = tf.reduce_max(source_sequence_length)
            maximum_iterations = tf.to_int32(tf.round(tf.to_float(max_encoder_length) * decoding_length_factor))
        return maximum_iterations

    @abstractmethod
    def _build_encoder(self, hparams):
        """Build the encoder; must return (encoder_outputs, encoder_state)."""
        raise NotImplementedError("Subclasses must implement _build_encoder")

    def _build_encoder_cell(self, hparams, num_layers, num_residual_layers, base_gpu=0):
        """Create the encoder RNN cell via create_rnn_cell."""
        return create_rnn_cell(
            unit_type=hparams.unit_type,
            num_units=hparams.num_units,
            num_layers=num_layers,
            num_residual_layers=num_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=hparams.num_gpus,
            mode=self.mode,
            base_gpu=base_gpu,
            single_cell_fn=self.single_cell_fn)

    @abstractmethod
    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, source_sequence_length):
        """Build the decoder cell; must return (cell, decoder_initial_state)."""
        raise NotImplementedError("Subclasses must implement _build_decoder_cell")

    def _build_decoder(self, encoder_outputs, encoder_state, hparams):
        """Build and run a RNN decoder with a final projection layer.

        Returns:
          (logits, sample_id, final_context_state).  With beam search,
          logits is a tf.no_op() and sample_id holds the predicted ids.
        """
        # Start/end token ids for inference decoding (start_tokens / end_token).
        tgt_sos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
        tgt_eos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
        iterator = self.iterator

        # maximum_iteration: The maximum decoding steps.
        maximum_iterations = self._get_infer_maximum_iterations(hparams, iterator.source_sequence_length)

        ## Decoder.
        with tf.variable_scope("decoder") as decoder_scope:
            cell, decoder_initial_state = self._build_decoder_cell(hparams, encoder_outputs, encoder_state, iterator.source_sequence_length)

            ## Train or eval
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                # decoder_emp_inp: [max_time, batch_size, num_units]
                target_input = iterator.target_input
                if self.time_major:
                    target_input = tf.transpose(target_input)
                decoder_emb_inp = tf.nn.embedding_lookup(self.embedding_decoder, target_input)

                # Helper
                # A helper for use during training. Only reads inputs.
                # Returned sample_ids are the argmax of the RNN output logits.
                helper = tf.contrib.seq2seq.TrainingHelper(
                    decoder_emb_inp, iterator.target_sequence_length,
                    time_major=self.time_major)

                # Decoder
                my_decoder = tf.contrib.seq2seq.BasicDecoder(cell, helper, decoder_initial_state,)

                # Dynamic decoding
                outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder,
                    output_time_major=self.time_major,
                    swap_memory=True,
                    scope=decoder_scope)
                sample_id = outputs.sample_id
                # Note: there's a subtle difference here between train and inference.
                # We could have set output_layer when create my_decoder
                #   and shared more code between train and inference.
                # We chose to apply the output_layer to all timesteps for speed:
                #   10% improvements for small models & 20% for larger ones.
                # If memory is a concern, we should apply output_layer per timestep.
                logits = self.output_layer(outputs.rnn_output)
            ## Inference
            else:
                beam_width = hparams.beam_width
                length_penalty_weight = hparams.length_penalty_weight
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id

                if beam_width > 0:
                    my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                        cell=cell,
                        embedding=self.embedding_decoder,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=decoder_initial_state,
                        beam_width=beam_width,
                        output_layer=self.output_layer,
                        length_penalty_weight=length_penalty_weight)
                else:
                    # Helper: temperature sampling when sampling_temperature > 0,
                    # otherwise greedy (argmax) decoding.
                    sampling_temperature = hparams.sampling_temperature
                    if sampling_temperature > 0.0:
                        helper = tf.contrib.seq2seq.SampleEmbeddingHelper(
                            self.embedding_decoder, start_tokens, end_token,
                            softmax_temperature=sampling_temperature,
                            seed=hparams.random_seed)
                    else:
                        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(self.embedding_decoder, start_tokens, end_token)

                    # Decoder (BasicDecoder's keyword is `initial_state`).
                    my_decoder = tf.contrib.seq2seq.BasicDecoder(
                        cell=cell,
                        helper=helper,
                        initial_state=decoder_initial_state,
                        output_layer=self.output_layer)

                # Dynamic decoding
                outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder,
                    maximum_iterations=maximum_iterations,
                    output_time_major=self.time_major,
                    swap_memory=True,
                    scope=decoder_scope)

                if beam_width > 0:
                    logits = tf.no_op()
                    sample_id = outputs.predicted_ids
                else:
                    logits = outputs.rnn_output
                    sample_id = outputs.sample_id
        return logits, sample_id, final_context_state
                    

In [10]:
# Dynamic-RNN sequence-to-sequence model WITHOUT attention: the decoder is
# initialised from the encoder's final state only (AttentionModel adds
# attention on top of this class).
class Model(BaseModel):
    """Encoder-decoder seq2seq model built from dynamic RNNs, no attention.

    The encoder is unidirectional ("uni") or bidirectional ("bi") depending on
    ``hparams.encoder_type``; the decoder cell is seeded with the encoder state.
    """

    def _build_bidirectional_rnn(self, inputs, sequence_length,
                                 dtype, hparams,
                                 num_bi_layers,
                                 num_bi_residual_layers,
                                 base_gpu=0):
        """Create and run a bidirectional RNN encoder stack.

        The forward stack is built with ``base_gpu`` and the backward stack with
        ``base_gpu + num_bi_layers``, i.e. right after the layers occupied by
        the forward stack.

        Args:
          inputs: embedded inputs (time-major when ``self.time_major``).
          sequence_length: per-example sequence lengths.
          dtype: dtype used for the initial states.
          hparams: hyper-parameters passed through to ``_build_encoder_cell``.
          num_bi_layers: number of layers in EACH direction.
          num_bi_residual_layers: number of residual layers in each direction.
          base_gpu: offset for the forward stack (presumably the first GPU
            index used by ``_build_encoder_cell`` -- confirm there).

        Returns:
          (outputs, state): forward and backward outputs concatenated on the
          last axis, and the ``(forward_state, backward_state)`` pair returned
          by ``tf.nn.bidirectional_dynamic_rnn``.
        """
        fw_cell = self._build_encoder_cell(
            hparams, num_bi_layers, num_bi_residual_layers, base_gpu=base_gpu)
        # The backward stack starts after the num_bi_layers layers used by the
        # forward stack. Offsetting by the residual-layer count (0 for GNMT)
        # would place both directions on the same devices.
        bw_cell = self._build_encoder_cell(
            hparams, num_bi_layers, num_bi_residual_layers,
            base_gpu=base_gpu + num_bi_layers)
        bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
            fw_cell,
            bw_cell,
            inputs,
            dtype=dtype,
            sequence_length=sequence_length,
            time_major=self.time_major,
            swap_memory=True)
        return tf.concat(bi_outputs, -1), bi_state

    def _build_encoder(self, hparams):
        """Build the encoder selected by ``hparams.encoder_type``.

        Args:
          hparams: hyper-parameters; ``encoder_type`` must be "uni" or "bi".

        Returns:
          (encoder_outputs, encoder_state). For "bi" with more than one layer
          per direction, the state is a flat tuple that interleaves per-layer
          forward and backward states: (fw_0, bw_0, fw_1, bw_1, ...).

        Raises:
          ValueError: for an unknown ``encoder_type``.
        """
        num_layers = self.num_encoder_layers
        num_residual_layers = self.num_encoder_residual_layers
        iterator = self.iterator

        source = iterator.source
        if self.time_major:
            source = tf.transpose(source)
        with tf.variable_scope("encoder") as scope:
            dtype = scope.dtype
            # encoder_emb_inp: [max_time, batch_size, num_units] when time_major.
            encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder,
                                                     source)

            if hparams.encoder_type == "uni":
                cell = self._build_encoder_cell(
                    hparams, num_layers, num_residual_layers)
                encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                    cell,
                    encoder_emb_inp,
                    dtype=dtype,
                    sequence_length=iterator.source_sequence_length,
                    time_major=self.time_major,
                    swap_memory=True)
            elif hparams.encoder_type == "bi":
                # Split the layer budget evenly between the two directions.
                num_bi_layers = num_layers // 2
                num_bi_residual_layers = num_residual_layers // 2
                encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn(
                    inputs=encoder_emb_inp,
                    sequence_length=iterator.source_sequence_length,
                    dtype=dtype,
                    hparams=hparams,
                    num_bi_layers=num_bi_layers,
                    num_bi_residual_layers=num_bi_residual_layers)
                if num_bi_layers == 1:
                    encoder_state = bi_encoder_state
                else:
                    # Interleave forward (index 0) and backward (index 1)
                    # states layer by layer.
                    encoder_state = tuple(
                        bi_encoder_state[direction][layer_id]
                        for layer_id in range(num_bi_layers)
                        for direction in (0, 1))
            else:
                raise ValueError("Unknown encoder_type %s" % hparams.encoder_type)
        return encoder_outputs, encoder_state

    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build the attention-free decoder cell and its initial state.

        Args:
          hparams: hyper-parameters; ``hparams.attention`` must be falsy.
          encoder_outputs: unused here (kept for the signature shared with
            subclasses).
          encoder_state: final encoder state, used as the decoder's initial
            state.
          source_sequence_length: unused here (kept for the shared signature).

        Returns:
          (cell, decoder_initial_state).

        Raises:
          ValueError: if attention is requested.
        """
        if hparams.attention:
            raise ValueError("BasicModel doesn't support attention.")

        cell = create_rnn_cell(
            unit_type=hparams.unit_type,
            num_units=hparams.num_units,
            num_layers=self.num_decoder_layers,
            num_residual_layers=self.num_decoder_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=self.num_gpus,
            mode=self.mode,
            single_cell_fn=self.single_cell_fn)

        # Beam search is only needed at inference time: during training the
        # target sequence is known, so no search is required. At inference the
        # decoder tracks beam_width hypotheses per example, so tile_batch
        # expands the encoder state from batch_size to
        # batch_size * beam_width rows.
        #
        # Example with vocabulary {a, b, c} and beam width 2:
        #   1. Step 1 keeps the 2 most probable words, say "a" and "c".
        #   2. Step 2 extends each kept sequence with every vocabulary word,
        #      giving 6 candidates (aa ab ac ca cb cc), and keeps the 2
        #      highest-scoring ones, say "aa" and "cb".
        #   3. This repeats until the end token is produced; the 2
        #      highest-scoring sequences are returned.
        if self.mode == tf.contrib.learn.ModeKeys.INFER and hparams.beam_width > 0:
            decoder_initial_state = tf.contrib.seq2seq.tile_batch(
                encoder_state, multiplier=hparams.beam_width)
        else:
            decoder_initial_state = encoder_state
        return cell, decoder_initial_state

In [None]:
"""Attention-based sequence-to-sequence model with dynamic RNN support."""
class AttentionModel(Model):
    """Seq2seq model whose decoder attends over the encoder outputs.

    The decoder RNN is wrapped in ``tf.contrib.seq2seq.AttentionWrapper``; the
    attention mechanism comes from a pluggable factory
    (``extra_args.attention_mechanism_fn`` or ``create_attention_mechanism``).
    """

    def __init__(self,
                 hparams,
                 mode,
                 iterator,
                 source_vocab_table,
                 target_vocab_table,
                 reverse_target_vocab_table=None,
                 scope=None,
                 extra_args=None):
        """Pick the attention-mechanism factory, then build the model.

        A truthy ``extra_args.attention_mechanism_fn`` overrides the default
        ``create_attention_mechanism``. In INFER mode ``self.infer_summary`` is
        also populated.
        """
        # Set before the base constructor runs, which presumably builds the
        # graph through _build_decoder_cell -- TODO confirm in BaseModel.
        custom_fn = extra_args.attention_mechanism_fn if extra_args else None
        self.attention_mechanism_fn = custom_fn or create_attention_mechanism

        super().__init__(
            hparams=hparams,
            mode=mode,
            iterator=iterator,
            source_vocab_table=source_vocab_table,
            target_vocab_table=target_vocab_table,
            reverse_target_vocab_table=reverse_target_vocab_table,
            scope=scope,
            extra_args=extra_args)

        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            return
        self.infer_summary = self._get_infer_summary(hparams)

    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build the attention-wrapped decoder cell and its initial state.

        Args:
          hparams: hyper-parameters; ``attention_architecture`` must be
            "standard".
          encoder_outputs: encoder outputs, used as the attention memory.
          encoder_state: final encoder state.
          source_sequence_length: source lengths, passed to the attention
            mechanism factory.

        Returns:
          (cell, decoder_initial_state). With beam search in INFER mode the
          memory, lengths and encoder state are tiled beam_width times and the
          initial state covers batch_size * beam_width rows.

        Raises:
          ValueError: if ``attention_architecture`` is not "standard".
        """
        mechanism_option = hparams.attention
        architecture = hparams.attention_architecture
        if architecture != "standard":
            raise ValueError(
                "Unknown attention architecture %s" % architecture)

        units = hparams.num_units
        layer_count = self.num_decoder_layers
        residual_count = self.num_decoder_residual_layers
        width = hparams.beam_width
        state_dtype = tf.float32

        # Make the attention memory batch-major.
        memory = (tf.transpose(encoder_outputs, [1, 0, 2])
                  if self.time_major else encoder_outputs)

        if self.mode == tf.contrib.learn.ModeKeys.INFER and width > 0:
            # Beam search decodes beam_width hypotheses per example, so every
            # per-example encoder tensor is replicated beam_width times.
            tile = tf.contrib.seq2seq.tile_batch
            memory = tile(memory, multiplier=width)
            source_sequence_length = tile(source_sequence_length,
                                          multiplier=width)
            encoder_state = tile(encoder_state, multiplier=width)
            effective_batch = self.batch_size * width
        else:
            effective_batch = self.batch_size

        mechanism = self.attention_mechanism_fn(
            mechanism_option, units, memory, source_sequence_length, self.mode)

        wrapped = create_rnn_cell(
            unit_type=hparams.unit_type,
            num_units=units,
            num_layers=layer_count,
            num_residual_layers=residual_count,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=self.num_gpus,
            mode=self.mode,
            single_cell_fn=self.single_cell_fn)

        # Alignments are recorded only for INFER without beam search, the one
        # case in which _get_infer_summary renders attention images.
        keep_alignments = (self.mode == tf.contrib.learn.ModeKeys.INFER
                           and width == 0)
        wrapped = tf.contrib.seq2seq.AttentionWrapper(
            wrapped,
            mechanism,
            attention_layer_size=units,
            alignment_history=keep_alignments,
            output_attention=hparams.output_attention,
            name="attention")

        # Place the whole wrapped cell on the device that get_device_str picks
        # for layer index layer_count - 1.
        # TODO(thangluong): do we need num_layers, num_gpus?
        device = model_helper.get_device_str(layer_count - 1, self.num_gpus)
        wrapped = tf.contrib.rnn.DeviceWrapper(wrapped, device)

        seed_from_encoder = hparams.pass_hidden_state
        initial_state = wrapped.zero_state(effective_batch, state_dtype)
        if seed_from_encoder:
            # Replace the wrapped RNN's state with the encoder's final state.
            initial_state = initial_state.clone(cell_state=encoder_state)
        return wrapped, initial_state

    def _get_infer_summary(self, hparams):
        """Return an attention-image summary, or a no-op under beam search."""
        return (tf.no_op() if hparams.beam_width > 0
                else _create_attention_images_summary(self.final_context_state))


In [None]:
class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell):
    """A MultiRNNCell with GNMT-style attention.

    Only the bottom cell is an attention cell; its attention vector is
    concatenated to the input of every upper layer.
    """

    def __init__(self, attention_cell, cells, use_new_attention=False):
        """Creates a GNMTAttentionMultiCell.

        Args:
          attention_cell: An instance of AttentionWrapper (the bottom layer).
          cells: list of the upper RNN cells (must be a list: it is
            concatenated after ``attention_cell``).
          use_new_attention: Whether upper layers receive the attention
            generated at the current step by the bottom layer (True) or the
            attention carried in from the previous step (False, default).
        """
        cells = [attention_cell] + cells
        self.use_new_attention = use_new_attention
        super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True)

    def __call__(self, inputs, state, scope=None):
        """Run one step, copying the bottom layer's attention to upper layers.

        Args:
          inputs: input to the bottom (attention) cell.
          state: tuple of per-layer states; ``state[0]`` is the attention
            cell's state.
          scope: optional variable scope name (default "multi_rnn_cell").

        Returns:
          (output, new_states): the top layer's output and a tuple of the new
          per-layer states, in the same order as ``state``.

        Raises:
          ValueError: if ``state`` is not a sequence.
        """
        # NOTE(review): `nest` is not imported in the visible part of this
        # notebook; presumably tensorflow.python.util.nest -- confirm.
        if not nest.is_sequence(state):
            raise ValueError("Expected state to be a tuple of length %d, but received: %s"
                             % (len(self.state_size), state))
        with tf.variable_scope(scope or "multi_rnn_cell"):
            new_states = []
            with tf.variable_scope("cell_0_attention"):
                attention_cell = self._cells[0]
                attention_state = state[0]
                cur_inp, new_attention_state = attention_cell(inputs, attention_state)
                new_states.append(new_attention_state)

            for i in range(1, len(self._cells)):
                with tf.variable_scope("cell_%d" % i):
                    # Use the same private `_cells` list as above; a public
                    # `cells` attribute is not guaranteed on MultiRNNCell.
                    cell = self._cells[i]
                    cur_state = state[i]

                    # Feed this step's attention (gnmt_v2) or the incoming,
                    # previous-step attention (gnmt) to the upper layer.
                    if self.use_new_attention:
                        cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1)
                    else:
                        cur_inp = tf.concat([cur_inp, attention_state.attention], -1)

                    cur_inp, new_state = cell(cur_inp, cur_state)
                    new_states.append(new_state)
        return cur_inp, tuple(new_states)


def gnmt_residual_fn(inputs, outputs):
    """Residual function that handles different input and output inner dims.

    Module-level so it can be passed as ``residual_fn`` when building the
    GNMT decoder cell list.

    Args:
      inputs: cell inputs, i.e. the actual inputs concatenated with the
        attention vector along the last axis.
      outputs: cell outputs.

    Returns:
      outputs + actual inputs, where the actual inputs are the leading slice of
      ``inputs``' last axis whose width equals that of ``outputs``.
    """
    # NOTE(review): relies on a `nest` module defined elsewhere in the
    # notebook (presumably tensorflow.python.util.nest) -- confirm.

    def split_input(inp, out):
        # Split the last axis into [output width, remainder]; the remainder is
        # the concatenated attention vector.
        out_dim = out.get_shape().as_list()[-1]
        inp_dim = inp.get_shape().as_list()[-1]
        return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)

    actual_inputs, _ = nest.map_structure(split_input, inputs, outputs)

    def assert_shape_match(inp, out):
        inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.assert_same_structure(actual_inputs, outputs)
    nest.map_structure(assert_shape_match, actual_inputs, outputs)
    return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs)

In [None]:
"""GNMT attention sequence-to-sequence model with dynamic RNN support."""
class GNMTModel(AttentionModel):
    """GNMT-style seq2seq model with dynamic RNN support.

    Encoder ("gnmt"): one bidirectional layer followed by
    ``num_encoder_layers - 1`` unidirectional layers. Decoder ("gnmt" /
    "gnmt_v2"): only the bottom layer is wrapped with attention, and its
    attention vector is fed to every upper layer. "uni"/"bi" encoders and
    "standard" attention fall back to ``AttentionModel``.
    """

    def __init__(self,
                 hparams,
                 mode,
                 iterator,
                 source_vocab_table,
                 target_vocab_table,
                 reverse_target_vocab_table=None,
                 scope=None,
                 extra_args=None):
        super(GNMTModel, self).__init__(
            hparams=hparams,
            mode=mode,
            iterator=iterator,
            source_vocab_table=source_vocab_table,
            target_vocab_table=target_vocab_table,
            reverse_target_vocab_table=reverse_target_vocab_table,
            scope=scope,
            extra_args=extra_args)

    def _build_encoder(self, hparams):
        """Build a GNMT encoder, or defer to the parent for "uni"/"bi".

        Returns:
          (encoder_outputs, encoder_state). For "gnmt", encoder_state is a flat
          tuple: the backward state of the bidirectional layer followed by the
          state of each unidirectional layer.

        Raises:
          ValueError: for an unknown ``encoder_type``.
        """
        if hparams.encoder_type in ("uni", "bi"):
            return super(GNMTModel, self)._build_encoder(hparams)
        if hparams.encoder_type != "gnmt":
            raise ValueError("Unknown encoder_type %s" % hparams.encoder_type)

        # GNMT encoder: one bidirectional layer plus a unidirectional stack.
        num_bi_layers = 1
        num_uni_layers = self.num_encoder_layers - num_bi_layers
        iterator = self.iterator
        source = iterator.source
        if self.time_major:
            source = tf.transpose(source)
        with tf.variable_scope("encoder") as scope:
            dtype = scope.dtype
            # encoder_emb_inp: [max_time, batch_size, num_units] when time_major.
            encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder,
                                                     source)

            # Bottom layer: bidirectional, via Model._build_bidirectional_rnn.
            bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn(
                inputs=encoder_emb_inp,
                sequence_length=iterator.source_sequence_length,
                dtype=dtype,
                hparams=hparams,
                num_bi_layers=num_bi_layers,
                num_bi_residual_layers=0,  # no residual connection
            )

            # Upper layers: unidirectional stack over the bidirectional outputs.
            uni_cell = create_rnn_cell(
                unit_type=hparams.unit_type,
                num_units=hparams.num_units,
                num_layers=num_uni_layers,
                num_residual_layers=self.num_encoder_residual_layers,
                forget_bias=hparams.forget_bias,
                dropout=hparams.dropout,
                num_gpus=self.num_gpus,
                base_gpu=1,
                mode=self.mode,
                single_cell_fn=self.single_cell_fn)
            # encoder_outputs: [max_time, batch_size, num_units] when time_major.
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                uni_cell,
                bi_encoder_outputs,
                dtype=dtype,
                sequence_length=iterator.source_sequence_length,
                time_major=self.time_major)

            # bi_encoder_state is (forward, backward). The decoder receives the
            # backward state followed by every unidirectional layer's state;
            # the forward state is dropped. A single-layer uni stack yields a
            # bare state, so it is wrapped in a tuple first.
            uni_states = (encoder_state,) if num_uni_layers == 1 else encoder_state
            encoder_state = (bi_encoder_state[1],) + uni_states
        return encoder_outputs, encoder_state

    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build a decoder cell with the GNMT attention architecture.

        Args:
          hparams: hyper-parameters; ``attention_architecture`` is one of
            "standard", "gnmt" or "gnmt_v2".
          encoder_outputs: encoder outputs, used as the attention memory.
          encoder_state: per-layer encoder states.
          source_sequence_length: source lengths, passed to the attention
            mechanism factory.

        Returns:
          (cell, decoder_initial_state).

        Raises:
          ValueError: for an unknown ``attention_architecture``.
        """
        # Standard attention is handled entirely by AttentionModel.
        if hparams.attention_architecture == "standard":
            return super(GNMTModel, self)._build_decoder_cell(
                hparams, encoder_outputs, encoder_state, source_sequence_length)

        attention_option = hparams.attention
        attention_architecture = hparams.attention_architecture
        # Validate before any graph ops are created.
        if attention_architecture not in ("gnmt", "gnmt_v2"):
            raise ValueError("Unknown attention_architecture %s" % attention_architecture)
        num_units = hparams.num_units
        beam_width = hparams.beam_width
        dtype = tf.float32

        # Make the attention memory batch-major.
        if self.time_major:
            memory = tf.transpose(encoder_outputs, [1, 0, 2])
        else:
            memory = encoder_outputs

        if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0:
            # Beam search decodes beam_width hypotheses per example, so every
            # per-example encoder tensor is replicated beam_width times.
            memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
            source_sequence_length = tf.contrib.seq2seq.tile_batch(
                source_sequence_length, multiplier=beam_width)
            encoder_state = tf.contrib.seq2seq.tile_batch(
                encoder_state, multiplier=beam_width)
            batch_size = self.batch_size * beam_width
        else:
            batch_size = self.batch_size

        attention_mechanism = self.attention_mechanism_fn(
            attention_option, num_units, memory, source_sequence_length, self.mode)
        cell_list = model_helper._cell_list(  # pylint: disable=protected-access
            unit_type=hparams.unit_type,
            num_units=num_units,
            num_layers=self.num_decoder_layers,
            num_residual_layers=self.num_decoder_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=self.num_gpus,
            mode=self.mode,
            single_cell_fn=self.single_cell_fn,
            residual_fn=gnmt_residual_fn)
        # Only wrap the bottom layer with the attention mechanism.
        attention_cell = cell_list.pop(0)
        # Only generate alignment in greedy INFER mode.
        alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and
                             beam_width == 0)
        attention_cell = tf.contrib.seq2seq.AttentionWrapper(
            attention_cell,
            attention_mechanism,
            attention_layer_size=None,  # don't use attention layer.
            output_attention=False,
            alignment_history=alignment_history,
            name="attention")
        # "gnmt" feeds upper layers the previous step's attention; "gnmt_v2"
        # feeds them the attention computed at the current step.
        cell = GNMTAttentionMultiCell(
            attention_cell, cell_list,
            use_new_attention=(attention_architecture == "gnmt_v2"))

        if hparams.pass_hidden_state:
            # Seed each layer with its encoder state; for the attention layer
            # only the wrapped cell_state is replaced.
            decoder_initial_state = tuple(
                zs.clone(cell_state=es)
                if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es
                for zs, es in zip(cell.zero_state(batch_size, dtype), encoder_state))
        else:
            decoder_initial_state = cell.zero_state(batch_size, dtype)
        return cell, decoder_initial_state

    def _get_infer_summary(self, hparams):
        """Return the inference summary op.

        For GNMT attention without beam search, this is the attention-image
        summary built from the bottom (attention) layer's state. With beam
        search it is a no-op.
        """
        # Standard attention
        if hparams.attention_architecture == "standard":
            return super(GNMTModel, self)._get_infer_summary(hparams)

        # GNMT attention
        if hparams.beam_width > 0:
            return tf.no_op()
        # The decoder state is a per-layer tuple whose element 0 is the
        # attention cell's state (GNMTAttentionMultiCell puts it first). Call
        # the helper directly, as AttentionModel does; no `attention_model`
        # module is visibly imported in this notebook.
        return _create_attention_images_summary(self.final_context_state[0])