In [None]:
import os
import time

import tensorflow as tf

In [None]:
# Hyper-parameter dict consumed by PGN(...) / train(...); empty here, fill it in before training.
params = dict()

In [None]:
def init_checkpoint(dir):
    """Write an empty ``init`` checkpoint under ``dir`` and immediately restore it.

    The checkpoint tracks no objects, so the restore does not change any model
    variables.

    ``tf.train.Checkpoint.write`` is used instead of ``save``: ``save`` also
    rewrites the directory's checkpoint metadata (the file read by
    ``tf.train.latest_checkpoint`` / ``tf.train.CheckpointManager``), so on a
    model directory the empty checkpoint could be picked up as the "latest"
    one and shadow real training checkpoints.

    Args:
        dir: directory in which the ``init`` checkpoint files are written.

    Returns:
        None.
    """
    checkpoint = tf.train.Checkpoint()
    init_path = checkpoint.write(os.path.join(dir, 'init'))
    checkpoint.restore(init_path)

In [None]:
class Encoder(tf.keras.layers.Layer):
    """Bidirectional-GRU encoder over frozen pre-trained embeddings.

    ``enc_units`` is the width of the *concatenated* bidirectional output:
    each direction gets ``enc_units // 2`` units, so an even ``enc_units`` is
    expected (an odd value yields ``enc_units - 1`` output features).
    """

    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz, embedding_matrix):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        # Units per direction; forward and backward outputs are concatenated.
        self.enc_units = enc_units // 2
        # Embeddings are initialised from embedding_matrix and kept frozen.
        self.embedding = tf.keras.layers.Embedding(vocab_size,
                                                   embedding_dim,
                                                   weights=[embedding_matrix],
                                                   trainable=False)
        # tf.keras.layers.GRU automatically selects the CPU or GPU implementation.
        self.gru = tf.keras.layers.GRU(self.enc_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

        self.bigru = tf.keras.layers.Bidirectional(self.gru, merge_mode='concat')

    def call(self, x, hidden):
        """Encode a batch of token ids.

        Args:
            x: token ids, presumably (batch, seq_len) -- TODO confirm.
            hidden: (batch, 2 * self.enc_units) initial state; the first half
                seeds the forward GRU and the second half the backward GRU.

        Returns:
            output: (batch, seq_len, 2 * self.enc_units) per-step features.
            state: (batch, 2 * self.enc_units) concatenated final states.
        """
        x = self.embedding(x)
        hidden = tf.split(hidden, num_or_size_splits=2, axis=1)
        output, forward_state, backward_state = self.bigru(x, initial_state=hidden)
        state = tf.concat([forward_state, backward_state], axis=1)
        return output, state

    def initialize_hidden_state(self, batch_sz=None):
        """Return an all-zero initial state of shape (batch, 2 * self.enc_units).

        Args:
            batch_sz: optional batch size. Defaults to the size given at
                construction; pass the actual size to handle a smaller final batch.
        """
        if batch_sz is None:
            batch_sz = self.batch_sz
        return tf.zeros((batch_sz, 2 * self.enc_units))


In [None]:
class Decoder(tf.keras.layers.Layer):
    """Single-step GRU decoder producing a softmax distribution over the vocabulary."""

    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz, embedding_matrix):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        # Embeddings are initialised from embedding_matrix and kept frozen.
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                                   weights=[embedding_matrix],
                                                   trainable=False)
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        # Softmax activation: the output is a probability distribution, not logits.
        self.fc = tf.keras.layers.Dense(vocab_size, activation=tf.keras.activations.softmax)

    def call(self, x, hidden, enc_output, context_vector):
        """Run one decoding step.

        Args:
            x: current input token ids, presumably (batch, 1) -- TODO confirm.
            hidden: previous decoder state, used as the GRU's initial state so
                the recurrent state carries across steps (must have dec_units
                features; None starts from a zero state).
            enc_output: encoder outputs; not used here (attention is computed
                by the caller), kept for interface compatibility.
            context_vector: attention context, (batch, hidden_size).

        Returns:
            x: the GRU input, (batch, 1, hidden_size + embedding_dim).
            out: (batch, vocab_size) probability distribution.
            state: new decoder state, (batch, dec_units).
        """
        # (batch, 1) -> (batch, 1, embedding_dim)
        x = self.embedding(x)

        # Prepend the context: (batch, 1, hidden_size + embedding_dim)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # Thread the previous state through the GRU; previously `hidden` was
        # ignored and every step restarted from a zero state.
        output, state = self.gru(x, initial_state=hidden)
        # (batch, 1, dec_units) -> (batch, dec_units)
        output = tf.reshape(output, (-1, output.shape[2]))

        # (batch, vocab_size)
        out = self.fc(output)

        return x, out, state

