Skip to content

Commit

Permalink
Merge pull request #928 from dwf/shared_floatx_kwargs
Browse files Browse the repository at this point in the history
Pass shared_floatx kwargs to theano.shared.
  • Loading branch information
dmitriy-serdyuk committed Dec 1, 2015
2 parents fce7f6b + b7f5df3 commit 506dfe1
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions blocks/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def shared_floatx_nans(shape, **kwargs):
return shared_floatx(numpy.nan * numpy.zeros(shape), **kwargs)


def shared_floatx(value, name=None, borrow=False, dtype=None):
"""Transform a value into a shared variable of type floatX.
def shared_floatx(value, name=None, borrow=False, dtype=None, **kwargs):
r"""Transform a value into a shared variable of type floatX.
Parameters
----------
Expand All @@ -123,6 +123,8 @@ def shared_floatx(value, name=None, borrow=False, dtype=None):
dtype : :obj:`str`, optional
The `dtype` of the shared variable. Default value is
:attr:`config.floatX`.
\*\*kwargs
Keyword arguments to pass to the :func:`~theano.shared` function.
Returns
-------
Expand All @@ -133,12 +135,11 @@ def shared_floatx(value, name=None, borrow=False, dtype=None):
if dtype is None:
dtype = theano.config.floatX
return theano.shared(theano._asarray(value, dtype=dtype),
name=name,
borrow=borrow)
name=name, borrow=borrow, **kwargs)


def shared_like(variable, name=None):
"""Construct a shared variable to hold the value of a tensor variable.
def shared_like(variable, name=None, **kwargs):
r"""Construct a shared variable to hold the value of a tensor variable.
Parameters
----------
Expand All @@ -148,14 +149,16 @@ def shared_like(variable, name=None):
name : :obj:`str` or :obj:`None`
The name of the shared variable. If None, the name is determined
based on variable's name.
\*\*kwargs
Keyword arguments to pass to the :func:`~theano.shared` function.
"""
variable = tensor.as_tensor_variable(variable)
if name is None:
name = "shared_{}".format(variable.name)
return theano.shared(numpy.zeros((0,) * variable.ndim,
dtype=variable.dtype),
name=name)
name=name, **kwargs)


def reraise_as(new_exc):
Expand Down

0 comments on commit 506dfe1

Please sign in to comment.