
Commit

update svd-softmax
Koichiro Tamura committed Dec 15, 2017
1 parent bf54640 · commit 22940b4
Showing 2 changed files with 58 additions and 60 deletions.
36 changes: 6 additions & 30 deletions README.md
@@ -19,40 +19,16 @@ svd-softmax implemented in Tensorflow by [Koichiro Tamura](http://koichirotamura.com/)
 ## Why this project?
 
 Since it is very important to reduce the calculation cost of the softmax output in NLP tasks, I tried to implement the idea in [SVD-Softmax: Fast Softmax Approximation on Large Vocabulary Neural Networks](https://papers.nips.cc/paper/7130-svd-softmax-fast-softmax-approximation-on-large-vocabulary-neural-networks).
-However, there are some problems with the implementation of svd-softmax in Tensorflow, so I would like to discuss them here; please tell me the solutions if you know them.
 
-## Problems to solve
+## Room for improvement
 
 ### No gradient defined for operation SVD
 
-The SVD (singular value decomposition) method in Tensorflow, [tf.svd()](https://www.tensorflow.org/api_docs/python/tf/svd), does not support a gradient in the Tensorflow graph. If you would like to use SVD-softmax during training, you have to implement a trainable SVD function yourself.
+The SVD (singular value decomposition) method in Tensorflow, [tf.svd()](https://www.tensorflow.org/api_docs/python/tf/svd), does not support a gradient in the Tensorflow graph. This means that you have to use another training method, such as NCE.
 
-### Too slow SVD-softmax on GPU
+### More efficient code for updating the top-N words
 
-Even when svd-softmax is used only for evaluation, its calculation is too slow.
-For example, I tried svd-softmax in the [Transformer](https://arxiv.org/abs/1706.03762) with the following hyperparameters and environment:
-
-- vocabulary size = 30000
-- hidden units = 256
-- window size = 256
-- num of full view = 2048
-- JPO Japanese-Chinese corpus (1 million pairs)
-- 4x TITAN X (Pascal) (liquid cooling)
-- Ubuntu 16.04.1 LTS
-
-The calculation times in my experiments were as follows:
-
-- full softmax (code as follows): about 0.4 sec
-
-```
-logits = tf.matmul(self.dec_output, tf.transpose(self.weights))
-logits = tf.nn.bias_add(logits, self.biases)
-logits = tf.nn.softmax(logits)
-```
-
-- tf.svd() (line 32 in svd_softmax.py): about 2.5 sec
-
-That is, computing the SVD alone is slower than the full softmax.
-
-I don't know how to deal with this problem, so please tell me the solution if you can.
+Since Tensorflow uses a static graph, it is difficult to update the top-N words with the full-view vector multiplication.
+If you know a more efficient way to implement this, please tell me.
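
Since tf.svd() has no gradient, the README above falls back to NCE for training. The following is only a rough sketch of that idea, not code from this repository: the placeholder names (dec_output, output_weight, output_bias), the optimizer, and the sample count of 512 are assumptions chosen for illustration. Training uses tf.nn.nce_loss (tf.nn.sampled_softmax_loss would work the same way) on the shared output weights, and SVD-softmax is reserved for evaluation.

```
import tensorflow as tf

vocab_size, hidden_units = 30000, 256                              # illustrative sizes

dec_output = tf.placeholder(tf.float32, [None, hidden_units])      # decoder states, [batch*T_q, hidden]
labels = tf.placeholder(tf.int64, [None, 1])                       # target word ids
weights = tf.get_variable("output_weight", [vocab_size, hidden_units])
biases = tf.get_variable("output_bias", [vocab_size], initializer=tf.zeros_initializer())

# training path: a sampled objective needs no gradient through tf.svd()
loss = tf.reduce_mean(tf.nn.nce_loss(
    weights=weights, biases=biases, labels=labels,
    inputs=dec_output, num_sampled=512, num_classes=vocab_size))
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

# evaluation path: SVD-softmax reuses the same weights and biases (see svd_softmax.py below)
```

Only the output weights and biases are shared between the two paths; the SVD decomposition is recomputed from the trained weights at evaluation time.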

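The removed "Too slow SVD-softmax in GPU" section quotes roughly 0.4 sec for the full softmax against roughly 2.5 sec for the tf.svd() call alone. A minimal timing harness in that spirit could look like the sketch below; the random inputs, batch size, and single-run timing are invented for illustration and are not the setup behind the quoted numbers.

```
import time
import numpy as np
import tensorflow as tf

vocab_size, hidden_units, batch = 30000, 256, 4096                 # illustrative sizes only

dec_output = tf.constant(np.random.randn(batch, hidden_units).astype(np.float32))
weights = tf.constant(np.random.randn(vocab_size, hidden_units).astype(np.float32))
biases = tf.constant(np.zeros([vocab_size], dtype=np.float32))

full_softmax = tf.nn.softmax(tf.nn.bias_add(tf.matmul(dec_output, tf.transpose(weights)), biases))
svd_op = tf.svd(weights, full_matrices=False)                      # the decomposition step on its own

with tf.Session() as sess:
    for name, op in [("full softmax", full_softmax), ("tf.svd", svd_op)]:
        sess.run(op)                                               # warm-up, excludes graph construction cost
        start = time.time()
        sess.run(op)
        print("%s: %.3f sec" % (name, time.time() - start))
```
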
82 changes: 52 additions & 30 deletions svd_softmax.py
@@ -2,47 +2,69 @@
 """
 SVD-Softmax: Fast Softmax Approximation on Large Vocabulary Neural Networks
 http://papers.nips.cc/paper/7130-svd-softmax-fast-softmax-approximation-on-large-vocabulary-neural-networks.pdf
 implemented in Tensorflow by Koichiro Tamura (http://koichirotamura.com/)
 """
 
 import tensorflow as tf
 import math
 
-def svd_softmax(dec, tgt_vocab_size, hidden_units, window_size=2**5, num_full_view=2**11):
-    """
-    svd-softmax
-    :param dec: A Tensor [batch_size, seq_length, hidden_units], decoder output
-    :param tgt_vocab_size: int, num of vocabulary
-    :param hidden_units: int, num of hidden units
-    :param window_size: int, width of preview window W (hidden_units / 8 is recommended)
-    :param num_full_view: int, num of full-view size
-    :return: A Tensor [batch_size, seq_length, tgt_vocab_size], output after softmax approximation
-    """
-
-    with tf.variable_scope("logits", reuse=True):
-        weights = tf.Variable(
-            tf.truncated_normal([tgt_vocab_size, hidden_units],
-                                stddev=1.0 / math.sqrt(hidden_units)), name="output_weight")
-        biases = tf.Variable(tf.zeros([tgt_vocab_size]), name="output_bias")
-        dec_output = tf.reshape(dec, [-1, hidden_units])  # [batch_size*T_q, hidden]
-
-        # svd-softmax
-        _s, U, V = tf.svd(weights, full_matrices=False)
-        B = tf.matmul(U, tf.diag(_s))
-
-        _h = tf.einsum('ij,aj->ai', tf.transpose(V), dec_output)  # [batch_size*T_q, hidden]
-        _z = tf.add(tf.einsum('ij,aj->ai', B[:, :window_size], _h[:, :window_size]), biases)  # [batch_size*T_q, voc]
-
-        top_k = tf.nn.top_k(_z, k=tgt_vocab_size)
-        _indices, values = top_k.indices, top_k.values  # [batch_size*T_q, voc]
-
-        _z = tf.add(tf.squeeze(tf.matmul(tf.gather(B, _indices[:, :num_full_view]), tf.expand_dims(_h, axis=-1))), tf.gather(biases, _indices[:, :num_full_view]))  # [batch_size*T_q, num_full_view]
-        _z = tf.concat([_z, values[:, num_full_view:]], axis=-1)
-        _z = tf.map_fn(lambda x: tf.gather(x[0], tf.invert_permutation(x[1])), (_z, _indices), dtype=(tf.float32))  # [batch_size*T_q, voc]
-        _z = tf.exp(_z)
-        Z = tf.expand_dims(tf.reduce_sum(_z, axis=-1), axis=1)  # [batch_size*T_q, 1]
-        logits = _z / Z
-
-        return tf.reshape(logits, [-1, tf.shape(dec)[1], tgt_vocab_size])
+class SVDSoftmax(object):
+    """svd-softmax class"""
+
+    def __init__(self, tgt_vocab_size, hidden_units, window_size=2 ** 5, num_full_view=2 ** 11):
+        """
+        initialize SVD
+        :param tgt_vocab_size: int, num of vocabulary
+        :param hidden_units: int, num of hidden units
+        :param window_size: int, width of preview window W (hidden_units / 8 is recommended)
+        :param num_full_view: int, num of full-view size
+        """
+        self.tgt_vocab_size = tgt_vocab_size
+        self.hidden_units = hidden_units
+        self.window_size = window_size
+        self.num_full_view = num_full_view
+
+        # B = U * diag(s), refreshed by update_params()
+        self.B = tf.Variable(
+            tf.truncated_normal([self.tgt_vocab_size, self.hidden_units],
+                                stddev=1.0 / math.sqrt(hidden_units)), name="B_SVD", trainable=False)
+        # transposed V, refreshed by update_params()
+        self.V_t = tf.Variable(
+            tf.truncated_normal([self.hidden_units, self.hidden_units],
+                                stddev=1.0 / math.sqrt(hidden_units)), name="V_SVD", trainable=False)
+
+    def update_params(self, weights):
+        """
+        update svd parameters B, V_t from the softmax output weight
+        :param weights: A Tensor [tgt_vocab_size, hidden_units], output weight of softmax
+        :return: an op that assigns the refreshed B and V_t
+        """
+        _s, U, V = tf.svd(weights, full_matrices=False)
+        update_B = self.B.assign(tf.matmul(U, tf.diag(_s)))
+        update_V_t = self.V_t.assign(tf.transpose(V))
+        # assign() only builds ops in graph mode, so return them to be run by the caller
+        return tf.group(update_B, update_V_t)
+
+    def get_output(self, dec_output, biases):
+        """
+        get svd-softmax approximation
+        :param dec_output: A Tensor [batch_size*seq_length, hidden_units], decoder output
+        :param biases: A Tensor [tgt_vocab_size], output bias
+        :return: A Tensor [batch_size*seq_length, tgt_vocab_size], output after softmax approximation
+        """
+        # preview: approximate logits from the first window_size singular dimensions
+        _h = tf.einsum('ij,aj->ai', self.V_t, dec_output)
+        _z = tf.add(tf.einsum('ij,aj->ai', self.B[:, :self.window_size], _h[:, :self.window_size]), biases)
+
+        top_k = tf.nn.top_k(_z, k=self.tgt_vocab_size)
+        _indices, values = top_k.indices, top_k.values
+
+        # full view: recompute exact logits only for the num_full_view most promising words
+        _z = tf.add(tf.squeeze(tf.matmul(tf.gather(self.B, _indices[:, :self.num_full_view]), tf.expand_dims(_h, axis=-1))),
+                    tf.gather(biases, _indices[:, :self.num_full_view]))
+        _z = tf.concat([_z, values[:, self.num_full_view:]], axis=-1)
+        # restore vocabulary order, then normalize the exponentiated logits
+        _z = tf.map_fn(lambda x: tf.gather(x[0], tf.invert_permutation(x[1])), (_z, _indices), dtype=tf.float32)
+        _z = tf.exp(_z)
+        Z = tf.expand_dims(tf.reduce_sum(_z, axis=-1), axis=1)  # [batch_size*seq_length, 1]
+        logits = _z / Z
+
+        return logits
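
For completeness, here is a minimal usage sketch of the new SVDSoftmax class. The shapes and variable names are placeholders, and it assumes the update op returned by update_params above is run explicitly whenever the output weights change.

```
import tensorflow as tf
from svd_softmax import SVDSoftmax

vocab_size, hidden_units = 30000, 256                              # illustrative sizes

dec_output = tf.placeholder(tf.float32, [None, hidden_units])      # [batch*T_q, hidden]
weights = tf.get_variable("output_weight", [vocab_size, hidden_units])
biases = tf.get_variable("output_bias", [vocab_size], initializer=tf.zeros_initializer())

svd_layer = SVDSoftmax(vocab_size, hidden_units,
                       window_size=hidden_units // 8,              # W = h/8, as the docstring recommends
                       num_full_view=2 ** 11)
update_op = svd_layer.update_params(weights)                       # rebuilds B and V_t from the trained weights
probs = svd_layer.get_output(dec_output, biases)                   # [batch*T_q, vocab_size] softmax approximation

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_op)                                            # refresh the decomposition before evaluation
    # probs can now be fetched with feed_dict={dec_output: ...}
```
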
