From e31f2cac8991fb0fb7b9ebdc9a7fd0861a2bf956 Mon Sep 17 00:00:00 2001
From: runame
Date: Mon, 4 Mar 2024 15:45:58 +0000
Subject: [PATCH 1/2] Add missing dropout rng for ImageNet-ViT and WMT

---
 .../workloads/imagenet_vit/imagenet_jax/workload.py      | 8 ++++++--
 algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 3 ++-
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
index 0bc1cd8cc..a9ac76ea2 100644
--- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -21,10 +21,14 @@
 # Make sure we inherit from the ViT base workload first.
 class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
 
-  def initialized(self, key: spec.RandomState,
+  def initialized(self,
+                  key: spec.RandomState,
                   model: nn.Module) -> spec.ModelInitState:
     input_shape = (1, 224, 224, 3)
-    variables = jax.jit(model.init)({'params': key}, jnp.ones(input_shape))
+    params_rng, dropout_rng = jax.random.split(key)
+    variables = jax.jit(
+        model.init)({'params': params_rng, 'dropout': dropout_rng},
+                    jnp.ones(input_shape))
     model_state, params = variables.pop('params')
     return params, model_state
 
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
index 0250206a6..1b6d6449e 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -234,8 +234,9 @@ def init_model_fn(
     self._train_model = models.Transformer(model_config)
     eval_config = replace(model_config, deterministic=True)
     self._eval_model = models.Transformer(eval_config)
+    params_rng, dropout_rng = jax.random.split(rng)
     initial_variables = jax.jit(self._eval_model.init)(
-        rng,
+        {'params': params_rng, 'dropout': dropout_rng},
         jnp.ones(input_shape, jnp.float32),
         jnp.ones(target_shape, jnp.float32))
 

From 1d0e8a9067a4c7443b501b1652fe3ec1c6385cee Mon Sep 17 00:00:00 2001
From: runame
Date: Mon, 4 Mar 2024 15:50:02 +0000
Subject: [PATCH 2/2] Fix yapf

---
 .../workloads/imagenet_vit/imagenet_jax/workload.py      | 3 +--
 algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 8 ++++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
index a9ac76ea2..93f603bca 100644
--- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -21,8 +21,7 @@
 # Make sure we inherit from the ViT base workload first.
 class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
 
-  def initialized(self,
-                  key: spec.RandomState,
+  def initialized(self, key: spec.RandomState,
                   model: nn.Module) -> spec.ModelInitState:
     input_shape = (1, 224, 224, 3)
     params_rng, dropout_rng = jax.random.split(key)

diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
index 1b6d6449e..b10d4056d 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -235,10 +235,10 @@ def init_model_fn(
     eval_config = replace(model_config, deterministic=True)
     self._eval_model = models.Transformer(eval_config)
     params_rng, dropout_rng = jax.random.split(rng)
-    initial_variables = jax.jit(self._eval_model.init)(
-        {'params': params_rng, 'dropout': dropout_rng},
-        jnp.ones(input_shape, jnp.float32),
-        jnp.ones(target_shape, jnp.float32))
+    initial_variables = jax.jit(
+        self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
+                               jnp.ones(input_shape, jnp.float32),
+                               jnp.ones(target_shape, jnp.float32))
 
     initial_params = initial_variables['params']
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
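
Reviewer note (aside, not part of the patches above): flax's Module.init
accepts a dict of named PRNG keys, and a module that applies nn.Dropout with
deterministic=False draws from the 'dropout' stream, so initializing with
only a 'params' key fails once dropout is active at init time. Below is a
minimal sketch of the pattern these patches adopt, assuming flax.linen;
TinyModel and every other name in it is hypothetical, not from the
repository.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn


    class TinyModel(nn.Module):
      """Toy module that applies dropout, like the ViT/Transformer models."""

      @nn.compact
      def __call__(self, x):
        x = nn.Dense(8)(x)
        # Draws from the 'dropout' RNG stream because it is not deterministic.
        x = nn.Dropout(rate=0.1, deterministic=False)(x)
        return x


    rng = jax.random.PRNGKey(0)
    # Split the incoming key so parameter init and dropout masks come from
    # independent streams, mirroring the change in PATCH 1/2.
    params_rng, dropout_rng = jax.random.split(rng)
    variables = jax.jit(TinyModel().init)(
        {'params': params_rng, 'dropout': dropout_rng}, jnp.ones((1, 4)))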