Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
def initialized(self, key: spec.RandomState,
model: nn.Module) -> spec.ModelInitState:
input_shape = (1, 224, 224, 3)
variables = jax.jit(model.init)({'params': key}, jnp.ones(input_shape))
params_rng, dropout_rng = jax.random.split(key)
variables = jax.jit(
model.init)({'params': params_rng, 'dropout': dropout_rng},
jnp.ones(input_shape))
model_state, params = variables.pop('params')
return params, model_state

Expand Down
9 changes: 5 additions & 4 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,11 @@ def init_model_fn(
self._train_model = models.Transformer(model_config)
eval_config = replace(model_config, deterministic=True)
self._eval_model = models.Transformer(eval_config)
initial_variables = jax.jit(self._eval_model.init)(
rng,
jnp.ones(input_shape, jnp.float32),
jnp.ones(target_shape, jnp.float32))
params_rng, dropout_rng = jax.random.split(rng)
initial_variables = jax.jit(
self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
jnp.ones(input_shape, jnp.float32),
jnp.ones(target_shape, jnp.float32))

initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
Expand Down