Skip to content

Commit

Permalink
add support for tf1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
korepwx committed Mar 8, 2018
1 parent 85eb5c9 commit c31b4e4
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
.idea
.cache
.pytest_cache
*.iml
/config.py
/debug.py
Expand Down
8 changes: 2 additions & 6 deletions .travis.yml
Expand Up @@ -5,13 +5,9 @@ services:
env:
matrix:
- PYTHON_VERSION=2 TENSORFLOW_VERSION=1.2
- PYTHON_VERSION=2 TENSORFLOW_VERSION=1.3
- PYTHON_VERSION=2 TENSORFLOW_VERSION=1.4
- PYTHON_VERSION=2 TENSORFLOW_VERSION=1.5
- PYTHON_VERSION=3 TENSORFLOW_VERSION=1.2
- PYTHON_VERSION=3 TENSORFLOW_VERSION=1.3
- PYTHON_VERSION=3 TENSORFLOW_VERSION=1.4
- PYTHON_VERSION=3 TENSORFLOW_VERSION=1.5
- PYTHON_VERSION=2 TENSORFLOW_VERSION=1.6
- PYTHON_VERSION=3 TENSORFLOW_VERSION=1.6
install:
- docker pull "ipwx/travis-tensorflow-docker:py${PYTHON_VERSION}tf${TENSORFLOW_VERSION}"
script:
Expand Down
8 changes: 2 additions & 6 deletions README.rst
Expand Up @@ -11,13 +11,9 @@ TFSnippet
+------------+-------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------+
| TensorFlow | Python 2 | Python 3 |
+============+=================================================================================================+=================================================================================================+
| 1.2 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/1 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/5 |
| 1.2 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/1 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/2 |
+------------+-------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------+
| 1.3 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/2 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/6 |
+------------+-------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------+
| 1.4 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/3 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/7 |
+------------+-------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------+
| 1.5 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/4 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/8 |
| 1.6 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/3 | .. image:: https://travis-matrix-badges.herokuapp.com/repos/korepwx/tfsnippet/branches/master/4 |
+------------+-------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------+

TF Snippet is a set of utilities for writing and testing TensorFlow models.
Expand Down
4 changes: 2 additions & 2 deletions tests/distributions/test_univariate.py
Expand Up @@ -63,12 +63,12 @@ def test_props(self):
np.testing.assert_allclose(categorical.logits.eval(), logits)

def test_dtype(self):
categorical = Categorical(logits=[0., 1.])
categorical = Categorical(logits=tf.constant([0., 1.]))
self.assertEqual(categorical.dtype, tf.int32)
samples = categorical.sample()
self.assertEqual(samples.dtype, tf.int32)

categorical = Categorical(logits=[0., 1.], dtype=tf.int64)
categorical = Categorical(logits=tf.constant([0., 1.]), dtype=tf.int64)
self.assertEqual(categorical.dtype, tf.int64)
samples = categorical.sample()
self.assertEqual(samples.dtype, tf.int64)
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_wrapper.py
Expand Up @@ -17,7 +17,7 @@ def test_distribution(self):
self.assertIs(distrib, d)

def test_zs_distribution(self):
normal = zd.Normal(0., 1.)
normal = zd.Normal(mean=0., std=1.)
distrib = as_distribution(normal)
self.assertIsInstance(distrib, Distribution)
self.assertIsInstance(distrib, ZhuSuanDistribution)
Expand Down
2 changes: 2 additions & 0 deletions tfsnippet/distributions/univariate.py
Expand Up @@ -104,6 +104,8 @@ def __init__(self, logits, dtype=None):
dtype: The value type of samples from the distribution.
(default ``tf.int32``)
"""
if dtype is None:
dtype = tf.int32
super(Categorical, self).__init__(
zd.Categorical(logits=logits, dtype=dtype))

Expand Down
2 changes: 1 addition & 1 deletion tfsnippet/utils/tfver.py
Expand Up @@ -15,4 +15,4 @@ def is_tensorflow_version_higher_or_equal(version):
Returns:
bool: True if higher or equal to, False if not.
"""
return semver.compare(version, tf.__version__) <= 0
return semver.compare(version, tf.__version__, loose=True) <= 0

0 comments on commit c31b4e4

Please sign in to comment.