flax.jax_utils package

flax.jax_utils

    partial_eval_by_shape

Multi device utilities

    replicate
    unreplicate
    prefetch_to_device
    pmean
    pad_shard_unpad
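The multi device utilities support a common data-parallel pattern: copy parameters to every local device, run a mapped step that averages values across devices, then bring one copy back to the host. The sketch below illustrates that pattern using only core JAX (`jax.pmap`, `jax.lax.pmean`, `jax.tree_util.tree_map`). It is not Flax's implementation. The helpers `replicate_tree` and `unreplicate_tree` are hypothetical stand-ins that approximate what `replicate` and `unreplicate` are meant to do, and the toy `update` step is an assumption chosen only to show a cross-device mean.

```python
import jax
import jax.numpy as jnp


def replicate_tree(tree, n_devices=None):
    """Hypothetical stand-in for replicate: stack one copy of every leaf per device
    along a new leading axis, so the tree can be fed to a pmapped function."""
    n = n_devices or jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (n,) + jnp.shape(x)), tree
    )


def unreplicate_tree(tree):
    """Hypothetical stand-in for unreplicate: keep only the first device's copy."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


def update(params, batch):
    # Toy per-device step (an assumption for illustration): compute the local
    # batch mean, average it across all devices, and add it to every parameter.
    local_mean = jnp.mean(batch)
    global_mean = jax.lax.pmean(local_mean, axis_name="devices")
    return jax.tree_util.tree_map(lambda p: p + global_mean, params)


# Map the step over the leading (device) axis; "devices" names that axis for pmean.
p_update = jax.pmap(update, axis_name="devices")

n = jax.local_device_count()
params = {"w": jnp.zeros((3,)), "b": jnp.float32(0.0)}
replicated = replicate_tree(params)

# Shard a global batch of n * 4 examples into shape (n, 4): one row per device.
batch = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

new_params = unreplicate_tree(p_update(replicated, batch))
```

Because every device holds the same number of examples, averaging the per-device means with `pmean` yields the mean of the whole global batch, so each replica applies an identical update and taking the first copy loses nothing. The same code runs on a single CPU, where `n` is 1.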