
Cifar dataloader does not work #889

@davidtweedle

Description

The CIFAR dataloader no longer works with JAX algorithms that use jax.jit. I have not tested whether the PyTorch algorithms still work with CIFAR.

When running the jax_nadamw_full_budget.py optimizer on the cifar workload, the run fails with:

ValueError: len(shards) = 128 must equal len(devices) = 8.

Here is the relevant log:

I0918 18:15:28.313870 140360727609984 submission_runner.py:359] Starting training loop.
Traceback (most recent call last):
File "/algorithmic-efficiency/submission_runner.py", line 869, in
app.run(main)
File "/usr/local/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/usr/local/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/algorithmic-efficiency/submission_runner.py", line 834, in main
score = score_submission_on_workload(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/algorithmic-efficiency/submission_runner.py", line 747, in score_submission_on_workload
score, _ = train_once(
^^^^^^^^^^^
File "/algorithmic-efficiency/submission_runner.py", line 375, in train_once
batch = data_selection(
^^^^^^^^^^^^^^^
File "/algorithmic-efficiency/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py", line 446, in data_selection
batch = next(input_queue)
^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/flax/jax_utils.py", line 147, in prefetch_to_device
enqueue(size) # Fill up the buffer.
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/flax/jax_utils.py", line 145, in enqueue
queue.append(jax.tree_util.tree_map(_prefetch, data))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/jax/_src/tree_util.py", line 361, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/jax/_src/tree_util.py", line 361, in
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^
File "/usr/local/lib/python3.11/site-packages/flax/jax_utils.py", line 141, in _prefetch
return jax.device_put_sharded(list(xs), devices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/jax/_src/api.py", line 2636, in device_put_sharded
raise ValueError(f"len(shards) = {len(shards)} must equal "
ValueError: len(shards) = 128 must equal len(devices) = 8.
2025-09-18 18:15:29.477815: W tensorflow/core/kernels/data/cache_dataset_ops.cc:916] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat(). You should use dataset.take(k).cache().repeat() instead.
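
For context on the failure: flax.jax_utils.prefetch_to_device ends in jax.device_put_sharded(list(xs), devices), which requires exactly one shard per local device. If a batch arrives with a flat leading dimension of 128 instead of a leading device axis, list(xs) has 128 entries for 8 devices. A minimal sketch of the mismatch (shapes illustrative, not the repository's code):

import jax
import jax.numpy as jnp

devices = jax.local_devices()          # 8 devices in the failing run
batch = jnp.zeros((128, 32, 32, 3))    # flat CIFAR batch, no leading device axis

# list(batch) has 128 entries (one per example), so this reproduces
# "ValueError: len(shards) = 128 must equal len(devices) = 8.":
#   jax.device_put_sharded(list(batch), devices)

# With a leading device axis, list(...) has len(devices) entries and it works:
per_device = batch.reshape((len(devices), -1) + batch.shape[1:])
arrays = jax.device_put_sharded(list(per_device), devices)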

Steps to Reproduce

  1. In algorithms/baselines/self_tuning/jax_nadamw_full_budget.py, add the following two lines to the get_batch_size function (see the sketch after these steps):

elif workload_name == 'cifar':
  return 128

  2. Then run the cifar workload in Docker:
    python submission_runner.py \
      --framework=jax \
      --workload=cifar \
      --experiment_dir=/experiment_runs \
      --experiment_name=jax_debug_cifar \
      --data_dir=/data \
      --tuning_ruleset=self \
      --submission_path=algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
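
For reference, the modified function would look roughly like this. The surrounding branches are illustrative, not the repository's exact code; only the cifar branch is the addition from step 1:

def get_batch_size(workload_name):
  # Sketch only; other workload branches elided/approximated.
  if workload_name == 'mnist':
    return 1024
  elif workload_name == 'cifar':
    return 128  # the two added lines from step 1
  else:
    raise ValueError(f'Unsupported workload name: {workload_name}.')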

Source or Possible Fix

I think cifar is not an officially supported workload, but it is useful for debugging, so if it is not too much trouble we should fix this.
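
One possible direction (a sketch assuming the CIFAR pipeline yields flat per-host batches, not a tested patch) is to reshape each batch to a leading local-device axis before it reaches flax.jax_utils.prefetch_to_device, since that is the layout device_put_sharded expects:

import jax

def shard_batch(batch):
  # Reshape each leaf from (global_batch, ...) to (num_devices, per_device_batch, ...).
  n = jax.local_device_count()
  return jax.tree_util.tree_map(
      lambda x: x.reshape((n, -1) + x.shape[1:]), batch)

# Hypothetical wiring: wrap the CIFAR batch iterator before prefetching, e.g.
# input_queue = jax_utils.prefetch_to_device(map(shard_batch, batch_iter), size=2)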