In [None]:
class BahdanauAttentionCoverage(tf.keras.layers.Layer):
    """Additive (Bahdanau) attention with an optional coverage feature.

    score_i = v^T tanh(W1 h_i + W2 s_t [+ Wc c_i]), where h_i are encoder
    outputs, s_t is the decoder state and c_i is the accumulated coverage.
    """

    def __init__(self, units):
        super(BahdanauAttentionCoverage, self).__init__()
        self.Wc = tf.keras.layers.Dense(units)
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, dec_hidden, enc_output, enc_padding_mask, use_coverage=False, prev_coverage=None):
        """Compute the attention context for one decoder step.

        Args:
            dec_hidden: decoder state, e.g. shape (16, 256).
            enc_output: encoder outputs, e.g. shape (16, 200, 256).
            enc_padding_mask: e.g. shape (16, 200); positions with 0 get zero
                attention weight.
            use_coverage: add the coverage feature and track coverage.
            prev_coverage: accumulated attention (batch, enc_len, 1), or None
                on the first coverage step.

        Returns:
            context_vector: (batch, hidden_size) attention-weighted encoder sum.
            attn_dist: (batch, enc_len) masked attention distribution.
            coverage: (batch, enc_len, 1) updated coverage when use_coverage,
                otherwise an empty list.
        """
        # (batch, hidden) -> (batch, 1, hidden) so it broadcasts over enc_len.
        hidden_with_time_axis = tf.expand_dims(dec_hidden, 1)

        def masked_attention(score):
            """Softmax score (batch, enc_len, 1), zero padded positions and renormalise.

            Returns a (batch, enc_len, 1) distribution. A row whose mask is all
            zeros yields all-zero weights instead of NaN.
            """
            attn_dist = tf.squeeze(score, axis=2)  # (batch, enc_len)
            attn_dist = tf.nn.softmax(attn_dist, axis=1)
            mask = tf.cast(enc_padding_mask, dtype=attn_dist.dtype)
            attn_dist *= mask
            masked_sums = tf.reduce_sum(attn_dist, axis=1)
            # Same as '/' for non-zero sums; avoids 0/0 = NaN for fully padded rows.
            attn_dist = tf.math.divide_no_nan(attn_dist, tf.reshape(masked_sums, [-1, 1]))
            return tf.expand_dims(attn_dist, axis=2)

        if use_coverage and prev_coverage is not None:
            # Non-first coverage step:
            # v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn)
            e = self.V(tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis) + self.Wc(prev_coverage)))
            attn_dist = masked_attention(e)
            # Accumulate this step's attention into the coverage vector.
            coverage = attn_dist + prev_coverage
        else:
            # v^T tanh(W_h h_i + W_s s_t + b_attn), shape (batch, enc_len, 1)
            e = self.V(tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis)))
            attn_dist = masked_attention(e)
            if use_coverage:
                # First coverage step: coverage starts as this step's attention.
                coverage = attn_dist
            else:
                coverage = []

        # Weighted sum over encoder positions: (batch, enc_len, hidden) -> (batch, hidden)
        context_vector = tf.reduce_sum(attn_dist * enc_output, axis=1)
        return context_vector, tf.squeeze(attn_dist, -1), coverage


