I am working on the pix2pix example. I don't get an error when I pass the picture to `model`, but when I pass it to `model_optimizer.target` I get an error. Why?
OK, so if you just want to evaluate this on a single device, remove this line from `create_optimizer`:
```python
optimizer = flax.jax_utils.replicate(optimizer)
```
The `replicate` function broadcasts the params to every device being used: on a single GPU that's just one (which is kind of silly), but on multi-GPU or TPU setups it's 4, 8, etc. Every replicated array gains a new leading axis whose size is the number of devices. Notice that you didn't replicate the model above but did replicate the optimizer (again, both hold params inside them for convenience), which is why the second call broke.
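To make that shape change concrete, here is a minimal pure-JAX sketch. It assumes a toy parameter dict standing in for a single conv layer of the generator; `replicate_tree` and `conv_layer` are hypothetical helpers, and `replicate_tree` only mimics the shape effect of `flax.jax_utils.replicate` (the real function also places one copy on each device). The unreplicated params work, while the replicated kernel carries an extra leading axis that the convolution can't handle. Indexing that axis away, which is what `flax.jax_utils.unreplicate` does, makes the call work again.

```python
import jax
import jax.numpy as jnp


def replicate_tree(tree, n_devices):
    # Hypothetical helper: mimics only the *shape* effect of
    # flax.jax_utils.replicate, i.e. every leaf gets a new leading axis
    # of size n_devices. (The real function also puts a copy on each device.)
    return jax.tree_util.tree_map(
        lambda leaf: jnp.broadcast_to(leaf, (n_devices,) + leaf.shape), tree
    )


def conv_layer(params, x):
    # One NHWC convolution standing in for the first layer of a generator.
    y = jax.lax.conv_general_dilated(
        x,
        params["kernel"],
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y + params["bias"]


params = {"kernel": jnp.ones((3, 3, 3, 64)), "bias": jnp.zeros((64,))}
n = jax.local_device_count()
replicated = replicate_tree(params, n)

print(params["kernel"].shape, "->", replicated["kernel"].shape)  # (3, 3, 3, 64) -> (n, 3, 3, 3, 64)

x = jnp.ones((1, 8, 8, 3))  # one small NHWC "image"
print(conv_layer(params, x).shape)  # (1, 8, 8, 64)

try:
    # The kernel is now rank 5, but the conv expects a rank-4 HWIO kernel.
    conv_layer(replicated, x)
except Exception as err:  # broad catch: this is only a demo of the failure
    print("replicated params fail:", type(err).__name__)

# Dropping the device axis (what flax.jax_utils.unreplicate does) fixes it.
single = jax.tree_util.tree_map(lambda leaf: leaf[0], replicated)
print(conv_layer(single, x).shape)  # (1, 8, 8, 64)
```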
Now, assuming you do want to run replicated SPMD computation across multiple devices, keep that `replicate` call in `create_optimizer`, but also define a model eval function to be `pmap`ped so that it can use the replicated parameters stored across devices, for example:
```python
import flax
import jax


@jax.pmap
def eval_w_pmap(model, x, prng_key):
    with flax.nn.stochastic(prng_key):
        # model is just a container for the replicated params
        return model(x)


ldc = jax.local_device_count()
pmap_test_input = jax.random.normal(jax.random.PRNGKey(1), (ldc, 1, 256, 256, 3))
pmap_rngs = jax.random.split(jax.random.PRNGKey(0), ldc)
eval_w_pmap(generator_optimizer.target, pmap_test_input, pmap_rngs).shape  # (8, 1, 256, 256, 3) with 8 local devices
```
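The same pattern can be sketched in plain JAX, without the `flax.nn` model container, to show what `pmap` does with each argument. This is only an illustration under stated assumptions: `toy_generator` is a hypothetical one-conv stand-in for the pix2pix generator, and its dropout consumes the per-device key the way `flax.nn.stochastic` does above. `jax.pmap` maps over the leading axis of every argument, so each device receives one params copy, one `(1, 256, 256, 3)` input slice and one key, and the outputs come back stacked along a leading device axis.

```python
import jax
import jax.numpy as jnp


def toy_generator(params, x, rng):
    # Hypothetical stand-in for the pix2pix generator: one conv + dropout.
    # `rng` plays the role of the key passed to flax.nn.stochastic.
    y = jax.lax.conv_general_dilated(
        x,
        params["kernel"],
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    keep = jax.random.bernoulli(rng, 0.5, y.shape)  # dropout mask, keep prob 0.5
    return jnp.tanh(jnp.where(keep, y / 0.5, 0.0))  # inverted dropout, then tanh


# pmap maps over the leading axis of every argument: each device gets its
# own slice of the params, the input batch and the rng keys.
eval_w_pmap = jax.pmap(toy_generator)

ldc = jax.local_device_count()
params = {"kernel": jnp.ones((3, 3, 3, 3)) * 0.01}
# Same shape effect as replicating the params: add a leading device axis.
replicated = jax.tree_util.tree_map(
    lambda leaf: jnp.broadcast_to(leaf, (ldc,) + leaf.shape), params
)

test_input = jax.random.normal(jax.random.PRNGKey(1), (ldc, 1, 256, 256, 3))
rngs = jax.random.split(jax.random.PRNGKey(0), ldc)  # one key per device

out = eval_w_pmap(replicated, test_input, rngs)
print(out.shape)  # (ldc, 1, 256, 256, 3), e.g. (1, 1, 256, 256, 3) on one GPU
```

Because the device axis comes first, the result for device `i` is simply `out[i]`.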
I hope that helps explain what's going on - please let me know if it's still not clear!
Related PR: #186
Thank you!