Skip to content

Commit

Permalink
Improve docs of embed_id
Browse files Browse the repository at this point in the history
  • Loading branch information
keisuke-umezawa committed Aug 2, 2017
1 parent 8486edf commit 978ce99
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions chainer/functions/connection/embed_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def backward(self, inputs, grad_outputs):
def embed_id(x, W, ignore_label=None):
"""Efficient linear function for one-hot input.
This function implements so called *word embedding*. It takes two
This function implements so called *word embeddings*. It takes two
arguments: a set of IDs (words) ``x`` in :math:`B` dimensional integer
vector, and a set of all ID (word) embeddings ``W`` in :math:`V \\times d`
float32 matrix. It outputs :math:`B \\times d` matrix whose ``i``-th
Expand All @@ -96,16 +96,39 @@ def embed_id(x, W, ignore_label=None):
This function is only differentiable on the input ``W``.
Args:
x (~chainer.Variable): Batch vectors of IDs.
W (~chainer.Variable): Representation of each ID (a.k.a.
word embeddings).
ignore_label (int or None): If ``ignore_label`` is an int value,
``i``-th column of return value is filled with ``0``.
x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`):
Batch vectors of IDs. Each element must be :class:`numpy.int32`.
W (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`):
Distributed representation of each ID (a.k.a. word embeddings).
ignore_label (:class:`int` or :class:`None`):
If ``ignore_label`` is an int value, ``i``-th column of return
value is filled with ``0``.
Returns:
~chainer.Variable: Output variable.
.. seealso:: :class:`~chainer.links.EmbedID`
.. admonition:: Example
>>> x = np.array([2, 1]).astype('i')
>>> x
array([2, 1], dtype=int32)
>>> W = np.array([[0, 0, 0],
[1, 1, 1],
[2, 2, 2]]).astype('f')
>>> W
array([[ 0., 0., 0.],
[ 1., 1., 1.],
[ 2., 2., 2.]], dtype=float32)
>>> F.embed_id(x, W).data
array([[ 2., 2., 2.],
[ 1., 1., 1.]], dtype=float32)
>>> F.embed_id(x, W, ignore_label=1).data
array([[ 2., 2., 2.],
[ 0., 0., 0.]], dtype=float32)
"""
return EmbedIDFunction(ignore_label=ignore_label)(x, W)

0 comments on commit 978ce99

Please sign in to comment.