Merged
65 commits
234eb5a
add hparam index flags to submission_runner
priyakasimbeg Sep 13, 2023
7a0f0f9
fix
priyakasimbeg Sep 13, 2023
fe05acd
clarify
priyakasimbeg Sep 13, 2023
77f44fb
reformatting
priyakasimbeg Sep 13, 2023
0100e44
fix
priyakasimbeg Sep 13, 2023
3c5882b
fix
priyakasimbeg Sep 13, 2023
301de4a
Switch to absolute paths in Dockerfile
runame Sep 18, 2023
cc6d0db
Only set DEBIAN_FRONTEND where necessary
runame Sep 18, 2023
e7b854c
Add instructions for running Singularity/Apptainer container to README
runame Sep 18, 2023
5c48b2b
Merge branch 'dev' into singularity
runame Sep 21, 2023
7970388
update tuning search spaces for speech_workloads
priyakasimbeg Sep 25, 2023
862f500
add tabulate for deepspeech debugging
priyakasimbeg Sep 26, 2023
1816234
update target setting algo for conformer to adamw
priyakasimbeg Sep 26, 2023
6b9655f
tabulate typo fix
priyakasimbeg Sep 26, 2023
28db392
typo
priyakasimbeg Sep 26, 2023
fab70f9
add lr logging target_setting nadamw
priyakasimbeg Sep 27, 2023
fc07904
Merge branch 'speech_targets' of github.com:mlcommons/algorithmic-eff…
priyakasimbeg Sep 27, 2023
1ea08b0
reverse padding fixes
priyakasimbeg Sep 28, 2023
20376e0
Merge pull request #516 from runame/singularity
priyakasimbeg Sep 28, 2023
c3cf664
log_step_hint
priyakasimbeg Sep 28, 2023
784230f
log step hint in submission runner
priyakasimbeg Sep 28, 2023
e83937b
add logging
priyakasimbeg Sep 28, 2023
0606b01
lgoging
priyakasimbeg Sep 28, 2023
b23aae4
logging
priyakasimbeg Sep 28, 2023
27cb037
fix logging
priyakasimbeg Sep 28, 2023
a0c8624
more logging;
priyakasimbeg Sep 28, 2023
d977605
remove inheritance
priyakasimbeg Sep 28, 2023
9749f00
temp change to conformer
priyakasimbeg Sep 28, 2023
869bdad
fix step hint
priyakasimbeg Sep 29, 2023
0a60f89
deepspeech inheritance fix
priyakasimbeg Sep 29, 2023
54f3eb9
Merge branch 'dev' into speech_targets
priyakasimbeg Sep 29, 2023
06df206
remove lr schedule logging
priyakasimbeg Sep 29, 2023
aa6c538
debugging statements
priyakasimbeg Sep 29, 2023
3afcc9f
fix
priyakasimbeg Sep 29, 2023
3c72358
fix
priyakasimbeg Sep 29, 2023
bc55c34
fix imports
priyakasimbeg Sep 29, 2023
1323b57
import fix
priyakasimbeg Sep 29, 2023
0c67481
fix
priyakasimbeg Sep 30, 2023
70cfc72
merge fix
priyakasimbeg Sep 30, 2023
a76083b
fix
priyakasimbeg Sep 30, 2023
88e9d9e
fix
priyakasimbeg Sep 30, 2023
1e44349
copy jax and pytorch loss_fn, model_fn and _eval_model_on_split to de…
priyakasimbeg Sep 30, 2023
347d1df
fix imports
priyakasimbeg Sep 30, 2023
3351f73
Merge pull request #509 from mlcommons/hparam_trial_indices
priyakasimbeg Sep 30, 2023
f7789db
fix block
priyakasimbeg Sep 30, 2023
711a8fe
fix and formatting
priyakasimbeg Oct 2, 2023
345644b
import fix
priyakasimbeg Oct 2, 2023
36a0a73
test
priyakasimbeg Oct 2, 2023
33e1896
add import fix
priyakasimbeg Oct 2, 2023
2b40f5f
fix
priyakasimbeg Oct 2, 2023
29c123e
fix
priyakasimbeg Oct 2, 2023
0229fd8
add init for deepspeech workloads
priyakasimbeg Oct 2, 2023
9485422
missing import
priyakasimbeg Oct 2, 2023
8f85841
clean up deepspeech refactoring
priyakasimbeg Oct 2, 2023
5ea7b13
fix lint
priyakasimbeg Oct 2, 2023
fd1e49c
add targets to deepspeech
priyakasimbeg Oct 2, 2023
4274538
lint
priyakasimbeg Oct 2, 2023
09eed7e
formatting
priyakasimbeg Oct 3, 2023
a664134
Change padding for Deepspeech LSTM layer
priyakasimbeg Oct 5, 2023
8c642dc
Merge pull request #526 from mlcommons/speech_targets_clean
priyakasimbeg Oct 5, 2023
2cb31a2
Adjust runtime budget for self-tuning ruleset and check that tuning s…
runame Oct 6, 2023
5407ab1
Remove test target from scoring
runame Oct 6, 2023
b84efff
Merge pull request #535 from runame/scoring
priyakasimbeg Oct 6, 2023
cadcef0
Merge pull request #534 from runame/self-tuning-budget
priyakasimbeg Oct 6, 2023
4131232
Merge pull request #533 from mlcommons/deepspeech-padding-change
priyakasimbeg Oct 7, 2023
24 changes: 22 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


## Installation
You can install this package and dependences in a [python virtual environment](#virtual-environment) or use a [Docker container](#install-in-docker) (recommended).
You can install this package and dependences in a [python virtual environment](#virtual-environment) or use a [Docker/Singularity/Apptainer container](#install-in-docker) (recommended).

*TL;DR to install the Jax version for GPU run:*

Expand Down Expand Up @@ -89,7 +89,8 @@ pip3 install -e '.[full]'
</details>

## Docker
We recommend using a Docker container to ensure a similar environment to our scoring and testing environments.
We recommend using a Docker container to ensure a similar environment to our scoring and testing environments.
Alternatively, a Singularity/Apptainer container can also be used (see instructions below).


**Prerequisites for NVIDIA GPU set up**: You may have to install the NVIDIA Container Toolkit so that the containers can locate the NVIDIA drivers and GPUs.
Expand Down Expand Up @@ -133,6 +134,25 @@ To use the Docker container as an interactive virtual environment, you can run a
### Running Docker Container (End-to-end)
To run a submission end-to-end in a containerized environment see [Getting Started Document](./getting_started.md#run-your-submission-in-a-docker-container).

### Using Singularity/Apptainer instead of Docker
Since many compute clusters don't allow the usage of Docker due to security concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide instructions on how to build an Apptainer container based on the Dockerfile provided here.

To convert the Dockerfile into an Apptainer definition file, we will use [spython](https://github.com/singularityhub/singularity-cli):
```bash
pip3 install spython
cd algorithmic-efficiency/docker
spython recipe Dockerfile &> Singularity.def
```
Now we can build the Apptainer image by running
```bash
singularity build --fakeroot <singularity_image_name>.sif Singularity.def
```
To start a shell session with GPU support (by using the `--nv` flag), we can run
```bash
singularity shell --nv <singularity_image_name>.sif
```
Similarly to Docker, Apptainer allows you to bind specific paths on the host system and the container by specifying the `--bind` flag, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html).

# Getting Started
For instructions on developing and scoring your own algorithm in the benchmark see [Getting Started Document](./getting_started.md).
## Running a workload
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _build_input_queue(
}

padded_batch = data_utils.shard_and_maybe_pad_np(
numpy_batch, padding_value=1.0, global_batch_size=global_batch_size)
numpy_batch, padding_value=1.0)
yield padded_batch

# Does NOT apply regularization, which is left to the submitter to do in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,20 @@ def init_model_fn(

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'

@property
def validation_target_value(self) -> float:
return 0.118232

@property
def test_target_value(self) -> float:
return 0.073397

@property
def step_hint(self) -> int:
"""Max num steps the baseline algo was given to reach the target."""
return 48_000

@property
def max_allowed_runtime_sec(self) -> int:
return 55_506 # ~15.4 hours
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,20 @@ def init_model_fn(

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['lin.weight', 'lin.bias']

@property
def validation_target_value(self) -> float:
return 0.118232

@property
def test_target_value(self) -> float:
return 0.073397

@property
def step_hint(self) -> int:
"""Max num steps the baseline algo was given to reach the target."""
return 48_000

@property
def max_allowed_runtime_sec(self) -> int:
return 55_506 # ~15.4 hours
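As a quick sanity check on the `# ~15.4 hours` comment, the following is a minimal sketch that mirrors the four properties added above. `ToyWorkload` is a hypothetical stand-in, not a class from the repository, and the per-step figure is only an average implied by the two numbers, not a measured value.

```python
class ToyWorkload:
  """Hypothetical stand-in mirroring the properties added in this diff."""

  @property
  def validation_target_value(self) -> float:
    return 0.118232

  @property
  def test_target_value(self) -> float:
    return 0.073397

  @property
  def step_hint(self) -> int:
    return 48_000

  @property
  def max_allowed_runtime_sec(self) -> int:
    return 55_506


workload = ToyWorkload()

# Convert the runtime budget from seconds to hours.
runtime_hours = workload.max_allowed_runtime_sec / 3600
print(f'{runtime_hours:.1f} hours')  # 15.4 hours

# Average wall-clock seconds per step if all step_hint steps fit the budget.
sec_per_step = workload.max_allowed_runtime_sec / workload.step_hint
print(f'{sec_per_step:.3f} s/step')
```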

This file was deleted.

23 changes: 11 additions & 12 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

# To build Docker image
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
ARG DEBIAN_FRONTEND=noninteractive

# Installing machine packages
RUN echo "Setting up machine"
RUN apt-get update
RUN apt-get install -y curl tar
RUN apt-get install -y git python3 pip wget ffmpeg
RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg
RUN apt-get install libtcmalloc-minimal4
RUN apt-get install unzip
RUN apt-get install pigz
Expand All @@ -34,38 +33,38 @@ RUN echo "Setting up algorithmic_efficiency repo"
ARG branch="main"
ARG framework="both"
ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git
RUN git clone $git_url && cd algorithmic-efficiency
RUN cd algorithmic-efficiency && git checkout $branch
RUN git clone $git_url && cd /algorithmic-efficiency
RUN cd /algorithmic-efficiency && git checkout $branch

RUN cd algorithmic-efficiency && pip install -e '.[full]'
RUN cd /algorithmic-efficiency && pip install -e '.[full]'

RUN if [ "$framework" = "jax" ] ; then \
echo "Installing Jax GPU" \
&& cd algorithmic-efficiency \
&& cd /algorithmic-efficiency \
&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
&& pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
elif [ "$framework" = "pytorch" ] ; then \
echo "Installing Pytorch GPU" \
&& cd algorithmic-efficiency \
&& cd /algorithmic-efficiency \
&& pip install -e '.[jax_cpu]' \
&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
elif [ "$framework" = "both" ] ; then \
echo "Installing Jax GPU and Pytorch GPU" \
&& cd algorithmic-efficiency \
&& cd /algorithmic-efficiency \
&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
else \
echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \
&& exit 1 ; \
fi

RUN cd algorithmic-efficiency && pip install -e '.[wandb]'
RUN cd /algorithmic-efficiency && pip install -e '.[wandb]'

RUN cd algorithmic-efficiency && git fetch origin
RUN cd algorithmic-efficiency && git pull
RUN cd /algorithmic-efficiency && git fetch origin
RUN cd /algorithmic-efficiency && git pull

# Todo: remove this, this is temporary for developing
COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
RUN chmod a+x /algorithmic-efficiency/docker/scripts/startup.sh

ENTRYPOINT ["bash", "algorithmic-efficiency/docker/scripts/startup.sh"]
ENTRYPOINT ["bash", "/algorithmic-efficiency/docker/scripts/startup.sh"]
4 changes: 2 additions & 2 deletions reference_algorithms/target_setting_algorithms/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ python3 submission_runner.py \
--experiment_dir=$ROOT_DIR \
--experiment_name=target_setting \
--workload=librispeech_conformer \
--submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py \
--submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \
--tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
```
```bash
Expand All @@ -123,7 +123,7 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc
--experiment_dir=$ROOT_DIR \
--experiment_name=target_setting \
--workload=librispeech_conformer \
--submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py \
--submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py \
--tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
```

Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
{
"learning_rate": {
"feasible_points": [
0.001308209823469072
0.002106913873888147
]
},
"beta1": {
"feasible_points": [
0.9731333693827139
0.8231189937738506
]
},
"beta2": {
"feasible_points": [
0.9981232922116359
0.8774571227688758
]
},
"warmup_steps": {
"feasible_points": [
9999
1199
]
},
"weight_decay": {
"feasible_points": [
0.16375311233774334
0.27590534177690645
]
}
}
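To illustrate how a search space of this shape can be read, here is a minimal sketch that expands the `feasible_points` lists into concrete hyperparameter settings. It assumes the second value in each pair above is the added line, and `grid_points` is a hypothetical helper; the repository's own search-space loader is not shown in this diff and may work differently.

```python
import itertools
import json

# Values copied from the diff above, assuming the second value of each
# pair is the newly added one.
SEARCH_SPACE_JSON = """
{
  "learning_rate": {"feasible_points": [0.002106913873888147]},
  "beta1": {"feasible_points": [0.8231189937738506]},
  "beta2": {"feasible_points": [0.8774571227688758]},
  "warmup_steps": {"feasible_points": [1199]},
  "weight_decay": {"feasible_points": [0.27590534177690645]}
}
"""


def grid_points(search_space):
  """Hypothetical helper: Cartesian product of all feasible points."""
  names = list(search_space)
  value_lists = [search_space[name]['feasible_points'] for name in names]
  return [
      dict(zip(names, values)) for values in itertools.product(*value_lists)
  ]


points = grid_points(json.loads(SEARCH_SPACE_JSON))
print(len(points))  # 1
print(points[0]['warmup_steps'])  # 1199
```

With a single feasible point per hyperparameter, the grid collapses to exactly one setting, which is what a fixed target-setting configuration needs.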
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
},
"warmup_steps": {
"feasible_points": [
1200
720
]
},
"weight_decay": {
Expand Down
21 changes: 5 additions & 16 deletions scoring/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@


def generate_eval_cols(metrics):
splits = ['train', 'validation', 'test']
splits = ['train', 'validation']
return [f'{split}/{col}' for split, col in itertools.product(splits, metrics)]


Expand Down Expand Up @@ -108,15 +108,13 @@ def get_index_that_reaches_best(workload_df, metric_col):

def get_index_that_reaches_target(workload_df,
validation_metric,
test_metric,
validation_target,
test_target):
validation_target):
"""Get the eval index in which a workload reaches the target metric_col.

Args:
workload_df: A subset of a submission's trials DataFrame that
includes only the trials in a single workload.
metric_col: Name of array column in workload_df (e.g., `validation/l1_loss`).
metric_col: Name of array column in workload_df (e.g. `validation/l1_loss`).
target: Target value for metric_col.

Returns:
Expand All @@ -125,20 +123,13 @@ def get_index_that_reaches_target(workload_df,
"""
is_minimized = check_if_minimized(validation_metric)
validation_series = workload_df[validation_metric]
test_series = workload_df[test_metric]

validation_series = validation_series[validation_series != np.nan]
validation_series = validation_series[test_series != np.nan]
test_series = test_series[validation_series != np.nan]
test_series = test_series[test_series != np.nan]

op = operator.le if is_minimized else operator.ge
validation_target_reached = validation_series.apply(
lambda x: op(x, validation_target))
test_target_reached = test_series.apply(lambda x: op(x, test_target))

target_reached = pd.Series(validation_target_reached[0]
& test_target_reached[0])
target_reached = pd.Series(validation_target_reached[0])
# Remove trials that never reach the target
target_reached = target_reached[target_reached.apply(np.any)]
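Two details in this hunk are easy to misread, so here is a minimal sketch using plain NumPy arrays. First, NaN never compares equal to anything, so a `!= np.nan` mask keeps every element and filters nothing. Second, the target check reduces to finding the first eval whose metric satisfies `op(x, target)`. `first_index_reaching_target` is a hypothetical helper written for illustration, not the repository's function, and the metric values are made up.

```python
import operator

import numpy as np

# A `!= np.nan` mask is True everywhere, NaN entries included.
metrics = np.array([0.30, np.nan, 0.11])
kept_by_comparison = metrics[metrics != np.nan]
kept_by_isnan = metrics[~np.isnan(metrics)]
print(len(kept_by_comparison), len(kept_by_isnan))  # 3 2


def first_index_reaching_target(values, target, is_minimized=True):
  """Hypothetical helper: index of the first eval meeting the target, or -1."""
  op = operator.le if is_minimized else operator.ge
  reached = op(np.asarray(values, dtype=float), target)
  return int(np.argmax(reached)) if reached.any() else -1


validation_target = 0.118232  # target value from the workload diff above
print(first_index_reaching_target([0.30, 0.15, 0.11], validation_target))  # 2
print(first_index_reaching_target([0.40, 0.35, 0.20], validation_target))  # -1
```

For a pandas Series of floats, `Series.notna()` plays the same role as `~np.isnan(...)` here. Comparisons against NaN evaluate to False, so a NaN eval never counts as reaching the target in this sketch.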

Expand Down Expand Up @@ -188,12 +179,10 @@ def get_times_for_submission(submission,
workload_init_kwargs=workload_init_kwargs)
metric_name = workload_obj.target_metric_name
validation_metric = f'validation/{metric_name}'
test_metric = f'test/{metric_name}'
validation_target = workload_obj.validation_target_value
test_target = workload_obj.test_target_value

trial_idx, time_idx = get_index_that_reaches_target(
group, validation_metric, test_metric, validation_target, test_target)
group, validation_metric, validation_target)
if time_idx > -1:
time_val = group[time_col].loc[trial_idx][time_idx]
else:
Expand Down
29 changes: 25 additions & 4 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import datetime
import gc
import importlib
import itertools
import json
import os
import struct
Expand Down Expand Up @@ -133,6 +134,14 @@
flags.DEFINE_boolean('save_checkpoints',
True,
'Whether or not to checkpoint the model at every eval.')
flags.DEFINE_integer(
'hparam_start_index',
None,
'Start index to slice set of hyperparameters in tuning search space.')
flags.DEFINE_integer(
'hparam_end_index',
None,
'End index to slice set of hyperparameters in tuning search space.')
flags.DEFINE_integer(
'rng_seed',
None,
Expand Down Expand Up @@ -331,9 +340,12 @@ def train_once(

train_state['accumulated_submission_time'] += (
train_step_end_time - train_state['last_step_end_time'])
# Use 3x the runtime budget for the self-tuning ruleset.
max_allowed_runtime_sec = (
workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
else 3 * workload.max_allowed_runtime_sec)
train_state['is_time_remaining'] = (
train_state['accumulated_submission_time'] <
workload.max_allowed_runtime_sec)
train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
# Check if submission is eligible for an untimed eval.
if ((train_step_end_time - train_state['last_eval_time']) >=
workload.eval_period_time_sec or train_state['training_complete']):
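The budget rule in this hunk can be sketched in isolation as follows. `runtime_budget_sec` is a hypothetical helper, and `'self'` is assumed here to be the name of the non-external ruleset; the conditional in the diff only tests for `'external'`, so any other value takes the tripled branch.

```python
def runtime_budget_sec(base_budget_sec: int, tuning_ruleset: str) -> int:
  """Hypothetical helper mirroring the conditional in this hunk."""
  if tuning_ruleset == 'external':
    return base_budget_sec
  # Every other ruleset value gets 3x the base budget.
  return 3 * base_budget_sec


# 55_506 is the max_allowed_runtime_sec added to the workloads earlier
# in this diff.
print(runtime_budget_sec(55_506, 'external'))  # 55506
print(runtime_budget_sec(55_506, 'self'))  # 166518


def is_time_remaining(accumulated_sec: float, budget_sec: int) -> bool:
  # Strict inequality, as in the diff: hitting the budget exactly stops.
  return accumulated_sec < budget_sec


print(is_time_remaining(60_000, runtime_budget_sec(55_506, 'self')))  # True
```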
Expand Down Expand Up @@ -455,6 +467,8 @@ def score_submission_on_workload(workload: spec.Workload,
num_tuning_trials: Optional[int] = None,
log_dir: Optional[str] = None,
save_checkpoints: Optional[bool] = True,
hparam_start_index: Optional[int] = None,
hparam_end_index: Optional[int] = None,
rng_seed: Optional[int] = None):
# Expand paths because '~' may not be recognized
data_dir = os.path.expanduser(data_dir)
Expand Down Expand Up @@ -500,7 +514,9 @@ def score_submission_on_workload(workload: spec.Workload,
json.load(search_space_file), num_tuning_trials)
all_timings = []
all_metrics = []
for hi, hyperparameters in enumerate(tuning_search_space):
tuning_search_space_iter = itertools.islice(
enumerate(tuning_search_space), hparam_start_index, hparam_end_index)
for hi, hyperparameters in tuning_search_space_iter:
# Generate a new seed from hardware sources of randomness for each trial.
if not rng_seed:
rng_seed = struct.unpack('I', os.urandom(4))[0]
Expand Down Expand Up @@ -545,7 +561,7 @@ def score_submission_on_workload(workload: spec.Workload,
all_timings.append(timing)
all_metrics.append(metrics)
score = min(all_timings)
for ti in range(num_tuning_trials):
for ti, _ in tuning_search_space_iter:
logging.info(f'Tuning trial {ti + 1}/{num_tuning_trials}')
logging.info(f'Hyperparameters: {tuning_search_space[ti]}')
logging.info(f'Metrics: {all_metrics[ti]}')
Expand All @@ -554,6 +570,9 @@ def score_submission_on_workload(workload: spec.Workload,
logging.info(f'Total number of evals: {num_evals}')
logging.info('=' * 20)
else:
if tuning_search_space is not None:
raise ValueError(
'Cannot provide a tuning search space when using self tuning.')
if not rng_seed:
rng_seed = struct.unpack('q', os.urandom(8))[0]
rng = prng.PRNGKey(rng_seed)
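A note on the `itertools.islice` pattern used above: slicing `enumerate(...)` keeps the original trial indices, but the result is a single-pass iterator. Once the training loop has consumed it, a second `for` loop over the same object sees no items, so the per-trial logging loop in this diff would print nothing. The sketch below demonstrates this with made-up placeholder data and shows one possible workaround, materializing the slice into a list; it is not the fix adopted by the repository.

```python
import itertools

# Placeholder search space; the real entries are hyperparameter settings.
tuning_search_space = ['hp0', 'hp1', 'hp2', 'hp3', 'hp4']
hparam_start_index, hparam_end_index = 1, 4

trial_iter = itertools.islice(
    enumerate(tuning_search_space), hparam_start_index, hparam_end_index)
first_pass = list(trial_iter)
second_pass = list(trial_iter)  # the iterator is already exhausted
print(first_pass)  # [(1, 'hp1'), (2, 'hp2'), (3, 'hp3')]
print(second_pass)  # []

# Workaround sketch: materialize the slice once and reuse the list.
selected = list(
    itertools.islice(
        enumerate(tuning_search_space), hparam_start_index, hparam_end_index))

all_metrics = []
for hi, hyperparameters in selected:
  all_metrics.append(f'metrics for trial {hi}')

# Results were appended in order, so index them by position in the slice,
# not by the original trial index hi.
report = [(hi, all_metrics[pos]) for pos, (hi, _) in enumerate(selected)]
print(report[0])  # (1, 'metrics for trial 1')
```

Passing `None` for both bounds, the flag defaults in this diff, makes `islice` yield the whole sequence.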
Expand Down Expand Up @@ -621,6 +640,8 @@ def main(_):
num_tuning_trials=FLAGS.num_tuning_trials,
log_dir=logging_dir_path,
save_checkpoints=FLAGS.save_checkpoints,
hparam_start_index=FLAGS.hparam_start_index,
hparam_end_index=FLAGS.hparam_end_index,
rng_seed=FLAGS.rng_seed)
logging.info(f'Final {FLAGS.workload} score: {score}')

Expand Down