From d6e39b15e92b708b886c008110f99859456e17e1 Mon Sep 17 00:00:00 2001
From: init-22
Date: Sun, 21 Sep 2025 08:41:31 +0000
Subject: [PATCH 1/5] cifar workload error resolution

---
 .../cifar/cifar_jax/input_pipeline.py      |  1 -
 .../self_tuning/jax_nadamw_full_budget.py  | 19 ++++++++-----------
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
index 7fbc95bc6..9c83a9b06 100644
--- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
+++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
@@ -186,5 +186,4 @@ def create_input_iter(
     ),
     ds,
   )
-  it = jax_utils.prefetch_to_device(it, 2)
   return it
diff --git a/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
index ee424d3b7..890267c64 100644
--- a/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
+++ b/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
@@ -212,7 +212,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
   )
 
   optimizer_state = opt_init_fn(params_zeros_like)
-  return jax_utils.replicate(optimizer_state), opt_update_fn
+  return optimizer_state, opt_update_fn
 
 
 def train_step(
@@ -304,16 +304,11 @@ def update_params(
     grad_clip = hyperparameters['grad_clip']
   else:
     grad_clip = None
-  dropout_rate = hyperparameters.dropout_rate
+  dropout_rate = hyperparameters['dropout_rate']
 
   # Create shardings for each argument
-  mesh = jax.sharding.Mesh(jax.devices(), ('batch'))
-  replicated = jax_sharding_utils.get_replicate_sharding(
-    mesh
-  )  # No partitioning
-  sharded = jax_sharding_utils.get_batch_sharding(
-    mesh
-  )  # Partition along batch dimension
+  replicated = jax_sharding_utils.get_replicate_sharding()  # No partitioning
+  sharded = jax_sharding_utils.get_batch_dim_sharding()  # Partition along batch dimension
 
   # Create the sharding rules for each argument
   arg_shardings = (
@@ -362,8 +357,8 @@ def update_params(
   if global_step % 100 == 0 and workload.metrics_logger is not None:
     workload.metrics_logger.append_scalar_metrics(
       {
-        'loss': loss[0],
-        'grad_norm': grad_norm[0],
+        'loss': loss,
+        'grad_norm': grad_norm,
       },
       global_step,
     )
@@ -417,6 +412,8 @@ def get_batch_size(workload_name):
     return 128
   elif workload_name == 'mnist':
     return 16
+  elif workload_name == 'cifar':
+    return 128
   else:
     raise ValueError(f'Unsupported workload name: {workload_name}.')
 

From 82b0957d776ce35b261273704d194e85b264b404 Mon Sep 17 00:00:00 2001
From: init-22
Date: Sun, 21 Sep 2025 08:49:37 +0000
Subject: [PATCH 2/5] ruff formatting issues

---
 algorithms/baselines/self_tuning/jax_nadamw_full_budget.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
index 890267c64..d093f6a56 100644
--- a/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
+++ b/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
@@ -308,7 +308,9 @@ def update_params(
 
   # Create shardings for each argument
   replicated = jax_sharding_utils.get_replicate_sharding()  # No partitioning
-  sharded = jax_sharding_utils.get_batch_dim_sharding()  # Partition along batch dimension
+  sharded = (
+    jax_sharding_utils.get_batch_dim_sharding()
+  )  # Partition along batch dimension
 
   # Create the sharding rules for each argument
   arg_shardings = (

From 8b48977fdebacb08a558c777341053c652e3c2f0 Mon Sep 17 00:00:00 2001
From: init-22
Date: Sun, 21 Sep 2025 08:52:54 +0000
Subject: [PATCH 3/5] ruff linting issues

---
 algoperf/workloads/cifar/cifar_jax/input_pipeline.py       | 1 -
 algorithms/baselines/self_tuning/jax_nadamw_full_budget.py | 1 -
 2 files changed, 2 deletions(-)

diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
index 9c83a9b06..307e9e705 100644
--- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
+++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
@@ -11,7 +11,6 @@
 import jax
 import tensorflow as tf
 import tensorflow_datasets as tfds
-from flax import jax_utils
 
 from algoperf import spec
 from algoperf.data_utils import shard_and_maybe_pad_np
diff --git a/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
index d093f6a56..7c41d4377 100644
--- a/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
+++ b/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
@@ -16,7 +16,6 @@
 import jax
 import jax.numpy as jnp
 import optax
-from flax import jax_utils
 
 from algoperf import jax_sharding_utils, spec
 

From 5f6f8397d89fd6eebe3d66850c6291c3c40def01 Mon Sep 17 00:00:00 2001
From: init-22
Date: Thu, 25 Sep 2025 16:24:02 +0000
Subject: [PATCH 4/5] updating the code to remove older way of tree map

---
 algoperf/workloads/cifar/cifar_jax/workload.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py
index defc30121..9e2be6a14 100644
--- a/algoperf/workloads/cifar/cifar_jax/workload.py
+++ b/algoperf/workloads/cifar/cifar_jax/workload.py
@@ -220,4 +220,4 @@ def _normalize_eval_metrics(
     self, num_examples: int, total_metrics: Dict[str, Any]
   ) -> Dict[str, float]:
     """Normalize eval metrics."""
-    return jax.tree_map(lambda x: x / num_examples, total_metrics)
+    return jax.tree.map(lambda x: x / num_examples, total_metrics)

From fe54a453e822af8437a7d6c8413689e3e021f257 Mon Sep 17 00:00:00 2001
From: init-22
Date: Fri, 26 Sep 2025 04:53:39 +0000
Subject: [PATCH 5/5] documentation changes for docker env setup

---
 docs/GETTING_STARTED.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
index 0cc286099..f4318e319 100644
--- a/docs/GETTING_STARTED.md
+++ b/docs/GETTING_STARTED.md
@@ -123,7 +123,7 @@ We recommend using a Docker container to ensure a similar environment to our sco
   docker build -t <docker_image_name> . --build-arg framework=<framework>
   ```
 
-  The `framework` flag can be either `pytorch`, `jax` or `both`. Specifying the framework will install the framework specific dependencies.
+  The `framework` flag can be either `pytorch` or `jax`. Specifying the framework will install the framework specific dependencies.
   The `docker_image_name` is arbitrary.
 
 #### Running Docker Container (Interactive)
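
Note (editor's sketch, not part of the patch series): the changes above retire the pmap-era `flax.jax_utils` helpers (`replicate`, `prefetch_to_device`) in favour of `jit`-based sharding via `algoperf.jax_sharding_utils`. The example below uses plain `jax.sharding` primitives, since the internals of `get_replicate_sharding()` and `get_batch_dim_sharding()` are not shown in the diff, to illustrate the general pattern the patches adopt: parameters and optimizer state stay replicated, while the batch is partitioned along its leading dimension. The `sgd_step` function and all shapes are hypothetical.

```python
# Illustrative sketch only: jit + sharding pattern assumed to replace
# flax.jax_utils.replicate / prefetch_to_device in the patches above.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional device mesh over all local devices.
mesh = Mesh(jax.devices(), axis_names=('batch',))
replicated = NamedSharding(mesh, P())  # same copy on every device
batch_sharded = NamedSharding(mesh, P('batch'))  # split along the leading dim


@jax.jit
def sgd_step(params, batch):
  """Toy stand-in for the real NAdamW train_step."""

  def loss_fn(p):
    preds = batch['inputs'] @ p
    return jnp.mean((preds - batch['targets']) ** 2)

  grads = jax.grad(loss_fn)(params)
  return params - 0.01 * grads


# Parameters replicated; batch (leading dim divisible by the device count)
# partitioned across devices, mirroring the arg_shardings tuple in the patch.
params = jax.device_put(jnp.zeros((8, 1)), replicated)
batch = {
  'inputs': jax.device_put(jnp.ones((16, 8)), batch_sharded),
  'targets': jax.device_put(jnp.zeros((16, 1)), batch_sharded),
}
params = sgd_step(params, batch)
```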