
Commit 605c6d5
Fix byte encoding
tf.strings.unicode_decode decodes strings to Unicode code points, so any
non-latin-1 character falls outside the 0-255 range of the byte embedding
table. This change introduces actual byte encodings.

In some experiments, I found that with real byte representations it helps
to have two layers in the byte RNN (presumably to combine the bytes of
characters that have multi-byte encodings), so this also changes the
default hyperparameters.
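For illustration (not part of the commit; a minimal sketch assuming TensorFlow 1.x, with a made-up example string): the character "ő" (U+0151) has code point 337, which overflows a 256-row byte embedding table, while its UTF-8 encoding consists of the two bytes 197 and 145.

import tensorflow as tf

word = tf.constant("ő")

# Old behaviour: Unicode code points. Yields [337], out of range for a
# 256-row byte embedding matrix.
codepoints = tf.strings.unicode_decode(word, input_encoding='UTF-8')

# New behaviour: the raw UTF-8 bytes. Yields [197, 145], both < 256.
raw_bytes = tf.decode_raw(word, tf.uint8)

with tf.Session() as sess:
    print(sess.run(codepoints))  # [337]
    print(sess.run(raw_bytes))   # [197 145]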
danieldk committed Aug 8, 2019
1 parent 7d781cf commit 605c6d5
Showing 2 changed files with 25 additions and 15 deletions.
36 changes: 23 additions & 13 deletions sticker-graph/sticker_graph/model.py
@@ -36,6 +36,22 @@ def affine_transform(self, prefix, x, n_outputs):
                 x, w, b), [
                 batch_size, -1, n_outputs])
 
+    def decode_word(self, word, max_subword_len):
+        bytes = tf.decode_raw(word, tf.uint8)[:max_subword_len]
+        return tf.pad(bytes, [[0, max_subword_len - tf.shape(bytes)[0]]])
+
+    def decode_words(self, words, max_subword_len):
+        shape = tf.shape(words)
+
+        # Flatten, because map_fn only works on axis 0
+        flat = tf.reshape(words, [-1])
+
+        # Decode words.
+        decoded = tf.map_fn(lambda x: self.decode_word(x, max_subword_len), flat, dtype=tf.uint8)
+
+        # Restore shape
+        return tf.reshape(decoded, [shape[0], shape[1], max_subword_len])
+
     def masked_softmax_loss(self, prefix, logits, labels, mask):
         # Compute losses
@@ -143,27 +159,21 @@ def subword_reprs(self):
             # Convert strings to a byte tensor.
             #
             # Shape: [batch_size, seq_len, subword_len]
-            subword_bytes = tf.strings.unicode_decode(
-                self.subwords, input_encoding='UTF-8')
-            subword_bytes_padded = subword_bytes.to_tensor(
-                0)[:, :, :self.args.subword_len]
+            with tf.device("/cpu:0"):
+                subword_bytes = tf.cast(self.decode_words(self.subwords, self.args.subword_len), tf.int32)
 
-            # Get the lengths of the subwords. Only the last dimension should
-            # be ragged, so no actual padding should happen.
-            #
+            # Get the lengths of the subwords.
             # Shape: [batch_size, seq_len]
-            subword_lens = tf.math.minimum(
-                subword_bytes.row_lengths(
-                    axis=-1).to_tensor(0),
-                self.args.subword_len)
+            with tf.device("/cpu:0"):
+                subword_lens = tf.count_nonzero(subword_bytes, axis=-1, dtype=tf.int32)
 
             # Lookup byte embeddings, this results in a tensor of shape.
             #
             # Shape: [batch_size, seq_len, max_bytes_len, byte_embed_size]
             byte_embeds = tf.get_variable(
                 "byte_embeds", [
                     256, self.args.byte_embed_size])
-            byte_reprs = tf.nn.embedding_lookup(byte_embeds, subword_bytes_padded)
+            byte_reprs = tf.nn.embedding_lookup(byte_embeds, subword_bytes)
 
             byte_reprs = tf.contrib.layers.dropout(
                 byte_reprs,
@@ -173,7 +183,7 @@ def subword_reprs(self):
             # Prepare shape for applying the RNN:
             #
             # Shape: [batch_size * seq_len, max_bytes_len, byte_embed_size]
-            bytes_shape = tf.shape(subword_bytes_padded)
+            bytes_shape = tf.shape(subword_bytes)
             byte_reprs = tf.reshape(
                 byte_reprs, [-1, bytes_shape[2], self.args.byte_embed_size])
             byte_lens = tf.reshape(subword_lens, [-1])
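
For illustration (not part of the commit), a standalone sketch of the new helpers, assuming TensorFlow 1.x, showing how a [batch, seq] string tensor becomes a fixed-shape, zero-padded byte tensor:

import tensorflow as tf

def decode_word(word, max_subword_len):
    # The raw UTF-8 bytes of one word, truncated and zero-padded to a
    # fixed length.
    byte_seq = tf.decode_raw(word, tf.uint8)[:max_subword_len]
    return tf.pad(byte_seq, [[0, max_subword_len - tf.shape(byte_seq)[0]]])

def decode_words(words, max_subword_len):
    shape = tf.shape(words)
    # map_fn only iterates over axis 0, so flatten first.
    flat = tf.reshape(words, [-1])
    decoded = tf.map_fn(
        lambda w: decode_word(w, max_subword_len), flat, dtype=tf.uint8)
    return tf.reshape(decoded, [shape[0], shape[1], max_subword_len])

words = tf.constant([["ab", "ő"]])  # "ő" encodes to the two bytes 197, 145
with tf.Session() as sess:
    print(sess.run(decode_words(words, 4)))
    # [[[ 97  98   0   0]
    #   [197 145   0   0]]]

The zero padding is also what lets the new length computation above use tf.count_nonzero (assuming no word contains a NUL byte).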
4 changes: 2 additions & 2 deletions sticker-graph/sticker_graph/write_helper.py
@@ -101,12 +101,12 @@ def get_common_parser():
"--subword_layers",
type=int,
help="character RNN hidden layers",
default=1)
default=2)
parser.add_argument(
"--subword_len",
type=int,
help="number of characters in character-based representations",
default=20)
default=40)
parser.add_argument(
"--subword_residual",
action='store_true',
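
A generic sketch of the motivation for the new --subword_layers default of 2 (not the repository's actual RNN code; the GRU cell choice and sizes here are made up): with raw bytes, a character such as "ő" arrives as two time steps, so a second RNN layer gives the network capacity to first merge bytes into characters and then characters into a subword representation.

import tensorflow as tf

byte_embed_size = 25
hidden_size = 100

# [batch * seq_len, max_bytes_len, byte_embed_size], as prepared in
# subword_reprs above.
byte_reprs = tf.placeholder(tf.float32, [None, None, byte_embed_size])
byte_lens = tf.placeholder(tf.int32, [None])

# Two stacked GRU layers: the lower layer can combine the bytes of
# multi-byte characters, the upper layer composes the subword.
cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.GRUCell(hidden_size) for _ in range(2)])
_, states = tf.nn.dynamic_rnn(
    cell, byte_reprs, sequence_length=byte_lens, dtype=tf.float32)
subword_repr = states[-1]  # final state of the top layer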
