
Cannot reproduce SOTA WikiText-103 results #112

Closed
menghuanlater opened this issue Apr 27, 2020 · 4 comments

Comments

@menghuanlater

I used the pretrained Transformer-XL weights and the same vocabulary to rebuild the large model (in TensorFlow 2.0) and evaluate it on the test set. In my experiments, {tgt_len=128, mem_len=1600, clamp_len=1000} only reaches a test perplexity of around 35, {tgt_len=384, mem_len=384, clamp_len=1000} reaches around 24, and {tgt_len=2048, mem_len=2048, clamp_len=1000} reaches around 20, but none of these settings reach the paper's result of 18.3. Why?
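For context, the evaluation in the script below amounts to sliding a fixed-size memory over the test set and exponentiating the average per-token loss. A minimal sketch of that loop, with illustrative names rather than the repository's API (the `model` signature here is an assumption), looks like this:

```python
import math
import tensorflow as tf

# Minimal sketch of sliding-memory evaluation (illustrative names, not the repo's API):
# each segment of tgt_len tokens is scored against a cache of mem_len previous hidden
# states, the cache then keeps only the most recent mem_len positions, and perplexity
# is exp(total negative log-likelihood / total token count).
def evaluate(model, segments, mem_len, n_layers, d_model, batch_size=1):
    cache = tf.zeros([batch_size, n_layers, mem_len, d_model])
    total_nll, total_tokens = 0.0, 0
    for inputs, targets in segments:
        nll, new_hidden = model(inputs, cache=cache, ground_truth=targets, training=False)
        cache = tf.concat([cache, new_hidden], axis=2)[:, :, -mem_len:, :]
        total_nll += float(tf.reduce_sum(nll))
        total_tokens += int(tf.shape(nll)[0])
    return math.exp(total_nll / total_tokens)
```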

```python
#!/usr/bin/env python
# -*- coding:utf-8 -*-

import tensorflow as tf
from tensorflow import keras
import numpy as np
import pickle
from DataService import DataObjForWT_PTB as DataObj
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

vocab_size_dic = {
    "wikitext-103": 267736,
    "enwiki8": 0,
    "text8": 0
}

class Vanilla_XL(keras.Model):
  def __init__(self, dataset_name: str, segment_size: int, dropout_attn, dropout_norm, n_layers,
               n_heads, d_embed, d_model, ffn_mul, cutoffs):
    super(Vanilla_XL, self).__init__()
    self.vocab_size = vocab_size_dic[dataset_name]
    self.segment_size = segment_size
    self.dropout_attn = dropout_attn
    self.dropout_norm = dropout_norm
    self.d_model = d_model
    self.d_embed = d_embed
    self.ffn_mul = ffn_mul
    self.cutoffs = cutoffs
    self.n_layers = n_layers
    self.n_heads = n_heads

    # embedding
    self.token_embedding = AdaptiveEmbedding(
        cutoffs=self.cutoffs, d_embed=self.d_embed, embed_drop_rate=self.dropout_norm,
        input_dim=self.vocab_size, out_dim=d_model, div_value=4
    )

    self.all_encoder_layers = []
    for layer in range(self.n_layers):
        self.all_encoder_layers.append(
            SingleTransformerBlock(
                d_model=d_model, ffn_size=self.ffn_mul * d_model, n_heads=self.n_heads,
                dropout_attn=self.dropout_attn, dropout_norm=self.dropout_norm,
                cur_layer=layer
            )
        )

    self.softmax_out_layer = AdaptiveSoftmax(cutoffs=self.cutoffs, d_embed=self.d_embed,
                                             adaptive_embedding_obj=self.token_embedding, div_value=4)

  def call(self, inputs, training=None, **kwargs):
    cache = kwargs["cache"] 
    padding_mask = kwargs["padding_mask"]
    segment_embedding = self.token_embedding(inputs=inputs, is_training=training)
    new_cache = [segment_embedding[:, tf.newaxis, :, :]]

    cur_layer_out = segment_embedding
    for layer in range(self.n_layers):
        cur_layer_out = self.all_encoder_layers[layer](
            inputs=cur_layer_out, cache=cache[:, layer, :, :],
            is_training=training, padding_mask=padding_mask
        )
        if layer != self.n_layers - 1:
            new_cache.append(cur_layer_out[:, tf.newaxis, :, :])
    final_out = cur_layer_out
    g_t = kwargs["ground_truth"]
    no_pad_indices = tf.where(tf.not_equal(g_t, PAD))
    final_out = tf.gather_nd(final_out, no_pad_indices)
    g_t = tf.gather_nd(g_t, no_pad_indices)
    log_prob = self.softmax_out_layer(inputs=final_out, ground_truth=g_t)
    return log_prob, tf.concat(new_cache, axis=1)

class AdaptiveEmbedding(keras.layers.Layer):
  def __init__(self, cutoffs, embed_drop_rate, input_dim, out_dim, d_embed, div_value=4):
    super(AdaptiveEmbedding, self).__init__()
    assert isinstance(cutoffs, list)
    self.cutoffs = cutoffs
    self.input_dim = input_dim
    self.out_dim = out_dim
    self.d_embed = d_embed
    self.div_value = div_value

    self.cluster_embedding_list = []
    self.projection_list = []
    self.dropout_layer = keras.layers.Dropout(rate=embed_drop_rate)

    for i in range(len(self.cutoffs) - 1):
        in_dims = self.cutoffs[i + 1] - self.cutoffs[i]
        o_dims = self.d_embed // (self.div_value ** i)
        self.cluster_embedding_list.append(
            keras.layers.Embedding(
                input_dim=in_dims, output_dim=o_dims,
                weights=[tf.convert_to_tensor(
                    pre_train_weights["transformer/adaptive_embed/cutoff_%d/lookup_table:0" % i],
                    dtype=tf.float32)]
            ))
        self.projection_list.append(
            tf.Variable(
                initial_value=tf.convert_to_tensor(
                    pre_train_weights["transformer/adaptive_embed/cutoff_%d/proj_W:0" % i]
                ), dtype=tf.float32
            )
        )

  def call(self, inputs, **kwargs):
    # collect the masked, projected embeddings of each cluster and sum them
    x = []
    for i in range(len(self.cutoffs) - 1):
        start = self.cutoffs[i]
        end = self.cutoffs[i + 1] 
        actual = tf.math.logical_and(inputs >= start, inputs < end)
        mask = tf.expand_dims(tf.cast(actual, dtype=tf.float32), axis=2)
        new_input = inputs - start
        new_input = tf.where(actual, new_input, tf.zeros_like(new_input, dtype=tf.int32))
        embed = self.cluster_embedding_list[i](inputs=new_input)
        linear_proj = tf.matmul(embed, self.projection_list[i], transpose_b=False)
        x.append(tf.multiply(linear_proj, mask))
    out = tf.zeros_like(x[0], dtype=tf.float32)
    for j in range(len(x)):
        out += x[j]
    out *= self.out_dim ** 0.5
    return self.dropout_layer(out, training=kwargs["is_training"])

class AdaptiveSoftmax(keras.layers.Layer):
  def __init__(self, cutoffs, d_embed, adaptive_embedding_obj, div_value=4):
    super(AdaptiveSoftmax, self).__init__()
    self.cutoffs = cutoffs
    self.d_embed = d_embed
    self.div_value = div_value
    assert isinstance(adaptive_embedding_obj, AdaptiveEmbedding)
    self.adaptive_embedding_obj = adaptive_embedding_obj
    self.tail_clusters_embedding = keras.layers.Embedding(
        input_dim=len(self.cutoffs) - 2, output_dim=self.d_embed,
        weights=[tf.convert_to_tensor(pre_train_weights["transformer/adaptive_softmax/cutoff_0/cluster_W:0"])]
    )
    self.clusters_bias = tf.Variable(
        initial_value=tf.convert_to_tensor(pre_train_weights["transformer/adaptive_softmax/cutoff_0/cluster_b:0"]),
        dtype=tf.float32
    )

    self.head_projection = tf.Variable(
        initial_value=tf.convert_to_tensor(
            pre_train_weights["transformer/adaptive_softmax/cutoff_0/proj:0"]
        ), dtype=tf.float32
    )

    self.bias_list = []
    for i in range(len(self.cutoffs) - 1):
        self.bias_list.append(
            tf.convert_to_tensor(pre_train_weights["transformer/adaptive_softmax/cutoff_%d/b:0" % i])
        )
    self.projection_list = self.adaptive_embedding_obj.projection_list

  def call(self, inputs, **kwargs):
    x = []
    g_t = kwargs["ground_truth"]
    head_all_vocab_embedding = self.adaptive_embedding_obj.cluster_embedding_list[0](
        inputs=tf.convert_to_tensor([i for i in range(self.cutoffs[1] - self.cutoffs[0])], dtype=tf.int32)
    )  # (c0, dim)

    all_tail_cluster_embedding = self.tail_clusters_embedding(
        inputs=tf.convert_to_tensor([i for i in range(len(self.cutoffs) - 2)], dtype=tf.int32)
    )  # (3, dim)
    head_embedding = tf.concat([head_all_vocab_embedding, all_tail_cluster_embedding], axis=0)
    head_proj_out = tf.matmul(inputs, self.head_projection, transpose_b=True)
    head_logits = tf.matmul(head_proj_out, head_embedding, transpose_b=True)
    head_logits += tf.concat([self.bias_list[0], self.clusters_bias], axis=0)
    head_softmax = tf.nn.softmax(head_logits, axis=-1)

    for i in range(len(self.cutoffs) - 1):
        start = self.cutoffs[i]
        end = self.cutoffs[i + 1]
        cur_cluster_indices = tf.where(tf.math.logical_and(g_t >= start, g_t < end))
        seq_len = tf.shape(cur_cluster_indices)[0]
        cur_g_t = tf.gather_nd(g_t, cur_cluster_indices)
        cur_g_t = cur_g_t - start
        cur_g_t = tf.expand_dims(cur_g_t, axis=1)
        first_dim = tf.expand_dims(tf.range(seq_len, dtype=tf.int32), axis=1)
        r_s = tf.concat([first_dim, cur_g_t], axis=1)
        if i == 0: 
            cur_softmax = tf.gather_nd(head_softmax, cur_cluster_indices)
            cur_out_prob = tf.gather_nd(cur_softmax, r_s)
            cur_out_prob = tf.where(cur_out_prob >= 1e-9, cur_out_prob,
                                    tf.ones_like(cur_out_prob, dtype=tf.float32) * 1e-9)
            cur_log_prob = -tf.math.log(cur_out_prob)
        else:
            pre_softmax = tf.gather_nd(head_softmax, cur_cluster_indices)[..., self.cutoffs[1] + i - 2]
            pre_softmax = tf.where(pre_softmax > 1e-9, pre_softmax,
                                   tf.ones_like(pre_softmax, dtype=tf.float32) * 1e-9)
            pre_log_prob = -tf.math.log(pre_softmax)

            cur_inputs = tf.gather_nd(inputs, cur_cluster_indices)

            all_cur_cluster_embedding = self.adaptive_embedding_obj.cluster_embedding_list[i](
                tf.convert_to_tensor([i for i in range(end - start)], dtype=tf.int32)
            )
            cur_inputs = tf.matmul(cur_inputs, self.projection_list[i], transpose_b=True)
            cur_logits = tf.matmul(cur_inputs, all_cur_cluster_embedding, transpose_b=True)
            cur_logits += self.bias_list[i]

            cur_softmax = tf.nn.softmax(cur_logits, axis=-1)
            cur_out_prob = tf.gather_nd(cur_softmax, r_s)
            cur_out_prob = tf.where(cur_out_prob >= 1e-9, cur_out_prob,
                                    tf.ones_like(cur_out_prob, dtype=tf.float32) * 1e-9)
            cur_log_prob = -tf.math.log(cur_out_prob)

            cur_log_prob += pre_log_prob
        x.append(cur_log_prob)
    return tf.concat(x, axis=0)

class SingleTransformerBlock(keras.layers.Layer):
  def __init__(self, d_model, ffn_size, n_heads, dropout_attn, dropout_norm, cur_layer):
    super(SingleTransformerBlock, self).__init__()
    self.n_heads = n_heads
    self.cur_layer = cur_layer
    self.d_model = d_model

    self.w_query = keras.layers.Dense(
        units=d_model, use_bias=False,
        kernel_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/rel_attn/qkv/kernel:0" % cur_layer][:, 0:d_model]
        )
    )
    self.w_key = keras.layers.Dense(
        units=d_model, use_bias=False,
        kernel_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/rel_attn/qkv/kernel:0" % cur_layer][:, d_model:2 * d_model]
        )
    )
    self.w_value = keras.layers.Dense(
        units=d_model, use_bias=False,
        kernel_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/rel_attn/qkv/kernel:0" % cur_layer][:, 2 * d_model:]
        )
    )
    self.w_rel_pos = keras.layers.Dense(
        units=d_model, use_bias=False,
        kernel_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/rel_attn/r/kernel:0" % cur_layer]
        )
    )

    self.w_attn = keras.layers.Dense(
        units=d_model, use_bias=False,
        kernel_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/rel_attn/o/kernel:0" % cur_layer]
        )
    )

    self.w_ffn_up = keras.layers.Dense(
        units=ffn_size, activation="relu",
        kernel_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/ff/layer_1/kernel:0" % cur_layer]
        ),
        bias_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/ff/layer_1/bias:0" % cur_layer]
        )
    )
    self.w_ffn_down = keras.layers.Dense(
        units=d_model,
        kernel_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/ff/layer_2/kernel:0" % cur_layer]
        ),
        bias_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/ff/layer_2/bias:0" % cur_layer]
        )
    )

    x_ut = tf.convert_to_tensor(pre_train_weights["transformer/r_w_bias:0"][cur_layer],
                                dtype=tf.float32)  # (head, dim // head)
    x_vt = tf.convert_to_tensor(pre_train_weights["transformer/r_r_bias:0"][cur_layer],
                                dtype=tf.float32)  # (head, dim // head)

    self.ut = tf.Variable(initial_value=tf.reshape(
        x_ut, shape=(d_model,)
    ), dtype=tf.float32, trainable=True)
    self.vt = tf.Variable(initial_value=tf.reshape(
        x_vt, shape=(d_model,)
    ), dtype=tf.float32, trainable=True)

    self.attn_drop = keras.layers.Dropout(rate=dropout_attn)
    self.ffn_drop = keras.layers.Dropout(rate=dropout_norm)

    self.attn_ln = keras.layers.LayerNormalization(
        gamma_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/rel_attn/LayerNorm/gamma:0" % cur_layer]
        ), beta_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/rel_attn/LayerNorm/beta:0" % cur_layer]
        )
    )
    self.ffn_ln = keras.layers.LayerNormalization(
        gamma_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/ff/LayerNorm/gamma:0" % cur_layer]
        ), beta_initializer=tf.constant_initializer(
            pre_train_weights["transformer/layer_%d/ff/LayerNorm/beta:0" % cur_layer]
        )
    )

  def call(self, inputs, **kwargs):
    cache = kwargs["cache"]
    fusion_inputs = tf.concat([cache, inputs], axis=1)

    query = tf.concat(tf.split(self.w_query(inputs) + self.ut, axis=2, num_or_size_splits=self.n_heads), axis=0)
    q2 = tf.concat(tf.split(self.w_query(inputs) + self.vt, axis=2, num_or_size_splits=self.n_heads), axis=0)
    key = tf.concat(tf.split(self.w_key(fusion_inputs), axis=2, num_or_size_splits=self.n_heads), axis=0)
    value = tf.concat(tf.split(self.w_value(fusion_inputs), axis=2, num_or_size_splits=self.n_heads), axis=0)

    pos_enc = G.create_pre_relative_encoding(seq_length=fusion_inputs.shape[1], dim=fusion_inputs.shape[2])
    pos_enc = tf.tile(self.w_rel_pos(pos_enc)[tf.newaxis, ...], multiples=[fusion_inputs.shape[0], 1, 1])
    pos_enc = tf.concat(tf.split(pos_enc, axis=2, num_or_size_splits=self.n_heads), axis=0)

    attn_out = self.rel_scaled_dot_product_attention(query=query, key=key, value=value, pos_enc=pos_enc,
                                                     padding_mask=kwargs["padding_mask"], q2=q2,
                                                     look_ahead_mask=G.create_look_ahead_mask(
                                                         q_len=inputs.shape[1], k_len=fusion_inputs.shape[1]))
    attn_out = tf.concat(tf.split(attn_out, axis=0, num_or_size_splits=self.n_heads), axis=2)

    attn_out = self.w_attn(attn_out)
    attn_out = self.attn_drop(attn_out, training=kwargs["is_training"])
    res_out_1 = attn_out + inputs
    ln_out_1 = self.attn_ln(res_out_1)

    ffn_up = self.w_ffn_up(ln_out_1)
    ffn_down = self.w_ffn_down(ffn_up)

    ffn_out = self.ffn_drop(ffn_down, training=kwargs["is_training"])
    res_out_2 = ln_out_1 + ffn_out
    ln_out_2 = self.ffn_ln(res_out_2)
    return ln_out_2

  @staticmethod
  def rel_scaled_dot_product_attention(query, q2, key, value, pos_enc, padding_mask, look_ahead_mask):
    matmul_qk = tf.matmul(query, key, transpose_b=True)
    matmul_qp = tf.matmul(q2, pos_enc, transpose_b=True)

    pad_zero_1 = tf.zeros(shape=(query.shape[0], key.shape[1] - query.shape[1], key.shape[1]),
                          dtype=tf.float32)
    pad_zero_2 = tf.zeros(shape=(query.shape[0], key.shape[1], 1), dtype=tf.float32)
    matmul_qp = tf.concat([pad_zero_2, tf.concat([pad_zero_1, matmul_qp], axis=1)], axis=2)

    matmul_qp = tf.reshape(matmul_qp, shape=(matmul_qp.shape[0], matmul_qp.shape[2], matmul_qp.shape[1]))[:, 1:, :]

    matmul_qp = matmul_qp[:, -query.shape[1]:, :]

    matmul_out = matmul_qk + matmul_qp
    dk = tf.cast(tf.shape(value)[-1], tf.float32)
    scaled_attention_logits = matmul_out / tf.math.sqrt(dk)

    pad_one = tf.ones(shape=(padding_mask.shape[0], key.shape[1] - query.shape[1]), dtype=tf.float32)
    padding_mask = tf.concat([pad_one, padding_mask], axis=1) 
    padding_mask = tf.tile(padding_mask[:, tf.newaxis, :],
                           multiples=[query.shape[0] // padding_mask.shape[0], query.shape[1], 1])
    look_ahead_mask = tf.tile(look_ahead_mask[tf.newaxis, :], multiples=[query.shape[0], 1, 1])
    mask = tf.multiply(padding_mask, look_ahead_mask)
    scaled_attention_logits += (1 - mask) * -1e9
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    output = tf.matmul(attention_weights, value)

    return output

class GeneralFunction:
  @staticmethod
  def create_look_ahead_mask(q_len: int, k_len: int, same_length=True):
    mask = tf.linalg.band_part(tf.ones(shape=(k_len, k_len), dtype=tf.float32), -1, 0)[-q_len:, ...]
    if same_length:
        x = mask[:, 0: q_len]
        y = mask[:, q_len:]
        x = tf.linalg.band_part(x, 0, -1)
        mask = tf.concat([x, y], axis=1)
    return mask

  @staticmethod
  def create_pre_relative_encoding(seq_length: int, dim: int):
    pos = np.arange(start=seq_length - 1, step=-1, stop=-1, dtype=np.float32)[..., np.newaxis]
    pos = np.minimum(pos, 1000)
    all_i = np.arange(dim, dtype=np.float32)[np.newaxis, ...]
    angle_rates = 1 / np.power(10000, (2 * (all_i // 2)) / np.float32(dim))
    angle_rads = pos * angle_rates
    x = np.sin(angle_rads[:, 0::2])
    y = np.cos(angle_rads[:, 1::2])
    pos_enc = tf.convert_to_tensor(tf.concat([x, y], axis=-1), dtype=tf.float32)
    return pos_enc

class Main:
  def __init__(self, **kwargs):
    self.kwargs = kwargs
    self.data_obj = DataObj(dataset_name=kwargs["dataset_name"], segment_size=kwargs["segment_size"],
                            pad_id=PAD, batch_size=batch_size)
    self.cache = self.get_init_cache()
    self.model = Vanilla_XL(
        dataset_name=kwargs["dataset_name"], n_heads=kwargs["n_heads"], n_layers=kwargs["n_layers"],
        dropout_norm=kwargs["dropout_norm"], dropout_attn=kwargs["dropout_attn"],
        d_embed=kwargs["d_embed"], ffn_mul=kwargs["ffn_mul"], segment_size=kwargs["segment_size"],
        cutoffs=kwargs["cutoffs"], d_model=kwargs["d_model"]
    )

  def train(self):
    ppl, count = self.eval(is_valid=True)
    print("valid_ppl: %.3f, all_tokens:%d" % (ppl, count))
    ppl, count = self.eval(is_valid=False)
    print("test_ppl: %.3f, all_tokens:%d" % (ppl, count))

  def eval(self, is_valid):
    sum_loss, sum_count = 0.0, 0
    dic = self.data_obj.get_next_valid_test_segment(is_valid=is_valid)
    self.cache = self.get_init_cache()
    while dic is not None:
        loss, count, new_cache = self.eval_step(inputs=dic["input_ids"], ground_truth=dic["ground_truth"],
                                                padding_mask=dic["input_mask"])
        self.cache = tf.concat(
            [self.cache[:, :, self.kwargs["segment_size"]:, :], new_cache], axis=2
        )
        sum_loss += loss
        sum_count += count
        dic = self.data_obj.get_next_valid_test_segment(is_valid=is_valid)
    ppl = tf.exp(sum_loss / sum_count)
    return ppl, sum_count

  @tf.function
  def eval_step(self, inputs, ground_truth, padding_mask):
    log_prob, new_seg_cache = self.model(inputs=inputs, training=False, padding_mask=padding_mask,
                                         cache=self.cache, ground_truth=ground_truth)
    total_loss = tf.reduce_sum(log_prob)
    count = tf.cast(tf.shape(log_prob)[0], dtype=tf.float32)
    return total_loss, count, new_seg_cache

  def get_init_cache(self):
    return tf.zeros(
        shape=(batch_size, self.kwargs["n_layers"], self.kwargs["mem_len"], self.kwargs["d_model"]),
        dtype=tf.float32)

if __name__ == "__main__":
    with open("InitWeights/WT103/weights.p", "rb") as f:
        pre_train_weights = pickle.load(f)
    dataset = "wikitext-103"
    PAD = 0
    batch_size = 1
    G = GeneralFunction()
    _cutoffs = [
        1, 20001, 40001, 200001, vocab_size_dic[dataset]
    ]
    a_epoch_segment = {
        "384": 268820 // batch_size,
        "512": 201615 // batch_size,
        "256": 403230 // batch_size
    }
    E = Main(dataset_name=dataset, segment_size=128, mem_len=1600, n_heads=16, d_model=1024, n_layers=18,
             d_embed=1024, batch_size=batch_size, dropout_attn=0.2, dropout_norm=0.2,
             ffn_mul=4, cutoffs=_cutoffs, method="AC001")
    E.train()
```
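One detail worth flagging in the adaptive softmax above (an illustrative alternative, not the fix reported later in this thread): it computes `tf.nn.softmax`, clamps probabilities at 1e-9, and then takes the log, which throws away precision in float32. Gathering the target log-probabilities directly from `tf.nn.log_softmax` avoids both the clamp and the round-trip. The helper name below is hypothetical:

```python
import tensorflow as tf

# Illustrative alternative to the softmax -> clamp -> log pattern used above:
# take the per-token negative log-probability straight from log_softmax.
def target_neg_log_prob(logits, target_ids):
    # logits:     (n_tokens, cluster_vocab_size) unnormalized scores
    # target_ids: (n_tokens,) indices into the cluster vocabulary
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    target_ids = tf.cast(target_ids, tf.int32)
    idx = tf.stack([tf.range(tf.shape(target_ids)[0]), target_ids], axis=1)
    return -tf.gather_nd(log_probs, idx)
```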

@menghuanlater
Author

I also found something interesting: if I set tgt_len to 1600 and mem_len to 128, the test perplexity drops to 19.7.

@menghuanlater
Author

This bug is caused by the TensorFlow 2.x GPU build, specifically by how it retains numerical precision during parallel computation.
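One generic way to gauge how much float32 round-off can move the evaluated log-probabilities (a hypothetical check, not something reported in this thread) is to recompute the same logits in float64 and compare:

```python
import tensorflow as tf

# Generic precision probe (hypothetical, not from this thread): compare per-token
# log-probabilities computed in float32 vs. float64 over a WikiText-103-sized vocabulary.
logits32 = tf.random.normal([8, 267736], dtype=tf.float32)
lp32 = tf.nn.log_softmax(logits32, axis=-1)
lp64 = tf.nn.log_softmax(tf.cast(logits32, tf.float64), axis=-1)
print(float(tf.reduce_max(tf.abs(tf.cast(lp32, tf.float64) - lp64))))
```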

@tonytan48

@menghuanlater Hi, I was curious whether you changed the environment to TF 1.12 and Python 2.7 as the authors suggest. Did you manage to get the reported perplexity of 18.3?

@menghuanlater
Author

menghuanlater commented Sep 4, 2020 via email