In [None]:
class Pointer(tf.keras.layers.Layer):
    """Generation switch: p_gen = sigmoid(w_s.s_t + w_c.c_t + w_i.x_t)."""

    def __init__(self):
        super(Pointer, self).__init__()
        # Each projection maps its input to one scalar per example.
        self.w_s_reduce = tf.keras.layers.Dense(1)
        self.w_i_reduce = tf.keras.layers.Dense(1)
        self.w_c_reduce = tf.keras.layers.Dense(1)

    def call(self, context_vector, state, dec_inp):
        """Return p_gen in (0, 1) from the context, decoder state and decoder input."""
        state_term = self.w_s_reduce(state)
        context_term = self.w_c_reduce(context_vector)
        input_term = self.w_i_reduce(dec_inp)
        return tf.nn.sigmoid(state_term + context_term + input_term)

In [None]:
class PGN(tf.keras.Model):
    """Pointer-generator network: BiGRU encoder, coverage attention, GRU decoder and p_gen pointer."""

    def __init__(self, params):
        super(PGN, self).__init__()
        self.embedding_matrix = load_word2vec(params)
        self.params = params
        self.encoder = Encoder(params["vocab_size"],
                               params["embed_size"],
                               params["enc_units"],
                               params["batch_size"],
                               self.embedding_matrix)
        self.attention = BahdanauAttentionCoverage(params["attn_units"])
        self.decoder = Decoder(params["vocab_size"],
                               params["embed_size"],
                               params["dec_units"],
                               params["batch_size"],
                               self.embedding_matrix)
        self.pointer = Pointer()

    def call_encoder(self, enc_inp):
        """Encode enc_inp from a zero initial state; returns (enc_output, enc_hidden)."""
        enc_hidden = self.encoder.initialize_hidden_state()
        # [batch_sz, max_train_x, enc_units], [batch_sz, enc_units]
        enc_output, enc_hidden = self.encoder(enc_inp, enc_hidden)
        return enc_output, enc_hidden

    def call(self, enc_output, dec_hidden, enc_inp,
             enc_extended_inp, dec_inp, batch_oov_len,
             enc_padding_mask, use_coverage, prev_coverage):
        """Decode the full target sequence with teacher forcing.

        Args:
            enc_output: encoder outputs, e.g. shape (16, 200, 256).
            dec_hidden: initial decoder state, e.g. shape (16, 256).
            enc_inp: encoder token ids; not read here, kept for interface compatibility.
            enc_extended_inp: encoder ids, presumably in the source-extended vocabulary.
            dec_inp: decoder input ids (batch, dec_len); one decoding step per column.
            batch_oov_len: presumably the number of source OOV words in the batch.
            enc_padding_mask: encoder padding mask, e.g. shape (16, 200).
            use_coverage: enable the coverage mechanism in attention.
            prev_coverage: initial coverage (None at the start of decoding).

        Returns:
            dict with "logits" (from decoding.calc_final_dist), "dec_hidden",
            "attentions", "coverages" and "p_gens". When params['mode'] is
            "train" (case-insensitive) the per-step values are Python lists;
            otherwise they are stacked along axis 1.
        """
        predictions = []
        attentions = []
        coverages = []
        p_gens = []
        context_vector, attn_dist, coverage_next = self.attention(dec_hidden,  # shape=(16, 256)
                                                                  enc_output,  # shape=(16, 200, 256)
                                                                  enc_padding_mask,  # (16, 200)
                                                                  use_coverage,
                                                                  prev_coverage)  # None
        for t in range(dec_inp.shape[1]):
            # Teacher forcing: feed the gold token of step t.
            dec_x, pred, dec_hidden = self.decoder(tf.expand_dims(dec_inp[:, t], 1),
                                                   dec_hidden,
                                                   enc_output,
                                                   context_vector)
            context_vector, attn_dist, coverage_next = self.attention(dec_hidden,
                                                                      enc_output,
                                                                      enc_padding_mask,
                                                                      use_coverage,
                                                                      coverage_next)
            p_gen = self.pointer(context_vector, dec_hidden, tf.squeeze(dec_x, axis=1))
            predictions.append(pred)
            coverages.append(coverage_next)
            attentions.append(attn_dist)
            p_gens.append(p_gen)

        final_dists = decoding.calc_final_dist(enc_extended_inp,
                                               predictions,
                                               attentions,
                                               p_gens,
                                               batch_oov_len,
                                               self.params["vocab_size"],
                                               self.params["batch_size"])
        # Case-insensitive, matching the `params["mode"].lower()` check in train().
        if str(self.params['mode']).lower() == "train":
            # The training losses iterate over these per-step lists.
            outputs = dict(logits=final_dists, dec_hidden=dec_hidden, attentions=attentions, coverages=coverages, p_gens=p_gens)
        else:
            outputs = dict(logits=tf.stack(final_dists, 1),
                           dec_hidden=dec_hidden,
                           attentions=tf.stack(attentions, 1),
                           coverages=tf.stack(coverages, 1),
                           p_gens=tf.stack(p_gens, 1))

        return outputs

