diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..b71e42e01 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,5 @@ +# Change log + +## TODO: algorithmic-efficiency 0.1.0 + +First release of AlgoPerf benchmarking code. diff --git a/README.md b/README.md index 1be096c2e..6ffbab6f7 100644 --- a/README.md +++ b/README.md @@ -29,9 +29,12 @@ - [Getting Started](#getting-started) - [Rules](#rules) - [Contributing](#contributing) +- [Diclaimers](#disclaimers) +- [FAQS](#faqs) - [Citing AlgoPerf Benchmark](#citing-algoperf-benchmark) + ## Installation You can install this package and dependences in a [python virtual environment](#virtual-environment) or use a [Docker/Singularity/Apptainer container](#install-in-docker) (recommended). @@ -126,7 +129,15 @@ To use the Docker container as an interactive virtual environment, you can run a \ --keep_container_alive true ``` -2. Open a bash terminal + Note: You may have to use double quotes around `algorithmic-efficiency` [path] in the mounting `-v` flag. If the above command fails try replacing the following line: + ```bash + -v $HOME/algorithmic-efficiency:/algorithmic-efficiency2 \ + ``` + with + ``` + -v $HOME"/algorithmic-efficiency:/algorithmic-efficiency" \ + ``` + - Open a bash terminal ```bash docker exec -it /bin/bash ``` @@ -221,9 +232,60 @@ The rules for the MLCommons Algorithmic Efficency benchmark can be found in the If you are interested in contributing to the work of the working group, feel free to [join the weekly meetings](https://mlcommons.org/en/groups/research-algorithms/), open issues. See our [CONTRIBUTING.md](CONTRIBUTING.md) for MLCommons contributing guidelines and setup and workflow instructions. -# Note on shared data pipelines between JAX and PyTorch +# Disclaimers + +## Shared data pipelines between JAX and PyTorch The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT workloads are using the same TensorFlow input pipelines. Due to differences in how Jax and PyTorch distribute computations across devices, the PyTorch workloads have an additional overhead for these workloads. Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details. While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example. + +# FAQS + +## Setup and Platform + +### My machine only has one GPU. How can I use this repo? +You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes in our reference algorithms `algorithmic-efficiency/baselines` and `algorithmic-efficiency/reference_algorithms` are tuned for a machine with 8 16GB V100 GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' +on the benchmarking hardware, so if you are using fewer +GPUs with higher per GPU memory, please monitor your memory usage +to make make sure it will fit on 8xV100 GPUs with 16GB of VRAM per card. + +### How do I run this on my SLURM cluster? +You may run into issues with `sudo` and `docker` on a SLURM cluster. To run the workloads in a SLURM cluster you can use Apptainer (previously Singularity), see this [section](using-singularity/apptainer-instead-of-docker). +### How can I run this on my AWS/GCP/Azure cloud project? + Depending on your virtual machine, you may have to install the correct GPU drivers and the NVIDIA Docker toolkit. For example, in GCP you will have to do the following. +1. If you don't have a VM instance yet, we recommend creating a +new Compute Instance with the "Deep Learning on Linux" Image in Boot disk options. +2. To install the NVIDIA Docker toolkit, you can use `scripts/cloud-startup.sh` as a startup script for the VM. This will automate the installation of the NVIDIA GPU Drivers and NVIDIA Docker toolkit. + +## Submissions +### Can submission be structured using multiple files? +Yes, your submission can be structured using multiple files. +### Can I install custom dependencies? +You may use custom dependencies as long as they do not conflict with any of the pinned packages in `algorithmic-efficiency/setup.cfg`. +To include your custom dependencies in your submission, please include them in a requirements.txt file. Please refer to the [Software dependencies](https://github.com/mlcommons/algorithmic-efficiency/blob/main/RULES.md#software-dependencies) section of our rules. +### How can I know if my code can be run on benchmarking hardware? +The benchmarking hardware specifications are documented in the [Getting Started Document](./getting_started.md). +We recommend monitoring your submission's memory usage so that it does not exceed the available memory +on the competition hardware. We also recommend to do a dry run using a cloud instance. +### Are we allowed to use our own hardware to self-report the results? +You only have to use the competition hardware for runs that are directly involved in the scoring procedure. This includes all runs for the self-tuning ruleset, but only the runs of the best hyperparameter configuration in each study for the external tuning ruleset. For example, you could use your own (different) hardware to tune your submission and identify the best hyperparameter configuration (in each study) and then only run this configuration (i.e. 5 runs, one for each study) on the competition hardware. + +# Citing AlgoPerf Benchmark +If you use the **AlgoPerf** Benchmark in your work, please consider citing: + +> [George E. Dahl, Frank Schneider, Zachary Nado, et al.
+> **Benchmarking Neural Network Training Algorithms**
+> *arXiv 2306.07179*](http://arxiv.org/abs/2306.07179) + +```bibtex +@misc{dahl2023algoperf, + title={{Benchmarking Neural Network Training Algorithms}}, + author={Dahl, George E. and Schneider, Frank and Nado, Zachary and Agarwal, Naman and Sastry, Chandramouli Shama and Hennig, Philipp and Medapati, Sourabh and Eschenhagen, Runa and Kasimbeg, Priya and Suo, Daniel and Bae, Juhan and Gilmer, Justin and Peirson, Abel L. and Khan, Bilal and Anil, Rohan and Rabbat, Mike and Krishnan, Shankar and Snider, Daniel and Amid, Ehsan and Chen, Kongtao and Maddison, Chris J. and Vasudev, Rakshith and Badura, Michal and Garg, Ankush and Mattson, Peter}, + year={2023}, + eprint={2306.07179}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/fastmri/workload.py b/algorithmic_efficiency/workloads/fastmri/workload.py index ecfa27547..4677dc2bb 100644 --- a/algorithmic_efficiency/workloads/fastmri/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/workload.py @@ -51,7 +51,7 @@ def num_validation_examples(self) -> int: @property def num_test_examples(self) -> int: - return 3581 + return 3548 @property def eval_batch_size(self) -> int: diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 665c3c894..2da7dcfb3 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -71,7 +71,10 @@ def forward(self, x): class Subsample(nn.Module): - def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0): + def __init__(self, + encoder_dim: int = 0, + input_dropout_rate: float = 0.0, + num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim self.input_dropout_rate = input_dropout_rate @@ -81,7 +84,10 @@ def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0): self.conv2 = Conv2dSubsampling( input_channels=encoder_dim, output_channels=encoder_dim) - self.linear = nn.LazyLinear(out_features=self.encoder_dim, bias=True) + self.linear = nn.Linear( + in_features=self.encoder_dim * num_bins // 4, + out_features=self.encoder_dim, + bias=True) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) self.dropout = nn.Dropout(p=self.input_dropout_rate) @@ -123,6 +129,7 @@ def __init__(self, self.kernel = nn.Parameter( torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) self.bias = nn.Parameter(torch.zeros(output_channels)) + self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) def get_same_padding(self, input_shape): in_height, in_width = input_shape[2:] @@ -162,15 +169,11 @@ def forward(self, inputs, paddings): input_length = paddings.shape[1] stride = self.filter_stride[0] pad_len = (input_length + stride - 1) // stride * stride - input_length - padded_paddings = torch.cat([ - paddings[:, None, :], - torch.zeros( - size=(paddings.shape[0], 1, pad_len), device=paddings.device) - ], - dim=2) + padded_paddings = F.pad( + paddings[:, None, :], (0, pad_len), mode='constant', value=0) out_padding = F.conv1d( input=padded_paddings, - weight=torch.ones([1, 1, 1], device=paddings.device), + weight=self.paddings_kernel, stride=self.filter_stride[:1]) out_padding = out_padding.squeeze(dim=1) outputs = outputs * (1 - out_padding[:, None, :, None]) @@ -184,11 +187,15 @@ def __init__(self, config: ConformerConfig): self.config = config self.ln = LayerNorm(dim=config.encoder_dim) - self.linear1 = nn.LazyLinear( + self.linear1 = nn.Linear( + in_features=config.encoder_dim, out_features=config.encoder_dim * config.feed_forward_expansion_factor, bias=True) self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate) - self.linear2 = nn.LazyLinear(out_features=config.encoder_dim, bias=True) + self.linear2 = nn.Linear( + in_features=config.encoder_dim * config.feed_forward_expansion_factor, + out_features=config.encoder_dim, + bias=True) if config.feed_forward_residual_dropout_rate is None: feed_forward_residual_dropout_rate = 0.1 @@ -253,217 +260,32 @@ def forward(self, inputs): return inputs * scale -class MHSAwithQS(nn.MultiheadAttention): - # pylint: disable=locally-disabled, use-a-generator, line-too-long, invalid-name +class MHSAwithQS(nn.Module): + def __init__(self, config: ConformerConfig): - super().__init__( - embed_dim=config.encoder_dim, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout_rate, - bias=True, - batch_first=True) + super().__init__() + self.embed_dim = config.encoder_dim + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout_rate + self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) + self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) - def _scaled_in_proj_weight(self): - # Scale the query projection weight. - qs_input = self.in_proj_weight[:self.embed_dim].view( - self.num_heads, self.embed_dim // self.num_heads, -1).transpose(1, 2) - in_proj_queryW_scaled = self.qs(qs_input).transpose( - 1, 2).view(*self.in_proj_weight[:self.embed_dim].shape) - in_proj_weight = torch.cat( - [in_proj_queryW_scaled, self.in_proj_weight[self.embed_dim:]]) - return in_proj_weight - - def _scaled_in_proj_bias(self): - # Scale the query bias. - in_proj_queryb_scaled = self.qs(self.in_proj_bias[:self.embed_dim].view( - self.num_heads, self.embed_dim // self.num_heads)).view(-1) - in_proj_bias = torch.cat( - [in_proj_queryb_scaled, self.in_proj_bias[self.embed_dim:]]) - return in_proj_bias - - def forward(self, - query, - key, - value, - key_padding_mask=None, - need_weights: bool = True, - attn_mask=None, - average_attn_weights: bool = True): - r""" - Args: - query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` - or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, - :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. - Queries are compared against key-value pairs to produce the output. - See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` - or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, - :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. - See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when - ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source - sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. - See "Attention Is All You Need" for more details. - key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` - to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. - Binary and byte masks are supported. - For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for - the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. - need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. - Default: ``True``. - attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape - :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, - :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be - broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. - Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the - corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the - corresponding position is not allowed to attend. For a float mask, the mask values will be added to - the attention weight. - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across - heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an - effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) - - Outputs: - - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, - :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, - where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the - embedding dimension ``embed_dim``. - - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, - returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or - :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and - :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per - head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. - - .. note:: - `batch_first` argument is ignored for unbatched inputs. - """ - is_batched = query.dim() == 3 - if key_padding_mask is not None: - _kpm_dtype = key_padding_mask.dtype - if _kpm_dtype != torch.bool and not torch.is_floating_point( - key_padding_mask): - raise AssertionError( - "only bool and floating types of key_padding_mask are supported") - why_not_fast_path = '' - if not is_batched: - why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" - elif query is not key or key is not value: - # When lifting this restriction, don't forget to either - # enforce that the dtypes all match or test cases where - # they don't! - why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" - elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" - elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype: - # this case will fail anyway, but at least they'll get a useful error message. - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" - elif self.training: - why_not_fast_path = "training is enabled" - elif not self.batch_first: - why_not_fast_path = "batch_first was not True" - elif self.bias_k is not None: - why_not_fast_path = "self.bias_k was not None" - elif self.bias_v is not None: - why_not_fast_path = "self.bias_v was not None" - elif self.dropout: - why_not_fast_path = f"dropout was {self.dropout}, required zero" - elif self.add_zero_attn: - why_not_fast_path = "add_zero_attn was enabled" - elif not self._qkv_same_embed_dim: - why_not_fast_path = "_qkv_same_embed_dim was not True" - elif attn_mask is not None: - why_not_fast_path = "attn_mask was not None" - elif query.is_nested and key_padding_mask is not None: - why_not_fast_path = "key_padding_mask is not supported with NestedTensor input" - elif self.num_heads % 2 == 1: - why_not_fast_path = "num_heads is odd" - elif torch.is_autocast_enabled(): - why_not_fast_path = "autocast is enabled" - - if not why_not_fast_path: - tensor_args = ( - query, - key, - value, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - ) - # We have to use list comprehensions below because TorchScript does not support - # generator expressions. - if torch.overrides.has_torch_function(tensor_args): - why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) - for x in tensor_args]): - why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" - elif torch.is_grad_enabled() and any( - [x is not None and x.requires_grad for x in tensor_args]): - why_not_fast_path = ( - "grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad") - if not why_not_fast_path: - # Scale the query bias parameter and the query projection weight. - in_proj_weight = self._scaled_in_proj_weight() - in_proj_bias = self._scaled_in_proj_bias() - return torch._native_multi_head_attention( - query, - key, - value, - self.embed_dim, - self.num_heads, - in_proj_weight, - in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - key_padding_mask if key_padding_mask is not None else attn_mask, - need_weights, - average_attn_weights, - 1 if key_padding_mask is not None else - 0 if attn_mask is not None else None) - any_nested = query.is_nested or key.is_nested or value.is_nested - assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " + - f"The fast path was not hit because {why_not_fast_path}") - - if self.batch_first and is_batched: - # make sure that the transpose op does not affect the "is" property - if key is value: - if query is key: - query = key = value = query.transpose(1, 0) - else: - query, key = [x.transpose(1, 0) for x in (query, key)] - value = key - else: - query, key, value = [x.transpose(1, 0) for x in (query, key, value)] - - if not self._qkv_same_embed_dim: - attn_output, attn_output_weights = F.multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, use_separate_proj_weight=True, - q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, - v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights) - else: - # Scale the query bias parameter and the query projection weight. - in_proj_weight = self._scaled_in_proj_weight() - in_proj_bias = self._scaled_in_proj_bias() - attn_output, attn_output_weights = F.multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - in_proj_weight, in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, average_attn_weights=average_attn_weights) - if self.batch_first and is_batched: - return attn_output.transpose(1, 0), attn_output_weights - else: - return attn_output, attn_output_weights + def forward(self, inputs, key_padding_mask=None): + batch_size, seq_len, embed_dim = inputs.shape + q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2) + q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) + k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + out = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=~key_padding_mask[:, None, None], + dropout_p=self.dropout, + ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + out = self.out_proj(out) + return out class MultiHeadedSelfAttention(nn.Module): @@ -483,12 +305,9 @@ def __init__(self, config: ConformerConfig): def forward(self, outputs, paddings): outputs = self.ln(outputs) - outputs, _ = self.self_attention( - query=outputs, - key=outputs, - value=outputs, - key_padding_mask=paddings==1, - need_weights=False, + outputs = self.self_attention( + outputs, + key_padding_mask=paddings == 1, ) outputs = self.dropout(outputs) return outputs @@ -504,18 +323,29 @@ def __init__(self, config: ConformerConfig): self.register_buffer('running_var', running_var) self.scale = nn.Parameter(torch.zeros(config.encoder_dim)) self.bias = nn.Parameter(torch.zeros(config.encoder_dim)) - self.register_buffer('momentum', - torch.FloatTensor([config.batch_norm_momentum])) - self.register_buffer('epsilon', - torch.FloatTensor([config.batch_norm_epsilon])) + self.register_buffer('dim', torch.FloatTensor([config.encoder_dim])) - # self.momentum = config.batch_norm_momentum - # self.epsilon = config.batch_norm_epsilon - # self.dim = config.encoder_dim + self.momentum = config.batch_norm_momentum + self.epsilon = config.batch_norm_epsilon def forward(self, inputs, input_paddings): #inputs: NHD #padding: NH + """ + Alternatively: + inputs[input_paddings==0] = F.batch_norm( + input = inputs[input_paddings==0], + running_mean = self.running_mean, + running_var = self.running_var, + weight = 1+self.scale, + bias = self.bias, + training = self.training, + momentum=1-self.momentum, + eps=self.epsilon + ) + inputs.masked_fill(input_paddings[...,None] != 0, 0) + return inputs + """ mask = 1 - input_paddings[:, :, None] if self.training: count = mask.sum() @@ -627,7 +457,9 @@ def __init__(self, config: ConformerConfig): else: input_dropout_rate = config.input_dropout_rate self.subsample = Subsample( - encoder_dim=config.encoder_dim, input_dropout_rate=input_dropout_rate) + encoder_dim=config.encoder_dim, + input_dropout_rate=input_dropout_rate, + num_bins=preprocessing_config.num_bins) self.conformers = nn.ModuleList( [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 24f4eb1fc..c4f4a1247 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -47,8 +47,11 @@ def init_model_fn( input_dropout_rate. """ torch.random.manual_seed(rng[0]) - # Disable cudnn benchmark to avoid OOM errors. + # Configure torch backends to avoid OOM errors. torch.backends.cudnn.benchmark = False + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_math_sdp(True) model = conformer_model.ConformerEncoderDecoder( conformer_model.ConformerConfig( attention_residual_dropout_rate=dropout_rate, @@ -57,13 +60,6 @@ def init_model_fn( input_dropout_rate=aux_dropout_rate, use_specaug=self.use_specaug)) self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') - # Run model once to initialize lazy layers. - # Run the initialization in eval mode to disable BN tracking. - model = model.eval() - t = MAX_INPUT_LENGTH - wave = torch.randn((2, t)) - pad = torch.zeros_like(wave) - _ = model(wave, pad) conformer_model.initialize(model) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) diff --git a/getting_started.md b/getting_started.md index 2942e632b..96e58edab 100644 --- a/getting_started.md +++ b/getting_started.md @@ -13,7 +13,7 @@ To get started you will have to make a few decisions and install the repository 1. Decide if you would like to develop your submission in either Pytorch or Jax. 2. Set up your workstation or VM. We recommend to use a setup similar to the [benchmarking hardware](https://github.com/mlcommons/algorithmic-efficiency/blob/main/RULES.md#benchmarking-hardware). The specs on the benchmarking machines are: - - 8 V100 GPUs + - 8 V100 GPUs - 240 GB in RAM - 2 TB in storage (for datasets). 3. Install the algorithmic package and dependencies, see [Installation](./README.md#installation). diff --git a/submission_runner.py b/submission_runner.py index bb06f698d..656599a42 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -217,10 +217,8 @@ def train_once( model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: - compile_error_workloads = ['ogbg', 'criteo1tb'] - eager_backend_workloads = [ - 'librispeech_conformer', 'librispeech_deepspeech' - ] + compile_error_workloads = ['librispeech_conformer', 'ogbg', 'criteo1tb'] + eager_backend_workloads = ['librispeech_deepspeech'] aot_eager_backend_workloads = [] if FLAGS.workload in compile_error_workloads: logging.warning( diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index 1d243d83e..d414001dd 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -38,11 +38,15 @@ def sd_transform(sd): out = {} for k in sd: if 'Attention' in ''.join(k): - if 'in_proj' in k[-1]: - new_key = k[:-1] + if 'Dense_0' in k[-2]: + # In-proj + new_key = k[:-2] chunks = sd[k].chunk(3) for t, c in zip(['query', 'key', 'value'], chunks): - out[new_key + (t, k[-1].split('_')[-1])] = c + out[new_key + (t, k[-1])] = c + elif 'Dense_1' in k[-2]: + # Out-proj + out[(*k[:-2], 'out', k[-1])] = sd[k] else: out[k] = sd[k] else: