Skip to content

Commit

Permalink
Fix pixelcnn test
Browse files Browse the repository at this point in the history
  • Loading branch information
juliuskunze committed Nov 29, 2019
1 parent bd2104b commit ce6e2b6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/pixelcnn.py
Expand Up @@ -283,7 +283,7 @@ def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999
opt = Adam(exponential_decay(step_size, 1, decay_rate))
state = opt.init(loss.init_parameters(next(test_batches), key=init_key))

for epoch in range(0):
for epoch in range(epochs):
for batch in get_train_batches():
key, update_key = random.split(key)
i = opt.get_step(state)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_examples.py
Expand Up @@ -7,7 +7,7 @@
from jax.random import PRNGKey

from examples.mnist_vae import gaussian_sample, bernoulli_logpdf, gaussian_kl
from examples.pixelcnn import PixelCNNPP
from examples.pixelcnn import PixelCNNPP, image_dtype
from examples.wavenet import calculate_receptive_field, discretized_mix_logistic_loss, Wavenet
from jaxnet import parametrized, Dense, Sequential, Conv, flatten, GRUCell, Rnn, \
Parameter, parameter, Reparametrized, L2Regularized, optimizers
Expand Down Expand Up @@ -244,7 +244,7 @@ def loss(batch):

def test_pixelcnn():
loss, _ = PixelCNNPP(nr_filters=1, nr_resnet=1)
images = np.zeros((2, 16, 16, 3), np.uint8)
images = np.zeros((2, 16, 16, 3), image_dtype)
opt = optimizers.Adam()
state = opt.init(loss.init_parameters(images, key=PRNGKey(0)))
# take ~20s, disabled for faster tests:
Expand Down

0 comments on commit ce6e2b6

Please sign in to comment.