In [None]:
# Decoder.fc already applies softmax, so predictions are probabilities (from_logits=False).
# reduction='none' keeps per-example losses so padded positions can be masked afterwards.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction='none',
    from_logits=False,
)


def loss_function(real, outputs, padding_mask, cov_loss_wt, use_coverage):
    """Compute the training loss, choosing the loss that matches the model's output format.

    PGN in train mode returns ``outputs["logits"]`` as a Python list of per-step
    final distributions. Such a list always goes through pgn_log_loss_function,
    whether or not coverage is on. seq2seq_loss_function indexes ``pred[:, t]``
    and cannot handle a list. Tensor logits without coverage keep using the
    seq2seq loss.

    Args:
        real: target ids, (batch, dec_len).
        outputs: dict with "logits" and, when use_coverage, "attentions".
        padding_mask: (batch, dec_len); 0 marks padded target positions.
        cov_loss_wt: weight of the coverage-loss term.
        use_coverage: add cov_loss_wt * coverage loss.

    Returns:
        Scalar loss.
    """
    pred = outputs["logits"]
    if use_coverage or isinstance(pred, (list, tuple)):
        loss = pgn_log_loss_function(real, pred, padding_mask)
        if use_coverage:
            loss = loss + cov_loss_wt * _coverage_loss(outputs["attentions"], padding_mask)
        return loss
    return seq2seq_loss_function(real, pred, padding_mask)


def seq2seq_loss_function(real, pred, padding_mask):
    """Masked cross-entropy loss used when training the plain seq2seq model.

    :param real: target ids, e.g. shape=(16, 50)
    :param pred: per-step probability distributions, e.g. shape=(16, 50, 30000)
    :param padding_mask: (batch, dec_len); 0 zeroes out that position's loss
    :return: mean over decoder steps of the batch-averaged masked loss
    """
    num_steps = real.shape[1]

    def step_loss(t):
        # Per-example cross-entropy at step t, with padded positions zeroed.
        per_example = loss_object(real[:, t], pred[:, t])
        per_example *= tf.cast(padding_mask[:, t], dtype=per_example.dtype)
        # Averaged over the whole batch, padded examples included.
        return tf.reduce_mean(per_example)

    return sum(step_loss(t) for t in range(num_steps)) / num_steps


