Skip to content

Commit

Permalink
add example for links.EmbedID
Browse files Browse the repository at this point in the history
  • Loading branch information
keisuke-umezawa committed Aug 5, 2017
1 parent 978ce99 commit 3717dbf
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
4 changes: 2 additions & 2 deletions chainer/functions/connection/embed_id.py
Expand Up @@ -117,8 +117,8 @@ def embed_id(x, W, ignore_label=None):
>>> x
array([2, 1], dtype=int32)
>>> W = np.array([[0, 0, 0],
[1, 1, 1],
[2, 2, 2]]).astype('f')
... [1, 1, 1],
... [2, 2, 2]]).astype('f')
>>> W
array([[ 0., 0., 0.],
[ 1., 1., 1.],
Expand Down
20 changes: 19 additions & 1 deletion chainer/links/connection/embed_id.py
Expand Up @@ -22,11 +22,29 @@ class EmbedID(link.Link):
ignore_label (int or None): If ``ignore_label`` is an int value,
``i``-th column of return value is filled with ``0``.
.. seealso:: :func:`chainer.functions.embed_id`
.. seealso:: :func:`~chainer.functions.embed_id`
Attributes:
W (~chainer.Variable): Embedding parameter matrix.
.. admonition:: Example
>>> W = np.array([[0, 0, 0],
... [1, 1, 1],
... [2, 2, 2]]).astype('f')
>>> W
array([[ 0., 0., 0.],
[ 1., 1., 1.],
[ 2., 2., 2.]], dtype=float32)
>>> l = L.EmbedID(W.shape[0], W.shape[1], initialW=W)
>>> x = np.array([2, 1]).astype('i')
>>> x
array([2, 1], dtype=int32)
>>> y = l(x)
>>> y.data
array([[ 2., 2., 2.],
[ 1., 1., 1.]], dtype=float32)
"""

ignore_label = None
Expand Down

0 comments on commit 3717dbf

Please sign in to comment.