diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 0bc1cd8cc..93f603bca 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -24,7 +24,10 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def initialized(self, key: spec.RandomState, model: nn.Module) -> spec.ModelInitState: input_shape = (1, 224, 224, 3) - variables = jax.jit(model.init)({'params': key}, jnp.ones(input_shape)) + params_rng, dropout_rng = jax.random.split(key) + variables = jax.jit( + model.init)({'params': params_rng, 'dropout': dropout_rng}, + jnp.ones(input_shape)) model_state, params = variables.pop('params') return params, model_state diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 0250206a6..b10d4056d 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -234,10 +234,11 @@ def init_model_fn( self._train_model = models.Transformer(model_config) eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) - initial_variables = jax.jit(self._eval_model.init)( - rng, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + params_rng, dropout_rng = jax.random.split(rng) + initial_variables = jax.jit( + self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32)) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params)