def pgn_log_loss_function(real, final_dists, padding_mask):
    """Masked negative log-likelihood of the gold tokens under the PGN final distributions.

    Args:
        real: target ids, (batch, dec_len); presumably extended-vocab ids for
            source OOV words.
        final_dists: sequence of length dec_len, each a (batch, vocab) tensor
            of probabilities.
        padding_mask: (batch, dec_len); 0 marks padded target positions.

    Returns:
        Scalar loss, averaged by _mask_and_avg.
    """
    # Use the dynamic batch size: real.shape[0] is None whenever the batch
    # dimension is not statically known, which made tf.range fail.
    batch_nums = tf.range(tf.shape(real)[0])  # shape (batch_size)
    loss_per_step = []  # will be list length max_dec_steps containing shape (batch_size)
    for dec_step, dist in enumerate(final_dists):
        # Gold target ids at this step, cast so tf.stack sees matching dtypes.
        targets = tf.cast(real[:, dec_step], batch_nums.dtype)
        indices = tf.stack((batch_nums, targets), axis=1)  # shape (batch_size, 2)
        # Probability of the correct word at this step, shape (batch_size).
        gold_probs = tf.gather_nd(dist, indices)
        loss_per_step.append(-tf.math.log(gold_probs))
    # Apply the decoder padding mask and average.
    return _mask_and_avg(loss_per_step, padding_mask)


def _mask_and_avg(values, padding_mask):
    """Applies mask to values then returns overall average (a scalar).

    Each example's masked sum is divided by its number of unmasked steps, and
    the per-example results are averaged over the batch.

    Args:
      values: a list length max_dec_steps containing tensors of shape (batch_size).
      padding_mask: tensor shape (batch_size, max_dec_steps) containing 1s and 0s.
    Returns:
      a scalar. An example with no unmasked steps contributes 0 instead of NaN.
    """
    padding_mask = tf.cast(padding_mask, dtype=values[0].dtype)
    dec_lens = tf.reduce_sum(padding_mask, axis=1)  # shape (batch_size,)
    values_per_step = [v * padding_mask[:, dec_step] for dec_step, v in enumerate(values)]
    # Identical to '/' for non-zero lengths; avoids 0/0 = NaN for fully padded examples.
    values_per_ex = tf.math.divide_no_nan(sum(values_per_step), dec_lens)
    return tf.reduce_mean(values_per_ex)  # overall average


def _coverage_loss(attn_dists, padding_mask):
    """Coverage loss computed from the per-step attention distributions.

    At step t the loss is sum_i min(a_i^t, c_i^t), where c^t is the sum of all
    earlier attention distributions (all zeros at the first step).

    Args:
      attn_dists: list length max_dec_steps, each of shape (batch_size, attn_length).
      padding_mask: shape (batch_size, max_dec_steps).
    Returns:
      coverage_loss: scalar, masked and averaged by _mask_and_avg.
    """
    running_coverage = tf.zeros_like(attn_dists[0])
    step_losses = []
    for attn in attn_dists:
        # Overlap between this step's attention and everything attended to so far.
        step_losses.append(tf.reduce_sum(tf.minimum(attn, running_coverage), [1]))
        running_coverage = running_coverage + attn
    return _mask_and_avg(step_losses, padding_mask)

