2 changes: 0 additions & 2 deletions algoperf/workloads/cifar/cifar_jax/input_pipeline.py
@@ -11,7 +11,6 @@
 import jax
 import tensorflow as tf
 import tensorflow_datasets as tfds
-from flax import jax_utils
 
 from algoperf import spec
 from algoperf.data_utils import shard_and_maybe_pad_np
@@ -186,5 +185,4 @@ def create_input_iter(
     ),
     ds,
   )
-  it = jax_utils.prefetch_to_device(it, 2)
   return it
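Note: with `flax.jax_utils.prefetch_to_device` removed, the iterator is returned as-is and batches have to reach the devices some other way. A minimal sketch of one option using `jax.device_put` over a data-parallel mesh (illustrative only; the axis name `'batch'` and the helper `shard_batches` are assumptions, not this pipeline's code):

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional data-parallel mesh over all local devices.
mesh = Mesh(jax.devices(), ('batch',))
batch_sharding = NamedSharding(mesh, PartitionSpec('batch'))

def shard_batches(iterator):
  """Yield batches placed on devices, split along the leading dimension."""
  for batch in iterator:
    yield jax.tree.map(lambda x: jax.device_put(x, batch_sharding), batch)
```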
2 changes: 1 addition & 1 deletion algoperf/workloads/cifar/cifar_jax/workload.py
@@ -220,4 +220,4 @@ def _normalize_eval_metrics(
     self, num_examples: int, total_metrics: Dict[str, Any]
   ) -> Dict[str, float]:
     """Normalize eval metrics."""
-    return jax.tree_map(lambda x: x / num_examples, total_metrics)
+    return jax.tree.map(lambda x: x / num_examples, total_metrics)
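For reference, `jax.tree_map` is deprecated in recent JAX releases in favor of `jax.tree.map`; both apply a function to every leaf of a pytree. A tiny self-contained example with made-up values:

```python
import jax

total_metrics = {'loss': 12.0, 'accuracy': 30.0}  # hypothetical summed metrics
num_examples = 60
normalized = jax.tree.map(lambda x: x / num_examples, total_metrics)
print(normalized)  # {'loss': 0.2, 'accuracy': 0.5}
```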
20 changes: 9 additions & 11 deletions algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
@@ -16,7 +16,6 @@
 import jax
 import jax.numpy as jnp
 import optax
-from flax import jax_utils
 
 from algoperf import jax_sharding_utils, spec
 
@@ -212,7 +211,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
   )
   optimizer_state = opt_init_fn(params_zeros_like)
 
-  return jax_utils.replicate(optimizer_state), opt_update_fn
+  return optimizer_state, opt_update_fn
 
 
 def train_step(
@@ -304,15 +303,12 @@ def update_params(
     grad_clip = hyperparameters['grad_clip']
   else:
     grad_clip = None
-  dropout_rate = hyperparameters.dropout_rate
+  dropout_rate = hyperparameters['dropout_rate']
 
   # Create shardings for each argument
-  mesh = jax.sharding.Mesh(jax.devices(), ('batch'))
-  replicated = jax_sharding_utils.get_replicate_sharding(
-    mesh
-  )  # No partitioning
-  sharded = jax_sharding_utils.get_batch_sharding(
-    mesh
+  replicated = jax_sharding_utils.get_replicate_sharding()  # No partitioning
+  sharded = (
+    jax_sharding_utils.get_batch_dim_sharding()
   )  # Partition along batch dimension
 
   # Create the sharding rules for each argument
@@ -362,8 +358,8 @@ def update_params(
   if global_step % 100 == 0 and workload.metrics_logger is not None:
     workload.metrics_logger.append_scalar_metrics(
       {
-        'loss': loss[0],
-        'grad_norm': grad_norm[0],
+        'loss': loss,
+        'grad_norm': grad_norm,
       },
       global_step,
     )
@@ -417,6 +413,8 @@ def get_batch_size(workload_name):
     return 128
   elif workload_name == 'mnist':
     return 16
+  elif workload_name == 'cifar':
+    return 128
   else:
     raise ValueError(f'Unsupported workload name: {workload_name}.')
 
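The `get_replicate_sharding` and `get_batch_dim_sharding` helpers above come from the repo's `jax_sharding_utils`; their definitions are not part of this diff. A rough standalone sketch of the same idea with plain `jax.sharding` types — parameters replicated, batch split across devices — using a toy `sgd_step` that is assumed here, not the submission's actual update:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ('batch',))
replicated = NamedSharding(mesh, PartitionSpec())            # no partitioning
batch_sharded = NamedSharding(mesh, PartitionSpec('batch'))  # split along dim 0

def sgd_step(params, batch):
  def loss_fn(p):
    preds = batch['x'] @ p
    return jnp.mean((preds - batch['y']) ** 2)
  return params - 0.1 * jax.grad(loss_fn)(params)

# Sharding rules per argument: parameters replicated, batch sharded;
# batch_sharded acts as a pytree prefix for the batch dict.
jitted_step = jax.jit(
  sgd_step,
  in_shardings=(replicated, batch_sharded),
  out_shardings=replicated,
)

params = jnp.zeros((4,))
batch = {'x': jnp.ones((8, 4)), 'y': jnp.zeros((8,))}
params = jitted_step(params, batch)
```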
2 changes: 1 addition & 1 deletion docs/GETTING_STARTED.md
@@ -123,7 +123,7 @@ We recommend using a Docker container to ensure a similar environment to our sco
 docker build -t <docker_image_name> . --build-arg framework=<framework>
 ```
 
-The `framework` flag can be either `pytorch`, `jax` or `both`. Specifying the framework will install the framework specific dependencies.
+The `framework` flag can be either `pytorch` or `jax`. Specifying the framework will install the framework specific dependencies.
 
 #### Running Docker Container (Interactive)