In [None]:
def train_model_pgn(model, dataset, params, ckpt_manager):
    """Train a PGN model with Adam, checkpointing when the epoch loss improves.

    Args:
        model: PGN instance.
        dataset: iterable of (inputs, targets) dict pairs with the keys read below.
        params: dict; reads "learning_rate", "is_coverage", "cov_loss_wt",
            "epochs", and optionally "checkpoint_every" (default 2: check on
            every 2nd epoch) and "init_best_loss" (default 20: a checkpoint
            is saved only once the epoch loss drops below this).
        ckpt_manager: tf.train.CheckpointManager used to save checkpoints.
    """
    optimizer = tf.keras.optimizers.Adam(name='Adam', learning_rate=params["learning_rate"])

    @tf.function()
    def train_step(enc_inp, enc_extended_inp, dec_inp, dec_tar, batch_oov_len, enc_padding_mask, padding_mask):
        with tf.GradientTape() as tape:
            enc_output, enc_hidden = model.call_encoder(enc_inp)
            # The decoder starts from the final encoder state.
            dec_hidden = enc_hidden
            outputs = model(enc_output,  # shape=(3, 200, 256)
                            dec_hidden,  # shape=(3, 256)
                            enc_inp,  # shape=(3, 200)
                            enc_extended_inp,  # shape=(3, 200)
                            dec_inp,  # shape=(3, 50)
                            batch_oov_len,  # shape=()
                            enc_padding_mask,  # shape=(3, 200)
                            params['is_coverage'],
                            prev_coverage=None)
            loss = loss_function(dec_tar,
                                 outputs,
                                 padding_mask,
                                 params["cov_loss_wt"],
                                 params['is_coverage'])

        variables = model.encoder.trainable_variables + \
                    model.attention.trainable_variables + \
                    model.decoder.trainable_variables + \
                    model.pointer.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        return loss

    best_loss = params.get("init_best_loss", 20)
    checkpoint_every = params.get("checkpoint_every", 2)
    epochs = params['epochs']
    for epoch in range(epochs):
        t0 = time.time()
        step = 0
        total_loss = 0
        for batch in dataset:
            loss = train_step(batch[0]["enc_input"],  # shape=(16, 200)
                              batch[0]["extended_enc_input"],  # shape=(16, 200)
                              batch[1]["dec_input"],  # shape=(16, 50)
                              batch[1]["dec_target"],  # shape=(16, 50)
                              batch[0]["max_oov_len"],  # ()
                              batch[0]["sample_encoder_pad_mask"],  # shape=(16, 200)
                              batch[1]["sample_decoder_pad_mask"])  # shape=(16, 50)

            step += 1
            total_loss += loss
            if step % 100 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, step, total_loss / step))

        if step == 0:
            # Empty dataset: nothing to average (previously a ZeroDivisionError).
            print('Epoch {} had no batches'.format(epoch + 1))
            continue

        epoch_loss = float(total_loss / step)
        if epoch % checkpoint_every == 0 and epoch_loss < best_loss:
            best_loss = epoch_loss
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {} ,best loss {}'.format(epoch + 1, ckpt_save_path, best_loss))
        # Report every epoch, not only the ones that produced a checkpoint.
        print('Epoch {} Loss {:.4f}'.format(epoch + 1, epoch_loss))
        print('Time taken for 1 epoch {} sec\n'.format(time.time() - t0))


In [None]:
def train(params):
    """Build vocab, batcher and model, restore the latest checkpoint if any, then train.

    Args:
        params: configuration dict; reads "mode", "vocab_path", "vocab_size",
            "use_pgn" and "pgn_model_dir" / "seq2seq_model_dir", plus whatever
            the model and training loop read.

    Raises:
        AssertionError: if params["mode"] is not "train" (case-insensitive).
    """
    assert params["mode"].lower() == "train", "change training mode to 'train'"

    vocab = Vocab(params["vocab_path"], params["vocab_size"])
    print('true vocab is ', vocab)

    print("Creating the batcher ...")
    b = batcher(vocab, params)

    print("Building the model ...")
    if params.get('use_pgn'):
        model = PGN(params)
        ckpt = tf.train.Checkpoint(step=tf.Variable(0), PGN=model)
        checkpoint_dir = params["pgn_model_dir"]
    else:
        model = SequenceToSequence(params)
        ckpt = tf.train.Checkpoint(SequenceToSequence=model)
        checkpoint_dir = params["seq2seq_model_dir"]
    print("Creating the checkpoint manager", checkpoint_dir)
    # No init_fn. The former `init_fn=init_checkpoint(checkpoint_dir)` ran
    # init_checkpoint immediately and passed its None result as init_fn.
    # Beforehand it also wrote an empty checkpoint into checkpoint_dir via
    # Checkpoint.save, which updates the directory's checkpoint metadata.
    # That empty checkpoint could then be restored instead of real weights.
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)
    ckpt_manager.restore_or_initialize()
    if ckpt_manager.latest_checkpoint:
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    print("Starting the training ...")
    if params.get('use_pgn'):
        train_model_pgn(model, b, params, ckpt_manager)
    else:
        train_model(model, b, params, ckpt_manager)