From d672b841e03f80ebb1e3b0f6dd87eca646dc44c1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Apr 2025 03:20:46 +0000 Subject: [PATCH 001/123] add jit-friendly dropout w rate in call --- algoperf/jax_utils.py | 121 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 algoperf/jax_utils.py diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py new file mode 100644 index 000000000..ddafd77c6 --- /dev/null +++ b/algoperf/jax_utils.py @@ -0,0 +1,121 @@ +from collections.abc import Sequence + +import jax +import jax.numpy as jnp +from jax import lax, random + +import flax.linen as nn +from flax.linen.module import Module, compact, merge_param +from flax.typing import PRNGKey + + +# Custom Layers +class Dropout(Module): + """Create a dropout layer. + Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. + The reference dropout implementation is modified support changes to dropout rate during training by: + 1) adding rate argument to the __call__ method + 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code + + .. note:: + When using :meth:`Module.apply() `, make sure + to include an RNG seed named ``'dropout'``. Dropout isn't necessary for + variable initialization. + + Example usage:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class MLP(nn.Module): + ... @nn.compact + ... def __call__(self, x, train): + ... x = nn.Dense(4)(x) + ... x = nn.Dropout(0.5, deterministic=not train)(x) + ... return x + + >>> model = MLP() + >>> x = jnp.ones((1, 3)) + >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout + >>> model.apply(variables, x, train=False) # don't use dropout + Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32) + >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout + Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32) + + Attributes: + rate: the dropout probability. (_not_ the keep rate!) + broadcast_dims: dimensions that will share the same dropout mask + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and + masked, whereas if true, no mask is applied and the inputs are returned as + is. + rng_collection: the rng collection name to use when requesting an rng key. + """ + + rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = "dropout" + legacy: bool = True + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. + + Args: + inputs: the inputs that should be randomly masked. + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and + masked, whereas if true, no mask is applied and the inputs are returned + as is. + rate: the dropout probability. (_not_ the keep rate!) + rng: an optional PRNGKey used as the random key, if not specified, one + will be generated using ``make_rng`` with the ``rng_collection`` name. + + Returns: + The masked inputs reweighted to preserve mean. + """ + deterministic = merge_param("deterministic", self.deterministic, deterministic) + + # Override self.rate if rate is passed to __call__ + if not (self.rate is not None and rate is not None): + rate = merge_param("rate", self.rate, rate) + + if self.legacy: + if rate == 0.0: + return inputs + + # Prevent gradient NaNs in 1.0 edge-case. + if rate == 1.0: + return jnp.zeros_like(inputs) + + if deterministic: + return inputs + + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + + return lax.select( + mask, jnp.nan_to_num(inputs / keep_prob), jnp.zeros_like(inputs) + ) + + +# Utilities for debugging +def print_jax_model_summary(model, fake_inputs): + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240}, + ) + print(tabulate_fn(fake_inputs, train=False)) From aa25e208ef7f283d2ee3d4ebc304e45d6c24fa8f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 5 May 2025 08:28:01 +0000 Subject: [PATCH 002/123] remove nan_to_num convertion --- algoperf/jax_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index ddafd77c6..3ca3f1bfc 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -104,10 +104,7 @@ def __call__( broadcast_shape[dim] = 1 mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) mask = jnp.broadcast_to(mask, inputs.shape) - - return lax.select( - mask, jnp.nan_to_num(inputs / keep_prob), jnp.zeros_like(inputs) - ) + return lax.select(mask, inputs, jnp.zeros_like(inputs)) # Utilities for debugging From 85a35785a523ff44a1fca93e41d3cb082d72d4c2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 5 May 2025 09:35:39 +0000 Subject: [PATCH 003/123] update models with custom dropout layer --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 6 +++--- algoperf/workloads/fastmri/fastmri_jax/models.py | 5 +++-- .../imagenet_vit/imagenet_jax/models.py | 13 +++++++------ .../librispeech_jax/models.py | 11 ++++++----- .../librispeech_jax/models.py | 5 +++-- algoperf/workloads/ogbg/ogbg_jax/models.py | 4 +++- algoperf/workloads/wmt/wmt_jax/models.py | 16 +++++++++------- 7 files changed, 34 insertions(+), 26 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index 6d9a489ff..e89db0c86 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -1,11 +1,11 @@ """A JAX implementation of DLRM-Small.""" - from typing import Sequence import flax.linen as nn from jax import nn as jnn import jax.numpy as jnp +from algoperf.jax_utils import Dropout class DLRMResNet(nn.Module): """Define a DLRMResNet model. @@ -89,7 +89,7 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input) x = nn.relu(x) if self.dropout_rate and layer_idx == num_layers_top - 2: - x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + x = Dropout(rate=self.dropout_rate, deterministic=not train)(x) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. @@ -212,7 +212,7 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.LayerNorm()(top_mlp_input) if (self.dropout_rate is not None and self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2): - top_mlp_input = nn.Dropout( + top_mlp_input = Dropout( rate=self.dropout_rate, deterministic=not train)( top_mlp_input) logits = top_mlp_input diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 44bff0e21..f29e0be22 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp +from algoperf.jax_utils import Dropout def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation @@ -172,7 +173,7 @@ def __call__(self, x, train=True): x = activation_fn(x) # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW - x = nn.Dropout( + x = Dropout( self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( x) x = nn.Conv( @@ -186,7 +187,7 @@ def __call__(self, x, train=True): else: x = _instance_norm2d(x, (1, 2)) x = activation_fn(x) - x = nn.Dropout( + x = Dropout( self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( x) return x diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 7ce3a0395..902658fbe 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -11,6 +11,7 @@ import jax.numpy as jnp from algoperf import spec +from algoperf.jax_utils import Dropout def posemb_sincos_2d(h: int, @@ -53,7 +54,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y - x = nn.Dropout(rate=self.dropout_rate)(x, train) + x = Dropout(rate=self.dropout_rate)(x, train) x = nn.Dense(d, **inits)(x) return x @@ -76,7 +77,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(rate=self.dropout_rate)(y, train) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) @@ -85,7 +86,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: use_glu=self.use_glu, dropout_rate=self.dropout_rate, name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(rate=self.dropout_rate)(y, train) x = x + y else: y = x @@ -95,7 +96,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(rate=self.dropout_rate)(y, train) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) @@ -105,7 +106,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: use_glu=self.use_glu, dropout_rate=self.dropout_rate, name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) + y = Dropout(rate=self.dropout_rate)(y, train) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) @@ -205,7 +206,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: dropout_rate = self.dropout_rate if dropout_rate is None: dropout_rate = 0.0 - x = nn.Dropout(rate=dropout_rate)(x, not train) + x = Dropout(rate=dropout_rate)(x, not train) x = Encoder( depth=self.depth, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 593d463c3..51e93acc9 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -26,6 +26,7 @@ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter +from algoperf.jax_utils import Dropout @struct.dataclass @@ -129,7 +130,7 @@ def __call__(self, inputs, input_paddings, train): outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)( seq_length=outputs.shape[1]) - outputs = nn.Dropout( + outputs = Dropout( rate=self.input_dropout_rate, deterministic=not train)( outputs) @@ -217,7 +218,7 @@ def __call__(self, inputs, padding_mask=None, train=False): 'config.activation_function_name values, recieved ' f'{config.activation_function_name}') inputs = activation_fn(inputs) - inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)( + inputs = Dropout(rate=config.feed_forward_dropout_rate)( inputs, deterministic=not train) inputs = inputs * padding_mask @@ -234,7 +235,7 @@ def __call__(self, inputs, padding_mask=None, train=False): else: feed_forward_residual_dropout_rate = ( config.feed_forward_residual_dropout_rate) - inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)( + inputs = Dropout(rate=feed_forward_residual_dropout_rate)( inputs, deterministic=not train) return inputs @@ -416,7 +417,7 @@ def __call__(self, inputs, paddings, train): attention_residual_dropout_rate = 0.1 else: attention_residual_dropout_rate = config.attention_residual_dropout_rate - result = nn.Dropout( + result = Dropout( rate=attention_residual_dropout_rate, deterministic=not train)( result) @@ -578,7 +579,7 @@ def __call__(self, conv_residual_dropout_rate = 0.0 else: conv_residual_dropout_rate = config.conv_residual_dropout_rate - inputs = nn.Dropout( + inputs = Dropout( rate=conv_residual_dropout_rate, deterministic=not train)( inputs) return inputs diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b116f44cd..f937a1692 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -20,6 +20,7 @@ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter +from algoperf.jax_utils import Dropout Array = jnp.ndarray StateType = Union[Array, Tuple[Array, ...]] @@ -110,7 +111,7 @@ def __call__(self, inputs, output_paddings, train): input_dropout_rate = 0.1 else: input_dropout_rate = config.input_dropout_rate - outputs = nn.Dropout( + outputs = Dropout( rate=input_dropout_rate, deterministic=not train)( outputs) @@ -216,7 +217,7 @@ def __call__(self, inputs, input_paddings=None, train=False): feed_forward_dropout_rate = 0.1 else: feed_forward_dropout_rate = config.feed_forward_dropout_rate - inputs = nn.Dropout(rate=feed_forward_dropout_rate)( + inputs = Dropout(rate=feed_forward_dropout_rate)( inputs, deterministic=not train) return inputs diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 0e66d2ab8..f5710a3ab 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -6,6 +6,8 @@ import jax.numpy as jnp import jraph +from algoperf.jax_utils import Dropout + def _make_embed(latent_dim, name): @@ -50,7 +52,7 @@ def __call__(self, graph, train): dropout_rate = 0.1 else: dropout_rate = self.dropout_rate - dropout = nn.Dropout(rate=dropout_rate, deterministic=not train) + dropout = Dropout(rate=dropout_rate, deterministic=not train) graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 97fee032f..04f46e8ac 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -11,6 +11,8 @@ import jax.numpy as jnp import numpy as np +from algoperf.jax_utils import Dropout + @struct.dataclass class TransformerConfig: @@ -172,14 +174,14 @@ def __call__(self, inputs): dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) output = nn.Dense( actual_out_dim, dtype=cfg.dtype, kernel_init=cfg.kernel_init, bias_init=cfg.bias_init)( x) - output = nn.Dropout(rate=dropout_rate)( + output = Dropout(rate=dropout_rate)( output, deterministic=cfg.deterministic) return output @@ -229,7 +231,7 @@ def __call__(self, inputs, encoder_mask=None): dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) x = x + inputs if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -293,7 +295,7 @@ def __call__(self, dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) x = x + targets if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -312,7 +314,7 @@ def __call__(self, deterministic=cfg.deterministic)( cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x if not pre_ln: y = nn.LayerNorm(dtype=cfg.dtype)(y) @@ -366,7 +368,7 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None): dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) x = x.astype(cfg.dtype) @@ -436,7 +438,7 @@ def __call__(self, dropout_rate = 0.1 else: dropout_rate = cfg.dropout_rate - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y.astype(cfg.dtype) From 9354079c1a97b51fe722069b29608953f58aa107 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 5 May 2025 10:36:48 +0000 Subject: [PATCH 004/123] add functional dropout for criteo, fastmri, and vit --- .../criteo1tb/criteo1tb_jax/models.py | 21 +++++---- .../workloads/fastmri/fastmri_jax/models.py | 20 ++++----- .../imagenet_vit/imagenet_jax/models.py | 45 ++++++++++--------- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index e89db0c86..c56748bb1 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -28,7 +28,10 @@ class DLRMResNet(nn.Module): embedding_init_multiplier: float = None # Unused @nn.compact - def __call__(self, x, train): + def __call__(self, x, train, dropout_rate=None): + if not dropout_rate: + dropout_rate=self.dropout_rate + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -88,8 +91,8 @@ def scaled_init(key, shape, dtype=jnp.float_): stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( top_mlp_input) x = nn.relu(x) - if self.dropout_rate and layer_idx == num_layers_top - 2: - x = Dropout(rate=self.dropout_rate, deterministic=not train)(x) + if dropout_rate and layer_idx == num_layers_top - 2: + x = Dropout(deterministic=not train)(x, rate=dropout_rate) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. @@ -151,7 +154,10 @@ class DlrmSmall(nn.Module): embedding_init_multiplier: float = None @nn.compact - def __call__(self, x, train): + def __call__(self, x, train, dropout_rate=None): + if not dropout_rate: + dropout_rate = self.dropout_rate + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -210,10 +216,9 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.relu(top_mlp_input) if self.use_layer_norm: top_mlp_input = nn.LayerNorm()(top_mlp_input) - if (self.dropout_rate is not None and self.dropout_rate > 0.0 and + if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): - top_mlp_input = Dropout( - rate=self.dropout_rate, deterministic=not train)( - top_mlp_input) + top_mlp_input = Dropout(deterministic=not train)( + top_mlp_input, rate=dropout_rate) logits = top_mlp_input return logits diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index f29e0be22..3d5460c18 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -62,10 +62,9 @@ class UNet(nn.Module): use_layer_norm: bool = False @nn.compact - def __call__(self, x, train=True): - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 + def __call__(self, x, train=True, dropout_rate=None): + if not dropout_rate: + dropout_rate = self.dropout_rate # pylint: disable=invalid-name _ConvBlock = functools.partial( @@ -144,7 +143,7 @@ class ConvBlock(nn.Module): use_layer_norm: bool @nn.compact - def __call__(self, x, train=True): + def __call__(self, x, train=True, dropout_rate=None): """Forward function. Note: Pytorch is NCHW and jax/flax is NHWC. Args: @@ -153,6 +152,8 @@ def __call__(self, x, train=True): Returns: jnp.array: Output tensor of shape `(N, H, W, out_channels)`. """ + if not dropout_rate: + dropout_rate=self.dropout_rate x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), @@ -173,9 +174,8 @@ def __call__(self, x, train=True): x = activation_fn(x) # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW - x = Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) + x = Dropout(broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate ) x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), @@ -188,8 +188,8 @@ def __call__(self, x, train=True): x = _instance_norm2d(x, (1, 2)) x = activation_fn(x) x = Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) + broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate) return x diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 902658fbe..10a90f37d 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -39,8 +39,11 @@ class MlpBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor: """Applies Transformer MlpBlock module.""" + if not dropout_rate: + dropout_rate = self.dropout_rate + inits = { 'kernel_init': nn.initializers.xavier_uniform(), 'bias_init': nn.initializers.normal(stddev=1e-6), @@ -54,7 +57,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y - x = Dropout(rate=self.dropout_rate)(x, train) + x = Dropout()(x, train, rate=dropout_rate) x = nn.Dense(d, **inits)(x) return x @@ -68,7 +71,10 @@ class Encoder1DBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor: + if not dropout_rate: + dropout_rate=self.dropout_rate + if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) y = nn.MultiHeadDotProductAttention( @@ -77,16 +83,15 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout(rate=self.dropout_rate)(y, train) + y = Dropout()(y, train, dropout_rate=dropout_rate) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, - dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = Dropout(rate=self.dropout_rate)(y, train) + name='MlpBlock_3')(y, train, dropout_rate=dropout_rate) + y = Dropout()(y, train, rate=dropout_rate) x = x + y else: y = x @@ -96,7 +101,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout(rate=self.dropout_rate)(y, train) + y = Dropout()(y, train, rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) @@ -104,9 +109,8 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, - dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = Dropout(rate=self.dropout_rate)(y, train) + name='MlpBlock_3')(y, train, dropout_rate=dropout_rate) + y = Dropout()(y, train)(rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) @@ -123,7 +127,10 @@ class Encoder(nn.Module): use_post_layer_norm: bool = False @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor: + if not dropout_rate: + dropout_rate=self.dropout_rate + # Input Encoder for lyr in range(self.depth): block = Encoder1DBlock( @@ -132,7 +139,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=self.dropout_rate) + )(dropout_rate=dropout_rate) x = block(x, train) if not self.use_post_layer_norm: return nn.LayerNorm(name='encoder_layernorm')(x) @@ -187,7 +194,9 @@ def get_posemb(self, return posemb_sincos_2d(*seqshape, width, dtype=dtype) @nn.compact - def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: + def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> spec.Tensor: + if not dropout_rate: + dropout_rate = self.dropout_rate # Patch extraction x = nn.Conv( self.width, @@ -203,10 +212,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: # Add posemb before adding extra token. x = x + self.get_posemb((h, w), c, x.dtype) - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 - x = Dropout(rate=dropout_rate)(x, not train) + x = Dropout()(x, not train, rate=dropout_rate) x = Encoder( depth=self.depth, @@ -214,9 +220,8 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate, name='Transformer')( - x, train=not train) + x, train=not train, dropout_rate=dropout_rate) if self.use_map: x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) From feb9cc5cfd57575e11c91039c0292d26014e99fe Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 5 May 2025 10:48:34 +0000 Subject: [PATCH 005/123] add functional dropout for ogbg --- algoperf/workloads/ogbg/ogbg_jax/models.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index f5710a3ab..6ced9bef5 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -47,12 +47,10 @@ class GNN(nn.Module): activation_fn_name: str = 'relu' @nn.compact - def __call__(self, graph, train): - if self.dropout_rate is None: - dropout_rate = 0.1 - else: + def __call__(self, graph, train, dropout_rate=None): + if not dropout_rate: dropout_rate = self.dropout_rate - dropout = Dropout(rate=dropout_rate, deterministic=not train) + dropout = Dropout(deterministic=not train, rate=dropout_rate) graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) From 9bba078b5e17a7881ffaa294a551129d4acf5c65 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 15 May 2025 21:43:18 +0000 Subject: [PATCH 006/123] modify wmt model for dropout passing --- algoperf/workloads/wmt/wmt_jax/models.py | 599 +++++++++++---------- algoperf/workloads/wmt/wmt_jax/workload.py | 4 +- 2 files changed, 310 insertions(+), 293 deletions(-) diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 04f46e8ac..a5b484320 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -140,325 +140,342 @@ def __call__(self, inputs, inputs_positions=None): class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block. + """Transformer MLP / feed-forward block. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - out_dim: optionally specify out dimension. - """ - config: TransformerConfig - out_dim: Optional[int] = None + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + out_dim: optionally specify out dimension. + """ - @nn.compact - def __call__(self, inputs): - """Applies Transformer MlpBlock module.""" - cfg = self.config - actual_out_dim = ( - inputs.shape[-1] if self.out_dim is None else self.out_dim) - x = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) - x = cfg.activation(x) - if cfg.glu: - y = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) - x = x * y - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - output = nn.Dense( - actual_out_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - x) - output = Dropout(rate=dropout_rate)( - output, deterministic=cfg.deterministic) - return output + config: TransformerConfig + out_dim: Optional[int] = None + + @nn.compact + def __call__(self, inputs, dropout_rate=None): + """Applies Transformer MlpBlock module.""" + cfg = self.config + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim + x = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(inputs) + x = cfg.activation(x) + if cfg.glu: + y = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(inputs) + x = x * y + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic) + output = nn.Dense( + actual_out_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(x) + output = Dropout()(output, rate=dropout_rate, deterministic=cfg.deterministic) + return output class Encoder1DBlock(nn.Module): - """Transformer encoder layer. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ - config: TransformerConfig - - @nn.compact - def __call__(self, inputs, encoder_mask=None): - """Applies Encoder1DBlock module. + """Transformer encoder layer. - Args: - inputs: input data. - encoder_mask: encoder self-attention mask. - - Returns: - output after transformer encoder block. + Attributes: + config: TransformerConfig dataclass containing hyperparameters. """ - cfg = self.config - pre_ln = cfg.pre_ln - # Attention block. - assert inputs.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * x, x, mask=encoder_mask) - - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - x = x + inputs - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # MLP block. - y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = MlpBlock(config=cfg)(y) - - return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) + config: TransformerConfig + + @nn.compact + def __call__(self, inputs, encoder_mask=None, dropout_rate=None): + """Applies Encoder1DBlock module. + + Args: + inputs: input data. + encoder_mask: encoder self-attention mask. + + Returns: + output after transformer encoder block. + """ + cfg = self.config + pre_ln = cfg.pre_ln + + # Attention block. + assert inputs.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs + if dropout_rate is None: + if cfg.attention_dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * x, x, mask=encoder_mask) + + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = x + inputs + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # MLP block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = MlpBlock(config=cfg)(y) + + return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) class EncoderDecoder1DBlock(nn.Module): - """Transformer encoder-decoder layer. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ - config: TransformerConfig - - @nn.compact - def __call__(self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None): - """Applies EncoderDecoder1DBlock module. + """Transformer encoder-decoder layer. - Args: - targets: input data for decoder - encoded: input data from encoder - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. - - Returns: - output after transformer encoder-decoder block. + Attributes: + config: TransformerConfig dataclass containing hyperparameters. """ - cfg = self.config - pre_ln = cfg.pre_ln - # Decoder block. - assert targets.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets + config: TransformerConfig - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic, - decode=cfg.decode)( - cfg.attention_temp * x, x, mask=decoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - x = x + targets - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # Encoder-Decoder block. - y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - - y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) - y = y + x - if not pre_ln: - y = nn.LayerNorm(dtype=cfg.dtype)(y) - - # MLP block. - z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y - z = MlpBlock(config=cfg)(z) - - return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) + @nn.compact + def __call__( + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=None, + ): + """Applies EncoderDecoder1DBlock module. + + Args: + targets: input data for decoder + encoded: input data from encoder + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. + + Returns: + output after transformer encoder-decoder block. + """ + cfg = self.config + pre_ln = cfg.pre_ln + + # Decoder block. + assert targets.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets + + if dropout_rate is None: + if cfg.attention_dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + decode=cfg.decode, + )(cfg.attention_temp * x, x, mask=decoder_mask) + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = x + targets + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # Encoder-Decoder block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) + + y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = y + x + if not pre_ln: + y = nn.LayerNorm(dtype=cfg.dtype)(y) + + # MLP block. + z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y + z = MlpBlock(config=cfg)(z) + + return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - shared_embedding: a shared embedding layer to use. - """ - config: TransformerConfig - shared_embedding: Any = None - - @nn.compact - def __call__(self, inputs, inputs_positions=None, encoder_mask=None): - """Applies Transformer model on the inputs. + """Transformer Model Encoder for sequence to sequence translation. - Args: - inputs: input data - inputs_positions: input subsequence positions for packed examples. - encoder_mask: decoder self-attention mask. - - Returns: - output of a transformer encoder. + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. """ - cfg = self.config - assert inputs.ndim == 2 # (batch, len) - # Input Embedding - if self.shared_embedding is None: - input_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) - else: - input_embed = self.shared_embedding - x = inputs.astype('int32') - x = input_embed(x) - x = AddPositionEmbs( - config=cfg, decode=False, name='posembed_input')( - x, inputs_positions=inputs_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - - x = x.astype(cfg.dtype) - - # Input Encoder - for lyr in range(cfg.num_layers): - x = Encoder1DBlock( - config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask) - - encoded = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x) - if cfg.pre_ln else x) - - return encoded + config: TransformerConfig + shared_embedding: Any = None + + @nn.compact + def __call__( + self, inputs, inputs_positions=None, encoder_mask=None, dropout_rate=None + ): + """Applies Transformer model on the inputs. + + Args: + inputs: input data + inputs_positions: input subsequence positions for packed examples. + encoder_mask: decoder self-attention mask. + + Returns: + output of a transformer encoder. + """ + cfg = self.config + assert inputs.ndim == 2 # (batch, len) + + # Input Embedding + if self.shared_embedding is None: + input_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) + else: + input_embed = self.shared_embedding + x = inputs.astype("int32") + x = input_embed(x) + x = AddPositionEmbs(config=cfg, decode=False, name="posembed_input")( + x, inputs_positions=inputs_positions + ) + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + + x = x.astype(cfg.dtype) + + # Input Encoder + for lyr in range(cfg.num_layers): + x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask) + + encoded = ( + nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) + if cfg.pre_ln + else x + ) + + return encoded class Decoder(nn.Module): - """Transformer Model Decoder for sequence to sequence translation. + """Transformer Model Decoder for sequence to sequence translation. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - shared_embedding: a shared embedding layer to use. - """ - config: TransformerConfig - shared_embedding: Any = None - - @nn.compact - def __call__(self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None): - """Applies Transformer model on the inputs. - - Args: - encoded: encoded input data from encoder. - targets: target inputs. - targets_positions: input subsequence positions for packed examples. - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. - - Returns: - output of a transformer decoder. + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. """ - cfg = self.config - assert encoded.ndim == 3 # (batch, len, depth) - assert targets.ndim == 2 # (batch, len) + config: TransformerConfig + shared_embedding: Any = None - # Target Embedding - if self.shared_embedding is None: - output_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) - else: - output_embed = self.shared_embedding - - y = targets.astype('int32') - if not cfg.decode: - y = shift_right(y) - y = output_embed(y) - y = AddPositionEmbs( - config=cfg, decode=cfg.decode, name='posembed_output')( - y, inputs_positions=targets_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) - - y = y.astype(cfg.dtype) - - # Target-Input Decoder - for lyr in range(cfg.num_layers): - y = EncoderDecoder1DBlock( - config=cfg, name=f'encoderdecoderblock_{lyr}')( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) - y = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y) - if cfg.pre_ln else y) - - # Use the transpose of embedding matrix for logit transform. - logits = output_embed.attend(y.astype(jnp.float32)) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - return logits + @nn.compact + def __call__( + self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=None, + ): + """Applies Transformer model on the inputs. + + Args: + encoded: encoded input data from encoder. + targets: target inputs. + targets_positions: input subsequence positions for packed examples. + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. + + Returns: + output of a transformer decoder. + """ + cfg = self.config + + assert encoded.ndim == 3 # (batch, len, depth) + assert targets.ndim == 2 # (batch, len) + + # Target Embedding + if self.shared_embedding is None: + output_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) + else: + output_embed = self.shared_embedding + + y = targets.astype("int32") + if not cfg.decode: + y = shift_right(y) + y = output_embed(y) + y = AddPositionEmbs(config=cfg, decode=cfg.decode, name="posembed_output")( + y, inputs_positions=targets_positions + ) + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + + y = y.astype(cfg.dtype) + + # Target-Input Decoder + for lyr in range(cfg.num_layers): + y = EncoderDecoder1DBlock(config=cfg, name=f"encoderdecoderblock_{lyr}")( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + ) + y = ( + nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) + if cfg.pre_ln + else y + ) + + # Use the transpose of embedding matrix for logit transform. + logits = output_embed.attend(y.astype(jnp.float32)) + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + return logits class Transformer(nn.Module): diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cdfcb91df..240ad2c11 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -209,8 +209,8 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = 0.0, + aux_dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" init_fake_batch_size = 2 From 31f601977335e170c54e50d37ece29ecaea9a314 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 15 May 2025 22:19:20 +0000 Subject: [PATCH 007/123] modify wmt model for dropout passing --- .../librispeech_jax/models.py | 2 +- algoperf/workloads/wmt/wmt_jax/models.py | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index f937a1692..003bf4ea7 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -73,7 +73,7 @@ class Subsample(nn.Module): config: DeepspeechConfig @nn.compact - def __call__(self, inputs, output_paddings, train): + def __call__(self, inputs, output_paddings, train, dropout_rate=None): config = self.config outputs = jnp.expand_dims(inputs, axis=-1) diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index a5b484320..b84fb6d96 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -236,7 +236,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): # MLP block. y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = MlpBlock(config=cfg)(y) + y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate) return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) @@ -324,7 +324,7 @@ def __call__( # MLP block. z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y - z = MlpBlock(config=cfg)(z) + z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate) return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) @@ -382,7 +382,7 @@ def __call__( # Input Encoder for lyr in range(cfg.num_layers): - x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask) + x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate) encoded = ( nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) @@ -464,6 +464,7 @@ def __call__( encoded, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate, ) y = ( nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) @@ -503,7 +504,7 @@ def setup(self): self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): + def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropout_rate=None): """Applies Transformer encoder-branch on the inputs. Args: @@ -528,7 +529,7 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): jnp.equal, dtype=cfg.dtype)) return self.encoder( - inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask) + inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask, dropout_rate=dropout_rate) def decode( self, @@ -595,7 +596,8 @@ def __call__(self, inputs_positions=None, targets_positions=None, inputs_segmentation=None, - targets_segmentation=None): + targets_segmentation=None, + dropout_rate=None): """Applies Transformer model on the inputs. Args: @@ -612,7 +614,8 @@ def __call__(self, encoded = self.encode( inputs, inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate) return self.decode( encoded, @@ -620,4 +623,5 @@ def __call__(self, targets, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation) + targets_segmentation=targets_segmentation, + dropout_rate=dropout_rate) From e36d29432a960283f9a44baf745644c5c3ddbdaa Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 May 2025 23:32:41 +0000 Subject: [PATCH 008/123] reformatting and dropout fixes to fastmri and vit --- .../criteo1tb/criteo1tb_jax/models.py | 7 +- .../workloads/fastmri/fastmri_jax/models.py | 18 +- .../workloads/fastmri/fastmri_jax/workload.py | 21 +- .../imagenet_vit/imagenet_jax/models.py | 75 ++- .../imagenet_vit/imagenet_jax/workload.py | 22 +- .../librispeech_jax/models.py | 4 +- algoperf/workloads/wmt/wmt_jax/models.py | 525 +++++++++--------- 7 files changed, 362 insertions(+), 310 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index c56748bb1..0b2126915 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -7,6 +7,7 @@ from algoperf.jax_utils import Dropout + class DLRMResNet(nn.Module): """Define a DLRMResNet model. @@ -30,7 +31,7 @@ class DLRMResNet(nn.Module): @nn.compact def __call__(self, x, train, dropout_rate=None): if not dropout_rate: - dropout_rate=self.dropout_rate + dropout_rate = self.dropout_rate bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -157,7 +158,7 @@ class DlrmSmall(nn.Module): def __call__(self, x, train, dropout_rate=None): if not dropout_rate: dropout_rate = self.dropout_rate - + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -219,6 +220,6 @@ def scaled_init(key, shape, dtype=jnp.float_): if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_input = Dropout(deterministic=not train)( - top_mlp_input, rate=dropout_rate) + top_mlp_input, rate=dropout_rate) logits = top_mlp_input return logits diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 3d5460c18..7ecca2add 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -21,6 +21,7 @@ from algoperf.jax_utils import Dropout + def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation # but preserves double or complex floating points @@ -57,13 +58,13 @@ class UNet(nn.Module): num_channels: int = 32 num_pool_layers: int = 4 out_channels = 1 - dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. + dropout_rate: float = 0.0 use_tanh: bool = False use_layer_norm: bool = False @nn.compact def __call__(self, x, train=True, dropout_rate=None): - if not dropout_rate: + if dropout_rate is None: dropout_rate = self.dropout_rate # pylint: disable=invalid-name @@ -138,7 +139,7 @@ class ConvBlock(nn.Module): dropout_rate: Dropout probability. """ out_channels: int - dropout_rate: float + dropout_rate: float = 0.0 use_tanh: bool use_layer_norm: bool @@ -152,8 +153,8 @@ def __call__(self, x, train=True, dropout_rate=None): Returns: jnp.array: Output tensor of shape `(N, H, W, out_channels)`. """ - if not dropout_rate: - dropout_rate=self.dropout_rate + if dropout_rate is None: + dropout_rate = self.dropout_rate x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), @@ -174,8 +175,9 @@ def __call__(self, x, train=True, dropout_rate=None): x = activation_fn(x) # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW - x = Dropout(broadcast_dims=(1, 2), deterministic=not train)( - x, rate=dropout_rate ) + x = Dropout( + dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate) x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), @@ -188,7 +190,7 @@ def __call__(self, x, train=True, dropout_rate=None): x = _instance_norm2d(x, (1, 2)) x = activation_fn(x) x = Dropout( - broadcast_dims=(1, 2), deterministic=not train)( + dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( x, rate=dropout_rate) return x diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 1156cf30a..17ce6b442 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -26,12 +26,21 @@ def init_model_fn( """aux_dropout_rate is unused.""" del aux_dropout_rate fake_batch = jnp.zeros((13, 320, 320)) - self._model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) + if dropout_rate is None: + self._model = UNet( + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + ) + else: + self._model = UNet( + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + dropout_rate=dropout_rate) + params_rng, dropout_rng = jax.random.split(rng) variables = jax.jit( self._model.init)({'params': params_rng, 'dropout': dropout_rng}, diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 10a90f37d..227f7c297 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -39,9 +39,12 @@ class MlpBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor: + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate=None) -> spec.Tensor: """Applies Transformer MlpBlock module.""" - if not dropout_rate: + if dropout_rate is None: dropout_rate = self.dropout_rate inits = { @@ -57,7 +60,7 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spe y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y - x = Dropout()(x, train, rate=dropout_rate) + x = Dropout(dropout_rate)(x, train, rate=dropout_rate) x = nn.Dense(d, **inits)(x) return x @@ -71,9 +74,12 @@ class Encoder1DBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor: - if not dropout_rate: - dropout_rate=self.dropout_rate + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate=dropout_rate) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) @@ -83,15 +89,14 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout()(y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, dropout_rate=dropout_rate) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) y = MlpBlock( - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - name='MlpBlock_3')(y, train, dropout_rate=dropout_rate) - y = Dropout()(y, train, rate=dropout_rate) + mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3')( + y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y else: y = x @@ -101,7 +106,7 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout()(y, train, rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) @@ -109,8 +114,10 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, - name='MlpBlock_3')(y, train, dropout_rate=dropout_rate) - y = Dropout()(y, train)(rate=dropout_rate) + name='MlpBlock_3', + dropout_rate=dropout_rate)( + y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train)(rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) @@ -127,9 +134,12 @@ class Encoder(nn.Module): use_post_layer_norm: bool = False @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor: - if not dropout_rate: - dropout_rate=self.dropout_rate + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate # Input Encoder for lyr in range(self.depth): @@ -139,7 +149,8 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spe num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - )(dropout_rate=dropout_rate) + dropout_rate=dropout_rate)( + dropout_rate=dropout_rate) x = block(x, train) if not self.use_post_layer_norm: return nn.LayerNorm(name='encoder_layernorm')(x) @@ -151,9 +162,12 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 + dropout_rate: 0.0 @nn.compact - def __call__(self, x): + def __call__(self, x, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.dropout_rate n, _, d = x.shape probe = self.param('probe', nn.initializers.xavier_uniform(), (1, 1, d), @@ -166,7 +180,7 @@ def __call__(self, x): kernel_init=nn.initializers.xavier_uniform())(probe, x) y = nn.LayerNorm()(x) - x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y) return x[:, 0] @@ -180,7 +194,7 @@ class ViT(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 rep_size: Union[int, bool] = True - dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. + dropout_rate: Optional[float] = 0.0 reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True use_glu: bool = False @@ -194,8 +208,12 @@ def get_posemb(self, return posemb_sincos_2d(*seqshape, width, dtype=dtype) @nn.compact - def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> spec.Tensor: - if not dropout_rate: + def __call__(self, + x: spec.Tensor, + *, + train: bool = False, + dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: dropout_rate = self.dropout_rate # Patch extraction x = nn.Conv( @@ -212,7 +230,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> # Add posemb before adding extra token. x = x + self.get_posemb((h, w), c, x.dtype) - x = Dropout()(x, not train, rate=dropout_rate) + x = Dropout(dropout_rate)(x, not train, rate=dropout_rate) x = Encoder( depth=self.depth, @@ -220,11 +238,16 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - name='Transformer')( + name='Transformer', + dropout_rate=dropout_rate)( x, train=not train, dropout_rate=dropout_rate) if self.use_map: - x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) + x = MAPHead( + num_heads=self.num_heads, + mlp_dim=self.mlp_dim, + dropout_rate=dropout_rate)( + x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 35a6c46be..5d07b5ff8 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -36,13 +36,21 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate - self._model = models.ViT( - dropout_rate=dropout_rate, - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) + if dropout_rate is None: + self._model = models.ViT( + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16')) + else: + self._model = models.ViT( + dropout_rate=dropout_rate, + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 003bf4ea7..4cdb02ee1 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -111,9 +111,7 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None): input_dropout_rate = 0.1 else: input_dropout_rate = config.input_dropout_rate - outputs = Dropout( - rate=input_dropout_rate, deterministic=not train)( - outputs) + outputs = Dropout(rate=input_dropout_rate, deterministic=not train)(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index b84fb6d96..54a917a09 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -140,64 +140,68 @@ def __call__(self, inputs, inputs_positions=None): class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block. + """Transformer MLP / feed-forward block. Attributes: config: TransformerConfig dataclass containing hyperparameters. out_dim: optionally specify out dimension. """ - config: TransformerConfig - out_dim: Optional[int] = None - - @nn.compact - def __call__(self, inputs, dropout_rate=None): - """Applies Transformer MlpBlock module.""" - cfg = self.config - actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim - x = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - )(inputs) - x = cfg.activation(x) - if cfg.glu: - y = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - )(inputs) - x = x * y - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic) - output = nn.Dense( - actual_out_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - )(x) - output = Dropout()(output, rate=dropout_rate, deterministic=cfg.deterministic) - return output + config: TransformerConfig + out_dim: Optional[int] = None + + @nn.compact + def __call__(self, inputs, dropout_rate=None): + """Applies Transformer MlpBlock module.""" + cfg = self.config + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim + x = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )( + inputs) + x = cfg.activation(x) + if cfg.glu: + y = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )( + inputs) + x = x * y + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic) + output = nn.Dense( + actual_out_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )( + x) + output = Dropout()( + output, rate=dropout_rate, deterministic=cfg.deterministic) + return output class Encoder1DBlock(nn.Module): - """Transformer encoder layer. + """Transformer encoder layer. Attributes: config: TransformerConfig dataclass containing hyperparameters. """ - config: TransformerConfig + config: TransformerConfig - @nn.compact - def __call__(self, inputs, encoder_mask=None, dropout_rate=None): - """Applies Encoder1DBlock module. + @nn.compact + def __call__(self, inputs, encoder_mask=None, dropout_rate=None): + """Applies Encoder1DBlock module. Args: inputs: input data. @@ -206,60 +210,60 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): Returns: output after transformer encoder block. """ - cfg = self.config - pre_ln = cfg.pre_ln - - # Attention block. - assert inputs.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs - if dropout_rate is None: - if cfg.attention_dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=dropout_rate, - deterministic=cfg.deterministic, - )(cfg.attention_temp * x, x, mask=encoder_mask) - - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) - x = x + inputs - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # MLP block. - y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate) - - return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) + cfg = self.config + pre_ln = cfg.pre_ln + + # Attention block. + assert inputs.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs + if dropout_rate is None: + if cfg.attention_dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * x, x, mask=encoder_mask) + + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = x + inputs + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # MLP block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate) + + return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) class EncoderDecoder1DBlock(nn.Module): - """Transformer encoder-decoder layer. + """Transformer encoder-decoder layer. Attributes: config: TransformerConfig dataclass containing hyperparameters. """ - config: TransformerConfig + config: TransformerConfig - @nn.compact - def __call__( - self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - dropout_rate=None, - ): - """Applies EncoderDecoder1DBlock module. + @nn.compact + def __call__( + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=None, + ): + """Applies EncoderDecoder1DBlock module. Args: targets: input data for decoder @@ -270,81 +274,83 @@ def __call__( Returns: output after transformer encoder-decoder block. """ - cfg = self.config - pre_ln = cfg.pre_ln - - # Decoder block. - assert targets.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets - - if dropout_rate is None: - if cfg.attention_dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=dropout_rate, - deterministic=cfg.deterministic, - decode=cfg.decode, - )(cfg.attention_temp * x, x, mask=decoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) - x = x + targets - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # Encoder-Decoder block. - y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=dropout_rate, - deterministic=cfg.deterministic, - )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - - y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) - y = y + x - if not pre_ln: - y = nn.LayerNorm(dtype=cfg.dtype)(y) - - # MLP block. - z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y - z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate) - - return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) + cfg = self.config + pre_ln = cfg.pre_ln + + # Decoder block. + assert targets.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets + + if dropout_rate is None: + if cfg.attention_dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + decode=cfg.decode, + )(cfg.attention_temp * x, x, mask=decoder_mask) + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = x + targets + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # Encoder-Decoder block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) + + y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = y + x + if not pre_ln: + y = nn.LayerNorm(dtype=cfg.dtype)(y) + + # MLP block. + z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y + z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate) + + return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation. + """Transformer Model Encoder for sequence to sequence translation. Attributes: config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ - config: TransformerConfig - shared_embedding: Any = None + config: TransformerConfig + shared_embedding: Any = None - @nn.compact - def __call__( - self, inputs, inputs_positions=None, encoder_mask=None, dropout_rate=None - ): - """Applies Transformer model on the inputs. + @nn.compact + def __call__(self, + inputs, + inputs_positions=None, + encoder_mask=None, + dropout_rate=None): + """Applies Transformer model on the inputs. Args: inputs: input data @@ -354,67 +360,66 @@ def __call__( Returns: output of a transformer encoder. """ - cfg = self.config - assert inputs.ndim == 2 # (batch, len) - - # Input Embedding - if self.shared_embedding is None: - input_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0), - ) - else: - input_embed = self.shared_embedding - x = inputs.astype("int32") - x = input_embed(x) - x = AddPositionEmbs(config=cfg, decode=False, name="posembed_input")( - x, inputs_positions=inputs_positions - ) - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) - - x = x.astype(cfg.dtype) - - # Input Encoder - for lyr in range(cfg.num_layers): - x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate) - - encoded = ( - nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) - if cfg.pre_ln - else x - ) - - return encoded + cfg = self.config + assert inputs.ndim == 2 # (batch, len) + + # Input Embedding + if self.shared_embedding is None: + input_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) + else: + input_embed = self.shared_embedding + x = inputs.astype("int32") + x = input_embed(x) + x = AddPositionEmbs( + config=cfg, decode=False, name="posembed_input")( + x, inputs_positions=inputs_positions) + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + + x = x.astype(cfg.dtype) + + # Input Encoder + for lyr in range(cfg.num_layers): + x = Encoder1DBlock( + config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate) + + encoded = ( + nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) + if cfg.pre_ln else x) + + return encoded class Decoder(nn.Module): - """Transformer Model Decoder for sequence to sequence translation. + """Transformer Model Decoder for sequence to sequence translation. Attributes: config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ - config: TransformerConfig - shared_embedding: Any = None + config: TransformerConfig + shared_embedding: Any = None - @nn.compact - def __call__( - self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None, - dropout_rate=None, - ): - """Applies Transformer model on the inputs. + @nn.compact + def __call__( + self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=None, + ): + """Applies Transformer model on the inputs. Args: encoded: encoded input data from encoder. @@ -426,57 +431,56 @@ def __call__( Returns: output of a transformer decoder. """ - cfg = self.config - - assert encoded.ndim == 3 # (batch, len, depth) - assert targets.ndim == 2 # (batch, len) - - # Target Embedding - if self.shared_embedding is None: - output_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0), - ) - else: - output_embed = self.shared_embedding - - y = targets.astype("int32") - if not cfg.decode: - y = shift_right(y) - y = output_embed(y) - y = AddPositionEmbs(config=cfg, decode=cfg.decode, name="posembed_output")( - y, inputs_positions=targets_positions - ) - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) - - y = y.astype(cfg.dtype) - - # Target-Input Decoder - for lyr in range(cfg.num_layers): - y = EncoderDecoder1DBlock(config=cfg, name=f"encoderdecoderblock_{lyr}")( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - dropout_rate=dropout_rate, - ) - y = ( - nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) - if cfg.pre_ln - else y - ) - - # Use the transpose of embedding matrix for logit transform. - logits = output_embed.attend(y.astype(jnp.float32)) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - return logits + cfg = self.config + + assert encoded.ndim == 3 # (batch, len, depth) + assert targets.ndim == 2 # (batch, len) + + # Target Embedding + if self.shared_embedding is None: + output_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) + else: + output_embed = self.shared_embedding + + y = targets.astype("int32") + if not cfg.decode: + y = shift_right(y) + y = output_embed(y) + y = AddPositionEmbs( + config=cfg, decode=cfg.decode, name="posembed_output")( + y, inputs_positions=targets_positions) + if dropout_rate is None: + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + + y = y.astype(cfg.dtype) + + # Target-Input Decoder + for lyr in range(cfg.num_layers): + y = EncoderDecoder1DBlock( + config=cfg, name=f"encoderdecoderblock_{lyr}")( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate, + ) + y = ( + nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) + if cfg.pre_ln else y) + + # Use the transpose of embedding matrix for logit transform. + logits = output_embed.attend(y.astype(jnp.float32)) + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + return logits class Transformer(nn.Module): @@ -504,7 +508,11 @@ def setup(self): self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropout_rate=None): + def encode(self, + inputs, + inputs_positions=None, + inputs_segmentation=None, + dropout_rate=None): """Applies Transformer encoder-branch on the inputs. Args: @@ -529,7 +537,10 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropou jnp.equal, dtype=cfg.dtype)) return self.encoder( - inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask, dropout_rate=dropout_rate) + inputs, + inputs_positions=inputs_positions, + encoder_mask=encoder_mask, + dropout_rate=dropout_rate) def decode( self, From 363da8ac032c82b2ed8ac1c7ea64a25be243cbc6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 May 2025 23:39:39 +0000 Subject: [PATCH 009/123] dropout fix for criteo1tb jax --- algoperf/workloads/criteo1tb/criteo1tb_jax/models.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index 0b2126915..b7af15208 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -30,7 +30,7 @@ class DLRMResNet(nn.Module): @nn.compact def __call__(self, x, train, dropout_rate=None): - if not dropout_rate: + if dropout_rate is None: dropout_rate = self.dropout_rate bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) @@ -93,7 +93,7 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input) x = nn.relu(x) if dropout_rate and layer_idx == num_layers_top - 2: - x = Dropout(deterministic=not train)(x, rate=dropout_rate) + x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. @@ -156,7 +156,7 @@ class DlrmSmall(nn.Module): @nn.compact def __call__(self, x, train, dropout_rate=None): - if not dropout_rate: + if dropout_rate is None: dropout_rate = self.dropout_rate bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) @@ -219,7 +219,8 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.LayerNorm()(top_mlp_input) if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): - top_mlp_input = Dropout(deterministic=not train)( - top_mlp_input, rate=dropout_rate) + top_mlp_input = Dropout( + dropout_rate, deterministic=not train)( + top_mlp_input, rate=dropout_rate) logits = top_mlp_input return logits From 341bf8996de4e4897eb9559124b1123c115dd62a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 May 2025 23:42:00 +0000 Subject: [PATCH 010/123] dropout fix for criteo1tb jax --- .../criteo1tb/criteo1tb_jax/workload.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 91761e458..bad2f4390 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -82,15 +82,26 @@ def init_model_fn( model_class = models.DLRMResNet else: model_class = models.DlrmSmall - self._model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - dropout_rate=dropout_rate, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) + + if dropout_rate is None: + self._model = model_class( + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + embed_dim=self.embed_dim, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier) + else: + self._model = model_class( + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + embed_dim=self.embed_dim, + dropout_rate=dropout_rate, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier) params_rng, dropout_rng = jax.random.split(rng) init_fake_batch_size = 2 From f0c385bcec139050fe11bee01f0bc0ba0b9194d9 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 May 2025 23:49:42 +0000 Subject: [PATCH 011/123] remove aux dropout option from conformer and from init_model_fn signature for fastmri, vit and criteo --- .../criteo1tb/criteo1tb_jax/workload.py | 2 -- .../workloads/fastmri/fastmri_jax/workload.py | 3 +- .../imagenet_vit/imagenet_jax/workload.py | 4 +-- .../librispeech_jax/workload.py | 28 ++++++++++++------- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index bad2f4390..e3864643b 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -73,11 +73,9 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, tabulate: Optional[bool] = False, ) -> spec.ModelInitState: """Only dropout is used.""" - del aux_dropout_rate if self.use_resnet: model_class = models.DLRMResNet else: diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 17ce6b442..13ab5c1b8 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -22,9 +22,8 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + ) -> spec.ModelInitState: """aux_dropout_rate is unused.""" - del aux_dropout_rate fake_batch = jnp.zeros((13, 320, 320)) if dropout_rate is None: self._model = UNet( diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 5d07b5ff8..5107ed993 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -33,9 +33,7 @@ def initialized(self, key: spec.RandomState, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + dropout_rate: Optional[float] = None) -> spec.ModelInitState: if dropout_rate is None: self._model = models.ViT( num_classes=self._num_classes, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 39012a20d..4da70fc61 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -61,24 +61,32 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + ) -> spec.ModelInitState: """Conformer model init function. - Here we use dropout_rate as *_residual_dropout_rate, and aux_dropout_rate as + Here we use dropout_rate as *_residual_dropout_rate, and for input_dropout_rate. """ if self.use_gelu: activation_function_name = 'gelu' else: activation_function_name = 'swish' - model_config = models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=aux_dropout_rate, - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name) + if dropout_rate is None: + model_config = models.ConformerConfig( + attention_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name) + else: + model_config = models.ConformerConfig( + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name) + self._model = models.Conformer(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] From 7af5c941d81a7e66e3128afaf5b49c6f2730c302 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 31 May 2025 00:26:32 +0000 Subject: [PATCH 012/123] add dropout piping for conformer and deepspeech --- .../workloads/fastmri/fastmri_jax/workload.py | 2 +- .../librispeech_jax/models.py | 7 +-- .../librispeech_jax/workload.py | 2 +- .../librispeech_jax/models.py | 34 +++++++++------ .../librispeech_jax/workload.py | 43 +++++++++++-------- algoperf/workloads/ogbg/ogbg_jax/models.py | 4 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 26 ++++++----- algoperf/workloads/wmt/wmt_jax/workload.py | 24 ++++++----- 8 files changed, 83 insertions(+), 59 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 13ab5c1b8..439b8d055 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -22,7 +22,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - ) -> spec.ModelInitState: + ) -> spec.ModelInitState: """aux_dropout_rate is unused.""" fake_batch = jnp.zeros((13, 320, 320)) if dropout_rate is None: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 51e93acc9..29c349e11 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -38,13 +38,10 @@ class ConformerConfig: num_attention_heads: int = 8 num_encoder_layers: int = 4 attention_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - attention_residual_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.0. + attention_residual_dropout_rate: Optional[float] = 0.0 conv_residual_dropout_rate: Optional[float] = 0.0 feed_forward_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - feed_forward_residual_dropout_rate: Optional[float] = 0.1 + feed_forward_residual_dropout_rate: Optional[float] = 0.0 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 4da70fc61..a54f52c04 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -61,7 +61,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - ) -> spec.ModelInitState: + ) -> spec.ModelInitState: """Conformer model init function. Here we use dropout_rate as *_residual_dropout_rate, and for diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 4cdb02ee1..3ad31b532 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -75,6 +75,9 @@ class Subsample(nn.Module): @nn.compact def __call__(self, inputs, output_paddings, train, dropout_rate=None): config = self.config + if dropout_rate is None: + dropout_rate = config.dropout_rate + outputs = jnp.expand_dims(inputs, axis=-1) outputs, output_paddings = Conv2dSubsampling( @@ -111,7 +114,9 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None): input_dropout_rate = 0.1 else: input_dropout_rate = config.input_dropout_rate - outputs = Dropout(rate=input_dropout_rate, deterministic=not train)(outputs) + outputs = Dropout( + rate=input_dropout_rate, deterministic=not train, rate=dropout_rate)( + outputs, rate=dropout_rate) return outputs, output_paddings @@ -187,7 +192,13 @@ class FeedForwardModule(nn.Module): config: DeepspeechConfig @nn.compact - def __call__(self, inputs, input_paddings=None, train=False): + def __call__(self, + inputs, + input_paddings=None, + train=False, + dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.config.feed_forward_dropout_rate padding_mask = jnp.expand_dims(1 - input_paddings, -1) config = self.config @@ -211,12 +222,8 @@ def __call__(self, inputs, input_paddings=None, train=False): inputs = nn.relu(inputs) inputs *= padding_mask - if config.feed_forward_dropout_rate is None: - feed_forward_dropout_rate = 0.1 - else: - feed_forward_dropout_rate = config.feed_forward_dropout_rate - inputs = Dropout(rate=feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)( + inputs, deterministic=not train, rate=dropout_rate) return inputs @@ -472,8 +479,10 @@ def setup(self): ) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, dropout_rate=None): config = self.config + if dropout_rate is None: + dropout_rate = config.dropout_rate outputs = inputs output_paddings = input_paddings @@ -493,7 +502,7 @@ def __call__(self, inputs, input_paddings, train): # Subsample input by a factor of 4 by performing strided convolutions. outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train) + config=config)(outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the lstm layers. for _ in range(config.num_lstm_layers): @@ -507,9 +516,8 @@ def __call__(self, inputs, input_paddings, train): outputs = outputs + FeedForwardModule(config=self.config)( outputs, output_paddings, train) else: - outputs = FeedForwardModule(config=self.config)(outputs, - output_paddings, - train) + outputs = FeedForwardModule(config=self.config)( + outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the decoder which in this case is a trivial projection layer. if config.enable_decoder_layer_norm: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..3c9a96f99 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -18,24 +18,31 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Deepspeech model init function. - - Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate - as input_dropout_rate. """ - model_config = models.DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, - use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count, - ) + if dropout_rate is None: + model_config = models.DeepspeechConfig( + use_specaug=self.use_specaug, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, + ) + else: + model_config = models.DeepspeechConfig( + feed_forward_dropout_rate=dropout_rate, + use_specaug=self.use_specaug, + input_dropout_rate=dropout_rate, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, + ) self._model = models.Deepspeech(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] @@ -64,6 +71,7 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool, use_running_average_bn: Optional[bool] = None + dropout_rate: Optional[bool] = None ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] @@ -75,7 +83,8 @@ def model_fn( input_paddings, train=True, rngs={'dropout' : rng}, - mutable=['batch_stats']) + mutable=['batch_stats'], + dropout_rate=dropout_rate) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 6ced9bef5..f6cb1c490 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -48,9 +48,9 @@ class GNN(nn.Module): @nn.compact def __call__(self, graph, train, dropout_rate=None): - if not dropout_rate: + if dropout_rate is not None: dropout_rate = self.dropout_rate - dropout = Dropout(deterministic=not train, rate=dropout_rate) + dropout = Dropout(dropout_rate, deterministic=not train)(dropout_rate) graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e895d15a7..f7de3f982 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -20,18 +20,24 @@ class OgbgWorkload(BaseOgbgWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is unused.""" - del aux_dropout_rate rng, params_rng, dropout_rng = jax.random.split(rng, 3) - self._model = models.GNN( - self._num_outputs, - dropout_rate=dropout_rate, - activation_fn_name=self.activation_fn_name, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) + if dropout_rate is None: + self._model = models.GNN( + self._num_outputs, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) + else: + self._model = models.GNN( + self._num_outputs, + dropout_rate=dropout_rate, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) init_fn = jax.jit(functools.partial(self._model.init, train=False)) fake_batch = jraph.GraphsTuple( n_node=jnp.asarray([1]), diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 240ad2c11..b1f1e78a8 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -209,10 +209,7 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = 0.0, - aux_dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - + dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState: init_fake_batch_size = 2 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -224,13 +221,20 @@ def init_model_fn( else: raise ValueError(f'Unknown activation function {self.activation}.') + if dropout_rate is None: + model_config = models.TransformerConfig( + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu) + else: model_config = models.TransformerConfig( - dropout_rate=dropout_rate, - attention_dropout_rate=aux_dropout_rate, - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) + dropout_rate=dropout_rate, + attention_dropout_rate=dropout_rate, + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu) self._train_model = models.Transformer(model_config) eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) From cbd065b490e1e57212d6b0112715ee73fcdf1a67 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 31 May 2025 01:30:22 +0000 Subject: [PATCH 013/123] pipe dropout through model_fn --- .../criteo1tb/criteo1tb_jax/workload.py | 4 +- .../workloads/fastmri/fastmri_jax/workload.py | 14 ++-- .../imagenet_vit/imagenet_jax/workload.py | 6 +- .../librispeech_jax/models.py | 72 +++++++++---------- .../librispeech_jax/workload.py | 6 +- .../librispeech_jax/workload.py | 2 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 6 +- 7 files changed, 60 insertions(+), 50 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index e3864643b..101e02c15 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -126,7 +126,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm inputs = augmented_and_preprocessed_input_batch['inputs'] @@ -134,6 +135,7 @@ def model_fn( apply_kwargs = {'train': train} if train: apply_kwargs['rngs'] = {'dropout': rng} + apply_kwargs['dropout_rate'] = dropout_rate logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs) return logits_batch, None diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 439b8d055..3d891cf8f 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -60,14 +60,18 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train) + + if train: + logits = self._model.apply({'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate) return logits, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 5107ed993..89355ac6e 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -66,14 +66,16 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply({'params': params}, augmented_and_preprocessed_input_batch['inputs'], rngs={'dropout': rng}, - train=train) + train=train, + dropout_rate=dropout_rate) return logits, None def _eval_model_on_split(self, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 29c349e11..2ca0fffdc 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -37,7 +37,7 @@ class ConformerConfig: encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 - attention_dropout_rate: float = 0.0 + dropout_rate: float = 0.1 attention_residual_dropout_rate: Optional[float] = 0.0 conv_residual_dropout_rate: Optional[float] = 0.0 feed_forward_dropout_rate: float = 0.0 @@ -51,8 +51,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 batch_norm_momentum: float = 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -98,10 +96,12 @@ class Subsample(nn.Module): input_dropout_rate: dropout rate for inputs. """ encoder_dim: int = 0 - input_dropout_rate: float = 0.0 + dropout_rate: float = 0.0 @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.dropout_rate output_paddings = input_paddings outputs = jnp.expand_dims(inputs, axis=-1) @@ -128,8 +128,8 @@ def __call__(self, inputs, input_paddings, train): seq_length=outputs.shape[1]) outputs = Dropout( - rate=self.input_dropout_rate, deterministic=not train)( - outputs) + rate=dropout_rate, deterministic=not train)( + outputs, rate=dropout_rate) return outputs, output_paddings @@ -196,9 +196,10 @@ class FeedForwardModule(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False): + def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=dropout_rate): config = self.config - + if dropout_rate is None: + dropout_rate = config.dropout_rate inputs = LayerNorm(dim=config.encoder_dim)(inputs) inputs = nn.Dense( @@ -387,8 +388,11 @@ class MultiHeadedSelfAttention(nn.Module): config: ConformerConfig = None @nn.compact - def __call__(self, inputs, paddings, train): + def __call__(self, inputs, paddings, train, dropout_rate=None): config = self.config + if dropout_rate is None: + dropout_rate = config.dropout_rate + mask_paddings = 1 - paddings attention_mask = nn.make_attention_mask( mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) @@ -410,13 +414,9 @@ def __call__(self, inputs, paddings, train): deterministic=not train)( inputs_q=inputs, mask=attention_mask) - if config.attention_residual_dropout_rate is None: - attention_residual_dropout_rate = 0.1 - else: - attention_residual_dropout_rate = config.attention_residual_dropout_rate result = Dropout( - rate=attention_residual_dropout_rate, deterministic=not train)( - result) + rate=dropout_rate, deterministic=not train)( + result, rate=dropout_rate) return result @@ -526,8 +526,11 @@ def __call__(self, input_paddings, train, update_batch_norm, - use_running_average_bn): + use_running_average_bn, + dropout_rate=None): config = self.config + if dropout_rate is None: + dropout_rate = config.dropout_rate inputs = LayerNorm(dim=config.encoder_dim)(inputs) input_gated1 = nn.Dense( @@ -572,13 +575,9 @@ def __call__(self, config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())( inputs) - if config.conv_residual_dropout_rate is None: - conv_residual_dropout_rate = 0.0 - else: - conv_residual_dropout_rate = config.conv_residual_dropout_rate inputs = Dropout( - rate=conv_residual_dropout_rate, deterministic=not train)( - inputs) + rate=dropout_rate, deterministic=not train)( + inputs, rate=dropout_rate) return inputs @@ -603,26 +602,28 @@ def __call__(self, input_paddings, train, update_batch_norm, - use_running_average): + use_running_average, + dropout_rate=None): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train, dropout_rate) inputs = inputs + MultiHeadedSelfAttention(config=self.config)( - inputs, input_paddings, train) + inputs, input_paddings, train, dropout_rate=dropout_rate) inputs = inputs + \ ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, - use_running_average + use_running_average, + dropout_rate ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train, dropout_rate) if config.use_post_layer_norm: inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -656,7 +657,8 @@ def __call__(self, input_paddings, train, update_batch_norm: Optional[bool] = None, - use_running_average_bn: Optional[bool] = None): + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[float] = None): config = self.config outputs = inputs @@ -681,15 +683,10 @@ def __call__(self, if train and config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - # Subsample input by a factor of 4 by performing strided convolutions. - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate outputs, output_paddings = Subsample( encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate)( - outputs, output_paddings, train) + dropout_rate=dropout_rate)( + outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): @@ -697,7 +694,8 @@ def __call__(self, output_paddings, train, update_batch_norm, - use_running_average_bn) + use_running_average_bn, + dropout_rate) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index a54f52c04..2e082cf07 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -116,7 +116,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[float] = None, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] @@ -129,7 +130,8 @@ def model_fn( train=True, rngs={'dropout' : rng}, mutable=['batch_stats'], - use_running_average_bn=use_running_average_bn) + use_running_average_bn=use_running_average_bn, + dropout_rate=dropout_rate) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3c9a96f99..825b470db 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -70,7 +70,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + use_running_average_bn: Optional[bool] = None, dropout_rate: Optional[bool] = None ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index f7de3f982..3becd5599 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -63,7 +63,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: Optional[float]) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del update_batch_norm # No BN in the GNN model. if model_state is not None: @@ -74,7 +75,8 @@ def model_fn( logits = self._model.apply({'params': params}, augmented_and_preprocessed_input_batch['inputs'], rngs={'dropout': rng}, - train=train) + train=train, + dropout_rate=dropout_rate) return logits, None def _binary_cross_entropy_with_mask( From 31babfd9da6e2ad55156f55aeeb1c9cf10d88edc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 19:05:10 +0000 Subject: [PATCH 014/123] fix syntax --- algoperf/workloads/wmt/wmt_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index b1f1e78a8..367c062cb 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -228,7 +228,7 @@ def init_model_fn( activation=activation, glu=self.glu) else: - model_config = models.TransformerConfig( + model_config = models.TransformerConfig( dropout_rate=dropout_rate, attention_dropout_rate=dropout_rate, pre_ln=self.pre_ln, From 95d67db14e4c0d68e4868b55240c2211b8b039af Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 20:45:08 +0000 Subject: [PATCH 015/123] dropout changes wmt jax --- algoperf/workloads/wmt/wmt_jax/models.py | 67 +++++++++------------- algoperf/workloads/wmt/wmt_jax/workload.py | 6 +- 2 files changed, 30 insertions(+), 43 deletions(-) diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 54a917a09..3947a1b81 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -28,10 +28,7 @@ class TransformerConfig: max_len: int = 256 activation: Callable = nn.relu glu: bool = False - #If None, defaults to 0.1. dropout_rate: Optional[float] = 0.1 - #If None, defaults to 0.1. - attention_dropout_rate: Optional[float] = 0.1 attention_temp: float = 1.0 deterministic: bool = False decode: bool = False @@ -154,6 +151,9 @@ class MlpBlock(nn.Module): def __call__(self, inputs, dropout_rate=None): """Applies Transformer MlpBlock module.""" cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( cfg.mlp_dim, @@ -172,12 +172,7 @@ def __call__(self, inputs, dropout_rate=None): )( inputs) x = x * y - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)(x, rate=dropout_rate, deterministic=cfg.deterministic) output = nn.Dense( actual_out_dim, dtype=cfg.dtype, @@ -185,7 +180,7 @@ def __call__(self, inputs, dropout_rate=None): bias_init=cfg.bias_init, )( x) - output = Dropout()( + output = Dropout(rate=dropout_rate)( output, rate=dropout_rate, deterministic=cfg.deterministic) return output @@ -211,16 +206,14 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): output after transformer encoder block. """ cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate + pre_ln = cfg.pre_ln # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs - if dropout_rate is None: - if cfg.attention_dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate x = nn.MultiHeadDotProductAttention( num_heads=cfg.num_heads, dtype=cfg.dtype, @@ -233,7 +226,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): deterministic=cfg.deterministic, )(cfg.attention_temp * x, x, mask=encoder_mask) - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + inputs if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -275,17 +268,15 @@ def __call__( output after transformer encoder-decoder block. """ cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate + pre_ln = cfg.pre_ln # Decoder block. assert targets.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets - if dropout_rate is None: - if cfg.attention_dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate x = nn.MultiHeadDotProductAttention( num_heads=cfg.num_heads, dtype=cfg.dtype, @@ -298,11 +289,8 @@ def __call__( deterministic=cfg.deterministic, decode=cfg.decode, )(cfg.attention_temp * x, x, mask=decoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + targets if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -321,7 +309,7 @@ def __call__( deterministic=cfg.deterministic, )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate) y = y + x if not pre_ln: y = nn.LayerNorm(dtype=cfg.dtype)(y) @@ -361,6 +349,9 @@ def __call__(self, output of a transformer encoder. """ cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate + assert inputs.ndim == 2 # (batch, len) # Input Embedding @@ -377,12 +368,7 @@ def __call__(self, x = AddPositionEmbs( config=cfg, decode=False, name="posembed_input")( x, inputs_positions=inputs_positions) - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) x = x.astype(cfg.dtype) @@ -432,6 +418,8 @@ def __call__( output of a transformer decoder. """ cfg = self.config + if dropout_rate is None: + dropout_rate = cfg.dropout_rate assert encoded.ndim == 3 # (batch, len, depth) assert targets.ndim == 2 # (batch, len) @@ -453,12 +441,7 @@ def __call__( y = AddPositionEmbs( config=cfg, decode=cfg.decode, name="posembed_output")( y, inputs_positions=targets_positions) - if dropout_rate is None: - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate) y = y.astype(cfg.dtype) @@ -549,7 +532,8 @@ def decode( targets, targets_positions=None, inputs_segmentation=None, - targets_segmentation=None): + targets_segmentation=None, + dropout_rate=None): """Applies Transformer decoder-branch on encoded-input and target. Args: @@ -598,7 +582,8 @@ def decode( targets, targets_positions=targets_positions, decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=droput_rate) return logits.astype(self.config.dtype) def __call__(self, diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 367c062cb..193732640 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -259,7 +259,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: Optional[float] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm @@ -286,7 +287,8 @@ def model_fn( targets_positions=targets_positions, inputs_segmentation=inputs_segmentations, targets_segmentation=targets_segmentations, - rngs={'dropout': rng}) + rngs={'dropout': rng}, + dropout_rate=None) return logits_batch, None def _normalize_eval_metrics( From 2c96b884eba3cd8500c8d3fd1de6feb28194fbe3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 20:49:15 +0000 Subject: [PATCH 016/123] modify dockerfile --- docker/Dockerfile | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 76bc5cfe0..72e3a810f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" @@ -23,8 +23,8 @@ RUN apt-get update && apt-get install -y \ libreadline-dev \ libffi-dev \ curl \ - libbz2-dev \ liblzma-dev \ + libbz2-dev \ vim # Download and install Python 3.11 @@ -56,8 +56,6 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ -RUN pip install --upgrade pip - # Install Algorithmic efficiency repo RUN pip install --upgrade pip @@ -71,18 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_cpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_cpu]'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ From 54786a6594bf70388e8791aab2b35c78b7cbf028 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 4 Jun 2025 20:49:52 +0000 Subject: [PATCH 017/123] modify docker build script --- docker/build_docker_images.sh | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 645b81955..6b5e67ceb 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -1,27 +1,40 @@ #!/bin/bash # Bash script to build and push dev docker images to artifact repo # Usage: -# bash build_docker_images.sh -b +# bash build_docker_images.sh -b -f # Make program exit with non-zero exit code if any command fails. set -e -while getopts b: flag +while getopts "b:p:f:" flag; do case "${flag}" in b) GIT_BRANCH=${OPTARG};; + p) PROJECT=${OPTARG};; + f) FRAMEWORK=${OPTARG};; esac done # Artifact repostiory -ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +if [ "$PROJECT" = "mlcommons-algoperf" ]; then + ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +else + ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo" +fi -if [[ -z ${GIT_BRANCH+x} ]] +if [[ -z ${GIT_BRANCH+x} ]]; then GIT_BRANCH='main' # Set default argument fi -for FRAMEWORK in "jax" "pytorch" "both" +FRAMEWORKS=( "jax" "pythorch" "both" ) + +if [[ -n "$FRAMEWORK" ]]; +then + FRAMEWORKS=("$FRAMEWORK") +fi + +for FRAMEWORK in "${FRAMEWORKS[@]}"; do IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" DOCKER_BUILD_COMMAND="docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH" From 9c189adc9b11948ab8e605f75da17187ccf35f3a Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 5 Jun 2025 16:07:06 -0400 Subject: [PATCH 018/123] Update metrics.py - fix for ogbg pytorch It seems that the problem affecting the pytorch ogbg workloads (but only if they run for some length of time) has to do with jax/xla cpu compilation of the metrics computation. By converting the jax arrays to numpy, hopefully this can be avoided. The next step is to test on schedule free and shampoo, which I hope to do very soon. --- algoperf/workloads/ogbg/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index 55f83d905..c2db383b5 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -40,7 +40,7 @@ def compute(self): if USE_PYTORCH_DDP: # Sync labels, logits, and masks across devices. - all_values = [labels, logits, mask] + all_values = [np.array(labels), np.array(logits), np.array(mask)] for idx, array in enumerate(all_values): tensor = torch.as_tensor(array, device=DEVICE) # Assumes that the tensors on all devices have the same shape. @@ -51,7 +51,7 @@ def compute(self): mask = mask.astype(bool) - probs = jax.nn.sigmoid(logits) + probs = 1 / (1 + np.exp(-logits)) num_tasks = labels.shape[1] average_precisions = np.full(num_tasks, np.nan) From 246d68ee96d05d9b4b463dd08ae36e1715f6b3bb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 23:38:41 +0000 Subject: [PATCH 019/123] fsmall fixes --- algoperf/workloads/fastmri/fastmri_jax/models.py | 2 +- algoperf/workloads/imagenet_vit/imagenet_jax/models.py | 2 +- .../workloads/librispeech_conformer/librispeech_jax/models.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 7ecca2add..b04510297 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -139,9 +139,9 @@ class ConvBlock(nn.Module): dropout_rate: Dropout probability. """ out_channels: int - dropout_rate: float = 0.0 use_tanh: bool use_layer_norm: bool + dropout_rate: float = 0.0 @nn.compact def __call__(self, x, train=True, dropout_rate=None): diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 227f7c297..8ffc0b610 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -162,7 +162,7 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 - dropout_rate: 0.0 + dropout_rate: float = 0.0 @nn.compact def __call__(self, x, dropout_rate=None): diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 2ca0fffdc..2d0da15e5 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -196,7 +196,7 @@ class FeedForwardModule(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=dropout_rate): + def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=None): config = self.config if dropout_rate is None: dropout_rate = config.dropout_rate From 0c8dd14d617ff1a642915f34ccd1e504d5a8c0a1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 6 Jun 2025 00:39:59 +0000 Subject: [PATCH 020/123] change docker base image to 12.1.1 --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 72e3a810f..f1fc99550 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu20.04 +FROM nvidia/cuda:12.1.1-cudnn-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" From a78fa6642aed752010e76e953d11ee4c54bafddd Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 6 Jun 2025 00:49:04 +0000 Subject: [PATCH 021/123] update base image --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index f1fc99550..9926b0542 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:12.1.1-cudnn-devel-ubuntu20.04 +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" From 4beef49bad7570c40ad1563e9ffc5ab83ba8df90 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Jun 2025 04:04:21 +0000 Subject: [PATCH 022/123] add slurm instructions --- scoring/utils/slurm/README.md | 16 +++ .../utils/slurm/algoperf_slurm_cluster.yaml | 105 ++++++++++++++ .../slurm/algoperf_slurm_pakcer_builder.yaml | 132 ++++++++++++++++++ scoring/utils/slurm/config.json | 106 ++++++++++++++ 4 files changed, 359 insertions(+) create mode 100644 scoring/utils/slurm/algoperf_slurm_cluster.yaml create mode 100644 scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml create mode 100644 scoring/utils/slurm/config.json diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md index ffd56fbf3..29df653e4 100644 --- a/scoring/utils/slurm/README.md +++ b/scoring/utils/slurm/README.md @@ -1,3 +1,4 @@ +# Launching SLURM jobs with SBATCH This folder contains a SLURM batch script that can be used to run jobs where each job corresponds to a training run on a given workload, training algorithm, random seed and tuning trial (if on external tuning ruleset). To launch jobs: @@ -24,3 +25,18 @@ python3 make_job_config.py \ ``` sbatch run_jobs.sh ``` + + +# Set up new SLURM cluster +If you are setting up a new cluster, we recommend using the [HPC toolkit to set up a SLURM cluster](https://cloud.google.com/cluster-toolkit/docs/quickstarts/slurm-cluster). +To set up the new cluster: + +1) [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart). +2) Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml). +3) Manually update the image: + 1) Create a VM from the Disk image created in the previous step. + 2) Install the NVIDIA container toolkit on the VM. + 3) Transfer the data from GCP bucket to `/opt/data`. + 4) Create a new disk image from the VM. +4) Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml). + diff --git a/scoring/utils/slurm/algoperf_slurm_cluster.yaml b/scoring/utils/slurm/algoperf_slurm_cluster.yaml new file mode 100644 index 000000000..e6c35e017 --- /dev/null +++ b/scoring/utils/slurm/algoperf_slurm_cluster.yaml @@ -0,0 +1,105 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +blueprint_name: algoperf-slurm-internal + +vars: + project_id: training-algorithms-external + deployment_name: algoperf-slurm-internal + region: europe-west4 + zone: europe-west4-a + disk_size_gb: 3000 + slurm_cluster_name: algoperf + image_name: algoperf-image-data-container-tkt + +# Recommended to use GCS backend for Terraform state +# See https://github.com/GoogleCloudPlatform/hpc-toolkit/tree/main/examples#optional-setting-up-a-remote-terraform-state +# +# terraform_backend_defaults: +# type: gcs +# configuration: +# bucket: <> + +deployment_groups: +- group: primary + modules: + - id: network + source: modules/network/vpc + + - id: homefs + source: community/modules/file-system/nfs-server + use: [network] + settings: + local_mounts: [/home] + disk_size: 3000 + zone: $(vars.zone) + + - id: script + source: modules/scripts/startup-script + settings: + +- group: cluster + modules: + - id: v100_nodeset + source: community/modules/compute/schedmd-slurm-gcp-v6-nodeset + use: + - network + settings: + node_count_dynamic_max: 25 # set to 0 if you want node to live forever + region: $(vars.region) + zone: $(vars.zone) + enable_placement: false + bandwidth_tier: gvnic_enabled + machine_type: n1-standard-64 + guest_accelerator: + - type: nvidia-tesla-v100 + count: 8 + instance_image: + project: $(vars.project_id) + name: $(vars.image_name) + instance_image_custom: true + + - id: v100_partition + source: community/modules/compute/schedmd-slurm-gcp-v6-partition + use: [v100_nodeset] + settings: + exclusive: false + partition_name: v100 + is_default: true + + - id: slurm_login + source: community/modules/scheduler/schedmd-slurm-gcp-v6-login + use: [network] + settings: + enable_login_public_ips: true + instance_image: + project: $(vars.project_id) + name: $(vars.image_name) + instance_image_custom: true + zone: $(vars.zone) + + - id: slurm_controller + source: community/modules/scheduler/schedmd-slurm-gcp-v6-controller + use: + - network + - v100_partition + - homefs + - slurm_login + settings: + enable_controller_public_ips: true + instance_image: + project: $(vars.project_id) + name: $(vars.image_name) + instance_image_custom: true + region: $(vars.region) diff --git a/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml new file mode 100644 index 000000000..286728e1d --- /dev/null +++ b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml @@ -0,0 +1,132 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--- + +blueprint_name: algoperf-slurm-packer + +vars: + project_id: training-algorithms-external + deployment_name: algoperf-slurm-packer + region: europe-west4 + zone: europe-west4-a + new_image: + family: algoperf-image + project: $(vars.project_id) + disk_size_gb: 3000 + slurm_cluster_name: algoperf-packer + +# Recommended to use GCS backend for Terraform state +# See https://github.com/GoogleCloudPlatform/hpc-toolkit/tree/main/examples#optional-setting-up-a-remote-terraform-state +# +# terraform_backend_defaults: +# type: gcs +# configuration: +# bucket: <> + +deployment_groups: +- group: primary + modules: + - id: network + source: modules/network/vpc + + - id: script + source: modules/scripts/startup-script + settings: + region: $(vars.region) + install_ansible: true + docker: + enabled: true + world_writable: true + # (TODO) Do I need this? + configure_ssh_host_patterns: + - 10.0.0.* + - 10.1.0.* + - 10.2.0.* + - 10.3.0.* + - 10.4.0.* + - 10.5.0.* + - 10.6.0.* + - 10.7.0.* + - $(vars.slurm_cluster_name)* + runners: + - type: shell + destination: install-ml-libraries.sh + content: | + #!/bin/bash + # this script is designed to execute on Slurm images published by SchedMD that: + # - are based on Debian distribution of Linux + # - have NVIDIA drivers pre-installed + + set -e -o pipefail + + echo "deb https://packages.cloud.google.com/apt google-fast-socket main" > /etc/apt/sources.list.d/google-fast-socket.list + apt-get update --allow-releaseinfo-change + apt-get install --assume-yes google-fast-socket + + CONDA_BASE=/opt/conda + + if [ -d $CONDA_BASE ]; then + exit 0 + fi + + DL_DIR=\$(mktemp -d) + cd $DL_DIR + curl -L -O https://github.com/conda-forge/miniforge/releases/download/24.7.1-2/Miniforge3-24.7.1-2-Linux-x86_64.sh + HOME=$DL_DIR bash Miniforge3-24.7.1-2-Linux-x86_64.sh -b -p $CONDA_BASE + cd - + rm -rf $DL_DIR + unset DL_DIR + + source $CONDA_BASE/bin/activate base + conda init --system + conda config --system --set auto_activate_base False + # following channel ordering is important! use strict_priority! + conda config --system --set channel_priority strict + conda update -n base conda --yes + + ### create a virtual environment for tensorflow + conda create -n tf python=3.11 --yes + conda activate tf + pip install tensorflow[and-cuda]==2.18.* + pip install tensorrt==10.6.* + + ### create a virtual environment for pytorch + conda create -n pytorch python=3.11 --yes + conda activate pytorch + pip install torch torchvision torchaudio + +- group: packer + modules: + - id: custom-image + source: modules/packer/custom-image + kind: packer + use: + - network + - script + settings: + # give VM a public IP to ensure startup script can reach public internet + # w/o new VPC + omit_external_ip: false + source_image_project_id: [schedmd-slurm-public] + # see latest in https://github.com/GoogleCloudPlatform/slurm-gcp/blob/master/docs/images.md#published-image-family + source_image_family: slurm-gcp-6-8-debian-11 + # You can find size of source image by using following command + # gcloud compute images describe-from-family --project schedmd-slurm-public + disk_size: $(vars.disk_size_gb) + image_family: $(vars.new_image.family) + # building this image does not require a GPU-enabled VM + machine_type: c2-standard-16 + state_timeout: 300m + zone: $(vars.zone) diff --git a/scoring/utils/slurm/config.json b/scoring/utils/slurm/config.json new file mode 100644 index 000000000..dc19e57f7 --- /dev/null +++ b/scoring/utils/slurm/config.json @@ -0,0 +1,106 @@ +{ + "0": { + "framework": "jax", + "workload": "imagenet_resnet", + "dataset": "imagenet", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 411096763, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "1": { + "framework": "jax", + "workload": "imagenet_vit", + "dataset": "imagenet", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -1884713130, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "2": { + "framework": "jax", + "workload": "fastmri", + "dataset": "fastmri", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -214785144, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "3": { + "framework": "jax", + "workload": "ogbg", + "dataset": "ogbg", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -893097833, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "4": { + "framework": "jax", + "workload": "wmt", + "dataset": "wmt", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -1244182279, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "5": { + "framework": "jax", + "workload": "librispeech_deepspeech", + "dataset": "librispeech", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 1546003634, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "6": { + "framework": "jax", + "workload": "criteo1tb", + "dataset": "criteo1tb", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -2062333143, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "7": { + "framework": "jax", + "workload": "librispeech_conformer", + "dataset": "librispeech", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 409209730, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + } +} \ No newline at end of file From 2b782dda88beef216629832e2511fb636ad9f974 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Jun 2025 04:06:24 +0000 Subject: [PATCH 023/123] formatting docs --- scoring/utils/slurm/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md index 29df653e4..3e5d85b73 100644 --- a/scoring/utils/slurm/README.md +++ b/scoring/utils/slurm/README.md @@ -34,9 +34,9 @@ To set up the new cluster: 1) [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart). 2) Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml). 3) Manually update the image: - 1) Create a VM from the Disk image created in the previous step. - 2) Install the NVIDIA container toolkit on the VM. - 3) Transfer the data from GCP bucket to `/opt/data`. - 4) Create a new disk image from the VM. + 1) Create a VM from the Disk image created in the previous step. + 2) Install the NVIDIA container toolkit on the VM. + 3) Transfer the data from GCP bucket to `/opt/data`. + 4) Create a new disk image from the VM. 4) Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml). From dfe4fb47055cfc019dc11b22d8ffa86720d39af7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Jun 2025 04:09:41 +0000 Subject: [PATCH 024/123] update instrucitons --- scoring/utils/slurm/README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md index 3e5d85b73..ed42752dd 100644 --- a/scoring/utils/slurm/README.md +++ b/scoring/utils/slurm/README.md @@ -11,7 +11,7 @@ python3 make_job_config.py \ --framework ``` 2) Save the config.json in the same directory you will run the sbatch script from. -3) Check the sbatch script `run_jobs.sh`. +3) Copy the example sbatch script `run_jobs.sh`. - Set the task range to the number of tasks in the config. ``` #SBATCH --array=0-119 @@ -21,6 +21,16 @@ python3 make_job_config.py \ #SBATCH --output=experiments///job_%A_%a.out #SBATCH --error=experiments///job_%A_%a.err ``` +- Update the gcp project information, docker image, config file path and bucket to save the logs to as necessary: +``` +REPO="us-central1-docker.pkg.dev" +IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_main" +y | gcloud auth configure-docker $REPO +docker pull $IMAGE +# Job config (ATTENTION: you may want to modify this) +config_file="$HOME/configs/pmap_job_config.json" # Replace with your config file path +LOGS_BUCKET="algoperf-runs-internal" +``` 4) Submit a SLURM batch job by running: ``` sbatch run_jobs.sh From f0019ac4ad04fb8bfe2c6430475a7043dec77ff3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Jun 2025 04:23:18 +0000 Subject: [PATCH 025/123] small fix --- .../workloads/librispeech_deepspeech/librispeech_jax/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 3ad31b532..455366e5e 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -115,7 +115,7 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None): else: input_dropout_rate = config.input_dropout_rate outputs = Dropout( - rate=input_dropout_rate, deterministic=not train, rate=dropout_rate)( + rate=input_dropout_rate, deterministic=not train)( outputs, rate=dropout_rate) return outputs, output_paddings From 3cb012e919374e204f123145ebbe596bf72b4eac Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Jun 2025 04:34:36 +0000 Subject: [PATCH 026/123] remove aux_dropout from submission_runner.py --- submission_runner.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 468a04c7c..bb4a8c6cc 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -229,13 +229,10 @@ def train_once( logging.info('Initializing model.') with profiler.profile('Initializing model'): dropout_rate = None - aux_dropout_rate = None if hasattr(hyperparameters, 'dropout_rate'): dropout_rate = hyperparameters.dropout_rate - if hasattr(hyperparameters, 'aux_dropout_rate'): - aux_dropout_rate = hyperparameters.aux_dropout_rate model_params, model_state = workload.init_model_fn( - model_init_rng, dropout_rate, aux_dropout_rate) + model_init_rng, dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ 'librispeech_conformer', From fdc956bf62b80f84e634d6717ab9bc12aea33a9e Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Mon, 9 Jun 2025 10:27:45 -0400 Subject: [PATCH 027/123] Update metrics.py The problem with torchrun and jax seems to be caused by jax.nn.sigmoid. --- algoperf/workloads/ogbg/metrics.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index c2db383b5..982e1044e 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -37,10 +37,11 @@ def compute(self): labels = values['labels'] logits = values['logits'] mask = values['mask'] + sigmoid = jax.nn.sigmoid if USE_PYTORCH_DDP: # Sync labels, logits, and masks across devices. - all_values = [np.array(labels), np.array(logits), np.array(mask)] + all_values = [labels, logits, mask] for idx, array in enumerate(all_values): tensor = torch.as_tensor(array, device=DEVICE) # Assumes that the tensors on all devices have the same shape. @@ -48,10 +49,11 @@ def compute(self): dist.all_gather(all_tensors, tensor) all_values[idx] = torch.cat(all_tensors).cpu().numpy() labels, logits, mask = all_values + sigmoid = lambda x: 1 / (1 + np.exp(-x)) mask = mask.astype(bool) - probs = 1 / (1 + np.exp(-logits)) + probs = sigmoid(logits) num_tasks = labels.shape[1] average_precisions = np.full(num_tasks, np.nan) From 6c888df9a365be98332575bae74295b23501a7ea Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Mon, 9 Jun 2025 10:43:29 -0400 Subject: [PATCH 028/123] Update metrics.py Changed from lambda expression which pylint doesn't like. --- algoperf/workloads/ogbg/metrics.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index 982e1044e..5e1e7c4ef 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -31,6 +31,9 @@ class MeanAveragePrecision( metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))): """Computes the mean average precision (mAP) over different tasks.""" + def sigmoid_np(x): + return 1 / (1 + np.exp(-x)) + def compute(self): # Matches the official OGB evaluation scheme for mean average precision. values = super().compute() @@ -49,7 +52,7 @@ def compute(self): dist.all_gather(all_tensors, tensor) all_values[idx] = torch.cat(all_tensors).cpu().numpy() labels, logits, mask = all_values - sigmoid = lambda x: 1 / (1 + np.exp(-x)) + sigmoid = sigmoid_np mask = mask.astype(bool) From e4a55ab1db0a114a3a713c327ab334effcba9d53 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Mon, 9 Jun 2025 17:19:11 -0400 Subject: [PATCH 029/123] Update metrics.py Defined np sigmoid inside use_pytorch_ddp --- algoperf/workloads/ogbg/metrics.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index 5e1e7c4ef..8f342c25d 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -31,9 +31,6 @@ class MeanAveragePrecision( metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))): """Computes the mean average precision (mAP) over different tasks.""" - def sigmoid_np(x): - return 1 / (1 + np.exp(-x)) - def compute(self): # Matches the official OGB evaluation scheme for mean average precision. values = super().compute() @@ -52,6 +49,8 @@ def compute(self): dist.all_gather(all_tensors, tensor) all_values[idx] = torch.cat(all_tensors).cpu().numpy() labels, logits, mask = all_values + def sigmoid_np(x): + return 1 / (1 + np.exp(-x)) sigmoid = sigmoid_np mask = mask.astype(bool) From 07f89a2b69d6f7667c83961e0fea4ae228681355 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Mon, 9 Jun 2025 17:29:58 -0400 Subject: [PATCH 030/123] Update metrics.py Added white space before and after sigmoid_np --- algoperf/workloads/ogbg/metrics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index 8f342c25d..19d43aae4 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -49,8 +49,10 @@ def compute(self): dist.all_gather(all_tensors, tensor) all_values[idx] = torch.cat(all_tensors).cpu().numpy() labels, logits, mask = all_values + def sigmoid_np(x): return 1 / (1 + np.exp(-x)) + sigmoid = sigmoid_np mask = mask.astype(bool) From 3e436c771f270be867a6ec973eb6c022a5c6ae69 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Mon, 9 Jun 2025 18:19:52 -0400 Subject: [PATCH 031/123] Update metrics.py Fix white space --- algoperf/workloads/ogbg/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index 19d43aae4..ea6041a6c 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -49,10 +49,10 @@ def compute(self): dist.all_gather(all_tensors, tensor) all_values[idx] = torch.cat(all_tensors).cpu().numpy() labels, logits, mask = all_values - + def sigmoid_np(x): return 1 / (1 + np.exp(-x)) - + sigmoid = sigmoid_np mask = mask.astype(bool) From b3060762eda8a47c6409bf5804ccca8d53686e0c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 10 Jun 2025 18:42:28 +0200 Subject: [PATCH 032/123] dropout fix criteo, fastmri, vit, conf --- .../criteo1tb_pytorch/models_dropout.py | 298 ++++++++++ .../models_functional_dropout.py | 308 +++++++++++ algoperf/workloads/dropout_modules.py | 41 ++ .../fastmri/fastmri_pytorch/models_dropout.py | 167 ++++++ .../imagenet_pytorch/models_dropout.py | 395 +++++++++++++ .../librispeech_pytorch/models_dropout.py | 518 ++++++++++++++++++ 6 files changed, 1727 insertions(+) create mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py create mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py create mode 100644 algoperf/workloads/dropout_modules.py create mode 100644 algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py create mode 100644 algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py create mode 100644 algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py new file mode 100644 index 000000000..8042ec31e --- /dev/null +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -0,0 +1,298 @@ +"""Pytorch implementation of DLRM-Small.""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout + + +class DenseBlock(nn.Module): + """Dense block with optional residual connection.""" "" + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + + def forward(self, x): + return self.module(x) + x if self.resnet else self.module(x) + + +class DenseBlockWithDropout(nn.Module): + """Dense block with optional residual connection and support for dropout.""" + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + self._supports_custom_dropout = True + + def forward(self, x, p=None): + return self.module(x, p) + x if self.resnet else self.module(x, p) + + +class DotInteract(nn.Module): + """Performs feature interaction operation between dense or sparse features.""" + + def __init__(self, num_sparse_features): + super().__init__() + self.triu_indices = torch.triu_indices(num_sparse_features + 1, + num_sparse_features + 1) + + def forward(self, dense_features, sparse_features): + combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), + dim=1) + interactions = torch.bmm(combined_values, + torch.transpose(combined_values, 1, 2)) + interactions_flat = interactions[:, + self.triu_indices[0], + self.triu_indices[1]] + return torch.cat((dense_features, interactions_flat), dim=1) + + +class DLRMResNet(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(256, 256, 256), + mlp_top_dims=(256, 256, 256, 256, 1), + embed_dim=128, + dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): + super().__init__() + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. + num_chunks = 4 + assert vocab_size % num_chunks == 0 + self.embedding_table_chucks = [] + scale = 1.0 / torch.sqrt(self.vocab_size) + for i in range(num_chunks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + input_dim = self.num_dense_features + bot_mlp_blocks = [] + for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): + block = [] + block.append(nn.Linear(input_dim, dense_dim)) + block.append(nn.ReLU(inplace=True)) + block = nn.Sequential(*block) + if layer_idx > 0: + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) + bot_mlp_blocks.append(block) + input_dim = dense_dim + self.bot_mlp = nn.Sequential(*bot_mlp_blocks) + + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + # Number of sparse features = 26 + fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] + num_layers_top = len(self.mlp_top_dims) + mlp_top_blocks = [] + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + block = [] + block.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + block.append(nn.ReLU(inplace=True)) + if (dropout_rate is not None and dropout_rate > 0.0 and + layer_idx == num_layers_top - 2): + block.append(CustomDropout()) # (nico) + block = SequentialWithDropout(*block) # (nico) + if (layer_idx != 0) and (layer_idx != num_layers_top - 1): + block = DenseBlockWithDropout(block, resnet=True) + else: + block = DenseBlockWithDropout(block) + mlp_top_blocks.append(block) + fan_in = fan_out + self.top_mlp = SequentialWithDropout(*mlp_top_blocks) # (nico) + + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + def forward(self, x, dropout_rate): + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. + embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. + sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, 26 * self.embed_dim]) + top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) + + # Final MLP. + logits = self.top_mlp(top_mlp_input, dropout_rate) + return logits + + +class DlrmSmall(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(512, 256, 128), + mlp_top_dims=(1024, 1024, 512, 256, 1), + embed_dim=128, + dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): + super().__init__() + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + self.embedding_init_multiplier = embedding_init_multiplier + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. + num_chunks = 4 + assert vocab_size % num_chunks == 0 + self.embedding_table_chucks = [] + + if self.embedding_init_multiplier is None: + scale = 1.0 / torch.sqrt(self.vocab_size) + else: + scale = self.embedding_init_multiplier + + for i in range(num_chunks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + input_dim = self.num_dense_features + bottom_mlp_layers = [] + for dense_dim in self.mlp_bottom_dims: + bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) + bottom_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) + input_dim = dense_dim + self.bot_mlp = nn.Sequential(*bottom_mlp_layers) + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) + + # TODO: Write down the formula here instead of the constant. + input_dims = 506 + num_layers_top = len(self.mlp_top_dims) + top_mlp_layers = [] + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + fan_in = input_dims if layer_idx == 0 \ + else self.mlp_top_dims[layer_idx - 1] + top_mlp_layers.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + top_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) + if (dropout_rate is not None and dropout_rate > 0.0 and + layer_idx == num_layers_top - 2): + top_mlp_layers.append(CustomDropout()) # (nico) + self.top_mlp = SequentialWithDropout(*top_mlp_layers) # (nico) + if use_layer_norm: + self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) + else: + self.embed_ln = None + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + def forward(self, x, dropout_rate): + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. + embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. + sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, -1, self.embed_dim]) + if self.embed_ln: + embedded_sparse = self.embed_ln(embedded_sparse) + # Dot product interactions. + concatenated_dense = self.dot_interact( + dense_features=embedded_dense, sparse_features=embedded_sparse) + + # Final MLP. + logits = self.top_mlp(concatenated_dense, dropout_rate) + return logits diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py new file mode 100644 index 000000000..346e0e72a --- /dev/null +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py @@ -0,0 +1,308 @@ +"""Pytorch implementation of DLRM-Small.""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn + + +class DenseBlock(nn.Module): + """Dense block with optional residual connection.""" "" + + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + + def forward(self, x): + if self.resnet: + return self.module(x) + x + else: + return self.module(x) + + +class DotInteract(nn.Module): + """Performs feature interaction operation between dense or sparse features.""" + + def __init__(self, num_sparse_features): + super().__init__() + self.triu_indices = torch.triu_indices(num_sparse_features + 1, + num_sparse_features + 1) + + def forward(self, dense_features, sparse_features): + combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), + dim=1) + interactions = torch.bmm(combined_values, + torch.transpose(combined_values, 1, 2)) + interactions_flat = interactions[:, + self.triu_indices[0], + self.triu_indices[1]] + return torch.cat((dense_features, interactions_flat), dim=1) + + +class DLRMResNet(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(256, 256, 256), + mlp_top_dims=(256, 256, 256, 256, 1), + embed_dim=128, + # dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): + super().__init__() + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. + num_chunks = 4 + assert vocab_size % num_chunks == 0 + self.embedding_table_chucks = [] + scale = 1.0 / torch.sqrt(self.vocab_size) + for i in range(num_chunks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + input_dim = self.num_dense_features + bot_mlp_blocks = [] + for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): + block = [] + block.append(nn.Linear(input_dim, dense_dim)) + block.append(nn.ReLU(inplace=True)) + block = nn.Sequential(*block) + if layer_idx > 0: + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) + bot_mlp_blocks.append(block) + input_dim = dense_dim + self.bot_mlp = nn.Sequential(*bot_mlp_blocks) + + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + # Number of sparse features = 26 + fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] + num_layers_top = len(self.mlp_top_dims) + mlp_top_blocks = [] + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + block = [] + block.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + block.append(nn.ReLU(inplace=True)) + # if (dropout_rate is not None and dropout_rate > 0.0 and + # layer_idx == num_layers_top - 2): + # block.append(nn.Dropout(p=dropout_rate)) + block = nn.Sequential(*block) + if (layer_idx != 0) and (layer_idx != num_layers_top - 1): + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) + mlp_top_blocks.append(block) + fan_in = fan_out + self.top_mlp = nn.Sequential(*mlp_top_blocks) + + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + def forward(self, x, dropout_rate): + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. + embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. + sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, 26 * self.embed_dim]) + top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) + + # Final MLP (horrible!!). + h = top_mlp_input + num_layers_top = len(self.mlp_top_dims) + for layer_idx, block in enumerate(self.top_mlp): + # block.module is nn.Sequential([...]) + seq = block.module + # 1) linear + out = seq[0](h) + # 2) ReLU (if present) + if layer_idx < (num_layers_top - 1): + out = seq[1](out) + # 3) functional dropout at penult layer + if dropout_rate > 0 and layer_idx == num_layers_top - 2: + out = F.dropout(out, dropout_rate, training=self.training) + # 4) wrap in residual if needed + h = out + h if block.resnet else out + return h + + +class DlrmSmall(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(512, 256, 128), + mlp_top_dims=(1024, 1024, 512, 256, 1), + embed_dim=128, + # dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): + super().__init__() + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + self.embedding_init_multiplier = embedding_init_multiplier + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. + num_chunks = 4 + assert vocab_size % num_chunks == 0 + self.embedding_table_chucks = [] + + if self.embedding_init_multiplier is None: + scale = 1.0 / torch.sqrt(self.vocab_size) + else: + scale = self.embedding_init_multiplier + + for i in range(num_chunks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + input_dim = self.num_dense_features + bottom_mlp_layers = [] + for dense_dim in self.mlp_bottom_dims: + bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) + bottom_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) + input_dim = dense_dim + self.bot_mlp = nn.Sequential(*bottom_mlp_layers) + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) + + # TODO: Write down the formula here instead of the constant. + input_dims = 506 + num_layers_top = len(self.mlp_top_dims) + top_mlp_layers = [] + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + fan_in = input_dims if layer_idx == 0 \ + else self.mlp_top_dims[layer_idx - 1] + top_mlp_layers.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + top_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) + # if (dropout_rate is not None and dropout_rate > 0.0 and + # layer_idx == num_layers_top - 2): + # top_mlp_layers.append(nn.Dropout(p=dropout_rate)) + self.top_mlp = nn.Sequential(*top_mlp_layers) + if use_layer_norm: + self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) + else: + self.embed_ln = None + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + def forward(self, x, dropout_rate): + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. + embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. + sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, -1, self.embed_dim]) + if self.embed_ln: + embedded_sparse = self.embed_ln(embedded_sparse) + # Dot product interactions. + concatenated_dense = self.dot_interact( + dense_features=embedded_dense, sparse_features=embedded_sparse) + + # Final MLP: run each layer, and after the penultimate layer do functional dropout + h = concatenated_dense + N = len(self.top_mlp) + for idx, layer in enumerate(self.top_mlp): + h = layer(h) + # insert dropout exactly where nn.Dropout used to live + if dropout_rate > 0 and idx == N - 2: + h = F.dropout(h, dropout_rate, training=self.training) + return h diff --git a/algoperf/workloads/dropout_modules.py b/algoperf/workloads/dropout_modules.py new file mode 100644 index 000000000..3917b75bf --- /dev/null +++ b/algoperf/workloads/dropout_modules.py @@ -0,0 +1,41 @@ +"""Custom classes to support a dynamic modulized dropout, see issue??TODO""" + +from torch import Tensor +from torch import nn +import torch.nn.functional as F + + +class CustomDropout(nn.Module): + """A module around torch.nn.functional.dropout.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout(input, p, training=self.training) + + +class CustomDropout2d(nn.Module): + """A module around torch.nn.functional.dropout2d.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout2d(input, p, training=self.training) + + +class SequentialWithDropout(nn.Sequential): + """Sequential of modules with dropout.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._supports_custom_dropout = True + + def forward(self, x, p): + for module in self: + # if isinstance(module, (CustomDropout, SequentialWithDropout, DenseBlockWithDropout)): + if getattr(module, '_supports_custom_dropout', False): # TODO (nico): improve + x = module(x, p) + else: + x = module(x) + return x diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py new file mode 100644 index 000000000..5862f6352 --- /dev/null +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -0,0 +1,167 @@ +"""U-Net Model. + +Adapted from fastMRI: +https://github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py +""" + +from functools import partial +from typing import Optional + +import torch +from torch import nn +from torch import Tensor +from torch.nn import functional as F + +from algoperf import init_utils +from algoperf.workloads.dropout_modules import CustomDropout2d, SequentialWithDropout + + + +class UNet(nn.Module): + r"""U-Net model from + `"U-net: Convolutional networks + for biomedical image segmentation" + `_. + """ + + def __init__(self, + in_chans: int = 1, + out_chans: int = 1, + num_channels: int = 32, + num_pool_layers: int = 4, + use_tanh: bool = False, + use_layer_norm: bool = False) -> None: + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.num_channels = num_channels + self.num_pool_layers = num_pool_layers + self.down_sample_layers = nn.ModuleList([ + ConvBlock(in_chans, + num_channels, + use_tanh, + use_layer_norm) + ]) + ch = num_channels + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append( + ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append( + TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + self.up_conv.append( + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + ch //= 2 + + self.up_transpose_conv.append( + TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + self.up_conv.append( + SequentialWithDropout( + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + )) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + init_utils.pytorch_default_init(m) + + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + stack = [] + output = x + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output, dropout_rate) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output = self.conv(output, dropout_rate) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/bottom if needed to handle + # odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output, dropout_rate) + + return output + + +class ConvBlock(nn.Module): + # A Convolutional Block that consists of two convolution layers each + # followed by instance normalization, LeakyReLU activation and dropout_rate. + + def __init__(self, + in_chans: int, + out_chans: int, + use_tanh: bool, + use_layer_norm: bool) -> None: + super().__init__() + self._supports_custom_dropout = True + + if use_layer_norm: + norm_layer = partial(nn.GroupNorm, 1, eps=1e-6) + else: + norm_layer = nn.InstanceNorm2d + if use_tanh: + activation_fn = nn.Tanh() + else: + activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.conv_layers = SequentialWithDropout( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + norm_layer(out_chans), + activation_fn, + CustomDropout2d(), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + norm_layer(out_chans), + activation_fn, + CustomDropout2d(), + ) + + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + return self.conv_layers(x, dropout_rate) + + +class TransposeConvBlock(nn.Module): + # A Transpose Convolutional Block that consists of one convolution transpose + # layers followed by instance normalization and LeakyReLU activation. + + def __init__( + self, + in_chans: int, + out_chans: int, + use_tanh: bool, + use_layer_norm: bool, + ): + super().__init__() + if use_tanh: + activation_fn = nn.Tanh() + else: + activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False), + nn.InstanceNorm2d(out_chans), + activation_fn, + ) + + def forward(self, x: Tensor) -> Tensor: + return self.layers(x) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py new file mode 100644 index 000000000..f5e315fd7 --- /dev/null +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py @@ -0,0 +1,395 @@ +"""PyTorch implementation of refactored and simplified ViT. + +Adapted from: +https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit +and https://github.com/lucidrains/vit-pytorch. +""" + +import math +from typing import Any, Optional, Tuple, Union + +import torch +from torch import nn +import torch.nn.functional as F + +from algoperf import init_utils +from algoperf import spec +from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention + + +def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: + """Follows the MoCo v3 logic.""" + _, width, h, w = patches.shape + device = patches.device + y, x = torch.meshgrid(torch.arange(h, device=device), + torch.arange(w, device=device), indexing='ij') + + if width % 4 != 0: + raise ValueError('Width must be mult of 4 for sincos posemb.') + omega = torch.arange(width // 4, device=device) / (width // 4 - 1) + omega = 1. / (temperature**omega) + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return pe[None, :, :] + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block.""" + + def __init__( + self, + width: int, + mlp_dim: Optional[int] = None, # Defaults to 4x input dim. + use_glu: bool = False, + dropout_rate: float = 0.0) -> None: + super().__init__() + + self.width = width + self.mlp_dim = mlp_dim or 4 * width + self.use_glu = use_glu + self.dropout_rate = dropout_rate + + self.linear1 = nn.Linear(self.width, self.mlp_dim) + self.act_fnc = nn.GELU(approximate='tanh') + + if self.use_glu: + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) + else: + self.glu_linear = None + + self.linear2 = nn.Linear(self.mlp_dim, self.width) + + self.reset_parameters() + + def reset_parameters(self) -> None: + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight.data) + if module.bias is not None: + module.bias.data.normal_(std=1e-6) + + def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + x = self.linear1(x) + x = self.act_fnc(x) + + if self.use_glu: + y = self.glu_linear(x) + x = x * y + + x = F.dropout(x, dropout_rate, training=self.training) + x = self.linear2(x) + return x + + +class SelfAttention(nn.Module): + """Self-attention special case of multi-head dot-product attention.""" + + def __init__(self, + width: int, + num_heads: int = 8, + dropout_rate: float = 0.0) -> None: + super().__init__() + + self.width = width + self.num_heads = num_heads + + assert width % num_heads == 0, ( + 'Memory dimension must be divisible by number of heads.') + + self.head_dim = int(width / num_heads) + self.all_head_dim = self.num_heads * self.head_dim + self.dropout_rate = dropout_rate + + self.query = nn.Linear(self.width, self.all_head_dim) + self.key = nn.Linear(self.width, self.all_head_dim) + self.value = nn.Linear(self.width, self.all_head_dim) + self.out = nn.Linear(self.width, self.width) + self.reset_parameters() + + def reset_parameters(self) -> None: + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight.data) + if module.bias is not None: + nn.init.constant_(module.bias.data, 0.) + + def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: + new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + mixed_query_layer = self.query(x) + + key_layer = self.transpose_for_scores(self.key(x)) + value_layer = self.transpose_for_scores(self.value(x)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.head_dim) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = F.dropout(attention_probs, dropout_rate, training=self.training) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,) + context_layer = context_layer.view(new_context_layer_shape) + out = self.out(context_layer) + return out + + +class Encoder1DBlock(nn.Module): + """Single transformer encoder block (MHSA + MLP).""" + + def __init__(self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, + dropout_rate: float = 0.0) -> None: + super().__init__() + + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + self.dropout_rate = dropout_rate + + self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) + self.self_attention1 = SelfAttention(self.width, self.num_heads) + self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) + self.mlp3 = MlpBlock( + width=self.width, + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=dropout_rate) + + def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + if not self.use_post_layer_norm: + y = self.layer_norm0(x) + y = self.self_attention1(y) + y = F.dropout(y, dropout_rate, training=self.training) + x = x + y + + y = self.layer_norm2(x) + y = self.mlp3(y) + y = F.dropout(y, dropout_rate, training=self.training) + x = x + y + else: + y = x + y = self.self_attention1(y) + y = F.dropout(y, dropout_rate, training=self.training) + x = x + y + x = self.layer_norm0(x) + + y = x + y = self.mlp3(y) + y = F.dropout(y, dropout_rate, training=self.training) + x = x + y + x = self.layer_norm2(x) + return x + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + + def __init__(self, + depth: int, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, + dropout_rate: float = 0.0) -> None: + super().__init__() + + self.depth = depth + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + + self.net = nn.ModuleList([ + Encoder1DBlock(self.width, + self.mlp_dim, + self.num_heads, + self.use_glu, + self.use_post_layer_norm, + dropout_rate) for _ in range(depth) + ]) + + if not self.use_post_layer_norm: + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + else: + self.encoder_norm = None + + def forward(self, x: spec.Tensor) -> spec.Tensor: + # Input Encoder. + for block in self.net: + x = block(x) + if not self.use_post_layer_norm: + return self.encoder_norm(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12): + super().__init__() + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + + self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) + nn.init.xavier_uniform_(self.probe.data) + + self.mha = MultiheadAttention( + self.width, num_heads=self.num_heads, self_attn=False, bias=True) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) + self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) + + def forward(self, x: spec.Tensor) -> spec.Tensor: + n, _, _ = x.shape + probe = torch.tile(self.probe, [n, 1, 1]) + + x = self.mha(probe, x)[0] + y = self.layer_norm(x) + x = x + self.mlp(y) + return x[:, 0] + + +class ViT(nn.Module): + """ViT model.""" + + image_height: int = 224 + image_width: int = 224 + channels: int = 3 + + def __init__( + self, + num_classes: int = 1000, + patch_size: Tuple[int, int] = (16, 16), + width: int = 768, + depth: int = 12, + mlp_dim: Optional[int] = None, # Defaults to 4x input dim. + num_heads: int = 12, + rep_size: Union[int, bool] = True, + dropout_rate: Optional[float] = 0.0, + head_zeroinit: bool = True, + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, + dtype: Any = torch.float32) -> None: + super().__init__() + if dropout_rate is None: + dropout_rate = 0.0 + + self.num_classes = num_classes + self.patch_size = patch_size + self.width = width + self.depth = depth + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.rep_size = rep_size + self.head_zeroinit = head_zeroinit + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + self.use_map = use_map + self.dtype = dtype + self.dropout_rate = dropout_rate + + if self.rep_size: + rep_size = self.width if self.rep_size is True else self.rep_size + self.pre_logits = nn.Linear(self.width, rep_size) + + self.conv_patch_extract = nn.Conv2d( + self.channels, + self.width, + self.patch_size, + stride=self.patch_size, + padding='valid') + + self.encoder = Encoder( + depth=self.depth, + width=self.width, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + dropout_rate=dropout_rate) + + if self.num_classes: + self.head = nn.Linear(self.width, self.num_classes) + + if self.use_map: + self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) + else: + self.map = None + + self.reset_parameters() + + def reset_parameters(self) -> None: + init_utils.pytorch_default_init(self.conv_patch_extract) + + if self.rep_size: + init_utils.pytorch_default_init(self.pre_logits) + + if self.num_classes: + if self.head_zeroinit: + nn.init.constant_(self.head.weight.data, 0.) + nn.init.constant_(self.head.bias.data, 0.) + else: + init_utils.pytorch_default_init(self.head) + + def get_posemb(self, x: spec.Tensor) -> spec.Tensor: + return posemb_sincos_2d(x).type(self.dtype) + + def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + # Patch extraction. + x = self.conv_patch_extract(x) + + # Add posemb before adding extra token. + n, c, h, w = x.shape + pes = self.get_posemb(x) + + # Reshape to match Jax's ViT implementation. + x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2) + x = x + pes + + x = F.dropout(x, dropout_rate, training=self.training) + x = self.encoder(x) + + if self.use_map: + x = self.map(x) + else: + x = torch.mean(x, dim=1) + + if self.rep_size: + x = torch.tanh(self.pre_logits(x)) + + if self.num_classes: + x = self.head(x) + + return x diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py new file mode 100644 index 000000000..da66dfe43 --- /dev/null +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py @@ -0,0 +1,518 @@ +"""This is a pytorch implementation mirroring: +https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. +""" + +from dataclasses import dataclass +from functools import partial +import math +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn import init +import torch.nn.functional as F + +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ + preprocessor +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ + SpecAug + + +@dataclass +class ConformerConfig: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 + encoder_dim: int = 512 + num_attention_heads: int = 8 + num_encoder_layers: int = 4 + attention_dropout_rate: float = 0.0 + # If None, defaults to 0.1. + attention_residual_dropout_rate: Optional[float] = 0.1 + # If None, defaults to 0.0. + conv_residual_dropout_rate: Optional[float] = 0.0 + feed_forward_dropout_rate: float = 0.0 + # If None, defaults to 0.1. + feed_forward_residual_dropout_rate: Optional[float] = 0.1 + convolution_kernel_size: int = 5 + feed_forward_expansion_factor: int = 4 + freq_mask_count: int = 2 + freq_mask_max_bins: int = 27 + time_mask_count: int = 10 + time_mask_max_frames: int = 40 + time_mask_max_ratio: float = 0.05 + time_masks_per_frame: float = 0.0 + use_dynamic_time_mask_max_frames: bool = True + # If None, defaults to 0.1. + input_dropout_rate: Optional[float] = 0.1 + batch_norm_momentum: float = 1 - 0.999 + batch_norm_epsilon: float = 0.001 + use_specaug: bool = True + attention_temperature: float = 1.0 + activation_function_name: str = 'swish' + use_post_layer_norm: bool = True + + +def initialize(m): + if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d): + init.xavier_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.MultiheadAttention): + init.xavier_uniform_(m.in_proj_weight) + for i in m.children(): + initialize(i) + + +class LayerNorm(nn.Module): + + def __init__(self, dim, epsilon=1e-6): + super().__init__() + self.dim = dim + + self.scale = nn.Parameter(torch.zeros(self.dim)) + self.bias = nn.Parameter(torch.zeros(self.dim)) + self.epsilon = epsilon + + def forward(self, x): + return F.layer_norm(x, (self.dim,), 1 + self.scale, self.bias, self.epsilon) + + +class Subsample(nn.Module): + + def __init__(self, + encoder_dim: int = 0, + input_dropout_rate: float = 0.0, + num_bins: int = 80): + super().__init__() + self.encoder_dim = encoder_dim + self.input_dropout_rate = input_dropout_rate + + self.conv1 = Conv2dSubsampling( + input_channels=1, output_channels=encoder_dim) + self.conv2 = Conv2dSubsampling( + input_channels=encoder_dim, output_channels=encoder_dim) + + self.linear = nn.Linear( + in_features=self.encoder_dim * num_bins // 4, + out_features=self.encoder_dim, + bias=True) + self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) + + def forward(self, inputs, input_paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.input_dropout_rate + + output_paddings = input_paddings + outputs = inputs[:, None, :, :] + + outputs, output_paddings = self.conv1(outputs, output_paddings) + outputs, output_paddings = self.conv2(outputs, output_paddings) + + batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape + outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, + subsampled_lengths, + subsampled_dims * channels) + + outputs = self.linear(outputs) + outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) + + return outputs, output_paddings + + +class Conv2dSubsampling(nn.Module): + + def __init__(self, + input_channels: int, + output_channels: int, + filter_stride: Tuple[int] = (2, 2), + padding: str = 'SAME'): + super().__init__() + + self.input_channels = input_channels + self.output_channels = output_channels + self.filter_stride = filter_stride + self.padding = padding + + self.filter_shape = (output_channels, input_channels, 3, 3) + + self.kernel = nn.Parameter( + torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) + self.bias = nn.Parameter(torch.zeros(output_channels)) + self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) + + def get_same_padding(self, input_shape): + in_height, in_width = input_shape[2:] + stride_height, stride_width = self.filter_stride + filter_height, filter_width = 3, 3 + if in_height % stride_height == 0: + pad_along_height = max(filter_height - stride_height, 0) + else: + pad_along_height = max(filter_height - (in_height % stride_height), 0) + if in_width % stride_width == 0: + pad_along_width = max(filter_width - stride_width, 0) + else: + pad_along_width = max(filter_width - (in_width % stride_width), 0) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + return (pad_left, pad_right, pad_top, pad_bottom) + + def forward(self, inputs, paddings): + groups = inputs.shape[1] // self.input_channels + + if self.padding == 'SAME': + in_ = F.pad(inputs, self.get_same_padding(inputs.shape)) + else: + in_ = inputs + outputs = F.conv2d( + input=in_, + weight=self.kernel, + bias=self.bias, + stride=self.filter_stride, + dilation=(1, 1), + groups=groups) + + outputs = F.relu(outputs) + + input_length = paddings.shape[1] + stride = self.filter_stride[0] + pad_len = (input_length + stride - 1) // stride * stride - input_length + padded_paddings = F.pad( + paddings[:, None, :], (0, pad_len), mode='constant', value=0) + out_padding = F.conv1d( + input=padded_paddings, + weight=self.paddings_kernel, + stride=self.filter_stride[:1]) + out_padding = out_padding.squeeze(dim=1) + outputs = outputs * (1 - out_padding[:, None, :, None]) + return outputs, out_padding + + +class FeedForwardModule(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + self.config = config + + self.ln = LayerNorm(dim=config.encoder_dim) + self.linear1 = nn.Linear( + in_features=config.encoder_dim, + out_features=config.encoder_dim * config.feed_forward_expansion_factor, + bias=True) + self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True) + self.linear2 = nn.Linear( + in_features=config.encoder_dim * config.feed_forward_expansion_factor, + out_features=config.encoder_dim, + bias=True) + + if config.feed_forward_residual_dropout_rate is None: + self.feed_forward_residual_dropout_rate = 0.1 + else: + self.feed_forward_residual_dropout_rate = config.feed_forward_residual_dropout_rate + + def forward(self, inputs, padding_mask, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.feed_forward_residual_dropout_rate + + inputs = self.ln(inputs) + inputs = self.linear1(inputs) + if self.config.activation_function_name == 'swish': + activation_fn = F.silu + elif self.config.activation_function_name == 'gelu': + # Use tanh approximation of GELU which is default for jax + activation_fn = partial(F.gelu, approximate='tanh') + else: + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}') + inputs = activation_fn(inputs) + inputs = self.dropout1(inputs) + inputs = inputs * padding_mask + inputs = self.linear2(inputs) + inputs = inputs * padding_mask + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) + + return inputs + + +class AddPositionalEmbedding(nn.Module): + + def __init__(self, + min_timescale: int = 1, + max_timescale: int = 10_000, + embedding_dim: int = 512): + super().__init__() + self.min_timescale = min_timescale + self.max_timescale = max_timescale + self.embedding_dim = embedding_dim + num_timescales = self.embedding_dim // 2 + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale)) / ( + num_timescales - 1) + inv_timescales = self.min_timescale * \ + torch.exp(torch.arange(num_timescales, dtype=torch.float32) + * -log_timescale_increment) + self.register_buffer('inv_timescales', inv_timescales[None, None, :]) + + def forward(self, seq_length): + position = torch.arange( + end=seq_length, dtype=torch.float32, device=self.inv_timescales.device) + scaled_time = position[None, :, None] * self.inv_timescales + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + if self.embedding_dim % 2: + signal = torch.cat( + [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2) + return signal + + +class QueryScaler(nn.Module): + + def __init__(self, dim): + super().__init__() + self.dim = dim + self.scale = nn.Parameter(torch.zeros(self.dim)) + + def forward(self, inputs): + r_softplus_0 = 1.442695041 + scale = r_softplus_0 * F.softplus(self.scale) + return inputs * scale + + +class MHSAwithQS(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + self.embed_dim = config.encoder_dim + self.num_heads = config.num_attention_heads + self.attention_dropout_rate = config.attention_dropout_rate + self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) + self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) + self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) + self.attention_temperature = config.attention_temperature + + def forward(self, inputs, key_padding_mask=None): + batch_size, seq_len, embed_dim = inputs.shape + q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2) + q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) + k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + out = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=~key_padding_mask[:, None, None], + dropout_p=self.attention_dropout_rate, + ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + out = out * self.attention_temperature + out = self.out_proj(out) + return out + + +class MultiHeadedSelfAttention(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + + self.config = config + + self.ln = LayerNorm(dim=config.encoder_dim) + self.self_attention = MHSAwithQS(config) + if config.attention_residual_dropout_rate is None: + self.attention_residual_dropout_rate = 0.1 + else: + self.attention_residual_dropout_rate = config.attention_residual_dropout_rate + + def forward(self, outputs, paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.attention_residual_dropout_rate + + outputs = self.ln(outputs) + outputs = self.self_attention( + outputs, + key_padding_mask=paddings == 1, + ) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) + return outputs + + +class BatchNorm(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + running_mean = torch.zeros(config.encoder_dim) + running_var = torch.ones(config.encoder_dim) + self.register_buffer('running_mean', running_mean) + self.register_buffer('running_var', running_var) + self.scale = nn.Parameter(torch.zeros(config.encoder_dim)) + self.bias = nn.Parameter(torch.zeros(config.encoder_dim)) + + self.register_buffer('dim', torch.FloatTensor([config.encoder_dim])) + self.momentum = config.batch_norm_momentum + self.epsilon = config.batch_norm_epsilon + + def forward(self, inputs, input_paddings): + #inputs: NHD + #padding: NH + """ + Alternatively: + inputs[input_paddings==0] = F.batch_norm( + input = inputs[input_paddings==0], + running_mean = self.running_mean, + running_var = self.running_var, + weight = 1+self.scale, + bias = self.bias, + training = self.training, + momentum=1-self.momentum, + eps=self.epsilon + ) + inputs.masked_fill(input_paddings[...,None] != 0, 0) + return inputs + """ + mask = 1 - input_paddings[:, :, None] + if self.training: + count = mask.sum() + masked_inp = inputs.masked_fill(mask == 0, 0) + mean = (masked_inp).sum(dim=(0, 1)) / count + var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count + + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() + + else: + mean = self.running_mean + var = self.running_var + v = (1 + self.scale) * torch.rsqrt(var + self.epsilon) + bn = (inputs - mean) * v + self.bias + output = bn.masked_fill(mask == 0, 0) + return output + + +class ConvolutionBlock(nn.Module): + + def __init__(self, config): + super().__init__() + + self.config = config + self.ln = LayerNorm(dim=config.encoder_dim) + self.lin1 = nn.Linear( + in_features=config.encoder_dim, out_features=config.encoder_dim) + self.lin2 = nn.Linear( + in_features=config.encoder_dim, out_features=config.encoder_dim) + + self.conv1 = nn.Conv1d( + in_channels=config.encoder_dim, + out_channels=config.encoder_dim, + kernel_size=(config.convolution_kernel_size,), + stride=(1,), + padding='same', + bias=False, + groups=config.encoder_dim) + self.bn = BatchNorm(config) + self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) + if config.conv_residual_dropout_rate is None: + self.conv_residual_dropout_rate = 0.0 + else: + self.conv_residual_dropout_rate = config.conv_residual_dropout_rate + + def forward(self, inputs, input_paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.conv_residual_dropout_rate + + inputs = self.ln(inputs) + + inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) + inputs = inputs * (1 - input_paddings[:, :, None]) + + inputs = inputs.permute(0, 2, 1) + inputs = self.conv1(inputs) + inputs = inputs.permute(0, 2, 1) + + inputs = self.bn(inputs, input_paddings) + if self.config.activation_function_name == 'swish': + activation_fn = F.silu + elif self.config.activation_function_name == 'gelu': + activation_fn = F.gelu + else: + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}') + inputs = activation_fn(inputs) + inputs = self.lin3(inputs) + + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) + return inputs + + +class ConformerBlock(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + + self.ff1 = FeedForwardModule(config) + self.mhsa = MultiHeadedSelfAttention(config) + self.conv = ConvolutionBlock(config) + self.ff2 = FeedForwardModule(config) + self.ln = None + if config.use_post_layer_norm: + self.ln = LayerNorm(dim=config.encoder_dim) + + def forward(self, inputs, input_paddings, dropout_rate=None): + padding_mask = 1 - input_paddings[:, :, None] + inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) + inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) + inputs = inputs + self.conv(inputs, input_paddings, dropout_rate) + inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate) + if self.ln: + inputs = self.ln(inputs) + return inputs + + +class ConformerEncoderDecoder(nn.Module): + + def __init__(self, config: ConformerConfig): + super().__init__() + self.config = config + preprocessing_config = preprocessor.PreprocessorConfig() + self.preprocessor = preprocessor.MelFilterbankFrontend( + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) + self.specaug = SpecAug( + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + ) + if config.input_dropout_rate is None: + input_dropout_rate = 0.1 + else: + input_dropout_rate = config.input_dropout_rate + self.subsample = Subsample( + encoder_dim=config.encoder_dim, + input_dropout_rate=input_dropout_rate, + num_bins=preprocessing_config.num_bins) + self.conformers = nn.ModuleList( + [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) + + self.ln = LayerNorm(config.encoder_dim) + self.lin = nn.Linear(config.encoder_dim, config.vocab_size) + + def forward(self, inputs, input_paddings, dropout_rate=None): + outputs = inputs + output_paddings = input_paddings + outputs, output_paddings = self.preprocessor(outputs, output_paddings) + if self.training and self.config.use_specaug: + outputs, output_paddings = self.specaug(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings) + for conformer in self.conformers: + outputs = conformer(outputs, output_paddings, dropout_rate) + outputs = self.ln(outputs) + outputs = self.lin(outputs) + return outputs, output_paddings From 3e7a3967ba5f4845826e971ba5f0c49fa4c031b9 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 11:22:23 +0200 Subject: [PATCH 033/123] dropout fix deepspeech, ogbg --- algoperf/workloads/dropout_modules.py | 5 +- .../librispeech_pytorch/models_dropout.py | 395 ++++++++++++++++++ .../ogbg/ogbg_pytorch/models_dropout.py | 314 ++++++++++++++ 3 files changed, 711 insertions(+), 3 deletions(-) create mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py create mode 100644 algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py diff --git a/algoperf/workloads/dropout_modules.py b/algoperf/workloads/dropout_modules.py index 3917b75bf..6cec3f7ad 100644 --- a/algoperf/workloads/dropout_modules.py +++ b/algoperf/workloads/dropout_modules.py @@ -31,10 +31,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._supports_custom_dropout = True - def forward(self, x, p): + def forward(self, x: Tensor, p: float) -> Tensor: for module in self: - # if isinstance(module, (CustomDropout, SequentialWithDropout, DenseBlockWithDropout)): - if getattr(module, '_supports_custom_dropout', False): # TODO (nico): improve + if getattr(module, '_supports_custom_dropout', False): x = module(x, p) else: x = module(x) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py new file mode 100644 index 000000000..e68a820ed --- /dev/null +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py @@ -0,0 +1,395 @@ +"""This is a pytorch implementation mirroring: +https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. +""" + +from dataclasses import dataclass +import os +from typing import Optional, Tuple + +import torch +from torch import nn +import torch.distributed.nn as dist_nn +import torch.nn.functional as F + +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ + preprocessor +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ + SpecAug + +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ + + +@dataclass +class DeepspeechConfig: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 + encoder_dim: int = 512 + num_lstm_layers: int = 6 + num_ffn_layers: int = 3 + conv_subsampling_factor: int = 2 + conv_subsampling_layers: int = 2 + use_specaug: bool = True + freq_mask_count: int = 2 + freq_mask_max_bins: int = 27 + time_mask_count: int = 10 + time_mask_max_frames: int = 40 + time_mask_max_ratio: float = 0.05 + time_masks_per_frame: float = 0.0 + use_dynamic_time_mask_max_frames: bool = True + batch_norm_momentum: float = 1 - 0.999 + batch_norm_epsilon: float = 0.001 + # If None, defaults to 0.1. + input_dropout_rate: Optional[float] = 0.1 + # If None, defaults to 0.1. + feed_forward_dropout_rate: Optional[float] = 0.1 + enable_residual_connections: bool = True + enable_decoder_layer_norm: bool = True + bidirectional: bool = True + use_tanh: bool = False + layernorm_everywhere: bool = False + + +class LayerNorm(nn.Module): + + def __init__(self, dim, epsilon=1e-6): + super().__init__() + self.dim = dim + + self.scale = nn.Parameter(torch.zeros(self.dim)) + self.bias = nn.Parameter(torch.zeros(self.dim)) + self.epsilon = epsilon + + def forward(self, x): + mean = x.mean(dim=-1, keepdims=True) + var = x.var(dim=-1, unbiased=False, keepdims=True) + + normed_x = (x - mean) * torch.rsqrt(var + self.epsilon) + normed_x *= (1 + self.scale) + normed_x += self.bias + + return normed_x + + +class Subsample(nn.Module): + + def __init__(self, config: DeepspeechConfig): + super().__init__() + encoder_dim = config.encoder_dim + + self.encoder_dim = encoder_dim + + self.conv1 = Conv2dSubsampling( + input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh) + self.conv2 = Conv2dSubsampling( + input_channels=encoder_dim, + output_channels=encoder_dim, + use_tanh=config.use_tanh) + + self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) + + if config.input_dropout_rate is None: + self.input_dropout_rate = 0.1 + else: + self.input_dropout_rate = config.input_dropout_rate + + def forward(self, inputs, input_paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.input_dropout_rate + + output_paddings = input_paddings + outputs = inputs[:, None, :, :] + + outputs, output_paddings = self.conv1(outputs, output_paddings) + outputs, output_paddings = self.conv2(outputs, output_paddings) + + batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape + outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, + subsampled_lengths, + subsampled_dims * channels) + + outputs = self.lin(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training) + + return outputs, output_paddings + + +class Conv2dSubsampling(nn.Module): + + def __init__(self, + input_channels: int, + output_channels: int, + filter_stride: Tuple[int] = (2, 2), + padding: str = 'SAME', + batch_norm_momentum: float = 0.999, + batch_norm_epsilon: float = 0.001, + use_tanh: bool = False): + super().__init__() + + self.input_channels = input_channels + self.output_channels = output_channels + self.filter_stride = filter_stride + self.padding = padding + + self.filter_shape = (output_channels, input_channels, 3, 3) + + self.kernel = nn.Parameter( + nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) + self.bias = nn.Parameter(torch.zeros(output_channels)) + + self.use_tanh = use_tanh + + def get_same_padding(self, input_shape): + in_height, in_width = input_shape[2:] + stride_height, stride_width = self.filter_stride + filter_height, filter_width = 3, 3 + if in_height % stride_height == 0: + pad_along_height = max(filter_height - stride_height, 0) + else: + pad_along_height = max(filter_height - (in_height % stride_height), 0) + if in_width % stride_width == 0: + pad_along_width = max(filter_width - stride_width, 0) + else: + pad_along_width = max(filter_width - (in_width % stride_width), 0) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + return (pad_left, pad_right, pad_top, pad_bottom) + + def forward(self, inputs, paddings): + groups = inputs.shape[1] // self.input_channels + + if self.padding == 'SAME': + in_ = F.pad(inputs, self.get_same_padding(inputs.shape)) + else: + in_ = inputs + outputs = F.conv2d( + input=in_, + weight=self.kernel, + bias=self.bias, + stride=self.filter_stride, + dilation=(1, 1), + groups=groups) + + if self.use_tanh: + outputs = F.tanh(outputs) + else: + outputs = F.relu(outputs) + + input_length = paddings.shape[1] + stride = self.filter_stride[0] + pad_len = (input_length + stride - 1) // stride * stride - input_length + out_padding = F.conv1d( + input=torch.cat([ + paddings[:, None, :], + torch.zeros( + size=(paddings.shape[0], 1, pad_len), device=paddings.device) + ], + dim=2), + weight=torch.ones([1, 1, 1], device=paddings.device), + stride=self.filter_stride[:1]) + out_padding = out_padding.squeeze(dim=1) + outputs = outputs * (1 - out_padding[:, None, :, None]) + return outputs, out_padding + + +class FeedForwardModule(nn.Module): + + def __init__(self, config: DeepspeechConfig): + super().__init__() + self.config = config + + if config.layernorm_everywhere: + self.normalization_layer = LayerNorm(config.encoder_dim) + else: + self.bn_normalization_layer = BatchNorm( + dim=config.encoder_dim, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon) + self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) + if config.feed_forward_dropout_rate is None: + self.feed_forward_dropout_rate = 0.1 + else: + self.feed_forward_dropout_rate = config.feed_forward_dropout_rate + + def forward(self, inputs, input_paddings, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.feed_forward_dropout_rate + + padding_mask = (1 - input_paddings)[:, :, None] + if self.config.layernorm_everywhere: + inputs = self.normalization_layer(inputs) + else: # batchnorm + inputs = self.bn_normalization_layer(inputs, input_paddings) + + inputs = self.lin(inputs) + + if self.config.use_tanh: + inputs = F.tanh(inputs) + else: + inputs = F.relu(inputs) + + inputs = inputs * padding_mask + inputs = F.dropout(inputs, dropout_rate, training=self.training) + + return inputs + + +class BatchNorm(nn.Module): + + def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon): + super().__init__() + running_mean = torch.zeros(dim) + running_var = torch.ones(dim) + self.register_buffer('running_mean', running_mean) + self.register_buffer('running_var', running_var) + self.weight = nn.Parameter(torch.zeros(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) + + self.momentum = batch_norm_momentum + self.epsilon = batch_norm_epsilon + self.dim = dim + + def forward(self, inputs, input_paddings): + #inputs: NHD + #padding: NH + mask = 1 - input_paddings[:, :, None] + if self.training: + count = mask.sum() + masked_inp = inputs.masked_fill(mask == 0, 0) + sum_ = (masked_inp).sum(dim=(0, 1)) + if USE_PYTORCH_DDP: + sum_ = dist_nn.all_reduce(sum_) + count = dist_nn.all_reduce(count) + mean = sum_ / count + + sum_ = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) + if USE_PYTORCH_DDP: + sum_ = dist_nn.all_reduce(sum_) + var = sum_ / count + + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() + else: + mean = self.running_mean + var = self.running_var + v = (1 + self.weight) * torch.rsqrt(var + self.epsilon) + bn = (inputs - mean) * v + self.bias + output = bn.masked_fill(mask == 0, 0) + return output + + +class BatchRNN(nn.Module): + + def __init__(self, config: DeepspeechConfig): + super().__init__() + self.config = config + hidden_size = config.encoder_dim + input_size = config.encoder_dim + bidirectional = config.bidirectional + self.bidirectional = bidirectional + + if config.layernorm_everywhere: + self.normalization_layer = LayerNorm(config.encoder_dim) + else: + self.bn_normalization_layer = BatchNorm(config.encoder_dim, + config.batch_norm_momentum, + config.batch_norm_epsilon) + + if bidirectional: + self.lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size // 2, + bidirectional=True, + batch_first=True) + else: + self.lstm = nn.LSTM( + input_size=input_size, hidden_size=hidden_size, batch_first=True) + + def forward(self, inputs, input_paddings): + if self.config.layernorm_everywhere: + inputs = self.normalization_layer(inputs) + else: + inputs = self.bn_normalization_layer(inputs, input_paddings) + lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() + packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( + inputs, lengths, batch_first=True, enforce_sorted=False) + packed_outputs, _ = self.lstm(packed_inputs) + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence( + packed_outputs, batch_first=True) + if outputs.shape[1] < inputs.shape[1]: + outputs = torch.cat([ + outputs, + torch.zeros( + size=(outputs.shape[0], + inputs.shape[1] - outputs.shape[1], + outputs.shape[2]), + device=outputs.device) + ], + dim=1) + return outputs + + +class DeepspeechEncoderDecoder(nn.Module): + + def __init__(self, config: DeepspeechConfig): + super().__init__() + self.config = config + + self.specaug = SpecAug( + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + ) + preprocessing_config = preprocessor.PreprocessorConfig() + self.preprocessor = preprocessor.MelFilterbankFrontend( + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) + + self.subsample = Subsample(config=config) + + self.lstms = nn.ModuleList( + [BatchRNN(config) for _ in range(config.num_lstm_layers)]) + self.ffns = nn.ModuleList( + [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]) + + if config.enable_decoder_layer_norm: + self.ln = LayerNorm(config.encoder_dim) + else: + self.ln = nn.Identity() + + self.lin = nn.Linear(config.encoder_dim, config.vocab_size) + + def forward(self, inputs, input_paddings, dropout_rate=None): + outputs = inputs + output_paddings = input_paddings + + outputs, output_paddings = self.preprocessor(outputs, output_paddings) + if self.training and self.config.use_specaug: + outputs, output_paddings = self.specaug(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) + for idx in range(self.config.num_lstm_layers): + if self.config.enable_residual_connections: + outputs = outputs + self.lstms[idx](outputs, output_paddings) + else: + outputs = self.lstms[idx](outputs, output_paddings) + + for idx in range(self.config.num_ffn_layers): + if self.config.enable_residual_connections: + outputs = outputs + self.ffns[idx](outputs, output_paddings) + else: + outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) + + if self.config.enable_decoder_layer_norm: + outputs = self.ln(outputs) + + outputs = self.lin(outputs) + + return outputs, output_paddings diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py new file mode 100644 index 000000000..1d89ea9e7 --- /dev/null +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py @@ -0,0 +1,314 @@ +# Ported to PyTorch from +# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. +from functools import partial +from typing import Callable, Optional, Tuple + +import jax.tree_util as tree +from jraph import GraphsTuple +import torch +from torch import nn + +from algoperf import init_utils +from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout + + +def _make_mlp(in_dim, hidden_dims, activation_fn): + """Creates a MLP with specified dimensions.""" + layers = SequentialWithDropout() + for i, dim in enumerate(hidden_dims): + layers.add_module(f'dense_{i}', + nn.Linear(in_features=in_dim, out_features=dim)) + layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) + layers.add_module(f'activation_fn_{i}', activation_fn()) + layers.add_module(f'dropout_{i}', CustomDropout()) + in_dim = dim + return layers + + +class GNN(nn.Module): + """Defines a graph network. + + The model assumes the input data is a jraph.GraphsTuple without global + variables. The final prediction will be encoded in the globals. + """ + + def __init__(self, + num_outputs: int = 128, + dropout_rate: Optional[float] = 0.1, + activation_fn_name: str = 'relu', + latent_dim: int = 256, + hidden_dims: Tuple[int] = (256,), + num_message_passing_steps: int = 5) -> None: + super().__init__() + self.latent_dim = latent_dim + self.hidden_dims = hidden_dims + self.num_message_passing_steps = num_message_passing_steps + self.num_outputs = num_outputs + if dropout_rate is None: + self.dropout_rate = 0.1 + # in_features are specifically chosen for the ogbg workload. + self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) + self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) + + if activation_fn_name == 'relu': + activation_fn = nn.ReLU + elif activation_fn_name == 'gelu': + activation_fn = partial(nn.GELU, approximate='tanh') + elif activation_fn_name == 'silu': + activation_fn = nn.SiLU + else: + raise ValueError( + f'Invalid activation function name: {self.activation_fn_name}') + + graph_network_layers = [] + for st in range(self.num_message_passing_steps): + # Constants in in_dims are based on forward call of GraphNetwork: + # specifically update_edge_fn update_node_fn and update_global_fn. + if st == 0: + in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs + in_dim_node_fn = self.latent_dim + self.hidden_dims[ + -1] * 2 + self.num_outputs + last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs + else: + in_dim_edge_fn = self.hidden_dims[-1] * 4 + in_dim_node_fn = self.hidden_dims[-1] * 4 + last_in_dim = self.hidden_dims[-1] * 3 + + graph_network_layers.append( + GraphNetwork( + update_edge_fn=_make_mlp(in_dim_edge_fn, + self.hidden_dims, + activation_fn), + update_node_fn=_make_mlp(in_dim_node_fn, + self.hidden_dims, + activation_fn), + update_global_fn=_make_mlp(last_in_dim, + self.hidden_dims, + activation_fn))) + self.graph_network = SequentialWithDropout(*graph_network_layers) + + self.decoder = nn.Linear( + in_features=self.hidden_dims[-1], out_features=self.num_outputs) + + for m in self.modules(): + if isinstance(m, nn.Linear): + init_utils.pytorch_default_init(m) + + def forward(self, graph: GraphsTuple, dropout_rate=None) -> torch.Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + + graph = graph._replace( + globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], + device=graph.n_node.device)) + graph = graph._replace(nodes=self.node_embedder(graph.nodes)) + graph = graph._replace(edges=self.edge_embedder(graph.edges)) + + graph = self.graph_network(graph, dropout_rate) + + # Map globals to represent the final result + graph = graph._replace(globals=self.decoder(graph.globals)) + + return graph.globals + + +class GraphNetwork(nn.Module): + """Returns a method that applies a configured GraphNetwork. + This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 + There is one difference. For the nodes update the class aggregates over the + sender edges and receiver edges separately. This is a bit more general + than the algorithm described in the paper. The original behaviour can be + recovered by using only the receiver edge aggregations for the update. + In addition this implementation supports softmax attention over incoming + edge features. + Example usage:: + gn = GraphNetwork(update_edge_function, + update_node_function, **kwargs) + # Conduct multiple rounds of message passing with the same parameters: + for _ in range(num_message_passing_steps): + graph = gn(graph) + Args: + update_edge_fn: function used to update the edges or None to deactivate edge + updates. + update_node_fn: function used to update the nodes or None to deactivate node + updates. + update_global_fn: function used to update the globals or None to deactivate + globals updates. + Returns: + A method that applies the configured GraphNetwork. + """ + + def __init__(self, + update_edge_fn: Optional[Callable] = None, + update_node_fn: Optional[Callable] = None, + update_global_fn: Optional[Callable] = None) -> None: + super().__init__() + self.update_edge_fn = update_edge_fn + self.update_node_fn = update_node_fn + self.update_global_fn = update_global_fn + self._supports_custom_dropout = True # supports SequentialWithDropout + + def forward(self, graph: GraphsTuple, dropout_rate=None) -> GraphsTuple: + """Applies a configured GraphNetwork to a graph. + This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 + There is one difference. For the nodes update the class aggregates over the + sender edges and receiver edges separately. This is a bit more general + the algorithm described in the paper. The original behaviour can be + recovered by using only the receiver edge aggregations for the update. + In addition this implementation supports softmax attention over incoming + edge features. + Many popular Graph Neural Networks can be implemented as special cases of + GraphNets, for more information please see the paper. + Args: + graph: a `GraphsTuple` containing the graph. + Returns: + Updated `GraphsTuple`. + """ + nodes, edges, receivers, senders, globals_, n_node, n_edge = graph + sum_n_node = tree.tree_leaves(nodes)[0].shape[0] + if not tree.tree_all( + tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)): + raise ValueError( + 'All node arrays in nest must contain the same number of nodes.') + + sent_attributes = tree.tree_map(lambda n: n[senders], nodes) + received_attributes = tree.tree_map(lambda n: n[receivers], nodes) + # Here we scatter the global features to the corresponding edges, + # giving us tensors of shape [num_edges, global_feat]. + global_edge_attributes = tree.tree_map( + lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_) + if self.update_edge_fn: + edge_fn_inputs = torch.cat( + [edges, sent_attributes, received_attributes, global_edge_attributes], + dim=-1) + edges = self.update_edge_fn(edge_fn_inputs, dropout_rate) + + if self.update_node_fn: + sent_attributes = tree.tree_map( + lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges) + received_attributes = tree.tree_map( + lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), + edges) + # Here we scatter the global features to the corresponding nodes, + # giving us tensors of shape [num_nodes, global_feat]. + global_attributes = tree.tree_map( + lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) + node_fn_inputs = torch.cat( + [nodes, sent_attributes, received_attributes, global_attributes], + dim=-1) + nodes = self.update_node_fn(node_fn_inputs, dropout_rate) + + if self.update_global_fn: + n_graph = n_node.shape[0] + graph_idx = torch.arange(n_graph, device=graph.n_node.device) + # To aggregate nodes and edges from each graph to global features, + # we first construct tensors that map the node to the corresponding graph. + # For example, if you have `n_node=[1,2]`, we construct the tensor + # [0, 1, 1]. We then do the same for edges. + node_gr_idx = torch.repeat_interleave(graph_idx, n_node, dim=0) + edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0) + # We use the aggregation function to pool the nodes/edges per graph. + node_attributes = tree.tree_map( + lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes) + edge_attributes = tree.tree_map( + lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges) + # These pooled nodes are the inputs to the global update fn. + global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_], + dim=-1) + globals_ = self.update_global_fn(global_fn_inputs, dropout_rate) + + return GraphsTuple( + nodes=nodes, + edges=edges, + receivers=receivers, + senders=senders, + globals=globals_, + n_node=n_node, + n_edge=n_edge) + + +# Forked from +# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py. +def scatter_sum(src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + r""" + | + .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ + master/docs/source/_figures/add.svg?sanitize=true + :align: center + :width: 400px + | + Reduces all values from the :attr:`src` tensor into :attr:`out` at the + indices specified in the :attr:`index` tensor along a given axis + :attr:`dim`. + For each value in :attr:`src`, its output index is specified by its index + in :attr:`src` for dimensions outside of :attr:`dim` and by the + corresponding value in :attr:`index` for dimension :attr:`dim`. + The applied reduction is here defined as a sum. + Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional + tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` + and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional + tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. + Moreover, the values of :attr:`index` must be between :math:`0` and + :math:`y - 1`, although no specific ordering of indices is required. + The :attr:`index` tensor supports broadcasting in case its dimensions do + not match with :attr:`src`. + For one-dimensional tensors, the operation computes + .. math:: + \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j + where :math:`\sum_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. + .. note:: + This operation is implemented via atomic operations on the GPU and is + therefore **non-deterministic** since the order of parallel operations + to the same value is undetermined. + For floating-point variables, this results in a source of variance in + the result. + :param src: The source tensor. + :param index: The indices of elements to scatter. + :param dim: The axis along which to index. (default: :obj:`-1`) + :param out: The destination tensor. + :param dim_size: If :attr:`out` is not given, automatically create output + with size :attr:`dim_size` at dimension :attr:`dim`. + If :attr:`dim_size` is not given, a minimal sized output tensor + according to :obj:`index.max() + 1` is returned. + :rtype: :class:`Tensor` + .. code-block:: python + src = torch.randn(10, 6, 64) + index = torch.tensor([0, 1, 0, 1, 2, 1]) + # Broadcasting in the first and last dim. + out = scatter_sum(src, index, dim=1) + print(out.size()) + .. code-block:: + torch.Size([10, 3, 64]) + """ + index = broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +# Forked from +# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/utils.py. +def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand(other.size()) + return src From e80add440b4ab281414d7babf4f5492c16d758e2 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 11:23:05 +0200 Subject: [PATCH 034/123] remove attention_dropout_rate from wmt --- .../wmt/wmt_pytorch/models_dropout.py | 981 ++++++++++++++++++ 1 file changed, 981 insertions(+) create mode 100644 algoperf/workloads/wmt/wmt_pytorch/models_dropout.py diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py new file mode 100644 index 000000000..588d06abf --- /dev/null +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -0,0 +1,981 @@ +import copy +import math +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F +from torch.nn.init import normal_ +from torch.nn.init import xavier_uniform_ + + +def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: + """Make a causal mask for self-attention. + + Args: + x: input array of shape `[batch..., len]` + device: device to store the idxs + + Returns: + A `[batch..., len, len]` shaped causal attention mask. + """ + idxs = torch.broadcast_to( + torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape) + return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2)) + + +def make_src_mask(src, inputs_segmentation, nhead): + """Utility for creating src mask and adjust it for PyTorch Transformer API.""" + src_mask = torch.mul((src > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) + # Add segmentation block-diagonal attention mask if using segmented data. + if inputs_segmentation is not None: + src_mask = torch.logical_and( + src_mask, + torch.eq( + inputs_segmentation.unsqueeze(-1), + inputs_segmentation.unsqueeze(-2))) + # Flip values and ensure numerical stability. + src_mask = torch.repeat_interleave( + torch.logical_not(src_mask), repeats=nhead, dim=0) + new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32) + new_src_mask.masked_fill_(src_mask, -1e10) + return new_src_mask + + +def make_tgt_and_memory_mask(tgt, + src, + inputs_segmentation, + targets_segmentation, + decode, + nhead): + """ Utility for creating target and memory mask and adjust them for PyTorch + Transformer API.""" + if not decode: + tgt_mask = torch.logical_and( + torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), + make_causal_mask(tgt, device=tgt.device)) + memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) + else: + tgt_mask = None + memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1), + (src > 0).unsqueeze(-2)) + # Add segmentation block-diagonal attention masks if using segmented data. + if inputs_segmentation is not None: + tgt_mask = torch.logical_and( + tgt_mask, + torch.eq( + targets_segmentation.unsqueeze(-1), + targets_segmentation.unsqueeze(-2))) + memory_mask = torch.logical_and( + memory_mask, + torch.eq( + targets_segmentation.unsqueeze(-1), + inputs_segmentation.unsqueeze(-2))) + # Flip values and ensure numerical stability. + memory_mask = torch.repeat_interleave( + torch.logical_not(memory_mask), repeats=nhead, dim=0) + new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32) + new_memory_mask.masked_fill_(memory_mask, -1e10) + if tgt_mask is not None: + tgt_mask = torch.repeat_interleave( + torch.logical_not(tgt_mask), repeats=nhead, dim=0) + new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32) + new_tgt_mask.masked_fill_(tgt_mask, -1e10) + tgt_mask = new_tgt_mask + return tgt_mask, new_memory_mask + + +def shift_right(x, axis=1): + """Shift the input to the right by padding on axis 1.""" + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = (1, 0) + pad_widths = tuple(t for tup in reversed(pad_widths) for t in tup) + padded = F.pad(x, pad_widths, mode='constant') + return padded[:, :-1] + + +class Transformer(nn.Module): + """Transformer architecture based on the model from the WMT Jax workload.""" + + def __init__(self, + ntoken: int = 32000, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + dropout_rate: Optional[float] = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True): + super().__init__() + if dropout_rate is None: + dropout_rate = 0.1 + self.pos_encoder = PositionalEncoding(d_model, dropout_rate) + self.shared_embedding = nn.Embedding(ntoken, d_model) + self.encoder = Encoder(d_model, + nhead, + d_hid, + nlayers, + dropout_rate, + activation, + glu, + layer_norm_eps, + attention_temp, + pre_ln) + self.decoder = Decoder(d_model, + nhead, + d_hid, + nlayers, + dropout_rate, + activation, + glu, + layer_norm_eps, + attention_temp, + pre_ln) + # Share positional encoding and embedding between encoder and decoder. + self.encoder.pos_encoder = self.pos_encoder + self.encoder.shared_embedding = self.shared_embedding + self.decoder.pos_encoder = self.pos_encoder + self.decoder.shared_embedding = self.shared_embedding + + self._reset_parameters() + + def _reset_parameters(self): + """Initiate parameters in the transformer model.""" + for module in self.modules(): + if isinstance(module, nn.Linear): + xavier_uniform_(module.weight) + if module.bias is not None: + normal_(module.bias, std=1e-6) + + def forward(self, + src: Tensor, + tgt: Tensor, + inputs_positions: Optional[Tensor] = None, + targets_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + targets_segmentation: Optional[Tensor] = None, + decode: bool = False) -> Tensor: + """ + Args: + src: Tensor, shape [batch_size, seq_len] + tgt: Tensor, shape [batch_size, seq_len] + inputs_positions: Optional[Tensor], shape [batch_size, seq_len] + targets_positions: Optional[Tensor], shape [batch_size, seq_len] + inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] + targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] + decode: bool + + Returns: + output Tensor of shape [batch_size, seq_len, ntoken] + """ + if src.size(0) != tgt.size(0): + raise RuntimeError('The batch size of src and tgt must be equal.') + memory = self.encoder( + src, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation) + output = self.decoder( + tgt, + memory, + src, # just for calculating the padding mask + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + decode=decode) + return output + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class. + num_layers: the number of sub-encoder-layers in the encoder. + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to + nested tensor (and convert back on output). This will improve + the overall performance of TransformerEncoder when padding + rate is high. + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(12, 8) + >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, 6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ['norm'] + + def __init__(self, + encoder_layer, + num_layers, + norm=None, + enable_nested_tensor=True, + mask_check=True): + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)]) + self.num_layers = num_layers + self.norm = norm + self.enable_nested_tensor = enable_nested_tensor + self.mask_check = mask_check + + def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + """Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + + Shape: + see the docs in Transformer class. + """ + output = src + convert_to_nested = False + + for mod in self.layers: + output = mod(output, src_mask=mask) + + if convert_to_nested: + output = output.to_padded_tensor(0.) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class Encoder(nn.Module): + + def __init__(self, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + dropout_rate: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True): + super().__init__() + self.nhead = nhead + self.shared_embedding = None + self.pos_encoder = None + encoder_layer = TransformerEncoderLayer( + d_model, + nhead, + d_hid, + dropout_rate, + activation=activation, + glu=glu, + layer_norm_eps=layer_norm_eps, + attention_temp=attention_temp, + pre_ln=pre_ln) + encoder_norm = ( + nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) + self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm) + + def forward(self, + src: Tensor, + inputs_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None) -> Tensor: + src = src.to(torch.int) + src_mask = make_src_mask(src, inputs_segmentation, self.nhead) + src = self.shared_embedding(src) + src = self.pos_encoder(src, inputs_positions) + memory = self.encoder(src, mask=src_mask) + return memory + + +class Decoder(nn.Module): + + def __init__(self, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + dropout_rate: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True): + super().__init__() + self.nhead = nhead + self.shared_embedding = None + self.pos_encoder = None + self.decoder = TransformerDecoder(d_model, + nhead, + d_hid, + dropout_rate, + activation, + glu, + layer_norm_eps, + nlayers, + attention_temp, + pre_ln) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + src: Tensor, # just for calculating the padding mask + targets_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + targets_segmentation: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None) -> Any: + tgt = tgt.to(torch.int) + tgt_mask, memory_mask = make_tgt_and_memory_mask( + tgt, src, inputs_segmentation, targets_segmentation, + decode, self.nhead) + if not decode: + tgt = shift_right(tgt) + tgt = self.shared_embedding(tgt) + tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache) + if decode: + tgt, cache = tgt + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + decode=decode, + max_len=max_len, + cache=cache) + if decode: + output, cache = output + normalize = math.sqrt(output.shape[-1]) + output = torch.matmul(output, self.shared_embedding.weight.T) / normalize + if decode: + return output, cache + return output + + +class PositionalEncoding(nn.Module): + + def __init__(self, + d_model: int, + dropout_rate: float = 0.1, + max_len: int = 256): + super().__init__() + self.dropout = nn.Dropout(p=dropout_rate) + + position = torch.arange(max_len).unsqueeze(1) + scale_factor = -math.log(10000.0) / (d_model // 2 - 1) + div_term = torch.exp(torch.arange(d_model // 2) * scale_factor) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, :d_model // 2] = torch.sin(position * div_term) + pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward( + self, + x: Tensor, + inputs_positions: Optional[Tensor] = None, + decode: bool = False, + cache: Optional[Dict[str, Dict[str, Tensor]]] = None + ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: + """ + Args: + x: Tensor (shape [batch_size, seq_len, embedding_dim]) + inputs_positions: Tensor (shape [batch_size, seq_len]) or None + decode: bool + cache: Dict[str, Dict[str, Tensor]] or None + Returns: + Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]] + """ + # We use a cache position index for tracking decoding position. + if decode: + name = self._get_name() + if cache is None: + cache = { + name: { + 'cache_index': + torch.tensor(0, dtype=torch.long, device=self.pe.device), + }, + } + pe = self.pe[0, cache[name]['cache_index'], :] + cache[name]['cache_index'] += 1 + return self.dropout(x + pe), cache + if inputs_positions is None: + # normal unpacked case: + pe = self.pe[:, :x.size(1), :] + else: + # for packed data we need to use known position indices: + pe = self.pe[0, inputs_positions, :] + return self.dropout(x + pe) + + +# TransformerEncoderLayer and TransformerDecoderLayer are taken from: +# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py +# Main difference is the use of custom MultiheadAttention modules. +class TransformerEncoderLayer(nn.Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, + Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all + you need. In Advances in Neural Information Processing Systems, + pages 6000-6010. Users may modify or implement in a different way during + application. + Args: + d_model: the number of expected features in the input (default=1024). + nhead: the number of heads in the multiheadattention models (default=16). + dim_feedforward: the dimension of the feedforward network model + (default=1024). + dropout_rate: the dropout_rate value (default=0.1). + activation: the activation function of the intermediate layer, can be a + string ("relu" or "gelu") or a unary callable (default=F.relu). + layer_norm_eps: the eps value in layer normalization components + (default=1e-6). + pre_ln: if ``True``, layer norm is done prior to attention and + feedforward operations, respectivaly. Otherwise it's done after. + Default: ``True``. + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(32, 10, 512) + >>> out = encoder_layer(src) + """ + __constants__ = ['pre_ln'] + + def __init__(self, + d_model: int = 1024, + nhead: int = 16, + dim_feedforward: int = 1024, + dropout_rate: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + device=None, + dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + self_attn=True, + dropout_rate=dropout_rate, + attention_temp=attention_temp, + bias=False, + **factory_kwargs) + + # Implementation of Feedforward model. + self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.glu = glu + if self.glu: + self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout_rate) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.pre_ln = pre_ln + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = nn.Dropout(dropout_rate) + self.dropout2 = nn.Dropout(dropout_rate) + + self.activation = activation + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + + Shape: + see the docs in Transformer class. + """ + x = src + if self.pre_ln: + x = x + self._sa_block(self.norm1(x), src_mask) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1(x + self._sa_block(x, src_mask)) + x = self.norm2(x + self._ff_block(x)) + + return x + + # Self-attention block: + def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask) + return self.dropout1(x) + + # Feed forward block: + def _ff_block(self, inputs: Tensor) -> Tensor: + x = self.activation(self.linear1(inputs)) + if self.glu: + y = self.linear_glu(inputs) + x = x * y + x = self.linear2(self.dropout(x)) + return self.dropout2(x) + + +# Modified to use cache for autoregressive decoding and custom +# MultiheadAttention modules. +class TransformerDecoder(nn.Module): + r"""TransformerDecoder is a stack of N decoder layers + Args: + d_model: the number of expected features in the input (default=1024) + nhead: the number of heads in the multiheadattention models (default=16) + d_hid: the dimension of the feedforward network model + (default=1024) + dropout_rate: the dropout_rate value (default=0.1) + layer_norm_eps: the eps value in layer normalization components + (default=1e-6). + decoder_layer: an instance of the TransformerDecoderLayer() class + num_layers: the number of sub-decoder-layers in the decoder + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(12, 8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, 6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, + d_model, + nhead, + d_hid, + dropout_rate, + activation, + glu, + layer_norm_eps, + num_layers, + attention_temp, + pre_ln): + super().__init__() + self.layers = nn.ModuleList([ + TransformerDecoderLayer( + d_model, + nhead, + d_hid, + dropout_rate, + activation, + glu, + layer_norm_eps=layer_norm_eps, + attention_temp=attention_temp, + pre_ln=pre_ln) for _ in range(num_layers) + ]) + self.num_layers = num_layers + self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) + + def forward(self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None) -> Any: + r"""Pass the inputs (and mask) through the decoder layer in turn. + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + decode: whether to use cache for autoregressive decoding or not. + max_len: maximum sequence length, necessary for decoding cache. + Shape: + see the docs in Transformer class. + """ + output = tgt + + for idx, mod in enumerate(self.layers): + output, cache = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=idx) + + if self.norm is not None: + output = self.norm(output) + + if decode: + return output, cache + return output + + +# Modified to use cache for autoregressive decoding and custom +# MultiheadAttention modules. +class TransformerDecoderLayer(nn.Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and + feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, + Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all + you need. In Advances in Neural Information Processing Systems, + pages 6000-6010. Users may modify or implement in a different way during + application. + Args: + d_model: the number of expected features in the input (default=1024). + nhead: the number of heads in the multiheadattention models (default=16). + dim_feedforward: the dimension of the feedforward network model + (default=1024). + dropout_rate: the dropout_rate value (default=0.1). + activation: the activation function of the intermediate layer, can be a + string ("relu" or "gelu") or a unary callable (default=F.relu). + layer_norm_eps: the eps value in layer normalization components + (default=1e-6). + pre_ln: if ``True``, layer norm is done prior to self attention, + multihead attention and feedforward operations, respectivaly. + Otherwise it's done after. Default: ``True``. + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(32, 10, 512) + >>> tgt = torch.rand(32, 20, 512) + >>> out = decoder_layer(tgt, memory) + """ + __constants__ = ['pre_ln'] + + def __init__(self, + d_model: int = 1024, + nhead: int = 16, + dim_feedforward: int = 1024, + dropout_rate: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + pre_ln: bool = True, + attention_temp: float = 1.0, + device=None, + dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + self_attn=True, + dropout_rate=dropout_rate, + attention_temp=attention_temp, + bias=False, + **factory_kwargs) + self.multihead_attn = MultiheadAttention( + d_model, + nhead, + self_attn=False, + dropout_rate=dropout_rate, + attention_temp=attention_temp, + bias=False, + **factory_kwargs) + + # Implementation of Feedforward model. + self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.glu = glu + if self.glu: + self.linear_glu = nn.Linear(dim_feedforward, + dim_feedforward, + **factory_kwargs) + self.dropout = nn.Dropout(dropout_rate) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.pre_ln = pre_ln + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = nn.Dropout(dropout_rate) + self.dropout2 = nn.Dropout(dropout_rate) + self.dropout3 = nn.Dropout(dropout_rate) + + self.activation = activation + + def forward( # pylint: disable=arguments-renamed + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None) -> Any: + r"""Pass the inputs (and mask) through the decoder layer. + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + decode: wether to use cache for autoregressive decoding or not. + max_len: maximum sequence length, necessary for decoding cache. + Shape: + see the docs in Transformer class. + """ + # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + + x = tgt + if self.pre_ln: + sa_out, cache = self._sa_block( + self.norm1(x), + tgt_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index) + x = x + sa_out + x = x + self._mha_block(self.norm2(x), memory, memory_mask) + x = x + self._ff_block(self.norm3(x)) + else: + sa_out, cache = self._sa_block( + x, + tgt_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index) + x = self.norm1(x + sa_out) + x = self.norm2(x + self._mha_block(x, memory, memory_mask)) + x = self.norm3(x + self._ff_block(x)) + + return x, cache + + # Self-attention block: + def _sa_block( # pylint: disable=arguments-renamed + self, + x: Tensor, + attn_mask: Optional[Tensor], + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None) -> Any: + x, cache = self.self_attn( + x, + attn_mask=attn_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index) + return self.dropout1(x), cache + + # Multihead attention block: + def _mha_block(self, x: Tensor, mem: Tensor, + attn_mask: Optional[Tensor]) -> Tensor: + x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) + return self.dropout2(x) + + # Feed forward block. + def _ff_block(self, inputs: Tensor) -> Tensor: + x = self.activation(self.linear1(inputs)) + if self.glu: + y = self.linear_glu(inputs) + x = x * y + x = self.linear2(self.dropout(x)) + return self.dropout3(x) + + +class MultiheadAttention(nn.Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. Supports self-attention and + encoder-decoder attention. + See `Attention Is All You Need `_. + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will + be split across ``num_heads`` (i.e. each head will have dimension + ``embed_dim // num_heads``). + self_attn: Whether self attention or encoder-decoder attention is used. + Default: ``True``. + dropout_rate: Dropout probability on ``attn_output_weights``. + Default: ``0.0`` (no dropout_rate). + bias: If specified, adds bias to input / output projection layers. + Default: ``False``. + device: The device of the module. + dtype: The dtype of the module. + Examples:: + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, cache = multihead_attn(x) + """ + + def __init__(self, + embed_dim: int, + num_heads: int, + self_attn: bool = True, + dropout_rate: float = 0., + attention_temp: float = 1.0, + bias: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.self_attn = self_attn + self.dropout = dropout_rate + self.head_dim = embed_dim // num_heads + self.attention_temp = attention_temp + assert self.head_dim * num_heads == self.embed_dim, \ + 'embed_dim must be divisible by num_heads.' + + factory_kwargs = {'device': device, 'dtype': dtype} + if self_attn: + # Self-attention. + self.in_proj = nn.Linear( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + else: + # Encoder-decoder attention. + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + self.kv_proj = nn.Linear( + embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + + self._reset_parameters() + + def _reset_parameters(self): + """Initiate parameters in the MultiheadAttention module.""" + for module in self.modules(): + if isinstance(module, nn.Linear): + xavier_uniform_(module.weight) + if module.bias is not None: + normal_(module.bias, std=1e-6) + + def forward(self, + x: Tensor, + mem: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None) -> Any: + r""" + Args: + x: Batch of input sequences of shape + (batch size, sequence length, embedding dimensionality) for self + attention mechanism. See "Attention Is All You Need" for more details. + mem: Batch of input sequences of shape + (batch size, sequence length, embedding dimensionality) for + encoder-decoder attention. See "Attention Is All You Need" for more + details. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain + positions. Must be of shape :math:`(L, S)` or + :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the + batch size, :math:`L` is the target sequence length, and :math:`S` + is the source sequence length. A 2D mask will be broadcasted across + the batch while a 3D mask allows for a different mask for each entry + in the batch. Binary, byte, and float masks are supported. + For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, + a non-zero value indicates that the corresponding position is not + allowed to attend. For a float mask, the mask values will be added to + the attention weight. + decode: wether to use cache for autoregressive decoding or not. + max_len: maximum sequence length, necessary for decoding cache. + cache: cache dictionary for autoregressive decoding. + index: index of the current decoding step, necessary for decoding cache. + Outputs: + - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where + :math:`L` is the target sequence length, :math:`N` is the batch size, + and :math:`E` is the embedding dimension ``embed_dim``. + - **cache** - For autoregressive decoding. + """ + # Shape: (batch size, sequence length, embedding dimensionality) + bsz, seq_len, embed_dim = x.size() + # In projection. + if self.self_attn: + q, k, v = self.in_proj(x).split(self.embed_dim, dim=2) + else: + q = self.q_proj(x) + k, v = self.kv_proj(mem).split(self.embed_dim, dim=2) + # This is 1 (!= seq_len) during autoreregressive decoding. + tgt_len = q.size(1) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + name = f'decoder.layers.{index}.self_attn' + loc_cache = cache[name] if decode and name in cache else None + if decode: + if loc_cache is None: + loc_cache = { + 'cached_key': + torch.zeros((bsz, max_len, embed_dim), + dtype=k.dtype, + device=k.device), + 'cached_value': + torch.zeros((bsz, max_len, embed_dim), + dtype=v.dtype, + device=v.device), + 'cache_index': + torch.tensor(0, dtype=torch.long, device=k.device), + } + cached_key = loc_cache['cached_key'] + cached_value = loc_cache['cached_value'] + cache_index = loc_cache['cache_index'] + # Shape check of cached keys against query input. + expected_shape = (bsz, 1, embed_dim) + if expected_shape != x.shape: + raise ValueError('Autoregressive cache shape error, expected query ' + f'shape {expected_shape} instead got {x.shape}.') + # Update key, value caches with our new 1d spatial slices. + cached_key[:, cache_index:cache_index + 1, :] = k + cached_value[:, cache_index:cache_index + 1, :] = v + k = cached_key + v = cached_value + cache_index += 1 + # Causal mask for cached decoder self-attention: + # our single query position should only attend to those key + # positions that have already been generated and cached, + # not the remaining zero elements. + if attn_mask is not None: + raise ValueError('Attention mask has to be None for decode == True.') + attn_mask = (torch.arange(max_len, device=k.device) >= + cache_index).reshape(1, max_len) + + # Update sequence length to account for complete sequence. + seq_len = k.size(1) + + # Rearrange q, k, v for multihead attention. + q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + # Check dtype and shape of attention mask. + if not decode and attn_mask is not None: + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ + f'Float and bool dtypes are supported, not {attn_mask.dtype}.' + # Ensure attn_mask's dim is 3. + if attn_mask.dim() == 3: + correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, ' + f'but should be {correct_3d_size}.') + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported") + # Reshape attention mask to be consistent with q, k, v. + attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len) + + # Convert attention mask to float. + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, -1e10) + attn_mask = new_attn_mask + + # Adjust dropout_rate probability. + dropout_rate = self.dropout if self.training else 0.0 + + # Calculate attention. + q = self.attention_temp * q + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask, dropout_rate) + # Rearrange for output projection. + attn_output = attn_output.transpose(1, 2).contiguous().view( + bsz, tgt_len, embed_dim) + # Output projection. + attn_output = self.out_proj(attn_output) + + if decode: + cache[name] = loc_cache + + return attn_output, cache From 84b1bd19bb1083947adfa87a9488b074a6b170ac Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 15:37:33 +0200 Subject: [PATCH 035/123] dropout fix on wmt --- .../wmt/wmt_pytorch/models_dropout.py | 168 ++++++++++-------- 1 file changed, 91 insertions(+), 77 deletions(-) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py index 588d06abf..c5014d87d 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -112,14 +112,15 @@ def __init__(self, pre_ln: bool = True): super().__init__() if dropout_rate is None: - dropout_rate = 0.1 - self.pos_encoder = PositionalEncoding(d_model, dropout_rate) + self.dropout_rate = 0.1 + else: + self.dropout_rate = dropout_rate + self.pos_encoder = PositionalEncoding(d_model) self.shared_embedding = nn.Embedding(ntoken, d_model) self.encoder = Encoder(d_model, nhead, d_hid, nlayers, - dropout_rate, activation, glu, layer_norm_eps, @@ -129,7 +130,6 @@ def __init__(self, nhead, d_hid, nlayers, - dropout_rate, activation, glu, layer_norm_eps, @@ -158,7 +158,8 @@ def forward(self, targets_positions: Optional[Tensor] = None, inputs_segmentation: Optional[Tensor] = None, targets_segmentation: Optional[Tensor] = None, - decode: bool = False) -> Tensor: + decode: bool = False, + dropout_rate: Optional[float] = None) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] @@ -168,16 +169,22 @@ def forward(self, inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] decode: bool + dropout_rate: Optional[float] Returns: output Tensor of shape [batch_size, seq_len, ntoken] """ if src.size(0) != tgt.size(0): raise RuntimeError('The batch size of src and tgt must be equal.') + + if dropout_rate is None: + dropout_rate = self.dropout_rate + memory = self.encoder( src, inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate) output = self.decoder( tgt, memory, @@ -185,7 +192,8 @@ def forward(self, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, - decode=decode) + decode=decode, + dropout_rate=dropout_rate) return output @@ -224,12 +232,15 @@ def __init__(self, self.enable_nested_tensor = enable_nested_tensor self.mask_check = mask_check - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, src: Tensor, + mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = None) -> Tensor: """Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). + dropout_rate: the dropout probability (optional) Shape: see the docs in Transformer class. @@ -238,7 +249,7 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: convert_to_nested = False for mod in self.layers: - output = mod(output, src_mask=mask) + output = mod(output, src_mask=mask, dropout_rate=dropout_rate) if convert_to_nested: output = output.to_padded_tensor(0.) @@ -256,7 +267,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -270,7 +280,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, activation=activation, glu=glu, layer_norm_eps=layer_norm_eps, @@ -283,12 +292,13 @@ def __init__(self, def forward(self, src: Tensor, inputs_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None) -> Tensor: + inputs_segmentation: Optional[Tensor] = None, + dropout_rate: Optional[float] = None) -> Tensor: src = src.to(torch.int) src_mask = make_src_mask(src, inputs_segmentation, self.nhead) src = self.shared_embedding(src) - src = self.pos_encoder(src, inputs_positions) - memory = self.encoder(src, mask=src_mask) + src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate) + memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate) return memory @@ -299,7 +309,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -312,7 +321,6 @@ def __init__(self, self.decoder = TransformerDecoder(d_model, nhead, d_hid, - dropout_rate, activation, glu, layer_norm_eps, @@ -330,7 +338,8 @@ def forward( targets_segmentation: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = None) -> Any: tgt = tgt.to(torch.int) tgt_mask, memory_mask = make_tgt_and_memory_mask( tgt, src, inputs_segmentation, targets_segmentation, @@ -338,7 +347,7 @@ def forward( if not decode: tgt = shift_right(tgt) tgt = self.shared_embedding(tgt) - tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache) + tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate) if decode: tgt, cache = tgt output = self.decoder( @@ -348,7 +357,8 @@ def forward( memory_mask=memory_mask, decode=decode, max_len=max_len, - cache=cache) + cache=cache, + dropout_rate=dropout_rate) if decode: output, cache = output normalize = math.sqrt(output.shape[-1]) @@ -362,10 +372,8 @@ class PositionalEncoding(nn.Module): def __init__(self, d_model: int, - dropout_rate: float = 0.1, max_len: int = 256): super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) position = torch.arange(max_len).unsqueeze(1) scale_factor = -math.log(10000.0) / (d_model // 2 - 1) @@ -380,7 +388,8 @@ def forward( x: Tensor, inputs_positions: Optional[Tensor] = None, decode: bool = False, - cache: Optional[Dict[str, Dict[str, Tensor]]] = None + cache: Optional[Dict[str, Dict[str, Tensor]]] = None, + dropout_rate: Optional[float] = None ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: """ Args: @@ -403,14 +412,14 @@ def forward( } pe = self.pe[0, cache[name]['cache_index'], :] cache[name]['cache_index'] += 1 - return self.dropout(x + pe), cache + return F.dropout(x + pe, dropout_rate, self.training), cache if inputs_positions is None: # normal unpacked case: pe = self.pe[:, :x.size(1), :] else: # for packed data we need to use known position indices: pe = self.pe[0, inputs_positions, :] - return self.dropout(x + pe) + return F.dropout(x + pe, dropout_rate, self.training) # TransformerEncoderLayer and TransformerDecoderLayer are taken from: @@ -448,7 +457,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -462,7 +470,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -472,50 +479,55 @@ def __init__(self, self.glu = glu if self.glu: self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) self.activation = activation - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + def forward(self, + src: Tensor, + src_mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = None) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). - + dropout_rate: the dropout probability value (optional). Shape: see the docs in Transformer class. """ x = src if self.pre_ln: - x = x + self._sa_block(self.norm1(x), src_mask) - x = x + self._ff_block(self.norm2(x)) + x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate) + x = x + self._ff_block(self.norm2(x), dropout_rate) else: - x = self.norm1(x + self._sa_block(x, src_mask)) - x = self.norm2(x + self._ff_block(x)) + x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate)) + x = self.norm2(x + self._ff_block(x, dropout_rate)) return x # Self-attention block: - def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.self_attn(x, attn_mask=attn_mask) - return self.dropout1(x) + def _sa_block(self, + x: Tensor, + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = None) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, + inputs: Tensor, + dropout_rate: Optional[float] = None) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout2(x) + x = self.linear2(F.dropout(x, dropout_rate, training=self.training)) + return F.dropout(x, dropout_rate, training=self.training) # Modified to use cache for autoregressive decoding and custom @@ -527,7 +539,6 @@ class TransformerDecoder(nn.Module): nhead: the number of heads in the multiheadattention models (default=16) d_hid: the dimension of the feedforward network model (default=1024) - dropout_rate: the dropout_rate value (default=0.1) layer_norm_eps: the eps value in layer normalization components (default=1e-6). decoder_layer: an instance of the TransformerDecoderLayer() class @@ -545,7 +556,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, activation, glu, layer_norm_eps, @@ -558,7 +568,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, activation, glu, layer_norm_eps=layer_norm_eps, @@ -575,7 +584,8 @@ def forward(self, memory_mask: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = None) -> Any: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: tgt: the sequence to the decoder (required). @@ -584,6 +594,7 @@ def forward(self, memory_mask: the mask for the memory sequence (optional). decode: whether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. """ @@ -598,7 +609,8 @@ def forward(self, decode=decode, max_len=max_len, cache=cache, - index=idx) + index=idx, + dropout_rate=dropout_rate) if self.norm is not None: output = self.norm(output) @@ -624,7 +636,6 @@ class TransformerDecoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -644,7 +655,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -658,7 +668,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -666,7 +675,6 @@ def __init__(self, d_model, nhead, self_attn=False, - dropout_rate=dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -678,16 +686,12 @@ def __init__(self, self.linear_glu = nn.Linear(dim_feedforward, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) - self.dropout3 = nn.Dropout(dropout_rate) self.activation = activation @@ -700,7 +704,8 @@ def forward( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = None) -> Any: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -709,6 +714,7 @@ def forward( # pylint: disable=arguments-renamed memory_mask: the mask for the memory sequence (optional). decode: wether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. """ @@ -722,10 +728,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = x + sa_out - x = x + self._mha_block(self.norm2(x), memory, memory_mask) - x = x + self._ff_block(self.norm3(x)) + x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate) + x = x + self._ff_block(self.norm3(x), dropout_rate) else: sa_out, cache = self._sa_block( x, @@ -733,10 +740,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = self.norm1(x + sa_out) - x = self.norm2(x + self._mha_block(x, memory, memory_mask)) - x = self.norm3(x + self._ff_block(x)) + x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate)) + x = self.norm3(x + self._ff_block(x, dropout_rate)) return x, cache @@ -748,30 +756,38 @@ def _sa_block( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = None) -> Any: x, cache = self.self_attn( x, attn_mask=attn_mask, decode=decode, max_len=max_len, cache=cache, - index=index) - return self.dropout1(x), cache + index=index, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training), cache # Multihead attention block: def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) - return self.dropout2(x) + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = None) -> Tensor: + x, _ = self.multihead_attn( + x, + mem, + attn_mask=attn_mask, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training) # Feed forward block. - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, inputs: Tensor, + dropout_rate: Optional[float] = None) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout3(x) + x = self.linear2(F.dropout(x, dropout_rate, self.training)) + return F.dropout(x, dropout_rate, self.training) class MultiheadAttention(nn.Module): @@ -789,8 +805,6 @@ class MultiheadAttention(nn.Module): ``embed_dim // num_heads``). self_attn: Whether self attention or encoder-decoder attention is used. Default: ``True``. - dropout_rate: Dropout probability on ``attn_output_weights``. - Default: ``0.0`` (no dropout_rate). bias: If specified, adds bias to input / output projection layers. Default: ``False``. device: The device of the module. @@ -804,7 +818,6 @@ def __init__(self, embed_dim: int, num_heads: int, self_attn: bool = True, - dropout_rate: float = 0., attention_temp: float = 1.0, bias: bool = False, device: Optional[torch.device] = None, @@ -813,7 +826,6 @@ def __init__(self, self.embed_dim = embed_dim self.num_heads = num_heads self.self_attn = self_attn - self.dropout = dropout_rate self.head_dim = embed_dim // num_heads self.attention_temp = attention_temp assert self.head_dim * num_heads == self.embed_dim, \ @@ -848,7 +860,8 @@ def forward(self, decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = None) -> Any: r""" Args: x: Batch of input sequences of shape @@ -874,6 +887,7 @@ def forward(self, max_len: maximum sequence length, necessary for decoding cache. cache: cache dictionary for autoregressive decoding. index: index of the current decoding step, necessary for decoding cache. + dropout_rate: dropout probability on ``attn_output_weights``. Outputs: - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where :math:`L` is the target sequence length, :math:`N` is the batch size, @@ -963,12 +977,12 @@ def forward(self, attn_mask = new_attn_mask # Adjust dropout_rate probability. - dropout_rate = self.dropout if self.training else 0.0 + attn_dropout_rate = dropout_rate if self.training else 0.0 # Calculate attention. q = self.attention_temp * q attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, dropout_rate) + q, k, v, attn_mask, attn_dropout_rate) # Rearrange for output projection. attn_output = attn_output.transpose(1, 2).contiguous().view( bsz, tgt_len, embed_dim) From af08bb91f93266d12e6ffeab7045b7c57cb9143f Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 17:19:44 +0200 Subject: [PATCH 036/123] fix dropout, ALL tested --- .../criteo1tb_pytorch/models_dropout.py | 35 +- .../models_functional_dropout.py | 308 ------------------ .../fastmri/fastmri_pytorch/models_dropout.py | 13 +- .../librispeech_pytorch/models_dropout.py | 2 +- .../librispeech_pytorch/models_dropout.py | 2 +- 5 files changed, 37 insertions(+), 323 deletions(-) delete mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index 8042ec31e..d8d7393e4 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -79,6 +79,10 @@ def __init__(self, self.mlp_bottom_dims = mlp_bottom_dims self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim + if dropout_rate is None: + self.dropout_rate = 0.0 + else: + self.dropout_rate = dropout_rate # Ideally, we should use the pooled embedding implementation from # `TorchRec`. However, in order to have identical implementation @@ -127,17 +131,16 @@ def __init__(self, block.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): block.append(nn.ReLU(inplace=True)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - block.append(CustomDropout()) # (nico) - block = SequentialWithDropout(*block) # (nico) + if layer_idx == num_layers_top - 2: + block.append(CustomDropout()) + block = SequentialWithDropout(*block) if (layer_idx != 0) and (layer_idx != num_layers_top - 1): block = DenseBlockWithDropout(block, resnet=True) else: block = DenseBlockWithDropout(block) mlp_top_blocks.append(block) fan_in = fan_out - self.top_mlp = SequentialWithDropout(*mlp_top_blocks) # (nico) + self.top_mlp = SequentialWithDropout(*mlp_top_blocks) for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): @@ -149,7 +152,10 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate): + def forward(self, x, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.dropout_rate + batch_size = x.shape[0] dense_features, sparse_features = torch.split( @@ -201,6 +207,11 @@ def __init__(self, self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim self.embedding_init_multiplier = embedding_init_multiplier + self.dropout_rate = dropout_rate + if dropout_rate is None: + self.dropout_rate = 0.0 + else: + self.dropout_rate = dropout_rate # Ideally, we should use the pooled embedding implementation from # `TorchRec`. However, in order to have identical implementation @@ -253,10 +264,9 @@ def __init__(self, top_mlp_layers.append(nn.ReLU(inplace=True)) if use_layer_norm: top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_layers.append(CustomDropout()) # (nico) - self.top_mlp = SequentialWithDropout(*top_mlp_layers) # (nico) + if layer_idx == num_layers_top - 2: + top_mlp_layers.append(CustomDropout()) + self.top_mlp = SequentialWithDropout(*top_mlp_layers) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: @@ -271,7 +281,10 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate): + def forward(self, x, dropout_rate=None): + if dropout_rate is None: + dropout_rate = self.dropout_rate + batch_size = x.shape[0] dense_features, sparse_features = torch.split( diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py deleted file mode 100644 index 346e0e72a..000000000 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py +++ /dev/null @@ -1,308 +0,0 @@ -"""Pytorch implementation of DLRM-Small.""" - -import math - -import torch -import torch.nn.functional as F -from torch import nn - - -class DenseBlock(nn.Module): - """Dense block with optional residual connection.""" "" - - def __init__(self, module, resnet=False): - super().__init__() - self.module = module - self.resnet = resnet - - def forward(self, x): - if self.resnet: - return self.module(x) + x - else: - return self.module(x) - - -class DotInteract(nn.Module): - """Performs feature interaction operation between dense or sparse features.""" - - def __init__(self, num_sparse_features): - super().__init__() - self.triu_indices = torch.triu_indices(num_sparse_features + 1, - num_sparse_features + 1) - - def forward(self, dense_features, sparse_features): - combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), - dim=1) - interactions = torch.bmm(combined_values, - torch.transpose(combined_values, 1, 2)) - interactions_flat = interactions[:, - self.triu_indices[0], - self.triu_indices[1]] - return torch.cat((dense_features, interactions_flat), dim=1) - - -class DLRMResNet(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(256, 256, 256), - mlp_top_dims=(256, 256, 256, 256, 1), - embed_dim=128, - # dropout_rate=0.0, - use_layer_norm=False, - embedding_init_multiplier=None): - super().__init__() - self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) - self.num_dense_features = num_dense_features - self.num_sparse_features = num_sparse_features - self.mlp_bottom_dims = mlp_bottom_dims - self.mlp_top_dims = mlp_top_dims - self.embed_dim = embed_dim - - # Ideally, we should use the pooled embedding implementation from - # `TorchRec`. However, in order to have identical implementation - # with that of Jax, we define a single embedding matrix. - num_chunks = 4 - assert vocab_size % num_chunks == 0 - self.embedding_table_chucks = [] - scale = 1.0 / torch.sqrt(self.vocab_size) - for i in range(num_chunks): - chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) - chunk.data.uniform_(0, 1) - chunk.data = scale * chunk.data - self.register_parameter(f'embedding_chunk_{i}', chunk) - self.embedding_table_chucks.append(chunk) - - input_dim = self.num_dense_features - bot_mlp_blocks = [] - for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): - block = [] - block.append(nn.Linear(input_dim, dense_dim)) - block.append(nn.ReLU(inplace=True)) - block = nn.Sequential(*block) - if layer_idx > 0: - block = DenseBlock(block, resnet=True) - else: - block = DenseBlock(block) - bot_mlp_blocks.append(block) - input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bot_mlp_blocks) - - for module in self.bot_mlp.modules(): - if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) - nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - # Number of sparse features = 26 - fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] - num_layers_top = len(self.mlp_top_dims) - mlp_top_blocks = [] - for layer_idx, fan_out in enumerate(self.mlp_top_dims): - block = [] - block.append(nn.Linear(fan_in, fan_out)) - if layer_idx < (num_layers_top - 1): - block.append(nn.ReLU(inplace=True)) - # if (dropout_rate is not None and dropout_rate > 0.0 and - # layer_idx == num_layers_top - 2): - # block.append(nn.Dropout(p=dropout_rate)) - block = nn.Sequential(*block) - if (layer_idx != 0) and (layer_idx != num_layers_top - 1): - block = DenseBlock(block, resnet=True) - else: - block = DenseBlock(block) - mlp_top_blocks.append(block) - fan_in = fan_out - self.top_mlp = nn.Sequential(*mlp_top_blocks) - - for module in self.top_mlp.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - def forward(self, x, dropout_rate): - batch_size = x.shape[0] - - dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) - - # Bottom MLP. - embedded_dense = self.bot_mlp(dense_features) - - # Sparse feature processing. - sparse_features = sparse_features.to(dtype=torch.int32) - idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size - embedding_table = torch.cat(self.embedding_table_chucks, dim=0) - embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, 26 * self.embed_dim]) - top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) - - # Final MLP (horrible!!). - h = top_mlp_input - num_layers_top = len(self.mlp_top_dims) - for layer_idx, block in enumerate(self.top_mlp): - # block.module is nn.Sequential([...]) - seq = block.module - # 1) linear - out = seq[0](h) - # 2) ReLU (if present) - if layer_idx < (num_layers_top - 1): - out = seq[1](out) - # 3) functional dropout at penult layer - if dropout_rate > 0 and layer_idx == num_layers_top - 2: - out = F.dropout(out, dropout_rate, training=self.training) - # 4) wrap in residual if needed - h = out + h if block.resnet else out - return h - - -class DlrmSmall(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(512, 256, 128), - mlp_top_dims=(1024, 1024, 512, 256, 1), - embed_dim=128, - # dropout_rate=0.0, - use_layer_norm=False, - embedding_init_multiplier=None): - super().__init__() - self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) - self.num_dense_features = num_dense_features - self.num_sparse_features = num_sparse_features - self.mlp_bottom_dims = mlp_bottom_dims - self.mlp_top_dims = mlp_top_dims - self.embed_dim = embed_dim - self.embedding_init_multiplier = embedding_init_multiplier - - # Ideally, we should use the pooled embedding implementation from - # `TorchRec`. However, in order to have identical implementation - # with that of Jax, we define a single embedding matrix. - num_chunks = 4 - assert vocab_size % num_chunks == 0 - self.embedding_table_chucks = [] - - if self.embedding_init_multiplier is None: - scale = 1.0 / torch.sqrt(self.vocab_size) - else: - scale = self.embedding_init_multiplier - - for i in range(num_chunks): - chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) - chunk.data.uniform_(0, 1) - chunk.data = scale * chunk.data - self.register_parameter(f'embedding_chunk_{i}', chunk) - self.embedding_table_chucks.append(chunk) - - input_dim = self.num_dense_features - bottom_mlp_layers = [] - for dense_dim in self.mlp_bottom_dims: - bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) - bottom_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) - input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bottom_mlp_layers) - for module in self.bot_mlp.modules(): - if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) - nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) - - # TODO: Write down the formula here instead of the constant. - input_dims = 506 - num_layers_top = len(self.mlp_top_dims) - top_mlp_layers = [] - for layer_idx, fan_out in enumerate(self.mlp_top_dims): - fan_in = input_dims if layer_idx == 0 \ - else self.mlp_top_dims[layer_idx - 1] - top_mlp_layers.append(nn.Linear(fan_in, fan_out)) - if layer_idx < (num_layers_top - 1): - top_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - # if (dropout_rate is not None and dropout_rate > 0.0 and - # layer_idx == num_layers_top - 2): - # top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential(*top_mlp_layers) - if use_layer_norm: - self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) - else: - self.embed_ln = None - for module in self.top_mlp.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - def forward(self, x, dropout_rate): - batch_size = x.shape[0] - - dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) - - # Bottom MLP. - embedded_dense = self.bot_mlp(dense_features) - - # Sparse feature processing. - sparse_features = sparse_features.to(dtype=torch.int32) - idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size - embedding_table = torch.cat(self.embedding_table_chucks, dim=0) - embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, -1, self.embed_dim]) - if self.embed_ln: - embedded_sparse = self.embed_ln(embedded_sparse) - # Dot product interactions. - concatenated_dense = self.dot_interact( - dense_features=embedded_dense, sparse_features=embedded_sparse) - - # Final MLP: run each layer, and after the penultimate layer do functional dropout - h = concatenated_dense - N = len(self.top_mlp) - for idx, layer in enumerate(self.top_mlp): - h = layer(h) - # insert dropout exactly where nn.Dropout used to live - if dropout_rate > 0 and idx == N - 2: - h = F.dropout(h, dropout_rate, training=self.training) - return h diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 5862f6352..8954cb737 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -29,6 +29,7 @@ def __init__(self, out_chans: int = 1, num_channels: int = 32, num_pool_layers: int = 4, + dropout_rate: Optional[float] = 0.0, use_tanh: bool = False, use_layer_norm: bool = False) -> None: super().__init__() @@ -37,6 +38,11 @@ def __init__(self, self.out_chans = out_chans self.num_channels = num_channels self.num_pool_layers = num_pool_layers + if dropout_rate is None: + self.dropout_rate = 0.0 + else: + self.dropout_rate = dropout_rate + self.down_sample_layers = nn.ModuleList([ ConvBlock(in_chans, num_channels, @@ -72,7 +78,10 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor: + if dropout_rate is None: + dropout_rate = self.dropout_rate + stack = [] output = x @@ -136,7 +145,7 @@ def __init__(self, CustomDropout2d(), ) - def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor: return self.conv_layers(x, dropout_rate) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py index da66dfe43..9ff662fb8 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py @@ -510,7 +510,7 @@ def forward(self, inputs, input_paddings, dropout_rate=None): outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) for conformer in self.conformers: outputs = conformer(outputs, output_paddings, dropout_rate) outputs = self.ln(outputs) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py index e68a820ed..8797aa578 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py @@ -383,7 +383,7 @@ def forward(self, inputs, input_paddings, dropout_rate=None): for idx in range(self.config.num_ffn_layers): if self.config.enable_residual_connections: - outputs = outputs + self.ffns[idx](outputs, output_paddings) + outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate) else: outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) From 7a6651a69953af4655e0f7c50b8c7fefe71aa9ca Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 17:20:39 +0200 Subject: [PATCH 037/123] add dropout equivalence tests --- .../test_model_equivalence.py | 77 ++++++++++++ .../fastmri_pytorch/test_model_equivalence.py | 98 +++++++++++++++ .../test_model_equivalence.py | 112 ++++++++++++++++++ .../test_model_equivalence.py | 91 ++++++++++++++ .../test_model_equivalence.py | 89 ++++++++++++++ .../ogbg_pytorch/test_model_equivalence.py | 76 ++++++++++++ .../wmt_pytorch/test_model_equivalence.py | 83 +++++++++++++ 7 files changed, 626 insertions(+) create mode 100644 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py create mode 100644 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..b9b1232ef --- /dev/null +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -0,0 +1,77 @@ +""" +Runs fwd pass with random input for our DLRM models and compares outputs. +Run it as: + python3 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import ( + DLRMResNet as OriginalDLRMResNet, + DlrmSmall as OriginalDlrmSmall, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import ( + DLRMResNet as CustomDLRMResNet, + DlrmSmall as CustomDlrmSmall, +) + + +BATCH, DENSE, SPARSE = 16, 13, 26 +FEATURES = DENSE + SPARSE +VOCAB = 1000 +DEVICE = 'cuda' +TORCH_COMPILE = False +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + +class ModelEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='DLRMResNet, p=None', model='dlrm_resnet', dropout_rate=None), + dict(testcase_name='DlrmSmall, p=None', model='dlrm_small', dropout_rate=None), + dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0), + dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0), + dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1), + dict(testcase_name='DlrmSmall, p=0.1', model='dlrm_small', dropout_rate=0.1), + dict(testcase_name='DLRMResNet, p=1.0', model='dlrm_resnet', dropout_rate=1.0), + dict(testcase_name='DlrmSmall, p=1.0', model='dlrm_small', dropout_rate=1.0), + ) + def test_forward(self, model, dropout_rate): + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) + ) + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [dropout_rate, None]: + + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) + orig.to(DEVICE) + + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB, dropout_rate=custom_init_dropout_rate) + cust.to(DEVICE) + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1 = orig(x) + torch.manual_seed(SEED); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..6339ff21b --- /dev/null +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -0,0 +1,98 @@ +""" +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. +Run it as: + python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet as OriginalUNet +from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import UNet as CustomUNet + +BATCH, IN_CHANS, H, W = 4, 1, 256, 256 +OUT_CHANS, C, LAYERS = 1, 32, 4 +DEVICE = 'cuda' +TORCH_COMPILE = True +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + +class FastMRIModeEquivalenceTest(parameterized.TestCase): + + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=None', dropout_rate=None), + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different values of dropout_rate.""" + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [dropout_rate, None]: + + torch.manual_seed(SEED) + orig = OriginalUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=custom_init_dropout_rate + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + + @parameterized.named_parameters( + dict(testcase_name='default', use_tanh=False, use_layer_norm=False), + dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False), + dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), + dict(testcase_name='both', use_tanh=True, use_layer_norm=True), + ) + def test_arch_setups(self, use_tanh, use_layer_norm): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate, + use_tanh=use_tanh, use_layer_norm=use_layer_norm + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, + use_tanh=use_tanh, use_layer_norm=use_layer_norm + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..56644f152 --- /dev/null +++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py @@ -0,0 +1,112 @@ +""" +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. +Run it as: + python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os +import itertools + +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import ViT as OriginalVit +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import ViT as CustomVit + +# Model / test hyper-params +BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W) +WIDTH, DEPTH, HEADS = 256, 4, 8 +DROPOUT_RATE = None +DEVICE = 'cuda' +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + +class ImageNetVitModeEquivalenceTest(parameterized.TestCase): + + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=None', dropout_rate=None), + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.6', dropout_rate=0.6), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different dropout_values.""" + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [dropout_rate, None]: + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + dropout_rate=custom_init_dropout_rate, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + + @parameterized.named_parameters([ + dict( + testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}", + use_glu=use_glu, + use_post_ln=use_post_ln, + use_map=use_map, + ) + for use_glu, use_post_ln, use_map in itertools.product([False, True], repeat=3) + ]) + def test_arch(self, use_glu, use_post_ln, use_map): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + dropout_rate=None, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..19525a98b --- /dev/null +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -0,0 +1,91 @@ +""" +Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs. +Run with: + python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py + +`dropout_rate` controls the following args: +- `attention_residual_dropout_rate` (if None, 0.1 +- `conv_residual_dropout_rate` (if None, 0.0) +- `feed_forward_residual_dropout_rate` (if None, 0.1) +- `input_dropout_rate` (if None, 0.1) +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + # ConformerConfig, + ConformerEncoderDecoder as OriginalConf +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import( + ConformerEncoderDecoder as CustomConf, + ConformerConfig, +) + +B, T = 32, 36_000 +DEVICE = 'cuda' + +os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(mode=True) +SEED = 1996 + + +class ConformerEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='p=None', dropout_rate=None), + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [None, dropout_rate]: + + torch.manual_seed(SEED) + orig = OriginalConf( + ConformerConfig( + num_encoder_layers=3, + attention_residual_dropout_rate=dropout_rate, + conv_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomConf( + ConformerConfig( + num_encoder_layers=3, + attention_residual_dropout_rate=custom_init_dropout_rate, + conv_residual_dropout_rate=custom_init_dropout_rate, + feed_forward_residual_dropout_rate=custom_init_dropout_rate, + input_dropout_rate=custom_init_dropout_rate + )).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..e31f4a7eb --- /dev/null +++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py @@ -0,0 +1,89 @@ +""" +Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs. +Run with: + python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py + +`dropout_rate` controls the following args: +- `input_dropout_rate` (if None, 0.1 +- `feed_forward_dropout_rate` (if None, 0.1) +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch +import os + +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( + DeepspeechEncoderDecoder as OriginalModel +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import( + DeepspeechEncoderDecoder as CustomModel, + DeepspeechConfig, +) + +B, T = 32, 30_000 +DEVICE = 'cuda' +TORCH_COMPILE = True + +os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(mode=True) +SEED = 1996 + + +class DeepSpeechEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='p=None', dropout_rate=None), + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [None, dropout_rate]: + + torch.manual_seed(SEED) + orig = OriginalModel( + DeepspeechConfig( + num_lstm_layers=2, + num_ffn_layers=2, + input_dropout_rate=dropout_rate, + feed_forward_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel(DeepspeechConfig( + num_lstm_layers=2, + num_ffn_layers=2, + input_dropout_rate=custom_init_dropout_rate, + feed_forward_dropout_rate=custom_init_dropout_rate, + )).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..cc1857705 --- /dev/null +++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py @@ -0,0 +1,76 @@ +""" +Runs fwd pass with random graphs for OGBG GNN models and compares outputs. +Run with: + python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch, os, random, numpy as np +from jraph import GraphsTuple + +from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel +from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import GNN as CustomModel + +B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph +NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims +DEVICE = 'cuda' + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) +SEED = 1996 + + +def _rand_graph(): + total_nodes, total_edges = B * N, B * E + nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE) + edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE) + senders, receivers = [], [] + for i in range(B): + offset = i * N + s = torch.randint(N, (E,), device=DEVICE) + offset + r = torch.randint(N, (E,), device=DEVICE) + offset + senders.append(s), receivers.append(r) + senders = torch.cat(senders); receivers = torch.cat(receivers) + n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32) + n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32) + return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge) + + +class GNNEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='None', dropout_rate=None), + dict(testcase_name='0.0', dropout_rate=0.0), + dict(testcase_name='0.2', dropout_rate=0.2), + dict(testcase_name='0.7', dropout_rate=0.7), + dict(testcase_name='1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [None, dropout_rate]: + + orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) + cust = CustomModel(dropout_rate=custom_init_dropout_rate).to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + graph = _rand_graph() + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(graph) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py new file mode 100644 index 000000000..9aca717d9 --- /dev/null +++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py @@ -0,0 +1,83 @@ +""" +Runs fwd pass with random input for WMT Transformer models and compares outputs. +Run with: + python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py +""" + +from absl.testing import absltest, parameterized +from torch.testing import assert_close +import torch, os, random, numpy as np + +from algoperf.workloads.wmt.wmt_pytorch.models import ( + Transformer as OriginalModel, +) +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import ( + Transformer as CustomModel, +) + +B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000 +DEVICE = "cuda" +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + + +def _rand_tokens(bs, seqlen): + return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE) + + +class TransformerEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + # NOTE: removed dropout=1.0 will generate nan in scaled_dot_product_attention + + dict(testcase_name="None", dropout_rate=None, compile=False), + dict(testcase_name="0.0", dropout_rate=0.0, compile=False), + dict(testcase_name="0.2", dropout_rate=0.2, compile=False), + dict(testcase_name="0.7", dropout_rate=0.7, compile=False), + + dict(testcase_name="p=None, compile", dropout_rate=None, compile=True), + dict(testcase_name="p=0.0, compile", dropout_rate=0.0, compile=True), + dict(testcase_name="p=0.2, compile", dropout_rate=0.2, compile=True), + dict(testcase_name="p=0.7, compile", dropout_rate=0.7, compile=True), + ) + def test_forward(self, dropout_rate, compile): + + # Test initalizing custom model with a None dropout_rate + for custom_init_dropout_rate in [None, dropout_rate]: + + orig = OriginalModel( + dropout_rate=dropout_rate, + attention_dropout_rate=dropout_rate + ).to(DEVICE) + cust = CustomModel( + dropout_rate=custom_init_dropout_rate + ).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(src, tgt, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == "__main__": + absltest.main() From a7ff3d1ab09c57807f4d0c7b219803407c085c69 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 17:39:46 +0200 Subject: [PATCH 038/123] moved custom dropout to pytorch_utils --- algoperf/pytorch_utils.py | 38 ++++++++++++++++++ .../criteo1tb_pytorch/models_dropout.py | 2 +- algoperf/workloads/dropout_modules.py | 40 ------------------- .../fastmri/fastmri_pytorch/models_dropout.py | 2 +- .../ogbg/ogbg_pytorch/models_dropout.py | 2 +- 5 files changed, 41 insertions(+), 43 deletions(-) delete mode 100644 algoperf/workloads/dropout_modules.py diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index 4a674985d..4af77088e 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -5,7 +5,10 @@ import jax import tensorflow as tf import torch +from torch import Tensor +import torch.nn as nn import torch.distributed as dist +import torch.nn.functional as F from algoperf import spec from algoperf.profiler import Profiler @@ -77,3 +80,38 @@ def update_batch_norm_fn(module: spec.ParameterContainer, module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): module.momentum = module.momentum_backup + + +class CustomDropout(nn.Module): + """A module around torch.nn.functional.dropout.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout(input, p, training=self.training) + + +class CustomDropout2d(nn.Module): + """A module around torch.nn.functional.dropout2d.""" + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, input: Tensor, p: float) -> Tensor: + return F.dropout2d(input, p, training=self.training) + + +class SequentialWithDropout(nn.Sequential): + """Sequential of modules with dropout.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._supports_custom_dropout = True + + def forward(self, x: Tensor, p: float) -> Tensor: + for module in self: + if getattr(module, '_supports_custom_dropout', False): + x = module(x, p) + else: + x = module(x) + return x diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index d8d7393e4..065ebd1f8 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torch import nn -from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout class DenseBlock(nn.Module): diff --git a/algoperf/workloads/dropout_modules.py b/algoperf/workloads/dropout_modules.py deleted file mode 100644 index 6cec3f7ad..000000000 --- a/algoperf/workloads/dropout_modules.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Custom classes to support a dynamic modulized dropout, see issue??TODO""" - -from torch import Tensor -from torch import nn -import torch.nn.functional as F - - -class CustomDropout(nn.Module): - """A module around torch.nn.functional.dropout.""" - def __init__(self): - super().__init__() - self._supports_custom_dropout = True - - def forward(self, input: Tensor, p: float) -> Tensor: - return F.dropout(input, p, training=self.training) - - -class CustomDropout2d(nn.Module): - """A module around torch.nn.functional.dropout2d.""" - def __init__(self): - super().__init__() - self._supports_custom_dropout = True - - def forward(self, input: Tensor, p: float) -> Tensor: - return F.dropout2d(input, p, training=self.training) - - -class SequentialWithDropout(nn.Sequential): - """Sequential of modules with dropout.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._supports_custom_dropout = True - - def forward(self, x: Tensor, p: float) -> Tensor: - for module in self: - if getattr(module, '_supports_custom_dropout', False): - x = module(x, p) - else: - x = module(x) - return x diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 8954cb737..260cb7e44 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -13,7 +13,7 @@ from torch.nn import functional as F from algoperf import init_utils -from algoperf.workloads.dropout_modules import CustomDropout2d, SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py index 1d89ea9e7..b86b88caa 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py @@ -9,7 +9,7 @@ from torch import nn from algoperf import init_utils -from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout def _make_mlp(in_dim, hidden_dims, activation_fn): From f26ab02d987a839c481dc74d3d15c1d920165d38 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 11 Jun 2025 18:14:26 +0200 Subject: [PATCH 039/123] remove aux_dropout from pytorch workloads --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 4 +--- .../workloads/fastmri/fastmri_pytorch/workload.py | 4 +--- .../imagenet_resnet/imagenet_pytorch/workload.py | 4 +--- .../imagenet_vit/imagenet_pytorch/workload.py | 4 +--- .../librispeech_pytorch/workload.py | 11 +++-------- .../librispeech_pytorch/workload.py | 11 +++-------- algoperf/workloads/ogbg/ogbg_pytorch/workload.py | 5 +---- algoperf/workloads/wmt/wmt_pytorch/workload.py | 6 ++---- 8 files changed, 13 insertions(+), 36 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 726aa8705..638022a5e 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -68,10 +68,8 @@ def loss_fn( def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Only dropout is used.""" - del aux_dropout_rate torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. torch.backends.cudnn.benchmark = False diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 58943de2f..9582325e1 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -108,9 +108,7 @@ def _build_input_queue(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + dropout_rate: Optional[float] = None) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = UNet( num_pool_layers=self.num_pool_layers, diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index ed29271f3..372cac7fa 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -157,11 +157,9 @@ def _build_dataset( def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate - del aux_dropout_rate torch.random.manual_seed(rng[0]) if self.use_silu and self.use_gelu: diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 97bb38515..1a6bb1381 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -24,9 +24,7 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + dropout_rate: Optional[float] = None) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = models.ViT( dropout_rate=dropout_rate, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 5ed37957e..39f33f4aa 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -64,13 +64,8 @@ def attention_temperature(self) -> float: def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Conformer model init function. - - Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as - input_dropout_rate. - """ + dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """Conformer model init function.""" torch.random.manual_seed(rng[0]) # Configure torch backends to avoid OOM errors. torch.backends.cudnn.benchmark = False @@ -86,7 +81,7 @@ def init_model_fn( attention_residual_dropout_rate=dropout_rate, feed_forward_residual_dropout_rate=dropout_rate, conv_residual_dropout_rate=dropout_rate, - input_dropout_rate=aux_dropout_rate, + input_dropout_rate=dropout_rate, use_specaug=self.use_specaug, attention_temperature=self.attention_temperature, use_post_layer_norm=self.use_post_layer_norm, diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index e5387f5cb..932ba9392 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -25,19 +25,14 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Deepspeech model init function. - - Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate - as input_dropout_rate. - """ + dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """Deepspeech model init function.""" torch.random.manual_seed(rng[0]) model = DeepspeechEncoderDecoder( DeepspeechConfig( feed_forward_dropout_rate=dropout_rate, use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate, + input_dropout_rate=dropout_rate, use_tanh=self.use_tanh, enable_residual_connections=self.enable_residual_connections, enable_decoder_layer_norm=self.enable_decoder_layer_norm, diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 45295ac7f..1dd85951d 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -139,10 +139,7 @@ def _build_input_queue(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is unused.""" - del aux_dropout_rate + dropout_rate: Optional[float] = None) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = GNN( num_outputs=self._num_outputs, diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index d0716d6c8..64eea73b7 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -168,9 +168,7 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" + dropout_rate: Optional[float] = None) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.activation == 'relu': @@ -182,7 +180,7 @@ def init_model_fn( model = Transformer( dropout_rate=dropout_rate, - attention_dropout_rate=aux_dropout_rate, + attention_dropout_rate=dropout_rate, pre_ln=self.pre_ln, attention_temp=self.attention_temp, activation=activation, From 872393770c6e64f1cf5cb6e6a2b0227c5e88685d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 11 Jun 2025 10:22:56 -0700 Subject: [PATCH 040/123] Update submission.py --- submissions/template/submission.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/submissions/template/submission.py b/submissions/template/submission.py index a4fdc62b4..2269b7dbb 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -15,9 +15,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters: spec.Hyperparameters, rng: spec.RandomState) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule. - Returns: - optimizer state - optimizer_update_fn + Returns: spec.OptimizerState initialized optimizer state """ pass @@ -37,9 +35,9 @@ def update_params( train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """ Returns: - (new_optimizer_state, update_fn) - new_params - new_model_state + spec.OptimizerState: new optimizer state + spec.ParameterTypeTree: new params + new_model_state: new model state """ pass From e0a0e624b7cd70fea51a054d85cd1b418076cdef Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 10:57:15 +0200 Subject: [PATCH 041/123] criteo rm dropout from init --- .../criteo1tb_pytorch/models_dropout.py | 21 +++-------- .../criteo1tb/criteo1tb_pytorch/workload.py | 5 +-- .../test_model_equivalence.py | 35 ++++++++----------- 3 files changed, 20 insertions(+), 41 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index 065ebd1f8..2ac5c2d1b 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -8,6 +8,8 @@ from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout +DEFAULT_DROPOUT_RATE = 0.0 + class DenseBlock(nn.Module): """Dense block with optional residual connection.""" "" @@ -69,7 +71,6 @@ def __init__(self, mlp_bottom_dims=(256, 256, 256), mlp_top_dims=(256, 256, 256, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -79,10 +80,6 @@ def __init__(self, self.mlp_bottom_dims = mlp_bottom_dims self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim - if dropout_rate is None: - self.dropout_rate = 0.0 - else: - self.dropout_rate = dropout_rate # Ideally, we should use the pooled embedding implementation from # `TorchRec`. However, in order to have identical implementation @@ -152,9 +149,7 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE): batch_size = x.shape[0] @@ -196,7 +191,6 @@ def __init__(self, mlp_bottom_dims=(512, 256, 128), mlp_top_dims=(1024, 1024, 512, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -207,11 +201,6 @@ def __init__(self, self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim self.embedding_init_multiplier = embedding_init_multiplier - self.dropout_rate = dropout_rate - if dropout_rate is None: - self.dropout_rate = 0.0 - else: - self.dropout_rate = dropout_rate # Ideally, we should use the pooled embedding implementation from # `TorchRec`. However, in order to have identical implementation @@ -281,9 +270,7 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE): batch_size = x.shape[0] diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 638022a5e..b128f5bd5 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -67,9 +67,7 @@ def loss_fn( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Only dropout is used.""" + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. torch.backends.cudnn.benchmark = False @@ -83,7 +81,6 @@ def init_model_fn( mlp_bottom_dims=self.mlp_bottom_dims, mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, - dropout_rate=dropout_rate, use_layer_norm=self.use_layer_norm, embedding_init_multiplier=self.embedding_init_multiplier) self._param_shapes = param_utils.pytorch_param_shapes(model) diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index b9b1232ef..c4f074ff3 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -34,8 +34,6 @@ class ModelEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='DLRMResNet, p=None', model='dlrm_resnet', dropout_rate=None), - dict(testcase_name='DlrmSmall, p=None', model='dlrm_small', dropout_rate=None), dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0), dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0), dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1), @@ -50,27 +48,24 @@ def test_forward(self, model, dropout_rate): else (OriginalDlrmSmall, CustomDlrmSmall) ) - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [dropout_rate, None]: + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) + orig.to(DEVICE) - torch.manual_seed(SEED) - orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) - orig.to(DEVICE) + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB) + cust.to(DEVICE) - torch.manual_seed(SEED) - cust = CustCls(vocab_size=VOCAB, dropout_rate=custom_init_dropout_rate) - cust.to(DEVICE) - - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) - - x = torch.randn(BATCH, FEATURES, device=DEVICE) + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(SEED); y1 = orig(x) - torch.manual_seed(SEED); y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1 = orig(x) + torch.manual_seed(SEED); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) if __name__ == '__main__': From 1e2f379a955bbd75c57042cb6b73750fd3f06eb7 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:07:04 +0200 Subject: [PATCH 042/123] criteo rm dropout from init --- .../criteo1tb_pytorch/test_model_equivalence.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index c4f074ff3..f40f7b3df 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -63,10 +63,15 @@ def test_forward(self, model, dropout_rate): for mode in ('train', 'eval'): getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(SEED); y1 = orig(x) - torch.manual_seed(SEED); y2 = cust(x, dropout_rate) + torch.manual_seed(SEED) + y1 = orig(x) + torch.manual_seed(SEED) + if mode == 'train': + y2 = cust(x, dropout_rate) + else: + y2 = cust(x) assert_close(y1, y2, atol=0, rtol=0) - + if __name__ == '__main__': absltest.main() From f10e3dc6bdebdf87a3460b600c729c63512f7397 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:11:26 +0200 Subject: [PATCH 043/123] criteo rm dropout from init --- .../dropout_fix/criteo1tb_pytorch/test_model_equivalence.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index f40f7b3df..32aa5b34f 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -62,7 +62,9 @@ def test_forward(self, model, dropout_rate): x = torch.randn(BATCH, FEATURES, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(SEED) y1 = orig(x) torch.manual_seed(SEED) @@ -70,6 +72,7 @@ def test_forward(self, model, dropout_rate): y2 = cust(x, dropout_rate) else: y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) From 027b053e838a57e5a1c33389d0b8f1b0d4269aa2 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:25:31 +0200 Subject: [PATCH 044/123] criteo rm dropout from init --- .../workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index 2ac5c2d1b..b5ee465e2 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -3,7 +3,6 @@ import math import torch -import torch.nn.functional as F from torch import nn from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout @@ -30,7 +29,7 @@ def __init__(self, module, resnet=False): self.resnet = resnet self._supports_custom_dropout = True - def forward(self, x, p=None): + def forward(self, x, p): return self.module(x, p) + x if self.resnet else self.module(x, p) From 74c43aa20e3c61b488234014816f1bffeec1508d Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:26:04 +0200 Subject: [PATCH 045/123] fastmri rm dropout from init --- .../fastmri/fastmri_pytorch/models_dropout.py | 12 ++----- .../fastmri/fastmri_pytorch/workload.py | 6 ++-- .../fastmri_pytorch/test_model_equivalence.py | 32 ++++++++----------- 3 files changed, 19 insertions(+), 31 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 260cb7e44..73b1d81d9 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -15,6 +15,7 @@ from algoperf import init_utils from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout +DEFAULT_DROPOUT_RATE = 0.0 class UNet(nn.Module): @@ -29,7 +30,6 @@ def __init__(self, out_chans: int = 1, num_channels: int = 32, num_pool_layers: int = 4, - dropout_rate: Optional[float] = 0.0, use_tanh: bool = False, use_layer_norm: bool = False) -> None: super().__init__() @@ -38,10 +38,6 @@ def __init__(self, self.out_chans = out_chans self.num_channels = num_channels self.num_pool_layers = num_pool_layers - if dropout_rate is None: - self.dropout_rate = 0.0 - else: - self.dropout_rate = dropout_rate self.down_sample_layers = nn.ModuleList([ ConvBlock(in_chans, @@ -78,9 +74,7 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: Tensor, dropout_rate: Optional[float] = DEFAULT_DROPOUT_RATE) -> Tensor: stack = [] output = x @@ -145,7 +139,7 @@ def __init__(self, CustomDropout2d(), ) - def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor: + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: return self.conv_layers(x, dropout_rate) diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 9582325e1..6da0bb0af 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -107,15 +107,13 @@ def _build_input_queue(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = UNet( num_pool_layers=self.num_pool_layers, num_channels=self.num_channels, use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) + use_layer_norm=self.use_layer_norm) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py index 6339ff21b..c71ff8980 100644 --- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -33,36 +33,32 @@ def fwd_pass(self, orig, cust, dropout_rate): torch.manual_seed(0) y1 = orig(x) torch.manual_seed(0) - y2 = cust(x, dropout_rate) + if mode == 'train': + y2 = cust(x, dropout_rate) + else: + y2 = cust(x) assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( - dict(testcase_name='p=None', dropout_rate=None), dict(testcase_name='p=0.0', dropout_rate=0.0), dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.7', dropout_rate=0.7), dict(testcase_name='p=1.0', dropout_rate=1.0), ) def test_dropout_values(self, dropout_rate): """Test different values of dropout_rate.""" - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [dropout_rate, None]: - - torch.manual_seed(SEED) - orig = OriginalUNet( - IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate - ).to(DEVICE) + torch.manual_seed(SEED) + orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomUNet( - IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=custom_init_dropout_rate - ).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - cust.load_state_dict(orig.state_dict()) # sync weights - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) - self.fwd_pass(orig, cust, dropout_rate) + self.fwd_pass(orig, cust, dropout_rate) @parameterized.named_parameters( @@ -71,7 +67,7 @@ def test_dropout_values(self, dropout_rate): dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), dict(testcase_name='both', use_tanh=True, use_layer_norm=True), ) - def test_arch_setups(self, use_tanh, use_layer_norm): + def test_arch_configs(self, use_tanh, use_layer_norm): """Test different architecture configurations, fixed dropout_rate.""" dropout_rate = 0.1 From 64276ef67cb0ab56bac998c665c82d62e7722087 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:44:11 +0200 Subject: [PATCH 046/123] vit rm dropout at init --- .../imagenet_pytorch/models_dropout.py | 72 +++++++------------ .../wmt/wmt_pytorch/models_dropout.py | 2 +- .../test_model_equivalence.py | 72 ++++++++++++------- 3 files changed, 72 insertions(+), 74 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py index f5e315fd7..8641847b0 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py @@ -14,7 +14,9 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention + +DEFAULT_DROPOUT_RATE = 0.0 def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: @@ -41,14 +43,12 @@ def __init__( self, width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - use_glu: bool = False, - dropout_rate: float = 0.0) -> None: + use_glu: bool = False) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu - self.dropout_rate = dropout_rate self.linear1 = nn.Linear(self.width, self.mlp_dim) self.act_fnc = nn.GELU(approximate='tanh') @@ -69,9 +69,7 @@ def reset_parameters(self) -> None: if module.bias is not None: module.bias.data.normal_(std=1e-6) - def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: x = self.linear1(x) x = self.act_fnc(x) @@ -90,8 +88,7 @@ class SelfAttention(nn.Module): def __init__(self, width: int, - num_heads: int = 8, - dropout_rate: float = 0.0) -> None: + num_heads: int = 8) -> None: super().__init__() self.width = width @@ -102,7 +99,6 @@ def __init__(self, self.head_dim = int(width / num_heads) self.all_head_dim = self.num_heads * self.head_dim - self.dropout_rate = dropout_rate self.query = nn.Linear(self.width, self.all_head_dim) self.key = nn.Linear(self.width, self.all_head_dim) @@ -122,9 +118,7 @@ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: mixed_query_layer = self.query(x) @@ -136,7 +130,7 @@ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: attention_scores = attention_scores / math.sqrt(self.head_dim) attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = F.dropout(attention_probs, dropout_rate, training=self.training) + attention_probs = F.dropout(attention_probs, dropout_rate, self.training) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() @@ -154,8 +148,7 @@ def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.width = width @@ -163,7 +156,6 @@ def __init__(self, self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm - self.dropout_rate = dropout_rate self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) @@ -171,32 +163,29 @@ def __init__(self, self.mlp3 = MlpBlock( width=self.width, mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dropout_rate=dropout_rate) + use_glu=self.use_glu) - def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: if not self.use_post_layer_norm: y = self.layer_norm0(x) - y = self.self_attention1(y) + y = self.self_attention1(y, dropout_rate) y = F.dropout(y, dropout_rate, training=self.training) x = x + y y = self.layer_norm2(x) - y = self.mlp3(y) + y = self.mlp3(y, dropout_rate) y = F.dropout(y, dropout_rate, training=self.training) x = x + y else: y = x - y = self.self_attention1(y) + y = self.self_attention1(y, dropout_rate) y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm0(x) y = x - y = self.mlp3(y) + y = self.mlp3(y, dropout_rate) y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm2(x) @@ -212,8 +201,7 @@ def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.depth = depth @@ -228,8 +216,7 @@ def __init__(self, self.mlp_dim, self.num_heads, self.use_glu, - self.use_post_layer_norm, - dropout_rate) for _ in range(depth) + self.use_post_layer_norm) for _ in range(depth) ]) if not self.use_post_layer_norm: @@ -237,10 +224,10 @@ def __init__(self, else: self.encoder_norm = None - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: # Input Encoder. for block in self.net: - x = block(x) + x = block(x, dropout_rate) if not self.use_post_layer_norm: return self.encoder_norm(x) else: @@ -267,13 +254,13 @@ def __init__(self, self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape probe = torch.tile(self.probe, [n, 1, 1]) - x = self.mha(probe, x)[0] + x = self.mha(probe, x, dropout_rate=dropout_rate)[0] y = self.layer_norm(x) - x = x + self.mlp(y) + x = x + self.mlp(y, dropout_rate) return x[:, 0] @@ -293,15 +280,12 @@ def __init__( mlp_dim: Optional[int] = None, # Defaults to 4x input dim. num_heads: int = 12, rep_size: Union[int, bool] = True, - dropout_rate: Optional[float] = 0.0, head_zeroinit: bool = True, use_glu: bool = False, use_post_layer_norm: bool = False, use_map: bool = False, dtype: Any = torch.float32) -> None: super().__init__() - if dropout_rate is None: - dropout_rate = 0.0 self.num_classes = num_classes self.patch_size = patch_size @@ -315,7 +299,6 @@ def __init__( self.use_post_layer_norm = use_post_layer_norm self.use_map = use_map self.dtype = dtype - self.dropout_rate = dropout_rate if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size @@ -334,8 +317,7 @@ def __init__( mlp_dim=self.mlp_dim, num_heads=self.num_heads, use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate) + use_post_layer_norm=self.use_post_layer_norm) if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) @@ -363,9 +345,7 @@ def reset_parameters(self) -> None: def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward(self, x: spec.Tensor, dropout_rate=DEFAULT_DROPOUT_RATE) -> spec.Tensor: # Patch extraction. x = self.conv_patch_extract(x) @@ -379,10 +359,10 @@ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor: x = x + pes x = F.dropout(x, dropout_rate, training=self.training) - x = self.encoder(x) + x = self.encoder(x, dropout_rate) if self.use_map: - x = self.map(x) + x = self.map(x, dropout_rate) else: x = torch.mean(x, dim=1) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py index c5014d87d..6e265cd7f 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -861,7 +861,7 @@ def forward(self, max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = None) -> Any: # TODO: (nico) remove default?! r""" Args: x: Batch of input sequences of shape diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py index 56644f152..32db2e7d4 100644 --- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py @@ -32,41 +32,41 @@ def fwd_pass(self, orig, cust, dropout_rate): for mode in ('train', 'eval'): getattr(orig, mode)() getattr(cust, mode)() - torch.manual_seed(0); y1 = orig(x) - torch.manual_seed(0); y2 = cust(x, dropout_rate) + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + if mode == 'train': + y2 = cust(x, dropout_rate) + else: + y2 = cust(x) assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( - dict(testcase_name='p=None', dropout_rate=None), dict(testcase_name='p=0.0', dropout_rate=0.0), dict(testcase_name='p=0.1', dropout_rate=0.1), dict(testcase_name='p=0.6', dropout_rate=0.6), dict(testcase_name='p=1.0', dropout_rate=1.0), ) def test_dropout_values(self, dropout_rate): - """Test different dropout_values.""" - - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [dropout_rate, None]: - - torch.manual_seed(SEED) - orig = OriginalVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - dropout_rate=dropout_rate, - ).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - dropout_rate=custom_init_dropout_rate, - ).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - self.fwd_pass(orig, cust, dropout_rate) + """Test different dropout_rates.""" + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) @parameterized.named_parameters([ @@ -101,12 +101,30 @@ def test_arch(self, use_glu, use_post_ln, use_map): use_glu=use_glu, use_post_layer_norm=use_post_ln, use_map=use_map, - dropout_rate=None, ).to(DEVICE) cust.load_state_dict(orig.state_dict()) # sync weights self.fwd_pass(orig, cust, dropout_rate) + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + if __name__ == '__main__': absltest.main() From 44029d2ef1b8ac81556e90e40b119b642c2b4e41 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:44:16 +0200 Subject: [PATCH 047/123] vit rm dropout at init --- algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 1a6bb1381..20bd3828b 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -23,11 +23,9 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = models.ViT( - dropout_rate=dropout_rate, num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, From 44ffec1d6821228eb839a0069b21daeaf32cb08c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:47:52 +0200 Subject: [PATCH 048/123] add default dropout test --- .../test_model_equivalence.py | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index 32aa5b34f..c59331a3d 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -23,7 +23,6 @@ FEATURES = DENSE + SPARSE VOCAB = 1000 DEVICE = 'cuda' -TORCH_COMPILE = False SEED = 1996 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" @@ -49,16 +48,11 @@ def test_forward(self, model, dropout_rate): ) torch.manual_seed(SEED) - orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) - orig.to(DEVICE) + orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE) torch.manual_seed(SEED) - cust = CustCls(vocab_size=VOCAB) - cust.to(DEVICE) + cust = CustCls(vocab_size=VOCAB).to(DEVICE) - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) - x = torch.randn(BATCH, FEATURES, device=DEVICE) for mode in ('train', 'eval'): @@ -75,6 +69,29 @@ def test_forward(self, model, dropout_rate): assert_close(y1, y2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), + dict(testcase_name='DlrmSmall, default', model='dlrm_small'), + ) + def test_default_dropout(self, model): + """Test default dropout_rate.""" + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) + ) + + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB).to(DEVICE) + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB).to(DEVICE) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) if __name__ == '__main__': absltest.main() From 9d12fa65b1963ab1b77d1cd8787ddad380bf803a Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 11:49:55 +0200 Subject: [PATCH 049/123] add default dropout test --- .../fastmri_pytorch/test_model_equivalence.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py index c71ff8980..6c8ca896c 100644 --- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -15,7 +15,7 @@ BATCH, IN_CHANS, H, W = 4, 1, 256, 256 OUT_CHANS, C, LAYERS = 1, 32, 4 DEVICE = 'cuda' -TORCH_COMPILE = True +TORCH_COMPILE = False SEED = 1996 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" @@ -89,6 +89,24 @@ def test_arch_configs(self, use_tanh, use_layer_norm): self.fwd_pass(orig, cust, dropout_rate) + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) if __name__ == '__main__': absltest.main() From ac45a9fc07aab0b3af28d06ca7540acb06bcc561 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 12:36:14 +0200 Subject: [PATCH 050/123] conformer: rm dropout_rate from init --- .../librispeech_pytorch/models_dropout.py | 41 ++----- .../librispeech_pytorch/workload.py | 7 +- .../test_model_equivalence.py | 106 +++++++++++------- 3 files changed, 74 insertions(+), 80 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py index 9ff662fb8..f77c8a814 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py @@ -17,6 +17,11 @@ from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug +DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE = 0.1 +DEFAULT_CONV_RESIDUAL_DROPOUT_RATE = 0.0 +DEFAULT_FFN_RESIDUAL_DROPOUT_RATE = 0.1 +DEFAULT_INPUT_DROPOUT_RATE = 0.1 + @dataclass class ConformerConfig: @@ -26,13 +31,7 @@ class ConformerConfig: num_attention_heads: int = 8 num_encoder_layers: int = 4 attention_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - attention_residual_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.0. - conv_residual_dropout_rate: Optional[float] = 0.0 feed_forward_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - feed_forward_residual_dropout_rate: Optional[float] = 0.1 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -42,8 +41,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -81,11 +78,9 @@ class Subsample(nn.Module): def __init__(self, encoder_dim: int = 0, - input_dropout_rate: float = 0.0, num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim - self.input_dropout_rate = input_dropout_rate self.conv1 = Conv2dSubsampling( input_channels=1, output_channels=encoder_dim) @@ -100,7 +95,7 @@ def __init__(self, def forward(self, inputs, input_paddings, dropout_rate=None): if dropout_rate is None: - dropout_rate = self.input_dropout_rate + dropout_rate = DEFAULT_INPUT_DROPOUT_RATE output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -207,14 +202,9 @@ def __init__(self, config: ConformerConfig): out_features=config.encoder_dim, bias=True) - if config.feed_forward_residual_dropout_rate is None: - self.feed_forward_residual_dropout_rate = 0.1 - else: - self.feed_forward_residual_dropout_rate = config.feed_forward_residual_dropout_rate - def forward(self, inputs, padding_mask, dropout_rate=None): if dropout_rate is None: - dropout_rate = self.feed_forward_residual_dropout_rate + dropout_rate = DEFAULT_FFN_RESIDUAL_DROPOUT_RATE inputs = self.ln(inputs) inputs = self.linear1(inputs) @@ -319,14 +309,10 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(dim=config.encoder_dim) self.self_attention = MHSAwithQS(config) - if config.attention_residual_dropout_rate is None: - self.attention_residual_dropout_rate = 0.1 - else: - self.attention_residual_dropout_rate = config.attention_residual_dropout_rate def forward(self, outputs, paddings, dropout_rate=None): if dropout_rate is None: - dropout_rate = self.attention_residual_dropout_rate + dropout_rate = DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE outputs = self.ln(outputs) outputs = self.self_attention( @@ -413,14 +399,10 @@ def __init__(self, config): groups=config.encoder_dim) self.bn = BatchNorm(config) self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - if config.conv_residual_dropout_rate is None: - self.conv_residual_dropout_rate = 0.0 - else: - self.conv_residual_dropout_rate = config.conv_residual_dropout_rate def forward(self, inputs, input_paddings, dropout_rate=None): if dropout_rate is None: - dropout_rate = self.conv_residual_dropout_rate + dropout_rate = DEFAULT_CONV_RESIDUAL_DROPOUT_RATE inputs = self.ln(inputs) @@ -490,13 +472,8 @@ def __init__(self, config: ConformerConfig): time_masks_per_frame=config.time_masks_per_frame, use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames ) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate self.subsample = Subsample( encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate, num_bins=preprocessing_config.num_bins) self.conformers = nn.ModuleList( [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 39f33f4aa..2d0942fe9 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -63,8 +63,7 @@ def attention_temperature(self) -> float: def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: """Conformer model init function.""" torch.random.manual_seed(rng[0]) # Configure torch backends to avoid OOM errors. @@ -78,10 +77,6 @@ def init_model_fn( activation_function_name = 'swish' model = models.ConformerEncoderDecoder( models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - conv_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, use_specaug=self.use_specaug, attention_temperature=self.attention_temperature, use_post_layer_norm=self.use_post_layer_norm, diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py index 19525a98b..ec8318c9a 100644 --- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -16,14 +16,15 @@ import os from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( - # ConformerConfig, - ConformerEncoderDecoder as OriginalConf + ConformerConfig as OriginalConfig, + ConformerEncoderDecoder as OriginalModel ) from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import( - ConformerEncoderDecoder as CustomConf, - ConformerConfig, + ConformerConfig as CustomConfig, + ConformerEncoderDecoder as CustomModel, ) +N_LAYERS = 3 B, T = 32, 36_000 DEVICE = 'cuda' @@ -37,55 +38,76 @@ class ConformerEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='p=None', dropout_rate=None), dict(testcase_name='p=0.0', dropout_rate=0.0), dict(testcase_name='p=0.2', dropout_rate=0.2), dict(testcase_name='p=0.7', dropout_rate=0.7), dict(testcase_name='p=1.0', dropout_rate=1.0), ) def test_forward(self, dropout_rate): - - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [None, dropout_rate]: + + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( + num_encoder_layers=N_LAYERS, + attention_residual_dropout_rate=dropout_rate, + conv_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel( + CustomConfig( + num_encoder_layers=N_LAYERS + ) + ).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() torch.manual_seed(SEED) - orig = OriginalConf( - ConformerConfig( - num_encoder_layers=3, - attention_residual_dropout_rate=dropout_rate, - conv_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, - )).to(DEVICE) - + y1, p1 = orig(x, paddings) torch.manual_seed(SEED) - cust = CustomConf( - ConformerConfig( - num_encoder_layers=3, - attention_residual_dropout_rate=custom_init_dropout_rate, - conv_residual_dropout_rate=custom_init_dropout_rate, - feed_forward_residual_dropout_rate=custom_init_dropout_rate, - input_dropout_rate=custom_init_dropout_rate - )).to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - - torch.manual_seed(SEED) + if mode == 'train': y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) + else: + y2, p2 = cust(x, paddings) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalModel(OriginalConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) + orig.load_state_dict(cust.state_dict()) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) if __name__ == '__main__': absltest.main() From 31d64f6c3228eb410065b212994d3f02883d19ee Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Jun 2025 13:36:15 +0200 Subject: [PATCH 051/123] rm dropout_rate at init from all workloads --- .../fastmri/fastmri_pytorch/models_dropout.py | 5 +- .../imagenet_pytorch/models_dropout.py | 5 +- .../librispeech_pytorch/models_dropout.py | 24 +---- .../librispeech_pytorch/workload.py | 5 +- .../ogbg/ogbg_pytorch/models_dropout.py | 15 +-- .../workloads/ogbg/ogbg_pytorch/workload.py | 4 +- .../wmt/wmt_pytorch/models_dropout.py | 44 ++++---- .../workloads/wmt/wmt_pytorch/workload.py | 5 +- .../test_model_equivalence.py | 17 ++- .../fastmri_pytorch/test_model_equivalence.py | 15 ++- .../test_model_equivalence.py | 15 ++- .../test_model_equivalence.py | 12 ++- .../test_model_equivalence.py | 100 +++++++++++------- .../ogbg_pytorch/test_model_equivalence.py | 55 +++++++--- .../wmt_pytorch/test_model_equivalence.py | 97 +++++++++++------ 15 files changed, 234 insertions(+), 184 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 73b1d81d9..0e59e1436 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -74,7 +74,10 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor, dropout_rate: Optional[float] = DEFAULT_DROPOUT_RATE) -> Tensor: + def forward( + self, + x: Tensor, + dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor: stack = [] output = x diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py index 8641847b0..570cee575 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py @@ -345,7 +345,10 @@ def reset_parameters(self) -> None: def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward(self, x: spec.Tensor, dropout_rate=DEFAULT_DROPOUT_RATE) -> spec.Tensor: + def forward( + self, + x: spec.Tensor, + dropout_rate: float = DEFAULT_DROPOUT_RATE) -> spec.Tensor: # Patch extraction. x = self.conv_patch_extract(x) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py index 8797aa578..21a4df614 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py @@ -17,6 +17,7 @@ SpecAug USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +DEFAULT_DROPOUT_RATE = 0.1 @dataclass @@ -38,10 +39,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. - feed_forward_dropout_rate: Optional[float] = 0.1 enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True @@ -87,14 +84,7 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) - if config.input_dropout_rate is None: - self.input_dropout_rate = 0.1 - else: - self.input_dropout_rate = config.input_dropout_rate - - def forward(self, inputs, input_paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.input_dropout_rate + def forward(self, inputs, input_paddings, dropout_rate): output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -207,14 +197,8 @@ def __init__(self, config: DeepspeechConfig): batch_norm_momentum=config.batch_norm_momentum, batch_norm_epsilon=config.batch_norm_epsilon) self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) - if config.feed_forward_dropout_rate is None: - self.feed_forward_dropout_rate = 0.1 - else: - self.feed_forward_dropout_rate = config.feed_forward_dropout_rate - def forward(self, inputs, input_paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.feed_forward_dropout_rate + def forward(self, inputs, input_paddings, dropout_rate): padding_mask = (1 - input_paddings)[:, :, None] if self.config.layernorm_everywhere: @@ -367,7 +351,7 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings, dropout_rate=None): + def forward(self, inputs, input_paddings, dropout_rate=DEFAULT_DROPOUT_RATE): outputs = inputs output_paddings = input_paddings diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 932ba9392..e6ec4764f 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -24,15 +24,12 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: """Deepspeech model init function.""" torch.random.manual_seed(rng[0]) model = DeepspeechEncoderDecoder( DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, use_specaug=self.use_specaug, - input_dropout_rate=dropout_rate, use_tanh=self.use_tanh, enable_residual_connections=self.enable_residual_connections, enable_decoder_layer_norm=self.enable_decoder_layer_norm, diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py index b86b88caa..c8ed23dda 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py @@ -11,6 +11,8 @@ from algoperf import init_utils from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout +DEFAULT_DROPOUT_RATE = 0.1 + def _make_mlp(in_dim, hidden_dims, activation_fn): """Creates a MLP with specified dimensions.""" @@ -34,7 +36,6 @@ class GNN(nn.Module): def __init__(self, num_outputs: int = 128, - dropout_rate: Optional[float] = 0.1, activation_fn_name: str = 'relu', latent_dim: int = 256, hidden_dims: Tuple[int] = (256,), @@ -44,8 +45,6 @@ def __init__(self, self.hidden_dims = hidden_dims self.num_message_passing_steps = num_message_passing_steps self.num_outputs = num_outputs - if dropout_rate is None: - self.dropout_rate = 0.1 # in_features are specifically chosen for the ogbg workload. self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) @@ -94,9 +93,10 @@ def __init__(self, if isinstance(m, nn.Linear): init_utils.pytorch_default_init(m) - def forward(self, graph: GraphsTuple, dropout_rate=None) -> torch.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + def forward( + self, + graph: GraphsTuple, + dropout_rate: float = DEFAULT_DROPOUT_RATE) -> torch.Tensor: graph = graph._replace( globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], @@ -148,7 +148,7 @@ def __init__(self, self.update_global_fn = update_global_fn self._supports_custom_dropout = True # supports SequentialWithDropout - def forward(self, graph: GraphsTuple, dropout_rate=None) -> GraphsTuple: + def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: """Applies a configured GraphNetwork to a graph. This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 There is one difference. For the nodes update the class aggregates over the @@ -161,6 +161,7 @@ def forward(self, graph: GraphsTuple, dropout_rate=None) -> GraphsTuple: GraphNets, for more information please see the paper. Args: graph: a `GraphsTuple` containing the graph. + dropout_rate: dropout probability value. Returns: Updated `GraphsTuple`. """ diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 1dd85951d..7ead696ce 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -138,12 +138,10 @@ def _build_input_queue(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = GNN( num_outputs=self._num_outputs, - dropout_rate=dropout_rate, hidden_dims=self.hidden_dims, latent_dim=self.latent_dim, num_message_passing_steps=self.num_message_passing_steps, diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py index 6e265cd7f..a5d822669 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -9,6 +9,8 @@ from torch.nn.init import normal_ from torch.nn.init import xavier_uniform_ +DEFAULT_DROPOUT_RATE = 0.1 + def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: """Make a causal mask for self-attention. @@ -104,17 +106,12 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: Optional[float] = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, attention_temp: float = 1.0, pre_ln: bool = True): super().__init__() - if dropout_rate is None: - self.dropout_rate = 0.1 - else: - self.dropout_rate = dropout_rate self.pos_encoder = PositionalEncoding(d_model) self.shared_embedding = nn.Embedding(ntoken, d_model) self.encoder = Encoder(d_model, @@ -159,7 +156,7 @@ def forward(self, inputs_segmentation: Optional[Tensor] = None, targets_segmentation: Optional[Tensor] = None, decode: bool = False, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] @@ -169,7 +166,7 @@ def forward(self, inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] decode: bool - dropout_rate: Optional[float] + dropout_rate: float Returns: output Tensor of shape [batch_size, seq_len, ntoken] @@ -177,9 +174,6 @@ def forward(self, if src.size(0) != tgt.size(0): raise RuntimeError('The batch size of src and tgt must be equal.') - if dropout_rate is None: - dropout_rate = self.dropout_rate - memory = self.encoder( src, inputs_positions=inputs_positions, @@ -234,13 +228,13 @@ def __init__(self, def forward(self, src: Tensor, mask: Optional[Tensor] = None, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: """Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). - dropout_rate: the dropout probability (optional) + dropout_rate: the dropout probability (optional). Shape: see the docs in Transformer class. @@ -293,7 +287,7 @@ def forward(self, src: Tensor, inputs_positions: Optional[Tensor] = None, inputs_segmentation: Optional[Tensor] = None, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: src = src.to(torch.int) src_mask = make_src_mask(src, inputs_segmentation, self.nhead) src = self.shared_embedding(src) @@ -339,7 +333,7 @@ def forward( decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = 0.0) -> Any: tgt = tgt.to(torch.int) tgt_mask, memory_mask = make_tgt_and_memory_mask( tgt, src, inputs_segmentation, targets_segmentation, @@ -389,7 +383,7 @@ def forward( inputs_positions: Optional[Tensor] = None, decode: bool = False, cache: Optional[Dict[str, Dict[str, Tensor]]] = None, - dropout_rate: Optional[float] = None + dropout_rate: Optional[float] = 0.0 ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: """ Args: @@ -397,6 +391,7 @@ def forward( inputs_positions: Tensor (shape [batch_size, seq_len]) or None decode: bool cache: Dict[str, Dict[str, Tensor]] or None + dropout_rate: Optional[float] Returns: Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]] """ @@ -438,7 +433,6 @@ class TransformerEncoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -490,7 +484,7 @@ def __init__(self, def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: r"""Pass the input through the encoder layer. Args: @@ -514,14 +508,14 @@ def forward(self, def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: def _ff_block(self, inputs: Tensor, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) @@ -585,7 +579,7 @@ def forward(self, decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: tgt: the sequence to the decoder (required). @@ -705,7 +699,7 @@ def forward( # pylint: disable=arguments-renamed max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -757,7 +751,7 @@ def _sa_block( # pylint: disable=arguments-renamed max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None, - dropout_rate: Optional[float] = None) -> Any: + dropout_rate: Optional[float] = 0.0) -> Any: x, cache = self.self_attn( x, attn_mask=attn_mask, @@ -771,7 +765,7 @@ def _sa_block( # pylint: disable=arguments-renamed # Multihead attention block: def _mha_block(self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor], - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: x, _ = self.multihead_attn( x, mem, @@ -781,7 +775,7 @@ def _mha_block(self, x: Tensor, mem: Tensor, # Feed forward block. def _ff_block(self, inputs: Tensor, - dropout_rate: Optional[float] = None) -> Tensor: + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) @@ -861,7 +855,7 @@ def forward(self, max_len: Optional[int] = None, cache: Optional[dict] = None, index: Optional[int] = None, - dropout_rate: Optional[float] = None) -> Any: # TODO: (nico) remove default?! + dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?! r""" Args: x: Batch of input sequences of shape diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index 64eea73b7..bb9c3834f 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -167,8 +167,7 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.activation == 'relu': @@ -179,8 +178,6 @@ def init_model_fn( raise ValueError(f'Unknown activation function {self.activation}.') model = Transformer( - dropout_rate=dropout_rate, - attention_dropout_rate=dropout_rate, pre_ln=self.pre_ln, attention_temp=self.attention_temp, activation=activation, diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index c59331a3d..db56b17cf 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -56,18 +56,13 @@ def test_forward(self, model, dropout_rate): x = torch.randn(BATCH, FEATURES, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1 = orig(x) - torch.manual_seed(SEED) - if mode == 'train': - y2 = cust(x, dropout_rate) - else: - y2 = cust(x) - + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1 = orig(x) + torch.manual_seed(SEED); y2 = cust(x, dropout_rate) assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py index 6c8ca896c..0d3d52980 100644 --- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -28,16 +28,13 @@ class FastMRIModeEquivalenceTest(parameterized.TestCase): def fwd_pass(self, orig, cust, dropout_rate): x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - if mode == 'train': - y2 = cust(x, dropout_rate) - else: - y2 = cust(x) + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x, dropout_rate) assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0); y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( dict(testcase_name='p=0.0', dropout_rate=0.0), diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py index 32db2e7d4..d19fad0ba 100644 --- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py @@ -30,16 +30,13 @@ class ImageNetVitModeEquivalenceTest(parameterized.TestCase): def fwd_pass(self, orig, cust, dropout_rate): x = torch.randn(BATCH, C, H, W, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - if mode == 'train': - y2 = cust(x, dropout_rate) - else: - y2 = cust(x) + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(0); y1 = orig(x) + torch.manual_seed(0); y2 = cust(x, dropout_rate) assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0); y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) @parameterized.named_parameters( dict(testcase_name='p=0.0', dropout_rate=0.0), diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py index ec8318c9a..a4238bbc9 100644 --- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -74,14 +74,16 @@ def test_forward(self, dropout_rate): torch.manual_seed(SEED) y1, p1 = orig(x, paddings) torch.manual_seed(SEED) - if mode == 'train': - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - else: - y2, p2 = cust(x, paddings) - + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) assert_close(p1, p2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) @parameterized.named_parameters( dict(testcase_name=''), diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py index e31f4a7eb..acdc8c5b3 100644 --- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py @@ -14,11 +14,12 @@ import os from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( - DeepspeechEncoderDecoder as OriginalModel + DeepspeechEncoderDecoder as OriginalModel, + DeepspeechConfig as OriginalConfig ) from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import( DeepspeechEncoderDecoder as CustomModel, - DeepspeechConfig, + DeepspeechConfig as CustomConfig ) B, T = 32, 30_000 @@ -35,55 +36,82 @@ class DeepSpeechEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='p=None', dropout_rate=None), dict(testcase_name='p=0.0', dropout_rate=0.0), dict(testcase_name='p=0.2', dropout_rate=0.2), dict(testcase_name='p=0.7', dropout_rate=0.7), dict(testcase_name='p=1.0', dropout_rate=1.0), ) def test_forward(self, dropout_rate): - - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [None, dropout_rate]: + """Test different dropout_rate values.""" + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( + num_lstm_layers=2, + num_ffn_layers=2, + input_dropout_rate=dropout_rate, + feed_forward_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig( + num_lstm_layers=2, + num_ffn_layers=2, + )).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(SEED) - orig = OriginalModel( - DeepspeechConfig( - num_lstm_layers=2, - num_ffn_layers=2, - input_dropout_rate=dropout_rate, - feed_forward_dropout_rate=dropout_rate, - )).to(DEVICE) + y1, p1 = orig(x, paddings) torch.manual_seed(SEED) - cust = CustomModel(DeepspeechConfig( - num_lstm_layers=2, - num_ffn_layers=2, - input_dropout_rate=custom_init_dropout_rate, - feed_forward_dropout_rate=custom_init_dropout_rate, - )).to(DEVICE) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - orig.load_state_dict(cust.state_dict()) # sync weights - - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval torch.manual_seed(SEED) - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - + y2, p2 = cust(x, paddings) assert_close(y1, y2, atol=0, rtol=0) assert_close(p1, p2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalModel(OriginalConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig); cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)(); getattr(cust, mode)() + torch.manual_seed(SEED); y1, p1 = orig(x, paddings) + torch.manual_seed(SEED); y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if __name__ == '__main__': absltest.main() diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py index cc1857705..aaca6cebd 100644 --- a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py @@ -42,35 +42,62 @@ def _rand_graph(): class GNNEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='None', dropout_rate=None), dict(testcase_name='0.0', dropout_rate=0.0), dict(testcase_name='0.2', dropout_rate=0.2), dict(testcase_name='0.7', dropout_rate=0.7), dict(testcase_name='1.0', dropout_rate=1.0), ) def test_forward(self, dropout_rate): + """Test different dropout_rates.""" - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [None, dropout_rate]: + orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights - orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) - cust = CustomModel(dropout_rate=custom_init_dropout_rate).to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights + graph = _rand_graph() - graph = _rand_graph() + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(graph) - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y1 = orig(graph) + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph, dropout_rate=dropout_rate) - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(graph, dropout_rate=dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph) assert_close(y1, y2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name=''), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + graph = _rand_graph() + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(graph) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(graph) + + assert_close(y1, y2, atol=0, rtol=0) + + if __name__ == '__main__': absltest.main() diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py index 9aca717d9..9675f1df2 100644 --- a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py @@ -32,52 +32,79 @@ def _rand_tokens(bs, seqlen): class TransformerEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - # NOTE: removed dropout=1.0 will generate nan in scaled_dot_product_attention - - dict(testcase_name="None", dropout_rate=None, compile=False), + # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention dict(testcase_name="0.0", dropout_rate=0.0, compile=False), dict(testcase_name="0.2", dropout_rate=0.2, compile=False), dict(testcase_name="0.7", dropout_rate=0.7, compile=False), - - dict(testcase_name="p=None, compile", dropout_rate=None, compile=True), - dict(testcase_name="p=0.0, compile", dropout_rate=0.0, compile=True), - dict(testcase_name="p=0.2, compile", dropout_rate=0.2, compile=True), - dict(testcase_name="p=0.7, compile", dropout_rate=0.7, compile=True), + dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True), + dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True), + dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True), ) - def test_forward(self, dropout_rate, compile): - - # Test initalizing custom model with a None dropout_rate - for custom_init_dropout_rate in [None, dropout_rate]: - - orig = OriginalModel( - dropout_rate=dropout_rate, - attention_dropout_rate=dropout_rate - ).to(DEVICE) - cust = CustomModel( - dropout_rate=custom_init_dropout_rate - ).to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if compile: - orig = torch.compile(orig) - cust = torch.compile(cust) + def test_dropout_value(self, dropout_rate, compile): - src = _rand_tokens(B, SRC_LEN) - tgt = _rand_tokens(B, TGT_LEN) + orig = OriginalModel( + dropout_rate=dropout_rate, + attention_dropout_rate=dropout_rate + ).to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) - for mode in ("train", "eval"): - getattr(orig, mode)() - getattr(cust, mode)() + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y1 = orig(src, tgt) + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(src, tgt) + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(src, tgt, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(src, tgt, dropout_rate=dropout_rate) - + y2 = cust(src, tgt) assert_close(y1, y2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name="default", compile=False), + dict(testcase_name="default_compile", compile=True), + ) + def test_default(self, compile): + + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) + y2 = cust(src, tgt) + + assert_close(y1, y2, atol=0, rtol=0) + + if __name__ == "__main__": absltest.main() From 5e192dd6397194d1435631752a1c29289b9a9888 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 21:02:11 +0000 Subject: [PATCH 052/123] remove dropout_rate from init_model_fn for all jax workloads --- .../workloads/cifar/cifar_jax/workload.py | 6 +-- .../criteo1tb/criteo1tb_jax/workload.py | 34 ++++++----------- .../workloads/fastmri/fastmri_jax/workload.py | 27 +++++-------- .../imagenet_resnet/imagenet_jax/workload.py | 7 +--- .../imagenet_vit/imagenet_jax/workload.py | 28 +++++--------- .../librispeech_jax/workload.py | 28 +++++--------- .../librispeech_jax/workload.py | 38 ++++++------------- .../workloads/mnist/mnist_jax/workload.py | 7 +--- algoperf/workloads/ogbg/ogbg_jax/workload.py | 29 +++++--------- algoperf/workloads/wmt/wmt_jax/workload.py | 26 ++++--------- 10 files changed, 71 insertions(+), 159 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ad43bc62f..3f2397f8c 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -81,12 +81,8 @@ def sync_batch_stats( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate model_cls = getattr(models, 'ResNet18') model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 101e02c15..dcb7b9a57 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -72,7 +72,6 @@ def loss_fn( def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, tabulate: Optional[bool] = False, ) -> spec.ModelInitState: """Only dropout is used.""" @@ -81,27 +80,16 @@ def init_model_fn( else: model_class = models.DlrmSmall - if dropout_rate is None: - self._model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) - else: - self._model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - dropout_rate=dropout_rate, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) - - params_rng, dropout_rng = jax.random.split(rng) + self._model = model_class( + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + embed_dim=self.embed_dim, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier) + + params_rng, _= jax.random.split(rng) init_fake_batch_size = 2 num_categorical_features = 26 num_dense_features = 13 @@ -109,7 +97,7 @@ def init_model_fn( input_shape = (init_fake_batch_size, input_size) init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( - {'params': params_rng, 'dropout': dropout_rng}, + {'params': params_rng,}, jnp.ones(input_shape, jnp.float32)) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 3d891cf8f..bf0acfc8d 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -21,28 +21,19 @@ class FastMRIWorkload(BaseFastMRIWorkload): def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, ) -> spec.ModelInitState: """aux_dropout_rate is unused.""" fake_batch = jnp.zeros((13, 320, 320)) - if dropout_rate is None: - self._model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, + self._model = UNet( + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, ) - else: - self._model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) - - params_rng, dropout_rng = jax.random.split(rng) - variables = jax.jit( - self._model.init)({'params': params_rng, 'dropout': dropout_rng}, + + params_rng, _ = jax.random.split(rng) + init_fn = functools.partial(self._model.init, train=False) + variables = jax.jit(init_fn)({'params': params_rng}, fake_batch) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 4ec3937b8..2a255fee4 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -84,12 +84,7 @@ def sync_batch_stats( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + rng: spec.RandomState,) -> spec.ModelInitState: model_cls = getattr(models, 'ResNet50') if self.use_silu and self.use_gelu: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 89355ac6e..b8a870de5 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -23,32 +23,22 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def initialized(self, key: spec.RandomState, model: nn.Module) -> spec.ModelInitState: input_shape = (1, 224, 224, 3) - params_rng, dropout_rng = jax.random.split(key) + params_rng, _ = jax.random.split(key) variables = jax.jit( - model.init)({'params': params_rng, 'dropout': dropout_rng}, + model.init)({'params': params_rng}, jnp.ones(input_shape)) model_state, params = pop(variables, "params") return params, model_state def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: - if dropout_rate is None: - self._model = models.ViT( - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) - else: - self._model = models.ViT( - dropout_rate=dropout_rate, - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) + rng: spec.RandomState) -> spec.ModelInitState: + self._model = models.ViT( + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 2e082cf07..042dba7f4 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -60,7 +60,6 @@ def attention_temperature(self) -> float: def init_model_fn( self, rng: spec.RandomState, - dropout_rate: Optional[float] = None, ) -> spec.ModelInitState: """Conformer model init function. @@ -71,21 +70,14 @@ def init_model_fn( activation_function_name = 'gelu' else: activation_function_name = 'swish' - if dropout_rate is None: - model_config = models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name) - else: - model_config = models.ConformerConfig( - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name) + model_config = models.ConformerConfig( + attention_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name) self._model = models.Conformer(model_config) input_shape = [(320000,), (320000,)] @@ -93,8 +85,8 @@ def init_model_fn( model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) - params_rng, dropout_rng = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, + params_rng, _ = jax.random.split(rng, 2) + variables = model_init_fn({'params': params_rng}, *fake_input_batch) model_state, params = pop(variables, "params") diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 825b470db..2213f189e 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -17,40 +17,26 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: """Deepspeech model init function. """ - if dropout_rate is None: - model_config = models.DeepspeechConfig( - use_specaug=self.use_specaug, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count, - ) - else: - model_config = models.DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, - use_specaug=self.use_specaug, - input_dropout_rate=dropout_rate, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count, - ) + model_config = models.DeepspeechConfig( + use_specaug=self.use_specaug, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, + ) self._model = models.Deepspeech(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) - params_rng, dropout_rng = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, + params_rng, _ = jax.random.split(rng, 2) + variables = model_init_fn({'params': params_rng,}, *fake_input_batch) model_state = variables[ diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5a4382da1..5f3fdcf78 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -34,12 +34,7 @@ class MnistWorkload(BaseMnistWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + rng: spec.RandomState) -> spec.ModelInitState: init_val = jnp.ones((1, 28, 28, 1), jnp.float32) self._model = _Model() initial_params = self._model.init({'params': rng}, init_val, diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index 3becd5599..aaa5b4064 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -19,25 +19,14 @@ class OgbgWorkload(BaseOgbgWorkload): def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is unused.""" - rng, params_rng, dropout_rng = jax.random.split(rng, 3) - if dropout_rate is None: - self._model = models.GNN( - self._num_outputs, - activation_fn_name=self.activation_fn_name, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) - else: - self._model = models.GNN( - self._num_outputs, - dropout_rate=dropout_rate, - activation_fn_name=self.activation_fn_name, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) + rng: spec.RandomState) -> spec.ModelInitState: + rng, params_rng = jax.random.split(rng, 2) + self._model = models.GNN( + self._num_outputs, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) init_fn = jax.jit(functools.partial(self._model.init, train=False)) fake_batch = jraph.GraphsTuple( n_node=jnp.asarray([1]), @@ -47,7 +36,7 @@ def init_model_fn( globals=jnp.zeros((1, self._num_outputs)), senders=jnp.asarray([0]), receivers=jnp.asarray([0])) - params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch) + params = init_fn({'params': params_rng}, fake_batch) params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 193732640..9e109dc86 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -208,8 +208,7 @@ def translate_and_calculate_bleu(self, def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState: + rng: spec.RandomState) -> spec.ModelInitState: init_fake_batch_size = 2 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -221,26 +220,17 @@ def init_model_fn( else: raise ValueError(f'Unknown activation function {self.activation}.') - if dropout_rate is None: - model_config = models.TransformerConfig( - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) - else: - model_config = models.TransformerConfig( - dropout_rate=dropout_rate, - attention_dropout_rate=dropout_rate, - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) + model_config = models.TransformerConfig( + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu) self._train_model = models.Transformer(model_config) eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) - params_rng, dropout_rng = jax.random.split(rng) + params_rng, _ = jax.random.split(rng) initial_variables = jax.jit( - self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, + self._eval_model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) From 23828cdb0d54207a6714263b4d9f44531011f375 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 21:03:11 +0000 Subject: [PATCH 053/123] remove dropout from model initialization call in submission_runner.py --- submission_runner.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index bb4a8c6cc..d076a1043 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -228,11 +228,8 @@ def train_once( global_batch_size=global_batch_size) logging.info('Initializing model.') with profiler.profile('Initializing model'): - dropout_rate = None - if hasattr(hyperparameters, 'dropout_rate'): - dropout_rate = hyperparameters.dropout_rate model_params, model_state = workload.init_model_fn( - model_init_rng, dropout_rate) + model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ 'librispeech_conformer', From 86b86245a3754d59b8e707eecacd3be0477419d7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 12 Jun 2025 21:28:35 +0000 Subject: [PATCH 054/123] remove dropout check for None and use default instead if not passed --- .../criteo1tb/criteo1tb_jax/models.py | 11 +++---- .../workloads/fastmri/fastmri_jax/models.py | 12 +++----- .../imagenet_vit/imagenet_jax/models.py | 26 +++++------------ .../librispeech_jax/models.py | 27 +++++------------ .../librispeech_jax/models.py | 21 ++++---------- algoperf/workloads/ogbg/ogbg_jax/models.py | 7 ++--- algoperf/workloads/wmt/wmt_jax/models.py | 29 +++++++------------ 7 files changed, 41 insertions(+), 92 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index b7af15208..57cb7f2d9 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -7,6 +7,7 @@ from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.0 class DLRMResNet(nn.Module): """Define a DLRMResNet model. @@ -24,14 +25,12 @@ class DLRMResNet(nn.Module): mlp_bottom_dims: Sequence[int] = (256, 256, 256) mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) embed_dim: int = 128 - dropout_rate: float = 0.0 + dropout_rate: float = DROPOUT_RATE use_layer_norm: bool = False # Unused. embedding_init_multiplier: float = None # Unused @nn.compact - def __call__(self, x, train, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def __call__(self, x, train, dropout_rate=DROPOUT_RATE): bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -155,9 +154,7 @@ class DlrmSmall(nn.Module): embedding_init_multiplier: float = None @nn.compact - def __call__(self, x, train, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def __call__(self, x, train, dropout_rate=DROPOUT_RATE): bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index b04510297..5850defa7 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -21,6 +21,7 @@ from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.0 def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation @@ -58,15 +59,12 @@ class UNet(nn.Module): num_channels: int = 32 num_pool_layers: int = 4 out_channels = 1 - dropout_rate: float = 0.0 + dropout_rate: float = DROPOUT_RATE use_tanh: bool = False use_layer_norm: bool = False @nn.compact - def __call__(self, x, train=True, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate - + def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): # pylint: disable=invalid-name _ConvBlock = functools.partial( ConvBlock, @@ -144,7 +142,7 @@ class ConvBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x, train=True, dropout_rate=None): + def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): """Forward function. Note: Pytorch is NCHW and jax/flax is NHWC. Args: @@ -153,8 +151,6 @@ def __call__(self, x, train=True, dropout_rate=None): Returns: jnp.array: Output tensor of shape `(N, H, W, out_channels)`. """ - if dropout_rate is None: - dropout_rate = self.dropout_rate x = nn.Conv( features=self.out_channels, kernel_size=(3, 3), diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 8ffc0b610..7c5d7bd26 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -13,6 +13,7 @@ from algoperf import spec from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.0 def posemb_sincos_2d(h: int, w: int, @@ -36,17 +37,14 @@ class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. use_glu: bool = False - dropout_rate: float = 0.0 + dropout_rate: float = DROPOUT_RATE @nn.compact def __call__(self, x: spec.Tensor, train: bool = True, - dropout_rate=None) -> spec.Tensor: + dropout_rate=DROPOUT_RATE) -> spec.Tensor: """Applies Transformer MlpBlock module.""" - if dropout_rate is None: - dropout_rate = self.dropout_rate - inits = { 'kernel_init': nn.initializers.xavier_uniform(), 'bias_init': nn.initializers.normal(stddev=1e-6), @@ -78,8 +76,6 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) @@ -136,11 +132,7 @@ class Encoder(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - train: bool = True, - dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate - + train: bool = True) -> spec.Tensor: # Input Encoder for lyr in range(self.depth): block = Encoder1DBlock( @@ -165,9 +157,7 @@ class MAPHead(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, x, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def __call__(self, x, dropout_rate=DROPOUT_RATE): n, _, d = x.shape probe = self.param('probe', nn.initializers.xavier_uniform(), (1, 1, d), @@ -194,7 +184,7 @@ class ViT(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 rep_size: Union[int, bool] = True - dropout_rate: Optional[float] = 0.0 + dropout_rate: [float] = DROPOUT_RATE reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True use_glu: bool = False @@ -212,9 +202,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, - dropout_rate=None) -> spec.Tensor: - if dropout_rate is None: - dropout_rate = self.dropout_rate + dropout_rate=DROPOUT_RATE) -> spec.Tensor: # Patch extraction x = nn.Conv( self.width, diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 2d0da15e5..f7beed914 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -28,6 +28,7 @@ spectrum_augmenter from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.1 @struct.dataclass class ConformerConfig: @@ -37,11 +38,7 @@ class ConformerConfig: encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 - dropout_rate: float = 0.1 - attention_residual_dropout_rate: Optional[float] = 0.0 - conv_residual_dropout_rate: Optional[float] = 0.0 - feed_forward_dropout_rate: float = 0.0 - feed_forward_residual_dropout_rate: Optional[float] = 0.0 + dropout_rate: float = DROPOUT_RATE convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -96,12 +93,8 @@ class Subsample(nn.Module): input_dropout_rate: dropout rate for inputs. """ encoder_dim: int = 0 - dropout_rate: float = 0.0 - @nn.compact - def __call__(self, inputs, input_paddings, train, dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.dropout_rate + def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): output_paddings = input_paddings outputs = jnp.expand_dims(inputs, axis=-1) @@ -196,7 +189,7 @@ class FeedForwardModule(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=None): + def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE): config = self.config if dropout_rate is None: dropout_rate = config.dropout_rate @@ -388,10 +381,8 @@ class MultiHeadedSelfAttention(nn.Module): config: ConformerConfig = None @nn.compact - def __call__(self, inputs, paddings, train, dropout_rate=None): + def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate mask_paddings = 1 - paddings attention_mask = nn.make_attention_mask( @@ -527,10 +518,8 @@ def __call__(self, train, update_batch_norm, use_running_average_bn, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate inputs = LayerNorm(dim=config.encoder_dim)(inputs) input_gated1 = nn.Dense( @@ -603,7 +592,7 @@ def __call__(self, train, update_batch_norm, use_running_average, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -658,7 +647,7 @@ def __call__(self, train, update_batch_norm: Optional[bool] = None, use_running_average_bn: Optional[bool] = None, - dropout_rate: Optional[float] = None): + dropout_rate: float = DROPOUT_RATE: config = self.config outputs = inputs diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 455366e5e..84ba58ee2 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -1,4 +1,4 @@ -r"""Deepspeech. +"""Deepspeech. This model uses a deepspeech2 network to convert speech to text. paper : https://arxiv.org/abs/1512.02595 @@ -31,6 +31,8 @@ CarryHistory = Any Output = Any +DROPOUT_RATE=0.1 + @struct.dataclass class DeepspeechConfig: @@ -52,10 +54,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True batch_norm_momentum: float = 0.999 batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. - feed_forward_dropout_rate: Optional[float] = 0.1 enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True @@ -73,11 +71,8 @@ class Subsample(nn.Module): config: DeepspeechConfig @nn.compact - def __call__(self, inputs, output_paddings, train, dropout_rate=None): + def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate - outputs = jnp.expand_dims(inputs, axis=-1) outputs, output_paddings = Conv2dSubsampling( @@ -196,9 +191,7 @@ def __call__(self, inputs, input_paddings=None, train=False, - dropout_rate=None): - if dropout_rate is None: - dropout_rate = self.config.feed_forward_dropout_rate + dropout_rate=DROPOUT_RATE): padding_mask = jnp.expand_dims(1 - input_paddings, -1) config = self.config @@ -479,10 +472,8 @@ def setup(self): ) @nn.compact - def __call__(self, inputs, input_paddings, train, dropout_rate=None): + def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate outputs = inputs output_paddings = input_paddings diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index f6cb1c490..59d989284 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -8,6 +8,7 @@ from algoperf.jax_utils import Dropout +DROPOUT_RATE=0.1 def _make_embed(latent_dim, name): @@ -41,15 +42,11 @@ class GNN(nn.Module): num_outputs: int latent_dim: int = 256 hidden_dims: Tuple[int] = (256,) - # If None, defaults to 0.1. - dropout_rate: Optional[float] = 0.1 num_message_passing_steps: int = 5 activation_fn_name: str = 'relu' @nn.compact - def __call__(self, graph, train, dropout_rate=None): - if dropout_rate is not None: - dropout_rate = self.dropout_rate + def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): dropout = Dropout(dropout_rate, deterministic=not train)(dropout_rate) graph = graph._replace( diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 3947a1b81..e262214ac 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -14,6 +14,8 @@ from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.1 + @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" @@ -28,7 +30,6 @@ class TransformerConfig: max_len: int = 256 activation: Callable = nn.relu glu: bool = False - dropout_rate: Optional[float] = 0.1 attention_temp: float = 1.0 deterministic: bool = False decode: bool = False @@ -148,11 +149,9 @@ class MlpBlock(nn.Module): out_dim: Optional[int] = None @nn.compact - def __call__(self, inputs, dropout_rate=None): + def __call__(self, inputs, dropout_rate=DROPOUT_RATE): """Applies Transformer MlpBlock module.""" cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( @@ -195,7 +194,7 @@ class Encoder1DBlock(nn.Module): config: TransformerConfig @nn.compact - def __call__(self, inputs, encoder_mask=None, dropout_rate=None): + def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE): """Applies Encoder1DBlock module. Args: @@ -206,8 +205,6 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None): output after transformer encoder block. """ cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate pre_ln = cfg.pre_ln @@ -254,7 +251,7 @@ def __call__( encoded, decoder_mask=None, encoder_decoder_mask=None, - dropout_rate=None, + dropout_rate=DROPOUT_RATE, ): """Applies EncoderDecoder1DBlock module. @@ -268,8 +265,6 @@ def __call__( output after transformer encoder-decoder block. """ cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate pre_ln = cfg.pre_ln @@ -337,7 +332,7 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): """Applies Transformer model on the inputs. Args: @@ -349,8 +344,6 @@ def __call__(self, output of a transformer encoder. """ cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate assert inputs.ndim == 2 # (batch, len) @@ -403,7 +396,7 @@ def __call__( targets_positions=None, decoder_mask=None, encoder_decoder_mask=None, - dropout_rate=None, + dropout_rate=DROPOUT_RATE, ): """Applies Transformer model on the inputs. @@ -418,8 +411,6 @@ def __call__( output of a transformer decoder. """ cfg = self.config - if dropout_rate is None: - dropout_rate = cfg.dropout_rate assert encoded.ndim == 3 # (batch, len, depth) assert targets.ndim == 2 # (batch, len) @@ -495,7 +486,7 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): """Applies Transformer encoder-branch on the inputs. Args: @@ -533,7 +524,7 @@ def decode( targets_positions=None, inputs_segmentation=None, targets_segmentation=None, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): """Applies Transformer decoder-branch on encoded-input and target. Args: @@ -593,7 +584,7 @@ def __call__(self, targets_positions=None, inputs_segmentation=None, targets_segmentation=None, - dropout_rate=None): + dropout_rate=DROPOUT_RATE): """Applies Transformer model on the inputs. Args: From 0128c9fe4caa2590c60af6c7276ab6789d06ae6d Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 13 Jun 2025 13:50:00 +0200 Subject: [PATCH 055/123] pipe dropout to model_fn, set default in workload --- algoperf/spec.py | 6 +++-- .../criteo1tb_pytorch/models_dropout.py | 6 ++--- .../criteo1tb/criteo1tb_pytorch/workload.py | 6 +++-- .../fastmri/fastmri_pytorch/models_dropout.py | 4 +-- .../fastmri/fastmri_pytorch/workload.py | 9 ++++--- .../imagenet_pytorch/workload.py | 9 +++---- .../imagenet_pytorch/models_dropout.py | 4 +-- .../imagenet_vit/imagenet_pytorch/workload.py | 7 +++-- .../librispeech_pytorch/models_dropout.py | 27 +++++-------------- .../librispeech_pytorch/workload.py | 7 +++-- .../librispeech_pytorch/models_dropout.py | 4 +-- .../ogbg/ogbg_pytorch/models_dropout.py | 4 +-- .../workloads/ogbg/ogbg_pytorch/workload.py | 8 ++++-- .../wmt/wmt_pytorch/models_dropout.py | 4 +-- .../workloads/wmt/wmt_pytorch/workload.py | 8 ++++-- 15 files changed, 60 insertions(+), 53 deletions(-) diff --git a/algoperf/spec.py b/algoperf/spec.py index cf4f1a14e..9670dcb76 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -247,7 +247,8 @@ def init_model_fn(self, # ModelAuxiliaryState, # ForwardPassMode, # RandomState, - # bool], + # bool, + # float], # Tensor] @abc.abstractmethod def model_fn(self, @@ -256,7 +257,8 @@ def model_fn(self, model_state: ModelAuxiliaryState, mode: ForwardPassMode, rng: RandomState, - update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float) -> Tuple[Tensor, ModelAuxiliaryState]: """Return logits_batch""" # Possible side effect of updating BN. diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py index b5ee465e2..f0653a665 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py @@ -7,7 +7,7 @@ from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout -DEFAULT_DROPOUT_RATE = 0.0 +DROPOUT_RATE = 0.0 class DenseBlock(nn.Module): @@ -148,7 +148,7 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE): + def forward(self, x, dropout_rate=DROPOUT_RATE): batch_size = x.shape[0] @@ -269,7 +269,7 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE): + def forward(self, x, dropout_rate=DROPOUT_RATE): batch_size = x.shape[0] diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index b128f5bd5..74cb3e140 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -15,6 +15,7 @@ BaseCriteo1TbDlrmSmallWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() +DROPOUT_RATE = 0.0 class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload): @@ -103,7 +104,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -123,7 +125,7 @@ def model_fn( } with contexts[mode](): - logits_batch = model(inputs) + logits_batch = model(inputs, dropout_rate=dropout_rate) return logits_batch, None diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py index 0e59e1436..0b8ac5499 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py @@ -15,7 +15,7 @@ from algoperf import init_utils from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout -DEFAULT_DROPOUT_RATE = 0.0 +DROPOUT_RATE = 0.0 class UNet(nn.Module): @@ -77,7 +77,7 @@ def __init__(self, def forward( self, x: Tensor, - dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor: + dropout_rate: float = DROPOUT_RATE) -> Tensor: stack = [] output = x diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 6da0bb0af..6374c62d6 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -19,6 +19,8 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +DROPOUT_RATE = 0.0 + class FastMRIWorkload(BaseFastMRIWorkload): @@ -134,7 +136,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -154,8 +157,8 @@ def model_fn( with contexts[mode](): logit_batch = model( - augmented_and_preprocessed_input_batch['inputs'].unsqueeze( - 1)).squeeze(1) + augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1), + dropout_rate=dropout_rate).squeeze(1) return logit_batch, None diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 372cac7fa..f28eb1762 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -156,10 +156,7 @@ def _build_dataset( def init_model_fn( self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate + rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.use_silu and self.use_gelu: @@ -192,9 +189,11 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng + del dropout_rate model = params diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py index 570cee575..60e09edb5 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py @@ -16,7 +16,7 @@ from algoperf import spec from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention -DEFAULT_DROPOUT_RATE = 0.0 +DROPOUT_RATE = 0.0 def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: @@ -348,7 +348,7 @@ def get_posemb(self, x: spec.Tensor) -> spec.Tensor: def forward( self, x: spec.Tensor, - dropout_rate: float = DEFAULT_DROPOUT_RATE) -> spec.Tensor: + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: # Patch extraction. x = self.conv_patch_extract(x) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 20bd3828b..8b011071a 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -16,6 +16,7 @@ from algoperf.workloads.imagenet_vit.workload import decode_variant USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +DROPOUT_RATE = 0.0 # Make sure we inherit from the ViT base workload first. @@ -51,7 +52,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -70,7 +72,8 @@ def model_fn( } with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + logits_batch = model(augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate) return logits_batch, None diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py index f77c8a814..a6a60bf95 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py @@ -17,10 +17,7 @@ from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug -DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE = 0.1 -DEFAULT_CONV_RESIDUAL_DROPOUT_RATE = 0.0 -DEFAULT_FFN_RESIDUAL_DROPOUT_RATE = 0.1 -DEFAULT_INPUT_DROPOUT_RATE = 0.1 +DROPOUT_RATE = 0.1 @dataclass @@ -93,9 +90,7 @@ def __init__(self, bias=True) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) - def forward(self, inputs, input_paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = DEFAULT_INPUT_DROPOUT_RATE + def forward(self, inputs, input_paddings, dropout_rate): output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -202,9 +197,7 @@ def __init__(self, config: ConformerConfig): out_features=config.encoder_dim, bias=True) - def forward(self, inputs, padding_mask, dropout_rate=None): - if dropout_rate is None: - dropout_rate = DEFAULT_FFN_RESIDUAL_DROPOUT_RATE + def forward(self, inputs, padding_mask, dropout_rate): inputs = self.ln(inputs) inputs = self.linear1(inputs) @@ -310,10 +303,7 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(dim=config.encoder_dim) self.self_attention = MHSAwithQS(config) - def forward(self, outputs, paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE - + def forward(self, outputs, paddings, dropout_rate): outputs = self.ln(outputs) outputs = self.self_attention( outputs, @@ -400,10 +390,7 @@ def __init__(self, config): self.bn = BatchNorm(config) self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - def forward(self, inputs, input_paddings, dropout_rate=None): - if dropout_rate is None: - dropout_rate = DEFAULT_CONV_RESIDUAL_DROPOUT_RATE - + def forward(self, inputs, input_paddings, dropout_rate): inputs = self.ln(inputs) inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) @@ -442,7 +429,7 @@ def __init__(self, config: ConformerConfig): if config.use_post_layer_norm: self.ln = LayerNorm(dim=config.encoder_dim) - def forward(self, inputs, input_paddings, dropout_rate=None): + def forward(self, inputs, input_paddings, dropout_rate): padding_mask = 1 - input_paddings[:, :, None] inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) @@ -481,7 +468,7 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(config.encoder_dim) self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings, dropout_rate=None): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings outputs, output_paddings = self.preprocessor(outputs, output_paddings) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 2d0942fe9..d99bc1608 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -24,6 +24,7 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() MAX_INPUT_LENGTH = 320000 +DROPOUT_RATE = 0.1 class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): @@ -105,7 +106,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -126,7 +128,8 @@ def model_fn( with contexts[mode](): inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] logits, logits_paddings = model(inputs.to(DEVICE), - input_paddings.to(DEVICE)) + input_paddings.to(DEVICE), + dropout_rate=dropout_rate) return (logits, logits_paddings), None def _build_input_queue( diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py index 21a4df614..a8480a343 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py @@ -17,7 +17,7 @@ SpecAug USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ -DEFAULT_DROPOUT_RATE = 0.1 +DROPOUT_RATE = 0.1 @dataclass @@ -351,7 +351,7 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings, dropout_rate=DEFAULT_DROPOUT_RATE): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py index c8ed23dda..be5882333 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py @@ -11,7 +11,7 @@ from algoperf import init_utils from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout -DEFAULT_DROPOUT_RATE = 0.1 +DROPOUT_RATE = 0.1 def _make_mlp(in_dim, hidden_dims, activation_fn): @@ -96,7 +96,7 @@ def __init__(self, def forward( self, graph: GraphsTuple, - dropout_rate: float = DEFAULT_DROPOUT_RATE) -> torch.Tensor: + dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: graph = graph._replace( globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 7ead696ce..281f4cd08 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -17,6 +17,8 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +DROPOUT_RATE = 0.1 + def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: @@ -166,7 +168,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del rng del update_batch_norm # No BN in the GNN model. @@ -186,7 +189,8 @@ def model_fn( } with contexts[mode](): - logits = model(augmented_and_preprocessed_input_batch['inputs']) + logits = model(augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate) return logits, None diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py index a5d822669..a43df30d4 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py @@ -9,7 +9,7 @@ from torch.nn.init import normal_ from torch.nn.init import xavier_uniform_ -DEFAULT_DROPOUT_RATE = 0.1 +DROPOUT_RATE = 0.1 def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: @@ -156,7 +156,7 @@ def forward(self, inputs_segmentation: Optional[Tensor] = None, targets_segmentation: Optional[Tensor] = None, decode: bool = False, - dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor: + dropout_rate: float = DROPOUT_RATE) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index bb9c3834f..d30abc4c7 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -22,6 +22,8 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +DROPOUT_RATE = 0.1 + class WmtWorkload(BaseWmtWorkload): """WMT PyTorch workload.""" @@ -202,7 +204,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -228,7 +231,8 @@ def model_fn( inputs_segmentation=augmented_and_preprocessed_input_batch.get( 'inputs_segmentation', None), targets_segmentation=augmented_and_preprocessed_input_batch.get( - 'targets_segmentation', None)) + 'targets_segmentation', None), + dropout_rate=dropout_rate) return logits_batch, None From a7cba1a1acc9da53b9500cab8755f49c88bdbb2c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 13 Jun 2025 14:19:36 +0200 Subject: [PATCH 056/123] remove aux_dropout from pytorch workloads --- .../test_model_equivalence.py | 32 +------------------ .../test_model_equivalence.py | 2 +- 2 files changed, 2 insertions(+), 32 deletions(-) diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py index a4238bbc9..4a1252a39 100644 --- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -3,11 +3,7 @@ Run with: python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py -`dropout_rate` controls the following args: -- `attention_residual_dropout_rate` (if None, 0.1 -- `conv_residual_dropout_rate` (if None, 0.0) -- `feed_forward_residual_dropout_rate` (if None, 0.1) -- `input_dropout_rate` (if None, 0.1) +NOTE: we don't test for default dropout_rate values, since they changed. """ from absl.testing import absltest, parameterized @@ -85,31 +81,5 @@ def test_forward(self, dropout_rate): assert_close(y1, y2, atol=0, rtol=0) assert_close(p1, p2, atol=0, rtol=0) - @parameterized.named_parameters( - dict(testcase_name=''), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - - torch.manual_seed(SEED) - orig = OriginalModel(OriginalConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) - orig.load_state_dict(cust.state_dict()) - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - if __name__ == '__main__': absltest.main() diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py index acdc8c5b3..58ddb354e 100644 --- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py @@ -24,7 +24,7 @@ B, T = 32, 30_000 DEVICE = 'cuda' -TORCH_COMPILE = True +TORCH_COMPILE = False os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" torch.backends.cudnn.benchmark = False From 05bff916dee7de6852afc6d95e2564ad57aa77ef Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 13 Jun 2025 20:45:13 +0000 Subject: [PATCH 057/123] fix to model_fn default dropout value --- algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- algoperf/workloads/fastmri/fastmri_jax/workload.py | 2 +- algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 2 +- .../workloads/librispeech_conformer/librispeech_jax/workload.py | 2 +- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 2 +- algoperf/workloads/wmt/wmt_jax/workload.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index dcb7b9a57..cb7e8cf9f 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -115,7 +115,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm inputs = augmented_and_preprocessed_input_batch['inputs'] diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index bf0acfc8d..acdf077e1 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -52,7 +52,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index b8a870de5..08a8f4eb1 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -57,7 +57,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 042dba7f4..8d966ef87 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -109,7 +109,7 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool, use_running_average_bn: Optional[bool] = None, - dropout_rate: Optional[float] = None, + dropout_rate: Optional[float] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2213f189e..2bb119439 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -57,7 +57,7 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool, use_running_average_bn: Optional[bool] = None, - dropout_rate: Optional[bool] = None + dropout_rate: Optional[bool] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index aaa5b4064..e03252ed9 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -53,7 +53,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: Optional[float]) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del update_batch_norm # No BN in the GNN model. if model_state is not None: diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 9e109dc86..9548f5b7e 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -250,7 +250,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: Optional[float] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: [float] = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm From d8e39b0da371abbcf311ce1d09e06439bd5a0eec Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Jun 2025 17:08:50 +0200 Subject: [PATCH 058/123] fix to model_fn default dropout_rate --- .../criteo1tb/criteo1tb_pytorch/workload.py | 3 +-- .../fastmri/fastmri_pytorch/workload.py | 5 ++--- .../imagenet_vit/imagenet_pytorch/workload.py | 3 +-- .../librispeech_pytorch/workload.py | 3 +-- .../librispeech_pytorch/workload.py | 17 ++++++++++++++++- .../workloads/ogbg/ogbg_pytorch/workload.py | 5 ++--- algoperf/workloads/wmt/wmt_pytorch/workload.py | 5 ++--- 7 files changed, 25 insertions(+), 16 deletions(-) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 74cb3e140..48c6592f2 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -15,7 +15,6 @@ BaseCriteo1TbDlrmSmallWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() -DROPOUT_RATE = 0.0 class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload): @@ -105,7 +104,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 6374c62d6..9b96230fc 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -13,14 +13,13 @@ from algoperf import pytorch_utils from algoperf import spec import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_pytorch import models from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -DROPOUT_RATE = 0.0 - class FastMRIWorkload(BaseFastMRIWorkload): @@ -137,7 +136,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 8b011071a..f86a1b1c2 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -16,7 +16,6 @@ from algoperf.workloads.imagenet_vit.workload import decode_variant USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -DROPOUT_RATE = 0.0 # Make sure we inherit from the ViT base workload first. @@ -53,7 +52,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index d99bc1608..0477a7389 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -24,7 +24,6 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() MAX_INPUT_LENGTH = 320000 -DROPOUT_RATE = 0.1 class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): @@ -107,7 +106,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index e6ec4764f..bf345cfc9 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Tuple import torch from torch.nn.parallel import DistributedDataParallel as DDP @@ -6,6 +6,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ initialize from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ @@ -54,6 +55,20 @@ def init_model_fn( else: model = torch.nn.DataParallel(model) return model, None + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + # override super method, changing only the default dropout_rate + return super().model_fn( + params, augmented_and_preprocessed_input_batch, model_state, + mode, rng, update_batch_norm, dropout_rate) def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 281f4cd08..758b36b60 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -12,13 +12,12 @@ from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg.ogbg_pytorch import models from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN from algoperf.workloads.ogbg.workload import BaseOgbgWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -DROPOUT_RATE = 0.1 - def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: @@ -169,7 +168,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del rng del update_batch_norm # No BN in the GNN model. diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index d30abc4c7..4c787becc 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -17,13 +17,12 @@ from algoperf import spec from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_pytorch import decode +from algoperf.workloads.wmt.wmt_pytorch import models from algoperf.workloads.wmt.wmt_pytorch.models import Transformer from algoperf.workloads.wmt.workload import BaseWmtWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -DROPOUT_RATE = 0.1 - class WmtWorkload(BaseWmtWorkload): """WMT PyTorch workload.""" @@ -205,7 +204,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm From 7a0015830e840211b8002f068b1eb918c8390f5c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Jun 2025 17:14:18 +0200 Subject: [PATCH 059/123] rm models_dropout torch files --- .../criteo1tb/criteo1tb_pytorch/models.py | 54 +- .../criteo1tb_pytorch/models_dropout.py | 297 ------ .../fastmri/fastmri_pytorch/models.py | 44 +- .../fastmri/fastmri_pytorch/models_dropout.py | 173 --- .../imagenet_vit/imagenet_pytorch/models.py | 84 +- .../imagenet_pytorch/models_dropout.py | 378 ------- .../librispeech_pytorch/models.py | 70 +- .../librispeech_pytorch/models_dropout.py | 482 --------- .../librispeech_pytorch/models.py | 32 +- .../librispeech_pytorch/models_dropout.py | 379 ------- .../workloads/ogbg/ogbg_pytorch/models.py | 35 +- .../ogbg/ogbg_pytorch/models_dropout.py | 315 ------ algoperf/workloads/wmt/wmt_pytorch/models.py | 181 ++-- .../wmt/wmt_pytorch/models_dropout.py | 989 ------------------ 14 files changed, 234 insertions(+), 3279 deletions(-) delete mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py delete mode 100644 algoperf/workloads/wmt/wmt_pytorch/models_dropout.py diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py index 7a40f0e81..f0653a665 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -5,20 +5,32 @@ import torch from torch import nn +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout + +DROPOUT_RATE = 0.0 + class DenseBlock(nn.Module): """Dense block with optional residual connection.""" "" - def __init__(self, module, resnet=False): super().__init__() self.module = module self.resnet = resnet def forward(self, x): - if self.resnet: - return self.module(x) + x - else: - return self.module(x) + return self.module(x) + x if self.resnet else self.module(x) + + +class DenseBlockWithDropout(nn.Module): + """Dense block with optional residual connection and support for dropout.""" + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + self._supports_custom_dropout = True + + def forward(self, x, p): + return self.module(x, p) + x if self.resnet else self.module(x, p) class DotInteract(nn.Module): @@ -58,7 +70,6 @@ def __init__(self, mlp_bottom_dims=(256, 256, 256), mlp_top_dims=(256, 256, 256, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -116,17 +127,16 @@ def __init__(self, block.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): block.append(nn.ReLU(inplace=True)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - block.append(nn.Dropout(p=dropout_rate)) - block = nn.Sequential(*block) + if layer_idx == num_layers_top - 2: + block.append(CustomDropout()) + block = SequentialWithDropout(*block) if (layer_idx != 0) and (layer_idx != num_layers_top - 1): - block = DenseBlock(block, resnet=True) + block = DenseBlockWithDropout(block, resnet=True) else: - block = DenseBlock(block) + block = DenseBlockWithDropout(block) mlp_top_blocks.append(block) fan_in = fan_out - self.top_mlp = nn.Sequential(*mlp_top_blocks) + self.top_mlp = SequentialWithDropout(*mlp_top_blocks) for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): @@ -138,7 +148,8 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x): + def forward(self, x, dropout_rate=DROPOUT_RATE): + batch_size = x.shape[0] dense_features, sparse_features = torch.split( @@ -157,7 +168,7 @@ def forward(self, x): top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) # Final MLP. - logits = self.top_mlp(top_mlp_input) + logits = self.top_mlp(top_mlp_input, dropout_rate) return logits @@ -179,7 +190,6 @@ def __init__(self, mlp_bottom_dims=(512, 256, 128), mlp_top_dims=(1024, 1024, 512, 256, 1), embed_dim=128, - dropout_rate=0.0, use_layer_norm=False, embedding_init_multiplier=None): super().__init__() @@ -242,10 +252,9 @@ def __init__(self, top_mlp_layers.append(nn.ReLU(inplace=True)) if use_layer_norm: top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential(*top_mlp_layers) + if layer_idx == num_layers_top - 2: + top_mlp_layers.append(CustomDropout()) + self.top_mlp = SequentialWithDropout(*top_mlp_layers) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: @@ -260,7 +269,8 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - def forward(self, x): + def forward(self, x, dropout_rate=DROPOUT_RATE): + batch_size = x.shape[0] dense_features, sparse_features = torch.split( @@ -283,5 +293,5 @@ def forward(self, x): dense_features=embedded_dense, sparse_features=embedded_sparse) # Final MLP. - logits = self.top_mlp(concatenated_dense) + logits = self.top_mlp(concatenated_dense, dropout_rate) return logits diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py deleted file mode 100644 index f0653a665..000000000 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py +++ /dev/null @@ -1,297 +0,0 @@ -"""Pytorch implementation of DLRM-Small.""" - -import math - -import torch -from torch import nn - -from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout - -DROPOUT_RATE = 0.0 - - -class DenseBlock(nn.Module): - """Dense block with optional residual connection.""" "" - def __init__(self, module, resnet=False): - super().__init__() - self.module = module - self.resnet = resnet - - def forward(self, x): - return self.module(x) + x if self.resnet else self.module(x) - - -class DenseBlockWithDropout(nn.Module): - """Dense block with optional residual connection and support for dropout.""" - def __init__(self, module, resnet=False): - super().__init__() - self.module = module - self.resnet = resnet - self._supports_custom_dropout = True - - def forward(self, x, p): - return self.module(x, p) + x if self.resnet else self.module(x, p) - - -class DotInteract(nn.Module): - """Performs feature interaction operation between dense or sparse features.""" - - def __init__(self, num_sparse_features): - super().__init__() - self.triu_indices = torch.triu_indices(num_sparse_features + 1, - num_sparse_features + 1) - - def forward(self, dense_features, sparse_features): - combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), - dim=1) - interactions = torch.bmm(combined_values, - torch.transpose(combined_values, 1, 2)) - interactions_flat = interactions[:, - self.triu_indices[0], - self.triu_indices[1]] - return torch.cat((dense_features, interactions_flat), dim=1) - - -class DLRMResNet(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(256, 256, 256), - mlp_top_dims=(256, 256, 256, 256, 1), - embed_dim=128, - use_layer_norm=False, - embedding_init_multiplier=None): - super().__init__() - self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) - self.num_dense_features = num_dense_features - self.num_sparse_features = num_sparse_features - self.mlp_bottom_dims = mlp_bottom_dims - self.mlp_top_dims = mlp_top_dims - self.embed_dim = embed_dim - - # Ideally, we should use the pooled embedding implementation from - # `TorchRec`. However, in order to have identical implementation - # with that of Jax, we define a single embedding matrix. - num_chunks = 4 - assert vocab_size % num_chunks == 0 - self.embedding_table_chucks = [] - scale = 1.0 / torch.sqrt(self.vocab_size) - for i in range(num_chunks): - chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) - chunk.data.uniform_(0, 1) - chunk.data = scale * chunk.data - self.register_parameter(f'embedding_chunk_{i}', chunk) - self.embedding_table_chucks.append(chunk) - - input_dim = self.num_dense_features - bot_mlp_blocks = [] - for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): - block = [] - block.append(nn.Linear(input_dim, dense_dim)) - block.append(nn.ReLU(inplace=True)) - block = nn.Sequential(*block) - if layer_idx > 0: - block = DenseBlock(block, resnet=True) - else: - block = DenseBlock(block) - bot_mlp_blocks.append(block) - input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bot_mlp_blocks) - - for module in self.bot_mlp.modules(): - if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) - nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - # Number of sparse features = 26 - fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] - num_layers_top = len(self.mlp_top_dims) - mlp_top_blocks = [] - for layer_idx, fan_out in enumerate(self.mlp_top_dims): - block = [] - block.append(nn.Linear(fan_in, fan_out)) - if layer_idx < (num_layers_top - 1): - block.append(nn.ReLU(inplace=True)) - if layer_idx == num_layers_top - 2: - block.append(CustomDropout()) - block = SequentialWithDropout(*block) - if (layer_idx != 0) and (layer_idx != num_layers_top - 1): - block = DenseBlockWithDropout(block, resnet=True) - else: - block = DenseBlockWithDropout(block) - mlp_top_blocks.append(block) - fan_in = fan_out - self.top_mlp = SequentialWithDropout(*mlp_top_blocks) - - for module in self.top_mlp.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - def forward(self, x, dropout_rate=DROPOUT_RATE): - - batch_size = x.shape[0] - - dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) - - # Bottom MLP. - embedded_dense = self.bot_mlp(dense_features) - - # Sparse feature processing. - sparse_features = sparse_features.to(dtype=torch.int32) - idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size - embedding_table = torch.cat(self.embedding_table_chucks, dim=0) - embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, 26 * self.embed_dim]) - top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) - - # Final MLP. - logits = self.top_mlp(top_mlp_input, dropout_rate) - return logits - - -class DlrmSmall(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(512, 256, 128), - mlp_top_dims=(1024, 1024, 512, 256, 1), - embed_dim=128, - use_layer_norm=False, - embedding_init_multiplier=None): - super().__init__() - self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) - self.num_dense_features = num_dense_features - self.num_sparse_features = num_sparse_features - self.mlp_bottom_dims = mlp_bottom_dims - self.mlp_top_dims = mlp_top_dims - self.embed_dim = embed_dim - self.embedding_init_multiplier = embedding_init_multiplier - - # Ideally, we should use the pooled embedding implementation from - # `TorchRec`. However, in order to have identical implementation - # with that of Jax, we define a single embedding matrix. - num_chunks = 4 - assert vocab_size % num_chunks == 0 - self.embedding_table_chucks = [] - - if self.embedding_init_multiplier is None: - scale = 1.0 / torch.sqrt(self.vocab_size) - else: - scale = self.embedding_init_multiplier - - for i in range(num_chunks): - chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) - chunk.data.uniform_(0, 1) - chunk.data = scale * chunk.data - self.register_parameter(f'embedding_chunk_{i}', chunk) - self.embedding_table_chucks.append(chunk) - - input_dim = self.num_dense_features - bottom_mlp_layers = [] - for dense_dim in self.mlp_bottom_dims: - bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) - bottom_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) - input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bottom_mlp_layers) - for module in self.bot_mlp.modules(): - if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) - nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) - - # TODO: Write down the formula here instead of the constant. - input_dims = 506 - num_layers_top = len(self.mlp_top_dims) - top_mlp_layers = [] - for layer_idx, fan_out in enumerate(self.mlp_top_dims): - fan_in = input_dims if layer_idx == 0 \ - else self.mlp_top_dims[layer_idx - 1] - top_mlp_layers.append(nn.Linear(fan_in, fan_out)) - if layer_idx < (num_layers_top - 1): - top_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - if layer_idx == num_layers_top - 2: - top_mlp_layers.append(CustomDropout()) - self.top_mlp = SequentialWithDropout(*top_mlp_layers) - if use_layer_norm: - self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) - else: - self.embed_ln = None - for module in self.top_mlp.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) - - def forward(self, x, dropout_rate=DROPOUT_RATE): - - batch_size = x.shape[0] - - dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) - - # Bottom MLP. - embedded_dense = self.bot_mlp(dense_features) - - # Sparse feature processing. - sparse_features = sparse_features.to(dtype=torch.int32) - idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size - embedding_table = torch.cat(self.embedding_table_chucks, dim=0) - embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, -1, self.embed_dim]) - if self.embed_ln: - embedded_sparse = self.embed_ln(embedded_sparse) - # Dot product interactions. - concatenated_dense = self.dot_interact( - dense_features=embedded_dense, sparse_features=embedded_sparse) - - # Final MLP. - logits = self.top_mlp(concatenated_dense, dropout_rate) - return logits diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py index 28f20bf20..0b8ac5499 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py @@ -13,6 +13,9 @@ from torch.nn import functional as F from algoperf import init_utils +from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout + +DROPOUT_RATE = 0.0 class UNet(nn.Module): @@ -27,7 +30,6 @@ def __init__(self, out_chans: int = 1, num_channels: int = 32, num_pool_layers: int = 4, - dropout_rate: Optional[float] = 0.0, use_tanh: bool = False, use_layer_norm: bool = False) -> None: super().__init__() @@ -36,21 +38,19 @@ def __init__(self, self.out_chans = out_chans self.num_channels = num_channels self.num_pool_layers = num_pool_layers - if dropout_rate is None: - dropout_rate = 0.0 + self.down_sample_layers = nn.ModuleList([ ConvBlock(in_chans, num_channels, - dropout_rate, use_tanh, use_layer_norm) ]) ch = num_channels for _ in range(num_pool_layers - 1): self.down_sample_layers.append( - ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm)) + ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)) ch *= 2 - self.conv = ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm) + self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) self.up_conv = nn.ModuleList() self.up_transpose_conv = nn.ModuleList() @@ -59,14 +59,14 @@ def __init__(self, self.up_transpose_conv.append( TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) self.up_conv.append( - ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm)) + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) ch //= 2 self.up_transpose_conv.append( TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) self.up_conv.append( - nn.Sequential( - ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm), + SequentialWithDropout( + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), )) @@ -74,24 +74,28 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + dropout_rate: float = DROPOUT_RATE) -> Tensor: + stack = [] output = x # apply down-sampling layers for layer in self.down_sample_layers: - output = layer(output) + output = layer(output, dropout_rate) stack.append(output) output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) - output = self.conv(output) + output = self.conv(output, dropout_rate) # apply up-sampling layers for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): downsample_layer = stack.pop() output = transpose_conv(output) - # reflect pad on the right/botton if needed to handle + # reflect pad on the right/bottom if needed to handle # odd input dimensions padding = [0, 0, 0, 0] if output.shape[-1] != downsample_layer.shape[-1]: @@ -102,7 +106,7 @@ def forward(self, x: Tensor) -> Tensor: output = F.pad(output, padding, "reflect") output = torch.cat([output, downsample_layer], dim=1) - output = conv(output) + output = conv(output, dropout_rate) return output @@ -114,10 +118,10 @@ class ConvBlock(nn.Module): def __init__(self, in_chans: int, out_chans: int, - dropout_rate: float, use_tanh: bool, use_layer_norm: bool) -> None: super().__init__() + self._supports_custom_dropout = True if use_layer_norm: norm_layer = partial(nn.GroupNorm, 1, eps=1e-6) @@ -127,19 +131,19 @@ def __init__(self, activation_fn = nn.Tanh() else: activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.conv_layers = nn.Sequential( + self.conv_layers = SequentialWithDropout( nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), norm_layer(out_chans), activation_fn, - nn.Dropout2d(dropout_rate), + CustomDropout2d(), nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), norm_layer(out_chans), activation_fn, - nn.Dropout2d(dropout_rate), + CustomDropout2d(), ) - def forward(self, x: Tensor) -> Tensor: - return self.conv_layers(x) + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + return self.conv_layers(x, dropout_rate) class TransposeConvBlock(nn.Module): diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py deleted file mode 100644 index 0b8ac5499..000000000 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py +++ /dev/null @@ -1,173 +0,0 @@ -"""U-Net Model. - -Adapted from fastMRI: -https://github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py -""" - -from functools import partial -from typing import Optional - -import torch -from torch import nn -from torch import Tensor -from torch.nn import functional as F - -from algoperf import init_utils -from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout - -DROPOUT_RATE = 0.0 - - -class UNet(nn.Module): - r"""U-Net model from - `"U-net: Convolutional networks - for biomedical image segmentation" - `_. - """ - - def __init__(self, - in_chans: int = 1, - out_chans: int = 1, - num_channels: int = 32, - num_pool_layers: int = 4, - use_tanh: bool = False, - use_layer_norm: bool = False) -> None: - super().__init__() - - self.in_chans = in_chans - self.out_chans = out_chans - self.num_channels = num_channels - self.num_pool_layers = num_pool_layers - - self.down_sample_layers = nn.ModuleList([ - ConvBlock(in_chans, - num_channels, - use_tanh, - use_layer_norm) - ]) - ch = num_channels - for _ in range(num_pool_layers - 1): - self.down_sample_layers.append( - ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)) - ch *= 2 - self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) - - self.up_conv = nn.ModuleList() - self.up_transpose_conv = nn.ModuleList() - - for _ in range(num_pool_layers - 1): - self.up_transpose_conv.append( - TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) - self.up_conv.append( - ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) - ch //= 2 - - self.up_transpose_conv.append( - TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) - self.up_conv.append( - SequentialWithDropout( - ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), - nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), - )) - - for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): - init_utils.pytorch_default_init(m) - - def forward( - self, - x: Tensor, - dropout_rate: float = DROPOUT_RATE) -> Tensor: - - stack = [] - output = x - - # apply down-sampling layers - for layer in self.down_sample_layers: - output = layer(output, dropout_rate) - stack.append(output) - output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) - - output = self.conv(output, dropout_rate) - - # apply up-sampling layers - for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): - downsample_layer = stack.pop() - output = transpose_conv(output) - - # reflect pad on the right/bottom if needed to handle - # odd input dimensions - padding = [0, 0, 0, 0] - if output.shape[-1] != downsample_layer.shape[-1]: - padding[1] = 1 # padding right - if output.shape[-2] != downsample_layer.shape[-2]: - padding[3] = 1 # padding bottom - if torch.sum(torch.tensor(padding)) != 0: - output = F.pad(output, padding, "reflect") - - output = torch.cat([output, downsample_layer], dim=1) - output = conv(output, dropout_rate) - - return output - - -class ConvBlock(nn.Module): - # A Convolutional Block that consists of two convolution layers each - # followed by instance normalization, LeakyReLU activation and dropout_rate. - - def __init__(self, - in_chans: int, - out_chans: int, - use_tanh: bool, - use_layer_norm: bool) -> None: - super().__init__() - self._supports_custom_dropout = True - - if use_layer_norm: - norm_layer = partial(nn.GroupNorm, 1, eps=1e-6) - else: - norm_layer = nn.InstanceNorm2d - if use_tanh: - activation_fn = nn.Tanh() - else: - activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.conv_layers = SequentialWithDropout( - nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), - norm_layer(out_chans), - activation_fn, - CustomDropout2d(), - nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), - norm_layer(out_chans), - activation_fn, - CustomDropout2d(), - ) - - def forward(self, x: Tensor, dropout_rate: float) -> Tensor: - return self.conv_layers(x, dropout_rate) - - -class TransposeConvBlock(nn.Module): - # A Transpose Convolutional Block that consists of one convolution transpose - # layers followed by instance normalization and LeakyReLU activation. - - def __init__( - self, - in_chans: int, - out_chans: int, - use_tanh: bool, - use_layer_norm: bool, - ): - super().__init__() - if use_tanh: - activation_fn = nn.Tanh() - else: - activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.layers = nn.Sequential( - nn.ConvTranspose2d( - in_chans, out_chans, kernel_size=2, stride=2, bias=False), - nn.InstanceNorm2d(out_chans), - activation_fn, - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layers(x) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fcf0992d3..60e09edb5 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -14,7 +14,9 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention + +DROPOUT_RATE = 0.0 def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: @@ -41,18 +43,15 @@ def __init__( self, width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - use_glu: bool = False, - dropout_rate: float = 0.0) -> None: + use_glu: bool = False) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu - self.dropout_rate = dropout_rate self.linear1 = nn.Linear(self.width, self.mlp_dim) self.act_fnc = nn.GELU(approximate='tanh') - self.dropout = nn.Dropout(self.dropout_rate) if self.use_glu: self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) @@ -70,7 +69,8 @@ def reset_parameters(self) -> None: if module.bias is not None: module.bias.data.normal_(std=1e-6) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: + x = self.linear1(x) x = self.act_fnc(x) @@ -78,7 +78,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: y = self.glu_linear(x) x = x * y - x = self.dropout(x) + x = F.dropout(x, dropout_rate, training=self.training) x = self.linear2(x) return x @@ -88,8 +88,7 @@ class SelfAttention(nn.Module): def __init__(self, width: int, - num_heads: int = 8, - dropout_rate: float = 0.0) -> None: + num_heads: int = 8) -> None: super().__init__() self.width = width @@ -104,7 +103,6 @@ def __init__(self, self.query = nn.Linear(self.width, self.all_head_dim) self.key = nn.Linear(self.width, self.all_head_dim) self.value = nn.Linear(self.width, self.all_head_dim) - self.dropout = nn.Dropout(dropout_rate) self.out = nn.Linear(self.width, self.width) self.reset_parameters() @@ -120,7 +118,8 @@ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: + mixed_query_layer = self.query(x) key_layer = self.transpose_for_scores(self.key(x)) @@ -131,7 +130,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: attention_scores = attention_scores / math.sqrt(self.head_dim) attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) + attention_probs = F.dropout(attention_probs, dropout_rate, self.training) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() @@ -149,8 +148,7 @@ def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.width = width @@ -161,35 +159,34 @@ def __init__(self, self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) - self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) self.mlp3 = MlpBlock( width=self.width, mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dropout_rate=dropout_rate) + use_glu=self.use_glu) + + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - def forward(self, x: spec.Tensor) -> spec.Tensor: if not self.use_post_layer_norm: y = self.layer_norm0(x) - y = self.self_attention1(y) - y = self.dropout(y) + y = self.self_attention1(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y y = self.layer_norm2(x) - y = self.mlp3(y) - y = self.dropout(y) + y = self.mlp3(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y else: y = x - y = self.self_attention1(y) - y = self.dropout(y) + y = self.self_attention1(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm0(x) y = x - y = self.mlp3(y) - y = self.dropout(y) + y = self.mlp3(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm2(x) return x @@ -204,8 +201,7 @@ def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + use_post_layer_norm: bool = False) -> None: super().__init__() self.depth = depth @@ -220,8 +216,7 @@ def __init__(self, self.mlp_dim, self.num_heads, self.use_glu, - self.use_post_layer_norm, - dropout_rate) for _ in range(depth) + self.use_post_layer_norm) for _ in range(depth) ]) if not self.use_post_layer_norm: @@ -229,10 +224,10 @@ def __init__(self, else: self.encoder_norm = None - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: # Input Encoder. for block in self.net: - x = block(x) + x = block(x, dropout_rate) if not self.use_post_layer_norm: return self.encoder_norm(x) else: @@ -259,13 +254,13 @@ def __init__(self, self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape probe = torch.tile(self.probe, [n, 1, 1]) - x = self.mha(probe, x)[0] + x = self.mha(probe, x, dropout_rate=dropout_rate)[0] y = self.layer_norm(x) - x = x + self.mlp(y) + x = x + self.mlp(y, dropout_rate) return x[:, 0] @@ -285,15 +280,12 @@ def __init__( mlp_dim: Optional[int] = None, # Defaults to 4x input dim. num_heads: int = 12, rep_size: Union[int, bool] = True, - dropout_rate: Optional[float] = 0.0, head_zeroinit: bool = True, use_glu: bool = False, use_post_layer_norm: bool = False, use_map: bool = False, dtype: Any = torch.float32) -> None: super().__init__() - if dropout_rate is None: - dropout_rate = 0.0 self.num_classes = num_classes self.patch_size = patch_size @@ -318,7 +310,6 @@ def __init__( self.patch_size, stride=self.patch_size, padding='valid') - self.dropout = nn.Dropout(p=dropout_rate) self.encoder = Encoder( depth=self.depth, @@ -326,8 +317,7 @@ def __init__( mlp_dim=self.mlp_dim, num_heads=self.num_heads, use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate) + use_post_layer_norm=self.use_post_layer_norm) if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) @@ -355,7 +345,11 @@ def reset_parameters(self) -> None: def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward( + self, + x: spec.Tensor, + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: + # Patch extraction. x = self.conv_patch_extract(x) @@ -367,11 +361,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2) x = x + pes - x = self.dropout(x) - x = self.encoder(x) + x = F.dropout(x, dropout_rate, training=self.training) + x = self.encoder(x, dropout_rate) if self.use_map: - x = self.map(x) + x = self.map(x, dropout_rate) else: x = torch.mean(x, dim=1) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py deleted file mode 100644 index 60e09edb5..000000000 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py +++ /dev/null @@ -1,378 +0,0 @@ -"""PyTorch implementation of refactored and simplified ViT. - -Adapted from: -https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit -and https://github.com/lucidrains/vit-pytorch. -""" - -import math -from typing import Any, Optional, Tuple, Union - -import torch -from torch import nn -import torch.nn.functional as F - -from algoperf import init_utils -from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention - -DROPOUT_RATE = 0.0 - - -def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: - """Follows the MoCo v3 logic.""" - _, width, h, w = patches.shape - device = patches.device - y, x = torch.meshgrid(torch.arange(h, device=device), - torch.arange(w, device=device), indexing='ij') - - if width % 4 != 0: - raise ValueError('Width must be mult of 4 for sincos posemb.') - omega = torch.arange(width // 4, device=device) / (width // 4 - 1) - omega = 1. / (temperature**omega) - y = y.flatten()[:, None] * omega[None, :] - x = x.flatten()[:, None] * omega[None, :] - pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) - return pe[None, :, :] - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block.""" - - def __init__( - self, - width: int, - mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - use_glu: bool = False) -> None: - super().__init__() - - self.width = width - self.mlp_dim = mlp_dim or 4 * width - self.use_glu = use_glu - - self.linear1 = nn.Linear(self.width, self.mlp_dim) - self.act_fnc = nn.GELU(approximate='tanh') - - if self.use_glu: - self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) - else: - self.glu_linear = None - - self.linear2 = nn.Linear(self.mlp_dim, self.width) - - self.reset_parameters() - - def reset_parameters(self) -> None: - for module in self.modules(): - if isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight.data) - if module.bias is not None: - module.bias.data.normal_(std=1e-6) - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - - x = self.linear1(x) - x = self.act_fnc(x) - - if self.use_glu: - y = self.glu_linear(x) - x = x * y - - x = F.dropout(x, dropout_rate, training=self.training) - x = self.linear2(x) - return x - - -class SelfAttention(nn.Module): - """Self-attention special case of multi-head dot-product attention.""" - - def __init__(self, - width: int, - num_heads: int = 8) -> None: - super().__init__() - - self.width = width - self.num_heads = num_heads - - assert width % num_heads == 0, ( - 'Memory dimension must be divisible by number of heads.') - - self.head_dim = int(width / num_heads) - self.all_head_dim = self.num_heads * self.head_dim - - self.query = nn.Linear(self.width, self.all_head_dim) - self.key = nn.Linear(self.width, self.all_head_dim) - self.value = nn.Linear(self.width, self.all_head_dim) - self.out = nn.Linear(self.width, self.width) - self.reset_parameters() - - def reset_parameters(self) -> None: - for module in self.modules(): - if isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight.data) - if module.bias is not None: - nn.init.constant_(module.bias.data, 0.) - - def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: - new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - - mixed_query_layer = self.query(x) - - key_layer = self.transpose_for_scores(self.key(x)) - value_layer = self.transpose_for_scores(self.value(x)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.head_dim) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = F.dropout(attention_probs, dropout_rate, self.training) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,) - context_layer = context_layer.view(new_context_layer_shape) - out = self.out(context_layer) - return out - - -class Encoder1DBlock(nn.Module): - """Single transformer encoder block (MHSA + MLP).""" - - def __init__(self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - use_glu: bool = False, - use_post_layer_norm: bool = False) -> None: - super().__init__() - - self.width = width - self.mlp_dim = mlp_dim - self.num_heads = num_heads - self.use_glu = use_glu - self.use_post_layer_norm = use_post_layer_norm - - self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) - self.self_attention1 = SelfAttention(self.width, self.num_heads) - self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) - self.mlp3 = MlpBlock( - width=self.width, - mlp_dim=self.mlp_dim, - use_glu=self.use_glu) - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - - if not self.use_post_layer_norm: - y = self.layer_norm0(x) - y = self.self_attention1(y, dropout_rate) - y = F.dropout(y, dropout_rate, training=self.training) - x = x + y - - y = self.layer_norm2(x) - y = self.mlp3(y, dropout_rate) - y = F.dropout(y, dropout_rate, training=self.training) - x = x + y - else: - y = x - y = self.self_attention1(y, dropout_rate) - y = F.dropout(y, dropout_rate, training=self.training) - x = x + y - x = self.layer_norm0(x) - - y = x - y = self.mlp3(y, dropout_rate) - y = F.dropout(y, dropout_rate, training=self.training) - x = x + y - x = self.layer_norm2(x) - return x - - -class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation.""" - - def __init__(self, - depth: int, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - use_glu: bool = False, - use_post_layer_norm: bool = False) -> None: - super().__init__() - - self.depth = depth - self.width = width - self.mlp_dim = mlp_dim - self.num_heads = num_heads - self.use_glu = use_glu - self.use_post_layer_norm = use_post_layer_norm - - self.net = nn.ModuleList([ - Encoder1DBlock(self.width, - self.mlp_dim, - self.num_heads, - self.use_glu, - self.use_post_layer_norm) for _ in range(depth) - ]) - - if not self.use_post_layer_norm: - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) - else: - self.encoder_norm = None - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - # Input Encoder. - for block in self.net: - x = block(x, dropout_rate) - if not self.use_post_layer_norm: - return self.encoder_norm(x) - else: - return x - - -class MAPHead(nn.Module): - """Multihead Attention Pooling.""" - - def __init__(self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12): - super().__init__() - self.width = width - self.mlp_dim = mlp_dim - self.num_heads = num_heads - - self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) - nn.init.xavier_uniform_(self.probe.data) - - self.mha = MultiheadAttention( - self.width, num_heads=self.num_heads, self_attn=False, bias=True) - self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) - self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - - def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - n, _, _ = x.shape - probe = torch.tile(self.probe, [n, 1, 1]) - - x = self.mha(probe, x, dropout_rate=dropout_rate)[0] - y = self.layer_norm(x) - x = x + self.mlp(y, dropout_rate) - return x[:, 0] - - -class ViT(nn.Module): - """ViT model.""" - - image_height: int = 224 - image_width: int = 224 - channels: int = 3 - - def __init__( - self, - num_classes: int = 1000, - patch_size: Tuple[int, int] = (16, 16), - width: int = 768, - depth: int = 12, - mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - num_heads: int = 12, - rep_size: Union[int, bool] = True, - head_zeroinit: bool = True, - use_glu: bool = False, - use_post_layer_norm: bool = False, - use_map: bool = False, - dtype: Any = torch.float32) -> None: - super().__init__() - - self.num_classes = num_classes - self.patch_size = patch_size - self.width = width - self.depth = depth - self.mlp_dim = mlp_dim - self.num_heads = num_heads - self.rep_size = rep_size - self.head_zeroinit = head_zeroinit - self.use_glu = use_glu - self.use_post_layer_norm = use_post_layer_norm - self.use_map = use_map - self.dtype = dtype - - if self.rep_size: - rep_size = self.width if self.rep_size is True else self.rep_size - self.pre_logits = nn.Linear(self.width, rep_size) - - self.conv_patch_extract = nn.Conv2d( - self.channels, - self.width, - self.patch_size, - stride=self.patch_size, - padding='valid') - - self.encoder = Encoder( - depth=self.depth, - width=self.width, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm) - - if self.num_classes: - self.head = nn.Linear(self.width, self.num_classes) - - if self.use_map: - self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) - else: - self.map = None - - self.reset_parameters() - - def reset_parameters(self) -> None: - init_utils.pytorch_default_init(self.conv_patch_extract) - - if self.rep_size: - init_utils.pytorch_default_init(self.pre_logits) - - if self.num_classes: - if self.head_zeroinit: - nn.init.constant_(self.head.weight.data, 0.) - nn.init.constant_(self.head.bias.data, 0.) - else: - init_utils.pytorch_default_init(self.head) - - def get_posemb(self, x: spec.Tensor) -> spec.Tensor: - return posemb_sincos_2d(x).type(self.dtype) - - def forward( - self, - x: spec.Tensor, - dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: - - # Patch extraction. - x = self.conv_patch_extract(x) - - # Add posemb before adding extra token. - n, c, h, w = x.shape - pes = self.get_posemb(x) - - # Reshape to match Jax's ViT implementation. - x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2) - x = x + pes - - x = F.dropout(x, dropout_rate, training=self.training) - x = self.encoder(x, dropout_rate) - - if self.use_map: - x = self.map(x, dropout_rate) - else: - x = torch.mean(x, dim=1) - - if self.rep_size: - x = torch.tanh(self.pre_logits(x)) - - if self.num_classes: - x = self.head(x) - - return x diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index db1e24521..a6a60bf95 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import partial import math -from typing import Tuple +from typing import Optional, Tuple import torch from torch import nn @@ -17,6 +17,8 @@ from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug +DROPOUT_RATE = 0.1 + @dataclass class ConformerConfig: @@ -26,10 +28,7 @@ class ConformerConfig: num_attention_heads: int = 8 num_encoder_layers: int = 4 attention_dropout_rate: float = 0.0 - attention_residual_dropout_rate: float = 0.1 - conv_residual_dropout_rate: float = 0.0 feed_forward_dropout_rate: float = 0.0 - feed_forward_residual_dropout_rate: float = 0.1 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -39,7 +38,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - input_dropout_rate: float = 0.1 batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -77,11 +75,9 @@ class Subsample(nn.Module): def __init__(self, encoder_dim: int = 0, - input_dropout_rate: float = 0.0, num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim - self.input_dropout_rate = input_dropout_rate self.conv1 = Conv2dSubsampling( input_channels=1, output_channels=encoder_dim) @@ -93,9 +89,9 @@ def __init__(self, out_features=self.encoder_dim, bias=True) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) - self.dropout = nn.Dropout(p=self.input_dropout_rate, inplace=True) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): + output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -109,7 +105,7 @@ def forward(self, inputs, input_paddings): outputs = self.linear(outputs) outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) - outputs = self.dropout(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) return outputs, output_paddings @@ -201,15 +197,8 @@ def __init__(self, config: ConformerConfig): out_features=config.encoder_dim, bias=True) - if config.feed_forward_residual_dropout_rate is None: - feed_forward_residual_dropout_rate = 0.1 - else: - feed_forward_residual_dropout_rate = ( - config.feed_forward_residual_dropout_rate) - self.dropout2 = nn.Dropout( - p=feed_forward_residual_dropout_rate, inplace=True) + def forward(self, inputs, padding_mask, dropout_rate): - def forward(self, inputs, padding_mask): inputs = self.ln(inputs) inputs = self.linear1(inputs) if self.config.activation_function_name == 'swish': @@ -226,7 +215,7 @@ def forward(self, inputs, padding_mask): inputs = inputs * padding_mask inputs = self.linear2(inputs) inputs = inputs * padding_mask - inputs = self.dropout2(inputs) + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) return inputs @@ -280,7 +269,7 @@ def __init__(self, config: ConformerConfig): super().__init__() self.embed_dim = config.encoder_dim self.num_heads = config.num_attention_heads - self.dropout = config.attention_dropout_rate + self.attention_dropout_rate = config.attention_dropout_rate self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) @@ -297,7 +286,7 @@ def forward(self, inputs, key_padding_mask=None): key=k, value=v, attn_mask=~key_padding_mask[:, None, None], - dropout_p=self.dropout, + dropout_p=self.attention_dropout_rate, ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) out = out * self.attention_temperature out = self.out_proj(out) @@ -313,19 +302,14 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(dim=config.encoder_dim) self.self_attention = MHSAwithQS(config) - if config.attention_residual_dropout_rate is None: - attention_residual_dropout_rate = 0.1 - else: - attention_residual_dropout_rate = config.attention_residual_dropout_rate - self.dropout = nn.Dropout(p=attention_residual_dropout_rate, inplace=True) - def forward(self, outputs, paddings): + def forward(self, outputs, paddings, dropout_rate): outputs = self.ln(outputs) outputs = self.self_attention( outputs, key_padding_mask=paddings == 1, ) - outputs = self.dropout(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) return outputs @@ -405,13 +389,8 @@ def __init__(self, config): groups=config.encoder_dim) self.bn = BatchNorm(config) self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - if config.conv_residual_dropout_rate is None: - conv_residual_dropout_rate = 0.0 - else: - conv_residual_dropout_rate = config.conv_residual_dropout_rate - self.dropout = nn.Dropout(p=conv_residual_dropout_rate, inplace=True) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): inputs = self.ln(inputs) inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) @@ -433,7 +412,7 @@ def forward(self, inputs, input_paddings): inputs = activation_fn(inputs) inputs = self.lin3(inputs) - inputs = self.dropout(inputs) + inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) return inputs @@ -450,12 +429,12 @@ def __init__(self, config: ConformerConfig): if config.use_post_layer_norm: self.ln = LayerNorm(dim=config.encoder_dim) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): padding_mask = 1 - input_paddings[:, :, None] - inputs = inputs + 0.5 * self.ff1(inputs, padding_mask) - inputs = inputs + self.mhsa(inputs, input_paddings) - inputs = inputs + self.conv(inputs, input_paddings) - inputs = inputs + 0.5 * self.ff2(inputs, padding_mask) + inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) + inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) + inputs = inputs + self.conv(inputs, input_paddings, dropout_rate) + inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate) if self.ln: inputs = self.ln(inputs) return inputs @@ -480,13 +459,8 @@ def __init__(self, config: ConformerConfig): time_masks_per_frame=config.time_masks_per_frame, use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames ) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate self.subsample = Subsample( encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate, num_bins=preprocessing_config.num_bins) self.conformers = nn.ModuleList( [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) @@ -494,15 +468,15 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(config.encoder_dim) self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) for conformer in self.conformers: - outputs = conformer(outputs, output_paddings) + outputs = conformer(outputs, output_paddings, dropout_rate) outputs = self.ln(outputs) outputs = self.lin(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py deleted file mode 100644 index a6a60bf95..000000000 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py +++ /dev/null @@ -1,482 +0,0 @@ -"""This is a pytorch implementation mirroring: -https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. -""" - -from dataclasses import dataclass -from functools import partial -import math -from typing import Optional, Tuple - -import torch -from torch import nn -from torch.nn import init -import torch.nn.functional as F - -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ - SpecAug - -DROPOUT_RATE = 0.1 - - -@dataclass -class ConformerConfig: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - vocab_size: int = 1024 - encoder_dim: int = 512 - num_attention_heads: int = 8 - num_encoder_layers: int = 4 - attention_dropout_rate: float = 0.0 - feed_forward_dropout_rate: float = 0.0 - convolution_kernel_size: int = 5 - feed_forward_expansion_factor: int = 4 - freq_mask_count: int = 2 - freq_mask_max_bins: int = 27 - time_mask_count: int = 10 - time_mask_max_frames: int = 40 - time_mask_max_ratio: float = 0.05 - time_masks_per_frame: float = 0.0 - use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 1 - 0.999 - batch_norm_epsilon: float = 0.001 - use_specaug: bool = True - attention_temperature: float = 1.0 - activation_function_name: str = 'swish' - use_post_layer_norm: bool = True - - -def initialize(m): - if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d): - init.xavier_uniform_(m.weight) - if m.bias is not None: - init.constant_(m.bias, 0) - elif isinstance(m, nn.MultiheadAttention): - init.xavier_uniform_(m.in_proj_weight) - for i in m.children(): - initialize(i) - - -class LayerNorm(nn.Module): - - def __init__(self, dim, epsilon=1e-6): - super().__init__() - self.dim = dim - - self.scale = nn.Parameter(torch.zeros(self.dim)) - self.bias = nn.Parameter(torch.zeros(self.dim)) - self.epsilon = epsilon - - def forward(self, x): - return F.layer_norm(x, (self.dim,), 1 + self.scale, self.bias, self.epsilon) - - -class Subsample(nn.Module): - - def __init__(self, - encoder_dim: int = 0, - num_bins: int = 80): - super().__init__() - self.encoder_dim = encoder_dim - - self.conv1 = Conv2dSubsampling( - input_channels=1, output_channels=encoder_dim) - self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, output_channels=encoder_dim) - - self.linear = nn.Linear( - in_features=self.encoder_dim * num_bins // 4, - out_features=self.encoder_dim, - bias=True) - self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) - - def forward(self, inputs, input_paddings, dropout_rate): - - output_paddings = input_paddings - outputs = inputs[:, None, :, :] - - outputs, output_paddings = self.conv1(outputs, output_paddings) - outputs, output_paddings = self.conv2(outputs, output_paddings) - - batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape - outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, - subsampled_lengths, - subsampled_dims * channels) - - outputs = self.linear(outputs) - outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) - outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) - - return outputs, output_paddings - - -class Conv2dSubsampling(nn.Module): - - def __init__(self, - input_channels: int, - output_channels: int, - filter_stride: Tuple[int] = (2, 2), - padding: str = 'SAME'): - super().__init__() - - self.input_channels = input_channels - self.output_channels = output_channels - self.filter_stride = filter_stride - self.padding = padding - - self.filter_shape = (output_channels, input_channels, 3, 3) - - self.kernel = nn.Parameter( - torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) - self.bias = nn.Parameter(torch.zeros(output_channels)) - self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) - - def get_same_padding(self, input_shape): - in_height, in_width = input_shape[2:] - stride_height, stride_width = self.filter_stride - filter_height, filter_width = 3, 3 - if in_height % stride_height == 0: - pad_along_height = max(filter_height - stride_height, 0) - else: - pad_along_height = max(filter_height - (in_height % stride_height), 0) - if in_width % stride_width == 0: - pad_along_width = max(filter_width - stride_width, 0) - else: - pad_along_width = max(filter_width - (in_width % stride_width), 0) - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - return (pad_left, pad_right, pad_top, pad_bottom) - - def forward(self, inputs, paddings): - groups = inputs.shape[1] // self.input_channels - - if self.padding == 'SAME': - in_ = F.pad(inputs, self.get_same_padding(inputs.shape)) - else: - in_ = inputs - outputs = F.conv2d( - input=in_, - weight=self.kernel, - bias=self.bias, - stride=self.filter_stride, - dilation=(1, 1), - groups=groups) - - outputs = F.relu(outputs) - - input_length = paddings.shape[1] - stride = self.filter_stride[0] - pad_len = (input_length + stride - 1) // stride * stride - input_length - padded_paddings = F.pad( - paddings[:, None, :], (0, pad_len), mode='constant', value=0) - out_padding = F.conv1d( - input=padded_paddings, - weight=self.paddings_kernel, - stride=self.filter_stride[:1]) - out_padding = out_padding.squeeze(dim=1) - outputs = outputs * (1 - out_padding[:, None, :, None]) - return outputs, out_padding - - -class FeedForwardModule(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - self.config = config - - self.ln = LayerNorm(dim=config.encoder_dim) - self.linear1 = nn.Linear( - in_features=config.encoder_dim, - out_features=config.encoder_dim * config.feed_forward_expansion_factor, - bias=True) - self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True) - self.linear2 = nn.Linear( - in_features=config.encoder_dim * config.feed_forward_expansion_factor, - out_features=config.encoder_dim, - bias=True) - - def forward(self, inputs, padding_mask, dropout_rate): - - inputs = self.ln(inputs) - inputs = self.linear1(inputs) - if self.config.activation_function_name == 'swish': - activation_fn = F.silu - elif self.config.activation_function_name == 'gelu': - # Use tanh approximation of GELU which is default for jax - activation_fn = partial(F.gelu, approximate='tanh') - else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') - inputs = activation_fn(inputs) - inputs = self.dropout1(inputs) - inputs = inputs * padding_mask - inputs = self.linear2(inputs) - inputs = inputs * padding_mask - inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) - - return inputs - - -class AddPositionalEmbedding(nn.Module): - - def __init__(self, - min_timescale: int = 1, - max_timescale: int = 10_000, - embedding_dim: int = 512): - super().__init__() - self.min_timescale = min_timescale - self.max_timescale = max_timescale - self.embedding_dim = embedding_dim - num_timescales = self.embedding_dim // 2 - log_timescale_increment = math.log( - float(self.max_timescale) / float(self.min_timescale)) / ( - num_timescales - 1) - inv_timescales = self.min_timescale * \ - torch.exp(torch.arange(num_timescales, dtype=torch.float32) - * -log_timescale_increment) - self.register_buffer('inv_timescales', inv_timescales[None, None, :]) - - def forward(self, seq_length): - position = torch.arange( - end=seq_length, dtype=torch.float32, device=self.inv_timescales.device) - scaled_time = position[None, :, None] * self.inv_timescales - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) - if self.embedding_dim % 2: - signal = torch.cat( - [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2) - return signal - - -class QueryScaler(nn.Module): - - def __init__(self, dim): - super().__init__() - self.dim = dim - self.scale = nn.Parameter(torch.zeros(self.dim)) - - def forward(self, inputs): - r_softplus_0 = 1.442695041 - scale = r_softplus_0 * F.softplus(self.scale) - return inputs * scale - - -class MHSAwithQS(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - self.embed_dim = config.encoder_dim - self.num_heads = config.num_attention_heads - self.attention_dropout_rate = config.attention_dropout_rate - self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) - self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) - self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) - self.attention_temperature = config.attention_temperature - - def forward(self, inputs, key_padding_mask=None): - batch_size, seq_len, embed_dim = inputs.shape - q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2) - q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) - k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) - v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) - out = F.scaled_dot_product_attention( - query=q, - key=k, - value=v, - attn_mask=~key_padding_mask[:, None, None], - dropout_p=self.attention_dropout_rate, - ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) - out = out * self.attention_temperature - out = self.out_proj(out) - return out - - -class MultiHeadedSelfAttention(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - - self.config = config - - self.ln = LayerNorm(dim=config.encoder_dim) - self.self_attention = MHSAwithQS(config) - - def forward(self, outputs, paddings, dropout_rate): - outputs = self.ln(outputs) - outputs = self.self_attention( - outputs, - key_padding_mask=paddings == 1, - ) - outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) - return outputs - - -class BatchNorm(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - running_mean = torch.zeros(config.encoder_dim) - running_var = torch.ones(config.encoder_dim) - self.register_buffer('running_mean', running_mean) - self.register_buffer('running_var', running_var) - self.scale = nn.Parameter(torch.zeros(config.encoder_dim)) - self.bias = nn.Parameter(torch.zeros(config.encoder_dim)) - - self.register_buffer('dim', torch.FloatTensor([config.encoder_dim])) - self.momentum = config.batch_norm_momentum - self.epsilon = config.batch_norm_epsilon - - def forward(self, inputs, input_paddings): - #inputs: NHD - #padding: NH - """ - Alternatively: - inputs[input_paddings==0] = F.batch_norm( - input = inputs[input_paddings==0], - running_mean = self.running_mean, - running_var = self.running_var, - weight = 1+self.scale, - bias = self.bias, - training = self.training, - momentum=1-self.momentum, - eps=self.epsilon - ) - inputs.masked_fill(input_paddings[...,None] != 0, 0) - return inputs - """ - mask = 1 - input_paddings[:, :, None] - if self.training: - count = mask.sum() - masked_inp = inputs.masked_fill(mask == 0, 0) - mean = (masked_inp).sum(dim=(0, 1)) / count - var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count - - self.running_mean = (1 - self.momentum) * self.running_mean + ( - self.momentum) * mean.detach() - self.running_var = (1 - self.momentum) * self.running_var + ( - self.momentum) * var.detach() - - else: - mean = self.running_mean - var = self.running_var - v = (1 + self.scale) * torch.rsqrt(var + self.epsilon) - bn = (inputs - mean) * v + self.bias - output = bn.masked_fill(mask == 0, 0) - return output - - -class ConvolutionBlock(nn.Module): - - def __init__(self, config): - super().__init__() - - self.config = config - self.ln = LayerNorm(dim=config.encoder_dim) - self.lin1 = nn.Linear( - in_features=config.encoder_dim, out_features=config.encoder_dim) - self.lin2 = nn.Linear( - in_features=config.encoder_dim, out_features=config.encoder_dim) - - self.conv1 = nn.Conv1d( - in_channels=config.encoder_dim, - out_channels=config.encoder_dim, - kernel_size=(config.convolution_kernel_size,), - stride=(1,), - padding='same', - bias=False, - groups=config.encoder_dim) - self.bn = BatchNorm(config) - self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - - def forward(self, inputs, input_paddings, dropout_rate): - inputs = self.ln(inputs) - - inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) - inputs = inputs * (1 - input_paddings[:, :, None]) - - inputs = inputs.permute(0, 2, 1) - inputs = self.conv1(inputs) - inputs = inputs.permute(0, 2, 1) - - inputs = self.bn(inputs, input_paddings) - if self.config.activation_function_name == 'swish': - activation_fn = F.silu - elif self.config.activation_function_name == 'gelu': - activation_fn = F.gelu - else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') - inputs = activation_fn(inputs) - inputs = self.lin3(inputs) - - inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) - return inputs - - -class ConformerBlock(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - - self.ff1 = FeedForwardModule(config) - self.mhsa = MultiHeadedSelfAttention(config) - self.conv = ConvolutionBlock(config) - self.ff2 = FeedForwardModule(config) - self.ln = None - if config.use_post_layer_norm: - self.ln = LayerNorm(dim=config.encoder_dim) - - def forward(self, inputs, input_paddings, dropout_rate): - padding_mask = 1 - input_paddings[:, :, None] - inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) - inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) - inputs = inputs + self.conv(inputs, input_paddings, dropout_rate) - inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate) - if self.ln: - inputs = self.ln(inputs) - return inputs - - -class ConformerEncoderDecoder(nn.Module): - - def __init__(self, config: ConformerConfig): - super().__init__() - self.config = config - preprocessing_config = preprocessor.PreprocessorConfig() - self.preprocessor = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) - self.specaug = SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames - ) - self.subsample = Subsample( - encoder_dim=config.encoder_dim, - num_bins=preprocessing_config.num_bins) - self.conformers = nn.ModuleList( - [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) - - self.ln = LayerNorm(config.encoder_dim) - self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - - def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): - outputs = inputs - output_paddings = input_paddings - outputs, output_paddings = self.preprocessor(outputs, output_paddings) - if self.training and self.config.use_specaug: - outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) - for conformer in self.conformers: - outputs = conformer(outputs, output_paddings, dropout_rate) - outputs = self.ln(outputs) - outputs = self.lin(outputs) - return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index 84d317326..a8480a343 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -17,6 +17,7 @@ SpecAug USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +DROPOUT_RATE = 0.1 @dataclass @@ -38,10 +39,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. - feed_forward_dropout_rate: Optional[float] = 0.1 enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True @@ -87,13 +84,8 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate - self.dropout = nn.Dropout(p=input_dropout_rate) + def forward(self, inputs, input_paddings, dropout_rate): - def forward(self, inputs, input_paddings): output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -106,7 +98,7 @@ def forward(self, inputs, input_paddings): subsampled_dims * channels) outputs = self.lin(outputs) - outputs = self.dropout(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training) return outputs, output_paddings @@ -205,13 +197,9 @@ def __init__(self, config: DeepspeechConfig): batch_norm_momentum=config.batch_norm_momentum, batch_norm_epsilon=config.batch_norm_epsilon) self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) - if config.feed_forward_dropout_rate is None: - feed_forward_dropout_rate = 0.1 - else: - feed_forward_dropout_rate = config.feed_forward_dropout_rate - self.dropout = nn.Dropout(p=feed_forward_dropout_rate) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): + padding_mask = (1 - input_paddings)[:, :, None] if self.config.layernorm_everywhere: inputs = self.normalization_layer(inputs) @@ -226,7 +214,7 @@ def forward(self, inputs, input_paddings): inputs = F.relu(inputs) inputs = inputs * padding_mask - inputs = self.dropout(inputs) + inputs = F.dropout(inputs, dropout_rate, training=self.training) return inputs @@ -363,14 +351,14 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings) + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) for idx in range(self.config.num_lstm_layers): if self.config.enable_residual_connections: outputs = outputs + self.lstms[idx](outputs, output_paddings) @@ -379,9 +367,9 @@ def forward(self, inputs, input_paddings): for idx in range(self.config.num_ffn_layers): if self.config.enable_residual_connections: - outputs = outputs + self.ffns[idx](outputs, output_paddings) + outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate) else: - outputs = self.ffns[idx](outputs, output_paddings) + outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) if self.config.enable_decoder_layer_norm: outputs = self.ln(outputs) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py deleted file mode 100644 index a8480a343..000000000 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py +++ /dev/null @@ -1,379 +0,0 @@ -"""This is a pytorch implementation mirroring: -https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. -""" - -from dataclasses import dataclass -import os -from typing import Optional, Tuple - -import torch -from torch import nn -import torch.distributed.nn as dist_nn -import torch.nn.functional as F - -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ - SpecAug - -USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ -DROPOUT_RATE = 0.1 - - -@dataclass -class DeepspeechConfig: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - vocab_size: int = 1024 - encoder_dim: int = 512 - num_lstm_layers: int = 6 - num_ffn_layers: int = 3 - conv_subsampling_factor: int = 2 - conv_subsampling_layers: int = 2 - use_specaug: bool = True - freq_mask_count: int = 2 - freq_mask_max_bins: int = 27 - time_mask_count: int = 10 - time_mask_max_frames: int = 40 - time_mask_max_ratio: float = 0.05 - time_masks_per_frame: float = 0.0 - use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 1 - 0.999 - batch_norm_epsilon: float = 0.001 - enable_residual_connections: bool = True - enable_decoder_layer_norm: bool = True - bidirectional: bool = True - use_tanh: bool = False - layernorm_everywhere: bool = False - - -class LayerNorm(nn.Module): - - def __init__(self, dim, epsilon=1e-6): - super().__init__() - self.dim = dim - - self.scale = nn.Parameter(torch.zeros(self.dim)) - self.bias = nn.Parameter(torch.zeros(self.dim)) - self.epsilon = epsilon - - def forward(self, x): - mean = x.mean(dim=-1, keepdims=True) - var = x.var(dim=-1, unbiased=False, keepdims=True) - - normed_x = (x - mean) * torch.rsqrt(var + self.epsilon) - normed_x *= (1 + self.scale) - normed_x += self.bias - - return normed_x - - -class Subsample(nn.Module): - - def __init__(self, config: DeepspeechConfig): - super().__init__() - encoder_dim = config.encoder_dim - - self.encoder_dim = encoder_dim - - self.conv1 = Conv2dSubsampling( - input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh) - self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, - output_channels=encoder_dim, - use_tanh=config.use_tanh) - - self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) - - def forward(self, inputs, input_paddings, dropout_rate): - - output_paddings = input_paddings - outputs = inputs[:, None, :, :] - - outputs, output_paddings = self.conv1(outputs, output_paddings) - outputs, output_paddings = self.conv2(outputs, output_paddings) - - batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape - outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, - subsampled_lengths, - subsampled_dims * channels) - - outputs = self.lin(outputs) - outputs = F.dropout(outputs, dropout_rate, training=self.training) - - return outputs, output_paddings - - -class Conv2dSubsampling(nn.Module): - - def __init__(self, - input_channels: int, - output_channels: int, - filter_stride: Tuple[int] = (2, 2), - padding: str = 'SAME', - batch_norm_momentum: float = 0.999, - batch_norm_epsilon: float = 0.001, - use_tanh: bool = False): - super().__init__() - - self.input_channels = input_channels - self.output_channels = output_channels - self.filter_stride = filter_stride - self.padding = padding - - self.filter_shape = (output_channels, input_channels, 3, 3) - - self.kernel = nn.Parameter( - nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) - self.bias = nn.Parameter(torch.zeros(output_channels)) - - self.use_tanh = use_tanh - - def get_same_padding(self, input_shape): - in_height, in_width = input_shape[2:] - stride_height, stride_width = self.filter_stride - filter_height, filter_width = 3, 3 - if in_height % stride_height == 0: - pad_along_height = max(filter_height - stride_height, 0) - else: - pad_along_height = max(filter_height - (in_height % stride_height), 0) - if in_width % stride_width == 0: - pad_along_width = max(filter_width - stride_width, 0) - else: - pad_along_width = max(filter_width - (in_width % stride_width), 0) - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - return (pad_left, pad_right, pad_top, pad_bottom) - - def forward(self, inputs, paddings): - groups = inputs.shape[1] // self.input_channels - - if self.padding == 'SAME': - in_ = F.pad(inputs, self.get_same_padding(inputs.shape)) - else: - in_ = inputs - outputs = F.conv2d( - input=in_, - weight=self.kernel, - bias=self.bias, - stride=self.filter_stride, - dilation=(1, 1), - groups=groups) - - if self.use_tanh: - outputs = F.tanh(outputs) - else: - outputs = F.relu(outputs) - - input_length = paddings.shape[1] - stride = self.filter_stride[0] - pad_len = (input_length + stride - 1) // stride * stride - input_length - out_padding = F.conv1d( - input=torch.cat([ - paddings[:, None, :], - torch.zeros( - size=(paddings.shape[0], 1, pad_len), device=paddings.device) - ], - dim=2), - weight=torch.ones([1, 1, 1], device=paddings.device), - stride=self.filter_stride[:1]) - out_padding = out_padding.squeeze(dim=1) - outputs = outputs * (1 - out_padding[:, None, :, None]) - return outputs, out_padding - - -class FeedForwardModule(nn.Module): - - def __init__(self, config: DeepspeechConfig): - super().__init__() - self.config = config - - if config.layernorm_everywhere: - self.normalization_layer = LayerNorm(config.encoder_dim) - else: - self.bn_normalization_layer = BatchNorm( - dim=config.encoder_dim, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon) - self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) - - def forward(self, inputs, input_paddings, dropout_rate): - - padding_mask = (1 - input_paddings)[:, :, None] - if self.config.layernorm_everywhere: - inputs = self.normalization_layer(inputs) - else: # batchnorm - inputs = self.bn_normalization_layer(inputs, input_paddings) - - inputs = self.lin(inputs) - - if self.config.use_tanh: - inputs = F.tanh(inputs) - else: - inputs = F.relu(inputs) - - inputs = inputs * padding_mask - inputs = F.dropout(inputs, dropout_rate, training=self.training) - - return inputs - - -class BatchNorm(nn.Module): - - def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon): - super().__init__() - running_mean = torch.zeros(dim) - running_var = torch.ones(dim) - self.register_buffer('running_mean', running_mean) - self.register_buffer('running_var', running_var) - self.weight = nn.Parameter(torch.zeros(dim)) - self.bias = nn.Parameter(torch.zeros(dim)) - - self.momentum = batch_norm_momentum - self.epsilon = batch_norm_epsilon - self.dim = dim - - def forward(self, inputs, input_paddings): - #inputs: NHD - #padding: NH - mask = 1 - input_paddings[:, :, None] - if self.training: - count = mask.sum() - masked_inp = inputs.masked_fill(mask == 0, 0) - sum_ = (masked_inp).sum(dim=(0, 1)) - if USE_PYTORCH_DDP: - sum_ = dist_nn.all_reduce(sum_) - count = dist_nn.all_reduce(count) - mean = sum_ / count - - sum_ = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) - if USE_PYTORCH_DDP: - sum_ = dist_nn.all_reduce(sum_) - var = sum_ / count - - self.running_mean = (1 - self.momentum) * self.running_mean + ( - self.momentum) * mean.detach() - self.running_var = (1 - self.momentum) * self.running_var + ( - self.momentum) * var.detach() - else: - mean = self.running_mean - var = self.running_var - v = (1 + self.weight) * torch.rsqrt(var + self.epsilon) - bn = (inputs - mean) * v + self.bias - output = bn.masked_fill(mask == 0, 0) - return output - - -class BatchRNN(nn.Module): - - def __init__(self, config: DeepspeechConfig): - super().__init__() - self.config = config - hidden_size = config.encoder_dim - input_size = config.encoder_dim - bidirectional = config.bidirectional - self.bidirectional = bidirectional - - if config.layernorm_everywhere: - self.normalization_layer = LayerNorm(config.encoder_dim) - else: - self.bn_normalization_layer = BatchNorm(config.encoder_dim, - config.batch_norm_momentum, - config.batch_norm_epsilon) - - if bidirectional: - self.lstm = nn.LSTM( - input_size=input_size, - hidden_size=hidden_size // 2, - bidirectional=True, - batch_first=True) - else: - self.lstm = nn.LSTM( - input_size=input_size, hidden_size=hidden_size, batch_first=True) - - def forward(self, inputs, input_paddings): - if self.config.layernorm_everywhere: - inputs = self.normalization_layer(inputs) - else: - inputs = self.bn_normalization_layer(inputs, input_paddings) - lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() - packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( - inputs, lengths, batch_first=True, enforce_sorted=False) - packed_outputs, _ = self.lstm(packed_inputs) - outputs, _ = torch.nn.utils.rnn.pad_packed_sequence( - packed_outputs, batch_first=True) - if outputs.shape[1] < inputs.shape[1]: - outputs = torch.cat([ - outputs, - torch.zeros( - size=(outputs.shape[0], - inputs.shape[1] - outputs.shape[1], - outputs.shape[2]), - device=outputs.device) - ], - dim=1) - return outputs - - -class DeepspeechEncoderDecoder(nn.Module): - - def __init__(self, config: DeepspeechConfig): - super().__init__() - self.config = config - - self.specaug = SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames - ) - preprocessing_config = preprocessor.PreprocessorConfig() - self.preprocessor = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) - - self.subsample = Subsample(config=config) - - self.lstms = nn.ModuleList( - [BatchRNN(config) for _ in range(config.num_lstm_layers)]) - self.ffns = nn.ModuleList( - [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]) - - if config.enable_decoder_layer_norm: - self.ln = LayerNorm(config.encoder_dim) - else: - self.ln = nn.Identity() - - self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - - def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): - outputs = inputs - output_paddings = input_paddings - - outputs, output_paddings = self.preprocessor(outputs, output_paddings) - if self.training and self.config.use_specaug: - outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) - for idx in range(self.config.num_lstm_layers): - if self.config.enable_residual_connections: - outputs = outputs + self.lstms[idx](outputs, output_paddings) - else: - outputs = self.lstms[idx](outputs, output_paddings) - - for idx in range(self.config.num_ffn_layers): - if self.config.enable_residual_connections: - outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate) - else: - outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) - - if self.config.enable_decoder_layer_norm: - outputs = self.ln(outputs) - - outputs = self.lin(outputs) - - return outputs, output_paddings diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py index fe9b29bc1..be5882333 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py @@ -9,17 +9,20 @@ from torch import nn from algoperf import init_utils +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout +DROPOUT_RATE = 0.1 -def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): + +def _make_mlp(in_dim, hidden_dims, activation_fn): """Creates a MLP with specified dimensions.""" - layers = nn.Sequential() + layers = SequentialWithDropout() for i, dim in enumerate(hidden_dims): layers.add_module(f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim)) layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) layers.add_module(f'activation_fn_{i}', activation_fn()) - layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) + layers.add_module(f'dropout_{i}', CustomDropout()) in_dim = dim return layers @@ -33,7 +36,6 @@ class GNN(nn.Module): def __init__(self, num_outputs: int = 128, - dropout_rate: Optional[float] = 0.1, activation_fn_name: str = 'relu', latent_dim: int = 256, hidden_dims: Tuple[int] = (256,), @@ -43,8 +45,6 @@ def __init__(self, self.hidden_dims = hidden_dims self.num_message_passing_steps = num_message_passing_steps self.num_outputs = num_outputs - if dropout_rate is None: - dropout_rate = 0.1 # in_features are specifically chosen for the ogbg workload. self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) @@ -77,17 +77,14 @@ def __init__(self, GraphNetwork( update_edge_fn=_make_mlp(in_dim_edge_fn, self.hidden_dims, - dropout_rate, activation_fn), update_node_fn=_make_mlp(in_dim_node_fn, self.hidden_dims, - dropout_rate, activation_fn), update_global_fn=_make_mlp(last_in_dim, self.hidden_dims, - dropout_rate, activation_fn))) - self.graph_network = nn.Sequential(*graph_network_layers) + self.graph_network = SequentialWithDropout(*graph_network_layers) self.decoder = nn.Linear( in_features=self.hidden_dims[-1], out_features=self.num_outputs) @@ -96,14 +93,18 @@ def __init__(self, if isinstance(m, nn.Linear): init_utils.pytorch_default_init(m) - def forward(self, graph: GraphsTuple) -> torch.Tensor: + def forward( + self, + graph: GraphsTuple, + dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: + graph = graph._replace( globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], device=graph.n_node.device)) graph = graph._replace(nodes=self.node_embedder(graph.nodes)) graph = graph._replace(edges=self.edge_embedder(graph.edges)) - graph = self.graph_network(graph) + graph = self.graph_network(graph, dropout_rate) # Map globals to represent the final result graph = graph._replace(globals=self.decoder(graph.globals)) @@ -145,8 +146,9 @@ def __init__(self, self.update_edge_fn = update_edge_fn self.update_node_fn = update_node_fn self.update_global_fn = update_global_fn + self._supports_custom_dropout = True # supports SequentialWithDropout - def forward(self, graph: GraphsTuple) -> GraphsTuple: + def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: """Applies a configured GraphNetwork to a graph. This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 There is one difference. For the nodes update the class aggregates over the @@ -159,6 +161,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: GraphNets, for more information please see the paper. Args: graph: a `GraphsTuple` containing the graph. + dropout_rate: dropout probability value. Returns: Updated `GraphsTuple`. """ @@ -179,7 +182,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: edge_fn_inputs = torch.cat( [edges, sent_attributes, received_attributes, global_edge_attributes], dim=-1) - edges = self.update_edge_fn(edge_fn_inputs) + edges = self.update_edge_fn(edge_fn_inputs, dropout_rate) if self.update_node_fn: sent_attributes = tree.tree_map( @@ -194,7 +197,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) - nodes = self.update_node_fn(node_fn_inputs) + nodes = self.update_node_fn(node_fn_inputs, dropout_rate) if self.update_global_fn: n_graph = n_node.shape[0] @@ -213,7 +216,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # These pooled nodes are the inputs to the global update fn. global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_], dim=-1) - globals_ = self.update_global_fn(global_fn_inputs) + globals_ = self.update_global_fn(global_fn_inputs, dropout_rate) return GraphsTuple( nodes=nodes, diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py deleted file mode 100644 index be5882333..000000000 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py +++ /dev/null @@ -1,315 +0,0 @@ -# Ported to PyTorch from -# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. -from functools import partial -from typing import Callable, Optional, Tuple - -import jax.tree_util as tree -from jraph import GraphsTuple -import torch -from torch import nn - -from algoperf import init_utils -from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout - -DROPOUT_RATE = 0.1 - - -def _make_mlp(in_dim, hidden_dims, activation_fn): - """Creates a MLP with specified dimensions.""" - layers = SequentialWithDropout() - for i, dim in enumerate(hidden_dims): - layers.add_module(f'dense_{i}', - nn.Linear(in_features=in_dim, out_features=dim)) - layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) - layers.add_module(f'activation_fn_{i}', activation_fn()) - layers.add_module(f'dropout_{i}', CustomDropout()) - in_dim = dim - return layers - - -class GNN(nn.Module): - """Defines a graph network. - - The model assumes the input data is a jraph.GraphsTuple without global - variables. The final prediction will be encoded in the globals. - """ - - def __init__(self, - num_outputs: int = 128, - activation_fn_name: str = 'relu', - latent_dim: int = 256, - hidden_dims: Tuple[int] = (256,), - num_message_passing_steps: int = 5) -> None: - super().__init__() - self.latent_dim = latent_dim - self.hidden_dims = hidden_dims - self.num_message_passing_steps = num_message_passing_steps - self.num_outputs = num_outputs - # in_features are specifically chosen for the ogbg workload. - self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) - self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) - - if activation_fn_name == 'relu': - activation_fn = nn.ReLU - elif activation_fn_name == 'gelu': - activation_fn = partial(nn.GELU, approximate='tanh') - elif activation_fn_name == 'silu': - activation_fn = nn.SiLU - else: - raise ValueError( - f'Invalid activation function name: {self.activation_fn_name}') - - graph_network_layers = [] - for st in range(self.num_message_passing_steps): - # Constants in in_dims are based on forward call of GraphNetwork: - # specifically update_edge_fn update_node_fn and update_global_fn. - if st == 0: - in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs - in_dim_node_fn = self.latent_dim + self.hidden_dims[ - -1] * 2 + self.num_outputs - last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs - else: - in_dim_edge_fn = self.hidden_dims[-1] * 4 - in_dim_node_fn = self.hidden_dims[-1] * 4 - last_in_dim = self.hidden_dims[-1] * 3 - - graph_network_layers.append( - GraphNetwork( - update_edge_fn=_make_mlp(in_dim_edge_fn, - self.hidden_dims, - activation_fn), - update_node_fn=_make_mlp(in_dim_node_fn, - self.hidden_dims, - activation_fn), - update_global_fn=_make_mlp(last_in_dim, - self.hidden_dims, - activation_fn))) - self.graph_network = SequentialWithDropout(*graph_network_layers) - - self.decoder = nn.Linear( - in_features=self.hidden_dims[-1], out_features=self.num_outputs) - - for m in self.modules(): - if isinstance(m, nn.Linear): - init_utils.pytorch_default_init(m) - - def forward( - self, - graph: GraphsTuple, - dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: - - graph = graph._replace( - globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], - device=graph.n_node.device)) - graph = graph._replace(nodes=self.node_embedder(graph.nodes)) - graph = graph._replace(edges=self.edge_embedder(graph.edges)) - - graph = self.graph_network(graph, dropout_rate) - - # Map globals to represent the final result - graph = graph._replace(globals=self.decoder(graph.globals)) - - return graph.globals - - -class GraphNetwork(nn.Module): - """Returns a method that applies a configured GraphNetwork. - This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 - There is one difference. For the nodes update the class aggregates over the - sender edges and receiver edges separately. This is a bit more general - than the algorithm described in the paper. The original behaviour can be - recovered by using only the receiver edge aggregations for the update. - In addition this implementation supports softmax attention over incoming - edge features. - Example usage:: - gn = GraphNetwork(update_edge_function, - update_node_function, **kwargs) - # Conduct multiple rounds of message passing with the same parameters: - for _ in range(num_message_passing_steps): - graph = gn(graph) - Args: - update_edge_fn: function used to update the edges or None to deactivate edge - updates. - update_node_fn: function used to update the nodes or None to deactivate node - updates. - update_global_fn: function used to update the globals or None to deactivate - globals updates. - Returns: - A method that applies the configured GraphNetwork. - """ - - def __init__(self, - update_edge_fn: Optional[Callable] = None, - update_node_fn: Optional[Callable] = None, - update_global_fn: Optional[Callable] = None) -> None: - super().__init__() - self.update_edge_fn = update_edge_fn - self.update_node_fn = update_node_fn - self.update_global_fn = update_global_fn - self._supports_custom_dropout = True # supports SequentialWithDropout - - def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: - """Applies a configured GraphNetwork to a graph. - This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 - There is one difference. For the nodes update the class aggregates over the - sender edges and receiver edges separately. This is a bit more general - the algorithm described in the paper. The original behaviour can be - recovered by using only the receiver edge aggregations for the update. - In addition this implementation supports softmax attention over incoming - edge features. - Many popular Graph Neural Networks can be implemented as special cases of - GraphNets, for more information please see the paper. - Args: - graph: a `GraphsTuple` containing the graph. - dropout_rate: dropout probability value. - Returns: - Updated `GraphsTuple`. - """ - nodes, edges, receivers, senders, globals_, n_node, n_edge = graph - sum_n_node = tree.tree_leaves(nodes)[0].shape[0] - if not tree.tree_all( - tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)): - raise ValueError( - 'All node arrays in nest must contain the same number of nodes.') - - sent_attributes = tree.tree_map(lambda n: n[senders], nodes) - received_attributes = tree.tree_map(lambda n: n[receivers], nodes) - # Here we scatter the global features to the corresponding edges, - # giving us tensors of shape [num_edges, global_feat]. - global_edge_attributes = tree.tree_map( - lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_) - if self.update_edge_fn: - edge_fn_inputs = torch.cat( - [edges, sent_attributes, received_attributes, global_edge_attributes], - dim=-1) - edges = self.update_edge_fn(edge_fn_inputs, dropout_rate) - - if self.update_node_fn: - sent_attributes = tree.tree_map( - lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges) - received_attributes = tree.tree_map( - lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), - edges) - # Here we scatter the global features to the corresponding nodes, - # giving us tensors of shape [num_nodes, global_feat]. - global_attributes = tree.tree_map( - lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) - node_fn_inputs = torch.cat( - [nodes, sent_attributes, received_attributes, global_attributes], - dim=-1) - nodes = self.update_node_fn(node_fn_inputs, dropout_rate) - - if self.update_global_fn: - n_graph = n_node.shape[0] - graph_idx = torch.arange(n_graph, device=graph.n_node.device) - # To aggregate nodes and edges from each graph to global features, - # we first construct tensors that map the node to the corresponding graph. - # For example, if you have `n_node=[1,2]`, we construct the tensor - # [0, 1, 1]. We then do the same for edges. - node_gr_idx = torch.repeat_interleave(graph_idx, n_node, dim=0) - edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0) - # We use the aggregation function to pool the nodes/edges per graph. - node_attributes = tree.tree_map( - lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes) - edge_attributes = tree.tree_map( - lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges) - # These pooled nodes are the inputs to the global update fn. - global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_], - dim=-1) - globals_ = self.update_global_fn(global_fn_inputs, dropout_rate) - - return GraphsTuple( - nodes=nodes, - edges=edges, - receivers=receivers, - senders=senders, - globals=globals_, - n_node=n_node, - n_edge=n_edge) - - -# Forked from -# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py. -def scatter_sum(src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None) -> torch.Tensor: - r""" - | - .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ - master/docs/source/_figures/add.svg?sanitize=true - :align: center - :width: 400px - | - Reduces all values from the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`. - For each value in :attr:`src`, its output index is specified by its index - in :attr:`src` for dimensions outside of :attr:`dim` and by the - corresponding value in :attr:`index` for dimension :attr:`dim`. - The applied reduction is here defined as a sum. - Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional - tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` - and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional - tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. - Moreover, the values of :attr:`index` must be between :math:`0` and - :math:`y - 1`, although no specific ordering of indices is required. - The :attr:`index` tensor supports broadcasting in case its dimensions do - not match with :attr:`src`. - For one-dimensional tensors, the operation computes - .. math:: - \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. - .. note:: - This operation is implemented via atomic operations on the GPU and is - therefore **non-deterministic** since the order of parallel operations - to the same value is undetermined. - For floating-point variables, this results in a source of variance in - the result. - :param src: The source tensor. - :param index: The indices of elements to scatter. - :param dim: The axis along which to index. (default: :obj:`-1`) - :param out: The destination tensor. - :param dim_size: If :attr:`out` is not given, automatically create output - with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor - according to :obj:`index.max() + 1` is returned. - :rtype: :class:`Tensor` - .. code-block:: python - src = torch.randn(10, 6, 64) - index = torch.tensor([0, 1, 0, 1, 2, 1]) - # Broadcasting in the first and last dim. - out = scatter_sum(src, index, dim=1) - print(out.size()) - .. code-block:: - torch.Size([10, 3, 64]) - """ - index = broadcast(index, src, dim) - if out is None: - size = list(src.size()) - if dim_size is not None: - size[dim] = dim_size - elif index.numel() == 0: - size[dim] = 0 - else: - size[dim] = int(index.max()) + 1 - out = torch.zeros(size, dtype=src.dtype, device=src.device) - return out.scatter_add_(dim, index, src) - else: - return out.scatter_add_(dim, index, src) - - -# Forked from -# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/utils.py. -def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): - if dim < 0: - dim = other.dim() + dim - if src.dim() == 1: - for _ in range(0, dim): - src = src.unsqueeze(0) - for _ in range(src.dim(), other.dim()): - src = src.unsqueeze(-1) - src = src.expand(other.size()) - return src diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index a1c7ce15e..a43df30d4 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -9,6 +9,8 @@ from torch.nn.init import normal_ from torch.nn.init import xavier_uniform_ +DROPOUT_RATE = 0.1 + def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: """Make a causal mask for self-attention. @@ -104,26 +106,18 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: Optional[float] = 0.1, - attention_dropout_rate: Optional[float] = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, attention_temp: float = 1.0, pre_ln: bool = True): super().__init__() - if dropout_rate is None: - dropout_rate = 0.1 - if attention_dropout_rate is None: - attention_dropout_rate = 0.1 - self.pos_encoder = PositionalEncoding(d_model, dropout_rate) + self.pos_encoder = PositionalEncoding(d_model) self.shared_embedding = nn.Embedding(ntoken, d_model) self.encoder = Encoder(d_model, nhead, d_hid, nlayers, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -133,8 +127,6 @@ def __init__(self, nhead, d_hid, nlayers, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -163,7 +155,8 @@ def forward(self, targets_positions: Optional[Tensor] = None, inputs_segmentation: Optional[Tensor] = None, targets_segmentation: Optional[Tensor] = None, - decode: bool = False) -> Tensor: + decode: bool = False, + dropout_rate: float = DROPOUT_RATE) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] @@ -173,16 +166,19 @@ def forward(self, inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] decode: bool + dropout_rate: float Returns: output Tensor of shape [batch_size, seq_len, ntoken] """ if src.size(0) != tgt.size(0): raise RuntimeError('The batch size of src and tgt must be equal.') + memory = self.encoder( src, inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate) output = self.decoder( tgt, memory, @@ -190,7 +186,8 @@ def forward(self, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, - decode=decode) + decode=decode, + dropout_rate=dropout_rate) return output @@ -229,12 +226,15 @@ def __init__(self, self.enable_nested_tensor = enable_nested_tensor self.mask_check = mask_check - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, src: Tensor, + mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0) -> Tensor: """Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). + dropout_rate: the dropout probability (optional). Shape: see the docs in Transformer class. @@ -243,7 +243,7 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: convert_to_nested = False for mod in self.layers: - output = mod(output, src_mask=mask) + output = mod(output, src_mask=mask, dropout_rate=dropout_rate) if convert_to_nested: output = output.to_padded_tensor(0.) @@ -261,8 +261,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -276,8 +274,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate=attention_dropout_rate, activation=activation, glu=glu, layer_norm_eps=layer_norm_eps, @@ -290,12 +286,13 @@ def __init__(self, def forward(self, src: Tensor, inputs_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None) -> Tensor: + inputs_segmentation: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0) -> Tensor: src = src.to(torch.int) src_mask = make_src_mask(src, inputs_segmentation, self.nhead) src = self.shared_embedding(src) - src = self.pos_encoder(src, inputs_positions) - memory = self.encoder(src, mask=src_mask) + src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate) + memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate) return memory @@ -306,8 +303,6 @@ def __init__(self, nhead: int = 16, d_hid: int = 1024, nlayers: int = 6, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -320,8 +315,6 @@ def __init__(self, self.decoder = TransformerDecoder(d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -339,7 +332,8 @@ def forward( targets_segmentation: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0) -> Any: tgt = tgt.to(torch.int) tgt_mask, memory_mask = make_tgt_and_memory_mask( tgt, src, inputs_segmentation, targets_segmentation, @@ -347,7 +341,7 @@ def forward( if not decode: tgt = shift_right(tgt) tgt = self.shared_embedding(tgt) - tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache) + tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate) if decode: tgt, cache = tgt output = self.decoder( @@ -357,7 +351,8 @@ def forward( memory_mask=memory_mask, decode=decode, max_len=max_len, - cache=cache) + cache=cache, + dropout_rate=dropout_rate) if decode: output, cache = output normalize = math.sqrt(output.shape[-1]) @@ -371,10 +366,8 @@ class PositionalEncoding(nn.Module): def __init__(self, d_model: int, - dropout_rate: float = 0.1, max_len: int = 256): super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) position = torch.arange(max_len).unsqueeze(1) scale_factor = -math.log(10000.0) / (d_model // 2 - 1) @@ -389,7 +382,8 @@ def forward( x: Tensor, inputs_positions: Optional[Tensor] = None, decode: bool = False, - cache: Optional[Dict[str, Dict[str, Tensor]]] = None + cache: Optional[Dict[str, Dict[str, Tensor]]] = None, + dropout_rate: Optional[float] = 0.0 ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: """ Args: @@ -397,6 +391,7 @@ def forward( inputs_positions: Tensor (shape [batch_size, seq_len]) or None decode: bool cache: Dict[str, Dict[str, Tensor]] or None + dropout_rate: Optional[float] Returns: Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]] """ @@ -412,14 +407,14 @@ def forward( } pe = self.pe[0, cache[name]['cache_index'], :] cache[name]['cache_index'] += 1 - return self.dropout(x + pe), cache + return F.dropout(x + pe, dropout_rate, self.training), cache if inputs_positions is None: # normal unpacked case: pe = self.pe[:, :x.size(1), :] else: # for packed data we need to use known position indices: pe = self.pe[0, inputs_positions, :] - return self.dropout(x + pe) + return F.dropout(x + pe, dropout_rate, self.training) # TransformerEncoderLayer and TransformerDecoderLayer are taken from: @@ -438,7 +433,6 @@ class TransformerEncoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -457,8 +451,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -472,7 +464,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=attention_dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -482,50 +473,55 @@ def __init__(self, self.glu = glu if self.glu: self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) self.activation = activation - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + def forward(self, + src: Tensor, + src_mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). - + dropout_rate: the dropout probability value (optional). Shape: see the docs in Transformer class. """ x = src if self.pre_ln: - x = x + self._sa_block(self.norm1(x), src_mask) - x = x + self._ff_block(self.norm2(x)) + x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate) + x = x + self._ff_block(self.norm2(x), dropout_rate) else: - x = self.norm1(x + self._sa_block(x, src_mask)) - x = self.norm2(x + self._ff_block(x)) + x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate)) + x = self.norm2(x + self._ff_block(x, dropout_rate)) return x # Self-attention block: - def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.self_attn(x, attn_mask=attn_mask) - return self.dropout1(x) + def _sa_block(self, + x: Tensor, + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, + inputs: Tensor, + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout2(x) + x = self.linear2(F.dropout(x, dropout_rate, training=self.training)) + return F.dropout(x, dropout_rate, training=self.training) # Modified to use cache for autoregressive decoding and custom @@ -537,7 +533,6 @@ class TransformerDecoder(nn.Module): nhead: the number of heads in the multiheadattention models (default=16) d_hid: the dimension of the feedforward network model (default=1024) - dropout_rate: the dropout_rate value (default=0.1) layer_norm_eps: the eps value in layer normalization components (default=1e-6). decoder_layer: an instance of the TransformerDecoderLayer() class @@ -555,8 +550,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps, @@ -569,8 +562,6 @@ def __init__(self, d_model, nhead, d_hid, - dropout_rate, - attention_dropout_rate, activation, glu, layer_norm_eps=layer_norm_eps, @@ -587,7 +578,8 @@ def forward(self, memory_mask: Optional[Tensor] = None, decode: bool = False, max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: tgt: the sequence to the decoder (required). @@ -596,6 +588,7 @@ def forward(self, memory_mask: the mask for the memory sequence (optional). decode: whether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. """ @@ -610,7 +603,8 @@ def forward(self, decode=decode, max_len=max_len, cache=cache, - index=idx) + index=idx, + dropout_rate=dropout_rate) if self.norm is not None: output = self.norm(output) @@ -636,7 +630,6 @@ class TransformerDecoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -656,8 +649,6 @@ def __init__(self, d_model: int = 1024, nhead: int = 16, dim_feedforward: int = 1024, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, glu: bool = False, layer_norm_eps: float = 1e-6, @@ -671,7 +662,6 @@ def __init__(self, d_model, nhead, self_attn=True, - dropout_rate=attention_dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -679,7 +669,6 @@ def __init__(self, d_model, nhead, self_attn=False, - dropout_rate=attention_dropout_rate, attention_temp=attention_temp, bias=False, **factory_kwargs) @@ -691,16 +680,12 @@ def __init__(self, self.linear_glu = nn.Linear(dim_feedforward, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) - self.dropout3 = nn.Dropout(dropout_rate) self.activation = activation @@ -713,7 +698,8 @@ def forward( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0) -> Any: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -722,6 +708,7 @@ def forward( # pylint: disable=arguments-renamed memory_mask: the mask for the memory sequence (optional). decode: wether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. """ @@ -735,10 +722,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = x + sa_out - x = x + self._mha_block(self.norm2(x), memory, memory_mask) - x = x + self._ff_block(self.norm3(x)) + x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate) + x = x + self._ff_block(self.norm3(x), dropout_rate) else: sa_out, cache = self._sa_block( x, @@ -746,10 +734,11 @@ def forward( # pylint: disable=arguments-renamed decode=decode, max_len=max_len, cache=cache, - index=index) + index=index, + dropout_rate=dropout_rate) x = self.norm1(x + sa_out) - x = self.norm2(x + self._mha_block(x, memory, memory_mask)) - x = self.norm3(x + self._ff_block(x)) + x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate)) + x = self.norm3(x + self._ff_block(x, dropout_rate)) return x, cache @@ -761,30 +750,38 @@ def _sa_block( # pylint: disable=arguments-renamed decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0) -> Any: x, cache = self.self_attn( x, attn_mask=attn_mask, decode=decode, max_len=max_len, cache=cache, - index=index) - return self.dropout1(x), cache + index=index, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training), cache # Multihead attention block: def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) - return self.dropout2(x) + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0) -> Tensor: + x, _ = self.multihead_attn( + x, + mem, + attn_mask=attn_mask, + dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, self.training) # Feed forward block. - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block(self, inputs: Tensor, + dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout3(x) + x = self.linear2(F.dropout(x, dropout_rate, self.training)) + return F.dropout(x, dropout_rate, self.training) class MultiheadAttention(nn.Module): @@ -802,8 +799,6 @@ class MultiheadAttention(nn.Module): ``embed_dim // num_heads``). self_attn: Whether self attention or encoder-decoder attention is used. Default: ``True``. - dropout_rate: Dropout probability on ``attn_output_weights``. - Default: ``0.0`` (no dropout_rate). bias: If specified, adds bias to input / output projection layers. Default: ``False``. device: The device of the module. @@ -817,7 +812,6 @@ def __init__(self, embed_dim: int, num_heads: int, self_attn: bool = True, - dropout_rate: float = 0., attention_temp: float = 1.0, bias: bool = False, device: Optional[torch.device] = None, @@ -826,7 +820,6 @@ def __init__(self, self.embed_dim = embed_dim self.num_heads = num_heads self.self_attn = self_attn - self.dropout = dropout_rate self.head_dim = embed_dim // num_heads self.attention_temp = attention_temp assert self.head_dim * num_heads == self.embed_dim, \ @@ -861,7 +854,8 @@ def forward(self, decode: bool = False, max_len: Optional[int] = None, cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?! r""" Args: x: Batch of input sequences of shape @@ -887,6 +881,7 @@ def forward(self, max_len: maximum sequence length, necessary for decoding cache. cache: cache dictionary for autoregressive decoding. index: index of the current decoding step, necessary for decoding cache. + dropout_rate: dropout probability on ``attn_output_weights``. Outputs: - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where :math:`L` is the target sequence length, :math:`N` is the batch size, @@ -976,12 +971,12 @@ def forward(self, attn_mask = new_attn_mask # Adjust dropout_rate probability. - dropout_rate = self.dropout if self.training else 0.0 + attn_dropout_rate = dropout_rate if self.training else 0.0 # Calculate attention. q = self.attention_temp * q attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, dropout_rate) + q, k, v, attn_mask, attn_dropout_rate) # Rearrange for output projection. attn_output = attn_output.transpose(1, 2).contiguous().view( bsz, tgt_len, embed_dim) diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py deleted file mode 100644 index a43df30d4..000000000 --- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py +++ /dev/null @@ -1,989 +0,0 @@ -import copy -import math -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F -from torch.nn.init import normal_ -from torch.nn.init import xavier_uniform_ - -DROPOUT_RATE = 0.1 - - -def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: - """Make a causal mask for self-attention. - - Args: - x: input array of shape `[batch..., len]` - device: device to store the idxs - - Returns: - A `[batch..., len, len]` shaped causal attention mask. - """ - idxs = torch.broadcast_to( - torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape) - return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2)) - - -def make_src_mask(src, inputs_segmentation, nhead): - """Utility for creating src mask and adjust it for PyTorch Transformer API.""" - src_mask = torch.mul((src > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) - # Add segmentation block-diagonal attention mask if using segmented data. - if inputs_segmentation is not None: - src_mask = torch.logical_and( - src_mask, - torch.eq( - inputs_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2))) - # Flip values and ensure numerical stability. - src_mask = torch.repeat_interleave( - torch.logical_not(src_mask), repeats=nhead, dim=0) - new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32) - new_src_mask.masked_fill_(src_mask, -1e10) - return new_src_mask - - -def make_tgt_and_memory_mask(tgt, - src, - inputs_segmentation, - targets_segmentation, - decode, - nhead): - """ Utility for creating target and memory mask and adjust them for PyTorch - Transformer API.""" - if not decode: - tgt_mask = torch.logical_and( - torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), - make_causal_mask(tgt, device=tgt.device)) - memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) - else: - tgt_mask = None - memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)) - # Add segmentation block-diagonal attention masks if using segmented data. - if inputs_segmentation is not None: - tgt_mask = torch.logical_and( - tgt_mask, - torch.eq( - targets_segmentation.unsqueeze(-1), - targets_segmentation.unsqueeze(-2))) - memory_mask = torch.logical_and( - memory_mask, - torch.eq( - targets_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2))) - # Flip values and ensure numerical stability. - memory_mask = torch.repeat_interleave( - torch.logical_not(memory_mask), repeats=nhead, dim=0) - new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32) - new_memory_mask.masked_fill_(memory_mask, -1e10) - if tgt_mask is not None: - tgt_mask = torch.repeat_interleave( - torch.logical_not(tgt_mask), repeats=nhead, dim=0) - new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32) - new_tgt_mask.masked_fill_(tgt_mask, -1e10) - tgt_mask = new_tgt_mask - return tgt_mask, new_memory_mask - - -def shift_right(x, axis=1): - """Shift the input to the right by padding on axis 1.""" - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[axis] = (1, 0) - pad_widths = tuple(t for tup in reversed(pad_widths) for t in tup) - padded = F.pad(x, pad_widths, mode='constant') - return padded[:, :-1] - - -class Transformer(nn.Module): - """Transformer architecture based on the model from the WMT Jax workload.""" - - def __init__(self, - ntoken: int = 32000, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): - super().__init__() - self.pos_encoder = PositionalEncoding(d_model) - self.shared_embedding = nn.Embedding(ntoken, d_model) - self.encoder = Encoder(d_model, - nhead, - d_hid, - nlayers, - activation, - glu, - layer_norm_eps, - attention_temp, - pre_ln) - self.decoder = Decoder(d_model, - nhead, - d_hid, - nlayers, - activation, - glu, - layer_norm_eps, - attention_temp, - pre_ln) - # Share positional encoding and embedding between encoder and decoder. - self.encoder.pos_encoder = self.pos_encoder - self.encoder.shared_embedding = self.shared_embedding - self.decoder.pos_encoder = self.pos_encoder - self.decoder.shared_embedding = self.shared_embedding - - self._reset_parameters() - - def _reset_parameters(self): - """Initiate parameters in the transformer model.""" - for module in self.modules(): - if isinstance(module, nn.Linear): - xavier_uniform_(module.weight) - if module.bias is not None: - normal_(module.bias, std=1e-6) - - def forward(self, - src: Tensor, - tgt: Tensor, - inputs_positions: Optional[Tensor] = None, - targets_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - targets_segmentation: Optional[Tensor] = None, - decode: bool = False, - dropout_rate: float = DROPOUT_RATE) -> Tensor: - """ - Args: - src: Tensor, shape [batch_size, seq_len] - tgt: Tensor, shape [batch_size, seq_len] - inputs_positions: Optional[Tensor], shape [batch_size, seq_len] - targets_positions: Optional[Tensor], shape [batch_size, seq_len] - inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] - targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] - decode: bool - dropout_rate: float - - Returns: - output Tensor of shape [batch_size, seq_len, ntoken] - """ - if src.size(0) != tgt.size(0): - raise RuntimeError('The batch size of src and tgt must be equal.') - - memory = self.encoder( - src, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation, - dropout_rate=dropout_rate) - output = self.decoder( - tgt, - memory, - src, # just for calculating the padding mask - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation, - decode=decode, - dropout_rate=dropout_rate) - return output - - -class TransformerEncoder(nn.Module): - r"""TransformerEncoder is a stack of N encoder layers. Users can build the - BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. - - Args: - encoder_layer: an instance of the TransformerEncoderLayer() class. - num_layers: the number of sub-encoder-layers in the encoder. - norm: the layer normalization component (optional). - enable_nested_tensor: if True, input will automatically convert to - nested tensor (and convert back on output). This will improve - the overall performance of TransformerEncoder when padding - rate is high. - - Examples:: - >>> encoder_layer = nn.TransformerEncoderLayer(12, 8) - >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, 6) - >>> src = torch.rand(10, 32, 512) - >>> out = transformer_encoder(src) - """ - __constants__ = ['norm'] - - def __init__(self, - encoder_layer, - num_layers, - norm=None, - enable_nested_tensor=True, - mask_check=True): - super().__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for _ in range(num_layers)]) - self.num_layers = num_layers - self.norm = norm - self.enable_nested_tensor = enable_nested_tensor - self.mask_check = mask_check - - def forward(self, src: Tensor, - mask: Optional[Tensor] = None, - dropout_rate: Optional[float] = 0.0) -> Tensor: - """Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - mask: the mask for the src sequence (optional). - dropout_rate: the dropout probability (optional). - - Shape: - see the docs in Transformer class. - """ - output = src - convert_to_nested = False - - for mod in self.layers: - output = mod(output, src_mask=mask, dropout_rate=dropout_rate) - - if convert_to_nested: - output = output.to_padded_tensor(0.) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class Encoder(nn.Module): - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): - super().__init__() - self.nhead = nhead - self.shared_embedding = None - self.pos_encoder = None - encoder_layer = TransformerEncoderLayer( - d_model, - nhead, - d_hid, - activation=activation, - glu=glu, - layer_norm_eps=layer_norm_eps, - attention_temp=attention_temp, - pre_ln=pre_ln) - encoder_norm = ( - nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) - self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm) - - def forward(self, - src: Tensor, - inputs_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - dropout_rate: Optional[float] = 0.0) -> Tensor: - src = src.to(torch.int) - src_mask = make_src_mask(src, inputs_segmentation, self.nhead) - src = self.shared_embedding(src) - src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate) - memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate) - return memory - - -class Decoder(nn.Module): - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): - super().__init__() - self.nhead = nhead - self.shared_embedding = None - self.pos_encoder = None - self.decoder = TransformerDecoder(d_model, - nhead, - d_hid, - activation, - glu, - layer_norm_eps, - nlayers, - attention_temp, - pre_ln) - - def forward( - self, - tgt: Tensor, - memory: Tensor, - src: Tensor, # just for calculating the padding mask - targets_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - targets_segmentation: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - dropout_rate: Optional[float] = 0.0) -> Any: - tgt = tgt.to(torch.int) - tgt_mask, memory_mask = make_tgt_and_memory_mask( - tgt, src, inputs_segmentation, targets_segmentation, - decode, self.nhead) - if not decode: - tgt = shift_right(tgt) - tgt = self.shared_embedding(tgt) - tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate) - if decode: - tgt, cache = tgt - output = self.decoder( - tgt, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - decode=decode, - max_len=max_len, - cache=cache, - dropout_rate=dropout_rate) - if decode: - output, cache = output - normalize = math.sqrt(output.shape[-1]) - output = torch.matmul(output, self.shared_embedding.weight.T) / normalize - if decode: - return output, cache - return output - - -class PositionalEncoding(nn.Module): - - def __init__(self, - d_model: int, - max_len: int = 256): - super().__init__() - - position = torch.arange(max_len).unsqueeze(1) - scale_factor = -math.log(10000.0) / (d_model // 2 - 1) - div_term = torch.exp(torch.arange(d_model // 2) * scale_factor) - pe = torch.zeros(1, max_len, d_model) - pe[0, :, :d_model // 2] = torch.sin(position * div_term) - pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term) - self.register_buffer('pe', pe) - - def forward( - self, - x: Tensor, - inputs_positions: Optional[Tensor] = None, - decode: bool = False, - cache: Optional[Dict[str, Dict[str, Tensor]]] = None, - dropout_rate: Optional[float] = 0.0 - ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: - """ - Args: - x: Tensor (shape [batch_size, seq_len, embedding_dim]) - inputs_positions: Tensor (shape [batch_size, seq_len]) or None - decode: bool - cache: Dict[str, Dict[str, Tensor]] or None - dropout_rate: Optional[float] - Returns: - Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]] - """ - # We use a cache position index for tracking decoding position. - if decode: - name = self._get_name() - if cache is None: - cache = { - name: { - 'cache_index': - torch.tensor(0, dtype=torch.long, device=self.pe.device), - }, - } - pe = self.pe[0, cache[name]['cache_index'], :] - cache[name]['cache_index'] += 1 - return F.dropout(x + pe, dropout_rate, self.training), cache - if inputs_positions is None: - # normal unpacked case: - pe = self.pe[:, :x.size(1), :] - else: - # for packed data we need to use known position indices: - pe = self.pe[0, inputs_positions, :] - return F.dropout(x + pe, dropout_rate, self.training) - - -# TransformerEncoderLayer and TransformerDecoderLayer are taken from: -# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py -# Main difference is the use of custom MultiheadAttention modules. -class TransformerEncoderLayer(nn.Module): - r"""TransformerEncoderLayer is made up of self-attn and feedforward network. - This standard encoder layer is based on the paper "Attention Is All You Need". - Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, - Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all - you need. In Advances in Neural Information Processing Systems, - pages 6000-6010. Users may modify or implement in a different way during - application. - Args: - d_model: the number of expected features in the input (default=1024). - nhead: the number of heads in the multiheadattention models (default=16). - dim_feedforward: the dimension of the feedforward network model - (default=1024). - activation: the activation function of the intermediate layer, can be a - string ("relu" or "gelu") or a unary callable (default=F.relu). - layer_norm_eps: the eps value in layer normalization components - (default=1e-6). - pre_ln: if ``True``, layer norm is done prior to attention and - feedforward operations, respectivaly. Otherwise it's done after. - Default: ``True``. - Examples:: - >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(32, 10, 512) - >>> out = encoder_layer(src) - """ - __constants__ = ['pre_ln'] - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - dim_feedforward: int = 1024, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True, - device=None, - dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.self_attn = MultiheadAttention( - d_model, - nhead, - self_attn=True, - attention_temp=attention_temp, - bias=False, - **factory_kwargs) - - # Implementation of Feedforward model. - self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.glu = glu - if self.glu: - self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) - - self.pre_ln = pre_ln - self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - - self.activation = activation - - def forward(self, - src: Tensor, - src_mask: Optional[Tensor] = None, - dropout_rate: Optional[float] = 0.0) -> Tensor: - r"""Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - dropout_rate: the dropout probability value (optional). - Shape: - see the docs in Transformer class. - """ - x = src - if self.pre_ln: - x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate) - x = x + self._ff_block(self.norm2(x), dropout_rate) - else: - x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate)) - x = self.norm2(x + self._ff_block(x, dropout_rate)) - - return x - - # Self-attention block: - def _sa_block(self, - x: Tensor, - attn_mask: Optional[Tensor], - dropout_rate: Optional[float] = 0.0) -> Tensor: - x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) - return F.dropout(x, dropout_rate, training=self.training) - - # Feed forward block: - def _ff_block(self, - inputs: Tensor, - dropout_rate: Optional[float] = 0.0) -> Tensor: - x = self.activation(self.linear1(inputs)) - if self.glu: - y = self.linear_glu(inputs) - x = x * y - x = self.linear2(F.dropout(x, dropout_rate, training=self.training)) - return F.dropout(x, dropout_rate, training=self.training) - - -# Modified to use cache for autoregressive decoding and custom -# MultiheadAttention modules. -class TransformerDecoder(nn.Module): - r"""TransformerDecoder is a stack of N decoder layers - Args: - d_model: the number of expected features in the input (default=1024) - nhead: the number of heads in the multiheadattention models (default=16) - d_hid: the dimension of the feedforward network model - (default=1024) - layer_norm_eps: the eps value in layer normalization components - (default=1e-6). - decoder_layer: an instance of the TransformerDecoderLayer() class - num_layers: the number of sub-decoder-layers in the decoder - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(12, 8) - >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, 6) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = transformer_decoder(tgt, memory) - """ - __constants__ = ['norm'] - - def __init__(self, - d_model, - nhead, - d_hid, - activation, - glu, - layer_norm_eps, - num_layers, - attention_temp, - pre_ln): - super().__init__() - self.layers = nn.ModuleList([ - TransformerDecoderLayer( - d_model, - nhead, - d_hid, - activation, - glu, - layer_norm_eps=layer_norm_eps, - attention_temp=attention_temp, - pre_ln=pre_ln) for _ in range(num_layers) - ]) - self.num_layers = num_layers - self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) - - def forward(self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - dropout_rate: Optional[float] = 0.0) -> Any: - r"""Pass the inputs (and mask) through the decoder layer in turn. - Args: - tgt: the sequence to the decoder (required). - memory: the sequence from the last layer of the encoder (required). - tgt_mask: the mask for the tgt sequence (optional). - memory_mask: the mask for the memory sequence (optional). - decode: whether to use cache for autoregressive decoding or not. - max_len: maximum sequence length, necessary for decoding cache. - dropout_rate: the dropout probability value (optional) - Shape: - see the docs in Transformer class. - """ - output = tgt - - for idx, mod in enumerate(self.layers): - output, cache = mod( - output, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=idx, - dropout_rate=dropout_rate) - - if self.norm is not None: - output = self.norm(output) - - if decode: - return output, cache - return output - - -# Modified to use cache for autoregressive decoding and custom -# MultiheadAttention modules. -class TransformerDecoderLayer(nn.Module): - r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and - feedforward network. - This standard decoder layer is based on the paper "Attention Is All You Need". - Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, - Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all - you need. In Advances in Neural Information Processing Systems, - pages 6000-6010. Users may modify or implement in a different way during - application. - Args: - d_model: the number of expected features in the input (default=1024). - nhead: the number of heads in the multiheadattention models (default=16). - dim_feedforward: the dimension of the feedforward network model - (default=1024). - activation: the activation function of the intermediate layer, can be a - string ("relu" or "gelu") or a unary callable (default=F.relu). - layer_norm_eps: the eps value in layer normalization components - (default=1e-6). - pre_ln: if ``True``, layer norm is done prior to self attention, - multihead attention and feedforward operations, respectivaly. - Otherwise it's done after. Default: ``True``. - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(32, 10, 512) - >>> tgt = torch.rand(32, 20, 512) - >>> out = decoder_layer(tgt, memory) - """ - __constants__ = ['pre_ln'] - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - dim_feedforward: int = 1024, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - pre_ln: bool = True, - attention_temp: float = 1.0, - device=None, - dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.self_attn = MultiheadAttention( - d_model, - nhead, - self_attn=True, - attention_temp=attention_temp, - bias=False, - **factory_kwargs) - self.multihead_attn = MultiheadAttention( - d_model, - nhead, - self_attn=False, - attention_temp=attention_temp, - bias=False, - **factory_kwargs) - - # Implementation of Feedforward model. - self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.glu = glu - if self.glu: - self.linear_glu = nn.Linear(dim_feedforward, - dim_feedforward, - **factory_kwargs) - self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) - - self.pre_ln = pre_ln - self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - - self.activation = activation - - def forward( # pylint: disable=arguments-renamed - self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None, - dropout_rate: Optional[float] = 0.0) -> Any: - r"""Pass the inputs (and mask) through the decoder layer. - Args: - tgt: the sequence to the decoder layer (required). - memory: the sequence from the last layer of the encoder (required). - tgt_mask: the mask for the tgt sequence (optional). - memory_mask: the mask for the memory sequence (optional). - decode: wether to use cache for autoregressive decoding or not. - max_len: maximum sequence length, necessary for decoding cache. - dropout_rate: the dropout probability value (optional) - Shape: - see the docs in Transformer class. - """ - # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf - - x = tgt - if self.pre_ln: - sa_out, cache = self._sa_block( - self.norm1(x), - tgt_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=index, - dropout_rate=dropout_rate) - x = x + sa_out - x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate) - x = x + self._ff_block(self.norm3(x), dropout_rate) - else: - sa_out, cache = self._sa_block( - x, - tgt_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=index, - dropout_rate=dropout_rate) - x = self.norm1(x + sa_out) - x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate)) - x = self.norm3(x + self._ff_block(x, dropout_rate)) - - return x, cache - - # Self-attention block: - def _sa_block( # pylint: disable=arguments-renamed - self, - x: Tensor, - attn_mask: Optional[Tensor], - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None, - dropout_rate: Optional[float] = 0.0) -> Any: - x, cache = self.self_attn( - x, - attn_mask=attn_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=index, - dropout_rate=dropout_rate) - return F.dropout(x, dropout_rate, self.training), cache - - # Multihead attention block: - def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor], - dropout_rate: Optional[float] = 0.0) -> Tensor: - x, _ = self.multihead_attn( - x, - mem, - attn_mask=attn_mask, - dropout_rate=dropout_rate) - return F.dropout(x, dropout_rate, self.training) - - # Feed forward block. - def _ff_block(self, inputs: Tensor, - dropout_rate: Optional[float] = 0.0) -> Tensor: - x = self.activation(self.linear1(inputs)) - if self.glu: - y = self.linear_glu(inputs) - x = x * y - x = self.linear2(F.dropout(x, dropout_rate, self.training)) - return F.dropout(x, dropout_rate, self.training) - - -class MultiheadAttention(nn.Module): - r"""Allows the model to jointly attend to information - from different representation subspaces. Supports self-attention and - encoder-decoder attention. - See `Attention Is All You Need `_. - .. math:: - \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O - where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. - Args: - embed_dim: Total dimension of the model. - num_heads: Number of parallel attention heads. Note that ``embed_dim`` will - be split across ``num_heads`` (i.e. each head will have dimension - ``embed_dim // num_heads``). - self_attn: Whether self attention or encoder-decoder attention is used. - Default: ``True``. - bias: If specified, adds bias to input / output projection layers. - Default: ``False``. - device: The device of the module. - dtype: The dtype of the module. - Examples:: - >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - >>> attn_output, cache = multihead_attn(x) - """ - - def __init__(self, - embed_dim: int, - num_heads: int, - self_attn: bool = True, - attention_temp: float = 1.0, - bias: bool = False, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.self_attn = self_attn - self.head_dim = embed_dim // num_heads - self.attention_temp = attention_temp - assert self.head_dim * num_heads == self.embed_dim, \ - 'embed_dim must be divisible by num_heads.' - - factory_kwargs = {'device': device, 'dtype': dtype} - if self_attn: - # Self-attention. - self.in_proj = nn.Linear( - embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) - else: - # Encoder-decoder attention. - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) - self.kv_proj = nn.Linear( - embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) - - self._reset_parameters() - - def _reset_parameters(self): - """Initiate parameters in the MultiheadAttention module.""" - for module in self.modules(): - if isinstance(module, nn.Linear): - xavier_uniform_(module.weight) - if module.bias is not None: - normal_(module.bias, std=1e-6) - - def forward(self, - x: Tensor, - mem: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None, - dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?! - r""" - Args: - x: Batch of input sequences of shape - (batch size, sequence length, embedding dimensionality) for self - attention mechanism. See "Attention Is All You Need" for more details. - mem: Batch of input sequences of shape - (batch size, sequence length, embedding dimensionality) for - encoder-decoder attention. See "Attention Is All You Need" for more - details. - attn_mask: If specified, a 2D or 3D mask preventing attention to certain - positions. Must be of shape :math:`(L, S)` or - :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the - batch size, :math:`L` is the target sequence length, and :math:`S` - is the source sequence length. A 2D mask will be broadcasted across - the batch while a 3D mask allows for a different mask for each entry - in the batch. Binary, byte, and float masks are supported. - For a binary mask, a ``True`` value indicates that the - corresponding position is not allowed to attend. For a byte mask, - a non-zero value indicates that the corresponding position is not - allowed to attend. For a float mask, the mask values will be added to - the attention weight. - decode: wether to use cache for autoregressive decoding or not. - max_len: maximum sequence length, necessary for decoding cache. - cache: cache dictionary for autoregressive decoding. - index: index of the current decoding step, necessary for decoding cache. - dropout_rate: dropout probability on ``attn_output_weights``. - Outputs: - - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where - :math:`L` is the target sequence length, :math:`N` is the batch size, - and :math:`E` is the embedding dimension ``embed_dim``. - - **cache** - For autoregressive decoding. - """ - # Shape: (batch size, sequence length, embedding dimensionality) - bsz, seq_len, embed_dim = x.size() - # In projection. - if self.self_attn: - q, k, v = self.in_proj(x).split(self.embed_dim, dim=2) - else: - q = self.q_proj(x) - k, v = self.kv_proj(mem).split(self.embed_dim, dim=2) - # This is 1 (!= seq_len) during autoreregressive decoding. - tgt_len = q.size(1) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - name = f'decoder.layers.{index}.self_attn' - loc_cache = cache[name] if decode and name in cache else None - if decode: - if loc_cache is None: - loc_cache = { - 'cached_key': - torch.zeros((bsz, max_len, embed_dim), - dtype=k.dtype, - device=k.device), - 'cached_value': - torch.zeros((bsz, max_len, embed_dim), - dtype=v.dtype, - device=v.device), - 'cache_index': - torch.tensor(0, dtype=torch.long, device=k.device), - } - cached_key = loc_cache['cached_key'] - cached_value = loc_cache['cached_value'] - cache_index = loc_cache['cache_index'] - # Shape check of cached keys against query input. - expected_shape = (bsz, 1, embed_dim) - if expected_shape != x.shape: - raise ValueError('Autoregressive cache shape error, expected query ' - f'shape {expected_shape} instead got {x.shape}.') - # Update key, value caches with our new 1d spatial slices. - cached_key[:, cache_index:cache_index + 1, :] = k - cached_value[:, cache_index:cache_index + 1, :] = v - k = cached_key - v = cached_value - cache_index += 1 - # Causal mask for cached decoder self-attention: - # our single query position should only attend to those key - # positions that have already been generated and cached, - # not the remaining zero elements. - if attn_mask is not None: - raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) >= - cache_index).reshape(1, max_len) - - # Update sequence length to account for complete sequence. - seq_len = k.size(1) - - # Rearrange q, k, v for multihead attention. - q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) - k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - - # Check dtype and shape of attention mask. - if not decode and attn_mask is not None: - assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ - f'Float and bool dtypes are supported, not {attn_mask.dtype}.' - # Ensure attn_mask's dim is 3. - if attn_mask.dim() == 3: - correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len) - if attn_mask.shape != correct_3d_size: - raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, ' - f'but should be {correct_3d_size}.') - else: - raise RuntimeError( - f"attn_mask's dimension {attn_mask.dim()} is not supported") - # Reshape attention mask to be consistent with q, k, v. - attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len) - - # Convert attention mask to float. - if attn_mask is not None and attn_mask.dtype == torch.bool: - new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) - new_attn_mask.masked_fill_(attn_mask, -1e10) - attn_mask = new_attn_mask - - # Adjust dropout_rate probability. - attn_dropout_rate = dropout_rate if self.training else 0.0 - - # Calculate attention. - q = self.attention_temp * q - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, attn_dropout_rate) - # Rearrange for output projection. - attn_output = attn_output.transpose(1, 2).contiguous().view( - bsz, tgt_len, embed_dim) - # Output projection. - attn_output = self.out_proj(attn_output) - - if decode: - cache[name] = loc_cache - - return attn_output, cache From f7d99a62670e8c525eef295f423b47f2026f5a38 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 17 Jun 2025 21:17:19 +0000 Subject: [PATCH 060/123] fixes --- algoperf/jax_utils.py | 8 ++++---- .../librispeech_conformer/librispeech_jax/models.py | 9 ++------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 3ca3f1bfc..c4904dc75 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -14,8 +14,8 @@ class Dropout(Module): """Create a dropout layer. Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. The reference dropout implementation is modified support changes to dropout rate during training by: - 1) adding rate argument to the __call__ method - 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code + 1) adding rate argument to the __call__ method. + 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code. .. note:: When using :meth:`Module.apply() `, make sure @@ -82,8 +82,8 @@ def __call__( deterministic = merge_param("deterministic", self.deterministic, deterministic) # Override self.rate if rate is passed to __call__ - if not (self.rate is not None and rate is not None): - rate = merge_param("rate", self.rate, rate) + if rate is None: + rate = self.rate if self.legacy: if rate == 0.0: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index f7beed914..0de6b1449 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -221,12 +221,7 @@ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_ inputs) inputs = inputs * padding_mask - if config.feed_forward_residual_dropout_rate is None: - feed_forward_residual_dropout_rate = 0.1 - else: - feed_forward_residual_dropout_rate = ( - config.feed_forward_residual_dropout_rate) - inputs = Dropout(rate=feed_forward_residual_dropout_rate)( + inputs = Dropout(rate=dropout_rate)( inputs, deterministic=not train) return inputs @@ -401,7 +396,7 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE): use_bias=True, broadcast_dropout=False, attention_fn=attention_fn, - dropout_rate=config.attention_dropout_rate, + dropout_rate=dropout_rate, deterministic=not train)( inputs_q=inputs, mask=attention_mask) From 3a41559dea66881a6264d2e1c96ff4d58a530353 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 18:56:24 +0000 Subject: [PATCH 061/123] fix reference_algorithm_tests.py --- tests/reference_algorithm_tests.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index c4ca514a8..58a4a5ddc 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -184,12 +184,11 @@ def __init__(self): if 'librispeech' in workload_name: self.tokenizer = _FakeTokenizer() - def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): + def init_model_fn(self, rng): # pylint: disable=line-too-long if not (FLAGS.identical and os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')): - return super().init_model_fn( - rng, dropout_rate=dropout_rate, aux_dropout_rate=aux_dropout_rate) + return super().init_model_fn(rng) if framework == 'jax': compare_module = importlib.import_module( f'tests.modeldiffs.{workload_name}.compare') @@ -201,7 +200,7 @@ def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): return (FrozenDict(**jax_utils.replicate(jax_params)), FrozenDict(**jax_utils.replicate(model_state)) if model_state is not None else model_state) - return super().init_model_fn([0], dropout_rate=0.0, aux_dropout_rate=0.0) + return super().init_model_fn([0]) @property def num_eval_train_examples(self): From 7c430227152b8951fb21454960f71169ab00eb09 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 20:12:45 +0000 Subject: [PATCH 062/123] fixes to ogbg and fastmri --- algoperf/workloads/fastmri/fastmri_jax/workload.py | 3 ++- algoperf/workloads/ogbg/ogbg_jax/models.py | 11 +++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index acdf077e1..b8067cbad 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -11,6 +11,7 @@ from algoperf import param_utils from algoperf import spec import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload @@ -52,7 +53,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 59d989284..d51ca2f20 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -18,7 +18,7 @@ def make_fn(inputs): return make_fn -def _make_mlp(hidden_dims, dropout, activation_fn): +def _make_mlp(hidden_dims, activation_fn, train, dropout_rate=DROPOUT_RATE): """Creates a MLP with specified dimensions.""" @jraph.concatenated_args @@ -28,7 +28,7 @@ def make_fn(inputs): x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) x = activation_fn(x) - x = dropout(x) + x = Dropout(rate=dropout_rate, deterministic=not train)(x, rate=dropout_rate) return x return make_fn @@ -47,7 +47,6 @@ class GNN(nn.Module): @nn.compact def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): - dropout = Dropout(dropout_rate, deterministic=not train)(dropout_rate) graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) @@ -70,11 +69,11 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( update_edge_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), update_node_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), update_global_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) + self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate)) graph = net(graph) From 894f4fb50f5bfdf0e4d2e197cf090e507a05fc15 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 20:29:37 +0000 Subject: [PATCH 063/123] fixes to fastmri and deepspeech --- algoperf/workloads/fastmri/fastmri_jax/workload.py | 11 +++++------ .../librispeech_conformer/librispeech_jax/models.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index b8067cbad..ccf9c6bad 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -58,12 +58,11 @@ def model_fn( del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN - if train: - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train, - dropout_rate=dropout_rate) + logits = self._model.apply({'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate) return logits, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 0de6b1449..366e42195 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -642,7 +642,7 @@ def __call__(self, train, update_batch_norm: Optional[bool] = None, use_running_average_bn: Optional[bool] = None, - dropout_rate: float = DROPOUT_RATE: + dropout_rate: float = DROPOUT_RATE): config = self.config outputs = inputs From 0bcf484282777f39d01b64e41ebba773aef1c913 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 20:36:49 +0000 Subject: [PATCH 064/123] fixes to conformer vit --- algoperf/workloads/imagenet_vit/imagenet_jax/models.py | 4 ++-- .../librispeech_conformer/librispeech_jax/models.py | 3 --- .../librispeech_conformer/librispeech_jax/workload.py | 3 --- .../librispeech_deepspeech/librispeech_jax/models.py | 6 +----- algoperf/workloads/wmt/wmt_jax/models.py | 2 +- 5 files changed, 4 insertions(+), 14 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 7c5d7bd26..091a3473e 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -125,14 +125,14 @@ class Encoder(nn.Module): depth: int mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 - dropout_rate: float = 0.0 use_glu: bool = False use_post_layer_norm: bool = False @nn.compact def __call__(self, x: spec.Tensor, - train: bool = True) -> spec.Tensor: + train: bool = True, + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: # Input Encoder for lyr in range(self.depth): block = Encoder1DBlock( diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 366e42195..bf0eb813e 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -38,7 +38,6 @@ class ConformerConfig: encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 - dropout_rate: float = DROPOUT_RATE convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -191,8 +190,6 @@ class FeedForwardModule(nn.Module): @nn.compact def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE): config = self.config - if dropout_rate is None: - dropout_rate = config.dropout_rate inputs = LayerNorm(dim=config.encoder_dim)(inputs) inputs = nn.Dense( diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 8d966ef87..eec707e5f 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -71,9 +71,6 @@ def init_model_fn( else: activation_function_name = 'swish' model_config = models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, use_specaug=self.use_specaug, attention_temperature=self.attention_temperature, use_post_layer_norm=self.use_post_layer_norm, diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 84ba58ee2..b47b1359a 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -105,12 +105,8 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE): kernel_init=nn.initializers.xavier_uniform())( outputs) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate outputs = Dropout( - rate=input_dropout_rate, deterministic=not train)( + rate=dropout_rate, deterministic=not train)( outputs, rate=dropout_rate) return outputs, output_paddings diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index e262214ac..38f76db80 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -574,7 +574,7 @@ def decode( targets_positions=targets_positions, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, - dropout_rate=droput_rate) + dropout_rate=dropout_rate) return logits.astype(self.config.dtype) def __call__(self, From 73c2276cb1907534f16b76f82e95c95400d04f8f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 20:48:56 +0000 Subject: [PATCH 065/123] conformer and vit fix for dropout refactor --- algoperf/workloads/imagenet_vit/imagenet_jax/models.py | 3 +-- .../librispeech_conformer/librispeech_jax/models.py | 7 +++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 091a3473e..a78a5e791 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -226,8 +226,7 @@ def __call__(self, num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - name='Transformer', - dropout_rate=dropout_rate)( + name='Transformer',)( x, train=not train, dropout_rate=dropout_rate) if self.use_map: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index bf0eb813e..1c2d79e15 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -206,8 +206,8 @@ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_ 'config.activation_function_name values, recieved ' f'{config.activation_function_name}') inputs = activation_fn(inputs) - inputs = Dropout(rate=config.feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)( + inputs, deterministic=not train, rate=dropout_rate) inputs = inputs * padding_mask @@ -665,8 +665,7 @@ def __call__(self, outputs, output_paddings = self.specaug(outputs, output_paddings) outputs, output_paddings = Subsample( - encoder_dim=config.encoder_dim, - dropout_rate=dropout_rate)( + encoder_dim=config.encoder_dim,)( outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the conformer encoder layers. From 5ff94d23242a6613dc5d62579a9a4fe44d017eec Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:12:10 +0000 Subject: [PATCH 066/123] wmt fixes --- .../imagenet_vit/imagenet_jax/models.py | 54 +++++++++---------- algoperf/workloads/wmt/wmt_jax/workload.py | 2 +- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index a78a5e791..716bd4239 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -85,7 +85,7 @@ def __call__(self, deterministic=train, name='MultiHeadDotProductAttention_1')( y) - y = Dropout(dropout_rate)(y, train, dropout_rate=dropout_rate) + y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) @@ -121,33 +121,31 @@ def __call__(self, class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation.""" - depth: int - mlp_dim: Optional[int] = None # Defaults to 4x input dim. - num_heads: int = 12 - use_glu: bool = False - use_post_layer_norm: bool = False - - @nn.compact - def __call__(self, - x: spec.Tensor, - train: bool = True, - dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: - # Input Encoder - for lyr in range(self.depth): - block = Encoder1DBlock( - name=f'encoderblock_{lyr}', - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate)( - dropout_rate=dropout_rate) - x = block(x, train) - if not self.use_post_layer_norm: - return nn.LayerNorm(name='encoder_layernorm')(x) - else: - return x + """Transformer Model Encoder for sequence to sequence translation.""" + + depth: int + mlp_dim: Optional[int] = None # Defaults to 4x input dim. + num_heads: int = 12 + use_glu: bool = False + use_post_layer_norm: bool = False + + @nn.compact + def __call__( + self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE + ) -> spec.Tensor: + # Input Encoder + for lyr in range(self.depth): + x = Encoder1DBlock( + name=f"encoderblock_{lyr}", + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + )(x, train=train, dropout_rate=dropout_rate) + if not self.use_post_layer_norm: + return nn.LayerNorm(name="encoder_layernorm")(x) + else: + return x class MAPHead(nn.Module): diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 9548f5b7e..24d4852b8 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -278,7 +278,7 @@ def model_fn( inputs_segmentation=inputs_segmentations, targets_segmentation=targets_segmentations, rngs={'dropout': rng}, - dropout_rate=None) + dropout_rate=dropout_rate) return logits_batch, None def _normalize_eval_metrics( From 9090e43e1970454dd250f277a7a26cbc4ff6f8dd Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:16:14 +0000 Subject: [PATCH 067/123] fix linting --- algoperf/pytorch_utils.py | 7 ++- .../criteo1tb/criteo1tb_pytorch/models.py | 7 ++- .../criteo1tb/criteo1tb_pytorch/workload.py | 7 +-- .../fastmri/fastmri_pytorch/models.py | 19 ++---- .../fastmri/fastmri_pytorch/workload.py | 7 +-- .../imagenet_pytorch/workload.py | 7 +-- .../imagenet_vit/imagenet_pytorch/models.py | 18 +++--- .../imagenet_vit/imagenet_pytorch/workload.py | 12 ++-- .../librispeech_pytorch/models.py | 19 +++--- .../librispeech_pytorch/workload.py | 7 +-- .../librispeech_pytorch/models.py | 3 +- .../librispeech_pytorch/workload.py | 19 +++--- .../workloads/ogbg/ogbg_pytorch/models.py | 10 +-- .../workloads/ogbg/ogbg_pytorch/workload.py | 12 ++-- algoperf/workloads/wmt/wmt_pytorch/models.py | 63 +++++++++++-------- .../workloads/wmt/wmt_pytorch/workload.py | 7 +-- .../test_model_equivalence.py | 23 +++---- .../fastmri_pytorch/test_model_equivalence.py | 14 +++-- .../test_model_equivalence.py | 16 +++-- .../test_model_equivalence.py | 24 +++---- .../test_model_equivalence.py | 24 +++---- .../ogbg_pytorch/test_model_equivalence.py | 14 +++-- .../wmt_pytorch/test_model_equivalence.py | 21 ++++--- 23 files changed, 193 insertions(+), 167 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index 4af77088e..bae26dea0 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -6,8 +6,8 @@ import tensorflow as tf import torch from torch import Tensor -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn import torch.nn.functional as F from algoperf import spec @@ -84,6 +84,7 @@ def update_batch_norm_fn(module: spec.ParameterContainer, class CustomDropout(nn.Module): """A module around torch.nn.functional.dropout.""" + def __init__(self): super().__init__() self._supports_custom_dropout = True @@ -94,6 +95,7 @@ def forward(self, input: Tensor, p: float) -> Tensor: class CustomDropout2d(nn.Module): """A module around torch.nn.functional.dropout2d.""" + def __init__(self): super().__init__() self._supports_custom_dropout = True @@ -104,6 +106,7 @@ def forward(self, input: Tensor, p: float) -> Tensor: class SequentialWithDropout(nn.Sequential): """Sequential of modules with dropout.""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._supports_custom_dropout = True @@ -113,5 +116,5 @@ def forward(self, x: Tensor, p: float) -> Tensor: if getattr(module, '_supports_custom_dropout', False): x = module(x, p) else: - x = module(x) + x = module(x) return x diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py index f0653a665..7574de3a7 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -5,13 +5,15 @@ import torch from torch import nn -from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout +from algoperf.pytorch_utils import SequentialWithDropout DROPOUT_RATE = 0.0 class DenseBlock(nn.Module): """Dense block with optional residual connection.""" "" + def __init__(self, module, resnet=False): super().__init__() self.module = module @@ -23,6 +25,7 @@ def forward(self, x): class DenseBlockWithDropout(nn.Module): """Dense block with optional residual connection and support for dropout.""" + def __init__(self, module, resnet=False): super().__init__() self.module = module @@ -30,7 +33,7 @@ def __init__(self, module, resnet=False): self._supports_custom_dropout = True def forward(self, x, p): - return self.module(x, p) + x if self.resnet else self.module(x, p) + return self.module(x, p) + x if self.resnet else self.module(x, p) class DotInteract(nn.Module): diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 48c6592f2..69f24c69d 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -65,9 +65,7 @@ def loss_fn( 'per_example': per_example_losses, } - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. torch.backends.cudnn.benchmark = False @@ -104,7 +102,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py index 0b8ac5499..3e7d7671c 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py @@ -13,7 +13,8 @@ from torch.nn import functional as F from algoperf import init_utils -from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout2d +from algoperf.pytorch_utils import SequentialWithDropout DROPOUT_RATE = 0.0 @@ -39,12 +40,8 @@ def __init__(self, self.num_channels = num_channels self.num_pool_layers = num_pool_layers - self.down_sample_layers = nn.ModuleList([ - ConvBlock(in_chans, - num_channels, - use_tanh, - use_layer_norm) - ]) + self.down_sample_layers = nn.ModuleList( + [ConvBlock(in_chans, num_channels, use_tanh, use_layer_norm)]) ch = num_channels for _ in range(num_pool_layers - 1): self.down_sample_layers.append( @@ -58,8 +55,7 @@ def __init__(self, for _ in range(num_pool_layers - 1): self.up_transpose_conv.append( TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) - self.up_conv.append( - ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + self.up_conv.append(ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) ch //= 2 self.up_transpose_conv.append( @@ -74,10 +70,7 @@ def __init__(self, if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward( - self, - x: Tensor, - dropout_rate: float = DROPOUT_RATE) -> Tensor: + def forward(self, x: Tensor, dropout_rate: float = DROPOUT_RATE) -> Tensor: stack = [] output = x diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 9b96230fc..1adbb57ca 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -106,9 +106,7 @@ def _build_input_queue(self, batch['volume_max'] = aux_tensors[2][RANK] yield batch - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = UNet( num_pool_layers=self.num_pool_layers, @@ -136,7 +134,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index f28eb1762..285ba3b4b 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -154,9 +154,7 @@ def _build_dataset( return dataloader - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.use_silu and self.use_gelu: @@ -190,7 +188,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = 0.0 + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del dropout_rate diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index 60e09edb5..9453780d0 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -14,7 +14,8 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \ + MultiheadAttention DROPOUT_RATE = 0.0 @@ -86,9 +87,7 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: class SelfAttention(nn.Module): """Self-attention special case of multi-head dot-product attention.""" - def __init__(self, - width: int, - num_heads: int = 8) -> None: + def __init__(self, width: int, num_heads: int = 8) -> None: super().__init__() self.width = width @@ -161,9 +160,7 @@ def __init__(self, self.self_attention1 = SelfAttention(self.width, self.num_heads) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) self.mlp3 = MlpBlock( - width=self.width, - mlp_dim=self.mlp_dim, - use_glu=self.use_glu) + width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: @@ -345,10 +342,9 @@ def reset_parameters(self) -> None: def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward( - self, - x: spec.Tensor, - dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: + def forward(self, + x: spec.Tensor, + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: # Patch extraction. x = self.conv_patch_extract(x) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index f86a1b1c2..e1c6844fe 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -21,9 +21,7 @@ # Make sure we inherit from the ViT base workload first. class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = models.ViT( num_classes=self._num_classes, @@ -52,7 +50,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -71,8 +70,9 @@ def model_fn( } with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs'], - dropout_rate=dropout_rate) + logits_batch = model( + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate) return logits_batch, None diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index a6a60bf95..10b8e585a 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -73,9 +73,7 @@ def forward(self, x): class Subsample(nn.Module): - def __init__(self, - encoder_dim: int = 0, - num_bins: int = 80): + def __init__(self, encoder_dim: int = 0, num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim @@ -105,7 +103,8 @@ def forward(self, inputs, input_paddings, dropout_rate): outputs = self.linear(outputs) outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) - outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) + outputs = F.dropout( + outputs, dropout_rate, training=self.training, inplace=True) return outputs, output_paddings @@ -215,7 +214,8 @@ def forward(self, inputs, padding_mask, dropout_rate): inputs = inputs * padding_mask inputs = self.linear2(inputs) inputs = inputs * padding_mask - inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) + inputs = F.dropout( + inputs, dropout_rate, training=self.training, inplace=True) return inputs @@ -309,7 +309,8 @@ def forward(self, outputs, paddings, dropout_rate): outputs, key_padding_mask=paddings == 1, ) - outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True) + outputs = F.dropout( + outputs, dropout_rate, training=self.training, inplace=True) return outputs @@ -412,7 +413,8 @@ def forward(self, inputs, input_paddings, dropout_rate): inputs = activation_fn(inputs) inputs = self.lin3(inputs) - inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True) + inputs = F.dropout( + inputs, dropout_rate, training=self.training, inplace=True) return inputs @@ -460,8 +462,7 @@ def __init__(self, config: ConformerConfig): use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames ) self.subsample = Subsample( - encoder_dim=config.encoder_dim, - num_bins=preprocessing_config.num_bins) + encoder_dim=config.encoder_dim, num_bins=preprocessing_config.num_bins) self.conformers = nn.ModuleList( [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 0477a7389..dbeabb16c 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -61,9 +61,7 @@ def use_gelu(self) -> bool: def attention_temperature(self) -> float: return 1.0 - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Conformer model init function.""" torch.random.manual_seed(rng[0]) # Configure torch backends to avoid OOM errors. @@ -106,7 +104,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index a8480a343..644c13a16 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -367,7 +367,8 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): for idx in range(self.config.num_ffn_layers): if self.config.enable_residual_connections: - outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate) + outputs = outputs + self.ffns[idx]( + outputs, output_paddings, dropout_rate) else: outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index bf345cfc9..7640d69a5 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -23,9 +23,7 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Deepspeech model init function.""" torch.random.manual_seed(rng[0]) model = DeepspeechEncoderDecoder( @@ -55,7 +53,7 @@ def init_model_fn( else: model = torch.nn.DataParallel(model) return model, None - + def model_fn( self, params: spec.ParameterContainer, @@ -64,11 +62,16 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: # override super method, changing only the default dropout_rate - return super().model_fn( - params, augmented_and_preprocessed_input_batch, model_state, - mode, rng, update_batch_norm, dropout_rate) + return super().model_fn(params, + augmented_and_preprocessed_input_batch, + model_state, + mode, + rng, + update_batch_norm, + dropout_rate) def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py index be5882333..8a40bef58 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py @@ -9,7 +9,8 @@ from torch import nn from algoperf import init_utils -from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout +from algoperf.pytorch_utils import SequentialWithDropout DROPOUT_RATE = 0.1 @@ -93,10 +94,9 @@ def __init__(self, if isinstance(m, nn.Linear): init_utils.pytorch_default_init(m) - def forward( - self, - graph: GraphsTuple, - dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: + def forward(self, + graph: GraphsTuple, + dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: graph = graph._replace( globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 758b36b60..a45a93668 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -137,9 +137,7 @@ def _build_input_queue(self, yield batch - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = GNN( num_outputs=self._num_outputs, @@ -168,7 +166,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del rng del update_batch_norm # No BN in the GNN model. @@ -188,8 +187,9 @@ def model_fn( } with contexts[mode](): - logits = model(augmented_and_preprocessed_input_batch['inputs'], - dropout_rate=dropout_rate) + logits = model( + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate) return logits, None diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index a43df30d4..0de719c4b 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -226,7 +226,8 @@ def __init__(self, self.enable_nested_tensor = enable_nested_tensor self.mask_check = mask_check - def forward(self, src: Tensor, + def forward(self, + src: Tensor, mask: Optional[Tensor] = None, dropout_rate: Optional[float] = 0.0) -> Tensor: """Pass the input through the encoder layers in turn. @@ -341,7 +342,12 @@ def forward( if not decode: tgt = shift_right(tgt) tgt = self.shared_embedding(tgt) - tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate) + tgt = self.pos_encoder( + tgt, + targets_positions, + decode=decode, + cache=cache, + dropout_rate=dropout_rate) if decode: tgt, cache = tgt output = self.decoder( @@ -364,9 +370,7 @@ def forward( class PositionalEncoding(nn.Module): - def __init__(self, - d_model: int, - max_len: int = 256): + def __init__(self, d_model: int, max_len: int = 256): super().__init__() position = torch.arange(max_len).unsqueeze(1) @@ -481,9 +485,9 @@ def __init__(self, self.activation = activation - def forward(self, - src: Tensor, - src_mask: Optional[Tensor] = None, + def forward(self, + src: Tensor, + src_mask: Optional[Tensor] = None, dropout_rate: Optional[float] = 0.0) -> Tensor: r"""Pass the input through the encoder layer. @@ -505,16 +509,16 @@ def forward(self, return x # Self-attention block: - def _sa_block(self, - x: Tensor, - attn_mask: Optional[Tensor], + def _sa_block(self, + x: Tensor, + attn_mask: Optional[Tensor], dropout_rate: Optional[float] = 0.0) -> Tensor: x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: - def _ff_block(self, - inputs: Tensor, + def _ff_block(self, + inputs: Tensor, dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: @@ -763,18 +767,21 @@ def _sa_block( # pylint: disable=arguments-renamed return F.dropout(x, dropout_rate, self.training), cache # Multihead attention block: - def _mha_block(self, x: Tensor, mem: Tensor, + def _mha_block(self, + x: Tensor, + mem: Tensor, attn_mask: Optional[Tensor], dropout_rate: Optional[float] = 0.0) -> Tensor: x, _ = self.multihead_attn( - x, - mem, - attn_mask=attn_mask, + x, + mem, + attn_mask=attn_mask, dropout_rate=dropout_rate) return F.dropout(x, dropout_rate, self.training) # Feed forward block. - def _ff_block(self, inputs: Tensor, + def _ff_block(self, + inputs: Tensor, dropout_rate: Optional[float] = 0.0) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: @@ -847,15 +854,17 @@ def _reset_parameters(self): if module.bias is not None: normal_(module.bias, std=1e-6) - def forward(self, - x: Tensor, - mem: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None, - dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?! + def forward( + self, + x: Tensor, + mem: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0 + ) -> Any: # TODO: (nico) remove default?! r""" Args: x: Batch of input sequences of shape diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index 4c787becc..4ec816f2f 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -166,9 +166,7 @@ def translate_and_calculate_bleu(self, bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.activation == 'relu': @@ -204,7 +202,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index db56b17cf..f51f1f9aa 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -4,20 +4,21 @@ python3 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py """ -from absl.testing import absltest, parameterized -from torch.testing import assert_close -import torch import os -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import ( - DLRMResNet as OriginalDLRMResNet, - DlrmSmall as OriginalDlrmSmall, -) -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import ( - DLRMResNet as CustomDLRMResNet, - DlrmSmall as CustomDlrmSmall, -) +from absl.testing import absltest +from absl.testing import parameterized +import torch +from torch.testing import assert_close +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \ + DLRMResNet as OriginalDLRMResNet +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \ + DlrmSmall as OriginalDlrmSmall +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import \ + DLRMResNet as CustomDLRMResNet +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import \ + DlrmSmall as CustomDlrmSmall BATCH, DENSE, SPARSE = 16, 13, 26 FEATURES = DENSE + SPARSE diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py index 0d3d52980..46f1a0f5a 100644 --- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -4,13 +4,17 @@ python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py """ -from absl.testing import absltest, parameterized -from torch.testing import assert_close -import torch import os -from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet as OriginalUNet -from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import UNet as CustomUNet +from absl.testing import absltest +from absl.testing import parameterized +import torch +from torch.testing import assert_close + +from algoperf.workloads.fastmri.fastmri_pytorch.models import \ + UNet as OriginalUNet +from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \ + UNet as CustomUNet BATCH, IN_CHANS, H, W = 4, 1, 256, 256 OUT_CHANS, C, LAYERS = 1, 32, 4 diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py index d19fad0ba..e6d58f5f7 100644 --- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py @@ -4,14 +4,18 @@ python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py """ -from absl.testing import absltest, parameterized -from torch.testing import assert_close -import torch -import os import itertools +import os + +from absl.testing import absltest +from absl.testing import parameterized +import torch +from torch.testing import assert_close -from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import ViT as OriginalVit -from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import ViT as CustomVit +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ + ViT as OriginalVit +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import \ + ViT as CustomVit # Model / test hyper-params BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W) diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py index 4a1252a39..d9511fcc4 100644 --- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -6,19 +6,21 @@ NOTE: we don't test for default dropout_rate values, since they changed. """ -from absl.testing import absltest, parameterized -from torch.testing import assert_close -import torch import os -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( - ConformerConfig as OriginalConfig, - ConformerEncoderDecoder as OriginalModel -) -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import( - ConformerConfig as CustomConfig, - ConformerEncoderDecoder as CustomModel, -) +from absl.testing import absltest +from absl.testing import parameterized +import torch +from torch.testing import assert_close + +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ + ConformerConfig as OriginalConfig +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ + ConformerEncoderDecoder as OriginalModel +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \ + ConformerConfig as CustomConfig +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \ + ConformerEncoderDecoder as CustomModel N_LAYERS = 3 B, T = 32, 36_000 diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py index 58ddb354e..610e6b18e 100644 --- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py @@ -8,19 +8,21 @@ - `feed_forward_dropout_rate` (if None, 0.1) """ -from absl.testing import absltest, parameterized -from torch.testing import assert_close -import torch import os -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( - DeepspeechEncoderDecoder as OriginalModel, - DeepspeechConfig as OriginalConfig -) -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import( - DeepspeechEncoderDecoder as CustomModel, - DeepspeechConfig as CustomConfig -) +from absl.testing import absltest +from absl.testing import parameterized +import torch +from torch.testing import assert_close + +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ + DeepspeechConfig as OriginalConfig +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ + DeepspeechEncoderDecoder as OriginalModel +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \ + DeepspeechConfig as CustomConfig +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \ + DeepspeechEncoderDecoder as CustomModel B, T = 32, 30_000 DEVICE = 'cuda' diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py index aaca6cebd..f5c7e992b 100644 --- a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py @@ -4,13 +4,19 @@ python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py """ -from absl.testing import absltest, parameterized -from torch.testing import assert_close -import torch, os, random, numpy as np +import os +import random + +from absl.testing import absltest +from absl.testing import parameterized from jraph import GraphsTuple +import numpy as np +import torch +from torch.testing import assert_close from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel -from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import GNN as CustomModel +from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \ + GNN as CustomModel B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py index 9675f1df2..918043cfd 100644 --- a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py @@ -4,16 +4,19 @@ python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py """ -from absl.testing import absltest, parameterized +import os +import random + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +import torch from torch.testing import assert_close -import torch, os, random, numpy as np - -from algoperf.workloads.wmt.wmt_pytorch.models import ( - Transformer as OriginalModel, -) -from algoperf.workloads.wmt.wmt_pytorch.models_dropout import ( - Transformer as CustomModel, -) + +from algoperf.workloads.wmt.wmt_pytorch.models import \ + Transformer as OriginalModel +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \ + Transformer as CustomModel B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000 DEVICE = "cuda" From 4e69255642807cbb38c9d3390898f9713e59ece7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:16:42 +0000 Subject: [PATCH 068/123] formatting --- algoperf/jax_utils.py | 104 +++++++++--------- .../workloads/cifar/cifar_jax/workload.py | 4 +- .../criteo1tb/criteo1tb_jax/models.py | 1 + .../criteo1tb/criteo1tb_jax/workload.py | 12 +- .../workloads/fastmri/fastmri_jax/models.py | 1 + .../workloads/fastmri/fastmri_jax/workload.py | 16 +-- .../imagenet_resnet/imagenet_jax/workload.py | 3 +- .../imagenet_vit/imagenet_jax/models.py | 58 +++++----- .../imagenet_vit/imagenet_jax/workload.py | 12 +- .../librispeech_jax/models.py | 15 ++- .../librispeech_jax/workload.py | 3 +- .../librispeech_jax/models.py | 4 +- .../librispeech_jax/workload.py | 9 +- .../workloads/mnist/mnist_jax/workload.py | 4 +- algoperf/workloads/ogbg/ogbg_jax/models.py | 22 +++- algoperf/workloads/ogbg/ogbg_jax/workload.py | 7 +- algoperf/workloads/wmt/wmt_jax/models.py | 20 ++-- algoperf/workloads/wmt/wmt_jax/workload.py | 7 +- 18 files changed, 164 insertions(+), 138 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index c4904dc75..369eb1b1a 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -1,17 +1,19 @@ from collections.abc import Sequence -import jax -import jax.numpy as jnp -from jax import lax, random - import flax.linen as nn -from flax.linen.module import Module, compact, merge_param +from flax.linen.module import compact +from flax.linen.module import merge_param +from flax.linen.module import Module from flax.typing import PRNGKey +import jax +from jax import lax +from jax import random +import jax.numpy as jnp # Custom Layers class Dropout(Module): - """Create a dropout layer. + """Create a dropout layer. Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. The reference dropout implementation is modified support changes to dropout rate during training by: 1) adding rate argument to the __call__ method. @@ -51,21 +53,21 @@ class Dropout(Module): rng_collection: the rng collection name to use when requesting an rng key. """ - rate: float | None = None - broadcast_dims: Sequence[int] = () - deterministic: bool | None = None - rng_collection: str = "dropout" - legacy: bool = True - - @compact - def __call__( - self, - inputs, - deterministic: bool | None = None, - rate: float | None = None, - rng: PRNGKey | None = None, - ): - """Applies a random dropout mask to the input. + rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = "dropout" + legacy: bool = True + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. Args: inputs: the inputs that should be randomly masked. @@ -79,40 +81,44 @@ def __call__( Returns: The masked inputs reweighted to preserve mean. """ - deterministic = merge_param("deterministic", self.deterministic, deterministic) + deterministic = merge_param("deterministic", + self.deterministic, + deterministic) - # Override self.rate if rate is passed to __call__ - if rate is None: - rate = self.rate + # Override self.rate if rate is passed to __call__ + if rate is None: + rate = self.rate - if self.legacy: - if rate == 0.0: - return inputs + if self.legacy: + if rate == 0.0: + return inputs - # Prevent gradient NaNs in 1.0 edge-case. - if rate == 1.0: - return jnp.zeros_like(inputs) + # Prevent gradient NaNs in 1.0 edge-case. + if rate == 1.0: + return jnp.zeros_like(inputs) - if deterministic: - return inputs + if deterministic: + return inputs - keep_prob = 1.0 - rate - if rng is None: - rng = self.make_rng(self.rng_collection) - broadcast_shape = list(inputs.shape) - for dim in self.broadcast_dims: - broadcast_shape[dim] = 1 - mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) - mask = jnp.broadcast_to(mask, inputs.shape) - return lax.select(mask, inputs, jnp.zeros_like(inputs)) + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + return lax.select(mask, inputs, jnp.zeros_like(inputs)) # Utilities for debugging def print_jax_model_summary(model, fake_inputs): - """Prints a summary of the jax module.""" - tabulate_fn = nn.tabulate( - model, - jax.random.PRNGKey(0), - console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240}, - ) - print(tabulate_fn(fake_inputs, train=False)) + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={ + "force_terminal": False, "force_jupyter": False, "width": 240 + }, + ) + print(tabulate_fn(fake_inputs, train=False)) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index 3f2397f8c..c6cc50fbf 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -79,9 +79,7 @@ def sync_batch_stats( new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" model_cls = getattr(models, 'ResNet18') model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index 57cb7f2d9..4a91a80b8 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -9,6 +9,7 @@ DROPOUT_RATE = 0.0 + class DLRMResNet(nn.Module): """Define a DLRMResNet model. diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index cb7e8cf9f..d84d18d5c 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -89,16 +89,17 @@ def init_model_fn( use_layer_norm=self.use_layer_norm, embedding_init_multiplier=self.embedding_init_multiplier) - params_rng, _= jax.random.split(rng) + params_rng, _ = jax.random.split(rng) init_fake_batch_size = 2 num_categorical_features = 26 num_dense_features = 13 input_size = num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) init_fn = functools.partial(self._model.init, train=False) - initial_variables = jax.jit(init_fn)( - {'params': params_rng,}, - jnp.ones(input_shape, jnp.float32)) + initial_variables = jax.jit(init_fn)({ + 'params': params_rng, + }, + jnp.ones(input_shape, jnp.float32)) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -115,7 +116,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm inputs = augmented_and_preprocessed_input_batch['inputs'] diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 5850defa7..70c7fc4a5 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -23,6 +23,7 @@ DROPOUT_RATE = 0.0 + def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation # but preserves double or complex floating points diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index ccf9c6bad..bd0aa1d0b 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -30,12 +30,11 @@ def init_model_fn( num_channels=self.num_channels, use_tanh=self.use_tanh, use_layer_norm=self.use_layer_norm, - ) + ) params_rng, _ = jax.random.split(rng) init_fn = functools.partial(self._model.init, train=False) - variables = jax.jit(init_fn)({'params': params_rng}, - fake_batch) + variables = jax.jit(init_fn)({'params': params_rng}, fake_batch) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -53,16 +52,17 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train, - dropout_rate=dropout_rate) + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate) return logits, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 2a255fee4..7896dcd05 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -84,7 +84,8 @@ def sync_batch_stats( def init_model_fn( self, - rng: spec.RandomState,) -> spec.ModelInitState: + rng: spec.RandomState, + ) -> spec.ModelInitState: model_cls = getattr(models, 'ResNet50') if self.use_silu and self.use_gelu: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 716bd4239..f33dea723 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -15,6 +15,7 @@ DROPOUT_RATE = 0.0 + def posemb_sincos_2d(h: int, w: int, width: int, @@ -121,31 +122,32 @@ def __call__(self, class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation.""" - - depth: int - mlp_dim: Optional[int] = None # Defaults to 4x input dim. - num_heads: int = 12 - use_glu: bool = False - use_post_layer_norm: bool = False - - @nn.compact - def __call__( - self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE - ) -> spec.Tensor: - # Input Encoder - for lyr in range(self.depth): - x = Encoder1DBlock( - name=f"encoderblock_{lyr}", - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - )(x, train=train, dropout_rate=dropout_rate) - if not self.use_post_layer_norm: - return nn.LayerNorm(name="encoder_layernorm")(x) - else: - return x + """Transformer Model Encoder for sequence to sequence translation.""" + + depth: int + mlp_dim: Optional[int] = None # Defaults to 4x input dim. + num_heads: int = 12 + use_glu: bool = False + use_post_layer_norm: bool = False + + @nn.compact + def __call__(self, + x: spec.Tensor, + train: bool = True, + dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: + # Input Encoder + for lyr in range(self.depth): + x = Encoder1DBlock( + name=f"encoderblock_{lyr}", + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + )(x, train=train, dropout_rate=dropout_rate) + if not self.use_post_layer_norm: + return nn.LayerNorm(name="encoder_layernorm")(x) + else: + return x class MAPHead(nn.Module): @@ -182,7 +184,7 @@ class ViT(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 rep_size: Union[int, bool] = True - dropout_rate: [float] = DROPOUT_RATE + dropout_rate: [float] = DROPOUT_RATE reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True use_glu: bool = False @@ -224,8 +226,8 @@ def __call__(self, num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - name='Transformer',)( - x, train=not train, dropout_rate=dropout_rate) + name='Transformer', + )(x, train=not train, dropout_rate=dropout_rate) if self.use_map: x = MAPHead( diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 08a8f4eb1..d0fb4fd72 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -24,15 +24,12 @@ def initialized(self, key: spec.RandomState, model: nn.Module) -> spec.ModelInitState: input_shape = (1, 224, 224, 3) params_rng, _ = jax.random.split(key) - variables = jax.jit( - model.init)({'params': params_rng}, - jnp.ones(input_shape)) + variables = jax.jit(model.init)({'params': params_rng}, + jnp.ones(input_shape)) model_state, params = pop(variables, "params") return params, model_state - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, @@ -57,7 +54,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 1c2d79e15..b2eee1c37 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -22,14 +22,15 @@ import jax.numpy as jnp import numpy as np +from algoperf.jax_utils import Dropout from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter -from algoperf.jax_utils import Dropout DROPOUT_RATE = 0.1 + @struct.dataclass class ConformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" @@ -92,6 +93,7 @@ class Subsample(nn.Module): input_dropout_rate: dropout rate for inputs. """ encoder_dim: int = 0 + @nn.compact def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): output_paddings = input_paddings @@ -188,7 +190,11 @@ class FeedForwardModule(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE): + def __call__(self, + inputs, + padding_mask=None, + train=False, + dropout_rate=DROPOUT_RATE): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -218,8 +224,7 @@ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_ inputs) inputs = inputs * padding_mask - inputs = Dropout(rate=dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)(inputs, deterministic=not train) return inputs @@ -583,7 +588,7 @@ def __call__(self, input_paddings, train, update_batch_norm, - use_running_average, + use_running_average, dropout_rate=DROPOUT_RATE): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index eec707e5f..1e1a1d3f8 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -83,8 +83,7 @@ def init_model_fn( model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, _ = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng}, - *fake_input_batch) + variables = model_init_fn({'params': params_rng}, *fake_input_batch) model_state, params = pop(variables, "params") diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b47b1359a..1bd998027 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -16,11 +16,11 @@ from jax.experimental import rnn import jax.numpy as jnp +from algoperf.jax_utils import Dropout from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter -from algoperf.jax_utils import Dropout Array = jnp.ndarray StateType = Union[Array, Tuple[Array, ...]] @@ -31,7 +31,7 @@ CarryHistory = Any Output = Any -DROPOUT_RATE=0.1 +DROPOUT_RATE = 0.1 @struct.dataclass diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2bb119439..81a56db72 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -15,9 +15,7 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Deepspeech model init function. """ model_config = models.DeepspeechConfig( @@ -36,8 +34,9 @@ def init_model_fn( model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, _ = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng,}, - *fake_input_batch) + variables = model_init_fn({ + 'params': params_rng, + }, *fake_input_batch) model_state = variables[ 'batch_stats'] if not self.layernorm_everywhere else {} diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5f3fdcf78..27bd9ae54 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -32,9 +32,7 @@ def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor: class MnistWorkload(BaseMnistWorkload): - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_val = jnp.ones((1, 28, 28, 1), jnp.float32) self._model = _Model() initial_params = self._model.init({'params': rng}, init_val, diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index d51ca2f20..06eef6187 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -8,7 +8,8 @@ from algoperf.jax_utils import Dropout -DROPOUT_RATE=0.1 +DROPOUT_RATE = 0.1 + def _make_embed(latent_dim, name): @@ -28,7 +29,9 @@ def make_fn(inputs): x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) x = activation_fn(x) - x = Dropout(rate=dropout_rate, deterministic=not train)(x, rate=dropout_rate) + x = Dropout( + rate=dropout_rate, deterministic=not train)( + x, rate=dropout_rate) return x return make_fn @@ -69,11 +72,20 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( update_edge_fn=_make_mlp( - self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate), update_node_fn=_make_mlp( - self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate), update_global_fn=_make_mlp( - self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate)) + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate)) graph = net(graph) diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e03252ed9..04a9bce2e 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -17,9 +17,7 @@ class OgbgWorkload(BaseOgbgWorkload): - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: rng, params_rng = jax.random.split(rng, 2) self._model = models.GNN( self._num_outputs, @@ -53,7 +51,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del update_batch_norm # No BN in the GNN model. if model_state is not None: diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 38f76db80..1147eb34b 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -13,9 +13,9 @@ from algoperf.jax_utils import Dropout - DROPOUT_RATE = 0.1 + @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" @@ -171,7 +171,8 @@ def __call__(self, inputs, dropout_rate=DROPOUT_RATE): )( inputs) x = x * y - x = Dropout(rate=dropout_rate)(x, rate=dropout_rate, deterministic=cfg.deterministic) + x = Dropout(rate=dropout_rate)( + x, rate=dropout_rate, deterministic=cfg.deterministic) output = nn.Dense( actual_out_dim, dtype=cfg.dtype, @@ -223,7 +224,8 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE): deterministic=cfg.deterministic, )(cfg.attention_temp * x, x, mask=encoder_mask) - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + inputs if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -285,7 +287,8 @@ def __call__( decode=cfg.decode, )(cfg.attention_temp * x, x, mask=decoder_mask) - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate) x = x + targets if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -304,7 +307,8 @@ def __call__( deterministic=cfg.deterministic, )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = Dropout(rate=dropout_rate)( + y, deterministic=cfg.deterministic, rate=dropout_rate) y = y + x if not pre_ln: y = nn.LayerNorm(dtype=cfg.dtype)(y) @@ -361,7 +365,8 @@ def __call__(self, x = AddPositionEmbs( config=cfg, decode=False, name="posembed_input")( x, inputs_positions=inputs_positions) - x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate) + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate) x = x.astype(cfg.dtype) @@ -432,7 +437,8 @@ def __call__( y = AddPositionEmbs( config=cfg, decode=cfg.decode, name="posembed_output")( y, inputs_positions=targets_positions) - y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate) + y = Dropout(rate=dropout_rate)( + y, deterministic=cfg.deterministic, rate=dropout_rate) y = y.astype(cfg.dtype) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index 24d4852b8..d402f9d95 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -206,9 +206,7 @@ def translate_and_calculate_bleu(self, bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score - def init_model_fn( - self, - rng: spec.RandomState) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_fake_batch_size = 2 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -250,7 +248,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: [float] = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: [float] = models.DROPOUT_RATE + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm From 3ac97ae017defbed75aa8eedd4c1ab09b759e4c1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:20:02 +0000 Subject: [PATCH 069/123] fix formatting --- .../test_model_equivalence.py | 125 ++++++----- .../fastmri_pytorch/test_model_equivalence.py | 177 +++++++++------- .../test_model_equivalence.py | 199 +++++++++--------- .../test_model_equivalence.py | 81 ++++--- .../test_model_equivalence.py | 145 +++++++------ .../ogbg_pytorch/test_model_equivalence.py | 123 ++++++----- .../wmt_pytorch/test_model_equivalence.py | 160 +++++++------- 7 files changed, 540 insertions(+), 470 deletions(-) diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py index f51f1f9aa..733052dd0 100644 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py @@ -31,63 +31,90 @@ torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(True) + class ModelEquivalenceTest(parameterized.TestCase): - @parameterized.named_parameters( - dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0), - dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0), - dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1), - dict(testcase_name='DlrmSmall, p=0.1', model='dlrm_small', dropout_rate=0.1), - dict(testcase_name='DLRMResNet, p=1.0', model='dlrm_resnet', dropout_rate=1.0), - dict(testcase_name='DlrmSmall, p=1.0', model='dlrm_small', dropout_rate=1.0), + @parameterized.named_parameters( + dict( + testcase_name='DLRMResNet, p=0.0', + model='dlrm_resnet', + dropout_rate=0.0), + dict( + testcase_name='DlrmSmall, p=0.0', + model='dlrm_small', + dropout_rate=0.0), + dict( + testcase_name='DLRMResNet, p=0.1', + model='dlrm_resnet', + dropout_rate=0.1), + dict( + testcase_name='DlrmSmall, p=0.1', + model='dlrm_small', + dropout_rate=0.1), + dict( + testcase_name='DLRMResNet, p=1.0', + model='dlrm_resnet', + dropout_rate=1.0), + dict( + testcase_name='DlrmSmall, p=1.0', + model='dlrm_small', + dropout_rate=1.0), + ) + def test_forward(self, model, dropout_rate): + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) ) - def test_forward(self, model, dropout_rate): - OrigCls, CustCls = ( - (OriginalDLRMResNet, CustomDLRMResNet) - if model == 'dlrm_resnet' - else (OriginalDlrmSmall, CustomDlrmSmall) - ) - torch.manual_seed(SEED) - orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE) + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB).to(DEVICE) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(SEED) + y1 = orig(x) + torch.manual_seed(SEED) + y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval torch.manual_seed(SEED) - cust = CustCls(vocab_size=VOCAB).to(DEVICE) - - x = torch.randn(BATCH, FEATURES, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(SEED); y1 = orig(x) - torch.manual_seed(SEED); y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED); y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), - dict(testcase_name='DlrmSmall, default', model='dlrm_small'), + y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), + dict(testcase_name='DlrmSmall, default', model='dlrm_small'), + ) + def test_default_dropout(self, model): + """Test default dropout_rate.""" + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) ) - def test_default_dropout(self, model): - """Test default dropout_rate.""" - OrigCls, CustCls = ( - (OriginalDLRMResNet, CustomDLRMResNet) - if model == 'dlrm_resnet' - else (OriginalDlrmSmall, CustomDlrmSmall) - ) - torch.manual_seed(SEED) - orig = OrigCls(vocab_size=VOCAB).to(DEVICE) - torch.manual_seed(SEED) - cust = CustCls(vocab_size=VOCAB).to(DEVICE) + torch.manual_seed(SEED) + orig = OrigCls(vocab_size=VOCAB).to(DEVICE) + torch.manual_seed(SEED) + cust = CustCls(vocab_size=VOCAB).to(DEVICE) + + x = torch.randn(BATCH, FEATURES, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) - x = torch.randn(BATCH, FEATURES, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(0); y1 = orig(x) - torch.manual_seed(0); y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py index 46f1a0f5a..1d318e8c6 100644 --- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py @@ -27,87 +27,104 @@ torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(True) + class FastMRIModeEquivalenceTest(parameterized.TestCase): - def fwd_pass(self, orig, cust, dropout_rate): - x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(0); y1 = orig(x) - torch.manual_seed(0); y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(0); y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.1', dropout_rate=0.1), - dict(testcase_name='p=0.7', dropout_rate=0.7), - dict(testcase_name='p=1.0', dropout_rate=1.0), - ) - def test_dropout_values(self, dropout_rate): - """Test different values of dropout_rate.""" - - torch.manual_seed(SEED) - orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) - - self.fwd_pass(orig, cust, dropout_rate) - - - @parameterized.named_parameters( - dict(testcase_name='default', use_tanh=False, use_layer_norm=False), - dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False), - dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), - dict(testcase_name='both', use_tanh=True, use_layer_norm=True), - ) - def test_arch_configs(self, use_tanh, use_layer_norm): - """Test different architecture configurations, fixed dropout_rate.""" - dropout_rate = 0.1 - - torch.manual_seed(SEED) - orig = OriginalUNet( - IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate, - use_tanh=use_tanh, use_layer_norm=use_layer_norm - ).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomUNet( - IN_CHANS, OUT_CHANS, C, LAYERS, - use_tanh=use_tanh, use_layer_norm=use_layer_norm - ).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) - - self.fwd_pass(orig, cust, dropout_rate) - - @parameterized.named_parameters( - dict(testcase_name=''), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - - torch.manual_seed(SEED) - orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - cust.load_state_dict(orig.state_dict()) # sync weights - - x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(0); y1 = orig(x) - torch.manual_seed(0); y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0) + y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different values of dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig) + cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters( + dict(testcase_name='default', use_tanh=False, use_layer_norm=False), + dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False), + dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), + dict(testcase_name='both', use_tanh=True, use_layer_norm=True), + ) + def test_arch_configs(self, use_tanh, use_layer_norm): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalUNet( + IN_CHANS, + OUT_CHANS, + C, + LAYERS, + dropout_rate=dropout_rate, + use_tanh=use_tanh, + use_layer_norm=use_layer_norm).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet( + IN_CHANS, + OUT_CHANS, + C, + LAYERS, + use_tanh=use_tanh, + use_layer_norm=use_layer_norm).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig) + cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters( + dict(testcase_name=''),) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py index e6d58f5f7..f51eaec7e 100644 --- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py @@ -18,7 +18,7 @@ ViT as CustomVit # Model / test hyper-params -BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W) +BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W) WIDTH, DEPTH, HEADS = 256, 4, 8 DROPOUT_RATE = None DEVICE = 'cuda' @@ -29,103 +29,110 @@ torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(True) + class ImageNetVitModeEquivalenceTest(parameterized.TestCase): - def fwd_pass(self, orig, cust, dropout_rate): - x = torch.randn(BATCH, C, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(0); y1 = orig(x) - torch.manual_seed(0); y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(0); y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.1', dropout_rate=0.1), - dict(testcase_name='p=0.6', dropout_rate=0.6), - dict(testcase_name='p=1.0', dropout_rate=1.0), - ) - def test_dropout_values(self, dropout_rate): - """Test different dropout_rates.""" - - torch.manual_seed(SEED) - orig = OriginalVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - dropout_rate=dropout_rate, - ).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - ).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - self.fwd_pass(orig, cust, dropout_rate) - - - @parameterized.named_parameters([ - dict( - testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}", - use_glu=use_glu, - use_post_ln=use_post_ln, - use_map=use_map, - ) - for use_glu, use_post_ln, use_map in itertools.product([False, True], repeat=3) - ]) - def test_arch(self, use_glu, use_post_ln, use_map): - """Test different architecture configurations, fixed dropout_rate.""" - dropout_rate = 0.1 - - torch.manual_seed(SEED) - orig = OriginalVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - use_glu=use_glu, - use_post_layer_norm=use_post_ln, - use_map=use_map, - dropout_rate=dropout_rate, - ).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - use_glu=use_glu, - use_post_layer_norm=use_post_ln, - use_map=use_map, - ).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - self.fwd_pass(orig, cust, dropout_rate) - - @parameterized.named_parameters( - dict(testcase_name=''), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - - torch.manual_seed(SEED) - orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) - cust.load_state_dict(orig.state_dict()) # sync weights - - x = torch.randn(BATCH, C, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(0); y1 = orig(x) - torch.manual_seed(0); y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0) + y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.6', dropout_rate=0.6), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different dropout_rates.""" + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters([ + dict( + testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}", + use_glu=use_glu, + use_post_ln=use_post_ln, + use_map=use_map, + ) for use_glu, + use_post_ln, + use_map in itertools.product([False, True], repeat=3) + ]) + def test_arch(self, use_glu, use_post_ln, use_map): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters( + dict(testcase_name=''),) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py index d9511fcc4..02f3a3d84 100644 --- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py @@ -26,7 +26,7 @@ B, T = 32, 36_000 DEVICE = 'cuda' -os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(mode=True) @@ -35,53 +35,50 @@ class ConformerEquivalenceTest(parameterized.TestCase): - @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.2', dropout_rate=0.2), - dict(testcase_name='p=0.7', dropout_rate=0.7), - dict(testcase_name='p=1.0', dropout_rate=1.0), - ) - def test_forward(self, dropout_rate): - - torch.manual_seed(SEED) - orig = OriginalModel( - OriginalConfig( + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( num_encoder_layers=N_LAYERS, attention_residual_dropout_rate=dropout_rate, conv_residual_dropout_rate=dropout_rate, feed_forward_residual_dropout_rate=dropout_rate, input_dropout_rate=dropout_rate, )).to(DEVICE) - + + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval torch.manual_seed(SEED) - cust = CustomModel( - CustomConfig( - num_encoder_layers=N_LAYERS - ) - ).to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py index 610e6b18e..7d6a94592 100644 --- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py @@ -28,7 +28,7 @@ DEVICE = 'cuda' TORCH_COMPILE = False -os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(mode=True) @@ -37,83 +37,88 @@ class DeepSpeechEquivalenceTest(parameterized.TestCase): - @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.2', dropout_rate=0.2), - dict(testcase_name='p=0.7', dropout_rate=0.7), - dict(testcase_name='p=1.0', dropout_rate=1.0), - ) - def test_forward(self, dropout_rate): - """Test different dropout_rate values.""" - - torch.manual_seed(SEED) - orig = OriginalModel( - OriginalConfig( + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + """Test different dropout_rate values.""" + + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( num_lstm_layers=2, num_ffn_layers=2, input_dropout_rate=dropout_rate, feed_forward_dropout_rate=dropout_rate, )).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig( - num_lstm_layers=2, - num_ffn_layers=2, - )).to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - - - @parameterized.named_parameters( - dict(testcase_name=''), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - torch.manual_seed(SEED) - orig = OriginalModel(OriginalConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig( + num_lstm_layers=2, + num_ffn_layers=2, + )).to(DEVICE) - if TORCH_COMPILE: - orig = torch.compile(orig); cust = torch.compile(cust) + orig.load_state_dict(cust.state_dict()) # sync weights - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + if TORCH_COMPILE: + orig = torch.compile(orig) + cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name=''),) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalModel(OriginalConfig(num_lstm_layers=2, + num_ffn_layers=2)).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig(num_lstm_layers=2, + num_ffn_layers=2)).to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig) + cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) - for mode in ('train', 'eval'): - getattr(orig, mode)(); getattr(cust, mode)() - torch.manual_seed(SEED); y1, p1 = orig(x, paddings) - torch.manual_seed(SEED); y2, p2 = cust(x, paddings) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py index f5c7e992b..3b3feb680 100644 --- a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py @@ -18,7 +18,7 @@ from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \ GNN as CustomModel -B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph +B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims DEVICE = 'cuda' @@ -30,80 +30,89 @@ def _rand_graph(): - total_nodes, total_edges = B * N, B * E - nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE) - edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE) - senders, receivers = [], [] - for i in range(B): - offset = i * N - s = torch.randint(N, (E,), device=DEVICE) + offset - r = torch.randint(N, (E,), device=DEVICE) + offset - senders.append(s), receivers.append(r) - senders = torch.cat(senders); receivers = torch.cat(receivers) - n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32) - n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32) - return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge) + total_nodes, total_edges = B * N, B * E + nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE) + edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE) + senders, receivers = [], [] + for i in range(B): + offset = i * N + s = torch.randint(N, (E,), device=DEVICE) + offset + r = torch.randint(N, (E,), device=DEVICE) + offset + senders.append(s), receivers.append(r) + senders = torch.cat(senders) + receivers = torch.cat(receivers) + n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32) + n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32) + return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge) class GNNEquivalenceTest(parameterized.TestCase): - @parameterized.named_parameters( - dict(testcase_name='0.0', dropout_rate=0.0), - dict(testcase_name='0.2', dropout_rate=0.2), - dict(testcase_name='0.7', dropout_rate=0.7), - dict(testcase_name='1.0', dropout_rate=1.0), - ) - def test_forward(self, dropout_rate): - """Test different dropout_rates.""" + @parameterized.named_parameters( + dict(testcase_name='0.0', dropout_rate=0.0), + dict(testcase_name='0.2', dropout_rate=0.2), + dict(testcase_name='0.7', dropout_rate=0.7), + dict(testcase_name='1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + """Test different dropout_rates.""" - orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) - cust = CustomModel().to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights + orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights - graph = _rand_graph() + graph = _rand_graph() - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y1 = orig(graph) + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y1 = orig(graph) - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(graph, dropout_rate=dropout_rate) + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(graph, dropout_rate=dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) + assert_close(y1, y2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(graph) - assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(graph) + assert_close(y1, y2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name=''),) + def test_default_dropout(self): + """Test default dropout_rate.""" - @parameterized.named_parameters( - dict(testcase_name=''), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - - orig = OriginalModel().to(DEVICE) - cust = CustomModel().to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights - graph = _rand_graph() + graph = _rand_graph() - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y1 = orig(graph) + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y1 = orig(graph) - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(graph) + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(graph) - assert_close(y1, y2, atol=0, rtol=0) + assert_close(y1, y2, atol=0, rtol=0) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py index 918043cfd..03f289a68 100644 --- a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py +++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py @@ -29,85 +29,93 @@ def _rand_tokens(bs, seqlen): - return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE) + return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE) class TransformerEquivalenceTest(parameterized.TestCase): - @parameterized.named_parameters( - # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention - dict(testcase_name="0.0", dropout_rate=0.0, compile=False), - dict(testcase_name="0.2", dropout_rate=0.2, compile=False), - dict(testcase_name="0.7", dropout_rate=0.7, compile=False), - dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True), - dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True), - dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True), - ) - def test_dropout_value(self, dropout_rate, compile): - - orig = OriginalModel( - dropout_rate=dropout_rate, - attention_dropout_rate=dropout_rate - ).to(DEVICE) - cust = CustomModel().to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if compile: - orig = torch.compile(orig) - cust = torch.compile(cust) - - src = _rand_tokens(B, SRC_LEN) - tgt = _rand_tokens(B, TGT_LEN) - - for mode in ("train", "eval"): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y1 = orig(src, tgt) - - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(src, tgt, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(src, tgt) - assert_close(y1, y2, atol=0, rtol=0) - - - @parameterized.named_parameters( - dict(testcase_name="default", compile=False), - dict(testcase_name="default_compile", compile=True), - ) - def test_default(self, compile): - - orig = OriginalModel().to(DEVICE) - cust = CustomModel().to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if compile: - orig = torch.compile(orig) - cust = torch.compile(cust) - - src = _rand_tokens(B, SRC_LEN) - tgt = _rand_tokens(B, TGT_LEN) - - for mode in ("train", "eval"): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y1 = orig(src, tgt) - - torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED) - y2 = cust(src, tgt) - - assert_close(y1, y2, atol=0, rtol=0) - + @parameterized.named_parameters( + # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention + dict(testcase_name="0.0", dropout_rate=0.0, compile=False), + dict(testcase_name="0.2", dropout_rate=0.2, compile=False), + dict(testcase_name="0.7", dropout_rate=0.7, compile=False), + dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True), + dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True), + dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True), + ) + def test_dropout_value(self, dropout_rate, compile): + + orig = OriginalModel( + dropout_rate=dropout_rate, + attention_dropout_rate=dropout_rate).to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(src, tgt, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(src, tgt) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name="default", compile=False), + dict(testcase_name="default_compile", compile=True), + ) + def test_default(self, compile): + + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(src, tgt) + + assert_close(y1, y2, atol=0, rtol=0) + if __name__ == "__main__": - absltest.main() + absltest.main() From badf12453a56b078c3b156c0b850ec5e8158bb81 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:23:03 +0000 Subject: [PATCH 070/123] fix test --- tests/reference_algorithm_tests.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index c4ca514a8..58a4a5ddc 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -184,12 +184,11 @@ def __init__(self): if 'librispeech' in workload_name: self.tokenizer = _FakeTokenizer() - def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): + def init_model_fn(self, rng): # pylint: disable=line-too-long if not (FLAGS.identical and os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')): - return super().init_model_fn( - rng, dropout_rate=dropout_rate, aux_dropout_rate=aux_dropout_rate) + return super().init_model_fn(rng) if framework == 'jax': compare_module = importlib.import_module( f'tests.modeldiffs.{workload_name}.compare') @@ -201,7 +200,7 @@ def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): return (FrozenDict(**jax_utils.replicate(jax_params)), FrozenDict(**jax_utils.replicate(model_state)) if model_state is not None else model_state) - return super().init_model_fn([0], dropout_rate=0.0, aux_dropout_rate=0.0) + return super().init_model_fn([0]) @property def num_eval_train_examples(self): From eff3ea19572d0c3f3497bcf997948d7b3cbe7c69 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:28:18 +0000 Subject: [PATCH 071/123] fix lint errors --- algoperf/workloads/fastmri/fastmri_jax/models.py | 1 - algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 2 +- .../workloads/librispeech_deepspeech/librispeech_jax/models.py | 3 ++- algoperf/workloads/ogbg/ogbg_jax/models.py | 2 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 70c7fc4a5..a5fe060b9 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -13,7 +13,6 @@ github.com/facebookresearch/fastMRI/tree/main/fastmri/data """ import functools -from typing import Optional import flax.linen as nn import jax diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index d0fb4fd72..ab9df0f62 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -1,6 +1,6 @@ """ImageNet workload implemented in Jax.""" -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple from flax import jax_utils from flax import linen as nn diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 1bd998027..fab0b3259 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -489,7 +489,8 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): # Subsample input by a factor of 4 by performing strided convolutions. outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train, dropout_rate=dropout_rate) + config=config)(outputs, output_paddings, train, + dropout_rate=dropout_rate) # Run the lstm layers. for _ in range(config.num_lstm_layers): diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 06eef6187..8524bb60e 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -1,6 +1,6 @@ # Forked from the init2winit implementation here # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. -from typing import Optional, Tuple +from typing import Tuple from flax import linen as nn import jax.numpy as jnp diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index 04a9bce2e..0535aea83 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -1,6 +1,6 @@ """OGBG workload implemented in Jax.""" import functools -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple from flax import jax_utils import jax From f7fd6c7452ead2771a9ebc6cd1e50ba99d5f3d9a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:42:36 +0000 Subject: [PATCH 072/123] formatting --- algoperf/jax_utils.py | 92 +++++++++---------- .../librispeech_jax/models.py | 2 +- 2 files changed, 45 insertions(+), 49 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 369eb1b1a..214a178c6 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -13,7 +13,7 @@ # Custom Layers class Dropout(Module): - """Create a dropout layer. + """Create a dropout layer. Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. The reference dropout implementation is modified support changes to dropout rate during training by: 1) adding rate argument to the __call__ method. @@ -53,21 +53,21 @@ class Dropout(Module): rng_collection: the rng collection name to use when requesting an rng key. """ - rate: float | None = None - broadcast_dims: Sequence[int] = () - deterministic: bool | None = None - rng_collection: str = "dropout" - legacy: bool = True - - @compact - def __call__( - self, - inputs, - deterministic: bool | None = None, - rate: float | None = None, - rng: PRNGKey | None = None, - ): - """Applies a random dropout mask to the input. + rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = "dropout" + legacy: bool = True + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. Args: inputs: the inputs that should be randomly masked. @@ -81,44 +81,40 @@ def __call__( Returns: The masked inputs reweighted to preserve mean. """ - deterministic = merge_param("deterministic", - self.deterministic, - deterministic) + deterministic = merge_param("deterministic", self.deterministic, deterministic) - # Override self.rate if rate is passed to __call__ - if rate is None: - rate = self.rate + # Override self.rate if rate is passed to __call__ + if rate is None: + rate = self.rate - if self.legacy: - if rate == 0.0: - return inputs + if self.legacy: + if rate == 0.0: + return inputs - # Prevent gradient NaNs in 1.0 edge-case. - if rate == 1.0: - return jnp.zeros_like(inputs) + # Prevent gradient NaNs in 1.0 edge-case. + if rate == 1.0: + return jnp.zeros_like(inputs) - if deterministic: - return inputs + if deterministic: + return inputs - keep_prob = 1.0 - rate - if rng is None: - rng = self.make_rng(self.rng_collection) - broadcast_shape = list(inputs.shape) - for dim in self.broadcast_dims: - broadcast_shape[dim] = 1 - mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) - mask = jnp.broadcast_to(mask, inputs.shape) - return lax.select(mask, inputs, jnp.zeros_like(inputs)) + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + return lax.select(mask, inputs, jnp.zeros_like(inputs)) # Utilities for debugging def print_jax_model_summary(model, fake_inputs): - """Prints a summary of the jax module.""" - tabulate_fn = nn.tabulate( - model, - jax.random.PRNGKey(0), - console_kwargs={ - "force_terminal": False, "force_jupyter": False, "width": 240 - }, - ) - print(tabulate_fn(fake_inputs, train=False)) + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240}, + ) + print(tabulate_fn(fake_inputs, train=False)) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index fab0b3259..262fc1a95 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -489,7 +489,7 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): # Subsample input by a factor of 4 by performing strided convolutions. outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train, + config=config)(outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the lstm layers. From 8fc4cc5cd7914d698271bd50583432832e8dc98c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:44:35 +0000 Subject: [PATCH 073/123] fix spacing issues --- algoperf/jax_utils.py | 92 ++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 214a178c6..369eb1b1a 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -13,7 +13,7 @@ # Custom Layers class Dropout(Module): - """Create a dropout layer. + """Create a dropout layer. Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. The reference dropout implementation is modified support changes to dropout rate during training by: 1) adding rate argument to the __call__ method. @@ -53,21 +53,21 @@ class Dropout(Module): rng_collection: the rng collection name to use when requesting an rng key. """ - rate: float | None = None - broadcast_dims: Sequence[int] = () - deterministic: bool | None = None - rng_collection: str = "dropout" - legacy: bool = True - - @compact - def __call__( - self, - inputs, - deterministic: bool | None = None, - rate: float | None = None, - rng: PRNGKey | None = None, - ): - """Applies a random dropout mask to the input. + rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = "dropout" + legacy: bool = True + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. Args: inputs: the inputs that should be randomly masked. @@ -81,40 +81,44 @@ def __call__( Returns: The masked inputs reweighted to preserve mean. """ - deterministic = merge_param("deterministic", self.deterministic, deterministic) + deterministic = merge_param("deterministic", + self.deterministic, + deterministic) - # Override self.rate if rate is passed to __call__ - if rate is None: - rate = self.rate + # Override self.rate if rate is passed to __call__ + if rate is None: + rate = self.rate - if self.legacy: - if rate == 0.0: - return inputs + if self.legacy: + if rate == 0.0: + return inputs - # Prevent gradient NaNs in 1.0 edge-case. - if rate == 1.0: - return jnp.zeros_like(inputs) + # Prevent gradient NaNs in 1.0 edge-case. + if rate == 1.0: + return jnp.zeros_like(inputs) - if deterministic: - return inputs + if deterministic: + return inputs - keep_prob = 1.0 - rate - if rng is None: - rng = self.make_rng(self.rng_collection) - broadcast_shape = list(inputs.shape) - for dim in self.broadcast_dims: - broadcast_shape[dim] = 1 - mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) - mask = jnp.broadcast_to(mask, inputs.shape) - return lax.select(mask, inputs, jnp.zeros_like(inputs)) + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + return lax.select(mask, inputs, jnp.zeros_like(inputs)) # Utilities for debugging def print_jax_model_summary(model, fake_inputs): - """Prints a summary of the jax module.""" - tabulate_fn = nn.tabulate( - model, - jax.random.PRNGKey(0), - console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240}, - ) - print(tabulate_fn(fake_inputs, train=False)) + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={ + "force_terminal": False, "force_jupyter": False, "width": 240 + }, + ) + print(tabulate_fn(fake_inputs, train=False)) From 99c31114af33c6bd2ea5e9fcdc400131dc17bc78 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 21:53:37 +0000 Subject: [PATCH 074/123] formatting --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 4e15e4400..1daa72848 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ pytorch_gpu = [ based_on_style = "yapf" each_dict_entry_on_separate_line = false split_all_top_level_comma_separated_values = true +column_limit = 80 [tool.yapfignore] ignore_patterns = ["algoperf/_version.py"] From c2f4ed0eb0fe5ceebcd9a21c7a857de644654f2e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 18 Jun 2025 22:05:33 +0000 Subject: [PATCH 075/123] formatting --- algoperf/jax_utils.py | 30 ++++++++++++++++++------------ submission_runner.py | 3 +-- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 369eb1b1a..467606241 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -13,11 +13,15 @@ # Custom Layers class Dropout(Module): + # pylint: disable=line-too-long """Create a dropout layer. - Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. - The reference dropout implementation is modified support changes to dropout rate during training by: + Forked from + https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. + The reference dropout implementation is modified support changes + to dropout rate during training by: 1) adding rate argument to the __call__ method. - 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code. + 2) removing the if-else condition to check for edge cases, which + will trigger a recompile for jitted code. .. note:: When using :meth:`Module.apply() `, make sure @@ -47,10 +51,11 @@ class Dropout(Module): Attributes: rate: the dropout probability. (_not_ the keep rate!) broadcast_dims: dimensions that will share the same dropout mask - deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and - masked, whereas if true, no mask is applied and the inputs are returned as - is. - rng_collection: the rng collection name to use when requesting an rng key. + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. + rng_collection: the rng collection name to use when requesting an rng + key. """ rate: float | None = None @@ -71,12 +76,13 @@ def __call__( Args: inputs: the inputs that should be randomly masked. - deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and - masked, whereas if true, no mask is applied and the inputs are returned - as is. + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. rate: the dropout probability. (_not_ the keep rate!) - rng: an optional PRNGKey used as the random key, if not specified, one - will be generated using ``make_rng`` with the ``rng_collection`` name. + rng: an optional PRNGKey used as the random key, if not specified, + one will be generated using ``make_rng`` with the + ``rng_collection`` name. Returns: The masked inputs reweighted to preserve mean. diff --git a/submission_runner.py b/submission_runner.py index d076a1043..221a7c21d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -228,8 +228,7 @@ def train_once( global_batch_size=global_batch_size) logging.info('Initializing model.') with profiler.profile('Initializing model'): - model_params, model_state = workload.init_model_fn( - model_init_rng) + model_params, model_state = workload.init_model_fn(model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ 'librispeech_conformer', From b20f49dbee50e2c9b1114aa9529fb6a83a012e79 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 19 Jun 2025 00:01:36 +0000 Subject: [PATCH 076/123] formatting --- algoperf/pytorch_utils.py | 4 ++-- algoperf/workloads/fastmri/fastmri_pytorch/models.py | 1 - algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py | 2 +- .../workloads/imagenet_vit/imagenet_pytorch/workload.py | 2 +- .../librispeech_conformer/librispeech_pytorch/models.py | 5 +++-- .../librispeech_deepspeech/librispeech_pytorch/models.py | 6 ++++-- 6 files changed, 11 insertions(+), 9 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index bae26dea0..f3d8782a4 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -89,8 +89,8 @@ def __init__(self): super().__init__() self._supports_custom_dropout = True - def forward(self, input: Tensor, p: float) -> Tensor: - return F.dropout(input, p, training=self.training) + def forward(self, x: Tensor, p: float) -> Tensor: + return F.dropout(x, p, training=self.training) class CustomDropout2d(nn.Module): diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py index 3e7d7671c..8441f06c2 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py @@ -5,7 +5,6 @@ """ from functools import partial -from typing import Optional import torch from torch import nn diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index 9453780d0..6aa3306ba 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -14,7 +14,7 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \ +from algoperf.workloads.wmt.wmt_pytorch.models import \ MultiheadAttention DROPOUT_RATE = 0.0 diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index e1c6844fe..d43e90e80 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -1,7 +1,7 @@ """ImageNet ViT workload implemented in PyTorch.""" import contextlib -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple import torch from torch.nn.parallel import DistributedDataParallel as DDP diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index 10b8e585a..d917151a3 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import partial import math -from typing import Optional, Tuple +from typing import Tuple import torch from torch import nn @@ -475,7 +475,8 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) + outputs, output_paddings = self.subsample(outputs, output_paddings, + dropout_rate) for conformer in self.conformers: outputs = conformer(outputs, output_paddings, dropout_rate) outputs = self.ln(outputs) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index 644c13a16..589793dbd 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -4,7 +4,7 @@ from dataclasses import dataclass import os -from typing import Optional, Tuple +from typing import Tuple import torch from torch import nn @@ -358,7 +358,9 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) + outputs, output_paddings = self.subsample(outputs, + output_paddings, + dropout_rate) for idx in range(self.config.num_lstm_layers): if self.config.enable_residual_connections: outputs = outputs + self.lstms[idx](outputs, output_paddings) From 0ea37ee8977c5d33219baada62fdc60decfa5a8d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 19 Jun 2025 00:02:05 +0000 Subject: [PATCH 077/123] fix --- .../librispeech_conformer/librispeech_pytorch/models.py | 2 +- .../librispeech_deepspeech/librispeech_pytorch/models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index d917151a3..3a2eda4af 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -475,7 +475,7 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings, + outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate) for conformer in self.conformers: outputs = conformer(outputs, output_paddings, dropout_rate) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index 589793dbd..3d8c000e1 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -359,7 +359,7 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) outputs, output_paddings = self.subsample(outputs, - output_paddings, + output_paddings, dropout_rate) for idx in range(self.config.num_lstm_layers): if self.config.enable_residual_connections: From 594f285c00289e03a1452c2fb674ab917f51f466 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 19 Jun 2025 00:20:35 +0000 Subject: [PATCH 078/123] pylint fixes --- algoperf/pytorch_utils.py | 6 +++--- algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index f3d8782a4..bac171ca9 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -7,7 +7,7 @@ import torch from torch import Tensor import torch.distributed as dist -import torch.nn as nn +from torch import nn import torch.nn.functional as F from algoperf import spec @@ -100,8 +100,8 @@ def __init__(self): super().__init__() self._supports_custom_dropout = True - def forward(self, input: Tensor, p: float) -> Tensor: - return F.dropout2d(input, p, training=self.training) + def forward(self, x: Tensor, p: float) -> Tensor: + return F.dropout2d(x, p, training=self.training) class SequentialWithDropout(nn.Sequential): diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index ab9df0f62..914ab8f86 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -1,6 +1,6 @@ """ImageNet workload implemented in Jax.""" -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple from flax import jax_utils from flax import linen as nn @@ -54,10 +54,12 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, dropout_rate: float = models.DROPOUT_RATE ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm + del use_running_average_bn train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply({'params': params}, augmented_and_preprocessed_input_batch['inputs'], From f14ff8f7c0e391578e1d785f53cff51854de8979 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 19 Jun 2025 00:24:55 +0000 Subject: [PATCH 079/123] isort fixes --- algoperf/pytorch_utils.py | 2 +- algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 2 +- algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index bac171ca9..b81d2969a 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -5,9 +5,9 @@ import jax import tensorflow as tf import torch +from torch import nn from torch import Tensor import torch.distributed as dist -from torch import nn import torch.nn.functional as F from algoperf import spec diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 914ab8f86..0e320b9b9 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -1,6 +1,6 @@ """ImageNet workload implemented in Jax.""" -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple from flax import jax_utils from flax import linen as nn diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index 6aa3306ba..cb503cd9f 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -14,8 +14,7 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models import \ - MultiheadAttention +from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention DROPOUT_RATE = 0.0 From 2a8586a34bdd0baa8f641d2bc52fa37805186ca7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 19 Jun 2025 00:36:48 +0000 Subject: [PATCH 080/123] pylint fixes --- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 7640d69a5..0b9ce1e3c 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -65,6 +65,7 @@ def model_fn( dropout_rate: float = models.DROPOUT_RATE ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: # override super method, changing only the default dropout_rate + # pylint: disable=useless-parent-delegation return super().model_fn(params, augmented_and_preprocessed_input_batch, model_state, From ad36a7c3a409146c6bc6ebb9301b7b9beb1a6716 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 21 Jun 2025 03:06:36 +0000 Subject: [PATCH 081/123] add dropout tests --- algoperf/jax_utils.py | 4 +- .../criteo1tb/criteo1tb_jax/models_ref.py | 219 ++++++ .../fastmri/fastmri_jax/models_ref.py | 220 ++++++ .../imagenet_jax/models_ref.py | 132 ++++ .../imagenet_vit/imagenet_jax/models_ref.py | 235 ++++++ .../librispeech_jax/models_ref.py | 712 ++++++++++++++++++ .../librispeech_jax/models_ref.py | 525 +++++++++++++ .../workloads/ogbg/ogbg_jax/models_ref.py | 88 +++ algoperf/workloads/wmt/wmt_jax/models_ref.py | 604 +++++++++++++++ .../criteo1tb_jax/test_model_equivalence.py | 156 ++++ .../fastmri_jax/test_model_equivalence.py | 130 ++++ .../test_model_equivalence.py | 138 ++++ .../test_model_equivalence.py | 84 +++ .../test_model_equivalence.py | 124 +++ .../test_model_equivalence.py | 118 +++ .../wmt_pytorch_jax/test_model_equivalence.py | 121 +++ tests/test_jax_utils.py | 237 ++++++ 17 files changed, 3845 insertions(+), 2 deletions(-) create mode 100644 algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py create mode 100644 algoperf/workloads/fastmri/fastmri_jax/models_ref.py create mode 100644 algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py create mode 100644 algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py create mode 100644 algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py create mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py create mode 100644 algoperf/workloads/ogbg/ogbg_jax/models_ref.py create mode 100644 algoperf/workloads/wmt/wmt_jax/models_ref.py create mode 100644 tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py create mode 100644 tests/dropout_fix/fastmri_jax/test_model_equivalence.py create mode 100644 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py create mode 100644 tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py create mode 100644 tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py create mode 100644 tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py create mode 100644 tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py create mode 100644 tests/test_jax_utils.py diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 467606241..28a4ba8c9 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -62,7 +62,7 @@ class Dropout(Module): broadcast_dims: Sequence[int] = () deterministic: bool | None = None rng_collection: str = "dropout" - legacy: bool = True + legacy: bool = False @compact def __call__( @@ -114,7 +114,7 @@ def __call__( broadcast_shape[dim] = 1 mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) mask = jnp.broadcast_to(mask, inputs.shape) - return lax.select(mask, inputs, jnp.zeros_like(inputs)) + return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) # Utilities for debugging diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py new file mode 100644 index 000000000..8406b9eb1 --- /dev/null +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py @@ -0,0 +1,219 @@ +"""A JAX implementation of DLRM-Small.""" + +from typing import Sequence + +import flax.linen as nn +from jax import nn as jnn +import jax.numpy as jnp + + +class DLRMResNet(nn.Module): + """Define a DLRMResNet model. + + Parameters: + vocab_size: the size of a single unified embedding table. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + num_dense_features: number of dense features as the bottom mlp input. + embed_dim: embedding dimension. + """ + + vocab_size: int = 32 * 128 * 1024 # 4_194_304 + num_dense_features: int = 13 + mlp_bottom_dims: Sequence[int] = (256, 256, 256) + mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) + embed_dim: int = 128 + dropout_rate: float = 0.0 + use_layer_norm: bool = False # Unused. + embedding_init_multiplier: float = None # Unused + + @nn.compact + def __call__(self, x, train): + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) + cat_features = jnp.asarray(cat_features, dtype=jnp.int32) + + # bottom mlp + mlp_bottom_dims = self.mlp_bottom_dims + + bot_mlp_input = nn.Dense( + mlp_bottom_dims[0], + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5), + )( + bot_mlp_input) + bot_mlp_input = nn.relu(bot_mlp_input) + + for dense_dim in mlp_bottom_dims[1:]: + x = nn.Dense( + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), + )( + bot_mlp_input) + bot_mlp_input += nn.relu(x) + + base_init_fn = jnn.initializers.uniform(scale=1.0) + # Embedding table init and lookup for a single unified table. + idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + + def scaled_init(key, shape, dtype=jnp.float_): + return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) + + embedding_table = self.param('embedding_table', + scaled_init, [self.vocab_size, self.embed_dim]) + + embed_features = embedding_table[idx_lookup] + batch_size = bot_mlp_input.shape[0] + embed_features = jnp.reshape(embed_features, + (batch_size, 26 * self.embed_dim)) + top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) + mlp_input_dim = top_mlp_input.shape[1] + mlp_top_dims = self.mlp_top_dims + num_layers_top = len(mlp_top_dims) + top_mlp_input = nn.Dense( + mlp_top_dims[0], + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))( + top_mlp_input) + top_mlp_input = nn.relu(top_mlp_input) + for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: + fan_in = mlp_top_dims[layer_idx - 1] + x = nn.Dense( + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( + top_mlp_input) + x = nn.relu(x) + if self.dropout_rate and layer_idx == num_layers_top - 2: + x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + top_mlp_input += x + # In the DLRM model the last layer width is always 1. We can hardcode that + # below. + logits = nn.Dense( + 1, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))( + top_mlp_input) + return logits + + +def dot_interact(concat_features): + """Performs feature interaction operation between dense or sparse features. + Input tensors represent dense or sparse features. + Pre-condition: The tensors have been stacked along dimension 1. + Args: + concat_features: Array of features with shape [B, n_features, feature_dim]. + Returns: + activations: Array representing interacted features. + """ + batch_size = concat_features.shape[0] + + # Interact features, select upper or lower-triangular portion, and reshape. + xactions = jnp.matmul(concat_features, + jnp.transpose(concat_features, [0, 2, 1])) + feature_dim = xactions.shape[-1] + + indices = jnp.array(jnp.triu_indices(feature_dim)) + num_elems = indices.shape[1] + indices = jnp.tile(indices, [1, batch_size]) + indices0 = jnp.reshape( + jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), + [1, -1]) + indices = tuple(jnp.concatenate((indices0, indices), 0)) + activations = xactions[indices] + activations = jnp.reshape(activations, [batch_size, -1]) + return activations + + +class DlrmSmall(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + vocab_size: int = 32 * 128 * 1024 # 4_194_304. + num_dense_features: int = 13 + mlp_bottom_dims: Sequence[int] = (512, 256, 128) + mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) + embed_dim: int = 128 + dropout_rate: float = 0.0 + use_layer_norm: bool = False + embedding_init_multiplier: float = None + + @nn.compact + def __call__(self, x, train): + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) + cat_features = jnp.asarray(cat_features, dtype=jnp.int32) + + # Bottom MLP. + for dense_dim in self.mlp_bottom_dims: + bot_mlp_input = nn.Dense( + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), + )( + bot_mlp_input) + bot_mlp_input = nn.relu(bot_mlp_input) + if self.use_layer_norm: + bot_mlp_input = nn.LayerNorm()(bot_mlp_input) + bot_mlp_output = bot_mlp_input + batch_size = bot_mlp_output.shape[0] + feature_stack = jnp.reshape(bot_mlp_output, + [batch_size, -1, self.embed_dim]) + + # Embedding table look-up. + idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + + if self.embedding_init_multiplier is None: + scale = 1 / jnp.sqrt(self.vocab_size) + else: + scale = self.embedding_init_multiplier + + def scaled_init(key, shape, dtype=jnp.float_): + return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale + + embedding_table = self.param('embedding_table', + scaled_init, [self.vocab_size, self.embed_dim]) + + idx_lookup = jnp.reshape(idx_lookup, [-1]) + embed_features = embedding_table[idx_lookup] + embed_features = jnp.reshape(embed_features, + [batch_size, -1, self.embed_dim]) + if self.use_layer_norm: + embed_features = nn.LayerNorm()(embed_features) + feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) + dot_interact_output = dot_interact(concat_features=feature_stack) + top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], + axis=-1) + mlp_input_dim = top_mlp_input.shape[1] + mlp_top_dims = self.mlp_top_dims + num_layers_top = len(mlp_top_dims) + for layer_idx, fan_out in enumerate(mlp_top_dims): + fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] + top_mlp_input = nn.Dense( + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( + top_mlp_input) + if layer_idx < (num_layers_top - 1): + top_mlp_input = nn.relu(top_mlp_input) + if self.use_layer_norm: + top_mlp_input = nn.LayerNorm()(top_mlp_input) + if (self.dropout_rate is not None and self.dropout_rate > 0.0 and + layer_idx == num_layers_top - 2): + top_mlp_input = nn.Dropout( + rate=self.dropout_rate, deterministic=not train)( + top_mlp_input) + logits = top_mlp_input + return logits \ No newline at end of file diff --git a/algoperf/workloads/fastmri/fastmri_jax/models_ref.py b/algoperf/workloads/fastmri/fastmri_jax/models_ref.py new file mode 100644 index 000000000..a2d56a4b4 --- /dev/null +++ b/algoperf/workloads/fastmri/fastmri_jax/models_ref.py @@ -0,0 +1,220 @@ +"""Jax / Flax implementation of FastMRI U-Net. + +Forked from +https://github.com/google/init2winit/blob/master/init2winit/model_lib/unet.py + +Original implementation: +github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py + +Training: +github.com/facebookresearch/fastMRI/blob/main/fastmri/pl_modules/unet_module.py + +Data: +github.com/facebookresearch/fastMRI/tree/main/fastmri/data +""" +import functools +from typing import Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp + + +def _instance_norm2d(x, axes, epsilon=1e-5): + # promote x to at least float32, this avoids half precision computation + # but preserves double or complex floating points + x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x))) + mean = jnp.mean(x, axes) + mean2 = jnp.mean(jnp.square(x), axes) + # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due + # to floating point round-off errors. + var = jnp.maximum(0., mean2 - jnp.square(mean)) + stats_shape = list(x.shape) + for axis in axes: + stats_shape[axis] = 1 + mean = mean.reshape(stats_shape) + var = var.reshape(stats_shape) + y = x - mean + mul = jnp.sqrt(var + epsilon) + y /= mul + return y + + +class UNet(nn.Module): + """Jax / Flax implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + + out_channels: Number of channels in the output to the U-Net model. + channels: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + dropout_rate: Dropout probability. + """ + num_channels: int = 32 + num_pool_layers: int = 4 + out_channels = 1 + dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. + use_tanh: bool = False + use_layer_norm: bool = False + + @nn.compact + def __call__(self, x, train=True): + dropout_rate = self.dropout_rate + if dropout_rate is None: + dropout_rate = 0.0 + + # pylint: disable=invalid-name + _ConvBlock = functools.partial( + ConvBlock, + dropout_rate=dropout_rate, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm) + _TransposeConvBlock = functools.partial( + TransposeConvBlock, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm) + + down_sample_layers = [_ConvBlock(self.num_channels)] + + ch = self.num_channels + for _ in range(self.num_pool_layers - 1): + down_sample_layers.append(_ConvBlock(ch * 2)) + ch *= 2 + conv = _ConvBlock(ch * 2) + + up_conv = [] + up_transpose_conv = [] + for _ in range(self.num_pool_layers - 1): + up_transpose_conv.append(_TransposeConvBlock(ch)) + up_conv.append(_ConvBlock(ch)) + ch //= 2 + + up_transpose_conv.append(_TransposeConvBlock(ch)) + up_conv.append(_ConvBlock(ch)) + + stack = [] + output = jnp.expand_dims(x, axis=-1) + + # apply down-sampling layers + for layer in down_sample_layers: + output = layer(output, train) + stack.append(output) + output = nn.avg_pool(output, window_shape=(2, 2), strides=(2, 2)) + + output = conv(output, train) + + # apply up-sampling layers + for transpose_conv, conv in zip(up_transpose_conv, up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding_right = 0 + padding_bottom = 0 + if output.shape[-2] != downsample_layer.shape[-2]: + padding_right = 1 # padding right + if output.shape[-3] != downsample_layer.shape[-3]: + padding_bottom = 1 # padding bottom + + if padding_right or padding_bottom: + padding = ((0, 0), (0, padding_bottom), (0, padding_right), (0, 0)) + output = jnp.pad(output, padding, mode='reflect') + + output = jnp.concatenate((output, downsample_layer), axis=-1) + output = conv(output, train) + + output = nn.Conv( + self.out_channels, kernel_size=(1, 1), strides=(1, 1))( + output) + return output.squeeze(-1) + + +class ConvBlock(nn.Module): + """A Convolutional Block. + out_channels: Number of channels in the output. + dropout_rate: Dropout probability. + """ + out_channels: int + dropout_rate: float + use_tanh: bool + use_layer_norm: bool + + @nn.compact + def __call__(self, x, train=True): + """Forward function. + Note: Pytorch is NCHW and jax/flax is NHWC. + Args: + x: Input 4D tensor of shape `(N, H, W, in_channels)`. + train: deterministic or not (use init2winit naming). + Returns: + jnp.array: Output tensor of shape `(N, H, W, out_channels)`. + """ + x = nn.Conv( + features=self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + use_bias=False)( + x) + if self.use_layer_norm: + x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x) + else: + # DO NOT SUBMIT check that this comment edit is correct + # InstanceNorm2d was run with no learnable params in reference code + # so this is a simple normalization along spatial dims. + x = _instance_norm2d(x, (1, 2)) + if self.use_tanh: + activation_fn = nn.tanh + else: + activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2) + x = activation_fn(x) + # Ref code uses dropout2d which applies the same mask for the entire channel + # Replicated by using broadcast dims to have the same filter on HW + x = nn.Dropout( + self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x) + x = nn.Conv( + features=self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + use_bias=False)( + x) + if self.use_layer_norm: + x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x) + else: + x = _instance_norm2d(x, (1, 2)) + x = activation_fn(x) + x = nn.Dropout( + self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x) + return x + + +class TransposeConvBlock(nn.Module): + """A Transpose Convolutional Block. + out_channels: Number of channels in the output. + """ + out_channels: int + use_tanh: bool + use_layer_norm: bool + + @nn.compact + def __call__(self, x): + """Forward function. + Args: + x: Input 4D tensor of shape `(N, H, W, in_channels)`. + Returns: + jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`. + """ + x = nn.ConvTranspose( + self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)( + x) + x = _instance_norm2d(x, (1, 2)) + if self.use_tanh: + activation_fn = nn.tanh + else: + activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2) + x = activation_fn(x) + return x \ No newline at end of file diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py new file mode 100644 index 000000000..357dadc13 --- /dev/null +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py @@ -0,0 +1,132 @@ +"""Jax implementation of ResNet V1. + +Adapted from Flax example: +https://github.com/google/flax/blob/main/examples/imagenet/models.py. +""" + +import functools +from typing import Any, Callable, Optional, Tuple + +from flax import linen as nn +import jax.numpy as jnp + +from algoperf import spec + +ModuleDef = nn.Module + + +class ResNetBlock(nn.Module): + """ResNet block.""" + filters: int + conv: ModuleDef + norm: ModuleDef + act: Callable + strides: Tuple[int, int] = (1, 1) + bn_init_scale: float = 0. + + @nn.compact + def __call__(self, x: spec.Tensor) -> spec.Tensor: + residual = x + y = self.conv(self.filters, (3, 3), self.strides)(x) + y = self.norm()(y) + y = self.act(y) + y = self.conv(self.filters, (3, 3))(y) + y = self.norm(scale_init=nn.initializers.constant(self.bn_init_scale))(y) + + if residual.shape != y.shape or self.strides != (1, 1): + residual = self.conv( + self.filters, (1, 1), self.strides, name='Conv_proj')( + residual) + residual = self.norm(name='BatchNorm_proj')(residual) + + return self.act(residual + y) + + +class BottleneckResNetBlock(nn.Module): + """Bottleneck ResNet block.""" + filters: int + conv: ModuleDef + norm: ModuleDef + act: Callable + strides: Tuple[int, int] = (1, 1) + bn_init_scale: Optional[float] = None + + @nn.compact + def __call__(self, x: spec.Tensor) -> spec.Tensor: + residual = x + y = self.conv(self.filters, (1, 1))(x) + y = self.norm()(y) + y = self.act(y) + y = self.conv(self.filters, (3, 3), self.strides)(y) + y = self.norm()(y) + y = self.act(y) + y = self.conv(self.filters * 4, (1, 1))(y) + y = self.norm(scale_init=nn.initializers.constant(self.bn_init_scale))(y) + + if residual.shape != y.shape or self.strides != (1, 1): + residual = self.conv( + self.filters * 4, (1, 1), self.strides, name='Conv_proj')( + residual) + residual = self.norm(name='BatchNorm_proj')(residual) + + return self.act(residual + y) + + +class ResNet(nn.Module): + stage_sizes: Tuple[int] + block_cls: ModuleDef + num_classes: int + num_filters: int = 64 + dtype: Any = jnp.float32 + act: Callable = nn.relu + bn_init_scale: float = 0. + + @nn.compact + def __call__(self, + x: spec.Tensor, + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None) -> spec.Tensor: + conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm + norm = functools.partial( + nn.BatchNorm, + use_running_average=use_running_average_bn, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype) + + x = conv( + self.num_filters, (7, 7), (2, 2), + padding=[(3, 3), (3, 3)], + name='Conv_init')( + x) + x = norm(name='BatchNorm_init')(x) + x = self.act(x) + x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1))) + for i, block_size in enumerate(self.stage_sizes): + for j in range(block_size): + strides = (2, 2) if i > 0 and j == 0 else (1, 1) + x = self.block_cls( + self.num_filters * 2**i, + strides=strides, + conv=conv, + norm=norm, + act=self.act, + bn_init_scale=self.bn_init_scale)( + x) + x = jnp.mean(x, axis=(1, 2)) + x = nn.Dense( + self.num_classes, + kernel_init=nn.initializers.normal(), + dtype=self.dtype)( + x) + return x + + +ResNet18 = functools.partial( + ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock) +ResNet50 = functools.partial( + ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock) \ No newline at end of file diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py new file mode 100644 index 000000000..beb8a2eb8 --- /dev/null +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py @@ -0,0 +1,235 @@ +"""Jax implementation of refactored and simplified ViT. + +Forked from: +https://github.com/google/init2winit/blob/master/init2winit/model_lib/vit.py, +originally from https://github.com/google/big_vision with modifications noted. +""" + +from typing import Optional, Sequence, Union + +from flax import linen as nn +import jax.numpy as jnp + +from algoperf import spec + + +def posemb_sincos_2d(h: int, + w: int, + width: int, + temperature: int = 10_000., + dtype: jnp.dtype = jnp.float32) -> spec.Tensor: + """Follows the MoCo v3 logic.""" + y, x = jnp.mgrid[:h, :w] #pylint: disable=unpacking-non-sequence + + if width % 4 != 0: + raise ValueError('Width must be mult of 4 for sincos posemb.') + omega = jnp.arange(width // 4) / (width // 4 - 1) + omega = 1. / (temperature**omega) + y = jnp.einsum('m,d->md', y.flatten(), omega) + x = jnp.einsum('m,d->md', x.flatten(), omega) + pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) + return jnp.asarray(pe, dtype)[None, :, :] + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim. + use_glu: bool = False + dropout_rate: float = 0.0 + + @nn.compact + def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + """Applies Transformer MlpBlock module.""" + inits = { + 'kernel_init': nn.initializers.xavier_uniform(), + 'bias_init': nn.initializers.normal(stddev=1e-6), + } + + d = x.shape[2] + x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) + x = nn.gelu(x) + + if self.use_glu: + y = nn.Dense(self.mlp_dim, **inits)(x) + x = x * y + + x = nn.Dropout(rate=self.dropout_rate)(x, train) + x = nn.Dense(d, **inits)(x) + return x + + +class Encoder1DBlock(nn.Module): + """Single transformer encoder block (MHSA + MLP).""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim. + num_heads: int = 12 + use_glu: bool = False + use_post_layer_norm: bool = False + dropout_rate: float = 0.0 + + @nn.compact + def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + if not self.use_post_layer_norm: + y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + + y = nn.LayerNorm(name='LayerNorm_2')(x) + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + else: + y = x + y = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_0')(x) + + y = x + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_2')(x) + + return x + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + depth: int + mlp_dim: Optional[int] = None # Defaults to 4x input dim. + num_heads: int = 12 + dropout_rate: float = 0.0 + use_glu: bool = False + use_post_layer_norm: bool = False + + @nn.compact + def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: + # Input Encoder + for lyr in range(self.depth): + block = Encoder1DBlock( + name=f'encoderblock_{lyr}', + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + dropout_rate=self.dropout_rate) + x = block(x, train) + if not self.use_post_layer_norm: + return nn.LayerNorm(name='encoder_layernorm')(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + + @nn.compact + def __call__(self, x): + n, _, d = x.shape + probe = self.param('probe', + nn.initializers.xavier_uniform(), (1, 1, d), + x.dtype) + probe = jnp.tile(probe, [n, 1, 1]) + + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform())(probe, x) + + y = nn.LayerNorm()(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + return x[:, 0] + + +class ViT(nn.Module): + """ViT model.""" + + num_classes: int = 1000 + patch_size: Sequence[int] = (16, 16) + width: int = 768 + depth: int = 12 + mlp_dim: Optional[int] = None # Defaults to 4x input dim. + num_heads: int = 12 + rep_size: Union[int, bool] = True + dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. + reinit: Optional[Sequence[str]] = None + head_zeroinit: bool = True + use_glu: bool = False + use_post_layer_norm: bool = False + use_map: bool = False + + def get_posemb(self, + seqshape: tuple, + width: int, + dtype: jnp.dtype = jnp.float32) -> spec.Tensor: + return posemb_sincos_2d(*seqshape, width, dtype=dtype) + + @nn.compact + def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: + # Patch extraction + x = nn.Conv( + self.width, + self.patch_size, + strides=self.patch_size, + padding='VALID', + name='conv_patch_extract')( + x) + + n, h, w, c = x.shape + x = jnp.reshape(x, [n, h * w, c]) + + # Add posemb before adding extra token. + x = x + self.get_posemb((h, w), c, x.dtype) + + dropout_rate = self.dropout_rate + if dropout_rate is None: + dropout_rate = 0.0 + x = nn.Dropout(rate=dropout_rate)(x, not train) + + x = Encoder( + depth=self.depth, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + dropout_rate=dropout_rate, + name='Transformer')( + x, train=not train) + + if self.use_map: + x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) + else: + x = jnp.mean(x, axis=1) + + if self.rep_size: + rep_size = self.width if self.rep_size is True else self.rep_size + hid = nn.Dense(rep_size, name='pre_logits') + x = nn.tanh(hid(x)) + + if self.num_classes: + kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {} + head = nn.Dense(self.num_classes, name='head', **kw) + x = head(x) + + return x \ No newline at end of file diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py new file mode 100644 index 000000000..969d9423c --- /dev/null +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py @@ -0,0 +1,712 @@ +r"""Conformer. + +This model uses a conformer network to convert speech to text. +paper : https://arxiv.org/abs/2005.08100 + +high-level overview of Conformer encoder layer. + + x = x + 0.5 * FeedForward(x) + x = x + MHSA(x) + x = x + ConvolutionBlock(x) + x = x + 0.5 * FeedForward(x) + y = layer_norm(x) +""" + +import functools +import math +from typing import Any, List, Optional + +from flax import linen as nn +from flax import struct +import jax +import jax.numpy as jnp +import numpy as np + +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ + librispeech_preprocessor as preprocessor +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ + spectrum_augmenter + + +@struct.dataclass +class ConformerConfig: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 + dtype: Any = jnp.float32 + encoder_dim: int = 512 + num_attention_heads: int = 8 + num_encoder_layers: int = 4 + attention_dropout_rate: float = 0.0 + # If None, defaults to 0.1. + attention_residual_dropout_rate: Optional[float] = 0.1 + # If None, defaults to 0.0. + conv_residual_dropout_rate: Optional[float] = 0.0 + feed_forward_dropout_rate: float = 0.0 + # If None, defaults to 0.1. + feed_forward_residual_dropout_rate: Optional[float] = 0.1 + convolution_kernel_size: int = 5 + feed_forward_expansion_factor: int = 4 + freq_mask_count: int = 2 + freq_mask_max_bins: int = 27 + time_mask_count: int = 10 + time_mask_max_frames: int = 40 + time_mask_max_ratio: float = 0.05 + time_masks_per_frame: float = 0.0 + use_dynamic_time_mask_max_frames: bool = True + # If None, defaults to 0.1. + input_dropout_rate: Optional[float] = 0.1 + batch_norm_momentum: float = 0.999 + batch_norm_epsilon: float = 0.001 + use_specaug: bool = True + attention_temperature: float = 1.0 + activation_function_name: str = 'swish' + use_post_layer_norm: bool = True + + +class LayerNorm(nn.Module): + """Module implementing layer normalization. + + This implementation is same as in this paper: + https://arxiv.org/pdf/1607.06450.pdf. + + note: we multiply normalized inputs by (1 + scale) and initialize scale to + zeros, this differs from default flax implementation of multiplying by scale + and initializing to ones. + """ + dim: int = 0 + epsilon: float = 1e-6 + + def setup(self): + self.scale = self.param('scale', nn.initializers.zeros, [self.dim]) + self.bias = self.param('bias', nn.initializers.zeros, [self.dim]) + + @nn.compact + def __call__(self, inputs): + mean = jnp.mean(inputs, axis=[-1], keepdims=True) + var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True) + + normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) + normed_inputs *= (1 + self.scale) + normed_inputs += self.bias + + return normed_inputs + + +class Subsample(nn.Module): + """Module to perform strided convolution in order to subsample inputs. + + Attributes: + encoder_dim: model dimension of conformer. + input_dropout_rate: dropout rate for inputs. + """ + encoder_dim: int = 0 + input_dropout_rate: float = 0.0 + + @nn.compact + def __call__(self, inputs, input_paddings, train): + output_paddings = input_paddings + outputs = jnp.expand_dims(inputs, axis=-1) + + outputs, output_paddings = Conv2dSubsampling( + input_channels=1, output_channels=self.encoder_dim)( + outputs, output_paddings) + + outputs, output_paddings = Conv2dSubsampling( + input_channels=self.encoder_dim, + output_channels=self.encoder_dim)(outputs, output_paddings) + + batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape + + outputs = jnp.reshape( + outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) + + outputs = nn.Dense( + self.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform())( + outputs) + + outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)( + seq_length=outputs.shape[1]) + + outputs = nn.Dropout( + rate=self.input_dropout_rate, deterministic=not train)( + outputs) + + return outputs, output_paddings + + +class Conv2dSubsampling(nn.Module): + """Helper module used in Subsample layer. + + 1) Performs strided convolution over inputs and then applies non-linearity. + 2) Also performs strided convolution over input_paddings to return the correct + paddings for downstream layers. + """ + input_channels: int = 0 + output_channels: int = 0 + filter_stride: List[int] = (2, 2) + padding: str = 'SAME' + + def setup(self): + self.filter_shape = (3, 3, self.input_channels, self.output_channels) + self.kernel = self.param('kernel', + nn.initializers.xavier_uniform(), + self.filter_shape) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + + @nn.compact + def __call__(self, inputs, paddings): + # Computing strided convolution to subsample inputs. + feature_group_count = inputs.shape[3] // self.filter_shape[2] + outputs = jax.lax.conv_general_dilated( + lhs=inputs, + rhs=self.kernel, + window_strides=self.filter_stride, + padding=self.padding, + rhs_dilation=(1, 1), + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + feature_group_count=feature_group_count) + + outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) + outputs = nn.relu(outputs) + + # Computing correct paddings post input convolution. + input_length = paddings.shape[1] + stride = self.filter_stride[0] + + pad_len = (input_length + stride - 1) // stride * stride - input_length + out_padding = jax.lax.conv_general_dilated( + lhs=paddings[:, :, None], + rhs=jnp.ones([1, 1, 1]), + window_strides=self.filter_stride[:1], + padding=[(0, pad_len)], + dimension_numbers=('NHC', 'HIO', 'NHC')) + out_padding = jnp.squeeze(out_padding, axis=-1) + + # Mask outputs by correct paddings to ensure padded elements in inputs map + # to padded value in outputs. + outputs = outputs * \ + (1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) + return outputs, out_padding + + +class FeedForwardModule(nn.Module): + """Feedforward block of conformer layer. + """ + config: ConformerConfig + + @nn.compact + def __call__(self, inputs, padding_mask=None, train=False): + config = self.config + + inputs = LayerNorm(dim=config.encoder_dim)(inputs) + + inputs = nn.Dense( + config.encoder_dim * config.feed_forward_expansion_factor, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform())( + inputs) + if config.activation_function_name == 'swish': + activation_fn = nn.swish + elif config.activation_function_name == 'gelu': + activation_fn = nn.gelu + else: + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{config.activation_function_name}') + inputs = activation_fn(inputs) + inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)( + inputs, deterministic=not train) + + inputs = inputs * padding_mask + + inputs = nn.Dense( + config.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform())( + inputs) + inputs = inputs * padding_mask + + if config.feed_forward_residual_dropout_rate is None: + feed_forward_residual_dropout_rate = 0.1 + else: + feed_forward_residual_dropout_rate = ( + config.feed_forward_residual_dropout_rate) + inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)( + inputs, deterministic=not train) + + return inputs + + +class AddPositionalEmbedding(nn.Module): + """Adds (optionally learned) positional embeddings to the inputs. + + Attributes: + max_len: maximum possible length for the input + posemb_init: positional embedding initializer + """ + min_timescale: int = 1 + max_timescale: int = 10_000 + embedding_dim: int = 512 + + @nn.compact + def __call__(self, seq_length): + position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] + num_timescales = self.embedding_dim // 2 + log_timescale_increment = ( + math.log(float(self.max_timescale) / float(self.min_timescale)) / + jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)) + inv_timescales = self.min_timescale * jnp.exp( + jnp.arange(num_timescales, dtype=jnp.float32) * + -log_timescale_increment) + scaled_time = ( + position[:, :, jnp.newaxis] * + inv_timescales[jnp.newaxis, jnp.newaxis, :]) + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], + axis=2).astype(jnp.float32) + # Force usage of `np` rather than `jnp` to compute static values at trace + # time. + signal = jnp.pad(signal, + [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]]) + return signal + + +# Adapted from lingvo attention layer for query scaling +# https://github.com/tensorflow/lingvo/blob/7de4ca8fff3cb28c2ecb21bbd7b02a964ce727f7/lingvo/jax/layers/attentions.py#L201 +class QueryScaler(nn.Module): + """A layer to scale individual dims of the query attention matrix.""" + dim: int = 0 + + def setup(self): + self.scale = self.param('scale', nn.initializers.zeros, [self.dim]) + + @nn.compact + def __call__(self, inputs): + inputs_shape = inputs.shape + if inputs_shape[-1] != self.dim: + raise ValueError('QueryScaler expects inputs to have' + ' same last dimension as scaling param.') + + # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we + # can avoid unnecessary XLA op fusion mess on TPU. + r_softplus_0 = 1.442695041 + + scale = jnp.array(r_softplus_0, dtype=inputs.dtype) + scale *= jax.nn.softplus(self.scale) + + return inputs * scale + + +# Modifying flax linen default dot product attention function to add +# query scaling, reference to original function here : +# https://github.com/google/flax/blob/a9af38085a7a49b571cf37d375060fd683e74972/flax/linen/attention.py#L121 +def dot_product_attention(query, + key, + value, + bias=None, + mask=None, + broadcast_dropout=True, + dropout_rng=None, + dropout_rate=0., + deterministic=False, + dtype=jnp.float32, + precision=None, + temperature=1.0): + """Computes dot-product attention given query, key, and value. + + This is the core function for applying attention based on + https://arxiv.org/abs/1706.03762. It's slightly modified to add query scaling. + It calculates the attention weights given query and key and combines the + values using the attention weights. + + Note: query, key, value needn't have any batch dimensions. + + Args: + query: queries for calculating attention with shape of + `[batch..., q_length, num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of + `[batch..., kv_length, num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of + `[batch..., kv_length, num_heads, v_depth_per_head]`. + bias: bias for the attention weights. This should be broadcastable to the + shape `[batch..., num_heads, q_length, kv_length]`. + This can be used for incorporating causal masks, padding masks, + proximity bias, etc. + mask: mask for the attention weights. This should be broadcastable to the + shape `[batch..., num_heads, q_length, kv_length]`. + This can be used for incorporating causal masks. + Attention weights are masked out if their corresponding mask value + is `False`. + broadcast_dropout: bool: use a broadcasted dropout along batch dims. + dropout_rng: JAX PRNGKey: to be used for dropout + dropout_rate: dropout rate + deterministic: bool, deterministic or not (to apply dropout) + dtype: the dtype of the computation (default: float32) + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + + Returns: + Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. + """ + assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( + 'q, k, v batch dims must match.') + assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( + 'q, k, v num_heads must match.') + assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' + + # compute attention weights + query = QueryScaler(dim=query.shape[-1])(query) + attn_weights = nn.attention.dot_product_attention_weights( + query, + key, + bias, + mask, + broadcast_dropout, + dropout_rng, + dropout_rate, + deterministic, + dtype, + precision) + + # return weighted sum over values for each query position + return jnp.einsum( + '...hqk,...khd->...qhd', attn_weights, value, + precision=precision) * temperature + + +class MultiHeadedSelfAttention(nn.Module): + """Self attention sub-layer used in the Conformer layer. + + Input is first normalized using layer norm. Output is processed using + multi-headed attention. + + Note: this attention implementation uses a learned scale parameter to scale + query matrix before passing it to flax attention module. + """ + config: ConformerConfig = None + + @nn.compact + def __call__(self, inputs, paddings, train): + config = self.config + mask_paddings = 1 - paddings + attention_mask = nn.make_attention_mask( + mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) + + inputs = LayerNorm(dim=config.encoder_dim)(inputs) + attention_fn = functools.partial( + dot_product_attention, temperature=config.attention_temperature) + result = nn.MultiHeadDotProductAttention( + num_heads=config.num_attention_heads, + qkv_features=config.encoder_dim, + decode=False, + dtype=config.dtype, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + use_bias=True, + broadcast_dropout=False, + attention_fn=attention_fn, + dropout_rate=config.attention_dropout_rate, + deterministic=not train)( + inputs_q=inputs, mask=attention_mask) + + if config.attention_residual_dropout_rate is None: + attention_residual_dropout_rate = 0.1 + else: + attention_residual_dropout_rate = config.attention_residual_dropout_rate + result = nn.Dropout( + rate=attention_residual_dropout_rate, deterministic=not train)( + result) + + return result + + +class BatchNorm(nn.Module): + """Implements batch norm respecting input paddings. + + This implementation takes into account input padding by masking inputs before + computing mean and variance. + + This is inspired by lingvo jax implementation of BatchNorm: + https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92 + + and the corresponding defaults for momentum and epsilon have been copied over + from lingvo. + """ + config: ConformerConfig + + def setup(self): + dim = self.config.encoder_dim + dtype = self.config.dtype + + self.ra_mean = self.variable('batch_stats', + 'mean', + lambda s: jnp.zeros(s, dtype), + dim) + self.ra_var = self.variable('batch_stats', + 'var', + lambda s: jnp.ones(s, dtype), + dim) + + self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) + self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) + + @nn.compact + def __call__(self, + inputs, + input_paddings, + update_batch_norm, + use_running_average_bn): + rank = inputs.ndim + reduce_over_dims = list(range(0, rank - 1)) + + padding = jnp.expand_dims(input_paddings, -1) + momentum = self.config.batch_norm_momentum + epsilon = self.config.batch_norm_epsilon + + if use_running_average_bn: + mean = self.ra_mean.value + var = self.ra_var.value + + else: + # compute batch statistics + mask = 1.0 - padding + sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) + count_v = jnp.sum( + jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) + + count_v = jnp.maximum(count_v, 1.0) + mean = sum_v / count_v + + sum_vv = jnp.sum( + (inputs - mean) * (inputs - mean) * mask, + axis=reduce_over_dims, + keepdims=True) + + var = sum_vv / count_v + + if update_batch_norm: + self.ra_mean.value = momentum * \ + self.ra_mean.value + (1 - momentum) * mean + self.ra_var.value = momentum * \ + self.ra_var.value + (1 - momentum) * var + + inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) + bn_output = (inputs - mean) * inv + self.beta + bn_output *= 1.0 - padding + + return bn_output + + +class ConvolutionBlock(nn.Module): + r"""Convolution block in conformer layer. + + architecture: + + input # (batch, time, hidden_dim) + | + layer_norm(.) # (batch, time, hidden_dim) + dense(.), dense(.) # (batch, time, 2 * hidden_dim) + | / + glu(.) # (batch, time, hidden_dim) + depthwise_conv1d(.) + batch_norm(.) + act(.) + | + dense(.) + dropout(.) + | + output + """ + config: ConformerConfig + + @nn.compact + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average_bn): + config = self.config + inputs = LayerNorm(dim=config.encoder_dim)(inputs) + + input_gated1 = nn.Dense( + config.encoder_dim, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=True)( + inputs) + + input_gated2 = nn.Dense( + config.encoder_dim, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=True)( + inputs) + + inputs = input_gated1 * jax.nn.sigmoid(input_gated2) + inputs = inputs * (1 - jnp.expand_dims(input_paddings, -1)) + + inputs = nn.Conv( + features=config.encoder_dim, + kernel_size=(config.convolution_kernel_size,), + strides=(1,), + padding='SAME', + feature_group_count=config.encoder_dim, + use_bias=False, + kernel_init=nn.initializers.xavier_uniform())( + inputs) + + inputs = BatchNorm(config)(inputs, + input_paddings, + update_batch_norm, + use_running_average_bn) + if config.activation_function_name == 'swish': + activation_fn = nn.swish + elif config.activation_function_name == 'gelu': + activation_fn = nn.gelu + else: + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{config.activation_function_name}') + inputs = activation_fn(inputs) + inputs = nn.Dense( + config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())( + inputs) + + if config.conv_residual_dropout_rate is None: + conv_residual_dropout_rate = 0.0 + else: + conv_residual_dropout_rate = config.conv_residual_dropout_rate + inputs = nn.Dropout( + rate=conv_residual_dropout_rate, deterministic=not train)( + inputs) + return inputs + + +class ConformerBlock(nn.Module): + """Implements a single conformer encoder layer. + + High level overview: + + x = x + 0.5 * FeedForward(x) + x = x + MHSA(x) + x = x + ConvolutionBlock(x) + x = x + 0.5 * FeedForward(x) + + y = layer_norm(x) + + """ + config: ConformerConfig + + @nn.compact + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average): + config = self.config + padding_mask = jnp.expand_dims(1 - input_paddings, -1) + + inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( + inputs, padding_mask, train) + + inputs = inputs + MultiHeadedSelfAttention(config=self.config)( + inputs, input_paddings, train) + + inputs = inputs + \ + ConvolutionBlock(config)(inputs, + input_paddings, + train, + update_batch_norm, + use_running_average + ) + + inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( + inputs, padding_mask, train) + + if config.use_post_layer_norm: + inputs = LayerNorm(dim=config.encoder_dim)(inputs) + + return inputs + + +class Conformer(nn.Module): + """Conformer (encoder + decoder) block. + + Takes audio input signals and outputs probability distribution over vocab size + for each time step. The output is then fed into a CTC loss which eliminates + the need for alignment with targets. + """ + config: ConformerConfig + + def setup(self): + self.specaug = spectrum_augmenter.SpecAug( + freq_mask_count=self.config.freq_mask_count, + freq_mask_max_bins=self.config.freq_mask_max_bins, + time_mask_count=self.config.time_mask_count, + time_mask_max_frames=self.config.time_mask_max_frames, + time_mask_max_ratio=self.config.time_mask_max_ratio, + time_masks_per_frame=self.config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=self.config + .use_dynamic_time_mask_max_frames) + + @nn.compact + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None): + config = self.config + + outputs = inputs + output_paddings = input_paddings + + # Set BN args if not supplied for backwards compatibility + if update_batch_norm is None: + update_batch_norm = train + if use_running_average_bn is None: + use_running_average_bn = not train + + # Compute normalized log mel spectrograms from input audio signal. + preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() + outputs, output_paddings = preprocessor.MelFilterbankFrontend( + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)( + outputs, output_paddings) + + # Ablate random parts of input along temporal and frequency dimension + # following the specaug procedure in https://arxiv.org/abs/1904.08779. + if train and config.use_specaug: + outputs, output_paddings = self.specaug(outputs, output_paddings) + + # Subsample input by a factor of 4 by performing strided convolutions. + if config.input_dropout_rate is None: + input_dropout_rate = 0.1 + else: + input_dropout_rate = config.input_dropout_rate + outputs, output_paddings = Subsample( + encoder_dim=config.encoder_dim, + input_dropout_rate=input_dropout_rate)( + outputs, output_paddings, train) + + # Run the conformer encoder layers. + for _ in range(config.num_encoder_layers): + outputs = ConformerBlock(config)(outputs, + output_paddings, + train, + update_batch_norm, + use_running_average_bn) + + outputs = LayerNorm(config.encoder_dim)(outputs) + # Run the decoder which in this case is a trivial projection layer. + outputs = nn.Dense( + config.vocab_size, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform())( + outputs) + + return outputs, output_paddings \ No newline at end of file diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py new file mode 100644 index 000000000..7b7c9720a --- /dev/null +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py @@ -0,0 +1,525 @@ +r"""Deepspeech. + +This model uses a deepspeech2 network to convert speech to text. +paper : https://arxiv.org/abs/1512.02595 + +# BiLSTM code contributed by bastings@ +# github : https://github.com/bastings +# webpage : https://bastings.github.io/ +""" + +from typing import Any, Dict, List, Optional, Tuple, Union + +from flax import linen as nn +from flax import struct +import jax +from jax.experimental import rnn +import jax.numpy as jnp + +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ + librispeech_preprocessor as preprocessor +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ + spectrum_augmenter + +Array = jnp.ndarray +StateType = Union[Array, Tuple[Array, ...]] +PRNGKey = Any +Shape = Tuple[int] +Dtype = Any +Carry = Any +CarryHistory = Any +Output = Any + + +@struct.dataclass +class DeepspeechConfig: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 + dtype: Any = jnp.float32 + encoder_dim: int = 512 + num_lstm_layers: int = 6 + num_ffn_layers: int = 3 + conv_subsampling_factor: int = 2 + conv_subsampling_layers: int = 2 + use_specaug: bool = True + freq_mask_count: int = 2 + freq_mask_max_bins: int = 27 + time_mask_count: int = 10 + time_mask_max_frames: int = 40 + time_mask_max_ratio: float = 0.05 + time_masks_per_frame: float = 0.0 + use_dynamic_time_mask_max_frames: bool = True + batch_norm_momentum: float = 0.999 + batch_norm_epsilon: float = 0.001 + # If None, defaults to 0.1. + input_dropout_rate: Optional[float] = 0.1 + # If None, defaults to 0.1. + feed_forward_dropout_rate: Optional[float] = 0.1 + enable_residual_connections: bool = True + enable_decoder_layer_norm: bool = True + bidirectional: bool = True + use_tanh: bool = False + layernorm_everywhere: bool = False + + +class Subsample(nn.Module): + """Module to perform strided convolution in order to subsample inputs. + + Attributes: + encoder_dim: model dimension of conformer. + input_dropout_rate: dropout rate for inputs. + """ + config: DeepspeechConfig + + @nn.compact + def __call__(self, inputs, output_paddings, train): + config = self.config + outputs = jnp.expand_dims(inputs, axis=-1) + + outputs, output_paddings = Conv2dSubsampling( + encoder_dim=config.encoder_dim, + dtype=config.dtype, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon, + input_channels=1, + output_channels=config.encoder_dim, + use_tanh=config.use_tanh + )(outputs, output_paddings, train) + + outputs, output_paddings = Conv2dSubsampling( + encoder_dim=config.encoder_dim, + dtype=config.dtype, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon, + input_channels=config.encoder_dim, + output_channels=config.encoder_dim, + use_tanh=config.use_tanh)(outputs, output_paddings, train) + + batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape + + outputs = jnp.reshape( + outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) + + outputs = nn.Dense( + config.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform())( + outputs) + + if config.input_dropout_rate is None: + input_dropout_rate = 0.1 + else: + input_dropout_rate = config.input_dropout_rate + outputs = nn.Dropout( + rate=input_dropout_rate, deterministic=not train)( + outputs) + + return outputs, output_paddings + + +class Conv2dSubsampling(nn.Module): + """Helper module used in Subsample layer. + + 1) Performs strided convolution over inputs and then applies non-linearity. + 2) Also performs strided convolution over input_paddings to return the correct + paddings for downstream layers. + """ + input_channels: int = 0 + output_channels: int = 0 + filter_stride: List[int] = (2, 2) + padding: str = 'SAME' + encoder_dim: int = 0 + dtype: Any = jnp.float32 + batch_norm_momentum: float = 0.999 + batch_norm_epsilon: float = 0.001 + use_tanh: bool = False + + def setup(self): + self.filter_shape = (3, 3, self.input_channels, self.output_channels) + self.kernel = self.param('kernel', + nn.initializers.xavier_uniform(), + self.filter_shape) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + + @nn.compact + def __call__(self, inputs, paddings, train): + # Computing strided convolution to subsample inputs. + feature_group_count = inputs.shape[3] // self.filter_shape[2] + outputs = jax.lax.conv_general_dilated( + lhs=inputs, + rhs=self.kernel, + window_strides=self.filter_stride, + padding=self.padding, + rhs_dilation=(1, 1), + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + feature_group_count=feature_group_count) + + outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) + + if self.use_tanh: + outputs = nn.tanh(outputs) + else: + outputs = nn.relu(outputs) + + # Computing correct paddings post input convolution. + input_length = paddings.shape[1] + stride = self.filter_stride[0] + + pad_len = (input_length + stride - 1) // stride * stride - input_length + out_padding = jax.lax.conv_general_dilated( + lhs=paddings[:, :, None], + rhs=jnp.ones([1, 1, 1]), + window_strides=self.filter_stride[:1], + padding=[(0, pad_len)], + dimension_numbers=('NHC', 'HIO', 'NHC')) + out_padding = jnp.squeeze(out_padding, axis=-1) + + # Mask outputs by correct paddings to ensure padded elements in inputs map + # to padded value in outputs. + outputs = outputs * (1.0 - + jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) + + return outputs, out_padding + + +class FeedForwardModule(nn.Module): + """Feedforward block of conformer layer.""" + config: DeepspeechConfig + + @nn.compact + def __call__(self, inputs, input_paddings=None, train=False): + padding_mask = jnp.expand_dims(1 - input_paddings, -1) + config = self.config + + if config.layernorm_everywhere: + inputs = LayerNorm(config.encoder_dim)(inputs) + else: + inputs = BatchNorm(config.encoder_dim, + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon)(inputs, + input_paddings, + train) + inputs = nn.Dense( + config.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform())( + inputs) + if config.use_tanh: + inputs = nn.tanh(inputs) + else: + inputs = nn.relu(inputs) + inputs *= padding_mask + + if config.feed_forward_dropout_rate is None: + feed_forward_dropout_rate = 0.1 + else: + feed_forward_dropout_rate = config.feed_forward_dropout_rate + inputs = nn.Dropout(rate=feed_forward_dropout_rate)( + inputs, deterministic=not train) + + return inputs + + +class LayerNorm(nn.Module): + """Module implementing layer normalization. + + This implementation is same as in this paper: + https://arxiv.org/pdf/1607.06450.pdf. + + note: we multiply normalized inputs by (1 + scale) and initialize scale to + zeros, this differs from default flax implementation of multiplying by scale + and initializing to ones. + """ + dim: int = 0 + epsilon: float = 1e-6 + + def setup(self): + self.scale = self.param('scale', nn.initializers.zeros, [self.dim]) + self.bias = self.param('bias', nn.initializers.zeros, [self.dim]) + + @nn.compact + def __call__(self, inputs): + mean = jnp.mean(inputs, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True) + + normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) + normed_inputs *= (1 + self.scale) + normed_inputs += self.bias + + return normed_inputs + + +class BatchNorm(nn.Module): + """Implements batch norm respecting input paddings. + + This implementation takes into account input padding by masking inputs before + computing mean and variance. + + This is inspired by lingvo jax implementation of BatchNorm: + https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92 + + and the corresponding defaults for momentum and epsilon have been copied over + from lingvo. + """ + encoder_dim: int = 0 + dtype: Any = jnp.float32 + batch_norm_momentum: float = 0.999 + batch_norm_epsilon: float = 0.001 + + def setup(self): + dim = self.encoder_dim + dtype = self.dtype + + self.ra_mean = self.variable('batch_stats', + 'mean', + lambda s: jnp.zeros(s, dtype), + dim) + self.ra_var = self.variable('batch_stats', + 'var', + lambda s: jnp.ones(s, dtype), + dim) + + self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) + self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) + + def _get_default_paddings(self, inputs): + """Gets the default paddings for an input.""" + in_shape = list(inputs.shape) + in_shape[-1] = 1 + + return jnp.zeros(in_shape, dtype=inputs.dtype) + + @nn.compact + def __call__(self, inputs, input_paddings=None, train=False): + rank = inputs.ndim + reduce_over_dims = list(range(0, rank - 1)) + + if input_paddings is None: + padding = self._get_default_paddings(inputs) + else: + padding = jnp.expand_dims(input_paddings, -1) + + momentum = self.batch_norm_momentum + epsilon = self.batch_norm_epsilon + + if train: + mask = 1.0 - padding + sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) + count_v = jnp.sum( + jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) + + sum_v = jax.lax.psum(sum_v, axis_name='batch') + count_v = jax.lax.psum(count_v, axis_name='batch') + + count_v = jnp.maximum(count_v, 1.0) + mean = sum_v / count_v + variance = (inputs - mean) * (inputs - mean) * mask + + sum_vv = jnp.sum(variance, axis=reduce_over_dims, keepdims=True) + + sum_vv = jax.lax.psum(sum_vv, axis_name='batch') + var = sum_vv / count_v + + self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean + self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var + else: + mean = self.ra_mean.value + var = self.ra_var.value + + inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) + + bn_output = (inputs - mean) * inv + self.beta + bn_output *= 1.0 - padding + + return bn_output + # return inputs + + +class CudnnLSTM(nn.Module): + features: int + num_layers: int = 1 + dropout_rate: float = 0.0 + bidirectional: bool = False + + @nn.compact + def __call__( + self, + inputs: Array, + segmentation_mask: Optional[Array] = None, + return_carry: Optional[bool] = None, + deterministic: bool = False, + initial_states: Optional[Tuple[Array, Array]] = None, + use_cuda: bool = True, + ) -> Union[Array, Tuple[Array, Carry]]: + + if jax.devices()[0].platform != 'gpu': + use_cuda = False + + batch_size = inputs.shape[0] + input_size = inputs.shape[2] + num_directions = 2 if self.bidirectional else 1 + dropout = 0.0 if deterministic else self.dropout_rate + + weights = self.param( + 'weights', + rnn.init_lstm_weight, + input_size, + self.features, + self.num_layers, + self.bidirectional, + ) + + if initial_states is None: + h_0 = jnp.zeros( + (num_directions * self.num_layers, batch_size, self.features), + jnp.float32, + ) + c_0 = jnp.zeros( + (num_directions * self.num_layers, batch_size, self.features), + jnp.float32, + ) + else: + h_0, c_0 = initial_states + + if segmentation_mask is not None: + seq_lengths = jnp.sum(1 - segmentation_mask, axis=1, dtype=jnp.int32) + else: + seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32) + + if use_cuda: + y, h, c = rnn.lstm( + x=inputs, h_0=h_0, c_0=c_0, weights=weights, + seq_lengths=seq_lengths, input_size=input_size, + hidden_size=self.features, num_layers=self.num_layers, + dropout=dropout, bidirectional=self.bidirectional, + ) + else: + weight_ih, weight_hh, bias_ih, bias_hh = self.unpack_weights( + weights, input_size) + y, h, c = rnn.lstm_ref( + x=inputs, h_0=h_0, c_0=c_0, W_ih=weight_ih, W_hh=weight_hh, + b_ih=bias_ih, b_hh=bias_hh, seq_lengths=seq_lengths, + input_size=input_size, hidden_size=self.features, + num_layers=self.num_layers, dropout=dropout, + bidirectional=self.bidirectional, + ) + + if return_carry: + return y, (h, c) + + return y + + @nn.nowrap + def unpack_weights( + self, weights: Array, input_size: int + ) -> Tuple[ + Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]]: + return jax.experimental.rnn.unpack_lstm_weights( + weights, + input_size, + self.features, + self.num_layers, + self.bidirectional, + ) + + +class BatchRNN(nn.Module): + """Implements a single deepspeech encoder layer. + """ + config: DeepspeechConfig + + @nn.compact + def __call__(self, inputs, input_paddings, train): + config = self.config + + if config.layernorm_everywhere: + inputs = LayerNorm(config.encoder_dim)(inputs) + else: + inputs = BatchNorm(config.encoder_dim, + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon)(inputs, + input_paddings, + train) + output = CudnnLSTM( + features=config.encoder_dim // 2, + bidirectional=config.bidirectional, + num_layers=1)(inputs, input_paddings) + + return output + + +class Deepspeech(nn.Module): + """Conformer (encoder + decoder) block. + + Takes audio input signals and outputs probability distribution over vocab size + for each time step. The output is then fed into a CTC loss which eliminates + the need for alignment with targets. + """ + config: DeepspeechConfig + + def setup(self): + config = self.config + self.specaug = spectrum_augmenter.SpecAug( + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + ) + + @nn.compact + def __call__(self, inputs, input_paddings, train): + config = self.config + + outputs = inputs + output_paddings = input_paddings + + # Compute normalized log mel spectrograms from input audio signal. + preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() + outputs, output_paddings = preprocessor.MelFilterbankFrontend( + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(outputs, + output_paddings) + + # Ablate random parts of input along temporal and frequency dimension + # following the specaug procedure in https://arxiv.org/abs/1904.08779. + if config.use_specaug and train: + outputs, output_paddings = self.specaug(outputs, output_paddings) + + # Subsample input by a factor of 4 by performing strided convolutions. + outputs, output_paddings = Subsample( + config=config)(outputs, output_paddings, train) + + # Run the lstm layers. + for _ in range(config.num_lstm_layers): + if config.enable_residual_connections: + outputs = outputs + BatchRNN(config)(outputs, output_paddings, train) + else: + outputs = BatchRNN(config)(outputs, output_paddings, train) + + for _ in range(config.num_ffn_layers): + if config.enable_residual_connections: + outputs = outputs + FeedForwardModule(config=self.config)( + outputs, output_paddings, train) + else: + outputs = FeedForwardModule(config=self.config)(outputs, + output_paddings, + train) + + # Run the decoder which in this case is a trivial projection layer. + if config.enable_decoder_layer_norm: + outputs = LayerNorm(config.encoder_dim)(outputs) + + outputs = nn.Dense( + config.vocab_size, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform())( + outputs) + + return outputs, output_paddings \ No newline at end of file diff --git a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py new file mode 100644 index 000000000..f0a9e3dc1 --- /dev/null +++ b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py @@ -0,0 +1,88 @@ +# Forked from the init2winit implementation here +# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. +from typing import Optional, Tuple + +from flax import linen as nn +import jax.numpy as jnp +import jraph + + +def _make_embed(latent_dim, name): + + def make_fn(inputs): + return nn.Dense(features=latent_dim, name=name)(inputs) + + return make_fn + + +def _make_mlp(hidden_dims, dropout, activation_fn): + """Creates a MLP with specified dimensions.""" + + @jraph.concatenated_args + def make_fn(inputs): + x = inputs + for dim in hidden_dims: + x = nn.Dense(features=dim)(x) + x = nn.LayerNorm()(x) + x = activation_fn(x) + x = dropout(x) + return x + + return make_fn + + +class GNN(nn.Module): + """Defines a graph network. + The model assumes the input data is a jraph.GraphsTuple without global + variables. The final prediction will be encoded in the globals. + """ + num_outputs: int + latent_dim: int = 256 + hidden_dims: Tuple[int] = (256,) + # If None, defaults to 0.1. + dropout_rate: Optional[float] = 0.1 + num_message_passing_steps: int = 5 + activation_fn_name: str = 'relu' + + @nn.compact + def __call__(self, graph, train): + if self.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = self.dropout_rate + dropout = nn.Dropout(rate=dropout_rate, deterministic=not train) + + graph = graph._replace( + globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) + + embedder = jraph.GraphMapFeatures( + embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'), + embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding')) + graph = embedder(graph) + + if self.activation_fn_name == 'relu': + activation_fn = nn.relu + elif self.activation_fn_name == 'gelu': + activation_fn = nn.gelu + elif self.activation_fn_name == 'silu': + activation_fn = nn.silu + else: + raise ValueError( + f'Invalid activation function name: {self.activation_fn_name}') + + for _ in range(self.num_message_passing_steps): + net = jraph.GraphNetwork( + update_edge_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + update_node_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + update_global_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) + + graph = net(graph) + + # Map globals to represent the final result + decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs)) + graph = decoder(graph) + + return graph.globals \ No newline at end of file diff --git a/algoperf/workloads/wmt/wmt_jax/models_ref.py b/algoperf/workloads/wmt/wmt_jax/models_ref.py new file mode 100644 index 000000000..e1f44aaa6 --- /dev/null +++ b/algoperf/workloads/wmt/wmt_jax/models_ref.py @@ -0,0 +1,604 @@ +"""Transformer-based machine translation model. + +Reference https://github.com/google/flax/tree/main/examples/wmt. +""" + +from typing import Any, Callable, Optional + +from flax import linen as nn +from flax import struct +from jax import lax +import jax.numpy as jnp +import numpy as np + + +@struct.dataclass +class TransformerConfig: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + share_embeddings: bool = True + dtype: Any = jnp.float32 + vocab_size: int = 32000 + emb_dim: int = 1024 + num_heads: int = 16 + num_layers: int = 6 + qkv_dim: int = 1024 + mlp_dim: int = 1024 + max_len: int = 256 + activation: Callable = nn.relu + glu: bool = False + #If None, defaults to 0.1. + dropout_rate: Optional[float] = 0.1 + #If None, defaults to 0.1. + attention_dropout_rate: Optional[float] = 0.1 + attention_temp: float = 1.0 + deterministic: bool = False + decode: bool = False + kernel_init: Callable = nn.initializers.xavier_uniform() + bias_init: Callable = nn.initializers.normal(stddev=1e-6) + posemb_init: Optional[Callable] = None + pre_ln: bool = True + + +def shift_right(x, axis=1): + """Shift the input to the right by padding on axis 1.""" + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = (1, 0) + padded = jnp.pad( + x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + return padded[:, :-1] + + +def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0): + """1D Sinusoidal Position Embedding Initializer. + + Args: + max_len: maximum possible length for the input. + min_scale: float: minimum frequency-scale in sine grating. + max_scale: float: maximum frequency-scale in sine grating. + + Returns: + output: init function returning `(1, max_len, d_feature)` + """ + + def init(key, shape, dtype=np.float32): + """Sinusoidal init.""" + del key, dtype + d_feature = shape[-1] + pe = np.zeros((max_len, d_feature), dtype=np.float32) + position = np.arange(0, max_len)[:, np.newaxis] + scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) + div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) + pe[:, :d_feature // 2] = np.sin(position * div_term) + pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term) + pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] + return jnp.array(pe) + + return init + + +class AddPositionEmbs(nn.Module): + """Adds (optionally learned) positional embeddings to the inputs. + + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + decode: whether to run in single-position autoregressive mode. + """ + config: TransformerConfig + decode: bool = False + + @nn.compact + def __call__(self, inputs, inputs_positions=None): + """Applies AddPositionEmbs module. + + By default this layer uses a fixed sinusoidal embedding table. If a + learned position embedding is desired, pass an initializer to + posemb_init in the configuration. + + Args: + inputs: input data. + inputs_positions: input position indices for packed sequences. + + Returns: + output: `(bs, timesteps, in_dim)` + """ + cfg = self.config + # inputs.shape is (batch_size, seq_len, emb_dim) + assert inputs.ndim == 3, ('Number of dimensions should be 3,' + f' but it is: {inputs.ndim}') + length = inputs.shape[1] + pos_emb_shape = (1, cfg.max_len, inputs.shape[-1]) + if cfg.posemb_init is None: + # Use a fixed (non-learned) sinusoidal position embedding. + pos_embedding = sinusoidal_init(max_len=cfg.max_len)(None, + pos_emb_shape, + None) + else: + pos_embedding = self.param('pos_embedding', + cfg.posemb_init, + pos_emb_shape) + pe = pos_embedding[:, :length, :] + + # We use a cache position index for tracking decoding position. + if self.decode: + is_initialized = self.has_variable('cache', 'cache_index') + cache_index = self.variable('cache', + 'cache_index', + lambda: jnp.array(0, dtype=jnp.uint32)) + if is_initialized: + i = cache_index.value + cache_index.value = i + 1 + _, _, df = pos_embedding.shape + pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)), (1, 1, df)) + if inputs_positions is None: + # normal unpacked case: + return inputs + pe + else: + # for packed data we need to use known position indices: + return inputs + jnp.take(pe[0], inputs_positions, axis=0) + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block. + + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + out_dim: optionally specify out dimension. + """ + config: TransformerConfig + out_dim: Optional[int] = None + + @nn.compact + def __call__(self, inputs): + """Applies Transformer MlpBlock module.""" + cfg = self.config + actual_out_dim = ( + inputs.shape[-1] if self.out_dim is None else self.out_dim) + x = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init)( + inputs) + x = cfg.activation(x) + if cfg.glu: + y = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init)( + inputs) + x = x * y + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + output = nn.Dense( + actual_out_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init)( + x) + output = nn.Dropout(rate=dropout_rate)( + output, deterministic=cfg.deterministic) + return output + + +class Encoder1DBlock(nn.Module): + """Transformer encoder layer. + + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + """ + config: TransformerConfig + + @nn.compact + def __call__(self, inputs, encoder_mask=None): + """Applies Encoder1DBlock module. + + Args: + inputs: input data. + encoder_mask: encoder self-attention mask. + + Returns: + output after transformer encoder block. + """ + cfg = self.config + pre_ln = cfg.pre_ln + + # Attention block. + assert inputs.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs + if cfg.attention_dropout_rate is None: + attention_dropout_rate = 0.1 + else: + attention_dropout_rate = cfg.attention_dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=attention_dropout_rate, + deterministic=cfg.deterministic)( + cfg.attention_temp * x, x, mask=encoder_mask) + + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = x + inputs + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # MLP block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = MlpBlock(config=cfg)(y) + + return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) + + +class EncoderDecoder1DBlock(nn.Module): + """Transformer encoder-decoder layer. + + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + """ + config: TransformerConfig + + @nn.compact + def __call__(self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None): + """Applies EncoderDecoder1DBlock module. + + Args: + targets: input data for decoder + encoded: input data from encoder + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. + + Returns: + output after transformer encoder-decoder block. + """ + cfg = self.config + pre_ln = cfg.pre_ln + + # Decoder block. + assert targets.ndim == 3 + x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets + + if cfg.attention_dropout_rate is None: + attention_dropout_rate = 0.1 + else: + attention_dropout_rate = cfg.attention_dropout_rate + x = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=attention_dropout_rate, + deterministic=cfg.deterministic, + decode=cfg.decode)( + cfg.attention_temp * x, x, mask=decoder_mask) + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = x + targets + if not pre_ln: + x = nn.LayerNorm(dtype=cfg.dtype)(x) + + # Encoder-Decoder block. + y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x + y = nn.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=attention_dropout_rate, + deterministic=cfg.deterministic)( + cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) + + y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + y = y + x + if not pre_ln: + y = nn.LayerNorm(dtype=cfg.dtype)(y) + + # MLP block. + z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y + z = MlpBlock(config=cfg)(z) + + return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation. + + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. + """ + config: TransformerConfig + shared_embedding: Any = None + + @nn.compact + def __call__(self, inputs, inputs_positions=None, encoder_mask=None): + """Applies Transformer model on the inputs. + + Args: + inputs: input data + inputs_positions: input subsequence positions for packed examples. + encoder_mask: decoder self-attention mask. + + Returns: + output of a transformer encoder. + """ + cfg = self.config + assert inputs.ndim == 2 # (batch, len) + + # Input Embedding + if self.shared_embedding is None: + input_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0)) + else: + input_embed = self.shared_embedding + x = inputs.astype('int32') + x = input_embed(x) + x = AddPositionEmbs( + config=cfg, decode=False, name='posembed_input')( + x, inputs_positions=inputs_positions) + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + + x = x.astype(cfg.dtype) + + # Input Encoder + for lyr in range(cfg.num_layers): + x = Encoder1DBlock( + config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask) + + encoded = ( + nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x) + if cfg.pre_ln else x) + + return encoded + + +class Decoder(nn.Module): + """Transformer Model Decoder for sequence to sequence translation. + + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. + """ + config: TransformerConfig + shared_embedding: Any = None + + @nn.compact + def __call__(self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None): + """Applies Transformer model on the inputs. + + Args: + encoded: encoded input data from encoder. + targets: target inputs. + targets_positions: input subsequence positions for packed examples. + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. + + Returns: + output of a transformer decoder. + """ + cfg = self.config + + assert encoded.ndim == 3 # (batch, len, depth) + assert targets.ndim == 2 # (batch, len) + + # Target Embedding + if self.shared_embedding is None: + output_embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0)) + else: + output_embed = self.shared_embedding + + y = targets.astype('int32') + if not cfg.decode: + y = shift_right(y) + y = output_embed(y) + y = AddPositionEmbs( + config=cfg, decode=cfg.decode, name='posembed_output')( + y, inputs_positions=targets_positions) + if cfg.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = cfg.dropout_rate + y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + + y = y.astype(cfg.dtype) + + # Target-Input Decoder + for lyr in range(cfg.num_layers): + y = EncoderDecoder1DBlock( + config=cfg, name=f'encoderdecoderblock_{lyr}')( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask) + y = ( + nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y) + if cfg.pre_ln else y) + + # Use the transpose of embedding matrix for logit transform. + logits = output_embed.attend(y.astype(jnp.float32)) + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + return logits + + +class Transformer(nn.Module): + """Transformer Model for sequence to sequence translation. + + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + """ + config: TransformerConfig + + def setup(self): + cfg = self.config + + if cfg.share_embeddings: + if cfg.vocab_size is not None: + assert cfg.vocab_size == cfg.vocab_size, ( + "can't share embedding with different vocab sizes.") + self.shared_embedding = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0)) + else: + self.shared_embedding = None + + self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) + self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) + + def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): + """Applies Transformer encoder-branch on the inputs. + + Args: + inputs: input data. + inputs_positions: input subsequence positions for packed examples. + inputs_segmentation: input segmentation info for packed examples. + + Returns: + encoded feature array from the transformer encoder. + """ + cfg = self.config + # Make padding attention mask. + encoder_mask = nn.make_attention_mask( + inputs > 0, inputs > 0, dtype=cfg.dtype) + # Add segmentation block-diagonal attention mask if using segmented data. + if inputs_segmentation is not None: + encoder_mask = nn.combine_masks( + encoder_mask, + nn.make_attention_mask( + inputs_segmentation, + inputs_segmentation, + jnp.equal, + dtype=cfg.dtype)) + return self.encoder( + inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask) + + def decode( + self, + encoded, + inputs, # only needed for masks + targets, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None): + """Applies Transformer decoder-branch on encoded-input and target. + + Args: + encoded: encoded input data from encoder. + inputs: input data (only needed for masking). + targets: target data. + targets_positions: target subsequence positions for packed examples. + inputs_segmentation: input segmentation info for packed examples. + targets_segmentation: target segmentation info for packed examples. + + Returns: + logits array from transformer decoder. + """ + cfg = self.config + + # Make padding attention masks. + if cfg.decode: + decoder_mask = None + encoder_decoder_mask = nn.make_attention_mask( + jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype) + else: + decoder_mask = nn.combine_masks( + nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype), + nn.make_causal_mask(targets, dtype=cfg.dtype)) + encoder_decoder_mask = nn.make_attention_mask( + targets > 0, inputs > 0, dtype=cfg.dtype) + + # Add segmentation block-diagonal attention masks if using segmented data. + if inputs_segmentation is not None: + decoder_mask = nn.combine_masks( + decoder_mask, + nn.make_attention_mask( + targets_segmentation, + targets_segmentation, + jnp.equal, + dtype=cfg.dtype)) + encoder_decoder_mask = nn.combine_masks( + encoder_decoder_mask, + nn.make_attention_mask( + targets_segmentation, + inputs_segmentation, + jnp.equal, + dtype=cfg.dtype)) + logits = self.decoder( + encoded, + targets, + targets_positions=targets_positions, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask) + return logits.astype(self.config.dtype) + + def __call__(self, + inputs, + targets, + inputs_positions=None, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None): + """Applies Transformer model on the inputs. + + Args: + inputs: input data. + targets: target data. + inputs_positions: input subsequence positions for packed examples. + targets_positions: target subsequence positions for packed examples. + inputs_segmentation: input segmentation info for packed examples. + targets_segmentation: target segmentation info for packed examples. + + Returns: + logits array from full transformer. + """ + encoded = self.encode( + inputs, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation) + + return self.decode( + encoded, + inputs, # only used for masks + targets, + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation) \ No newline at end of file diff --git a/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py new file mode 100644 index 000000000..10aeaa650 --- /dev/null +++ b/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py @@ -0,0 +1,156 @@ +""" +Runs fwd pass with random input for our DLRM models and compares outputs. +Run it as: + python3 tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py +""" + +import os + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +# import equinox as eqx + +from jax.tree_util import tree_structure, tree_leaves, tree_map + + +def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8): + """ + A custom function to check if two PyTrees are equal, handling floats with a tolerance. + """ + # 1. Check if the structures are the same + if tree_structure(a) != tree_structure(b): + return False + + # 2. Define a comparison function for leaves + def leaf_comparator(x, y): + # Use allclose for floating-point JAX arrays + if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating): + return jnp.allclose(x, y, rtol=rtol, atol=atol) + # Use standard equality for everything else + else: + return x == y + + # 3. Map the comparison function over the trees and check if all results are True + # We also need to flatten the results of the tree_map and check if all are True + comparison_tree = tree_map(leaf_comparator, a, b) + all_equal = all(tree_leaves(comparison_tree)) + + return all_equal + +from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import \ + DLRMResNet as OriginalDLRMResNet +from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import \ + DlrmSmall as OriginalDlrmSmall +from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \ + DLRMResNet as CustomDLRMResNet +from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \ + DlrmSmall as CustomDlrmSmall + +BATCH, DENSE, SPARSE = 16, 13, 26 +FEATURES = DENSE + SPARSE +VOCAB = 1000 +DEVICE = 'cuda' +SEED = 1996 + + +class ModelEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='DLRMResNet, p=0.0', + model='dlrm_resnet', + dropout_rate=0.0), + dict( + testcase_name='DlrmSmall, p=0.0', + model='dlrm_small', + dropout_rate=0.0), + dict( + testcase_name='DLRMResNet, p=0.1', + model='dlrm_resnet', + dropout_rate=0.1), + dict( + testcase_name='DlrmSmall, p=0.1', + model='dlrm_small', + dropout_rate=0.1), + ) + def test_forward(self, model, dropout_rate): + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) + ) + + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + fake_batch = jnp.ones((2, 39)) + assert dropout_rate == 0.1 + orig_model = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) + cust_model = CustCls(vocab_size=VOCAB) + initial_params_original = orig_model.init({'params': rng}, + fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + fake_batch, + train=False) + assert pytrees_are_equal( + initial_params_original, initial_params_custom, rtol=1e-6) + + x = jax.random.normal(data_rng, shape=(BATCH, FEATURES)) + + for mode in ('train', 'eval'): + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + x, + train=train, + rngs={'dropout': dropout_rng}) + y2 = cust_model.apply( + initial_params_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}) + + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) + + @parameterized.named_parameters( + dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), + dict(testcase_name='DlrmSmall, default', model='dlrm_small'), + ) + def test_default_dropout(self, model): + """Test default dropout_rate.""" + OrigCls, CustCls = ( + (OriginalDLRMResNet, CustomDLRMResNet) + if model == 'dlrm_resnet' + else (OriginalDlrmSmall, CustomDlrmSmall) + ) + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + fake_batch = jnp.ones((2, 39)) + orig_model = OrigCls(vocab_size=VOCAB) + cust_model = CustCls(vocab_size=VOCAB) + initial_params_original = orig_model.init({'params': rng}, + fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + fake_batch, + train=False) + + x = jax.random.normal(data_rng, shape=(BATCH, FEATURES)) + + for mode in ('train', 'eval'): + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + x, + train=train, + rngs={'dropout': dropout_rng}) + y2 = cust_model.apply( + initial_params_custom, x, train=train, rngs={'dropout': dropout_rng}) + + assert jnp.allclose(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py new file mode 100644 index 000000000..1d318e8c6 --- /dev/null +++ b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py @@ -0,0 +1,130 @@ +""" +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. +Run it as: + python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +""" + +import os + +from absl.testing import absltest +from absl.testing import parameterized +import torch +from torch.testing import assert_close + +from algoperf.workloads.fastmri.fastmri_pytorch.models import \ + UNet as OriginalUNet +from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \ + UNet as CustomUNet + +BATCH, IN_CHANS, H, W = 4, 1, 256, 256 +OUT_CHANS, C, LAYERS = 1, 32, 4 +DEVICE = 'cuda' +TORCH_COMPILE = False +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + + +class FastMRIModeEquivalenceTest(parameterized.TestCase): + + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0) + y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different values of dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalUNet( + IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig) + cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters( + dict(testcase_name='default', use_tanh=False, use_layer_norm=False), + dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False), + dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), + dict(testcase_name='both', use_tanh=True, use_layer_norm=True), + ) + def test_arch_configs(self, use_tanh, use_layer_norm): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalUNet( + IN_CHANS, + OUT_CHANS, + C, + LAYERS, + dropout_rate=dropout_rate, + use_tanh=use_tanh, + use_layer_norm=use_layer_norm).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomUNet( + IN_CHANS, + OUT_CHANS, + C, + LAYERS, + use_tanh=use_tanh, + use_layer_norm=use_layer_norm).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + if TORCH_COMPILE: + orig = torch.compile(orig) + cust = torch.compile(cust) + + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters( + dict(testcase_name=''),) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py new file mode 100644 index 000000000..f51eaec7e --- /dev/null +++ b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py @@ -0,0 +1,138 @@ +""" +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. +Run it as: + python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +""" + +import itertools +import os + +from absl.testing import absltest +from absl.testing import parameterized +import torch +from torch.testing import assert_close + +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ + ViT as OriginalVit +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import \ + ViT as CustomVit + +# Model / test hyper-params +BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W) +WIDTH, DEPTH, HEADS = 256, 4, 8 +DROPOUT_RATE = None +DEVICE = 'cuda' +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + + +class ImageNetVitModeEquivalenceTest(parameterized.TestCase): + + def fwd_pass(self, orig, cust, dropout_rate): + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(0) + y2 = cust(x, dropout_rate) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.1', dropout_rate=0.1), + dict(testcase_name='p=0.6', dropout_rate=0.6), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_dropout_values(self, dropout_rate): + """Test different dropout_rates.""" + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters([ + dict( + testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}", + use_glu=use_glu, + use_post_ln=use_post_ln, + use_map=use_map, + ) for use_glu, + use_post_ln, + use_map in itertools.product([False, True], repeat=3) + ]) + def test_arch(self, use_glu, use_post_ln, use_map): + """Test different architecture configurations, fixed dropout_rate.""" + dropout_rate = 0.1 + + torch.manual_seed(SEED) + orig = OriginalVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + dropout_rate=dropout_rate, + ).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomVit( + width=WIDTH, + depth=DEPTH, + num_heads=HEADS, + use_glu=use_glu, + use_post_layer_norm=use_post_ln, + use_map=use_map, + ).to(DEVICE) + + cust.load_state_dict(orig.state_dict()) # sync weights + self.fwd_pass(orig, cust, dropout_rate) + + @parameterized.named_parameters( + dict(testcase_name=''),) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) + cust.load_state_dict(orig.state_dict()) # sync weights + + x = torch.randn(BATCH, C, H, W, device=DEVICE) + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(0) + y1 = orig(x) + torch.manual_seed(0) + y2 = cust(x) + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py new file mode 100644 index 000000000..02f3a3d84 --- /dev/null +++ b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py @@ -0,0 +1,84 @@ +""" +Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs. +Run with: + python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py + +NOTE: we don't test for default dropout_rate values, since they changed. +""" + +import os + +from absl.testing import absltest +from absl.testing import parameterized +import torch +from torch.testing import assert_close + +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ + ConformerConfig as OriginalConfig +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ + ConformerEncoderDecoder as OriginalModel +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \ + ConformerConfig as CustomConfig +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \ + ConformerEncoderDecoder as CustomModel + +N_LAYERS = 3 +B, T = 32, 36_000 +DEVICE = 'cuda' + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(mode=True) +SEED = 1996 + + +class ConformerEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( + num_encoder_layers=N_LAYERS, + attention_residual_dropout_rate=dropout_rate, + conv_residual_dropout_rate=dropout_rate, + feed_forward_residual_dropout_rate=dropout_rate, + input_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py new file mode 100644 index 000000000..7d6a94592 --- /dev/null +++ b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py @@ -0,0 +1,124 @@ +""" +Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs. +Run with: + python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py + +`dropout_rate` controls the following args: +- `input_dropout_rate` (if None, 0.1 +- `feed_forward_dropout_rate` (if None, 0.1) +""" + +import os + +from absl.testing import absltest +from absl.testing import parameterized +import torch +from torch.testing import assert_close + +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ + DeepspeechConfig as OriginalConfig +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ + DeepspeechEncoderDecoder as OriginalModel +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \ + DeepspeechConfig as CustomConfig +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \ + DeepspeechEncoderDecoder as CustomModel + +B, T = 32, 30_000 +DEVICE = 'cuda' +TORCH_COMPILE = False + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(mode=True) +SEED = 1996 + + +class DeepSpeechEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='p=0.0', dropout_rate=0.0), + dict(testcase_name='p=0.2', dropout_rate=0.2), + dict(testcase_name='p=0.7', dropout_rate=0.7), + dict(testcase_name='p=1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + """Test different dropout_rate values.""" + + torch.manual_seed(SEED) + orig = OriginalModel( + OriginalConfig( + num_lstm_layers=2, + num_ffn_layers=2, + input_dropout_rate=dropout_rate, + feed_forward_dropout_rate=dropout_rate, + )).to(DEVICE) + + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig( + num_lstm_layers=2, + num_ffn_layers=2, + )).to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig) + cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name=''),) + def test_default_dropout(self): + """Test default dropout_rate.""" + + torch.manual_seed(SEED) + orig = OriginalModel(OriginalConfig(num_lstm_layers=2, + num_ffn_layers=2)).to(DEVICE) + torch.manual_seed(SEED) + cust = CustomModel(CustomConfig(num_lstm_layers=2, + num_ffn_layers=2)).to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + if TORCH_COMPILE: + orig = torch.compile(orig) + cust = torch.compile(cust) + + x = torch.randn(B, T, device=DEVICE) + paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + torch.manual_seed(SEED) + y1, p1 = orig(x, paddings) + torch.manual_seed(SEED) + y2, p2 = cust(x, paddings) + assert_close(y1, y2, atol=0, rtol=0) + assert_close(p1, p2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py new file mode 100644 index 000000000..3b3feb680 --- /dev/null +++ b/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py @@ -0,0 +1,118 @@ +""" +Runs fwd pass with random graphs for OGBG GNN models and compares outputs. +Run with: + python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py +""" + +import os +import random + +from absl.testing import absltest +from absl.testing import parameterized +from jraph import GraphsTuple +import numpy as np +import torch +from torch.testing import assert_close + +from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel +from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \ + GNN as CustomModel + +B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph +NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims +DEVICE = 'cuda' + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) +SEED = 1996 + + +def _rand_graph(): + total_nodes, total_edges = B * N, B * E + nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE) + edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE) + senders, receivers = [], [] + for i in range(B): + offset = i * N + s = torch.randint(N, (E,), device=DEVICE) + offset + r = torch.randint(N, (E,), device=DEVICE) + offset + senders.append(s), receivers.append(r) + senders = torch.cat(senders) + receivers = torch.cat(receivers) + n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32) + n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32) + return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge) + + +class GNNEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='0.0', dropout_rate=0.0), + dict(testcase_name='0.2', dropout_rate=0.2), + dict(testcase_name='0.7', dropout_rate=0.7), + dict(testcase_name='1.0', dropout_rate=1.0), + ) + def test_forward(self, dropout_rate): + """Test different dropout_rates.""" + + orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + graph = _rand_graph() + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y1 = orig(graph) + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(graph, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(graph) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name=''),) + def test_default_dropout(self): + """Test default dropout_rate.""" + + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + orig.load_state_dict(cust.state_dict()) # sync weights + + graph = _rand_graph() + + for mode in ('train', 'eval'): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y1 = orig(graph) + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(graph) + + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py new file mode 100644 index 000000000..03f289a68 --- /dev/null +++ b/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py @@ -0,0 +1,121 @@ +""" +Runs fwd pass with random input for WMT Transformer models and compares outputs. +Run with: + python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py +""" + +import os +import random + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +import torch +from torch.testing import assert_close + +from algoperf.workloads.wmt.wmt_pytorch.models import \ + Transformer as OriginalModel +from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \ + Transformer as CustomModel + +B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000 +DEVICE = "cuda" +SEED = 1996 + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.use_deterministic_algorithms(True) + + +def _rand_tokens(bs, seqlen): + return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE) + + +class TransformerEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention + dict(testcase_name="0.0", dropout_rate=0.0, compile=False), + dict(testcase_name="0.2", dropout_rate=0.2, compile=False), + dict(testcase_name="0.7", dropout_rate=0.7, compile=False), + dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True), + dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True), + dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True), + ) + def test_dropout_value(self, dropout_rate, compile): + + orig = OriginalModel( + dropout_rate=dropout_rate, + attention_dropout_rate=dropout_rate).to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(src, tgt, dropout_rate=dropout_rate) + + assert_close(y1, y2, atol=0, rtol=0) + + if mode == 'eval': # one extra test: omit dropout at eval + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(src, tgt) + assert_close(y1, y2, atol=0, rtol=0) + + @parameterized.named_parameters( + dict(testcase_name="default", compile=False), + dict(testcase_name="default_compile", compile=True), + ) + def test_default(self, compile): + + orig = OriginalModel().to(DEVICE) + cust = CustomModel().to(DEVICE) + + orig.load_state_dict(cust.state_dict()) # sync weights + + if compile: + orig = torch.compile(orig) + cust = torch.compile(cust) + + src = _rand_tokens(B, SRC_LEN) + tgt = _rand_tokens(B, TGT_LEN) + + for mode in ("train", "eval"): + getattr(orig, mode)() + getattr(cust, mode)() + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y1 = orig(src, tgt) + + torch.manual_seed(SEED) + random.seed(SEED) + np.random.seed(SEED) + y2 = cust(src, tgt) + + assert_close(y1, y2, atol=0, rtol=0) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py new file mode 100644 index 000000000..04713915f --- /dev/null +++ b/tests/test_jax_utils.py @@ -0,0 +1,237 @@ +""" +Test algoperf.jax_utils.Dropout by comparing to flax.linen.Dropout +Run it as: pytest +""" + +import os + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import flax.linen as nn + +from jax.tree_util import tree_structure, tree_leaves, tree_map +from algoperf.jax_utils import Dropout +from functools import partial + + +SEED = 1996 +DEFAULT_DROPOUT = 0.5 + + +def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8): + """ + A custom function to check if two PyTrees are equal, handling floats with a tolerance. + """ + # 1. Check if the structures are the same + if tree_structure(a) != tree_structure(b): + return False + + # 2. Define a comparison function for leaves + def leaf_comparator(x, y): + # Use allclose for floating-point JAX arrays + if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating): + return jnp.allclose(x, y, rtol=rtol, atol=atol) + # Use standard equality for everything else + else: + return x == y + + # 3. Map the comparison function over the trees and check if all results are True + # We also need to flatten the results of the tree_map and check if all are True + comparison_tree = tree_map(leaf_comparator, a, b) + all_equal = all(tree_leaves(comparison_tree)) + + return all_equal + + +class LegacyDropoutModel(nn.Module): + dropout_rate: float = DEFAULT_DROPOUT + + @nn.compact + def __call__(self, x, train): + return nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + + +class DropoutModel(nn.Module): + + @nn.compact + def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT): + return Dropout( + rate=dropout_rate, deterministic=not train)( + x, rate=dropout_rate) + + +class ModelEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name="Dropout, p=0.0, train", dropout_rate=0.0, + mode="train"), + dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"), + dict( + testcase_name="Dropout, p=0.1, train", dropout_rate=0.1, + mode="train"), + dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"), + ) + def test_forward(self, dropout_rate, mode): + """ Compare forward pass of Dropout layer to flax.linen.Dropout in train and + eval mode. + """ + + # initialize models + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + fake_batch = jnp.ones((10,)) + orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) + cust_model = DropoutModel() + + initial_variables_original = orig_model.init({"params": rng}, + fake_batch, + train=False) + initial_variables_custom = cust_model.init({"params": rng}, + fake_batch, + train=False) + + assert pytrees_are_equal( + initial_variables_original, initial_variables_custom, rtol=1e-6) + + # forward pass + x = jnp.ones((10,)) + + train = mode == "train" + y1 = orig_model.apply( + initial_variables_original, + x, + train=train, + rngs={"dropout": dropout_rng}) + y2 = cust_model.apply( + initial_variables_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={"dropout": dropout_rng}, + ) + + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) + + @parameterized.named_parameters( + dict( + testcase_name="Dropout, p=0.0, train", dropout_rate=0.0, + mode="train"), + dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"), + dict( + testcase_name="Dropout, p=0.1, train", dropout_rate=0.1, + mode="train"), + dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"), + ) + def test_dropout_update(self, dropout_rate, mode): + """Call forward pass of Dropout layer with two different dropout rates + and check that the output matches to flax.linen.Dropout in train and + eval mode. + """ + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + fake_batch = jnp.ones((10,)) + orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) + cust_model = DropoutModel() + + initial_variables_original = orig_model.init({"params": rng}, + fake_batch, + train=False) + + initial_variables_custom = cust_model.init({"params": rng}, + fake_batch, + train=False) + + assert pytrees_are_equal( + initial_variables_original, initial_variables_custom, rtol=1e-6) + + # forward pass + x = jnp.ones((10,)) + + train = mode == "train" + y1 = orig_model.apply( + initial_variables_original, + x, + train=train, + rngs={"dropout": dropout_rng}) + + _ = cust_model.apply( + initial_variables_custom, + x, + train=train, + dropout_rate=0.9, + rngs={"dropout": dropout_rng}, + ) + + y2 = cust_model.apply( + initial_variables_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={"dropout": dropout_rng}, + ) + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) + + @parameterized.named_parameters( + dict( + testcase_name="Dropout, p=0.0, train", dropout_rate=0.0, + mode="train"), + dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"), + dict( + testcase_name="Dropout, p=0.1, train", dropout_rate=0.1, + mode="train"), + dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"), + ) + def test_jitted_updates(self, dropout_rate, mode): + """ Compare forward pass of Dropout layer to flax.linen.Dropout in train and + eval mode. + """ + + # initialize models + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + fake_batch = jnp.ones((10,)) + orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) + cust_model = DropoutModel() + + initial_variables_original = orig_model.init({"params": rng}, + fake_batch, + train=False) + initial_variables_custom = cust_model.init({"params": rng}, + fake_batch, + train=False) + + assert pytrees_are_equal( + initial_variables_original, initial_variables_custom, rtol=1e-6) + + # forward pass + x = jnp.ones((10,)) + + train = mode == "train" + jitted_original_apply = jax.jit( + partial(orig_model.apply), static_argnames=['train']) + jitted_custom_apply = jax.jit( + partial(cust_model.apply), static_argnames=['train']) + + + def multiple_fwd_passes_custom_layer(): + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: + y2 = jitted_custom_apply( + initial_variables_custom, + x, + train=train, + dropout_rate=d, + rngs={"dropout": dropout_rng}, + ) + return y2 + + def multiple_fwd_passes_original_layer(): + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: + y1 = jitted_original_apply( + initial_variables_original, + x, + train=train, + rngs={"dropout": dropout_rng}) + +if __name__ == "__main__": + absltest.main() From d3f25d8ec082f37541a92f340e41cbb0b9a6b269 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 21 Jun 2025 08:05:03 +0000 Subject: [PATCH 082/123] add tests --- .../criteo1tb/criteo1tb_jax/models_ref.py | 2 +- .../librispeech_jax/models_ref.py | 6 +- algoperf/workloads/ogbg/ogbg_jax/models.py | 2 +- .../workloads/ogbg/ogbg_jax/models_ref.py | 16 +- .../fastmri_jax/test_model_equivalence.py | 174 ++++++++-------- .../test_model_equivalence.py | 188 ++++++++---------- .../test_model_equivalence.py | 147 ++++++++------ .../test_model_equivalence.py | 161 +++++++-------- .../test_model_equivalence.py | 118 ----------- .../wmt_pytorch_jax/test_model_equivalence.py | 121 ----------- 10 files changed, 337 insertions(+), 598 deletions(-) delete mode 100644 tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py delete mode 100644 tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py index 8406b9eb1..d0352e290 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py @@ -23,7 +23,7 @@ class DLRMResNet(nn.Module): mlp_bottom_dims: Sequence[int] = (256, 256, 256) mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) embed_dim: int = 128 - dropout_rate: float = 0.0 + dropout_rate: float = 0.1 use_layer_norm: bool = False # Unused. embedding_init_multiplier: float = None # Unused diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py index 969d9423c..f168833d3 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py @@ -36,12 +36,12 @@ class ConformerConfig: encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 - attention_dropout_rate: float = 0.0 + attention_dropout_rate: float = 0.1 # If None, defaults to 0.1. attention_residual_dropout_rate: Optional[float] = 0.1 # If None, defaults to 0.0. - conv_residual_dropout_rate: Optional[float] = 0.0 - feed_forward_dropout_rate: float = 0.0 + conv_residual_dropout_rate: Optional[float] = 0.1 + feed_forward_dropout_rate: float = 0.1 # If None, defaults to 0.1. feed_forward_residual_dropout_rate: Optional[float] = 0.1 convolution_kernel_size: int = 5 diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 8524bb60e..7207e033d 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -93,4 +93,4 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs)) graph = decoder(graph) - return graph.globals + return graph.globals \ No newline at end of file diff --git a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py index f0a9e3dc1..ca3d89426 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py @@ -15,7 +15,7 @@ def make_fn(inputs): return make_fn -def _make_mlp(hidden_dims, dropout, activation_fn): +def _make_mlp(hidden_dims, activation_fn, train, dropout_rate): """Creates a MLP with specified dimensions.""" @jraph.concatenated_args @@ -25,7 +25,7 @@ def make_fn(inputs): x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) x = activation_fn(x) - x = dropout(x) + x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x) return x return make_fn @@ -46,11 +46,7 @@ class GNN(nn.Module): @nn.compact def __call__(self, graph, train): - if self.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = self.dropout_rate - dropout = nn.Dropout(rate=dropout_rate, deterministic=not train) + dropout_rate = self.dropout_rate graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) @@ -73,11 +69,11 @@ def __call__(self, graph, train): for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( update_edge_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), update_node_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), update_global_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) + self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate)) graph = net(graph) diff --git a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py index 1d318e8c6..fbb6d2499 100644 --- a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py +++ b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py @@ -6,125 +6,113 @@ import os + from absl.testing import absltest from absl.testing import parameterized -import torch -from torch.testing import assert_close +import jax +import jax.numpy as jnp +# import equinox as eqx + -from algoperf.workloads.fastmri.fastmri_pytorch.models import \ +from algoperf.workloads.fastmri.fastmri_jax.models_ref import \ UNet as OriginalUNet -from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \ +from algoperf.workloads.fastmri.fastmri_jax.models import \ UNet as CustomUNet BATCH, IN_CHANS, H, W = 4, 1, 256, 256 OUT_CHANS, C, LAYERS = 1, 32, 4 -DEVICE = 'cuda' -TORCH_COMPILE = False SEED = 1996 -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(True) - -class FastMRIModeEquivalenceTest(parameterized.TestCase): - - def fwd_pass(self, orig, cust, dropout_rate): - x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(0) - y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) +class ModelEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.1', dropout_rate=0.1), - dict(testcase_name='p=0.7', dropout_rate=0.7), - dict(testcase_name='p=1.0', dropout_rate=1.0), + dict( + testcase_name='UNet, p=0.0', + dropout_rate=0.0), + dict( + testcase_name='UNet, p=0.1', + dropout_rate=0.1), ) - def test_dropout_values(self, dropout_rate): - """Test different values of dropout_rate.""" + def test_forward(self, dropout_rate): + OrigCls, CustCls = (OriginalUNet, CustomUNet) - torch.manual_seed(SEED) - orig = OriginalUNet( - IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - cust.load_state_dict(orig.state_dict()) # sync weights - if TORCH_COMPILE: - orig = torch.compile(orig) - cust = torch.compile(cust) + kwargs = dict(num_pool_layers = LAYERS, num_channels=IN_CHANS) + orig_model = OrigCls(**kwargs) + cust_model = CustCls(**kwargs) - self.fwd_pass(orig, cust, dropout_rate) + fake_batch = jnp.ones((BATCH, IN_CHANS, H, W)) - @parameterized.named_parameters( - dict(testcase_name='default', use_tanh=False, use_layer_norm=False), - dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False), - dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), - dict(testcase_name='both', use_tanh=True, use_layer_norm=True), - ) - def test_arch_configs(self, use_tanh, use_layer_norm): - """Test different architecture configurations, fixed dropout_rate.""" - dropout_rate = 0.1 - - torch.manual_seed(SEED) - orig = OriginalUNet( - IN_CHANS, - OUT_CHANS, - C, - LAYERS, - dropout_rate=dropout_rate, - use_tanh=use_tanh, - use_layer_norm=use_layer_norm).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomUNet( - IN_CHANS, - OUT_CHANS, - C, - LAYERS, - use_tanh=use_tanh, - use_layer_norm=use_layer_norm).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - if TORCH_COMPILE: - orig = torch.compile(orig) - cust = torch.compile(cust) - - self.fwd_pass(orig, cust, dropout_rate) + initial_params_original = orig_model.init({'params': rng}, + fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + fake_batch, + train=False) + + # fwd + x = jax.random.normal(data_rng, shape=(BATCH, H, W)) + + for mode in ('train', 'eval'): + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + x, + train=train, + rngs={'dropout': dropout_rng}) + y2 = cust_model.apply( + initial_params_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}) + + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) @parameterized.named_parameters( - dict(testcase_name=''),) + dict(testcase_name='UNet, default'), + ) def test_default_dropout(self): """Test default dropout_rate.""" + OrigCls, CustCls = (OriginalUNet, CustomUNet) - torch.manual_seed(SEED) - orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - cust.load_state_dict(orig.state_dict()) # sync weights - x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + + kwargs = dict(num_pool_layers=LAYERS, + num_channels=IN_CHANS, + ) + orig_model = OrigCls(**kwargs) + cust_model = CustCls(**kwargs) + fake_batch = jnp.ones((2, IN_CHANS, H, W)) + + initial_params_original = orig_model.init({'params': rng}, + fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + fake_batch, + train=False) + + # fwd + x = jax.random.normal(data_rng, shape=(BATCH, H, W)) + + for mode in ('train', 'eval'): + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + x, + train=train, + rngs={'dropout': dropout_rng}) + y2 = cust_model.apply( + initial_params_custom, x, train=train, rngs={'dropout': dropout_rng}) + + assert jnp.allclose(y1, y2, atol=0, rtol=0) if __name__ == '__main__': absltest.main() diff --git a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py index f51eaec7e..6806ca99f 100644 --- a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py +++ b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py @@ -1,138 +1,110 @@ """ Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. Run it as: - python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py + python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py """ -import itertools import os from absl.testing import absltest from absl.testing import parameterized -import torch -from torch.testing import assert_close +import jax +import jax.numpy as jnp -from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_vit.imagenet_jax.models_ref import \ ViT as OriginalVit -from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import \ +from algoperf.workloads.imagenet_vit.imagenet_jax.models import \ ViT as CustomVit # Model / test hyper-params -BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W) -WIDTH, DEPTH, HEADS = 256, 4, 8 -DROPOUT_RATE = None -DEVICE = 'cuda' -SEED = 1996 +INPUT_SHAPE = (2, 224, 124, 3) +SEED = 1994 -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(True) +class ImageNetVitModeEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='ViT, p=0.0', + dropout_rate=0.0), + dict( + testcase_name='ViT, p=0.1', + dropout_rate=0.1), + ) + def test_forward(self, dropout_rate): + OrigCls, CustCls = (OriginalVit, CustomVit) -class ImageNetVitModeEquivalenceTest(parameterized.TestCase): + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + + orig_model = OrigCls() + cust_model = CustCls() + + fake_batch = jnp.ones(INPUT_SHAPE) + + initial_params_original = orig_model.init({'params': rng}, + fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + fake_batch, + train=False) + + # fwd + x = jax.random.normal(data_rng, shape=INPUT_SHAPE) - def fwd_pass(self, orig, cust, dropout_rate): - x = torch.randn(BATCH, C, H, W, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(0) - y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + x, + train=train, + rngs={'dropout': dropout_rng}) + y2 = cust_model.apply( + initial_params_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}) + + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.1', dropout_rate=0.1), - dict(testcase_name='p=0.6', dropout_rate=0.6), - dict(testcase_name='p=1.0', dropout_rate=1.0), + dict(testcase_name='UNet, default'), ) - def test_dropout_values(self, dropout_rate): - """Test different dropout_rates.""" - - torch.manual_seed(SEED) - orig = OriginalVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - dropout_rate=dropout_rate, - ).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - ).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - self.fwd_pass(orig, cust, dropout_rate) - - @parameterized.named_parameters([ - dict( - testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}", - use_glu=use_glu, - use_post_ln=use_post_ln, - use_map=use_map, - ) for use_glu, - use_post_ln, - use_map in itertools.product([False, True], repeat=3) - ]) - def test_arch(self, use_glu, use_post_ln, use_map): - """Test different architecture configurations, fixed dropout_rate.""" - dropout_rate = 0.1 - - torch.manual_seed(SEED) - orig = OriginalVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - use_glu=use_glu, - use_post_layer_norm=use_post_ln, - use_map=use_map, - dropout_rate=dropout_rate, - ).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - use_glu=use_glu, - use_post_layer_norm=use_post_ln, - use_map=use_map, - ).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - self.fwd_pass(orig, cust, dropout_rate) - - @parameterized.named_parameters( - dict(testcase_name=''),) def test_default_dropout(self): """Test default dropout_rate.""" + OrigCls, CustCls = (OriginalVit, CustomVit) + + + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + + orig_model = OrigCls() + cust_model = CustCls() + + fake_batch = jnp.ones(INPUT_SHAPE) + + initial_params_original = orig_model.init({'params': rng}, + fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + fake_batch, + train=False) - torch.manual_seed(SEED) - orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) - cust.load_state_dict(orig.state_dict()) # sync weights + # fwd + x = jax.random.normal(data_rng, INPUT_SHAPE) - x = torch.randn(BATCH, C, H, W, device=DEVICE) for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + x, + train=train, + rngs={'dropout': dropout_rng}) + y2 = cust_model.apply( + initial_params_custom, x, train=train, rngs={'dropout': dropout_rng}) + assert jnp.allclose(y1, y2, atol=0, rtol=0) if __name__ == '__main__': - absltest.main() + absltest.main() \ No newline at end of file diff --git a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py index 02f3a3d84..c5d849dce 100644 --- a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py @@ -1,84 +1,117 @@ """ -Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs. -Run with: - python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py - -NOTE: we don't test for default dropout_rate values, since they changed. +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. +Run it as: + python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py """ import os from absl.testing import absltest from absl.testing import parameterized -import torch -from torch.testing import assert_close +import jax +import jax.numpy as jnp -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - ConformerConfig as OriginalConfig -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - ConformerEncoderDecoder as OriginalModel -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \ - ConformerConfig as CustomConfig -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \ - ConformerEncoderDecoder as CustomModel +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import ConformerConfig as CustClsConfig +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import Conformer as CustCls -N_LAYERS = 3 -B, T = 32, 36_000 -DEVICE = 'cuda' +from algoperf.workloads.librispeech_conformer.librispeech_jax.models_ref import ConformerConfig as OrigClsConfig +from algoperf.workloads.librispeech_conformer.librispeech_jax.models_ref import Conformer as OrigCls -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(mode=True) -SEED = 1996 +# Model / test hyper-params +INPUT_SHAPE = [(3200,), (3200,)] +SEED = 1994 -class ConformerEquivalenceTest(parameterized.TestCase): +class ModeEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.2', dropout_rate=0.2), - dict(testcase_name='p=0.7', dropout_rate=0.7), - dict(testcase_name='p=1.0', dropout_rate=1.0), + dict( + testcase_name='Conformer, p=0.0', + dropout_rate=0.0), + dict( + testcase_name='Conformer, p=0.1', + dropout_rate=0.1), ) def test_forward(self, dropout_rate): - torch.manual_seed(SEED) - orig = OriginalModel( - OriginalConfig( - num_encoder_layers=N_LAYERS, - attention_residual_dropout_rate=dropout_rate, - conv_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, - )).to(DEVICE) + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) + orig_model = OrigCls(OrigClsConfig(attention_dropout_rate=dropout_rate, + attention_residual_dropout_rate = dropout_rate, + conv_residual_dropout_rate = dropout_rate, + feed_forward_dropout_rate = dropout_rate, + feed_forward_residual_dropout_rate = dropout_rate, + input_dropout_rate=dropout_rate)) + cust_model = CustCls(CustClsConfig()) - orig.load_state_dict(cust.state_dict()) # sync weights + fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE] - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + initial_params_original = orig_model.init({'params': rng}, + *fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + *fake_batch, + train=False) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() + # fwd + x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] + + for train in [True]: + (y1, _), _ = orig_model.apply( + initial_params_original, + *x, + train=train, + rngs={'dropout': dropout_rng}, + mutable=['batch_stats'],) + (y2, _), _ = cust_model.apply( + initial_params_custom, + *x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}, + mutable=['batch_stats']) + + assert jnp.allclose(y1, y2) - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) + @parameterized.named_parameters( + dict(testcase_name='Conformer, default'), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + + orig_model = OrigCls(OrigClsConfig()) + cust_model = CustCls(CustClsConfig()) + + fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE] + + initial_params_original = orig_model.init({'params': rng}, + *fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + *fake_batch, + train=False) + + # fwd + x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] + + for mode in ('train', 'eval'): + train = mode == 'train' + (y1, _), _ = orig_model.apply( + initial_params_original, + *x, + train=train, + rngs={'dropout': dropout_rng}, mutable=['batch_stats']) + (y2, _), _ = cust_model.apply( + initial_params_custom, *x, train=train, rngs={'dropout': dropout_rng}, mutable=['batch_stats']) + + + assert jnp.allclose(y1, y2) if __name__ == '__main__': - absltest.main() + absltest.main() \ No newline at end of file diff --git a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py index 7d6a94592..5155199f5 100644 --- a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py +++ b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py @@ -1,124 +1,113 @@ """ -Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs. -Run with: - python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py - -`dropout_rate` controls the following args: -- `input_dropout_rate` (if None, 0.1 -- `feed_forward_dropout_rate` (if None, 0.1) +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. +Run it as: + python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py """ import os from absl.testing import absltest from absl.testing import parameterized -import torch -from torch.testing import assert_close +import jax +import jax.numpy as jnp -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechConfig as OriginalConfig -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechEncoderDecoder as OriginalModel -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \ - DeepspeechConfig as CustomConfig -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \ - DeepspeechEncoderDecoder as CustomModel +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models import DeepspeechConfig as CustClsConfig +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models import Deepspeech as CustCls -B, T = 32, 30_000 -DEVICE = 'cuda' -TORCH_COMPILE = False +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models_ref import DeepspeechConfig as OrigClsConfig +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models_ref import Deepspeech as OrigCls -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(mode=True) -SEED = 1996 +# Model / test hyper-params +INPUT_SHAPE = [(3200,), (3200,)] +SEED = 1994 -class DeepSpeechEquivalenceTest(parameterized.TestCase): +class ModeEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.2', dropout_rate=0.2), - dict(testcase_name='p=0.7', dropout_rate=0.7), - dict(testcase_name='p=1.0', dropout_rate=1.0), + dict( + testcase_name='Conformer, p=0.0', + dropout_rate=0.0), + dict( + testcase_name='Conformer, p=0.1', + dropout_rate=0.1), ) def test_forward(self, dropout_rate): - """Test different dropout_rate values.""" - torch.manual_seed(SEED) - orig = OriginalModel( - OriginalConfig( - num_lstm_layers=2, - num_ffn_layers=2, - input_dropout_rate=dropout_rate, - feed_forward_dropout_rate=dropout_rate, - )).to(DEVICE) + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig( - num_lstm_layers=2, - num_ffn_layers=2, - )).to(DEVICE) + orig_model = OrigCls(OrigClsConfig) + cust_model = CustCls(CustClsConfig) - orig.load_state_dict(cust.state_dict()) # sync weights + fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE] - if TORCH_COMPILE: - orig = torch.compile(orig) - cust = torch.compile(cust) + initial_params_original = orig_model.init({'params': rng}, + *fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + *fake_batch, + train=False) - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + # fwd + x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) + train = mode == 'train' + (y1, _), _ = orig_model.apply( + initial_params_original, + *x, + train=train, + rngs={'dropout': dropout_rng}, + mutable=['batch_stats'],) + (y2, _), _ = cust_model.apply( + initial_params_custom, + *x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}, + mutable=['batch_stats']) - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) + assert jnp.allclose(y1, y2) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) @parameterized.named_parameters( - dict(testcase_name=''),) + dict(testcase_name='Conformer, default'), + ) def test_default_dropout(self): """Test default dropout_rate.""" + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - torch.manual_seed(SEED) - orig = OriginalModel(OriginalConfig(num_lstm_layers=2, - num_ffn_layers=2)).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig(num_lstm_layers=2, - num_ffn_layers=2)).to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights + orig_model = OrigCls(OrigClsConfig) + cust_model = CustCls(CustClsConfig) - if TORCH_COMPILE: - orig = torch.compile(orig) - cust = torch.compile(cust) + fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE] - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) + initial_params_original = orig_model.init({'params': rng}, + *fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + *fake_batch, + train=False) + + # fwd + x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) + train = mode == 'train' + (y1, _), _ = orig_model.apply( + initial_params_original, + *x, + train=train, + rngs={'dropout': dropout_rng}, mutable=['batch_stats']) + (y2, _), _ = cust_model.apply( + initial_params_custom, *x, train=train, rngs={'dropout': dropout_rng}, mutable=['batch_stats']) + + + assert jnp.allclose(y1, y2) if __name__ == '__main__': - absltest.main() + absltest.main() \ No newline at end of file diff --git a/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py deleted file mode 100644 index 3b3feb680..000000000 --- a/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py +++ /dev/null @@ -1,118 +0,0 @@ -""" -Runs fwd pass with random graphs for OGBG GNN models and compares outputs. -Run with: - python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py -""" - -import os -import random - -from absl.testing import absltest -from absl.testing import parameterized -from jraph import GraphsTuple -import numpy as np -import torch -from torch.testing import assert_close - -from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel -from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \ - GNN as CustomModel - -B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph -NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims -DEVICE = 'cuda' - -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(True) -SEED = 1996 - - -def _rand_graph(): - total_nodes, total_edges = B * N, B * E - nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE) - edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE) - senders, receivers = [], [] - for i in range(B): - offset = i * N - s = torch.randint(N, (E,), device=DEVICE) + offset - r = torch.randint(N, (E,), device=DEVICE) + offset - senders.append(s), receivers.append(r) - senders = torch.cat(senders) - receivers = torch.cat(receivers) - n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32) - n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32) - return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge) - - -class GNNEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict(testcase_name='0.0', dropout_rate=0.0), - dict(testcase_name='0.2', dropout_rate=0.2), - dict(testcase_name='0.7', dropout_rate=0.7), - dict(testcase_name='1.0', dropout_rate=1.0), - ) - def test_forward(self, dropout_rate): - """Test different dropout_rates.""" - - orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) - cust = CustomModel().to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights - - graph = _rand_graph() - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y1 = orig(graph) - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(graph, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(graph) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name=''),) - def test_default_dropout(self): - """Test default dropout_rate.""" - - orig = OriginalModel().to(DEVICE) - cust = CustomModel().to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights - - graph = _rand_graph() - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y1 = orig(graph) - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(graph) - - assert_close(y1, y2, atol=0, rtol=0) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py deleted file mode 100644 index 03f289a68..000000000 --- a/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -Runs fwd pass with random input for WMT Transformer models and compares outputs. -Run with: - python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py -""" - -import os -import random - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np -import torch -from torch.testing import assert_close - -from algoperf.workloads.wmt.wmt_pytorch.models import \ - Transformer as OriginalModel -from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \ - Transformer as CustomModel - -B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000 -DEVICE = "cuda" -SEED = 1996 - -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(True) - - -def _rand_tokens(bs, seqlen): - return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE) - - -class TransformerEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention - dict(testcase_name="0.0", dropout_rate=0.0, compile=False), - dict(testcase_name="0.2", dropout_rate=0.2, compile=False), - dict(testcase_name="0.7", dropout_rate=0.7, compile=False), - dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True), - dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True), - dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True), - ) - def test_dropout_value(self, dropout_rate, compile): - - orig = OriginalModel( - dropout_rate=dropout_rate, - attention_dropout_rate=dropout_rate).to(DEVICE) - cust = CustomModel().to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if compile: - orig = torch.compile(orig) - cust = torch.compile(cust) - - src = _rand_tokens(B, SRC_LEN) - tgt = _rand_tokens(B, TGT_LEN) - - for mode in ("train", "eval"): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y1 = orig(src, tgt) - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(src, tgt, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(src, tgt) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name="default", compile=False), - dict(testcase_name="default_compile", compile=True), - ) - def test_default(self, compile): - - orig = OriginalModel().to(DEVICE) - cust = CustomModel().to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if compile: - orig = torch.compile(orig) - cust = torch.compile(cust) - - src = _rand_tokens(B, SRC_LEN) - tgt = _rand_tokens(B, TGT_LEN) - - for mode in ("train", "eval"): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y1 = orig(src, tgt) - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(src, tgt) - - assert_close(y1, y2, atol=0, rtol=0) - - -if __name__ == "__main__": - absltest.main() From caacb84cb4c59226e2031a5c8d2a715242830662 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 21 Jun 2025 08:05:59 +0000 Subject: [PATCH 083/123] add tests --- .../ogbg_jax/test_model_equivalence.py | 133 ++++++++++++++++++ .../wmt_jax/test_model_equivalence.py | 129 +++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 tests/dropout_fix/ogbg_jax/test_model_equivalence.py create mode 100644 tests/dropout_fix/wmt_jax/test_model_equivalence.py diff --git a/tests/dropout_fix/ogbg_jax/test_model_equivalence.py b/tests/dropout_fix/ogbg_jax/test_model_equivalence.py new file mode 100644 index 000000000..1b5b8a180 --- /dev/null +++ b/tests/dropout_fix/ogbg_jax/test_model_equivalence.py @@ -0,0 +1,133 @@ +""" +Runs fwd pass with random input for OGBG +""" + +import os + +import jraph + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp + +from algoperf.workloads.ogbg.ogbg_jax.models_ref import \ + GNN as OrigCls +from algoperf.workloads.ogbg.ogbg_jax.models import \ + GNN as CustCls + +# Model / test hyper-params +SEED = 1994 + +class ModeEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='OGBG, p=0.0', + dropout_rate=0.0), + dict( + testcase_name='OGBG, p=0.1', + dropout_rate=0.1), + ) + def test_forward(self, dropout_rate): + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + + orig_model = OrigCls(num_outputs=128, dropout_rate=dropout_rate) + cust_model = CustCls(num_outputs=128) + + fake_batch = jraph.GraphsTuple( + n_node=jnp.asarray([1]), + n_edge=jnp.asarray([1]), + nodes=jnp.ones((1, 9)), + edges=jnp.ones((1, 3)), + globals=jnp.zeros((1, 128)), + senders=jnp.asarray([0]), + receivers=jnp.asarray([0])) + + initial_params_original = orig_model.init({'params': rng}, + fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + fake_batch, + train=False) + + # fwd + x = jraph.GraphsTuple( + n_node=jnp.asarray([1]), + n_edge=jnp.asarray([1]), + nodes=jnp.ones((1, 9)), + edges=jnp.ones((1, 3)), + globals=jnp.zeros((1, 128)), + senders=jnp.asarray([0]), + receivers=jnp.asarray([0])) + + for mode in ('train', 'eval'): + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + x, + train=train, + rngs={'dropout': dropout_rng}) + y2 = cust_model.apply( + initial_params_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}) + + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) + + @parameterized.named_parameters( + dict(testcase_name='OGBG, default'), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + + + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + + orig_model = OrigCls(num_outputs=128) + cust_model = CustCls(num_outputs=128) + + fake_batch = jraph.GraphsTuple( + n_node=jnp.asarray([1]), + n_edge=jnp.asarray([1]), + nodes=jnp.ones((1, 9)), + edges=jnp.ones((1, 3)), + globals=jnp.zeros((1, 128)), + senders=jnp.asarray([0]), + receivers=jnp.asarray([0])) + + initial_params_original = orig_model.init({'params': rng}, + fake_batch, + train=False) + initial_params_custom = cust_model.init({'params': rng}, + fake_batch, + train=False) + + # fwd + x = jraph.GraphsTuple( + n_node=jnp.asarray([1]), + n_edge=jnp.asarray([1]), + nodes=jnp.ones((1, 9)), + edges=jnp.ones((1, 3)), + globals=jnp.zeros((1, 128)), + senders=jnp.asarray([0]), + receivers=jnp.asarray([0])) + + for mode in ('train', 'eval'): + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + x, + train=train, + rngs={'dropout': dropout_rng}) + y2 = cust_model.apply( + initial_params_custom, x, train=train, rngs={'dropout': dropout_rng}) + + assert jnp.allclose(y1, y2, atol=0, rtol=0) + +if __name__ == '__main__': + absltest.main() \ No newline at end of file diff --git a/tests/dropout_fix/wmt_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_jax/test_model_equivalence.py new file mode 100644 index 000000000..c0443bb50 --- /dev/null +++ b/tests/dropout_fix/wmt_jax/test_model_equivalence.py @@ -0,0 +1,129 @@ +""" +Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. +Run it as: + python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py +""" + +import os + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp + +from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig as CustClsConfig +from algoperf.workloads.wmt.wmt_jax.models import Transformer as CustCls + +from algoperf.workloads.wmt.wmt_jax.models_ref import TransformerConfig as OrigClsConfig +from algoperf.workloads.wmt.wmt_jax.models_ref import Transformer as OrigCls + + +# Model / test hyper-params +SEED = 1994 + +class ModeEquivalenceTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='WMT, p=0.0', + dropout_rate=0.0), + dict( + testcase_name='WMT p=0.1', + dropout_rate=0.1), + ) + def test_forward(self, dropout_rate): + + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + + orig_model = OrigCls(OrigClsConfig) + cust_model = CustCls(CustClsConfig) + + init_fake_batch_size = 8 + input_shape = (init_fake_batch_size, 256) + target_shape = (init_fake_batch_size, 256) + + initial_params_original = orig_model.init({'params': rng}, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + train=False) + initial_params_custom = cust_model.init({'params': rng}, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + train=False) + + # fwd + + for mode in ('train', 'eval'): + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + train=train, + rngs={'dropout': dropout_rng}, + mutable=['batch_stats'],) + y2 = cust_model.apply( + initial_params_custom, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}, + mutable=['batch_stats']) + + for i in range(len(y1)): + assert jnp.allclose(y1[i], y2[i]) + + + + @parameterized.named_parameters( + dict(testcase_name='WMT, default'), + ) + def test_default_dropout(self): + """Test default dropout_rate.""" + # init model + rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + + orig_model = OrigCls(OrigClsConfig) + cust_model = CustCls(CustClsConfig) + + init_fake_batch_size = 8 + input_shape = (init_fake_batch_size, 256) + target_shape = (init_fake_batch_size, 256) + + initial_params_original = orig_model.init({'params': rng}, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + train=False) + initial_params_custom = cust_model.init({'params': rng}, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + train=False) + + # fwd + x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] + + for mode in ('train', 'eval'): + train = mode == 'train' + y1 = orig_model.apply( + initial_params_original, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + train=train, + rngs={'dropout': dropout_rng}, mutable=['batch_stats']) + y2 = cust_model.apply( + initial_params_custom, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + train=train, rngs={'dropout': dropout_rng}, + mutable=['batch_stats']) + + print(jax.tree.map(lambda x: x.shape, y1)) + + for i in range(len(y1)): + assert jnp.allclose(y1[i], y2[i]) + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file From 8b0a12529870be76cd216835f943462424ec5131 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 21 Jun 2025 08:28:09 +0000 Subject: [PATCH 084/123] fix wmt test --- .../wmt_jax/test_model_equivalence.py | 104 ++++++++---------- 1 file changed, 44 insertions(+), 60 deletions(-) diff --git a/tests/dropout_fix/wmt_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_jax/test_model_equivalence.py index c0443bb50..d420bf691 100644 --- a/tests/dropout_fix/wmt_jax/test_model_equivalence.py +++ b/tests/dropout_fix/wmt_jax/test_model_equivalence.py @@ -26,18 +26,20 @@ class ModeEquivalenceTest(parameterized.TestCase): @parameterized.named_parameters( dict( testcase_name='WMT, p=0.0', - dropout_rate=0.0), + dropout_rate=0.0, + train=True), dict( testcase_name='WMT p=0.1', - dropout_rate=0.1), + dropout_rate=0.1, + train=False), ) - def test_forward(self, dropout_rate): + def test_forward(self, dropout_rate, train): # init model rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - orig_model = OrigCls(OrigClsConfig) - cust_model = CustCls(CustClsConfig) + orig_model = OrigCls(OrigClsConfig(deterministic=not train, attention_dropout_rate=dropout_rate, dropout_rate=dropout_rate)) + cust_model = CustCls(CustClsConfig(deterministic=not train)) init_fake_batch_size = 8 input_shape = (init_fake_batch_size, 256) @@ -45,48 +47,39 @@ def test_forward(self, dropout_rate): initial_params_original = orig_model.init({'params': rng}, jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - train=False) + jnp.ones(target_shape, jnp.float32)) initial_params_custom = cust_model.init({'params': rng}, jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - train=False) + jnp.ones(target_shape, jnp.float32),) # fwd - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - train=train, - rngs={'dropout': dropout_rng}, - mutable=['batch_stats'],) - y2 = cust_model.apply( - initial_params_custom, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - train=train, - dropout_rate=dropout_rate, - rngs={'dropout': dropout_rng}, - mutable=['batch_stats']) - - for i in range(len(y1)): - assert jnp.allclose(y1[i], y2[i]) + y1 = orig_model.apply( + initial_params_original, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + rngs={'dropout': dropout_rng}) + + y2 = cust_model.apply( + initial_params_custom, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}) + + assert jnp.allclose(y1, y2) @parameterized.named_parameters( - dict(testcase_name='WMT, default'), + dict(testcase_name='WMT, default train', train=True), + dict(testcase_name='WMT, default eval', train=False), ) - def test_default_dropout(self): + def test_default_dropout(self, train): """Test default dropout_rate.""" - # init model rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls(OrigClsConfig) - cust_model = CustCls(CustClsConfig) + orig_model = OrigCls(OrigClsConfig(deterministic=not train)) + cust_model = CustCls(CustClsConfig(deterministic=not train)) init_fake_batch_size = 8 input_shape = (init_fake_batch_size, 256) @@ -94,35 +87,26 @@ def test_default_dropout(self): initial_params_original = orig_model.init({'params': rng}, jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - train=False) + jnp.ones(target_shape, jnp.float32)) initial_params_custom = cust_model.init({'params': rng}, jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - train=False) + jnp.ones(target_shape, jnp.float32)) # fwd - x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] - - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - train=train, - rngs={'dropout': dropout_rng}, mutable=['batch_stats']) - y2 = cust_model.apply( - initial_params_custom, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - train=train, rngs={'dropout': dropout_rng}, - mutable=['batch_stats']) - - print(jax.tree.map(lambda x: x.shape, y1)) - - for i in range(len(y1)): - assert jnp.allclose(y1[i], y2[i]) + + y1 = orig_model.apply( + initial_params_original, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + rngs={'dropout': dropout_rng}) + + y2 = cust_model.apply( + initial_params_custom, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + rngs={'dropout': dropout_rng}) + + assert jnp.allclose(y1, y2) if __name__ == '__main__': From 6c7d69590f58c55e928165c958e51a86af697928 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 24 Jun 2025 00:00:40 +0000 Subject: [PATCH 085/123] remove dropout fix tests --- .../criteo1tb_jax/test_model_equivalence.py | 156 ------------------ .../test_model_equivalence.py | 120 -------------- .../fastmri_jax/test_model_equivalence.py | 118 ------------- .../fastmri_pytorch/test_model_equivalence.py | 130 --------------- .../test_model_equivalence.py | 110 ------------ .../test_model_equivalence.py | 138 ---------------- .../test_model_equivalence.py | 117 ------------- .../test_model_equivalence.py | 84 ---------- .../test_model_equivalence.py | 113 ------------- .../test_model_equivalence.py | 124 -------------- .../ogbg_jax/test_model_equivalence.py | 133 --------------- .../ogbg_pytorch/test_model_equivalence.py | 118 ------------- .../wmt_jax/test_model_equivalence.py | 113 ------------- .../wmt_pytorch/test_model_equivalence.py | 121 -------------- 14 files changed, 1695 deletions(-) delete mode 100644 tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py delete mode 100644 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py delete mode 100644 tests/dropout_fix/fastmri_jax/test_model_equivalence.py delete mode 100644 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py delete mode 100644 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py delete mode 100644 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py delete mode 100644 tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py delete mode 100644 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py delete mode 100644 tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py delete mode 100644 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py delete mode 100644 tests/dropout_fix/ogbg_jax/test_model_equivalence.py delete mode 100644 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py delete mode 100644 tests/dropout_fix/wmt_jax/test_model_equivalence.py delete mode 100644 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py diff --git a/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py deleted file mode 100644 index 10aeaa650..000000000 --- a/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -Runs fwd pass with random input for our DLRM models and compares outputs. -Run it as: - python3 tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py -""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import jax -import jax.numpy as jnp -# import equinox as eqx - -from jax.tree_util import tree_structure, tree_leaves, tree_map - - -def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8): - """ - A custom function to check if two PyTrees are equal, handling floats with a tolerance. - """ - # 1. Check if the structures are the same - if tree_structure(a) != tree_structure(b): - return False - - # 2. Define a comparison function for leaves - def leaf_comparator(x, y): - # Use allclose for floating-point JAX arrays - if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating): - return jnp.allclose(x, y, rtol=rtol, atol=atol) - # Use standard equality for everything else - else: - return x == y - - # 3. Map the comparison function over the trees and check if all results are True - # We also need to flatten the results of the tree_map and check if all are True - comparison_tree = tree_map(leaf_comparator, a, b) - all_equal = all(tree_leaves(comparison_tree)) - - return all_equal - -from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import \ - DLRMResNet as OriginalDLRMResNet -from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import \ - DlrmSmall as OriginalDlrmSmall -from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \ - DLRMResNet as CustomDLRMResNet -from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \ - DlrmSmall as CustomDlrmSmall - -BATCH, DENSE, SPARSE = 16, 13, 26 -FEATURES = DENSE + SPARSE -VOCAB = 1000 -DEVICE = 'cuda' -SEED = 1996 - - -class ModelEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='DLRMResNet, p=0.0', - model='dlrm_resnet', - dropout_rate=0.0), - dict( - testcase_name='DlrmSmall, p=0.0', - model='dlrm_small', - dropout_rate=0.0), - dict( - testcase_name='DLRMResNet, p=0.1', - model='dlrm_resnet', - dropout_rate=0.1), - dict( - testcase_name='DlrmSmall, p=0.1', - model='dlrm_small', - dropout_rate=0.1), - ) - def test_forward(self, model, dropout_rate): - OrigCls, CustCls = ( - (OriginalDLRMResNet, CustomDLRMResNet) - if model == 'dlrm_resnet' - else (OriginalDlrmSmall, CustomDlrmSmall) - ) - - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - fake_batch = jnp.ones((2, 39)) - assert dropout_rate == 0.1 - orig_model = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate) - cust_model = CustCls(vocab_size=VOCAB) - initial_params_original = orig_model.init({'params': rng}, - fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - fake_batch, - train=False) - assert pytrees_are_equal( - initial_params_original, initial_params_custom, rtol=1e-6) - - x = jax.random.normal(data_rng, shape=(BATCH, FEATURES)) - - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - x, - train=train, - rngs={'dropout': dropout_rng}) - y2 = cust_model.apply( - initial_params_custom, - x, - train=train, - dropout_rate=dropout_rate, - rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) - - @parameterized.named_parameters( - dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), - dict(testcase_name='DlrmSmall, default', model='dlrm_small'), - ) - def test_default_dropout(self, model): - """Test default dropout_rate.""" - OrigCls, CustCls = ( - (OriginalDLRMResNet, CustomDLRMResNet) - if model == 'dlrm_resnet' - else (OriginalDlrmSmall, CustomDlrmSmall) - ) - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - fake_batch = jnp.ones((2, 39)) - orig_model = OrigCls(vocab_size=VOCAB) - cust_model = CustCls(vocab_size=VOCAB) - initial_params_original = orig_model.init({'params': rng}, - fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - fake_batch, - train=False) - - x = jax.random.normal(data_rng, shape=(BATCH, FEATURES)) - - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - x, - train=train, - rngs={'dropout': dropout_rng}) - y2 = cust_model.apply( - initial_params_custom, x, train=train, rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2, atol=0, rtol=0) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py deleted file mode 100644 index 733052dd0..000000000 --- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py +++ /dev/null @@ -1,120 +0,0 @@ -""" -Runs fwd pass with random input for our DLRM models and compares outputs. -Run it as: - python3 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py -""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import torch -from torch.testing import assert_close - -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \ - DLRMResNet as OriginalDLRMResNet -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \ - DlrmSmall as OriginalDlrmSmall -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import \ - DLRMResNet as CustomDLRMResNet -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import \ - DlrmSmall as CustomDlrmSmall - -BATCH, DENSE, SPARSE = 16, 13, 26 -FEATURES = DENSE + SPARSE -VOCAB = 1000 -DEVICE = 'cuda' -SEED = 1996 - -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(True) - - -class ModelEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='DLRMResNet, p=0.0', - model='dlrm_resnet', - dropout_rate=0.0), - dict( - testcase_name='DlrmSmall, p=0.0', - model='dlrm_small', - dropout_rate=0.0), - dict( - testcase_name='DLRMResNet, p=0.1', - model='dlrm_resnet', - dropout_rate=0.1), - dict( - testcase_name='DlrmSmall, p=0.1', - model='dlrm_small', - dropout_rate=0.1), - dict( - testcase_name='DLRMResNet, p=1.0', - model='dlrm_resnet', - dropout_rate=1.0), - dict( - testcase_name='DlrmSmall, p=1.0', - model='dlrm_small', - dropout_rate=1.0), - ) - def test_forward(self, model, dropout_rate): - OrigCls, CustCls = ( - (OriginalDLRMResNet, CustomDLRMResNet) - if model == 'dlrm_resnet' - else (OriginalDlrmSmall, CustomDlrmSmall) - ) - - torch.manual_seed(SEED) - orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustCls(vocab_size=VOCAB).to(DEVICE) - - x = torch.randn(BATCH, FEATURES, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(SEED) - y1 = orig(x) - torch.manual_seed(SEED) - y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'), - dict(testcase_name='DlrmSmall, default', model='dlrm_small'), - ) - def test_default_dropout(self, model): - """Test default dropout_rate.""" - OrigCls, CustCls = ( - (OriginalDLRMResNet, CustomDLRMResNet) - if model == 'dlrm_resnet' - else (OriginalDlrmSmall, CustomDlrmSmall) - ) - - torch.manual_seed(SEED) - orig = OrigCls(vocab_size=VOCAB).to(DEVICE) - torch.manual_seed(SEED) - cust = CustCls(vocab_size=VOCAB).to(DEVICE) - - x = torch.randn(BATCH, FEATURES, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py deleted file mode 100644 index fbb6d2499..000000000 --- a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py +++ /dev/null @@ -1,118 +0,0 @@ -""" -Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. -Run it as: - python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py -""" - -import os - - -from absl.testing import absltest -from absl.testing import parameterized -import jax -import jax.numpy as jnp -# import equinox as eqx - - -from algoperf.workloads.fastmri.fastmri_jax.models_ref import \ - UNet as OriginalUNet -from algoperf.workloads.fastmri.fastmri_jax.models import \ - UNet as CustomUNet - -BATCH, IN_CHANS, H, W = 4, 1, 256, 256 -OUT_CHANS, C, LAYERS = 1, 32, 4 -SEED = 1996 - - -class ModelEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='UNet, p=0.0', - dropout_rate=0.0), - dict( - testcase_name='UNet, p=0.1', - dropout_rate=0.1), - ) - def test_forward(self, dropout_rate): - OrigCls, CustCls = (OriginalUNet, CustomUNet) - - - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - kwargs = dict(num_pool_layers = LAYERS, num_channels=IN_CHANS) - orig_model = OrigCls(**kwargs) - cust_model = CustCls(**kwargs) - - fake_batch = jnp.ones((BATCH, IN_CHANS, H, W)) - - initial_params_original = orig_model.init({'params': rng}, - fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - fake_batch, - train=False) - - # fwd - x = jax.random.normal(data_rng, shape=(BATCH, H, W)) - - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - x, - train=train, - rngs={'dropout': dropout_rng}) - y2 = cust_model.apply( - initial_params_custom, - x, - train=train, - dropout_rate=dropout_rate, - rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) - - @parameterized.named_parameters( - dict(testcase_name='UNet, default'), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - OrigCls, CustCls = (OriginalUNet, CustomUNet) - - - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - kwargs = dict(num_pool_layers=LAYERS, - num_channels=IN_CHANS, - ) - orig_model = OrigCls(**kwargs) - cust_model = CustCls(**kwargs) - - fake_batch = jnp.ones((2, IN_CHANS, H, W)) - - initial_params_original = orig_model.init({'params': rng}, - fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - fake_batch, - train=False) - - # fwd - x = jax.random.normal(data_rng, shape=(BATCH, H, W)) - - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - x, - train=train, - rngs={'dropout': dropout_rng}) - y2 = cust_model.apply( - initial_params_custom, x, train=train, rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2, atol=0, rtol=0) - -if __name__ == '__main__': - absltest.main() diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py deleted file mode 100644 index 1d318e8c6..000000000 --- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. -Run it as: - python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py -""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import torch -from torch.testing import assert_close - -from algoperf.workloads.fastmri.fastmri_pytorch.models import \ - UNet as OriginalUNet -from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \ - UNet as CustomUNet - -BATCH, IN_CHANS, H, W = 4, 1, 256, 256 -OUT_CHANS, C, LAYERS = 1, 32, 4 -DEVICE = 'cuda' -TORCH_COMPILE = False -SEED = 1996 - -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(True) - - -class FastMRIModeEquivalenceTest(parameterized.TestCase): - - def fwd_pass(self, orig, cust, dropout_rate): - x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(0) - y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.1', dropout_rate=0.1), - dict(testcase_name='p=0.7', dropout_rate=0.7), - dict(testcase_name='p=1.0', dropout_rate=1.0), - ) - def test_dropout_values(self, dropout_rate): - """Test different values of dropout_rate.""" - - torch.manual_seed(SEED) - orig = OriginalUNet( - IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - if TORCH_COMPILE: - orig = torch.compile(orig) - cust = torch.compile(cust) - - self.fwd_pass(orig, cust, dropout_rate) - - @parameterized.named_parameters( - dict(testcase_name='default', use_tanh=False, use_layer_norm=False), - dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False), - dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True), - dict(testcase_name='both', use_tanh=True, use_layer_norm=True), - ) - def test_arch_configs(self, use_tanh, use_layer_norm): - """Test different architecture configurations, fixed dropout_rate.""" - dropout_rate = 0.1 - - torch.manual_seed(SEED) - orig = OriginalUNet( - IN_CHANS, - OUT_CHANS, - C, - LAYERS, - dropout_rate=dropout_rate, - use_tanh=use_tanh, - use_layer_norm=use_layer_norm).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomUNet( - IN_CHANS, - OUT_CHANS, - C, - LAYERS, - use_tanh=use_tanh, - use_layer_norm=use_layer_norm).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - if TORCH_COMPILE: - orig = torch.compile(orig) - cust = torch.compile(cust) - - self.fwd_pass(orig, cust, dropout_rate) - - @parameterized.named_parameters( - dict(testcase_name=''),) - def test_default_dropout(self): - """Test default dropout_rate.""" - - torch.manual_seed(SEED) - orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE) - cust.load_state_dict(orig.state_dict()) # sync weights - - x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py deleted file mode 100644 index 6806ca99f..000000000 --- a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py +++ /dev/null @@ -1,110 +0,0 @@ -""" -Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. -Run it as: - python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py -""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import jax -import jax.numpy as jnp - -from algoperf.workloads.imagenet_vit.imagenet_jax.models_ref import \ - ViT as OriginalVit -from algoperf.workloads.imagenet_vit.imagenet_jax.models import \ - ViT as CustomVit - -# Model / test hyper-params -INPUT_SHAPE = (2, 224, 124, 3) -SEED = 1994 - -class ImageNetVitModeEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='ViT, p=0.0', - dropout_rate=0.0), - dict( - testcase_name='ViT, p=0.1', - dropout_rate=0.1), - ) - def test_forward(self, dropout_rate): - OrigCls, CustCls = (OriginalVit, CustomVit) - - - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls() - cust_model = CustCls() - - fake_batch = jnp.ones(INPUT_SHAPE) - - initial_params_original = orig_model.init({'params': rng}, - fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - fake_batch, - train=False) - - # fwd - x = jax.random.normal(data_rng, shape=INPUT_SHAPE) - - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - x, - train=train, - rngs={'dropout': dropout_rng}) - y2 = cust_model.apply( - initial_params_custom, - x, - train=train, - dropout_rate=dropout_rate, - rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) - - @parameterized.named_parameters( - dict(testcase_name='UNet, default'), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - OrigCls, CustCls = (OriginalVit, CustomVit) - - - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls() - cust_model = CustCls() - - fake_batch = jnp.ones(INPUT_SHAPE) - - initial_params_original = orig_model.init({'params': rng}, - fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - fake_batch, - train=False) - - # fwd - x = jax.random.normal(data_rng, INPUT_SHAPE) - - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - x, - train=train, - rngs={'dropout': dropout_rng}) - y2 = cust_model.apply( - initial_params_custom, x, train=train, rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2, atol=0, rtol=0) - -if __name__ == '__main__': - absltest.main() \ No newline at end of file diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py deleted file mode 100644 index f51eaec7e..000000000 --- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. -Run it as: - python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py -""" - -import itertools -import os - -from absl.testing import absltest -from absl.testing import parameterized -import torch -from torch.testing import assert_close - -from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ - ViT as OriginalVit -from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import \ - ViT as CustomVit - -# Model / test hyper-params -BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W) -WIDTH, DEPTH, HEADS = 256, 4, 8 -DROPOUT_RATE = None -DEVICE = 'cuda' -SEED = 1996 - -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(True) - - -class ImageNetVitModeEquivalenceTest(parameterized.TestCase): - - def fwd_pass(self, orig, cust, dropout_rate): - x = torch.randn(BATCH, C, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(0) - y2 = cust(x, dropout_rate) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.1', dropout_rate=0.1), - dict(testcase_name='p=0.6', dropout_rate=0.6), - dict(testcase_name='p=1.0', dropout_rate=1.0), - ) - def test_dropout_values(self, dropout_rate): - """Test different dropout_rates.""" - - torch.manual_seed(SEED) - orig = OriginalVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - dropout_rate=dropout_rate, - ).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - ).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - self.fwd_pass(orig, cust, dropout_rate) - - @parameterized.named_parameters([ - dict( - testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}", - use_glu=use_glu, - use_post_ln=use_post_ln, - use_map=use_map, - ) for use_glu, - use_post_ln, - use_map in itertools.product([False, True], repeat=3) - ]) - def test_arch(self, use_glu, use_post_ln, use_map): - """Test different architecture configurations, fixed dropout_rate.""" - dropout_rate = 0.1 - - torch.manual_seed(SEED) - orig = OriginalVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - use_glu=use_glu, - use_post_layer_norm=use_post_ln, - use_map=use_map, - dropout_rate=dropout_rate, - ).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomVit( - width=WIDTH, - depth=DEPTH, - num_heads=HEADS, - use_glu=use_glu, - use_post_layer_norm=use_post_ln, - use_map=use_map, - ).to(DEVICE) - - cust.load_state_dict(orig.state_dict()) # sync weights - self.fwd_pass(orig, cust, dropout_rate) - - @parameterized.named_parameters( - dict(testcase_name=''),) - def test_default_dropout(self): - """Test default dropout_rate.""" - - torch.manual_seed(SEED) - orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE) - cust.load_state_dict(orig.state_dict()) # sync weights - - x = torch.randn(BATCH, C, H, W, device=DEVICE) - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(0) - y1 = orig(x) - torch.manual_seed(0) - y2 = cust(x) - assert_close(y1, y2, atol=0, rtol=0) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py deleted file mode 100644 index c5d849dce..000000000 --- a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. -Run it as: - python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py -""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import jax -import jax.numpy as jnp - -from algoperf.workloads.librispeech_conformer.librispeech_jax.models import ConformerConfig as CustClsConfig -from algoperf.workloads.librispeech_conformer.librispeech_jax.models import Conformer as CustCls - -from algoperf.workloads.librispeech_conformer.librispeech_jax.models_ref import ConformerConfig as OrigClsConfig -from algoperf.workloads.librispeech_conformer.librispeech_jax.models_ref import Conformer as OrigCls - - -# Model / test hyper-params -INPUT_SHAPE = [(3200,), (3200,)] -SEED = 1994 - -class ModeEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='Conformer, p=0.0', - dropout_rate=0.0), - dict( - testcase_name='Conformer, p=0.1', - dropout_rate=0.1), - ) - def test_forward(self, dropout_rate): - - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls(OrigClsConfig(attention_dropout_rate=dropout_rate, - attention_residual_dropout_rate = dropout_rate, - conv_residual_dropout_rate = dropout_rate, - feed_forward_dropout_rate = dropout_rate, - feed_forward_residual_dropout_rate = dropout_rate, - input_dropout_rate=dropout_rate)) - cust_model = CustCls(CustClsConfig()) - - fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE] - - initial_params_original = orig_model.init({'params': rng}, - *fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - *fake_batch, - train=False) - - # fwd - x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] - - for train in [True]: - (y1, _), _ = orig_model.apply( - initial_params_original, - *x, - train=train, - rngs={'dropout': dropout_rng}, - mutable=['batch_stats'],) - (y2, _), _ = cust_model.apply( - initial_params_custom, - *x, - train=train, - dropout_rate=dropout_rate, - rngs={'dropout': dropout_rng}, - mutable=['batch_stats']) - - assert jnp.allclose(y1, y2) - - - - @parameterized.named_parameters( - dict(testcase_name='Conformer, default'), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls(OrigClsConfig()) - cust_model = CustCls(CustClsConfig()) - - fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE] - - initial_params_original = orig_model.init({'params': rng}, - *fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - *fake_batch, - train=False) - - # fwd - x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] - - for mode in ('train', 'eval'): - train = mode == 'train' - (y1, _), _ = orig_model.apply( - initial_params_original, - *x, - train=train, - rngs={'dropout': dropout_rng}, mutable=['batch_stats']) - (y2, _), _ = cust_model.apply( - initial_params_custom, *x, train=train, rngs={'dropout': dropout_rng}, mutable=['batch_stats']) - - - assert jnp.allclose(y1, y2) - - -if __name__ == '__main__': - absltest.main() \ No newline at end of file diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py deleted file mode 100644 index 02f3a3d84..000000000 --- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs. -Run with: - python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py - -NOTE: we don't test for default dropout_rate values, since they changed. -""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import torch -from torch.testing import assert_close - -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - ConformerConfig as OriginalConfig -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - ConformerEncoderDecoder as OriginalModel -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \ - ConformerConfig as CustomConfig -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \ - ConformerEncoderDecoder as CustomModel - -N_LAYERS = 3 -B, T = 32, 36_000 -DEVICE = 'cuda' - -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(mode=True) -SEED = 1996 - - -class ConformerEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.2', dropout_rate=0.2), - dict(testcase_name='p=0.7', dropout_rate=0.7), - dict(testcase_name='p=1.0', dropout_rate=1.0), - ) - def test_forward(self, dropout_rate): - - torch.manual_seed(SEED) - orig = OriginalModel( - OriginalConfig( - num_encoder_layers=N_LAYERS, - attention_residual_dropout_rate=dropout_rate, - conv_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=dropout_rate, - )).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py deleted file mode 100644 index 5155199f5..000000000 --- a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. -Run it as: - python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py -""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import jax -import jax.numpy as jnp - -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models import DeepspeechConfig as CustClsConfig -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models import Deepspeech as CustCls - -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models_ref import DeepspeechConfig as OrigClsConfig -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models_ref import Deepspeech as OrigCls - - -# Model / test hyper-params -INPUT_SHAPE = [(3200,), (3200,)] -SEED = 1994 - -class ModeEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='Conformer, p=0.0', - dropout_rate=0.0), - dict( - testcase_name='Conformer, p=0.1', - dropout_rate=0.1), - ) - def test_forward(self, dropout_rate): - - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls(OrigClsConfig) - cust_model = CustCls(CustClsConfig) - - fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE] - - initial_params_original = orig_model.init({'params': rng}, - *fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - *fake_batch, - train=False) - - # fwd - x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] - - for mode in ('train', 'eval'): - train = mode == 'train' - (y1, _), _ = orig_model.apply( - initial_params_original, - *x, - train=train, - rngs={'dropout': dropout_rng}, - mutable=['batch_stats'],) - (y2, _), _ = cust_model.apply( - initial_params_custom, - *x, - train=train, - dropout_rate=dropout_rate, - rngs={'dropout': dropout_rng}, - mutable=['batch_stats']) - - assert jnp.allclose(y1, y2) - - - - @parameterized.named_parameters( - dict(testcase_name='Conformer, default'), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls(OrigClsConfig) - cust_model = CustCls(CustClsConfig) - - fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE] - - initial_params_original = orig_model.init({'params': rng}, - *fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - *fake_batch, - train=False) - - # fwd - x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE] - - for mode in ('train', 'eval'): - train = mode == 'train' - (y1, _), _ = orig_model.apply( - initial_params_original, - *x, - train=train, - rngs={'dropout': dropout_rng}, mutable=['batch_stats']) - (y2, _), _ = cust_model.apply( - initial_params_custom, *x, train=train, rngs={'dropout': dropout_rng}, mutable=['batch_stats']) - - - assert jnp.allclose(y1, y2) - - -if __name__ == '__main__': - absltest.main() \ No newline at end of file diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py deleted file mode 100644 index 7d6a94592..000000000 --- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py +++ /dev/null @@ -1,124 +0,0 @@ -""" -Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs. -Run with: - python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py - -`dropout_rate` controls the following args: -- `input_dropout_rate` (if None, 0.1 -- `feed_forward_dropout_rate` (if None, 0.1) -""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import torch -from torch.testing import assert_close - -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechConfig as OriginalConfig -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechEncoderDecoder as OriginalModel -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \ - DeepspeechConfig as CustomConfig -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \ - DeepspeechEncoderDecoder as CustomModel - -B, T = 32, 30_000 -DEVICE = 'cuda' -TORCH_COMPILE = False - -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(mode=True) -SEED = 1996 - - -class DeepSpeechEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict(testcase_name='p=0.0', dropout_rate=0.0), - dict(testcase_name='p=0.2', dropout_rate=0.2), - dict(testcase_name='p=0.7', dropout_rate=0.7), - dict(testcase_name='p=1.0', dropout_rate=1.0), - ) - def test_forward(self, dropout_rate): - """Test different dropout_rate values.""" - - torch.manual_seed(SEED) - orig = OriginalModel( - OriginalConfig( - num_lstm_layers=2, - num_ffn_layers=2, - input_dropout_rate=dropout_rate, - feed_forward_dropout_rate=dropout_rate, - )).to(DEVICE) - - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig( - num_lstm_layers=2, - num_ffn_layers=2, - )).to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if TORCH_COMPILE: - orig = torch.compile(orig) - cust = torch.compile(cust) - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name=''),) - def test_default_dropout(self): - """Test default dropout_rate.""" - - torch.manual_seed(SEED) - orig = OriginalModel(OriginalConfig(num_lstm_layers=2, - num_ffn_layers=2)).to(DEVICE) - torch.manual_seed(SEED) - cust = CustomModel(CustomConfig(num_lstm_layers=2, - num_ffn_layers=2)).to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights - - if TORCH_COMPILE: - orig = torch.compile(orig) - cust = torch.compile(cust) - - x = torch.randn(B, T, device=DEVICE) - paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE) - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - torch.manual_seed(SEED) - y1, p1 = orig(x, paddings) - torch.manual_seed(SEED) - y2, p2 = cust(x, paddings) - assert_close(y1, y2, atol=0, rtol=0) - assert_close(p1, p2, atol=0, rtol=0) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/dropout_fix/ogbg_jax/test_model_equivalence.py b/tests/dropout_fix/ogbg_jax/test_model_equivalence.py deleted file mode 100644 index 1b5b8a180..000000000 --- a/tests/dropout_fix/ogbg_jax/test_model_equivalence.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Runs fwd pass with random input for OGBG -""" - -import os - -import jraph - -from absl.testing import absltest -from absl.testing import parameterized -import jax -import jax.numpy as jnp - -from algoperf.workloads.ogbg.ogbg_jax.models_ref import \ - GNN as OrigCls -from algoperf.workloads.ogbg.ogbg_jax.models import \ - GNN as CustCls - -# Model / test hyper-params -SEED = 1994 - -class ModeEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='OGBG, p=0.0', - dropout_rate=0.0), - dict( - testcase_name='OGBG, p=0.1', - dropout_rate=0.1), - ) - def test_forward(self, dropout_rate): - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls(num_outputs=128, dropout_rate=dropout_rate) - cust_model = CustCls(num_outputs=128) - - fake_batch = jraph.GraphsTuple( - n_node=jnp.asarray([1]), - n_edge=jnp.asarray([1]), - nodes=jnp.ones((1, 9)), - edges=jnp.ones((1, 3)), - globals=jnp.zeros((1, 128)), - senders=jnp.asarray([0]), - receivers=jnp.asarray([0])) - - initial_params_original = orig_model.init({'params': rng}, - fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - fake_batch, - train=False) - - # fwd - x = jraph.GraphsTuple( - n_node=jnp.asarray([1]), - n_edge=jnp.asarray([1]), - nodes=jnp.ones((1, 9)), - edges=jnp.ones((1, 3)), - globals=jnp.zeros((1, 128)), - senders=jnp.asarray([0]), - receivers=jnp.asarray([0])) - - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - x, - train=train, - rngs={'dropout': dropout_rng}) - y2 = cust_model.apply( - initial_params_custom, - x, - train=train, - dropout_rate=dropout_rate, - rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) - - @parameterized.named_parameters( - dict(testcase_name='OGBG, default'), - ) - def test_default_dropout(self): - """Test default dropout_rate.""" - - - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls(num_outputs=128) - cust_model = CustCls(num_outputs=128) - - fake_batch = jraph.GraphsTuple( - n_node=jnp.asarray([1]), - n_edge=jnp.asarray([1]), - nodes=jnp.ones((1, 9)), - edges=jnp.ones((1, 3)), - globals=jnp.zeros((1, 128)), - senders=jnp.asarray([0]), - receivers=jnp.asarray([0])) - - initial_params_original = orig_model.init({'params': rng}, - fake_batch, - train=False) - initial_params_custom = cust_model.init({'params': rng}, - fake_batch, - train=False) - - # fwd - x = jraph.GraphsTuple( - n_node=jnp.asarray([1]), - n_edge=jnp.asarray([1]), - nodes=jnp.ones((1, 9)), - edges=jnp.ones((1, 3)), - globals=jnp.zeros((1, 128)), - senders=jnp.asarray([0]), - receivers=jnp.asarray([0])) - - for mode in ('train', 'eval'): - train = mode == 'train' - y1 = orig_model.apply( - initial_params_original, - x, - train=train, - rngs={'dropout': dropout_rng}) - y2 = cust_model.apply( - initial_params_custom, x, train=train, rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2, atol=0, rtol=0) - -if __name__ == '__main__': - absltest.main() \ No newline at end of file diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py deleted file mode 100644 index 3b3feb680..000000000 --- a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py +++ /dev/null @@ -1,118 +0,0 @@ -""" -Runs fwd pass with random graphs for OGBG GNN models and compares outputs. -Run with: - python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py -""" - -import os -import random - -from absl.testing import absltest -from absl.testing import parameterized -from jraph import GraphsTuple -import numpy as np -import torch -from torch.testing import assert_close - -from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel -from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \ - GNN as CustomModel - -B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph -NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims -DEVICE = 'cuda' - -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(True) -SEED = 1996 - - -def _rand_graph(): - total_nodes, total_edges = B * N, B * E - nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE) - edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE) - senders, receivers = [], [] - for i in range(B): - offset = i * N - s = torch.randint(N, (E,), device=DEVICE) + offset - r = torch.randint(N, (E,), device=DEVICE) + offset - senders.append(s), receivers.append(r) - senders = torch.cat(senders) - receivers = torch.cat(receivers) - n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32) - n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32) - return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge) - - -class GNNEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict(testcase_name='0.0', dropout_rate=0.0), - dict(testcase_name='0.2', dropout_rate=0.2), - dict(testcase_name='0.7', dropout_rate=0.7), - dict(testcase_name='1.0', dropout_rate=1.0), - ) - def test_forward(self, dropout_rate): - """Test different dropout_rates.""" - - orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE) - cust = CustomModel().to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights - - graph = _rand_graph() - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y1 = orig(graph) - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(graph, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(graph) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name=''),) - def test_default_dropout(self): - """Test default dropout_rate.""" - - orig = OriginalModel().to(DEVICE) - cust = CustomModel().to(DEVICE) - orig.load_state_dict(cust.state_dict()) # sync weights - - graph = _rand_graph() - - for mode in ('train', 'eval'): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y1 = orig(graph) - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(graph) - - assert_close(y1, y2, atol=0, rtol=0) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/dropout_fix/wmt_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_jax/test_model_equivalence.py deleted file mode 100644 index d420bf691..000000000 --- a/tests/dropout_fix/wmt_jax/test_model_equivalence.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Runs fwd pass with random input for FASTMRI U-Net models and compares outputs. -Run it as: - python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py -""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import jax -import jax.numpy as jnp - -from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig as CustClsConfig -from algoperf.workloads.wmt.wmt_jax.models import Transformer as CustCls - -from algoperf.workloads.wmt.wmt_jax.models_ref import TransformerConfig as OrigClsConfig -from algoperf.workloads.wmt.wmt_jax.models_ref import Transformer as OrigCls - - -# Model / test hyper-params -SEED = 1994 - -class ModeEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name='WMT, p=0.0', - dropout_rate=0.0, - train=True), - dict( - testcase_name='WMT p=0.1', - dropout_rate=0.1, - train=False), - ) - def test_forward(self, dropout_rate, train): - - # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - - orig_model = OrigCls(OrigClsConfig(deterministic=not train, attention_dropout_rate=dropout_rate, dropout_rate=dropout_rate)) - cust_model = CustCls(CustClsConfig(deterministic=not train)) - - init_fake_batch_size = 8 - input_shape = (init_fake_batch_size, 256) - target_shape = (init_fake_batch_size, 256) - - initial_params_original = orig_model.init({'params': rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) - initial_params_custom = cust_model.init({'params': rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32),) - - # fwd - - y1 = orig_model.apply( - initial_params_original, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - rngs={'dropout': dropout_rng}) - - y2 = cust_model.apply( - initial_params_custom, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - dropout_rate=dropout_rate, - rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2) - - - - @parameterized.named_parameters( - dict(testcase_name='WMT, default train', train=True), - dict(testcase_name='WMT, default eval', train=False), - ) - def test_default_dropout(self, train): - """Test default dropout_rate.""" - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) - orig_model = OrigCls(OrigClsConfig(deterministic=not train)) - cust_model = CustCls(CustClsConfig(deterministic=not train)) - - init_fake_batch_size = 8 - input_shape = (init_fake_batch_size, 256) - target_shape = (init_fake_batch_size, 256) - - initial_params_original = orig_model.init({'params': rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) - initial_params_custom = cust_model.init({'params': rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) - - # fwd - - y1 = orig_model.apply( - initial_params_original, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - rngs={'dropout': dropout_rng}) - - y2 = cust_model.apply( - initial_params_custom, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32), - rngs={'dropout': dropout_rng}) - - assert jnp.allclose(y1, y2) - - -if __name__ == '__main__': - absltest.main() \ No newline at end of file diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py deleted file mode 100644 index 03f289a68..000000000 --- a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -Runs fwd pass with random input for WMT Transformer models and compares outputs. -Run with: - python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py -""" - -import os -import random - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np -import torch -from torch.testing import assert_close - -from algoperf.workloads.wmt.wmt_pytorch.models import \ - Transformer as OriginalModel -from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \ - Transformer as CustomModel - -B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000 -DEVICE = "cuda" -SEED = 1996 - -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" -torch.backends.cudnn.benchmark = False -torch.backends.cudnn.deterministic = True -torch.use_deterministic_algorithms(True) - - -def _rand_tokens(bs, seqlen): - return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE) - - -class TransformerEquivalenceTest(parameterized.TestCase): - - @parameterized.named_parameters( - # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention - dict(testcase_name="0.0", dropout_rate=0.0, compile=False), - dict(testcase_name="0.2", dropout_rate=0.2, compile=False), - dict(testcase_name="0.7", dropout_rate=0.7, compile=False), - dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True), - dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True), - dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True), - ) - def test_dropout_value(self, dropout_rate, compile): - - orig = OriginalModel( - dropout_rate=dropout_rate, - attention_dropout_rate=dropout_rate).to(DEVICE) - cust = CustomModel().to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if compile: - orig = torch.compile(orig) - cust = torch.compile(cust) - - src = _rand_tokens(B, SRC_LEN) - tgt = _rand_tokens(B, TGT_LEN) - - for mode in ("train", "eval"): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y1 = orig(src, tgt) - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(src, tgt, dropout_rate=dropout_rate) - - assert_close(y1, y2, atol=0, rtol=0) - - if mode == 'eval': # one extra test: omit dropout at eval - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(src, tgt) - assert_close(y1, y2, atol=0, rtol=0) - - @parameterized.named_parameters( - dict(testcase_name="default", compile=False), - dict(testcase_name="default_compile", compile=True), - ) - def test_default(self, compile): - - orig = OriginalModel().to(DEVICE) - cust = CustomModel().to(DEVICE) - - orig.load_state_dict(cust.state_dict()) # sync weights - - if compile: - orig = torch.compile(orig) - cust = torch.compile(cust) - - src = _rand_tokens(B, SRC_LEN) - tgt = _rand_tokens(B, TGT_LEN) - - for mode in ("train", "eval"): - getattr(orig, mode)() - getattr(cust, mode)() - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y1 = orig(src, tgt) - - torch.manual_seed(SEED) - random.seed(SEED) - np.random.seed(SEED) - y2 = cust(src, tgt) - - assert_close(y1, y2, atol=0, rtol=0) - - -if __name__ == "__main__": - absltest.main() From 66f5ed320226e81c39394053386cf09a864bf881 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 25 Jun 2025 05:08:11 +0000 Subject: [PATCH 086/123] fix formatting --- tests/test_jax_utils.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py index 04713915f..8a156149b 100644 --- a/tests/test_jax_utils.py +++ b/tests/test_jax_utils.py @@ -3,18 +3,19 @@ Run it as: pytest """ +from functools import partial import os from absl.testing import absltest from absl.testing import parameterized +import flax.linen as nn import jax import jax.numpy as jnp -import flax.linen as nn +from jax.tree_util import tree_leaves +from jax.tree_util import tree_map +from jax.tree_util import tree_structure -from jax.tree_util import tree_structure, tree_leaves, tree_map from algoperf.jax_utils import Dropout -from functools import partial - SEED = 1996 DEFAULT_DROPOUT = 0.5 @@ -213,25 +214,25 @@ def test_jitted_updates(self, dropout_rate, mode): jitted_custom_apply = jax.jit( partial(cust_model.apply), static_argnames=['train']) - def multiple_fwd_passes_custom_layer(): - for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: - y2 = jitted_custom_apply( - initial_variables_custom, - x, - train=train, - dropout_rate=d, - rngs={"dropout": dropout_rng}, - ) - return y2 + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: + y2 = jitted_custom_apply( + initial_variables_custom, + x, + train=train, + dropout_rate=d, + rngs={"dropout": dropout_rng}, + ) + return y2 def multiple_fwd_passes_original_layer(): - for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: - y1 = jitted_original_apply( + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: + y1 = jitted_original_apply( initial_variables_original, x, train=train, rngs={"dropout": dropout_rng}) + if __name__ == "__main__": absltest.main() From 62b1cc97b346110574c086f923f987685f3642c1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 25 Jun 2025 05:24:36 +0000 Subject: [PATCH 087/123] remove reference model implementations used for testing --- .../criteo1tb/criteo1tb_jax/models_ref.py | 219 ------ .../fastmri/fastmri_jax/models_ref.py | 220 ------ .../imagenet_jax/models_ref.py | 132 ---- .../imagenet_vit/imagenet_jax/models_ref.py | 235 ------ .../librispeech_jax/models_ref.py | 712 ------------------ .../librispeech_jax/models_ref.py | 525 ------------- .../workloads/ogbg/ogbg_jax/models_ref.py | 84 --- algoperf/workloads/wmt/wmt_jax/models_ref.py | 604 --------------- 8 files changed, 2731 deletions(-) delete mode 100644 algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py delete mode 100644 algoperf/workloads/fastmri/fastmri_jax/models_ref.py delete mode 100644 algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py delete mode 100644 algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py delete mode 100644 algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py delete mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py delete mode 100644 algoperf/workloads/ogbg/ogbg_jax/models_ref.py delete mode 100644 algoperf/workloads/wmt/wmt_jax/models_ref.py diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py deleted file mode 100644 index d0352e290..000000000 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py +++ /dev/null @@ -1,219 +0,0 @@ -"""A JAX implementation of DLRM-Small.""" - -from typing import Sequence - -import flax.linen as nn -from jax import nn as jnn -import jax.numpy as jnp - - -class DLRMResNet(nn.Module): - """Define a DLRMResNet model. - - Parameters: - vocab_size: the size of a single unified embedding table. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - num_dense_features: number of dense features as the bottom mlp input. - embed_dim: embedding dimension. - """ - - vocab_size: int = 32 * 128 * 1024 # 4_194_304 - num_dense_features: int = 13 - mlp_bottom_dims: Sequence[int] = (256, 256, 256) - mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) - embed_dim: int = 128 - dropout_rate: float = 0.1 - use_layer_norm: bool = False # Unused. - embedding_init_multiplier: float = None # Unused - - @nn.compact - def __call__(self, x, train): - bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) - cat_features = jnp.asarray(cat_features, dtype=jnp.int32) - - # bottom mlp - mlp_bottom_dims = self.mlp_bottom_dims - - bot_mlp_input = nn.Dense( - mlp_bottom_dims[0], - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5), - )( - bot_mlp_input) - bot_mlp_input = nn.relu(bot_mlp_input) - - for dense_dim in mlp_bottom_dims[1:]: - x = nn.Dense( - dense_dim, - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), - )( - bot_mlp_input) - bot_mlp_input += nn.relu(x) - - base_init_fn = jnn.initializers.uniform(scale=1.0) - # Embedding table init and lookup for a single unified table. - idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - - def scaled_init(key, shape, dtype=jnp.float_): - return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) - - embedding_table = self.param('embedding_table', - scaled_init, [self.vocab_size, self.embed_dim]) - - embed_features = embedding_table[idx_lookup] - batch_size = bot_mlp_input.shape[0] - embed_features = jnp.reshape(embed_features, - (batch_size, 26 * self.embed_dim)) - top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) - mlp_input_dim = top_mlp_input.shape[1] - mlp_top_dims = self.mlp_top_dims - num_layers_top = len(mlp_top_dims) - top_mlp_input = nn.Dense( - mlp_top_dims[0], - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))), - bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))( - top_mlp_input) - top_mlp_input = nn.relu(top_mlp_input) - for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: - fan_in = mlp_top_dims[layer_idx - 1] - x = nn.Dense( - fan_out, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), - bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( - top_mlp_input) - x = nn.relu(x) - if self.dropout_rate and layer_idx == num_layers_top - 2: - x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) - top_mlp_input += x - # In the DLRM model the last layer width is always 1. We can hardcode that - # below. - logits = nn.Dense( - 1, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))( - top_mlp_input) - return logits - - -def dot_interact(concat_features): - """Performs feature interaction operation between dense or sparse features. - Input tensors represent dense or sparse features. - Pre-condition: The tensors have been stacked along dimension 1. - Args: - concat_features: Array of features with shape [B, n_features, feature_dim]. - Returns: - activations: Array representing interacted features. - """ - batch_size = concat_features.shape[0] - - # Interact features, select upper or lower-triangular portion, and reshape. - xactions = jnp.matmul(concat_features, - jnp.transpose(concat_features, [0, 2, 1])) - feature_dim = xactions.shape[-1] - - indices = jnp.array(jnp.triu_indices(feature_dim)) - num_elems = indices.shape[1] - indices = jnp.tile(indices, [1, batch_size]) - indices0 = jnp.reshape( - jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), - [1, -1]) - indices = tuple(jnp.concatenate((indices0, indices), 0)) - activations = xactions[indices] - activations = jnp.reshape(activations, [batch_size, -1]) - return activations - - -class DlrmSmall(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - vocab_size: int = 32 * 128 * 1024 # 4_194_304. - num_dense_features: int = 13 - mlp_bottom_dims: Sequence[int] = (512, 256, 128) - mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) - embed_dim: int = 128 - dropout_rate: float = 0.0 - use_layer_norm: bool = False - embedding_init_multiplier: float = None - - @nn.compact - def __call__(self, x, train): - bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) - cat_features = jnp.asarray(cat_features, dtype=jnp.int32) - - # Bottom MLP. - for dense_dim in self.mlp_bottom_dims: - bot_mlp_input = nn.Dense( - dense_dim, - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), - )( - bot_mlp_input) - bot_mlp_input = nn.relu(bot_mlp_input) - if self.use_layer_norm: - bot_mlp_input = nn.LayerNorm()(bot_mlp_input) - bot_mlp_output = bot_mlp_input - batch_size = bot_mlp_output.shape[0] - feature_stack = jnp.reshape(bot_mlp_output, - [batch_size, -1, self.embed_dim]) - - # Embedding table look-up. - idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - - if self.embedding_init_multiplier is None: - scale = 1 / jnp.sqrt(self.vocab_size) - else: - scale = self.embedding_init_multiplier - - def scaled_init(key, shape, dtype=jnp.float_): - return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale - - embedding_table = self.param('embedding_table', - scaled_init, [self.vocab_size, self.embed_dim]) - - idx_lookup = jnp.reshape(idx_lookup, [-1]) - embed_features = embedding_table[idx_lookup] - embed_features = jnp.reshape(embed_features, - [batch_size, -1, self.embed_dim]) - if self.use_layer_norm: - embed_features = nn.LayerNorm()(embed_features) - feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) - dot_interact_output = dot_interact(concat_features=feature_stack) - top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], - axis=-1) - mlp_input_dim = top_mlp_input.shape[1] - mlp_top_dims = self.mlp_top_dims - num_layers_top = len(mlp_top_dims) - for layer_idx, fan_out in enumerate(mlp_top_dims): - fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] - top_mlp_input = nn.Dense( - fan_out, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( - top_mlp_input) - if layer_idx < (num_layers_top - 1): - top_mlp_input = nn.relu(top_mlp_input) - if self.use_layer_norm: - top_mlp_input = nn.LayerNorm()(top_mlp_input) - if (self.dropout_rate is not None and self.dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_input = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)( - top_mlp_input) - logits = top_mlp_input - return logits \ No newline at end of file diff --git a/algoperf/workloads/fastmri/fastmri_jax/models_ref.py b/algoperf/workloads/fastmri/fastmri_jax/models_ref.py deleted file mode 100644 index a2d56a4b4..000000000 --- a/algoperf/workloads/fastmri/fastmri_jax/models_ref.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Jax / Flax implementation of FastMRI U-Net. - -Forked from -https://github.com/google/init2winit/blob/master/init2winit/model_lib/unet.py - -Original implementation: -github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py - -Training: -github.com/facebookresearch/fastMRI/blob/main/fastmri/pl_modules/unet_module.py - -Data: -github.com/facebookresearch/fastMRI/tree/main/fastmri/data -""" -import functools -from typing import Optional - -import flax.linen as nn -import jax -import jax.numpy as jnp - - -def _instance_norm2d(x, axes, epsilon=1e-5): - # promote x to at least float32, this avoids half precision computation - # but preserves double or complex floating points - x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x))) - mean = jnp.mean(x, axes) - mean2 = jnp.mean(jnp.square(x), axes) - # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due - # to floating point round-off errors. - var = jnp.maximum(0., mean2 - jnp.square(mean)) - stats_shape = list(x.shape) - for axis in axes: - stats_shape[axis] = 1 - mean = mean.reshape(stats_shape) - var = var.reshape(stats_shape) - y = x - mean - mul = jnp.sqrt(var + epsilon) - y /= mul - return y - - -class UNet(nn.Module): - """Jax / Flax implementation of a U-Net model. - - O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks - for biomedical image segmentation. In International Conference on Medical - image computing and computer-assisted intervention, pages 234–241. - Springer, 2015. - - out_channels: Number of channels in the output to the U-Net model. - channels: Number of output channels of the first convolution layer. - num_pool_layers: Number of down-sampling and up-sampling layers. - dropout_rate: Dropout probability. - """ - num_channels: int = 32 - num_pool_layers: int = 4 - out_channels = 1 - dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. - use_tanh: bool = False - use_layer_norm: bool = False - - @nn.compact - def __call__(self, x, train=True): - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 - - # pylint: disable=invalid-name - _ConvBlock = functools.partial( - ConvBlock, - dropout_rate=dropout_rate, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm) - _TransposeConvBlock = functools.partial( - TransposeConvBlock, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm) - - down_sample_layers = [_ConvBlock(self.num_channels)] - - ch = self.num_channels - for _ in range(self.num_pool_layers - 1): - down_sample_layers.append(_ConvBlock(ch * 2)) - ch *= 2 - conv = _ConvBlock(ch * 2) - - up_conv = [] - up_transpose_conv = [] - for _ in range(self.num_pool_layers - 1): - up_transpose_conv.append(_TransposeConvBlock(ch)) - up_conv.append(_ConvBlock(ch)) - ch //= 2 - - up_transpose_conv.append(_TransposeConvBlock(ch)) - up_conv.append(_ConvBlock(ch)) - - stack = [] - output = jnp.expand_dims(x, axis=-1) - - # apply down-sampling layers - for layer in down_sample_layers: - output = layer(output, train) - stack.append(output) - output = nn.avg_pool(output, window_shape=(2, 2), strides=(2, 2)) - - output = conv(output, train) - - # apply up-sampling layers - for transpose_conv, conv in zip(up_transpose_conv, up_conv): - downsample_layer = stack.pop() - output = transpose_conv(output) - - # reflect pad on the right/botton if needed to handle odd input dimensions - padding_right = 0 - padding_bottom = 0 - if output.shape[-2] != downsample_layer.shape[-2]: - padding_right = 1 # padding right - if output.shape[-3] != downsample_layer.shape[-3]: - padding_bottom = 1 # padding bottom - - if padding_right or padding_bottom: - padding = ((0, 0), (0, padding_bottom), (0, padding_right), (0, 0)) - output = jnp.pad(output, padding, mode='reflect') - - output = jnp.concatenate((output, downsample_layer), axis=-1) - output = conv(output, train) - - output = nn.Conv( - self.out_channels, kernel_size=(1, 1), strides=(1, 1))( - output) - return output.squeeze(-1) - - -class ConvBlock(nn.Module): - """A Convolutional Block. - out_channels: Number of channels in the output. - dropout_rate: Dropout probability. - """ - out_channels: int - dropout_rate: float - use_tanh: bool - use_layer_norm: bool - - @nn.compact - def __call__(self, x, train=True): - """Forward function. - Note: Pytorch is NCHW and jax/flax is NHWC. - Args: - x: Input 4D tensor of shape `(N, H, W, in_channels)`. - train: deterministic or not (use init2winit naming). - Returns: - jnp.array: Output tensor of shape `(N, H, W, out_channels)`. - """ - x = nn.Conv( - features=self.out_channels, - kernel_size=(3, 3), - strides=(1, 1), - use_bias=False)( - x) - if self.use_layer_norm: - x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x) - else: - # DO NOT SUBMIT check that this comment edit is correct - # InstanceNorm2d was run with no learnable params in reference code - # so this is a simple normalization along spatial dims. - x = _instance_norm2d(x, (1, 2)) - if self.use_tanh: - activation_fn = nn.tanh - else: - activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2) - x = activation_fn(x) - # Ref code uses dropout2d which applies the same mask for the entire channel - # Replicated by using broadcast dims to have the same filter on HW - x = nn.Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) - x = nn.Conv( - features=self.out_channels, - kernel_size=(3, 3), - strides=(1, 1), - use_bias=False)( - x) - if self.use_layer_norm: - x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x) - else: - x = _instance_norm2d(x, (1, 2)) - x = activation_fn(x) - x = nn.Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) - return x - - -class TransposeConvBlock(nn.Module): - """A Transpose Convolutional Block. - out_channels: Number of channels in the output. - """ - out_channels: int - use_tanh: bool - use_layer_norm: bool - - @nn.compact - def __call__(self, x): - """Forward function. - Args: - x: Input 4D tensor of shape `(N, H, W, in_channels)`. - Returns: - jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`. - """ - x = nn.ConvTranspose( - self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)( - x) - x = _instance_norm2d(x, (1, 2)) - if self.use_tanh: - activation_fn = nn.tanh - else: - activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2) - x = activation_fn(x) - return x \ No newline at end of file diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py deleted file mode 100644 index 357dadc13..000000000 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Jax implementation of ResNet V1. - -Adapted from Flax example: -https://github.com/google/flax/blob/main/examples/imagenet/models.py. -""" - -import functools -from typing import Any, Callable, Optional, Tuple - -from flax import linen as nn -import jax.numpy as jnp - -from algoperf import spec - -ModuleDef = nn.Module - - -class ResNetBlock(nn.Module): - """ResNet block.""" - filters: int - conv: ModuleDef - norm: ModuleDef - act: Callable - strides: Tuple[int, int] = (1, 1) - bn_init_scale: float = 0. - - @nn.compact - def __call__(self, x: spec.Tensor) -> spec.Tensor: - residual = x - y = self.conv(self.filters, (3, 3), self.strides)(x) - y = self.norm()(y) - y = self.act(y) - y = self.conv(self.filters, (3, 3))(y) - y = self.norm(scale_init=nn.initializers.constant(self.bn_init_scale))(y) - - if residual.shape != y.shape or self.strides != (1, 1): - residual = self.conv( - self.filters, (1, 1), self.strides, name='Conv_proj')( - residual) - residual = self.norm(name='BatchNorm_proj')(residual) - - return self.act(residual + y) - - -class BottleneckResNetBlock(nn.Module): - """Bottleneck ResNet block.""" - filters: int - conv: ModuleDef - norm: ModuleDef - act: Callable - strides: Tuple[int, int] = (1, 1) - bn_init_scale: Optional[float] = None - - @nn.compact - def __call__(self, x: spec.Tensor) -> spec.Tensor: - residual = x - y = self.conv(self.filters, (1, 1))(x) - y = self.norm()(y) - y = self.act(y) - y = self.conv(self.filters, (3, 3), self.strides)(y) - y = self.norm()(y) - y = self.act(y) - y = self.conv(self.filters * 4, (1, 1))(y) - y = self.norm(scale_init=nn.initializers.constant(self.bn_init_scale))(y) - - if residual.shape != y.shape or self.strides != (1, 1): - residual = self.conv( - self.filters * 4, (1, 1), self.strides, name='Conv_proj')( - residual) - residual = self.norm(name='BatchNorm_proj')(residual) - - return self.act(residual + y) - - -class ResNet(nn.Module): - stage_sizes: Tuple[int] - block_cls: ModuleDef - num_classes: int - num_filters: int = 64 - dtype: Any = jnp.float32 - act: Callable = nn.relu - bn_init_scale: float = 0. - - @nn.compact - def __call__(self, - x: spec.Tensor, - update_batch_norm: bool = True, - use_running_average_bn: Optional[bool] = None) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) - - # Preserve default behavior for backwards compatibility - if use_running_average_bn is None: - use_running_average_bn = not update_batch_norm - norm = functools.partial( - nn.BatchNorm, - use_running_average=use_running_average_bn, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype) - - x = conv( - self.num_filters, (7, 7), (2, 2), - padding=[(3, 3), (3, 3)], - name='Conv_init')( - x) - x = norm(name='BatchNorm_init')(x) - x = self.act(x) - x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1))) - for i, block_size in enumerate(self.stage_sizes): - for j in range(block_size): - strides = (2, 2) if i > 0 and j == 0 else (1, 1) - x = self.block_cls( - self.num_filters * 2**i, - strides=strides, - conv=conv, - norm=norm, - act=self.act, - bn_init_scale=self.bn_init_scale)( - x) - x = jnp.mean(x, axis=(1, 2)) - x = nn.Dense( - self.num_classes, - kernel_init=nn.initializers.normal(), - dtype=self.dtype)( - x) - return x - - -ResNet18 = functools.partial( - ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock) -ResNet50 = functools.partial( - ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock) \ No newline at end of file diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py deleted file mode 100644 index beb8a2eb8..000000000 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py +++ /dev/null @@ -1,235 +0,0 @@ -"""Jax implementation of refactored and simplified ViT. - -Forked from: -https://github.com/google/init2winit/blob/master/init2winit/model_lib/vit.py, -originally from https://github.com/google/big_vision with modifications noted. -""" - -from typing import Optional, Sequence, Union - -from flax import linen as nn -import jax.numpy as jnp - -from algoperf import spec - - -def posemb_sincos_2d(h: int, - w: int, - width: int, - temperature: int = 10_000., - dtype: jnp.dtype = jnp.float32) -> spec.Tensor: - """Follows the MoCo v3 logic.""" - y, x = jnp.mgrid[:h, :w] #pylint: disable=unpacking-non-sequence - - if width % 4 != 0: - raise ValueError('Width must be mult of 4 for sincos posemb.') - omega = jnp.arange(width // 4) / (width // 4 - 1) - omega = 1. / (temperature**omega) - y = jnp.einsum('m,d->md', y.flatten(), omega) - x = jnp.einsum('m,d->md', x.flatten(), omega) - pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) - return jnp.asarray(pe, dtype)[None, :, :] - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block.""" - mlp_dim: Optional[int] = None # Defaults to 4x input dim. - use_glu: bool = False - dropout_rate: float = 0.0 - - @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: - """Applies Transformer MlpBlock module.""" - inits = { - 'kernel_init': nn.initializers.xavier_uniform(), - 'bias_init': nn.initializers.normal(stddev=1e-6), - } - - d = x.shape[2] - x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) - x = nn.gelu(x) - - if self.use_glu: - y = nn.Dense(self.mlp_dim, **inits)(x) - x = x * y - - x = nn.Dropout(rate=self.dropout_rate)(x, train) - x = nn.Dense(d, **inits)(x) - return x - - -class Encoder1DBlock(nn.Module): - """Single transformer encoder block (MHSA + MLP).""" - mlp_dim: Optional[int] = None # Defaults to 4x input dim. - num_heads: int = 12 - use_glu: bool = False - use_post_layer_norm: bool = False - dropout_rate: float = 0.0 - - @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: - if not self.use_post_layer_norm: - y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - - y = nn.LayerNorm(name='LayerNorm_2')(x) - y = MlpBlock( - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - else: - y = x - y = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - x = nn.LayerNorm(name='LayerNorm_0')(x) - - y = x - y = MlpBlock( - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - x = nn.LayerNorm(name='LayerNorm_2')(x) - - return x - - -class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation.""" - depth: int - mlp_dim: Optional[int] = None # Defaults to 4x input dim. - num_heads: int = 12 - dropout_rate: float = 0.0 - use_glu: bool = False - use_post_layer_norm: bool = False - - @nn.compact - def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: - # Input Encoder - for lyr in range(self.depth): - block = Encoder1DBlock( - name=f'encoderblock_{lyr}', - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=self.dropout_rate) - x = block(x, train) - if not self.use_post_layer_norm: - return nn.LayerNorm(name='encoder_layernorm')(x) - else: - return x - - -class MAPHead(nn.Module): - """Multihead Attention Pooling.""" - mlp_dim: Optional[int] = None # Defaults to 4x input dim - num_heads: int = 12 - - @nn.compact - def __call__(self, x): - n, _, d = x.shape - probe = self.param('probe', - nn.initializers.xavier_uniform(), (1, 1, d), - x.dtype) - probe = jnp.tile(probe, [n, 1, 1]) - - x = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())(probe, x) - - y = nn.LayerNorm()(x) - x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) - return x[:, 0] - - -class ViT(nn.Module): - """ViT model.""" - - num_classes: int = 1000 - patch_size: Sequence[int] = (16, 16) - width: int = 768 - depth: int = 12 - mlp_dim: Optional[int] = None # Defaults to 4x input dim. - num_heads: int = 12 - rep_size: Union[int, bool] = True - dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. - reinit: Optional[Sequence[str]] = None - head_zeroinit: bool = True - use_glu: bool = False - use_post_layer_norm: bool = False - use_map: bool = False - - def get_posemb(self, - seqshape: tuple, - width: int, - dtype: jnp.dtype = jnp.float32) -> spec.Tensor: - return posemb_sincos_2d(*seqshape, width, dtype=dtype) - - @nn.compact - def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: - # Patch extraction - x = nn.Conv( - self.width, - self.patch_size, - strides=self.patch_size, - padding='VALID', - name='conv_patch_extract')( - x) - - n, h, w, c = x.shape - x = jnp.reshape(x, [n, h * w, c]) - - # Add posemb before adding extra token. - x = x + self.get_posemb((h, w), c, x.dtype) - - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 - x = nn.Dropout(rate=dropout_rate)(x, not train) - - x = Encoder( - depth=self.depth, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate, - name='Transformer')( - x, train=not train) - - if self.use_map: - x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) - else: - x = jnp.mean(x, axis=1) - - if self.rep_size: - rep_size = self.width if self.rep_size is True else self.rep_size - hid = nn.Dense(rep_size, name='pre_logits') - x = nn.tanh(hid(x)) - - if self.num_classes: - kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {} - head = nn.Dense(self.num_classes, name='head', **kw) - x = head(x) - - return x \ No newline at end of file diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py deleted file mode 100644 index f168833d3..000000000 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py +++ /dev/null @@ -1,712 +0,0 @@ -r"""Conformer. - -This model uses a conformer network to convert speech to text. -paper : https://arxiv.org/abs/2005.08100 - -high-level overview of Conformer encoder layer. - - x = x + 0.5 * FeedForward(x) - x = x + MHSA(x) - x = x + ConvolutionBlock(x) - x = x + 0.5 * FeedForward(x) - y = layer_norm(x) -""" - -import functools -import math -from typing import Any, List, Optional - -from flax import linen as nn -from flax import struct -import jax -import jax.numpy as jnp -import numpy as np - -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - librispeech_preprocessor as preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - spectrum_augmenter - - -@struct.dataclass -class ConformerConfig: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - vocab_size: int = 1024 - dtype: Any = jnp.float32 - encoder_dim: int = 512 - num_attention_heads: int = 8 - num_encoder_layers: int = 4 - attention_dropout_rate: float = 0.1 - # If None, defaults to 0.1. - attention_residual_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.0. - conv_residual_dropout_rate: Optional[float] = 0.1 - feed_forward_dropout_rate: float = 0.1 - # If None, defaults to 0.1. - feed_forward_residual_dropout_rate: Optional[float] = 0.1 - convolution_kernel_size: int = 5 - feed_forward_expansion_factor: int = 4 - freq_mask_count: int = 2 - freq_mask_max_bins: int = 27 - time_mask_count: int = 10 - time_mask_max_frames: int = 40 - time_mask_max_ratio: float = 0.05 - time_masks_per_frame: float = 0.0 - use_dynamic_time_mask_max_frames: bool = True - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - batch_norm_momentum: float = 0.999 - batch_norm_epsilon: float = 0.001 - use_specaug: bool = True - attention_temperature: float = 1.0 - activation_function_name: str = 'swish' - use_post_layer_norm: bool = True - - -class LayerNorm(nn.Module): - """Module implementing layer normalization. - - This implementation is same as in this paper: - https://arxiv.org/pdf/1607.06450.pdf. - - note: we multiply normalized inputs by (1 + scale) and initialize scale to - zeros, this differs from default flax implementation of multiplying by scale - and initializing to ones. - """ - dim: int = 0 - epsilon: float = 1e-6 - - def setup(self): - self.scale = self.param('scale', nn.initializers.zeros, [self.dim]) - self.bias = self.param('bias', nn.initializers.zeros, [self.dim]) - - @nn.compact - def __call__(self, inputs): - mean = jnp.mean(inputs, axis=[-1], keepdims=True) - var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True) - - normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) - normed_inputs *= (1 + self.scale) - normed_inputs += self.bias - - return normed_inputs - - -class Subsample(nn.Module): - """Module to perform strided convolution in order to subsample inputs. - - Attributes: - encoder_dim: model dimension of conformer. - input_dropout_rate: dropout rate for inputs. - """ - encoder_dim: int = 0 - input_dropout_rate: float = 0.0 - - @nn.compact - def __call__(self, inputs, input_paddings, train): - output_paddings = input_paddings - outputs = jnp.expand_dims(inputs, axis=-1) - - outputs, output_paddings = Conv2dSubsampling( - input_channels=1, output_channels=self.encoder_dim)( - outputs, output_paddings) - - outputs, output_paddings = Conv2dSubsampling( - input_channels=self.encoder_dim, - output_channels=self.encoder_dim)(outputs, output_paddings) - - batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape - - outputs = jnp.reshape( - outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) - - outputs = nn.Dense( - self.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) - - outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)( - seq_length=outputs.shape[1]) - - outputs = nn.Dropout( - rate=self.input_dropout_rate, deterministic=not train)( - outputs) - - return outputs, output_paddings - - -class Conv2dSubsampling(nn.Module): - """Helper module used in Subsample layer. - - 1) Performs strided convolution over inputs and then applies non-linearity. - 2) Also performs strided convolution over input_paddings to return the correct - paddings for downstream layers. - """ - input_channels: int = 0 - output_channels: int = 0 - filter_stride: List[int] = (2, 2) - padding: str = 'SAME' - - def setup(self): - self.filter_shape = (3, 3, self.input_channels, self.output_channels) - self.kernel = self.param('kernel', - nn.initializers.xavier_uniform(), - self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) - - @nn.compact - def __call__(self, inputs, paddings): - # Computing strided convolution to subsample inputs. - feature_group_count = inputs.shape[3] // self.filter_shape[2] - outputs = jax.lax.conv_general_dilated( - lhs=inputs, - rhs=self.kernel, - window_strides=self.filter_stride, - padding=self.padding, - rhs_dilation=(1, 1), - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - feature_group_count=feature_group_count) - - outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) - outputs = nn.relu(outputs) - - # Computing correct paddings post input convolution. - input_length = paddings.shape[1] - stride = self.filter_stride[0] - - pad_len = (input_length + stride - 1) // stride * stride - input_length - out_padding = jax.lax.conv_general_dilated( - lhs=paddings[:, :, None], - rhs=jnp.ones([1, 1, 1]), - window_strides=self.filter_stride[:1], - padding=[(0, pad_len)], - dimension_numbers=('NHC', 'HIO', 'NHC')) - out_padding = jnp.squeeze(out_padding, axis=-1) - - # Mask outputs by correct paddings to ensure padded elements in inputs map - # to padded value in outputs. - outputs = outputs * \ - (1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) - return outputs, out_padding - - -class FeedForwardModule(nn.Module): - """Feedforward block of conformer layer. - """ - config: ConformerConfig - - @nn.compact - def __call__(self, inputs, padding_mask=None, train=False): - config = self.config - - inputs = LayerNorm(dim=config.encoder_dim)(inputs) - - inputs = nn.Dense( - config.encoder_dim * config.feed_forward_expansion_factor, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) - if config.activation_function_name == 'swish': - activation_fn = nn.swish - elif config.activation_function_name == 'gelu': - activation_fn = nn.gelu - else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{config.activation_function_name}') - inputs = activation_fn(inputs) - inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)( - inputs, deterministic=not train) - - inputs = inputs * padding_mask - - inputs = nn.Dense( - config.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) - inputs = inputs * padding_mask - - if config.feed_forward_residual_dropout_rate is None: - feed_forward_residual_dropout_rate = 0.1 - else: - feed_forward_residual_dropout_rate = ( - config.feed_forward_residual_dropout_rate) - inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)( - inputs, deterministic=not train) - - return inputs - - -class AddPositionalEmbedding(nn.Module): - """Adds (optionally learned) positional embeddings to the inputs. - - Attributes: - max_len: maximum possible length for the input - posemb_init: positional embedding initializer - """ - min_timescale: int = 1 - max_timescale: int = 10_000 - embedding_dim: int = 512 - - @nn.compact - def __call__(self, seq_length): - position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] - num_timescales = self.embedding_dim // 2 - log_timescale_increment = ( - math.log(float(self.max_timescale) / float(self.min_timescale)) / - jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)) - inv_timescales = self.min_timescale * jnp.exp( - jnp.arange(num_timescales, dtype=jnp.float32) * - -log_timescale_increment) - scaled_time = ( - position[:, :, jnp.newaxis] * - inv_timescales[jnp.newaxis, jnp.newaxis, :]) - signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], - axis=2).astype(jnp.float32) - # Force usage of `np` rather than `jnp` to compute static values at trace - # time. - signal = jnp.pad(signal, - [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]]) - return signal - - -# Adapted from lingvo attention layer for query scaling -# https://github.com/tensorflow/lingvo/blob/7de4ca8fff3cb28c2ecb21bbd7b02a964ce727f7/lingvo/jax/layers/attentions.py#L201 -class QueryScaler(nn.Module): - """A layer to scale individual dims of the query attention matrix.""" - dim: int = 0 - - def setup(self): - self.scale = self.param('scale', nn.initializers.zeros, [self.dim]) - - @nn.compact - def __call__(self, inputs): - inputs_shape = inputs.shape - if inputs_shape[-1] != self.dim: - raise ValueError('QueryScaler expects inputs to have' - ' same last dimension as scaling param.') - - # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we - # can avoid unnecessary XLA op fusion mess on TPU. - r_softplus_0 = 1.442695041 - - scale = jnp.array(r_softplus_0, dtype=inputs.dtype) - scale *= jax.nn.softplus(self.scale) - - return inputs * scale - - -# Modifying flax linen default dot product attention function to add -# query scaling, reference to original function here : -# https://github.com/google/flax/blob/a9af38085a7a49b571cf37d375060fd683e74972/flax/linen/attention.py#L121 -def dot_product_attention(query, - key, - value, - bias=None, - mask=None, - broadcast_dropout=True, - dropout_rng=None, - dropout_rate=0., - deterministic=False, - dtype=jnp.float32, - precision=None, - temperature=1.0): - """Computes dot-product attention given query, key, and value. - - This is the core function for applying attention based on - https://arxiv.org/abs/1706.03762. It's slightly modified to add query scaling. - It calculates the attention weights given query and key and combines the - values using the attention weights. - - Note: query, key, value needn't have any batch dimensions. - - Args: - query: queries for calculating attention with shape of - `[batch..., q_length, num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of - `[batch..., kv_length, num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of - `[batch..., kv_length, num_heads, v_depth_per_head]`. - bias: bias for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks, padding masks, - proximity bias, etc. - mask: mask for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks. - Attention weights are masked out if their corresponding mask value - is `False`. - broadcast_dropout: bool: use a broadcasted dropout along batch dims. - dropout_rng: JAX PRNGKey: to be used for dropout - dropout_rate: dropout rate - deterministic: bool, deterministic or not (to apply dropout) - dtype: the dtype of the computation (default: float32) - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - - Returns: - Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. - """ - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') - assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( - 'q, k, v num_heads must match.') - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' - - # compute attention weights - query = QueryScaler(dim=query.shape[-1])(query) - attn_weights = nn.attention.dot_product_attention_weights( - query, - key, - bias, - mask, - broadcast_dropout, - dropout_rng, - dropout_rate, - deterministic, - dtype, - precision) - - # return weighted sum over values for each query position - return jnp.einsum( - '...hqk,...khd->...qhd', attn_weights, value, - precision=precision) * temperature - - -class MultiHeadedSelfAttention(nn.Module): - """Self attention sub-layer used in the Conformer layer. - - Input is first normalized using layer norm. Output is processed using - multi-headed attention. - - Note: this attention implementation uses a learned scale parameter to scale - query matrix before passing it to flax attention module. - """ - config: ConformerConfig = None - - @nn.compact - def __call__(self, inputs, paddings, train): - config = self.config - mask_paddings = 1 - paddings - attention_mask = nn.make_attention_mask( - mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) - - inputs = LayerNorm(dim=config.encoder_dim)(inputs) - attention_fn = functools.partial( - dot_product_attention, temperature=config.attention_temperature) - result = nn.MultiHeadDotProductAttention( - num_heads=config.num_attention_heads, - qkv_features=config.encoder_dim, - decode=False, - dtype=config.dtype, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros, - use_bias=True, - broadcast_dropout=False, - attention_fn=attention_fn, - dropout_rate=config.attention_dropout_rate, - deterministic=not train)( - inputs_q=inputs, mask=attention_mask) - - if config.attention_residual_dropout_rate is None: - attention_residual_dropout_rate = 0.1 - else: - attention_residual_dropout_rate = config.attention_residual_dropout_rate - result = nn.Dropout( - rate=attention_residual_dropout_rate, deterministic=not train)( - result) - - return result - - -class BatchNorm(nn.Module): - """Implements batch norm respecting input paddings. - - This implementation takes into account input padding by masking inputs before - computing mean and variance. - - This is inspired by lingvo jax implementation of BatchNorm: - https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92 - - and the corresponding defaults for momentum and epsilon have been copied over - from lingvo. - """ - config: ConformerConfig - - def setup(self): - dim = self.config.encoder_dim - dtype = self.config.dtype - - self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), - dim) - self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), - dim) - - self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) - self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) - - @nn.compact - def __call__(self, - inputs, - input_paddings, - update_batch_norm, - use_running_average_bn): - rank = inputs.ndim - reduce_over_dims = list(range(0, rank - 1)) - - padding = jnp.expand_dims(input_paddings, -1) - momentum = self.config.batch_norm_momentum - epsilon = self.config.batch_norm_epsilon - - if use_running_average_bn: - mean = self.ra_mean.value - var = self.ra_var.value - - else: - # compute batch statistics - mask = 1.0 - padding - sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) - count_v = jnp.sum( - jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) - - count_v = jnp.maximum(count_v, 1.0) - mean = sum_v / count_v - - sum_vv = jnp.sum( - (inputs - mean) * (inputs - mean) * mask, - axis=reduce_over_dims, - keepdims=True) - - var = sum_vv / count_v - - if update_batch_norm: - self.ra_mean.value = momentum * \ - self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * \ - self.ra_var.value + (1 - momentum) * var - - inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) - bn_output = (inputs - mean) * inv + self.beta - bn_output *= 1.0 - padding - - return bn_output - - -class ConvolutionBlock(nn.Module): - r"""Convolution block in conformer layer. - - architecture: - - input # (batch, time, hidden_dim) - | - layer_norm(.) # (batch, time, hidden_dim) - dense(.), dense(.) # (batch, time, 2 * hidden_dim) - | / - glu(.) # (batch, time, hidden_dim) - depthwise_conv1d(.) - batch_norm(.) - act(.) - | - dense(.) - dropout(.) - | - output - """ - config: ConformerConfig - - @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm, - use_running_average_bn): - config = self.config - inputs = LayerNorm(dim=config.encoder_dim)(inputs) - - input_gated1 = nn.Dense( - config.encoder_dim, - kernel_init=nn.initializers.xavier_uniform(), - use_bias=True)( - inputs) - - input_gated2 = nn.Dense( - config.encoder_dim, - kernel_init=nn.initializers.xavier_uniform(), - use_bias=True)( - inputs) - - inputs = input_gated1 * jax.nn.sigmoid(input_gated2) - inputs = inputs * (1 - jnp.expand_dims(input_paddings, -1)) - - inputs = nn.Conv( - features=config.encoder_dim, - kernel_size=(config.convolution_kernel_size,), - strides=(1,), - padding='SAME', - feature_group_count=config.encoder_dim, - use_bias=False, - kernel_init=nn.initializers.xavier_uniform())( - inputs) - - inputs = BatchNorm(config)(inputs, - input_paddings, - update_batch_norm, - use_running_average_bn) - if config.activation_function_name == 'swish': - activation_fn = nn.swish - elif config.activation_function_name == 'gelu': - activation_fn = nn.gelu - else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{config.activation_function_name}') - inputs = activation_fn(inputs) - inputs = nn.Dense( - config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())( - inputs) - - if config.conv_residual_dropout_rate is None: - conv_residual_dropout_rate = 0.0 - else: - conv_residual_dropout_rate = config.conv_residual_dropout_rate - inputs = nn.Dropout( - rate=conv_residual_dropout_rate, deterministic=not train)( - inputs) - return inputs - - -class ConformerBlock(nn.Module): - """Implements a single conformer encoder layer. - - High level overview: - - x = x + 0.5 * FeedForward(x) - x = x + MHSA(x) - x = x + ConvolutionBlock(x) - x = x + 0.5 * FeedForward(x) - - y = layer_norm(x) - - """ - config: ConformerConfig - - @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm, - use_running_average): - config = self.config - padding_mask = jnp.expand_dims(1 - input_paddings, -1) - - inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) - - inputs = inputs + MultiHeadedSelfAttention(config=self.config)( - inputs, input_paddings, train) - - inputs = inputs + \ - ConvolutionBlock(config)(inputs, - input_paddings, - train, - update_batch_norm, - use_running_average - ) - - inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) - - if config.use_post_layer_norm: - inputs = LayerNorm(dim=config.encoder_dim)(inputs) - - return inputs - - -class Conformer(nn.Module): - """Conformer (encoder + decoder) block. - - Takes audio input signals and outputs probability distribution over vocab size - for each time step. The output is then fed into a CTC loss which eliminates - the need for alignment with targets. - """ - config: ConformerConfig - - def setup(self): - self.specaug = spectrum_augmenter.SpecAug( - freq_mask_count=self.config.freq_mask_count, - freq_mask_max_bins=self.config.freq_mask_max_bins, - time_mask_count=self.config.time_mask_count, - time_mask_max_frames=self.config.time_mask_max_frames, - time_mask_max_ratio=self.config.time_mask_max_ratio, - time_masks_per_frame=self.config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=self.config - .use_dynamic_time_mask_max_frames) - - @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm: Optional[bool] = None, - use_running_average_bn: Optional[bool] = None): - config = self.config - - outputs = inputs - output_paddings = input_paddings - - # Set BN args if not supplied for backwards compatibility - if update_batch_norm is None: - update_batch_norm = train - if use_running_average_bn is None: - use_running_average_bn = not train - - # Compute normalized log mel spectrograms from input audio signal. - preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() - outputs, output_paddings = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)( - outputs, output_paddings) - - # Ablate random parts of input along temporal and frequency dimension - # following the specaug procedure in https://arxiv.org/abs/1904.08779. - if train and config.use_specaug: - outputs, output_paddings = self.specaug(outputs, output_paddings) - - # Subsample input by a factor of 4 by performing strided convolutions. - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate - outputs, output_paddings = Subsample( - encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate)( - outputs, output_paddings, train) - - # Run the conformer encoder layers. - for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, - output_paddings, - train, - update_batch_norm, - use_running_average_bn) - - outputs = LayerNorm(config.encoder_dim)(outputs) - # Run the decoder which in this case is a trivial projection layer. - outputs = nn.Dense( - config.vocab_size, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) - - return outputs, output_paddings \ No newline at end of file diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py deleted file mode 100644 index 7b7c9720a..000000000 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py +++ /dev/null @@ -1,525 +0,0 @@ -r"""Deepspeech. - -This model uses a deepspeech2 network to convert speech to text. -paper : https://arxiv.org/abs/1512.02595 - -# BiLSTM code contributed by bastings@ -# github : https://github.com/bastings -# webpage : https://bastings.github.io/ -""" - -from typing import Any, Dict, List, Optional, Tuple, Union - -from flax import linen as nn -from flax import struct -import jax -from jax.experimental import rnn -import jax.numpy as jnp - -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - librispeech_preprocessor as preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - spectrum_augmenter - -Array = jnp.ndarray -StateType = Union[Array, Tuple[Array, ...]] -PRNGKey = Any -Shape = Tuple[int] -Dtype = Any -Carry = Any -CarryHistory = Any -Output = Any - - -@struct.dataclass -class DeepspeechConfig: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - vocab_size: int = 1024 - dtype: Any = jnp.float32 - encoder_dim: int = 512 - num_lstm_layers: int = 6 - num_ffn_layers: int = 3 - conv_subsampling_factor: int = 2 - conv_subsampling_layers: int = 2 - use_specaug: bool = True - freq_mask_count: int = 2 - freq_mask_max_bins: int = 27 - time_mask_count: int = 10 - time_mask_max_frames: int = 40 - time_mask_max_ratio: float = 0.05 - time_masks_per_frame: float = 0.0 - use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 0.999 - batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. - feed_forward_dropout_rate: Optional[float] = 0.1 - enable_residual_connections: bool = True - enable_decoder_layer_norm: bool = True - bidirectional: bool = True - use_tanh: bool = False - layernorm_everywhere: bool = False - - -class Subsample(nn.Module): - """Module to perform strided convolution in order to subsample inputs. - - Attributes: - encoder_dim: model dimension of conformer. - input_dropout_rate: dropout rate for inputs. - """ - config: DeepspeechConfig - - @nn.compact - def __call__(self, inputs, output_paddings, train): - config = self.config - outputs = jnp.expand_dims(inputs, axis=-1) - - outputs, output_paddings = Conv2dSubsampling( - encoder_dim=config.encoder_dim, - dtype=config.dtype, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon, - input_channels=1, - output_channels=config.encoder_dim, - use_tanh=config.use_tanh - )(outputs, output_paddings, train) - - outputs, output_paddings = Conv2dSubsampling( - encoder_dim=config.encoder_dim, - dtype=config.dtype, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon, - input_channels=config.encoder_dim, - output_channels=config.encoder_dim, - use_tanh=config.use_tanh)(outputs, output_paddings, train) - - batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape - - outputs = jnp.reshape( - outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) - - outputs = nn.Dense( - config.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) - - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate - outputs = nn.Dropout( - rate=input_dropout_rate, deterministic=not train)( - outputs) - - return outputs, output_paddings - - -class Conv2dSubsampling(nn.Module): - """Helper module used in Subsample layer. - - 1) Performs strided convolution over inputs and then applies non-linearity. - 2) Also performs strided convolution over input_paddings to return the correct - paddings for downstream layers. - """ - input_channels: int = 0 - output_channels: int = 0 - filter_stride: List[int] = (2, 2) - padding: str = 'SAME' - encoder_dim: int = 0 - dtype: Any = jnp.float32 - batch_norm_momentum: float = 0.999 - batch_norm_epsilon: float = 0.001 - use_tanh: bool = False - - def setup(self): - self.filter_shape = (3, 3, self.input_channels, self.output_channels) - self.kernel = self.param('kernel', - nn.initializers.xavier_uniform(), - self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) - - @nn.compact - def __call__(self, inputs, paddings, train): - # Computing strided convolution to subsample inputs. - feature_group_count = inputs.shape[3] // self.filter_shape[2] - outputs = jax.lax.conv_general_dilated( - lhs=inputs, - rhs=self.kernel, - window_strides=self.filter_stride, - padding=self.padding, - rhs_dilation=(1, 1), - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - feature_group_count=feature_group_count) - - outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) - - if self.use_tanh: - outputs = nn.tanh(outputs) - else: - outputs = nn.relu(outputs) - - # Computing correct paddings post input convolution. - input_length = paddings.shape[1] - stride = self.filter_stride[0] - - pad_len = (input_length + stride - 1) // stride * stride - input_length - out_padding = jax.lax.conv_general_dilated( - lhs=paddings[:, :, None], - rhs=jnp.ones([1, 1, 1]), - window_strides=self.filter_stride[:1], - padding=[(0, pad_len)], - dimension_numbers=('NHC', 'HIO', 'NHC')) - out_padding = jnp.squeeze(out_padding, axis=-1) - - # Mask outputs by correct paddings to ensure padded elements in inputs map - # to padded value in outputs. - outputs = outputs * (1.0 - - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) - - return outputs, out_padding - - -class FeedForwardModule(nn.Module): - """Feedforward block of conformer layer.""" - config: DeepspeechConfig - - @nn.compact - def __call__(self, inputs, input_paddings=None, train=False): - padding_mask = jnp.expand_dims(1 - input_paddings, -1) - config = self.config - - if config.layernorm_everywhere: - inputs = LayerNorm(config.encoder_dim)(inputs) - else: - inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, - input_paddings, - train) - inputs = nn.Dense( - config.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) - if config.use_tanh: - inputs = nn.tanh(inputs) - else: - inputs = nn.relu(inputs) - inputs *= padding_mask - - if config.feed_forward_dropout_rate is None: - feed_forward_dropout_rate = 0.1 - else: - feed_forward_dropout_rate = config.feed_forward_dropout_rate - inputs = nn.Dropout(rate=feed_forward_dropout_rate)( - inputs, deterministic=not train) - - return inputs - - -class LayerNorm(nn.Module): - """Module implementing layer normalization. - - This implementation is same as in this paper: - https://arxiv.org/pdf/1607.06450.pdf. - - note: we multiply normalized inputs by (1 + scale) and initialize scale to - zeros, this differs from default flax implementation of multiplying by scale - and initializing to ones. - """ - dim: int = 0 - epsilon: float = 1e-6 - - def setup(self): - self.scale = self.param('scale', nn.initializers.zeros, [self.dim]) - self.bias = self.param('bias', nn.initializers.zeros, [self.dim]) - - @nn.compact - def __call__(self, inputs): - mean = jnp.mean(inputs, axis=-1, keepdims=True) - var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True) - - normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) - normed_inputs *= (1 + self.scale) - normed_inputs += self.bias - - return normed_inputs - - -class BatchNorm(nn.Module): - """Implements batch norm respecting input paddings. - - This implementation takes into account input padding by masking inputs before - computing mean and variance. - - This is inspired by lingvo jax implementation of BatchNorm: - https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92 - - and the corresponding defaults for momentum and epsilon have been copied over - from lingvo. - """ - encoder_dim: int = 0 - dtype: Any = jnp.float32 - batch_norm_momentum: float = 0.999 - batch_norm_epsilon: float = 0.001 - - def setup(self): - dim = self.encoder_dim - dtype = self.dtype - - self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), - dim) - self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), - dim) - - self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) - self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) - - def _get_default_paddings(self, inputs): - """Gets the default paddings for an input.""" - in_shape = list(inputs.shape) - in_shape[-1] = 1 - - return jnp.zeros(in_shape, dtype=inputs.dtype) - - @nn.compact - def __call__(self, inputs, input_paddings=None, train=False): - rank = inputs.ndim - reduce_over_dims = list(range(0, rank - 1)) - - if input_paddings is None: - padding = self._get_default_paddings(inputs) - else: - padding = jnp.expand_dims(input_paddings, -1) - - momentum = self.batch_norm_momentum - epsilon = self.batch_norm_epsilon - - if train: - mask = 1.0 - padding - sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) - count_v = jnp.sum( - jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) - - sum_v = jax.lax.psum(sum_v, axis_name='batch') - count_v = jax.lax.psum(count_v, axis_name='batch') - - count_v = jnp.maximum(count_v, 1.0) - mean = sum_v / count_v - variance = (inputs - mean) * (inputs - mean) * mask - - sum_vv = jnp.sum(variance, axis=reduce_over_dims, keepdims=True) - - sum_vv = jax.lax.psum(sum_vv, axis_name='batch') - var = sum_vv / count_v - - self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var - else: - mean = self.ra_mean.value - var = self.ra_var.value - - inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) - - bn_output = (inputs - mean) * inv + self.beta - bn_output *= 1.0 - padding - - return bn_output - # return inputs - - -class CudnnLSTM(nn.Module): - features: int - num_layers: int = 1 - dropout_rate: float = 0.0 - bidirectional: bool = False - - @nn.compact - def __call__( - self, - inputs: Array, - segmentation_mask: Optional[Array] = None, - return_carry: Optional[bool] = None, - deterministic: bool = False, - initial_states: Optional[Tuple[Array, Array]] = None, - use_cuda: bool = True, - ) -> Union[Array, Tuple[Array, Carry]]: - - if jax.devices()[0].platform != 'gpu': - use_cuda = False - - batch_size = inputs.shape[0] - input_size = inputs.shape[2] - num_directions = 2 if self.bidirectional else 1 - dropout = 0.0 if deterministic else self.dropout_rate - - weights = self.param( - 'weights', - rnn.init_lstm_weight, - input_size, - self.features, - self.num_layers, - self.bidirectional, - ) - - if initial_states is None: - h_0 = jnp.zeros( - (num_directions * self.num_layers, batch_size, self.features), - jnp.float32, - ) - c_0 = jnp.zeros( - (num_directions * self.num_layers, batch_size, self.features), - jnp.float32, - ) - else: - h_0, c_0 = initial_states - - if segmentation_mask is not None: - seq_lengths = jnp.sum(1 - segmentation_mask, axis=1, dtype=jnp.int32) - else: - seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32) - - if use_cuda: - y, h, c = rnn.lstm( - x=inputs, h_0=h_0, c_0=c_0, weights=weights, - seq_lengths=seq_lengths, input_size=input_size, - hidden_size=self.features, num_layers=self.num_layers, - dropout=dropout, bidirectional=self.bidirectional, - ) - else: - weight_ih, weight_hh, bias_ih, bias_hh = self.unpack_weights( - weights, input_size) - y, h, c = rnn.lstm_ref( - x=inputs, h_0=h_0, c_0=c_0, W_ih=weight_ih, W_hh=weight_hh, - b_ih=bias_ih, b_hh=bias_hh, seq_lengths=seq_lengths, - input_size=input_size, hidden_size=self.features, - num_layers=self.num_layers, dropout=dropout, - bidirectional=self.bidirectional, - ) - - if return_carry: - return y, (h, c) - - return y - - @nn.nowrap - def unpack_weights( - self, weights: Array, input_size: int - ) -> Tuple[ - Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]]: - return jax.experimental.rnn.unpack_lstm_weights( - weights, - input_size, - self.features, - self.num_layers, - self.bidirectional, - ) - - -class BatchRNN(nn.Module): - """Implements a single deepspeech encoder layer. - """ - config: DeepspeechConfig - - @nn.compact - def __call__(self, inputs, input_paddings, train): - config = self.config - - if config.layernorm_everywhere: - inputs = LayerNorm(config.encoder_dim)(inputs) - else: - inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, - input_paddings, - train) - output = CudnnLSTM( - features=config.encoder_dim // 2, - bidirectional=config.bidirectional, - num_layers=1)(inputs, input_paddings) - - return output - - -class Deepspeech(nn.Module): - """Conformer (encoder + decoder) block. - - Takes audio input signals and outputs probability distribution over vocab size - for each time step. The output is then fed into a CTC loss which eliminates - the need for alignment with targets. - """ - config: DeepspeechConfig - - def setup(self): - config = self.config - self.specaug = spectrum_augmenter.SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames - ) - - @nn.compact - def __call__(self, inputs, input_paddings, train): - config = self.config - - outputs = inputs - output_paddings = input_paddings - - # Compute normalized log mel spectrograms from input audio signal. - preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() - outputs, output_paddings = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(outputs, - output_paddings) - - # Ablate random parts of input along temporal and frequency dimension - # following the specaug procedure in https://arxiv.org/abs/1904.08779. - if config.use_specaug and train: - outputs, output_paddings = self.specaug(outputs, output_paddings) - - # Subsample input by a factor of 4 by performing strided convolutions. - outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train) - - # Run the lstm layers. - for _ in range(config.num_lstm_layers): - if config.enable_residual_connections: - outputs = outputs + BatchRNN(config)(outputs, output_paddings, train) - else: - outputs = BatchRNN(config)(outputs, output_paddings, train) - - for _ in range(config.num_ffn_layers): - if config.enable_residual_connections: - outputs = outputs + FeedForwardModule(config=self.config)( - outputs, output_paddings, train) - else: - outputs = FeedForwardModule(config=self.config)(outputs, - output_paddings, - train) - - # Run the decoder which in this case is a trivial projection layer. - if config.enable_decoder_layer_norm: - outputs = LayerNorm(config.encoder_dim)(outputs) - - outputs = nn.Dense( - config.vocab_size, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) - - return outputs, output_paddings \ No newline at end of file diff --git a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py deleted file mode 100644 index ca3d89426..000000000 --- a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py +++ /dev/null @@ -1,84 +0,0 @@ -# Forked from the init2winit implementation here -# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. -from typing import Optional, Tuple - -from flax import linen as nn -import jax.numpy as jnp -import jraph - - -def _make_embed(latent_dim, name): - - def make_fn(inputs): - return nn.Dense(features=latent_dim, name=name)(inputs) - - return make_fn - - -def _make_mlp(hidden_dims, activation_fn, train, dropout_rate): - """Creates a MLP with specified dimensions.""" - - @jraph.concatenated_args - def make_fn(inputs): - x = inputs - for dim in hidden_dims: - x = nn.Dense(features=dim)(x) - x = nn.LayerNorm()(x) - x = activation_fn(x) - x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x) - return x - - return make_fn - - -class GNN(nn.Module): - """Defines a graph network. - The model assumes the input data is a jraph.GraphsTuple without global - variables. The final prediction will be encoded in the globals. - """ - num_outputs: int - latent_dim: int = 256 - hidden_dims: Tuple[int] = (256,) - # If None, defaults to 0.1. - dropout_rate: Optional[float] = 0.1 - num_message_passing_steps: int = 5 - activation_fn_name: str = 'relu' - - @nn.compact - def __call__(self, graph, train): - dropout_rate = self.dropout_rate - - graph = graph._replace( - globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) - - embedder = jraph.GraphMapFeatures( - embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'), - embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding')) - graph = embedder(graph) - - if self.activation_fn_name == 'relu': - activation_fn = nn.relu - elif self.activation_fn_name == 'gelu': - activation_fn = nn.gelu - elif self.activation_fn_name == 'silu': - activation_fn = nn.silu - else: - raise ValueError( - f'Invalid activation function name: {self.activation_fn_name}') - - for _ in range(self.num_message_passing_steps): - net = jraph.GraphNetwork( - update_edge_fn=_make_mlp( - self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), - update_node_fn=_make_mlp( - self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate), - update_global_fn=_make_mlp( - self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate)) - - graph = net(graph) - - # Map globals to represent the final result - decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs)) - graph = decoder(graph) - - return graph.globals \ No newline at end of file diff --git a/algoperf/workloads/wmt/wmt_jax/models_ref.py b/algoperf/workloads/wmt/wmt_jax/models_ref.py deleted file mode 100644 index e1f44aaa6..000000000 --- a/algoperf/workloads/wmt/wmt_jax/models_ref.py +++ /dev/null @@ -1,604 +0,0 @@ -"""Transformer-based machine translation model. - -Reference https://github.com/google/flax/tree/main/examples/wmt. -""" - -from typing import Any, Callable, Optional - -from flax import linen as nn -from flax import struct -from jax import lax -import jax.numpy as jnp -import numpy as np - - -@struct.dataclass -class TransformerConfig: - """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" - share_embeddings: bool = True - dtype: Any = jnp.float32 - vocab_size: int = 32000 - emb_dim: int = 1024 - num_heads: int = 16 - num_layers: int = 6 - qkv_dim: int = 1024 - mlp_dim: int = 1024 - max_len: int = 256 - activation: Callable = nn.relu - glu: bool = False - #If None, defaults to 0.1. - dropout_rate: Optional[float] = 0.1 - #If None, defaults to 0.1. - attention_dropout_rate: Optional[float] = 0.1 - attention_temp: float = 1.0 - deterministic: bool = False - decode: bool = False - kernel_init: Callable = nn.initializers.xavier_uniform() - bias_init: Callable = nn.initializers.normal(stddev=1e-6) - posemb_init: Optional[Callable] = None - pre_ln: bool = True - - -def shift_right(x, axis=1): - """Shift the input to the right by padding on axis 1.""" - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[axis] = (1, 0) - padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) - return padded[:, :-1] - - -def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0): - """1D Sinusoidal Position Embedding Initializer. - - Args: - max_len: maximum possible length for the input. - min_scale: float: minimum frequency-scale in sine grating. - max_scale: float: maximum frequency-scale in sine grating. - - Returns: - output: init function returning `(1, max_len, d_feature)` - """ - - def init(key, shape, dtype=np.float32): - """Sinusoidal init.""" - del key, dtype - d_feature = shape[-1] - pe = np.zeros((max_len, d_feature), dtype=np.float32) - position = np.arange(0, max_len)[:, np.newaxis] - scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) - div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) - pe[:, :d_feature // 2] = np.sin(position * div_term) - pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term) - pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] - return jnp.array(pe) - - return init - - -class AddPositionEmbs(nn.Module): - """Adds (optionally learned) positional embeddings to the inputs. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - decode: whether to run in single-position autoregressive mode. - """ - config: TransformerConfig - decode: bool = False - - @nn.compact - def __call__(self, inputs, inputs_positions=None): - """Applies AddPositionEmbs module. - - By default this layer uses a fixed sinusoidal embedding table. If a - learned position embedding is desired, pass an initializer to - posemb_init in the configuration. - - Args: - inputs: input data. - inputs_positions: input position indices for packed sequences. - - Returns: - output: `(bs, timesteps, in_dim)` - """ - cfg = self.config - # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - f' but it is: {inputs.ndim}') - length = inputs.shape[1] - pos_emb_shape = (1, cfg.max_len, inputs.shape[-1]) - if cfg.posemb_init is None: - # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=cfg.max_len)(None, - pos_emb_shape, - None) - else: - pos_embedding = self.param('pos_embedding', - cfg.posemb_init, - pos_emb_shape) - pe = pos_embedding[:, :length, :] - - # We use a cache position index for tracking decoding position. - if self.decode: - is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', - 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) - if is_initialized: - i = cache_index.value - cache_index.value = i + 1 - _, _, df = pos_embedding.shape - pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)), (1, 1, df)) - if inputs_positions is None: - # normal unpacked case: - return inputs + pe - else: - # for packed data we need to use known position indices: - return inputs + jnp.take(pe[0], inputs_positions, axis=0) - - -class MlpBlock(nn.Module): - """Transformer MLP / feed-forward block. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - out_dim: optionally specify out dimension. - """ - config: TransformerConfig - out_dim: Optional[int] = None - - @nn.compact - def __call__(self, inputs): - """Applies Transformer MlpBlock module.""" - cfg = self.config - actual_out_dim = ( - inputs.shape[-1] if self.out_dim is None else self.out_dim) - x = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) - x = cfg.activation(x) - if cfg.glu: - y = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) - x = x * y - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - output = nn.Dense( - actual_out_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - x) - output = nn.Dropout(rate=dropout_rate)( - output, deterministic=cfg.deterministic) - return output - - -class Encoder1DBlock(nn.Module): - """Transformer encoder layer. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ - config: TransformerConfig - - @nn.compact - def __call__(self, inputs, encoder_mask=None): - """Applies Encoder1DBlock module. - - Args: - inputs: input data. - encoder_mask: encoder self-attention mask. - - Returns: - output after transformer encoder block. - """ - cfg = self.config - pre_ln = cfg.pre_ln - - # Attention block. - assert inputs.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * x, x, mask=encoder_mask) - - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - x = x + inputs - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # MLP block. - y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = MlpBlock(config=cfg)(y) - - return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) - - -class EncoderDecoder1DBlock(nn.Module): - """Transformer encoder-decoder layer. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ - config: TransformerConfig - - @nn.compact - def __call__(self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None): - """Applies EncoderDecoder1DBlock module. - - Args: - targets: input data for decoder - encoded: input data from encoder - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. - - Returns: - output after transformer encoder-decoder block. - """ - cfg = self.config - pre_ln = cfg.pre_ln - - # Decoder block. - assert targets.ndim == 3 - x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets - - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate - x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic, - decode=cfg.decode)( - cfg.attention_temp * x, x, mask=decoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - x = x + targets - if not pre_ln: - x = nn.LayerNorm(dtype=cfg.dtype)(x) - - # Encoder-Decoder block. - y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) - y = y + x - if not pre_ln: - y = nn.LayerNorm(dtype=cfg.dtype)(y) - - # MLP block. - z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y - z = MlpBlock(config=cfg)(z) - - return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) - - -class Encoder(nn.Module): - """Transformer Model Encoder for sequence to sequence translation. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - shared_embedding: a shared embedding layer to use. - """ - config: TransformerConfig - shared_embedding: Any = None - - @nn.compact - def __call__(self, inputs, inputs_positions=None, encoder_mask=None): - """Applies Transformer model on the inputs. - - Args: - inputs: input data - inputs_positions: input subsequence positions for packed examples. - encoder_mask: decoder self-attention mask. - - Returns: - output of a transformer encoder. - """ - cfg = self.config - assert inputs.ndim == 2 # (batch, len) - - # Input Embedding - if self.shared_embedding is None: - input_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) - else: - input_embed = self.shared_embedding - x = inputs.astype('int32') - x = input_embed(x) - x = AddPositionEmbs( - config=cfg, decode=False, name='posembed_input')( - x, inputs_positions=inputs_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - - x = x.astype(cfg.dtype) - - # Input Encoder - for lyr in range(cfg.num_layers): - x = Encoder1DBlock( - config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask) - - encoded = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x) - if cfg.pre_ln else x) - - return encoded - - -class Decoder(nn.Module): - """Transformer Model Decoder for sequence to sequence translation. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - shared_embedding: a shared embedding layer to use. - """ - config: TransformerConfig - shared_embedding: Any = None - - @nn.compact - def __call__(self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None): - """Applies Transformer model on the inputs. - - Args: - encoded: encoded input data from encoder. - targets: target inputs. - targets_positions: input subsequence positions for packed examples. - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. - - Returns: - output of a transformer decoder. - """ - cfg = self.config - - assert encoded.ndim == 3 # (batch, len, depth) - assert targets.ndim == 2 # (batch, len) - - # Target Embedding - if self.shared_embedding is None: - output_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) - else: - output_embed = self.shared_embedding - - y = targets.astype('int32') - if not cfg.decode: - y = shift_right(y) - y = output_embed(y) - y = AddPositionEmbs( - config=cfg, decode=cfg.decode, name='posembed_output')( - y, inputs_positions=targets_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) - - y = y.astype(cfg.dtype) - - # Target-Input Decoder - for lyr in range(cfg.num_layers): - y = EncoderDecoder1DBlock( - config=cfg, name=f'encoderdecoderblock_{lyr}')( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) - y = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y) - if cfg.pre_ln else y) - - # Use the transpose of embedding matrix for logit transform. - logits = output_embed.attend(y.astype(jnp.float32)) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - return logits - - -class Transformer(nn.Module): - """Transformer Model for sequence to sequence translation. - - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ - config: TransformerConfig - - def setup(self): - cfg = self.config - - if cfg.share_embeddings: - if cfg.vocab_size is not None: - assert cfg.vocab_size == cfg.vocab_size, ( - "can't share embedding with different vocab sizes.") - self.shared_embedding = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) - else: - self.shared_embedding = None - - self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) - self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - - def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): - """Applies Transformer encoder-branch on the inputs. - - Args: - inputs: input data. - inputs_positions: input subsequence positions for packed examples. - inputs_segmentation: input segmentation info for packed examples. - - Returns: - encoded feature array from the transformer encoder. - """ - cfg = self.config - # Make padding attention mask. - encoder_mask = nn.make_attention_mask( - inputs > 0, inputs > 0, dtype=cfg.dtype) - # Add segmentation block-diagonal attention mask if using segmented data. - if inputs_segmentation is not None: - encoder_mask = nn.combine_masks( - encoder_mask, - nn.make_attention_mask( - inputs_segmentation, - inputs_segmentation, - jnp.equal, - dtype=cfg.dtype)) - return self.encoder( - inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask) - - def decode( - self, - encoded, - inputs, # only needed for masks - targets, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None): - """Applies Transformer decoder-branch on encoded-input and target. - - Args: - encoded: encoded input data from encoder. - inputs: input data (only needed for masking). - targets: target data. - targets_positions: target subsequence positions for packed examples. - inputs_segmentation: input segmentation info for packed examples. - targets_segmentation: target segmentation info for packed examples. - - Returns: - logits array from transformer decoder. - """ - cfg = self.config - - # Make padding attention masks. - if cfg.decode: - decoder_mask = None - encoder_decoder_mask = nn.make_attention_mask( - jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype) - else: - decoder_mask = nn.combine_masks( - nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype), - nn.make_causal_mask(targets, dtype=cfg.dtype)) - encoder_decoder_mask = nn.make_attention_mask( - targets > 0, inputs > 0, dtype=cfg.dtype) - - # Add segmentation block-diagonal attention masks if using segmented data. - if inputs_segmentation is not None: - decoder_mask = nn.combine_masks( - decoder_mask, - nn.make_attention_mask( - targets_segmentation, - targets_segmentation, - jnp.equal, - dtype=cfg.dtype)) - encoder_decoder_mask = nn.combine_masks( - encoder_decoder_mask, - nn.make_attention_mask( - targets_segmentation, - inputs_segmentation, - jnp.equal, - dtype=cfg.dtype)) - logits = self.decoder( - encoded, - targets, - targets_positions=targets_positions, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) - return logits.astype(self.config.dtype) - - def __call__(self, - inputs, - targets, - inputs_positions=None, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None): - """Applies Transformer model on the inputs. - - Args: - inputs: input data. - targets: target data. - inputs_positions: input subsequence positions for packed examples. - targets_positions: target subsequence positions for packed examples. - inputs_segmentation: input segmentation info for packed examples. - targets_segmentation: target segmentation info for packed examples. - - Returns: - logits array from full transformer. - """ - encoded = self.encode( - inputs, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) - - return self.decode( - encoded, - inputs, # only used for masks - targets, - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation) \ No newline at end of file From 161c264c7c007f9d3f065f2affb55091cc61e492 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 25 Jun 2025 05:32:28 +0000 Subject: [PATCH 088/123] lint fix --- algoperf/workloads/ogbg/ogbg_jax/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 7207e033d..8524bb60e 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -93,4 +93,4 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs)) graph = decoder(graph) - return graph.globals \ No newline at end of file + return graph.globals From ac76d4ff7258fba7882d3e414c4b7f97fc0627de Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 25 Jun 2025 05:47:05 +0000 Subject: [PATCH 089/123] formatting fixes --- tests/test_jax_utils.py | 43 ++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py index 8a156149b..7ee61d1e1 100644 --- a/tests/test_jax_utils.py +++ b/tests/test_jax_utils.py @@ -4,7 +4,6 @@ """ from functools import partial -import os from absl.testing import absltest from absl.testing import parameterized @@ -63,7 +62,7 @@ def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT): x, rate=dropout_rate) -class ModelEquivalenceTest(parameterized.TestCase): +class DropoutTest(parameterized.TestCase): @parameterized.named_parameters( dict( @@ -185,8 +184,7 @@ def test_dropout_update(self, dropout_rate, mode): dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"), ) def test_jitted_updates(self, dropout_rate, mode): - """ Compare forward pass of Dropout layer to flax.linen.Dropout in train and - eval mode. + """ Compare jitted updates with dropout. """ # initialize models @@ -214,24 +212,25 @@ def test_jitted_updates(self, dropout_rate, mode): jitted_custom_apply = jax.jit( partial(cust_model.apply), static_argnames=['train']) - def multiple_fwd_passes_custom_layer(): - for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: - y2 = jitted_custom_apply( - initial_variables_custom, - x, - train=train, - dropout_rate=d, - rngs={"dropout": dropout_rng}, - ) - return y2 - - def multiple_fwd_passes_original_layer(): - for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: - y1 = jitted_original_apply( - initial_variables_original, - x, - train=train, - rngs={"dropout": dropout_rng}) + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: + y1 = jitted_original_apply( + initial_variables_original, + x, + train=train, + rngs={"dropout": dropout_rng}) + return y1 + + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: + y2 = jitted_custom_apply( + initial_variables_custom, + x, + train=train, + dropout_rate=d, + rngs={"dropout": dropout_rng}, + ) + return y2 + + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) if __name__ == "__main__": From e4eaceaf57e32b74cafd2a31e16cc64bb5a354ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 25 Jun 2025 06:10:58 +0000 Subject: [PATCH 090/123] fix linting --- tests/test_jax_utils.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py index 7ee61d1e1..a90b44f62 100644 --- a/tests/test_jax_utils.py +++ b/tests/test_jax_utils.py @@ -22,13 +22,12 @@ def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8): """ - A custom function to check if two PyTrees are equal, handling floats with a tolerance. + A custom function to check if two PyTrees are equal, handling floats with + a tolerance. """ - # 1. Check if the structures are the same if tree_structure(a) != tree_structure(b): return False - # 2. Define a comparison function for leaves def leaf_comparator(x, y): # Use allclose for floating-point JAX arrays if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating): @@ -37,8 +36,6 @@ def leaf_comparator(x, y): else: return x == y - # 3. Map the comparison function over the trees and check if all results are True - # We also need to flatten the results of the tree_map and check if all are True comparison_tree = tree_map(leaf_comparator, a, b) all_equal = all(tree_leaves(comparison_tree)) @@ -80,7 +77,7 @@ def test_forward(self, dropout_rate, mode): """ # initialize models - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2) fake_batch = jnp.ones((10,)) orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) cust_model = DropoutModel() @@ -130,7 +127,7 @@ def test_dropout_update(self, dropout_rate, mode): eval mode. """ # init model - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2) fake_batch = jnp.ones((10,)) orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) cust_model = DropoutModel() From 0f430495f6ecd6ab6efcebcbd99bbdb7ca04f018 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 25 Jun 2025 06:11:35 +0000 Subject: [PATCH 091/123] fix linting --- tests/test_jax_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py index a90b44f62..22a3d0929 100644 --- a/tests/test_jax_utils.py +++ b/tests/test_jax_utils.py @@ -215,7 +215,6 @@ def test_jitted_updates(self, dropout_rate, mode): x, train=train, rngs={"dropout": dropout_rng}) - return y1 for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: y2 = jitted_custom_apply( @@ -225,8 +224,6 @@ def test_jitted_updates(self, dropout_rate, mode): dropout_rate=d, rngs={"dropout": dropout_rng}, ) - return y2 - assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) From a151382d5c61ce29f80003869cf0be83f374aadb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 25 Jun 2025 06:20:47 +0000 Subject: [PATCH 092/123] pylint fix --- tests/test_jax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py index 22a3d0929..d54bf47aa 100644 --- a/tests/test_jax_utils.py +++ b/tests/test_jax_utils.py @@ -185,7 +185,7 @@ def test_jitted_updates(self, dropout_rate, mode): """ # initialize models - rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3) + rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2) fake_batch = jnp.ones((10,)) orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) cust_model = DropoutModel() From 78a14095273f38f03697005b37e7365560bafe55 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 18 Jun 2025 14:17:23 +0200 Subject: [PATCH 093/123] Formatting, ignore `.eggs/` in yapf --- pyproject.toml | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1daa72848..f72491380 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ dependencies = [ "clu==0.0.12", "matplotlib>=3.9.2", "tabulate==0.9.0", - ] [build-system] @@ -69,9 +68,7 @@ version_file = "algoperf/_version.py" ############################################################################### [project.optional-dependencies] # All workloads -full = [ - "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", -] +full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package @@ -104,11 +101,7 @@ jax_core_deps = [ "ml_dtypes==0.4.1", "protobuf==4.25.5", ] -jax_cpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", - "algoperf[jax_core_deps]", -] +jax_cpu = ["jax==0.4.28", "jaxlib==0.4.28", "algoperf[jax_core_deps]"] jax_gpu = [ "jax==0.4.28", "jaxlib==0.4.28", @@ -117,10 +110,8 @@ jax_gpu = [ "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] -pytorch_gpu = [ - "torch==2.5.1", - "torchvision==0.20.1", -] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. +pytorch_gpu = ["torch==2.5.1", "torchvision==0.20.1"] +# Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. ############################################################################### # Linting Configurations # @@ -133,7 +124,7 @@ each_dict_entry_on_separate_line = false split_all_top_level_comma_separated_values = true column_limit = 80 [tool.yapfignore] -ignore_patterns = ["algoperf/_version.py"] +ignore_patterns = ["algoperf/_version.py", ".eggs/**"] # isort configuration [tool.isort] From 462e8b7988ab722682e0f2ceb18785ec2028cd10 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 18 Jun 2025 15:44:01 +0200 Subject: [PATCH 094/123] Replace yapf, pylint, isort with ruff --- pyproject.toml | 258 ++++++++----------------------------------------- 1 file changed, 39 insertions(+), 219 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f72491380..1532a66c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,9 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ @@ -72,13 +72,7 @@ full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package -dev = [ - "isort==5.12.0", - "pylint==2.17.4", - "pytest==8.3.3", - "yapf==0.32.0", - "pre-commit==4.0.1", -] +dev = ["ruff==0.12.0", "pytest==8.3.3", "pre-commit==4.0.1"] wandb = ["wandb==0.19.6"] @@ -111,215 +105,41 @@ jax_gpu = [ ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] pytorch_gpu = ["torch==2.5.1", "torchvision==0.20.1"] -# Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. +# Note: omit the cuda suffix and installing from the appropriate wheel +# will result in using locally installed CUDA. ############################################################################### -# Linting Configurations # +# Linting & Formatting Configurations # ############################################################################### - -# yapf configuration -[tool.yapf] -based_on_style = "yapf" -each_dict_entry_on_separate_line = false -split_all_top_level_comma_separated_values = true -column_limit = 80 -[tool.yapfignore] -ignore_patterns = ["algoperf/_version.py", ".eggs/**"] - -# isort configuration -[tool.isort] -profile = "google" - -# pylint configuration -[tool.pylint.MASTER] -persistent = false -ignore = "get_references_web.py,get_references_web_single_group.py,_version.py" - -[tool.pylint.REPORTS] -reports = false -msg-template = "{msg_id}:{line:3} {obj}: {msg} [{symbol}]" - -[tool.pylint.MESSAGES_CONTROL] -enable = "indexing-exception,old-raise-syntax" - -[tool.pylint.BASIC] -# Required attributes for module, separated by a comma -#required-attributes= -# Regular expression which should only match the name -# of functions or classes which do not require a docstring. -no-docstring-rgx = "(__.*__|main)" -# Min length in lines of a function that requires a docstring. -docstring-min-length = 10 -# Regular expression which should only match correct module names. The -# leading underscore is sanctioned for private modules by Google's style -# guide. -# -# There are exceptions to the basic rule (_?[a-z][a-z0-9_]*) to cover -# requirements of Python's module system. -module-rgx = "^(_?[a-z][a-z0-9_]*)|__init__$" -# Regular expression which should only match correct module level names -const-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$" -# Regular expression which should only match correct class attribute -class-attribute-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$" -# Regular expression which should only match correct class names -class-rgx = "^_?[A-Z][a-zA-Z0-9]*$" -# Regular expression which should only match correct function names. -# 'camel_case' and 'snake_case' group names are used for consistency of naming -# styles across functions and methods. -function-rgx = "^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$" -# Regular expression which should only match correct method names. -# 'camel_case' and 'snake_case' group names are used for consistency of naming -# styles across functions and methods. 'exempt' indicates a name which is -# consistent with all naming styles. -method-rgx = "(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|_testDatasetSize|setUpClass|test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|(?:test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$" -# Regular expression which should only match correct instance attribute names -attr-rgx = "^_{0,2}[a-z][a-z0-9_]*$" -# Regular expression which should only match correct argument names -argument-rgx = "^[a-z][a-z0-9_]*$" -# Regular expression which should only match correct variable names -variable-rgx = "^[a-z][a-z0-9_]*$" -# Regular expression which should only match correct list comprehension / -# generator expression variable names -inlinevar-rgx = "^[a-z][a-z0-9_]*$" -# Good variable names which should always be accepted, separated by a comma -good-names = "main,_" -# Bad variable names which should always be refused, separated by a comma -bad-names = "" -# List of builtins function names that should not be used, separated by a comma -#bad-functions=input,apply,reduce -# List of decorators that define properties, such as abc.abstractproperty. -property-classes = "abc.abstractproperty" - -[tool.pylint.typecheck] -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members = true - -# List of decorators that create context managers from functions, such as -# contextlib.contextmanager. -contextmanager-decorators = [ - "contextlib.contextmanager", - "contextlib2.contextmanager", -] - -[tool.pylint.VARIABLES] -# Tells whether we should check for unused import in __init__ files. -init-import = false - -# A regular expression matching names used for dummy variables (i.e. not used). -dummy-variables-rgx = "^\\*{0,2}(_$|unused_|dummy_)" - -# List of additional names supposed to be defined in builtins. -additional-builtins = [] - -[tool.pylint.CLASSES] -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods = ["__init__", "__new__", "setUp"] - -# Valid names for the first argument to a class method. -valid-classmethod-first-arg = ["cls", "class_"] - -[tool.pylint.EXCEPTIONS] -overgeneral-exceptions = [ - "builtins.StandardError", - "builtins.Exception", - "builtins.BaseException", -] - -[tool.pylint.IMPORTS] -# Deprecated modules which should not be used, separated by a comma -deprecated-modules = ["regsub", "TERMIOS", "Bastion", "rexec", "sets"] - -[tool.pylint.FORMAT] -# List of checkers and warnings to disable. -disable = [ - "abstract-method", - "access-member-before-definition", - "arguments-differ", - "assignment-from-no-return", - "attribute-defined-outside-init", - "bad-mcs-classmethod-argument", - "bad-option-value", - "c-extension-no-member", - "consider-merging-isinstance", - "consider-using-dict-comprehension", - "consider-using-enumerate", - "consider-using-in", - "consider-using-set-comprehension", - "consider-using-ternary", - "deprecated-method", - "design", - "file-ignored", - "fixme", - "global-statement", - "import-error", - "inconsistent-return-statements", - "invalid-unary-operand-type", - "len-as-condition", - "locally-disabled", - "locally-enabled", - "misplaced-comparison-constant", - "missing-docstring", - "multiple-imports", - "no-else-return", - "no-member", - "no-name-in-module", - "no-self-use", - "no-value-for-parameter", - "not-an-iterable", - "not-context-manager", - "pointless-except", - "protected-access", - "redefined-argument-from-local", - "signature-differs", - "similarities", - "simplifiable-if-expression", - "star-args", - "super-init-not-called", - "suppressed-message", - "too-many-function-args", - "trailing-comma-tuple", - "trailing-newlines", - "ungrouped-imports", - "unnecessary-pass", - "unsubscriptable-object", - "unused-argument", - "useless-object-inheritance", - "useless-return", - "useless-suppression", - "wrong-import-order", - "wrong-import-position", - "unneeded-not", - "unexpected-keyword-arg", - "redundant-keyword-arg", - "unspecified-encoding", - "logging-fstring-interpolation", - "consider-using-f-string", - "use-dict-literal", -] -# Maximum number of characters on a single line. -max-line-length = 80 -ignore-long-lines = "(?x)(^\\s*(import|from)\\s|^\\s*(\\#\\ )??$|^[a-zA-Z_][a-zA-Z0-9_]*\\s*=\\s*('[^']\\S+'|\"[^\"]\\S+\"))" -# Maximum number of lines in a module -max-module-lines = 99999 -# String used as indentation unit. We differ from PEP8's normal 4 spaces. -indent-string = ' ' -single-line-if-stmt = true -# Do not warn about multiple statements on a single line for constructs like -# if test: stmt -[tool.pylint.LOGGING] -logging-modules = "logging,absl.logging" -# Add logging modules. -[tool.pylint.MISCELLANEOUS] -# Maximum line length for lambdas -#short-func-length=1 -# List of module members that should be marked as deprecated. -# All of the string functions are listed in 4.1.4 Deprecated string functions -# in the Python 2.4 docs. -#deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint -# List of exceptions that do not need to be mentioned in the Raises section of -# a docstring. -#ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError -# Number of spaces of indent required when the last token on the preceding line -# is an open (, [, or {. -indent-after-paren = 4 +[tool.ruff] +line-length = 80 +indent-width = 2 +exclude = ["_version.py"] +target-version = "py311" + +[tool.ruff.format] +quote-style = "single" + +[tool.ruff.lint] +select = ["I", "PL"] +ignore = [ + # Conflicting lint rules with Ruff's formatter + # (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules). + "W191", + "E111", + "E114", + "E117", + "D206", + "D300", + "Q000", + "Q001", + "Q002", + "Q003", + "COM812", + "COM819", + "ISC001", + "ISC002", + "FBT001", + "FBT003", + "TD003", +] \ No newline at end of file From 383db7a0f4a463537b1493a56808628db6f238b9 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 18 Jun 2025 15:52:11 +0200 Subject: [PATCH 095/123] Replace pre-commit with ruff --- .pre-commit-config.yaml | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc8f13d25..f2d684c53 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,14 +1,10 @@ repos: - - repo: https://github.com/google/yapf - rev: v0.32.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.12.0 hooks: - - id: yapf - args: ["--in-place", "--parallel", "--verbose", "--recursive"] - - repo: https://github.com/pycqa/isort - rev: 5.10.1 - hooks: - - id: isort - - repo: https://github.com/pycqa/pylint - rev: v2.16.1 - hooks: - - id: pylint + # Run the linter (don't change files). + - id: ruff-check + # Run the formatter (don't change files). + - id: ruff-format + args: ["--check"] From 2c2813628e8c8e4c9942012b045ea8188db7e941 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 18 Jun 2025 16:08:54 +0200 Subject: [PATCH 096/123] Use extend-select instead, and reduce lint rules --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1532a66c7..2db40f61f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,8 @@ target-version = "py311" quote-style = "single" [tool.ruff.lint] -select = ["I", "PL"] +extend-select = ["I"] +# Could add (in the future): "E", "F", "UP", "B", "SIM", "PL" ignore = [ # Conflicting lint rules with Ruff's formatter # (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules). From 999f7a2f280f3c91c8f29cd4652200f8d63dae24 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 18 Jun 2025 16:09:09 +0200 Subject: [PATCH 097/123] Replace linting GH actions with ruff --- .github/workflows/linting.yml | 39 +++++++++-------------------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 992496b69..4308e4201 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -3,7 +3,7 @@ name: Linting on: [push, pull_request] jobs: - pylint: + ruff-linting: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -11,19 +11,15 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.11.10 - - name: Install pylint + - name: Install ruff run: | python -m pip install --upgrade pip - pip install pylint==2.16.1 - - name: Run pylint + pip install ruff==0.12.0 + - name: Run ruff linter run: | - pylint algoperf - pylint reference_algorithms - pylint prize_qualification_baselines - pylint submission_runner.py - pylint tests + ruff check - isort: + ruff-formatter: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -31,26 +27,11 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.11.10 - - name: Install isort + - name: Install ruff run: | python -m pip install --upgrade pip - pip install isort==5.12.0 - - name: Run isort + pip install ruff==0.12.0 + - name: Run ruff formatter run: | - isort . --check --diff + ruff format --check - yapf: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.11.10 - uses: actions/setup-python@v2 - with: - python-version: 3.11.10 - - name: Install yapf - run: | - python -m pip install --upgrade pip - pip install yapf==0.32 toml - - name: Run yapf - run: | - yapf . --diff --recursive From 7b245d6a6692509b4e994106a1c3587a640e4165 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 18 Jun 2025 16:41:28 +0200 Subject: [PATCH 098/123] Add ruff badge --- README.md | 39 ++++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 8e470266d..0666d21d5 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,11 @@ Benchmark/Results Paper

-[![CI](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml/badge.svg)](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml) -[![Lint](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml/badge.svg)](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml) -[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://github.com/mlcommons/algorithmic-efficiency/blob/main/LICENSE.md) -[![Code style: yapf](https://img.shields.io/badge/code%20style-yapf-orange)](https://github.com/google/yapf) -[![Discord](https://dcbadge.vercel.app/api/server/5FPXK7SMt6?style=flat)](https://discord.gg/5FPXK7SMt6) +[![CI Status](https://img.shields.io/github/actions/workflow/status/mlcommons/algorithmic-efficiency/CI.yml?style=flat-square&logo=github&label=CI)](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml) +[![Linting Status](https://img.shields.io/github/actions/workflow/status/mlcommons/algorithmic-efficiency/linting.yml?style=flat-square&logo=github&label=Linting)](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml) +[![Code Style Ruff](https://img.shields.io/badge/Code%20Style-Ruff-brightgreen?style=flat-square&logo=ruff)](https://github.com/astral-sh/ruff) +[![GitHub License](https://img.shields.io/github/license/mlcommons/algorithmic-efficiency?style=flat-square&label=License)](LICENSE.md) +[![Discord](https://dcbadge.limes.pink/api/server/5FPXK7SMt6?style=flat-square)](https://discord.gg/5FPXK7SMt6) --- @@ -28,11 +28,12 @@ Submissions are evaluated based on their "time-to-result", i.e., the wall-clock --- -> This is the repository for the *AlgoPerf: Training Algorithms benchmark* measuring neural network training speedups due to algorithmic improvements. +> This is the repository for the _AlgoPerf: Training Algorithms benchmark_ measuring neural network training speedups due to algorithmic improvements. > It is developed by the [MLCommons Algorithms Working Group](https://mlcommons.org/en/groups/research-algorithms/). > This repository holds the benchmark code, the benchmark's [**technical documentation**](/docs/DOCUMENTATION.md) and [**getting started guides**](/docs/GETTING_STARTED.md). For a detailed description of the benchmark design, see our [**introductory paper**](https://arxiv.org/abs/2306.07179), for the results of the inaugural competition see our [**results paper**](https://openreview.net/forum?id=CtM5xjRSfm). > > **See our [AlgoPerf Leaderboard](https://github.com/mlcommons/submissions_algorithms) for the latest results of the benchmark and to submit your algorithm.** + --- > [!IMPORTANT] @@ -50,14 +51,13 @@ Submissions are evaluated based on their "time-to-result", i.e., the wall-clock ## Installation -> [!TIP] -> **If you have any questions about the benchmark competition or you run into any issues, please feel free to contact us.** Either [file an issue](https://github.com/mlcommons/algorithmic-efficiency/issues), ask a question on [our Discord](https://discord.gg/5FPXK7SMt6) or [join our weekly meetings](https://mlcommons.org/en/groups/research-algorithms/). +> [!TIP] > **If you have any questions about the benchmark competition or you run into any issues, please feel free to contact us.** Either [file an issue](https://github.com/mlcommons/algorithmic-efficiency/issues), ask a question on [our Discord](https://discord.gg/5FPXK7SMt6) or [join our weekly meetings](https://mlcommons.org/en/groups/research-algorithms/). You can install this package and dependencies in a [Python virtual environment](/docs/GETTING_STARTED.md#python-virtual-environment) or use a [Docker/Singularity/Apptainer container](/docs/GETTING_STARTED.md#docker) (recommended). We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments. Both options are described in detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document. -*TL;DR to install the Jax version for GPU run:* +_TL;DR to install the Jax version for GPU run:_ ```bash pip3 install -e '.[pytorch_cpu]' @@ -65,7 +65,7 @@ pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax pip3 install -e '.[full]' ``` -*TL;DR to install the PyTorch version for GPU run:* +_TL;DR to install the PyTorch version for GPU run:_ ```bash pip3 install -e '.[jax_cpu]' @@ -77,7 +77,7 @@ pip3 install -e '.[full]' For detailed instructions on developing your own algorithm in the benchmark see the [Getting Started](/docs/GETTING_STARTED.md) document. -*TL;DR running a JAX workload:* +_TL;DR running a JAX workload:_ ```bash python3 submission_runner.py \ @@ -89,7 +89,7 @@ python3 submission_runner.py \ --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json ``` -*TL;DR running a PyTorch workload:* +_TL;DR running a PyTorch workload:_ ```bash python3 submission_runner.py \ @@ -117,17 +117,15 @@ Our [**Contributing**](/docs/CONTRIBUTING.md) document provides further MLCommon ## License -The *AlgoPerf* codebase is licensed under the [Apache License 2.0](/LICENSE.md). +The _AlgoPerf_ codebase is licensed under the [Apache License 2.0](/LICENSE.md). ## Paper and Citing the AlgoPerf Benchmark -In our paper ["Benchmarking Neural Network Training Algorithms"](http://arxiv.org/abs/2306.07179) we motivate, describe, and justify the *AlgoPerf: Training Algorithms* benchmark. +In our paper ["Benchmarking Neural Network Training Algorithms"](http://arxiv.org/abs/2306.07179) we motivate, describe, and justify the _AlgoPerf: Training Algorithms_ benchmark. -If you are using the *AlgoPerf benchmark*, its codebase, baselines, or workloads, please consider citing our paper: +If you are using the _AlgoPerf benchmark_, its codebase, baselines, or workloads, please consider citing our paper: -> [Dahl, Schneider, Nado, et al.
-> **Benchmarking Neural Network Training Algorithms**
-> *arXiv 2306.07179*](http://arxiv.org/abs/2306.07179) +> [Dahl, Schneider, Nado, et al.
> **Benchmarking Neural Network Training Algorithms**
> _arXiv 2306.07179_](http://arxiv.org/abs/2306.07179) ```bibtex @Misc{Dahl2023AlgoPerf, @@ -139,10 +137,9 @@ If you are using the *AlgoPerf benchmark*, its codebase, baselines, or workloads } ``` -If you use the results from the first *AlgoPerf competition*, please consider citing the results paper, as well as the relevant submissions: +If you use the results from the first _AlgoPerf competition_, please consider citing the results paper, as well as the relevant submissions: -> [Kasimbeg, Schneider, Eschenhagen, et al.
-> **Accelerating neural network training: An analysis of the AlgoPerf competition**
+> [Kasimbeg, Schneider, Eschenhagen, et al.
> **Accelerating neural network training: An analysis of the AlgoPerf competition**
> ICLR 2025](https://openreview.net/forum?id=CtM5xjRSfm) ```bibtex From 277674c64cfdf671c1721c5e32feb8219bc19a2c Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 18 Jun 2025 16:54:06 +0200 Subject: [PATCH 099/123] Update style testing with ruff --- docs/CONTRIBUTING.md | 42 +++++++++++++++--------------------------- 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 7778030dc..e9918e14c 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -179,13 +179,13 @@ docker run -t -d \ To find the container IDs of running containers ```bash -docker ps +docker ps ``` To see the logging output ```bash -docker logs +docker logs ``` To enter a bash session in the container @@ -209,7 +209,7 @@ docker run -t -d \ --gpus all \ --ipc=host \ \ ---keep_container_alive true +--keep_container_alive true ``` ## Submitting PRs @@ -222,38 +222,26 @@ We run tests with GitHub Actions, configured in the [.github/workflows](.github/ ### Style Testing -We run yapf and linting tests on PRs. You can view and fix offending errors with these instructions. - +We run formatting and linting tests via ruff on PRs. You can view and fix offending errors with these instructions. To run the below commands, use the versions installed via `pip install -e '.[dev]'`. -To automatically fix formatting errors, run the following (*WARNING:* this will edit your code, so it is suggested to make a git commit first!): +To check whether your code is **formatted** correctly, run the following: ```bash -yapf -i -r -vv -p algoperf datasets prize_qualification_baselines reference_algorithms tests *.py +ruff format --check ``` -To sort all import orderings, run the following: +To automatically fix formatting errors you can run `ruff format`, without the `--check` flag. +(**WARNING**: this will edit your code, so it is suggested to make a git commit first!) -```bash -isort . -``` - -To just print out all offending import orderings, run the following: +To check whether your code is **linted** correctly, run the following: ```bash -isort . --check --diff +ruff check ``` -To print out all offending pylint issues, run the following: - -```bash -pylint algoperf -pylint datasets -pylint prize_qualification_baselines -pylint reference_algorithms -pylint submission_runner.py -pylint tests -``` +To automatically fix linting errors you can run `ruff check --fix`, with the additional `--fix` flag. +(**WARNING**: this will edit your code, so it is suggested to make a git commit first!) ### Unit and Integration Tests @@ -270,9 +258,9 @@ To run a regression test: 1. Build and upload latest Docker images from dev branch. - ```bash - bash ~/algorithmic-efficiency/docker/build_docker_images.sh -b dev - ``` + ```bash + bash ~/algorithmic-efficiency/docker/build_docker_images.sh -b dev + ``` 2. Turn on the self-hosted runner. 3. Run the self-hosted runner application for the runner to accept jobs. From 830d2c2506e2f85d737e0cb9ce4f433b7d898923 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:16:03 +0200 Subject: [PATCH 100/123] Format submission_runner --- submission_runner.py | 772 ++++++++++++++++++++++++------------------- 1 file changed, 426 insertions(+), 346 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 221a7c21d..8dc8589cb 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -17,41 +17,37 @@ import datetime import gc import importlib -from inspect import signature import itertools import json import os import struct import time +from inspect import signature from types import MappingProxyType from typing import Any, Dict, Optional, Tuple -from absl import app -from absl import flags -from absl import logging import jax import tensorflow as tf import torch import torch.distributed as dist +from absl import app, flags, logging # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') -from algoperf import checkpoint_utils -from algoperf import halton -from algoperf import logger_utils -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.profiler import Profiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.pytorch_utils import sync_ddp_time -from algoperf.workloads import workloads +from algoperf import checkpoint_utils, halton, logger_utils, spec # noqa: E402 +from algoperf import random_utils as prng # noqa: E402 +from algoperf.profiler import PassThroughProfiler, Profiler # noqa: E402 +from algoperf.pytorch_utils import ( # noqa: E402 + pytorch_init, + pytorch_setup, + sync_ddp_time, +) +from algoperf.workloads import workloads # noqa: E402 # Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disables tensorRT, cuda warnings. # disable only for deepspeech if it works fine for other workloads os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' @@ -62,106 +58,121 @@ WORKLOADS = workloads.WORKLOADS flags.DEFINE_string( - 'submission_path', - None, - 'The relative path of the Python file containing submission functions. ' - 'NOTE: the submission dir must have an __init__.py file!') + 'submission_path', + None, + 'The relative path of the Python file containing submission functions. ' + 'NOTE: the submission dir must have an __init__.py file!', +) flags.DEFINE_string( - 'workload', - None, - help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}' + 'workload', + None, + help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}', ) flags.DEFINE_enum( - 'tuning_ruleset', - 'external', - enum_values=['external', 'self'], - help='Which tuning ruleset to use.') + 'tuning_ruleset', + 'external', + enum_values=['external', 'self'], + help='Which tuning ruleset to use.', +) flags.DEFINE_string( - 'tuning_search_space', - None, - 'The path to the JSON file describing the external tuning search space.') -flags.DEFINE_integer('num_tuning_trials', - 1, - 'The number of external hyperparameter trials to run.') + 'tuning_search_space', + None, + 'The path to the JSON file describing the external tuning search space.', +) +flags.DEFINE_integer( + 'num_tuning_trials', 1, 'The number of external hyperparameter trials to run.' +) flags.DEFINE_string('data_dir', '~/data', 'Dataset location.') -flags.DEFINE_string('imagenet_v2_data_dir', - None, - 'Dataset location for ImageNet-v2.') -flags.DEFINE_string('librispeech_tokenizer_vocab_path', - '', - 'Location to librispeech tokenizer.') +flags.DEFINE_string( + 'imagenet_v2_data_dir', None, 'Dataset location for ImageNet-v2.' +) +flags.DEFINE_string( + 'librispeech_tokenizer_vocab_path', '', 'Location to librispeech tokenizer.' +) flags.DEFINE_enum( - 'framework', - None, - enum_values=['jax', 'pytorch'], - help='Whether to use Jax or Pytorch for the submission. Controls among ' - 'other things if the Jax or Numpy RNG library is used for RNG.') + 'framework', + None, + enum_values=['jax', 'pytorch'], + help='Whether to use Jax or Pytorch for the submission. Controls among ' + 'other things if the Jax or Numpy RNG library is used for RNG.', +) flags.DEFINE_boolean( - 'torch_compile', - True, - 'Whether to use `torch.compile` to JIT-compile PyTorch code. ' - 'This will only take effect when `framework`==pytorch.') + 'torch_compile', + True, + 'Whether to use `torch.compile` to JIT-compile PyTorch code. ' + 'This will only take effect when `framework`==pytorch.', +) flags.DEFINE_string( - 'experiment_dir', - None, - 'The root directory to store all experiments. ' - 'It is required and the directory should have ' - 'an absolute path rather than a relative path.') + 'experiment_dir', + None, + 'The root directory to store all experiments. ' + 'It is required and the directory should have ' + 'an absolute path rather than a relative path.', +) flags.DEFINE_string('experiment_name', None, 'Name of the experiment.') flags.DEFINE_boolean( - 'save_checkpoints', - True, - 'Whether or not to save checkpoints of the model and optimizer ' - 'at every eval and after training.') + 'save_checkpoints', + True, + 'Whether or not to save checkpoints of the model and optimizer ' + 'at every eval and after training.', +) +flags.DEFINE_boolean( + 'save_intermediate_checkpoints', + True, + 'Whether to save any intermediate checkpoints. ' + 'If False, it will only keep the latest checkpoint.', +) flags.DEFINE_boolean( - 'save_intermediate_checkpoints', - True, - 'Whether to save any intermediate checkpoints. ' - 'If False, it will only keep the latest checkpoint.') -flags.DEFINE_boolean('resume_last_run', - None, - 'Whether to resume the experiment from its last run.') + 'resume_last_run', None, 'Whether to resume the experiment from its last run.' +) flags.DEFINE_boolean( - 'append_timestamp', - False, - 'If True, the current datetime will be appended to the experiment name. ' - 'Useful for guaranteeing a unique experiment dir for new runs.') -flags.DEFINE_boolean('use_wandb', - False, - 'Whether to use Weights & Biases logging.') + 'append_timestamp', + False, + 'If True, the current datetime will be appended to the experiment name. ' + 'Useful for guaranteeing a unique experiment dir for new runs.', +) +flags.DEFINE_boolean( + 'use_wandb', False, 'Whether to use Weights & Biases logging.' +) flags.DEFINE_boolean('profile', False, 'Whether to produce profiling output.') -flags.DEFINE_integer('max_global_steps', - None, - 'Maximum number of update steps.') +flags.DEFINE_integer( + 'max_global_steps', None, 'Maximum number of update steps.' +) flags.DEFINE_boolean( - 'overwrite', - False, - 'Whether to overwrite the experiment with identical experiment_dir and' - 'experiment_name.') + 'overwrite', + False, + 'Whether to overwrite the experiment with identical experiment_dir and' + 'experiment_name.', +) flags.DEFINE_integer( - 'hparam_start_index', - None, - 'Start index to slice set of hyperparameters in tuning search space.') + 'hparam_start_index', + None, + 'Start index to slice set of hyperparameters in tuning search space.', +) flags.DEFINE_integer( - 'hparam_end_index', - None, - 'End index to slice set of hyperparameters in tuning search space.') + 'hparam_end_index', + None, + 'End index to slice set of hyperparameters in tuning search space.', +) flags.DEFINE_integer( - 'rng_seed', - None, - 'Value of rng seed. If None, a random seed will' - 'be generated from hardware.') -flags.DEFINE_boolean('set_pytorch_max_split_size', - False, - 'If true, set pytorch max_split_size_mb to 256') + 'rng_seed', + None, + 'Value of rng seed. If None, a random seed willbe generated from hardware.', +) +flags.DEFINE_boolean( + 'set_pytorch_max_split_size', + False, + 'If true, set pytorch max_split_size_mb to 256', +) flags.DEFINE_integer( - 'pytorch_eval_num_workers', - 0, - 'Number of workers for ImageNet PyTorch evaluation data loaders.' - 'WARNING: Setting pytorch_eval_num_workers != 0, will result ' - 'in incorrect evals currently, see issues/732.') + 'pytorch_eval_num_workers', + 0, + 'Number of workers for ImageNet PyTorch evaluation data loaders.' + 'WARNING: Setting pytorch_eval_num_workers != 0, will result ' + 'in incorrect evals currently, see issues/732.', +) FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -193,23 +204,23 @@ def _reset_cuda_mem(): def train_once( - workload: spec.Workload, - workload_name: str, - global_batch_size: int, - global_eval_batch_size: int, - data_dir: str, - imagenet_v2_data_dir: str, - init_optimizer_state: spec.InitOptimizerFn, - update_params: spec.UpdateParamsFn, - data_selection: spec.DataSelectionFn, - prepare_for_eval: Optional[spec.PrepareForEvalFn], - hyperparameters: Optional[spec.Hyperparameters], - rng_seed: int, - rng: spec.RandomState, - profiler: Profiler, - max_global_steps: int = None, - log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True + workload: spec.Workload, + workload_name: str, + global_batch_size: int, + global_eval_batch_size: int, + data_dir: str, + imagenet_v2_data_dir: str, + init_optimizer_state: spec.InitOptimizerFn, + update_params: spec.UpdateParamsFn, + data_selection: spec.DataSelectionFn, + prepare_for_eval: Optional[spec.PrepareForEvalFn], + hyperparameters: Optional[spec.Hyperparameters], + rng_seed: int, + rng: spec.RandomState, + profiler: Profiler, + max_global_steps: int = None, + log_dir: Optional[str] = None, + save_checkpoints: Optional[bool] = True, ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) @@ -222,40 +233,44 @@ def train_once( workload.eval_num_workers = FLAGS.pytorch_eval_num_workers with profiler.profile('Initializing dataset'): input_queue = workload._build_input_queue( - data_rng, - 'train', - data_dir=data_dir, - global_batch_size=global_batch_size) + data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size + ) logging.info('Initializing model.') with profiler.profile('Initializing model'): model_params, model_state = workload.init_model_fn(model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ - 'librispeech_conformer', - 'ogbg', - 'criteo1tb', - 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_conformer', + 'ogbg', + 'criteo1tb', + 'imagenet_vit', + 'librispeech_deepspeech', ] eager_backend_workloads = [] aot_eager_backend_workloads = [] loss_compilation_workloads = [ - 'fastmri', 'librispeech_deepspeech', 'ogbg', 'wmt' + 'fastmri', + 'librispeech_deepspeech', + 'ogbg', + 'wmt', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: logging.warning( - 'These workloads cannot be fully compiled under current ' - 'PyTorch version. Proceeding without `torch.compile`.') + 'These workloads cannot be fully compiled under current ' + 'PyTorch version. Proceeding without `torch.compile`.' + ) elif base_workload in eager_backend_workloads: logging.warning( - 'These workloads cannot be fully compiled under current ' - 'PyTorch version. Proceeding with `backend=eager`.') + 'These workloads cannot be fully compiled under current ' + 'PyTorch version. Proceeding with `backend=eager`.' + ) model_params = torch.compile(model_params, backend='eager') elif base_workload in aot_eager_backend_workloads: logging.warning( - 'These workloads cannot be fully compiled under current ' - 'PyTorch version. Proceeding with `backend=aot_eager`.') + 'These workloads cannot be fully compiled under current ' + 'PyTorch version. Proceeding with `backend=aot_eager`.' + ) model_params = torch.compile(model_params, backend='aot_eager') else: logging.info('Performing `torch.compile`.') @@ -264,11 +279,9 @@ def train_once( workload.loss_fn = torch.compile(workload.loss_fn) logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): - optimizer_state = init_optimizer_state(workload, - model_params, - model_state, - hyperparameters, - opt_init_rng) + optimizer_state = init_optimizer_state( + workload, model_params, model_state, hyperparameters, opt_init_rng + ) logging.info('Initializing metrics bundle.') # Check if 'train_state' is in the function signature @@ -276,15 +289,15 @@ def train_once( # Bookkeeping. train_state = { - 'validation_goal_reached': False, - 'test_goal_reached': False, - 'is_time_remaining': True, - 'last_eval_time': 0, - 'training_complete': False, - 'accumulated_submission_time': 0, - 'accumulated_eval_time': 0, - 'accumulated_logging_time': 0, - 'last_step_end_time': None, + 'validation_goal_reached': False, + 'test_goal_reached': False, + 'is_time_remaining': True, + 'last_eval_time': 0, + 'training_complete': False, + 'accumulated_submission_time': 0, + 'accumulated_eval_time': 0, + 'accumulated_logging_time': 0, + 'last_step_end_time': None, } global_step = 0 eval_results = [] @@ -294,22 +307,25 @@ def train_once( logging.info('Initializing checkpoint and logger.') if log_dir is not None: # If the checkpoint exists, load from the checkpoint. - (optimizer_state, - model_params, - model_state, - train_state, - eval_results, - global_step, - preemption_count) = checkpoint_utils.maybe_restore_checkpoint( - FLAGS.framework, - optimizer_state, - model_params, - model_state, - train_state, - eval_results, - global_step, - preemption_count, - checkpoint_dir=log_dir) + ( + optimizer_state, + model_params, + model_state, + train_state, + eval_results, + global_step, + preemption_count, + ) = checkpoint_utils.maybe_restore_checkpoint( + FLAGS.framework, + optimizer_state, + model_params, + model_state, + train_state, + eval_results, + global_step, + preemption_count, + checkpoint_dir=log_dir, + ) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') meta_data = logger_utils.get_meta_data(workload, rng_seed) @@ -319,9 +335,9 @@ def train_once( logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) metrics_logger = None if RANK == 0: - metrics_logger = logger_utils.set_up_loggers(log_dir, - flags.FLAGS, - hyperparameters) + metrics_logger = logger_utils.set_up_loggers( + log_dir, flags.FLAGS, hyperparameters + ) workload.attach_metrics_logger(metrics_logger) global_start_time = get_time() @@ -329,42 +345,50 @@ def train_once( logging.info('Starting training loop.') goals_reached = ( - train_state['validation_goal_reached'] and - train_state['test_goal_reached']) - while train_state['is_time_remaining'] and \ - not goals_reached and \ - not train_state['training_complete']: - + train_state['validation_goal_reached'] and train_state['test_goal_reached'] + ) + while ( + train_state['is_time_remaining'] + and not goals_reached + and not train_state['training_complete'] + ): step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, prep_eval_rng, eval_rng = \ - prng.split(step_rng, 4) + data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split( + step_rng, 4 + ) with profiler.profile('Data selection'): - batch = data_selection(workload, - input_queue, - optimizer_state, - model_params, - model_state, - hyperparameters, - global_step, - data_select_rng) + batch = data_selection( + workload, + input_queue, + optimizer_state, + model_params, + model_state, + hyperparameters, + global_step, + data_select_rng, + ) try: with profiler.profile('Update parameters'): optimizer_state, model_params, model_state = update_params( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types, - model_state=model_state, - hyperparameters=hyperparameters, - batch=batch, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - eval_results=eval_results, - global_step=global_step, - rng=update_rng, - **({'train_state': MappingProxyType(train_state)} - if needs_train_state else {})) + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + batch=batch, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=update_rng, + **( + {'train_state': MappingProxyType(train_state)} + if needs_train_state + else {} + ), + ) except spec.TrainingCompleteError: train_state['training_complete'] = True global_step += 1 @@ -374,121 +398,139 @@ def train_once( train_step_end_time = get_time() train_state['accumulated_submission_time'] += ( - train_step_end_time - train_state['last_step_end_time']) + train_step_end_time - train_state['last_step_end_time'] + ) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): - + if ( + train_step_end_time - train_state['last_eval_time'] + ) >= workload.eval_period_time_sec or train_state['training_complete']: # Prepare for evaluation (timed). if prepare_for_eval is not None: - with profiler.profile('Prepare for eval'): del batch prepare_for_eval_start_time = get_time() optimizer_state, model_params, model_state = prepare_for_eval( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types, - model_state=model_state, - hyperparameters=hyperparameters, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - eval_results=eval_results, - global_step=global_step, - rng=prep_eval_rng) + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=prep_eval_rng, + ) prepare_for_eval_end_time = get_time() # Update sumbission time. train_state['accumulated_submission_time'] += ( - prepare_for_eval_end_time - prepare_for_eval_start_time) + prepare_for_eval_end_time - prepare_for_eval_start_time + ) # Check if time is remaining, # use 1.5x the runtime budget for the self-tuning ruleset. max_allowed_runtime_sec = ( - workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' - else 1.5 * workload.max_allowed_runtime_sec) + workload.max_allowed_runtime_sec + if FLAGS.tuning_ruleset == 'external' + else 1.5 * workload.max_allowed_runtime_sec + ) train_state['is_time_remaining'] = ( - train_state['accumulated_submission_time'] < max_allowed_runtime_sec) + train_state['accumulated_submission_time'] < max_allowed_runtime_sec + ) # Eval if time is remaining (untimed). if train_state['is_time_remaining']: - with profiler.profile('Evaluation'): _reset_cuda_mem() try: eval_start_time = get_time() - latest_eval_result = workload.eval_model(global_eval_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir, - global_step) + latest_eval_result = workload.eval_model( + global_eval_batch_size, + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir, + global_step, + ) # Check if targets reached. # Note that this is one of the stopping conditions for the length of # a training run. To score the run we only consider the time # to validation target retrospectively. train_state['validation_goal_reached'] = ( - workload.has_reached_validation_target(latest_eval_result) or - train_state['validation_goal_reached']) + workload.has_reached_validation_target(latest_eval_result) + or train_state['validation_goal_reached'] + ) train_state['test_goal_reached'] = ( - workload.has_reached_test_target(latest_eval_result) or - train_state['test_goal_reached']) + workload.has_reached_test_target(latest_eval_result) + or train_state['test_goal_reached'] + ) goals_reached = ( - train_state['validation_goal_reached'] and - train_state['test_goal_reached']) + train_state['validation_goal_reached'] + and train_state['test_goal_reached'] + ) # Save last eval time. eval_end_time = get_time() train_state['last_eval_time'] = eval_end_time # Accumulate eval time. - train_state[ - 'accumulated_eval_time'] += eval_end_time - eval_start_time + train_state['accumulated_eval_time'] += ( + eval_end_time - eval_start_time + ) # Add times to eval results for logging. - latest_eval_result['score'] = ( - train_state['accumulated_submission_time']) - latest_eval_result[ - 'total_duration'] = eval_end_time - global_start_time + latest_eval_result['score'] = train_state[ + 'accumulated_submission_time' + ] + latest_eval_result['total_duration'] = ( + eval_end_time - global_start_time + ) latest_eval_result['accumulated_submission_time'] = train_state[ - 'accumulated_submission_time'] + 'accumulated_submission_time' + ] latest_eval_result['accumulated_eval_time'] = train_state[ - 'accumulated_eval_time'] + 'accumulated_eval_time' + ] latest_eval_result['accumulated_logging_time'] = train_state[ - 'accumulated_logging_time'] + 'accumulated_logging_time' + ] time_since_start = latest_eval_result['total_duration'] - logging.info(f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}') + logging.info( + f'Time since start: {time_since_start:.2f}s, ' + f'\tStep: {global_step}, \t{latest_eval_result}' + ) eval_results.append((global_step, latest_eval_result)) logging_start_time = get_time() if log_dir is not None and RANK == 0: metrics_logger.append_scalar_metrics( - latest_eval_result, - global_step=global_step, - preemption_count=preemption_count, - is_eval=True, + latest_eval_result, + global_step=global_step, + preemption_count=preemption_count, + is_eval=True, ) if save_checkpoints: checkpoint_utils.save_checkpoint( - framework=FLAGS.framework, - optimizer_state=optimizer_state, - model_params=model_params, - model_state=model_state, - train_state=train_state, - eval_results=eval_results, - global_step=global_step, - preemption_count=preemption_count, - checkpoint_dir=log_dir, - save_intermediate_checkpoints=FLAGS - .save_intermediate_checkpoints) + framework=FLAGS.framework, + optimizer_state=optimizer_state, + model_params=model_params, + model_state=model_state, + train_state=train_state, + eval_results=eval_results, + global_step=global_step, + preemption_count=preemption_count, + checkpoint_dir=log_dir, + save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints, + ) logging_end_time = get_time() train_state['accumulated_logging_time'] += ( - logging_end_time - logging_start_time) + logging_end_time - logging_start_time + ) _reset_cuda_mem() @@ -496,8 +538,9 @@ def train_once( logging.exception(f'Eval step {global_step} error.\n') if 'out of memory' in str(e): logging.warning( - 'Error: GPU out of memory during eval during step ' - f'{global_step}, error : {str(e)}.') + 'Error: GPU out of memory during eval during step ' + f'{global_step}, error : {str(e)}.' + ) _reset_cuda_mem() train_state['last_step_end_time'] = get_time() @@ -506,41 +549,45 @@ def train_once( if log_dir is not None and RANK == 0: metrics_logger.append_scalar_metrics( - {'score': train_state['accumulated_submission_time']}, - global_step=global_step, - preemption_count=preemption_count) + {'score': train_state['accumulated_submission_time']}, + global_step=global_step, + preemption_count=preemption_count, + ) metrics_logger.finish() if save_checkpoints: checkpoint_utils.save_checkpoint( - framework=FLAGS.framework, - optimizer_state=optimizer_state, - model_params=model_params, - model_state=model_state, - train_state=train_state, - eval_results=eval_results, - global_step=global_step, - preemption_count=preemption_count, - checkpoint_dir=log_dir, - save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints) + framework=FLAGS.framework, + optimizer_state=optimizer_state, + model_params=model_params, + model_state=model_state, + train_state=train_state, + eval_results=eval_results, + global_step=global_step, + preemption_count=preemption_count, + checkpoint_dir=log_dir, + save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints, + ) return train_state['accumulated_submission_time'], metrics -def score_submission_on_workload(workload: spec.Workload, - workload_name: str, - submission_path: str, - data_dir: str, - tuning_ruleset: str, - profiler: Optional[Profiler] = None, - max_global_steps: Optional[int] = None, - imagenet_v2_data_dir: Optional[str] = None, - tuning_search_space: Optional[str] = None, - num_tuning_trials: Optional[int] = None, - log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True, - hparam_start_index: Optional[bool] = None, - hparam_end_index: Optional[bool] = None, - rng_seed: Optional[int] = None): +def score_submission_on_workload( + workload: spec.Workload, + workload_name: str, + submission_path: str, + data_dir: str, + tuning_ruleset: str, + profiler: Optional[Profiler] = None, + max_global_steps: Optional[int] = None, + imagenet_v2_data_dir: Optional[str] = None, + tuning_search_space: Optional[str] = None, + num_tuning_trials: Optional[int] = None, + log_dir: Optional[str] = None, + save_checkpoints: Optional[bool] = True, + hparam_start_index: Optional[bool] = None, + hparam_end_index: Optional[bool] = None, + rng_seed: Optional[int] = None, +): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -564,18 +611,21 @@ def score_submission_on_workload(workload: spec.Workload, n_gpus = max(N_GPUS, jax.local_device_count()) if global_batch_size % n_gpus != 0: raise ValueError( - f'The global batch size ({global_batch_size}) has to be divisible by ' - f'the number of GPUs ({n_gpus}).') + f'The global batch size ({global_batch_size}) has to be divisible by ' + f'the number of GPUs ({n_gpus}).' + ) if hasattr(submission_module, 'get_eval_batch_size'): # If the user specifies the eval batch size, use the provided one. global_eval_batch_size = submission_module.get_eval_batch_size( - workload_name) + workload_name + ) else: global_eval_batch_size = workload.eval_batch_size if global_eval_batch_size % n_gpus != 0: raise ValueError( - f'The global eval batch size ({global_eval_batch_size}) has to be ' - f'divisible by the number of GPUs ({n_gpus}).') + f'The global eval batch size ({global_eval_batch_size}) has to be ' + f'divisible by the number of GPUs ({n_gpus}).' + ) if tuning_ruleset == 'external': # If the submission runner is responsible for hyperparameter tuning, load in @@ -583,15 +633,18 @@ def score_submission_on_workload(workload: spec.Workload, # settings from it. if tuning_search_space is None: raise ValueError( - 'Must provide a tuning search space JSON file when using external ' - 'tuning.') + 'Must provide a tuning search space JSON file when using external ' + 'tuning.' + ) with open(tuning_search_space, 'r', encoding='UTF-8') as search_space_file: tuning_search_space = halton.generate_search( - json.load(search_space_file), num_tuning_trials) + json.load(search_space_file), num_tuning_trials + ) all_timings = {} all_metrics = {} tuning_search_space_iter = itertools.islice( - enumerate(tuning_search_space), hparam_start_index, hparam_end_index) + enumerate(tuning_search_space), hparam_start_index, hparam_end_index + ) for hi, hyperparameters in tuning_search_space_iter: # Generate a new seed from hardware sources of randomness for each trial. if not rng_seed: @@ -615,25 +668,31 @@ def score_submission_on_workload(workload: spec.Workload, # If existing hyperparameter exists, use saved # hyperparameters for consistency. - hyperparameters = logger_utils.write_hparams(hyperparameters, - tuning_dir_name) + hyperparameters = logger_utils.write_hparams( + hyperparameters, tuning_dir_name + ) tuning_search_space[hi] = hyperparameters with profiler.profile('Train'): - timing, metrics = train_once(workload, workload_name, - global_batch_size, - global_eval_batch_size, - data_dir, imagenet_v2_data_dir, - init_optimizer_state, - update_params, data_selection, - prepare_for_eval, - hyperparameters, - rng_seed, - rng, - profiler, - max_global_steps, - tuning_dir_name, - save_checkpoints=save_checkpoints,) + timing, metrics = train_once( + workload, + workload_name, + global_batch_size, + global_eval_batch_size, + data_dir, + imagenet_v2_data_dir, + init_optimizer_state, + update_params, + data_selection, + prepare_for_eval, + hyperparameters, + rng_seed, + rng, + profiler, + max_global_steps, + tuning_dir_name, + save_checkpoints=save_checkpoints, + ) all_timings[hi] = timing all_metrics[hi] = metrics logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}') @@ -647,7 +706,8 @@ def score_submission_on_workload(workload: spec.Workload, else: if tuning_search_space is not None: raise ValueError( - 'Cannot provide a tuning search space when using self tuning.') + 'Cannot provide a tuning search space when using self tuning.' + ) if not rng_seed: rng_seed = struct.unpack('q', os.urandom(8))[0] rng = prng.PRNGKey(rng_seed) @@ -659,11 +719,24 @@ def score_submission_on_workload(workload: spec.Workload, logger_utils.makedir(log_dir) with profiler.profile('Train'): score, _ = train_once( - workload, workload_name, global_batch_size, global_eval_batch_size, - data_dir, imagenet_v2_data_dir, - init_optimizer_state, update_params, data_selection, prepare_for_eval, - None, rng_seed, rng, profiler, max_global_steps, log_dir, - save_checkpoints=save_checkpoints) + workload, + workload_name, + global_batch_size, + global_eval_batch_size, + data_dir, + imagenet_v2_data_dir, + init_optimizer_state, + update_params, + data_selection, + prepare_for_eval, + None, + rng_seed, + rng, + profiler, + max_global_steps, + log_dir, + save_checkpoints=save_checkpoints, + ) return score @@ -687,59 +760,66 @@ def main(_): # TODO: remove once issue resolved. if FLAGS.pytorch_eval_num_workers != 0: logging.warning( - 'WARNING: Setting pytorch_eval_num_workers != 0, will result ' - 'in incorrect evals currently, see issues/732.') + 'WARNING: Setting pytorch_eval_num_workers != 0, will result ' + 'in incorrect evals currently, see issues/732.' + ) workload_metadata = WORKLOADS[FLAGS.workload] if base_workload in [ - 'librispeech_conformer', - 'librispeech_deepspeech', - 'imagenet_vit', - 'criteo1tb' + 'librispeech_conformer', + 'librispeech_deepspeech', + 'imagenet_vit', + 'criteo1tb', ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( - BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + f'_{FLAGS.framework}', - 'workload.py') + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'_{FLAGS.framework}', + 'workload.py', + ) workload_init_kwargs = {} if FLAGS.librispeech_tokenizer_vocab_path: workload_init_kwargs['tokenizer_vocab_path'] = ( - FLAGS.librispeech_tokenizer_vocab_path) + FLAGS.librispeech_tokenizer_vocab_path + ) workload = workloads.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - workload_init_kwargs=workload_init_kwargs) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) experiment_name = FLAGS.experiment_name if experiment_name and FLAGS.append_timestamp: experiment_name += datetime.datetime.now().strftime('-%Y-%m-%d-%H-%M-%S') - logging_dir_path = logger_utils.get_log_dir(FLAGS.experiment_dir, - FLAGS.workload, - FLAGS.framework, - experiment_name, - FLAGS.resume_last_run, - FLAGS.overwrite) + logging_dir_path = logger_utils.get_log_dir( + FLAGS.experiment_dir, + FLAGS.workload, + FLAGS.framework, + experiment_name, + FLAGS.resume_last_run, + FLAGS.overwrite, + ) score = score_submission_on_workload( - workload=workload, - workload_name=FLAGS.workload, - submission_path=FLAGS.submission_path, - data_dir=FLAGS.data_dir, - tuning_ruleset=FLAGS.tuning_ruleset, - profiler=profiler, - max_global_steps=FLAGS.max_global_steps, - imagenet_v2_data_dir=FLAGS.imagenet_v2_data_dir, - tuning_search_space=FLAGS.tuning_search_space, - num_tuning_trials=FLAGS.num_tuning_trials, - log_dir=logging_dir_path, - save_checkpoints=FLAGS.save_checkpoints, - hparam_start_index=FLAGS.hparam_start_index, - hparam_end_index=FLAGS.hparam_end_index, - rng_seed=FLAGS.rng_seed) + workload=workload, + workload_name=FLAGS.workload, + submission_path=FLAGS.submission_path, + data_dir=FLAGS.data_dir, + tuning_ruleset=FLAGS.tuning_ruleset, + profiler=profiler, + max_global_steps=FLAGS.max_global_steps, + imagenet_v2_data_dir=FLAGS.imagenet_v2_data_dir, + tuning_search_space=FLAGS.tuning_search_space, + num_tuning_trials=FLAGS.num_tuning_trials, + log_dir=logging_dir_path, + save_checkpoints=FLAGS.save_checkpoints, + hparam_start_index=FLAGS.hparam_start_index, + hparam_end_index=FLAGS.hparam_end_index, + rng_seed=FLAGS.rng_seed, + ) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: From d84eddf2bcf6e4db1f7b47a13ea11affc5192f16 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:16:12 +0200 Subject: [PATCH 101/123] Format submissions/ --- submissions/submission_checker.py | 38 ++++---- submissions/template/submission.py | 138 +++++++++++++++-------------- 2 files changed, 95 insertions(+), 81 deletions(-) diff --git a/submissions/submission_checker.py b/submissions/submission_checker.py index ab657c0f0..fcc4e1faf 100644 --- a/submissions/submission_checker.py +++ b/submissions/submission_checker.py @@ -41,7 +41,8 @@ def _check_ruleset_subdirs(submission_dir): contents = os.listdir(submission_dir) if not ((EXTERNAL_TUNING in contents) or (SELF_TUNING in contents)): logging.info( - f'CHECK FAILED: {submission_dir} does not contain ruleset subdir.') + f'CHECK FAILED: {submission_dir} does not contain ruleset subdir.' + ) return False return True @@ -54,7 +55,7 @@ def _check_submission_module(submission_dir): contents = os.listdir(os.path.join(root, submission_dir)) if SUBMISSION_MODULE not in contents: logging.info( - f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {SUBMISSION_MODULE}' + f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {SUBMISSION_MODULE}' ) return False return True @@ -68,7 +69,7 @@ def _check_tuning_search_space_file(submission_dir): contents = os.listdir(os.path.join(root, submission_dir)) if TUNING_SEARCH_SPACE_FILENAME not in contents: logging.info( - f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {TUNING_SEARCH_SPACE_FILENAME}' + f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {TUNING_SEARCH_SPACE_FILENAME}' ) return False return True @@ -76,18 +77,22 @@ def _check_tuning_search_space_file(submission_dir): def run_checks(submission_dir): """Top-level checker function. - Call individual checkers from this function. - """ + Call individual checkers from this function. + """ logging.info('Running repository checks.') # Execute checks contains_ruleset_subdirs = _check_ruleset_subdirs(submission_dir) contains_submission_module = _check_submission_module(submission_dir) contains_tuning_search_space_file = _check_tuning_search_space_file( - submission_dir) + submission_dir + ) - if not (contains_ruleset_subdirs and contains_submission_module and - contains_tuning_search_space_file): + if not ( + contains_ruleset_subdirs + and contains_submission_module + and contains_tuning_search_space_file + ): logging.info('TESTS FAILED.') return False @@ -98,16 +103,17 @@ def run_checks(submission_dir): def get_parser(): """Parse commandline.""" parser = argparse.ArgumentParser( - description='Checks for submission folder for AlgoPerf',) + description='Checks for submission folder for AlgoPerf', + ) parser.add_argument( - 'folder', - type=str, - help='the folder for a submission package.', + 'folder', + type=str, + help='the folder for a submission package.', ) parser.add_argument( - '--log_output', - type=str, - default='submission_checker.log', + '--log_output', + type=str, + default='submission_checker.log', ) return parser @@ -118,7 +124,7 @@ def main(): logging.basicConfig(filename=args.log_output, level=logging.INFO) logging.getLogger().addHandler(logging.StreamHandler()) - formatter = logging.Formatter("%(levelname)s - %(message)s") + formatter = logging.Formatter('%(levelname)s - %(message)s') logging.getLogger().handlers[0].setFormatter(formatter) logging.getLogger().handlers[1].setFormatter(formatter) diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 2269b7dbb..db6900afd 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,94 +4,102 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. """ + from typing import Any, Dict, Iterator, List, Optional, Tuple from algoperf import spec -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule. - Returns: spec.OptimizerState initialized optimizer state - """ + Returns: spec.OptimizerState initialized optimizer state + """ pass def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """ + Returns: + spec.OptimizerState: new optimizer state + spec.ParameterTypeTree: new params + new_model_state: new model state """ - Returns: - spec.OptimizerState: new optimizer state - spec.ParameterTypeTree: new params - new_model_state: new model state - """ pass -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """ + Returns: + new_optimizer_state + new_params + new_model_state """ - Returns: - new_optimizer_state - new_params - new_model_state - """ pass def get_batch_size(workload_name): """ - Gets batch size for workload. - Note that these batch sizes only apply during training and not during evals. - Args: - workload_name (str): Valid workload_name values are: "wmt", "ogbg", - "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit", - "librispeech_deepspeech", "librispeech_conformer" or any of the - variants. - Returns: - int: batch_size - Raises: - ValueError: If workload_name is not handled. - """ + Gets batch size for workload. + Note that these batch sizes only apply during training and not during evals. + Args: + workload_name (str): Valid workload_name values are: "wmt", "ogbg", + "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit", + "librispeech_deepspeech", "librispeech_conformer" or any of the + variants. + Returns: + int: batch_size + Raises: + ValueError: If workload_name is not handled. + """ pass -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - Tip: - If you would just like the next batch from the input queue return next(input_queue). + Each element of the queue is a batch of training examples and labels. + Tip: + If you would just like the next batch from the input queue return next(input_queue). - Returns: - batch: next batch of input data - """ + Returns: + batch: next batch of input data + """ pass From fbbeafa65541419f8cc1879948d112d34539fb5c Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:17:59 +0200 Subject: [PATCH 102/123] Format scoring/ --- .../generate_held_out_workloads.py | 70 +++--- scoring/algoperf_v05/score_submissions.py | 197 ++++++++------- scoring/compute_speedups.py | 71 +++--- scoring/performance_profile.py | 218 +++++++++------- scoring/score_submissions.py | 190 ++++++++------ scoring/scoring_utils.py | 169 ++++++------- scoring/test_performance_profile.py | 4 +- scoring/test_scoring_utils.py | 4 +- scoring/utils/package_logs.py | 12 +- scoring/utils/run_workloads.py | 232 ++++++++++-------- scoring/utils/slurm/README.md | 41 ++-- .../utils/slurm/algoperf_slurm_cluster.yaml | 130 +++++----- .../slurm/algoperf_slurm_pakcer_builder.yaml | 169 +++++++------ scoring/utils/slurm/config.json | 210 ++++++++-------- scoring/utils/slurm/make_job_config.py | 72 +++--- .../workload_metadata_external_tuning.json | 66 ++--- .../utils/workload_metadata_self_tuning.json | 66 ++--- 17 files changed, 1030 insertions(+), 891 deletions(-) diff --git a/scoring/algoperf_v05/generate_held_out_workloads.py b/scoring/algoperf_v05/generate_held_out_workloads.py index 647dc3c3d..e9ebf6a53 100644 --- a/scoring/algoperf_v05/generate_held_out_workloads.py +++ b/scoring/algoperf_v05/generate_held_out_workloads.py @@ -2,49 +2,51 @@ import os import struct -from absl import app -from absl import flags -from absl import logging import numpy as np +from absl import app, flags, logging flags.DEFINE_integer( - 'held_out_workloads_seed', - None, - 'Random seed for scoring.' - 'AlgoPerf v0.5 seed: 3438810845') -flags.DEFINE_string('output_filename', - 'held_out_workloads.json', - 'Path to file to record sampled held_out workloads.') + 'held_out_workloads_seed', + None, + 'Random seed for scoring.AlgoPerf v0.5 seed: 3438810845', +) +flags.DEFINE_string( + 'output_filename', + 'held_out_workloads.json', + 'Path to file to record sampled held_out workloads.', +) FLAGS = flags.FLAGS HELD_OUT_WORKLOADS = { - 'librispeech': [ - 'librispeech_conformer_attention_temperature', - 'librispeech_conformer_layernorm', - # 'librispeech_conformer_gelu', # Removed due to bug in target setting procedure - 'librispeech_deepspeech_no_resnet', - 'librispeech_deepspeech_norm_and_spec_aug', - 'librispeech_deepspeech_tanh' - ], - 'imagenet': [ - 'imagenet_resnet_silu', - 'imagenet_resnet_gelu', - 'imagenet_resnet_large_bn_init', - 'imagenet_vit_glu', - 'imagenet_vit_post_ln', - 'imagenet_vit_map' - ], - 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], - 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], - 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], - 'criteo1tb': [ - 'criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet' - ] + 'librispeech': [ + 'librispeech_conformer_attention_temperature', + 'librispeech_conformer_layernorm', + # 'librispeech_conformer_gelu', # Removed due to bug in target setting procedure + 'librispeech_deepspeech_no_resnet', + 'librispeech_deepspeech_norm_and_spec_aug', + 'librispeech_deepspeech_tanh', + ], + 'imagenet': [ + 'imagenet_resnet_silu', + 'imagenet_resnet_gelu', + 'imagenet_resnet_large_bn_init', + 'imagenet_vit_glu', + 'imagenet_vit_post_ln', + 'imagenet_vit_map', + ], + 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], + 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], + 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], + 'criteo1tb': [ + 'criteo1tb_layernorm', + 'criteo1tb_embed_init', + 'criteo1tb_resnet', + ], } def save_held_out_workloads(held_out_workloads, filename): - with open(filename, "w") as f: + with open(filename, 'w') as f: json.dump(held_out_workloads, f) @@ -63,7 +65,7 @@ def main(_): sampled_index = rng.integers(len(v)) sampled_held_out_workloads.append(v[sampled_index]) - logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}") + logging.info(f'Sampled held-out workloads: {sampled_held_out_workloads}') save_held_out_workloads(sampled_held_out_workloads, output_filename) diff --git a/scoring/algoperf_v05/score_submissions.py b/scoring/algoperf_v05/score_submissions.py index 8cc06b15f..6ef931a54 100644 --- a/scoring/algoperf_v05/score_submissions.py +++ b/scoring/algoperf_v05/score_submissions.py @@ -16,57 +16,62 @@ import os import pickle -from absl import app -from absl import flags -from absl import logging import numpy as np import pandas as pd import performance_profile import scoring_utils +from absl import app, flags, logging from tabulate import tabulate flags.DEFINE_string( - 'submission_directory', - None, - 'Path to submission directory containing experiment directories.') + 'submission_directory', + None, + 'Path to submission directory containing experiment directories.', +) flags.DEFINE_string( - 'output_dir', - 'scoring_results', - 'Path to save performance profile artifacts, submission_summaries and results files.' + 'output_dir', + 'scoring_results', + 'Path to save performance profile artifacts, submission_summaries and results files.', +) +flags.DEFINE_boolean( + 'compute_performance_profiles', + False, + 'Whether or not to compute the performance profiles.', ) -flags.DEFINE_boolean('compute_performance_profiles', - False, - 'Whether or not to compute the performance profiles.') flags.DEFINE_boolean( - 'strict', - False, - 'Whether to enforce scoring criteria on variant performance and on' - '5-trial median performance. Note that during official scoring this ' - 'flag will be set to True.') + 'strict', + False, + 'Whether to enforce scoring criteria on variant performance and on' + '5-trial median performance. Note that during official scoring this ' + 'flag will be set to True.', +) flags.DEFINE_boolean( - 'self_tuning_ruleset', - False, - 'Whether to score on self-tuning ruleset or externally tuned ruleset') + 'self_tuning_ruleset', + False, + 'Whether to score on self-tuning ruleset or externally tuned ruleset', +) flags.DEFINE_string( - 'save_results_to_filename', - None, - 'Filename to save the processed results that are fed into the performance profile functions.' + 'save_results_to_filename', + None, + 'Filename to save the processed results that are fed into the performance profile functions.', ) flags.DEFINE_string( - 'load_results_from_filename', - None, - 'Filename to load processed results from that are fed into performance profile functions' + 'load_results_from_filename', + None, + 'Filename to load processed results from that are fed into performance profile functions', ) flags.DEFINE_string( - 'exclude_submissions', - '', - 'Optional comma seperated list of names of submissions to exclude from scoring.' + 'exclude_submissions', + '', + 'Optional comma seperated list of names of submissions to exclude from scoring.', ) FLAGS = flags.FLAGS def get_summary_df(workload, workload_df, include_test_split=False): - validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation') + validation_metric, validation_target = ( + scoring_utils.get_workload_metrics_and_targets(workload, split='validation') + ) is_minimized = performance_profile.check_if_minimized(validation_metric) target_op = operator.le if is_minimized else operator.ge @@ -79,47 +84,69 @@ def get_summary_df(workload, workload_df, include_test_split=False): summary_df['val target metric name'] = validation_metric summary_df['val target metric value'] = validation_target - summary_df['val target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) + summary_df['val target reached'] = ( + workload_df[validation_metric] + .apply(lambda x: target_op(x, validation_target)) + .apply(np.any) + ) summary_df['best metric value on val'] = workload_df[validation_metric].apply( - lambda x: best_op(x)) + lambda x: best_op(x) + ) workload_df['index best eval on val'] = workload_df[validation_metric].apply( - lambda x: idx_op(x)) + lambda x: idx_op(x) + ) summary_df['time to best eval on val (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval on val']], - axis=1) - workload_df['val target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) + lambda x: x['accumulated_submission_time'][x['index best eval on val']], + axis=1, + ) + workload_df['val target reached'] = ( + workload_df[validation_metric] + .apply(lambda x: target_op(x, validation_target)) + .apply(np.any) + ) workload_df['index to target on val'] = workload_df.apply( - lambda x: np.argmax(target_op(x[validation_metric], validation_target)) - if x['val target reached'] else np.nan, - axis=1) + lambda x: np.argmax(target_op(x[validation_metric], validation_target)) + if x['val target reached'] + else np.nan, + axis=1, + ) summary_df['time to target on val (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][int(x[ - 'index to target on val'])] if x['val target reached'] else np.inf, - axis=1) + lambda x: x['accumulated_submission_time'][int(x['index to target on val'])] + if x['val target reached'] + else np.inf, + axis=1, + ) # test metrics if include_test_split: - test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test') + test_metric, test_target = scoring_utils.get_workload_metrics_and_targets( + workload, split='test' + ) summary_df['test target metric name'] = test_metric summary_df['test target metric value'] = test_target - summary_df['test target reached'] = workload_df[test_metric].apply( - lambda x: target_op(x, test_target)).apply(np.any) + summary_df['test target reached'] = ( + workload_df[test_metric] + .apply(lambda x: target_op(x, test_target)) + .apply(np.any) + ) summary_df['best metric value on test'] = workload_df[test_metric].apply( - lambda x: best_op(x)) + lambda x: best_op(x) + ) workload_df['index best eval on test'] = workload_df[test_metric].apply( - lambda x: idx_op(x)) + lambda x: idx_op(x) + ) summary_df['time to best eval on test (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval on test'] - ], - axis=1) + lambda x: x['accumulated_submission_time'][x['index best eval on test']], + axis=1, + ) summary_df['time to target on test (s)'] = summary_df.apply( - lambda x: x['time to best eval on test (s)'] - if x['test target reached'] else np.inf, - axis=1) + lambda x: x['time to best eval on test (s)'] + if x['test target reached'] + else np.inf, + axis=1, + ) return summary_df @@ -133,7 +160,8 @@ def get_submission_summary(df, include_test_split=True): print(df) for workload, group in df.groupby('workload'): summary_df = get_summary_df( - workload, group, include_test_split=include_test_split) + workload, group, include_test_split=include_test_split + ) dfs.append(summary_df) df = pd.concat(dfs) @@ -164,61 +192,64 @@ def main(_): # Optionally read results to filename if FLAGS.load_results_from_filename: with open( - os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), - 'rb') as f: + os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), 'rb' + ) as f: results = pickle.load(f) else: for team in os.listdir(FLAGS.submission_directory): for submission in os.listdir( - os.path.join(FLAGS.submission_directory, team)): + os.path.join(FLAGS.submission_directory, team) + ): print(submission) if submission in FLAGS.exclude_submissions.split(','): continue - experiment_path = os.path.join(FLAGS.submission_directory, - team, - submission) + experiment_path = os.path.join( + FLAGS.submission_directory, team, submission + ) df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df summary_df = get_submission_summary(df) with open( - os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), - 'w') as fout: + os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' + ) as fout: summary_df.to_csv(fout) # Optionally save results to filename if FLAGS.save_results_to_filename: with open( - os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), - 'wb') as f: + os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), 'wb' + ) as f: pickle.dump(results, f) if not FLAGS.strict: logging.warning( - 'You are running with strict=False. This will relax ' - 'scoring criteria on the held-out workloads, number of trials and number ' - 'of studies. Your score may not be an accurate representation ' - 'under competition scoring rules. To enforce the criteria set strict=True.' + 'You are running with strict=False. This will relax ' + 'scoring criteria on the held-out workloads, number of trials and number ' + 'of studies. Your score may not be an accurate representation ' + 'under competition scoring rules. To enforce the criteria set strict=True.' ) if FLAGS.compute_performance_profiles: performance_profile_df = performance_profile.compute_performance_profiles( - results, - time_col='score', - min_tau=1.0, - max_tau=4.0, - reference_submission_tag=None, - num_points=100, - scale='linear', - verbosity=0, - self_tuning_ruleset=FLAGS.self_tuning_ruleset, - strict=FLAGS.strict, - output_dir=FLAGS.output_dir, + results, + time_col='score', + min_tau=1.0, + max_tau=4.0, + reference_submission_tag=None, + num_points=100, + scale='linear', + verbosity=0, + self_tuning_ruleset=FLAGS.self_tuning_ruleset, + strict=FLAGS.strict, + output_dir=FLAGS.output_dir, ) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) performance_profile.plot_performance_profiles( - performance_profile_df, 'score', save_dir=FLAGS.output_dir) + performance_profile_df, 'score', save_dir=FLAGS.output_dir + ) performance_profile_str = tabulate( - performance_profile_df.T, headers='keys', tablefmt='psql') + performance_profile_df.T, headers='keys', tablefmt='psql' + ) logging.info(f'Performance profile:\n {performance_profile_str}') scores = compute_leaderboard_score(performance_profile_df) scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv')) diff --git a/scoring/compute_speedups.py b/scoring/compute_speedups.py index d0e5bf70b..1740a6dce 100644 --- a/scoring/compute_speedups.py +++ b/scoring/compute_speedups.py @@ -2,39 +2,39 @@ import pickle -from absl import app -from absl import flags import numpy as np import pandas as pd -from performance_profile import BASE_WORKLOADS -from performance_profile import get_workloads_time_to_target +from absl import app, flags +from performance_profile import BASE_WORKLOADS, get_workloads_time_to_target from scipy import stats flags.DEFINE_string('results_txt', None, 'Path to full scoring results file.') flags.DEFINE_string( - 'base', - 'prize_qualification_baseline', - 'Base submission to compare to. Defaults to the `prize_qualification_baseline`.' + 'base', + 'prize_qualification_baseline', + 'Base submission to compare to. Defaults to the `prize_qualification_baseline`.', ) flags.DEFINE_string('comparison', None, 'Submission to compute the speedup of.') -flags.DEFINE_boolean('self_tuning_ruleset', - False, - 'Whether the self-tuning ruleset is being scored.') -flags.DEFINE_boolean('save_results', - False, - 'Whether to save the results to disk.') +flags.DEFINE_boolean( + 'self_tuning_ruleset', + False, + 'Whether the self-tuning ruleset is being scored.', +) +flags.DEFINE_boolean( + 'save_results', False, 'Whether to save the results to disk.' +) FLAGS = flags.FLAGS # These are the old budgets, used in the first iteration of the competition. MAX_BUDGETS = { - 'criteo1tb': 7703, - 'fastmri': 8859, - 'imagenet_resnet': 63_008, - 'imagenet_vit': 77_520, - 'librispeech_conformer': 61_068, - 'librispeech_deepspeech': 55_506, - 'ogbg': 18_477, - 'wmt': 48_151, + 'criteo1tb': 7703, + 'fastmri': 8859, + 'imagenet_resnet': 63_008, + 'imagenet_vit': 77_520, + 'librispeech_conformer': 61_068, + 'librispeech_deepspeech': 55_506, + 'ogbg': 18_477, + 'wmt': 48_151, } @@ -63,16 +63,16 @@ def compute_speedup(): # Compute median over runtimes for both training algorithms base_results = get_workloads_time_to_target( - results[FLAGS.base], - FLAGS.base, - time_col="score", - self_tuning_ruleset=FLAGS.self_tuning_ruleset, + results[FLAGS.base], + FLAGS.base, + time_col='score', + self_tuning_ruleset=FLAGS.self_tuning_ruleset, ) comparison_results = get_workloads_time_to_target( - results[FLAGS.comparison], - FLAGS.comparison, - time_col="score", - self_tuning_ruleset=FLAGS.self_tuning_ruleset, + results[FLAGS.comparison], + FLAGS.comparison, + time_col='score', + self_tuning_ruleset=FLAGS.self_tuning_ruleset, ) # Merge results @@ -85,20 +85,23 @@ def compute_speedup(): merged_results = merged_results.apply(replace_inf, axis=1) # Compute speedup - merged_results['speedup'] = merged_results[ - f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}'] + merged_results['speedup'] = ( + merged_results[f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}'] + ) speedups = merged_results['speedup'].to_numpy() mean_speedup = stats.gmean(speedups) # Geometric mean over workload speedups print(merged_results, end='\n\n') print( - f"Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1-mean_speedup):.1%}" + f'Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1 - mean_speedup):.1%}' ) if FLAGS.save_results: # Optionally save results to disk - print("Saving results to disk...") - filename = f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1-mean_speedup):.1%}.csv' + print('Saving results to disk...') + filename = ( + f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1 - mean_speedup):.1%}.csv' + ) merged_results.to_csv(filename) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 05026a0c7..d79f705d1 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -25,21 +25,22 @@ The keys in this dictionary should match the workload identifiers used in the dictionary of submissions. """ + import itertools import json import operator import os import re -from absl import logging import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pandas as pd +from absl import logging from tabulate import tabulate -from algoperf.workloads.workloads import get_base_workload_name import algoperf.workloads.workloads as workloads_registry +from algoperf.workloads.workloads import get_base_workload_name from scoring import scoring_utils WORKLOADS = workloads_registry.WORKLOADS @@ -49,7 +50,7 @@ # Open json file to read heldout workloads # TODO: This probably shouldn't be hardcoded but passed as an argument.\ try: - with open("held_out_workloads_algoperf_v05.json", "r") as f: + with open('held_out_workloads_algoperf_v05.json', 'r') as f: HELDOUT_WORKLOADS = json.load(f) except: HELDOUT_WORKLOADS = None @@ -64,22 +65,22 @@ NUM_STUDIES = 3 MIN_EVAL_METRICS = [ - 'ce_loss', - 'error_rate', - 'ctc_loss', - 'wer', - 'l1_loss', - 'loss', + 'ce_loss', + 'error_rate', + 'ctc_loss', + 'wer', + 'l1_loss', + 'loss', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] -#MPL params +# MPL params mpl.rcParams['figure.figsize'] = (16, 10) # Width, height in inches mpl.rcParams['font.family'] = 'serif' -mpl.rcParams['font.serif'] = [ - 'Times New Roman' -] + mpl.rcParams['font.serif'] # Add Times New Roman as first choice +mpl.rcParams['font.serif'] = ['Times New Roman'] + mpl.rcParams[ + 'font.serif' +] # Add Times New Roman as first choice mpl.rcParams['font.size'] = 22 mpl.rcParams['savefig.dpi'] = 300 # Set resolution for saved figures @@ -87,16 +88,17 @@ mpl.rcParams['lines.linewidth'] = 3 # Adjust line thickness if needed mpl.rcParams['lines.markersize'] = 6 # Adjust marker size if needed mpl.rcParams['axes.prop_cycle'] = mpl.cycler( - color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", - "#9467bd"]) # Example color cycle (consider ColorBrewer or viridis) + color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'] +) # Example color cycle (consider ColorBrewer or viridis) mpl.rcParams['axes.labelsize'] = 22 # Axis label font size mpl.rcParams['xtick.labelsize'] = 20 # Tick label font size mpl.rcParams['ytick.labelsize'] = 20 # Legends and Gridlines mpl.rcParams['legend.fontsize'] = 20 # Legend font size -mpl.rcParams[ - 'legend.loc'] = 'best' # Let matplotlib decide the best legend location +mpl.rcParams['legend.loc'] = ( + 'best' # Let matplotlib decide the best legend location +) mpl.rcParams['axes.grid'] = True # Enable grid mpl.rcParams['grid.alpha'] = 0.4 # Gridline transparency @@ -113,7 +115,8 @@ def generate_eval_cols(metrics): MINIMIZE_REGISTRY = {k: True for k in generate_eval_cols(MIN_EVAL_METRICS)} MINIMIZE_REGISTRY.update( - {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)}) + {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)} +) MINIMIZE_REGISTRY['train_cost'] = True @@ -125,13 +128,15 @@ def check_if_minimized(col_name): if col in col_name: return MINIMIZE_REGISTRY[col] - raise ValueError(f'Column {col_name} not found in `MINIMIZE_REGISTRY` as ' - 'either a column name or a substring of a column name.') + raise ValueError( + f'Column {col_name} not found in `MINIMIZE_REGISTRY` as ' + 'either a column name or a substring of a column name.' + ) -def get_best_trial_index(workload_df, - validation_metric, - validation_target=None): +def get_best_trial_index( + workload_df, validation_metric, validation_target=None +): """Get the eval index in which a workload reaches the target metric_col. Args: @@ -150,7 +155,8 @@ def get_best_trial_index(workload_df, op = operator.le if is_minimized else operator.ge validation_target_reached = validation_series.apply( - lambda x: op(x, validation_target)) + lambda x: op(x, validation_target) + ) target_reached = pd.Series(validation_target_reached) # Remove trials that never reach the target @@ -166,12 +172,14 @@ def get_best_trial_index(workload_df, return trial, index_reached[trial] -def get_workloads_time_to_target(submission, - submission_name, - time_col='global_step', - verbosity=1, - self_tuning_ruleset=False, - strict=False): +def get_workloads_time_to_target( + submission, + submission_name, + time_col='global_step', + verbosity=1, + self_tuning_ruleset=False, + strict=False, +): """Get times to target for each workload in a submission. Args: @@ -191,60 +199,72 @@ def get_workloads_time_to_target(submission, if num_workloads != NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS: if strict: raise ValueError( - f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' - f'but found {num_workloads} workloads for {submission_name}.') - logging.warning( f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' - f'but found {num_workloads} workloads for {submission_name}.') + f'but found {num_workloads} workloads for {submission_name}.' + ) + logging.warning( + f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' + f'but found {num_workloads} workloads for {submission_name}.' + ) # For each workload get submission time get the submission times to target. for workload, group in submission.groupby('workload'): - validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload) + validation_metric, validation_target = ( + scoring_utils.get_workload_metrics_and_targets(workload) + ) # Check number of studies time_vals_per_study = [] num_studies = len(group.groupby('study')) if num_studies != NUM_STUDIES: if strict: - raise ValueError(f'Expecting {NUM_STUDIES} studies for workload ' - f'{workload} but found {num_studies} studies ' - f'for {submission_name}.') + raise ValueError( + f'Expecting {NUM_STUDIES} studies for workload ' + f'{workload} but found {num_studies} studies ' + f'for {submission_name}.' + ) else: - logging.warning(f'Expecting {NUM_STUDIES} studies for workload ' - f'{workload} but found {num_studies} studies ' - f'for {submission_name}.') + logging.warning( + f'Expecting {NUM_STUDIES} studies for workload ' + f'{workload} but found {num_studies} studies ' + f'for {submission_name}.' + ) # For each study check trials for study, group in group.groupby('study'): - # Check number of trials per study num_trials = len(group) if num_trials != NUM_TRIALS and not self_tuning_ruleset: if strict: raise ValueError( - f'In Study {study}: Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials ' - f'for {submission_name}.') + f'In Study {study}: Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials ' + f'for {submission_name}.' + ) else: logging.warning( - f'In Study {study}: Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials ' - f'for {submission_name}.') + f'In Study {study}: Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials ' + f'for {submission_name}.' + ) # Get trial and time index that reaches target trial_idx, time_idx = get_best_trial_index( - group, validation_metric, validation_target) + group, validation_metric, validation_target + ) if time_idx > -1: time_val = group[time_col].loc[trial_idx][time_idx] else: time_val = float('inf') time_vals_per_study.append(time_val) - workloads.append({ + workloads.append( + { 'submission': submission_name, 'workload': re.sub(r'_(jax|pytorch)$', '', workload), time_col: np.median(time_vals_per_study), - }) + } + ) df = pd.DataFrame.from_records(workloads) df = df.pivot(index='submission', columns='workload', values=time_col) @@ -252,7 +272,6 @@ def get_workloads_time_to_target(submission, def variant_criteria_filter(base_workload, variant_workload): - def filter(x): try: if x[variant_workload] == np.inf: @@ -269,17 +288,19 @@ def filter(x): return filter -def compute_performance_profiles(submissions, - time_col='global_step', - min_tau=1.0, - max_tau=None, - reference_submission_tag=None, - num_points=100, - scale='linear', - verbosity=0, - strict=False, - self_tuning_ruleset=False, - output_dir=None): +def compute_performance_profiles( + submissions, + time_col='global_step', + min_tau=1.0, + max_tau=None, + reference_submission_tag=None, + num_points=100, + scale='linear', + verbosity=0, + strict=False, + self_tuning_ruleset=False, + output_dir=None, +): """Compute performance profiles for a set of submission by some time column. Args: @@ -308,16 +329,20 @@ def compute_performance_profiles(submissions, for submission_tag, submission in submissions.items(): logging.info( - f'\nComputing performance profile with respect to `{time_col}` for ' - f'{submission_tag}') + f'\nComputing performance profile with respect to `{time_col}` for ' + f'{submission_tag}' + ) # Get time to targets for each submission across studies and trials dfs.append( - get_workloads_time_to_target(submission, - submission_tag, - time_col, - verbosity, - self_tuning_ruleset, - strict)) + get_workloads_time_to_target( + submission, + submission_tag, + time_col, + verbosity, + self_tuning_ruleset, + strict, + ) + ) df = pd.concat(dfs) # Restrict to base and sampled held-out workloads # (ignore the additional workload variants of the baseline @@ -335,7 +360,8 @@ def compute_performance_profiles(submissions, # If base do not have finite score set variant score to inf base_workload = get_base_workload_name(workload) df[workload] = df.apply( - variant_criteria_filter(workload, base_workload), axis=1) + variant_criteria_filter(workload, base_workload), axis=1 + ) # Set score to inf if not within 4x of fastest submission best_scores = df.min(axis=0) @@ -347,17 +373,20 @@ def compute_performance_profiles(submissions, # If variants do not have finite score set base_workload score to inf base_workload = get_base_workload_name(workload) df[base_workload] = df.apply( - variant_criteria_filter(base_workload, workload), axis=1) + variant_criteria_filter(base_workload, workload), axis=1 + ) df = df[BASE_WORKLOADS] if verbosity > 0: logging.info('\n`{time_col}` to reach target:') - with pd.option_context('display.max_rows', - None, - 'display.max_columns', - None, - 'display.width', - 1000): + with pd.option_context( + 'display.max_rows', + None, + 'display.max_columns', + None, + 'display.width', + 1000, + ): logging.info(df) # Divide by the fastest. @@ -368,12 +397,14 @@ def compute_performance_profiles(submissions, if verbosity > 0: logging.info('\n`{time_col}` to reach target normalized to best:') - with pd.option_context('display.max_rows', - None, - 'display.max_columns', - None, - 'display.width', - 1000): + with pd.option_context( + 'display.max_rows', + None, + 'display.max_columns', + None, + 'display.width', + 1000, + ): logging.info(df) # If no max_tau is supplied, choose the value of tau that would plot all non @@ -385,7 +416,8 @@ def compute_performance_profiles(submissions, points = np.linspace(min_tau, max_tau, num=num_points) elif scale == 'log': points = np.logspace( - np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0) + np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0 + ) def rho(r, tau): return (r <= tau).sum(axis=1) / NUM_BASE_WORKLOADS @@ -431,11 +463,9 @@ def maybe_save_df_to_csv(save_dir, df, path, **to_csv_kwargs): df.to_csv(fout, **to_csv_kwargs) -def plot_performance_profiles(perf_df, - df_col, - scale='linear', - save_dir=None, - figsize=(30, 10)): +def plot_performance_profiles( + perf_df, df_col, scale='linear', save_dir=None, figsize=(30, 10) +): """Plot performance profiles. Args: @@ -462,6 +492,6 @@ def plot_performance_profiles(perf_df, fig.legend(bbox_to_anchor=(1.0, 1.0)) plt.tight_layout() maybe_save_figure(save_dir, f'performance_profile_by_{df_col_display}') - maybe_save_df_to_csv(save_dir, - perf_df, - f'performance_profile_{df_col_display}.csv') + maybe_save_df_to_csv( + save_dir, perf_df, f'performance_profile_{df_col_display}.csv' + ) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index f07dc8cdd..b48509f02 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -17,57 +17,62 @@ import os import pickle -from absl import app -from absl import flags -from absl import logging import numpy as np import pandas as pd import performance_profile import scoring_utils +from absl import app, flags, logging from tabulate import tabulate flags.DEFINE_string( - 'submission_directory', - None, - 'Path to submission directory containing experiment directories.') + 'submission_directory', + None, + 'Path to submission directory containing experiment directories.', +) flags.DEFINE_string( - 'output_dir', - 'scoring_results', - 'Path to save performance profile artifacts, submission_summaries and results files.' + 'output_dir', + 'scoring_results', + 'Path to save performance profile artifacts, submission_summaries and results files.', +) +flags.DEFINE_boolean( + 'compute_performance_profiles', + False, + 'Whether or not to compute the performance profiles.', ) -flags.DEFINE_boolean('compute_performance_profiles', - False, - 'Whether or not to compute the performance profiles.') flags.DEFINE_boolean( - 'strict', - False, - 'Whether to enforce scoring criteria on variant performance and on' - '5-trial median performance. Note that during official scoring this ' - 'flag will be set to True.') + 'strict', + False, + 'Whether to enforce scoring criteria on variant performance and on' + '5-trial median performance. Note that during official scoring this ' + 'flag will be set to True.', +) flags.DEFINE_boolean( - 'self_tuning_ruleset', - False, - 'Whether to score on self-tuning ruleset or externally tuned ruleset') + 'self_tuning_ruleset', + False, + 'Whether to score on self-tuning ruleset or externally tuned ruleset', +) flags.DEFINE_string( - 'save_results_to_filename', - None, - 'Filename to save the processed results that are fed into the performance profile functions.' + 'save_results_to_filename', + None, + 'Filename to save the processed results that are fed into the performance profile functions.', ) flags.DEFINE_string( - 'load_results_from_filename', - None, - 'Filename to load processed results from that are fed into performance profile functions' + 'load_results_from_filename', + None, + 'Filename to load processed results from that are fed into performance profile functions', ) flags.DEFINE_string( - 'exclude_submissions', - '', - 'Optional comma seperated list of names of submissions to exclude from scoring.' + 'exclude_submissions', + '', + 'Optional comma seperated list of names of submissions to exclude from scoring.', ) FLAGS = flags.FLAGS def get_summary_df(workload, workload_df, include_test_split=False): - validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation') + validation_metric, validation_target = ( + scoring_utils.get_workload_metrics_and_targets(workload, split='validation') + ) is_minimized = performance_profile.check_if_minimized(validation_metric) target_op = operator.le if is_minimized else operator.ge @@ -80,47 +85,69 @@ def get_summary_df(workload, workload_df, include_test_split=False): summary_df['val target metric name'] = validation_metric summary_df['val target metric value'] = validation_target - summary_df['val target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) + summary_df['val target reached'] = ( + workload_df[validation_metric] + .apply(lambda x: target_op(x, validation_target)) + .apply(np.any) + ) summary_df['best metric value on val'] = workload_df[validation_metric].apply( - lambda x: best_op(x)) + lambda x: best_op(x) + ) workload_df['index best eval on val'] = workload_df[validation_metric].apply( - lambda x: idx_op(x)) + lambda x: idx_op(x) + ) summary_df['time to best eval on val (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval on val']], - axis=1) - workload_df['val target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) + lambda x: x['accumulated_submission_time'][x['index best eval on val']], + axis=1, + ) + workload_df['val target reached'] = ( + workload_df[validation_metric] + .apply(lambda x: target_op(x, validation_target)) + .apply(np.any) + ) workload_df['index to target on val'] = workload_df.apply( - lambda x: np.argmax(target_op(x[validation_metric], validation_target)) - if x['val target reached'] else np.nan, - axis=1) + lambda x: np.argmax(target_op(x[validation_metric], validation_target)) + if x['val target reached'] + else np.nan, + axis=1, + ) summary_df['time to target on val (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][int(x[ - 'index to target on val'])] if x['val target reached'] else np.inf, - axis=1) + lambda x: x['accumulated_submission_time'][int(x['index to target on val'])] + if x['val target reached'] + else np.inf, + axis=1, + ) # test metrics if include_test_split: - test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test') + test_metric, test_target = scoring_utils.get_workload_metrics_and_targets( + workload, split='test' + ) summary_df['test target metric name'] = test_metric summary_df['test target metric value'] = test_target - summary_df['test target reached'] = workload_df[test_metric].apply( - lambda x: target_op(x, test_target)).apply(np.any) + summary_df['test target reached'] = ( + workload_df[test_metric] + .apply(lambda x: target_op(x, test_target)) + .apply(np.any) + ) summary_df['best metric value on test'] = workload_df[test_metric].apply( - lambda x: best_op(x)) + lambda x: best_op(x) + ) workload_df['index best eval on test'] = workload_df[test_metric].apply( - lambda x: idx_op(x)) + lambda x: idx_op(x) + ) summary_df['time to best eval on test (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval on test'] - ], - axis=1) + lambda x: x['accumulated_submission_time'][x['index best eval on test']], + axis=1, + ) summary_df['time to target on test (s)'] = summary_df.apply( - lambda x: x['time to best eval on test (s)'] - if x['test target reached'] else np.inf, - axis=1) + lambda x: x['time to best eval on test (s)'] + if x['test target reached'] + else np.inf, + axis=1, + ) return summary_df @@ -134,7 +161,8 @@ def get_submission_summary(df, include_test_split=True): print(df) for workload, group in df.groupby('workload'): summary_df = get_summary_df( - workload, group, include_test_split=include_test_split) + workload, group, include_test_split=include_test_split + ) dfs.append(summary_df) df = pd.concat(dfs) @@ -161,13 +189,13 @@ def compute_leaderboard_score(df, normalize=True): def main(_): results = {} os.makedirs(FLAGS.output_dir, exist_ok=True) - logging.info(f"Scoring submissions in {FLAGS.submission_directory}") + logging.info(f'Scoring submissions in {FLAGS.submission_directory}') # Optionally read results to filename if FLAGS.load_results_from_filename: with open( - os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), - 'rb') as f: + os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), 'rb' + ) as f: results = pickle.load(f) else: for submission in os.listdir(FLAGS.submission_directory): @@ -179,44 +207,46 @@ def main(_): results[submission] = df summary_df = get_submission_summary(df) with open( - os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), - 'w') as fout: + os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' + ) as fout: summary_df.to_csv(fout) # Optionally save results to filename if FLAGS.save_results_to_filename: with open( - os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), - 'wb') as f: + os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), 'wb' + ) as f: pickle.dump(results, f) if not FLAGS.strict: logging.warning( - 'You are running with strict=False. This will relax ' - 'scoring criteria on the held-out workloads, number of trials and number ' - 'of studies. Your score may not be an accurate representation ' - 'under competition scoring rules. To enforce the criteria set strict=True.' + 'You are running with strict=False. This will relax ' + 'scoring criteria on the held-out workloads, number of trials and number ' + 'of studies. Your score may not be an accurate representation ' + 'under competition scoring rules. To enforce the criteria set strict=True.' ) if FLAGS.compute_performance_profiles: performance_profile_df = performance_profile.compute_performance_profiles( - results, - time_col='score', - min_tau=1.0, - max_tau=4.0, - reference_submission_tag=None, - num_points=100, - scale='linear', - verbosity=0, - self_tuning_ruleset=FLAGS.self_tuning_ruleset, - strict=FLAGS.strict, - output_dir=FLAGS.output_dir, + results, + time_col='score', + min_tau=1.0, + max_tau=4.0, + reference_submission_tag=None, + num_points=100, + scale='linear', + verbosity=0, + self_tuning_ruleset=FLAGS.self_tuning_ruleset, + strict=FLAGS.strict, + output_dir=FLAGS.output_dir, ) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) performance_profile.plot_performance_profiles( - performance_profile_df, 'score', save_dir=FLAGS.output_dir) + performance_profile_df, 'score', save_dir=FLAGS.output_dir + ) performance_profile_str = tabulate( - performance_profile_df.T, headers='keys', tablefmt='psql') + performance_profile_df.T, headers='keys', tablefmt='psql' + ) logging.info(f'Performance profile:\n {performance_profile_str}') scores = compute_leaderboard_score(performance_profile_df) scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv')) diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index ac513816e..ab639f870 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -4,8 +4,8 @@ import os import re -from absl import logging import pandas as pd +from absl import logging import algoperf.workloads.workloads as workloads_registry @@ -13,7 +13,7 @@ METRICS_LINE_REGEX = '(.*) Metrics: ({.*})' TRIAL_DIR_REGEX = 'trial_(\d+)' MEASUREMENTS_FILENAME = 'eval_measurements.csv' -TIMESTAMP = r"-\d{4}(-\d{2}){5}" +TIMESTAMP = r'-\d{4}(-\d{2}){5}' WORKLOADS = workloads_registry.WORKLOADS WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)' @@ -22,12 +22,11 @@ #### File IO helper functions ### def get_logfile_paths(logdir): - """Gets all files ending in .log in logdir - """ + """Gets all files ending in .log in logdir""" filenames = os.listdir(logdir) logfile_paths = [] for f in filenames: - if f.endswith(".log"): + if f.endswith('.log'): f = os.path.join(logdir, f) logfile_paths.append(f) return logfile_paths @@ -36,23 +35,23 @@ def get_logfile_paths(logdir): ### Logfile reading helper functions ### def decode_metrics_line(line): """Convert metrics line to dict. - Args: - line: str - - Returns: - dict_of_lists: dict where keys are metric names and vals - are lists of values. - e.g. {'loss':[5.1, 3.2, 1.0], - 'step':[100, 200, 300]} - """ + Args: + line: str + + Returns: + dict_of_lists: dict where keys are metric names and vals + are lists of values. + e.g. {'loss':[5.1, 3.2, 1.0], + 'step':[100, 200, 300]} + """ eval_results = [] dict_str = re.match(METRICS_LINE_REGEX, line).group(2) - dict_str = dict_str.replace("'", "\"") - dict_str = dict_str.replace("(", "") - dict_str = dict_str.replace(")", "") - dict_str = dict_str.replace("DeviceArray", "") - dict_str = dict_str.replace(", dtype=float32", "") - dict_str = dict_str.replace("nan", "0") + dict_str = dict_str.replace("'", '"') + dict_str = dict_str.replace('(', '') + dict_str = dict_str.replace(')', '') + dict_str = dict_str.replace('DeviceArray', '') + dict_str = dict_str.replace(', dtype=float32', '') + dict_str = dict_str.replace('nan', '0') metrics_dict = json.loads(dict_str) for item in metrics_dict['eval_results']: if isinstance(item, dict): @@ -73,18 +72,18 @@ def decode_metrics_line(line): def get_trials_dict(logfile): - """Get a dict of dicts with metrics for each - tuning run. - - Returns: - trials_dict: Dict of dicts where outer dict keys - are trial indices and inner dict key-value pairs - are metrics and list of values. - e.g. {'trial_0': {'loss':[5.1, 3.2, 1.0], - 'step':[100, 200, 300]}, - 'trial_1': {'loss':[5.1, 3.2, 1.0], - 'step':[100, 200, 300]}} - """ + """Get a dict of dicts with metrics for each + tuning run. + + Returns: + trials_dict: Dict of dicts where outer dict keys + are trial indices and inner dict key-value pairs + are metrics and list of values. + e.g. {'trial_0': {'loss':[5.1, 3.2, 1.0], + 'step':[100, 200, 300]}, + 'trial_1': {'loss':[5.1, 3.2, 1.0], + 'step':[100, 200, 300]}} + """ trial = 0 metrics_lines = {} with open(logfile, 'r') as f: @@ -100,16 +99,16 @@ def get_trials_dict(logfile): ### Results formatting helper functions ### def get_trials_df_dict(logfile): - """Get a dict with dataframes with metrics for each - tuning run. - Preferable format for saving dataframes for tables. - Args: - logfile: str path to logfile. - - Returns: - DataFrame where indices are index of eval and - columns are metric names. - """ + """Get a dict with dataframes with metrics for each + tuning run. + Preferable format for saving dataframes for tables. + Args: + logfile: str path to logfile. + + Returns: + DataFrame where indices are index of eval and + columns are metric names. + """ trials_dict = get_trials_dict(logfile) trials_df_dict = {} for trial, metrics in trials_dict.items(): @@ -119,20 +118,20 @@ def get_trials_df_dict(logfile): def get_trials_df(logfile): """Gets a df of per trial results from a logfile. - Args: - experiment_dir: str - - Returns: - df: DataFrame where indices are trials, columns are - metric names and values are lists. - e.g - +---------+-----------------+-----------------+ - | | loss | step | - |---------+-----------------+-----------------| - | trial_0 | [5.1, 3.2, 1.0] | [100, 200, 300] | - | trial_1 | [5.1, 3.2, 1.0] | [100, 200, 300] | - +---------+-----------------+-----------------+ - """ + Args: + experiment_dir: str + + Returns: + df: DataFrame where indices are trials, columns are + metric names and values are lists. + e.g + +---------+-----------------+-----------------+ + | | loss | step | + |---------+-----------------+-----------------| + | trial_0 | [5.1, 3.2, 1.0] | [100, 200, 300] | + | trial_1 | [5.1, 3.2, 1.0] | [100, 200, 300] | + +---------+-----------------+-----------------+ + """ trials_dict = get_trials_dict(logfile) df = pd.DataFrame(trials_dict).transpose() return df @@ -141,13 +140,13 @@ def get_trials_df(logfile): ## Get scoring code def get_experiment_df(experiment_dir): """Gets a df of per trial results from an experiment dir. - The output df can be provided as input to - performance_profile.compute_performance_profiles. + The output df can be provided as input to + performance_profile.compute_performance_profiles. Args: - experiment_dir: path to experiment directory containing - results for workloads. Measurements from experiments - sharing the same prefix but different timestamps are - collected together. + experiment_dir: path to experiment directory containing + results for workloads. Measurements from experiments + sharing the same prefix but different timestamps are + collected together. The directory structure is assumed to be: + experiment_dir + study @@ -156,9 +155,9 @@ def get_experiment_df(experiment_dir): - eval_measurements.csv Returns: - df: DataFrame where indices are trials, columns are + df: DataFrame where indices are trials, columns are metric names and values are lists of length num evals. - e.g + e.g +----+-----------+--------+----------------------------+--------------------+--------------------+ | | workload | study |trial | validation/accuracy| score | |----+-----------+--------+----------------------------+--------------------+--------------------| @@ -167,35 +166,37 @@ def get_experiment_df(experiment_dir): """ df = pd.DataFrame() paths = filter( - lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir, - glob.glob(f"{experiment_dir}*")) + lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir, + glob.glob(f'{experiment_dir}*'), + ) for experiment_dir in paths: study_dirs = os.listdir(experiment_dir) for study_dir in study_dirs: workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir)) workload_dirs = [ - w for w in workload_dirs - if os.path.isdir(os.path.join(experiment_dir, study_dir, w)) + w + for w in workload_dirs + if os.path.isdir(os.path.join(experiment_dir, study_dir, w)) ] print(workload_dirs) for workload in workload_dirs: data = { - 'workload': workload, + 'workload': workload, } logging.info(os.path.join(experiment_dir, study_dir, workload)) trial_dirs = [ - t for t in os.listdir( - os.path.join(experiment_dir, study_dir, workload)) - if re.match(TRIAL_DIR_REGEX, t) + t + for t in os.listdir(os.path.join(experiment_dir, study_dir, workload)) + if re.match(TRIAL_DIR_REGEX, t) ] for trial in trial_dirs: eval_measurements_filepath = os.path.join( - experiment_dir, - study_dir, - workload, - trial, - MEASUREMENTS_FILENAME, + experiment_dir, + study_dir, + workload, + trial, + MEASUREMENTS_FILENAME, ) try: trial_df = pd.read_csv(eval_measurements_filepath) @@ -221,14 +222,16 @@ def get_workload_metrics_and_targets(workload, split='validation'): # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( - BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + f'{framework}', - 'workload.py') + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'{framework}', + 'workload.py', + ) workload_init_kwargs = {} workload_obj = workloads_registry.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - workload_init_kwargs=workload_init_kwargs) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) metric_name = workload_obj.target_metric_name if split == 'validation': metric = f'validation/{metric_name}' diff --git a/scoring/test_performance_profile.py b/scoring/test_performance_profile.py index 166c82d09..01c96de71 100644 --- a/scoring/test_performance_profile.py +++ b/scoring/test_performance_profile.py @@ -2,12 +2,10 @@ from absl.testing import absltest -from scoring import performance_profile -from scoring import scoring_utils +from scoring import performance_profile, scoring_utils class Test(absltest.TestCase): - def test_get_workloads_time_to_target(self): # TODO(kasimbeg) pass diff --git a/scoring/test_scoring_utils.py b/scoring/test_scoring_utils.py index 7509e3e46..e3a5f7263 100644 --- a/scoring/test_scoring_utils.py +++ b/scoring/test_scoring_utils.py @@ -2,8 +2,7 @@ from absl.testing import absltest -from scoring import performance_profile -from scoring import scoring_utils +from scoring import performance_profile, scoring_utils TEST_LOGFILE = 'test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log' TEST_DIR = 'test_data/experiment_dir' @@ -11,7 +10,6 @@ class Test(absltest.TestCase): - def test_get_trials_dict(self): trials_dict = scoring_utils.get_trials_dict(TEST_LOGFILE) self.assertEqual(len(trials_dict['1']['global_step']), NUM_EVALS) diff --git a/scoring/utils/package_logs.py b/scoring/utils/package_logs.py index 074075abf..e341570a1 100644 --- a/scoring/utils/package_logs.py +++ b/scoring/utils/package_logs.py @@ -3,11 +3,11 @@ python3 package_logs.py --experiment_dir --destination_dir """ + import os import shutil -from absl import app -from absl import flags +from absl import app, flags flags.DEFINE_string('experiment_dir', None, 'Path to experiment.') flags.DEFINE_string('destination_dir', None, 'Path to save submission logs') @@ -17,10 +17,10 @@ def move_logs(experiment_dir, destination_dir): """Copy files from experiment path to destination directory. - Args: - experiment_dir: Path to experiment dir. - destination_dir: Path to destination dir. - """ + Args: + experiment_dir: Path to experiment dir. + destination_dir: Path to destination dir. + """ if not os.path.exists(experiment_dir): raise IOError(f'Directory does not exist {destination_dir}') diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 683fb3c63..e2de01130 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -16,101 +16,111 @@ import subprocess import time -from absl import app -from absl import flags -from absl import logging +from absl import app, flags, logging +import docker from algoperf import random_utils as prng from algoperf.workloads.workloads import get_base_workload_name -import docker flags.DEFINE_string( - 'docker_image_url', - 'europe-west4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo/algoperf_jax_dev', - 'URL to docker image') + 'docker_image_url', + 'europe-west4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo/algoperf_jax_dev', + 'URL to docker image', +) flags.DEFINE_integer( - 'run_percentage', - 100, - 'Percentage of max num steps to run for.' - 'Must set the flag enable_step_budget to True for this to take effect.') -flags.DEFINE_string('experiment_name', - 'my_experiment', - 'Name of top sub directory in experiment dir.') -flags.DEFINE_boolean('rsync_data', - True, - 'Whether or not to transfer the data from GCP w rsync.') + 'run_percentage', + 100, + 'Percentage of max num steps to run for.' + 'Must set the flag enable_step_budget to True for this to take effect.', +) +flags.DEFINE_string( + 'experiment_name', + 'my_experiment', + 'Name of top sub directory in experiment dir.', +) +flags.DEFINE_boolean( + 'rsync_data', True, 'Whether or not to transfer the data from GCP w rsync.' +) flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.') flags.DEFINE_string( - 'submission_path', - 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', - 'Path to reference submission.') + 'submission_path', + 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', + 'Path to reference submission.', +) flags.DEFINE_string( - 'tuning_search_space', - 'prize_qualification_baselines/external_tuning/tuning_search_space.json', - 'Path to tuning search space.') + 'tuning_search_space', + 'prize_qualification_baselines/external_tuning/tuning_search_space.json', + 'Path to tuning search space.', +) flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.') flags.DEFINE_boolean( - 'dry_run', - False, - 'Whether or not to actually run the docker containers. ' - 'If False, simply print the docker run commands. ') + 'dry_run', + False, + 'Whether or not to actually run the docker containers. ' + 'If False, simply print the docker run commands. ', +) flags.DEFINE_enum( - 'tuning_ruleset', - 'external', - enum_values=['external', 'self'], - help='Can be either external of self.') + 'tuning_ruleset', + 'external', + enum_values=['external', 'self'], + help='Can be either external of self.', +) flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') flags.DEFINE_integer('study_start_index', None, 'Start index for studies.') flags.DEFINE_integer('study_end_index', None, 'End index for studies.') flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.') -flags.DEFINE_integer('hparam_start_index', - None, - 'Start index for tuning trials.') +flags.DEFINE_integer( + 'hparam_start_index', None, 'Start index for tuning trials.' +) flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.') flags.DEFINE_integer('seed', None, 'Random seed for evaluating a submission.') -flags.DEFINE_integer('submission_id', - 0, - 'Submission ID to generate study and hparam seeds.') -flags.DEFINE_string('held_out_workloads_config_path', - None, - 'Path to config containing held-out workloads') +flags.DEFINE_integer( + 'submission_id', 0, 'Submission ID to generate study and hparam seeds.' +) flags.DEFINE_string( - 'workload_metadata_path', - None, - 'Path to config containing dataset and maximum number of steps per workload.' - 'The default values of these are set to the full budgets as determined ' - 'via the target-setting procedure. ' - 'We provide workload_metadata_external_tuning.json and ' - 'workload_metadata_self_tuning.json as references.' - 'Note that training will be interrupted at either the set maximum number ' - 'of steps or the fixed workload maximum run time, whichever comes first. ' - 'If your algorithm has a smaller per step time than our baselines ' - 'you may want to increase the number of steps per workload.') + 'held_out_workloads_config_path', + None, + 'Path to config containing held-out workloads', +) flags.DEFINE_string( - 'workloads', - None, - 'String representing a comma separated list of workload names.' - 'If not None, only run this workload, else run all workloads in workload_metadata_path.' + 'workload_metadata_path', + None, + 'Path to config containing dataset and maximum number of steps per workload.' + 'The default values of these are set to the full budgets as determined ' + 'via the target-setting procedure. ' + 'We provide workload_metadata_external_tuning.json and ' + 'workload_metadata_self_tuning.json as references.' + 'Note that training will be interrupted at either the set maximum number ' + 'of steps or the fixed workload maximum run time, whichever comes first. ' + 'If your algorithm has a smaller per step time than our baselines ' + 'you may want to increase the number of steps per workload.', +) +flags.DEFINE_string( + 'workloads', + None, + 'String representing a comma separated list of workload names.' + 'If not None, only run this workload, else run all workloads in workload_metadata_path.', +) +flags.DEFINE_string( + 'additional_requirements_path', None, 'Path to requirements.txt if any.' ) -flags.DEFINE_string('additional_requirements_path', - None, - 'Path to requirements.txt if any.') flags.DEFINE_integer( - 'max_steps', - None, - 'Maximum number of steps to run. Must set flag enable_step_budget.' - 'This flag takes precedence over the run_percentage flag.') + 'max_steps', + None, + 'Maximum number of steps to run. Must set flag enable_step_budget.' + 'This flag takes precedence over the run_percentage flag.', +) flags.DEFINE_bool( - 'enable_step_budget', - False, - 'Flag that has to be explicitly set to override time budgets to step budget percentage.' + 'enable_step_budget', + False, + 'Flag that has to be explicitly set to override time budgets to step budget percentage.', ) FLAGS = flags.FLAGS def read_held_out_workloads(filename): - with open(filename, "r") as f: + with open(filename, 'r') as f: held_out_workloads = json.load(f) return held_out_workloads @@ -132,11 +142,13 @@ def kill_containers(): def gpu_is_active(): - output = subprocess.check_output([ + output = subprocess.check_output( + [ 'nvidia-smi', '--query-gpu=utilization.gpu', - '--format=csv,noheader,nounits' - ]) + '--format=csv,noheader,nounits', + ] + ) return any(int(x) > 0 for x in output.decode().splitlines()) @@ -151,7 +163,8 @@ def wait_until_container_not_running(sleep_interval=5 * 60): gpu_last_active = datetime.datetime.now().timestamp() if (datetime.datetime.now().timestamp() - gpu_last_active) > 45 * 60: kill_containers( - "Killing container: GPUs have been inactive > 45 minutes...") + 'Killing container: GPUs have been inactive > 45 minutes...' + ) time.sleep(sleep_interval) return @@ -167,7 +180,9 @@ def main(_): hparam_start_index_flag = '' hparam_end_index_flag = '' if FLAGS.hparam_start_index: - hparam_start_index_flag = f'--hparam_start_index {FLAGS.hparam_start_index} ' + hparam_start_index_flag = ( + f'--hparam_start_index {FLAGS.hparam_start_index} ' + ) if FLAGS.hparam_end_index: hparam_end_index_flag = f'--hparam_end_index {FLAGS.hparam_end_index} ' study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 @@ -178,7 +193,9 @@ def main(_): additional_requirements_path_flag = '' if FLAGS.additional_requirements_path: - additional_requirements_path_flag = f'--additional_requirements_path {FLAGS.additional_requirements_path} ' + additional_requirements_path_flag = ( + f'--additional_requirements_path {FLAGS.additional_requirements_path} ' + ) submission_id = FLAGS.submission_id @@ -188,7 +205,7 @@ def main(_): rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) - rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id))) + rng_key = prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id)) with open(FLAGS.workload_metadata_path) as f: workload_metadata = json.load(f) @@ -199,20 +216,24 @@ def main(_): # Read heldout workloads if FLAGS.held_out_workloads_config_path: held_out_workloads = read_held_out_workloads( - FLAGS.held_out_workloads_config_path) + FLAGS.held_out_workloads_config_path + ) workloads = workloads + held_out_workloads # Filter workloads if explicit workloads specified if FLAGS.workloads is not None: workloads = list( - filter(lambda x: x in FLAGS.workloads.split(','), workloads)) + filter(lambda x: x in FLAGS.workloads.split(','), workloads) + ) if len(workloads) != len(FLAGS.workloads.split(',')): unmatched_workloads = set(FLAGS.workloads.split(',')) - set(workloads) raise ValueError(f'Invalid workload name {unmatched_workloads}') rng_subkeys = prng.split(rng_key, num_studies) - for study_index, rng_subkey in zip(range(study_start_index, study_end_index + 1), rng_subkeys): + for study_index, rng_subkey in zip( + range(study_start_index, study_end_index + 1), rng_subkeys + ): print('-' * 100) print('*' * 40, f'Starting study {study_index + 1}/{num_studies}', '*' * 40) print('-' * 100) @@ -225,40 +246,46 @@ def main(_): base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() os.system( - "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'" + ) # clear caches print('=' * 100) dataset = workload_metadata[base_workload_name]['dataset'] max_steps_flag = '' if FLAGS.enable_step_budget: - run_fraction = FLAGS.run_percentage / 100. + run_fraction = FLAGS.run_percentage / 100.0 if FLAGS.max_steps is None: - max_steps = int(workload_metadata[base_workload_name]['max_steps'] * - run_fraction) + max_steps = int( + workload_metadata[base_workload_name]['max_steps'] * run_fraction + ) else: max_steps = FLAGS.max_steps max_steps_flag = f'-m {max_steps}' mount_repo_flag = '' if FLAGS.local: - mount_repo_flag = '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' - command = ('docker run -t -d -v /home/kasimbeg/data/:/data/ ' - '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' - '-v /home/kasimbeg/experiment_runs/logs:/logs ' - f'{mount_repo_flag}' - '--gpus all --ipc=host ' - f'{docker_image_url} ' - f'-d {dataset} ' - f'-f {framework} ' - f'-s {submission_path} ' - f'-w {workload} ' - f'-e {study_dir} ' - f'{max_steps_flag} ' - f'--num_tuning_trials {num_tuning_trials} ' - f'--rng_seed {run_seed} ' - f'{additional_requirements_path_flag}' - '-c false ' - '-o true ' - '-i true ') + mount_repo_flag = ( + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' + ) + command = ( + 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' + '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' + '-v /home/kasimbeg/experiment_runs/logs:/logs ' + f'{mount_repo_flag}' + '--gpus all --ipc=host ' + f'{docker_image_url} ' + f'-d {dataset} ' + f'-f {framework} ' + f'-s {submission_path} ' + f'-w {workload} ' + f'-e {study_dir} ' + f'{max_steps_flag} ' + f'--num_tuning_trials {num_tuning_trials} ' + f'--rng_seed {run_seed} ' + f'{additional_requirements_path_flag}' + '-c false ' + '-o true ' + '-i true ' + ) # Append tuning ruleset flags tuning_ruleset_flags = '' @@ -280,18 +307,19 @@ def main(_): return_code = 0 if return_code == 0: print( - f'SUCCESS: container for {framework} {workload} launched successfully' + f'SUCCESS: container for {framework} {workload} launched successfully' ) print(f'Command: {command}') print(f'Results will be logged to {experiment_name}') else: print( - f'Failed: container for {framework} {workload} failed with exit code {return_code}.' + f'Failed: container for {framework} {workload} failed with exit code {return_code}.' ) print(f'Command: {command}') wait_until_container_not_running() os.system( - "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'" + ) # clear caches print('=' * 100) diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md index ed42752dd..a8e41f04b 100644 --- a/scoring/utils/slurm/README.md +++ b/scoring/utils/slurm/README.md @@ -1,27 +1,37 @@ # Launching SLURM jobs with SBATCH + This folder contains a SLURM batch script that can be used to run jobs where each job corresponds to a training run on a given workload, training algorithm, random seed and tuning trial (if on external tuning ruleset). To launch jobs: -1) Generate a job config. The following command will generate a config.json. -``` + +1. Generate a job config. The following command will generate a config.json. + +```bash python3 make_job_config.py \ --submission_path \ --tuning_search_space \ --experiment_dir $HOME/experiments/ \ --framework ``` -2) Save the config.json in the same directory you will run the sbatch script from. -3) Copy the example sbatch script `run_jobs.sh`. + +2. Save the config.json in the same directory you will run the sbatch script from. +3. Copy the example sbatch script `run_jobs.sh`. + - Set the task range to the number of tasks in the config. + ``` #SBATCH --array=0-119 ``` + - Set the output and error logs directory for the SLURM logs. + ``` #SBATCH --output=experiments///job_%A_%a.out #SBATCH --error=experiments///job_%A_%a.err ``` + - Update the gcp project information, docker image, config file path and bucket to save the logs to as necessary: + ``` REPO="us-central1-docker.pkg.dev" IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_main" @@ -31,22 +41,23 @@ docker pull $IMAGE config_file="$HOME/configs/pmap_job_config.json" # Replace with your config file path LOGS_BUCKET="algoperf-runs-internal" ``` -4) Submit a SLURM batch job by running: + +4. Submit a SLURM batch job by running: + ``` sbatch run_jobs.sh ``` - # Set up new SLURM cluster + If you are setting up a new cluster, we recommend using the [HPC toolkit to set up a SLURM cluster](https://cloud.google.com/cluster-toolkit/docs/quickstarts/slurm-cluster). To set up the new cluster: -1) [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart). -2) Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml). -3) Manually update the image: - 1) Create a VM from the Disk image created in the previous step. - 2) Install the NVIDIA container toolkit on the VM. - 3) Transfer the data from GCP bucket to `/opt/data`. - 4) Create a new disk image from the VM. -4) Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml). - +1. [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart). +2. Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml). +3. Manually update the image: + 1. Create a VM from the Disk image created in the previous step. + 2. Install the NVIDIA container toolkit on the VM. + 3. Transfer the data from GCP bucket to `/opt/data`. + 4. Create a new disk image from the VM. +4. Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml). diff --git a/scoring/utils/slurm/algoperf_slurm_cluster.yaml b/scoring/utils/slurm/algoperf_slurm_cluster.yaml index e6c35e017..073fe98cc 100644 --- a/scoring/utils/slurm/algoperf_slurm_cluster.yaml +++ b/scoring/utils/slurm/algoperf_slurm_cluster.yaml @@ -32,74 +32,74 @@ vars: # bucket: <> deployment_groups: -- group: primary - modules: - - id: network - source: modules/network/vpc + - group: primary + modules: + - id: network + source: modules/network/vpc - - id: homefs - source: community/modules/file-system/nfs-server - use: [network] - settings: - local_mounts: [/home] - disk_size: 3000 - zone: $(vars.zone) + - id: homefs + source: community/modules/file-system/nfs-server + use: [network] + settings: + local_mounts: [/home] + disk_size: 3000 + zone: $(vars.zone) - - id: script - source: modules/scripts/startup-script - settings: + - id: script + source: modules/scripts/startup-script + settings: -- group: cluster - modules: - - id: v100_nodeset - source: community/modules/compute/schedmd-slurm-gcp-v6-nodeset - use: - - network - settings: - node_count_dynamic_max: 25 # set to 0 if you want node to live forever - region: $(vars.region) - zone: $(vars.zone) - enable_placement: false - bandwidth_tier: gvnic_enabled - machine_type: n1-standard-64 - guest_accelerator: - - type: nvidia-tesla-v100 - count: 8 - instance_image: - project: $(vars.project_id) - name: $(vars.image_name) - instance_image_custom: true + - group: cluster + modules: + - id: v100_nodeset + source: community/modules/compute/schedmd-slurm-gcp-v6-nodeset + use: + - network + settings: + node_count_dynamic_max: 25 # set to 0 if you want node to live forever + region: $(vars.region) + zone: $(vars.zone) + enable_placement: false + bandwidth_tier: gvnic_enabled + machine_type: n1-standard-64 + guest_accelerator: + - type: nvidia-tesla-v100 + count: 8 + instance_image: + project: $(vars.project_id) + name: $(vars.image_name) + instance_image_custom: true - - id: v100_partition - source: community/modules/compute/schedmd-slurm-gcp-v6-partition - use: [v100_nodeset] - settings: - exclusive: false - partition_name: v100 - is_default: true + - id: v100_partition + source: community/modules/compute/schedmd-slurm-gcp-v6-partition + use: [v100_nodeset] + settings: + exclusive: false + partition_name: v100 + is_default: true - - id: slurm_login - source: community/modules/scheduler/schedmd-slurm-gcp-v6-login - use: [network] - settings: - enable_login_public_ips: true - instance_image: - project: $(vars.project_id) - name: $(vars.image_name) - instance_image_custom: true - zone: $(vars.zone) + - id: slurm_login + source: community/modules/scheduler/schedmd-slurm-gcp-v6-login + use: [network] + settings: + enable_login_public_ips: true + instance_image: + project: $(vars.project_id) + name: $(vars.image_name) + instance_image_custom: true + zone: $(vars.zone) - - id: slurm_controller - source: community/modules/scheduler/schedmd-slurm-gcp-v6-controller - use: - - network - - v100_partition - - homefs - - slurm_login - settings: - enable_controller_public_ips: true - instance_image: - project: $(vars.project_id) - name: $(vars.image_name) - instance_image_custom: true - region: $(vars.region) + - id: slurm_controller + source: community/modules/scheduler/schedmd-slurm-gcp-v6-controller + use: + - network + - v100_partition + - homefs + - slurm_login + settings: + enable_controller_public_ips: true + instance_image: + project: $(vars.project_id) + name: $(vars.image_name) + instance_image_custom: true + region: $(vars.region) diff --git a/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml index 286728e1d..f3b5be5dd 100644 --- a/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml +++ b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml @@ -13,7 +13,6 @@ # limitations under the License. --- - blueprint_name: algoperf-slurm-packer vars: @@ -36,97 +35,97 @@ vars: # bucket: <> deployment_groups: -- group: primary - modules: - - id: network - source: modules/network/vpc + - group: primary + modules: + - id: network + source: modules/network/vpc - - id: script - source: modules/scripts/startup-script - settings: - region: $(vars.region) - install_ansible: true - docker: - enabled: true - world_writable: true - # (TODO) Do I need this? - configure_ssh_host_patterns: - - 10.0.0.* - - 10.1.0.* - - 10.2.0.* - - 10.3.0.* - - 10.4.0.* - - 10.5.0.* - - 10.6.0.* - - 10.7.0.* - - $(vars.slurm_cluster_name)* - runners: - - type: shell - destination: install-ml-libraries.sh - content: | - #!/bin/bash - # this script is designed to execute on Slurm images published by SchedMD that: - # - are based on Debian distribution of Linux - # - have NVIDIA drivers pre-installed + - id: script + source: modules/scripts/startup-script + settings: + region: $(vars.region) + install_ansible: true + docker: + enabled: true + world_writable: true + # (TODO) Do I need this? + configure_ssh_host_patterns: + - 10.0.0.* + - 10.1.0.* + - 10.2.0.* + - 10.3.0.* + - 10.4.0.* + - 10.5.0.* + - 10.6.0.* + - 10.7.0.* + - $(vars.slurm_cluster_name)* + runners: + - type: shell + destination: install-ml-libraries.sh + content: | + #!/bin/bash + # this script is designed to execute on Slurm images published by SchedMD that: + # - are based on Debian distribution of Linux + # - have NVIDIA drivers pre-installed - set -e -o pipefail + set -e -o pipefail - echo "deb https://packages.cloud.google.com/apt google-fast-socket main" > /etc/apt/sources.list.d/google-fast-socket.list - apt-get update --allow-releaseinfo-change - apt-get install --assume-yes google-fast-socket + echo "deb https://packages.cloud.google.com/apt google-fast-socket main" > /etc/apt/sources.list.d/google-fast-socket.list + apt-get update --allow-releaseinfo-change + apt-get install --assume-yes google-fast-socket - CONDA_BASE=/opt/conda + CONDA_BASE=/opt/conda - if [ -d $CONDA_BASE ]; then - exit 0 - fi + if [ -d $CONDA_BASE ]; then + exit 0 + fi - DL_DIR=\$(mktemp -d) - cd $DL_DIR - curl -L -O https://github.com/conda-forge/miniforge/releases/download/24.7.1-2/Miniforge3-24.7.1-2-Linux-x86_64.sh - HOME=$DL_DIR bash Miniforge3-24.7.1-2-Linux-x86_64.sh -b -p $CONDA_BASE - cd - - rm -rf $DL_DIR - unset DL_DIR + DL_DIR=\$(mktemp -d) + cd $DL_DIR + curl -L -O https://github.com/conda-forge/miniforge/releases/download/24.7.1-2/Miniforge3-24.7.1-2-Linux-x86_64.sh + HOME=$DL_DIR bash Miniforge3-24.7.1-2-Linux-x86_64.sh -b -p $CONDA_BASE + cd - + rm -rf $DL_DIR + unset DL_DIR - source $CONDA_BASE/bin/activate base - conda init --system - conda config --system --set auto_activate_base False - # following channel ordering is important! use strict_priority! - conda config --system --set channel_priority strict - conda update -n base conda --yes + source $CONDA_BASE/bin/activate base + conda init --system + conda config --system --set auto_activate_base False + # following channel ordering is important! use strict_priority! + conda config --system --set channel_priority strict + conda update -n base conda --yes - ### create a virtual environment for tensorflow - conda create -n tf python=3.11 --yes - conda activate tf - pip install tensorflow[and-cuda]==2.18.* - pip install tensorrt==10.6.* + ### create a virtual environment for tensorflow + conda create -n tf python=3.11 --yes + conda activate tf + pip install tensorflow[and-cuda]==2.18.* + pip install tensorrt==10.6.* - ### create a virtual environment for pytorch - conda create -n pytorch python=3.11 --yes - conda activate pytorch - pip install torch torchvision torchaudio + ### create a virtual environment for pytorch + conda create -n pytorch python=3.11 --yes + conda activate pytorch + pip install torch torchvision torchaudio -- group: packer - modules: - - id: custom-image - source: modules/packer/custom-image - kind: packer - use: - - network - - script - settings: - # give VM a public IP to ensure startup script can reach public internet - # w/o new VPC - omit_external_ip: false - source_image_project_id: [schedmd-slurm-public] - # see latest in https://github.com/GoogleCloudPlatform/slurm-gcp/blob/master/docs/images.md#published-image-family - source_image_family: slurm-gcp-6-8-debian-11 - # You can find size of source image by using following command - # gcloud compute images describe-from-family --project schedmd-slurm-public - disk_size: $(vars.disk_size_gb) - image_family: $(vars.new_image.family) - # building this image does not require a GPU-enabled VM - machine_type: c2-standard-16 - state_timeout: 300m - zone: $(vars.zone) + - group: packer + modules: + - id: custom-image + source: modules/packer/custom-image + kind: packer + use: + - network + - script + settings: + # give VM a public IP to ensure startup script can reach public internet + # w/o new VPC + omit_external_ip: false + source_image_project_id: [schedmd-slurm-public] + # see latest in https://github.com/GoogleCloudPlatform/slurm-gcp/blob/master/docs/images.md#published-image-family + source_image_family: slurm-gcp-6-8-debian-11 + # You can find size of source image by using following command + # gcloud compute images describe-from-family --project schedmd-slurm-public + disk_size: $(vars.disk_size_gb) + image_family: $(vars.new_image.family) + # building this image does not require a GPU-enabled VM + machine_type: c2-standard-16 + state_timeout: 300m + zone: $(vars.zone) diff --git a/scoring/utils/slurm/config.json b/scoring/utils/slurm/config.json index dc19e57f7..cb49f9bf4 100644 --- a/scoring/utils/slurm/config.json +++ b/scoring/utils/slurm/config.json @@ -1,106 +1,106 @@ { - "0": { - "framework": "jax", - "workload": "imagenet_resnet", - "dataset": "imagenet", - "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", - "experiment_dir": "$HOME/experiments/jit_switch/study_0", - "rng_seed": 411096763, - "tuning_ruleset": "external", - "num_tuning_trials": 1, - "hparam_start_index": 0, - "hparam_end_index": 1, - "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" - }, - "1": { - "framework": "jax", - "workload": "imagenet_vit", - "dataset": "imagenet", - "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", - "experiment_dir": "$HOME/experiments/jit_switch/study_0", - "rng_seed": -1884713130, - "tuning_ruleset": "external", - "num_tuning_trials": 1, - "hparam_start_index": 0, - "hparam_end_index": 1, - "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" - }, - "2": { - "framework": "jax", - "workload": "fastmri", - "dataset": "fastmri", - "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", - "experiment_dir": "$HOME/experiments/jit_switch/study_0", - "rng_seed": -214785144, - "tuning_ruleset": "external", - "num_tuning_trials": 1, - "hparam_start_index": 0, - "hparam_end_index": 1, - "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" - }, - "3": { - "framework": "jax", - "workload": "ogbg", - "dataset": "ogbg", - "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", - "experiment_dir": "$HOME/experiments/jit_switch/study_0", - "rng_seed": -893097833, - "tuning_ruleset": "external", - "num_tuning_trials": 1, - "hparam_start_index": 0, - "hparam_end_index": 1, - "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" - }, - "4": { - "framework": "jax", - "workload": "wmt", - "dataset": "wmt", - "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", - "experiment_dir": "$HOME/experiments/jit_switch/study_0", - "rng_seed": -1244182279, - "tuning_ruleset": "external", - "num_tuning_trials": 1, - "hparam_start_index": 0, - "hparam_end_index": 1, - "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" - }, - "5": { - "framework": "jax", - "workload": "librispeech_deepspeech", - "dataset": "librispeech", - "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", - "experiment_dir": "$HOME/experiments/jit_switch/study_0", - "rng_seed": 1546003634, - "tuning_ruleset": "external", - "num_tuning_trials": 1, - "hparam_start_index": 0, - "hparam_end_index": 1, - "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" - }, - "6": { - "framework": "jax", - "workload": "criteo1tb", - "dataset": "criteo1tb", - "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", - "experiment_dir": "$HOME/experiments/jit_switch/study_0", - "rng_seed": -2062333143, - "tuning_ruleset": "external", - "num_tuning_trials": 1, - "hparam_start_index": 0, - "hparam_end_index": 1, - "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" - }, - "7": { - "framework": "jax", - "workload": "librispeech_conformer", - "dataset": "librispeech", - "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", - "experiment_dir": "$HOME/experiments/jit_switch/study_0", - "rng_seed": 409209730, - "tuning_ruleset": "external", - "num_tuning_trials": 1, - "hparam_start_index": 0, - "hparam_end_index": 1, - "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" - } -} \ No newline at end of file + "0": { + "framework": "jax", + "workload": "imagenet_resnet", + "dataset": "imagenet", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 411096763, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "1": { + "framework": "jax", + "workload": "imagenet_vit", + "dataset": "imagenet", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -1884713130, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "2": { + "framework": "jax", + "workload": "fastmri", + "dataset": "fastmri", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -214785144, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "3": { + "framework": "jax", + "workload": "ogbg", + "dataset": "ogbg", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -893097833, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "4": { + "framework": "jax", + "workload": "wmt", + "dataset": "wmt", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -1244182279, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "5": { + "framework": "jax", + "workload": "librispeech_deepspeech", + "dataset": "librispeech", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 1546003634, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "6": { + "framework": "jax", + "workload": "criteo1tb", + "dataset": "criteo1tb", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -2062333143, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "7": { + "framework": "jax", + "workload": "librispeech_conformer", + "dataset": "librispeech", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 409209730, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + } +} diff --git a/scoring/utils/slurm/make_job_config.py b/scoring/utils/slurm/make_job_config.py index 116e70459..f6a1ca158 100644 --- a/scoring/utils/slurm/make_job_config.py +++ b/scoring/utils/slurm/make_job_config.py @@ -6,60 +6,66 @@ --experiment_dir $HOME/experiments/ \ --framework """ + import json import os -from absl import app -from absl import flags import jax +from absl import app, flags SUBMISSION_PATH = 'submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2/submission.py' -EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2' +EXPERIMENT_DIR = ( + 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2' +) TUNING_SEARCH_SPACE = None FRAMEWORK = 'pytorch' TUNING_RULESET = 'self' flags.DEFINE_string( - 'submission_path', - SUBMISSION_PATH, - 'Path to submission module relative to algorithmic-efficiency dir.') + 'submission_path', + SUBMISSION_PATH, + 'Path to submission module relative to algorithmic-efficiency dir.', +) +flags.DEFINE_string( + 'tuning_search_space', + TUNING_SEARCH_SPACE, + 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.', +) flags.DEFINE_string( - 'tuning_search_space', - TUNING_SEARCH_SPACE, - 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.' + 'experiment_dir', + EXPERIMENT_DIR, + 'Path to experiment dir where logs will be saved.', ) -flags.DEFINE_string('experiment_dir', - EXPERIMENT_DIR, - 'Path to experiment dir where logs will be saved.') flags.DEFINE_enum( - 'framework', - FRAMEWORK, - enum_values=['jax', 'pytorch'], - help='Can be either pytorch or jax.') + 'framework', + FRAMEWORK, + enum_values=['jax', 'pytorch'], + help='Can be either pytorch or jax.', +) flags.DEFINE_integer('seed', 0, 'RNG seed to to generate study seeds from.') flags.DEFINE_enum( - 'tuning_ruleset', - TUNING_RULESET, - enum_values=['external', 'self'], - help='Which tuning ruleset to score this submission on. Can be external or self.' + 'tuning_ruleset', + TUNING_RULESET, + enum_values=['external', 'self'], + help='Which tuning ruleset to score this submission on. Can be external or self.', ) FLAGS = flags.FLAGS -MIN_INT = -2**(31) -MAX_INT = 2**(31) - 1 +MIN_INT = -(2 ** (31)) +MAX_INT = 2 ** (31) - 1 NUM_TUNING_TRIALS = 5 # For external tuning ruleset NUM_STUDIES = 3 WORKLOADS = { - "imagenet_resnet": {"dataset": "imagenet"}, - "imagenet_vit": {"dataset": "imagenet"}, - "fastmri": {"dataset": "fastmri"}, - "ogbg": {"dataset": "ogbg"}, - "wmt": {"dataset": "wmt"}, - "librispeech_deepspeech": {"dataset": "librispeech"}, - "criteo1tb": {"dataset": "criteo1tb"}, - "librispeech_conformer": {"dataset": "librispeech"} + 'imagenet_resnet': {'dataset': 'imagenet'}, + 'imagenet_vit': {'dataset': 'imagenet'}, + 'fastmri': {'dataset': 'fastmri'}, + 'ogbg': {'dataset': 'ogbg'}, + 'wmt': {'dataset': 'wmt'}, + 'librispeech_deepspeech': {'dataset': 'librispeech'}, + 'criteo1tb': {'dataset': 'criteo1tb'}, + 'librispeech_conformer': {'dataset': 'librispeech'}, } @@ -81,7 +87,7 @@ def main(_): print(seed) # Add job job = {} - study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}") + study_dir = os.path.join(FLAGS.experiment_dir, f'study_{study_index}') job['framework'] = FLAGS.framework job['workload'] = workload job['dataset'] = WORKLOADS[workload]['dataset'] @@ -103,7 +109,7 @@ def main(_): print(seed) # Add job job = {} - study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}") + study_dir = os.path.join(FLAGS.experiment_dir, f'study_{study_index}') job['framework'] = FLAGS.framework job['workload'] = workload job['dataset'] = WORKLOADS[workload]['dataset'] @@ -119,7 +125,7 @@ def main(_): # Convert job array to dict with job indices job_dict = {} for i, job in enumerate(jobs): - job_dict[f"{i}"] = job + job_dict[f'{i}'] = job with open('config.json', 'w') as f: json.dump(job_dict, f, indent=4) diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c205d28b2..c7d4ae195 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -1,34 +1,34 @@ { - "imagenet_resnet": { - "max_steps": 186666, - "dataset": "imagenet" - }, - "imagenet_vit": { - "max_steps": 186666, - "dataset": "imagenet" - }, - "fastmri": { - "max_steps": 36189, - "dataset": "fastmri" - }, - "ogbg": { - "max_steps": 80000, - "dataset": "ogbg" - }, - "wmt": { - "max_steps": 133333, - "dataset": "wmt" - }, - "librispeech_deepspeech": { - "max_steps": 48000, - "dataset": "librispeech" - }, - "criteo1tb": { - "max_steps": 10666, - "dataset": "criteo1tb" - }, - "librispeech_conformer": { - "max_steps": 80000, - "dataset": "librispeech" - } - } \ No newline at end of file + "imagenet_resnet": { + "max_steps": 186666, + "dataset": "imagenet" + }, + "imagenet_vit": { + "max_steps": 186666, + "dataset": "imagenet" + }, + "fastmri": { + "max_steps": 36189, + "dataset": "fastmri" + }, + "ogbg": { + "max_steps": 80000, + "dataset": "ogbg" + }, + "wmt": { + "max_steps": 133333, + "dataset": "wmt" + }, + "librispeech_deepspeech": { + "max_steps": 48000, + "dataset": "librispeech" + }, + "criteo1tb": { + "max_steps": 10666, + "dataset": "criteo1tb" + }, + "librispeech_conformer": { + "max_steps": 80000, + "dataset": "librispeech" + } +} diff --git a/scoring/utils/workload_metadata_self_tuning.json b/scoring/utils/workload_metadata_self_tuning.json index 105d5c52f..9d3e6b93d 100644 --- a/scoring/utils/workload_metadata_self_tuning.json +++ b/scoring/utils/workload_metadata_self_tuning.json @@ -1,34 +1,34 @@ { - "imagenet_resnet": { - "max_steps": 559998, - "dataset": "imagenet" - }, - "imagenet_vit": { - "max_steps": 559998, - "dataset": "imagenet" - }, - "fastmri": { - "max_steps": 108567, - "dataset": "fastmri" - }, - "ogbg": { - "max_steps": 240000, - "dataset": "ogbg" - }, - "wmt": { - "max_steps": 399999, - "dataset": "wmt" - }, - "librispeech_deepspeech": { - "max_steps": 144000, - "dataset": "librispeech" - }, - "criteo1tb": { - "max_steps": 31998, - "dataset": "criteo1tb" - }, - "librispeech_conformer": { - "max_steps": 240000, - "dataset": "librispeech" - } - } \ No newline at end of file + "imagenet_resnet": { + "max_steps": 559998, + "dataset": "imagenet" + }, + "imagenet_vit": { + "max_steps": 559998, + "dataset": "imagenet" + }, + "fastmri": { + "max_steps": 108567, + "dataset": "fastmri" + }, + "ogbg": { + "max_steps": 240000, + "dataset": "ogbg" + }, + "wmt": { + "max_steps": 399999, + "dataset": "wmt" + }, + "librispeech_deepspeech": { + "max_steps": 144000, + "dataset": "librispeech" + }, + "criteo1tb": { + "max_steps": 31998, + "dataset": "criteo1tb" + }, + "librispeech_conformer": { + "max_steps": 240000, + "dataset": "librispeech" + } +} From f4ae9beac42dbcad30b77cfc386616ab29214061 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:19:28 +0200 Subject: [PATCH 103/123] Format reference_algorithms/ --- .../cifar/cifar_jax/submission.py | 178 +- .../cifar/cifar_pytorch/submission.py | 129 +- .../mnist/mnist_jax/submission.py | 156 +- .../mnist/mnist_pytorch/submission.py | 105 +- .../adafactor/jax/sharded_adafactor.py | 275 +-- .../adafactor/jax/submission.py | 205 +- .../adafactor/pytorch/submission.py | 288 +-- .../paper_baselines/adamw/jax/submission.py | 204 +- .../adamw/pytorch/submission.py | 157 +- .../paper_baselines/lamb/jax/submission.py | 204 +- .../lamb/pytorch/submission.py | 228 +- .../momentum/jax/submission.py | 220 +- .../momentum/pytorch/submission.py | 171 +- .../paper_baselines/nadamw/jax/submission.py | 269 ++- .../nadamw/pytorch/submission.py | 274 ++- .../nesterov/jax/submission.py | 220 +- .../nesterov/pytorch/submission.py | 171 +- .../paper_baselines/sam/jax/submission.py | 256 ++- .../paper_baselines/sam/pytorch/submission.py | 191 +- .../shampoo/jax/distributed_shampoo.py | 2007 +++++++++-------- .../paper_baselines/shampoo/jax/submission.py | 211 +- .../cosine_warmup.py | 30 +- .../criteo1tb/tuning_search_space.json | 20 +- .../data_selection.py | 17 +- .../target_setting_algorithms/jax_adamw.py | 52 +- .../target_setting_algorithms/jax_momentum.py | 79 +- .../target_setting_algorithms/jax_nadamw.py | 96 +- .../target_setting_algorithms/jax_nesterov.py | 79 +- .../jax_submission_base.py | 148 +- .../pytorch_adamw.py | 47 +- .../pytorch_momentum.py | 58 +- .../pytorch_nadamw.py | 172 +- .../pytorch_nesterov.py | 58 +- .../pytorch_submission_base.py | 101 +- 34 files changed, 3915 insertions(+), 3161 deletions(-) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 3d8e35eaa..37f74ac45 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -19,26 +19,29 @@ def get_batch_size(workload_name): def cosine_decay(lr, step, total_steps): - ratio = jnp.maximum(0., step / total_steps) - mult = 0.5 * (1. + jnp.cos(jnp.pi * ratio)) + ratio = jnp.maximum(0.0, step / total_steps) + mult = 0.5 * (1.0 + jnp.cos(jnp.pi * ratio)) return mult * lr -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): +def create_learning_rate_fn( + hparams: spec.Hyperparameters, steps_per_epoch: int +): """Create learning rate schedule.""" - base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 128. + base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 128.0 warmup_fn = optax.linear_schedule( - init_value=0., - end_value=base_learning_rate, - transition_steps=hparams.warmup_epochs * steps_per_epoch) + init_value=0.0, + end_value=base_learning_rate, + transition_steps=hparams.warmup_epochs * steps_per_epoch, + ) cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) + init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hparams.warmup_epochs * steps_per_epoch]) + schedules=[warmup_fn, cosine_fn], + boundaries=[hparams.warmup_epochs * steps_per_epoch], + ) return schedule_fn @@ -46,51 +49,59 @@ def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): steps_per_epoch = num_train_examples // get_batch_size('cifar') learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) opt_init_fn, opt_update_fn = optax.sgd( - nesterov=True, - momentum=hyperparameters.momentum, - learning_rate=learning_rate_fn) + nesterov=True, + momentum=hyperparameters.momentum, + learning_rate=learning_rate_fn, + ) return opt_init_fn, opt_update_fn -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: del model_params del model_state del rng - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) + opt_init_fn, opt_update_fn = optimizer( + hyperparameters, workload.num_train_examples + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, None, 0, 0), + static_broadcasted_argnums=(0, 1), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + hyperparameters, + batch, + rng, +): def _loss_fn(params): """loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn(batch['targets'], logits) loss = loss_dict['summed'] / loss_dict['n_valid_examples'] weight_penalty_params = jax.tree_util.tree_leaves(params) @@ -102,25 +113,27 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (_, new_model_state), grad = grad_fn(current_param_container) grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -130,21 +143,30 @@ def update_params( optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) new_optimizer_state, new_params, new_model_state = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + hyperparameters, + batch, + per_device_rngs, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -158,14 +180,16 @@ def prepare_for_eval(workload: spec.Workload, # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index d8b91f83a..5fd51c3b2 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -16,56 +16,63 @@ def get_batch_size(workload_name): return batch_sizes[workload_name] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: del workload del model_state del rng - base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 128. + base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 128.0 optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=base_lr, - momentum=hyperparameters.momentum, - weight_decay=hyperparameters.l2), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=base_lr, + momentum=hyperparameters.momentum, + weight_decay=hyperparameters.l2, + ), } scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-5, - end_factor=1., - total_iters=hyperparameters.warmup_epochs) + optimizer_state['optimizer'], + start_factor=1e-5, + end_factor=1.0, + total_iters=hyperparameters.warmup_epochs, + ) cosine_epochs = max( - hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) + hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1 + ) scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs) + optimizer_state['optimizer'], T_max=cosine_epochs + ) optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs]) + optimizer_state['optimizer'], + schedulers=[scheduler1, scheduler2], + milestones=[hyperparameters.warmup_epochs], + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del current_params_types del hyperparameters @@ -78,15 +85,17 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) + label_batch=batch['targets'], logits_batch=logits_batch + ) loss = loss_dict['summed'] / loss_dict['n_valid_examples'] loss.backward() @@ -99,16 +108,18 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -122,14 +133,16 @@ def prepare_for_eval(workload: spec.Workload, # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index c1f54597d..3ef97577f 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -18,50 +18,59 @@ def get_batch_size(workload_name): return batch_sizes[workload_name] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: del model_params del model_state del rng - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = optax.chain( - optax.scale_by_adam( - b1=1.0 - hyperparameters.one_minus_beta_1, - b2=0.999, - eps=hyperparameters.epsilon), - optax.scale(-hyperparameters.learning_rate)) + optax.scale_by_adam( + b1=1.0 - hyperparameters.one_minus_beta_1, + b2=0.999, + eps=hyperparameters.epsilon, + ), + optax.scale(-hyperparameters.learning_rate), + ) return jax_utils.replicate(opt_init_fn(params_zeros_like)), opt_update_fn # We need to jax.pmap here instead of inside update_params because the latter # would recompile the function every step. @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, None, 0, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_update_params(workload: spec.Workload, - opt_update_fn, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - optimizer_state: spec.OptimizerState, - rng: spec.RandomState) -> spec.UpdateReturn: + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, None, 0, 0, 0), + static_broadcasted_argnums=(0, 1), +) +def pmapped_update_params( + workload: spec.Workload, + opt_update_fn, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + optimizer_state: spec.OptimizerState, + rng: spec.RandomState, +) -> spec.UpdateReturn: del hyperparameters def loss_fn(params): logits_batch, new_model_state = workload.model_fn( - params=params, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=params, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn(batch['targets'], logits_batch) loss = loss_dict['summed'] / loss_dict['n_valid_examples'] return loss, new_model_state @@ -69,25 +78,27 @@ def loss_fn(params): grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, new_model_state), grad = grad_fn(current_param_container) grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -98,27 +109,30 @@ def update_params( per_device_rngs = jax.random.split(rng, jax.local_device_count()) optimizer_state, opt_update_fn = optimizer_state new_optimizer_state, updated_params, new_model_state = pmapped_update_params( - workload, - opt_update_fn, - current_param_container, - model_state, - hyperparameters, - batch, - optimizer_state, - per_device_rngs) + workload, + opt_update_fn, + current_param_container, + model_state, + hyperparameters, + batch, + optimizer_state, + per_device_rngs, + ) return (new_optimizer_state, opt_update_fn), updated_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -132,14 +146,16 @@ def prepare_for_eval(workload: spec.Workload, # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index dedd96793..9940fca6e 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -13,38 +13,41 @@ def get_batch_size(workload_name): return batch_sizes[workload_name] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: del model_state del workload del rng optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta_1, 0.999), - eps=hyperparameters.epsilon), + 'optimizer': torch.optim.Adam( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta_1, 0.999), + eps=hyperparameters.epsilon, + ), } return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del hyperparameters del loss_type @@ -59,15 +62,17 @@ def update_params( param.grad = None output, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=output) + label_batch=batch['targets'], logits_batch=output + ) loss = loss_dict['summed'] / loss_dict['n_valid_examples'] loss.backward() optimizer_state['optimizer'].step() @@ -75,16 +80,18 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -98,14 +105,16 @@ def prepare_for_eval(workload: spec.Workload, # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py index ff98464ae..c83f14a13 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py @@ -38,8 +38,9 @@ NestedHParams = Any -def to_quantized(fvalue: JTensor, - quantized_dtype: jnp.dtype) -> Tuple[JTensor, JTensor]: +def to_quantized( + fvalue: JTensor, quantized_dtype: jnp.dtype +) -> Tuple[JTensor, JTensor]: """Converts floating point values `fvalues` to quantized values. We use a very simple quantization scheme where the range is symmetric around @@ -82,16 +83,17 @@ def to_quantized(fvalue: JTensor, # We first decide the scale. if fvalue.ndim < 1: raise ValueError( - f'Input array {fvalue} must have a strictly positive number of ' - 'dimensions.') + f'Input array {fvalue} must have a strictly positive number of ' + 'dimensions.' + ) max_abs = jnp.max(jnp.abs(fvalue), axis=0) bucket_size = max_abs / num_buckets bs_expanded = bucket_size[jnp.newaxis, ...] # To avoid divide by 0.0 - bs_nonzero = jnp.where(bs_expanded > 0.0, - bs_expanded, - jnp.ones_like(bs_expanded)) + bs_nonzero = jnp.where( + bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded) + ) ratio = fvalue / bs_nonzero # We use rounding to remove bias. quantized = jnp.round(ratio) @@ -128,8 +130,8 @@ def adafactor_decay_rate_adam(beta2: float, step_counter: JTensor) -> JTensor: """ step = step_counter beta2 = jnp.array(beta2, dtype=jnp.float32) - t = step + 1. - return beta2 * (1. - jnp.power(beta2, t - 1.)) / (1. - jnp.power(beta2, t)) + t = step + 1.0 + return beta2 * (1.0 - jnp.power(beta2, t - 1.0)) / (1.0 - jnp.power(beta2, t)) def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor: @@ -145,7 +147,7 @@ def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor: """ step = step_counter exponent = jnp.array(exponent, dtype=jnp.float32) - return 1. - jnp.power((step + 1.), -exponent) + return 1.0 - jnp.power((step + 1.0), -exponent) def reduce_mean(array: JTensor) -> JTensor: @@ -187,6 +189,7 @@ def reduce_rms(array: JTensor) -> JTensor: @dataclasses.dataclass(frozen=True) class _ShardedAdafactorUpdateResult: """Structure containing per-variable info for Adafactor.""" + update: Optional[Any] m: Optional[Any] m_scale: Optional[Any] @@ -197,6 +200,7 @@ class _ShardedAdafactorUpdateResult: class ShardedAdafactorState(NamedTuple): """Overall state of the ShardedAdafactor optimizer.""" + count: JTensor m: Optional[NestedJTensor] m_scale: Optional[NestedJTensor] @@ -208,27 +212,29 @@ class ShardedAdafactorState(NamedTuple): class _ShardedAdafactorHelper: """Helper class to implement optax-based sharded Adafactor.""" - def __init__(self, - learning_rate: optax.Schedule, - weight_decay: Optional[float], - layerwise_adaptation: bool, - decay_method: str, - decay_adam: float, - decay_pow: float, - beta1: float, - clip_threshold: Optional[float], - factored: bool, - epsilon1_grad_sq_reg: float, - quantized_dtype: jnp.dtype, - respect_skip_lp_regularization: bool, - exclude_from_layerwise_adaptation: Optional[List[str]], - per_var_learning_summary: bool, - sort_factored_second_moment_dims: bool, - min_dim_size_to_factor: int, - multiply_by_parameter_scale: bool, - epsilon2_param_scale_reg: float, - maybe_inf_to_nan: bool, - nesterov: bool) -> None: + def __init__( + self, + learning_rate: optax.Schedule, + weight_decay: Optional[float], + layerwise_adaptation: bool, + decay_method: str, + decay_adam: float, + decay_pow: float, + beta1: float, + clip_threshold: Optional[float], + factored: bool, + epsilon1_grad_sq_reg: float, + quantized_dtype: jnp.dtype, + respect_skip_lp_regularization: bool, + exclude_from_layerwise_adaptation: Optional[List[str]], + per_var_learning_summary: bool, + sort_factored_second_moment_dims: bool, + min_dim_size_to_factor: int, + multiply_by_parameter_scale: bool, + epsilon2_param_scale_reg: float, + maybe_inf_to_nan: bool, + nesterov: bool, + ) -> None: """Constructor. See ShardedAdafactor() below.""" self._learning_rate = learning_rate @@ -315,12 +321,13 @@ def should_store_momentum_in_qint(self, shape): def to_state(self, count, result_tree): """Maps from a tree of (factored) values to separate trees of values.""" return ShardedAdafactorState( - count=count, - m=jax.tree.map(lambda o: o.m, result_tree), - m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), - vr=jax.tree.map(lambda o: o.vr, result_tree), - vc=jax.tree.map(lambda o: o.vc, result_tree), - v=jax.tree.map(lambda o: o.v, result_tree)) + count=count, + m=jax.tree.map(lambda o: o.m, result_tree), + m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), + vr=jax.tree.map(lambda o: o.vr, result_tree), + vc=jax.tree.map(lambda o: o.vc, result_tree), + v=jax.tree.map(lambda o: o.v, result_tree), + ) def init(self, param): """Initializes the optimizer state for a given param.""" @@ -353,12 +360,13 @@ def init(self, param): else: output_v = jnp.zeros(shape, dtype=jnp.float32) return _ShardedAdafactorUpdateResult( - update=output_update, - m=output_m, - m_scale=output_m_scale, - vr=output_vr, - vc=output_vc, - v=output_v) + update=output_update, + m=output_m, + m_scale=output_m_scale, + vr=output_vr, + vc=output_vc, + v=output_v, + ) def inf_to_nan(self, array): """Converting Infinity values to the more sticky NaN.""" @@ -386,16 +394,9 @@ def parameter_scale(self, var): """ return jnp.maximum(reduce_rms(var), jnp.asarray(self._epsilon2, var.dtype)) - def compute_var_and_slot_update(self, - count, - grad, - m, - m_scale, - vr, - vc, - v, - param, - var_name=None): + def compute_var_and_slot_update( + self, count, grad, m, m_scale, vr, vc, v, param, var_name=None + ): """Computes the var and optimizer slots updates for a single variable.""" # We can probably skip this step grad = grad.astype(jnp.float32) @@ -434,7 +435,7 @@ def compute_var_and_slot_update(self, update_scale += grad_squared_mean * 1e-30 # END HACK - mixing_rate = 1. - decay_rate + mixing_rate = 1.0 - decay_rate shape = param.shape output_m = jnp.zeros((1,)) @@ -449,18 +450,23 @@ def compute_var_and_slot_update(self, # reduce_mean(). vr_axis, vc_axis = factored_second_moment_dims grad_squared_row_mean = self.inf_to_nan( - jnp.mean(grad_squared, axis=vr_axis)) + jnp.mean(grad_squared, axis=vr_axis) + ) grad_squared_col_mean = self.inf_to_nan( - jnp.mean(grad_squared, axis=vc_axis)) + jnp.mean(grad_squared, axis=vc_axis) + ) new_vr = decay_rate * vr + mixing_rate * grad_squared_row_mean new_vc = decay_rate * vc + mixing_rate * grad_squared_col_mean output_vr = new_vr output_vc = new_vc long_term_mean = jnp.mean(new_vr, axis=-1, keepdims=True) - r_factor = 1. / jnp.sqrt(new_vr / long_term_mean) - c_factor = 1. / jnp.sqrt(new_vc) - x = grad * jnp.expand_dims(r_factor, vr_axis) * jnp.expand_dims( - c_factor, vc_axis) + r_factor = 1.0 / jnp.sqrt(new_vr / long_term_mean) + c_factor = 1.0 / jnp.sqrt(new_vc) + x = ( + grad + * jnp.expand_dims(r_factor, vr_axis) + * jnp.expand_dims(c_factor, vc_axis) + ) else: # v with sharding annotation. new_v = decay_rate * v + mixing_rate * grad_squared @@ -468,7 +474,7 @@ def compute_var_and_slot_update(self, x = grad / jnp.sqrt(new_v) if self._clip_threshold is not None: - clipping_denom = jnp.maximum(1., reduce_rms(x) / self._clip_threshold) + clipping_denom = jnp.maximum(1.0, reduce_rms(x) / self._clip_threshold) clipping_denom = self.inf_to_nan(clipping_denom) x /= clipping_denom @@ -481,7 +487,7 @@ def compute_var_and_slot_update(self, m = to_float(m, m_scale) if self._nesterov: subtrahend_original = subtrahend - subtrahend = self._beta1 * m + (1. - self._beta1) * subtrahend + subtrahend = self._beta1 * m + (1.0 - self._beta1) * subtrahend subtrahend = self.inf_to_nan(subtrahend) if self._quantized_dtype == jnp.bfloat16: new_m = subtrahend.astype(jnp.bfloat16) @@ -496,8 +502,8 @@ def compute_var_and_slot_update(self, if self._nesterov: subtrahend = ( - self._beta1 * subtrahend + - (1.0 - self._beta1) * subtrahend_original) + self._beta1 * subtrahend + (1.0 - self._beta1) * subtrahend_original + ) if self._weight_decay is not None: # Apply decoupled weight decay to be consistent with AdamW. @@ -527,43 +533,45 @@ def compute_var_and_slot_update(self, g_norm = reduce_rms(subtrahend / update_scale) + self._epsilon1 ratio = w_norm / g_norm ratio = jnp.where( - jnp.greater(w_norm, 0), - jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0), - 1.0) + jnp.greater(w_norm, 0), + jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0), + 1.0, + ) subtrahend *= ratio return _ShardedAdafactorUpdateResult( - update=-subtrahend, - m=output_m, - m_scale=output_m_scale, - vr=output_vr, - vc=output_vc, - v=output_v) + update=-subtrahend, + m=output_m, + m_scale=output_m_scale, + vr=output_vr, + vc=output_vc, + v=output_v, + ) def sharded_adafactor( - learning_rate: optax.Schedule, - weight_decay: Optional[Union[float, Dict[str, float]]] = None, - layerwise_adaptation: bool = False, - decay_method: str = 'adam', - decay_adam: float = 0.99, - decay_pow: float = 0., - beta1: float = 0.9, - clip_threshold: Optional[float] = 1., - factored: bool = True, - epsilon1_grad_sq_reg: float = 1e-30, - quantized_dtype: jnp.dtype = jnp.int8, - respect_skip_lp_regularization: bool = False, - exclude_from_layerwise_adaptation: Optional[List[str]] = None, - per_var_learning_summary: bool = False, - sort_factored_second_moment_dims: bool = False, - # min_dim_size_to_factor is only used when - # sort_factored_second_moment_dims=True. - min_dim_size_to_factor: int = 128, - multiply_by_parameter_scale: bool = False, - epsilon2_param_scale_reg: float = 1e-3, - maybe_inf_to_nan: bool = True, - nesterov: bool = False, + learning_rate: optax.Schedule, + weight_decay: Optional[Union[float, Dict[str, float]]] = None, + layerwise_adaptation: bool = False, + decay_method: str = 'adam', + decay_adam: float = 0.99, + decay_pow: float = 0.0, + beta1: float = 0.9, + clip_threshold: Optional[float] = 1.0, + factored: bool = True, + epsilon1_grad_sq_reg: float = 1e-30, + quantized_dtype: jnp.dtype = jnp.int8, + respect_skip_lp_regularization: bool = False, + exclude_from_layerwise_adaptation: Optional[List[str]] = None, + per_var_learning_summary: bool = False, + sort_factored_second_moment_dims: bool = False, + # min_dim_size_to_factor is only used when + # sort_factored_second_moment_dims=True. + min_dim_size_to_factor: int = 128, + multiply_by_parameter_scale: bool = False, + epsilon2_param_scale_reg: float = 1e-3, + maybe_inf_to_nan: bool = True, + nesterov: bool = False, ) -> optax.GradientTransformation: """AdaFactor optimizer that supports SPMD sharding. @@ -638,53 +646,60 @@ def sharded_adafactor( assert decay_pow >= 0 assert learning_rate is not None assert decay_method == 'adam' or decay_method == 'pow', ( - f'decay_method: {decay_method} not supported. Supported methods are ' - '"pow", or "adam".') + f'decay_method: {decay_method} not supported. Supported methods are ' + '"pow", or "adam".' + ) sharded_adafactor_helper = _ShardedAdafactorHelper( - learning_rate=learning_rate, - weight_decay=weight_decay, - layerwise_adaptation=layerwise_adaptation, - decay_method=decay_method, - decay_adam=decay_adam, - decay_pow=decay_pow, - beta1=beta1, - clip_threshold=clip_threshold, - factored=factored, - epsilon1_grad_sq_reg=epsilon1_grad_sq_reg, - quantized_dtype=quantized_dtype, - respect_skip_lp_regularization=respect_skip_lp_regularization, - exclude_from_layerwise_adaptation=exclude_from_layerwise_adaptation, - per_var_learning_summary=per_var_learning_summary, - sort_factored_second_moment_dims=sort_factored_second_moment_dims, - min_dim_size_to_factor=min_dim_size_to_factor, - multiply_by_parameter_scale=multiply_by_parameter_scale, - epsilon2_param_scale_reg=epsilon2_param_scale_reg, - maybe_inf_to_nan=maybe_inf_to_nan, - nesterov=nesterov) + learning_rate=learning_rate, + weight_decay=weight_decay, + layerwise_adaptation=layerwise_adaptation, + decay_method=decay_method, + decay_adam=decay_adam, + decay_pow=decay_pow, + beta1=beta1, + clip_threshold=clip_threshold, + factored=factored, + epsilon1_grad_sq_reg=epsilon1_grad_sq_reg, + quantized_dtype=quantized_dtype, + respect_skip_lp_regularization=respect_skip_lp_regularization, + exclude_from_layerwise_adaptation=exclude_from_layerwise_adaptation, + per_var_learning_summary=per_var_learning_summary, + sort_factored_second_moment_dims=sort_factored_second_moment_dims, + min_dim_size_to_factor=min_dim_size_to_factor, + multiply_by_parameter_scale=multiply_by_parameter_scale, + epsilon2_param_scale_reg=epsilon2_param_scale_reg, + maybe_inf_to_nan=maybe_inf_to_nan, + nesterov=nesterov, + ) def init_fn(params): """Initializes the optimizer's state.""" return sharded_adafactor_helper.to_state( - jnp.zeros([], jnp.int32), - jax.tree.map(sharded_adafactor_helper.init, params)) + jnp.zeros([], jnp.int32), + jax.tree.map(sharded_adafactor_helper.init, params), + ) def update_fn(updates, state, params=None): if params is None: raise ValueError( - 'You are using a transformation that requires the current value of ' - 'parameters, but you are not passing `params` when calling `update`.') + 'You are using a transformation that requires the current value of ' + 'parameters, but you are not passing `params` when calling `update`.' + ) compute_var_and_slot_update_fn = functools.partial( - sharded_adafactor_helper.compute_var_and_slot_update, state.count) - output = jax.tree.map(compute_var_and_slot_update_fn, - updates, - state.m, - state.m_scale, - state.vr, - state.vc, - state.v, - params) + sharded_adafactor_helper.compute_var_and_slot_update, state.count + ) + output = jax.tree.map( + compute_var_and_slot_update_fn, + updates, + state.m, + state.m_scale, + state.vr, + state.vc, + state.v, + params, + ) updates = jax.tree.map(lambda o: o.update, output) count_plus_one = state.count + jnp.array(1, jnp.int32) updated_states = sharded_adafactor_helper.to_state(count_plus_one, output) diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 1833ab8af..898de35eb 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -10,17 +10,20 @@ import optax from algoperf import spec -from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import \ - sharded_adafactor +from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import ( + sharded_adafactor, +) _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an Adafactor optimizer and a learning rate schedule.""" del model_params del model_state @@ -30,99 +33,113 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = sharded_adafactor( - learning_rate=lr_schedule_fn, - beta1=1.0 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + beta1=1.0 - hyperparameters.one_minus_beta1, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -139,37 +156,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -205,14 +228,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 7aa457a25..dd831566f 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -16,36 +16,40 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an Adafactor optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. optimizer_state = { - 'optimizer': - Adafactor( - model_params.parameters(), - lr=hyperparameters.learning_rate, - beta1=1 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay), + 'optimizer': Adafactor( + model_params.parameters(), + lr=hyperparameters.learning_rate, + beta1=1 - hyperparameters.one_minus_beta1, + weight_decay=hyperparameters.weight_decay, + ), } optimizer = optimizer_state['optimizer'] warmup = LinearLR( - optimizer, - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_steps) + optimizer, + start_factor=1e-10, + end_factor=1.0, + total_iters=hyperparameters.warmup_steps, + ) cosine_steps = max(workload.step_hint - hyperparameters.warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) optimizer_state['scheduler'] = SequentialLR( - optimizer, - schedulers=[warmup, cosine_decay], - milestones=[hyperparameters.warmup_steps]) + optimizer, + schedulers=[warmup, cosine_decay], + milestones=[hyperparameters.warmup_steps], + ) return optimizer_state @@ -54,56 +58,56 @@ class Adafactor(torch.optim.Optimizer): src/transformers/optimization.py#L386""" def __init__( - self, - params, - lr=None, - beta1=0.9, - decay_adam=0.99, - weight_decay=0.0, + self, + params, + lr=None, + beta1=0.9, + decay_adam=0.99, + weight_decay=0.0, ): defaults = dict( - lr=lr, - beta1=beta1, - decay_adam=decay_adam, - weight_decay=weight_decay, - decay_pow=0.0, - layerwise_adaptation=False, - decay_method='adam', - clip_threshold=1.0, - factored=True, - epsilon1_grad_sq_reg=1e-30, - respect_skip_lp_regularization=False, - exclude_from_layerwise_adaptation=None, - per_var_learning_summary=False, - sort_factored_second_moment_dims=False, - # Unused because sort_factored_second_moment_dims=False. - min_dim_size_to_factor=128, - multiply_by_parameter_scale=False, - # Unused because multiply_by_parameter_scale=False. - epsilon2_param_scale_reg=1e-3, - maybe_inf_to_nan=True, + lr=lr, + beta1=beta1, + decay_adam=decay_adam, + weight_decay=weight_decay, + decay_pow=0.0, + layerwise_adaptation=False, + decay_method='adam', + clip_threshold=1.0, + factored=True, + epsilon1_grad_sq_reg=1e-30, + respect_skip_lp_regularization=False, + exclude_from_layerwise_adaptation=None, + per_var_learning_summary=False, + sort_factored_second_moment_dims=False, + # Unused because sort_factored_second_moment_dims=False. + min_dim_size_to_factor=128, + multiply_by_parameter_scale=False, + # Unused because multiply_by_parameter_scale=False. + epsilon2_param_scale_reg=1e-3, + maybe_inf_to_nan=True, ) super().__init__(params, defaults) def inf_to_nan(self, group, x): - if group["maybe_inf_to_nan"]: + if group['maybe_inf_to_nan']: x = torch.nan_to_num(x, nan=torch.nan, posinf=torch.nan, neginf=torch.nan) return x def step(self, closure=None): """ - Performs a single optimization step - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ + Performs a single optimization step + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ loss = None if closure is not None: loss = closure() for group in self.param_groups: inf_to_nan = partial(self.inf_to_nan, group) - for p in group["params"]: + for p in group['params']: if p.grad is None: continue grad = p.grad.data @@ -111,7 +115,7 @@ def step(self, closure=None): if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: - raise RuntimeError("Adafactor does not support sparse gradients.") + raise RuntimeError('Adafactor does not support sparse gradients.') state = self.state[p] grad_shape = grad.shape @@ -120,51 +124,54 @@ def step(self, closure=None): # State Initialization if len(state) == 0: - state["step"] = 0 - state["exp_avg"] = torch.zeros_like(grad) + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(grad) if factored: - state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) - state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + - grad_shape[-1:]).to(grad) + state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).to(grad) else: - state["exp_avg_sq"] = torch.zeros_like(grad) + state['exp_avg_sq'] = torch.zeros_like(grad) else: - state["exp_avg"] = state["exp_avg"].to(grad) + state['exp_avg'] = state['exp_avg'].to(grad) if factored: - state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) - state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) else: - state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) p_data_fp32 = p.data if p.data.dtype in {torch.float16, torch.bfloat16}: p_data_fp32 = p_data_fp32.float() - state["step"] += 1 - lr = group["lr"] - beta1 = group["beta1"] - beta2 = group["decay_adam"] + state['step'] += 1 + lr = group['lr'] + beta1 = group['beta1'] + beta2 = group['decay_adam'] - t = state["step"] - beta2t = beta2 * (1. - beta2**(t - 1.)) / (1. - beta2**t) + t = state['step'] + beta2t = beta2 * (1.0 - beta2 ** (t - 1.0)) / (1.0 - beta2**t) - exp_avg_sq_update = (grad**2) + group["epsilon1_grad_sq_reg"] + exp_avg_sq_update = (grad**2) + group['epsilon1_grad_sq_reg'] if factored: - exp_avg_sq_row = state["exp_avg_sq_row"] - exp_avg_sq_col = state["exp_avg_sq_col"] + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] exp_avg_sq_row.mul_(beta2t).add_( - exp_avg_sq_update.mean(dim=-1), alpha=1.0 - beta2t) + exp_avg_sq_update.mean(dim=-1), alpha=1.0 - beta2t + ) exp_avg_sq_col.mul_(beta2t).add_( - exp_avg_sq_update.mean(dim=-2), alpha=1.0 - beta2t) + exp_avg_sq_update.mean(dim=-2), alpha=1.0 - beta2t + ) r_factor = inf_to_nan( - exp_avg_sq_row / - exp_avg_sq_row.mean(dim=-1, keepdim=True)).unsqueeze(-1) + exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True) + ).unsqueeze(-1) c_factor = inf_to_nan(exp_avg_sq_col).unsqueeze(-2) denom = r_factor * c_factor else: - exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq = state['exp_avg_sq'] exp_avg_sq.mul_(beta2t).add_(exp_avg_sq_update, alpha=1.0 - beta2t) denom = exp_avg_sq @@ -172,15 +179,16 @@ def step(self, closure=None): denom = denom.sqrt() update = grad / denom # Clip the update based on RMS. - clipping_denom = inf_to_nan(torch.square(update).mean().sqrt() \ - /group["clip_threshold"]).clamp(min=1.0) + clipping_denom = inf_to_nan( + torch.square(update).mean().sqrt() / group['clip_threshold'] + ).clamp(min=1.0) update = update / clipping_denom * lr # Momentum - exp_avg = state["exp_avg"] + exp_avg = state['exp_avg'] exp_avg.mul_(beta1).add_(update, alpha=1 - beta1) - if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * lr) + if group['weight_decay'] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr) p_data_fp32.add_(-exp_avg) @@ -191,18 +199,19 @@ def step(self, closure=None): def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -214,22 +223,26 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -243,12 +256,14 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -256,28 +271,34 @@ def update_params( if global_step <= 100 or global_step % 500 == 0: if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -314,14 +335,15 @@ def get_batch_size(workload_name): def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index dde41fa6d..52c8d5ee2 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -14,11 +14,13 @@ _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -28,101 +30,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -139,37 +155,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -209,14 +231,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 21d9b6b57..faefcd254 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -15,55 +15,60 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - torch.optim.AdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay, - fused=False), + 'optimizer': torch.optim.AdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + fused=False, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = hyperparameters.warmup_factor * step_hint warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -75,22 +80,26 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -104,7 +113,8 @@ def update_params( if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -113,31 +123,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -177,14 +194,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 70e305514..eedbbfc37 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -21,11 +21,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a LAMB optimizer and a learning rate schedule.""" del model_params del model_state @@ -35,61 +37,70 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = optax.lamb( - learning_rate=lr_schedule_fn, - b1=1 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] @@ -97,40 +108,45 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -147,37 +163,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -213,14 +235,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index c1c6cec0a..5cda59d6f 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -15,13 +15,9 @@ # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py class LAMB(torch.optim.Optimizer): - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0.0): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -39,7 +35,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -74,48 +71,53 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) lamb( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def lamb(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float): - +def lamb( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +): if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -147,61 +149,67 @@ def lamb(params: List[Tensor], update_norm = torch.linalg.norm(update) # Set trust_ratio to 1 in case where parameters would never be updated. - if param_norm == 0. or update_norm == 0.: - trust_ratio = 1. + if param_norm == 0.0 or update_norm == 0.0: + trust_ratio = 1.0 else: trust_ratio = param_norm / update_norm param.add_(update, alpha=-lr * trust_ratio) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a LAMB optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - LAMB( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay) + 'optimizer': LAMB( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -213,31 +221,36 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss, _ = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) loss.backward() if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -246,31 +259,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -306,14 +326,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index cbb6d6dcd..3540f9415 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -14,11 +14,13 @@ _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_params del model_state @@ -28,34 +30,39 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = sgd( - learning_rate=lr_schedule_fn, - weight_decay=hyperparameters.weight_decay, - momentum=1.0 - hyperparameters.one_minus_beta1, - nesterov=False) + learning_rate=lr_schedule_fn, + weight_decay=hyperparameters.weight_decay, + momentum=1.0 - hyperparameters.one_minus_beta1, + nesterov=False, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) decay_steps = step_hint - warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps] + ) return lr_schedule_fn @@ -82,81 +89,92 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): {SGD, Momentum, Nesterov} update. """ return optax.chain( - optax.add_decayed_weights(weight_decay), - optax.sgd( - learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) + optax.add_decayed_weights(weight_decay), + optax.sgd( + learning_rate=learning_rate, momentum=momentum, nesterov=nesterov + ), + ) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -173,37 +191,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -243,14 +267,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index c3760d20e..bad750857 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -14,24 +14,26 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=hyperparameters.learning_rate, - momentum=1.0 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay, - nesterov=False), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=hyperparameters.learning_rate, + momentum=1.0 - hyperparameters.one_minus_beta1, + weight_decay=hyperparameters.weight_decay, + nesterov=False, + ), } # Create learning rate schedule. @@ -43,43 +45,48 @@ def _lr_lambda(step: int) -> float: return lr_schedule_fn(step).item() / hyperparameters.learning_rate optimizer_state['scheduler'] = LambdaLR( - optimizer_state['optimizer'], lr_lambda=_lr_lambda) + optimizer_state['optimizer'], lr_lambda=_lr_lambda + ) return optimizer_state def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) decay_steps = step_hint - warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps] + ) return lr_schedule_fn def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -91,26 +98,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -123,7 +134,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -132,31 +144,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -196,14 +215,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index c451a18ac..aa1a08f69 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -4,15 +4,17 @@ # isort: off # We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) # isort: on import chex @@ -30,15 +32,14 @@ # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -73,19 +74,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -124,7 +128,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -132,6 +137,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. mu: optax.Updates nu: optax.Updates @@ -140,7 +146,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -156,11 +163,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -170,101 +179,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -281,37 +304,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -351,14 +380,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index a2f9fb4c5..ecd299988 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -21,33 +21,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -59,7 +56,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -67,7 +67,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -76,9 +77,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -107,51 +108,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -189,54 +196,59 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -248,26 +260,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -280,7 +296,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -289,31 +306,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -353,14 +377,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 0e53aae42..d32026212 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -14,11 +14,13 @@ _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_params del model_state @@ -28,34 +30,39 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = sgd( - learning_rate=lr_schedule_fn, - weight_decay=hyperparameters.weight_decay, - momentum=1.0 - hyperparameters.one_minus_beta1, - nesterov=True) + learning_rate=lr_schedule_fn, + weight_decay=hyperparameters.weight_decay, + momentum=1.0 - hyperparameters.one_minus_beta1, + nesterov=True, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) decay_steps = step_hint - warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps] + ) return lr_schedule_fn @@ -82,81 +89,92 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): {SGD, Momentum, Nesterov} update. """ return optax.chain( - optax.add_decayed_weights(weight_decay), - optax.sgd( - learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) + optax.add_decayed_weights(weight_decay), + optax.sgd( + learning_rate=learning_rate, momentum=momentum, nesterov=nesterov + ), + ) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -173,37 +191,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -243,14 +267,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index b4432fbff..77361f472 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -14,24 +14,26 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=hyperparameters.learning_rate, - momentum=1.0 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay, - nesterov=True), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=hyperparameters.learning_rate, + momentum=1.0 - hyperparameters.one_minus_beta1, + weight_decay=hyperparameters.weight_decay, + nesterov=True, + ), } # Create learning rate schedule. @@ -43,43 +45,48 @@ def _lr_lambda(step: int) -> float: return lr_schedule_fn(step).item() / hyperparameters.learning_rate optimizer_state['scheduler'] = LambdaLR( - optimizer_state['optimizer'], lr_lambda=_lr_lambda) + optimizer_state['optimizer'], lr_lambda=_lr_lambda + ) return optimizer_state def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) decay_steps = step_hint - warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps] + ) return lr_schedule_fn def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -91,26 +98,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -123,7 +134,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -132,31 +144,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -196,14 +215,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index b76589705..ce6db3ac3 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -23,7 +23,8 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray: y: A pytree of numpy ndarray, vector y in the equation above. """ gradient_norm = jnp.sqrt( - sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y))) + sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)) + ) normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) return normalized_gradient @@ -31,11 +32,11 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray: # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/ # sharpness_aware_minimization.py def sharpness_aware_minimization( - rho: float, - grad_clip: Optional[float], - batch_axis_name: str, - base_opt_init_fn, - base_opt_update_fn, + rho: float, + grad_clip: Optional[float], + batch_axis_name: str, + base_opt_init_fn, + base_opt_update_fn, ) -> optax.GradientTransformation: """Implementation of Sharpness Aware Minimization (SAM). Paper: https://arxiv.org/abs/2010.01412 @@ -68,22 +69,28 @@ def update_fn(updates, state, grad_fn_params_tuple): # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) noised_params = jax.tree_util.tree_map( - lambda p, u: p + rho * u, params, updates) + lambda p, u: p + rho * u, params, updates + ) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. - (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), - axis_name=batch_axis_name) + (n_valid_examples, updates) = lax.psum( + (n_valid_examples, updates), axis_name=batch_axis_name + ) updates = jax.tree.map(lambda x: x / n_valid_examples, updates) if grad_clip: updates_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates)) + ) scaled_updates = jax.tree.map( - lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, - lambda _: scaled_updates, - lambda _: updates, - None) + lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates + ) + updates = jax.lax.cond( + updates_norm > grad_clip, + lambda _: scaled_updates, + lambda _: updates, + None, + ) updates, state = base_opt_update_fn(updates, state, params) return updates, state @@ -91,11 +98,13 @@ def update_fn(updates, state, grad_fn_params_tuple): return optax.GradientTransformation(init_fn, update_fn) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a SAM optimizer (with AdamW base) and a learning rate schedule.""" del model_params del model_state @@ -105,111 +114,127 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create base optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) # Create SAM update fn. grad_clip = ( - hyperparameters.grad_clip - if hasattr(hyperparameters, 'grad_clip') else None) + hyperparameters.grad_clip if hasattr(hyperparameters, 'grad_clip') else None + ) opt_init_fn, opt_update_fn = sharpness_aware_minimization( - rho=hyperparameters.rho, - grad_clip=grad_clip, - batch_axis_name='batch', - base_opt_init_fn=opt_init_fn, - base_opt_update_fn=opt_update_fn) + rho=hyperparameters.rho, + grad_clip=grad_clip, + batch_axis_name='batch', + base_opt_init_fn=opt_init_fn, + base_opt_update_fn=opt_update_fn, + ) # Initialize optimizer state. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params, update_batch_norm=True): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=update_batch_norm) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=update_batch_norm, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) second_grad_fn = jax.value_and_grad( - functools.partial(_loss_fn, update_batch_norm=False), has_aux=True) + functools.partial(_loss_fn, update_batch_norm=False), has_aux=True + ) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, (second_grad_fn, current_param_container)) + grad, optimizer_state, (second_grad_fn, current_param_container) + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -226,37 +251,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -293,14 +324,15 @@ def get_batch_size(workload_name): def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 92603f036..fdd4eb8b7 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -17,13 +17,14 @@ # Modified from https://github.com/davda54/sam. class SAM(torch.optim.Optimizer): - - def __init__(self, - params: spec.ParameterContainer, - base_optimizer: torch.optim.Optimizer, - rho: float = 0.05, - adaptive: bool = False, - **kwargs): + def __init__( + self, + params: spec.ParameterContainer, + base_optimizer: torch.optim.Optimizer, + rho: float = 0.05, + adaptive: bool = False, + **kwargs, + ): if rho < 0.0: raise ValueError(f'Invalid rho, should be non-negative: {rho}') @@ -79,12 +80,18 @@ def _grad_norm(self): # In case of model parallelism, put everything on the same device. shared_device = self.param_groups[0]['params'][0].device norm = torch.norm( - torch.stack([((torch.abs(p) if group['adaptive'] else 1.0) * - p.grad).norm(p=2).to(shared_device) - for group in self.param_groups - for p in group['params'] - if p.grad is not None]), - p=2) + torch.stack( + [ + ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad) + .norm(p=2) + .to(shared_device) + for group in self.param_groups + for p in group['params'] + if p.grad is not None + ] + ), + p=2, + ) return norm def load_state_dict(self, state_dict: Dict): @@ -92,11 +99,13 @@ def load_state_dict(self, state_dict: Dict): self.base_optimizer.param_groups = self.param_groups -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_state del rng @@ -104,46 +113,50 @@ def init_optimizer_state(workload: spec.Workload, # Create SAM optimizer with AdamW base. base_optimizer = torch.optim.AdamW optimizer_state = { - 'optimizer': - SAM(model_params.parameters(), - base_optimizer=base_optimizer, - rho=hyperparameters.rho, - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': SAM( + model_params.parameters(), + base_optimizer=base_optimizer, + rho=hyperparameters.rho, + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) # Create learning rate schedule. optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -156,20 +169,24 @@ def update_params( def _loss_fn(params, update_batch_norm=True): """Loss function used for training.""" logits_batch, new_model_state = workload.model_fn( - params=params, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=update_batch_norm) + params=params, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=update_batch_norm, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -187,7 +204,8 @@ def _loss_fn(params, update_batch_norm=True): with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) optimizer_state['optimizer'].first_step(zero_grad=True) @@ -198,7 +216,8 @@ def _loss_fn(params, update_batch_norm=True): if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].second_step(zero_grad=True) optimizer_state['scheduler'].step() @@ -206,29 +225,34 @@ def _loss_fn(params, update_batch_norm=True): if global_step <= 100 or global_step % 500 == 0: if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': logging_loss.item(), - 'grad_norm': grad_norm.item(), - }, - global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - logging_loss.item(), - grad_norm.item()) + { + 'loss': logging_loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + logging_loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -265,14 +289,15 @@ def get_batch_size(workload_name): def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index a5c2732ac..830dd4816 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -61,13 +61,16 @@ @struct.dataclass class QuantizedValue: """State associated with quantized value.""" + quantized: chex.Array diagonal: chex.Array # Diagonal (if extract_diagonal is set) bucket_size: chex.Array quantized_dtype: jnp.dtype = struct.field( - pytree_node=False) # Dtype for the quantized value. + pytree_node=False + ) # Dtype for the quantized value. extract_diagonal: bool = struct.field( - pytree_node=False) # In case its centered. + pytree_node=False + ) # In case its centered. shape: Any = struct.field(pytree_node=False) # Shape of the tensor. @classmethod @@ -75,13 +78,16 @@ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False): if isinstance(fvalue, list) and not fvalue: return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, []) quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize( - fvalue, quantized_dtype, extract_diagonal) - return QuantizedValue(quantized, - diagonal_fvalue, - bucket_size, - quantized_dtype, - extract_diagonal, - list(quantized.shape)) + fvalue, quantized_dtype, extract_diagonal + ) + return QuantizedValue( + quantized, + diagonal_fvalue, + bucket_size, + quantized_dtype, + extract_diagonal, + list(quantized.shape), + ) # Quantization is from Lingvo JAX optimizers. # We extend it for int16 quantization of PSD matrices. @@ -106,7 +112,8 @@ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): if extract_diagonal and fvalue.ndim != 2: raise ValueError( - f'Input array {fvalue} must be 2D to work with extract_diagonal.') + f'Input array {fvalue} must be 2D to work with extract_diagonal.' + ) diagonal_fvalue = [] if extract_diagonal: @@ -119,16 +126,17 @@ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): # We first decide the scale. if fvalue.ndim < 1: raise ValueError( - f'Input array {fvalue} must have a strictly positive number of ' - 'dimensions.') + f'Input array {fvalue} must have a strictly positive number of ' + 'dimensions.' + ) max_abs = jnp.max(jnp.abs(fvalue), axis=0) bucket_size = max_abs / num_buckets bs_expanded = bucket_size[jnp.newaxis, Ellipsis] # To avoid divide by 0.0 - bs_nonzero = jnp.where(bs_expanded > 0.0, - bs_expanded, - jnp.ones_like(bs_expanded)) + bs_nonzero = jnp.where( + bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded) + ) ratio = fvalue / bs_nonzero # We use rounding to remove bias. quantized = jnp.round(ratio) @@ -155,10 +163,11 @@ def to_float(self): def _default_zero_field(): return struct.field( - default_factory=functools.partial(jnp.array, 0, jnp.float32)) + default_factory=functools.partial(jnp.array, 0, jnp.float32) + ) -T = TypeVar("T") +T = TypeVar('T') def _maybe_ix(ls, ix): @@ -180,17 +189,19 @@ def wrap_f(x, *args, **kwargs): InversePthRootDiagnosticsSubtype = TypeVar( - "InversePthRootDiagnosticsSubtype", bound="InversePthRootDiagnostics") + 'InversePthRootDiagnosticsSubtype', bound='InversePthRootDiagnostics' +) @struct.dataclass class InversePthRootDiagnostics: """Diagnostics for inverse p-th root iterative procedure. - Given an inverse pth root B = A^(-1/p), contains the average and - maximum diagonal and off diagonal absolute entrywise errors between - (B^p A) and I. - """ + Given an inverse pth root B = A^(-1/p), contains the average and + maximum diagonal and off diagonal absolute entrywise errors between + (B^p A) and I. + """ + max_diag_error: chex.Array = _default_zero_field() avg_diag_error: chex.Array = _default_zero_field() max_off_diag_error: chex.Array = _default_zero_field() @@ -201,35 +212,41 @@ class InversePthRootDiagnostics: def create(cls, pth_inverse_root, matrix, p): """Generates a diagnostics struct from (-1/p) root result.""" mat_m = jnp.matmul( - mat_power(pth_inverse_root, p), - matrix, - precision=jax.lax.Precision.HIGHEST) + mat_power(pth_inverse_root, p), + matrix, + precision=jax.lax.Precision.HIGHEST, + ) num_off_diag_entries = mat_m.size - jnp.diag(mat_m).size diag_error = jnp.abs(jnp.diag(mat_m) - 1).astype(jnp.float32) off_diag_error = jnp.abs(mat_m - jnp.diag(jnp.diag(mat_m))).astype( - jnp.float32) + jnp.float32 + ) return cls( - max_diag_error=jnp.max(diag_error).astype(jnp.float32), - avg_diag_error=jnp.mean(diag_error).astype(jnp.float32), - max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32), - avg_off_diag_error=(jnp.sum(off_diag_error) / - num_off_diag_entries).astype(jnp.float32), - p=jnp.array(p, jnp.float32)) + max_diag_error=jnp.max(diag_error).astype(jnp.float32), + avg_diag_error=jnp.mean(diag_error).astype(jnp.float32), + max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32), + avg_off_diag_error=( + jnp.sum(off_diag_error) / num_off_diag_entries + ).astype(jnp.float32), + p=jnp.array(p, jnp.float32), + ) LOBPCGDiagnosticsSubtype = TypeVar( - "LOBPCGDiagnosticsSubtype", bound="LOBPCGDiagnostics") + 'LOBPCGDiagnosticsSubtype', bound='LOBPCGDiagnostics' +) @struct.dataclass class LOBPCGDiagnostics: """Diagnostics for iterative LOBPCG eigenvalue routine. - Contains consistency error for LOBPCG eigenvalue routine, which - refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair - (v, lambda). This metics dataclass retains consistency error - and other useful LOBPCG values. - """ + Contains consistency error for LOBPCG eigenvalue routine, which + refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair + (v, lambda). This metics dataclass retains consistency error + and other useful LOBPCG values. + """ + lobpcg_iters: chex.Array = _default_zero_field() max_consistency_error: chex.Array = _default_zero_field() avg_consistency_error: chex.Array = _default_zero_field() @@ -248,7 +265,8 @@ def create(cls, matrix, eigvals, eigvecs, lobpcg_iters): mat_eigvecs = matrix.dot(eigvecs, precision=precision) consistency_error_unnormalized = jnp.linalg.norm( - mat_eigvecs - eigvals * eigvecs, ord=2, axis=0) + mat_eigvecs - eigvals * eigvecs, ord=2, axis=0 + ) normalization = jnp.linalg.norm(mat_eigvecs, ord=2, axis=0) + eigvals consistency_error = consistency_error_unnormalized / normalization @@ -256,20 +274,22 @@ def create(cls, matrix, eigvals, eigvecs, lobpcg_iters): orthogonality_error -= jnp.diag(jnp.diag(orthogonality_error)) return cls( - lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32), - max_consistency_error=jnp.max(consistency_error).astype(jnp.float32), - avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32), - avg_orthogonality_error=(jnp.sum(orthogonality_error) / - num_off_diag).astype(jnp.float32), - max_eigenvalue=jnp.max(eigvals).astype(jnp.float32), - min_eigenvalue=jnp.min(eigvals).astype(jnp.float32), - num_topk_eigenvectors=jnp.array(num_topk, jnp.float32), + lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32), + max_consistency_error=jnp.max(consistency_error).astype(jnp.float32), + avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32), + avg_orthogonality_error=( + jnp.sum(orthogonality_error) / num_off_diag + ).astype(jnp.float32), + max_eigenvalue=jnp.max(eigvals).astype(jnp.float32), + min_eigenvalue=jnp.min(eigvals).astype(jnp.float32), + num_topk_eigenvectors=jnp.array(num_topk, jnp.float32), ) @struct.dataclass class TrainingMetrics: """Diagnostic metrics from training.""" + # Error for inverse-pth roots. inverse_pth_root_errors: chex.Array = _default_zero_field() # Iteration count for inverse-pth roots. @@ -283,20 +303,24 @@ class TrainingMetrics: total_retries: chex.Array = _default_zero_field() lobpcg_diagnostics: LOBPCGDiagnostics = struct.field( - default_factory=LOBPCGDiagnostics) + default_factory=LOBPCGDiagnostics + ) # Rich matrix entrywise error diagnostics, if enabled. inverse_pth_root_diagnostics: InversePthRootDiagnostics = struct.field( - default_factory=InversePthRootDiagnostics) + default_factory=InversePthRootDiagnostics + ) # Diagnostics applied to the conditioned p-th root problem, after top # eigenvectors are removed, if LOBPCG is being applied. conditioned_inverse_pth_root_diagnostics: InversePthRootDiagnostics = ( - struct.field(default_factory=InversePthRootDiagnostics)) + struct.field(default_factory=InversePthRootDiagnostics) + ) # TODO(rohananil): Add more important metrics to track during training. # Per parameter optimizer state used in data-parallel training. class ParameterStats(NamedTuple): """State associated to each parameter of the model being trained.""" + diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner statistics: Optional[List[Any]] # Statistics (QuantizedValue, chex.Array) preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array) @@ -321,12 +345,14 @@ class GlobalShardedParameterStats: @struct.dataclass class LocalShardedParameterStats: """State associated to each parameter of the model being trained.""" + diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner momentum: QuantizedValue # Momentum for the shampoo preconditioner training_metrics: Union[TrainingMetrics, optax.MaskedNode] index_start: Union[np.int32, int] = struct.field( - pytree_node=False) # Index into global statistics array + pytree_node=False + ) # Index into global statistics array sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics. @@ -336,39 +362,44 @@ def default_training_metrics(): def init_training_metrics( - num_statistics, - generate_training_metrics, + num_statistics, + generate_training_metrics, ): """Initialize TrainingMetrics, masked if disabled.""" if not generate_training_metrics: return optax.MaskedNode() return jax.tree.map( - functools.partial(jnp.repeat, repeats=num_statistics), - default_training_metrics()) + functools.partial(jnp.repeat, repeats=num_statistics), + default_training_metrics(), + ) def init_training_metrics_shapes( - num_statistics, - generate_training_metrics, + num_statistics, + generate_training_metrics, ): """Initialize training metrics shape/dtype.""" seed = init_training_metrics( - num_statistics, - generate_training_metrics, + num_statistics, + generate_training_metrics, ) return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed) -def init_training_metrics_pspec(generate_training_metrics,): +def init_training_metrics_pspec( + generate_training_metrics, +): """Initialize training metrics partition specification.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree.map(lambda _: jax.sharding.PartitionSpec(), - default_training_metrics()) + return jax.tree.map( + lambda _: jax.sharding.PartitionSpec(), default_training_metrics() + ) class ShardedShampooStats(NamedTuple): """Shampoo state in sharded mode.""" + global_stats: Any local_stats: Any @@ -406,35 +437,35 @@ class PreconditionerType(enum.IntEnum): def power_iteration( - matrix, - num_iters=100, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST, - padding_start=None, + matrix, + num_iters=100, + error_tolerance=1e-6, + precision=lax.Precision.HIGHEST, + padding_start=None, ): r"""Power iteration algorithm. - The power iteration algorithm takes a symmetric PSD matrix `A`, and produces - a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue - of `A`, and a vector v, which is the corresponding eigenvector of `A`. - - References: - [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration) - - Args: - matrix: the symmetric PSD matrix. - num_iters: Number of iterations. - error_tolerance: Iterative exit condition. - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - padding_start: if set, assumes rows and columns after padding_start are - zero. - - Returns: - eigen vector, eigen value - """ + The power iteration algorithm takes a symmetric PSD matrix `A`, and produces + a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue + of `A`, and a vector v, which is the corresponding eigenvector of `A`. + + References: + [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration) + + Args: + matrix: the symmetric PSD matrix. + num_iters: Number of iterations. + error_tolerance: Iterative exit condition. + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + padding_start: if set, assumes rows and columns after padding_start are + zero. + + Returns: + eigen vector, eigen value + """ matrix_size = matrix.shape[-1] def _iter_condition(state): @@ -446,32 +477,38 @@ def _iter_body(state): i, new_v, s, s_v, unused_run_step = state new_v = new_v / jnp.linalg.norm(new_v) - s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision) - s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision) - return (i + 1, - s_v, - s_new, - s_v, - jnp.greater(jnp.abs(s_new - s), error_tolerance)) + s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision) + s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision) + return ( + i + 1, + s_v, + s_new, + s_v, + jnp.greater(jnp.abs(s_new - s), error_tolerance), + ) # Figure out how to use step as seed for random. - v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0, - matrix_size).astype(matrix.dtype) + v_0 = ( + np.random.RandomState(1729) + .uniform(-1.0, 1.0, matrix_size) + .astype(matrix.dtype) + ) v_0 = jnp.array(v_0) if padding_start is not None: - v_0 *= (jnp.arange(len(v_0), dtype=jnp.int32) < padding_start) + v_0 *= jnp.arange(len(v_0), dtype=jnp.int32) < padding_start init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True]) - _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, - init_state) + _, v_out, s_out, _, _ = lax.while_loop( + _iter_condition, _iter_body, init_state + ) v_out = v_out / jnp.linalg.norm(v_out) return v_out, s_out def mat_power( - mat_m, - p, - precision=lax.Precision.HIGHEST, + mat_m, + p, + precision=lax.Precision.HIGHEST, ): """A simple matrix power method. M^p where p can be TracedValue.""" power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE) @@ -483,9 +520,11 @@ def _iter_condition(state): def _iter_body(state): i, power, mat = state - power = jax.lax.cond(i % 2 == 1, - lambda: jnp.matmul(mat, power, precision=precision), - lambda: power) + power = jax.lax.cond( + i % 2 == 1, + lambda: jnp.matmul(mat, power, precision=precision), + lambda: power, + ) i //= 2 mat = jnp.matmul(mat, mat, precision=precision) return i, power, mat @@ -508,78 +547,81 @@ def _stable_subtract(b, a_minus_b): return (b**exp) * jnp.expm1(exp * jnp.log1p(a_minus_b / b)) return jnp.where( - # Choose the branch with the best log1p approximation. - jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a), - -_stable_subtract(a, -a_minus_b), - _stable_subtract(b, a_minus_b)) + # Choose the branch with the best log1p approximation. + jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a), + -_stable_subtract(a, -a_minus_b), + _stable_subtract(b, a_minus_b), + ) def matrix_inverse_pth_root( - matrix, - p, - num_iters=100, - ridge_epsilon=1e-6, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST, - relative_matrix_epsilon=True, - lobpcg_topk_precondition=0, - lobpcg_max_iter=0, - padding_start=None, - prev=None, - eigh=False, + matrix, + p, + num_iters=100, + ridge_epsilon=1e-6, + error_tolerance=1e-6, + precision=lax.Precision.HIGHEST, + relative_matrix_epsilon=True, + lobpcg_topk_precondition=0, + lobpcg_max_iter=0, + padding_start=None, + prev=None, + eigh=False, ): """Computes `matrix^(-1/p)`, where `p` is a positive integer. - This function uses the Eigh or Coupled newton iterations algorithm for - the computation of a matrix's inverse pth root. - - - References: - [Functions of Matrices, Theory and Computation, - Nicholas J Higham, Pg 184, Eq 7.18]( - https://epubs.siam.org/doi/book/10.1137/1.9780898717778) - - Args: - matrix: the symmetric PSD matrix whose power it to be computed - p: exponent, for p a positive integer. - num_iters: Maximum number of iterations. - ridge_epsilon: Ridge epsilon added to make the matrix positive definite. - error_tolerance: Error indicator, useful for early termination. - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - relative_matrix_epsilon: Whether to use relative epsilon to the max eigen - value when computing inverse-pth root. - lobpcg_topk_precondition: If nonzero, specifies the number of top - eigenvectors to subtract out before performing LOBPCG. Note this makes - relative_matrix_epsilon essentially free. - lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to - `lobpcg_topk_precondition`. - padding_start: If the input matrix was padded, then zeros out columns and - rows at the padding start. - prev: previous iteration's solution, zero-padded (unused) - eigh: If True, uses eigh for inverse-pth root computation. - - Returns: - `(matrix + eps)^(-1/p)` and error metrics. - - Note `eps` is not added to zeroed out padding rows and - columns. `eps` is just `ridge_epsilon` if - `relative_matrix_epsilon` is set to `False`, otherwise, it is the - ridge epsilon value scaled by the derived maximum eigenvalue of - the input matrix. - """ + This function uses the Eigh or Coupled newton iterations algorithm for + the computation of a matrix's inverse pth root. + + + References: + [Functions of Matrices, Theory and Computation, + Nicholas J Higham, Pg 184, Eq 7.18]( + https://epubs.siam.org/doi/book/10.1137/1.9780898717778) + + Args: + matrix: the symmetric PSD matrix whose power it to be computed + p: exponent, for p a positive integer. + num_iters: Maximum number of iterations. + ridge_epsilon: Ridge epsilon added to make the matrix positive definite. + error_tolerance: Error indicator, useful for early termination. + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + relative_matrix_epsilon: Whether to use relative epsilon to the max eigen + value when computing inverse-pth root. + lobpcg_topk_precondition: If nonzero, specifies the number of top + eigenvectors to subtract out before performing LOBPCG. Note this makes + relative_matrix_epsilon essentially free. + lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to + `lobpcg_topk_precondition`. + padding_start: If the input matrix was padded, then zeros out columns and + rows at the padding start. + prev: previous iteration's solution, zero-padded (unused) + eigh: If True, uses eigh for inverse-pth root computation. + + Returns: + `(matrix + eps)^(-1/p)` and error metrics. + + Note `eps` is not added to zeroed out padding rows and + columns. `eps` is just `ridge_epsilon` if + `relative_matrix_epsilon` is set to `False`, otherwise, it is the + ridge epsilon value scaled by the derived maximum eigenvalue of + the input matrix. + """ if eigh: - return matrix_inverse_pth_root_eigh(matrix, - p, - ridge_epsilon, - error_tolerance, - precision, - relative_matrix_epsilon, - padding_start, - prev) + return matrix_inverse_pth_root_eigh( + matrix, + p, + ridge_epsilon, + error_tolerance, + precision, + relative_matrix_epsilon, + padding_start, + prev, + ) del prev assert matrix.shape[0] == matrix.shape[1] @@ -596,7 +638,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + matrix.dtype + ) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -607,18 +650,23 @@ def matrix_inverse_pth_root( eigvals, eigvecs, lobpcg_diagnostics = None, None, None if lobpcg_topk_precondition > 0: # TODO(vladf): reuse previous top-k as the initial search directions - pad_shape = (matrix_size - lobpcg_topk_precondition, - lobpcg_topk_precondition) + pad_shape = ( + matrix_size - lobpcg_topk_precondition, + lobpcg_topk_precondition, + ) search_dirs = jnp.concatenate( - (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0) + (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0 + ) eigvals, eigvecs, lobpcg_iters = linalg.lobpcg_standard( # pylint: disable=unbalanced-tuple-unpacking - matrix, search_dirs, - lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter) + matrix, + search_dirs, + lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter, + ) lobpcg_diagnostics = LOBPCGDiagnostics.create( - matrix, - eigvals, - eigvecs, - lobpcg_iters, + matrix, + eigvals, + eigvecs, + lobpcg_iters, ) # The minimal eigenvalue among top-k becomes the maximal one in the whole @@ -628,7 +676,8 @@ def matrix_inverse_pth_root( # Deflate out top eigenvectors to reduce matrix condition number. matrix -= scaled_vecs.dot( - scaled_vecs.T, precision=jax.lax.Precision.HIGHEST) + scaled_vecs.T, precision=jax.lax.Precision.HIGHEST + ) if relative_matrix_epsilon: if eigvals is not None: @@ -637,11 +686,12 @@ def matrix_inverse_pth_root( # Only use power iteration if lobpcg wasn't already used to derive the # top eigenvalue. _, max_ev = power_iteration( - matrix=matrix, - num_iters=100, - error_tolerance=1e-6, - precision=precision, - padding_start=padding_start) + matrix=matrix, + num_iters=100, + error_tolerance=1e-6, + precision=precision, + padding_start=padding_start, + ) else: # Use absolute matrix epsilon scaling otherwise. max_ev = 1.0 @@ -654,8 +704,9 @@ def matrix_inverse_pth_root( def _iter_condition(state): i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, error_ratio = state - error_above_threshold = jnp.logical_and(error > error_tolerance, - error_ratio < max_error_ratio) + error_above_threshold = jnp.logical_and( + error > error_tolerance, error_ratio < max_error_ratio + ) return jnp.logical_and(i < num_iters, error_above_threshold) def _iter_body(state): @@ -673,7 +724,6 @@ def _iter_body(state): iters = 0 error_ratio = 0.0 else: - retry_loop_error_threshold = 0.05 num_tries = 6 init_outer_state = tuple([0, identity, 1000.0, 100, 1.0, True]) @@ -691,23 +741,26 @@ def _outer_body_fn(state): new_error = jnp.max(jnp.abs(new_mat_m_0 - identity)) new_mat_h_0 = identity * jnp.power(z, 1.0 / p) init_state = tuple( - [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0]) + [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0] + ) iters, mat_m, mat_h, old_mat_h, error, error_ratio = lax.while_loop( - _iter_condition, _iter_body, init_state) + _iter_condition, _iter_body, init_state + ) error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32) is_converged = jnp.asarray(error_ratio < max_error_ratio, old_mat_h.dtype) - resultant_mat_h = is_converged * \ - mat_h + (1 - is_converged) * old_mat_h - return (i + 1, - resultant_mat_h, - error, - iters, - error_ratio, - error > retry_loop_error_threshold) - - loop_outputs = jax.lax.while_loop(_outer_iter_condition_fn, - _outer_body_fn, - init_outer_state) + resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h + return ( + i + 1, + resultant_mat_h, + error, + iters, + error_ratio, + error > retry_loop_error_threshold, + ) + + loop_outputs = jax.lax.while_loop( + _outer_iter_condition_fn, _outer_body_fn, init_outer_state + ) total_retries, resultant_mat_h, error, iters, error_ratio, _ = loop_outputs conditioned_resultant_mat = resultant_mat_h @@ -723,35 +776,39 @@ def _outer_body_fn(state): pth_diff = _pth_root_difference(ridge_epsilon, jnp.min(eigvals), eigvals, p) scaled_vecs = eigvecs * jnp.sqrt(pth_diff) resultant_mat_h = conditioned_resultant_mat - scaled_vecs.dot( - scaled_vecs.T, precision=jax.lax.Precision.HIGHEST) + scaled_vecs.T, precision=jax.lax.Precision.HIGHEST + ) error_metrics = TrainingMetrics( - inverse_pth_root_errors=jnp.array(error, jnp.float32), - inverse_pth_root_iters=jnp.array(iters, jnp.float32), - final_error_ratio=jnp.array(error_ratio, jnp.float32), - max_eigen_value=jnp.array(max_ev, jnp.float32), - total_retries=jnp.array(total_retries, jnp.float32)) + inverse_pth_root_errors=jnp.array(error, jnp.float32), + inverse_pth_root_iters=jnp.array(iters, jnp.float32), + final_error_ratio=jnp.array(error_ratio, jnp.float32), + max_eigen_value=jnp.array(max_ev, jnp.float32), + total_retries=jnp.array(total_retries, jnp.float32), + ) if lobpcg_topk_precondition > 0: - damped_matrix = matrix + \ - (ridge_epsilon * (10**total_retries) * identity) + damped_matrix = matrix + (ridge_epsilon * (10**total_retries) * identity) conditioned_diagnostics = InversePthRootDiagnostics.create( - conditioned_resultant_mat, damped_matrix, p) + conditioned_resultant_mat, damped_matrix, p + ) unconditioned_damped_matrix = original_matrix + ridge_epsilon * identity unconditioned_diagnostics = InversePthRootDiagnostics.create( - resultant_mat_h, unconditioned_damped_matrix, p) + resultant_mat_h, unconditioned_damped_matrix, p + ) # The max entrywise error in error_metrics.inverse_pth_root_errors refers # to what was derived from the inverse pth root iteration, which with # LOBPCG refers to the conditioned problem. Make sure to use the error # from the unconditioned problem. unconditional_errors = jnp.maximum( - unconditioned_diagnostics.max_diag_error, - unconditioned_diagnostics.max_off_diag_error) + unconditioned_diagnostics.max_diag_error, + unconditioned_diagnostics.max_off_diag_error, + ) error_metrics = error_metrics.replace( - inverse_pth_root_errors=unconditional_errors, - lobpcg_diagnostics=lobpcg_diagnostics, - conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics, - inverse_pth_root_diagnostics=unconditioned_diagnostics, + inverse_pth_root_errors=unconditional_errors, + lobpcg_diagnostics=lobpcg_diagnostics, + conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics, + inverse_pth_root_diagnostics=unconditioned_diagnostics, ) if padding_start is not None: @@ -759,9 +816,9 @@ def _outer_body_fn(state): # due to some TPU hosts not having the same number of preconditioning # matrices. resultant_mat_h = jnp.where(padding_start == 0, 0.0, resultant_mat_h) - error = jnp.where(padding_start == 0, - 0.0, - error_metrics.inverse_pth_root_errors) + error = jnp.where( + padding_start == 0, 0.0, error_metrics.inverse_pth_root_errors + ) error_metrics = error_metrics.replace(inverse_pth_root_errors=error) resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype) @@ -769,44 +826,44 @@ def _outer_body_fn(state): def matrix_inverse_pth_root_eigh( - matrix, - p, - ridge_epsilon=1e-6, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST, - relative_matrix_epsilon=True, - padding_start=None, - prev=None, + matrix, + p, + ridge_epsilon=1e-6, + error_tolerance=1e-6, + precision=lax.Precision.HIGHEST, + relative_matrix_epsilon=True, + padding_start=None, + prev=None, ): """Computes `matrix^(-1/p)`, where `p` is a positive integer. - This function uses eigh for the computation of a matrix's inverse pth - root. - - Args: - matrix: the symmetric PSD matrix whose power it to be computed - p: exponent, for p a positive integer. - ridge_epsilon: Ridge epsilon added to make the matrix positive definite. - error_tolerance: Error indicator, useful for early termination. - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - relative_matrix_epsilon: Whether to use relative epsilon to the max eigen - value when computing inverse-pth root. - padding_start: If the input matrix was padded, then zeros out columns and - rows at the padding start. - prev: previous iteration's solution, zero-padded (unused) - - Returns: - `(matrix + eps)^(-1/p)` and error metrics. - - Note `eps` is not added to zeroed out padding rows and - columns. `eps` is just `ridge_epsilon` if - `relative_matrix_epsilon` is set to `False`, otherwise, it is the - ridge epsilon value scaled by the derived maximum eigenvalue of - the input matrix. - """ + This function uses eigh for the computation of a matrix's inverse pth + root. + + Args: + matrix: the symmetric PSD matrix whose power it to be computed + p: exponent, for p a positive integer. + ridge_epsilon: Ridge epsilon added to make the matrix positive definite. + error_tolerance: Error indicator, useful for early termination. + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + relative_matrix_epsilon: Whether to use relative epsilon to the max eigen + value when computing inverse-pth root. + padding_start: If the input matrix was padded, then zeros out columns and + rows at the padding start. + prev: previous iteration's solution, zero-padded (unused) + + Returns: + `(matrix + eps)^(-1/p)` and error metrics. + + Note `eps` is not added to zeroed out padding rows and + columns. `eps` is just `ridge_epsilon` if + `relative_matrix_epsilon` is set to `False`, otherwise, it is the + ridge epsilon value scaled by the derived maximum eigenvalue of + the input matrix. + """ del prev assert matrix.shape[0] == matrix.shape[1] matrix_size = matrix.shape[0] @@ -816,17 +873,19 @@ def matrix_inverse_pth_root_eigh( identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + matrix.dtype + ) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix if relative_matrix_epsilon: _, max_ev = power_iteration( - matrix=matrix, - num_iters=100, - error_tolerance=error_tolerance, - precision=precision, - padding_start=padding_start) + matrix=matrix, + num_iters=100, + error_tolerance=error_tolerance, + precision=precision, + padding_start=padding_start, + ) else: # Use absolute matrix epsilon scaling otherwise. max_ev = 1.0 @@ -837,9 +896,9 @@ def matrix_inverse_pth_root_eigh( if padding_start is not None: e *= jnp.flip(ix) mm = functools.partial(jnp.matmul, precision=precision) - inv_e = jnp.where(e == 0.0, - 0.0, - jnp.power(jnp.maximum(e, ridge_epsilon), alpha)) + inv_e = jnp.where( + e == 0.0, 0.0, jnp.power(jnp.maximum(e, ridge_epsilon), alpha) + ) val = mm(mm(u, jnp.diag(inv_e)), u.T) root = u * jnp.sqrt(inv_e) val = mm(root, root.T) @@ -849,12 +908,13 @@ def matrix_inverse_pth_root_eigh( eig_error *= jnp.flip(ix) error = jnp.max(jnp.abs(eig_error)) error_metrics = TrainingMetrics( - inverse_pth_root_errors=jnp.array(error, jnp.float32)) + inverse_pth_root_errors=jnp.array(error, jnp.float32) + ) if padding_start is not None: val = jnp.where(padding_start == 0, 0.0, val) - error = jnp.where(padding_start == 0, - 0.0, - error_metrics.inverse_pth_root_errors) + error = jnp.where( + padding_start == 0, 0.0, error_metrics.inverse_pth_root_errors + ) error_metrics = error_metrics.replace(inverse_pth_root_errors=error) val = jnp.asarray(val, orig_dtype) return val, error_metrics @@ -863,17 +923,17 @@ def matrix_inverse_pth_root_eigh( def merge_small_dims(shape_to_merge, max_dim): """Merge small dimensions. - If there are some small dimensions, we collapse them: - e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 - [1, 2, 768, 1, 2048] --> [2, 768, 2048] + If there are some small dimensions, we collapse them: + e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 + [1, 2, 768, 1, 2048] --> [2, 768, 2048] - Args: - shape_to_merge: Shape to merge small dimensions. - max_dim: Maximal dimension of output shape used in merging. + Args: + shape_to_merge: Shape to merge small dimensions. + max_dim: Maximal dimension of output shape used in merging. - Returns: - Merged shape. - """ + Returns: + Merged shape. + """ if shape_to_merge and np.all(np.array(shape_to_merge) == 1): return [1] @@ -894,20 +954,23 @@ def merge_small_dims(shape_to_merge, max_dim): def pad_square_matrix(mat, max_size): """Pad a square matrix up to max_size. - Args: - mat: a matrix to pad. - max_size: matrix size requested. + Args: + mat: a matrix to pad. + max_size: matrix size requested. - Returns: - Given M returns [[M, 0], [0, I]] - """ + Returns: + Given M returns [[M, 0], [0, I]] + """ rows, cols = mat.shape if rows != cols: - raise ValueError("Must have rows == cols, instead got " - f"rows={rows}, cols={cols}") + raise ValueError( + f'Must have rows == cols, instead got rows={rows}, cols={cols}' + ) if cols > max_size: - raise ValueError("Must have cols <= max_size. Instead got " - f"cols={cols}, max_size={max_size}.") + raise ValueError( + 'Must have cols <= max_size. Instead got ' + f'cols={cols}, max_size={max_size}.' + ) if rows == max_size: return mat pad_size = max_size - rows @@ -923,13 +986,13 @@ def pad_square_matrix(mat, max_size): def pad_vector(vec, max_size): """Pad a vector to a max_size. - Args: - vec: a vector to pad. - max_size: matrix size requested. + Args: + vec: a vector to pad. + max_size: matrix size requested. - Returns: - Given V returns [V, 0] - """ + Returns: + Given V returns [V, 0] + """ size = vec.shape[0] assert size <= max_size if size == max_size: @@ -949,9 +1012,9 @@ def _iter_body(unused_state): def _iter_condition(state): return state[0] - results = jax.lax.while_loop(_iter_condition, - _iter_body, - tuple([predicate] + init_state)) + results = jax.lax.while_loop( + _iter_condition, _iter_body, tuple([predicate] + init_state) + ) return tuple(results[1:]) @@ -985,7 +1048,7 @@ def partition(self, tensor): assert tensor.shape == self._shape tensors = [tensor] - for (i, indices) in self._splits: + for i, indices in self._splits: tensors_local = [] for t in tensors: tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i)) @@ -995,13 +1058,14 @@ def partition(self, tensor): def merge_partitions(self, partitions): """Merge partitions back to original shape.""" - for (i, indices) in reversed(self._splits): + for i, indices in reversed(self._splits): n = len(indices) + 1 partial_merged_tensors = [] ind = 0 while ind < len(partitions): partial_merged_tensors.append( - jnp.concatenate(partitions[ind:ind + n], axis=i)) + jnp.concatenate(partitions[ind : ind + n], axis=i) + ) ind += n partitions = partial_merged_tensors assert len(partitions) == 1 @@ -1011,25 +1075,25 @@ def merge_partitions(self, partitions): def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None): """Updated statistics via weighted average with new Gram matrix. - Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose - columns are the flattened slices of the tensor `g` along the given `axis`. - (So, `old_stats` and the returned matrix have dimensions n x n where - n = `g.shape[axis]`). - - Args: - old_stats: Old statistics. - g: Gradient tensor. - axis: Axis along which to slice `g`. - w1: Scalar weight for old statistics. - w2: Scalar weight for new Gram matrix. - precision: Optional precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - - Returns: - Weighted average of old and new statistics. - """ + Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose + columns are the flattened slices of the tensor `g` along the given `axis`. + (So, `old_stats` and the returned matrix have dimensions n x n where + n = `g.shape[axis]`). + + Args: + old_stats: Old statistics. + g: Gradient tensor. + axis: Axis along which to slice `g`. + w1: Scalar weight for old statistics. + w2: Scalar weight for new Gram matrix. + precision: Optional precision XLA related flag, the available options are: + a) lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + + Returns: + Weighted average of old and new statistics. + """ axes = [i for i in range(g.ndim) if i != axis] gram_matrix = jnp.tensordot(g, g, axes=(axes, axes), precision=precision) return w1 * old_stats + w2 * gram_matrix @@ -1039,67 +1103,68 @@ class Preconditioner: """Compute statistics/shape from gradients for preconditioning.""" def __init__( - self, - param, - block_size, - merge_small_dims_block_size, - best_effort_shape_interpretation, - preconditioner_type=PreconditionerType.ALL, + self, + param, + block_size, + merge_small_dims_block_size, + best_effort_shape_interpretation, + preconditioner_type=PreconditionerType.ALL, ): """Initializes the preconditioner. - Args: - param: parameter to precondition. - block_size: Block size used to split param. - merge_small_dims_block_size: Block size for merging dims. - best_effort_shape_interpretation: Whether to - collapse/merge dims together. - preconditioner_type: Type of preconditioner to use. - """ + Args: + param: parameter to precondition. + block_size: Block size used to split param. + merge_small_dims_block_size: Block size for merging dims. + best_effort_shape_interpretation: Whether to + collapse/merge dims together. + preconditioner_type: Type of preconditioner to use. + """ self._original_shape = param.shape self._transformed_shape = param.shape if best_effort_shape_interpretation: - self._transformed_shape = merge_small_dims(self._original_shape, - merge_small_dims_block_size) + self._transformed_shape = merge_small_dims( + self._original_shape, merge_small_dims_block_size + ) reshaped_param = jnp.reshape(param, self._transformed_shape) self._partitioner = BlockPartitioner(reshaped_param, block_size) self._preconditioner_type = preconditioner_type def updated_statistics_from_grad( - self, - stats, - grad, - w1, - w2, - to_float=None, - from_float=None, - precision=None, + self, + stats, + grad, + w1, + w2, + to_float=None, + from_float=None, + precision=None, ): """Update statistics from gradients. - Args: - stats: Old statistics or its Cholesky factor if `cholesky` is True. - grad: Gradient to compute statistics from. - w1: Weight for old statistics. - w2: Weight for new statistics. - to_float: Optional function for converting stats to floating point. - from_float: Optional function for converting from floating point. - precision: Optional precision XLA related flag, the available options - are: - a) lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - - Returns: - A list of updated gradient statistics for each partition. - """ + Args: + stats: Old statistics or its Cholesky factor if `cholesky` is True. + grad: Gradient to compute statistics from. + w1: Weight for old statistics. + w2: Weight for new statistics. + to_float: Optional function for converting stats to floating point. + from_float: Optional function for converting from floating point. + precision: Optional precision XLA related flag, the available options + are: + a) lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + + Returns: + A list of updated gradient statistics for each partition. + """ to_float = to_float if to_float is not None else (lambda x: x) from_float = from_float if from_float is not None else (lambda x: x) reshaped_grad = jnp.reshape(grad, self._transformed_shape) partitioned_grads = self._partitioner.partition(reshaped_grad) should_preconditioned_dims = self.should_precondition_dims() preconditioned_dims = [ - i for i, p in enumerate(should_preconditioned_dims) if p + i for i, p in enumerate(should_preconditioned_dims) if p ] new_stats = [] index = 0 @@ -1136,8 +1201,7 @@ def _preconds_for_grad(self, preconditioners, rank, start, end): elif self._preconditioner_type == PreconditionerType.OUTPUT: # When _preconditioner_type is OUTPUT, we append (rank - 1) many None # values to the beginning of the list to handle the False indices. - preconditioners_for_grad = [None] * \ - (rank - 1) + preconditioners_for_grad + preconditioners_for_grad = [None] * (rank - 1) + preconditioners_for_grad assert len(preconditioners_for_grad) == rank return preconditioners_for_grad @@ -1165,13 +1229,13 @@ def exponent_for_preconditioner(self): def preconditioned_grad(self, grad, preconditioners): """Precondition the gradient. - Args: - grad: A gradient tensor to precondition. - preconditioners: A list of preconditioners to apply. + Args: + grad: A gradient tensor to precondition. + preconditioners: A list of preconditioners to apply. - Returns: - A preconditioned gradient. - """ + Returns: + A preconditioned gradient. + """ reshaped_grad = jnp.reshape(grad, self._transformed_shape) partitioned_grads = self._partitioner.partition(reshaped_grad) should_preconditioned_dims = self.should_precondition_dims() @@ -1179,17 +1243,18 @@ def preconditioned_grad(self, grad, preconditioners): preconditioned_partitioned_grads = [] for i, g in enumerate(partitioned_grads): preconditioners_for_grad = self._preconds_for_grad( - preconditioners, - rank=len(should_preconditioned_dims), - start=i * num_preconditioners, - end=(i + 1) * num_preconditioners, + preconditioners, + rank=len(should_preconditioned_dims), + start=i * num_preconditioners, + end=(i + 1) * num_preconditioners, + ) + precond_g = self._precondition_block( + g, should_preconditioned_dims, preconditioners_for_grad ) - precond_g = self._precondition_block(g, - should_preconditioned_dims, - preconditioners_for_grad) preconditioned_partitioned_grads.append(precond_g) merged_grad = self._partitioner.merge_partitions( - preconditioned_partitioned_grads) + preconditioned_partitioned_grads + ) return jnp.reshape(merged_grad, self._original_shape) def _precondition_block(self, g, should_precondition_dim, preconditioners): @@ -1208,9 +1273,9 @@ def _precondition_block(self, g, should_precondition_dim, preconditioners): return g -def _convert_to_parameter_stats(global_stats, - local_stat, - convert_statistics=True): +def _convert_to_parameter_stats( + global_stats, local_stat, convert_statistics=True +): """Creates parameter stats from sharded stats.""" index_start = int(local_stat.index_start) index_end = int(len(local_stat.sizes)) + index_start @@ -1225,24 +1290,24 @@ def _convert_to_parameter_stats(global_stats, if not convert_statistics: new_statistics = None return ParameterStats( - local_stat.diagonal_statistics, - new_statistics, - new_preconditioners, - local_stat.diagonal_momentum, - local_stat.momentum, - local_stat.training_metrics, + local_stat.diagonal_statistics, + new_statistics, + new_preconditioners, + local_stat.diagonal_momentum, + local_stat.momentum, + local_stat.training_metrics, ) def _convert_from_parameter_stats(parameter_stats, local_stats): """Creates sharded stats from paramter stats.""" return LocalShardedParameterStats( - parameter_stats.diagonal_statistics, - parameter_stats.diagonal_momentum, - parameter_stats.momentum, - parameter_stats.training_metrics, - local_stats.index_start, - local_stats.sizes, + parameter_stats.diagonal_statistics, + parameter_stats.diagonal_momentum, + parameter_stats.momentum, + parameter_stats.training_metrics, + local_stats.index_start, + local_stats.sizes, ) @@ -1258,12 +1323,13 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old): # root calculation to find a new preconditioner, so that TensorBoard curves # look consistent (otherwise they'd oscillate between NaN and measured # values). - per_stat_metrics = efficient_cond(keep_old, - lambda: [local_stat.training_metrics], - [per_stat_metrics])[0] + per_stat_metrics = efficient_cond( + keep_old, lambda: [local_stat.training_metrics], [per_stat_metrics] + )[0] # pylint:enable=cell-var-from-loop new_local_stats.append( - local_stat.replace(training_metrics=per_stat_metrics)) + local_stat.replace(training_metrics=per_stat_metrics) + ) return new_local_stats @@ -1271,7 +1337,7 @@ def batch(x, num_devices): """Batch `x` so that so that leading axis is num_devices.""" n = len(x) b = int(n / num_devices) - return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)]) + return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)]) def unbatch(batched_values): @@ -1290,162 +1356,168 @@ def unbatch(batched_values): def distributed_shampoo( - learning_rate, - block_size=1024, - beta1=0.9, - beta2=0.999, - diagonal_epsilon=1e-8, - matrix_epsilon=1e-6, - weight_decay=0.0, - start_preconditioning_step=101, - preconditioning_compute_steps=20, - statistics_compute_steps=1, - best_effort_shape_interpretation=True, - graft_type=GraftingType.RMSPROP_NORMALIZED, - nesterov=True, - exponent_override=0, - # Pass pmap 'batch axis name' in pmap mode. - batch_axis_name=None, - # Only set following 3 params in pjit/spmd mode. - # WARNING: Experimental - statistics_partition_spec=None, - preconditioner_partition_spec=None, - num_devices_for_pjit=None, - shard_optimizer_states=False, - ### - # Experimental memory reduction mode - best_effort_memory_usage_reduction=True, - ### - inverse_failure_threshold=0.1, - moving_average_for_momentum=True, - skip_preconditioning_dim_size_gt=0, - clip_by_scaled_gradient_norm=None, - precision=lax.Precision.HIGHEST, - tensordot_precision=None, - relative_matrix_epsilon=True, - merge_small_dims_block_size=4096, - lobpcg_topk_precondition=0, - lobpcg_max_iter=0, - precondtioner_type=PreconditionerType.ALL, - custom_preconditioner=False, - skip_preconditioning_rank_lt=1, - decoupled_learning_rate=True, - decoupled_weight_decay=False, - generate_training_metrics=True, - reuse_preconditioner=False, - eigh=True, + learning_rate, + block_size=1024, + beta1=0.9, + beta2=0.999, + diagonal_epsilon=1e-8, + matrix_epsilon=1e-6, + weight_decay=0.0, + start_preconditioning_step=101, + preconditioning_compute_steps=20, + statistics_compute_steps=1, + best_effort_shape_interpretation=True, + graft_type=GraftingType.RMSPROP_NORMALIZED, + nesterov=True, + exponent_override=0, + # Pass pmap 'batch axis name' in pmap mode. + batch_axis_name=None, + # Only set following 3 params in pjit/spmd mode. + # WARNING: Experimental + statistics_partition_spec=None, + preconditioner_partition_spec=None, + num_devices_for_pjit=None, + shard_optimizer_states=False, + ### + # Experimental memory reduction mode + best_effort_memory_usage_reduction=True, + ### + inverse_failure_threshold=0.1, + moving_average_for_momentum=True, + skip_preconditioning_dim_size_gt=0, + clip_by_scaled_gradient_norm=None, + precision=lax.Precision.HIGHEST, + tensordot_precision=None, + relative_matrix_epsilon=True, + merge_small_dims_block_size=4096, + lobpcg_topk_precondition=0, + lobpcg_max_iter=0, + precondtioner_type=PreconditionerType.ALL, + custom_preconditioner=False, + skip_preconditioning_rank_lt=1, + decoupled_learning_rate=True, + decoupled_weight_decay=False, + generate_training_metrics=True, + reuse_preconditioner=False, + eigh=True, ): """Distributed Shampoo optimizer. - Distributed Shampoo is a second-order preconditioned method (concretely, a - variant of full-matrix Adagrad), that provides significant convergence and - wall-clock time improvements compared to conventional first-order methods, - and that has been shown to scale to large state-of-the-art deep learning - models. - - References: - Scalable Second Order Optimization for Deep Learning, - Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer - - Preprint: https://arxiv.org/abs/2002.09018 - - Args: - learning_rate: the step size used to update the parameters. - block_size: Block size for large layers (if > 0). Preconditioning compute - operation is cubic in the dimension of the tensor. Block size allows us - to chunk the layers into sub-layers of maximal dimension dictated by - this value. Use 128 as default (increase if you have compute budget). - beta1: momentum parameter. - beta2: second moment averaging parameter. - diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting - to AdaGrad is enabled). - matrix_epsilon: epsilon to add to statistics before computing inverse pth - root. If you are running in f32 precision for inverse pth root - (recommended today) this can go upto 1e-6. If you have latest hardware - with native f64 precision, set this upto 1e-12. - weight_decay: Weight decay for regularization. - start_preconditioning_step: When to start Shampoo update before which - diagonal update is used. This is because we dont have enough information - to do stable inverse. - preconditioning_compute_steps: How often to compute preconditioner. - Performance tuning params for controlling memory and compute - requirements. - Ideally set this and statistics_compute_steps params to 1. - statistics_compute_steps: How often to compute statistics. - best_effort_shape_interpretation: If there are some small dimensions, - collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if - block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048] - graft_type: Grafting is a technique to fix the layerwise scale of Shampoo - optimizer. This allows us to plugin the Shampoo optimizer into settings - where SGD/AdaGrad is already well tuned. - nesterov: Nesterov momentum. - exponent_override: Override the exponent used in matrix inverse. - batch_axis_name: labeled axis over pmap for data-parallel training the - optimizer used for. - statistics_partition_spec: PartitionSpec to be used in sharded mode. - preconditioner_partition_spec: PartitionSpec to be used in sharded mode. - num_devices_for_pjit: Number of devices to parallelize over when using - pjit. - shard_optimizer_states: Shard optimizer states to save memory in model - parallel training. - best_effort_memory_usage_reduction: Best effort memory usage reduction. - - diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) - -> jnp.int8 - statistics, preconditioners -> jnp.int16 + diagonals - inverse_failure_threshold: numerics are hard and inverses fail sometimes; - we determine that using this threshold. - moving_average_for_momentum: Whether to use moving average for momentum - instead of exponential moving average. - skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is - greater than this value. - clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful - when using RMSProp Grafting). - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - tensordot_precision: Optional precision to use for the tensordot operation - when computing statistics (e.g., G Gᵀ). Same options as `precision` - above. - relative_matrix_epsilon: Whether to use relative epsilon to the max eigen - value when computing inverse-pth root. - merge_small_dims_block_size: Used as the maximum block size to merge the - shapes. - lobpcg_topk_precondition: If nonzero, specifies the number of top - eigenvectors to subtract out before performing LOBPCG. Note this makes - relative_matrix_epsilon essentially free. - lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to - `lobpcg_topk_precondition`. - precondtioner_type: Preconditioner type to select all, left only or right - only preconditioners. - skip_preconditioning_rank_lt: Skips preconditioning for parameters with - rank less than this value. - decoupled_learning_rate: If True, use decoupled learning rate, otherwise - couple it with preconditioned gradient computation. (Default True) - decoupled_weight_decay: If True, use decoupled weight decay, otherwise - couple with weight decay. (Default False) - generate_training_metrics: If True, gather training metrics, otherwise - avoid generating them (to reduce memory usage). - reuse_preconditioner: If True, pass the previous derived preconditioner - as a warm start to the next iteratin's inverse pth root computation. - eigh: If True, and uses eigen decomposition for inverse-pth root. - - Returns: - a GradientTransformation. - """ + Distributed Shampoo is a second-order preconditioned method (concretely, a + variant of full-matrix Adagrad), that provides significant convergence and + wall-clock time improvements compared to conventional first-order methods, + and that has been shown to scale to large state-of-the-art deep learning + models. + + References: + Scalable Second Order Optimization for Deep Learning, + Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer + + Preprint: https://arxiv.org/abs/2002.09018 + + Args: + learning_rate: the step size used to update the parameters. + block_size: Block size for large layers (if > 0). Preconditioning compute + operation is cubic in the dimension of the tensor. Block size allows us + to chunk the layers into sub-layers of maximal dimension dictated by + this value. Use 128 as default (increase if you have compute budget). + beta1: momentum parameter. + beta2: second moment averaging parameter. + diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting + to AdaGrad is enabled). + matrix_epsilon: epsilon to add to statistics before computing inverse pth + root. If you are running in f32 precision for inverse pth root + (recommended today) this can go upto 1e-6. If you have latest hardware + with native f64 precision, set this upto 1e-12. + weight_decay: Weight decay for regularization. + start_preconditioning_step: When to start Shampoo update before which + diagonal update is used. This is because we dont have enough information + to do stable inverse. + preconditioning_compute_steps: How often to compute preconditioner. + Performance tuning params for controlling memory and compute + requirements. + Ideally set this and statistics_compute_steps params to 1. + statistics_compute_steps: How often to compute statistics. + best_effort_shape_interpretation: If there are some small dimensions, + collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if + block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048] + graft_type: Grafting is a technique to fix the layerwise scale of Shampoo + optimizer. This allows us to plugin the Shampoo optimizer into settings + where SGD/AdaGrad is already well tuned. + nesterov: Nesterov momentum. + exponent_override: Override the exponent used in matrix inverse. + batch_axis_name: labeled axis over pmap for data-parallel training the + optimizer used for. + statistics_partition_spec: PartitionSpec to be used in sharded mode. + preconditioner_partition_spec: PartitionSpec to be used in sharded mode. + num_devices_for_pjit: Number of devices to parallelize over when using + pjit. + shard_optimizer_states: Shard optimizer states to save memory in model + parallel training. + best_effort_memory_usage_reduction: Best effort memory usage reduction. - + diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) + -> jnp.int8 - statistics, preconditioners -> jnp.int16 + diagonals + inverse_failure_threshold: numerics are hard and inverses fail sometimes; + we determine that using this threshold. + moving_average_for_momentum: Whether to use moving average for momentum + instead of exponential moving average. + skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is + greater than this value. + clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful + when using RMSProp Grafting). + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + tensordot_precision: Optional precision to use for the tensordot operation + when computing statistics (e.g., G Gᵀ). Same options as `precision` + above. + relative_matrix_epsilon: Whether to use relative epsilon to the max eigen + value when computing inverse-pth root. + merge_small_dims_block_size: Used as the maximum block size to merge the + shapes. + lobpcg_topk_precondition: If nonzero, specifies the number of top + eigenvectors to subtract out before performing LOBPCG. Note this makes + relative_matrix_epsilon essentially free. + lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to + `lobpcg_topk_precondition`. + precondtioner_type: Preconditioner type to select all, left only or right + only preconditioners. + skip_preconditioning_rank_lt: Skips preconditioning for parameters with + rank less than this value. + decoupled_learning_rate: If True, use decoupled learning rate, otherwise + couple it with preconditioned gradient computation. (Default True) + decoupled_weight_decay: If True, use decoupled weight decay, otherwise + couple with weight decay. (Default False) + generate_training_metrics: If True, gather training metrics, otherwise + avoid generating them (to reduce memory usage). + reuse_preconditioner: If True, pass the previous derived preconditioner + as a warm start to the next iteratin's inverse pth root computation. + eigh: If True, and uses eigen decomposition for inverse-pth root. + + Returns: + a GradientTransformation. + """ reset_frequency = None def _graft_type_has_diagonal_statistics(): """Returns True if using diagonal firt order method for grafting.""" return graft_type not in [ - GraftingType.SGD, GraftingType.SQRT_N, GraftingType.NONE + GraftingType.SGD, + GraftingType.SQRT_N, + GraftingType.NONE, ] def quantized_dtype_for_momentum_buffers(var): - return jnp.int8 if best_effort_memory_usage_reduction and len( - var.shape) > 1 else jnp.float32 + return ( + jnp.int8 + if best_effort_memory_usage_reduction and len(var.shape) > 1 + else jnp.float32 + ) quantize_second_moment = ( - best_effort_memory_usage_reduction and batch_axis_name) + best_effort_memory_usage_reduction and batch_axis_name + ) # Preconditioner and statistics are both stores as int16 in this mode. # We take out the diagonal to make quantization easier. @@ -1472,19 +1544,20 @@ def _to_float(maybe_quantized): def preconditioner_from_params(param): """Returns a Preconditioner object for given param.""" return Preconditioner( - param, - block_size, - merge_small_dims_block_size, - best_effort_shape_interpretation, - precondtioner_type, + param, + block_size, + merge_small_dims_block_size, + best_effort_shape_interpretation, + precondtioner_type, ) def precond_dim(max_size): """Derives largest preconditioner dimension.""" return max_size - def pad_and_maybe_zero_preconditioners(preconditioners, total, max_size, - step): + def pad_and_maybe_zero_preconditioners( + preconditioners, total, max_size, step + ): """Pad preconditioners up to total x max_size x precond_dim(max_size).""" pd = precond_dim(max_size) @@ -1513,9 +1586,9 @@ def _pad_preconditioner(preconditioner): def sharded_init_fn(params): """Returns optimizer state (for PJIT mode). - Args: - params: the parameters that should be updated. - """ + Args: + params: the parameters that should be updated. + """ params_flat, treedef = jax.tree_util.tree_flatten(params) # Find max size to pad to. max_size = 0 @@ -1542,21 +1615,22 @@ def sharded_init_fn(params): sizes = [s[0] for s in shapes] shapes = preconditioner.shapes_for_preconditioners() statistics = [ - matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32) - for s in shapes + matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32) for s in shapes ] pd = precond_dim(max_size) # If the preconditioner is using a low-rank representation, initialize # it to zero instead of an invalid eye. preconditioners = [ - jnp.eye(max_size, pd, dtype=jnp.float32) * (pd == max_size) - for s in shapes + jnp.eye(max_size, pd, dtype=jnp.float32) * (pd == max_size) + for s in shapes ] padded_statistics.extend(statistics) padded_preconditioners.extend(preconditioners) exponent = ( - preconditioner.exponent_for_preconditioner() - if exponent_override == 0 else exponent_override) + preconditioner.exponent_for_preconditioner() + if exponent_override == 0 + else exponent_override + ) exponents.extend([exponent] * len(shapes)) diagonal_statistics = jnp.zeros_like(param) @@ -1564,16 +1638,18 @@ def sharded_init_fn(params): momentum = jnp.zeros_like(param) local_stats_flat.append( - LocalShardedParameterStats( - diagonal_statistics, - diagonal_momentum, - momentum, - init_training_metrics( - len(sizes), - generate_training_metrics, - ), - index_start, - sizes)) + LocalShardedParameterStats( + diagonal_statistics, + diagonal_momentum, + momentum, + init_training_metrics( + len(sizes), + generate_training_metrics, + ), + index_start, + sizes, + ) + ) local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat) to_pad = -len(padded_statistics) % num_devices_for_pjit @@ -1588,22 +1664,27 @@ def sharded_init_fn(params): # TODO(rohananil): Relax to only the size of the mesh axis where the dim # is split on. padded_statistics.extend( - [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]) + [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)] + ) pd = precond_dim(max_size) # If the preconditioner is using a low-rank representation, initialize # it to zero instead of an invalid eye. - padded_preconditioners.extend([ + padded_preconditioners.extend( + [ jnp.eye(max_size, pd, dtype=stat_dtype) * (pd == max_size) for _ in range(to_pad) - ]) + ] + ) exponents.extend([1 for _ in range(to_pad)]) global_stats = GlobalShardedParameterStats( - jnp.stack(padded_statistics), - jnp.stack(padded_preconditioners), - jnp.stack(exponents)) + jnp.stack(padded_statistics), + jnp.stack(padded_preconditioners), + jnp.stack(exponents), + ) return ShampooState( - count=jnp.zeros([], jnp.int32), - stats=ShardedShampooStats(global_stats, local_stats)) + count=jnp.zeros([], jnp.int32), + stats=ShardedShampooStats(global_stats, local_stats), + ) def _max_statistics_size_from_params(params): max_size = 0 @@ -1624,20 +1705,21 @@ def _remove_leading_sharding_annotation(pspec): else: return [] - def sharded_init_partition_spec_fn(params, - params_partition_spec, - partition_spec_for_statistics): + def sharded_init_partition_spec_fn( + params, params_partition_spec, partition_spec_for_statistics + ): """Returns a parallel state tree with PartitionSpec associated with state. - Args: - params: A pytree with params. - params_partition_spec: A pytree with PartitionSpec for params. - partition_spec_for_statistics: PartitionSpec for the statistics. - """ + Args: + params: A pytree with params. + params_partition_spec: A pytree with PartitionSpec for params. + partition_spec_for_statistics: PartitionSpec for the statistics. + """ # Parallel lists of spec, and params. param_pspec_flat, _ = jax.tree_util.tree_flatten( - params_partition_spec, is_leaf=lambda x: x is None) + params_partition_spec, is_leaf=lambda x: x is None + ) params_flat, treedef = jax.tree_util.tree_flatten(params) assert param_pspec_flat assert params_flat @@ -1667,48 +1749,57 @@ def sharded_init_partition_spec_fn(params, m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec) local_stats_flat.append( - LocalShardedParameterStats( - QuantizedValue( - param_pspec, - [], - [], - jnp.float32, - False, # pytype: disable=wrong-arg-types # numpy-scalars - list(param.shape)), - QuantizedValue( - m1_pspec, - [], - m1_scale_pspec, - qdtype, - False, # pytype: disable=wrong-arg-types # numpy-scalars - list(param.shape)), - QuantizedValue( - m2_pspec, - [], - m2_scale_pspec, - qdtype, - False, # pytype: disable=wrong-arg-types # numpy-scalars - list(param.shape)), - init_training_metrics_pspec(generate_training_metrics,), - index_start, - sizes)) + LocalShardedParameterStats( + QuantizedValue( + param_pspec, + [], + [], + jnp.float32, + False, # pytype: disable=wrong-arg-types # numpy-scalars + list(param.shape), + ), + QuantizedValue( + m1_pspec, + [], + m1_scale_pspec, + qdtype, + False, # pytype: disable=wrong-arg-types # numpy-scalars + list(param.shape), + ), + QuantizedValue( + m2_pspec, + [], + m2_scale_pspec, + qdtype, + False, # pytype: disable=wrong-arg-types # numpy-scalars + list(param.shape), + ), + init_training_metrics_pspec( + generate_training_metrics, + ), + index_start, + sizes, + ) + ) local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat) - global_stats = GlobalShardedParameterStats(partition_spec_for_statistics, - partition_spec_for_statistics, - jax.sharding.PartitionSpec()) + global_stats = GlobalShardedParameterStats( + partition_spec_for_statistics, + partition_spec_for_statistics, + jax.sharding.PartitionSpec(), + ) count_pspec = jax.sharding.PartitionSpec() return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars - count=count_pspec, - stats=ShardedShampooStats(global_stats, local_stats)) + count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats) + ) def sharded_init_shape_and_dtype_fn(params): """Returns a parallel state tree with shape, dtype associated with state. - Args: - params: A pytree with params. - """ + Args: + params: A pytree with params. + """ # Parallel lists of spec, and params. params_flat, treedef = jax.tree_util.tree_flatten(params) assert params_flat @@ -1739,31 +1830,39 @@ def sharded_init_shape_and_dtype_fn(params): diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype] local_stats_flat.append( - LocalShardedParameterStats( - QuantizedValue( - diagonal_statistics_shape_and_dtype, - [], - [], # pytype: disable=wrong-arg-types # numpy-scalars - jnp.float32, - False, - list(param.shape)), - QuantizedValue(m1_shape_and_dtype, [], - m1_scale_shape_and_dtype, - qdtype, - False, - list(param.shape)), - QuantizedValue(m2_shape_and_dtype, [], - m2_scale_shape_and_dtype, - qdtype, - False, - list(param.shape)), - init_training_metrics_shapes( - len(sizes), - generate_training_metrics, - ), - index_start, - sizes, - )) + LocalShardedParameterStats( + QuantizedValue( + diagonal_statistics_shape_and_dtype, + [], + [], # pytype: disable=wrong-arg-types # numpy-scalars + jnp.float32, + False, + list(param.shape), + ), + QuantizedValue( + m1_shape_and_dtype, + [], + m1_scale_shape_and_dtype, + qdtype, + False, + list(param.shape), + ), + QuantizedValue( + m2_shape_and_dtype, + [], + m2_scale_shape_and_dtype, + qdtype, + False, + list(param.shape), + ), + init_training_metrics_shapes( + len(sizes), + generate_training_metrics, + ), + index_start, + sizes, + ) + ) local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat) max_statistics_size = _max_statistics_size_from_params(params_flat) @@ -1773,29 +1872,36 @@ def sharded_init_shape_and_dtype_fn(params): num_statistics = num_devices_for_pjit max_statistics_size = block_size statistics_shape = [ - num_statistics, max_statistics_size, max_statistics_size + num_statistics, + max_statistics_size, + max_statistics_size, ] preconditioners_shape = [ - num_statistics, max_statistics_size, precond_dim(max_statistics_size) + num_statistics, + max_statistics_size, + precond_dim(max_statistics_size), ] global_stats = GlobalShardedParameterStats( - [statistics_shape, jnp.float32], [preconditioners_shape, jnp.float32], - [[num_statistics], jnp.int32]) + [statistics_shape, jnp.float32], + [preconditioners_shape, jnp.float32], + [[num_statistics], jnp.int32], + ) return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars - count=[[], jnp.float32], - stats=ShardedShampooStats(global_stats, local_stats)) + count=[[], jnp.float32], + stats=ShardedShampooStats(global_stats, local_stats), + ) def sharded_update_fn(grads, state, params): """Transform the input gradient and update all statistics in sharded mode. - Args: - grads: the gradient tensors for the parameters. - state: a named tuple containing the state of the optimizer - params: the parameters that should be updated. + Args: + grads: the gradient tensors for the parameters. + state: a named tuple containing the state of the optimizer + params: the parameters that should be updated. - Returns: - A tuple containing the new parameters and the new optimizer state. - """ + Returns: + A tuple containing the new parameters and the new optimizer state. + """ params_flat, treedef = jax.tree_util.tree_flatten(params) grads_flat = treedef.flatten_up_to(grads) @@ -1803,43 +1909,45 @@ def sharded_update_fn(grads, state, params): local_stats_flat = treedef.flatten_up_to(state.stats.local_stats) stats_flat = [] for local_stat in local_stats_flat: - stats_flat.append(_convert_to_parameter_stats( + stats_flat.append( + _convert_to_parameter_stats( global_stats, local_stat, - )) + ) + ) new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), - grads_flat, - stats_flat, - params_flat) + lambda g, s, p: _compute_stats(g, s, p, state.count), + grads_flat, + stats_flat, + params_flat, + ) outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), - grads_flat, - new_stats_flat, - params_flat) + lambda g, s, p: _transform_grad(g, s, p, state.count), + grads_flat, + new_stats_flat, + params_flat, + ) updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) updates = jax.tree_util.tree_unflatten(treedef, updates_flat) new_local_stats_flat = [] for new_stat, local_stat in zip(new_stats_flat, local_stats_flat): new_local_stats_flat.append( - _convert_from_parameter_stats( - new_stat, - local_stat, - )) + _convert_from_parameter_stats( + new_stat, + local_stat, + ) + ) max_size = global_stats.statistics.shape[1] new_padded_statistics = [] padding_starts = [] for stat in new_stats_flat: new_padded_statistics.extend( - [pad_square_matrix(stat, max_size) for stat in stat.statistics]) + [pad_square_matrix(stat, max_size) for stat in stat.statistics] + ) padding_starts.extend([len(stat) for stat in stat.statistics]) # Create global stats @@ -1857,7 +1965,8 @@ def sharded_update_fn(grads, state, params): stat_dtype = new_padded_statistics[0].dtype new_padded_statistics.extend( - [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]) + [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)] + ) padding_starts += [0] * to_pad if reuse_preconditioner: @@ -1865,29 +1974,30 @@ def sharded_update_fn(grads, state, params): for stat in new_stats_flat: prev_preconditioners.extend(stat.preconditioners) prev_padded_preconditioners = pad_and_maybe_zero_preconditioners( - prev_preconditioners, - len(new_padded_statistics), - max_size, - state.count) + prev_preconditioners, len(new_padded_statistics), max_size, state.count + ) else: prev_padded_preconditioners = None new_stacked_padded_statistics = jnp.stack(new_padded_statistics) new_stacked_padded_statistics = pjit.with_sharding_constraint( - new_stacked_padded_statistics, statistics_partition_spec) + new_stacked_padded_statistics, statistics_partition_spec + ) stacked_padding_starts = jnp.array(padding_starts, jnp.int32) prev_stacked_padded_preconditioners = _maybe(jnp.stack)( - prev_padded_preconditioners) + prev_padded_preconditioners + ) prev_stacked_padded_preconditioners = _maybe(pjit.with_sharding_constraint)( - prev_padded_preconditioners, statistics_partition_spec) + prev_padded_preconditioners, statistics_partition_spec + ) def _internal_inverse_pth_root_all(): preconditioners, metrics = _matrix_inverse_pth_root_pjit( - new_stacked_padded_statistics, - global_stats.exponents, - stacked_padding_starts, - prev_stacked_padded_preconditioners, - statistics_partition_spec, + new_stacked_padded_statistics, + global_stats.exponents, + stacked_padding_starts, + prev_stacked_padded_preconditioners, + statistics_partition_spec, ) return preconditioners, metrics @@ -1903,39 +2013,47 @@ def _internal_inverse_pth_root_all(): preconditioners_init = new_stacked_padded_statistics[:, :, :pd] n = new_stacked_padded_statistics.shape[0] metrics_init = cast( - TrainingMetrics, - init_training_metrics( - n, - generate_training_metrics=True, - )) + TrainingMetrics, + init_training_metrics( + n, + generate_training_metrics=True, + ), + ) new_errors = jnp.ones_like(metrics_init.inverse_pth_root_errors) * ( - inverse_failure_threshold) + inverse_failure_threshold + ) metrics_init = metrics_init.replace(inverse_pth_root_errors=new_errors) init_state = [preconditioners_init, metrics_init] new_preconditioners, metrics = efficient_cond( - perform_step, _internal_inverse_pth_root_all, init_state) + perform_step, _internal_inverse_pth_root_all, init_state + ) if generate_training_metrics: new_local_stats_flat = _add_metrics_into_local_stats( - new_local_stats_flat, metrics, ~perform_step) - new_local_stats = jax.tree_util.tree_unflatten(treedef, - new_local_stats_flat) + new_local_stats_flat, metrics, ~perform_step + ) + new_local_stats = jax.tree_util.tree_unflatten( + treedef, new_local_stats_flat + ) errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), - errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), errors >= inverse_failure_threshold + ).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. new_conditional_preconditioners = ( - predicate * global_stats.preconditioners + - (1.0 - predicate) * new_preconditioners) + predicate * global_stats.preconditioners + + (1.0 - predicate) * new_preconditioners + ) new_global_stats = GlobalShardedParameterStats( - new_stacked_padded_statistics, - new_conditional_preconditioners, - global_stats.exponents) + new_stacked_padded_statistics, + new_conditional_preconditioners, + global_stats.exponents, + ) new_shampoo_state = ShampooState( - count=state.count + 1, - stats=ShardedShampooStats(new_global_stats, new_local_stats)) + count=state.count + 1, + stats=ShardedShampooStats(new_global_stats, new_local_stats), + ) return updates, new_shampoo_state def init_fn(params): @@ -1948,13 +2066,13 @@ def _init(param): if not _skip_preconditioning(param): shapes = preconditioner.shapes_for_preconditioners() statistics = [ - matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes + matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes ] # If the preconditioner is using a low-rank representation, initialize # it to zero instead of an invalid eye. preconditioners = [ - jnp.eye(s[0], s[1], dtype=jnp.float32) * (s[0] == s[1]) - for s in shapes + jnp.eye(s[0], s[1], dtype=jnp.float32) * (s[0] == s[1]) + for s in shapes ] diagonal_statistics = [] @@ -1967,25 +2085,28 @@ def _init(param): momentum = jnp.zeros_like(param) return ParameterStats( - diagonal_statistics, - statistics, - preconditioners, - # _quantize_diagonal_statistics(diagonal_statistics), - # _maybe_quantize_statistics(statistics), - # _maybe_quantize_preconditioners(preconditioners), - diagonal_momentum, - momentum, - init_training_metrics( - len(statistics), - generate_training_metrics, - )) + diagonal_statistics, + statistics, + preconditioners, + # _quantize_diagonal_statistics(diagonal_statistics), + # _maybe_quantize_statistics(statistics), + # _maybe_quantize_preconditioners(preconditioners), + diagonal_momentum, + momentum, + init_training_metrics( + len(statistics), + generate_training_metrics, + ), + ) return ShampooState( - count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) + count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params) + ) def _skip_preconditioning(param): return len(param.shape) < skip_preconditioning_rank_lt or any( - s > skip_preconditioning_dim_size_gt for s in param.shape) + s > skip_preconditioning_dim_size_gt for s in param.shape + ) def _compute_stats(grad, state, param, step): """Compute per-parameter statistics.""" @@ -1997,105 +2118,117 @@ def _compute_stats(grad, state, param, step): def compute_updated_statistics(): return preconditioner.updated_statistics_from_grad( - state.statistics, - grad, - w1=w1, - w2=w2, - to_float=_to_float, - from_float=lambda x: x, - # from_float=lambda x: _maybe_quantize_statistics([x])[0], - precision=tensordot_precision, + state.statistics, + grad, + w1=w1, + w2=w2, + to_float=_to_float, + from_float=lambda x: x, + # from_float=lambda x: _maybe_quantize_statistics([x])[0], + precision=tensordot_precision, ) if statistics_compute_steps > 1: perform_step = step % statistics_compute_steps == 0 init_state = state.statistics new_statistics = list( - efficient_cond(perform_step, compute_updated_statistics, - init_state)) + efficient_cond(perform_step, compute_updated_statistics, init_state) + ) else: new_statistics = compute_updated_statistics() - return ParameterStats(state.diagonal_statistics, - new_statistics, - state.preconditioners, - state.diagonal_momentum, - state.momentum, - state.training_metrics) + return ParameterStats( + state.diagonal_statistics, + new_statistics, + state.preconditioners, + state.diagonal_momentum, + state.momentum, + state.training_metrics, + ) mi_pth_root = functools.partial( - matrix_inverse_pth_root, - ridge_epsilon=matrix_epsilon, - precision=precision, - relative_matrix_epsilon=relative_matrix_epsilon, - lobpcg_topk_precondition=lobpcg_topk_precondition, - lobpcg_max_iter=lobpcg_max_iter, - eigh=eigh) + matrix_inverse_pth_root, + ridge_epsilon=matrix_epsilon, + precision=precision, + relative_matrix_epsilon=relative_matrix_epsilon, + lobpcg_topk_precondition=lobpcg_topk_precondition, + lobpcg_max_iter=lobpcg_max_iter, + eigh=eigh, + ) def _matrix_inverse_pth_root_vmap(xs, ps, padding_starts, prev): return jax.vmap(mi_pth_root)( - xs, ps, padding_start=padding_starts, prev=prev) + xs, ps, padding_start=padding_starts, prev=prev + ) - def _matrix_inverse_pth_root_pjit(xs, - ps, - padding_starts, - prev_preconds=None, - statistics_partition_spec=None): + def _matrix_inverse_pth_root_pjit( + xs, ps, padding_starts, prev_preconds=None, statistics_partition_spec=None + ): # Partition the concatenated statistics matrix across all cores. pspec_for_partition = preconditioner_partition_spec partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition) if preconditioner_partition_spec: partitioned_ps_spec = jax.sharding.PartitionSpec( - preconditioner_partition_spec[0]) + preconditioner_partition_spec[0] + ) else: partitioned_ps_spec = None partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec) partitioned_prev_preconds = _maybe(pjit.with_sharding_constraint)( - prev_preconds, preconditioner_partition_spec) + prev_preconds, preconditioner_partition_spec + ) partitioned_padding_starts = pjit.with_sharding_constraint( - padding_starts, partitioned_ps_spec) # paddings are scalars like ps. + padding_starts, partitioned_ps_spec + ) # paddings are scalars like ps. # Run matrix inverse pth root on each shard. partitioned_preconditioners, partitioned_metrics = ( - _matrix_inverse_pth_root_vmap( - partitioned_xs, - partitioned_ps, - partitioned_padding_starts, - prev=partitioned_prev_preconds)) + _matrix_inverse_pth_root_vmap( + partitioned_xs, + partitioned_ps, + partitioned_padding_starts, + prev=partitioned_prev_preconds, + ) + ) # Reshard output to have the same PSpec as input. This is required to avoid # vmap seeing the full set of statistics. partitioned_preconditioners = pjit.with_sharding_constraint( - partitioned_preconditioners, pspec_for_partition) + partitioned_preconditioners, pspec_for_partition + ) # Recombine the outputs at each core. - preconditioners = pjit.with_sharding_constraint(partitioned_preconditioners, - statistics_partition_spec) - metrics = pjit.with_sharding_constraint(partitioned_metrics, - jax.sharding.PartitionSpec()) + preconditioners = pjit.with_sharding_constraint( + partitioned_preconditioners, statistics_partition_spec + ) + metrics = pjit.with_sharding_constraint( + partitioned_metrics, jax.sharding.PartitionSpec() + ) return preconditioners, metrics - def _pmap_compute_preconditioners(states, - step, - statistics, - num_statistics_per_state, - original_shapes, - exponents, - max_size, - prev_preconditioners): + def _pmap_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ): """Computes preconditioners for given statistics in states in PMAP mode. - Args: - states: A list of optimizer states. - step: Current step number - statistics: A list of statistics for all variables (for every dim) - num_statistics_per_state: Number of statistis per state to reconstruct - output states. - original_shapes: A list of shapes of the statistics. - exponents: Exponent power to use for inverse-pth roots. - max_size: Maximum dim of the statistics to pad. - prev_preconditioners: Previously available preconditioner. - - Returns: - New optimizer states after computing the preconditioner. - """ + Args: + states: A list of optimizer states. + step: Current step number + statistics: A list of statistics for all variables (for every dim) + num_statistics_per_state: Number of statistis per state to reconstruct + output states. + original_shapes: A list of shapes of the statistics. + exponents: Exponent power to use for inverse-pth roots. + max_size: Maximum dim of the statistics to pad. + prev_preconditioners: Previously available preconditioner. + + Returns: + New optimizer states after computing the preconditioner. + """ if batch_axis_name: num_devices = lax.psum(1, batch_axis_name) else: @@ -2103,13 +2236,15 @@ def _pmap_compute_preconditioners(states, num_statistics = len(statistics) # Pad statistics and exponents to next multiple of num_devices. packed_statistics = [ - pad_square_matrix(stat, max_size) for stat in statistics + pad_square_matrix(stat, max_size) for stat in statistics ] to_pad = -num_statistics % num_devices - packed_statistics.extend([ + packed_statistics.extend( + [ jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad) - ]) + ] + ) exponents.extend([1 for _ in range(to_pad)]) paddings = [len(stat) for stat in statistics] + [0] * to_pad @@ -2119,7 +2254,8 @@ def _pmap_compute_preconditioners(states, if reuse_preconditioner: assert len(prev_preconditioners) == num_statistics packed_preconditioners = pad_and_maybe_zero_preconditioners( - prev_preconditioners, len(packed_statistics), max_size, step) + prev_preconditioners, len(packed_statistics), max_size, step + ) else: packed_preconditioners = None @@ -2132,10 +2268,10 @@ def _internal_inverse_pth_root_all(): if batch_axis_name: current_replica = lax.axis_index(batch_axis_name) preconditioners, metrics = _matrix_inverse_pth_root_vmap( - all_statistics[current_replica], - all_exponents[current_replica], - all_paddings[current_replica], - _maybe_ix(all_preconditioners, current_replica), + all_statistics[current_replica], + all_exponents[current_replica], + all_paddings[current_replica], + _maybe_ix(all_preconditioners, current_replica), ) preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) metrics = jax.lax.all_gather(metrics, batch_axis_name) @@ -2143,14 +2279,15 @@ def _internal_inverse_pth_root_all(): metrics_flat = jax.tree.map(unbatch, metrics) else: preconditioners, metrics = _matrix_inverse_pth_root_vmap( - all_statistics[0], - all_exponents[0], - all_paddings[0], - _maybe_ix(all_preconditioners, 0), + all_statistics[0], + all_exponents[0], + all_paddings[0], + _maybe_ix(all_preconditioners, 0), ) preconditioners_flat = unbatch(jnp.stack([preconditioners])) metrics = jax.tree.map( - functools.partial(jnp.expand_dims, axis=0), metrics) + functools.partial(jnp.expand_dims, axis=0), metrics + ) metrics_flat = jax.tree.map(unbatch, metrics) return preconditioners_flat, metrics_flat @@ -2163,40 +2300,57 @@ def _internal_inverse_pth_root_all(): # shaped tensors. Note statistics will be ignored as we are passing in # a large error value. preconditioners_init = [ - s[:, :precond_dim(s.shape[0])] for s in packed_statistics + s[:, : precond_dim(s.shape[0])] for s in packed_statistics ] n = len(packed_statistics) metrics_init = jax.tree.map( - lambda x: [x] * n, - default_training_metrics().replace( - inverse_pth_root_errors=inverse_failure_threshold)) + lambda x: [x] * n, + default_training_metrics().replace( + inverse_pth_root_errors=inverse_failure_threshold + ), + ) init_state = [preconditioners_init, metrics_init] preconditioners_flat, metrics_flat = efficient_cond( - perform_step, _internal_inverse_pth_root_all, init_state) + perform_step, _internal_inverse_pth_root_all, init_state + ) def _skip(error): condition = jnp.logical_or( - jnp.isnan(error), error >= inverse_failure_threshold) + jnp.isnan(error), error >= inverse_failure_threshold + ) return condition.astype(error.dtype) def _select_preconditioner(error, new_p, old_p): return lax.cond( - _skip(error), lambda _: old_p, lambda _: new_p, operand=None) + _skip(error), lambda _: old_p, lambda _: new_p, operand=None + ) new_preconditioners_flat = [] new_errors_flat = metrics_flat.inverse_pth_root_errors - for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes, - prev_preconditioners, new_errors_flat): + for p, shape, prev_p, error in zip( + preconditioners_flat, + original_shapes, + prev_preconditioners, + new_errors_flat, + ): new_preconditioners_flat.append( - _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p)) + _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p) + ) - assert len(states) == (len(num_statistics_per_state), - f"{len(states)} vs {len(num_statistics_per_state)}") + assert len(states) == ( + len(num_statistics_per_state), + f'{len(states)} vs {len(num_statistics_per_state)}', + ) assert len(new_preconditioners_flat) == num_statistics assert len(new_errors_flat) == len(packed_statistics), ( - len(new_errors_flat), len(packed_statistics)) + len(new_errors_flat), + len(packed_statistics), + ) assert len(new_errors_flat) == num_statistics + to_pad, ( - len(new_errors_flat), num_statistics, to_pad) + len(new_errors_flat), + num_statistics, + to_pad, + ) # Add back empty preconditioners so we that we can set the optimizer state. preconditioners_for_states = [] @@ -2206,26 +2360,31 @@ def _select_preconditioner(error, new_p, old_p): if num_statistics == 0: preconditioners_for_states.append([]) metrics_for_states.append( - init_training_metrics(0, generate_training_metrics)) + init_training_metrics(0, generate_training_metrics) + ) else: - preconditioners_for_state = new_preconditioners_flat[idx:idx + - num_statistics] + preconditioners_for_state = new_preconditioners_flat[ + idx : idx + num_statistics + ] assert len(state.statistics) == len(preconditioners_for_state) preconditioners_for_states.append(preconditioners_for_state) if generate_training_metrics: # pylint:disable=cell-var-from-loop Used immediately. metrics_for_state = jax.tree.map( - lambda x: jnp.stack(x[idx:idx + num_statistics]), - metrics_flat, - is_leaf=lambda x: isinstance(x, list)) + lambda x: jnp.stack(x[idx : idx + num_statistics]), + metrics_flat, + is_leaf=lambda x: isinstance(x, list), + ) assert jax.tree_util.tree_all( - jax.tree.map(lambda x: len(state.statistics) == len(x), - metrics_for_state)) + jax.tree.map( + lambda x: len(state.statistics) == len(x), metrics_for_state + ) + ) # If we skipped preconditioner computation, record old metrics. - metrics_for_state = efficient_cond(perform_step, - lambda: [metrics_for_state], - [state.training_metrics])[0] + metrics_for_state = efficient_cond( + perform_step, lambda: [metrics_for_state], [state.training_metrics] + )[0] # pylint:enable=cell-var-from-loop else: metrics_for_state = optax.MaskedNode() @@ -2234,32 +2393,36 @@ def _select_preconditioner(error, new_p, old_p): idx += num_statistics new_states = [] for state, new_preconditioners, new_metrics in zip( - states, preconditioners_for_states, metrics_for_states): + states, preconditioners_for_states, metrics_for_states + ): # Note the preconditioner may have been skipped, but we still update the # metrics with the new error values; whether the preconditioner that's # actively being used is stale can be derived from the new_metrics # being greater than the failure threshold. new_states.append( - ParameterStats(state.diagonal_statistics, - state.statistics, - new_preconditioners, - state.diagonal_momentum, - state.momentum, - new_metrics)) + ParameterStats( + state.diagonal_statistics, + state.statistics, + new_preconditioners, + state.diagonal_momentum, + state.momentum, + new_metrics, + ) + ) return new_states def _compute_preconditioners(states, params, step): """Computes preconditioners for given statistics in states. - Args: - states: A list of optimizer states. - params: A list of params. - step: Current step number + Args: + states: A list of optimizer states. + params: A list of params. + step: Current step number - Returns: - New optimizer states after computing the preconditioner. - """ + Returns: + New optimizer states after computing the preconditioner. + """ statistics = [] num_statistics_per_state = [] original_shapes = [] @@ -2274,8 +2437,11 @@ def _compute_preconditioners(states, params, step): if num_statistics > 0: preconditioner = preconditioner_from_params(param) for statistic in state.statistics: - exponents.append(preconditioner.exponent_for_preconditioner( - ) if exponent_override == 0 else exponent_override) + exponents.append( + preconditioner.exponent_for_preconditioner() + if exponent_override == 0 + else exponent_override + ) original_shapes_for_state.append(statistic.shape) max_size = max(max_size, statistic.shape[0]) @@ -2283,14 +2449,16 @@ def _compute_preconditioners(states, params, step): prev_preconditioners.extend(state.preconditioners) original_shapes.extend(original_shapes_for_state) - return _pmap_compute_preconditioners(states, - step, - statistics, - num_statistics_per_state, - original_shapes, - exponents, - max_size, - prev_preconditioners) + return _pmap_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ) def _transform_grad(grad, state, param, step): """Transform per-parameter gradients.""" @@ -2298,21 +2466,25 @@ def _transform_grad(grad, state, param, step): sgd_update = grad new_diagonal_statistics = state.diagonal_statistics - if (graft_type == GraftingType.ADAGRAD or - graft_type == GraftingType.ADAGRAD_NORMALIZED): - + if ( + graft_type == GraftingType.ADAGRAD + or graft_type == GraftingType.ADAGRAD_NORMALIZED + ): scaled_grad = grad if graft_type == GraftingType.ADAGRAD_NORMALIZED: scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON) new_diagonal_statistics = ( - state.diagonal_statistics.to_float() + jnp.square(scaled_grad)) + state.diagonal_statistics.to_float() + jnp.square(scaled_grad) + ) adagrad_update = scaled_grad / ( - jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon) + jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon + ) grafting_update = adagrad_update - elif (graft_type == GraftingType.RMSPROP or - graft_type == GraftingType.RMSPROP_NORMALIZED): - + elif ( + graft_type == GraftingType.RMSPROP + or graft_type == GraftingType.RMSPROP_NORMALIZED + ): scaled_grad = grad if graft_type == GraftingType.RMSPROP_NORMALIZED: scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON) @@ -2321,15 +2493,19 @@ def _transform_grad(grad, state, param, step): w2 = jnp.where(beta2 == 1.0, beta2, 1.0 - beta2) new_diagonal_statistics = ( - w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad)) + w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad) + ) rmsprop_update = scaled_grad / ( - jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon) + jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon + ) if clip_by_scaled_gradient_norm: scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / ( - jnp.sqrt(float(rmsprop_update.size))) + jnp.sqrt(float(rmsprop_update.size)) + ) clipping_denom = jnp.maximum( - 1., scaled_grad_norm / clip_by_scaled_gradient_norm) + 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm + ) rmsprop_update /= clipping_denom grafting_update = rmsprop_update @@ -2349,12 +2525,14 @@ def _transform_grad(grad, state, param, step): precond_grad = grad if not _skip_preconditioning(param): - precond_grad = preconditioner.preconditioned_grad(precond_grad, - state.preconditioners) + precond_grad = preconditioner.preconditioned_grad( + precond_grad, state.preconditioners + ) else: if graft_type == GraftingType.NONE: - logging.error("skipping preconditioning without grafting for param %s", - param) + logging.error( + 'skipping preconditioning without grafting for param %s', param + ) precond_grad = grafting_update grafting_update_norm = jnp.linalg.norm(grafting_update) @@ -2369,39 +2547,49 @@ def _transform_grad(grad, state, param, step): shampoo_update_with_wd = shampoo_update grafting_update_with_wd = grafting_update - if (weight_decay != 0 and weight_decay is not None and - not decoupled_weight_decay): + if ( + weight_decay != 0 + and weight_decay is not None + and not decoupled_weight_decay + ): shampoo_update_with_wd = shampoo_update + weight_decay * param grafting_update_with_wd = grafting_update + weight_decay * param w = (1.0 - beta1) if moving_average_for_momentum else 1.0 shampoo_update_with_wd_momentum = ( - state.momentum * beta1 + w * shampoo_update_with_wd) + state.momentum * beta1 + w * shampoo_update_with_wd + ) grafting_update_with_wd_momentum = ( - state.diagonal_momentum * beta1 + w * grafting_update_with_wd) + state.diagonal_momentum * beta1 + w * grafting_update_with_wd + ) run_shampoo = (step >= start_preconditioning_step).astype( - grafting_update_with_wd_momentum.dtype) + grafting_update_with_wd_momentum.dtype + ) momentum_update = ( - run_shampoo * shampoo_update_with_wd_momentum + - (1.0 - run_shampoo) * grafting_update_with_wd_momentum) + run_shampoo * shampoo_update_with_wd_momentum + + (1.0 - run_shampoo) * grafting_update_with_wd_momentum + ) wd_update = ( - run_shampoo * shampoo_update_with_wd + - (1.0 - run_shampoo) * grafting_update_with_wd) + run_shampoo * shampoo_update_with_wd + + (1.0 - run_shampoo) * grafting_update_with_wd + ) nesterov_momentum_update = momentum_update if nesterov: nesterov_momentum_update = w * wd_update + beta1 * momentum_update - if (weight_decay != 0 and weight_decay is not None and - decoupled_weight_decay): + if ( + weight_decay != 0 and weight_decay is not None and decoupled_weight_decay + ): nesterov_momentum_update = ( - nesterov_momentum_update + lr * weight_decay * param) + nesterov_momentum_update + lr * weight_decay * param + ) momentum_multiplier = lr if decoupled_learning_rate else 1.0 transformed_update = -1.0 * momentum_multiplier * nesterov_momentum_update @@ -2409,26 +2597,28 @@ def _transform_grad(grad, state, param, step): new_diagonal_momentum = grafting_update_with_wd_momentum new_momentum = shampoo_update_with_wd_momentum - param_stats = ParameterStats(new_diagonal_statistics, - state.statistics, - state.preconditioners, - new_diagonal_momentum, - new_momentum, - state.training_metrics) + param_stats = ParameterStats( + new_diagonal_statistics, + state.statistics, + state.preconditioners, + new_diagonal_momentum, + new_momentum, + state.training_metrics, + ) return transformed_update, param_stats def update_fn(grads, state, params): """Transform the input gradient and update all statistics. - Args: - grads: the gradient tensors for the parameters and any custom - gradients for preconditioners. - state: a named tuple containing the state of the optimizer - params: the parameters that should be updated. + Args: + grads: the gradient tensors for the parameters and any custom + gradients for preconditioners. + state: a named tuple containing the state of the optimizer + params: the parameters that should be updated. - Returns: - A tuple containing the new parameters and the new optimizer state. - """ + Returns: + A tuple containing the new parameters and the new optimizer state. + """ grads_custom = None if custom_preconditioner and isinstance(grads, tuple): grads, grads_custom = grads @@ -2442,23 +2632,21 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), - stats_grads, - stats_flat, - params_flat) - - new_stats_flat = _compute_preconditioners(new_stats_flat, - params_flat, - state.count) + lambda g, s, p: _compute_stats(g, s, p, state.count), + stats_grads, + stats_flat, + params_flat, + ) + + new_stats_flat = _compute_preconditioners( + new_stats_flat, params_flat, state.count + ) outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), - grads_flat, - new_stats_flat, - params_flat) + lambda g, s, p: _transform_grad(g, s, p, state.count), + grads_flat, + new_stats_flat, + params_flat, + ) updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) updates = jax.tree_util.tree_unflatten(treedef, updates_flat) new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat) @@ -2472,9 +2660,10 @@ def update_fn(grads, state, params): def _init_fns(unused_params): return InitFnState( - init_fn=opt_init_fn, - pspec_fn=sharded_init_partition_spec_fn, - shape_and_dtype_fn=sharded_init_shape_and_dtype_fn) + init_fn=opt_init_fn, + pspec_fn=sharded_init_partition_spec_fn, + shape_and_dtype_fn=sharded_init_shape_and_dtype_fn, + ) opt_update_fn = sharded_update_fn return optax.GradientTransformation(_init_fns, opt_update_fn) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 2cd054062..526afe7d5 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -10,17 +10,20 @@ import optax from algoperf import spec -from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import \ - distributed_shampoo +from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import ( + distributed_shampoo, +) _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Shampoo optimizer and a learning rate schedule.""" del model_params del model_state @@ -30,102 +33,116 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = distributed_shampoo( - learning_rate=lr_schedule_fn, - beta1=1.0 - hyperparameters.one_minus_beta1, - beta2=hyperparameters.beta2, - weight_decay=hyperparameters.weight_decay, - batch_axis_name='batch', - eigh=False) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + beta1=1.0 - hyperparameters.one_minus_beta1, + beta2=hyperparameters.beta2, + weight_decay=hyperparameters.weight_decay, + batch_axis_name='batch', + eigh=False, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -142,37 +159,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -208,14 +231,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/target_setting_algorithms/cosine_warmup.py b/reference_algorithms/target_setting_algorithms/cosine_warmup.py index 116ebc555..eeb87cd87 100644 --- a/reference_algorithms/target_setting_algorithms/cosine_warmup.py +++ b/reference_algorithms/target_setting_algorithms/cosine_warmup.py @@ -9,27 +9,31 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=hyperparameters.warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=hyperparameters.warmup_steps, + ) cosine_steps = max(step_hint - hyperparameters.warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hyperparameters.warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[hyperparameters.warmup_steps] + ) return schedule_fn def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup = LinearLR( - optimizer, - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_steps) + optimizer, + start_factor=1e-10, + end_factor=1.0, + total_iters=hyperparameters.warmup_steps, + ) cosine_steps = max(step_hint - hyperparameters.warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, - schedulers=[warmup, cosine_decay], - milestones=[hyperparameters.warmup_steps]) + optimizer, + schedulers=[warmup, cosine_decay], + milestones=[hyperparameters.warmup_steps], + ) diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json index cab6fd5f7..6061940a9 100644 --- a/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json @@ -1,27 +1,17 @@ { "learning_rate": { - "feasible_points": [ - 0.0033313215673016375 - ] + "feasible_points": [0.0033313215673016375] }, "beta1": { - "feasible_points": [ - 0.948000082541717 - ] + "feasible_points": [0.948000082541717] }, "beta2": { - "feasible_points": [ - 0.9987934318891598 - ] + "feasible_points": [0.9987934318891598] }, "warmup_steps": { - "feasible_points": [ - 159 - ] + "feasible_points": [159] }, "weight_decay": { - "feasible_points": [ - 0.0035784380304876183 - ] + "feasible_points": [0.0035784380304876183] } } diff --git a/reference_algorithms/target_setting_algorithms/data_selection.py b/reference_algorithms/target_setting_algorithms/data_selection.py index 5e70f9f8b..e0d9c0ee9 100644 --- a/reference_algorithms/target_setting_algorithms/data_selection.py +++ b/reference_algorithms/target_setting_algorithms/data_selection.py @@ -4,14 +4,15 @@ def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index b64f0dfd6..e6f8d915c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -1,4 +1,5 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" + from flax import jax_utils import jax import jax.numpy as jnp @@ -6,39 +7,48 @@ from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import ( + get_batch_size, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( + update_params, +) # pylint: disable=unused-import -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_params del model_state del rng target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup( + target_setting_step_hint, hyperparameters + ) # Create optimizer. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8 + ) opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=epsilon, - weight_decay=hyperparameters.weight_decay) + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index a6c3d853b..b37464c1c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -8,19 +8,24 @@ import optax from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import ( + get_batch_size, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( + update_params, +) # pylint: disable=unused-import -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_params del model_state @@ -28,38 +33,44 @@ def init_optimizer_state(workload: spec.Workload, # Create learning rate schedule. target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = create_lr_schedule_fn( + target_setting_step_hint, hyperparameters + ) # Create optimizer. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = sgd( - learning_rate=lr_schedule_fn, - weight_decay=hyperparameters.weight_decay, - momentum=hyperparameters.beta1, - nesterov=False) + learning_rate=lr_schedule_fn, + weight_decay=hyperparameters.weight_decay, + momentum=hyperparameters.beta1, + nesterov=False, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=hyperparameters.warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=hyperparameters.warmup_steps, + ) decay_steps = step_hint - hyperparameters.warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], - boundaries=[hyperparameters.warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], + boundaries=[hyperparameters.warmup_steps], + ) return lr_schedule_fn @@ -85,6 +96,8 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): {SGD, Momentum, Nesterov} update. """ return optax.chain( - optax.add_decayed_weights(weight_decay), - optax.sgd( - learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) + optax.add_decayed_weights(weight_decay), + optax.sgd( + learning_rate=learning_rate, momentum=momentum, nesterov=nesterov + ), + ) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 597a43c9e..7d4a1fcc1 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -10,26 +10,28 @@ from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import ( + get_batch_size, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( + update_params, +) # pylint: disable=unused-import # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -61,19 +63,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: There seem to be multiple versions of NAdam. The original version is here @@ -109,7 +114,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -117,6 +123,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. mu: optax.Updates nu: optax.Updates @@ -125,7 +132,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -141,31 +149,37 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state del rng target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup( + target_setting_step_hint, hyperparameters + ) # Create optimizer. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8 + ) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=epsilon, - weight_decay=hyperparameters.weight_decay) + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 0c11044fc..714cbb225 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -8,19 +8,24 @@ import optax from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import ( + get_batch_size, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( + update_params, +) # pylint: disable=unused-import -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_params del model_state @@ -28,38 +33,44 @@ def init_optimizer_state(workload: spec.Workload, # Create learning rate schedule. target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = create_lr_schedule_fn( + target_setting_step_hint, hyperparameters + ) # Create optimizer. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = sgd( - learning_rate=lr_schedule_fn, - weight_decay=hyperparameters.weight_decay, - momentum=hyperparameters.beta1, - nesterov=True) + learning_rate=lr_schedule_fn, + weight_decay=hyperparameters.weight_decay, + momentum=hyperparameters.beta1, + nesterov=True, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=hyperparameters.warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=hyperparameters.warmup_steps, + ) decay_steps = step_hint - hyperparameters.warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], - boundaries=[hyperparameters.warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], + boundaries=[hyperparameters.warmup_steps], + ) return lr_schedule_fn @@ -85,6 +96,8 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): {SGD, Momentum, Nesterov} update. """ return optax.chain( - optax.add_decayed_weights(weight_decay), - optax.sgd( - learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) + optax.add_decayed_weights(weight_decay), + optax.sgd( + learning_rate=learning_rate, momentum=momentum, nesterov=nesterov + ), + ) diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 217228935..557e6957c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,4 +1,5 @@ """Update submission function in Jax.""" + import functools from typing import Any, Dict, List, Optional, Tuple @@ -13,75 +14,84 @@ @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -98,32 +108,46 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - new_optimizer_state, new_params, new_model_state, loss, grad_norm = pmapped_train_step( # pylint: disable=line-too-long - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, grad_clip, - label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = ( + pmapped_train_step( # pylint: disable=line-too-long + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) + ) # Log loss, grad_norm. - if ((global_step <= 100 or global_step % 500 == 0) and - workload.metrics_logger is not None): + if ( + global_step <= 100 or global_step % 500 == 0 + ) and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters diff --git a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py index c87bdfb7d..f2474a706 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py @@ -4,37 +4,44 @@ from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import ( + get_batch_size, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( + update_params, +) # pylint: disable=unused-import -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_state del rng epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8 + ) optimizer_state = { - 'optimizer': - torch.optim.AdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=epsilon, - weight_decay=hyperparameters.weight_decay), + 'optimizer': torch.optim.AdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay, + ), } target_setting_step_hint = int(0.75 * workload.step_hint) optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + target_setting_step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py index 584caff39..030939de5 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py @@ -4,40 +4,47 @@ from torch.optim.lr_scheduler import LambdaLR from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_momentum import \ - create_lr_schedule_fn -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import ( + get_batch_size, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_momentum import ( + create_lr_schedule_fn, +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( + update_params, +) # pylint: disable=unused-import + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=hyperparameters.learning_rate, - momentum=hyperparameters.beta1, - weight_decay=hyperparameters.weight_decay, - nesterov=False), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=hyperparameters.learning_rate, + momentum=hyperparameters.beta1, + weight_decay=hyperparameters.weight_decay, + nesterov=False, + ), } # Create learning rate schedule. target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = create_lr_schedule_fn( + target_setting_step_hint, hyperparameters + ) # PyTorch's LambdaLR expects the lr_lambda fn to return a factor which will # be multiplied with the base lr, so we have to divide by it here. @@ -45,6 +52,7 @@ def _lr_lambda(step: int) -> float: return lr_schedule_fn(step).item() / hyperparameters.learning_rate optimizer_state['scheduler'] = LambdaLR( - optimizer_state['optimizer'], lr_lambda=_lr_lambda) + optimizer_state['optimizer'], lr_lambda=_lr_lambda + ) return optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py index a9dee1d79..ceeebda6d 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py @@ -8,43 +8,43 @@ from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import ( + get_batch_size, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( + update_params, +) # pylint: disable=unused-import # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -56,7 +56,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -64,7 +67,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -72,10 +76,10 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ self._cuda_graph_capture_health_check() loss = None @@ -103,51 +107,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float): +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +): r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. - """ + See NAdamW class for details. + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -185,28 +195,32 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8 + ) optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=epsilon, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay, + ), } target_setting_step_hint = int(0.75 * workload.step_hint) optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + target_setting_step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py index 8e10db4ef..ddbcaefdb 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py @@ -4,40 +4,47 @@ from torch.optim.lr_scheduler import LambdaLR from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_nesterov import \ - create_lr_schedule_fn -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import ( + get_batch_size, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_nesterov import ( + create_lr_schedule_fn, +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( + update_params, +) # pylint: disable=unused-import + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=hyperparameters.learning_rate, - momentum=hyperparameters.beta1, - weight_decay=hyperparameters.weight_decay, - nesterov=True), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=hyperparameters.learning_rate, + momentum=hyperparameters.beta1, + weight_decay=hyperparameters.weight_decay, + nesterov=True, + ), } # Create learning rate schedule. target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = create_lr_schedule_fn( + target_setting_step_hint, hyperparameters + ) # PyTorch's LambdaLR expects the lr_lambda fn to return a factor which will # be multiplied with the base lr, so we have to divide by it here. @@ -45,6 +52,7 @@ def _lr_lambda(step: int) -> float: return lr_schedule_fn(step).item() / hyperparameters.learning_rate optimizer_state['scheduler'] = LambdaLR( - optimizer_state['optimizer'], lr_lambda=_lr_lambda) + optimizer_state['optimizer'], lr_lambda=_lr_lambda + ) return optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index bbfd8b0f2..36f736a6b 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -13,18 +13,19 @@ def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -36,26 +37,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -68,7 +73,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() if 'scheduler' in optimizer_state: optimizer_state['scheduler'].step() @@ -78,32 +84,39 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters From e5209d1dce6b9d04722cef9ad2873be7425bd90e Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:19:49 +0200 Subject: [PATCH 104/123] Format tests/ --- tests/modeldiffs/criteo1tb/compare.py | 51 ++- .../criteo1tb_embed_init/compare.py | 51 ++- .../modeldiffs/criteo1tb_layernorm/compare.py | 51 ++- tests/modeldiffs/criteo1tb_resnet/compare.py | 51 ++- tests/modeldiffs/diff.py | 95 ++-- tests/modeldiffs/fastmri/compare.py | 46 +- tests/modeldiffs/fastmri_layernorm/compare.py | 46 +- .../modeldiffs/fastmri_model_size/compare.py | 46 +- tests/modeldiffs/fastmri_tanh/compare.py | 46 +- tests/modeldiffs/imagenet_resnet/compare.py | 49 +- .../imagenet_resnet/gelu_compare.py | 43 +- .../imagenet_resnet/silu_compare.py | 43 +- tests/modeldiffs/imagenet_vit/compare.py | 57 +-- tests/modeldiffs/imagenet_vit_glu/compare.py | 43 +- tests/modeldiffs/imagenet_vit_map/compare.py | 43 +- .../modeldiffs/imagenet_vit_postln/compare.py | 43 +- .../librispeech_conformer/compare.py | 47 +- .../compare.py | 47 +- .../librispeech_conformer_gelu/compare.py | 47 +- .../compare.py | 47 +- .../librispeech_deepspeech/compare.py | 57 ++- .../compare.py | 47 +- .../librispeech_deepspeech_normaug/compare.py | 47 +- .../librispeech_deepspeech_tanh/compare.py | 47 +- tests/modeldiffs/ogbg/compare.py | 67 +-- tests/modeldiffs/ogbg_gelu/compare.py | 67 +-- tests/modeldiffs/ogbg_model_size/compare.py | 67 +-- tests/modeldiffs/ogbg_silu/compare.py | 67 +-- tests/modeldiffs/torch2jax_utils.py | 31 +- tests/modeldiffs/vanilla_sgd_jax.py | 27 +- tests/modeldiffs/vanilla_sgd_pytorch.py | 27 +- tests/modeldiffs/wmt/compare.py | 64 +-- .../modeldiffs/wmt_attention_temp/compare.py | 69 +-- tests/modeldiffs/wmt_glu_tanh/compare.py | 69 +-- tests/modeldiffs/wmt_post_ln/compare.py | 69 +-- tests/reference_algorithm_tests.py | 428 ++++++++++-------- tests/submission_runner_test.py | 70 +-- tests/test_baselines.py | 94 ++-- tests/test_num_params.py | 154 ++++--- tests/test_param_shapes.py | 113 +++-- tests/test_param_types.py | 213 +++++---- tests/test_ssim.py | 35 +- tests/test_traindiffs.py | 121 ++--- tests/test_version.py | 4 +- .../imagenet_jax/workload_test.py | 59 +-- 45 files changed, 1726 insertions(+), 1379 deletions(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 9de61a2a5..48f658d06 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -8,10 +8,12 @@ import torch from algoperf import spec -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ - Criteo1TbDlrmSmallWorkload as JaxWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallWorkload as PyTorchWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallWorkload as JaxWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -54,31 +56,34 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = { - 'inputs': torch.ones((2, 13 + 26)), - 'targets': torch.randint(low=0, high=1, size=(2,)), - 'weights': torch.ones(2), + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), } jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index f1897d16f..897920bac 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -8,10 +8,12 @@ import torch from algoperf import spec -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ - Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -53,31 +55,34 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = { - 'inputs': torch.ones((2, 13 + 26)), - 'targets': torch.randint(low=0, high=1, size=(2,)), - 'weights': torch.ones(2), + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), } jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 5aad3cc67..db1dec601 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -8,10 +8,12 @@ import torch from algoperf import spec -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ - Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -65,31 +67,34 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = { - 'inputs': torch.ones((2, 13 + 26)), - 'targets': torch.randint(low=0, high=1, size=(2,)), - 'weights': torch.ones(2), + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), } jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - # mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + # mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 169b1cdf4..4851a8ad4 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ - Criteo1TbDlrmSmallResNetWorkload as JaxWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallResNetWorkload as JaxWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -65,9 +67,9 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = { - 'inputs': torch.ones((2, 13 + 26)), - 'targets': torch.randint(low=0, high=1, size=(2,)), - 'weights': torch.ones(2), + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), } init_fake_batch_size = 2 @@ -80,23 +82,26 @@ def sd_transform(sd): # Test outputs for identical weights and inputs. pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index 52241fd3a..8449c3241 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -8,14 +8,17 @@ from tests.modeldiffs.torch2jax_utils import value_transform -#pylint: disable=dangerous-default-value -def torch2jax(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) +# pylint: disable=dangerous-default-value +def torch2jax( + jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0), +): + jax_params, model_state = jax_workload.init_model_fn( + jax.random.PRNGKey(0), **init_kwargs + ) pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) if isinstance(jax_params, dict): jax_params = FrozenDict(jax_params) @@ -24,8 +27,9 @@ def torch2jax(jax_workload, model_state = jax_utils.unreplicate(model_state) if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel), + ): pytorch_model = pytorch_model.module # Map and copy params of pytorch_model to jax_model. t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) @@ -39,22 +43,24 @@ def torch2jax(jax_workload, return jax_params, model_state, pytorch_model -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) +def out_diff( + jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None, +): + jax_params, model_state, pytorch_model = torch2jax( + jax_workload, pytorch_workload, key_transform, sd_transform + ) + out_p, _ = pytorch_workload.model_fn( + params=pytorch_model, **pytorch_model_kwargs + ) + out_j, _ = jax_workload.model_fn( + params=jax_params, model_state=model_state, **jax_model_kwargs + ) if out_transform is not None: out_p = out_transform(out_p) out_j = out_transform(out_j) @@ -67,15 +73,16 @@ def out_diff(jax_workload, class ModelDiffRunner: - - def __init__(self, - jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None) -> None: + def __init__( + self, + jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None, + ) -> None: """ Initializes the instance based on diffing logic. @@ -83,7 +90,7 @@ def __init__(self, jax_workload: Workload implementation using JAX. pytorch_workload: Workload implementation using PyTorch. jax_model_kwargs: Arguments to be used for model_fn in jax workload. - pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch + pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload. key_transform: Transformation function for keys. sd_transform: Transformation function for State Dictionary. @@ -99,10 +106,12 @@ def __init__(self, self.out_transform = out_transform def run(self): - out_diff(self.jax_workload, - self.pytorch_workload, - self.jax_model_kwargs, - self.pytorch_model_kwargs, - self.key_transform, - self.sd_transform, - self.out_transform) + out_diff( + self.jax_workload, + self.pytorch_workload, + self.jax_model_kwargs, + self.pytorch_model_kwargs, + self.key_transform, + self.sd_transform, + self.out_transform, + ) diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index c1a349cec..6a82bfb58 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -7,15 +7,16 @@ import torch from algoperf import spec -from algoperf.workloads.fastmri.fastmri_jax.workload import \ - FastMRIWorkload as JaxWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIWorkload as PyTorchWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRIWorkload as JaxWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRIWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): - def sort_key(k): if k[0] == 'ModuleList_0': return (0, *k) @@ -34,7 +35,7 @@ def sort_key(k): if 'ModuleList' in i or 'Sequential' in i: continue if i.startswith('ConvBlock'): - if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]: + if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]: c += 1 i = f'ConvBlock_{c}' if 'Conv2d' in i: @@ -64,22 +65,25 @@ def sort_key(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=None, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=None, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index f26ad185e..8ad47bcae 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -7,15 +7,16 @@ import torch from algoperf import spec -from algoperf.workloads.fastmri.fastmri_jax.workload import \ - FastMRILayerNormWorkload as JaxWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRILayerNormWorkload as PyTorchWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRILayerNormWorkload as JaxWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRILayerNormWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): - def sort_key(k): if k[0] == 'ModuleList_0': return (0, *k) @@ -35,7 +36,7 @@ def sort_key(k): if 'ModuleList' in i or 'Sequential' in i: continue if i.startswith('ConvBlock'): - if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]: + if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]: c += 1 i = f'ConvBlock_{c}' if 'Conv2d' in i: @@ -71,22 +72,25 @@ def sort_key(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=None, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=None, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index 42789539b..f6d5c5074 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -7,15 +7,16 @@ import torch from algoperf import spec -from algoperf.workloads.fastmri.fastmri_jax.workload import \ - FastMRIModelSizeWorkload as JaxWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIModelSizeWorkload as PyTorchWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRIModelSizeWorkload as JaxWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRIModelSizeWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): - def sort_key(k): if k[0] == 'ModuleList_0': return (0, *k) @@ -34,7 +35,7 @@ def sort_key(k): if 'ModuleList' in i or 'Sequential' in i: continue if i.startswith('ConvBlock'): - if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]: + if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]: c += 1 i = f'ConvBlock_{c}' if 'Conv2d' in i: @@ -64,22 +65,25 @@ def sort_key(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=None, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=None, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 13ecb890c..714a025b3 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -7,15 +7,16 @@ import torch from algoperf import spec -from algoperf.workloads.fastmri.fastmri_jax.workload import \ - FastMRITanhWorkload as JaxWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRITanhWorkload as PyTorchWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRITanhWorkload as JaxWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRITanhWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): - def sort_key(k): if k[0] == 'ModuleList_0': return (0, *k) @@ -34,7 +35,7 @@ def sort_key(k): if 'ModuleList' in i or 'Sequential' in i: continue if i.startswith('ConvBlock'): - if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]: + if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]: c += 1 i = f'ConvBlock_{c}' if 'Conv2d' in i: @@ -64,22 +65,25 @@ def sort_key(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=None, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=None, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index 59ab45555..e43cd069e 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetWorkload as JaxWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -53,11 +55,13 @@ def sd_transform(sd): c += 1 new_key = (f'BottleneckResNetBlock_{c}',) + k[2:] if 'Sequential' in ''.join(new_key): - new_key = tuple([ + new_key = tuple( + [ (i.replace('_0', '_proj') if 'BatchNorm' in i or 'Conv' in i else i) for i in new_key if 'Sequential' not in i - ]) + ] + ) sd[new_key] = sd[k] del sd[k] elif 'BatchNorm' in k[0] or 'Conv' in k[0]: @@ -81,22 +85,25 @@ def sd_transform(sd): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 07510ad70..8aa48382d 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetGELUWorkload as JaxWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetGELUWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetGELUWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetGELUWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -28,22 +30,25 @@ pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index 8246d17a2..393badd18 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetSiLUWorkload as JaxWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetSiLUWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetSiLUWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetSiLUWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -28,22 +30,25 @@ pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index b4ca7d8ec..84282d4be 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitWorkload as JaxWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -40,16 +42,16 @@ def key_transform(k): if attention: if pool_head: i = { - 'Linear_0': 'query', - 'Linear_1': 'key_value', - 'Linear_2': 'out', + 'Linear_0': 'query', + 'Linear_1': 'key_value', + 'Linear_2': 'out', }[i] else: i = { - 'Linear_0': 'query', - 'Linear_1': 'key', - 'Linear_2': 'value', - 'Linear_3': 'out', + 'Linear_0': 'query', + 'Linear_1': 'key', + 'Linear_2': 'value', + 'Linear_3': 'out', }[i] else: i = i.replace('Linear', 'Dense') @@ -94,22 +96,25 @@ def key_transform(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ).run() diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py index c152410b5..55c010b97 100644 --- a/tests/modeldiffs/imagenet_vit_glu/compare.py +++ b/tests/modeldiffs/imagenet_vit_glu/compare.py @@ -10,10 +10,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitGluWorkload as JaxWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitGluWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitGluWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitGluWorkload as PyTorchWorkload, +) sd_transform = None @@ -30,22 +32,25 @@ pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ).run() diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py index 7f1af41ab..17a7483c2 100644 --- a/tests/modeldiffs/imagenet_vit_map/compare.py +++ b/tests/modeldiffs/imagenet_vit_map/compare.py @@ -10,10 +10,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitMapWorkload as JaxWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitMapWorkload as PytWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitMapWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitMapWorkload as PytWorkload, +) def sd_transform(sd): @@ -41,22 +43,25 @@ def sd_transform(sd): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py index a3a639101..72d407a4b 100644 --- a/tests/modeldiffs/imagenet_vit_postln/compare.py +++ b/tests/modeldiffs/imagenet_vit_postln/compare.py @@ -10,10 +10,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitPostLNWorkload as JaxWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetViTPostLNWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitPostLNWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetViTPostLNWorkload as PyTorchWorkload, +) sd_transform = None @@ -30,22 +32,25 @@ pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ).run() diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index 664b1242d..80aba62f2 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerWorkload as JaxWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -69,24 +71,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index b0812e77d..a7bebf6bf 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -69,24 +71,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index 3032a0005..d8f7980a2 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerGeluWorkload as JaxWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerGeluWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerGeluWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerGeluWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -69,24 +71,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index d623ef352..7f4768c11 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerLayerNormWorkload as JaxWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerLayerNormWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerLayerNormWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerLayerNormWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -69,24 +71,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index 84b0a6c86..12cd11513 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ - LibriSpeechDeepSpeechWorkload as JaxWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -67,10 +69,12 @@ def sd_transform(sd): if isinstance(out[k], dict): kernels = ['kernel_ih_l0', 'kernel_hh_l0'] biases = ['bias_ih_l0', 'bias_hh_l0'] - weights = torch.cat([out[k][i].view(-1) for i in kernels] + - [out[k][i + '_reverse'].view(-1) for i in kernels] + - [out[k][i].view(-1) for i in biases] + - [out[k][i + '_reverse'].view(-1) for i in biases]) + weights = torch.cat( + [out[k][i].view(-1) for i in kernels] + + [out[k][i + '_reverse'].view(-1) for i in kernels] + + [out[k][i].view(-1) for i in biases] + + [out[k][i + '_reverse'].view(-1) for i in biases] + ) updates[k + ('weights',)] = weights keys_to_del.append(k) out.update(updates) @@ -94,24 +98,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py index 2540c1b93..6a719a84a 100644 --- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ - LibriSpeechDeepSpeechTanhWorkload as JaxWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechTanhWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.librispeech_deepspeech.compare import key_transform from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform @@ -30,24 +32,27 @@ pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py index e5972120d..c8820d397 100644 --- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ - LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.librispeech_deepspeech.compare import key_transform from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform @@ -30,24 +32,27 @@ pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py index 4d2c4a5d5..0882f3d1e 100644 --- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ - LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.librispeech_deepspeech.compare import key_transform from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform @@ -30,24 +32,27 @@ pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 5d5ef50bf..8c40d3c8a 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.ogbg.ogbg_jax.workload import \ - OgbgWorkload as JaxWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgWorkload as PyTorchWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgWorkload as JaxWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way @@ -41,8 +43,8 @@ def key_transform(k): layer_index = int(i.split('_')[1]) if graph_network: count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + - layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -54,7 +56,8 @@ def key_transform(k): elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: @@ -80,13 +83,14 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = dict( - n_node=torch.LongTensor([5]), - n_edge=torch.LongTensor([5]), - nodes=torch.randn(5, 9), - edges=torch.randn(5, 3), - globals=torch.randn(1, 128), - senders=torch.LongTensor(list(range(5))), - receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]), + ) jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} @@ -98,23 +102,26 @@ def sd_transform(sd): pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index fc3992998..f35ed8b17 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.ogbg.ogbg_jax.workload import \ - OgbgGeluWorkload as JaxWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgGeluWorkload as PyTorchWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgGeluWorkload as JaxWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgGeluWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way @@ -41,8 +43,8 @@ def key_transform(k): layer_index = int(i.split('_')[1]) if graph_network: count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + - layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -54,7 +56,8 @@ def key_transform(k): elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: @@ -80,13 +83,14 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = dict( - n_node=torch.LongTensor([5]), - n_edge=torch.LongTensor([5]), - nodes=torch.randn(5, 9), - edges=torch.randn(5, 3), - globals=torch.randn(1, 128), - senders=torch.LongTensor(list(range(5))), - receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]), + ) jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} @@ -98,23 +102,26 @@ def sd_transform(sd): pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index e7cfa745c..0042c71af 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.ogbg.ogbg_jax.workload import \ - OgbgModelSizeWorkload as JaxWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgModelSizeWorkload as PyTorchWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgModelSizeWorkload as JaxWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgModelSizeWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way @@ -41,8 +43,8 @@ def key_transform(k): layer_index = int(i.split('_')[1]) if graph_network: count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + - layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -54,7 +56,8 @@ def key_transform(k): elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: @@ -80,13 +83,14 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = dict( - n_node=torch.LongTensor([5]), - n_edge=torch.LongTensor([5]), - nodes=torch.randn(5, 9), - edges=torch.randn(5, 3), - globals=torch.randn(1, 128), - senders=torch.LongTensor(list(range(5))), - receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]), + ) jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} @@ -98,23 +102,26 @@ def sd_transform(sd): pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 4e3b96cf7..7583282cd 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.ogbg.ogbg_jax.workload import \ - OgbgSiluWorkload as JaxWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgSiluWorkload as PyTorchWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgSiluWorkload as JaxWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgSiluWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way @@ -41,8 +43,8 @@ def key_transform(k): layer_index = int(i.split('_')[1]) if graph_network: count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + - layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -54,7 +56,8 @@ def key_transform(k): elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: @@ -80,13 +83,14 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = dict( - n_node=torch.LongTensor([5]), - n_edge=torch.LongTensor([5]), - nodes=torch.randn(5, 9), - edges=torch.randn(5, 3), - globals=torch.randn(1, 128), - senders=torch.LongTensor(list(range(5))), - receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]), + ) jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} @@ -98,23 +102,26 @@ def sd_transform(sd): pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index d9264b400..7c77a152c 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -32,13 +32,17 @@ def flatten(jm, ret, keys=None): def value_transform(k, value, jax_value): k_str = ''.join(k).lower() - if ('conv' in k_str and 'kernel' in k_str) or \ - ('embedding' in k_str and 'kernel' in k_str): + if ('conv' in k_str and 'kernel' in k_str) or ( + 'embedding' in k_str and 'kernel' in k_str + ): if 'transpose' in k_str: # Assumes 2D ConvTranspose with stride equal to kernel_size. - return value.reshape(value.shape[0], value.shape[1], - -1).flip(-1).permute(2, 0, - 1).reshape(*jax_value.shape) + return ( + value.reshape(value.shape[0], value.shape[1], -1) + .flip(-1) + .permute(2, 0, 1) + .reshape(*jax_value.shape) + ) else: rank = len(value.shape) if rank == 3: @@ -51,16 +55,17 @@ def value_transform(k, value, jax_value): value = value.t().reshape(*list(jax_value.shape)) elif 'attention' in k_str and 'bias' in k_str: value = value.reshape(*list(jax_value.shape)) - elif ('dense' in k_str and 'kernel' in k_str) or \ - ('lstm' in k_str and 'kernel' in k_str) or \ - ('head' in k_str and 'kernel' in k_str) or \ - ('pre_logits' in k_str and 'kernel' in k_str): + elif ( + ('dense' in k_str and 'kernel' in k_str) + or ('lstm' in k_str and 'kernel' in k_str) + or ('head' in k_str and 'kernel' in k_str) + or ('pre_logits' in k_str and 'kernel' in k_str) + ): value = value.t() return value class Torch2Jax: - def __init__(self, torch_model, jax_model): self.torch_model = torch_model self.jax_model = jax_model @@ -73,13 +78,13 @@ def __init__(self, torch_model, jax_model): def key_transform(self, k_transform_fn): self.pytorch_sd = { - k_transform_fn(k): self.pytorch_sd[k] for k in self.pytorch_sd + k_transform_fn(k): self.pytorch_sd[k] for k in self.pytorch_sd } def value_transform(self, v_transform_fn): self.pytorch_sd = { - k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) - for k in self.pytorch_sd + k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) + for k in self.pytorch_sd } def sd_transform(self, sd_transform_fn): diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py index 5595894e6..e80e70b8e 100644 --- a/tests/modeldiffs/vanilla_sgd_jax.py +++ b/tests/modeldiffs/vanilla_sgd_jax.py @@ -4,25 +4,30 @@ import optax from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( + update_params, +) # pylint: disable=unused-import -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Vanilla SGD Optimizer.""" del model_params del model_state del rng # Create optimizer. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = optax.sgd(learning_rate=0.001) optimizer_state = opt_init_fn(params_zeros_like) diff --git a/tests/modeldiffs/vanilla_sgd_pytorch.py b/tests/modeldiffs/vanilla_sgd_pytorch.py index a6a0c5fa6..d6613479e 100644 --- a/tests/modeldiffs/vanilla_sgd_pytorch.py +++ b/tests/modeldiffs/vanilla_sgd_pytorch.py @@ -1,24 +1,29 @@ import torch from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( + data_selection, +) # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( + update_params, +) # pylint: disable=unused-import -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Vanilla SGD Optimizer.""" del model_state del rng optimizer_state = { - 'optimizer': - torch.optim.SGD(model_params.parameters(), lr=0.001, weight_decay=0), + 'optimizer': torch.optim.SGD( + model_params.parameters(), lr=0.001, weight_decay=0 + ), } return optimizer_state diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 109bfa629..02175c8b5 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -8,17 +8,20 @@ from algoperf import spec from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkload as PyTorchWorkload +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): new_key = [] for i in k: - if 'ModuleList' in i or\ - 'TransformerDecoder_' in i or\ - 'TransformerEncoder_' in i: + if ( + 'ModuleList' in i + or 'TransformerDecoder_' in i + or 'TransformerEncoder_' in i + ): continue if 'Linear' in i: if 'NonDynamicallyQuantizableLinear' in i: @@ -60,7 +63,7 @@ def sd_transform(sd): pass else: if new_key[-2] == 'Dense_0': - #q + # q out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] pass elif new_key[-2] == 'Dense_1': @@ -73,11 +76,11 @@ def sd_transform(sd): out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass out = { - tuple( - k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key + ): value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) @@ -112,29 +115,32 @@ def sd_transform(sd): tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256)) jax_batch = { - 'inputs': inp_tokens.detach().numpy(), - 'targets': tgt_tokens.detach().numpy(), + 'inputs': inp_tokens.detach().numpy(), + 'targets': tgt_tokens.detach().numpy(), } pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 1aa20fe3b..0c834dc86 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -7,19 +7,23 @@ import torch from algoperf import spec -from algoperf.workloads.wmt.wmt_jax.workload import \ - WmtWorkloadAttentionTemp as JaxWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadAttentionTemp as PyTorchWorkload +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkloadAttentionTemp as JaxWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkloadAttentionTemp as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): new_key = [] for i in k: - if 'ModuleList' in i or\ - 'TransformerDecoder_' in i or\ - 'TransformerEncoder_' in i: + if ( + 'ModuleList' in i + or 'TransformerDecoder_' in i + or 'TransformerEncoder_' in i + ): continue if 'Linear' in i: if 'NonDynamicallyQuantizableLinear' in i: @@ -61,7 +65,7 @@ def sd_transform(sd): pass else: if new_key[-2] == 'Dense_0': - #q + # q out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] pass elif new_key[-2] == 'Dense_1': @@ -74,11 +78,11 @@ def sd_transform(sd): out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass out = { - tuple( - k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key + ): value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) @@ -113,29 +117,32 @@ def sd_transform(sd): tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256)) jax_batch = { - 'inputs': inp_tokens.detach().numpy(), - 'targets': tgt_tokens.detach().numpy(), + 'inputs': inp_tokens.detach().numpy(), + 'targets': tgt_tokens.detach().numpy(), } pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index e98a6945d..f7de12326 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -7,19 +7,23 @@ import torch from algoperf import spec -from algoperf.workloads.wmt.wmt_jax.workload import \ - WmtWorkloadGLUTanH as JaxWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadGLUTanH as PyTorchWorkload +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkloadGLUTanH as JaxWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkloadGLUTanH as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): new_key = [] for i in k: - if 'ModuleList' in i or\ - 'TransformerDecoder_' in i or\ - 'TransformerEncoder_' in i: + if ( + 'ModuleList' in i + or 'TransformerDecoder_' in i + or 'TransformerEncoder_' in i + ): continue if 'Linear' in i: if 'NonDynamicallyQuantizableLinear' in i: @@ -61,7 +65,7 @@ def sd_transform(sd): pass else: if new_key[-2] == 'Dense_0': - #q + # q out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] pass elif new_key[-2] == 'Dense_1': @@ -74,11 +78,11 @@ def sd_transform(sd): out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass out = { - tuple( - k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key + ): value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) @@ -113,29 +117,32 @@ def sd_transform(sd): tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256)) jax_batch = { - 'inputs': inp_tokens.detach().numpy(), - 'targets': tgt_tokens.detach().numpy(), + 'inputs': inp_tokens.detach().numpy(), + 'targets': tgt_tokens.detach().numpy(), } pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index d110715b5..a8681ca8e 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -7,19 +7,23 @@ import torch from algoperf import spec -from algoperf.workloads.wmt.wmt_jax.workload import \ - WmtWorkloadPostLN as JaxWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadPostLN as PyTorchWorkload +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkloadPostLN as JaxWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkloadPostLN as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): new_key = [] for i in k: - if 'ModuleList' in i or\ - 'TransformerDecoder_' in i or\ - 'TransformerEncoder_' in i: + if ( + 'ModuleList' in i + or 'TransformerDecoder_' in i + or 'TransformerEncoder_' in i + ): continue if 'Linear' in i: if 'NonDynamicallyQuantizableLinear' in i: @@ -61,7 +65,7 @@ def sd_transform(sd): pass else: if new_key[-2] == 'Dense_0': - #q + # q out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] pass elif new_key[-2] == 'Dense_1': @@ -74,11 +78,11 @@ def sd_transform(sd): out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass out = { - tuple( - k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key + ): value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) @@ -113,29 +117,32 @@ def sd_transform(sd): tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256)) jax_batch = { - 'inputs': inp_tokens.detach().numpy(), - 'targets': tgt_tokens.detach().numpy(), + 'inputs': inp_tokens.detach().numpy(), + 'targets': tgt_tokens.detach().numpy(), } pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 58a4a5ddc..d17848aaf 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -27,67 +27,66 @@ import os import pickle -from absl import flags -from absl import logging -from absl.testing import absltest import flax -from flax import jax_utils -from flax.core.frozen_dict import FrozenDict import jax -from jraph import GraphsTuple import numpy as np import tensorflow as tf import torch import torch.distributed as dist +from absl import flags, logging +from absl.testing import absltest +from flax import jax_utils +from flax.core.frozen_dict import FrozenDict +from jraph import GraphsTuple -from algoperf import halton -from algoperf import pytorch_utils +import submission_runner +from algoperf import halton, pytorch_utils from algoperf import random_utils as prng from algoperf.profiler import PassThroughProfiler from algoperf.workloads import workloads from algoperf.workloads.ogbg import input_pipeline as ogbg_input_pipeline from algoperf.workloads.ogbg.ogbg_pytorch.workload import _graph_map -import submission_runner from tests.modeldiffs import diff as diff_utils flags.DEFINE_integer( - 'global_batch_size', - -1, - ('Global Batch size to use when running an individual workload. Otherwise ' - 'a per-device batch size of 2 is used.')) + 'global_batch_size', + -1, + ( + 'Global Batch size to use when running an individual workload. Otherwise ' + 'a per-device batch size of 2 is used.' + ), +) flags.DEFINE_integer('num_train_steps', 1, 'Number of steps to train.') flags.DEFINE_boolean('use_fake_input_queue', True, 'Use fake data examples.') flags.DEFINE_string('log_file', '/tmp/log.pkl', 'The log file') flags.DEFINE_boolean( - 'all', - False, - 'Run all workloads instead of using --workload and --framework.') -flags.DEFINE_boolean('identical', - False, - 'Run jax and pytorch with identical weights.') + 'all', False, 'Run all workloads instead of using --workload and --framework.' +) +flags.DEFINE_boolean( + 'identical', False, 'Run jax and pytorch with identical weights.' +) FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, PYTORCH_DEVICE, N_GPUS = pytorch_utils.pytorch_setup() tf.config.set_visible_devices([], 'GPU') _EXPECTED_METRIC_NAMES = { - 'cifar': ['train/loss', 'validation/loss', 'test/accuracy'], - 'criteo1tb': ['train/loss', 'validation/loss'], - 'criteo1tb_test': ['train/loss', 'validation/loss'], - 'fastmri': ['train/ssim', 'validation/ssim'], - 'imagenet_resnet': ['train/accuracy', 'validation/accuracy'], - 'imagenet_vit': ['train/accuracy', 'validation/accuracy'], - 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'], - 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'], - 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'], - 'ogbg': [ - 'train/accuracy', 'validation/loss', 'test/mean_average_precision' - ], - 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'], + 'cifar': ['train/loss', 'validation/loss', 'test/accuracy'], + 'criteo1tb': ['train/loss', 'validation/loss'], + 'criteo1tb_test': ['train/loss', 'validation/loss'], + 'fastmri': ['train/ssim', 'validation/ssim'], + 'imagenet_resnet': ['train/accuracy', 'validation/accuracy'], + 'imagenet_vit': ['train/accuracy', 'validation/accuracy'], + 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'], + 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'], + 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'], + 'ogbg': ['train/accuracy', 'validation/loss', 'test/mean_average_precision'], + 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'], } def _make_fake_image_batch(batch_shape, data_shape, num_classes): - examples = np.random.normal(size=(*batch_shape, - *data_shape)).astype(np.float32) + examples = np.random.normal(size=(*batch_shape, *data_shape)).astype( + np.float32 + ) labels = np.random.randint(0, num_classes, size=batch_shape) masks = np.ones(batch_shape, dtype=np.float32) return {'inputs': examples, 'targets': labels, 'weights': masks} @@ -96,16 +95,17 @@ def _make_fake_image_batch(batch_shape, data_shape, num_classes): def _pytorch_map(inputs): if USE_PYTORCH_DDP: return jax.tree.map( - lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs) + lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs + ) return jax.tree.map( - lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1]) - if len(a.shape) == 3 else torch.as_tensor(a, device=PYTORCH_DEVICE).view( - -1), - inputs) + lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1]) + if len(a.shape) == 3 + else torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1), + inputs, + ) class _FakeTokenizer: - def detokenize(self, *args): del args return tf.constant('this is a fake sequence?') @@ -113,15 +113,14 @@ def detokenize(self, *args): @flax.struct.dataclass class _FakeMetricsCollection: - def merge(self, *args): del args return self def compute(self): return { - 'wer': 0.0, - 'ctc_loss': 0.0, + 'wer': 0.0, + 'ctc_loss': 0.0, } def unreplicate(self): @@ -129,7 +128,6 @@ def unreplicate(self): class _FakeMetricsLogger: - def __init__(self): self.filename = FLAGS.log_file self.scalars = [] @@ -152,27 +150,27 @@ def append_eval_metrics(self, result): def save(self): with open(self.filename, 'wb') as f: - pickle.dump({'scalars': self.scalars, 'eval_results': self.eval_results}, - f) + pickle.dump( + {'scalars': self.scalars, 'eval_results': self.eval_results}, f + ) class _FakeMetricsBundle: - def gather_from_model_output(self, *args, **kwargs): del args del kwargs return _FakeMetricsCollection() -def _make_one_batch_workload(workload_class, - workload_name, - framework, - global_batch_size, - use_fake_input_queue, - n_gpus): - +def _make_one_batch_workload( + workload_class, + workload_name, + framework, + global_batch_size, + use_fake_input_queue, + n_gpus, +): class _OneEvalBatchWorkload(workload_class): - def __init__(self): kwargs = {} if 'librispeech' in workload_name: @@ -186,20 +184,27 @@ def __init__(self): def init_model_fn(self, rng): # pylint: disable=line-too-long - if not (FLAGS.identical and - os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')): + if not ( + FLAGS.identical + and os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py') + ): return super().init_model_fn(rng) if framework == 'jax': compare_module = importlib.import_module( - f'tests.modeldiffs.{workload_name}.compare') + f'tests.modeldiffs.{workload_name}.compare' + ) jax_params, model_state, _ = diff_utils.torch2jax( jax_workload=super(), pytorch_workload=compare_module.PyTorchWorkload(**self.init_kwargs), key_transform=compare_module.key_transform, - sd_transform=compare_module.sd_transform) - return (FrozenDict(**jax_utils.replicate(jax_params)), - FrozenDict(**jax_utils.replicate(model_state)) - if model_state is not None else model_state) + sd_transform=compare_module.sd_transform, + ) + return ( + FrozenDict(**jax_utils.replicate(jax_params)), + FrozenDict(**jax_utils.replicate(model_state)) + if model_state is not None + else model_state, + ) return super().init_model_fn([0]) @property @@ -235,73 +240,74 @@ def _build_input_queue(self, *args, **kwargs): else: data_shape = (3, 32, 32) fake_batch = _make_fake_image_batch( - batch_shape, data_shape=data_shape, num_classes=10) + batch_shape, data_shape=data_shape, num_classes=10 + ) elif workload_name == 'criteo1tb' or workload_name == 'criteo1tb_test': targets = np.ones(batch_shape) targets[0] = 0 fake_batch = { - 'inputs': np.ones((*batch_shape, 13 + 26)), - 'targets': targets, - 'weights': np.ones(batch_shape), + 'inputs': np.ones((*batch_shape, 13 + 26)), + 'targets': targets, + 'weights': np.ones(batch_shape), } elif workload_name in ['imagenet_resnet', 'imagenet_vit']: data_shape = (224, 224, 3) fake_batch = _make_fake_image_batch( - batch_shape, data_shape=data_shape, num_classes=1000) + batch_shape, data_shape=data_shape, num_classes=1000 + ) if framework == 'pytorch': num_dims = len(fake_batch['inputs'].shape) fake_batch['inputs'] = fake_batch['inputs'].transpose( - (*range(num_dims - 3), num_dims - 1, num_dims - 3, num_dims - 2)) + (*range(num_dims - 3), num_dims - 1, num_dims - 3, num_dims - 2) + ) elif 'librispeech' in workload_name: rate = 16000 - l = None - while l is None or l.shape[-1] < 320000: + audio_signal = None + while audio_signal is None or audio_signal.shape[-1] < 320000: duration = 0.5 - freq = 2**(np.random.rand(*batch_shape, 1) * 13) + freq = 2 ** (np.random.rand(*batch_shape, 1) * 13) wav = np.sin(2 * np.pi * freq * np.arange(rate * duration) / rate) - if l is None: - l = wav + if audio_signal is None: + audio_signal = wav else: - l = np.concatenate([l, wav], axis=-1) - inputs = l + audio_signal = np.concatenate([audio_signal, wav], axis=-1) + inputs = audio_signal targets = np.random.randint(low=1, high=1024, size=(*batch_shape, 256)) tgt_pad = np.arange(0, 256)[tuple([None] * len(batch_shape))] tgt_lengths = np.random.randint( - low=100, high=256, size=(*batch_shape, 1)) + low=100, high=256, size=(*batch_shape, 1) + ) tgt_pad = 1 * (tgt_pad > tgt_lengths) fake_batch = { - 'inputs': (inputs, np.zeros_like(inputs)), - 'targets': (targets, tgt_pad), + 'inputs': (inputs, np.zeros_like(inputs)), + 'targets': (targets, tgt_pad), } elif workload_name == 'mnist': fake_batch = _make_fake_image_batch( - batch_shape, data_shape=(28, 28, 1), num_classes=10) + batch_shape, data_shape=(28, 28, 1), num_classes=10 + ) elif workload_name == 'ogbg': tf.random.set_seed(5) def _fake_iter(): while True: fake_batch = { - 'num_nodes': - tf.ones((1,), dtype=tf.int64), - 'edge_index': - tf.ones((1, 2), dtype=tf.int64), - 'node_feat': - tf.random.normal((1, 9)), - 'edge_feat': - tf.random.normal((1, 3)), - 'labels': - tf.cast( - tf.random.uniform((self._num_outputs,), - minval=0, - maxval=2, - dtype=tf.int32), - tf.float32), + 'num_nodes': tf.ones((1,), dtype=tf.int64), + 'edge_index': tf.ones((1, 2), dtype=tf.int64), + 'node_feat': tf.random.normal((1, 9)), + 'edge_feat': tf.random.normal((1, 3)), + 'labels': tf.cast( + tf.random.uniform( + (self._num_outputs,), minval=0, maxval=2, dtype=tf.int32 + ), + tf.float32, + ), } yield fake_batch fake_batch_iter = ogbg_input_pipeline._get_batch_iterator( - _fake_iter(), global_batch_size) + _fake_iter(), global_batch_size + ) fake_batch = next(fake_batch_iter) # pylint: disable=stop-iteration-return if framework == 'pytorch': fake_batch['inputs'] = _graph_map(_pytorch_map, fake_batch['inputs']) @@ -310,48 +316,49 @@ def _fake_iter(): elif workload_name == 'wmt': max_len = 256 fake_batch = { - 'inputs': - np.random.randint( - low=0, high=32000, size=(*batch_shape, max_len)), - 'targets': - np.random.randint( - low=0, high=32000, size=(*batch_shape, max_len)), - 'weights': - np.random.randint(low=0, high=2, size=(*batch_shape, max_len)), + 'inputs': np.random.randint( + low=0, high=32000, size=(*batch_shape, max_len) + ), + 'targets': np.random.randint( + low=0, high=32000, size=(*batch_shape, max_len) + ), + 'weights': np.random.randint( + low=0, high=2, size=(*batch_shape, max_len) + ), } self._tokenizer = _FakeTokenizer() elif workload_name == 'fastmri': data_shape = (320, 320) fake_batch = { - 'inputs': - _make_fake_image_batch( - batch_shape, data_shape=data_shape, num_classes=1000) - ['inputs'], - 'targets': - _make_fake_image_batch( - batch_shape, data_shape=data_shape, num_classes=1000) - ['inputs'], - 'mean': - np.zeros(batch_shape), - 'std': - np.ones(batch_shape), - 'volume_max': - np.zeros(batch_shape), - 'weights': - np.ones(batch_shape), + 'inputs': _make_fake_image_batch( + batch_shape, data_shape=data_shape, num_classes=1000 + )['inputs'], + 'targets': _make_fake_image_batch( + batch_shape, data_shape=data_shape, num_classes=1000 + )['inputs'], + 'mean': np.zeros(batch_shape), + 'std': np.ones(batch_shape), + 'volume_max': np.zeros(batch_shape), + 'weights': np.ones(batch_shape), } else: raise ValueError( - 'Workload {} does not have a fake batch defined, you ' - 'can add it or use --use_fake_input_queue=false.'.format( - workload_name)) + 'Workload {} does not have a fake batch defined, you ' + 'can add it or use --use_fake_input_queue=false.'.format( + workload_name + ) + ) if framework == 'pytorch': def to_device(k, v): dtype = ( - torch.long if (k == 'targets' and workload_name != 'fastmri') else - torch.bool if k == 'weights' else torch.float) + torch.long + if (k == 'targets' and workload_name != 'fastmri') + else torch.bool + if k == 'weights' + else torch.float + ) if USE_PYTORCH_DDP: v = v[RANK] return torch.as_tensor(v, device=PYTORCH_DEVICE, dtype=dtype) @@ -387,24 +394,28 @@ def eval_model(self, *args, **kwargs): return _OneEvalBatchWorkload() -def _test_submission(workload_name, - framework, - submission_path, - search_space_path, - data_dir, - use_fake_input_queue, - n_gpus): +def _test_submission( + workload_name, + framework, + submission_path, + search_space_path, + data_dir, + use_fake_input_queue, + n_gpus, +): logging.info(f'========= Testing {workload_name} in {framework}.') FLAGS.framework = framework workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload_name]) workload_metadata['workload_path'] = os.path.join( - submission_runner.BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + '_' + framework, - 'workload.py') + submission_runner.BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + '_' + framework, + 'workload.py', + ) workload_class = workloads.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - return_class=True) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + return_class=True, + ) print(f'Workload class for {workload_name} is {workload_class}') submission_module_path = workloads.convert_filepath_to_module(submission_path) @@ -421,30 +432,32 @@ def _test_submission(workload_name, global_batch_size = FLAGS.global_batch_size if FLAGS.global_batch_size < 0: raise ValueError('Must set --global_batch_size.') - workload = _make_one_batch_workload(workload_class, - workload_name, - framework, - global_batch_size, - use_fake_input_queue, - n_gpus) + workload = _make_one_batch_workload( + workload_class, + workload_name, + framework, + global_batch_size, + use_fake_input_queue, + n_gpus, + ) # Get a sample hyperparameter setting. hyperparameters = {} if search_space_path != 'None': with open(search_space_path, 'r', encoding='UTF-8') as search_space_file: hyperparameters = halton.generate_search( - json.load(search_space_file), num_trials=1)[0] + json.load(search_space_file), num_trials=1 + )[0] rng = prng.PRNGKey(0) data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) input_queue = workload._build_input_queue( - data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size) + data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size + ) model_params, model_state = workload.init_model_fn(model_init_rng) - optimizer_state = init_optimizer_state(workload, - model_params, - model_state, - hyperparameters, - opt_init_rng) + optimizer_state = init_optimizer_state( + workload, model_params, model_state, hyperparameters, opt_init_rng + ) if USE_PYTORCH_DDP: torch.cuda.empty_cache() @@ -452,44 +465,49 @@ def _test_submission(workload_name, for global_step in range(FLAGS.num_train_steps): step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - batch = data_selection(workload, - input_queue, - optimizer_state, - model_params, - model_state, - hyperparameters, - global_step, - data_select_rng) + batch = data_selection( + workload, + input_queue, + optimizer_state, + model_params, + model_state, + hyperparameters, + global_step, + data_select_rng, + ) optimizer_state, model_params, model_state = update_params( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types, - model_state=model_state, - hyperparameters=hyperparameters, - batch=batch, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - train_state={}, - eval_results=[], - global_step=global_step, - rng=update_rng) + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + batch=batch, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + train_state={}, + eval_results=[], + global_step=global_step, + rng=update_rng, + ) eval_result = workload.eval_model( - global_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir=None, - global_step=global_step) - _ = workload.eval_model( global_batch_size, model_params, model_state, eval_rng, data_dir, imagenet_v2_data_dir=None, - global_step=global_step) + global_step=global_step, + ) + _ = workload.eval_model( + global_batch_size, + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir=None, + global_step=global_step, + ) return eval_result @@ -499,12 +517,15 @@ def _make_paths(repo_location, framework, workload_name): else: dataset_name = workload_name workload_dir = ( - f'{repo_location}/reference_algorithms/target_setting_algorithms/' - f'{workload_name}') + f'{repo_location}/reference_algorithms/target_setting_algorithms/' + f'{workload_name}' + ) search_space_path = f'{workload_dir}/tuning_search_space.json' - submission_path = (f'reference_algorithms/target_setting_algorithms/' - f'{workload_name}/{dataset_name}_{framework}/' - 'submission.py') + submission_path = ( + f'reference_algorithms/target_setting_algorithms/' + f'{workload_name}/{dataset_name}_{framework}/' + 'submission.py' + ) full_submission_path = f'{repo_location}/{submission_path}' if not os.path.exists(full_submission_path): return None, None @@ -534,7 +555,8 @@ def test_submission(self): if FLAGS.tuning_search_space: raise ValueError('Cannot set --tuning_search_space and --all.') references_dir = ( - f'{repo_location}/reference_algorithms/target_setting_algorithms') + f'{repo_location}/reference_algorithms/target_setting_algorithms' + ) for workload_name in os.listdir(references_dir): for framework in ['jax', 'pytorch']: if framework == 'pytorch': @@ -542,17 +564,19 @@ def test_submission(self): # First jax operation has to be called after pytorch_init. n_gpus = max(N_GPUS, jax.local_device_count()) search_space_path, submission_path = _make_paths( - repo_location, framework, workload_name) + repo_location, framework, workload_name + ) if search_space_path is None: continue eval_result = _test_submission( - workload_name, - framework, - submission_path, - search_space_path, - data_dir=FLAGS.data_dir, - use_fake_input_queue=FLAGS.use_fake_input_queue, - n_gpus=n_gpus) + workload_name, + framework, + submission_path, + search_space_path, + data_dir=FLAGS.data_dir, + use_fake_input_queue=FLAGS.use_fake_input_queue, + n_gpus=n_gpus, + ) self._assert_eval_result(workload_name, eval_result) else: framework = FLAGS.framework @@ -566,15 +590,17 @@ def test_submission(self): submission_path = FLAGS.submission_path else: search_space_path, submission_path = _make_paths( - repo_location, framework, workload_name) + repo_location, framework, workload_name + ) eval_result = _test_submission( - workload_name, - framework, - submission_path, - search_space_path, - data_dir=FLAGS.data_dir, - use_fake_input_queue=FLAGS.use_fake_input_queue, - n_gpus=n_gpus) + workload_name, + framework, + submission_path, + search_space_path, + data_dir=FLAGS.data_dir, + use_fake_input_queue=FLAGS.use_fake_input_queue, + n_gpus=n_gpus, + ) self._assert_eval_result(workload_name, eval_result) if USE_PYTORCH_DDP: diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py index ff724b201..b9beb9101 100644 --- a/tests/submission_runner_test.py +++ b/tests/submission_runner_test.py @@ -4,6 +4,7 @@ dataset to be available. For testing the workload and reference submission code for all workloads, see reference_algorithm_tests.py. """ + import copy import os import sys @@ -28,47 +29,46 @@ class SubmissionRunnerTest(parameterized.TestCase): """Tests for reference submissions.""" @parameterized.named_parameters( - dict( - testcase_name='mnist_jax', - workload='mnist', - framework='jax', - submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_jax/submission.py'), - tuning_search_space=( - f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json')), - dict( - testcase_name='mnist_pytorch', - workload='mnist', - framework='pytorch', - submission_path=( - f'{_MNIST_DEV_ALGO_DIR}/mnist_pytorch/submission.py'), - tuning_search_space=( - f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json')), + dict( + testcase_name='mnist_jax', + workload='mnist', + framework='jax', + submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_jax/submission.py'), + tuning_search_space=(f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'), + ), + dict( + testcase_name='mnist_pytorch', + workload='mnist', + framework='pytorch', + submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_pytorch/submission.py'), + tuning_search_space=(f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'), + ), ) - def test_submission(self, - workload, - framework, - submission_path, - tuning_search_space): + def test_submission( + self, workload, framework, submission_path, tuning_search_space + ): FLAGS.framework = framework workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload]) workload_metadata['workload_path'] = os.path.join( - submission_runner.BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + '_' + framework, - 'workload.py') + submission_runner.BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + '_' + framework, + 'workload.py', + ) workload_obj = submission_runner.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - workload_init_kwargs={}) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs={}, + ) score = submission_runner.score_submission_on_workload( - workload_obj, - workload, - submission_path, - data_dir='~/tensorflow_datasets', # The default in TFDS. - tuning_ruleset='external', - tuning_search_space=tuning_search_space, - num_tuning_trials=1, - profiler=PassThroughProfiler(), - max_global_steps=500, + workload_obj, + workload, + submission_path, + data_dir='~/tensorflow_datasets', # The default in TFDS. + tuning_ruleset='external', + tuning_search_space=tuning_search_space, + num_tuning_trials=1, + profiler=PassThroughProfiler(), + max_global_steps=500, ) logging.info(score) diff --git a/tests/test_baselines.py b/tests/test_baselines.py index b2be8aa11..9ebc50222 100644 --- a/tests/test_baselines.py +++ b/tests/test_baselines.py @@ -1,8 +1,9 @@ """Tests for submission.py for baselines. -This is an end-to-end test for all baselines on MNIST in PyTorch and Jax that +This is an end-to-end test for all baselines on MNIST in PyTorch and Jax that requires the dataset to be available. """ + import copy import os import sys @@ -24,41 +25,42 @@ MAX_GLOBAL_STEPS = 5 baselines = { - 'jax': [ - 'adafactor', - 'adamw', - 'lamb', - 'momentum', - 'nadamw', - 'nesterov', - 'sam', - 'shampoo', - ], - 'pytorch': [ - 'adamw', - 'momentum', - 'nadamw', - 'nesterov', - ], + 'jax': [ + 'adafactor', + 'adamw', + 'lamb', + 'momentum', + 'nadamw', + 'nesterov', + 'sam', + 'shampoo', + ], + 'pytorch': [ + 'adamw', + 'momentum', + 'nadamw', + 'nesterov', + ], } frameworks = [ - 'pytorch', - 'jax', + 'pytorch', + 'jax', ] -baseline_path = "reference_algorithms/paper_baselines" +baseline_path = 'reference_algorithms/paper_baselines' named_parameters = [] for f in frameworks: for b in baselines[f]: named_parameters.append( - dict( - testcase_name=f'{b}_{f}', - workload='mnist', - framework=f'{f}', - submission_path=f'{baseline_path}/{b}/{f}/submission.py', - tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json') + dict( + testcase_name=f'{b}_{f}', + workload='mnist', + framework=f'{f}', + submission_path=f'{baseline_path}/{b}/{f}/submission.py', + tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json', + ) ) @@ -66,31 +68,31 @@ class BaselineTest(parameterized.TestCase): """Tests for reference submissions.""" @parameterized.named_parameters(*named_parameters) - def test_baseline_submission(self, - workload, - framework, - submission_path, - tuning_search_space): + def test_baseline_submission( + self, workload, framework, submission_path, tuning_search_space + ): FLAGS.framework = framework workload_metadata = copy.deepcopy(workloads.WORKLOADS[workload]) workload_metadata['workload_path'] = os.path.join( - workloads.BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + '_' + framework, - 'workload.py') + workloads.BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + '_' + framework, + 'workload.py', + ) workload_obj = workloads.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - workload_init_kwargs={}) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs={}, + ) score = submission_runner.score_submission_on_workload( - workload_obj, - workload, - submission_path, - data_dir='~/tensorflow_datasets', # The default in TFDS. - tuning_ruleset='external', - tuning_search_space=tuning_search_space, - num_tuning_trials=1, - profiler=PassThroughProfiler(), - max_global_steps=MAX_GLOBAL_STEPS, + workload_obj, + workload, + submission_path, + data_dir='~/tensorflow_datasets', # The default in TFDS. + tuning_ruleset='external', + tuning_search_space=tuning_search_space, + num_tuning_trials=1, + profiler=PassThroughProfiler(), + max_global_steps=MAX_GLOBAL_STEPS, ) logging.info(score) diff --git a/tests/test_num_params.py b/tests/test_num_params.py index b0633025e..9361f4c72 100644 --- a/tests/test_num_params.py +++ b/tests/test_num_params.py @@ -5,48 +5,59 @@ import pytest import torch -from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \ - DlrmSmall as JaxDlrmSmall -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \ - DlrmSmall as PyTorchDlrmSmall -from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ - ResNet18 as JaxResNet_c10 -from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ - ResNet50 as JaxResNet -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - resnet18 as PyTorchResNet_c10 -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - resnet50 as PyTorchResNet +from algoperf.workloads.criteo1tb.criteo1tb_jax.models import ( + DlrmSmall as JaxDlrmSmall, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import ( + DlrmSmall as PyTorchDlrmSmall, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ( + ResNet18 as JaxResNet_c10, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ( + ResNet50 as JaxResNet, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import ( + resnet18 as PyTorchResNet_c10, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import ( + resnet50 as PyTorchResNet, +) from algoperf.workloads.imagenet_vit.imagenet_jax.models import ViT as JaxViT -from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ - ViT as PyTorchViT -from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ - Conformer as JaxConformer -from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ - ConformerConfig as JaxConformerConfig -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - ConformerConfig as PytorchConformerConfig -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - ConformerEncoderDecoder as PytorchConformer +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import ( + ViT as PyTorchViT, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import ( + Conformer as JaxConformer, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import ( + ConformerConfig as JaxConformerConfig, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + ConformerConfig as PytorchConformerConfig, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + ConformerEncoderDecoder as PytorchConformer, +) from algoperf.workloads.mnist.mnist_jax.workload import _Model as JaxMLP -from algoperf.workloads.mnist.mnist_pytorch.workload import \ - _Model as PyTorchMLP +from algoperf.workloads.mnist.mnist_pytorch.workload import _Model as PyTorchMLP from algoperf.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as PyTorchGNN from algoperf.workloads.wmt.wmt_jax.models import Transformer as JaxTransformer from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig -from algoperf.workloads.wmt.wmt_pytorch.models import \ - Transformer as PyTorchTransformer +from algoperf.workloads.wmt.wmt_pytorch.models import ( + Transformer as PyTorchTransformer, +) WORKLOADS = [ - 'mnist', - 'cifar', - 'criteo1tb', - 'imagenet_resnet', - 'imagenet_vit', - 'wmt', - 'ogbg', - 'librispeech_conformer', + 'mnist', + 'cifar', + 'criteo1tb', + 'imagenet_resnet', + 'imagenet_vit', + 'wmt', + 'ogbg', + 'librispeech_conformer', ] @@ -56,7 +67,8 @@ def test_matching_num_params(workload): # Count parameters of both models. num_jax_params = sum(x.size for x in jax.tree_util.tree_leaves(jax_model)) num_pytorch_params = sum( - p.numel() for p in pytorch_model.parameters() if p.requires_grad) + p.numel() for p in pytorch_model.parameters() if p.requires_grad + ) assert num_jax_params == num_pytorch_params @@ -72,8 +84,9 @@ def get_models(workload): # Init Jax model. input_shape = (1, 32, 32, 3) model_init = jax.jit(JaxResNet_c10(num_classes=10, dtype=jnp.float32).init) - jax_model = model_init(init_rngs, jnp.ones(input_shape, - jnp.float32))["params"] + jax_model = model_init(init_rngs, jnp.ones(input_shape, jnp.float32))[ + 'params' + ] # Init PyTorch model. pytorch_model = PyTorchResNet_c10(num_classes=10) @@ -85,35 +98,38 @@ def get_models(workload): vocab_size = 32 * 128 * 1024 input_shape = (1, 39) model_init = JaxDlrmSmall( - vocab_size=vocab_size, - num_dense_features=13, - mlp_bottom_dims=mlp_bottom_dims, - mlp_top_dims=mlp_top_dims, - embed_dim=embed_dim).init - jax_model = model_init(init_rngs, jnp.ones(input_shape, jnp.float32), - False)['params'] + vocab_size=vocab_size, + num_dense_features=13, + mlp_bottom_dims=mlp_bottom_dims, + mlp_top_dims=mlp_top_dims, + embed_dim=embed_dim, + ).init + jax_model = model_init( + init_rngs, jnp.ones(input_shape, jnp.float32), False + )['params'] # Init PyTorch model. pytorch_model = PyTorchDlrmSmall( - vocab_size=vocab_size, - num_dense_features=13, - mlp_bottom_dims=mlp_bottom_dims, - mlp_top_dims=mlp_top_dims, - embed_dim=embed_dim) + vocab_size=vocab_size, + num_dense_features=13, + mlp_bottom_dims=mlp_bottom_dims, + mlp_top_dims=mlp_top_dims, + embed_dim=embed_dim, + ) elif workload == 'imagenet_resnet': # Init Jax model. input_shape = (1, 224, 224, 3) - jax_model = JaxResNet( - num_classes=1000, - dtype=jnp.float32).init(init_rngs, jnp.ones(input_shape, - jnp.float32))['params'] + jax_model = JaxResNet(num_classes=1000, dtype=jnp.float32).init( + init_rngs, jnp.ones(input_shape, jnp.float32) + )['params'] # Init PyTorch model. pytorch_model = PyTorchResNet() elif workload == 'imagenet_vit': # Init Jax model. input_shape = (1, 224, 224, 3) jax_model = JaxViT(num_classes=1000).init( - init_rngs, jnp.ones(input_shape, jnp.float32))['params'] + init_rngs, jnp.ones(input_shape, jnp.float32) + )['params'] # Init PyTorch model. pytorch_model = PyTorchViT() elif workload == 'librispeech_conformer': @@ -123,8 +139,9 @@ def get_models(workload): # Init Jax model input_shape = [(320000,), (320000,)] fake_input_batch = [jnp.zeros((2, *x), jnp.float32) for x in input_shape] - jax_model = jax_model.init( - init_rngs, train=False, *fake_input_batch)["params"] + jax_model = jax_model.init(init_rngs, train=False, *fake_input_batch)[ + 'params' + ] # Run model once to initialize lazy layers wave = torch.randn(2, 320000) @@ -136,23 +153,26 @@ def get_models(workload): input_shape = (16, 256) target_shape = (16, 256) jax_model = JaxTransformer(TransformerConfig).init( - init_rngs, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32))['params'] + init_rngs, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + )['params'] # Init PyTorch model. pytorch_model = PyTorchTransformer() elif workload == 'ogbg': # Init Jax model. fake_batch = jraph.GraphsTuple( - n_node=jnp.asarray([1]), - n_edge=jnp.asarray([1]), - nodes=jnp.ones((1, 9)), - edges=jnp.ones((1, 3)), - globals=jnp.zeros((1, 128)), - senders=jnp.asarray([0]), - receivers=jnp.asarray([0])) + n_node=jnp.asarray([1]), + n_edge=jnp.asarray([1]), + nodes=jnp.ones((1, 9)), + edges=jnp.ones((1, 3)), + globals=jnp.zeros((1, 128)), + senders=jnp.asarray([0]), + receivers=jnp.asarray([0]), + ) jax_model = JaxGNN(num_outputs=128).init( - init_rngs, fake_batch, train=False)['params'] + init_rngs, fake_batch, train=False + )['params'] # Init PyTorch model. pytorch_model = PyTorchGNN(num_outputs=128) else: diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index df4c798d8..1badd39ed 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -7,40 +7,80 @@ # isort: skip_file # pylint:disable=line-too-long -from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload -from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload -from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload -from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload -from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload -from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload -from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload +from algoperf.workloads.cifar.cifar_jax.workload import ( + CifarWorkload as JaxCifarWorkload, +) +from algoperf.workloads.cifar.cifar_pytorch.workload import ( + CifarWorkload as PyTorchCifarWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload, +) +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRIWorkload as JaxFastMRIWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRIWorkload as PyTorchFastMRIWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload as JaxImagenetResNetWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetWorkload as PyTorchImagenetResNetWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitWorkload as JaxImagenetViTWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitWorkload as PyTorchImagenetViTWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload, +) +from algoperf.workloads.mnist.mnist_jax.workload import ( + MnistWorkload as JaxMnistWorkload, +) +from algoperf.workloads.mnist.mnist_pytorch.workload import ( + MnistWorkload as PyTorchMnistWorkload, +) +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgWorkload as JaxOgbgWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgWorkload as PyTorchOgbgWorkload, +) +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkload as JaxWmtWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkload as PyTorchWmtWorkload, +) # pylint:enable=line-too-long WORKLOADS = [ - 'cifar', - 'criteo1tb', - 'fastmri', - 'imagenet_resnet', - 'imagenet_vit', - # TODO: make tests work for these. - # 'librispeech_conformer', - # 'librispeech_deepspeech', - 'mnist', - 'ogbg', - 'wmt', + 'cifar', + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + # TODO: make tests work for these. + # 'librispeech_conformer', + # 'librispeech_deepspeech', + 'mnist', + 'ogbg', + 'wmt', ] @@ -56,9 +96,11 @@ def test_param_shapes(workload): if isinstance(jax_workload_param_shapes, dict): jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes) jax_param_shapes = jax.tree_util.tree_leaves( - jax_workload_param_shapes.unfreeze()) + jax_workload_param_shapes.unfreeze() + ) pytorch_param_shapes = jax.tree_util.tree_leaves( - pytorch_workload.param_shapes) + pytorch_workload.param_shapes + ) if workload == 'wmt': # The PyTorch transformer for WMT is implemented with fused linear layers # for the projection of QKV inside of the MultiheadAttention module. @@ -74,8 +116,9 @@ def test_param_shapes(workload): # Check if total number of params deduced from shapes match. num_jax_params = 0 num_pytorch_params = 0 - for jax_shape, pytorch_shape in zip_longest(jax_param_shapes, - pytorch_param_shapes): + for jax_shape, pytorch_shape in zip_longest( + jax_param_shapes, pytorch_param_shapes + ): if jax_shape is not None: num_jax_params += np.prod(jax_shape.shape_tuple) if pytorch_shape is not None: diff --git a/tests/test_param_types.py b/tests/test_param_types.py index d3722ae86..1583342ff 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -6,39 +6,79 @@ # isort: skip_file # pylint:disable=line-too-long -from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload -from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload -from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload -from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload -from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload -from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload -from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload +from algoperf.workloads.cifar.cifar_jax.workload import ( + CifarWorkload as JaxCifarWorkload, +) +from algoperf.workloads.cifar.cifar_pytorch.workload import ( + CifarWorkload as PyTorchCifarWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload, +) +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRIWorkload as JaxFastMRIWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRIWorkload as PyTorchFastMRIWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload as JaxImagenetResNetWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetWorkload as PyTorchImagenetResNetWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitWorkload as JaxImagenetViTWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitWorkload as PyTorchImagenetViTWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload, +) +from algoperf.workloads.mnist.mnist_jax.workload import ( + MnistWorkload as JaxMnistWorkload, +) +from algoperf.workloads.mnist.mnist_pytorch.workload import ( + MnistWorkload as PyTorchMnistWorkload, +) +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgWorkload as JaxOgbgWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgWorkload as PyTorchOgbgWorkload, +) +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkload as JaxWmtWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkload as PyTorchWmtWorkload, +) # pylint:enable=line-too-long WORKLOADS = [ - 'cifar', - 'criteo1tb', - 'fastmri', - 'imagenet_resnet', - 'imagenet_vit', - 'librispeech_conformer', - 'librispeech_deepspeech', - 'mnist', - 'ogbg', - 'wmt', + 'cifar', + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'mnist', + 'ogbg', + 'wmt', ] @@ -66,40 +106,32 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict): # Sometimes one framework will implement QKV as a single parameter, so we need # to make sure there are the same number of QKV params as Q, K, V. num_qkv = { - 'jax': - jax_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0), - 'pytorch': - pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0), + 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0), + 'pytorch': pytorch_param_types_dict.get( + spec.ParameterType.ATTENTION_QKV, 0 + ), } num_kv = { - 'jax': - jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0), - 'pytorch': - pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0), + 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0), + 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0), } num_q = { - 'jax': - jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0), - 'pytorch': - pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0), + 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0), + 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0), } num_k = { - 'jax': - jax_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0), - 'pytorch': - pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0), + 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0), + 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0), } num_v = { - 'jax': - jax_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0), - 'pytorch': - pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0), + 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0), + 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0), } num_bias = { - 'jax': - jax_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0), - 'pytorch': - pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0), + 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0), + 'pytorch': pytorch_param_types_dict.get( + spec.ParameterType.ATTENTION_BIAS, 0 + ), } qkv_match = num_qkv['jax'] == num_qkv['pytorch'] kv_match = num_kv['jax'] == num_kv['pytorch'] @@ -108,24 +140,33 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict): v_match = num_v['jax'] == num_v['pytorch'] bias_match = num_bias['jax'] == num_bias['pytorch'] qkv_match = ( - qkv_match and kv_match and q_match and k_match and v_match and bias_match) + qkv_match and kv_match and q_match and k_match and v_match and bias_match + ) # We subtract 2 * num_qkv from the number of biases because there are 2 # missing for each of q, k, v. - jax_qkv_match = ( - num_q['pytorch'] == num_k['pytorch'] == num_v['pytorch'] == num_qkv['jax'] - and (num_qkv['jax'] != 0 and - (num_bias['pytorch'] - 2 * num_qkv['jax']) == num_bias['jax'])) - pytorch_qkv_match = ( - num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv['pytorch'] and - (num_qkv['pytorch'] != 0 and - (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch'])) + jax_qkv_match = num_q['pytorch'] == num_k['pytorch'] == num_v[ + 'pytorch' + ] == num_qkv['jax'] and ( + num_qkv['jax'] != 0 + and (num_bias['pytorch'] - 2 * num_qkv['jax']) == num_bias['jax'] + ) + pytorch_qkv_match = num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv[ + 'pytorch' + ] and ( + num_qkv['pytorch'] != 0 + and (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch'] + ) pytorch_kv_match = ( - num_q['jax'] == num_k['jax'] == num_v['jax'] == - num_qkv['pytorch'] + num_kv['pytorch'] and - num_q['pytorch'] == num_kv['pytorch']) + num_q['jax'] + == num_k['jax'] + == num_v['jax'] + == num_qkv['pytorch'] + num_kv['pytorch'] + and num_q['pytorch'] == num_kv['pytorch'] + ) qkv_match = ( - qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match) + qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match + ) return qkv_match @@ -137,7 +178,8 @@ def test_param_types(workload_name): # Compare number of parameter tensors of both models. jax_param_types = jax.tree_util.tree_leaves(jax_workload.model_params_types) pytorch_param_types = jax.tree_util.tree_leaves( - pytorch_workload.model_params_types) + pytorch_workload.model_params_types + ) jax_param_types_dict = count_param_types(jax_param_types) pytorch_param_types_dict = count_param_types(pytorch_param_types) @@ -161,30 +203,33 @@ def test_param_types(workload_name): # Check if total number of each type match. attention_keys = { - spec.ParameterType.ATTENTION_QKV, - spec.ParameterType.ATTENTION_KV, - spec.ParameterType.ATTENTION_Q, - spec.ParameterType.ATTENTION_K, - spec.ParameterType.ATTENTION_V, - spec.ParameterType.ATTENTION_BIAS, + spec.ParameterType.ATTENTION_QKV, + spec.ParameterType.ATTENTION_KV, + spec.ParameterType.ATTENTION_Q, + spec.ParameterType.ATTENTION_K, + spec.ParameterType.ATTENTION_V, + spec.ParameterType.ATTENTION_BIAS, } non_attention_keys = set(jax_param_types_dict.keys()).union( - set(pytorch_param_types_dict.keys())) + set(pytorch_param_types_dict.keys()) + ) non_attention_keys -= attention_keys mismatches = '' - mismatches += _count_mismatches(jax_param_types_dict, - pytorch_param_types_dict, - non_attention_keys) - qkv_match = _check_attention_qkv_match(jax_param_types_dict, - pytorch_param_types_dict) + mismatches += _count_mismatches( + jax_param_types_dict, pytorch_param_types_dict, non_attention_keys + ) + qkv_match = _check_attention_qkv_match( + jax_param_types_dict, pytorch_param_types_dict + ) if not qkv_match: - mismatches += _count_mismatches(jax_param_types_dict, - pytorch_param_types_dict, - attention_keys) + mismatches += _count_mismatches( + jax_param_types_dict, pytorch_param_types_dict, attention_keys + ) if mismatches: raise ValueError( - f'On workload {workload_name}, count mismatch: {mismatches}') + f'On workload {workload_name}, count mismatch: {mismatches}' + ) def get_workload(workload_name): diff --git a/tests/test_ssim.py b/tests/test_ssim.py index 920556964..7d730c251 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -10,13 +10,14 @@ import torch from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.fastmri.fastmri_jax.ssim import \ - _uniform_filter as _jax_uniform_filter +from algoperf.workloads.fastmri.fastmri_jax.ssim import ( + _uniform_filter as _jax_uniform_filter, +) from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim as jax_ssim -from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ - _uniform_filter as _pytorch_uniform_filter -from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ - ssim as pytorch_ssim +from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ( + _uniform_filter as _pytorch_uniform_filter, +) +from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim as pytorch_ssim # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' @@ -31,7 +32,7 @@ def _create_fake_im(height: int, width: int) -> Tuple[jnp.array, torch.Tensor]: def _create_fake_batch( - batch_size: int, height: int, width: int + batch_size: int, height: int, width: int ) -> Tuple[Tuple[jnp.array, jnp.array], Tuple[torch.Tensor, torch.Tensor]]: logits = np.random.randn(batch_size, height, width) targets = np.random.randn(batch_size, height, width) @@ -47,9 +48,9 @@ class SSIMTest(parameterized.TestCase): and PyTorch.""" @parameterized.named_parameters( - dict(testcase_name='fastmri_im', height=320, width=320), - dict(testcase_name='uneven_even_im', height=31, width=16), - dict(testcase_name='even_uneven_im', height=42, width=53), + dict(testcase_name='fastmri_im', height=320, width=320), + dict(testcase_name='uneven_even_im', height=31, width=16), + dict(testcase_name='even_uneven_im', height=42, width=53), ) def test_uniform_filter(self, height: int, width: int) -> None: jax_im, pytorch_im = _create_fake_im(height, width) @@ -58,12 +59,9 @@ def test_uniform_filter(self, height: int, width: int) -> None: assert np.allclose(jax_result, torch_result, atol=1e-6) @parameterized.named_parameters( - dict( - testcase_name='fastmri_batch', batch_size=256, height=320, width=320), - dict( - testcase_name='uneven_even_batch', batch_size=8, height=31, width=16), - dict( - testcase_name='even_uneven_batch', batch_size=8, height=42, width=53), + dict(testcase_name='fastmri_batch', batch_size=256, height=320, width=320), + dict(testcase_name='uneven_even_batch', batch_size=8, height=31, width=16), + dict(testcase_name='even_uneven_batch', batch_size=8, height=42, width=53), ) def test_ssim(self, batch_size: int, height: int, width: int) -> None: jax_inputs, pytorch_inputs = _create_fake_batch(batch_size, height, width) @@ -71,9 +69,8 @@ def test_ssim(self, batch_size: int, height: int, width: int) -> None: pytorch_ssim_result = pytorch_ssim(*pytorch_inputs) self.assertEqual(jax_ssim_result.shape, pytorch_ssim_result.shape) assert np.allclose( - jax_ssim_result.sum().item(), - pytorch_ssim_result.sum().item(), - atol=1e-6) + jax_ssim_result.sum().item(), pytorch_ssim_result.sum().item(), atol=1e-6 + ) if __name__ == '__main__': diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index cea589202..b1982a3bf 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -3,6 +3,7 @@ Run it as: python3 test_traindiffs.py """ + import pickle import subprocess from subprocess import DEVNULL @@ -17,14 +18,14 @@ FLAGS = flags.FLAGS WORKLOADS = [ - 'imagenet_resnet', - 'imagenet_vit', - 'wmt', - 'librispeech_conformer', - 'librispeech_deepspeech', - 'fastmri', - 'ogbg', - 'criteo1tb' + 'imagenet_resnet', + 'imagenet_vit', + 'wmt', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'fastmri', + 'ogbg', + 'criteo1tb', ] GLOBAL_BATCH_SIZE = 16 NUM_TRAIN_STEPS = 10 @@ -35,7 +36,6 @@ class ModelDiffTest(parameterized.TestCase): - @parameterized.named_parameters(*named_parameters) def test_workload(self, workload): # pylint: disable=line-too-long, unnecessary-lambda-assignment @@ -50,24 +50,26 @@ def test_workload(self, workload): pytorch_logs_path = '/tmp/pyt_log.pkl' try: run( - f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs_path}' - f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', - shell=True, - stdout=DEVNULL, - stderr=STDOUT, - check=True) + f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs_path}' + f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', + shell=True, + stdout=DEVNULL, + stderr=STDOUT, + check=True, + ) except subprocess.CalledProcessError as e: - print("Error:", e) + print('Error:', e) try: run( - f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pytorch_logs_path}' - f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', - shell=True, - stdout=DEVNULL, - stderr=STDOUT, - check=True) + f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pytorch_logs_path}' + f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', + shell=True, + stdout=DEVNULL, + stderr=STDOUT, + check=True, + ) except subprocess.CalledProcessError as e: - print("Error:", e) + print('Error:', e) with open(jax_logs_path, 'rb') as f: jax_results = pickle.load(f) with open(pytorch_logs_path, 'rb') as f: @@ -75,17 +77,20 @@ def test_workload(self, workload): # PRINT RESULTS eval_metric_key = next( - iter( - filter(lambda k: 'train' in k and 'loss' in k, - jax_results['eval_results'][0]))) + iter( + filter( + lambda k: 'train' in k and 'loss' in k, jax_results['eval_results'][0] + ) + ) + ) header = [ - 'Iter', - 'Eval (jax)', - 'Eval (torch)', - 'Grad Norm (jax)', - 'Grad Norm (torch)', - 'Train Loss (jax)', - 'Train Loss (torch)', + 'Iter', + 'Eval (jax)', + 'Eval (torch)', + 'Grad Norm (jax)', + 'Grad Norm (torch)', + 'Train Loss (jax)', + 'Train Loss (torch)', ] fmt = lambda l: '|' + '|'.join(map(lambda x: f'{x:^20s}', l)) + '|' header = fmt(header) @@ -97,33 +102,41 @@ def test_workload(self, workload): for i in range(NUM_TRAIN_STEPS): rtol = 1 - row = map(lambda x: str(round(x, 5)), - [ - jax_results['eval_results'][i][eval_metric_key], - pytorch_results['eval_results'][i][eval_metric_key], - jax_results['scalars'][i]['grad_norm'], - pytorch_results['scalars'][i]['grad_norm'], - jax_results['scalars'][i]['loss'], - pytorch_results['scalars'][i]['loss'], - ]) + row = map( + lambda x: str(round(x, 5)), + [ + jax_results['eval_results'][i][eval_metric_key], + pytorch_results['eval_results'][i][eval_metric_key], + jax_results['scalars'][i]['grad_norm'], + pytorch_results['scalars'][i]['grad_norm'], + jax_results['scalars'][i]['loss'], + pytorch_results['scalars'][i]['loss'], + ], + ) print(fmt([f'{i}', *row])) print('=' * len(header)) self.assertTrue( # eval_results - allclose( - jax_results['eval_results'][i][eval_metric_key], - pytorch_results['eval_results'][i][eval_metric_key], - rtol=rtol)) + allclose( + jax_results['eval_results'][i][eval_metric_key], + pytorch_results['eval_results'][i][eval_metric_key], + rtol=rtol, + ) + ) self.assertTrue( # grad_norms - allclose( - jax_results['scalars'][i]['grad_norm'], - pytorch_results['scalars'][i]['grad_norm'], - rtol=rtol)) + allclose( + jax_results['scalars'][i]['grad_norm'], + pytorch_results['scalars'][i]['grad_norm'], + rtol=rtol, + ) + ) self.assertTrue( # loss - allclose( - jax_results['scalars'][i]['loss'], - pytorch_results['scalars'][i]['loss'], - rtol=rtol)) + allclose( + jax_results['scalars'][i]['loss'], + pytorch_results['scalars'][i]['loss'], + rtol=rtol, + ) + ) if __name__ == '__main__': diff --git a/tests/test_version.py b/tests/test_version.py index d1bfbd18f..69384953a 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -6,10 +6,10 @@ def test_version_attribute(): """Check whether __version__ exists and is a valid string.""" - assert hasattr(algoperf, "__version__") + assert hasattr(algoperf, '__version__') version = algoperf.__version__ assert isinstance(version, str) - version_elements = version.split(".") + version_elements = version.split('.') print(version_elements) # Only check the first two elements, i.e. major, minor # (patch is not checked as it is not required). diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py index d44234927..3d06c9839 100644 --- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py +++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py @@ -5,8 +5,9 @@ import jax.numpy as jnp from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload, +) def _pytree_total_diff(pytree_a, pytree_b): @@ -32,42 +33,48 @@ def test_forward_pass(self): # this function because we call it with a different combination of those two # args each time. Can't call with kwargs. pmapped_model_fn = jax.pmap( - workload.model_fn, - axis_name='batch', - in_axes=(0, 0, 0, None, None, None), - static_broadcasted_argnums=(3, 5)) + workload.model_fn, + axis_name='batch', + in_axes=(0, 0, 0, None, None, None), + static_broadcasted_argnums=(3, 5), + ) logits, updated_batch_stats = pmapped_model_fn( - model_params, - {'inputs': first_input_batch}, - batch_stats, - spec.ForwardPassMode.TRAIN, - rng, - True) + model_params, + {'inputs': first_input_batch}, + batch_stats, + spec.ForwardPassMode.TRAIN, + rng, + True, + ) self.assertEqual(logits.shape, expected_logits_shape) # Test that batch stats are updated. self.assertNotEqual( - _pytree_total_diff(batch_stats, updated_batch_stats), 0.0) + _pytree_total_diff(batch_stats, updated_batch_stats), 0.0 + ) second_input_batch = jax.random.normal(data_rngs[1], shape=input_shape) # Test that batch stats are not updated when we say so. _, same_batch_stats = pmapped_model_fn( - model_params, - {'inputs': second_input_batch}, - updated_batch_stats, - spec.ForwardPassMode.TRAIN, - rng, - False) + model_params, + {'inputs': second_input_batch}, + updated_batch_stats, + spec.ForwardPassMode.TRAIN, + rng, + False, + ) self.assertEqual( - _pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0) + _pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0 + ) # Test eval model. logits, _ = pmapped_model_fn( - model_params, - {'inputs': second_input_batch}, - batch_stats, - spec.ForwardPassMode.EVAL, - rng, - False) + model_params, + {'inputs': second_input_batch}, + batch_stats, + spec.ForwardPassMode.EVAL, + rng, + False, + ) self.assertEqual(logits.shape, expected_logits_shape) From f02671100980e742464c273a0825df9c4d99f264 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:20:18 +0200 Subject: [PATCH 105/123] Format prize_qualification_baselines/ --- .../external_tuning/jax_nadamw_full_budget.py | 269 ++++++++-------- .../jax_nadamw_target_setting.py | 269 ++++++++-------- .../pytorch_nadamw_full_budget.py | 274 +++++++++-------- .../pytorch_nadamw_target_setting.py | 274 +++++++++-------- .../self_tuning/jax_nadamw_full_budget.py | 281 +++++++++-------- .../self_tuning/jax_nadamw_target_setting.py | 281 +++++++++-------- .../self_tuning/pytorch_nadamw_full_budget.py | 286 ++++++++++-------- .../pytorch_nadamw_target_setting.py | 286 ++++++++++-------- 8 files changed, 1224 insertions(+), 996 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index c451a18ac..aa1a08f69 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -4,15 +4,17 @@ # isort: off # We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) # isort: on import chex @@ -30,15 +32,14 @@ # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -73,19 +74,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -124,7 +128,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -132,6 +137,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. mu: optax.Updates nu: optax.Updates @@ -140,7 +146,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -156,11 +163,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -170,101 +179,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -281,37 +304,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -351,14 +380,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index b8ac10f33..f1d7d62e0 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -4,15 +4,17 @@ # isort: off # We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) # isort: on import chex @@ -30,15 +32,14 @@ # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -73,19 +74,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -124,7 +128,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -132,6 +137,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. mu: optax.Updates nu: optax.Updates @@ -140,7 +146,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -156,11 +163,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -170,101 +179,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -281,37 +304,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -351,14 +380,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index a2f9fb4c5..ecd299988 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -21,33 +21,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -59,7 +56,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -67,7 +67,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -76,9 +77,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -107,51 +108,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -189,54 +196,59 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -248,26 +260,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -280,7 +296,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -289,31 +306,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -353,14 +377,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index a37b0d341..0d8054135 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -21,33 +21,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -59,7 +56,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -67,7 +67,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -76,9 +77,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -107,51 +108,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -189,54 +196,59 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -248,26 +260,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -280,7 +296,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -289,31 +306,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -353,14 +377,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 78c3b5b3e..fb322bd5a 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -4,15 +4,17 @@ # isort: off # We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) # isort: on import chex @@ -27,27 +29,26 @@ _GRAD_CLIP_EPS = 1e-6 HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 + 'dropout_rate': 0.1, + 'learning_rate': 0.0017486387539278373, + 'one_minus_beta1': 0.06733926164, + 'beta2': 0.9955159689799007, + 'weight_decay': 0.08121616522670176, + 'warmup_factor': 0.02, } # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -82,19 +83,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -133,7 +137,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -141,6 +146,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. mu: optax.Updates nu: optax.Updates @@ -149,7 +155,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -165,11 +172,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -182,101 +191,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters['warmup_factor'] * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters['learning_rate'], - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters['learning_rate'], + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps) + init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters['one_minus_beta1'], - b2=hyperparameters['beta2'], - eps=1e-8, - weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters['one_minus_beta1'], + b2=hyperparameters['beta2'], + eps=1e-8, + weight_decay=hyperparameters['weight_decay'], + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -296,37 +319,43 @@ def update_params( grad_clip = hyperparameters['grad_clip'] else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -366,14 +395,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index ffe854a0e..99d996bb9 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -4,15 +4,17 @@ # isort: off # We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) # isort: on import chex @@ -27,27 +29,26 @@ _GRAD_CLIP_EPS = 1e-6 HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 + 'dropout_rate': 0.1, + 'learning_rate': 0.0017486387539278373, + 'one_minus_beta1': 0.06733926164, + 'beta2': 0.9955159689799007, + 'weight_decay': 0.08121616522670176, + 'warmup_factor': 0.02, } # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -82,19 +83,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -133,7 +137,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -141,6 +146,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. mu: optax.Updates nu: optax.Updates @@ -149,7 +155,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -165,11 +172,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -182,101 +191,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters['warmup_factor'] * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters['learning_rate'], - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters['learning_rate'], + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps) + init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters['one_minus_beta1'], - b2=hyperparameters['beta2'], - eps=1e-8, - weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters['one_minus_beta1'], + b2=hyperparameters['beta2'], + eps=1e-8, + weight_decay=hyperparameters['weight_decay'], + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -296,37 +319,43 @@ def update_params( grad_clip = hyperparameters['grad_clip'] else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -366,14 +395,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index 554a28762..cc54e3b4e 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -18,12 +18,12 @@ USE_PYTORCH_DDP = pytorch_setup()[0] HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 + 'dropout_rate': 0.1, + 'learning_rate': 0.0017486387539278373, + 'one_minus_beta1': 0.06733926164, + 'beta2': 0.9955159689799007, + 'weight_decay': 0.08121616522670176, + 'warmup_factor': 0.02, } HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS) @@ -32,33 +32,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -70,7 +67,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -78,7 +78,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -87,9 +88,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -118,51 +119,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -200,11 +207,13 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng @@ -213,44 +222,47 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters = HPARAMS optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -265,26 +277,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -297,7 +313,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -306,31 +323,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -370,14 +394,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index e4317fa18..bd065dc06 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -18,12 +18,12 @@ USE_PYTORCH_DDP = pytorch_setup()[0] HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 + 'dropout_rate': 0.1, + 'learning_rate': 0.0017486387539278373, + 'one_minus_beta1': 0.06733926164, + 'beta2': 0.9955159689799007, + 'weight_decay': 0.08121616522670176, + 'warmup_factor': 0.02, } HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS) @@ -32,33 +32,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -70,7 +67,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -78,7 +78,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -87,9 +88,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -118,51 +119,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -200,11 +207,13 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng @@ -213,44 +222,47 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters = HPARAMS optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -265,26 +277,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -297,7 +313,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -306,31 +323,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -370,14 +394,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ From 531c99ec5545a01905b53bd2f92c5805dec1576b Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:20:52 +0200 Subject: [PATCH 106/123] Format datasets/ --- datasets/dataset_setup.py | 519 ++++++++++++++++------------- datasets/librispeech_preprocess.py | 74 ++-- datasets/librispeech_tokenizer.py | 42 ++- 3 files changed, 366 insertions(+), 269 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index efe923dbe..e110930cd 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -72,8 +72,7 @@ from torchvision.datasets import CIFAR10 from algoperf.workloads.wmt import tokenizer -from algoperf.workloads.wmt.input_pipeline import \ - normalize_feature_names +from algoperf.workloads.wmt.input_pipeline import normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer @@ -101,84 +100,96 @@ FASTMRI_TEST_TAR_FILENAME = 'knee_singlecoil_test.tar.xz' flags.DEFINE_boolean( - 'interactive_deletion', - True, - 'If true, user will be prompted before any files are deleted. If false, no ' - 'files will be deleted.') + 'interactive_deletion', + True, + 'If true, user will be prompted before any files are deleted. If false, no ' + 'files will be deleted.', +) flags.DEFINE_boolean( - 'all', - False, - 'Whether or not to download all datasets. If false, can download some ' - 'combination of datasets by setting the individual dataset flags below.') - -flags.DEFINE_boolean('criteo1tb', - False, - 'If --all=false, whether or not to download Criteo 1TB.') -flags.DEFINE_boolean('cifar', - False, - 'If --all=false, whether or not to download CIFAR-10.') -flags.DEFINE_boolean('fastmri', - False, - 'If --all=false, whether or not to download FastMRI.') -flags.DEFINE_boolean('imagenet', - False, - 'If --all=false, whether or not to download Imagenet.') -flags.DEFINE_boolean('librispeech', - False, - 'If --all=false, whether or not to download LibriSpeech.') -flags.DEFINE_boolean('mnist', - False, - 'If --all=false, whether or not to download MNIST.') -flags.DEFINE_boolean('ogbg', - False, - 'If --all=false, whether or not to download OGBG.') -flags.DEFINE_boolean('wmt', - False, - 'If --all=false, whether or not to download WMT.') + 'all', + False, + 'Whether or not to download all datasets. If false, can download some ' + 'combination of datasets by setting the individual dataset flags below.', +) + +flags.DEFINE_boolean( + 'criteo1tb', False, 'If --all=false, whether or not to download Criteo 1TB.' +) +flags.DEFINE_boolean( + 'cifar', False, 'If --all=false, whether or not to download CIFAR-10.' +) +flags.DEFINE_boolean( + 'fastmri', False, 'If --all=false, whether or not to download FastMRI.' +) +flags.DEFINE_boolean( + 'imagenet', False, 'If --all=false, whether or not to download Imagenet.' +) +flags.DEFINE_boolean( + 'librispeech', + False, + 'If --all=false, whether or not to download LibriSpeech.', +) +flags.DEFINE_boolean( + 'mnist', False, 'If --all=false, whether or not to download MNIST.' +) +flags.DEFINE_boolean( + 'ogbg', False, 'If --all=false, whether or not to download OGBG.' +) +flags.DEFINE_boolean( + 'wmt', False, 'If --all=false, whether or not to download WMT.' +) flags.DEFINE_string( - 'data_dir', - '~/data', - 'The path to the folder where datasets should be downloaded.') + 'data_dir', + '~/data', + 'The path to the folder where datasets should be downloaded.', +) flags.DEFINE_string( - 'temp_dir', - '/tmp/mlcommons', - 'A local path to a folder where temp files can be downloaded.') + 'temp_dir', + '/tmp/mlcommons', + 'A local path to a folder where temp files can be downloaded.', +) flags.DEFINE_string( - 'imagenet_train_url', - None, - 'Only necessary if you want this script to `wget` the ImageNet train ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'imagenet_train_url', + None, + 'Only necessary if you want this script to `wget` the ImageNet train ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_string( - 'imagenet_val_url', - None, - 'Only necessary if you want this script to `wget` the ImageNet validation ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'imagenet_val_url', + None, + 'Only necessary if you want this script to `wget` the ImageNet validation ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_string( - 'fastmri_knee_singlecoil_train_url', - None, - 'Only necessary if you want this script to `wget` the FastMRI train ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'fastmri_knee_singlecoil_train_url', + None, + 'Only necessary if you want this script to `wget` the FastMRI train ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_string( - 'fastmri_knee_singlecoil_val_url', - None, - 'Only necessary if you want this script to `wget` the FastMRI validation ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'fastmri_knee_singlecoil_val_url', + None, + 'Only necessary if you want this script to `wget` the FastMRI validation ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_string( - 'fastmri_knee_singlecoil_test_url', - None, - 'Only necessary if you want this script to `wget` the FastMRI test ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'fastmri_knee_singlecoil_test_url', + None, + 'Only necessary if you want this script to `wget` the FastMRI test ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_integer( - 'num_decompression_threads', - 8, - 'The number of threads to use in parallel when decompressing.') + 'num_decompression_threads', + 8, + 'The number of threads to use in parallel when decompressing.', +) flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') @@ -186,7 +197,7 @@ FLAGS = flags.FLAGS -os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +os.environ['CUDA_VISIBLE_DEVICES'] = '-1' def _maybe_mkdir(d): @@ -198,8 +209,10 @@ def _maybe_prompt_for_deletion(paths, interactive_deletion): if not interactive_deletion: return files_for_deletion = '\n'.join(paths) - logging.info('\n\n\nWARNING: the following temp files will be DELETED:' - f'\n{files_for_deletion}') + logging.info( + '\n\n\nWARNING: the following temp files will be DELETED:' + f'\n{files_for_deletion}' + ) delete_str = input('Confirm deletion? [y/N]: ') if delete_str.lower() == 'y': del_cmd = 'rm ' + ' '.join(f'"{s}"' for s in paths) @@ -225,8 +238,9 @@ def _download_url(url, data_dir, name=None): if os.path.exists(file_path): while True: - overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format( - file_path)).lower() + overwrite = input( + 'File already exists {}.\n Overwrite? (Y/n)'.format(file_path) + ).lower() if overwrite in ['y', 'n']: break logging.info('Invalid response. Try again.') @@ -240,17 +254,18 @@ def _download_url(url, data_dir, name=None): progress_bar.update(chunk_size_in_mib) f.write(chunk) progress_bar.close() - if (progress_bar.total != 0 and progress_bar.n != progress_bar.total): + if progress_bar.total != 0 and progress_bar.n != progress_bar.total: raise RuntimeError( - ('Download corrupted, size {n} MiB from {url} does not match ' - 'expected size {size} MiB').format( - url=url, n=progress_bar.n, size=progress_bar.total)) + ( + 'Download corrupted, size {n} MiB from {url} does not match ' + 'expected size {size} MiB' + ).format(url=url, n=progress_bar.n, size=progress_bar.total) + ) -def download_criteo1tb(data_dir, - tmp_dir, - num_decompression_threads, - interactive_deletion): +def download_criteo1tb( + data_dir, tmp_dir, num_decompression_threads, interactive_deletion +): criteo_dir = os.path.join(data_dir, 'criteo1tb') tmp_criteo_dir = os.path.join(tmp_dir, 'criteo1tb') _maybe_mkdir(criteo_dir) @@ -258,47 +273,56 @@ def download_criteo1tb(data_dir, # Forked from # https://github.com/iamleot/transferwee/blob/master/transferwee.py. - user_agent = ('Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:102.0) ' - 'Gecko/20100101 Firefox/102.0') + user_agent = ( + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:102.0) ' + 'Gecko/20100101 Firefox/102.0' + ) criteo_wetransfer_url = ( - 'https://criteo.wetransfer.com/downloads/' - '4bbea9b4a54baddea549d71271a38e2c20230428071257/d4f0d2') + 'https://criteo.wetransfer.com/downloads/' + '4bbea9b4a54baddea549d71271a38e2c20230428071257/d4f0d2' + ) _, _, transfer_id, security_hash = urllib.parse.urlparse( - criteo_wetransfer_url).path.split('/') + criteo_wetransfer_url + ).path.split('/') session = requests.Session() - session.headers.update({ + session.headers.update( + { 'User-Agent': user_agent, 'x-requested-with': 'XMLHttpRequest', - }) + } + ) r = session.get('https://wetransfer.com/') m = re.search('name="csrf-token" content="([^"]+)"', r.text) if m: session.headers.update({'x-csrf-token': m.group(1)}) get_url_request = session.post( - f'https://wetransfer.com/api/v4/transfers/{transfer_id}/download', - json={ - 'intent': 'entire_transfer', - 'security_hash': security_hash, - }) + f'https://wetransfer.com/api/v4/transfers/{transfer_id}/download', + json={ + 'intent': 'entire_transfer', + 'security_hash': security_hash, + }, + ) session.close() download_url = get_url_request.json().get('direct_link') logging.info(f'Downloading ~342GB Criteo 1TB data .zip file:\n{download_url}') download_request = requests.get( # pylint: disable=missing-timeout - download_url, - headers={'User-Agent': user_agent}, - stream=True) + download_url, headers={'User-Agent': user_agent}, stream=True + ) all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip') if not FLAGS.skip_download: download = True if os.path.exists(all_days_zip_filepath): while True: - overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format( - all_days_zip_filepath)).lower() + overwrite = input( + 'File already exists {}.\n Overwrite? (Y/n)'.format( + all_days_zip_filepath + ) + ).lower() if overwrite in ['y', 'n']: break logging.info('Invalid response. Try again.') @@ -324,8 +348,10 @@ def download_criteo1tb(data_dir, input_path = os.path.join(tmp_criteo_dir, f'day_{day}.gz') gz_paths.append(input_path) unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') - unzip_cmd = (f'pigz -d -c -p{num_decompression_threads} "{input_path}" > ' - f'"{unzipped_path}"') + unzip_cmd = ( + f'pigz -d -c -p{num_decompression_threads} "{input_path}" > ' + f'"{unzipped_path}"' + ) logging.info(f'Running Criteo unzip command for day {day}:\n{unzip_cmd}') processes.append(subprocess.Popen(unzip_cmd, shell=True)) for p in processes: @@ -341,8 +367,7 @@ def download_criteo1tb(data_dir, unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') unzipped_paths.append(unzipped_path) split_path = os.path.join(criteo_dir, f'day_{day}_') - split_cmd = ('split -a 2 -d -l 5000000 ' - f'"{unzipped_path}" "{split_path}"') + split_cmd = f'split -a 2 -d -l 5000000 "{unzipped_path}" "{split_path}"' logging.info(f'Running Criteo 1TB split command:\n{split_cmd}') batch_processes.append(subprocess.Popen(split_cmd, shell=True)) for p in batch_processes: @@ -362,45 +387,50 @@ def download_cifar(data_dir, framework): def extract_filename_from_url(url, start_str='knee', end_str='.xz'): - """ The url filenames are sometimes couched within a urldefense+aws access id + """The url filenames are sometimes couched within a urldefense+aws access id etc. string. Unfortunately querying the content disposition in requests fails (not provided)... so fast search is done here within the url. - """ + """ failure = -1 start = url.find(start_str) end = url.find(end_str) if failure in (start, end): raise ValueError( - f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}') + f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}' + ) end += len(end_str) # make it inclusive return url[start:end] -def download_fastmri(data_dir, - fastmri_train_url, - fastmri_val_url, - fastmri_test_url): +def download_fastmri( + data_dir, fastmri_train_url, fastmri_val_url, fastmri_test_url +): data_dir = os.path.join(data_dir, 'fastmri') # Download fastmri train dataset knee_train_filename = extract_filename_from_url(fastmri_train_url) logging.info( - 'Downloading fastmri train dataset from {}'.format(fastmri_train_url)) + 'Downloading fastmri train dataset from {}'.format(fastmri_train_url) + ) _download_url( - url=fastmri_train_url, data_dir=data_dir, name=knee_train_filename) + url=fastmri_train_url, data_dir=data_dir, name=knee_train_filename + ) # Download fastmri val dataset knee_val_filename = extract_filename_from_url(fastmri_val_url) logging.info( - 'Downloading fastmri val dataset from {}'.format(fastmri_val_url)) + 'Downloading fastmri val dataset from {}'.format(fastmri_val_url) + ) _download_url(url=fastmri_val_url, data_dir=data_dir, name=knee_val_filename) # Download fastmri test dataset knee_test_filename = extract_filename_from_url(fastmri_test_url) logging.info( - 'Downloading fastmri test dataset from {}'.format(fastmri_test_url)) + 'Downloading fastmri test dataset from {}'.format(fastmri_test_url) + ) _download_url( - url=fastmri_test_url, data_dir=data_dir, name=knee_test_filename) + url=fastmri_test_url, data_dir=data_dir, name=knee_test_filename + ) return data_dir @@ -432,18 +462,18 @@ def setup_fastmri(data_dir): # Rename folders to match what the workload expects os.rename( - os.path.join(data_dir, "singlecoil_train"), - os.path.join(data_dir, "knee_singlecoil_train"), + os.path.join(data_dir, 'singlecoil_train'), + os.path.join(data_dir, 'knee_singlecoil_train'), ) os.rename( - os.path.join(data_dir, "singlecoil_val"), - os.path.join(data_dir, "knee_singlecoil_val"), + os.path.join(data_dir, 'singlecoil_val'), + os.path.join(data_dir, 'knee_singlecoil_val'), ) os.rename( - os.path.join(data_dir, "singlecoil_test"), - os.path.join(data_dir, "knee_singlecoil_test"), + os.path.join(data_dir, 'singlecoil_test'), + os.path.join(data_dir, 'knee_singlecoil_test'), ) - logging.info("Set up fastMRI dataset complete") + logging.info('Set up fastMRI dataset complete') def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): @@ -456,26 +486,32 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): # been moved to the manual_download_dir. # Get paths in manual_download_dir. imagenet_jax_data_dir = os.path.join(data_dir, 'jax') - manual_download_dir = os.path.join(imagenet_jax_data_dir, - 'downloads', - 'manual') - imagenet_train_download_filepath = os.path.join(manual_download_dir, - IMAGENET_TRAIN_TAR_FILENAME) - imagenet_val_download_filepath = os.path.join(manual_download_dir, - IMAGENET_VAL_TAR_FILENAME) + manual_download_dir = os.path.join( + imagenet_jax_data_dir, 'downloads', 'manual' + ) + imagenet_train_download_filepath = os.path.join( + manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME + ) + imagenet_val_download_filepath = os.path.join( + manual_download_dir, IMAGENET_VAL_TAR_FILENAME + ) # Download imagenet train dataset if not os.path.exists(imagenet_train_filepath) and not os.path.exists( - imagenet_train_download_filepath): + imagenet_train_download_filepath + ): logging.info( - 'Downloading imagenet train dataset from {}'.format(imagenet_train_url)) + 'Downloading imagenet train dataset from {}'.format(imagenet_train_url) + ) _download_url(url=imagenet_train_url, data_dir=data_dir) # Download imagenet val dataset if not os.path.exists(imagenet_val_filepath) and not os.path.exists( - imagenet_val_download_filepath): - logging.info('Downloading imagenet validation dataset from {}'.format( - imagenet_val_url)) + imagenet_val_download_filepath + ): + logging.info( + 'Downloading imagenet validation dataset from {}'.format(imagenet_val_url) + ) _download_url(url=imagenet_val_url, data_dir=data_dir) # Download imagenet test set @@ -501,31 +537,40 @@ def setup_imagenet_jax(data_dir): # Setup jax dataset dir imagenet_jax_data_dir = os.path.join(data_dir, 'jax') - manual_download_dir = os.path.join(imagenet_jax_data_dir, - 'downloads', - 'manual') + manual_download_dir = os.path.join( + imagenet_jax_data_dir, 'downloads', 'manual' + ) os.makedirs(manual_download_dir, exist_ok=True) # Copy tar file into jax/downloads/manual logging.info('Checking if tar files already exists in jax/downloads/manual.') if not os.path.exists( - os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): - logging.info('Moving {} to {}'.format(train_tar_file_path, - manual_download_dir)) + os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME) + ): + logging.info( + 'Moving {} to {}'.format(train_tar_file_path, manual_download_dir) + ) shutil.move(train_tar_file_path, manual_download_dir) if not os.path.exists( - os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): - logging.info('Moving {} to {}'.format(val_tar_file_path, - manual_download_dir)) + os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME) + ): + logging.info( + 'Moving {} to {}'.format(val_tar_file_path, manual_download_dir) + ) shutil.move(val_tar_file_path, manual_download_dir) if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')): - logging.info('Moving imagenet_v2 to {}'.format( - os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))) - shutil.move(test_dir_path, - os.path.join(imagenet_jax_data_dir, 'imagenet_v2')) + logging.info( + 'Moving imagenet_v2 to {}'.format( + os.path.join(imagenet_jax_data_dir, 'imagenet_v2') + ) + ) + shutil.move( + test_dir_path, os.path.join(imagenet_jax_data_dir, 'imagenet_v2') + ) logging.info('Preparing imagenet data.') ds_builder = tfds.builder( - 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir)) + 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir) + ) ds_builder.download_and_prepare() logging.info('Set up imagenet dataset for jax framework complete') @@ -539,14 +584,18 @@ def setup_imagenet_pytorch(data_dir): manual_download_dir = os.path.join(data_dir, 'jax', 'downloads', 'manual') if not os.path.exists(train_tar_file_path): if os.path.exists( - os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): - train_tar_file_path = os.path.join(manual_download_dir, - IMAGENET_TRAIN_TAR_FILENAME) + os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME) + ): + train_tar_file_path = os.path.join( + manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME + ) if not os.path.exists(val_tar_file_path): if os.path.exists( - os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): - val_tar_file_path = os.path.join(manual_download_dir, - IMAGENET_VAL_TAR_FILENAME) + os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME) + ): + val_tar_file_path = os.path.join( + manual_download_dir, IMAGENET_VAL_TAR_FILENAME + ) # Setup pytorch dataset dir imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch') @@ -557,56 +606,68 @@ def setup_imagenet_pytorch(data_dir): # Move tar files and imagenet_v2 into pytorch directory if not os.path.exists( - os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME)): - logging.info('Moving {} to {}'.format(train_tar_file_path, - imagenet_pytorch_data_dir)) + os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME) + ): + logging.info( + 'Moving {} to {}'.format(train_tar_file_path, imagenet_pytorch_data_dir) + ) shutil.move(train_tar_file_path, imagenet_pytorch_data_dir) if not os.path.exists( - os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME)): - logging.info('Moving {} to {}'.format(val_tar_file_path, - imagenet_pytorch_data_dir)) + os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME) + ): + logging.info( + 'Moving {} to {}'.format(val_tar_file_path, imagenet_pytorch_data_dir) + ) shutil.move(val_tar_file_path, imagenet_pytorch_data_dir) if not os.path.exists(os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')): - logging.info('Moving imagenet_v2 to {}'.format( - os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2'))) - shutil.move(test_dir_path, - os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')) + logging.info( + 'Moving imagenet_v2 to {}'.format( + os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2') + ) + ) + shutil.move( + test_dir_path, os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2') + ) # Extract train data\ logging.info('Extracting imagenet train data') extract( - os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME), - os.path.join(imagenet_pytorch_data_dir, 'train'), - mode='r:') + os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME), + os.path.join(imagenet_pytorch_data_dir, 'train'), + mode='r:', + ) train_tar_filenames = os.listdir( - os.path.join(imagenet_pytorch_data_dir, 'train')) + os.path.join(imagenet_pytorch_data_dir, 'train') + ) for tar_filename in train_tar_filenames: if tar_filename.endswith('.tar'): dir_name = tar_filename[:-4] extract( - os.path.join(imagenet_pytorch_data_dir, 'train', tar_filename), - os.path.join(imagenet_pytorch_data_dir, 'train', dir_name), - mode='r:') + os.path.join(imagenet_pytorch_data_dir, 'train', tar_filename), + os.path.join(imagenet_pytorch_data_dir, 'train', dir_name), + mode='r:', + ) # Extract val data logging.info('Extracting imagenet val data') extract( - os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME), - os.path.join(imagenet_pytorch_data_dir, 'val'), - mode='r:') + os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME), + os.path.join(imagenet_pytorch_data_dir, 'val'), + mode='r:', + ) valprep_command = [ - 'wget', - '-qO-', - 'https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh' + 'wget', + '-qO-', + 'https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh', ] valprep_download = subprocess.Popen(valprep_command, stdout=subprocess.PIPE) - valprep_process = subprocess.Popen(['bash'], - stdin=valprep_download.stdout, - cwd=os.path.expanduser( - os.path.join(imagenet_pytorch_data_dir, - 'val'))) + valprep_process = subprocess.Popen( + ['bash'], + stdin=valprep_download.stdout, + cwd=os.path.expanduser(os.path.join(imagenet_pytorch_data_dir, 'val')), + ) valprep_download.stdout.close() valprep_process.communicate() logging.info('Set up imagenet dataset for pytorch framework complete') @@ -614,8 +675,8 @@ def setup_imagenet_pytorch(data_dir): def download_imagenet_v2(data_dir): tfds.builder( - 'imagenet_v2/matched-frequency:3.0.0', - data_dir=data_dir).download_and_prepare() + 'imagenet_v2/matched-frequency:3.0.0', data_dir=data_dir + ).download_and_prepare() def download_librispeech(data_dir, tmp_dir): @@ -634,41 +695,46 @@ def download_librispeech(data_dir, tmp_dir): if split == 'test' and version == 'other': continue wget_cmd = ( - f'wget --directory-prefix={tmp_librispeech_dir} ' - f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz') + f'wget --directory-prefix={tmp_librispeech_dir} ' + f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz' + ) subprocess.Popen(wget_cmd, shell=True).communicate() tar_path = os.path.join(tmp_librispeech_dir, f'{split}-{version}.tar.gz') subprocess.Popen( - f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', - shell=True).communicate() + f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True + ).communicate() tars = [ - 'raw-metadata.tar.gz', - 'train-clean-100.tar.gz', - 'train-clean-360.tar.gz', - 'train-other-500.tar.gz', + 'raw-metadata.tar.gz', + 'train-clean-100.tar.gz', + 'train-clean-360.tar.gz', + 'train-other-500.tar.gz', ] for tar_filename in tars: - wget_cmd = (f'wget --directory-prefix={tmp_librispeech_dir} ' - f'http://www.openslr.org/resources/12/{tar_filename}') + wget_cmd = ( + f'wget --directory-prefix={tmp_librispeech_dir} ' + f'http://www.openslr.org/resources/12/{tar_filename}' + ) subprocess.Popen(wget_cmd, shell=True).communicate() tar_path = os.path.join(tmp_librispeech_dir, tar_filename) subprocess.Popen( - f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', - shell=True).communicate() + f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True + ).communicate() tokenizer_vocab_path = os.path.join(final_data_dir, 'spm_model.vocab') if not os.path.exists(tokenizer_vocab_path): librispeech_tokenizer.run( - train=True, - input_dir=extracted_data_dir, - tokenizer_vocab_path=tokenizer_vocab_path) + train=True, + input_dir=extracted_data_dir, + tokenizer_vocab_path=tokenizer_vocab_path, + ) librispeech_preprocess.run( - input_dir=extracted_data_dir, - output_dir=final_data_dir, - tokenizer_vocab_path=tokenizer_vocab_path) + input_dir=extracted_data_dir, + output_dir=final_data_dir, + tokenizer_vocab_path=tokenizer_vocab_path, + ) def download_mnist(data_dir): @@ -691,12 +757,14 @@ def download_wmt(data_dir): if ds_name == 'wmt17_translate/de-en:1.0.0': ds = dataset_builder.as_dataset(split='train', shuffle_files=False) ds = ds.map( - functools.partial(normalize_feature_names, dataset_builder.info), - num_parallel_calls=tf.data.AUTOTUNE) + functools.partial(normalize_feature_names, dataset_builder.info), + num_parallel_calls=tf.data.AUTOTUNE, + ) # Tokenize data. vocab_path = os.path.join(data_dir, 'wmt_sentencepiece_model') tokenizer.train_tokenizer( - ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) + ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7 + ) def main(_): @@ -715,10 +783,9 @@ def main(_): if FLAGS.all or FLAGS.criteo1tb: logging.info('Downloading criteo1tb...') - download_criteo1tb(data_dir, - tmp_dir, - num_decompression_threads, - FLAGS.interactive_deletion) + download_criteo1tb( + data_dir, tmp_dir, num_decompression_threads, FLAGS.interactive_deletion + ) if FLAGS.all or FLAGS.mnist: logging.info('Downloading MNIST...') @@ -730,19 +797,24 @@ def main(_): knee_singlecoil_train_url = FLAGS.fastmri_knee_singlecoil_train_url knee_singlecoil_val_url = FLAGS.fastmri_knee_singlecoil_val_url knee_singlecoil_test_url = FLAGS.fastmri_knee_singlecoil_test_url - if None in (knee_singlecoil_train_url, - knee_singlecoil_val_url, - knee_singlecoil_test_url): + if None in ( + knee_singlecoil_train_url, + knee_singlecoil_val_url, + knee_singlecoil_test_url, + ): raise ValueError( - 'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url ' - 'to download the FastMRI dataset.\nSign up for the URLs at ' - 'https://fastmri.med.nyu.edu/.') + 'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url ' + 'to download the FastMRI dataset.\nSign up for the URLs at ' + 'https://fastmri.med.nyu.edu/.' + ) if not FLAGS.skip_download: - download_fastmri(data_dir, - knee_singlecoil_train_url, - knee_singlecoil_val_url, - knee_singlecoil_test_url) + download_fastmri( + data_dir, + knee_singlecoil_train_url, + knee_singlecoil_val_url, + knee_singlecoil_test_url, + ) logging.info('fastMRI download completed. Extracting...') setup_fastmri(data_dir) @@ -754,12 +826,13 @@ def main(_): imagenet_val_url = FLAGS.imagenet_val_url if imagenet_train_url is None or imagenet_val_url is None: raise ValueError( - 'Must provide both --imagenet_{train,val}_url to download the ' - 'ImageNet dataset. Sign up for the URLs at https://image-net.org/.') + 'Must provide both --imagenet_{train,val}_url to download the ' + 'ImageNet dataset. Sign up for the URLs at https://image-net.org/.' + ) if FLAGS.framework is None: raise ValueError( - 'Please specify either jax or pytorch framework through framework ' - 'flag.') + 'Please specify either jax or pytorch framework through framework flag.' + ) if not FLAGS.skip_download: logging.info('Downloading ImageNet...') download_imagenet(data_dir, imagenet_train_url, imagenet_val_url) diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py index a8c5cae1d..c419eb39b 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets/librispeech_preprocess.py @@ -28,17 +28,18 @@ # taken from TFDS page for librispeech dataset : # https://www.tensorflow.org/datasets/catalog/librispeech librispeech_example_counts = { - 'train-clean-100': 28539, - 'train-clean-360': 104014, - 'train-other-500': 148688, - 'test-clean': 2620, # 'test-other': 2939, - 'dev-clean': 2703, - 'dev-other': 2864, + 'train-clean-100': 28539, + 'train-clean-360': 104014, + 'train-other-500': 148688, + 'test-clean': 2620, # 'test-other': 2939, + 'dev-clean': 2703, + 'dev-other': 2864, } class Counter: """A threadsafe counter.""" + lock = threading.Lock() value = 0 @@ -56,10 +57,12 @@ def report_progress(count, total, start_time): now = time.time() size = 50 filled = int(round(size * count / float(total))) - percent = round(100. * count / float(total), 1) - bar = "-" * filled + "." * (size - filled) - sys.stdout.write("[%s] %d%% (%d of %d) %.2f sample/sec\r" % - (bar, percent, count, total, count / (now - start_time))) + percent = round(100.0 * count / float(total), 1) + bar = '-' * filled + '.' * (size - filled) + sys.stdout.write( + '[%s] %d%% (%d of %d) %.2f sample/sec\r' + % (bar, percent, count, total, count / (now - start_time)) + ) sys.stdout.flush() @@ -72,8 +75,10 @@ def process(index): data_folder, speaker_folder, chapter_folder = index utterance_ids = [] - trans_file = (f'{data_folder}/{speaker_folder}/{chapter_folder}/' - f'{speaker_folder}-{chapter_folder}.trans.txt') + trans_file = ( + f'{data_folder}/{speaker_folder}/{chapter_folder}/' + f'{speaker_folder}-{chapter_folder}.trans.txt' + ) if not exists(trans_file): skipped.inc() return utterance_ids @@ -82,7 +87,8 @@ def process(index): for l in f: utt, trans = l.strip().split(' ', maxsplit=1) audio_path = ( - f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac') + f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac' + ) if not os.path.isfile(audio_path): skipped.inc() @@ -105,9 +111,11 @@ def process(index): np.save('{}/{}/{}_targets.npy'.format(out_folder, split, utt), targets) finished.inc() - report_progress(finished.val() + skipped.val(), - librispeech_example_counts[split], - start_time) + report_progress( + finished.val() + skipped.val(), + librispeech_example_counts[split], + start_time, + ) utterance_ids.append(utt) return utterance_ids @@ -126,10 +134,12 @@ def process(index): end_time = time.time() elapsed_time = end_time - start_time - print(' \n time taken to preprocess split : ', - split, - ' = ', - time.strftime("%H:%M:%S", time.gmtime(elapsed_time))) + print( + ' \n time taken to preprocess split : ', + split, + ' = ', + time.strftime('%H:%M:%S', time.gmtime(elapsed_time)), + ) final_count = finished.val() + skipped.val() return pd.DataFrame(file_trans, columns=['id']), final_count @@ -147,12 +157,12 @@ def run(input_dir, output_dir, tokenizer_vocab_path): os.makedirs(output_dir, exist_ok=True) subset_list = [ - 'train-clean-100', - 'train-clean-360', - 'train-other-500', - 'dev-clean', - 'dev-other', - 'test-clean', # 'test-other', + 'train-clean-100', + 'train-clean-360', + 'train-other-500', + 'dev-clean', + 'dev-other', + 'test-clean', # 'test-other', ] for subset in subset_list: logging.info('Processing split = %s...', subset) @@ -160,10 +170,14 @@ def run(input_dir, output_dir, tokenizer_vocab_path): out_dir = os.path.join(output_dir, subset) os.makedirs(out_dir, exist_ok=True) example_ids, num_entries = preprocess_data( - in_dir, output_dir, tokenizer, subset) + in_dir, output_dir, tokenizer, subset + ) if num_entries != librispeech_example_counts[subset]: - raise ValueError('Preprocessed dataframe final count not equal to ' - 'expected count: {} vs expected {}'.format( - num_entries, librispeech_example_counts[subset])) + raise ValueError( + 'Preprocessed dataframe final count not equal to ' + 'expected count: {} vs expected {}'.format( + num_entries, librispeech_example_counts[subset] + ) + ) example_ids.to_csv(os.path.join(output_dir, f'{subset}.csv')) diff --git a/datasets/librispeech_tokenizer.py b/datasets/librispeech_tokenizer.py index 2f559752a..5b9888cc2 100644 --- a/datasets/librispeech_tokenizer.py +++ b/datasets/librispeech_tokenizer.py @@ -24,7 +24,8 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)): char_count = 0 with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars') as outfp: + delete=False, prefix='/tmp/ds_chars' + ) as outfp: for split in splits: data_folder = data_folder + '/' + split for _, speaker_folder in enumerate(os.listdir(data_folder)): @@ -32,8 +33,10 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)): break for chapter_folder in os.listdir(f'{data_folder}/{speaker_folder}'): - trans_file = (f'{data_folder}/{speaker_folder}/{chapter_folder}/' - f'{speaker_folder}-{chapter_folder}.trans.txt') + trans_file = ( + f'{data_folder}/{speaker_folder}/{chapter_folder}/' + f'{speaker_folder}-{chapter_folder}.trans.txt' + ) if not exists(trans_file): logging.info('path does not exist -> %s', trans_file) continue @@ -50,13 +53,15 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)): return outfp -def train_tokenizer(data_dir: str, - splits, - vocab_size: int = 1024, - model_path: str = 'spm_model.vocab', - maxchars: int = int(1e7), - model_type: str = 'unigram', - character_coverage: float = 1.0): +def train_tokenizer( + data_dir: str, + splits, + vocab_size: int = 1024, + model_path: str = 'spm_model.vocab', + maxchars: int = int(1e7), + model_type: str = 'unigram', + character_coverage: float = 1.0, +): """Train SentencePiece tokenizer from subset of tf dataset. Args: @@ -77,15 +82,18 @@ def train_tokenizer(data_dir: str, charfile = dump_chars_for_training(data_dir, splits, maxchars=maxchars) with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + delete=False, prefix='/tmp/sp_tmp' + ) as model_fp: pass # we just want a prefix'd tmp-filename - argstr = ' '.join([ + argstr = ' '.join( + [ f'--input={charfile.name}', f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', f'--model_prefix={model_fp.name}', f'--model_type={model_type}', - ]) + ] + ) spm.SentencePieceTrainer.Train(argstr) copy_rename_path = abs_model_path + '.rntmp' @@ -104,7 +112,8 @@ def load_tokenizer(model_filepath): with gfile.GFile(model_filepath, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=False, add_eos=True, reverse=False) + model=sp_model, add_bos=False, add_eos=True, reverse=False + ) return sp_tokenizer @@ -123,8 +132,9 @@ def run(train, input_dir, tokenizer_vocab_path): detokenized = tokenizer.detokenize(tokens).numpy().decode('utf-8') logging.info('Original input = %s', test_input) - logging.info('Output after after tokenizing and detokenizing = %s', - detokenized) + logging.info( + 'Output after after tokenizing and detokenizing = %s', detokenized + ) if detokenized == test_input: logging.info('Tokenizer working correctly!') From c34af17f6e272be88e914ef297fdea65c62940d5 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:21:09 +0200 Subject: [PATCH 107/123] Format algoperf/ --- algoperf/__init__.py | 2 +- algoperf/checkpoint_utils.py | 195 +++-- algoperf/data_utils.py | 72 +- algoperf/halton.py | 154 ++-- algoperf/init_utils.py | 4 +- algoperf/interop_utils.py | 3 +- algoperf/logger_utils.py | 149 ++-- algoperf/param_utils.py | 28 +- algoperf/profiler.py | 81 +- algoperf/pytorch_utils.py | 21 +- algoperf/random_utils.py | 10 +- algoperf/spec.py | 322 +++---- .../cifar/cifar_jax/input_pipeline.py | 143 ++-- algoperf/workloads/cifar/cifar_jax/models.py | 52 +- .../workloads/cifar/cifar_jax/workload.py | 189 ++-- .../workloads/cifar/cifar_pytorch/models.py | 109 ++- .../workloads/cifar/cifar_pytorch/workload.py | 165 ++-- algoperf/workloads/cifar/workload.py | 106 +-- .../criteo1tb/criteo1tb_jax/models.py | 133 +-- .../criteo1tb/criteo1tb_jax/workload.py | 120 +-- .../criteo1tb/criteo1tb_pytorch/models.py | 144 ++-- .../criteo1tb/criteo1tb_pytorch/workload.py | 174 ++-- .../workloads/criteo1tb/input_pipeline.py | 97 ++- algoperf/workloads/criteo1tb/workload.py | 75 +- .../workloads/fastmri/fastmri_jax/models.py | 80 +- .../workloads/fastmri/fastmri_jax/ssim.py | 24 +- .../workloads/fastmri/fastmri_jax/workload.py | 160 ++-- .../fastmri/fastmri_pytorch/models.py | 99 +-- .../workloads/fastmri/fastmri_pytorch/ssim.py | 21 +- .../fastmri/fastmri_pytorch/workload.py | 206 ++--- algoperf/workloads/fastmri/input_pipeline.py | 120 +-- algoperf/workloads/fastmri/workload.py | 38 +- .../imagenet_jax/custom_tf_addons.py | 548 ++++++------ .../imagenet_jax/input_pipeline.py | 358 ++++---- .../imagenet_resnet/imagenet_jax/models.py | 71 +- .../imagenet_jax/randaugment.py | 228 +++-- .../imagenet_resnet/imagenet_jax/workload.py | 267 +++--- .../imagenet_pytorch/models.py | 219 ++--- .../imagenet_pytorch/randaugment.py | 124 +-- .../imagenet_pytorch/workload.py | 230 ++--- .../workloads/imagenet_resnet/imagenet_v2.py | 38 +- .../workloads/imagenet_resnet/workload.py | 56 +- .../imagenet_vit/imagenet_jax/models.py | 153 ++-- .../imagenet_vit/imagenet_jax/workload.py | 113 +-- .../imagenet_vit/imagenet_pytorch/models.py | 157 ++-- .../imagenet_vit/imagenet_pytorch/workload.py | 55 +- algoperf/workloads/imagenet_vit/workload.py | 120 +-- .../librispeech_conformer/input_pipeline.py | 8 +- .../librispeech_preprocessor.py | 512 ++++++----- .../librispeech_jax/models.py | 484 ++++++----- .../librispeech_jax/spectrum_augmenter.py | 81 +- .../librispeech_jax/workload.py | 297 ++++--- .../librispeech_pytorch/models.py | 230 ++--- .../librispeech_pytorch/preprocessor.py | 642 +++++++------- .../librispeech_pytorch/spectrum_augmenter.py | 97 ++- .../librispeech_pytorch/workload.py | 239 +++--- .../librispeech_conformer/metrics.py | 39 +- .../librispeech_conformer/workload.py | 11 +- .../librispeech_jax/models.py | 309 ++++--- .../librispeech_jax/workload.py | 88 +- .../librispeech_pytorch/models.py | 177 ++-- .../librispeech_pytorch/workload.py | 80 +- .../workloads/mnist/mnist_jax/workload.py | 98 ++- .../workloads/mnist/mnist_pytorch/workload.py | 163 ++-- algoperf/workloads/mnist/workload.py | 152 ++-- algoperf/workloads/ogbg/input_pipeline.py | 59 +- algoperf/workloads/ogbg/metrics.py | 13 +- algoperf/workloads/ogbg/ogbg_jax/models.py | 56 +- algoperf/workloads/ogbg/ogbg_jax/workload.py | 108 +-- .../workloads/ogbg/ogbg_pytorch/models.py | 146 ++-- .../workloads/ogbg/ogbg_pytorch/workload.py | 184 ++-- algoperf/workloads/ogbg/workload.py | 110 +-- algoperf/workloads/utils.py | 12 +- algoperf/workloads/wmt/bleu.py | 341 ++++---- algoperf/workloads/wmt/input_pipeline.py | 162 ++-- algoperf/workloads/wmt/tokenizer.py | 78 +- algoperf/workloads/wmt/wmt_jax/decode.py | 160 ++-- algoperf/workloads/wmt/wmt_jax/models.py | 489 ++++++----- algoperf/workloads/wmt/wmt_jax/workload.py | 246 +++--- algoperf/workloads/wmt/wmt_pytorch/decode.py | 210 +++-- algoperf/workloads/wmt/wmt_pytorch/models.py | 808 ++++++++++-------- .../workloads/wmt/wmt_pytorch/workload.py | 270 +++--- algoperf/workloads/wmt/workload.py | 116 +-- algoperf/workloads/workloads.py | 303 +++---- 84 files changed, 7324 insertions(+), 6287 deletions(-) diff --git a/algoperf/__init__.py b/algoperf/__init__.py index 7d54f8290..5ecee05af 100644 --- a/algoperf/__init__.py +++ b/algoperf/__init__.py @@ -2,4 +2,4 @@ from ._version import version as __version__ -__all__ = ["__version__"] +__all__ = ['__version__'] diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index f4cb6c2db..577baaa34 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -20,24 +20,28 @@ from algoperf.pytorch_utils import pytorch_setup _, _, DEVICE, _ = pytorch_setup() -CheckpointReturn = Tuple[spec.OptimizerState, - spec.ParameterContainer, - spec.ModelAuxiliaryState, - dict, - list, - int, - int] - - -def maybe_restore_checkpoint(framework: str, - optimizer_state: spec.OptimizerState, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - train_state: dict, - eval_results: list, - global_step: int, - preemption_count: int, - checkpoint_dir: str) -> CheckpointReturn: +CheckpointReturn = Tuple[ + spec.OptimizerState, + spec.ParameterContainer, + spec.ModelAuxiliaryState, + dict, + list, + int, + int, +] + + +def maybe_restore_checkpoint( + framework: str, + optimizer_state: spec.OptimizerState, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + train_state: dict, + eval_results: list, + global_step: int, + preemption_count: int, + checkpoint_dir: str, +) -> CheckpointReturn: """Optionally restores from a checkpoint. The checkpoint logic is as follows: if there is a checkpoint in @@ -69,20 +73,22 @@ def maybe_restore_checkpoint(framework: str, uninitialized_global_step = -1 uninitialized_preemption_count = -1 checkpoint_state = { - 'model_params': model_params, - 'optimizer_state': opt_state, - 'model_state': model_state, - 'train_state': train_state, - 'eval_results': None, - 'global_step': uninitialized_global_step, - 'preemption_count': uninitialized_preemption_count, + 'model_params': model_params, + 'optimizer_state': opt_state, + 'model_state': model_state, + 'train_state': train_state, + 'eval_results': None, + 'global_step': uninitialized_global_step, + 'preemption_count': uninitialized_preemption_count, } if framework == 'jax': latest_ckpt = flax_checkpoints.restore_checkpoint( - checkpoint_dir, target=checkpoint_state) - save_path = os.path.join(checkpoint_dir, - 'checkpoint_' + str(latest_ckpt['global_step'])) + checkpoint_dir, target=checkpoint_state + ) + save_path = os.path.join( + checkpoint_dir, 'checkpoint_' + str(latest_ckpt['global_step']) + ) else: latest_ckpt = checkpoint_state save_path = latest_checkpoint(checkpoint_dir) @@ -94,55 +100,64 @@ def maybe_restore_checkpoint(framework: str, found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step if not found_checkpoint: - return (optimizer_state, - model_params, - model_state, - train_state, - eval_results, - global_step, - preemption_count) + return ( + optimizer_state, + model_params, + model_state, + train_state, + eval_results, + global_step, + preemption_count, + ) # If there's the latest checkpoint in the checkpoint_dir, restore from that. if framework == 'jax': checkpoint_state = replicate_checkpoint( - latest_ckpt, - pytree_keys=[ - 'optimizer_state', - 'model_params', - 'model_state', - ]) - checkpoint_state['optimizer_state'] = (checkpoint_state['optimizer_state'], - opt_update_fn) + latest_ckpt, + pytree_keys=[ + 'optimizer_state', + 'model_params', + 'model_state', + ], + ) + checkpoint_state['optimizer_state'] = ( + checkpoint_state['optimizer_state'], + opt_update_fn, + ) checkpoint_state['eval_results'] = [ - (value, key) for key, value in latest_ckpt['eval_results'].items() + (value, key) for key, value in latest_ckpt['eval_results'].items() ] else: checkpoint_state = latest_ckpt if isinstance( - model_params, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + model_params, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel), + ): model_params = model_params.module model_params.load_state_dict(checkpoint_state['model_params']) checkpoint_state['model_params'] = model_params for key in optimizer_state.keys(): optimizer_state[key].load_state_dict( - checkpoint_state['optimizer_state'][key]) + checkpoint_state['optimizer_state'][key] + ) checkpoint_state['optimizer_state'][key] = optimizer_state[key] logging.info(f'Loaded checkpoint from {save_path}.') - return (checkpoint_state['optimizer_state'], - checkpoint_state['model_params'], - checkpoint_state['model_state'], - checkpoint_state['train_state'], - list(checkpoint_state['eval_results']), - checkpoint_state['global_step'], - checkpoint_state['preemption_count'] + 1) - - -def replicate_checkpoint(latest: dict, - pytree_keys: Sequence[str], - replicate: bool = True) -> dict: + return ( + checkpoint_state['optimizer_state'], + checkpoint_state['model_params'], + checkpoint_state['model_state'], + checkpoint_state['train_state'], + list(checkpoint_state['eval_results']), + checkpoint_state['global_step'], + checkpoint_state['preemption_count'] + 1, + ) + + +def replicate_checkpoint( + latest: dict, pytree_keys: Sequence[str], replicate: bool = True +) -> dict: """Restores from the provided checkpoint. Args: @@ -163,16 +178,18 @@ def replicate_checkpoint(latest: dict, return pytree -def save_checkpoint(framework: str, - optimizer_state: spec.OptimizerState, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - train_state: dict, - eval_results: list, - global_step: int, - preemption_count: int, - checkpoint_dir: str, - save_intermediate_checkpoints: bool) -> None: +def save_checkpoint( + framework: str, + optimizer_state: spec.OptimizerState, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + train_state: dict, + eval_results: list, + global_step: int, + preemption_count: int, + checkpoint_dir: str, + save_intermediate_checkpoints: bool, +) -> None: """Save the checkpoint in `checkpoint_dir`. Args: @@ -199,8 +216,9 @@ def save_checkpoint(framework: str, model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: if isinstance( - model_params, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + model_params, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel), + ): model_params = model_params.module model_params = model_params.state_dict() optimizer_state_dict = {} @@ -209,33 +227,36 @@ def save_checkpoint(framework: str, optimizer_state_dict[key] = optimizer_state[key].state_dict() else: logging.warning( - f'The optimizer state for key {key} is not saved, because ' - f'{type(optimizer_state[key])} has not implemented a state_dict() ' - 'method.') + f'The optimizer state for key {key} is not saved, because ' + f'{type(optimizer_state[key])} has not implemented a state_dict() ' + 'method.' + ) opt_state = optimizer_state_dict checkpoint_state = { - 'model_params': model_params, - 'optimizer_state': opt_state, - 'model_state': model_state, - 'train_state': train_state, - 'eval_results': tuple(eval_results), - 'global_step': global_step, - 'preemption_count': preemption_count, + 'model_params': model_params, + 'optimizer_state': opt_state, + 'model_state': model_state, + 'train_state': train_state, + 'eval_results': tuple(eval_results), + 'global_step': global_step, + 'preemption_count': preemption_count, } save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}') if framework == 'jax': flax_checkpoints.save_checkpoint( - checkpoint_dir, - target=checkpoint_state, - step=global_step, - overwrite=True, - keep=np.inf if save_intermediate_checkpoints else 1) + checkpoint_dir, + target=checkpoint_state, + step=global_step, + overwrite=True, + keep=np.inf if save_intermediate_checkpoints else 1, + ) else: if not save_intermediate_checkpoints: checkpoint_files = gfile.glob( - os.path.join(checkpoint_dir, 'checkpoint_*')) + os.path.join(checkpoint_dir, 'checkpoint_*') + ) for path in checkpoint_files: logging.info('Removing checkpoint at %s', path) gfile.rmtree(path) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..919ccd125 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -15,9 +15,10 @@ def shard_and_maybe_pad_np( - batch: Dict[str, spec.Tensor], - padding_value: int = 0, - global_batch_size: Optional[int] = None) -> Dict[str, spec.Tensor]: + batch: Dict[str, spec.Tensor], + padding_value: int = 0, + global_batch_size: Optional[int] = None, +) -> Dict[str, spec.Tensor]: """Prepare tf data for JAX or PyTorch DDP. Convert an input batch from tf Tensors to numpy arrays, pad it with @@ -26,11 +27,13 @@ def shard_and_maybe_pad_np( """ local_device_count = max(torch.cuda.device_count(), jax.local_device_count()) inputs = batch['inputs'] - current_batch_size = inputs[0].shape[0] if isinstance( - inputs, tuple) else inputs.shape[0] + current_batch_size = ( + inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0] + ) if global_batch_size is not None: - assert global_batch_size >= current_batch_size, \ - 'global_batch_size must be larger than or equal to current_batch_size.' + assert global_batch_size >= current_batch_size, ( + 'global_batch_size must be larger than or equal to current_batch_size.' + ) # Always pad to global_batch_size if it is provided. pad_to_global_batch_size = global_batch_size > current_batch_size else: @@ -43,7 +46,8 @@ def shard_and_maybe_pad_np( pad_size = local_device_count - remainder_size targets = batch['targets'] targets_shape = tuple( - targets[0].shape if isinstance(targets, tuple) else targets.shape) + targets[0].shape if isinstance(targets, tuple) else targets.shape + ) # We need a 2d mask for WMT. mask_shape = targets_shape if len(targets_shape) < 3 else targets_shape[0] # Get weights from batch if there are any. @@ -68,9 +72,9 @@ def _prepare(x): return jax.tree.map(_prepare, batch) -def pad(tensor: np.ndarray, - pad_size: int, - padding_value: int = 0) -> np.ndarray: +def pad( + tensor: np.ndarray, pad_size: int, padding_value: int = 0 +) -> np.ndarray: if tensor.ndim > 1: pad_size = (pad_size, *tensor.shape[1:]) padding = np.full(pad_size, padding_value, dtype=tensor.dtype) @@ -78,8 +82,9 @@ def pad(tensor: np.ndarray, return padded_tensor -def mixup_pytorch(batch: Tuple[spec.Tensor, spec.Tensor], - alpha: float = 0.2) -> Tuple[spec.Tensor, spec.Tensor]: +def mixup_pytorch( + batch: Tuple[spec.Tensor, spec.Tensor], alpha: float = 0.2 +) -> Tuple[spec.Tensor, spec.Tensor]: inputs, targets = batch # Transform to one-hot targets. targets = F.one_hot(targets, num_classes=1000) @@ -144,12 +149,14 @@ class DistributedEvalSampler(Sampler): ... train(loader) """ - def __init__(self, - dataset: torch.utils.data.Dataset, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = False, - seed: int = 0) -> None: + def __init__( + self, + dataset: torch.utils.data.Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = False, + seed: int = 0, + ) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError('Requires distributed package to be available.') @@ -165,7 +172,7 @@ def __init__(self, # true value without extra samples self.total_size = len(self.dataset) indices = list(range(self.total_size)) - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] # true value without extra samples self.num_samples = len(indices) @@ -182,7 +189,7 @@ def __iter__(self) -> Iterable[int]: indices = list(range(len(self.dataset))) # Subsample. - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) @@ -203,11 +210,13 @@ def set_epoch(self, epoch: int) -> None: # Modified from github.com/pytorch/pytorch/issues/23900#issuecomment-518858050. -def cycle(iterable: Iterable, - keys: Tuple[str, ...] = ('inputs', 'targets'), - custom_sampler: bool = False, - use_mixup: bool = False, - mixup_alpha: float = 0.2) -> Iterable: +def cycle( + iterable: Iterable, + keys: Tuple[str, ...] = ('inputs', 'targets'), + custom_sampler: bool = False, + use_mixup: bool = False, + mixup_alpha: float = 0.2, +) -> Iterable: iterator = iter(iterable) epoch = 0 while True: @@ -229,11 +238,9 @@ def cycle(iterable: Iterable, # github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/ # ConvNets/image_classification/dataloaders.py class PrefetchedWrapper: - - def __init__(self, - dataloader: DataLoader, - device: torch.device, - start_epoch: int = 0) -> None: + def __init__( + self, dataloader: DataLoader, device: torch.device, start_epoch: int = 0 + ) -> None: self.dataloader = dataloader self.epoch = start_epoch self.device = device @@ -254,7 +261,8 @@ def prefetched_loader(self) -> Iterable[Tuple[spec.Tensor, spec.Tensor]]: for next_inputs, next_targets in self.dataloader: with torch.cuda.stream(stream): next_inputs = next_inputs.to( - self.device, dtype=torch.float, non_blocking=True) + self.device, dtype=torch.float, non_blocking=True + ) next_targets = next_targets.to(self.device, non_blocking=True) if not first: diff --git a/algoperf/halton.py b/algoperf/halton.py index 1f36b07bf..08c5466f1 100644 --- a/algoperf/halton.py +++ b/algoperf/halton.py @@ -36,10 +36,12 @@ def _is_prime(n: int) -> bool: return all(n % i != 0 for i in range(2, int(n**0.5) + 1)) and n != 2 -def _generate_dim(num_samples: int, - base: int, - per_dim_shift: bool, - shuffled_seed_sequence: List[int]) -> List[float]: +def _generate_dim( + num_samples: int, + base: int, + per_dim_shift: bool, + shuffled_seed_sequence: List[int], +) -> List[float]: """Generate `num_samples` from a Van der Corput sequence with base `base`. Args: @@ -59,8 +61,9 @@ def _generate_dim(num_samples: int, ValueError: if `base` is negative or not prime. """ if base < 0 or not _is_prime(base): - raise ValueError('Each Van der Corput sequence requires a prime `base`, ' - f'received {base}.') + raise ValueError( + f'Each Van der Corput sequence requires a prime `base`, received {base}.' + ) rng = random.RandomState(base) if shuffled_seed_sequence is None: @@ -76,7 +79,7 @@ def _generate_dim(num_samples: int, dim_sequence = [] for i in range(1, num_samples + 1): - num = 0. + num = 0.0 denominator = base while i: num += shuffled_seed_sequence[i % base] / denominator @@ -91,13 +94,15 @@ def _generate_dim(num_samples: int, Matrix = List[List[int]] -def generate_sequence(num_samples: int, - num_dims: int, - skip: int = 100, - per_dim_shift: bool = True, - shuffle_sequence: bool = True, - primes: Sequence[int] = None, - shuffled_seed_sequence: Matrix = None) -> Matrix: +def generate_sequence( + num_samples: int, + num_dims: int, + skip: int = 100, + per_dim_shift: bool = True, + shuffle_sequence: bool = True, + primes: Sequence[int] = None, + shuffled_seed_sequence: Matrix = None, +) -> Matrix: """Generate `num_samples` from a Halton sequence of dimension `num_dims`. Each dimension is generated independently from a shuffled Van der Corput @@ -140,25 +145,29 @@ def generate_sequence(num_samples: int, if primes is not None and len(primes) != num_dims: raise ValueError( - 'If passing in a sequence of primes it must be the same length as ' - f'num_dims={num_dims}, received {primes} (len {len(primes)}).') + 'If passing in a sequence of primes it must be the same length as ' + f'num_dims={num_dims}, received {primes} (len {len(primes)}).' + ) if shuffled_seed_sequence is not None: if len(shuffled_seed_sequence) != num_dims: raise ValueError( - 'If passing in `shuffled_seed_sequence` it must be the same length ' - f'as num_dims={num_dims}, received {shuffled_seed_sequence} ' - f'(len {len(shuffled_seed_sequence)}).') + 'If passing in `shuffled_seed_sequence` it must be the same length ' + f'as num_dims={num_dims}, received {shuffled_seed_sequence} ' + f'(len {len(shuffled_seed_sequence)}).' + ) for d in range(num_dims): if len(shuffled_seed_sequence[d]) != primes[d]: raise ValueError( - 'If passing in `shuffled_seed_sequence` it must have element `{d}` ' - 'be a sequence of length `primes[{d}]`={expected}, received ' - '{actual} (len {length})'.format( - d=d, - expected=primes[d], - actual=shuffled_seed_sequence[d], - length=shuffled_seed_sequence[d])) + 'If passing in `shuffled_seed_sequence` it must have element `{d}` ' + 'be a sequence of length `primes[{d}]`={expected}, received ' + '{actual} (len {length})'.format( + d=d, + expected=primes[d], + actual=shuffled_seed_sequence[d], + length=shuffled_seed_sequence[d], + ) + ) if primes is None: primes = [] @@ -166,7 +175,7 @@ def generate_sequence(num_samples: int, while len(primes) < num_dims + 1: primes = generate_primes(1000 * prime_attempts) prime_attempts += 1 - primes = primes[-num_dims - 1:-1] + primes = primes[-num_dims - 1 : -1] # Skip the first `skip` points in the sequence because they can have unwanted # correlations. @@ -179,10 +188,11 @@ def generate_sequence(num_samples: int, else: dim_shuffled_seed_sequence = shuffled_seed_sequence[d] dim_sequence = _generate_dim( - num_samples=num_samples, - base=primes[d], - shuffled_seed_sequence=dim_shuffled_seed_sequence, - per_dim_shift=per_dim_shift) + num_samples=num_samples, + base=primes[d], + shuffled_seed_sequence=dim_shuffled_seed_sequence, + per_dim_shift=per_dim_shift, + ) dim_sequence = dim_sequence[skip:] halton_sequence.append(dim_sequence) @@ -195,29 +205,29 @@ def generate_sequence(num_samples: int, return halton_sequence -def _generate_double_point(name: str, - min_val: float, - max_val: float, - scaling: str, - halton_point: float) -> Tuple[str, float]: +def _generate_double_point( + name: str, min_val: float, max_val: float, scaling: str, halton_point: float +) -> Tuple[str, float]: """Generate a float hyperparameter value from a Halton sequence point.""" if scaling not in ['linear', 'log']: raise ValueError( - 'Only log or linear scaling is supported for floating point ' - f'parameters. Received {scaling}.') + 'Only log or linear scaling is supported for floating point ' + f'parameters. Received {scaling}.' + ) if scaling == 'log': # To transform from [0, 1] to [min_val, max_val] on a log scale we do: # min_val * exp(x * log(max_val / min_val)). - rescaled_value = ( - min_val * math.exp(halton_point * math.log(max_val / min_val))) + rescaled_value = min_val * math.exp( + halton_point * math.log(max_val / min_val) + ) else: rescaled_value = halton_point * (max_val - min_val) + min_val return name, rescaled_value -def _generate_discrete_point(name: str, - feasible_points: Sequence[Any], - halton_point: float) -> Any: +def _generate_discrete_point( + name: str, feasible_points: Sequence[Any], halton_point: float +) -> Any: """Generate a discrete hyperparameter value from a Halton sequence point.""" index = int(math.floor(halton_point * len(feasible_points))) return name, feasible_points[index] @@ -236,27 +246,23 @@ def interval(start: int, end: int) -> Tuple[int, int]: def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: min_val, max_val = range_endpoints - return functools.partial(_generate_double_point, - name, - min_val, - max_val, - 'log') + return functools.partial( + _generate_double_point, name, min_val, max_val, 'log' + ) def uniform( - name: str, search_points: Union[_DiscretePoints, - Tuple[int, int]]) -> _GeneratorFn: + name: str, search_points: Union[_DiscretePoints, Tuple[int, int]] +) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): - return functools.partial(_generate_discrete_point, - name, - search_points.feasible_points) + return functools.partial( + _generate_discrete_point, name, search_points.feasible_points + ) min_val, max_val = search_points - return functools.partial(_generate_double_point, - name, - min_val, - max_val, - 'linear') + return functools.partial( + _generate_double_point, name, min_val, max_val, 'linear' + ) def product(sweeps: Sequence[_SweepSequence]) -> _SweepSequence: @@ -277,9 +283,10 @@ def sweep(name, feasible_points: Sequence[Any]) -> _SweepSequence: return [{name: x} for x in feasible_points.feasible_points] -def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, - _SweepSequence]], - length: int) -> _SweepSequence: +def zipit( + generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, _SweepSequence]], + length: int, +) -> _SweepSequence: """Zip together a list of hyperparameter generators. Args: @@ -302,7 +309,8 @@ def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, hyperparameter name from generator_fns_or_sweeps. """ halton_sequence = generate_sequence( - num_samples=length, num_dims=len(generator_fns_or_sweeps)) + num_samples=length, num_dims=len(generator_fns_or_sweeps) + ) # A List[Dict] of hyperparameter names to sweep values. hyperparameter_sweep = [] for trial_index in range(length): @@ -326,8 +334,9 @@ def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, _ListSearchSpace = List[Dict[str, Union[str, float, Sequence]]] -def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace], - num_trials: int) -> List[collections.namedtuple]: +def generate_search( + search_space: Union[_DictSearchSpace, _ListSearchSpace], num_trials: int +) -> List[collections.namedtuple]: """Generate a random search with the given bounds and scaling. Args:linear @@ -352,8 +361,9 @@ def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace], else: raise AttributeError('tuning_search_space should either be a dict or list.') - named_tuple_class = collections.namedtuple('Hyperparameters', - all_hyperparameter_names) + named_tuple_class = collections.namedtuple( + 'Hyperparameters', all_hyperparameter_names + ) if isinstance(search_space, dict): hyperparameter_generators = [] @@ -367,16 +377,18 @@ def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace], generator_fn = uniform(name, interval(space['min'], space['max'])) hyperparameter_generators.append(generator_fn) return [ - named_tuple_class(**p) - for p in zipit(hyperparameter_generators, num_trials) + named_tuple_class(**p) + for p in zipit(hyperparameter_generators, num_trials) ] else: hyperparameters = [] updated_num_trials = min(num_trials, len(search_space)) if num_trials != len(search_space): - logging.info(f'--num_tuning_trials was set to {num_trials}, but ' - f'{len(search_space)} trial(s) found in the JSON file. ' - f'Updating --num_tuning_trials to {updated_num_trials}.') + logging.info( + f'--num_tuning_trials was set to {num_trials}, but ' + f'{len(search_space)} trial(s) found in the JSON file. ' + f'Updating --num_tuning_trials to {updated_num_trials}.' + ) for trial in search_space: hyperparameters.append(named_tuple_class(**trial)) return hyperparameters[:updated_num_trials] diff --git a/algoperf/init_utils.py b/algoperf/init_utils.py index 185480cc7..c66a0be20 100644 --- a/algoperf/init_utils.py +++ b/algoperf/init_utils.py @@ -12,7 +12,7 @@ def pytorch_default_init(module: nn.Module) -> None: # Perform lecun_normal initialization. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) - std = math.sqrt(1. / fan_in) / .87962566103423978 + std = math.sqrt(1.0 / fan_in) / 0.87962566103423978 nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std) if module.bias is not None: - nn.init.constant_(module.bias, 0.) + nn.init.constant_(module.bias, 0.0) diff --git a/algoperf/interop_utils.py b/algoperf/interop_utils.py index 0c6535d7a..c30d0cf3b 100644 --- a/algoperf/interop_utils.py +++ b/algoperf/interop_utils.py @@ -6,7 +6,8 @@ def jax_to_pytorch(x: spec.Tensor, take_ownership: bool = False) -> spec.Tensor: return torch.utils.dlpack.from_dlpack( - jax.dlpack.to_dlpack(x, take_ownership=take_ownership)) + jax.dlpack.to_dlpack(x, take_ownership=take_ownership) + ) def pytorch_to_jax(x: torch.Tensor) -> spec.Tensor: diff --git a/algoperf/logger_utils.py b/algoperf/logger_utils.py index c988956dc..3c4898142 100644 --- a/algoperf/logger_utils.py +++ b/algoperf/logger_utils.py @@ -37,12 +37,12 @@ def makedir(dir_name: str, exist_ok: bool = True) -> None: def get_log_dir( - experiment_dir: str, - workload: spec.Workload, - framework: str, - experiment_name: str, - resume_last_run: bool, - overwrite: bool, + experiment_dir: str, + workload: spec.Workload, + framework: str, + experiment_name: str, + resume_last_run: bool, + overwrite: bool, ) -> Optional[str]: # Construct path to experiment workload directory. experiment_dir = os.path.expanduser(experiment_dir) @@ -50,26 +50,29 @@ def get_log_dir( if experiment_name is None: experiment_path = os.path.join(experiment_dir, workload_dir_name) else: - experiment_path = os.path.join(experiment_dir, - experiment_name, - workload_dir_name) + experiment_path = os.path.join( + experiment_dir, experiment_name, workload_dir_name + ) if os.path.exists(experiment_path): if overwrite: logging.info( - f'Removing existing experiment directory {experiment_path} because ' - '--overwrite was set.') + f'Removing existing experiment directory {experiment_path} because ' + '--overwrite was set.' + ) if RANK == 0: shutil.rmtree(experiment_path) elif resume_last_run: logging.info( - f'Resuming from experiment directory {experiment_path} because ' - '--resume_last_run was set.') + f'Resuming from experiment directory {experiment_path} because ' + '--resume_last_run was set.' + ) else: if RANK == 0: resume = input( - 'Found existing experiment dir with the same name: {}. Do you wish ' - 'to resume training from this dir? [y/N]:'.format(experiment_path)) + 'Found existing experiment dir with the same name: {}. Do you wish ' + 'to resume training from this dir? [y/N]:'.format(experiment_path) + ) if resume.lower() != 'y': sys.exit() @@ -83,16 +86,18 @@ def get_log_dir( return experiment_path -def write_hparams(hparams: spec.Hyperparameters, - tuning_dir: str) -> spec.Hyperparameters: +def write_hparams( + hparams: spec.Hyperparameters, tuning_dir: str +) -> spec.Hyperparameters: hparams_file_name = os.path.join(tuning_dir, 'hparams.json') if os.path.exists(hparams_file_name): # If hparams.json already exist, use the previously saved hyperparameters. logging.info('Loading hparams from %s.', hparams_file_name) with open(hparams_file_name, 'r') as f: hparams_dict = json.load(f) - hparams = collections.namedtuple('Hyperparameters', - hparams_dict)(**hparams_dict) + hparams = collections.namedtuple('Hyperparameters', hparams_dict)( + **hparams_dict + ) else: logging.info('Saving hparams to %s.', hparams_file_name) if RANK == 0: @@ -108,8 +113,8 @@ def write_json(name: str, log_dict: Dict, indent: int = 2) -> None: def write_to_csv( - metrics: Dict, - csv_path: str, + metrics: Dict, + csv_path: str, ) -> None: try: with open(csv_path, 'r') as csv_file: @@ -118,8 +123,10 @@ def write_to_csv( except (pd.errors.EmptyDataError, FileNotFoundError) as e: measurements = pd.DataFrame([metrics], columns=sorted(metrics.keys())) if isinstance(e, pd.errors.EmptyDataError): - logging.info('Measurements file is empty. Create a new one, starting ' - 'with metrics from this step.') + logging.info( + 'Measurements file is empty. Create a new one, starting ' + 'with metrics from this step.' + ) with open(csv_path, 'w') as csv_file: measurements.to_csv(csv_file, index=False) return @@ -130,7 +137,8 @@ def _get_utilization() -> Dict: # CPU util_data['cpu.util.avg_percent_since_last'] = psutil.cpu_percent( - interval=None) # non-blocking (cpu util percentage since last call) + interval=None + ) # non-blocking (cpu util percentage since last call) util_data['cpu.freq.current'] = psutil.cpu_freq().current # Memory @@ -208,11 +216,14 @@ def _get_system_hardware_info() -> Dict: def _get_system_software_info() -> Dict: system_software_info = {} - system_software_info['os_platform'] = \ - platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29' - system_software_info['python_version'] = platform.python_version( + system_software_info['os_platform'] = ( + platform.platform() + ) # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29' + system_software_info['python_version'] = ( + platform.python_version() ) # Ex. '3.11.10' - system_software_info['python_compiler'] = platform.python_compiler( + system_software_info['python_compiler'] = ( + platform.python_compiler() ) # Ex. 'GCC 9.3.0' # Note: do not store hostname as that may be sensitive @@ -228,19 +239,28 @@ def _get_system_software_info() -> Dict: def _get_git_commit_hash() -> str: - return subprocess.check_output(['git', 'rev-parse', - 'HEAD']).decode('ascii').strip() + return ( + subprocess.check_output(['git', 'rev-parse', 'HEAD']) + .decode('ascii') + .strip() + ) def _get_git_branch() -> str: - return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', - 'HEAD']).decode('ascii').strip() + return ( + subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + .decode('ascii') + .strip() + ) def _get_cpu_model_name() -> str: output = subprocess.check_output(['lscpu']).decode('ascii').strip() - return re.findall(r'(?=Model name:\s{1,}).*', - output)[0].split('Model name:')[1].strip() + return ( + re.findall(r'(?=Model name:\s{1,}).*', output)[0] + .split('Model name:')[1] + .strip() + ) def _is_primitive_type(item: Any) -> bool: @@ -252,23 +272,25 @@ def _get_workload_properties(workload: spec.Workload) -> Dict: workload_properties = {} skip_list = ['param_shapes', 'model_params_types'] keys = [ - key for key in dir(workload) - if not key.startswith('_') and key not in skip_list + key + for key in dir(workload) + if not key.startswith('_') and key not in skip_list ] for key in keys: try: attr = getattr(workload, key) except: # pylint: disable=bare-except logging.info( - f'Unable to record workload.{key} information. Continuing without it.' + f'Unable to record workload.{key} information. Continuing without it.' ) if _is_primitive_type(attr): workload_properties[f'workload.{key}'] = attr return workload_properties -def get_meta_data(workload: spec.Workload, - rng_seed: Optional[int] = None) -> Dict: +def get_meta_data( + workload: spec.Workload, rng_seed: Optional[int] = None +) -> Dict: meta_data = {} workload_properties = _get_workload_properties(workload) meta_data.update(workload_properties) @@ -290,12 +312,14 @@ class MetricLogger(object): the wrong time. """ - def __init__(self, - csv_path: str, - eval_csv_path: str, - events_dir: Optional[str] = None, - configs: Optional[flags.FLAGS] = None, - hyperparameters: Optional[spec.Hyperparameters] = None) -> None: + def __init__( + self, + csv_path: str, + eval_csv_path: str, + events_dir: Optional[str] = None, + configs: Optional[flags.FLAGS] = None, + hyperparameters: Optional[spec.Hyperparameters] = None, + ) -> None: self._measurements = {} self._csv_path = csv_path self._eval_csv_path = eval_csv_path @@ -305,15 +329,18 @@ def __init__(self, self._tb_metric_writer = metric_writers.create_default_writer(events_dir) if wandb is not None and self.use_wandb: wandb.init( - dir=events_dir, tags=[flags.FLAGS.workload, flags.FLAGS.framework]) + dir=events_dir, tags=[flags.FLAGS.workload, flags.FLAGS.framework] + ) wandb.config.update(configs) wandb.config.update(hyperparameters._asdict()) - def append_scalar_metrics(self, - metrics: Dict, - global_step: int, - preemption_count: Optional[int] = None, - is_eval: bool = False) -> None: + def append_scalar_metrics( + self, + metrics: Dict, + global_step: int, + preemption_count: Optional[int] = None, + is_eval: bool = False, + ) -> None: metrics['global_step'] = global_step if preemption_count is not None: metrics['preemption_count'] = preemption_count @@ -324,7 +351,8 @@ def append_scalar_metrics(self, if self._tb_metric_writer: self._tb_metric_writer.write_scalars( - step=int(metrics['global_step']), scalars=metrics) + step=int(metrics['global_step']), scalars=metrics + ) self._tb_metric_writer.flush() if wandb is not None and self.use_wandb: @@ -335,15 +363,16 @@ def finish(self) -> None: wandb.finish() -def set_up_loggers(train_dir: str, - configs: flags.FLAGS, - hyperparameters: spec.Hyperparameters) -> MetricLogger: +def set_up_loggers( + train_dir: str, configs: flags.FLAGS, hyperparameters: spec.Hyperparameters +) -> MetricLogger: csv_path = os.path.join(train_dir, 'measurements.csv') eval_csv_path = os.path.join(train_dir, 'eval_measurements.csv') metrics_logger = MetricLogger( - csv_path=csv_path, - eval_csv_path=eval_csv_path, - events_dir=train_dir, - configs=configs, - hyperparameters=hyperparameters) + csv_path=csv_path, + eval_csv_path=eval_csv_path, + events_dir=train_dir, + configs=configs, + hyperparameters=hyperparameters, + ) return metrics_logger diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 05d882404..908ef0f27 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -14,7 +14,8 @@ def pytorch_param_shapes(model: nn.Module) -> Dict[str, spec.ShapeTuple]: def pytorch_param_types( - param_shapes: Dict[str, spec.ShapeTuple]) -> Dict[str, spec.ParameterType]: + param_shapes: Dict[str, spec.ShapeTuple], +) -> Dict[str, spec.ParameterType]: param_types = {} for name in param_shapes.keys(): if 'bn' in name: @@ -65,18 +66,21 @@ def pytorch_param_types( def jax_param_shapes( - params: spec.ParameterContainer) -> spec.ParameterShapeTree: + params: spec.ParameterContainer, +) -> spec.ParameterShapeTree: return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params) -def jax_param_types(param_shapes: spec.ParameterShapeTree, - parent_name: str = '') -> Dict[str, spec.ParameterType]: +def jax_param_types( + param_shapes: spec.ParameterShapeTree, parent_name: str = '' +) -> Dict[str, spec.ParameterType]: param_types = {} for name, value in param_shapes.items(): name = name.lower() if isinstance(value, dict) or isinstance(value, flax.core.FrozenDict): param_types[name] = jax_param_types( - value, parent_name=parent_name + '/' + name) + value, parent_name=parent_name + '/' + name + ) else: if 'batchnorm' in parent_name or 'bn' in parent_name: if name == 'scale': @@ -85,7 +89,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree, param_types[name] = spec.ParameterType.BATCH_NORM_BIAS else: raise ValueError( - f'Unrecognized batch norm parameter: {parent_name}/{name}.') + f'Unrecognized batch norm parameter: {parent_name}/{name}.' + ) elif 'layernorm' in parent_name or 'ln' in parent_name: if name == 'scale': param_types[name] = spec.ParameterType.LAYER_NORM_SCALE @@ -93,7 +98,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree, param_types[name] = spec.ParameterType.LAYER_NORM_BIAS else: raise ValueError( - f'Unrecognized layer norm parameter: {parent_name}/{name}.') + f'Unrecognized layer norm parameter: {parent_name}/{name}.' + ) elif 'conv' in parent_name: if 'bias' in name: param_types[name] = spec.ParameterType.BIAS @@ -102,8 +108,9 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree, # Note that this is exact equality, not contained in, because # flax.linen.Embed names the embedding parameter "embedding" # https://github.com/google/flax/blob/main/flax/linen/linear.py#L604. - elif ('embedding' in name or - ('embedding' in parent_name and name == 'kernel')): + elif 'embedding' in name or ( + 'embedding' in parent_name and name == 'kernel' + ): param_types[name] = spec.ParameterType.EMBEDDING elif 'attention' in parent_name: if name == 'bias': @@ -122,7 +129,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree, param_types[name] = spec.ParameterType.ATTENTION_QKV else: raise ValueError( - f'Unrecognized attention parameter: {parent_name}/{name}.') + f'Unrecognized attention parameter: {parent_name}/{name}.' + ) elif 'bias' in name: param_types[name] = spec.ParameterType.BIAS else: diff --git a/algoperf/profiler.py b/algoperf/profiler.py index fa2a1bee2..0e791d3a8 100644 --- a/algoperf/profiler.py +++ b/algoperf/profiler.py @@ -21,7 +21,6 @@ def _get_monotonic_time() -> float: class Profiler: - def __init__(self, local_rank: Optional[int] = None) -> None: self._local_rank = local_rank @@ -41,7 +40,8 @@ def start(self, action_name: str) -> None: pass if action_name in self.current_actions: raise ValueError( - f'Attempted to start {action_name} which has already started.') + f'Attempted to start {action_name} which has already started.' + ) self.current_actions[action_name] = _get_monotonic_time() def stop(self, action_name: str) -> None: @@ -49,8 +49,10 @@ def stop(self, action_name: str) -> None: pass end_time = _get_monotonic_time() if action_name not in self.current_actions: - raise ValueError(f'Attempting to stop recording an action ' - f'({action_name}) which was never started.') + raise ValueError( + f'Attempting to stop recording an action ' + f'({action_name}) which was never started.' + ) start_time = self.current_actions.pop(action_name) duration = end_time - start_time self.recorded_durations[action_name].append(duration) @@ -64,16 +66,20 @@ def profile(self, action_name: str) -> Generator: self.stop(action_name) def _make_report( - self + self, ) -> Tuple[List[Tuple[str, float, float, int, float, float]], int, float]: total_duration = _get_monotonic_time() - self.start_time - report = [(str(a), - float(np.mean(d)), - float(np.std(d)), - len(d), - float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) for a, - d in self.recorded_durations.items()] + report = [ + ( + str(a), + float(np.mean(d)), + float(np.std(d)), + len(d), + float(np.sum(d)), + 100.0 * float(np.sum(d)) / total_duration, + ) + for a, d in self.recorded_durations.items() + ] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration @@ -92,32 +98,42 @@ def log_row(action, mean, std, num_calls, total, per): row += f' {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|' return row - header_string = log_row('Action', - 'Mean Duration (s)', - 'Std Duration (s)', - 'Num Calls', - 'Total Time (s)', - 'Percentage %') + header_string = log_row( + 'Action', + 'Mean Duration (s)', + 'Std Duration (s)', + 'Num Calls', + 'Total Time (s)', + 'Percentage %', + ) output_string_len = len(header_string.expandtabs()) sep_lines = f'{sep}{"-" * output_string_len}' output_string += sep_lines + header_string + sep_lines report, total_calls, total_duration = self._make_report() - output_string += log_row('Total', - '-----', - '-----', - f'{total_calls:}', - f'{total_duration:.5}', - '100 %') + output_string += log_row( + 'Total', + '-----', + '-----', + f'{total_calls:}', + f'{total_duration:.5}', + '100 %', + ) output_string += sep_lines - for action, mean_duration, std_duration, num_calls, \ - total_duration, duration_per in report: + for ( + action, + mean_duration, + std_duration, + num_calls, + total_duration, + duration_per, + ) in report: output_string += log_row( - action, - f'{mean_duration:.5}', - f'{std_duration:.5}', - f'{num_calls}', - f'{total_duration:.5}', - f'{duration_per:.5}', + action, + f'{mean_duration:.5}', + f'{std_duration:.5}', + f'{num_calls}', + f'{total_duration:.5}', + f'{duration_per:.5}', ) output_string += sep_lines output_string += sep @@ -125,7 +141,6 @@ def log_row(action, mean, std, num_calls, total, per): class PassThroughProfiler(Profiler): - def start(self, action_name: str) -> None: pass diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index b81d2969a..429e4d1e2 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -12,10 +12,12 @@ from algoperf import spec from algoperf.profiler import Profiler -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - BatchNorm as ConformerBatchNorm -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - BatchNorm as DeepspeechBatchNorm +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + BatchNorm as ConformerBatchNorm, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( + BatchNorm as DeepspeechBatchNorm, +) def pytorch_setup() -> Tuple[bool, int, torch.device, int]: @@ -61,12 +63,13 @@ def sync_ddp_time(time: float, device: torch.device) -> float: return time_tensor.item() -def update_batch_norm_fn(module: spec.ParameterContainer, - update_batch_norm: bool) -> None: +def update_batch_norm_fn( + module: spec.ParameterContainer, update_batch_norm: bool +) -> None: bn_layers = ( - torch.nn.modules.batchnorm._BatchNorm, # PyTorch BN base class. - ConformerBatchNorm, # Custom BN class for conformer model. - DeepspeechBatchNorm, # Custom BN class for deepspeech model. + torch.nn.modules.batchnorm._BatchNorm, # PyTorch BN base class. + ConformerBatchNorm, # Custom BN class for conformer model. + DeepspeechBatchNorm, # Custom BN class for deepspeech model. ) if isinstance(module, bn_layers): if not update_batch_norm: diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index a579976ad..41f4b6b41 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -10,8 +10,9 @@ import jax.random as jax_rng except (ImportError, ModuleNotFoundError): logging.warning( - 'Could not import jax.random for the submission runner, falling back to ' - 'numpy random_utils.') + 'Could not import jax.random for the submission runner, falling back to ' + 'numpy random_utils.' + ) jax_rng = None FLAGS = flags.FLAGS @@ -54,8 +55,9 @@ def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name def _check_jax_install() -> None: if jax_rng is None: raise ValueError( - 'Must install jax to use the jax RNG library, or use PyTorch and pass ' - '--framework=pytorch to use the Numpy version instead.') + 'Must install jax to use the jax RNG library, or use PyTorch and pass ' + '--framework=pytorch to use the Numpy version instead.' + ) def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: diff --git a/algoperf/spec.py b/algoperf/spec.py index 9670dcb76..5f7b930af 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -5,10 +5,10 @@ import functools from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union -from absl import logging import jax -from torch import nn import torch.nn.functional as F +from absl import logging +from torch import nn class LossType(enum.Enum): @@ -53,7 +53,6 @@ class ParameterType(enum.Enum): # Define this so that if using pytree iteration utilities, can iterate over the # model shapes pytree without iterating over the shape tuples. class ShapeTuple: - def __init__(self, shape_tuple): self.shape_tuple = shape_tuple @@ -64,19 +63,22 @@ def __eq__(self, other): return self.shape_tuple == other.shape_tuple -Shape = Union[Tuple[int], - Tuple[int, int], - Tuple[int, int, int], - Tuple[int, int, int, int], - ShapeTuple] +Shape = Union[ + Tuple[int], + Tuple[int, int], + Tuple[int, int, int], + Tuple[int, int, int, int], + ShapeTuple, +] ParameterShapeTree = Dict[str, Dict[str, Shape]] # If necessary, these can be zipped together easily given they have the same # structure, to get an iterator over pairs of leaves. ParameterKey = str # Dicts can be arbitrarily nested. -ParameterContainer = Union[Dict[ParameterKey, Dict[ParameterKey, Tensor]], - nn.Module] +ParameterContainer = Union[ + Dict[ParameterKey, Dict[ParameterKey, Tensor]], nn.Module +] ParameterTypeTree = Dict[ParameterKey, Dict[ParameterKey, ParameterType]] RandomState = Any # Union[jax.random.PRNGKey, int, bytes, ...] @@ -92,7 +94,6 @@ def __eq__(self, other): class Workload(metaclass=abc.ABCMeta): - def __init__(self, *args, **kwargs) -> None: del args del kwargs @@ -107,8 +108,9 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" @abc.abstractmethod - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: """Return whether or not the workload validation goal has been reached.""" @abc.abstractmethod @@ -117,14 +119,15 @@ def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: @abc.abstractmethod def _build_input_queue( - self, - data_rng: RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, Any]]: + self, + data_rng: RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, Any]]: """Build the input queue for the workload data. This is the only function that is NOT allowed to be called by submitters. @@ -213,8 +216,9 @@ def param_shapes(self): """The shapes of the parameters in the workload model.""" if self._param_shapes is None: raise ValueError( - 'This should not happen, workload.init_model_fn() should be called ' - 'before workload.param_shapes!') + 'This should not happen, workload.init_model_fn() should be called ' + 'before workload.param_shapes!' + ) return self._param_shapes @property @@ -222,8 +226,9 @@ def model_params_types(self): """The types of the parameters in the workload model.""" if self._param_types is None: raise ValueError( - 'This should not happen, workload.init_model_fn() should be called ' - 'before workload.param_types!') + 'This should not happen, workload.init_model_fn() should be called ' + 'before workload.param_types!' + ) return self._param_types @abc.abstractmethod @@ -234,10 +239,12 @@ def is_output_params(self, param_key: ParameterKey) -> bool: # Tuple[RandomState, Optional[float], Optional[float]], # ParameterContainer] @abc.abstractmethod - def init_model_fn(self, - rng: RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> ModelInitState: + def init_model_fn( + self, + rng: RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> ModelInitState: """Return (initial_params, initial_model_state).""" # ModelFn = Callable[ @@ -251,30 +258,35 @@ def init_model_fn(self, # float], # Tensor] @abc.abstractmethod - def model_fn(self, - params: ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, Tensor], - model_state: ModelAuxiliaryState, - mode: ForwardPassMode, - rng: RandomState, - update_batch_norm: bool, - dropout_rate: float) -> Tuple[Tensor, ModelAuxiliaryState]: + def model_fn( + self, + params: ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, Tensor], + model_state: ModelAuxiliaryState, + mode: ForwardPassMode, + rng: RandomState, + update_batch_norm: bool, + dropout_rate: float, + ) -> Tuple[Tensor, ModelAuxiliaryState]: """Return logits_batch""" # Possible side effect of updating BN. - def output_activation_fn(self, logits_batch: Tensor, - framework: str) -> Tensor: + def output_activation_fn( + self, logits_batch: Tensor, framework: str + ) -> Tensor: """Turn logits into probabilities, according to the loss_type property.""" if framework not in ['pytorch', 'jax']: raise ValueError( - f'`framework` has to be either `pytorch` or `jax`, got {framework}.') + f'`framework` has to be either `pytorch` or `jax`, got {framework}.' + ) activation_fn = { - LossType.MEAN_SQUARED_ERROR: lambda z: z, - LossType.MEAN_ABSOLUTE_ERROR: lambda z: z, + LossType.MEAN_SQUARED_ERROR: lambda z: z, + LossType.MEAN_ABSOLUTE_ERROR: lambda z: z, } is_pytorch = framework == 'pytorch' # If False, framework == 'jax'. softmax_fn = ( - functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax) + functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax + ) sigmoid_fn = F.sigmoid if is_pytorch else jax.nn.sigmoid activation_fn[LossType.SOFTMAX_CROSS_ENTROPY] = softmax_fn activation_fn[LossType.SIGMOID_CROSS_ENTROPY] = sigmoid_fn @@ -286,12 +298,13 @@ def output_activation_fn(self, logits_batch: Tensor, # `update_params`. @abc.abstractmethod def loss_fn( - self, - # Dense or one-hot labels, or a tuple of (tensor, padding) for speech. - label_batch: Union[Tuple[Tensor, Tensor], Tensor], - logits_batch: Union[Tuple[Tensor, Tensor], Tensor], - mask_batch: Optional[Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, Tensor]: # differentiable + self, + # Dense or one-hot labels, or a tuple of (tensor, padding) for speech. + label_batch: Union[Tuple[Tensor, Tensor], Tensor], + logits_batch: Union[Tuple[Tensor, Tensor], Tensor], + mask_batch: Optional[Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -300,48 +313,54 @@ def loss_fn( """ @abc.abstractmethod - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: ParameterContainer, - model_state: ModelAuxiliaryState, - rng: RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: ParameterContainer, + model_state: ModelAuxiliaryState, + rng: RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Evaluate the model on a given dataset split, return final scalars.""" - def eval_model(self, - global_batch_size: int, - params: ParameterContainer, - model_state: ModelAuxiliaryState, - rng: RandomState, - data_dir: str, - imagenet_v2_data_dir: Optional[str], - global_step: int) -> Dict[str, float]: + def eval_model( + self, + global_batch_size: int, + params: ParameterContainer, + model_state: ModelAuxiliaryState, + rng: RandomState, + data_dir: str, + imagenet_v2_data_dir: Optional[str], + global_step: int, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" logging.info('Evaluating on the training split.') train_metrics = self._eval_model_on_split( - split='eval_train', - num_examples=self.num_eval_train_examples, - global_batch_size=global_batch_size, - params=params, - model_state=model_state, - rng=rng, - data_dir=data_dir, - global_step=global_step) + split='eval_train', + num_examples=self.num_eval_train_examples, + global_batch_size=global_batch_size, + params=params, + model_state=model_state, + rng=rng, + data_dir=data_dir, + global_step=global_step, + ) eval_metrics = {'train/' + k: v for k, v in train_metrics.items()} # We always require a validation set. logging.info('Evaluating on the validation split.') validation_metrics = self._eval_model_on_split( - 'validation', - num_examples=self.num_validation_examples, - global_batch_size=global_batch_size, - params=params, - model_state=model_state, - rng=rng, - data_dir=data_dir, - global_step=global_step) + 'validation', + num_examples=self.num_validation_examples, + global_batch_size=global_batch_size, + params=params, + model_state=model_state, + rng=rng, + data_dir=data_dir, + global_step=global_step, + ) for k, v in validation_metrics.items(): eval_metrics['validation/' + k] = v eval_metrics['validation/num_examples'] = self.num_validation_examples @@ -350,14 +369,15 @@ def eval_model(self, if self.num_test_examples is not None: logging.info('Evaluating on the test split.') test_metrics = self._eval_model_on_split( - 'test', - num_examples=self.num_test_examples, - global_batch_size=global_batch_size, - params=params, - model_state=model_state, - rng=rng, - data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir, - global_step=global_step) + 'test', + num_examples=self.num_test_examples, + global_batch_size=global_batch_size, + params=params, + model_state=model_state, + rng=rng, + data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir, + global_step=global_step, + ) for k, v in test_metrics.items(): eval_metrics['test/' + k] = v eval_metrics['test/num_examples'] = self.num_test_examples @@ -374,27 +394,32 @@ class TrainingCompleteError(Exception): # Training algorithm track submission functions, to be filled in by the # submitter. -InitOptimizerFn = Callable[[ +InitOptimizerFn = Callable[ + [ Workload, ParameterContainer, ModelAuxiliaryState, Hyperparameters, - RandomState -], - OptimizerState] - - -def init_optimizer_state(workload: Workload, - model_params: ParameterContainer, - model_state: ModelAuxiliaryState, - hyperparameters: Hyperparameters, - rng: RandomState) -> OptimizerState: + RandomState, + ], + OptimizerState, +] + + +def init_optimizer_state( + workload: Workload, + model_params: ParameterContainer, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + rng: RandomState, +) -> OptimizerState: # return initial_optimizer_state pass UpdateReturn = Tuple[OptimizerState, ParameterContainer, ModelAuxiliaryState] -UpdateParamsFn = Callable[[ +UpdateParamsFn = Callable[ + [ Workload, ParameterContainer, ParameterTypeTree, @@ -406,9 +431,10 @@ def init_optimizer_state(workload: Workload, List[Tuple[int, float]], int, RandomState, - Optional[Dict[str, Any]] -], - UpdateReturn] + Optional[Dict[str, Any]], + ], + UpdateReturn, +] # Each call to this function is considered a "step". @@ -417,23 +443,26 @@ def init_optimizer_state(workload: Workload, # and if has not actually achieved the goal then it will be considered as not # achieved the goal and get an infinite time score. Most submissions will likely # wait until the next free eval and not use this functionality. -def update_params(workload: Workload, - current_param_container: ParameterContainer, - current_params_types: ParameterTypeTree, - model_state: ModelAuxiliaryState, - hyperparameters: Hyperparameters, - batch: Dict[str, Tensor], - loss_type: LossType, - optimizer_state: OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: RandomState, - train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn: +def update_params( + workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + batch: Dict[str, Tensor], + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" pass -PrepareForEvalFn = Callable[[ +PrepareForEvalFn = Callable[ + [ Workload, ParameterContainer, ParameterTypeTree, @@ -443,27 +472,31 @@ def update_params(workload: Workload, OptimizerState, List[Tuple[int, float]], int, - RandomState -], - UpdateReturn] + RandomState, + ], + UpdateReturn, +] # Prepare model and optimizer for evaluation. -def prepare_for_eval(workload: Workload, - current_param_container: ParameterContainer, - current_params_types: ParameterTypeTree, - model_state: ModelAuxiliaryState, - hyperparameters: Hyperparameters, - loss_type: LossType, - optimizer_state: OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: RandomState) -> UpdateReturn: +def prepare_for_eval( + workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState, +) -> UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" pass -DataSelectionFn = Callable[[ +DataSelectionFn = Callable[ + [ Workload, Iterator[Dict[str, Any]], OptimizerState, @@ -471,21 +504,24 @@ def prepare_for_eval(workload: Workload, LossType, Hyperparameters, int, - RandomState -], - Tuple[Tensor, Tensor]] + RandomState, + ], + Tuple[Tensor, Tensor], +] # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. -def data_selection(workload: Workload, - input_queue: Iterator[Dict[str, Any]], - optimizer_state: OptimizerState, - current_param_container: ParameterContainer, - model_state: ModelAuxiliaryState, - hyperparameters: Hyperparameters, - global_step: int, - rng: RandomState) -> Dict[str, Tensor]: +def data_selection( + workload: Workload, + input_queue: Iterator[Dict[str, Any]], + optimizer_state: OptimizerState, + current_param_container: ParameterContainer, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + global_step: int, + rng: RandomState, +) -> Dict[str, Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 728d05f29..3d831c4af 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -17,13 +17,15 @@ from algoperf.data_utils import shard_and_maybe_pad_np -def preprocess_for_train(image: spec.Tensor, - rng: spec.RandomState, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - crop_size: int, - padding_size: int, - dtype: tf.DType = tf.float32) -> spec.Tensor: +def preprocess_for_train( + image: spec.Tensor, + rng: spec.RandomState, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + crop_size: int, + padding_size: int, + dtype: tf.DType = tf.float32, +) -> spec.Tensor: """Preprocesses the given image for training. Args: @@ -44,20 +46,23 @@ def preprocess_for_train(image: spec.Tensor, flip_rng = rng[1, :] image_shape = tf.shape(image) - image = tf.image.resize_with_crop_or_pad(image, - image_shape[0] + padding_size, - image_shape[1] + padding_size) + image = tf.image.resize_with_crop_or_pad( + image, image_shape[0] + padding_size, image_shape[1] + padding_size + ) image = tf.image.stateless_random_crop( - image, (crop_size, crop_size, 3), seed=crop_rng) + image, (crop_size, crop_size, 3), seed=crop_rng + ) image = tf.image.stateless_random_flip_left_right(image, seed=flip_rng) image = normalize_image(image, mean_rgb, stddev_rgb, dtype=dtype) return image -def preprocess_for_eval(image: spec.Tensor, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - dtype: tf.DType = tf.float32) -> spec.Tensor: +def preprocess_for_eval( + image: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + dtype: tf.DType = tf.float32, +) -> spec.Tensor: """Preprocesses the given image for evaluation. Args: @@ -74,10 +79,12 @@ def preprocess_for_eval(image: spec.Tensor, return image -def normalize_image(image: spec.Tensor, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - dtype=tf.float32) -> spec.Tensor: +def normalize_image( + image: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + dtype=tf.float32, +) -> spec.Tensor: image = tf.image.convert_image_dtype(image, dtype) image -= tf.constant(mean_rgb, shape=[1, 1, 3], dtype=image.dtype) image /= tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=image.dtype) @@ -85,17 +92,17 @@ def normalize_image(image: spec.Tensor, def create_split( - split: str, - dataset_builder: tfds.core.dataset_builder.DatasetBuilder, - rng: spec.RandomState, - global_batch_size: int, - train: bool, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - cache: bool = False, - repeat_final_dataset: bool = False, - crop_size: int = 32, - padding_size: int = 4, + split: str, + dataset_builder: tfds.core.dataset_builder.DatasetBuilder, + rng: spec.RandomState, + global_batch_size: int, + train: bool, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + cache: bool = False, + repeat_final_dataset: bool = False, + crop_size: int = 32, + padding_size: int = 4, ) -> Iterator[Dict[str, spec.Tensor]]: """Creates a split from the CIFAR-10 dataset using TensorFlow Datasets.""" shuffle_rng, preprocess_rng = jax.random.split(rng, 2) @@ -104,14 +111,17 @@ def preprocess_example(example_index, example): dtype = tf.float32 if train: per_step_preprocess_rng = tf.random.experimental.stateless_fold_in( - tf.cast(preprocess_rng, tf.int64), example_index) - image = preprocess_for_train(example['image'], - per_step_preprocess_rng, - mean_rgb, - stddev_rgb, - crop_size, - padding_size, - dtype) + tf.cast(preprocess_rng, tf.int64), example_index + ) + image = preprocess_for_train( + example['image'], + per_step_preprocess_rng, + mean_rgb, + stddev_rgb, + crop_size, + padding_size, + dtype, + ) else: image = preprocess_for_eval(example['image'], mean_rgb, stddev_rgb, dtype) return {'inputs': image, 'targets': example['label']} @@ -132,7 +142,8 @@ def preprocess_example(example_index, example): # index that we can fold into the RNG seed. ds = ds.enumerate() ds = ds.map( - preprocess_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) + preprocess_example, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) ds = ds.batch(global_batch_size, drop_remainder=train) if repeat_final_dataset: @@ -144,32 +155,36 @@ def preprocess_example(example_index, example): def create_input_iter( - split: str, - dataset_builder: tfds.core.dataset_builder.DatasetBuilder, - rng: spec.RandomState, - global_batch_size: int, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - crop_size: int, - padding_size: int, - train: bool, - cache: bool, - repeat_final_dataset: bool) -> Iterator[Dict[str, spec.Tensor]]: + split: str, + dataset_builder: tfds.core.dataset_builder.DatasetBuilder, + rng: spec.RandomState, + global_batch_size: int, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + crop_size: int, + padding_size: int, + train: bool, + cache: bool, + repeat_final_dataset: bool, +) -> Iterator[Dict[str, spec.Tensor]]: ds = create_split( - split, - dataset_builder, - rng, - global_batch_size, - train=train, - mean_rgb=mean_rgb, - stddev_rgb=stddev_rgb, - cache=cache, - repeat_final_dataset=repeat_final_dataset, - crop_size=crop_size, - padding_size=padding_size) + split, + dataset_builder, + rng, + global_batch_size, + train=train, + mean_rgb=mean_rgb, + stddev_rgb=stddev_rgb, + cache=cache, + repeat_final_dataset=repeat_final_dataset, + crop_size=crop_size, + padding_size=padding_size, + ) it = map( - functools.partial( - shard_and_maybe_pad_np, global_batch_size=global_batch_size), - ds) + functools.partial( + shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py index 957079272..8d034796f 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -25,48 +25,52 @@ class ResNet(nn.Module): act: Callable = nn.relu @nn.compact - def __call__(self, - x: spec.Tensor, - update_batch_norm: bool = True, - use_running_average_bn: bool = None) -> spec.Tensor: + def __call__( + self, + x: spec.Tensor, + update_batch_norm: bool = True, + use_running_average_bn: bool = None, + ) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: use_running_average_bn = not update_batch_norm norm = functools.partial( - nn.BatchNorm, - use_running_average=use_running_average_bn, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype) + nn.BatchNorm, + use_running_average=use_running_average_bn, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype, + ) x = conv( - self.num_filters, (3, 3), (1, 1), - padding=[(1, 1), (1, 1)], - name='Conv_init')( - x) + self.num_filters, + (3, 3), + (1, 1), + padding=[(1, 1), (1, 1)], + name='Conv_init', + )(x) x = norm(name='BatchNorm_init')(x) x = nn.relu(x) for i, block_size in enumerate(self.stage_sizes): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = self.block_cls( - self.num_filters * 2**i, - strides=strides, - conv=conv, - norm=norm, - act=self.act)( - x) + self.num_filters * 2**i, + strides=strides, + conv=conv, + norm=norm, + act=self.act, + )(x) x = nn.avg_pool(x, (4, 4), strides=(4, 4)) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, - kernel_init=nn.initializers.normal(), - dtype=self.dtype)( - x) + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + )(x) return x ResNet18 = functools.partial( - ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock) + ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock +) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index c6cc50fbf..bc26e3899 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -3,32 +3,30 @@ import functools from typing import Any, Dict, Iterator, Optional, Tuple -from flax import jax_utils -from flax import linen as nn -from flax.core import pop import jax -from jax import lax import jax.numpy as jnp import optax import tensorflow_datasets as tfds +from flax import jax_utils +from flax import linen as nn +from flax.core import pop +from jax import lax -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter from algoperf.workloads.cifar.workload import BaseCifarWorkload class CifarWorkload(BaseCifarWorkload): - def _build_cifar_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, ) -> Iterator[Dict[str, spec.Tensor]]: ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) train = split == 'train' @@ -38,38 +36,38 @@ def _build_cifar_dataset( elif split == 'validation': split = f'train[{self.num_train_examples}:]' ds = create_input_iter( - split, - ds_builder, - data_rng, - batch_size, - self.train_mean, - self.train_stddev, - self.crop_size, - self.padding_size, - train=train, - cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset) + split, + ds_builder, + data_rng, + batch_size, + self.train_mean, + self.train_stddev, + self.crop_size, + self.padding_size, + train=train, + cache=not train if cache is None else cache, + repeat_final_dataset=repeat_final_dataset, + ) return ds def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches - return self._build_cifar_dataset(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset) + return self._build_cifar_dataset( + data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset + ) def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: + self, model_state: spec.ModelAuxiliaryState + ) -> spec.ModelAuxiliaryState: """Sync the batch statistics across replicas.""" # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics @@ -85,8 +83,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) - variables = jax.jit(model.init)({'params': rng}, - jnp.ones(input_shape, model.dtype)) + variables = jax.jit(model.init)( + {'params': rng}, jnp.ones(input_shape, model.dtype) + ) model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -98,43 +97,46 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( - variables, - augmented_and_preprocessed_input_batch['inputs'], - update_batch_norm=update_batch_norm, - mutable=['batch_stats'], - use_running_average_bn=use_running_average_bn) + variables, + augmented_and_preprocessed_input_batch['inputs'], + update_batch_norm=update_batch_norm, + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn, + ) return logits, new_model_state else: logits = self._model.apply( - variables, - augmented_and_preprocessed_input_batch['inputs'], - update_batch_norm=update_batch_norm, - mutable=False, - use_running_average_bn=use_running_average_bn) + variables, + augmented_and_preprocessed_input_batch['inputs'], + update_batch_norm=update_batch_norm, + mutable=False, + use_running_average_bn=use_running_average_bn, + ) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -144,7 +146,8 @@ def loss_fn( one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( - smoothed_targets * nn.log_softmax(logits_batch), axis=-1) + smoothed_targets * nn.log_softmax(logits_batch), axis=-1 + ) # `mask_batch` is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -153,51 +156,53 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } - def _compute_metrics(self, - logits: spec.Tensor, - labels: spec.Tensor, - weights: spec.Tensor) -> Dict[str, spec.Tensor]: + def _compute_metrics( + self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor + ) -> Dict[str, spec.Tensor]: summed_loss = self.loss_fn(labels, logits, weights)['summed'] # Number of correct predictions. accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, + 'loss': summed_loss, + 'accuracy': accuracy, } metrics = lax.psum(metrics, axis_name='batch') return metrics @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, None), + static_broadcasted_argnums=(0,), + ) def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) weights = batch.get('weights') if weights is None: weights = jnp.ones(len(logits)) return self._compute_metrics(logits, batch['targets'], weights) def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index e6a7a8a81..6beef89e6 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -12,23 +12,26 @@ from algoperf import spec from algoperf.init_utils import pytorch_default_init -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - BasicBlock -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - Bottleneck +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import ( + BasicBlock, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import ( + Bottleneck, +) from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1 class ResNet(nn.Module): - - def __init__(self, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - num_classes: int = 10, - groups: int = 1, - width_per_group: int = 64, - replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 10, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -42,21 +45,26 @@ def __init__(self, replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError( - 'replace_stride_with_dilation should be None ' - f'or a 3-element tuple, got {replace_stride_with_dilation}') + 'replace_stride_with_dilation should be None ' + f'or a 3-element tuple, got {replace_stride_with_dilation}' + ) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer( - block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) self.layer3 = self._make_layer( - block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) self.layer4 = self._make_layer( - block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) self.fc = nn.Linear(512 * block.expansion, num_classes) self.reset_parameters() @@ -68,7 +76,7 @@ def reset_parameters(self) -> None: nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) nn.init.normal_(self.fc.weight, std=1e-2) - nn.init.constant_(self.fc.bias, 0.) + nn.init.constant_(self.fc.bias, 0.0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, @@ -81,12 +89,14 @@ def reset_parameters(self) -> None: elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) - def _make_layer(self, - block: Type[Union[BasicBlock, Bottleneck]], - planes: int, - blocks: int, - stride: int = 1, - dilate: bool = False) -> nn.Sequential: + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -95,32 +105,39 @@ def _make_layer(self, stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = torch.nn.Sequential( - collections.OrderedDict([ - ("conv", conv1x1(self.inplanes, planes * block.expansion, - stride)), - ("bn", norm_layer(planes * block.expansion)), - ])) + collections.OrderedDict( + [ + ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ('bn', norm_layer(planes * block.expansion)), + ] + ) + ) layers = [] layers.append( - block(self.inplanes, - planes, - stride, - downsample, - self.groups, - self.base_width, - previous_dilation, - norm_layer)) + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( - block( - self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation, - norm_layer=norm_layer)) + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) return nn.Sequential(*layers) diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index d05131c27..d7e858226 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -23,7 +23,6 @@ class CifarWorkload(BaseCifarWorkload): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # Is set in submission_runner.py for workloads with PyTorch evaluation @@ -34,7 +33,8 @@ def __init__(self, *args, **kwargs) -> None: def eval_num_workers(self) -> int: if self._eval_num_workers is None: raise ValueError( - 'eval_num_workers property must be set before workload is used.') + 'eval_num_workers property must be set before workload is used.' + ) return self._eval_num_workers @eval_num_workers.setter @@ -42,47 +42,54 @@ def eval_num_workers(self, eval_num_workers: int): self._eval_num_workers = eval_num_workers def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, ) -> torch.utils.data.DataLoader: del cache del repeat_final_dataset is_train = split == 'train' - normalize = transforms.Compose([ + normalize = transforms.Compose( + [ transforms.ToTensor(), transforms.Normalize(mean=self.train_mean, std=self.train_stddev), - ]) + ] + ) eval_transform_config = normalize - train_transform_config = transforms.Compose([ + train_transform_config = transforms.Compose( + [ transforms.RandomCrop( - size=self.crop_size, - padding=self.padding_size, + size=self.crop_size, + padding=self.padding_size, ), transforms.RandomHorizontalFlip(), normalize, - ]) + ] + ) transform = train_transform_config if is_train else eval_transform_config dataset = CIFAR10( - root=data_dir, - train=split in ['train', 'eval_train', 'validation'], - download=False, - transform=transform) + root=data_dir, + train=split in ['train', 'eval_train', 'validation'], + download=False, + transform=transform, + ) assert self.num_train_examples + self.num_validation_examples == 50000 indices = list(range(50000)) indices_split = { - 'train': indices[:self.num_train_examples], - 'validation': indices[self.num_train_examples:], + 'train': indices[: self.num_train_examples], + 'validation': indices[self.num_train_examples :], } if split == 'eval_train': train_indices = indices_split['train'] random.Random(int(data_rng[0])).shuffle(train_indices) - indices_split['eval_train'] = train_indices[:self.num_eval_train_examples] + indices_split['eval_train'] = train_indices[ + : self.num_eval_train_examples + ] if split in indices_split: dataset = torch.utils.data.Subset(dataset, indices_split[split]) @@ -92,30 +99,34 @@ def _build_dataset( ds_iter_batch_size = per_device_batch_size if is_train: sampler = torch.utils.data.distributed.DistributedSampler( - dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True + ) else: sampler = data_utils.DistributedEvalSampler( - dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False + ) else: ds_iter_batch_size = global_batch_size dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=ds_iter_batch_size, - shuffle=not USE_PYTORCH_DDP and is_train, - sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, - pin_memory=True, - drop_last=is_train) + dataset, + batch_size=ds_iter_batch_size, + shuffle=not USE_PYTORCH_DDP and is_train, + sampler=sampler, + num_workers=4 if is_train else self.eval_num_workers, + pin_memory=True, + drop_last=is_train, + ) dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP) return dataloader def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate del aux_dropout_rate @@ -143,30 +154,34 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['fc.weight', 'fc.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng model = params if mode == spec.ForwardPassMode.EVAL: if update_batch_norm: raise ValueError( - 'Batch norm statistics cannot be updated during evaluation.') + 'Batch norm statistics cannot be updated during evaluation.' + ) model.eval() if mode == spec.ForwardPassMode.TRAIN: model.train() model.apply( - functools.partial( - pytorch_utils.update_batch_norm_fn, - update_batch_norm=update_batch_norm)) + functools.partial( + pytorch_utils.update_batch_norm_fn, + update_batch_norm=update_batch_norm, + ) + ) contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) @@ -175,11 +190,12 @@ def model_fn( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -187,10 +203,11 @@ def loss_fn( (not synced across devices). """ per_example_losses = F.cross_entropy( - logits_batch, - label_batch, - reduction='none', - label_smoothing=label_smoothing) + logits_batch, + label_batch, + reduction='none', + label_smoothing=label_smoothing, + ) # `mask_batch` is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -199,25 +216,27 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) targets = batch['targets'] weights = batch.get('weights') if weights is None: @@ -229,8 +248,8 @@ def _eval_model( return {'accuracy': accuracy, 'loss': summed_loss} def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py index c0d565108..61880fbfa 100644 --- a/algoperf/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -15,7 +15,6 @@ class BaseCifarWorkload(spec.Workload): - _num_classes: int = 10 @property @@ -23,8 +22,9 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'accuracy' - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/accuracy'] > self.validation_target_value @property @@ -51,8 +51,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -93,37 +94,35 @@ def eval_period_time_sec(self) -> int: return 600 # 10 mins. def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, ) -> Iterator[Dict[str, spec.Tensor]]: raise NotImplementedError def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches if split == 'test': if not cache: raise ValueError('cache must be True for split=test.') if not repeat_final_dataset: raise ValueError('repeat_final_dataset must be True for split=test.') - return self._build_dataset(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset) + return self._build_dataset( + data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset + ) @property def step_hint(self) -> int: @@ -133,39 +132,43 @@ def step_hint(self) -> int: return 4883 def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: raise NotImplementedError @abc.abstractmethod def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - cache=True, - repeat_final_dataset=True) + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + cache=True, + repeat_final_dataset=True, + ) num_batches = int(math.ceil(num_examples / global_batch_size)) num_devices = max(torch.cuda.device_count(), jax.local_device_count()) @@ -174,10 +177,9 @@ def _eval_model_on_split(self, batch = next(self._eval_iters[split]) per_device_model_rngs = prng.split(model_rng, num_devices) # We already average these metrics across devices inside _compute_metrics. - synced_metrics = self._eval_model(params, - batch, - model_state, - per_device_model_rngs) + synced_metrics = self._eval_model( + params, batch, model_state, per_device_model_rngs + ) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index 4a91a80b8..706c2b51a 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -1,9 +1,10 @@ """A JAX implementation of DLRM-Small.""" + from typing import Sequence import flax.linen as nn -from jax import nn as jnn import jax.numpy as jnp +from jax import nn as jnn from algoperf.jax_utils import Dropout @@ -32,7 +33,6 @@ class DLRMResNet(nn.Module): @nn.compact def __call__(self, x, train, dropout_rate=DROPOUT_RATE): - bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -40,20 +40,18 @@ def __call__(self, x, train, dropout_rate=DROPOUT_RATE): mlp_bottom_dims = self.mlp_bottom_dims bot_mlp_input = nn.Dense( - mlp_bottom_dims[0], - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5), - )( - bot_mlp_input) + mlp_bottom_dims[0], + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0] ** 0.5), + )(bot_mlp_input) bot_mlp_input = nn.relu(bot_mlp_input) for dense_dim in mlp_bottom_dims[1:]: x = nn.Dense( - dense_dim, - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), - )( - bot_mlp_input) + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), + )(bot_mlp_input) bot_mlp_input += nn.relu(x) base_init_fn = jnn.initializers.uniform(scale=1.0) @@ -63,34 +61,38 @@ def __call__(self, x, train, dropout_rate=DROPOUT_RATE): def scaled_init(key, shape, dtype=jnp.float_): return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) - embedding_table = self.param('embedding_table', - scaled_init, [self.vocab_size, self.embed_dim]) + embedding_table = self.param( + 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim] + ) embed_features = embedding_table[idx_lookup] batch_size = bot_mlp_input.shape[0] - embed_features = jnp.reshape(embed_features, - (batch_size, 26 * self.embed_dim)) + embed_features = jnp.reshape( + embed_features, (batch_size, 26 * self.embed_dim) + ) top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims num_layers_top = len(mlp_top_dims) top_mlp_input = nn.Dense( - mlp_top_dims[0], - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))), - bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))( - top_mlp_input) + mlp_top_dims[0], + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0])) + ), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / mlp_top_dims[0])), + )(top_mlp_input) top_mlp_input = nn.relu(top_mlp_input) for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: fan_in = mlp_top_dims[layer_idx - 1] x = nn.Dense( - fan_out, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), - bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( - top_mlp_input) + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out)) + ), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx]) + ), + )(top_mlp_input) x = nn.relu(x) if dropout_rate and layer_idx == num_layers_top - 2: x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate) @@ -98,11 +100,12 @@ def scaled_init(key, shape, dtype=jnp.float_): # In the DLRM model the last layer width is always 1. We can hardcode that # below. logits = nn.Dense( - 1, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))( - top_mlp_input) + 1, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1)) + ), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)), + )(top_mlp_input) return logits @@ -118,16 +121,18 @@ def dot_interact(concat_features): batch_size = concat_features.shape[0] # Interact features, select upper or lower-triangular portion, and reshape. - xactions = jnp.matmul(concat_features, - jnp.transpose(concat_features, [0, 2, 1])) + xactions = jnp.matmul( + concat_features, jnp.transpose(concat_features, [0, 2, 1]) + ) feature_dim = xactions.shape[-1] indices = jnp.array(jnp.triu_indices(feature_dim)) num_elems = indices.shape[1] indices = jnp.tile(indices, [1, batch_size]) indices0 = jnp.reshape( - jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), - [1, -1]) + jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), + [1, -1], + ) indices = tuple(jnp.concatenate((indices0, indices), 0)) activations = xactions[indices] activations = jnp.reshape(activations, [batch_size, -1]) @@ -156,25 +161,24 @@ class DlrmSmall(nn.Module): @nn.compact def __call__(self, x, train, dropout_rate=DROPOUT_RATE): - bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) # Bottom MLP. for dense_dim in self.mlp_bottom_dims: bot_mlp_input = nn.Dense( - dense_dim, - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), - )( - bot_mlp_input) + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), + )(bot_mlp_input) bot_mlp_input = nn.relu(bot_mlp_input) if self.use_layer_norm: bot_mlp_input = nn.LayerNorm()(bot_mlp_input) bot_mlp_output = bot_mlp_input batch_size = bot_mlp_output.shape[0] - feature_stack = jnp.reshape(bot_mlp_output, - [batch_size, -1, self.embed_dim]) + feature_stack = jnp.reshape( + bot_mlp_output, [batch_size, -1, self.embed_dim] + ) # Embedding table look-up. idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size @@ -187,38 +191,45 @@ def __call__(self, x, train, dropout_rate=DROPOUT_RATE): def scaled_init(key, shape, dtype=jnp.float_): return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale - embedding_table = self.param('embedding_table', - scaled_init, [self.vocab_size, self.embed_dim]) + embedding_table = self.param( + 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim] + ) idx_lookup = jnp.reshape(idx_lookup, [-1]) embed_features = embedding_table[idx_lookup] - embed_features = jnp.reshape(embed_features, - [batch_size, -1, self.embed_dim]) + embed_features = jnp.reshape( + embed_features, [batch_size, -1, self.embed_dim] + ) if self.use_layer_norm: embed_features = nn.LayerNorm()(embed_features) feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) dot_interact_output = dot_interact(concat_features=feature_stack) - top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], - axis=-1) + top_mlp_input = jnp.concatenate( + [bot_mlp_output, dot_interact_output], axis=-1 + ) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims num_layers_top = len(mlp_top_dims) for layer_idx, fan_out in enumerate(mlp_top_dims): fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] top_mlp_input = nn.Dense( - fan_out, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( - top_mlp_input) + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out)) + ), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)), + )(top_mlp_input) if layer_idx < (num_layers_top - 1): top_mlp_input = nn.relu(top_mlp_input) if self.use_layer_norm: top_mlp_input = nn.LayerNorm()(top_mlp_input) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_input = Dropout( - dropout_rate, deterministic=not train)( - top_mlp_input, rate=dropout_rate) + if ( + dropout_rate is not None + and dropout_rate > 0.0 + and layer_idx == num_layers_top - 2 + ): + top_mlp_input = Dropout(dropout_rate, deterministic=not train)( + top_mlp_input, rate=dropout_rate + ) logits = top_mlp_input return logits diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index d84d18d5c..283b3be8e 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -3,26 +3,24 @@ import functools from typing import Dict, Optional, Tuple -from flax import jax_utils import jax import jax.numpy as jnp import numpy as np +from flax import jax_utils -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.criteo1tb.criteo1tb_jax import models -from algoperf.workloads.criteo1tb.workload import \ - BaseCriteo1TbDlrmSmallWorkload +from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload): - @property def eval_batch_size(self) -> int: return 131_072 def _per_example_sigmoid_binary_cross_entropy( - self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor: + self, logits: spec.Tensor, targets: spec.Tensor + ) -> spec.Tensor: """Computes the sigmoid binary cross entropy per example. Args: @@ -39,11 +37,12 @@ def _per_example_sigmoid_binary_cross_entropy( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense (not one-hot) labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense (not one-hot) labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -55,7 +54,8 @@ def loss_fn( label_batch = jnp.reshape(label_batch, (batch_size,)) logits_batch = jnp.reshape(logits_batch, (batch_size,)) per_example_losses = self._per_example_sigmoid_binary_cross_entropy( - logits=logits_batch, targets=label_batch) + logits=logits_batch, targets=label_batch + ) if mask_batch is not None: mask_batch = jnp.reshape(mask_batch, (batch_size,)) per_example_losses *= mask_batch @@ -64,15 +64,15 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } def init_model_fn( - self, - rng: spec.RandomState, - tabulate: Optional[bool] = False, + self, + rng: spec.RandomState, + tabulate: Optional[bool] = False, ) -> spec.ModelInitState: """Only dropout is used.""" if self.use_resnet: @@ -81,13 +81,14 @@ def init_model_fn( model_class = models.DlrmSmall self._model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + embed_dim=self.embed_dim, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier, + ) params_rng, _ = jax.random.split(rng) init_fake_batch_size = 2 @@ -96,10 +97,12 @@ def init_model_fn( input_size = num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) init_fn = functools.partial(self._model.init, train=False) - initial_variables = jax.jit(init_fn)({ + initial_variables = jax.jit(init_fn)( + { 'params': params_rng, - }, - jnp.ones(input_shape, jnp.float32)) + }, + jnp.ones(input_shape, jnp.float32), + ) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -109,14 +112,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm @@ -130,35 +133,38 @@ def model_fn( return logits_batch, None @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_batch_pmapped(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0), + static_broadcasted_argnums=(0,), + ) + def _eval_batch_pmapped( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> spec.Tensor: logits, _ = self.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + params, + batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) weights = batch.get('weights') if weights is None: weights = jnp.ones(len(logits)) summed_loss = self.loss_fn( - label_batch=batch['targets'], logits_batch=logits, - mask_batch=weights)['summed'] + label_batch=batch['targets'], logits_batch=logits, mask_batch=weights + )['summed'] return summed_loss - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + def _eval_batch( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. return np.array( - self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) + self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64 + ) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): @@ -166,7 +172,6 @@ class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload): - @property def use_layer_norm(self) -> bool: """Whether or not to use LayerNorm in the model.""" @@ -200,7 +205,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): - @property def validation_target_value(self) -> float: return 0.129657 diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py index 7574de3a7..1906bf7ae 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -5,14 +5,13 @@ import torch from torch import nn -from algoperf.pytorch_utils import CustomDropout -from algoperf.pytorch_utils import SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout DROPOUT_RATE = 0.0 class DenseBlock(nn.Module): - """Dense block with optional residual connection.""" "" + """Dense block with optional residual connection.""" '' def __init__(self, module, resnet=False): super().__init__() @@ -41,17 +40,20 @@ class DotInteract(nn.Module): def __init__(self, num_sparse_features): super().__init__() - self.triu_indices = torch.triu_indices(num_sparse_features + 1, - num_sparse_features + 1) + self.triu_indices = torch.triu_indices( + num_sparse_features + 1, num_sparse_features + 1 + ) def forward(self, dense_features, sparse_features): - combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), - dim=1) - interactions = torch.bmm(combined_values, - torch.transpose(combined_values, 1, 2)) - interactions_flat = interactions[:, - self.triu_indices[0], - self.triu_indices[1]] + combined_values = torch.cat( + (dense_features.unsqueeze(1), sparse_features), dim=1 + ) + interactions = torch.bmm( + combined_values, torch.transpose(combined_values, 1, 2) + ) + interactions_flat = interactions[ + :, self.triu_indices[0], self.triu_indices[1] + ] return torch.cat((dense_features, interactions_flat), dim=1) @@ -66,15 +68,17 @@ class DLRMResNet(nn.Module): embed_dim: embedding dimension. """ - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(256, 256, 256), - mlp_top_dims=(256, 256, 256, 256, 1), - embed_dim=128, - use_layer_norm=False, - embedding_init_multiplier=None): + def __init__( + self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(256, 256, 256), + mlp_top_dims=(256, 256, 256, 256, 1), + embed_dim=128, + use_layer_norm=False, + embedding_init_multiplier=None, + ): super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features @@ -92,7 +96,8 @@ def __init__(self, scale = 1.0 / torch.sqrt(self.vocab_size) for i in range(num_chunks): chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim) + ) chunk.data.uniform_(0, 1) chunk.data = scale * chunk.data self.register_parameter(f'embedding_chunk_{i}', chunk) @@ -115,11 +120,11 @@ def __init__(self, for module in self.bot_mlp.modules(): if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) + limit = math.sqrt(6.0 / (module.in_features + module.out_features)) nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) + nn.init.normal_( + module.bias.data, 0.0, math.sqrt(1.0 / module.out_features) + ) # Number of sparse features = 26 fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] @@ -144,19 +149,20 @@ def __init__(self, for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) + module.weight.data, + 0.0, + math.sqrt(2.0 / (module.in_features + module.out_features)), + ) + nn.init.normal_( + module.bias.data, 0.0, math.sqrt(1.0 / module.out_features) + ) def forward(self, x, dropout_rate=DROPOUT_RATE): - batch_size = x.shape[0] dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) + x, [self.num_dense_features, self.num_sparse_features], 1 + ) # Bottom MLP. embedded_dense = self.bot_mlp(dense_features) @@ -166,8 +172,9 @@ def forward(self, x, dropout_rate=DROPOUT_RATE): idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size embedding_table = torch.cat(self.embedding_table_chucks, dim=0) embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, 26 * self.embed_dim]) + embedded_sparse = torch.reshape( + embedded_sparse, [batch_size, 26 * self.embed_dim] + ) top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) # Final MLP. @@ -186,15 +193,17 @@ class DlrmSmall(nn.Module): embed_dim: embedding dimension. """ - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(512, 256, 128), - mlp_top_dims=(1024, 1024, 512, 256, 1), - embed_dim=128, - use_layer_norm=False, - embedding_init_multiplier=None): + def __init__( + self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(512, 256, 128), + mlp_top_dims=(1024, 1024, 512, 256, 1), + embed_dim=128, + use_layer_norm=False, + embedding_init_multiplier=None, + ): super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features @@ -218,7 +227,8 @@ def __init__(self, for i in range(num_chunks): chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim) + ) chunk.data.uniform_(0, 1) chunk.data = scale * chunk.data self.register_parameter(f'embedding_chunk_{i}', chunk) @@ -235,21 +245,24 @@ def __init__(self, self.bot_mlp = nn.Sequential(*bottom_mlp_layers) for module in self.bot_mlp.modules(): if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) + limit = math.sqrt(6.0 / (module.in_features + module.out_features)) nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) + nn.init.normal_( + module.bias.data, 0.0, math.sqrt(1.0 / module.out_features) + ) - self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) + self.dot_interact = DotInteract( + num_sparse_features=num_sparse_features, + ) # TODO: Write down the formula here instead of the constant. input_dims = 506 num_layers_top = len(self.mlp_top_dims) top_mlp_layers = [] for layer_idx, fan_out in enumerate(self.mlp_top_dims): - fan_in = input_dims if layer_idx == 0 \ - else self.mlp_top_dims[layer_idx - 1] + fan_in = ( + input_dims if layer_idx == 0 else self.mlp_top_dims[layer_idx - 1] + ) top_mlp_layers.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): top_mlp_layers.append(nn.ReLU(inplace=True)) @@ -265,19 +278,20 @@ def __init__(self, for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) + module.weight.data, + 0.0, + math.sqrt(2.0 / (module.in_features + module.out_features)), + ) + nn.init.normal_( + module.bias.data, 0.0, math.sqrt(1.0 / module.out_features) + ) def forward(self, x, dropout_rate=DROPOUT_RATE): - batch_size = x.shape[0] dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) + x, [self.num_dense_features, self.num_sparse_features], 1 + ) # Bottom MLP. embedded_dense = self.bot_mlp(dense_features) @@ -287,13 +301,15 @@ def forward(self, x, dropout_rate=DROPOUT_RATE): idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size embedding_table = torch.cat(self.embedding_table_chucks, dim=0) embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, -1, self.embed_dim]) + embedded_sparse = torch.reshape( + embedded_sparse, [batch_size, -1, self.embed_dim] + ) if self.embed_ln: embedded_sparse = self.embed_ln(embedded_sparse) # Dot product interactions. concatenated_dense = self.dot_interact( - dense_features=embedded_dense, sparse_features=embedded_sparse) + dense_features=embedded_dense, sparse_features=embedded_sparse + ) # Final MLP. logits = self.top_mlp(concatenated_dense, dropout_rate) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 69f24c69d..74f91de43 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -7,24 +7,22 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.criteo1tb.criteo1tb_pytorch import models -from algoperf.workloads.criteo1tb.workload import \ - BaseCriteo1TbDlrmSmallWorkload +from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload): - @property def eval_batch_size(self) -> int: return 8_192 def _per_example_sigmoid_binary_cross_entropy( - self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor: + self, logits: spec.Tensor, targets: spec.Tensor + ) -> spec.Tensor: ls = torch.nn.LogSigmoid() log_p = ls(logits) log_not_p = ls(-logits) @@ -35,11 +33,12 @@ def _per_example_sigmoid_binary_cross_entropy( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense (not one-hot) labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense (not one-hot) labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -51,7 +50,8 @@ def loss_fn( label_batch = torch.reshape(label_batch, (batch_size,)) logits_batch = torch.reshape(logits_batch, (batch_size,)) per_example_losses = self._per_example_sigmoid_binary_cross_entropy( - logits=logits_batch, targets=label_batch) + logits=logits_batch, targets=label_batch + ) if mask_batch is not None: mask_batch = torch.reshape(mask_batch, (batch_size,)) per_example_losses *= mask_batch @@ -60,9 +60,9 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: @@ -74,13 +74,14 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: else: model_class = models.DlrmSmall model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + embed_dim=self.embed_dim, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier, + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -95,14 +96,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['top_mlp.4.weight', 'top_mlp.4.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -118,8 +119,8 @@ def model_fn( model.train() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): @@ -128,14 +129,15 @@ def model_fn( return logits_batch, None def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) @@ -143,35 +145,42 @@ def _build_input_queue( # avoid creating too many threads. if RANK == 0: np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset, + ) weights = None while True: if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( - batch['inputs'], dtype=torch.float32, device=DEVICE) + batch['inputs'], dtype=torch.float32, device=DEVICE + ) targets = torch.as_tensor( - batch['targets'], dtype=torch.float32, device=DEVICE) + batch['targets'], dtype=torch.float32, device=DEVICE + ) if not_train: weights = batch.get('weights') if weights is None: - weights = torch.ones((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) + weights = torch.ones( + (N_GPUS, per_device_batch_size, 1), + dtype=torch.float32, + device=DEVICE, + ) else: weights = torch.as_tensor( - weights, dtype=torch.float32, device=DEVICE) + weights, dtype=torch.float32, device=DEVICE + ) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: if not_train: # During eval, the batch size of the remainder might be different. per_device_batch_size = torch.tensor( - len(targets[0]), dtype=torch.int32, device=DEVICE) + len(targets[0]), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) dist.broadcast(weights, src=0) weights = weights[0] @@ -187,52 +196,57 @@ def _build_input_queue( else: if not_train: # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((1,), - dtype=torch.int32, - device=DEVICE) + per_device_batch_size = torch.empty( + (1,), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) - weights = torch.empty((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) + weights = torch.empty( + (N_GPUS, per_device_batch_size, 1), + dtype=torch.float32, + device=DEVICE, + ) dist.broadcast(weights, src=0) weights = weights[RANK] - inputs = torch.empty((N_GPUS, per_device_batch_size, 39), - dtype=torch.float32, - device=DEVICE) + inputs = torch.empty( + (N_GPUS, per_device_batch_size, 39), + dtype=torch.float32, + device=DEVICE, + ) dist.broadcast(inputs, src=0) inputs = inputs[RANK] - targets = torch.empty((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) + targets = torch.empty( + (N_GPUS, per_device_batch_size, 1), dtype=torch.float32, device=DEVICE + ) dist.broadcast(targets, src=0) targets = targets[RANK] if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) batch = { - 'inputs': inputs, - 'targets': targets, - 'weights': weights, + 'inputs': inputs, + 'targets': targets, + 'weights': weights, } yield batch - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + def _eval_batch( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> spec.Tensor: logits, _ = self.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + params, + batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) weights = batch.get('weights') if weights is None: weights = torch.ones(len(logits), device=DEVICE) summed_loss = self.loss_fn( - label_batch=batch['targets'], logits_batch=logits, - mask_batch=weights)['summed'] + label_batch=batch['targets'], logits_batch=logits, mask_batch=weights + )['summed'] return summed_loss.to(dtype=torch.float64) @@ -241,7 +255,6 @@ class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload): - @property def use_layer_norm(self) -> bool: """Whether or not to use LayerNorm in the model.""" @@ -275,7 +288,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): - @property def validation_target_value(self) -> float: return 0.129657 diff --git a/algoperf/workloads/criteo1tb/input_pipeline.py b/algoperf/workloads/criteo1tb/input_pipeline.py index 7e254336a..bce8b11c4 100644 --- a/algoperf/workloads/criteo1tb/input_pipeline.py +++ b/algoperf/workloads/criteo1tb/input_pipeline.py @@ -19,32 +19,32 @@ # Raw vocab sizes from # https://cloud.google.com/tpu/docs/tutorials/dlrm-dcn-2.x#run-model. _VOCAB_SIZES = [ - 39884406, - 39043, - 17289, - 7420, - 20263, - 3, - 7120, - 1543, - 63, - 38532951, - 2953546, - 403346, - 10, - 2208, - 11938, - 155, - 4, - 976, - 14, - 39979771, - 25641295, - 39664984, - 585935, - 12972, - 108, - 36, + 39884406, + 39043, + 17289, + 7420, + 20263, + 3, + 7120, + 1543, + 63, + 38532951, + 2953546, + 403346, + 10, + 2208, + 11938, + 155, + 4, + 976, + 14, + 39979771, + 25641295, + 39664984, + 585935, + 12972, + 108, + 36, ] @@ -60,7 +60,8 @@ def _parse_example_fn(num_dense_features, example): categorical_defaults = [['00000000'] for _ in range(len(_VOCAB_SIZES))] record_defaults = label_defaults + int_defaults + categorical_defaults fields = tf.io.decode_csv( - example, record_defaults, field_delim='\t', na_value='-1') + example, record_defaults, field_delim='\t', na_value='-1' + ) num_labels = 1 features = {} @@ -78,20 +79,24 @@ def _parse_example_fn(num_dense_features, example): # We append the column index to the string to make the same id in different # columns unique. cat_features.append( - tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx])) + tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx]) + ) cat_features = tf.cast( - tf.stack(cat_features, axis=1), dtype=int_features.dtype) + tf.stack(cat_features, axis=1), dtype=int_features.dtype + ) features['inputs'] = tf.concat([int_features, cat_features], axis=1) return features -def get_criteo1tb_dataset(split: str, - shuffle_rng, - data_dir: str, - num_dense_features: int, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): +def get_criteo1tb_dataset( + split: str, + shuffle_rng, + data_dir: str, + num_dense_features: int, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, +): """Get the Criteo 1TB dataset for a given split.""" num_test_files = _NUM_DAY_23_FILES // 2 + 1 if split in ['train', 'eval_train']: @@ -99,19 +104,20 @@ def get_criteo1tb_dataset(split: str, elif split == 'validation': # Assumes files are of the format day_23_04. file_paths = [ - os.path.join(data_dir, f'day_23_{str(s).zfill(2)}') - for s in range(num_test_files, _NUM_DAY_23_FILES) + os.path.join(data_dir, f'day_23_{str(s).zfill(2)}') + for s in range(num_test_files, _NUM_DAY_23_FILES) ] else: file_paths = [ - os.path.join(data_dir, f'day_23_{str(s).zfill(2)}') - for s in range(0, num_test_files) + os.path.join(data_dir, f'day_23_{str(s).zfill(2)}') + for s in range(0, num_test_files) ] is_training = split == 'train' shuffle = is_training or split == 'eval_train' ds = tf.data.Dataset.list_files( - file_paths, shuffle=shuffle, seed=shuffle_rng[0]) + file_paths, shuffle=shuffle, seed=shuffle_rng[0] + ) if shuffle: ds = ds.shuffle(buffer_size=1024) @@ -132,9 +138,10 @@ def get_criteo1tb_dataset(split: str, ds = ds.repeat() ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) return ds diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 617b2e987..9fb819203 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -29,8 +29,9 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'loss' - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/loss'] < self.validation_target_value @property @@ -71,8 +72,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -100,23 +102,25 @@ def eval_period_time_sec(self) -> int: return 2 * 60 # 2 mins. def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del cache ds = input_pipeline.get_criteo1tb_dataset( - split=split, - shuffle_rng=data_rng, - data_dir=data_dir, - num_dense_features=self.num_dense_features, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + split=split, + shuffle_rng=data_rng, + data_dir=data_dir, + num_dense_features=self.num_dense_features, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset, + ) for batch in iter(ds): yield batch @@ -126,15 +130,17 @@ def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" return 10_666 - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del model_state del global_step @@ -142,12 +148,13 @@ def _eval_model_on_split(self, if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - data_rng=rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=True) + data_rng=rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=True, + ) loss = 0.0 for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index a5fe060b9..b80c370ea 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -12,6 +12,7 @@ Data: github.com/facebookresearch/fastMRI/tree/main/fastmri/data """ + import functools import flax.linen as nn @@ -31,7 +32,7 @@ def _instance_norm2d(x, axes, epsilon=1e-5): mean2 = jnp.mean(jnp.square(x), axes) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. - var = jnp.maximum(0., mean2 - jnp.square(mean)) + var = jnp.maximum(0.0, mean2 - jnp.square(mean)) stats_shape = list(x.shape) for axis in axes: stats_shape[axis] = 1 @@ -46,16 +47,17 @@ def _instance_norm2d(x, axes, epsilon=1e-5): class UNet(nn.Module): """Jax / Flax implementation of a U-Net model. - O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks - for biomedical image segmentation. In International Conference on Medical - image computing and computer-assisted intervention, pages 234–241. - Springer, 2015. + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. - out_channels: Number of channels in the output to the U-Net model. - channels: Number of output channels of the first convolution layer. - num_pool_layers: Number of down-sampling and up-sampling layers. - dropout_rate: Dropout probability. + out_channels: Number of channels in the output to the U-Net model. + channels: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + dropout_rate: Dropout probability. """ + num_channels: int = 32 num_pool_layers: int = 4 out_channels = 1 @@ -67,14 +69,16 @@ class UNet(nn.Module): def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): # pylint: disable=invalid-name _ConvBlock = functools.partial( - ConvBlock, - dropout_rate=dropout_rate, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm) + ConvBlock, + dropout_rate=dropout_rate, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + ) _TransposeConvBlock = functools.partial( - TransposeConvBlock, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm) + TransposeConvBlock, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + ) down_sample_layers = [_ConvBlock(self.num_channels)] @@ -125,9 +129,9 @@ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): output = jnp.concatenate((output, downsample_layer), axis=-1) output = conv(output, train) - output = nn.Conv( - self.out_channels, kernel_size=(1, 1), strides=(1, 1))( - output) + output = nn.Conv(self.out_channels, kernel_size=(1, 1), strides=(1, 1))( + output + ) return output.squeeze(-1) @@ -136,6 +140,7 @@ class ConvBlock(nn.Module): out_channels: Number of channels in the output. dropout_rate: Dropout probability. """ + out_channels: int use_tanh: bool use_layer_norm: bool @@ -152,11 +157,11 @@ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): jnp.array: Output tensor of shape `(N, H, W, out_channels)`. """ x = nn.Conv( - features=self.out_channels, - kernel_size=(3, 3), - strides=(1, 1), - use_bias=False)( - x) + features=self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + use_bias=False, + )(x) if self.use_layer_norm: x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x) else: @@ -171,23 +176,23 @@ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): x = activation_fn(x) # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW - x = Dropout( - dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x, rate=dropout_rate) + x = Dropout(dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate + ) x = nn.Conv( - features=self.out_channels, - kernel_size=(3, 3), - strides=(1, 1), - use_bias=False)( - x) + features=self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + use_bias=False, + )(x) if self.use_layer_norm: x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x) else: x = _instance_norm2d(x, (1, 2)) x = activation_fn(x) - x = Dropout( - dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x, rate=dropout_rate) + x = Dropout(dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate + ) return x @@ -195,6 +200,7 @@ class TransposeConvBlock(nn.Module): """A Transpose Convolutional Block. out_channels: Number of channels in the output. """ + out_channels: int use_tanh: bool use_layer_norm: bool @@ -208,8 +214,8 @@ def __call__(self, x): jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`. """ x = nn.ConvTranspose( - self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)( - x) + self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False + )(x) x = _instance_norm2d(x, (1, 2)) if self.use_tanh: activation_fn = nn.tanh diff --git a/algoperf/workloads/fastmri/fastmri_jax/ssim.py b/algoperf/workloads/fastmri/fastmri_jax/ssim.py index e15b93616..ca2ee1b60 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/ssim.py +++ b/algoperf/workloads/fastmri/fastmri_jax/ssim.py @@ -49,12 +49,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): return ssims -def structural_similarity(im1, - im2, - data_range=1.0, - win_size=7, - k1=0.01, - k2=0.03): +def structural_similarity( + im1, im2, data_range=1.0, win_size=7, k1=0.01, k2=0.03 +): """Compute the mean structural similarity index between two images. NOTE(dsuo): modified from skimage.metrics.structural_similarity. @@ -85,7 +82,7 @@ def structural_similarity(im1, """ filter_func = functools.partial(_uniform_filter, size=win_size) - num_points = win_size**len(im1.shape) + num_points = win_size ** len(im1.shape) # filter has already normalized by num_points cov_norm = num_points / (num_points - 1) # sample covariance @@ -102,8 +99,8 @@ def structural_similarity(im1, vy = cov_norm * (uyy - uy * uy) vxy = cov_norm * (uxy - ux * uy) - c1 = (k1 * data_range)**2 - c2 = (k2 * data_range)**2 + c1 = (k1 * data_range) ** 2 + c2 = (k2 * data_range) ** 2 a1 = 2 * ux * uy + c1 a2 = 2 * vxy + c2 @@ -121,12 +118,15 @@ def structural_similarity(im1, def _uniform_filter(im, size=7): - def conv(im): - return jnp.convolve( + return ( + jnp.convolve( jnp.pad(im, pad_width=size // 2, mode='symmetric'), jnp.ones(size), - mode='valid') / size + mode='valid', + ) + / size + ) im = jax.vmap(conv, (0,))(im) im = jax.vmap(conv, (1,))(im) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index bd0aa1d0b..08bb25014 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -4,32 +4,29 @@ import math from typing import Dict, Optional, Tuple -from flax import jax_utils import jax import jax.numpy as jnp +from flax import jax_utils -from algoperf import param_utils -from algoperf import spec import algoperf.random_utils as prng -from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE -from algoperf.workloads.fastmri.fastmri_jax.models import UNet +from algoperf import param_utils, spec +from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE, UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload class FastMRIWorkload(BaseFastMRIWorkload): - def init_model_fn( - self, - rng: spec.RandomState, + self, + rng: spec.RandomState, ) -> spec.ModelInitState: """aux_dropout_rate is unused.""" fake_batch = jnp.zeros((13, 320, 320)) self._model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, ) params_rng, _ = jax.random.split(rng) @@ -45,34 +42,37 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Conv_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train, - dropout_rate=dropout_rate) + logits = self._model.apply( + {'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate, + ) return logits, None # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -81,8 +81,9 @@ def loss_fn( """ del label_smoothing per_example_losses = jnp.mean( - jnp.abs(logits_batch - label_batch), - axis=tuple(range(1, logits_batch.ndim))) + jnp.abs(logits_batch - label_batch), + axis=tuple(range(1, logits_batch.ndim)), + ) # mask_batch is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -91,56 +92,63 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_model(self, - params: spec.Tensor, - batch: Dict[str, spec.Tensor], - rng: spec.RandomState) -> Dict[str, spec.Tensor]: + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0), + static_broadcasted_argnums=(0,), + ) + def _eval_model( + self, + params: spec.Tensor, + batch: Dict[str, spec.Tensor], + rng: spec.RandomState, + ) -> Dict[str, spec.Tensor]: """Return the SSIM and loss as a dict.""" logits, _ = self.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=rng, - update_batch_norm=False) + params, + batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=rng, + update_batch_norm=False, + ) targets = batch['targets'] weights = batch.get('weights') if weights is None: weights = jnp.ones(len(logits)) ssim_vals = ssim( - logits, - targets, - mean=batch['mean'], - std=batch['std'], - volume_max=batch['volume_max']) + logits, + targets, + mean=batch['mean'], + std=batch['std'], + volume_max=batch['volume_max'], + ) ssim_sum = jnp.sum(ssim_vals * weights) summed_loss = self.loss_fn(targets, logits, weights)['summed'] metrics = { - 'ssim': ssim_sum, - 'loss': summed_loss, + 'ssim': ssim_sum, + 'loss': summed_loss, } metrics = jax.lax.psum(metrics, axis_name='batch') return metrics - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del model_state del global_step @@ -149,27 +157,27 @@ def _eval_model_on_split(self, if split not in self._eval_iters: # These iterators repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - data_rng, - split, - data_dir, - global_batch_size=global_batch_size, - repeat_final_dataset=True, - num_batches=num_batches) - - total_metrics = {'ssim': 0., 'loss': 0.} + data_rng, + split, + data_dir, + global_batch_size=global_batch_size, + repeat_final_dataset=True, + num_batches=num_batches, + ) + + total_metrics = {'ssim': 0.0, 'loss': 0.0} eval_rngs = prng.split(model_rng, jax.local_device_count()) for _ in range(num_batches): batch = next(self._eval_iters[split]) # We already sum these metrics across devices inside _eval_model. synced_metrics = self._eval_model(params, batch, eval_rngs) total_metrics = { - k: v + synced_metrics[k][0] for k, v in total_metrics.items() + k: v + synced_metrics[k][0] for k, v in total_metrics.items() } return {k: float(v.item() / num_examples) for k, v in total_metrics.items()} class FastMRIModelSizeWorkload(FastMRIWorkload): - @property def num_pool_layers(self) -> bool: """Whether or not to use tanh activations in the model.""" @@ -190,7 +198,6 @@ def test_target_value(self) -> float: class FastMRITanhWorkload(FastMRIWorkload): - @property def use_tanh(self) -> bool: """Whether or not to use tanh activations in the model.""" @@ -206,7 +213,6 @@ def test_target_value(self) -> float: class FastMRILayerNormWorkload(FastMRIWorkload): - @property def use_layer_norm(self) -> bool: """Whether or not to use tanh activations in the model.""" diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py index 8441f06c2..16cf8bd54 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py @@ -7,31 +7,31 @@ from functools import partial import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn from torch.nn import functional as F from algoperf import init_utils -from algoperf.pytorch_utils import CustomDropout2d -from algoperf.pytorch_utils import SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout DROPOUT_RATE = 0.0 class UNet(nn.Module): r"""U-Net model from - `"U-net: Convolutional networks - for biomedical image segmentation" - `_. - """ - - def __init__(self, - in_chans: int = 1, - out_chans: int = 1, - num_channels: int = 32, - num_pool_layers: int = 4, - use_tanh: bool = False, - use_layer_norm: bool = False) -> None: + `"U-net: Convolutional networks + for biomedical image segmentation" + `_. + """ + + def __init__( + self, + in_chans: int = 1, + out_chans: int = 1, + num_channels: int = 32, + num_pool_layers: int = 4, + use_tanh: bool = False, + use_layer_norm: bool = False, + ) -> None: super().__init__() self.in_chans = in_chans @@ -40,11 +40,13 @@ def __init__(self, self.num_pool_layers = num_pool_layers self.down_sample_layers = nn.ModuleList( - [ConvBlock(in_chans, num_channels, use_tanh, use_layer_norm)]) + [ConvBlock(in_chans, num_channels, use_tanh, use_layer_norm)] + ) ch = num_channels for _ in range(num_pool_layers - 1): self.down_sample_layers.append( - ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)) + ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) + ) ch *= 2 self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) @@ -53,24 +55,26 @@ def __init__(self, for _ in range(num_pool_layers - 1): self.up_transpose_conv.append( - TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm) + ) self.up_conv.append(ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) ch //= 2 self.up_transpose_conv.append( - TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm) + ) self.up_conv.append( - SequentialWithDropout( - ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), - nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), - )) + SequentialWithDropout( + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) def forward(self, x: Tensor, dropout_rate: float = DROPOUT_RATE) -> Tensor: - stack = [] output = x @@ -95,7 +99,7 @@ def forward(self, x: Tensor, dropout_rate: float = DROPOUT_RATE) -> Tensor: if output.shape[-2] != downsample_layer.shape[-2]: padding[3] = 1 # padding bottom if torch.sum(torch.tensor(padding)) != 0: - output = F.pad(output, padding, "reflect") + output = F.pad(output, padding, 'reflect') output = torch.cat([output, downsample_layer], dim=1) output = conv(output, dropout_rate) @@ -107,11 +111,9 @@ class ConvBlock(nn.Module): # A Convolutional Block that consists of two convolution layers each # followed by instance normalization, LeakyReLU activation and dropout_rate. - def __init__(self, - in_chans: int, - out_chans: int, - use_tanh: bool, - use_layer_norm: bool) -> None: + def __init__( + self, in_chans: int, out_chans: int, use_tanh: bool, use_layer_norm: bool + ) -> None: super().__init__() self._supports_custom_dropout = True @@ -124,14 +126,14 @@ def __init__(self, else: activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.conv_layers = SequentialWithDropout( - nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), - norm_layer(out_chans), - activation_fn, - CustomDropout2d(), - nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), - norm_layer(out_chans), - activation_fn, - CustomDropout2d(), + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + norm_layer(out_chans), + activation_fn, + CustomDropout2d(), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + norm_layer(out_chans), + activation_fn, + CustomDropout2d(), ) def forward(self, x: Tensor, dropout_rate: float) -> Tensor: @@ -143,11 +145,11 @@ class TransposeConvBlock(nn.Module): # layers followed by instance normalization and LeakyReLU activation. def __init__( - self, - in_chans: int, - out_chans: int, - use_tanh: bool, - use_layer_norm: bool, + self, + in_chans: int, + out_chans: int, + use_tanh: bool, + use_layer_norm: bool, ): super().__init__() if use_tanh: @@ -155,10 +157,11 @@ def __init__( else: activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.layers = nn.Sequential( - nn.ConvTranspose2d( - in_chans, out_chans, kernel_size=2, stride=2, bias=False), - nn.InstanceNorm2d(out_chans), - activation_fn, + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + activation_fn, ) def forward(self, x: Tensor) -> Tensor: diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py index 45b61bea4..7d594b959 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py @@ -32,9 +32,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): # NOTE(dsuo): `volume_max` can be 0 if we have a padded batch, but this will # lead to NaN values in `ssim`. - volume_max = torch.where(volume_max == 0, - torch.ones_like(volume_max), - volume_max) + volume_max = torch.where( + volume_max == 0, torch.ones_like(volume_max), volume_max + ) if mean is None: mean = torch.zeros(logits.shape[0], device=DEVICE) @@ -56,12 +56,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): return ssims -def structural_similarity(im1, - im2, - data_range=1.0, - win_size=7, - k1=0.01, - k2=0.03): +def structural_similarity( + im1, im2, data_range=1.0, win_size=7, k1=0.01, k2=0.03 +): """Compute the mean structural similarity index between two images. NOTE(dsuo): modified from skimage.metrics.structural_similarity. @@ -92,7 +89,7 @@ def structural_similarity(im1, """ filter_func = functools.partial(_uniform_filter, size=win_size) - num_points = win_size**len(im1.shape) + num_points = win_size ** len(im1.shape) # filter has already normalized by num_points cov_norm = num_points / (num_points - 1) # sample covariance @@ -109,8 +106,8 @@ def structural_similarity(im1, vy = cov_norm * (uyy - uy * uy) vxy = cov_norm * (uxy - ux * uy) - c1 = (k1 * data_range)**2 - c2 = (k2 * data_range)**2 + c1 = (k1 * data_range) ** 2 + c2 = (k2 * data_range) ** 2 a1 = 2 * ux * uy + c1 a2 = 2 * vxy + c2 diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 1adbb57ca..bddf6b1f3 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -9,10 +9,8 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec import algoperf.random_utils as prng +from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.fastmri.fastmri_pytorch import models from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim @@ -22,28 +20,31 @@ class FastMRIWorkload(BaseFastMRIWorkload): - - def _build_input_queue(self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None): + def _build_input_queue( + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ): per_device_batch_size = int(global_batch_size / N_GPUS) # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. if RANK == 0: data_rng = data_rng.astype('uint32') - np_iter = super()._build_input_queue(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset, - num_batches) + np_iter = super()._build_input_queue( + data_rng, + split, + data_dir, + global_batch_size, + cache, + repeat_final_dataset, + num_batches, + ) while True: if RANK == 0: @@ -59,20 +60,23 @@ def _build_input_queue(self, else: aux_tensor_list.append(tensor) batch[key] = ( - tensor[0] if USE_PYTORCH_DDP else tensor.view( - -1, *value.shape[2:])) + tensor[0] if USE_PYTORCH_DDP else tensor.view(-1, *value.shape[2:]) + ) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: if split != 'train': # During eval, the batch size of the remainder might be different. per_device_batch_size = torch.tensor( - len(batch['inputs']), dtype=torch.int32, device=DEVICE) + len(batch['inputs']), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) weights = weights if 'weights' in batch else None if weights is None: - weights = torch.ones((N_GPUS, per_device_batch_size), - dtype=torch.float64, - device=DEVICE) + weights = torch.ones( + (N_GPUS, per_device_batch_size), + dtype=torch.float64, + device=DEVICE, + ) # Has no effect, but without it `batch` has no `weights` key # for RANK == 0, but has one for all others. batch['weights'] = weights[0] @@ -83,20 +87,22 @@ def _build_input_queue(self, batch = {} if split != 'train': # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((), - dtype=torch.int32, - device=DEVICE) + per_device_batch_size = torch.empty( + (), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) - weights = torch.empty((N_GPUS, per_device_batch_size), - dtype=torch.float64, - device=DEVICE) + weights = torch.empty( + (N_GPUS, per_device_batch_size), dtype=torch.float64, device=DEVICE + ) dist.broadcast(weights, src=0) batch['weights'] = weights[RANK] - tensors = torch.empty((2, N_GPUS, per_device_batch_size, 320, 320), - device=DEVICE) + tensors = torch.empty( + (2, N_GPUS, per_device_batch_size, 320, 320), device=DEVICE + ) dist.broadcast(tensors, src=0) - aux_tensors = torch.empty((3, N_GPUS, per_device_batch_size), - device=DEVICE) + aux_tensors = torch.empty( + (3, N_GPUS, per_device_batch_size), device=DEVICE + ) dist.broadcast(aux_tensors, src=0) # Note that the batch dict in the RANK == 0 process is ordered. batch['inputs'] = tensors[0][RANK] @@ -109,10 +115,11 @@ def _build_input_queue(self, def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm) + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -127,14 +134,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['up_conv.3.1.weight', 'up_conv.3.1.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -149,25 +156,27 @@ def model_fn( model.train() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logit_batch = model( - augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1), - dropout_rate=dropout_rate).squeeze(1) + augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1), + dropout_rate=dropout_rate, + ).squeeze(1) return logit_batch, None # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -176,7 +185,8 @@ def loss_fn( """ del label_smoothing per_example_losses = F.l1_loss( - logits_batch, label_batch, reduction='none').mean(dim=(1, 2)) + logits_batch, label_batch, reduction='none' + ).mean(dim=(1, 2)) # mask_batch is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -185,46 +195,52 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } - def _eval_model(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - rng: spec.RandomState) -> Dict[str, spec.Tensor]: + def _eval_model( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + rng: spec.RandomState, + ) -> Dict[str, spec.Tensor]: """Return the SSIM and loss as a dict.""" outputs, _ = self.model_fn( - params, - batch, - None, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + None, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) targets = batch['targets'] weights = batch.get('weights') if weights is None: weights = torch.ones(len(outputs), device=DEVICE) weights_sum = weights.sum().to(torch.int) ssim_sum = ssim( - outputs[:weights_sum], - targets[:weights_sum], - mean=batch['mean'][:weights_sum], - std=batch['std'][:weights_sum], - volume_max=batch['volume_max'][:weights_sum]).sum() + outputs[:weights_sum], + targets[:weights_sum], + mean=batch['mean'][:weights_sum], + std=batch['std'][:weights_sum], + volume_max=batch['volume_max'][:weights_sum], + ).sum() summed_loss = self.loss_fn(targets, outputs, weights)['summed'] return {'ssim': ssim_sum, 'loss': summed_loss} - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del model_state del global_step @@ -233,22 +249,23 @@ def _eval_model_on_split(self, if split not in self._eval_iters: # These iterators repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - data_rng, - split, - data_dir, - global_batch_size=global_batch_size, - repeat_final_dataset=True, - num_batches=num_batches) + data_rng, + split, + data_dir, + global_batch_size=global_batch_size, + repeat_final_dataset=True, + num_batches=num_batches, + ) total_metrics = { - 'ssim': torch.tensor(0., device=DEVICE), - 'loss': torch.tensor(0., device=DEVICE), + 'ssim': torch.tensor(0.0, device=DEVICE), + 'loss': torch.tensor(0.0, device=DEVICE), } for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } if USE_PYTORCH_DDP: for metric in total_metrics.values(): @@ -257,7 +274,6 @@ def _eval_model_on_split(self, class FastMRIModelSizeWorkload(FastMRIWorkload): - @property def num_pool_layers(self) -> bool: """Whether or not to use tanh activations in the model.""" @@ -278,7 +294,6 @@ def test_target_value(self) -> float: class FastMRITanhWorkload(FastMRIWorkload): - @property def use_tanh(self) -> bool: """Whether or not to use tanh activations in the model.""" @@ -294,7 +309,6 @@ def test_target_value(self) -> float: class FastMRILayerNormWorkload(FastMRIWorkload): - @property def use_layer_norm(self) -> bool: """Whether or not to use tanh activations in the model.""" diff --git a/algoperf/workloads/fastmri/input_pipeline.py b/algoperf/workloads/fastmri/input_pipeline.py index f20611f43..62b3219c5 100644 --- a/algoperf/workloads/fastmri/input_pipeline.py +++ b/algoperf/workloads/fastmri/input_pipeline.py @@ -16,12 +16,9 @@ _EVAL_SEED = 0 -def _process_example(kspace, - kspace_shape, - target, - target_shape, - volume_max, - seed): +def _process_example( + kspace, kspace_shape, target, target_shape, volume_max, seed +): """Generate a single example (slice from mri image). Args: @@ -45,15 +42,17 @@ def _process_example(kspace, acceleration = tf.convert_to_tensor(4.0, dtype=tf.float32) num_low_frequencies = tf.cast( - num_cols_float * center_fraction, dtype=tf.int32) + num_cols_float * center_fraction, dtype=tf.int32 + ) # calculate_center_mask mask = tf.zeros(num_cols, dtype=tf.float32) pad = (num_cols - num_low_frequencies + 1) // 2 mask = tf.tensor_scatter_nd_update( - mask, - tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)), - tf.ones(num_low_frequencies)) + mask, + tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)), + tf.ones(num_low_frequencies), + ) # reshape_mask center_mask = tf.reshape(mask, (1, num_cols)) @@ -61,10 +60,12 @@ def _process_example(kspace, # calculate_acceleration_mask num_low_frequencies_float = tf.cast(num_low_frequencies, dtype=tf.float32) prob = (num_cols_float / acceleration - num_low_frequencies_float) / ( - num_cols_float - num_low_frequencies_float) + num_cols_float - num_low_frequencies_float + ) mask = tf.cast( - tf.random.stateless_uniform((num_cols,), seed) < prob, dtype=tf.float32) + tf.random.stateless_uniform((num_cols,), seed) < prob, dtype=tf.float32 + ) acceleration_mask = tf.reshape(mask, (1, num_cols)) mask = tf.math.maximum(center_mask, acceleration_mask) @@ -78,9 +79,11 @@ def _process_example(kspace, shifted_image = tf.signal.ifft2d(shifted_kspace) image = tf.signal.fftshift(shifted_image, axes=(0, 1)) scaling_norm = tf.cast( - tf.math.sqrt( - tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32')), - kspace.dtype) + tf.math.sqrt( + tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32') + ), + kspace.dtype, + ) image = image * scaling_norm image = tf.stack((tf.math.real(image), tf.math.imag(image)), axis=-1) @@ -108,48 +111,58 @@ def _process_example(kspace, target = tf.clip_by_value(norm_target, -6, 6) return { - 'inputs': image, - 'targets': target, - 'mean': mean, - 'std': std, - 'volume_max': volume_max, + 'inputs': image, + 'targets': target, + 'mean': mean, + 'std': std, + 'volume_max': volume_max, } def _h5_to_examples(path, log=False): """Yield MRI slices from an hdf5 file containing a single MRI volume.""" if log: - tf.print('fastmri_dataset._h5_to_examples call:', - path, - datetime.datetime.now().strftime('%H:%M:%S:%f')) + tf.print( + 'fastmri_dataset._h5_to_examples call:', + path, + datetime.datetime.now().strftime('%H:%M:%S:%f'), + ) with open(path, 'rb') as gf: with h5py.File(gf, 'r') as hf: # NOTE(dsuo): logic taken from reference code volume_max = hf.attrs.get('max', 0.0) for i in range(hf['kspace'].shape[0]): - yield hf['kspace'][i], hf['kspace'][i].shape, hf['reconstruction_esc'][ - i], hf['reconstruction_esc'][i].shape, volume_max + yield ( + hf['kspace'][i], + hf['kspace'][i].shape, + hf['reconstruction_esc'][i], + hf['reconstruction_esc'][i].shape, + volume_max, + ) def _create_generator(filename): signature = ( - tf.TensorSpec(shape=(640, None), dtype=tf.complex64), - tf.TensorSpec(shape=(2,), dtype=tf.int32), - tf.TensorSpec(shape=(320, 320), dtype=tf.float32), - tf.TensorSpec(shape=(2,), dtype=tf.int32), - tf.TensorSpec(shape=(), dtype=tf.float32), + tf.TensorSpec(shape=(640, None), dtype=tf.complex64), + tf.TensorSpec(shape=(2,), dtype=tf.int32), + tf.TensorSpec(shape=(320, 320), dtype=tf.float32), + tf.TensorSpec(shape=(2,), dtype=tf.int32), + tf.TensorSpec(shape=(), dtype=tf.float32), ) return tf.data.Dataset.from_generator( - _h5_to_examples, args=(filename,), output_signature=signature) + _h5_to_examples, args=(filename,), output_signature=signature + ) -def load_fastmri_split(global_batch_size, - split, - data_dir, - shuffle_rng, - num_batches, - repeat_final_eval_dataset): +def load_fastmri_split( + global_batch_size, + split, + data_dir, + shuffle_rng, + num_batches, + repeat_final_eval_dataset, +): """Creates a split from the FastMRI dataset using tf.data. NOTE: only creates knee singlecoil datasets. @@ -169,11 +182,13 @@ def load_fastmri_split(global_batch_size, # Check if data directories exist because glob will not raise an error if not os.path.exists(os.path.join(data_dir, _TRAIN_DIR)): - raise NotADirectoryError('Directory not found: {}'.format( - os.path.join(data_dir, _TRAIN_DIR))) + raise NotADirectoryError( + 'Directory not found: {}'.format(os.path.join(data_dir, _TRAIN_DIR)) + ) if not os.path.exists(os.path.join(data_dir, _VAL_DIR)): - raise NotADirectoryError('Directory not found: {}'.format( - os.path.join(data_dir, _VAL_DIR))) + raise NotADirectoryError( + 'Directory not found: {}'.format(os.path.join(data_dir, _VAL_DIR)) + ) if split in ['train', 'eval_train']: file_pattern = os.path.join(data_dir, _TRAIN_DIR, '*.h5') @@ -190,10 +205,8 @@ def load_fastmri_split(global_batch_size, shuffle = is_train or split == 'eval_train' ds = tf.data.Dataset.from_tensor_slices(h5_paths) ds = ds.interleave( - _create_generator, - cycle_length=32, - block_length=64, - num_parallel_calls=16) + _create_generator, cycle_length=32, block_length=64, num_parallel_calls=16 + ) if is_train: ds = ds.cache() @@ -201,7 +214,8 @@ def process_example(example_index, example): if shuffle: process_rng = tf.cast(jax.random.fold_in(shuffle_rng, 0), tf.int64) process_rng = tf.random.experimental.stateless_fold_in( - process_rng, example_index) + process_rng, example_index + ) else: # NOTE(dsuo): we use fixed randomness for eval. process_rng = tf.cast(jax.random.PRNGKey(_EVAL_SEED), tf.int64) @@ -211,9 +225,8 @@ def process_example(example_index, example): if shuffle: ds = ds.shuffle( - 16 * global_batch_size, - seed=shuffle_rng[0], - reshuffle_each_iteration=True) + 16 * global_batch_size, seed=shuffle_rng[0], reshuffle_each_iteration=True + ) if is_train: ds = ds.repeat() @@ -231,7 +244,8 @@ def process_example(example_index, example): ds = ds.repeat() ds = ds.prefetch(10) return map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 051749cc3..0b1ecfaa1 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -8,7 +8,6 @@ class BaseFastMRIWorkload(spec.Workload): - @property def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" @@ -61,8 +60,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -106,18 +106,22 @@ def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" return 18_094 - def _build_input_queue(self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None): + def _build_input_queue( + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ): del cache - return input_pipeline.load_fastmri_split(global_batch_size, - split, - data_dir, - data_rng, - num_batches, - repeat_final_dataset) + return input_pipeline.load_fastmri_split( + global_batch_size, + split, + data_dir, + data_rng, + num_batches, + repeat_final_dataset, + ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index 3d6939218..53368b384 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -1,5 +1,5 @@ """ -Note: +Note: The following code is adapted from: https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image @@ -12,35 +12,39 @@ import tensorflow as tf _IMAGE_DTYPES = { - tf.dtypes.uint8, - tf.dtypes.int32, - tf.dtypes.int64, - tf.dtypes.float16, - tf.dtypes.float32, - tf.dtypes.float64, + tf.dtypes.uint8, + tf.dtypes.int32, + tf.dtypes.int64, + tf.dtypes.float16, + tf.dtypes.float32, + tf.dtypes.float64, } -Number = Union[float, - int, - np.float16, - np.float32, - np.float64, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64,] - -TensorLike = Union[List[Union[Number, list]], - tuple, - Number, - np.ndarray, - tf.Tensor, - tf.SparseTensor, - tf.Variable,] +Number = Union[ + float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, +] + +TensorLike = Union[ + List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable, +] def get_ndims(image): @@ -50,16 +54,19 @@ def get_ndims(image): def to_4d_image(image): """Convert 2/3/4D image to 4D image. - Args: - image: 2/3/4D `Tensor`. + Args: + image: 2/3/4D `Tensor`. - Returns: - 4D `Tensor` with the same type. - """ - with tf.control_dependencies([ + Returns: + 4D `Tensor` with the same type. + """ + with tf.control_dependencies( + [ tf.debugging.assert_rank_in( - image, [2, 3, 4], message="`image` must be 2/3/4D tensor") - ]): + image, [2, 3, 4], message='`image` must be 2/3/4D tensor' + ) + ] + ): ndims = image.get_shape().ndims if ndims is None: return _dynamic_to_4d_image(image) @@ -80,12 +87,12 @@ def _dynamic_to_4d_image(image): left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) new_shape = tf.concat( - [ - tf.ones(shape=left_pad, dtype=tf.int32), - shape, - tf.ones(shape=right_pad, dtype=tf.int32), - ], - axis=0, + [ + tf.ones(shape=left_pad, dtype=tf.int32), + shape, + tf.ones(shape=right_pad, dtype=tf.int32), + ], + axis=0, ) return tf.reshape(image, new_shape) @@ -93,16 +100,16 @@ def _dynamic_to_4d_image(image): def from_4d_image(image, ndims): """Convert back to an image with `ndims` rank. - Args: - image: 4D `Tensor`. - ndims: The original rank of the image. + Args: + image: 4D `Tensor`. + ndims: The original rank of the image. - Returns: - `ndims`-D `Tensor` with the same type. - """ + Returns: + `ndims`-D `Tensor` with the same type. + """ with tf.control_dependencies( - [tf.debugging.assert_rank(image, 4, - message="`image` must be 4D tensor")]): + [tf.debugging.assert_rank(image, 4, message='`image` must be 4D tensor')] + ): if isinstance(ndims, tf.Tensor): return _dynamic_from_4d_image(image, ndims) elif ndims == 2: @@ -125,63 +132,64 @@ def _dynamic_from_4d_image(image, original_rank): def transform( - images: TensorLike, - transforms: TensorLike, - interpolation: str = "nearest", - fill_mode: str = "constant", - output_shape: Optional[list] = None, - name: Optional[str] = None, - fill_value: TensorLike = 0.0, + images: TensorLike, + transforms: TensorLike, + interpolation: str = 'nearest', + fill_mode: str = 'constant', + output_shape: Optional[list] = None, + name: Optional[str] = None, + fill_value: TensorLike = 0.0, ) -> tf.Tensor: """Applies the given transform(s) to the image(s). - Args: - images: A tensor of shape (num_images, num_rows, num_columns, - num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). - transforms: Projective transform matrix/matrices. A vector of length 8 or - tensor of size N x 8. If one row of transforms is - [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point - `(x, y)` to a transformed *input* point - `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, - where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to - the transform mapping input points to output points. Note that - gradients are not backpropagated into transformation parameters. - interpolation: Interpolation mode. - Supported values: "nearest", "bilinear". - fill_mode: Points outside the boundaries of the input are filled according - to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). - - *reflect*: `(d c b a | a b c d | d c b a)` - The input is extended by reflecting about the edge of the last pixel. - - *constant*: `(k k k k | a b c d | k k k k)` - The input is extended by filling all values beyond the edge with the - same constant value k = 0. - - *wrap*: `(a b c d | a b c d | a b c d)` - The input is extended by wrapping around to the opposite edge. - - *nearest*: `(a a a a | a b c d | d d d d)` - The input is extended by the nearest pixel. - fill_value: a float represents the value to be filled outside the - boundaries when `fill_mode` is "constant". - output_shape: Output dimesion after the transform, [height, width]. - If None, output is the same size as input image. - - name: The name of the op. - - Returns: - Image(s) with the same type and shape as `images`, with the given - transform(s) applied. Transformed coordinates outside of the input image - will be filled with zeros. - - Raises: - TypeError: If `image` is an invalid type. - ValueError: If output shape is not 1-D int32 Tensor. - """ - with tf.name_scope(name or "transform"): - image_or_images = tf.convert_to_tensor(images, name="images") + Args: + images: A tensor of shape (num_images, num_rows, num_columns, + num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). + transforms: Projective transform matrix/matrices. A vector of length 8 or + tensor of size N x 8. If one row of transforms is + [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point + `(x, y)` to a transformed *input* point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to + the transform mapping input points to output points. Note that + gradients are not backpropagated into transformation parameters. + interpolation: Interpolation mode. + Supported values: "nearest", "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + output_shape: Output dimesion after the transform, [height, width]. + If None, output is the same size as input image. + + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, with the given + transform(s) applied. Transformed coordinates outside of the input image + will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + ValueError: If output shape is not 1-D int32 Tensor. + """ + with tf.name_scope(name or 'transform'): + image_or_images = tf.convert_to_tensor(images, name='images') transform_or_transforms = tf.convert_to_tensor( - transforms, name="transforms", dtype=tf.dtypes.float32) + transforms, name='transforms', dtype=tf.dtypes.float32 + ) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) + raise TypeError('Invalid dtype %s.' % image_or_images.dtype) images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) @@ -189,61 +197,67 @@ def transform( output_shape = tf.shape(images)[1:3] output_shape = tf.convert_to_tensor( - output_shape, tf.dtypes.int32, name="output_shape") + output_shape, tf.dtypes.int32, name='output_shape' + ) if not output_shape.get_shape().is_compatible_with([2]): - raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " - "new_height, new_width") + raise ValueError( + 'output_shape must be a 1-D Tensor of 2 elements: new_height, new_width' + ) if len(transform_or_transforms.get_shape()) == 1: transforms = transform_or_transforms[None] elif transform_or_transforms.get_shape().ndims is None: - raise ValueError("transforms rank must be statically known") + raise ValueError('transforms rank must be statically known') elif len(transform_or_transforms.get_shape()) == 2: transforms = transform_or_transforms else: transforms = transform_or_transforms - raise ValueError("transforms should have rank 1 or 2, but got rank %d" % - len(transforms.get_shape())) + raise ValueError( + 'transforms should have rank 1 or 2, but got rank %d' + % len(transforms.get_shape()) + ) fill_value = tf.convert_to_tensor( - fill_value, dtype=tf.float32, name="fill_value") + fill_value, dtype=tf.float32, name='fill_value' + ) output = tf.raw_ops.ImageProjectiveTransformV3( - images=images, - transforms=transforms, - output_shape=output_shape, - interpolation=interpolation.upper(), - fill_mode=fill_mode.upper(), - fill_value=fill_value, + images=images, + transforms=transforms, + output_shape=output_shape, + interpolation=interpolation.upper(), + fill_mode=fill_mode.upper(), + fill_value=fill_value, ) return from_4d_image(output, original_ndims) def angles_to_projective_transforms( - angles: TensorLike, - image_height: TensorLike, - image_width: TensorLike, - name: Optional[str] = None, + angles: TensorLike, + image_height: TensorLike, + image_width: TensorLike, + name: Optional[str] = None, ) -> tf.Tensor: """Returns projective transform(s) for the given angle(s). - Args: - angles: A scalar angle to rotate all images by, or (for batches of - images) a vector with an angle to rotate each image in the batch. The - rank must be statically known (the shape is not `TensorShape(None)`. - image_height: Height of the image(s) to be transformed. - image_width: Width of the image(s) to be transformed. - - Returns: - A tensor of shape (num_images, 8). Projective transforms which can be - given to `transform` op. - """ - with tf.name_scope(name or "angles_to_projective_transforms"): + Args: + angles: A scalar angle to rotate all images by, or (for batches of + images) a vector with an angle to rotate each image in the batch. The + rank must be statically known (the shape is not `TensorShape(None)`. + image_height: Height of the image(s) to be transformed. + image_width: Width of the image(s) to be transformed. + + Returns: + A tensor of shape (num_images, 8). Projective transforms which can be + given to `transform` op. + """ + with tf.name_scope(name or 'angles_to_projective_transforms'): angle_or_angles = tf.convert_to_tensor( - angles, name="angles", dtype=tf.dtypes.float32) + angles, name='angles', dtype=tf.dtypes.float32 + ) if len(angle_or_angles.get_shape()) not in (0, 1): - raise ValueError("angles should have rank 0 or 1.") + raise ValueError('angles should have rank 0 or 1.') if len(angle_or_angles.get_shape()) == 0: angles = angle_or_angles[None] @@ -252,112 +266,116 @@ def angles_to_projective_transforms( cos_angles = tf.math.cos(angles) sin_angles = tf.math.sin(angles) - x_offset = ((image_width - 1) - - (cos_angles * (image_width - 1) - sin_angles * - (image_height - 1))) / 2.0 - y_offset = ((image_height - 1) - - (sin_angles * (image_width - 1) + cos_angles * - (image_height - 1))) / 2.0 + x_offset = ( + (image_width - 1) + - (cos_angles * (image_width - 1) - sin_angles * (image_height - 1)) + ) / 2.0 + y_offset = ( + (image_height - 1) + - (sin_angles * (image_width - 1) + cos_angles * (image_height - 1)) + ) / 2.0 num_angles = tf.shape(angles)[0] return tf.concat( - values=[ - cos_angles[:, None], - -sin_angles[:, None], - x_offset[:, None], - sin_angles[:, None], - cos_angles[:, None], - y_offset[:, None], - tf.zeros((num_angles, 2), tf.dtypes.float32), - ], - axis=1, + values=[ + cos_angles[:, None], + -sin_angles[:, None], + x_offset[:, None], + sin_angles[:, None], + cos_angles[:, None], + y_offset[:, None], + tf.zeros((num_angles, 2), tf.dtypes.float32), + ], + axis=1, ) def rotate_img( - images: TensorLike, - angles: TensorLike, - interpolation: str = "nearest", - fill_mode: str = "constant", - name: Optional[str] = None, - fill_value: TensorLike = 0.0, + images: TensorLike, + angles: TensorLike, + interpolation: str = 'nearest', + fill_mode: str = 'constant', + name: Optional[str] = None, + fill_value: TensorLike = 0.0, ) -> tf.Tensor: """Rotate image(s) counterclockwise by the passed angle(s) in radians. - Args: - images: A tensor of shape - `(num_images, num_rows, num_columns, num_channels)` - (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or - `(num_rows, num_columns)` (HW). - angles: A scalar angle to rotate all images by (if `images` has rank 4) - a vector of length num_images, with an angle for each image in the - batch. - interpolation: Interpolation mode. Supported values: "nearest", - "bilinear". - fill_mode: Points outside the boundaries of the input are filled according - to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). - - *reflect*: `(d c b a | a b c d | d c b a)` - The input is extended by reflecting about the edge of the last pixel. - - *constant*: `(k k k k | a b c d | k k k k)` - The input is extended by filling all values beyond the edge with the - same constant value k = 0. - - *wrap*: `(a b c d | a b c d | a b c d)` - The input is extended by wrapping around to the opposite edge. - - *nearest*: `(a a a a | a b c d | d d d d)` - The input is extended by the nearest pixel. - fill_value: a float represents the value to be filled outside the - boundaries when `fill_mode` is "constant". - name: The name of the op. - - Returns: - Image(s) with the same type and shape as `images`, rotated by the given - angle(s). Empty space due to the rotation will be filled with zeros. - - Raises: - TypeError: If `images` is an invalid type. - """ - with tf.name_scope(name or "rotate"): + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` + (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). + angles: A scalar angle to rotate all images by (if `images` has rank 4) + a vector of length num_images, with an angle for each image in the + batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, rotated by the given + angle(s). Empty space due to the rotation will be filled with zeros. + + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or 'rotate'): image_or_images = tf.convert_to_tensor(images) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) + raise TypeError('Invalid dtype %s.' % image_or_images.dtype) images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None] output = transform( - images, - angles_to_projective_transforms(angles, image_height, image_width), - interpolation=interpolation, - fill_mode=fill_mode, - fill_value=fill_value, + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, ) return from_4d_image(output, original_ndims) -def translations_to_projective_transforms(translations: TensorLike, - name: Optional[str] = None - ) -> tf.Tensor: +def translations_to_projective_transforms( + translations: TensorLike, name: Optional[str] = None +) -> tf.Tensor: """Returns projective transform(s) for the given translation(s). - Args: - translations: A 2-element list representing `[dx, dy]` or a matrix of - 2-element lists representing `[dx, dy]` to translate for each image - (for a batch of images). The rank must be statically known - (the shape is not `TensorShape(None)`). - name: The name of the op. - Returns: - A tensor of shape `(num_images, 8)` projective transforms which can be - given to `tfa.image.transform`. - """ - with tf.name_scope(name or "translations_to_projective_transforms"): + Args: + translations: A 2-element list representing `[dx, dy]` or a matrix of + 2-element lists representing `[dx, dy]` to translate for each image + (for a batch of images). The rank must be statically known + (the shape is not `TensorShape(None)`). + name: The name of the op. + Returns: + A tensor of shape `(num_images, 8)` projective transforms which can be + given to `tfa.image.transform`. + """ + with tf.name_scope(name or 'translations_to_projective_transforms'): translation_or_translations = tf.convert_to_tensor( - translations, name="translations", dtype=tf.dtypes.float32) + translations, name='translations', dtype=tf.dtypes.float32 + ) if translation_or_translations.get_shape().ndims is None: raise TypeError( - "translation_or_translations rank must be statically known") + 'translation_or_translations rank must be statically known' + ) if len(translation_or_translations.get_shape()) not in (1, 2): - raise TypeError("Translations should have rank 1 or 2.") + raise TypeError('Translations should have rank 1 or 2.') if len(translation_or_translations.get_shape()) == 1: translations = translation_or_translations[None] @@ -372,67 +390,67 @@ def translations_to_projective_transforms(translations: TensorLike, # where the last entry is implicit. # Translation matrices are always float32. return tf.concat( - values=[ - tf.ones((num_translations, 1), tf.dtypes.float32), - tf.zeros((num_translations, 1), tf.dtypes.float32), - -translations[:, 0, None], - tf.zeros((num_translations, 1), tf.dtypes.float32), - tf.ones((num_translations, 1), tf.dtypes.float32), - -translations[:, 1, None], - tf.zeros((num_translations, 2), tf.dtypes.float32), - ], - axis=1, + values=[ + tf.ones((num_translations, 1), tf.dtypes.float32), + tf.zeros((num_translations, 1), tf.dtypes.float32), + -translations[:, 0, None], + tf.zeros((num_translations, 1), tf.dtypes.float32), + tf.ones((num_translations, 1), tf.dtypes.float32), + -translations[:, 1, None], + tf.zeros((num_translations, 2), tf.dtypes.float32), + ], + axis=1, ) @tf.function def translate( - images: TensorLike, - translations: TensorLike, - interpolation: str = "nearest", - fill_mode: str = "constant", - name: Optional[str] = None, - fill_value: TensorLike = 0.0, + images: TensorLike, + translations: TensorLike, + interpolation: str = 'nearest', + fill_mode: str = 'constant', + name: Optional[str] = None, + fill_value: TensorLike = 0.0, ) -> tf.Tensor: """Translate image(s) by the passed vectors(s). - Args: - images: A tensor of shape - `(num_images, num_rows, num_columns, num_channels)` (NHWC), - `(num_rows, num_columns, num_channels)` (HWC), or - `(num_rows, num_columns)` (HW). The rank must be statically known (the - shape is not `TensorShape(None)`). - translations: A vector representing `[dx, dy]` or (if `images` has rank 4) - a matrix of length num_images, with a `[dx, dy]` vector for each image - in the batch. - interpolation: Interpolation mode. Supported values: "nearest", - "bilinear". - fill_mode: Points outside the boundaries of the input are filled according - to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). - - *reflect*: `(d c b a | a b c d | d c b a)` - The input is extended by reflecting about the edge of the last pixel. - - *constant*: `(k k k k | a b c d | k k k k)` - The input is extended by filling all values beyond the edge with the - same constant value k = 0. - - *wrap*: `(a b c d | a b c d | a b c d)` - The input is extended by wrapping around to the opposite edge. - - *nearest*: `(a a a a | a b c d | d d d d)` - The input is extended by the nearest pixel. - fill_value: a float represents the value to be filled outside the - boundaries when `fill_mode` is "constant". - name: The name of the op. - Returns: - Image(s) with the same type and shape as `images`, translated by the - given vector(s). Empty space due to the translation will be filled with - zeros. - Raises: - TypeError: If `images` is an invalid type. - """ - with tf.name_scope(name or "translate"): + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` (NHWC), + `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). The rank must be statically known (the + shape is not `TensorShape(None)`). + translations: A vector representing `[dx, dy]` or (if `images` has rank 4) + a matrix of length num_images, with a `[dx, dy]` vector for each image + in the batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + Returns: + Image(s) with the same type and shape as `images`, translated by the + given vector(s). Empty space due to the translation will be filled with + zeros. + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or 'translate'): return transform( - images, - translations_to_projective_transforms(translations), - interpolation=interpolation, - fill_mode=fill_mode, - fill_value=fill_value, + images, + translations_to_projective_transforms(translations), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 66105335b..fc42f4a5b 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -17,19 +17,21 @@ from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment TFDS_SPLIT_NAME = { - 'train': 'train', 'eval_train': 'train', 'validation': 'validation' + 'train': 'train', + 'eval_train': 'train', + 'validation': 'validation', } -def _distorted_bounding_box_crop(image_bytes: spec.Tensor, - rng: spec.RandomState, - bbox: spec.Tensor, - min_object_covered: float = 0.1, - aspect_ratio_range: Tuple[float, - float] = (0.75, - 1.33), - area_range: Tuple[float, float] = (0.05, 1.0), - max_attempts: int = 100) -> spec.Tensor: +def _distorted_bounding_box_crop( + image_bytes: spec.Tensor, + rng: spec.RandomState, + bbox: spec.Tensor, + min_object_covered: float = 0.1, + aspect_ratio_range: Tuple[float, float] = (0.75, 1.33), + area_range: Tuple[float, float] = (0.05, 1.0), + max_attempts: int = 100, +) -> spec.Tensor: """Generates cropped_image using one of the bboxes randomly distorted. See `tf.image.sample_distorted_bounding_box` for more documentation. @@ -57,14 +59,15 @@ def _distorted_bounding_box_crop(image_bytes: spec.Tensor, """ shape = tf.io.extract_jpeg_shape(image_bytes) bbox_begin, bbox_size, _ = tf.image.stateless_sample_distorted_bounding_box( - shape, - seed=rng, - bounding_boxes=bbox, - min_object_covered=min_object_covered, - aspect_ratio_range=aspect_ratio_range, - area_range=area_range, - max_attempts=max_attempts, - use_image_if_no_bounding_boxes=True) + shape, + seed=rng, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True, + ) # Crop the image to the specified bounding box. offset_y, offset_x, _ = tf.unstack(bbox_begin) @@ -84,8 +87,9 @@ def resize(image: spec.Tensor, image_size: int) -> spec.Tensor: Returns: Resized image 'Tensor'. """ - return tf.image.resize([image], [image_size, image_size], - method=tf.image.ResizeMethod.BICUBIC)[0] + return tf.image.resize( + [image], [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC + )[0] def _at_least_x_are_equal(a: spec.Tensor, b: spec.Tensor, x: float) -> bool: @@ -95,80 +99,93 @@ def _at_least_x_are_equal(a: spec.Tensor, b: spec.Tensor, x: float) -> bool: return tf.greater_equal(tf.reduce_sum(match), x) -def _decode_and_random_crop(image_bytes: spec.Tensor, - rng: spec.RandomState, - image_size: int, - aspect_ratio_range: Tuple[float, float], - area_range: Tuple[float, float], - resize_size: int) -> spec.Tensor: +def _decode_and_random_crop( + image_bytes: spec.Tensor, + rng: spec.RandomState, + image_size: int, + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + resize_size: int, +) -> spec.Tensor: """Make a random crop of image_size.""" bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) image = _distorted_bounding_box_crop( - image_bytes, - rng, - bbox, - min_object_covered=0.1, - aspect_ratio_range=aspect_ratio_range, - area_range=area_range, - max_attempts=10) + image_bytes, + rng, + bbox, + min_object_covered=0.1, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=10, + ) original_shape = tf.io.extract_jpeg_shape(image_bytes) bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) image = tf.cond( - bad, - lambda: _decode_and_center_crop(image_bytes, image_size, resize_size), - lambda: resize(image, image_size)) + bad, + lambda: _decode_and_center_crop(image_bytes, image_size, resize_size), + lambda: resize(image, image_size), + ) return image -def _decode_and_center_crop(image_bytes: spec.Tensor, - image_size: int, - resize_size: int) -> spec.Tensor: +def _decode_and_center_crop( + image_bytes: spec.Tensor, image_size: int, resize_size: int +) -> spec.Tensor: """Crops to center of image with padding then scales image_size.""" shape = tf.io.extract_jpeg_shape(image_bytes) image_height = shape[0] image_width = shape[1] padded_center_crop_size = tf.cast( - ((image_size / resize_size) * - tf.cast(tf.minimum(image_height, image_width), tf.float32)), - tf.int32) + ( + (image_size / resize_size) + * tf.cast(tf.minimum(image_height, image_width), tf.float32) + ), + tf.int32, + ) offset_height = ((image_height - padded_center_crop_size) + 1) // 2 offset_width = ((image_width - padded_center_crop_size) + 1) // 2 - crop_window = tf.stack([ + crop_window = tf.stack( + [ offset_height, offset_width, padded_center_crop_size, padded_center_crop_size, - ]) + ] + ) image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) image = resize(image, image_size) return image -def normalize_image(image: spec.Tensor, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float]) -> spec.Tensor: +def normalize_image( + image: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], +) -> spec.Tensor: image -= tf.constant(mean_rgb, shape=[1, 1, 3], dtype=image.dtype) image /= tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=image.dtype) return image -def preprocess_for_train(image_bytes: spec.Tensor, - rng: spec.RandomState, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - aspect_ratio_range: Tuple[float, float], - area_range: Tuple[float, float], - image_size: int, - resize_size: int, - dtype: tf.DType = tf.float32, - use_randaug: bool = False, - randaug_num_layers: int = 2, - randaug_magnitude: int = 10) -> spec.Tensor: +def preprocess_for_train( + image_bytes: spec.Tensor, + rng: spec.RandomState, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + image_size: int, + resize_size: int, + dtype: tf.DType = tf.float32, + use_randaug: bool = False, + randaug_num_layers: int = 2, + randaug_magnitude: int = 10, +) -> spec.Tensor: """Preprocesses the given image for training. Args: @@ -182,33 +199,36 @@ def preprocess_for_train(image_bytes: spec.Tensor, """ rngs = tf.random.experimental.stateless_split(rng, 3) - image = _decode_and_random_crop(image_bytes, - rngs[0], - image_size, - aspect_ratio_range, - area_range, - resize_size) + image = _decode_and_random_crop( + image_bytes, + rngs[0], + image_size, + aspect_ratio_range, + area_range, + resize_size, + ) image = tf.reshape(image, [image_size, image_size, 3]) image = tf.image.stateless_random_flip_left_right(image, seed=rngs[1]) if use_randaug: image = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8) - image = randaugment.distort_image_with_randaugment(image, - randaug_num_layers, - randaug_magnitude, - rngs[2]) + image = randaugment.distort_image_with_randaugment( + image, randaug_num_layers, randaug_magnitude, rngs[2] + ) image = tf.cast(image, tf.float32) image = normalize_image(image, mean_rgb, stddev_rgb) image = tf.image.convert_image_dtype(image, dtype=dtype) return image -def preprocess_for_eval(image_bytes: spec.Tensor, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - image_size: int, - resize_size: int, - dtype: tf.DType = tf.float32) -> spec.Tensor: +def preprocess_for_eval( + image_bytes: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + image_size: int, + resize_size: int, + dtype: tf.DType = tf.float32, +) -> spec.Tensor: """Preprocesses the given image for evaluation. Args: @@ -229,10 +249,12 @@ def preprocess_for_eval(image_bytes: spec.Tensor, # Modified from # github.com/google/init2winit/blob/master/init2winit/dataset_lib/ (cont. below) # image_preprocessing.py. -def mixup_tf(key: spec.RandomState, - inputs: spec.Tensor, - targets: spec.Tensor, - alpha: float = 0.2) -> Tuple[spec.Tensor, spec.Tensor]: +def mixup_tf( + key: spec.RandomState, + inputs: spec.Tensor, + targets: spec.Tensor, + alpha: float = 0.2, +) -> Tuple[spec.Tensor, spec.Tensor]: """Perform mixup https://arxiv.org/abs/1710.09412. NOTE: Code taken from https://github.com/google/big_vision with variables @@ -261,24 +283,26 @@ def mixup_tf(key: spec.RandomState, return inputs, targets -def create_split(split, - dataset_builder, - rng, - global_batch_size, - train, - image_size, - resize_size, - mean_rgb, - stddev_rgb, - cache=False, - repeat_final_dataset=False, - aspect_ratio_range=(0.75, 4.0 / 3.0), - area_range=(0.08, 1.0), - use_mixup=False, - mixup_alpha=0.1, - use_randaug=False, - randaug_num_layers=2, - randaug_magnitude=10) -> Iterator[Dict[str, spec.Tensor]]: +def create_split( + split, + dataset_builder, + rng, + global_batch_size, + train, + image_size, + resize_size, + mean_rgb, + stddev_rgb, + cache=False, + repeat_final_dataset=False, + aspect_ratio_range=(0.75, 4.0 / 3.0), + area_range=(0.08, 1.0), + use_mixup=False, + mixup_alpha=0.1, + use_randaug=False, + randaug_num_layers=2, + randaug_magnitude=10, +) -> Iterator[Dict[str, spec.Tensor]]: """Creates a split from the ImageNet dataset using TensorFlow Datasets.""" shuffle_rng, preprocess_rng, mixup_rng = jax.random.split(rng, 3) @@ -286,34 +310,35 @@ def decode_example(example_index, example): dtype = tf.float32 if train: per_step_preprocess_rng = tf.random.experimental.stateless_fold_in( - tf.cast(preprocess_rng, tf.int64), example_index) - - image = preprocess_for_train(example['image'], - per_step_preprocess_rng, - mean_rgb, - stddev_rgb, - aspect_ratio_range, - area_range, - image_size, - resize_size, - dtype, - use_randaug, - randaug_num_layers, - randaug_magnitude) + tf.cast(preprocess_rng, tf.int64), example_index + ) + + image = preprocess_for_train( + example['image'], + per_step_preprocess_rng, + mean_rgb, + stddev_rgb, + aspect_ratio_range, + area_range, + image_size, + resize_size, + dtype, + use_randaug, + randaug_num_layers, + randaug_magnitude, + ) else: - image = preprocess_for_eval(example['image'], - mean_rgb, - stddev_rgb, - image_size, - resize_size, - dtype) + image = preprocess_for_eval( + example['image'], mean_rgb, stddev_rgb, image_size, resize_size, dtype + ) return {'inputs': image, 'targets': example['label']} ds = dataset_builder.as_dataset( - split=TFDS_SPLIT_NAME[split], - decoders={ - 'image': tfds.decode.SkipDecoding(), - }) + split=TFDS_SPLIT_NAME[split], + decoders={ + 'image': tfds.decode.SkipDecoding(), + }, + ) options = tf.data.Options() options.threading.private_threadpool_size = 48 ds = ds.with_options(options) @@ -336,18 +361,21 @@ def decode_example(example_index, example): def mixup_batch(batch_index, batch): per_batch_mixup_rng = tf.random.experimental.stateless_fold_in( - mixup_rng, batch_index) + mixup_rng, batch_index + ) (inputs, targets) = mixup_tf( - per_batch_mixup_rng, - batch['inputs'], - batch['targets'], - alpha=mixup_alpha) + per_batch_mixup_rng, + batch['inputs'], + batch['targets'], + alpha=mixup_alpha, + ) batch['inputs'] = inputs batch['targets'] = targets return batch ds = ds.enumerate().map( - mixup_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE) + mixup_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) else: raise ValueError('Mixup can only be used for the training split.') @@ -359,44 +387,48 @@ def mixup_batch(batch_index, batch): return ds -def create_input_iter(split: str, - dataset_builder: tfds.core.dataset_builder.DatasetBuilder, - rng: spec.RandomState, - global_batch_size: int, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - image_size: int, - resize_size: int, - aspect_ratio_range: Tuple[float, float], - area_range: Tuple[float, float], - train: bool, - cache: bool, - repeat_final_dataset: bool, - use_mixup: bool, - mixup_alpha: float, - use_randaug: bool) -> Iterator[Dict[str, spec.Tensor]]: +def create_input_iter( + split: str, + dataset_builder: tfds.core.dataset_builder.DatasetBuilder, + rng: spec.RandomState, + global_batch_size: int, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + image_size: int, + resize_size: int, + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + train: bool, + cache: bool, + repeat_final_dataset: bool, + use_mixup: bool, + mixup_alpha: float, + use_randaug: bool, +) -> Iterator[Dict[str, spec.Tensor]]: ds = create_split( - split, - dataset_builder, - rng, - global_batch_size, - train=train, - image_size=image_size, - resize_size=resize_size, - mean_rgb=mean_rgb, - stddev_rgb=stddev_rgb, - cache=cache, - repeat_final_dataset=repeat_final_dataset, - aspect_ratio_range=aspect_ratio_range, - area_range=area_range, - use_mixup=use_mixup, - mixup_alpha=mixup_alpha, - use_randaug=use_randaug) + split, + dataset_builder, + rng, + global_batch_size, + train=train, + image_size=image_size, + resize_size=resize_size, + mean_rgb=mean_rgb, + stddev_rgb=stddev_rgb, + cache=cache, + repeat_final_dataset=repeat_final_dataset, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + use_mixup=use_mixup, + mixup_alpha=mixup_alpha, + use_randaug=use_randaug, + ) it = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. it = jax_utils.prefetch_to_device(it, 2) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index ffa60b260..1f3911708 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -17,12 +17,13 @@ class ResNetBlock(nn.Module): """ResNet block.""" + filters: int conv: ModuleDef norm: ModuleDef act: Callable strides: Tuple[int, int] = (1, 1) - bn_init_scale: float = 0. + bn_init_scale: float = 0.0 @nn.compact def __call__(self, x: spec.Tensor) -> spec.Tensor: @@ -35,8 +36,8 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor: if residual.shape != y.shape or self.strides != (1, 1): residual = self.conv( - self.filters, (1, 1), self.strides, name='Conv_proj')( - residual) + self.filters, (1, 1), self.strides, name='Conv_proj' + )(residual) residual = self.norm(name='BatchNorm_proj')(residual) return self.act(residual + y) @@ -44,6 +45,7 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor: class BottleneckResNetBlock(nn.Module): """Bottleneck ResNet block.""" + filters: int conv: ModuleDef norm: ModuleDef @@ -65,8 +67,8 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor: if residual.shape != y.shape or self.strides != (1, 1): residual = self.conv( - self.filters * 4, (1, 1), self.strides, name='Conv_proj')( - residual) + self.filters * 4, (1, 1), self.strides, name='Conv_proj' + )(residual) residual = self.norm(name='BatchNorm_proj')(residual) return self.act(residual + y) @@ -79,30 +81,35 @@ class ResNet(nn.Module): num_filters: int = 64 dtype: Any = jnp.float32 act: Callable = nn.relu - bn_init_scale: float = 0. + bn_init_scale: float = 0.0 @nn.compact - def __call__(self, - x: spec.Tensor, - update_batch_norm: bool = True, - use_running_average_bn: Optional[bool] = None) -> spec.Tensor: + def __call__( + self, + x: spec.Tensor, + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None, + ) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: use_running_average_bn = not update_batch_norm norm = functools.partial( - nn.BatchNorm, - use_running_average=use_running_average_bn, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype) + nn.BatchNorm, + use_running_average=use_running_average_bn, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype, + ) x = conv( - self.num_filters, (7, 7), (2, 2), - padding=[(3, 3), (3, 3)], - name='Conv_init')( - x) + self.num_filters, + (7, 7), + (2, 2), + padding=[(3, 3), (3, 3)], + name='Conv_init', + )(x) x = norm(name='BatchNorm_init')(x) x = self.act(x) x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1))) @@ -110,23 +117,23 @@ def __call__(self, for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = self.block_cls( - self.num_filters * 2**i, - strides=strides, - conv=conv, - norm=norm, - act=self.act, - bn_init_scale=self.bn_init_scale)( - x) + self.num_filters * 2**i, + strides=strides, + conv=conv, + norm=norm, + act=self.act, + bn_init_scale=self.bn_init_scale, + )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, - kernel_init=nn.initializers.normal(), - dtype=self.dtype)( - x) + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + )(x) return x ResNet18 = functools.partial( - ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock) + ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock +) ResNet50 = functools.partial( - ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock) + ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock +) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index c68e2de33..03b36e03d 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,16 +9,19 @@ import tensorflow as tf -from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ - rotate_img -from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ - transform -from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ - translate +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import ( + rotate_img, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import ( + transform, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import ( + translate, +) # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. -_MAX_LEVEL = 10. +_MAX_LEVEL = 10.0 def blend(image1, image2, factor): @@ -86,10 +89,12 @@ def cutout(image, pad_size, replace=0): # Sample the center location in the image where the zero mask will be applied. cutout_center_height = tf.random.uniform( - shape=[], minval=0, maxval=image_height, dtype=tf.int32) + shape=[], minval=0, maxval=image_height, dtype=tf.int32 + ) cutout_center_width = tf.random.uniform( - shape=[], minval=0, maxval=image_width, dtype=tf.int32) + shape=[], minval=0, maxval=image_width, dtype=tf.int32 + ) lower_pad = tf.maximum(0, cutout_center_height - pad_size) upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) @@ -97,20 +102,18 @@ def cutout(image, pad_size, replace=0): right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) cutout_shape = [ - image_height - (lower_pad + upper_pad), - image_width - (left_pad + right_pad), + image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad), ] padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] mask = tf.pad( - tf.zeros(cutout_shape, dtype=image.dtype), - padding_dims, - constant_values=1) + tf.zeros(cutout_shape, dtype=image.dtype), padding_dims, constant_values=1 + ) mask = tf.expand_dims(mask, -1) mask = tf.tile(mask, [1, 1, 3]) image = tf.where( - tf.equal(mask, 0), - tf.ones_like(image, dtype=image.dtype) * replace, - image) + tf.equal(mask, 0), tf.ones_like(image, dtype=image.dtype) * replace, image + ) return image @@ -204,7 +207,7 @@ def shear_x(image, level, replace): # with a matrix form of: # [1 level # 0 1]. - image = transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + image = transform(wrap(image), [1.0, level, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) return unwrap(image, replace) @@ -214,7 +217,7 @@ def shear_y(image, level, replace): # with a matrix form of: # [1 0 # level 1]. - image = transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + image = transform(wrap(image), [1.0, 0.0, 0.0, level, 1.0, 0.0, 0.0, 0.0]) return unwrap(image, replace) @@ -264,9 +267,12 @@ def sharpness(image, factor): # Make image 4D for conv operation. image = tf.expand_dims(image, 0) # SMOOTH PIL Kernel. - kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]], - dtype=tf.float32, - shape=[3, 3, 1, 1]) / 13. + kernel = ( + tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1] + ) + / 13.0 + ) # Tile across channel dimension. kernel = tf.tile(kernel, [1, 1, 3, 1]) strides = [1, 1, 1, 1] @@ -274,7 +280,8 @@ def sharpness(image, factor): # Some augmentation that uses depth-wise conv will cause crashing when # training on GPU. degenerate = tf.nn.depthwise_conv2d( - image, kernel, strides, padding='VALID', dilations=[1, 1]) + image, kernel, strides, padding='VALID', dilations=[1, 1] + ) degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) @@ -316,9 +323,10 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), - lambda: im, - lambda: tf.gather(build_lut(histo, step), im)) + tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im), + ) return tf.cast(result, tf.uint8) @@ -373,9 +381,10 @@ def unwrap(image, replace): # Where they are zero, fill them in with 'replace'. flattened_image = tf.where( - tf.equal(alpha_channel, 0), - tf.ones_like(flattened_image, dtype=image.dtype) * replace, - flattened_image) + tf.equal(alpha_channel, 0), + tf.ones_like(flattened_image, dtype=image.dtype) * replace, + flattened_image, + ) image = tf.reshape(flattened_image, image_shape) image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) @@ -383,22 +392,22 @@ def unwrap(image, replace): NAME_TO_FUNC = { - 'AutoContrast': autocontrast, - 'Equalize': equalize, - 'Invert': invert, - 'Rotate': rotate, - 'Posterize': posterize, - 'Solarize': solarize, - 'SolarizeAdd': solarize_add, - 'Color': color, - 'Contrast': contrast, - 'Brightness': brightness, - 'Sharpness': sharpness, - 'ShearX': shear_x, - 'ShearY': shear_y, - 'TranslateX': translate_x, - 'TranslateY': translate_y, - 'Cutout': cutout, + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x, + 'TranslateY': translate_y, + 'Cutout': cutout, } @@ -410,7 +419,7 @@ def _randomly_negate_tensor(tensor): def _rotate_level_to_arg(level): - level = (level / _MAX_LEVEL) * 30. + level = (level / _MAX_LEVEL) * 30.0 level = _randomly_negate_tensor(level) return (level,) @@ -435,47 +444,28 @@ def _translate_level_to_arg(level, translate_const): def level_to_arg(cutout_const, translate_const): return { - 'AutoContrast': - lambda level: (), - 'Equalize': - lambda level: (), - 'Invert': - lambda level: (), - 'Rotate': - _rotate_level_to_arg, - 'Posterize': - lambda level: (int((level / _MAX_LEVEL) * 4),), - 'Solarize': - lambda level: (int((level / _MAX_LEVEL) * 256),), - 'SolarizeAdd': - lambda level: (int((level / _MAX_LEVEL) * 110),), - 'Color': - _enhance_level_to_arg, - 'Contrast': - _enhance_level_to_arg, - 'Brightness': - _enhance_level_to_arg, - 'Sharpness': - _enhance_level_to_arg, - 'ShearX': - _shear_level_to_arg, - 'ShearY': - _shear_level_to_arg, - 'Cutout': - lambda level: (int((level / _MAX_LEVEL) * cutout_const),), - 'TranslateX': - lambda level: _translate_level_to_arg(level, translate_const), - 'TranslateY': - lambda level: _translate_level_to_arg(level, translate_const), + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Invert': lambda level: (), + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'Cutout': lambda level: (int((level / _MAX_LEVEL) * cutout_const),), + 'TranslateX': lambda level: _translate_level_to_arg(level, translate_const), + 'TranslateY': lambda level: _translate_level_to_arg(level, translate_const), } -def _parse_policy_info(name, - prob, - level, - replace_value, - cutout_const, - translate_const): +def _parse_policy_info( + name, prob, level, replace_value, cutout_const, translate_const +): """Return the function that corresponds to `name` and update `level` param.""" func = NAME_TO_FUNC[name] args = level_to_arg(cutout_const, translate_const)[name](level) @@ -514,45 +504,49 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): """ replace_value = [128] * 3 available_ops = [ - 'AutoContrast', - 'Equalize', - 'Invert', - 'Rotate', - 'Posterize', - 'Solarize', - 'Color', - 'Contrast', - 'Brightness', - 'Sharpness', - 'ShearX', - 'ShearY', - 'TranslateX', - 'TranslateY', - 'Cutout', - 'SolarizeAdd', + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'Posterize', + 'Solarize', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateX', + 'TranslateY', + 'Cutout', + 'SolarizeAdd', ] for layer_num in range(num_layers): key = tf.random.experimental.stateless_fold_in(key, layer_num) - op_to_select = tf.random.stateless_uniform([], - seed=key, - maxval=len(available_ops), - dtype=tf.int32) + op_to_select = tf.random.stateless_uniform( + [], seed=key, maxval=len(available_ops), dtype=tf.int32 + ) random_magnitude = float(magnitude) with tf.name_scope('randaug_layer_{}'.format(layer_num)): - for (i, op_name) in enumerate(available_ops): + for i, op_name in enumerate(available_ops): key = tf.random.experimental.stateless_fold_in(key, i) - prob = tf.random.stateless_uniform([], - seed=key, - minval=0.2, - maxval=0.8, - dtype=tf.float32) - func, _, args = _parse_policy_info(op_name, prob, random_magnitude, - replace_value, cutout_const=40, - translate_const=100) + prob = tf.random.stateless_uniform( + [], seed=key, minval=0.2, maxval=0.8, dtype=tf.float32 + ) + func, _, args = _parse_policy_info( + op_name, + prob, + random_magnitude, + replace_value, + cutout_const=40, + translate_const=100, + ) image = tf.cond( - tf.equal(i, op_to_select), - lambda selected_func=func, - selected_args=args: selected_func(image, *selected_args), - lambda: image) + tf.equal(i, op_to_select), + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args + ), + lambda: image, + ) return image diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 7896dcd05..c3035c212 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -9,70 +9,75 @@ import math from typing import Dict, Iterator, Optional, Tuple -from flax import jax_utils -from flax import linen as nn -from flax.core import pop import jax -from jax import lax import jax.numpy as jnp import optax import tensorflow_datasets as tfds +from flax import jax_utils +from flax import linen as nn +from flax.core import pop +from jax import lax -from algoperf import param_utils +from algoperf import param_utils, spec from algoperf import random_utils as prng -from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 -from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline -from algoperf.workloads.imagenet_resnet.imagenet_jax import models -from algoperf.workloads.imagenet_resnet.workload import \ - BaseImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax import ( + input_pipeline, + models, +) +from algoperf.workloads.imagenet_resnet.workload import ( + BaseImagenetResNetWorkload, +) class ImagenetResNetWorkload(BaseImagenetResNetWorkload): - def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - use_mixup: bool = False, - use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + use_mixup: bool = False, + use_randaug: bool = False, + ) -> Iterator[Dict[str, spec.Tensor]]: if split == 'test': np_iter = imagenet_v2.get_imagenet_v2_iter( - data_dir, - global_batch_size, - mean_rgb=self.train_mean, - stddev_rgb=self.train_stddev, - image_size=self.center_crop_size, - resize_size=self.resize_size) + data_dir, + global_batch_size, + mean_rgb=self.train_mean, + stddev_rgb=self.train_stddev, + image_size=self.center_crop_size, + resize_size=self.resize_size, + ) return itertools.cycle(np_iter) ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir) train = split == 'train' ds = input_pipeline.create_input_iter( - split, - ds_builder, - data_rng, - global_batch_size, - self.train_mean, - self.train_stddev, - self.center_crop_size, - self.resize_size, - self.aspect_ratio_range, - self.scale_ratio_range, - train=train, - cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset, - use_mixup=use_mixup, - mixup_alpha=0.2, - use_randaug=use_randaug) + split, + ds_builder, + data_rng, + global_batch_size, + self.train_mean, + self.train_stddev, + self.center_crop_size, + self.resize_size, + self.aspect_ratio_range, + self.scale_ratio_range, + train=train, + cache=not train if cache is None else cache, + repeat_final_dataset=repeat_final_dataset, + use_mixup=use_mixup, + mixup_alpha=0.2, + use_randaug=use_randaug, + ) return ds def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: + self, model_state: spec.ModelAuxiliaryState + ) -> spec.ModelAuxiliaryState: """Sync the batch statistics across replicas.""" # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics and @@ -83,8 +88,8 @@ def sync_batch_stats( return new_model_state def init_model_fn( - self, - rng: spec.RandomState, + self, + rng: spec.RandomState, ) -> spec.ModelInitState: model_cls = getattr(models, 'ResNet50') @@ -98,15 +103,17 @@ def init_model_fn( act_fnc = nn.relu model = model_cls( - num_classes=self._num_classes, - act=act_fnc, - bn_init_scale=self.bn_init_scale, - dtype=jnp.float32) + num_classes=self._num_classes, + act=act_fnc, + bn_init_scale=self.bn_init_scale, + dtype=jnp.float32, + ) self._model = model input_shape = (1, 224, 224, 3) - variables = jax.jit(model.init)({'params': rng}, - jnp.ones(input_shape, model.dtype)) - model_state, params = pop(variables, "params") + variables = jax.jit(model.init)( + {'params': rng}, jnp.ones(input_shape, model.dtype) + ) + model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) @@ -117,63 +124,70 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_model(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, 0), + static_broadcasted_argnums=(0,), + ) + def _eval_model( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[str, spec.Tensor]: logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng=rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng=rng, + update_batch_norm=False, + ) weights = batch.get('weights') return self._compute_metrics(logits, batch['targets'], weights) def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( - variables, - augmented_and_preprocessed_input_batch['inputs'], - update_batch_norm=update_batch_norm, - mutable=['batch_stats'], - use_running_average_bn=use_running_average_bn) + variables, + augmented_and_preprocessed_input_batch['inputs'], + update_batch_norm=update_batch_norm, + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn, + ) return logits, new_model_state else: logits = self._model.apply( - variables, - augmented_and_preprocessed_input_batch['inputs'], - update_batch_norm=update_batch_norm, - mutable=False, - use_running_average_bn=use_running_average_bn) + variables, + augmented_and_preprocessed_input_batch['inputs'], + update_batch_norm=update_batch_norm, + mutable=False, + use_running_average_bn=use_running_average_bn, + ) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -182,12 +196,14 @@ def loss_fn( """ if label_batch.shape[-1] != self._num_classes: one_hot_labels = jax.nn.one_hot( - label_batch, num_classes=self._num_classes) + label_batch, num_classes=self._num_classes + ) else: one_hot_labels = label_batch smoothed_labels = optax.smooth_labels(one_hot_labels, label_smoothing) per_example_losses = -jnp.sum( - smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1) + smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1 + ) # mask_batch is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -196,36 +212,37 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } - def _compute_metrics(self, - logits: spec.Tensor, - labels: spec.Tensor, - weights: spec.Tensor) -> Dict[str, spec.Tensor]: + def _compute_metrics( + self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor + ) -> Dict[str, spec.Tensor]: if weights is None: weights = jnp.ones(len(logits)) summed_loss = self.loss_fn(labels, logits, weights)['summed'] # not accuracy, but nr. of correct predictions accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, + 'loss': summed_loss, + 'accuracy': accuracy, } metrics = lax.psum(metrics, axis_name='batch') return metrics - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: del global_step if model_state is not None: # Sync batch statistics across replicas before evaluating. @@ -235,13 +252,14 @@ def _eval_model_on_split(self, # We already repeat the dataset indefinitely in tf.data. if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - data_rng, - split=split, - global_batch_size=global_batch_size, - data_dir=data_dir, - cache=True, - repeat_final_dataset=True, - num_batches=num_batches) + data_rng, + split=split, + global_batch_size=global_batch_size, + data_dir=data_dir, + cache=True, + repeat_final_dataset=True, + num_batches=num_batches, + ) eval_metrics = {} for bi in range(num_batches): @@ -249,22 +267,21 @@ def _eval_model_on_split(self, step_eval_rngs = prng.split(eval_rng, jax.local_device_count()) batch = next(self._eval_iters[split]) # We already average these metrics across devices inside _compute_metrics. - synced_metrics = self._eval_model(params, - batch, - model_state, - step_eval_rngs) + synced_metrics = self._eval_model( + params, batch, model_state, step_eval_rngs + ) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), - eval_metrics) + eval_metrics = jax.tree.map( + lambda x: float(x[0] / num_examples), eval_metrics + ) return eval_metrics class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload): - @property def use_silu(self) -> bool: return True @@ -279,7 +296,6 @@ def test_target_value(self) -> float: class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): - @property def use_gelu(self) -> bool: return True @@ -294,7 +310,6 @@ def test_target_value(self) -> float: class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): - @property def bn_init_scale(self) -> float: return 8.0 diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index aba9e671f..ab3fc4a37 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -15,44 +15,49 @@ from algoperf.init_utils import pytorch_default_init -def conv3x3(in_planes: int, - out_planes: int, - stride: int = 1, - groups: int = 1, - dilation: int = 1) -> nn.Conv2d: +def conv3x3( + in_planes: int, + out_planes: int, + stride: int = 1, + groups: int = 1, + dilation: int = 1, +) -> nn.Conv2d: """3x3 convolution with padding.""" return nn.Conv2d( - in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=dilation, - groups=groups, - bias=False, - dilation=dilation) + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution.""" return nn.Conv2d( - in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + in_planes, out_planes, kernel_size=1, stride=stride, bias=False + ) class BasicBlock(nn.Module): """ResNet block.""" + expansion: int = 1 def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - act_fnc: nn.Module = nn.ReLU(inplace=True) + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + act_fnc: nn.Module = nn.ReLU(inplace=True), ) -> None: super().__init__() if norm_layer is None: @@ -60,7 +65,7 @@ def __init__( if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + raise NotImplementedError('Dilation > 1 not supported in BasicBlock') # Both self.conv1 and self.downsample layers downsample # the input when stride != 1. self.conv1 = conv3x3(inplanes, planes, stride) @@ -92,24 +97,25 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class Bottleneck(nn.Module): """Bottleneck ResNet block.""" + expansion: int = 4 def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - act_fnc: nn.Module = nn.ReLU(inplace=True) + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + act_fnc: nn.Module = nn.ReLU(inplace=True), ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups + width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample # the input when stride != 1. self.conv1 = conv1x1(inplanes, width) @@ -146,18 +152,19 @@ def forward(self, x: Tensor) -> Tensor: class ResNet(nn.Module): - - def __init__(self, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - num_classes: int = 1000, - zero_init_residual: bool = True, - groups: int = 1, - width_per_group: int = 64, - replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - act_fnc: nn.Module = nn.ReLU(inplace=True), - bn_init_scale: float = 0.) -> None: + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = True, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + act_fnc: nn.Module = nn.ReLU(inplace=True), + bn_init_scale: float = 0.0, + ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -171,37 +178,42 @@ def __init__(self, replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError( - 'replace_stride_with_dilation should be None ' - f'or a 3-element tuple, got {replace_stride_with_dilation}') + 'replace_stride_with_dilation should be None ' + f'or a 3-element tuple, got {replace_stride_with_dilation}' + ) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) self.bn1 = norm_layer(self.inplanes) self.act_fnc = act_fnc self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, self.act_fnc, 64, layers[0]) self.layer2 = self._make_layer( - block, - self.act_fnc, - 128, - layers[1], - stride=2, - dilate=replace_stride_with_dilation[0]) + block, + self.act_fnc, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0], + ) self.layer3 = self._make_layer( - block, - self.act_fnc, - 256, - layers[2], - stride=2, - dilate=replace_stride_with_dilation[1]) + block, + self.act_fnc, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + ) self.layer4 = self._make_layer( - block, - self.act_fnc, - 512, - layers[3], - stride=2, - dilate=replace_stride_with_dilation[2]) + block, + self.act_fnc, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2], + ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) @@ -212,7 +224,7 @@ def __init__(self, nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) nn.init.normal_(self.fc.weight, std=1e-2) - nn.init.constant_(self.fc.bias, 0.) + nn.init.constant_(self.fc.bias, 0.0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, @@ -226,13 +238,15 @@ def __init__(self, elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, bn_init_scale) - def _make_layer(self, - block: Type[Union[BasicBlock, Bottleneck]], - act_fnc: nn.Module, - planes: int, - blocks: int, - stride: int = 1, - dilate: bool = False) -> nn.Sequential: + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + act_fnc: nn.Module, + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -241,34 +255,41 @@ def _make_layer(self, stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = torch.nn.Sequential( - collections.OrderedDict([ - ("conv", conv1x1(self.inplanes, planes * block.expansion, - stride)), - ("bn", norm_layer(planes * block.expansion)), - ])) + collections.OrderedDict( + [ + ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ('bn', norm_layer(planes * block.expansion)), + ] + ) + ) layers = [] layers.append( - block(self.inplanes, - planes, - stride, - downsample, - self.groups, - self.base_width, - previous_dilation, - norm_layer, - act_fnc)) + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + act_fnc, + ) + ) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( - block( - self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation, - norm_layer=norm_layer, - act_fnc=act_fnc)) + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + act_fnc=act_fnc, + ) + ) return nn.Sequential(*layers) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py index c7a98e77a..28ce00650 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py @@ -24,8 +24,8 @@ def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor: # Double the pad size to match Jax implementation. pad_size = pad_size * 2 - x0 = int(max(0, x0 - pad_size / 2.)) - y0 = int(max(0, y0 - pad_size / 2.)) + x0 = int(max(0, x0 - pad_size / 2.0)) + y0 = int(max(0, y0 - pad_size / 2.0)) x1 = int(min(image_width, x0 + pad_size)) y1 = int(min(image_height, y0 + pad_size)) xy = (x0, y0, x1, y1) @@ -36,7 +36,7 @@ def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor: def solarize(img: spec.Tensor, threshold: float) -> spec.Tensor: img = np.array(img) - new_img = np.where(img < threshold, img, 255. - img) + new_img = np.where(img < threshold, img, 255.0 - img) return PIL.Image.fromarray(new_img.astype(np.uint8)) @@ -49,54 +49,56 @@ def solarize_add(img: spec.Tensor, addition: int = 0) -> spec.Tensor: return PIL.Image.fromarray(new_img) -def _apply_op(img: spec.Tensor, - op_name: str, - magnitude: float, - interpolation: InterpolationMode, - fill: Optional[List[float]]) -> spec.Tensor: +def _apply_op( + img: spec.Tensor, + op_name: str, + magnitude: float, + interpolation: InterpolationMode, + fill: Optional[List[float]], +) -> spec.Tensor: if op_name == 'ShearX': # Magnitude should be arctan(magnitude). img = F.affine( - img, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[math.degrees(math.atan(magnitude)), 0.0], - interpolation=interpolation, - fill=fill, - center=[0, 0], + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(math.atan(magnitude)), 0.0], + interpolation=interpolation, + fill=fill, + center=[0, 0], ) elif op_name == 'ShearY': # Magnitude should be arctan(magnitude). img = F.affine( - img, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[0.0, math.degrees(math.atan(magnitude))], - interpolation=interpolation, - fill=fill, - center=[0, 0], + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(math.atan(magnitude))], + interpolation=interpolation, + fill=fill, + center=[0, 0], ) elif op_name == 'TranslateX': img = F.affine( - img, - angle=0.0, - translate=[int(magnitude), 0], - scale=1.0, - interpolation=interpolation, - shear=[0.0, 0.0], - fill=fill, + img, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill, ) elif op_name == 'TranslateY': img = F.affine( - img, - angle=0.0, - translate=[0, int(magnitude)], - scale=1.0, - interpolation=interpolation, - shear=[0.0, 0.0], - fill=fill, + img, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill, ) elif op_name == 'Rotate': img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) @@ -131,33 +133,32 @@ def _apply_op(img: spec.Tensor, def ops_space() -> Dict[str, Tuple[spec.Tensor, bool]]: return { - # op_name: (magnitudes, signed) - 'ShearX': (torch.tensor(0.3), True), - 'ShearY': (torch.tensor(0.3), True), - 'TranslateX': (torch.tensor(100), True), - 'TranslateY': (torch.tensor(100), True), - 'Rotate': (torch.tensor(30), True), - 'Brightness': (torch.tensor(1.9), False), - 'Color': (torch.tensor(1.9), False), - 'Contrast': (torch.tensor(1.9), False), - 'Sharpness': (torch.tensor(1.9), False), - 'Posterize': (torch.tensor(4), False), - 'Solarize': (torch.tensor(256), False), - 'SolarizeAdd': (torch.tensor(110), False), - 'AutoContrast': (torch.tensor(0.0), False), - 'Equalize': (torch.tensor(0.0), False), - 'Invert': (torch.tensor(0.0), False), - 'Cutout': (torch.tensor(40.0), False), + # op_name: (magnitudes, signed) + 'ShearX': (torch.tensor(0.3), True), + 'ShearY': (torch.tensor(0.3), True), + 'TranslateX': (torch.tensor(100), True), + 'TranslateY': (torch.tensor(100), True), + 'Rotate': (torch.tensor(30), True), + 'Brightness': (torch.tensor(1.9), False), + 'Color': (torch.tensor(1.9), False), + 'Contrast': (torch.tensor(1.9), False), + 'Sharpness': (torch.tensor(1.9), False), + 'Posterize': (torch.tensor(4), False), + 'Solarize': (torch.tensor(256), False), + 'SolarizeAdd': (torch.tensor(110), False), + 'AutoContrast': (torch.tensor(0.0), False), + 'Equalize': (torch.tensor(0.0), False), + 'Invert': (torch.tensor(0.0), False), + 'Cutout': (torch.tensor(40.0), False), } class RandAugment(torch.nn.Module): - def __init__( - self, - num_ops: int = 2, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None, + self, + num_ops: int = 2, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, ) -> None: super().__init__() self.num_ops = num_ops @@ -183,5 +184,6 @@ def forward(self, img: spec.Tensor) -> spec.Tensor: # With 50% prob turn the magnitude negative. magnitude *= -1.0 img = _apply_op( - img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + img, op_name, magnitude, interpolation=self.interpolation, fill=fill + ) return img diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 285ba3b4b..85a35dc45 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -16,22 +16,21 @@ from torchvision import transforms from torchvision.datasets.folder import ImageFolder -from algoperf import data_utils -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec import algoperf.random_utils as prng +from algoperf import data_utils, param_utils, pytorch_utils, spec from algoperf.workloads.imagenet_resnet import imagenet_v2 from algoperf.workloads.imagenet_resnet.imagenet_pytorch import randaugment from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50 -from algoperf.workloads.imagenet_resnet.workload import \ - BaseImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.workload import ( + BaseImagenetResNetWorkload, +) USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() def imagenet_v2_to_torch( - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + batch: Dict[str, spec.Tensor], +) -> Dict[str, spec.Tensor]: # Slice off the part of the batch for this device and then transpose from # [N, H, W, C] to [N, C, H, W]. Only transfer the inputs to GPU. new_batch = {} @@ -48,7 +47,6 @@ def imagenet_v2_to_torch( class ImagenetResNetWorkload(BaseImagenetResNetWorkload): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # Is set in submission_runner.py for workloads with PyTorch evaluation @@ -59,7 +57,8 @@ def __init__(self, *args, **kwargs) -> None: def eval_num_workers(self) -> int: if self._eval_num_workers is None: raise ValueError( - 'eval_num_workers property must be set before workload is used.') + 'eval_num_workers property must be set before workload is used.' + ) return self._eval_num_workers @eval_num_workers.setter @@ -67,60 +66,68 @@ def eval_num_workers(self, eval_num_workers: int): self._eval_num_workers = eval_num_workers def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - use_mixup: bool = False, - use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + use_mixup: bool = False, + use_randaug: bool = False, + ) -> Iterator[Dict[str, spec.Tensor]]: del cache del repeat_final_dataset if split == 'test': np_iter = imagenet_v2.get_imagenet_v2_iter( - data_dir, - global_batch_size, - mean_rgb=self.train_mean, - stddev_rgb=self.train_stddev, - image_size=self.center_crop_size, - resize_size=self.resize_size) + data_dir, + global_batch_size, + mean_rgb=self.train_mean, + stddev_rgb=self.train_stddev, + image_size=self.center_crop_size, + resize_size=self.resize_size, + ) return map(imagenet_v2_to_torch, itertools.cycle(np_iter)) is_train = split == 'train' normalize = transforms.Normalize( - mean=[i / 255. for i in self.train_mean], - std=[i / 255. for i in self.train_stddev]) + mean=[i / 255.0 for i in self.train_mean], + std=[i / 255.0 for i in self.train_stddev], + ) if is_train: transform_config = [ - transforms.RandomResizedCrop( - self.center_crop_size, - scale=self.scale_ratio_range, - ratio=self.aspect_ratio_range), - transforms.RandomHorizontalFlip(), + transforms.RandomResizedCrop( + self.center_crop_size, + scale=self.scale_ratio_range, + ratio=self.aspect_ratio_range, + ), + transforms.RandomHorizontalFlip(), ] if use_randaug: transform_config.append(randaugment.RandAugment()) transform_config.extend([transforms.ToTensor(), normalize]) transform_config = transforms.Compose(transform_config) else: - transform_config = transforms.Compose([ + transform_config = transforms.Compose( + [ transforms.Resize(self.resize_size), transforms.CenterCrop(self.center_crop_size), transforms.ToTensor(), normalize, - ]) + ] + ) folder = 'train' if 'train' in split else 'val' dataset = ImageFolder( - os.path.join(data_dir, folder), transform=transform_config) + os.path.join(data_dir, folder), transform=transform_config + ) if split == 'eval_train': indices = list(range(self.num_train_examples)) random.Random(int(data_rng[0])).shuffle(indices) - dataset = torch.utils.data.Subset(dataset, - indices[:self.num_eval_train_examples]) + dataset = torch.utils.data.Subset( + dataset, indices[: self.num_eval_train_examples] + ) sampler = None if USE_PYTORCH_DDP: @@ -131,26 +138,30 @@ def _build_dataset( if USE_PYTORCH_DDP: if is_train: sampler = torch.utils.data.distributed.DistributedSampler( - dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True + ) else: sampler = data_utils.DistributedEvalSampler( - dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False + ) dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=ds_iter_batch_size, - shuffle=not USE_PYTORCH_DDP and is_train, - sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, - pin_memory=True, - drop_last=is_train, - persistent_workers=is_train) + dataset, + batch_size=ds_iter_batch_size, + shuffle=not USE_PYTORCH_DDP and is_train, + sampler=sampler, + num_workers=4 if is_train else self.eval_num_workers, + pin_memory=True, + drop_last=is_train, + persistent_workers=is_train, + ) dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle( - dataloader, - custom_sampler=USE_PYTORCH_DDP, - use_mixup=use_mixup, - mixup_alpha=0.2) + dataloader, + custom_sampler=USE_PYTORCH_DDP, + use_mixup=use_mixup, + mixup_alpha=0.2, + ) return dataloader @@ -181,14 +192,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['fc.weight', 'fc.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = 0.0 + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -199,19 +210,22 @@ def model_fn( if mode == spec.ForwardPassMode.EVAL: if update_batch_norm: raise ValueError( - 'Batch norm statistics cannot be updated during evaluation.') + 'Batch norm statistics cannot be updated during evaluation.' + ) model.eval() if mode == spec.ForwardPassMode.TRAIN: model.train() model.apply( - functools.partial( - pytorch_utils.update_batch_norm_fn, - update_batch_norm=update_batch_norm)) + functools.partial( + pytorch_utils.update_batch_norm_fn, + update_batch_norm=update_batch_norm, + ) + ) contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): @@ -222,11 +236,12 @@ def model_fn( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -234,10 +249,11 @@ def loss_fn( (not synced across devices). """ per_example_losses = F.cross_entropy( - logits_batch, - label_batch, - reduction='none', - label_smoothing=label_smoothing) + logits_batch, + label_batch, + reduction='none', + label_smoothing=label_smoothing, + ) # `mask_batch` is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -246,15 +262,14 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } - def _compute_metrics(self, - logits: spec.Tensor, - labels: spec.Tensor, - weights: spec.Tensor) -> Dict[str, spec.Tensor]: + def _compute_metrics( + self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor + ) -> Dict[str, spec.Tensor]: """Return the mean accuracy and loss as a dict.""" if weights is None: weights = torch.ones(len(logits), device=DEVICE) @@ -264,15 +279,17 @@ def _compute_metrics(self, summed_loss = self.loss_fn(labels, logits, weights)['summed'] return {'accuracy': accuracy, 'loss': summed_loss} - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) @@ -280,31 +297,33 @@ def _eval_model_on_split(self, is_test = split == 'test' # These iterators repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - data_rng, - split=split, - global_batch_size=global_batch_size, - data_dir=data_dir, - cache=is_test, - repeat_final_dataset=is_test) + data_rng, + split=split, + global_batch_size=global_batch_size, + data_dir=data_dir, + cache=is_test, + repeat_final_dataset=is_test, + ) total_metrics = { - 'accuracy': torch.tensor(0., device=DEVICE), - 'loss': torch.tensor(0., device=DEVICE), + 'accuracy': torch.tensor(0.0, device=DEVICE), + 'loss': torch.tensor(0.0, device=DEVICE), } num_batches = int(math.ceil(num_examples / global_batch_size)) for _ in range(num_batches): batch = next(self._eval_iters[split]) logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - model_rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + model_rng, + update_batch_norm=False, + ) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } if USE_PYTORCH_DDP: for metric in total_metrics.values(): @@ -313,7 +332,6 @@ def _eval_model_on_split(self, class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload): - @property def use_silu(self) -> bool: return True @@ -328,7 +346,6 @@ def test_target_value(self) -> float: class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): - @property def use_gelu(self) -> bool: return True @@ -343,7 +360,6 @@ def test_target_value(self) -> float: class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): - @property def bn_init_scale(self) -> float: return 8.0 diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py index 84d364586..6ffb73367 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py @@ -13,32 +13,34 @@ from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline -def get_imagenet_v2_iter(data_dir: str, - global_batch_size: int, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - image_size: int, - resize_size: int) -> Iterator[Dict[str, spec.Tensor]]: +def get_imagenet_v2_iter( + data_dir: str, + global_batch_size: int, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + image_size: int, + resize_size: int, +) -> Iterator[Dict[str, spec.Tensor]]: """Always caches and repeats indefinitely.""" ds = tfds.load( - 'imagenet_v2/matched-frequency:3.0.0', - split='test', - data_dir=data_dir, - decoders={ - 'image': tfds.decode.SkipDecoding(), - }) + 'imagenet_v2/matched-frequency:3.0.0', + split='test', + data_dir=data_dir, + decoders={ + 'image': tfds.decode.SkipDecoding(), + }, + ) def _decode_example(example: Dict[str, float]) -> Dict[str, float]: - image = input_pipeline.preprocess_for_eval(example['image'], - mean_rgb, - stddev_rgb, - image_size, - resize_size) + image = input_pipeline.preprocess_for_eval( + example['image'], mean_rgb, stddev_rgb, image_size, resize_size + ) return {'inputs': image, 'targets': example['label']} ds = ds.map(_decode_example, num_parallel_calls=16) ds = ds.batch(global_batch_size) shard_pad_fn = functools.partial( - data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size) + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ) it = map(shard_pad_fn, iter(ds)) return it diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index 83fe97108..ef696e328 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -7,7 +7,6 @@ class BaseImagenetResNetWorkload(spec.Workload): - _num_classes: int = 1000 @property @@ -15,8 +14,9 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'accuracy' - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/accuracy'] > self.validation_target_value @property @@ -58,8 +58,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -109,38 +110,37 @@ def eval_period_time_sec(self) -> int: return 510 # 8.5 minutes. def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - use_mixup: bool = False, - use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + use_mixup: bool = False, + use_randaug: bool = False, + ) -> Iterator[Dict[str, spec.Tensor]]: raise NotImplementedError def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches if split == 'test': if not cache: raise ValueError('cache must be True for split=test.') if not repeat_final_dataset: raise ValueError('repeat_final_dataset must be True for split=test.') - return self._build_dataset(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset) + return self._build_dataset( + data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset + ) @property def step_hint(self) -> int: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index f33dea723..5e38acd8b 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -7,8 +7,8 @@ from typing import Optional, Sequence, Union -from flax import linen as nn import jax.numpy as jnp +from flax import linen as nn from algoperf import spec from algoperf.jax_utils import Dropout @@ -16,18 +16,20 @@ DROPOUT_RATE = 0.0 -def posemb_sincos_2d(h: int, - w: int, - width: int, - temperature: int = 10_000., - dtype: jnp.dtype = jnp.float32) -> spec.Tensor: +def posemb_sincos_2d( + h: int, + w: int, + width: int, + temperature: int = 10_000.0, + dtype: jnp.dtype = jnp.float32, +) -> spec.Tensor: """Follows the MoCo v3 logic.""" - y, x = jnp.mgrid[:h, :w] #pylint: disable=unpacking-non-sequence + y, x = jnp.mgrid[:h, :w] # pylint: disable=unpacking-non-sequence if width % 4 != 0: raise ValueError('Width must be mult of 4 for sincos posemb.') omega = jnp.arange(width // 4) / (width // 4 - 1) - omega = 1. / (temperature**omega) + omega = 1.0 / (temperature**omega) y = jnp.einsum('m,d->md', y.flatten(), omega) x = jnp.einsum('m,d->md', x.flatten(), omega) pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) @@ -36,19 +38,19 @@ def posemb_sincos_2d(h: int, class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim. use_glu: bool = False dropout_rate: float = DROPOUT_RATE @nn.compact - def __call__(self, - x: spec.Tensor, - train: bool = True, - dropout_rate=DROPOUT_RATE) -> spec.Tensor: + def __call__( + self, x: spec.Tensor, train: bool = True, dropout_rate=DROPOUT_RATE + ) -> spec.Tensor: """Applies Transformer MlpBlock module.""" inits = { - 'kernel_init': nn.initializers.xavier_uniform(), - 'bias_init': nn.initializers.normal(stddev=1e-6), + 'kernel_init': nn.initializers.xavier_uniform(), + 'bias_init': nn.initializers.normal(stddev=1e-6), } d = x.shape[2] @@ -66,6 +68,7 @@ def __call__(self, class Encoder1DBlock(nn.Module): """Single transformer encoder block (MHSA + MLP).""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 use_glu: bool = False @@ -73,47 +76,45 @@ class Encoder1DBlock(nn.Module): dropout_rate: float = 0.0 @nn.compact - def __call__(self, - x: spec.Tensor, - train: bool = True, - dropout_rate=dropout_rate) -> spec.Tensor: - + def __call__( + self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate + ) -> spec.Tensor: if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) y = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1', + )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y y = nn.LayerNorm(name='LayerNorm_2')(x) y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3')( - y, train, dropout_rate=dropout_rate) + mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3' + )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y else: y = x y = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1', + )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) y = x y = MlpBlock( - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - name='MlpBlock_3', - dropout_rate=dropout_rate)( - y, train, dropout_rate=dropout_rate) + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + name='MlpBlock_3', + dropout_rate=dropout_rate, + )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train)(rate=dropout_rate) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) @@ -131,27 +132,27 @@ class Encoder(nn.Module): use_post_layer_norm: bool = False @nn.compact - def __call__(self, - x: spec.Tensor, - train: bool = True, - dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: + def __call__( + self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE + ) -> spec.Tensor: # Input Encoder for lyr in range(self.depth): x = Encoder1DBlock( - name=f"encoderblock_{lyr}", - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, + name=f'encoderblock_{lyr}', + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, )(x, train=train, dropout_rate=dropout_rate) if not self.use_post_layer_norm: - return nn.LayerNorm(name="encoder_layernorm")(x) + return nn.LayerNorm(name='encoder_layernorm')(x) else: return x class MAPHead(nn.Module): """Multihead Attention Pooling.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 dropout_rate: float = 0.0 @@ -159,15 +160,16 @@ class MAPHead(nn.Module): @nn.compact def __call__(self, x, dropout_rate=DROPOUT_RATE): n, _, d = x.shape - probe = self.param('probe', - nn.initializers.xavier_uniform(), (1, 1, d), - x.dtype) + probe = self.param( + 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype + ) probe = jnp.tile(probe, [n, 1, 1]) x = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())(probe, x) + num_heads=self.num_heads, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(probe, x) y = nn.LayerNorm()(x) x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y) @@ -191,26 +193,23 @@ class ViT(nn.Module): use_post_layer_norm: bool = False use_map: bool = False - def get_posemb(self, - seqshape: tuple, - width: int, - dtype: jnp.dtype = jnp.float32) -> spec.Tensor: + def get_posemb( + self, seqshape: tuple, width: int, dtype: jnp.dtype = jnp.float32 + ) -> spec.Tensor: return posemb_sincos_2d(*seqshape, width, dtype=dtype) @nn.compact - def __call__(self, - x: spec.Tensor, - *, - train: bool = False, - dropout_rate=DROPOUT_RATE) -> spec.Tensor: + def __call__( + self, x: spec.Tensor, *, train: bool = False, dropout_rate=DROPOUT_RATE + ) -> spec.Tensor: # Patch extraction x = nn.Conv( - self.width, - self.patch_size, - strides=self.patch_size, - padding='VALID', - name='conv_patch_extract')( - x) + self.width, + self.patch_size, + strides=self.patch_size, + padding='VALID', + name='conv_patch_extract', + )(x) n, h, w, c = x.shape x = jnp.reshape(x, [n, h * w, c]) @@ -221,20 +220,20 @@ def __call__(self, x = Dropout(dropout_rate)(x, not train, rate=dropout_rate) x = Encoder( - depth=self.depth, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - name='Transformer', + depth=self.depth, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + name='Transformer', )(x, train=not train, dropout_rate=dropout_rate) if self.use_map: x = MAPHead( - num_heads=self.num_heads, - mlp_dim=self.mlp_dim, - dropout_rate=dropout_rate)( - x, dropout_rate=dropout_rate) + num_heads=self.num_heads, + mlp_dim=self.mlp_dim, + dropout_rate=dropout_rate, + )(x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 0e320b9b9..1637a2123 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -2,40 +2,44 @@ from typing import Dict, Optional, Tuple +import jax +import jax.numpy as jnp from flax import jax_utils from flax import linen as nn from flax.core import pop -import jax -import jax.numpy as jnp -from algoperf import param_utils -from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetWorkload +from algoperf import param_utils, spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload, +) from algoperf.workloads.imagenet_vit.imagenet_jax import models -from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload -from algoperf.workloads.imagenet_vit.workload import decode_variant +from algoperf.workloads.imagenet_vit.workload import ( + BaseImagenetVitWorkload, + decode_variant, +) # Make sure we inherit from the ViT base workload first. class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): - - def initialized(self, key: spec.RandomState, - model: nn.Module) -> spec.ModelInitState: + def initialized( + self, key: spec.RandomState, model: nn.Module + ) -> spec.ModelInitState: input_shape = (1, 224, 224, 3) params_rng, _ = jax.random.split(key) - variables = jax.jit(model.init)({'params': params_rng}, - jnp.ones(input_shape)) - model_state, params = pop(variables, "params") + variables = jax.jit(model.init)( + {'params': params_rng}, jnp.ones(input_shape) + ) + model_state, params = pop(variables, 'params') return params, model_state def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._model = models.ViT( - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16'), + ) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -47,49 +51,54 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'head' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm del use_running_average_bn train = mode == spec.ForwardPassMode.TRAIN - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train, - dropout_rate=dropout_rate) + logits = self._model.apply( + {'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate, + ) return logits, None - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: model_state = None - return super()._eval_model_on_split(split, - num_examples, - global_batch_size, - params, - model_state, - rng, - data_dir, - global_step) + return super()._eval_model_on_split( + split, + num_examples, + global_batch_size, + params, + model_state, + rng, + data_dir, + global_step, + ) class ImagenetVitGluWorkload(ImagenetVitWorkload): - @property def use_glu(self) -> bool: return True @@ -104,7 +113,6 @@ def test_target_value(self) -> float: class ImagenetVitPostLNWorkload(ImagenetVitWorkload): - @property def use_post_layer_norm(self) -> bool: return True @@ -119,7 +127,6 @@ def test_target_value(self) -> float: class ImagenetVitMapWorkload(ImagenetVitWorkload): - @property def use_map(self) -> bool: return True diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index cb503cd9f..fc2a3cd46 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -9,27 +9,29 @@ from typing import Any, Optional, Tuple, Union import torch -from torch import nn import torch.nn.functional as F +from torch import nn -from algoperf import init_utils -from algoperf import spec +from algoperf import init_utils, spec from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention DROPOUT_RATE = 0.0 -def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: +def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.0) -> spec.Tensor: """Follows the MoCo v3 logic.""" _, width, h, w = patches.shape device = patches.device - y, x = torch.meshgrid(torch.arange(h, device=device), - torch.arange(w, device=device), indexing='ij') + y, x = torch.meshgrid( + torch.arange(h, device=device), + torch.arange(w, device=device), + indexing='ij', + ) if width % 4 != 0: raise ValueError('Width must be mult of 4 for sincos posemb.') omega = torch.arange(width // 4, device=device) / (width // 4 - 1) - omega = 1. / (temperature**omega) + omega = 1.0 / (temperature**omega) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) @@ -40,10 +42,11 @@ class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" def __init__( - self, - width: int, - mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - use_glu: bool = False) -> None: + self, + width: int, + mlp_dim: Optional[int] = None, # Defaults to 4x input dim. + use_glu: bool = False, + ) -> None: super().__init__() self.width = width @@ -70,7 +73,6 @@ def reset_parameters(self) -> None: module.bias.data.normal_(std=1e-6) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - x = self.linear1(x) x = self.act_fnc(x) @@ -93,7 +95,8 @@ def __init__(self, width: int, num_heads: int = 8) -> None: self.num_heads = num_heads assert width % num_heads == 0, ( - 'Memory dimension must be divisible by number of heads.') + 'Memory dimension must be divisible by number of heads.' + ) self.head_dim = int(width / num_heads) self.all_head_dim = self.num_heads * self.head_dim @@ -109,7 +112,7 @@ def reset_parameters(self) -> None: if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight.data) if module.bias is not None: - nn.init.constant_(module.bias.data, 0.) + nn.init.constant_(module.bias.data, 0.0) def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim) @@ -117,7 +120,6 @@ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: return x.permute(0, 2, 1, 3) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - mixed_query_layer = self.query(x) key_layer = self.transpose_for_scores(self.key(x)) @@ -141,12 +143,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: class Encoder1DBlock(nn.Module): """Single transformer encoder block (MHSA + MLP).""" - def __init__(self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - use_glu: bool = False, - use_post_layer_norm: bool = False) -> None: + def __init__( + self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, + ) -> None: super().__init__() self.width = width @@ -159,10 +163,10 @@ def __init__(self, self.self_attention1 = SelfAttention(self.width, self.num_heads) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) self.mlp3 = MlpBlock( - width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu) + width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu + ) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: - if not self.use_post_layer_norm: y = self.layer_norm0(x) y = self.self_attention1(y, dropout_rate) @@ -191,13 +195,15 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation.""" - def __init__(self, - depth: int, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - use_glu: bool = False, - use_post_layer_norm: bool = False) -> None: + def __init__( + self, + depth: int, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, + ) -> None: super().__init__() self.depth = depth @@ -207,13 +213,18 @@ def __init__(self, self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm - self.net = nn.ModuleList([ - Encoder1DBlock(self.width, - self.mlp_dim, - self.num_heads, - self.use_glu, - self.use_post_layer_norm) for _ in range(depth) - ]) + self.net = nn.ModuleList( + [ + Encoder1DBlock( + self.width, + self.mlp_dim, + self.num_heads, + self.use_glu, + self.use_post_layer_norm, + ) + for _ in range(depth) + ] + ) if not self.use_post_layer_norm: self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) @@ -233,10 +244,9 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: class MAPHead(nn.Module): """Multihead Attention Pooling.""" - def __init__(self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12): + def __init__( + self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12 + ): super().__init__() self.width = width self.mlp_dim = mlp_dim @@ -246,7 +256,8 @@ def __init__(self, nn.init.xavier_uniform_(self.probe.data) self.mha = MultiheadAttention( - self.width, num_heads=self.num_heads, self_attn=False, bias=True) + self.width, num_heads=self.num_heads, self_attn=False, bias=True + ) self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) @@ -268,19 +279,20 @@ class ViT(nn.Module): channels: int = 3 def __init__( - self, - num_classes: int = 1000, - patch_size: Tuple[int, int] = (16, 16), - width: int = 768, - depth: int = 12, - mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - num_heads: int = 12, - rep_size: Union[int, bool] = True, - head_zeroinit: bool = True, - use_glu: bool = False, - use_post_layer_norm: bool = False, - use_map: bool = False, - dtype: Any = torch.float32) -> None: + self, + num_classes: int = 1000, + patch_size: Tuple[int, int] = (16, 16), + width: int = 768, + depth: int = 12, + mlp_dim: Optional[int] = None, # Defaults to 4x input dim. + num_heads: int = 12, + rep_size: Union[int, bool] = True, + head_zeroinit: bool = True, + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, + dtype: Any = torch.float32, + ) -> None: super().__init__() self.num_classes = num_classes @@ -301,19 +313,21 @@ def __init__( self.pre_logits = nn.Linear(self.width, rep_size) self.conv_patch_extract = nn.Conv2d( - self.channels, - self.width, - self.patch_size, - stride=self.patch_size, - padding='valid') + self.channels, + self.width, + self.patch_size, + stride=self.patch_size, + padding='valid', + ) self.encoder = Encoder( - depth=self.depth, - width=self.width, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm) + depth=self.depth, + width=self.width, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + ) if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) @@ -333,18 +347,17 @@ def reset_parameters(self) -> None: if self.num_classes: if self.head_zeroinit: - nn.init.constant_(self.head.weight.data, 0.) - nn.init.constant_(self.head.bias.data, 0.) + nn.init.constant_(self.head.weight.data, 0.0) + nn.init.constant_(self.head.bias.data, 0.0) else: init_utils.pytorch_default_init(self.head) def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward(self, - x: spec.Tensor, - dropout_rate: float = DROPOUT_RATE) -> spec.Tensor: - + def forward( + self, x: spec.Tensor, dropout_rate: float = DROPOUT_RATE + ) -> spec.Tensor: # Patch extraction. x = self.conv_patch_extract(x) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index d43e90e80..9c6faf70b 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -6,29 +6,30 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetWorkload +from algoperf import param_utils, pytorch_utils, spec +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetWorkload, +) from algoperf.workloads.imagenet_vit.imagenet_pytorch import models -from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload -from algoperf.workloads.imagenet_vit.workload import decode_variant +from algoperf.workloads.imagenet_vit.workload import ( + BaseImagenetVitWorkload, + decode_variant, +) USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() # Make sure we inherit from the ViT base workload first. class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): - def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = models.ViT( - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16'), + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -43,14 +44,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['head.weight', 'head.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -65,20 +66,20 @@ def model_fn( model.train() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logits_batch = model( - augmented_and_preprocessed_input_batch['inputs'], - dropout_rate=dropout_rate) + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate, + ) return logits_batch, None class ImagenetVitGluWorkload(ImagenetVitWorkload): - @property def use_glu(self) -> bool: return True @@ -93,7 +94,6 @@ def test_target_value(self) -> float: class ImagenetVitPostLNWorkload(ImagenetVitWorkload): - @property def use_post_layer_norm(self) -> bool: return True @@ -108,7 +108,6 @@ def test_target_value(self) -> float: class ImagenetVitMapWorkload(ImagenetVitWorkload): - @property def use_map(self) -> bool: return True diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index f249ddee8..2a0070ba4 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -3,8 +3,9 @@ from typing import Dict, Iterator, Optional from algoperf import spec -from algoperf.workloads.imagenet_resnet.workload import \ - BaseImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.workload import ( + BaseImagenetResNetWorkload, +) def decode_variant(variant: str) -> Dict[str, int]: @@ -12,46 +13,52 @@ def decode_variant(variant: str) -> Dict[str, int]: v, patch = variant.split('/') return { - # Reference: Table 2 of https://arxiv.org/abs/2106.04560. - 'width': { - 'Ti': 192, - 'S': 384, - 'M': 512, - 'B': 768, - 'L': 1024, - 'H': 1280, - 'g': 1408, - 'G': 1664, - }[v], - 'depth': { - 'Ti': 12, - 'S': 12, - 'M': 12, - 'B': 12, - 'L': 24, - 'H': 32, - 'g': 40, - 'G': 48, - }[v], - 'mlp_dim': { - 'Ti': 768, - 'S': 1536, - 'M': 2048, - 'B': 3072, - 'L': 4096, - 'H': 5120, - 'g': 6144, - 'G': 8192, - }[v], - 'num_heads': { - 'Ti': 3, 'S': 6, 'M': 8, 'B': 12, 'L': 16, 'H': 16, 'g': 16, 'G': 16 - }[v], - 'patch_size': (int(patch), int(patch)), + # Reference: Table 2 of https://arxiv.org/abs/2106.04560. + 'width': { + 'Ti': 192, + 'S': 384, + 'M': 512, + 'B': 768, + 'L': 1024, + 'H': 1280, + 'g': 1408, + 'G': 1664, + }[v], + 'depth': { + 'Ti': 12, + 'S': 12, + 'M': 12, + 'B': 12, + 'L': 24, + 'H': 32, + 'g': 40, + 'G': 48, + }[v], + 'mlp_dim': { + 'Ti': 768, + 'S': 1536, + 'M': 2048, + 'B': 3072, + 'L': 4096, + 'H': 5120, + 'g': 6144, + 'G': 8192, + }[v], + 'num_heads': { + 'Ti': 3, + 'S': 6, + 'M': 8, + 'B': 12, + 'L': 16, + 'H': 16, + 'g': 16, + 'G': 16, + }[v], + 'patch_size': (int(patch), int(patch)), } class BaseImagenetVitWorkload(BaseImagenetResNetWorkload): - @property def validation_target_value(self) -> float: return 1 - 0.22691 # 0.77309 @@ -88,25 +95,28 @@ def eval_period_time_sec(self) -> int: return 7 * 60 # 7 mins. def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - use_mixup: bool = False, - use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + use_mixup: bool = False, + use_randaug: bool = False, + ) -> Iterator[Dict[str, spec.Tensor]]: # We use mixup and Randaugment for ViT workloads. use_mixup = use_randaug = split == 'train' - return super()._build_dataset(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset, - use_mixup, - use_randaug) + return super()._build_dataset( + data_rng, + split, + data_dir, + global_batch_size, + cache, + repeat_final_dataset, + use_mixup, + use_randaug, + ) @property def step_hint(self) -> int: diff --git a/algoperf/workloads/librispeech_conformer/input_pipeline.py b/algoperf/workloads/librispeech_conformer/input_pipeline.py index 1310e7b59..570db07b3 100644 --- a/algoperf/workloads/librispeech_conformer/input_pipeline.py +++ b/algoperf/workloads/librispeech_conformer/input_pipeline.py @@ -10,7 +10,6 @@ class LibriSpeechDataset(torch.utils.data.Dataset): - def __init__(self, split, data_dir): super().__init__() self.data_dir = data_dir @@ -38,13 +37,14 @@ def __getitem__(self, index): audio_paddings = np.zeros_like(audio, dtype=np.float32) audio_paddings = np.pad( - audio_paddings, (0, 320000 - audio.shape[0]), constant_values=1.0) + audio_paddings, (0, 320000 - audio.shape[0]), constant_values=1.0 + ) audio = np.pad(audio, (0, 320000 - audio.shape[0]), constant_values=0.0) target_paddings = np.zeros_like(targets, dtype=np.float32) target_paddings = np.pad( - target_paddings, (0, 256 - target_paddings.shape[0]), - constant_values=1.0) + target_paddings, (0, 256 - target_paddings.shape[0]), constant_values=1.0 + ) targets = np.pad(targets, (0, 256 - targets.shape[0]), constant_values=0) audio = audio.astype(np.float32) audio_paddings = audio_paddings.astype(np.float32) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py index 9f45434d9..bd36b1bb9 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py @@ -21,174 +21,175 @@ _MEL_HIGH_FREQUENCY_Q = 1127.0 LIBRISPEECH_MEAN_VECTOR = [ - -7.6047816276550293, - -7.1206226348876953, - -6.8864245414733887, - -6.8705768585205078, - -6.9667720794677734, - -7.1084094047546387, - -6.9528026580810547, - -6.783994197845459, - -6.6195521354675293, - -6.4876265525817871, - -6.4120659828186035, - -6.394047737121582, - -6.4244871139526367, - -6.3993711471557617, - -6.5158271789550781, - -6.7137999534606934, - -6.8476877212524414, - -6.9885001182556152, - -6.9221386909484863, - -7.146148681640625, - -7.2040400505065918, - -7.0537552833557129, - -7.3140382766723633, - -7.1223249435424805, - -7.30251407623291, - -7.1212143898010254, - -7.2425732612609863, - -7.1730537414550781, - -7.0979413986206055, - -7.088747501373291, - -6.9849910736083984, - -6.8787732124328613, - -6.7602753639221191, - -6.6300945281982422, - -6.5145769119262695, - -6.4245057106018066, - -6.356513500213623, - -6.31787633895874, - -6.2660770416259766, - -6.2468328475952148, - -6.2821526527404785, - -6.1908388137817383, - -6.2484354972839355, - -6.1472640037536621, - -6.0924725532531738, - -6.0171003341674805, - -5.9250402450561523, - -5.8535833358764648, - -5.8209109306335449, - -5.8118929862976074, - -5.80783748626709, - -5.7714629173278809, - -5.7453732490539551, - -5.7705655097961426, - -5.7765641212463379, - -5.7831673622131348, - -5.7954087257385254, - -5.7994823455810547, - -5.8023476600646973, - -5.8047118186950684, - -5.8168182373046875, - -5.8844799995422363, - -5.9727106094360352, - -6.0444660186767578, - -6.1284866333007812, - -6.2257585525512695, - -6.3157496452331543, - -6.39061164855957, - -6.4928598403930664, - -6.5498456954956055, - -6.6054320335388184, - -6.6508378982543945, - -6.66917610168457, - -6.6726889610290527, - -6.684234619140625, - -6.6974577903747559, - -6.75471830368042, - -6.7949142456054688, - -6.8634209632873535, - -6.94186544418335 + -7.6047816276550293, + -7.1206226348876953, + -6.8864245414733887, + -6.8705768585205078, + -6.9667720794677734, + -7.1084094047546387, + -6.9528026580810547, + -6.783994197845459, + -6.6195521354675293, + -6.4876265525817871, + -6.4120659828186035, + -6.394047737121582, + -6.4244871139526367, + -6.3993711471557617, + -6.5158271789550781, + -6.7137999534606934, + -6.8476877212524414, + -6.9885001182556152, + -6.9221386909484863, + -7.146148681640625, + -7.2040400505065918, + -7.0537552833557129, + -7.3140382766723633, + -7.1223249435424805, + -7.30251407623291, + -7.1212143898010254, + -7.2425732612609863, + -7.1730537414550781, + -7.0979413986206055, + -7.088747501373291, + -6.9849910736083984, + -6.8787732124328613, + -6.7602753639221191, + -6.6300945281982422, + -6.5145769119262695, + -6.4245057106018066, + -6.356513500213623, + -6.31787633895874, + -6.2660770416259766, + -6.2468328475952148, + -6.2821526527404785, + -6.1908388137817383, + -6.2484354972839355, + -6.1472640037536621, + -6.0924725532531738, + -6.0171003341674805, + -5.9250402450561523, + -5.8535833358764648, + -5.8209109306335449, + -5.8118929862976074, + -5.80783748626709, + -5.7714629173278809, + -5.7453732490539551, + -5.7705655097961426, + -5.7765641212463379, + -5.7831673622131348, + -5.7954087257385254, + -5.7994823455810547, + -5.8023476600646973, + -5.8047118186950684, + -5.8168182373046875, + -5.8844799995422363, + -5.9727106094360352, + -6.0444660186767578, + -6.1284866333007812, + -6.2257585525512695, + -6.3157496452331543, + -6.39061164855957, + -6.4928598403930664, + -6.5498456954956055, + -6.6054320335388184, + -6.6508378982543945, + -6.66917610168457, + -6.6726889610290527, + -6.684234619140625, + -6.6974577903747559, + -6.75471830368042, + -6.7949142456054688, + -6.8634209632873535, + -6.94186544418335, ] LIBRISPEECH_STD_VECTOR = [ - 3.4353282451629639, - 3.5962932109832764, - 3.7012472152709961, - 3.7369205951690674, - 3.7535104751586914, - 3.693629264831543, - 3.6922497749328613, - 3.7641522884368896, - 3.8419716358184814, - 3.8999848365783691, - 3.9294240474700928, - 3.9317409992218018, - 3.9139585494995117, - 3.9031598567962646, - 3.8691999912261963, - 3.8155081272125244, - 3.7644970417022705, - 3.7099106311798096, - 3.6965086460113525, - 3.6003766059875488, - 3.5493226051330566, - 3.5465121269226074, - 3.45003604888916, - 3.4712812900543213, - 3.4084610939025879, - 3.4408135414123535, - 3.4104881286621094, - 3.4217638969421387, - 3.4312851428985596, - 3.4199209213256836, - 3.4305806159973145, - 3.4382665157318115, - 3.4580366611480713, - 3.4817991256713867, - 3.4958710670471191, - 3.5036792755126953, - 3.5047574043273926, - 3.4988734722137451, - 3.493056058883667, - 3.4822943210601807, - 3.459430456161499, - 3.4612770080566406, - 3.4559063911437988, - 3.4755423069000244, - 3.4971549510955811, - 3.5326557159423828, - 3.5705199241638184, - 3.5920312404632568, - 3.596907377243042, - 3.5913500785827637, - 3.5865931510925293, - 3.5826809406280518, - 3.5837743282318115, - 3.5895791053771973, - 3.5819313526153564, - 3.5837869644165039, - 3.5861184597015381, - 3.5889589786529541, - 3.592214822769165, - 3.5939455032348633, - 3.5856630802154541, - 3.5884113311767578, - 3.5921022891998291, - 3.5870490074157715, - 3.5806570053100586, - 3.5731067657470703, - 3.5617532730102539, - 3.54980731010437, - 3.5527374744415283, - 3.5475366115570068, - 3.5387849807739258, - 3.5256178379058838, - 3.5031836032867432, - 3.4922726154327393, - 3.4879646301269531, - 3.4725594520568848, - 3.4558389186859131, - 3.4351828098297119, - 3.4284293651580811, - 3.4299170970916748 + 3.4353282451629639, + 3.5962932109832764, + 3.7012472152709961, + 3.7369205951690674, + 3.7535104751586914, + 3.693629264831543, + 3.6922497749328613, + 3.7641522884368896, + 3.8419716358184814, + 3.8999848365783691, + 3.9294240474700928, + 3.9317409992218018, + 3.9139585494995117, + 3.9031598567962646, + 3.8691999912261963, + 3.8155081272125244, + 3.7644970417022705, + 3.7099106311798096, + 3.6965086460113525, + 3.6003766059875488, + 3.5493226051330566, + 3.5465121269226074, + 3.45003604888916, + 3.4712812900543213, + 3.4084610939025879, + 3.4408135414123535, + 3.4104881286621094, + 3.4217638969421387, + 3.4312851428985596, + 3.4199209213256836, + 3.4305806159973145, + 3.4382665157318115, + 3.4580366611480713, + 3.4817991256713867, + 3.4958710670471191, + 3.5036792755126953, + 3.5047574043273926, + 3.4988734722137451, + 3.493056058883667, + 3.4822943210601807, + 3.459430456161499, + 3.4612770080566406, + 3.4559063911437988, + 3.4755423069000244, + 3.4971549510955811, + 3.5326557159423828, + 3.5705199241638184, + 3.5920312404632568, + 3.596907377243042, + 3.5913500785827637, + 3.5865931510925293, + 3.5826809406280518, + 3.5837743282318115, + 3.5895791053771973, + 3.5819313526153564, + 3.5837869644165039, + 3.5861184597015381, + 3.5889589786529541, + 3.592214822769165, + 3.5939455032348633, + 3.5856630802154541, + 3.5884113311767578, + 3.5921022891998291, + 3.5870490074157715, + 3.5806570053100586, + 3.5731067657470703, + 3.5617532730102539, + 3.54980731010437, + 3.5527374744415283, + 3.5475366115570068, + 3.5387849807739258, + 3.5256178379058838, + 3.5031836032867432, + 3.4922726154327393, + 3.4879646301269531, + 3.4725594520568848, + 3.4558389186859131, + 3.4351828098297119, + 3.4284293651580811, + 3.4299170970916748, ] @struct.dataclass class LibrispeechPreprocessingConfig: """Config to hold all preprocessing options for librispeech dataset.""" + sample_rate: float = 16000.0 frame_size_ms: float = 25.0 frame_step_ms: float = 10.0 @@ -208,8 +209,9 @@ class LibrispeechPreprocessingConfig: def _hertz_to_mel(frequencies_hertz): """Convert hertz to mel.""" - return _MEL_HIGH_FREQUENCY_Q * jnp.log(1.0 + (frequencies_hertz / - _MEL_BREAK_FREQUENCY_HERTZ)) + return _MEL_HIGH_FREQUENCY_Q * jnp.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ) + ) def _pad_end_length(num_timesteps, frame_step, frame_size): @@ -221,11 +223,13 @@ def _pad_end_length(num_timesteps, frame_step, frame_size): return padded_length - num_timesteps -def frame(x, - frame_length: int, - frame_step: int, - pad_end: bool = False, - pad_value: Union[int, float] = 0.0): +def frame( + x, + frame_length: int, + frame_step: int, + pad_end: bool = False, + pad_value: Union[int, float] = 0.0, +): """Slides a window and extract values. This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with @@ -251,24 +255,31 @@ def frame(x, if pad_end: num_extends = _pad_end_length(num_timesteps, frame_step, frame_length) x = jnp.pad( - x, ((0, 0), (0, num_extends), (0, 0)), - 'constant', - constant_values=pad_value) + x, + ((0, 0), (0, num_extends), (0, 0)), + 'constant', + constant_values=pad_value, + ) flat_y = jax.lax.conv_general_dilated_patches( - x, (frame_length,), (frame_step,), - 'VALID', - dimension_numbers=('NTC', 'OIT', 'NTC')) + x, + (frame_length,), + (frame_step,), + 'VALID', + dimension_numbers=('NTC', 'OIT', 'NTC'), + ) ret = flat_y.reshape(flat_y.shape[:-1] + (num_channels, frame_length)) return ret.transpose((0, 1, 3, 2)) -def linear_to_mel_weight_matrix(num_mel_bins: int = 20, - num_spectrogram_bins: int = 129, - sample_rate: Union[int, float] = 8000, - lower_edge_hertz: Union[int, float] = 125.0, - upper_edge_hertz: Union[int, float] = 3800.0, - dtype: Any = jnp.float32): +def linear_to_mel_weight_matrix( + num_mel_bins: int = 20, + num_spectrogram_bins: int = 129, + sample_rate: Union[int, float] = 8000, + lower_edge_hertz: Union[int, float] = 125.0, + upper_edge_hertz: Union[int, float] = 3800.0, + dtype: Any = jnp.float32, +): r"""Jax-port of `tf.signal.linear_to_mel_weight_matrix`. Args: @@ -300,23 +311,29 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, if num_mel_bins <= 0: raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins) if lower_edge_hertz < 0.0: - raise ValueError('lower_edge_hertz must be non-negative. Got: %s' % - lower_edge_hertz) + raise ValueError( + 'lower_edge_hertz must be non-negative. Got: %s' % lower_edge_hertz + ) if lower_edge_hertz >= upper_edge_hertz: - raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' % - (lower_edge_hertz, upper_edge_hertz)) + raise ValueError( + 'lower_edge_hertz %.1f >= upper_edge_hertz %.1f' + % (lower_edge_hertz, upper_edge_hertz) + ) if sample_rate <= 0.0: raise ValueError('sample_rate must be positive. Got: %s' % sample_rate) if upper_edge_hertz > sample_rate / 2: - raise ValueError('upper_edge_hertz must not be larger than the Nyquist ' - 'frequency (sample_rate / 2). Got %s for sample_rate: %s' % - (upper_edge_hertz, sample_rate)) + raise ValueError( + 'upper_edge_hertz must not be larger than the Nyquist ' + 'frequency (sample_rate / 2). Got %s for sample_rate: %s' + % (upper_edge_hertz, sample_rate) + ) # HTK excludes the spectrogram DC bin. bands_to_zero = 1 nyquist_hertz = sample_rate / 2.0 linear_frequencies = jnp.linspace( - 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype)[bands_to_zero:] + 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype + )[bands_to_zero:] spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, jnp.newaxis] # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The @@ -324,10 +341,11 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into # num_mel_bins + 2 pieces. edges = jnp.linspace( - _hertz_to_mel(lower_edge_hertz), - _hertz_to_mel(upper_edge_hertz), - num_mel_bins + 2, - dtype=dtype) + _hertz_to_mel(lower_edge_hertz), + _hertz_to_mel(upper_edge_hertz), + num_mel_bins + 2, + dtype=dtype, + ) # Split the triples up and reshape them into [1, num_mel_bins] tensors. lower_edge_mel = edges[:-2][jnp.newaxis, :] @@ -337,9 +355,11 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, # Calculate lower and upper slopes for every spectrogram bin. # Line segments are linear in the mel domain, not Hertz. lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / ( - center_mel - lower_edge_mel) + center_mel - lower_edge_mel + ) upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / ( - upper_edge_mel - center_mel) + upper_edge_mel - center_mel + ) # Intersect the line segments with each other and zero. mel_weights_matrix = jnp.maximum(0.0, jnp.minimum(lower_slopes, upper_slopes)) @@ -366,23 +386,26 @@ def _hanning_greco(win_support, frame_size, dtype): """ if frame_size < win_support: raise ValueError( - 'Provided frame_size = {} is lower than win_support = {}'.format( - frame_size, win_support)) + 'Provided frame_size = {} is lower than win_support = {}'.format( + frame_size, win_support + ) + ) arg = jnp.pi * 2.0 / (win_support) - hann = 0.5 - (0.5 * jnp.cos(arg * - (jnp.arange(win_support, dtype=dtype) + 0.5))) + hann = 0.5 - ( + 0.5 * jnp.cos(arg * (jnp.arange(win_support, dtype=dtype) + 0.5)) + ) zero_size = frame_size - win_support return jnp.pad(hann, [(0, zero_size)]) def _next_pow_of_two(x: Union[int, float]) -> int: - return int(2**np.ceil(np.log2(x))) + return int(2 ** np.ceil(np.log2(x))) class SpectrogramFrontend(nn.Module): - """Layer to convert input audio signals from time domain to frequency domain. - """ + """Layer to convert input audio signals from time domain to frequency domain.""" + config: LibrispeechPreprocessingConfig = None input_scale_factor: float = 1.0 output_log: bool = False @@ -390,8 +413,9 @@ class SpectrogramFrontend(nn.Module): def setup(self) -> None: p = self.config self._frame_step = int(round(p.sample_rate * p.frame_step_ms / 1000.0)) - self._frame_size = int(round( - p.sample_rate * p.frame_size_ms / 1000.0)) + 1 # +1 for the preemph + self._frame_size = ( + int(round(p.sample_rate * p.frame_size_ms / 1000.0)) + 1 + ) # +1 for the preemph # TF-version has maximum of 512, but it's not always necessary self.fft_size = _next_pow_of_two(self._frame_size) @@ -421,32 +445,39 @@ def f(frame_size, dtype): def _apply_preemphasis(self, framed_signal): p = self.config if p.preemph_htk_flavor: - return jnp.concatenate([ - framed_signal[:, :, :1, :] * (1. - p.preemph), - (framed_signal[:, :, 1:-1, :] - - p.preemph * framed_signal[:, :, :-2, :]) - ], - axis=2) + return jnp.concatenate( + [ + framed_signal[:, :, :1, :] * (1.0 - p.preemph), + ( + framed_signal[:, :, 1:-1, :] + - p.preemph * framed_signal[:, :, :-2, :] + ), + ], + axis=2, + ) else: - return (framed_signal[:, :, 1:, :] - - p.preemph * framed_signal[:, :, :-1, :]) + return ( + framed_signal[:, :, 1:, :] - p.preemph * framed_signal[:, :, :-1, :] + ) def fprop_paddings(self, input_paddings): p = self.config if p.pad_end: - num_extends = _pad_end_length(input_paddings.shape[1], - self._frame_step, - self._frame_size) + num_extends = _pad_end_length( + input_paddings.shape[1], self._frame_step, self._frame_size + ) input_paddings = jnp.pad( - input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0) + input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0 + ) return jax.lax.reduce_window( - input_paddings, - init_value=1.0, - computation=jax.lax.min, - window_dimensions=[1, self._frame_size], - window_strides=[1, self._frame_step], - padding='valid') + input_paddings, + init_value=1.0, + computation=jax.lax.min, + window_dimensions=[1, self._frame_size], + window_strides=[1, self._frame_step], + padding='valid', + ) def next_prng_key(self, name='dropout'): return self.make_rng(name) @@ -469,7 +500,8 @@ def __call__(self, inputs, input_paddings): pcm_audio_chunk = inputs.astype(jnp.float32) * self.input_scale_factor framed_signal = frame( - pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end) + pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end + ) if p.preemph != 0.0: preemphasized = self._apply_preemphasis(framed_signal) @@ -477,8 +509,10 @@ def __call__(self, inputs, input_paddings): preemphasized = framed_signal[..., :-1, :] if p.noise_scale > 0.0: - noise_signal = jax.random.normal(self.next_prng_key(), - preemphasized.shape) * p.noise_scale + noise_signal = ( + jax.random.normal(self.next_prng_key(), preemphasized.shape) + * p.noise_scale + ) else: noise_signal = jnp.zeros(preemphasized.shape) @@ -501,8 +535,8 @@ def __call__(self, inputs, input_paddings): class MelFilterbankFrontend(nn.Module): - """Layer to compute log mel spectograms from input audio signals. - """ + """Layer to compute log mel spectograms from input audio signals.""" + config: LibrispeechPreprocessingConfig = None use_divide_stream: bool = True per_bin_mean: Optional[float] = None @@ -513,7 +547,8 @@ def setup(self): input_scale_factor = 2**-15 if self.use_divide_stream else 1.0 self.stft = SpectrogramFrontend( - p, input_scale_factor=input_scale_factor, output_log=False) + p, input_scale_factor=input_scale_factor, output_log=False + ) if self.per_bin_mean is None: per_bin_mean = [0.0] * p.num_bins @@ -526,9 +561,11 @@ def setup(self): per_bin_stddev = self.per_bin_stddev self._normalizer_mean = jnp.array(per_bin_mean)[ - jnp.newaxis, jnp.newaxis, :, jnp.newaxis] + jnp.newaxis, jnp.newaxis, :, jnp.newaxis + ] self._normalizer_stddev = jnp.array(per_bin_stddev)[ - jnp.newaxis, jnp.newaxis, :, jnp.newaxis] + jnp.newaxis, jnp.newaxis, :, jnp.newaxis + ] @nn.compact def __call__(self, inputs, input_paddings): @@ -537,18 +574,21 @@ def __call__(self, inputs, input_paddings): spect, spect_paddings = self.stft(inputs, input_paddings) mel_weights = linear_to_mel_weight_matrix( - num_mel_bins=p.num_bins, - num_spectrogram_bins=spect.shape[2], - sample_rate=p.sample_rate, - lower_edge_hertz=p.lower_edge_hertz, - upper_edge_hertz=p.upper_edge_hertz) + num_mel_bins=p.num_bins, + num_spectrogram_bins=spect.shape[2], + sample_rate=p.sample_rate, + lower_edge_hertz=p.lower_edge_hertz, + upper_edge_hertz=p.upper_edge_hertz, + ) mel_spectrogram = jnp.einsum('fn,btfc->btnc', mel_weights, spect) logmel_spectrogram = jnp.log(jnp.maximum(mel_spectrogram, p.output_floor)) normalized_logmel_spectrogram = ( - (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev) + logmel_spectrogram - self._normalizer_mean + ) / self._normalizer_stddev - normalized_logmel_spectrogram = jnp.squeeze(normalized_logmel_spectrogram, - -1) + normalized_logmel_spectrogram = jnp.squeeze( + normalized_logmel_spectrogram, -1 + ) return normalized_logmel_spectrogram, spect_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index b2eee1c37..9fc2e39ef 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -16,17 +16,19 @@ import math from typing import Any, List, Optional -from flax import linen as nn -from flax import struct import jax import jax.numpy as jnp import numpy as np +from flax import linen as nn +from flax import struct from algoperf.jax_utils import Dropout -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - librispeech_preprocessor as preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - spectrum_augmenter +from algoperf.workloads.librispeech_conformer.librispeech_jax import ( + librispeech_preprocessor as preprocessor, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax import ( + spectrum_augmenter, +) DROPOUT_RATE = 0.1 @@ -34,6 +36,7 @@ @struct.dataclass class ConformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 dtype: Any = jnp.float32 encoder_dim: int = 512 @@ -66,6 +69,7 @@ class LayerNorm(nn.Module): zeros, this differs from default flax implementation of multiplying by scale and initializing to ones. """ + dim: int = 0 epsilon: float = 1e-6 @@ -79,7 +83,7 @@ def __call__(self, inputs): var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True) normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) - normed_inputs *= (1 + self.scale) + normed_inputs *= 1 + self.scale normed_inputs += self.bias return normed_inputs @@ -92,6 +96,7 @@ class Subsample(nn.Module): encoder_dim: model dimension of conformer. input_dropout_rate: dropout rate for inputs. """ + encoder_dim: int = 0 @nn.compact @@ -100,30 +105,32 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): outputs = jnp.expand_dims(inputs, axis=-1) outputs, output_paddings = Conv2dSubsampling( - input_channels=1, output_channels=self.encoder_dim)( - outputs, output_paddings) + input_channels=1, output_channels=self.encoder_dim + )(outputs, output_paddings) outputs, output_paddings = Conv2dSubsampling( - input_channels=self.encoder_dim, - output_channels=self.encoder_dim)(outputs, output_paddings) + input_channels=self.encoder_dim, output_channels=self.encoder_dim + )(outputs, output_paddings) batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape outputs = jnp.reshape( - outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) + outputs, (batch_size, subsampled_lengths, subsampled_dims * channels) + ) outputs = nn.Dense( - self.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) + self.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)( - seq_length=outputs.shape[1]) + seq_length=outputs.shape[1] + ) - outputs = Dropout( - rate=dropout_rate, deterministic=not train)( - outputs, rate=dropout_rate) + outputs = Dropout(rate=dropout_rate, deterministic=not train)( + outputs, rate=dropout_rate + ) return outputs, output_paddings @@ -135,6 +142,7 @@ class Conv2dSubsampling(nn.Module): 2) Also performs strided convolution over input_paddings to return the correct paddings for downstream layers. """ + input_channels: int = 0 output_channels: int = 0 filter_stride: List[int] = (2, 2) @@ -142,24 +150,26 @@ class Conv2dSubsampling(nn.Module): def setup(self): self.filter_shape = (3, 3, self.input_channels, self.output_channels) - self.kernel = self.param('kernel', - nn.initializers.xavier_uniform(), - self.filter_shape) + self.kernel = self.param( + 'kernel', nn.initializers.xavier_uniform(), self.filter_shape + ) self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels + ) @nn.compact def __call__(self, inputs, paddings): # Computing strided convolution to subsample inputs. feature_group_count = inputs.shape[3] // self.filter_shape[2] outputs = jax.lax.conv_general_dilated( - lhs=inputs, - rhs=self.kernel, - window_strides=self.filter_stride, - padding=self.padding, - rhs_dilation=(1, 1), - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - feature_group_count=feature_group_count) + lhs=inputs, + rhs=self.kernel, + window_strides=self.filter_stride, + padding=self.padding, + rhs_dilation=(1, 1), + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + feature_group_count=feature_group_count, + ) outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) outputs = nn.relu(outputs) @@ -170,58 +180,61 @@ def __call__(self, inputs, paddings): pad_len = (input_length + stride - 1) // stride * stride - input_length out_padding = jax.lax.conv_general_dilated( - lhs=paddings[:, :, None], - rhs=jnp.ones([1, 1, 1]), - window_strides=self.filter_stride[:1], - padding=[(0, pad_len)], - dimension_numbers=('NHC', 'HIO', 'NHC')) + lhs=paddings[:, :, None], + rhs=jnp.ones([1, 1, 1]), + window_strides=self.filter_stride[:1], + padding=[(0, pad_len)], + dimension_numbers=('NHC', 'HIO', 'NHC'), + ) out_padding = jnp.squeeze(out_padding, axis=-1) # Mask outputs by correct paddings to ensure padded elements in inputs map # to padded value in outputs. - outputs = outputs * \ - (1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) + outputs = outputs * ( + 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1) + ) return outputs, out_padding class FeedForwardModule(nn.Module): - """Feedforward block of conformer layer. - """ + """Feedforward block of conformer layer.""" + config: ConformerConfig @nn.compact - def __call__(self, - inputs, - padding_mask=None, - train=False, - dropout_rate=DROPOUT_RATE): + def __call__( + self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE + ): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) inputs = nn.Dense( - config.encoder_dim * config.feed_forward_expansion_factor, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) + config.encoder_dim * config.feed_forward_expansion_factor, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': activation_fn = nn.gelu else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{config.activation_function_name}') + raise ValueError( + 'Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{config.activation_function_name}' + ) inputs = activation_fn(inputs) inputs = Dropout(rate=dropout_rate)( - inputs, deterministic=not train, rate=dropout_rate) + inputs, deterministic=not train, rate=dropout_rate + ) inputs = inputs * padding_mask inputs = nn.Dense( - config.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) + config.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) inputs = inputs * padding_mask inputs = Dropout(rate=dropout_rate)(inputs, deterministic=not train) @@ -236,6 +249,7 @@ class AddPositionalEmbedding(nn.Module): max_len: maximum possible length for the input posemb_init: positional embedding initializer """ + min_timescale: int = 1 max_timescale: int = 10_000 embedding_dim: int = 512 @@ -244,21 +258,23 @@ class AddPositionalEmbedding(nn.Module): def __call__(self, seq_length): position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] num_timescales = self.embedding_dim // 2 - log_timescale_increment = ( - math.log(float(self.max_timescale) / float(self.min_timescale)) / - jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)) + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale) + ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1) inv_timescales = self.min_timescale * jnp.exp( - jnp.arange(num_timescales, dtype=jnp.float32) * - -log_timescale_increment) + jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment + ) scaled_time = ( - position[:, :, jnp.newaxis] * - inv_timescales[jnp.newaxis, jnp.newaxis, :]) - signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], - axis=2).astype(jnp.float32) + position[:, :, jnp.newaxis] * inv_timescales[jnp.newaxis, jnp.newaxis, :] + ) + signal = jnp.concatenate( + [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=2 + ).astype(jnp.float32) # Force usage of `np` rather than `jnp` to compute static values at trace # time. - signal = jnp.pad(signal, - [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]]) + signal = jnp.pad( + signal, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]] + ) return signal @@ -266,6 +282,7 @@ def __call__(self, seq_length): # https://github.com/tensorflow/lingvo/blob/7de4ca8fff3cb28c2ecb21bbd7b02a964ce727f7/lingvo/jax/layers/attentions.py#L201 class QueryScaler(nn.Module): """A layer to scale individual dims of the query attention matrix.""" + dim: int = 0 def setup(self): @@ -275,8 +292,10 @@ def setup(self): def __call__(self, inputs): inputs_shape = inputs.shape if inputs_shape[-1] != self.dim: - raise ValueError('QueryScaler expects inputs to have' - ' same last dimension as scaling param.') + raise ValueError( + 'QueryScaler expects inputs to have' + ' same last dimension as scaling param.' + ) # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we # can avoid unnecessary XLA op fusion mess on TPU. @@ -291,18 +310,20 @@ def __call__(self, inputs): # Modifying flax linen default dot product attention function to add # query scaling, reference to original function here : # https://github.com/google/flax/blob/a9af38085a7a49b571cf37d375060fd683e74972/flax/linen/attention.py#L121 -def dot_product_attention(query, - key, - value, - bias=None, - mask=None, - broadcast_dropout=True, - dropout_rng=None, - dropout_rate=0., - deterministic=False, - dtype=jnp.float32, - precision=None, - temperature=1.0): +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + broadcast_dropout=True, + dropout_rng=None, + dropout_rate=0.0, + deterministic=False, + dtype=jnp.float32, + precision=None, + temperature=1.0, +): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -341,29 +362,35 @@ def dot_product_attention(query, """ assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') + 'q, k, v batch dims must match.' + ) assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( - 'q, k, v num_heads must match.') + 'q, k, v num_heads must match.' + ) assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' # compute attention weights query = QueryScaler(dim=query.shape[-1])(query) attn_weights = nn.attention.dot_product_attention_weights( - query, - key, - bias, - mask, - broadcast_dropout, - dropout_rng, - dropout_rate, - deterministic, - dtype, - precision) + query, + key, + bias, + mask, + broadcast_dropout, + dropout_rng, + dropout_rate, + deterministic, + dtype, + precision, + ) # return weighted sum over values for each query position - return jnp.einsum( - '...hqk,...khd->...qhd', attn_weights, value, - precision=precision) * temperature + return ( + jnp.einsum( + '...hqk,...khd->...qhd', attn_weights, value, precision=precision + ) + * temperature + ) class MultiHeadedSelfAttention(nn.Module): @@ -375,6 +402,7 @@ class MultiHeadedSelfAttention(nn.Module): Note: this attention implementation uses a learned scale parameter to scale query matrix before passing it to flax attention module. """ + config: ConformerConfig = None @nn.compact @@ -383,28 +411,30 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE): mask_paddings = 1 - paddings attention_mask = nn.make_attention_mask( - mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) + mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32 + ) inputs = LayerNorm(dim=config.encoder_dim)(inputs) attention_fn = functools.partial( - dot_product_attention, temperature=config.attention_temperature) + dot_product_attention, temperature=config.attention_temperature + ) result = nn.MultiHeadDotProductAttention( - num_heads=config.num_attention_heads, - qkv_features=config.encoder_dim, - decode=False, - dtype=config.dtype, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros, - use_bias=True, - broadcast_dropout=False, - attention_fn=attention_fn, - dropout_rate=dropout_rate, - deterministic=not train)( - inputs_q=inputs, mask=attention_mask) - - result = Dropout( - rate=dropout_rate, deterministic=not train)( - result, rate=dropout_rate) + num_heads=config.num_attention_heads, + qkv_features=config.encoder_dim, + decode=False, + dtype=config.dtype, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + use_bias=True, + broadcast_dropout=False, + attention_fn=attention_fn, + dropout_rate=dropout_rate, + deterministic=not train, + )(inputs_q=inputs, mask=attention_mask) + + result = Dropout(rate=dropout_rate, deterministic=not train)( + result, rate=dropout_rate + ) return result @@ -421,30 +451,27 @@ class BatchNorm(nn.Module): and the corresponding defaults for momentum and epsilon have been copied over from lingvo. """ + config: ConformerConfig def setup(self): dim = self.config.encoder_dim dtype = self.config.dtype - self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), - dim) - self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), - dim) + self.ra_mean = self.variable( + 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim + ) + self.ra_var = self.variable( + 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim + ) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, - inputs, - input_paddings, - update_batch_norm, - use_running_average_bn): + def __call__( + self, inputs, input_paddings, update_batch_norm, use_running_average_bn + ): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -461,23 +488,25 @@ def __call__(self, mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( - jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) + jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True + ) count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v sum_vv = jnp.sum( - (inputs - mean) * (inputs - mean) * mask, - axis=reduce_over_dims, - keepdims=True) + (inputs - mean) * (inputs - mean) * mask, + axis=reduce_over_dims, + keepdims=True, + ) var = sum_vv / count_v if update_batch_norm: - self.ra_mean.value = momentum * \ - self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * \ - self.ra_var.value + (1 - momentum) * var + self.ra_mean.value = ( + momentum * self.ra_mean.value + (1 - momentum) * mean + ) + self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) bn_output = (inputs - mean) * inv + self.beta @@ -506,64 +535,68 @@ class ConvolutionBlock(nn.Module): | output """ + config: ConformerConfig @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm, - use_running_average_bn, - dropout_rate=DROPOUT_RATE): + def __call__( + self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average_bn, + dropout_rate=DROPOUT_RATE, + ): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) input_gated1 = nn.Dense( - config.encoder_dim, - kernel_init=nn.initializers.xavier_uniform(), - use_bias=True)( - inputs) + config.encoder_dim, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=True, + )(inputs) input_gated2 = nn.Dense( - config.encoder_dim, - kernel_init=nn.initializers.xavier_uniform(), - use_bias=True)( - inputs) + config.encoder_dim, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=True, + )(inputs) inputs = input_gated1 * jax.nn.sigmoid(input_gated2) inputs = inputs * (1 - jnp.expand_dims(input_paddings, -1)) inputs = nn.Conv( - features=config.encoder_dim, - kernel_size=(config.convolution_kernel_size,), - strides=(1,), - padding='SAME', - feature_group_count=config.encoder_dim, - use_bias=False, - kernel_init=nn.initializers.xavier_uniform())( - inputs) - - inputs = BatchNorm(config)(inputs, - input_paddings, - update_batch_norm, - use_running_average_bn) + features=config.encoder_dim, + kernel_size=(config.convolution_kernel_size,), + strides=(1,), + padding='SAME', + feature_group_count=config.encoder_dim, + use_bias=False, + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) + + inputs = BatchNorm(config)( + inputs, input_paddings, update_batch_norm, use_running_average_bn + ) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': activation_fn = nn.gelu else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{config.activation_function_name}') + raise ValueError( + 'Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{config.activation_function_name}' + ) inputs = activation_fn(inputs) inputs = nn.Dense( - config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())( - inputs) + config.encoder_dim, kernel_init=nn.initializers.xavier_uniform() + )(inputs) - inputs = Dropout( - rate=dropout_rate, deterministic=not train)( - inputs, rate=dropout_rate) + inputs = Dropout(rate=dropout_rate, deterministic=not train)( + inputs, rate=dropout_rate + ) return inputs @@ -580,36 +613,42 @@ class ConformerBlock(nn.Module): y = layer_norm(x) """ + config: ConformerConfig @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm, - use_running_average, - dropout_rate=DROPOUT_RATE): + def __call__( + self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average, + dropout_rate=DROPOUT_RATE, + ): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train, dropout_rate) + inputs, padding_mask, train, dropout_rate + ) inputs = inputs + MultiHeadedSelfAttention(config=self.config)( - inputs, input_paddings, train, dropout_rate=dropout_rate) - - inputs = inputs + \ - ConvolutionBlock(config)(inputs, - input_paddings, - train, - update_batch_norm, - use_running_average, - dropout_rate - ) + inputs, input_paddings, train, dropout_rate=dropout_rate + ) + + inputs = inputs + ConvolutionBlock(config)( + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average, + dropout_rate, + ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train, dropout_rate) + inputs, padding_mask, train, dropout_rate + ) if config.use_post_layer_norm: inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -624,27 +663,30 @@ class Conformer(nn.Module): for each time step. The output is then fed into a CTC loss which eliminates the need for alignment with targets. """ + config: ConformerConfig def setup(self): self.specaug = spectrum_augmenter.SpecAug( - freq_mask_count=self.config.freq_mask_count, - freq_mask_max_bins=self.config.freq_mask_max_bins, - time_mask_count=self.config.time_mask_count, - time_mask_max_frames=self.config.time_mask_max_frames, - time_mask_max_ratio=self.config.time_mask_max_ratio, - time_masks_per_frame=self.config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=self.config - .use_dynamic_time_mask_max_frames) + freq_mask_count=self.config.freq_mask_count, + freq_mask_max_bins=self.config.freq_mask_max_bins, + time_mask_count=self.config.time_mask_count, + time_mask_max_frames=self.config.time_mask_max_frames, + time_mask_max_ratio=self.config.time_mask_max_ratio, + time_masks_per_frame=self.config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=self.config.use_dynamic_time_mask_max_frames, + ) @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm: Optional[bool] = None, - use_running_average_bn: Optional[bool] = None, - dropout_rate: float = DROPOUT_RATE): + def __call__( + self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None, + dropout_rate: float = DROPOUT_RATE, + ): config = self.config outputs = inputs @@ -661,8 +703,8 @@ def __call__(self, outputs, output_paddings = preprocessor.MelFilterbankFrontend( preprocessing_config, per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)( - outputs, output_paddings) + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + )(outputs, output_paddings) # Ablate random parts of input along temporal and frequency dimension # following the specaug procedure in https://arxiv.org/abs/1904.08779. @@ -670,24 +712,26 @@ def __call__(self, outputs, output_paddings = self.specaug(outputs, output_paddings) outputs, output_paddings = Subsample( - encoder_dim=config.encoder_dim,)( - outputs, output_paddings, train, dropout_rate=dropout_rate) + encoder_dim=config.encoder_dim, + )(outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, - output_paddings, - train, - update_batch_norm, - use_running_average_bn, - dropout_rate) + outputs = ConformerBlock(config)( + outputs, + output_paddings, + train, + update_batch_norm, + use_running_average_bn, + dropout_rate, + ) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. outputs = nn.Dense( - config.vocab_size, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) + config.vocab_size, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index 2a6f73d4d..d9c1e301b 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -17,6 +17,7 @@ class SpecAug(nn.Module): This is an essential component in speech recognition models that helps achieve better word error rates. """ + freq_mask_count: int = 2 freq_mask_max_bins: int = 27 time_mask_count: int = 10 @@ -28,26 +29,30 @@ class SpecAug(nn.Module): def next_prng_key(self, name='dropout'): return self.make_rng(name) - def _get_mask(self, - batch_size, - choose_range, - mask_size, - max_length=None, - masks_per_frame=0.0, - multiplicity=1, - max_ratio=1.0): + def _get_mask( + self, + batch_size, + choose_range, + mask_size, + max_length=None, + masks_per_frame=0.0, + multiplicity=1, + max_ratio=1.0, + ): # Sample lengths for multiple masks. if max_length and max_length > 0: max_length = jnp.tile(max_length, (batch_size,)) else: max_length = choose_range * max_ratio masked_portion = jax.random.uniform( - key=self.next_prng_key(), - shape=(batch_size, multiplicity), - minval=0.0, - maxval=1.0) - masked_frame_size = jnp.einsum('b,bm->bm', max_length, - masked_portion).astype(jnp.int32) + key=self.next_prng_key(), + shape=(batch_size, multiplicity), + minval=0.0, + maxval=1.0, + ) + masked_frame_size = jnp.einsum( + 'b,bm->bm', max_length, masked_portion + ).astype(jnp.int32) # Make sure the sampled length was smaller than max_ratio * length_bound. # Note that sampling in this way was biased # (shorter sequence may over-masked.) @@ -57,7 +62,8 @@ def _get_mask(self, # Choose starting point. random_start = jax.random.uniform( - key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0) + key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0 + ) start_with_in_valid_range = random_start * (choose_range - length + 1) start = start_with_in_valid_range.astype(jnp.int32) @@ -78,11 +84,13 @@ def _get_mask(self, # Sum masks with appropriate multiplicity. if masks_per_frame > 0: multiplicity_weights = jnp.tile( - jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), - [batch_size, 1]) + jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), + [batch_size, 1], + ) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights < - multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = ( + multiplicity_weights < multiplicity_tensor + ).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) @@ -98,8 +106,9 @@ def _time_mask(self, inputs, length): max_ratio = self.time_mask_max_ratio # If maximum mask length is zero, do nothing. - if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or - max_ratio <= 0.0): + if ( + time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames + ) or max_ratio <= 0.0: return inputs if multiplicity == 0: return inputs @@ -111,13 +120,14 @@ def _time_mask(self, inputs, length): time_mask_max_frames = None # Create masks in time direction and apply. block_arrays = self._get_mask( - batch_size, - choose_range=length, - mask_size=time_length, - max_length=time_mask_max_frames, - masks_per_frame=self.time_masks_per_frame, - multiplicity=multiplicity, - max_ratio=max_ratio) + batch_size, + choose_range=length, + mask_size=time_length, + max_length=time_mask_max_frames, + masks_per_frame=self.time_masks_per_frame, + multiplicity=multiplicity, + max_ratio=max_ratio, + ) outputs = jnp.einsum('bxy,bx->bxy', inputs, block_arrays) return outputs @@ -136,13 +146,14 @@ def _frequency_mask(self, inputs): choose_range = jnp.tile(num_freq, (batch_size,)) # Create masks in frequency direction and apply. block_arrays = self._get_mask( - batch_size, - choose_range=choose_range, - mask_size=num_freq, - max_length=freq_mask_max_bins, - masks_per_frame=0.0, - multiplicity=multiplicity, - max_ratio=1.0) + batch_size, + choose_range=choose_range, + mask_size=num_freq, + max_length=freq_mask_max_bins, + masks_per_frame=0.0, + multiplicity=multiplicity, + max_ratio=1.0, + ) outputs = jnp.einsum('bxy,by->bxy', inputs, block_arrays) return outputs diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 1e1a1d3f8..819e57a69 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -2,31 +2,28 @@ import math from typing import Dict, Iterator, Optional, Tuple -from flax import jax_utils -from flax.core import pop import flax.linen as nn import jax -from jax import lax import jax.numpy as jnp import numpy as np import optax import torch +from flax import jax_utils +from flax.core import pop +from jax import lax -from algoperf import data_utils -from algoperf import param_utils -from algoperf import spec -from algoperf.workloads.librispeech_conformer import metrics -from algoperf.workloads.librispeech_conformer import workload -from algoperf.workloads.librispeech_conformer.input_pipeline import \ - LibriSpeechDataset +from algoperf import data_utils, param_utils, spec +from algoperf.workloads.librispeech_conformer import metrics, workload +from algoperf.workloads.librispeech_conformer.input_pipeline import ( + LibriSpeechDataset, +) from algoperf.workloads.librispeech_conformer.librispeech_jax import models class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): - - def __init__(self, - tokenizer_vocab_path: Optional[str] = None, - use_specaug: bool = True) -> None: + def __init__( + self, tokenizer_vocab_path: Optional[str] = None, use_specaug: bool = True + ) -> None: super().__init__() self.metrics_bundle = metrics.get_metrics_bundle(tokenizer_vocab_path) self.use_specaug = use_specaug @@ -38,7 +35,8 @@ def __init__(self, def eval_num_workers(self) -> int: if self._eval_num_workers is None: raise ValueError( - 'eval_num_workers property must be set before workload is used.') + 'eval_num_workers property must be set before workload is used.' + ) return self._eval_num_workers @eval_num_workers.setter @@ -58,8 +56,8 @@ def attention_temperature(self) -> float: return 1.0 def init_model_fn( - self, - rng: spec.RandomState, + self, + rng: spec.RandomState, ) -> spec.ModelInitState: """Conformer model init function. @@ -71,10 +69,11 @@ def init_model_fn( else: activation_function_name = 'swish' model_config = models.ConformerConfig( - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name) + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name, + ) self._model = models.Conformer(model_config) input_shape = [(320000,), (320000,)] @@ -85,7 +84,7 @@ def init_model_fn( params_rng, _ = jax.random.split(rng, 2) variables = model_init_fn({'params': params_rng}, *fake_input_batch) - model_state, params = pop(variables, "params") + model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -97,49 +96,52 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None, - dropout_rate: Optional[float] = models.DROPOUT_RATE, + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[float] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( - variables, - inputs, - input_paddings, - train=True, - rngs={'dropout' : rng}, - mutable=['batch_stats'], - use_running_average_bn=use_running_average_bn, - dropout_rate=dropout_rate) + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout': rng}, + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn, + dropout_rate=dropout_rate, + ) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( - variables, - inputs, - input_paddings, - train=False, - mutable=False, - use_running_average_bn=use_running_average_bn) + variables, + inputs, + input_paddings, + train=False, + mutable=False, + use_running_average_bn=use_running_average_bn, + ) return (logits, logit_paddings), model_state def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del data_rng del cache del repeat_final_dataset @@ -158,38 +160,41 @@ def _build_input_queue( ds = LibriSpeechDataset(split=split, data_dir=data_dir) dataloader = data_utils.cycle( - torch.utils.data.DataLoader( - ds, - batch_size=global_batch_size, - shuffle=train, - sampler=None, - num_workers=4, - prefetch_factor=10, - pin_memory=False, - drop_last=train, - )) + torch.utils.data.DataLoader( + ds, + batch_size=global_batch_size, + shuffle=train, + sampler=None, + num_workers=4, + prefetch_factor=10, + pin_memory=False, + drop_last=train, + ) + ) for batch in iter(dataloader): inputs, input_paddings = batch['inputs'] targets, target_paddings = batch['targets'] numpy_batch = { - 'inputs': (inputs.numpy(), input_paddings.numpy()), - 'targets': (targets.numpy(), target_paddings.numpy()), + 'inputs': (inputs.numpy(), input_paddings.numpy()), + 'targets': (targets.numpy(), target_paddings.numpy()), } padded_batch = data_utils.shard_and_maybe_pad_np( - numpy_batch, padding_value=1.0) + numpy_batch, padding_value=1.0 + ) yield padded_batch # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) - logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) + logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -200,10 +205,9 @@ def loss_fn( logits, logit_paddings = logits_batch targets, target_paddings = label_batch logprobs = nn.log_softmax(logits) - per_example_losses = self.ctc_loss(logprobs, - logit_paddings, - targets, - target_paddings) + per_example_losses = self.ctc_loss( + logprobs, logit_paddings, targets, target_paddings + ) # mask_batch is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -213,23 +217,26 @@ def loss_fn( n_valid_examples = jnp.maximum(mask_batch.sum(), 1) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } - def ctc_loss(self, - logits: spec.Tensor, - logit_paddings: spec.Tensor, - labels: spec.Tensor, - label_paddings: spec.Tensor, - blank_id: int = 0) -> spec.Tensor: + def ctc_loss( + self, + logits: spec.Tensor, + logit_paddings: spec.Tensor, + labels: spec.Tensor, + label_paddings: spec.Tensor, + blank_id: int = 0, + ) -> spec.Tensor: return optax.ctc_loss( - logits=logits, - logit_paddings=logit_paddings, - labels=labels, - label_paddings=label_paddings, - blank_id=blank_id) + logits=logits, + logit_paddings=logit_paddings, + labels=labels, + label_paddings=label_paddings, + blank_id=blank_id, + ) # Adapted from lingvo's greedy decoding logic here: # https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138. @@ -240,21 +247,22 @@ def sequence_mask(self, lengths: spec.Tensor, maxlen: int) -> spec.Tensor: c = jnp.less_equal(b, lengths[:, jnp.newaxis]).astype(lengths.dtype) return c - def collapse_and_remove_blanks(self, - labels: spec.Tensor, - seq_length: spec.Tensor, - blank_id: int = 0) -> spec.Tensor: + def collapse_and_remove_blanks( + self, labels: spec.Tensor, seq_length: spec.Tensor, blank_id: int = 0 + ) -> spec.Tensor: b, t = labels.shape # Zap out blank. blank_mask = 1 - jnp.equal(labels, blank_id) labels = (labels * blank_mask).astype(labels.dtype) # Mask labels that don't equal previous label. - label_mask = jnp.concatenate([ + label_mask = jnp.concatenate( + [ jnp.ones_like(labels[:, :1], dtype=jnp.int32), jnp.not_equal(labels[:, 1:], labels[:, :-1]), - ], - axis=1) + ], + axis=1, + ) # Filter labels that aren't in the original sequence. maxlen = labels.shape[1] @@ -290,12 +298,14 @@ def collapse_and_remove_blanks(self, # Reshape back to square batch. batch_size = labels.shape[0] new_shape = [batch_size, new_maxlen] - return (jnp.reshape(flat, new_shape).astype(labels.dtype), - new_seq_len.astype(seq_length.dtype)) + return ( + jnp.reshape(flat, new_shape).astype(labels.dtype), + new_seq_len.astype(seq_length.dtype), + ) def greedy_decode( - self, logits: spec.Tensor, - logit_paddings: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]: + self, logits: spec.Tensor, logit_paddings: spec.Tensor + ) -> Tuple[spec.Tensor, spec.Tensor]: per_frame_max = jnp.argmax(logits, axis=-1) seqlen = jnp.sum(1.0 - logit_paddings, axis=-1) hyp, _ = self.collapse_and_remove_blanks(per_frame_max, seqlen, blank_id=0) @@ -303,45 +313,51 @@ def greedy_decode( return hyp, hyp_paddings @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, None), + static_broadcasted_argnums=(0,), + ) def eval_step_pmapped( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: (logits, logit_paddings), _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) decoded, decoded_paddings = self.greedy_decode(logits, logit_paddings) loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) targets, target_paddings = batch['targets'] return self.metrics_bundle.gather_from_model_output( - loss_dict=loss, - decoded=decoded, - decoded_paddings=decoded_paddings, - targets=targets, - target_paddings=target_paddings, - axis_name='batch') - - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + loss_dict=loss, + decoded=decoded, + decoded_paddings=decoded_paddings, + targets=targets, + target_paddings=target_paddings, + axis_name='batch', + ) + + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step if model_state is not None and len(model_state) > 0: @@ -351,15 +367,15 @@ def _eval_model_on_split(self, num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - rng, split, data_dir, global_batch_size, num_batches=num_batches) + rng, split, data_dir, global_batch_size, num_batches=num_batches + ) metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step_pmapped(params, - eval_batch, - model_state, - rng).unreplicate() + computed_metrics = self.eval_step_pmapped( + params, eval_batch, model_state, rng + ).unreplicate() if metrics_report is None: metrics_report = computed_metrics @@ -372,7 +388,8 @@ def _eval_model_on_split(self, return computed_metrics def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: + self, model_state: spec.ModelAuxiliaryState + ) -> spec.ModelAuxiliaryState: # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics and # we average them. @@ -383,8 +400,8 @@ def sync_batch_stats( class LibriSpeechConformerAttentionTemperatureWorkload( - LibriSpeechConformerWorkload): - + LibriSpeechConformerWorkload +): @property def attention_temperature(self) -> float: return 1.6 @@ -399,7 +416,6 @@ def test_target_value(self) -> float: class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): - @property def use_post_layer_norm(self) -> bool: return False @@ -414,7 +430,6 @@ def test_target_value(self) -> float: class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): - @property def use_gelu(self) -> bool: return True diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index 3a2eda4af..647b8ff0c 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -2,20 +2,22 @@ https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. """ +import math from dataclasses import dataclass from functools import partial -import math from typing import Tuple import torch +import torch.nn.functional as F from torch import nn from torch.nn import init -import torch.nn.functional as F -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ - SpecAug +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import ( + preprocessor, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import ( + SpecAug, +) DROPOUT_RATE = 0.1 @@ -23,6 +25,7 @@ @dataclass class ConformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 encoder_dim: int = 512 num_attention_heads: int = 8 @@ -58,7 +61,6 @@ def initialize(m): class LayerNorm(nn.Module): - def __init__(self, dim, epsilon=1e-6): super().__init__() self.dim = dim @@ -72,24 +74,25 @@ def forward(self, x): class Subsample(nn.Module): - def __init__(self, encoder_dim: int = 0, num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim self.conv1 = Conv2dSubsampling( - input_channels=1, output_channels=encoder_dim) + input_channels=1, output_channels=encoder_dim + ) self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, output_channels=encoder_dim) + input_channels=encoder_dim, output_channels=encoder_dim + ) self.linear = nn.Linear( - in_features=self.encoder_dim * num_bins // 4, - out_features=self.encoder_dim, - bias=True) + in_features=self.encoder_dim * num_bins // 4, + out_features=self.encoder_dim, + bias=True, + ) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) def forward(self, inputs, input_paddings, dropout_rate): - output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -97,25 +100,27 @@ def forward(self, inputs, input_paddings, dropout_rate): outputs, output_paddings = self.conv2(outputs, output_paddings) batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape - outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, - subsampled_lengths, - subsampled_dims * channels) + outputs = outputs.permute(0, 2, 3, 1).reshape( + batch_size, subsampled_lengths, subsampled_dims * channels + ) outputs = self.linear(outputs) outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) outputs = F.dropout( - outputs, dropout_rate, training=self.training, inplace=True) + outputs, dropout_rate, training=self.training, inplace=True + ) return outputs, output_paddings class Conv2dSubsampling(nn.Module): - - def __init__(self, - input_channels: int, - output_channels: int, - filter_stride: Tuple[int] = (2, 2), - padding: str = 'SAME'): + def __init__( + self, + input_channels: int, + output_channels: int, + filter_stride: Tuple[int] = (2, 2), + padding: str = 'SAME', + ): super().__init__() self.input_channels = input_channels @@ -126,7 +131,8 @@ def __init__(self, self.filter_shape = (output_channels, input_channels, 3, 3) self.kernel = nn.Parameter( - torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) + torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape)) + ) self.bias = nn.Parameter(torch.zeros(output_channels)) self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) @@ -156,12 +162,13 @@ def forward(self, inputs, paddings): else: in_ = inputs outputs = F.conv2d( - input=in_, - weight=self.kernel, - bias=self.bias, - stride=self.filter_stride, - dilation=(1, 1), - groups=groups) + input=in_, + weight=self.kernel, + bias=self.bias, + stride=self.filter_stride, + dilation=(1, 1), + groups=groups, + ) outputs = F.relu(outputs) @@ -169,35 +176,37 @@ def forward(self, inputs, paddings): stride = self.filter_stride[0] pad_len = (input_length + stride - 1) // stride * stride - input_length padded_paddings = F.pad( - paddings[:, None, :], (0, pad_len), mode='constant', value=0) + paddings[:, None, :], (0, pad_len), mode='constant', value=0 + ) out_padding = F.conv1d( - input=padded_paddings, - weight=self.paddings_kernel, - stride=self.filter_stride[:1]) + input=padded_paddings, + weight=self.paddings_kernel, + stride=self.filter_stride[:1], + ) out_padding = out_padding.squeeze(dim=1) outputs = outputs * (1 - out_padding[:, None, :, None]) return outputs, out_padding class FeedForwardModule(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() self.config = config self.ln = LayerNorm(dim=config.encoder_dim) self.linear1 = nn.Linear( - in_features=config.encoder_dim, - out_features=config.encoder_dim * config.feed_forward_expansion_factor, - bias=True) + in_features=config.encoder_dim, + out_features=config.encoder_dim * config.feed_forward_expansion_factor, + bias=True, + ) self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True) self.linear2 = nn.Linear( - in_features=config.encoder_dim * config.feed_forward_expansion_factor, - out_features=config.encoder_dim, - bias=True) + in_features=config.encoder_dim * config.feed_forward_expansion_factor, + out_features=config.encoder_dim, + bias=True, + ) def forward(self, inputs, padding_mask, dropout_rate): - inputs = self.ln(inputs) inputs = self.linear1(inputs) if self.config.activation_function_name == 'swish': @@ -206,52 +215,58 @@ def forward(self, inputs, padding_mask, dropout_rate): # Use tanh approximation of GELU which is default for jax activation_fn = partial(F.gelu, approximate='tanh') else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') + raise ValueError( + 'Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}' + ) inputs = activation_fn(inputs) inputs = self.dropout1(inputs) inputs = inputs * padding_mask inputs = self.linear2(inputs) inputs = inputs * padding_mask inputs = F.dropout( - inputs, dropout_rate, training=self.training, inplace=True) + inputs, dropout_rate, training=self.training, inplace=True + ) return inputs class AddPositionalEmbedding(nn.Module): - - def __init__(self, - min_timescale: int = 1, - max_timescale: int = 10_000, - embedding_dim: int = 512): + def __init__( + self, + min_timescale: int = 1, + max_timescale: int = 10_000, + embedding_dim: int = 512, + ): super().__init__() self.min_timescale = min_timescale self.max_timescale = max_timescale self.embedding_dim = embedding_dim num_timescales = self.embedding_dim // 2 log_timescale_increment = math.log( - float(self.max_timescale) / float(self.min_timescale)) / ( - num_timescales - 1) - inv_timescales = self.min_timescale * \ - torch.exp(torch.arange(num_timescales, dtype=torch.float32) - * -log_timescale_increment) + float(self.max_timescale) / float(self.min_timescale) + ) / (num_timescales - 1) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) + * -log_timescale_increment + ) self.register_buffer('inv_timescales', inv_timescales[None, None, :]) def forward(self, seq_length): position = torch.arange( - end=seq_length, dtype=torch.float32, device=self.inv_timescales.device) + end=seq_length, dtype=torch.float32, device=self.inv_timescales.device + ) scaled_time = position[None, :, None] * self.inv_timescales signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) if self.embedding_dim % 2: signal = torch.cat( - [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2) + [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2 + ) return signal class QueryScaler(nn.Module): - def __init__(self, dim): super().__init__() self.dim = dim @@ -264,7 +279,6 @@ def forward(self, inputs): class MHSAwithQS(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() self.embed_dim = config.encoder_dim @@ -281,20 +295,23 @@ def forward(self, inputs, key_padding_mask=None): q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) - out = F.scaled_dot_product_attention( + out = ( + F.scaled_dot_product_attention( query=q, key=k, value=v, attn_mask=~key_padding_mask[:, None, None], dropout_p=self.attention_dropout_rate, - ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + ) + .transpose(1, 2) + .reshape(batch_size, seq_len, embed_dim) + ) out = out * self.attention_temperature out = self.out_proj(out) return out class MultiHeadedSelfAttention(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() @@ -306,16 +323,16 @@ def __init__(self, config: ConformerConfig): def forward(self, outputs, paddings, dropout_rate): outputs = self.ln(outputs) outputs = self.self_attention( - outputs, - key_padding_mask=paddings == 1, + outputs, + key_padding_mask=paddings == 1, ) outputs = F.dropout( - outputs, dropout_rate, training=self.training, inplace=True) + outputs, dropout_rate, training=self.training, inplace=True + ) return outputs class BatchNorm(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() running_mean = torch.zeros(config.encoder_dim) @@ -330,8 +347,8 @@ def __init__(self, config: ConformerConfig): self.epsilon = config.batch_norm_epsilon def forward(self, inputs, input_paddings): - #inputs: NHD - #padding: NH + # inputs: NHD + # padding: NH """ Alternatively: inputs[input_paddings==0] = F.batch_norm( @@ -355,9 +372,11 @@ def forward(self, inputs, input_paddings): var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count self.running_mean = (1 - self.momentum) * self.running_mean + ( - self.momentum) * mean.detach() + self.momentum + ) * mean.detach() self.running_var = (1 - self.momentum) * self.running_var + ( - self.momentum) * var.detach() + self.momentum + ) * var.detach() else: mean = self.running_mean @@ -369,25 +388,27 @@ def forward(self, inputs, input_paddings): class ConvolutionBlock(nn.Module): - def __init__(self, config): super().__init__() self.config = config self.ln = LayerNorm(dim=config.encoder_dim) self.lin1 = nn.Linear( - in_features=config.encoder_dim, out_features=config.encoder_dim) + in_features=config.encoder_dim, out_features=config.encoder_dim + ) self.lin2 = nn.Linear( - in_features=config.encoder_dim, out_features=config.encoder_dim) + in_features=config.encoder_dim, out_features=config.encoder_dim + ) self.conv1 = nn.Conv1d( - in_channels=config.encoder_dim, - out_channels=config.encoder_dim, - kernel_size=(config.convolution_kernel_size,), - stride=(1,), - padding='same', - bias=False, - groups=config.encoder_dim) + in_channels=config.encoder_dim, + out_channels=config.encoder_dim, + kernel_size=(config.convolution_kernel_size,), + stride=(1,), + padding='same', + bias=False, + groups=config.encoder_dim, + ) self.bn = BatchNorm(config) self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) @@ -407,19 +428,21 @@ def forward(self, inputs, input_paddings, dropout_rate): elif self.config.activation_function_name == 'gelu': activation_fn = F.gelu else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') + raise ValueError( + 'Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}' + ) inputs = activation_fn(inputs) inputs = self.lin3(inputs) inputs = F.dropout( - inputs, dropout_rate, training=self.training, inplace=True) + inputs, dropout_rate, training=self.training, inplace=True + ) return inputs class ConformerBlock(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() @@ -443,28 +466,30 @@ def forward(self, inputs, input_paddings, dropout_rate): class ConformerEncoderDecoder(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() self.config = config preprocessing_config = preprocessor.PreprocessorConfig() self.preprocessor = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + ) self.specaug = SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames, ) self.subsample = Subsample( - encoder_dim=config.encoder_dim, num_bins=preprocessing_config.num_bins) + encoder_dim=config.encoder_dim, num_bins=preprocessing_config.num_bins + ) self.conformers = nn.ModuleList( - [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) + [ConformerBlock(config) for _ in range(config.num_encoder_layers)] + ) self.ln = LayerNorm(config.encoder_dim) self.lin = nn.Linear(config.encoder_dim, config.vocab_size) @@ -475,8 +500,9 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings, - dropout_rate) + outputs, output_paddings = self.subsample( + outputs, output_paddings, dropout_rate + ) for conformer in self.conformers: outputs = conformer(outputs, output_paddings, dropout_rate) outputs = self.ln(outputs) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py index 558a0f796..f8c1bd0d2 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py @@ -16,174 +16,175 @@ _MEL_HIGH_FREQUENCY_Q = 1127.0 LIBRISPEECH_MEAN_VECTOR = [ - -7.6047816276550293, - -7.1206226348876953, - -6.8864245414733887, - -6.8705768585205078, - -6.9667720794677734, - -7.1084094047546387, - -6.9528026580810547, - -6.783994197845459, - -6.6195521354675293, - -6.4876265525817871, - -6.4120659828186035, - -6.394047737121582, - -6.4244871139526367, - -6.3993711471557617, - -6.5158271789550781, - -6.7137999534606934, - -6.8476877212524414, - -6.9885001182556152, - -6.9221386909484863, - -7.146148681640625, - -7.2040400505065918, - -7.0537552833557129, - -7.3140382766723633, - -7.1223249435424805, - -7.30251407623291, - -7.1212143898010254, - -7.2425732612609863, - -7.1730537414550781, - -7.0979413986206055, - -7.088747501373291, - -6.9849910736083984, - -6.8787732124328613, - -6.7602753639221191, - -6.6300945281982422, - -6.5145769119262695, - -6.4245057106018066, - -6.356513500213623, - -6.31787633895874, - -6.2660770416259766, - -6.2468328475952148, - -6.2821526527404785, - -6.1908388137817383, - -6.2484354972839355, - -6.1472640037536621, - -6.0924725532531738, - -6.0171003341674805, - -5.9250402450561523, - -5.8535833358764648, - -5.8209109306335449, - -5.8118929862976074, - -5.80783748626709, - -5.7714629173278809, - -5.7453732490539551, - -5.7705655097961426, - -5.7765641212463379, - -5.7831673622131348, - -5.7954087257385254, - -5.7994823455810547, - -5.8023476600646973, - -5.8047118186950684, - -5.8168182373046875, - -5.8844799995422363, - -5.9727106094360352, - -6.0444660186767578, - -6.1284866333007812, - -6.2257585525512695, - -6.3157496452331543, - -6.39061164855957, - -6.4928598403930664, - -6.5498456954956055, - -6.6054320335388184, - -6.6508378982543945, - -6.66917610168457, - -6.6726889610290527, - -6.684234619140625, - -6.6974577903747559, - -6.75471830368042, - -6.7949142456054688, - -6.8634209632873535, - -6.94186544418335 + -7.6047816276550293, + -7.1206226348876953, + -6.8864245414733887, + -6.8705768585205078, + -6.9667720794677734, + -7.1084094047546387, + -6.9528026580810547, + -6.783994197845459, + -6.6195521354675293, + -6.4876265525817871, + -6.4120659828186035, + -6.394047737121582, + -6.4244871139526367, + -6.3993711471557617, + -6.5158271789550781, + -6.7137999534606934, + -6.8476877212524414, + -6.9885001182556152, + -6.9221386909484863, + -7.146148681640625, + -7.2040400505065918, + -7.0537552833557129, + -7.3140382766723633, + -7.1223249435424805, + -7.30251407623291, + -7.1212143898010254, + -7.2425732612609863, + -7.1730537414550781, + -7.0979413986206055, + -7.088747501373291, + -6.9849910736083984, + -6.8787732124328613, + -6.7602753639221191, + -6.6300945281982422, + -6.5145769119262695, + -6.4245057106018066, + -6.356513500213623, + -6.31787633895874, + -6.2660770416259766, + -6.2468328475952148, + -6.2821526527404785, + -6.1908388137817383, + -6.2484354972839355, + -6.1472640037536621, + -6.0924725532531738, + -6.0171003341674805, + -5.9250402450561523, + -5.8535833358764648, + -5.8209109306335449, + -5.8118929862976074, + -5.80783748626709, + -5.7714629173278809, + -5.7453732490539551, + -5.7705655097961426, + -5.7765641212463379, + -5.7831673622131348, + -5.7954087257385254, + -5.7994823455810547, + -5.8023476600646973, + -5.8047118186950684, + -5.8168182373046875, + -5.8844799995422363, + -5.9727106094360352, + -6.0444660186767578, + -6.1284866333007812, + -6.2257585525512695, + -6.3157496452331543, + -6.39061164855957, + -6.4928598403930664, + -6.5498456954956055, + -6.6054320335388184, + -6.6508378982543945, + -6.66917610168457, + -6.6726889610290527, + -6.684234619140625, + -6.6974577903747559, + -6.75471830368042, + -6.7949142456054688, + -6.8634209632873535, + -6.94186544418335, ] LIBRISPEECH_STD_VECTOR = [ - 3.4353282451629639, - 3.5962932109832764, - 3.7012472152709961, - 3.7369205951690674, - 3.7535104751586914, - 3.693629264831543, - 3.6922497749328613, - 3.7641522884368896, - 3.8419716358184814, - 3.8999848365783691, - 3.9294240474700928, - 3.9317409992218018, - 3.9139585494995117, - 3.9031598567962646, - 3.8691999912261963, - 3.8155081272125244, - 3.7644970417022705, - 3.7099106311798096, - 3.6965086460113525, - 3.6003766059875488, - 3.5493226051330566, - 3.5465121269226074, - 3.45003604888916, - 3.4712812900543213, - 3.4084610939025879, - 3.4408135414123535, - 3.4104881286621094, - 3.4217638969421387, - 3.4312851428985596, - 3.4199209213256836, - 3.4305806159973145, - 3.4382665157318115, - 3.4580366611480713, - 3.4817991256713867, - 3.4958710670471191, - 3.5036792755126953, - 3.5047574043273926, - 3.4988734722137451, - 3.493056058883667, - 3.4822943210601807, - 3.459430456161499, - 3.4612770080566406, - 3.4559063911437988, - 3.4755423069000244, - 3.4971549510955811, - 3.5326557159423828, - 3.5705199241638184, - 3.5920312404632568, - 3.596907377243042, - 3.5913500785827637, - 3.5865931510925293, - 3.5826809406280518, - 3.5837743282318115, - 3.5895791053771973, - 3.5819313526153564, - 3.5837869644165039, - 3.5861184597015381, - 3.5889589786529541, - 3.592214822769165, - 3.5939455032348633, - 3.5856630802154541, - 3.5884113311767578, - 3.5921022891998291, - 3.5870490074157715, - 3.5806570053100586, - 3.5731067657470703, - 3.5617532730102539, - 3.54980731010437, - 3.5527374744415283, - 3.5475366115570068, - 3.5387849807739258, - 3.5256178379058838, - 3.5031836032867432, - 3.4922726154327393, - 3.4879646301269531, - 3.4725594520568848, - 3.4558389186859131, - 3.4351828098297119, - 3.4284293651580811, - 3.4299170970916748 + 3.4353282451629639, + 3.5962932109832764, + 3.7012472152709961, + 3.7369205951690674, + 3.7535104751586914, + 3.693629264831543, + 3.6922497749328613, + 3.7641522884368896, + 3.8419716358184814, + 3.8999848365783691, + 3.9294240474700928, + 3.9317409992218018, + 3.9139585494995117, + 3.9031598567962646, + 3.8691999912261963, + 3.8155081272125244, + 3.7644970417022705, + 3.7099106311798096, + 3.6965086460113525, + 3.6003766059875488, + 3.5493226051330566, + 3.5465121269226074, + 3.45003604888916, + 3.4712812900543213, + 3.4084610939025879, + 3.4408135414123535, + 3.4104881286621094, + 3.4217638969421387, + 3.4312851428985596, + 3.4199209213256836, + 3.4305806159973145, + 3.4382665157318115, + 3.4580366611480713, + 3.4817991256713867, + 3.4958710670471191, + 3.5036792755126953, + 3.5047574043273926, + 3.4988734722137451, + 3.493056058883667, + 3.4822943210601807, + 3.459430456161499, + 3.4612770080566406, + 3.4559063911437988, + 3.4755423069000244, + 3.4971549510955811, + 3.5326557159423828, + 3.5705199241638184, + 3.5920312404632568, + 3.596907377243042, + 3.5913500785827637, + 3.5865931510925293, + 3.5826809406280518, + 3.5837743282318115, + 3.5895791053771973, + 3.5819313526153564, + 3.5837869644165039, + 3.5861184597015381, + 3.5889589786529541, + 3.592214822769165, + 3.5939455032348633, + 3.5856630802154541, + 3.5884113311767578, + 3.5921022891998291, + 3.5870490074157715, + 3.5806570053100586, + 3.5731067657470703, + 3.5617532730102539, + 3.54980731010437, + 3.5527374744415283, + 3.5475366115570068, + 3.5387849807739258, + 3.5256178379058838, + 3.5031836032867432, + 3.4922726154327393, + 3.4879646301269531, + 3.4725594520568848, + 3.4558389186859131, + 3.4351828098297119, + 3.4284293651580811, + 3.4299170970916748, ] @dataclass class PreprocessorConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + sample_rate = 16000 frame_size_ms = 25 frame_step_ms = 10 @@ -203,10 +204,12 @@ class PreprocessorConfig: def _hertz_to_mel(frequencies_hertz): """Convert hertz to mel.""" - log_fn = math.log if type(frequencies_hertz) in [type(0.0), type(0) - ] else torch.log - return _MEL_HIGH_FREQUENCY_Q * log_fn(1.0 + (frequencies_hertz / - _MEL_BREAK_FREQUENCY_HERTZ)) + log_fn = ( + math.log if type(frequencies_hertz) in [type(0.0), type(0)] else torch.log + ) + return _MEL_HIGH_FREQUENCY_Q * log_fn( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ) + ) def _pad_end_length(num_timesteps, frame_step, frame_size): @@ -218,28 +221,30 @@ def _pad_end_length(num_timesteps, frame_step, frame_size): return padded_length - num_timesteps -def frame(x, - frame_length: int, - frame_step: int, - pad_end: bool = False, - pad_value: Union[int, float] = 0.0): +def frame( + x, + frame_length: int, + frame_step: int, + pad_end: bool = False, + pad_value: Union[int, float] = 0.0, +): """Slides a window and extract values. - This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with - stride of `frame_step`, and returns an array `y` with the shape - `(batch_size, num_frames, frame_length, num_channels)`. Unlike the - counterpart in Tensorflow (`tf.signal.frame`), this function currently - does not take `axis` argument, and the input tensor `x` is expected to - have a shape of `(batch_size, timesteps, channels)`. - Args: - x: An input array with `(batch_size, timesteps, channels)`-shape. - frame_length: The frame length. - frame_step: The frame hop size. - pad_end: If True, the end of signal is padded so the window can continue - sliding while the starting point of the window is in the valid range. - pad_value: A scalar used as a padding value when `pad_end` is True. - Returns: - A tensor with shape `(*, num_frames, frame_length, num_channels)`. - """ + This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with + stride of `frame_step`, and returns an array `y` with the shape + `(batch_size, num_frames, frame_length, num_channels)`. Unlike the + counterpart in Tensorflow (`tf.signal.frame`), this function currently + does not take `axis` argument, and the input tensor `x` is expected to + have a shape of `(batch_size, timesteps, channels)`. + Args: + x: An input array with `(batch_size, timesteps, channels)`-shape. + frame_length: The frame length. + frame_step: The frame hop size. + pad_end: If True, the end of signal is padded so the window can continue + sliding while the starting point of the window is in the valid range. + pad_value: A scalar used as a padding value when `pad_end` is True. + Returns: + A tensor with shape `(*, num_frames, frame_length, num_channels)`. + """ num_timesteps = x.shape[1] if pad_end: @@ -250,60 +255,67 @@ def frame(x, return x.permute(0, 1, 3, 2) -def linear_to_mel_weight_matrix(num_mel_bins: int = 20, - num_spectrogram_bins: int = 129, - sample_rate: Union[int, float] = 8000, - lower_edge_hertz: Union[int, float] = 125.0, - upper_edge_hertz: Union[int, float] = 3800.0, - dtype: Any = torch.float32, - device='cpu'): +def linear_to_mel_weight_matrix( + num_mel_bins: int = 20, + num_spectrogram_bins: int = 129, + sample_rate: Union[int, float] = 8000, + lower_edge_hertz: Union[int, float] = 125.0, + upper_edge_hertz: Union[int, float] = 3800.0, + dtype: Any = torch.float32, + device='cpu', +): r"""Pytorch-port of `tf.signal.linear_to_mel_weight_matrix`. - Args: - num_mel_bins: Python int. How many bands in the resulting mel spectrum. - num_spectrogram_bins: An integer `Tensor`. How many bins there are in - the source spectrogram data, which is understood to be `fft_size // 2 + 1`, - i.e. the spectrogram only contains the nonredundant FFT bins. - sample_rate: An integer or float `Tensor`. Samples per second of the - input signal used to create the spectrogram. Used to figure out the - frequencies corresponding to each spectrogram bin, which dictates how they - are mapped into the mel scale. - lower_edge_hertz: Python float. Lower bound on the frequencies to be - included in the mel spectrum. This corresponds to the lower edge of the - lowest triangular band. - upper_edge_hertz: Python float. The desired top edge of the highest - frequency band. - dtype: The `DType` of the result matrix. Must be a floating point type. - Returns: - An array of shape `[num_spectrogram_bins, num_mel_bins]`. - Raises: - ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not - positive, `lower_edge_hertz` is negative, frequency edges are incorrectly - ordered, `upper_edge_hertz` is larger than the Nyquist frequency. - [mel]: https://en.wikipedia.org/wiki/Mel_scale - """ + Args: + num_mel_bins: Python int. How many bands in the resulting mel spectrum. + num_spectrogram_bins: An integer `Tensor`. How many bins there are in + the source spectrogram data, which is understood to be `fft_size // 2 + 1`, + i.e. the spectrogram only contains the nonredundant FFT bins. + sample_rate: An integer or float `Tensor`. Samples per second of the + input signal used to create the spectrogram. Used to figure out the + frequencies corresponding to each spectrogram bin, which dictates how they + are mapped into the mel scale. + lower_edge_hertz: Python float. Lower bound on the frequencies to be + included in the mel spectrum. This corresponds to the lower edge of the + lowest triangular band. + upper_edge_hertz: Python float. The desired top edge of the highest + frequency band. + dtype: The `DType` of the result matrix. Must be a floating point type. + Returns: + An array of shape `[num_spectrogram_bins, num_mel_bins]`. + Raises: + ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not + positive, `lower_edge_hertz` is negative, frequency edges are incorrectly + ordered, `upper_edge_hertz` is larger than the Nyquist frequency. + [mel]: https://en.wikipedia.org/wiki/Mel_scale + """ # Input validator from tensorflow/python/ops/signal/mel_ops.py#L71 if num_mel_bins <= 0: raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins) if lower_edge_hertz < 0.0: - raise ValueError('lower_edge_hertz must be non-negative. Got: %s' % - lower_edge_hertz) + raise ValueError( + 'lower_edge_hertz must be non-negative. Got: %s' % lower_edge_hertz + ) if lower_edge_hertz >= upper_edge_hertz: - raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' % - (lower_edge_hertz, upper_edge_hertz)) + raise ValueError( + 'lower_edge_hertz %.1f >= upper_edge_hertz %.1f' + % (lower_edge_hertz, upper_edge_hertz) + ) if sample_rate <= 0.0: raise ValueError('sample_rate must be positive. Got: %s' % sample_rate) if upper_edge_hertz > sample_rate / 2: - raise ValueError('upper_edge_hertz must not be larger than the Nyquist ' - 'frequency (sample_rate / 2). Got %s for sample_rate: %s' % - (upper_edge_hertz, sample_rate)) + raise ValueError( + 'upper_edge_hertz must not be larger than the Nyquist ' + 'frequency (sample_rate / 2). Got %s for sample_rate: %s' + % (upper_edge_hertz, sample_rate) + ) # HTK excludes the spectrogram DC bin. bands_to_zero = 1 nyquist_hertz = sample_rate / 2.0 linear_frequencies = torch.linspace( - 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype, - device=device)[bands_to_zero:] + 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype, device=device + )[bands_to_zero:] spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, None] # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The @@ -311,11 +323,12 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into # num_mel_bins + 2 pieces. edges = torch.linspace( - _hertz_to_mel(lower_edge_hertz), - _hertz_to_mel(upper_edge_hertz), - num_mel_bins + 2, - dtype=dtype, - device=device) + _hertz_to_mel(lower_edge_hertz), + _hertz_to_mel(upper_edge_hertz), + num_mel_bins + 2, + dtype=dtype, + device=device, + ) # Split the triples up and reshape them into [1, num_mel_bins] tensors. lower_edge_mel = edges[:-2][None, :] @@ -325,13 +338,16 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, # Calculate lower and upper slopes for every spectrogram bin. # Line segments are linear in the mel domain, not Hertz. lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / ( - center_mel - lower_edge_mel) + center_mel - lower_edge_mel + ) upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / ( - upper_edge_mel - center_mel) + upper_edge_mel - center_mel + ) # Intersect the line segments with each other and zero. mel_weights_matrix = torch.minimum(lower_slopes, upper_slopes).clamp( - min=0.0, max=None) + min=0.0, max=None + ) # Re-add the zeroed lower bins we sliced out above. return F.pad(mel_weights_matrix, (0, 0, bands_to_zero, 0)) @@ -339,43 +355,50 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, def _hanning_greco(win_support, frame_size, dtype, device='cpu'): """Add a greco-style hanning window to the graph. - Note that the Hanning window in Wikipedia is not the same as the Hanning - window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia - page would indicate. Talkin's explanation was that it was like wasting two - samples to have the values at the edge of the window to be 0.0 exactly. - Args: - win_support: Number of samples for non-zero support in the window - frame_size: Total size of the window (frame_size >= win_support) - dtype: TF data type - Returns: - Tensor of size frame_size with the window to apply. - """ + Note that the Hanning window in Wikipedia is not the same as the Hanning + window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia + page would indicate. Talkin's explanation was that it was like wasting two + samples to have the values at the edge of the window to be 0.0 exactly. + Args: + win_support: Number of samples for non-zero support in the window + frame_size: Total size of the window (frame_size >= win_support) + dtype: TF data type + Returns: + Tensor of size frame_size with the window to apply. + """ if frame_size < win_support: raise ValueError( - 'Provided frame_size = {} is lower than win_support = {}'.format( - frame_size, win_support)) + 'Provided frame_size = {} is lower than win_support = {}'.format( + frame_size, win_support + ) + ) arg = torch.pi * 2.0 / (win_support) - hann = 0.5 - (0.5 * torch.cos( - arg * (torch.arange(win_support, dtype=dtype, device=device) + 0.5))) + hann = 0.5 - ( + 0.5 + * torch.cos( + arg * (torch.arange(win_support, dtype=dtype, device=device) + 0.5) + ) + ) zero_size = frame_size - win_support return F.pad(hann, (0, zero_size)) def _next_pow_of_two(x: Union[int, float]) -> int: - return int(2**np.ceil(np.log2(x))) + return int(2 ** np.ceil(np.log2(x))) class SpectrogramFrontend(nn.Module): - """Layer to convert input audio signals from time domain to frequency domain. - """ - - def __init__(self, - config: PreprocessorConfig = None, - input_scale_factor: float = 1.0, - output_log: bool = False, - dtype=torch.float32, - device='cpu'): + """Layer to convert input audio signals from time domain to frequency domain.""" + + def __init__( + self, + config: PreprocessorConfig = None, + input_scale_factor: float = 1.0, + output_log: bool = False, + dtype=torch.float32, + device='cpu', + ): super().__init__() self.config = config @@ -384,8 +407,9 @@ def __init__(self, p = self.config self._frame_step = int(round(p.sample_rate * p.frame_step_ms / 1000.0)) - self._frame_size = int(round( - p.sample_rate * p.frame_size_ms / 1000.0)) + 1 # +1 for the preemph + self._frame_size = ( + int(round(p.sample_rate * p.frame_size_ms / 1000.0)) + 1 + ) # +1 for the preemph # TF-version has maximum of 512, but it's not always necessary self.fft_size = _next_pow_of_two(self._frame_size) @@ -399,23 +423,20 @@ def _hanning_window(frame_size, dtype): if frame_size % 2 == 0: # simulate periodic=True in tf.signal.hann_window return torch.hann_window( - window_length=frame_size, - periodic=True, - dtype=dtype, - device=device) + window_length=frame_size, periodic=True, dtype=dtype, device=device + ) else: return torch.hann_window( - window_length=frame_size, - periodic=False, - dtype=dtype, - device=device) + window_length=frame_size, periodic=False, dtype=dtype, device=device + ) self._window_fn = _hanning_window elif p.window_fn.upper() == 'HANNING_GRECO': # Greco-compatible hanning window def f(frame_size, dtype): return _hanning_greco( - self._frame_size - 1, frame_size, dtype, device=device) + self._frame_size - 1, frame_size, dtype, device=device + ) self._window_fn = f else: @@ -430,25 +451,31 @@ def f(frame_size, dtype): def _apply_preemphasis(self, framed_signal): p = self.config if p.preemph_htk_flavor: - return torch.cat([ - framed_signal[:, :, :1, :] * (1. - p.preemph), - (framed_signal[:, :, 1:-1, :] - - p.preemph * framed_signal[:, :, :-2, :]) - ], - dim=2) + return torch.cat( + [ + framed_signal[:, :, :1, :] * (1.0 - p.preemph), + ( + framed_signal[:, :, 1:-1, :] + - p.preemph * framed_signal[:, :, :-2, :] + ), + ], + dim=2, + ) else: - return (framed_signal[:, :, 1:, :] - - p.preemph * framed_signal[:, :, :-1, :]) + return ( + framed_signal[:, :, 1:, :] - p.preemph * framed_signal[:, :, :-1, :] + ) def fprop_paddings(self, input_paddings): p = self.config if p.pad_end: - num_extends = _pad_end_length(input_paddings.shape[1], - self._frame_step, - self._frame_size) + num_extends = _pad_end_length( + input_paddings.shape[1], self._frame_step, self._frame_size + ) input_paddings = F.pad(input_paddings, (0, num_extends), value=1.0) x = input_paddings.unfold( - dimension=1, size=self._frame_size, step=self._frame_step) + dimension=1, size=self._frame_size, step=self._frame_step + ) return x.min(dim=2)[0] def forward(self, inputs, input_paddings): @@ -467,7 +494,8 @@ def forward(self, inputs, input_paddings): pcm_audio_chunk = inputs * self.input_scale_factor framed_signal = frame( - pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end) + pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end + ) if p.preemph != 0.0: preemphasized = self._apply_preemphasis(framed_signal) @@ -497,12 +525,14 @@ def forward(self, inputs, input_paddings): class MelFilterbankFrontend(nn.Module): """Layer to compute log mel spectograms from input audio signals.""" - def __init__(self, - config: PreprocessorConfig = None, - use_divide_stream: bool = True, - per_bin_mean: Optional[float] = None, - per_bin_stddev: Optional[float] = None, - device='cpu'): + def __init__( + self, + config: PreprocessorConfig = None, + use_divide_stream: bool = True, + per_bin_mean: Optional[float] = None, + per_bin_stddev: Optional[float] = None, + device='cpu', + ): super().__init__() self.config = config @@ -513,7 +543,8 @@ def __init__(self, input_scale_factor = 2**-15 if self.use_divide_stream else 1.0 self.stft = SpectrogramFrontend( - p, input_scale_factor=input_scale_factor, output_log=False) + p, input_scale_factor=input_scale_factor, output_log=False + ) if self.per_bin_mean is None: per_bin_mean = [0.0] * p.num_bins @@ -525,10 +556,13 @@ def __init__(self, else: per_bin_stddev = self.per_bin_stddev - self.register_buffer('_normalizer_mean', - torch.FloatTensor(per_bin_mean)[None, None, :, None]) - self.register_buffer('_normalizer_stddev', - torch.FloatTensor(per_bin_stddev)[None, None, :, None]) + self.register_buffer( + '_normalizer_mean', torch.FloatTensor(per_bin_mean)[None, None, :, None] + ) + self.register_buffer( + '_normalizer_stddev', + torch.FloatTensor(per_bin_stddev)[None, None, :, None], + ) def forward(self, inputs, input_paddings): p = self.config @@ -536,20 +570,24 @@ def forward(self, inputs, input_paddings): spect, spect_paddings = self.stft(inputs, input_paddings) mel_weights = linear_to_mel_weight_matrix( - num_mel_bins=p.num_bins, - num_spectrogram_bins=spect.shape[2], - sample_rate=p.sample_rate, - lower_edge_hertz=p.lower_edge_hertz, - upper_edge_hertz=p.upper_edge_hertz, - device=spect.device) + num_mel_bins=p.num_bins, + num_spectrogram_bins=spect.shape[2], + sample_rate=p.sample_rate, + lower_edge_hertz=p.lower_edge_hertz, + upper_edge_hertz=p.upper_edge_hertz, + device=spect.device, + ) mel_spectrogram = torch.einsum('fn,btfc->btnc', mel_weights, spect) logmel_spectrogram = torch.log( - mel_spectrogram.clamp(min=p.output_floor, max=None)) + mel_spectrogram.clamp(min=p.output_floor, max=None) + ) normalized_logmel_spectrogram = ( - (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev) + logmel_spectrogram - self._normalizer_mean + ) / self._normalizer_stddev - normalized_logmel_spectrogram = torch.squeeze(normalized_logmel_spectrogram, - -1) + normalized_logmel_spectrogram = torch.squeeze( + normalized_logmel_spectrogram, -1 + ) return normalized_logmel_spectrogram, spect_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py index 11b93703e..66db657b8 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py @@ -9,19 +9,21 @@ class SpecAug(nn.Module): """Layer performs masking prodecure along time and frequency axis. - The procedure is detailed in https://arxiv.org/abs/1904.08779. - This is an essential component in speech recognition models that - helps achieve better word error rates. - """ - - def __init__(self, - freq_mask_count: int = 1, - freq_mask_max_bins: int = 15, - time_mask_count: int = 1, - time_mask_max_frames: int = 50, - time_mask_max_ratio: float = 1.0, - time_masks_per_frame: float = 0.0, - use_dynamic_time_mask_max_frames: bool = False): + The procedure is detailed in https://arxiv.org/abs/1904.08779. + This is an essential component in speech recognition models that + helps achieve better word error rates. + """ + + def __init__( + self, + freq_mask_count: int = 1, + freq_mask_max_bins: int = 15, + time_mask_count: int = 1, + time_mask_max_frames: int = 50, + time_mask_max_ratio: float = 1.0, + time_masks_per_frame: float = 0.0, + use_dynamic_time_mask_max_frames: bool = False, + ): super().__init__() self.freq_mask_count = freq_mask_count @@ -35,23 +37,26 @@ def __init__(self, def next_prng_key(self, name='dropout'): return self.make_rng(name) - def _get_mask(self, - batch_size, - choose_range, - mask_size, - max_length=None, - masks_per_frame=0.0, - multiplicity=1, - max_ratio=1.0, - device='cpu'): + def _get_mask( + self, + batch_size, + choose_range, + mask_size, + max_length=None, + masks_per_frame=0.0, + multiplicity=1, + max_ratio=1.0, + device='cpu', + ): # Sample lengths for multiple masks. if max_length and max_length > 0: max_length = max_length * torch.ones(batch_size, device=device) else: max_length = choose_range * max_ratio masked_portion = torch.rand(batch_size, multiplicity, device=device) - masked_frame_size = torch.einsum('b,bm->bm', max_length, - masked_portion).long() + masked_frame_size = torch.einsum( + 'b,bm->bm', max_length, masked_portion + ).long() # Make sure the sampled length was smaller than max_ratio * length_bound. # Note that sampling in this way was biased # (shorter sequence may over-masked.) @@ -80,8 +85,9 @@ def _get_mask(self, # Sum masks with appropriate multiplicity. if masks_per_frame > 0: multiplicity_weights = torch.tile( - torch.arange(multiplicity, device=device).long()[None, ...], - [batch_size, 1]) + torch.arange(multiplicity, device=device).long()[None, ...], + [batch_size, 1], + ) multiplicity_tensor = masks_per_frame * choose_range multiplicity_weights = (multiplicity_weights < multiplicity_tensor).long() pre_mask = torch.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) @@ -99,8 +105,9 @@ def _time_mask(self, inputs, length): max_ratio = self.time_mask_max_ratio # If maximum mask length is zero, do nothing. - if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or - max_ratio <= 0.0): + if ( + time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames + ) or max_ratio <= 0.0: return inputs if multiplicity == 0: return inputs @@ -112,14 +119,15 @@ def _time_mask(self, inputs, length): time_mask_max_frames = None # Create masks in time direction and apply. block_arrays = self._get_mask( - batch_size, - choose_range=length, - mask_size=time_length, - max_length=time_mask_max_frames, - masks_per_frame=self.time_masks_per_frame, - multiplicity=multiplicity, - max_ratio=max_ratio, - device=inputs.device) + batch_size, + choose_range=length, + mask_size=time_length, + max_length=time_mask_max_frames, + masks_per_frame=self.time_masks_per_frame, + multiplicity=multiplicity, + max_ratio=max_ratio, + device=inputs.device, + ) outputs = torch.einsum('bxy,bx->bxy', inputs, block_arrays) return outputs @@ -138,14 +146,15 @@ def _frequency_mask(self, inputs): choose_range = num_freq * torch.ones(batch_size, device=inputs.device) # Create masks in frequency direction and apply. block_arrays = self._get_mask( - batch_size, - choose_range=choose_range, - mask_size=num_freq, - max_length=freq_mask_max_bins, - masks_per_frame=0.0, - multiplicity=multiplicity, - max_ratio=1.0, - device=inputs.device) + batch_size, + choose_range=choose_range, + mask_size=num_freq, + max_length=freq_mask_max_bins, + masks_per_frame=0.0, + multiplicity=multiplicity, + max_ratio=1.0, + device=inputs.device, + ) outputs = torch.einsum('bxy,by->bxy', inputs, block_arrays) return outputs diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index dbeabb16c..25416682c 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -10,15 +10,12 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import data_utils -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec import algoperf.random_utils as prng -from algoperf.workloads.librispeech_conformer import metrics -from algoperf.workloads.librispeech_conformer import workload -from algoperf.workloads.librispeech_conformer.input_pipeline import \ - LibriSpeechDataset +from algoperf import data_utils, param_utils, pytorch_utils, spec +from algoperf.workloads.librispeech_conformer import metrics, workload +from algoperf.workloads.librispeech_conformer.input_pipeline import ( + LibriSpeechDataset, +) from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -27,10 +24,9 @@ class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): - - def __init__(self, - tokenizer_vocab_path: Optional[str] = None, - use_specaug: bool = True) -> None: + def __init__( + self, tokenizer_vocab_path: Optional[str] = None, use_specaug: bool = True + ) -> None: super().__init__() self.tokenizer = metrics.load_tokenizer(tokenizer_vocab_path) self.use_specaug = use_specaug @@ -42,7 +38,8 @@ def __init__(self, def eval_num_workers(self) -> int: if self._eval_num_workers is None: raise ValueError( - 'eval_num_workers property must be set before workload is used.') + 'eval_num_workers property must be set before workload is used.' + ) return self._eval_num_workers @eval_num_workers.setter @@ -74,11 +71,13 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: else: activation_function_name = 'swish' model = models.ConformerEncoderDecoder( - models.ConformerConfig( - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name)) + models.ConformerConfig( + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name, + ) + ) self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') models.initialize(model) self._param_shapes = param_utils.pytorch_param_shapes(model) @@ -97,14 +96,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -115,30 +114,33 @@ def model_fn( if mode == spec.ForwardPassMode.TRAIN: model.train() model.apply( - functools.partial( - pytorch_utils.update_batch_norm_fn, - update_batch_norm=update_batch_norm)) + functools.partial( + pytorch_utils.update_batch_norm_fn, + update_batch_norm=update_batch_norm, + ) + ) contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] - logits, logits_paddings = model(inputs.to(DEVICE), - input_paddings.to(DEVICE), - dropout_rate=dropout_rate) + logits, logits_paddings = model( + inputs.to(DEVICE), input_paddings.to(DEVICE), dropout_rate=dropout_rate + ) return (logits, logits_paddings), None def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del cache del repeat_final_dataset del num_batches @@ -157,7 +159,7 @@ def _build_input_queue( if split == 'eval_train': indices = list(range(len(ds))) random.Random(int(data_rng[0])).shuffle(indices) - ds = torch.utils.data.Subset(ds, indices[:self.num_eval_train_examples]) + ds = torch.utils.data.Subset(ds, indices[: self.num_eval_train_examples]) sampler = None if USE_PYTORCH_DDP: @@ -168,31 +170,36 @@ def _build_input_queue( if USE_PYTORCH_DDP: if is_train: sampler = torch.utils.data.distributed.DistributedSampler( - ds, num_replicas=N_GPUS, rank=RANK, shuffle=True) + ds, num_replicas=N_GPUS, rank=RANK, shuffle=True + ) else: sampler = data_utils.DistributedEvalSampler( - ds, num_replicas=N_GPUS, rank=RANK, shuffle=False) + ds, num_replicas=N_GPUS, rank=RANK, shuffle=False + ) dataloader = torch.utils.data.DataLoader( - ds, - batch_size=ds_iter_batch_size, - shuffle=not USE_PYTORCH_DDP and is_train, - sampler=sampler, - num_workers=4, - pin_memory=True, - drop_last=is_train) + ds, + batch_size=ds_iter_batch_size, + shuffle=not USE_PYTORCH_DDP and is_train, + sampler=sampler, + num_workers=4, + pin_memory=True, + drop_last=is_train, + ) dataloader = data_utils.cycle( - dataloader, custom_sampler=USE_PYTORCH_DDP, use_mixup=False) + dataloader, custom_sampler=USE_PYTORCH_DDP, use_mixup=False + ) return dataloader # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) - logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) + logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -206,10 +213,8 @@ def loss_fn( input_lengths = torch.einsum('bh->b', 1 - logit_paddings).long() target_lengths = torch.einsum('bh->b', 1 - target_paddings).long() per_example_losses = self.ctc_loss( - logprobs.permute(1, 0, 2), - targets.long(), - input_lengths, - target_lengths) + logprobs.permute(1, 0, 2), targets.long(), input_lengths, target_lengths + ) # mask_batch is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -220,21 +225,22 @@ def loss_fn( summed_loss = per_example_losses.sum() n_valid_examples = max(n_valid_examples, 1) return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } def greedy_decode( - self, logits: spec.Tensor, - logit_paddings: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]: + self, logits: spec.Tensor, logit_paddings: spec.Tensor + ) -> Tuple[spec.Tensor, spec.Tensor]: framewise_tokens = logits.max(dim=-1)[1] framewise_tokens = framewise_tokens * (1 - logit_paddings) # Add sentinel because unique_consecutive will flatten array # and then compute the unique. framewise_tokens = torch.cat( - [framewise_tokens, -torch.ones_like(framewise_tokens[:, 0:1])], dim=1) + [framewise_tokens, -torch.ones_like(framewise_tokens[:, 0:1])], dim=1 + ) _, indices = torch.unique_consecutive(framewise_tokens, return_inverse=True) indices -= indices.min(dim=1, keepdims=True)[0] result = torch.zeros_like(framewise_tokens) @@ -247,11 +253,12 @@ def greedy_decode( # Remove blanks (id = 0). blank_id = 0 fin_result = torch.zeros_like(result) - idxs = torch.arange( - fin_result.numel(), device=result.device).view(*fin_result.shape) - mask = torch.arange( - fin_result.shape[1], device=result.device).view( - 1, -1) < result.count_nonzero(dim=1).view(-1, 1) + idxs = torch.arange(fin_result.numel(), device=result.device).view( + *fin_result.shape + ) + mask = torch.arange(fin_result.shape[1], device=result.device).view( + 1, -1 + ) < result.count_nonzero(dim=1).view(-1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -265,29 +272,31 @@ def sync_sd(self, params: spec.ParameterContainer) -> None: sd[k] = sd[k] / N_GPUS params.load_state_dict(sd) - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) if split not in self._eval_iters: # These iterators repeat indefinitely. - self._eval_iters[split] = ( - self._build_input_queue( - data_rng, split, data_dir, global_batch_size=global_batch_size)) + self._eval_iters[split] = self._build_input_queue( + data_rng, split, data_dir, global_batch_size=global_batch_size + ) total_metrics = { - 'loss': torch.tensor(0., device=DEVICE), - 'lengths': torch.tensor(0., device=DEVICE), - 'word_errors': torch.tensor(0., device=DEVICE), - 'num_words': torch.tensor(0., device=DEVICE), + 'loss': torch.tensor(0.0, device=DEVICE), + 'lengths': torch.tensor(0.0, device=DEVICE), + 'word_errors': torch.tensor(0.0, device=DEVICE), + 'num_words': torch.tensor(0.0, device=DEVICE), } num_batches = int(math.ceil(num_examples / global_batch_size)) if self.requires_sync_before_eval: @@ -296,48 +305,50 @@ def _eval_model_on_split(self, batch = next(self._eval_iters[split]) (logits, logits_padding), _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - model_rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + model_rng, + update_batch_norm=False, + ) decoded, decoded_paddings = self.greedy_decode(logits, logits_padding) targets, target_paddings = batch['targets'] word_errors, num_words = metrics.compute_wer( - decoded=decoded.cpu().numpy(), - decoded_paddings=decoded_paddings.cpu().numpy(), - targets=targets.cpu().numpy(), - target_paddings=target_paddings.cpu().numpy(), - tokenizer=self.tokenizer) + decoded=decoded.cpu().numpy(), + decoded_paddings=decoded_paddings.cpu().numpy(), + targets=targets.cpu().numpy(), + target_paddings=target_paddings.cpu().numpy(), + tokenizer=self.tokenizer, + ) loss = self.loss_fn((targets, target_paddings), (logits, logits_padding)) summed_loss = loss['summed'] lengths = loss['n_valid_examples'] batch_metrics = { - 'loss': summed_loss, - 'lengths': lengths, - 'word_errors': word_errors, - 'num_words': num_words, + 'loss': summed_loss, + 'lengths': lengths, + 'word_errors': word_errors, + 'num_words': num_words, } total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) return { - 'ctc_loss': - float(total_metrics['loss'].item() / - total_metrics['lengths'].item()), - 'wer': - float(total_metrics['word_errors'].item() / - total_metrics['num_words'].item()), + 'ctc_loss': float( + total_metrics['loss'].item() / total_metrics['lengths'].item() + ), + 'wer': float( + total_metrics['word_errors'].item() / total_metrics['num_words'].item() + ), } class LibriSpeechConformerAttentionTemperatureWorkload( - LibriSpeechConformerWorkload): - + LibriSpeechConformerWorkload +): @property def attention_temperature(self) -> float: return 1.6 @@ -352,7 +363,6 @@ def test_target_value(self) -> float: class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): - @property def use_post_layer_norm(self) -> bool: return False @@ -367,7 +377,6 @@ def test_target_value(self) -> float: class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): - @property def use_gelu(self) -> bool: return True diff --git a/algoperf/workloads/librispeech_conformer/metrics.py b/algoperf/workloads/librispeech_conformer/metrics.py index de74cfe1b..d5c826575 100644 --- a/algoperf/workloads/librispeech_conformer/metrics.py +++ b/algoperf/workloads/librispeech_conformer/metrics.py @@ -15,17 +15,20 @@ def average_ctc_loss(): @flax.struct.dataclass class _Metric(metrics.Metric): """Applies `fun` and computes the average.""" + total: np.float32 weight: np.float32 @classmethod def from_model_output(cls, loss_dict, **_): return cls( - total=loss_dict['summed'], weight=loss_dict['n_valid_examples']) + total=loss_dict['summed'], weight=loss_dict['n_valid_examples'] + ) def merge(self, other): return type(self)( - total=self.total + other.total, weight=self.weight + other.weight) + total=self.total + other.total, weight=self.weight + other.weight + ) def compute(self): return self.total / self.weight @@ -74,9 +77,10 @@ def edit_distance(source, target): # possibilities and find minimum. else: distance[i][j] = 1 + min( - distance[i][j - 1], # Insert - distance[i - 1][j], # Remove - distance[i - 1][j - 1]) # Replace + distance[i][j - 1], # Insert + distance[i - 1][j], # Remove + distance[i - 1][j - 1], + ) # Replace return distance[num_source_words][num_target_words] @@ -109,17 +113,20 @@ def compute_wer(decoded, decoded_paddings, targets, target_paddings, tokenizer): return word_errors, num_words -def load_tokenizer(model_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False): +def load_tokenizer( + model_path: str, + add_bos: bool = False, + add_eos: bool = True, + reverse: bool = False, +): """Load a tf-text SentencePiece tokenizer from given model filepath.""" if model_path is None: return None with gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse) + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + ) return sp_tokenizer @@ -128,8 +135,10 @@ def wer(tokenizer_vocab_path): @flax.struct.dataclass class WER( - metrics.CollectingMetric.from_outputs( - ('decoded', 'decoded_paddings', 'targets', 'target_paddings'))): + metrics.CollectingMetric.from_outputs( + ('decoded', 'decoded_paddings', 'targets', 'target_paddings') + ) + ): """Computes the mean average precision for a binary classifier on CPU.""" def compute(self): @@ -144,7 +153,8 @@ def compute(self): values['decoded_paddings'], values['targets'].astype(np.int32), values['target_paddings'], - tokenizer) + tokenizer, + ) return word_errors / num_words @@ -153,4 +163,5 @@ def compute(self): def get_metrics_bundle(tokenizer_vocab_path): return metrics.Collection.create( - ctc_loss=average_ctc_loss(), wer=wer(tokenizer_vocab_path)) + ctc_loss=average_ctc_loss(), wer=wer(tokenizer_vocab_path) + ) diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 94f01dd97..791270719 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -5,7 +5,6 @@ class BaseLibrispeechWorkload(spec.Workload): - _num_outputs: int = 1024 @property @@ -25,8 +24,9 @@ def use_gelu(self) -> bool: def attention_temperature(self) -> float: raise NotImplementedError - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/wer'] < self.validation_target_value @property @@ -53,8 +53,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index 262fc1a95..225852b28 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -10,17 +10,19 @@ from typing import Any, Dict, List, Optional, Tuple, Union +import jax +import jax.numpy as jnp from flax import linen as nn from flax import struct -import jax from jax.experimental import rnn -import jax.numpy as jnp from algoperf.jax_utils import Dropout -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - librispeech_preprocessor as preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - spectrum_augmenter +from algoperf.workloads.librispeech_conformer.librispeech_jax import ( + librispeech_preprocessor as preprocessor, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax import ( + spectrum_augmenter, +) Array = jnp.ndarray StateType = Union[Array, Tuple[Array, ...]] @@ -37,6 +39,7 @@ @struct.dataclass class DeepspeechConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 dtype: Any = jnp.float32 encoder_dim: int = 512 @@ -68,6 +71,7 @@ class Subsample(nn.Module): encoder_dim: model dimension of conformer. input_dropout_rate: dropout rate for inputs. """ + config: DeepspeechConfig @nn.compact @@ -76,38 +80,40 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE): outputs = jnp.expand_dims(inputs, axis=-1) outputs, output_paddings = Conv2dSubsampling( - encoder_dim=config.encoder_dim, - dtype=config.dtype, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon, - input_channels=1, - output_channels=config.encoder_dim, - use_tanh=config.use_tanh - )(outputs, output_paddings, train) + encoder_dim=config.encoder_dim, + dtype=config.dtype, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon, + input_channels=1, + output_channels=config.encoder_dim, + use_tanh=config.use_tanh, + )(outputs, output_paddings, train) outputs, output_paddings = Conv2dSubsampling( - encoder_dim=config.encoder_dim, - dtype=config.dtype, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon, - input_channels=config.encoder_dim, - output_channels=config.encoder_dim, - use_tanh=config.use_tanh)(outputs, output_paddings, train) + encoder_dim=config.encoder_dim, + dtype=config.dtype, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon, + input_channels=config.encoder_dim, + output_channels=config.encoder_dim, + use_tanh=config.use_tanh, + )(outputs, output_paddings, train) batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape outputs = jnp.reshape( - outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) + outputs, (batch_size, subsampled_lengths, subsampled_dims * channels) + ) outputs = nn.Dense( - config.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) + config.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) - outputs = Dropout( - rate=dropout_rate, deterministic=not train)( - outputs, rate=dropout_rate) + outputs = Dropout(rate=dropout_rate, deterministic=not train)( + outputs, rate=dropout_rate + ) return outputs, output_paddings @@ -119,6 +125,7 @@ class Conv2dSubsampling(nn.Module): 2) Also performs strided convolution over input_paddings to return the correct paddings for downstream layers. """ + input_channels: int = 0 output_channels: int = 0 filter_stride: List[int] = (2, 2) @@ -131,24 +138,26 @@ class Conv2dSubsampling(nn.Module): def setup(self): self.filter_shape = (3, 3, self.input_channels, self.output_channels) - self.kernel = self.param('kernel', - nn.initializers.xavier_uniform(), - self.filter_shape) + self.kernel = self.param( + 'kernel', nn.initializers.xavier_uniform(), self.filter_shape + ) self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels + ) @nn.compact def __call__(self, inputs, paddings, train): # Computing strided convolution to subsample inputs. feature_group_count = inputs.shape[3] // self.filter_shape[2] outputs = jax.lax.conv_general_dilated( - lhs=inputs, - rhs=self.kernel, - window_strides=self.filter_stride, - padding=self.padding, - rhs_dilation=(1, 1), - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - feature_group_count=feature_group_count) + lhs=inputs, + rhs=self.kernel, + window_strides=self.filter_stride, + padding=self.padding, + rhs_dilation=(1, 1), + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + feature_group_count=feature_group_count, + ) outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) @@ -163,48 +172,49 @@ def __call__(self, inputs, paddings, train): pad_len = (input_length + stride - 1) // stride * stride - input_length out_padding = jax.lax.conv_general_dilated( - lhs=paddings[:, :, None], - rhs=jnp.ones([1, 1, 1]), - window_strides=self.filter_stride[:1], - padding=[(0, pad_len)], - dimension_numbers=('NHC', 'HIO', 'NHC')) + lhs=paddings[:, :, None], + rhs=jnp.ones([1, 1, 1]), + window_strides=self.filter_stride[:1], + padding=[(0, pad_len)], + dimension_numbers=('NHC', 'HIO', 'NHC'), + ) out_padding = jnp.squeeze(out_padding, axis=-1) # Mask outputs by correct paddings to ensure padded elements in inputs map # to padded value in outputs. - outputs = outputs * (1.0 - - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) + outputs = outputs * ( + 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1) + ) return outputs, out_padding class FeedForwardModule(nn.Module): """Feedforward block of conformer layer.""" + config: DeepspeechConfig @nn.compact - def __call__(self, - inputs, - input_paddings=None, - train=False, - dropout_rate=DROPOUT_RATE): + def __call__( + self, inputs, input_paddings=None, train=False, dropout_rate=DROPOUT_RATE + ): padding_mask = jnp.expand_dims(1 - input_paddings, -1) config = self.config if config.layernorm_everywhere: inputs = LayerNorm(config.encoder_dim)(inputs) else: - inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, - input_paddings, - train) - inputs = nn.Dense( + inputs = BatchNorm( config.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon, + )(inputs, input_paddings, train) + inputs = nn.Dense( + config.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) if config.use_tanh: inputs = nn.tanh(inputs) else: @@ -212,7 +222,8 @@ def __call__(self, inputs *= padding_mask inputs = Dropout(rate=dropout_rate)( - inputs, deterministic=not train, rate=dropout_rate) + inputs, deterministic=not train, rate=dropout_rate + ) return inputs @@ -227,6 +238,7 @@ class LayerNorm(nn.Module): zeros, this differs from default flax implementation of multiplying by scale and initializing to ones. """ + dim: int = 0 epsilon: float = 1e-6 @@ -240,7 +252,7 @@ def __call__(self, inputs): var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True) normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) - normed_inputs *= (1 + self.scale) + normed_inputs *= 1 + self.scale normed_inputs += self.bias return normed_inputs @@ -258,6 +270,7 @@ class BatchNorm(nn.Module): and the corresponding defaults for momentum and epsilon have been copied over from lingvo. """ + encoder_dim: int = 0 dtype: Any = jnp.float32 batch_norm_momentum: float = 0.999 @@ -267,14 +280,12 @@ def setup(self): dim = self.encoder_dim dtype = self.dtype - self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), - dim) - self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), - dim) + self.ra_mean = self.variable( + 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim + ) + self.ra_var = self.variable( + 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim + ) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @@ -303,7 +314,8 @@ def __call__(self, inputs, input_paddings=None, train=False): mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( - jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) + jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True + ) sum_v = jax.lax.psum(sum_v, axis_name='batch') count_v = jax.lax.psum(count_v, axis_name='batch') @@ -340,15 +352,14 @@ class CudnnLSTM(nn.Module): @nn.compact def __call__( - self, - inputs: Array, - segmentation_mask: Optional[Array] = None, - return_carry: Optional[bool] = None, - deterministic: bool = False, - initial_states: Optional[Tuple[Array, Array]] = None, - use_cuda: bool = True, + self, + inputs: Array, + segmentation_mask: Optional[Array] = None, + return_carry: Optional[bool] = None, + deterministic: bool = False, + initial_states: Optional[Tuple[Array, Array]] = None, + use_cuda: bool = True, ) -> Union[Array, Tuple[Array, Carry]]: - if jax.devices()[0].platform != 'gpu': use_cuda = False @@ -358,22 +369,22 @@ def __call__( dropout = 0.0 if deterministic else self.dropout_rate weights = self.param( - 'weights', - rnn.init_lstm_weight, - input_size, - self.features, - self.num_layers, - self.bidirectional, + 'weights', + rnn.init_lstm_weight, + input_size, + self.features, + self.num_layers, + self.bidirectional, ) if initial_states is None: h_0 = jnp.zeros( - (num_directions * self.num_layers, batch_size, self.features), - jnp.float32, + (num_directions * self.num_layers, batch_size, self.features), + jnp.float32, ) c_0 = jnp.zeros( - (num_directions * self.num_layers, batch_size, self.features), - jnp.float32, + (num_directions * self.num_layers, batch_size, self.features), + jnp.float32, ) else: h_0, c_0 = initial_states @@ -385,20 +396,35 @@ def __call__( if use_cuda: y, h, c = rnn.lstm( - x=inputs, h_0=h_0, c_0=c_0, weights=weights, - seq_lengths=seq_lengths, input_size=input_size, - hidden_size=self.features, num_layers=self.num_layers, - dropout=dropout, bidirectional=self.bidirectional, + x=inputs, + h_0=h_0, + c_0=c_0, + weights=weights, + seq_lengths=seq_lengths, + input_size=input_size, + hidden_size=self.features, + num_layers=self.num_layers, + dropout=dropout, + bidirectional=self.bidirectional, ) else: weight_ih, weight_hh, bias_ih, bias_hh = self.unpack_weights( - weights, input_size) + weights, input_size + ) y, h, c = rnn.lstm_ref( - x=inputs, h_0=h_0, c_0=c_0, W_ih=weight_ih, W_hh=weight_hh, - b_ih=bias_ih, b_hh=bias_hh, seq_lengths=seq_lengths, - input_size=input_size, hidden_size=self.features, - num_layers=self.num_layers, dropout=dropout, - bidirectional=self.bidirectional, + x=inputs, + h_0=h_0, + c_0=c_0, + W_ih=weight_ih, + W_hh=weight_hh, + b_ih=bias_ih, + b_hh=bias_hh, + seq_lengths=seq_lengths, + input_size=input_size, + hidden_size=self.features, + num_layers=self.num_layers, + dropout=dropout, + bidirectional=self.bidirectional, ) if return_carry: @@ -408,21 +434,22 @@ def __call__( @nn.nowrap def unpack_weights( - self, weights: Array, input_size: int + self, weights: Array, input_size: int ) -> Tuple[ - Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]]: + Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array] + ]: return jax.experimental.rnn.unpack_lstm_weights( - weights, - input_size, - self.features, - self.num_layers, - self.bidirectional, + weights, + input_size, + self.features, + self.num_layers, + self.bidirectional, ) class BatchRNN(nn.Module): - """Implements a single deepspeech encoder layer. - """ + """Implements a single deepspeech encoder layer.""" + config: DeepspeechConfig @nn.compact @@ -432,16 +459,17 @@ def __call__(self, inputs, input_paddings, train): if config.layernorm_everywhere: inputs = LayerNorm(config.encoder_dim)(inputs) else: - inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, - input_paddings, - train) + inputs = BatchNorm( + config.encoder_dim, + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon, + )(inputs, input_paddings, train) output = CudnnLSTM( - features=config.encoder_dim // 2, - bidirectional=config.bidirectional, - num_layers=1)(inputs, input_paddings) + features=config.encoder_dim // 2, + bidirectional=config.bidirectional, + num_layers=1, + )(inputs, input_paddings) return output @@ -453,18 +481,19 @@ class Deepspeech(nn.Module): for each time step. The output is then fed into a CTC loss which eliminates the need for alignment with targets. """ + config: DeepspeechConfig def setup(self): config = self.config self.specaug = spectrum_augmenter.SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames, ) @nn.compact @@ -477,10 +506,10 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): # Compute normalized log mel spectrograms from input audio signal. preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() outputs, output_paddings = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(outputs, - output_paddings) + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + )(outputs, output_paddings) # Ablate random parts of input along temporal and frequency dimension # following the specaug procedure in https://arxiv.org/abs/1904.08779. @@ -488,9 +517,9 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): outputs, output_paddings = self.specaug(outputs, output_paddings) # Subsample input by a factor of 4 by performing strided convolutions. - outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train, - dropout_rate=dropout_rate) + outputs, output_paddings = Subsample(config=config)( + outputs, output_paddings, train, dropout_rate=dropout_rate + ) # Run the lstm layers. for _ in range(config.num_lstm_layers): @@ -502,19 +531,21 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): for _ in range(config.num_ffn_layers): if config.enable_residual_connections: outputs = outputs + FeedForwardModule(config=self.config)( - outputs, output_paddings, train) + outputs, output_paddings, train + ) else: outputs = FeedForwardModule(config=self.config)( - outputs, output_paddings, train, dropout_rate=dropout_rate) + outputs, output_paddings, train, dropout_rate=dropout_rate + ) # Run the decoder which in this case is a trivial projection layer. if config.enable_decoder_layer_norm: outputs = LayerNorm(config.encoder_dim)(outputs) outputs = nn.Dense( - config.vocab_size, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) + config.vocab_size, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 81a56db72..b93934abf 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,31 +1,29 @@ import functools from typing import Dict, Optional, Tuple -from flax import jax_utils import jax import jax.numpy as jnp import numpy as np +from flax import jax_utils -from algoperf import param_utils -from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerWorkload +from algoperf import param_utils, spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerWorkload, +) from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): - def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: - """Deepspeech model init function. - """ + """Deepspeech model init function.""" model_config = models.DeepspeechConfig( - use_specaug=self.use_specaug, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count, + use_specaug=self.use_specaug, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, ) self._model = models.Deepspeech(model_config) input_shape = [(320000,), (320000,)] @@ -34,12 +32,16 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, _ = jax.random.split(rng, 2) - variables = model_init_fn({ + variables = model_init_fn( + { 'params': params_rng, - }, *fake_input_batch) + }, + *fake_input_batch, + ) - model_state = variables[ - 'batch_stats'] if not self.layernorm_everywhere else {} + model_state = ( + variables['batch_stats'] if not self.layernorm_everywhere else {} + ) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -48,36 +50,34 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: return params, model_state def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None, - dropout_rate: Optional[bool] = models.DROPOUT_RATE, + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[bool] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( - variables, - inputs, - input_paddings, - train=True, - rngs={'dropout' : rng}, - mutable=['batch_stats'], - dropout_rate=dropout_rate) + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout': rng}, + mutable=['batch_stats'], + dropout_rate=dropout_rate, + ) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( - variables, - inputs, - input_paddings, - train=False, - mutable=False) + variables, inputs, input_paddings, train=False, mutable=False + ) return (logits, logit_paddings), model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -126,7 +126,6 @@ def time_mask_count(self) -> int: class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload): - @property def use_tanh(self) -> bool: return True @@ -141,7 +140,6 @@ def test_target_value(self) -> float: class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload): - @property def enable_residual_connections(self) -> bool: return False @@ -155,9 +153,9 @@ def test_target_value(self) -> float: return 0.079297 -class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload - ): - +class LibriSpeechDeepSpeechNormAndSpecAugWorkload( + LibriSpeechDeepSpeechWorkload +): @property def eval_batch_size(self) -> int: return 128 diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index 3d8c000e1..ddb7b5c37 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -11,10 +11,12 @@ import torch.distributed.nn as dist_nn import torch.nn.functional as F -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ - SpecAug +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import ( + preprocessor, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import ( + SpecAug, +) USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ DROPOUT_RATE = 0.1 @@ -23,6 +25,7 @@ @dataclass class DeepspeechConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 encoder_dim: int = 512 num_lstm_layers: int = 6 @@ -47,7 +50,6 @@ class DeepspeechConfig: class LayerNorm(nn.Module): - def __init__(self, dim, epsilon=1e-6): super().__init__() self.dim = dim @@ -61,14 +63,13 @@ def forward(self, x): var = x.var(dim=-1, unbiased=False, keepdims=True) normed_x = (x - mean) * torch.rsqrt(var + self.epsilon) - normed_x *= (1 + self.scale) + normed_x *= 1 + self.scale normed_x += self.bias return normed_x class Subsample(nn.Module): - def __init__(self, config: DeepspeechConfig): super().__init__() encoder_dim = config.encoder_dim @@ -76,11 +77,13 @@ def __init__(self, config: DeepspeechConfig): self.encoder_dim = encoder_dim self.conv1 = Conv2dSubsampling( - input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh) + input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh + ) self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, - output_channels=encoder_dim, - use_tanh=config.use_tanh) + input_channels=encoder_dim, + output_channels=encoder_dim, + use_tanh=config.use_tanh, + ) self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) @@ -93,9 +96,9 @@ def forward(self, inputs, input_paddings, dropout_rate): outputs, output_paddings = self.conv2(outputs, output_paddings) batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape - outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, - subsampled_lengths, - subsampled_dims * channels) + outputs = outputs.permute(0, 2, 3, 1).reshape( + batch_size, subsampled_lengths, subsampled_dims * channels + ) outputs = self.lin(outputs) outputs = F.dropout(outputs, dropout_rate, training=self.training) @@ -104,15 +107,16 @@ def forward(self, inputs, input_paddings, dropout_rate): class Conv2dSubsampling(nn.Module): - - def __init__(self, - input_channels: int, - output_channels: int, - filter_stride: Tuple[int] = (2, 2), - padding: str = 'SAME', - batch_norm_momentum: float = 0.999, - batch_norm_epsilon: float = 0.001, - use_tanh: bool = False): + def __init__( + self, + input_channels: int, + output_channels: int, + filter_stride: Tuple[int] = (2, 2), + padding: str = 'SAME', + batch_norm_momentum: float = 0.999, + batch_norm_epsilon: float = 0.001, + use_tanh: bool = False, + ): super().__init__() self.input_channels = input_channels @@ -123,7 +127,8 @@ def __init__(self, self.filter_shape = (output_channels, input_channels, 3, 3) self.kernel = nn.Parameter( - nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) + nn.init.xavier_uniform_(torch.empty(*self.filter_shape)) + ) self.bias = nn.Parameter(torch.zeros(output_channels)) self.use_tanh = use_tanh @@ -154,12 +159,13 @@ def forward(self, inputs, paddings): else: in_ = inputs outputs = F.conv2d( - input=in_, - weight=self.kernel, - bias=self.bias, - stride=self.filter_stride, - dilation=(1, 1), - groups=groups) + input=in_, + weight=self.kernel, + bias=self.bias, + stride=self.filter_stride, + dilation=(1, 1), + groups=groups, + ) if self.use_tanh: outputs = F.tanh(outputs) @@ -170,21 +176,24 @@ def forward(self, inputs, paddings): stride = self.filter_stride[0] pad_len = (input_length + stride - 1) // stride * stride - input_length out_padding = F.conv1d( - input=torch.cat([ - paddings[:, None, :], - torch.zeros( - size=(paddings.shape[0], 1, pad_len), device=paddings.device) + input=torch.cat( + [ + paddings[:, None, :], + torch.zeros( + size=(paddings.shape[0], 1, pad_len), device=paddings.device + ), ], - dim=2), - weight=torch.ones([1, 1, 1], device=paddings.device), - stride=self.filter_stride[:1]) + dim=2, + ), + weight=torch.ones([1, 1, 1], device=paddings.device), + stride=self.filter_stride[:1], + ) out_padding = out_padding.squeeze(dim=1) outputs = outputs * (1 - out_padding[:, None, :, None]) return outputs, out_padding class FeedForwardModule(nn.Module): - def __init__(self, config: DeepspeechConfig): super().__init__() self.config = config @@ -193,9 +202,10 @@ def __init__(self, config: DeepspeechConfig): self.normalization_layer = LayerNorm(config.encoder_dim) else: self.bn_normalization_layer = BatchNorm( - dim=config.encoder_dim, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon) + dim=config.encoder_dim, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon, + ) self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) def forward(self, inputs, input_paddings, dropout_rate): @@ -220,7 +230,6 @@ def forward(self, inputs, input_paddings, dropout_rate): class BatchNorm(nn.Module): - def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon): super().__init__() running_mean = torch.zeros(dim) @@ -235,8 +244,8 @@ def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon): self.dim = dim def forward(self, inputs, input_paddings): - #inputs: NHD - #padding: NH + # inputs: NHD + # padding: NH mask = 1 - input_paddings[:, :, None] if self.training: count = mask.sum() @@ -253,9 +262,11 @@ def forward(self, inputs, input_paddings): var = sum_ / count self.running_mean = (1 - self.momentum) * self.running_mean + ( - self.momentum) * mean.detach() + self.momentum + ) * mean.detach() self.running_var = (1 - self.momentum) * self.running_var + ( - self.momentum) * var.detach() + self.momentum + ) * var.detach() else: mean = self.running_mean var = self.running_var @@ -266,7 +277,6 @@ def forward(self, inputs, input_paddings): class BatchRNN(nn.Module): - def __init__(self, config: DeepspeechConfig): super().__init__() self.config = config @@ -278,19 +288,23 @@ def __init__(self, config: DeepspeechConfig): if config.layernorm_everywhere: self.normalization_layer = LayerNorm(config.encoder_dim) else: - self.bn_normalization_layer = BatchNorm(config.encoder_dim, - config.batch_norm_momentum, - config.batch_norm_epsilon) + self.bn_normalization_layer = BatchNorm( + config.encoder_dim, + config.batch_norm_momentum, + config.batch_norm_epsilon, + ) if bidirectional: self.lstm = nn.LSTM( - input_size=input_size, - hidden_size=hidden_size // 2, - bidirectional=True, - batch_first=True) + input_size=input_size, + hidden_size=hidden_size // 2, + bidirectional=True, + batch_first=True, + ) else: self.lstm = nn.LSTM( - input_size=input_size, hidden_size=hidden_size, batch_first=True) + input_size=input_size, hidden_size=hidden_size, batch_first=True + ) def forward(self, inputs, input_paddings): if self.config.layernorm_everywhere: @@ -299,50 +313,59 @@ def forward(self, inputs, input_paddings): inputs = self.bn_normalization_layer(inputs, input_paddings) lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( - inputs, lengths, batch_first=True, enforce_sorted=False) + inputs, lengths, batch_first=True, enforce_sorted=False + ) packed_outputs, _ = self.lstm(packed_inputs) outputs, _ = torch.nn.utils.rnn.pad_packed_sequence( - packed_outputs, batch_first=True) + packed_outputs, batch_first=True + ) if outputs.shape[1] < inputs.shape[1]: - outputs = torch.cat([ + outputs = torch.cat( + [ outputs, torch.zeros( - size=(outputs.shape[0], - inputs.shape[1] - outputs.shape[1], - outputs.shape[2]), - device=outputs.device) - ], - dim=1) + size=( + outputs.shape[0], + inputs.shape[1] - outputs.shape[1], + outputs.shape[2], + ), + device=outputs.device, + ), + ], + dim=1, + ) return outputs class DeepspeechEncoderDecoder(nn.Module): - def __init__(self, config: DeepspeechConfig): super().__init__() self.config = config self.specaug = SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames, ) preprocessing_config = preprocessor.PreprocessorConfig() self.preprocessor = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + ) self.subsample = Subsample(config=config) self.lstms = nn.ModuleList( - [BatchRNN(config) for _ in range(config.num_lstm_layers)]) + [BatchRNN(config) for _ in range(config.num_lstm_layers)] + ) self.ffns = nn.ModuleList( - [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]) + [FeedForwardModule(config) for _ in range(config.num_ffn_layers)] + ) if config.enable_decoder_layer_norm: self.ln = LayerNorm(config.encoder_dim) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 0b9ce1e3c..672f3440f 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -3,18 +3,19 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - initialize -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechConfig -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechEncoderDecoder +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + initialize, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( + DeepspeechConfig, + DeepspeechEncoderDecoder, +) USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -22,19 +23,20 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): - def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Deepspeech model init function.""" torch.random.manual_seed(rng[0]) model = DeepspeechEncoderDecoder( - DeepspeechConfig( - use_specaug=self.use_specaug, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count)).eval() + DeepspeechConfig( + use_specaug=self.use_specaug, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, + ) + ).eval() self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') # Run model once to initialize lazy layers. t = MAX_INPUT_LENGTH @@ -55,24 +57,26 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: return model, None def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: # override super method, changing only the default dropout_rate # pylint: disable=useless-parent-delegation - return super().model_fn(params, - augmented_and_preprocessed_input_batch, - model_state, - mode, - rng, - update_batch_norm, - dropout_rate) + return super().model_fn( + params, + augmented_and_preprocessed_input_batch, + model_state, + mode, + rng, + update_batch_norm, + dropout_rate, + ) def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] @@ -120,7 +124,6 @@ def time_mask_count(self) -> int: class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload): - @property def use_tanh(self) -> bool: return True @@ -135,7 +138,6 @@ def test_target_value(self) -> float: class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload): - @property def enable_residual_connections(self) -> bool: return False @@ -149,9 +151,9 @@ def test_target_value(self) -> float: return 0.079297 -class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload - ): - +class LibriSpeechDeepSpeechNormAndSpecAugWorkload( + LibriSpeechDeepSpeechWorkload +): @property def eval_batch_size(self) -> int: return 128 diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 27bd9ae54..0d192cbf5 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -3,20 +3,18 @@ import functools from typing import Any, Dict, Optional, Tuple -from flax import jax_utils -from flax import linen as nn import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from flax import linen as nn +from jax import lax -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.mnist.workload import BaseMnistWorkload class _Model(nn.Module): - @nn.compact def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor: del train @@ -31,12 +29,12 @@ def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor: class MnistWorkload(BaseMnistWorkload): - def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_val = jnp.ones((1, 28, 28, 1), jnp.float32) self._model = _Model() - initial_params = self._model.init({'params': rng}, init_val, - train=True)['params'] + initial_params = self._model.init({'params': rng}, init_val, train=True)[ + 'params' + ] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) return jax_utils.replicate(initial_params), None @@ -45,31 +43,34 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_1' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN logits_batch = self._model.apply( - {'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - train=train) + {'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + train=train, + ) return logits_batch, None # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -79,7 +80,8 @@ def loss_fn( one_hot_targets = jax.nn.one_hot(label_batch, 10) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( - smoothed_targets * nn.log_softmax(logits_batch), axis=-1) + smoothed_targets * nn.log_softmax(logits_batch), axis=-1 + ) # `mask_batch` is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -88,41 +90,45 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, None), + static_broadcasted_argnums=(0,), + ) def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) weights = batch.get('weights') if weights is None: weights = jnp.ones(len(logits)) accuracy = jnp.sum( - (jnp.argmax(logits, axis=-1) == batch['targets']) * weights) + (jnp.argmax(logits, axis=-1) == batch['targets']) * weights + ) summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed'] metrics = {'accuracy': accuracy, 'loss': summed_loss} metrics = lax.psum(metrics, axis_name='batch') return metrics def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algoperf/workloads/mnist/mnist_pytorch/workload.py b/algoperf/workloads/mnist/mnist_pytorch/workload.py index 780e1bca0..ca861d551 100644 --- a/algoperf/workloads/mnist/mnist_pytorch/workload.py +++ b/algoperf/workloads/mnist/mnist_pytorch/workload.py @@ -20,18 +20,20 @@ class _Model(nn.Module): - def __init__(self) -> None: super().__init__() input_size = 28 * 28 num_hidden = 128 num_classes = 10 self.net = nn.Sequential( - OrderedDict([('layer1', - torch.nn.Linear(input_size, num_hidden, bias=True)), - ('layer1_sig', torch.nn.Sigmoid()), - ('layer2', - torch.nn.Linear(num_hidden, num_classes, bias=True))])) + OrderedDict( + [ + ('layer1', torch.nn.Linear(input_size, num_hidden, bias=True)), + ('layer1_sig', torch.nn.Sigmoid()), + ('layer2', torch.nn.Linear(num_hidden, num_classes, bias=True)), + ] + ) + ) def reset_parameters(self) -> None: for m in self.net.modules(): @@ -44,16 +46,16 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class MnistWorkload(BaseMnistWorkload): - def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del cache if N_GPUS != 0: per_device_batch_size = int(global_batch_size / N_GPUS) @@ -63,22 +65,27 @@ def _build_input_queue( # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. if RANK == 0: - np_iter = super()._build_input_queue(data_rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset) + np_iter = super()._build_input_queue( + data_rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset, + ) while True: if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( - batch['inputs'], dtype=torch.float32, device=DEVICE) + batch['inputs'], dtype=torch.float32, device=DEVICE + ) targets = torch.as_tensor( - batch['targets'], dtype=torch.long, device=DEVICE) + batch['targets'], dtype=torch.long, device=DEVICE + ) if 'weights' in batch: weights = torch.as_tensor( - batch['weights'], dtype=torch.bool, device=DEVICE) + batch['weights'], dtype=torch.bool, device=DEVICE + ) else: weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE) # Send batch to other devices when using DDP. @@ -94,34 +101,37 @@ def _build_input_queue( targets = targets.view(-1, *targets.shape[2:]) weights = weights.view(-1, *weights.shape[2:]) else: - inputs = torch.empty((N_GPUS, per_device_batch_size, 28, 28, 1), - dtype=torch.float32, - device=DEVICE) + inputs = torch.empty( + (N_GPUS, per_device_batch_size, 28, 28, 1), + dtype=torch.float32, + device=DEVICE, + ) dist.broadcast(inputs, src=0) inputs = inputs[RANK] - targets = torch.empty((N_GPUS, per_device_batch_size), - dtype=torch.long, - device=DEVICE) + targets = torch.empty( + (N_GPUS, per_device_batch_size), dtype=torch.long, device=DEVICE + ) dist.broadcast(targets, src=0) targets = targets[RANK] - weights = torch.empty((N_GPUS, per_device_batch_size), - dtype=torch.bool, - device=DEVICE) + weights = torch.empty( + (N_GPUS, per_device_batch_size), dtype=torch.bool, device=DEVICE + ) dist.broadcast(weights, src=0) weights = weights[RANK] batch = { - 'inputs': inputs.permute(0, 3, 1, 2), - 'targets': targets, - 'weights': weights, + 'inputs': inputs.permute(0, 3, 1, 2), + 'targets': targets, + 'weights': weights, } yield batch def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate del aux_dropout_rate @@ -149,13 +159,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['net.layer2.weight', 'net_layer2.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -163,8 +174,8 @@ def model_fn( if mode == spec.ForwardPassMode.EVAL: model.eval() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) @@ -173,11 +184,12 @@ def model_fn( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -185,10 +197,11 @@ def loss_fn( (not synced across devices). """ per_example_losses = F.cross_entropy( - logits_batch, - label_batch, - reduction='none', - label_smoothing=label_smoothing) + logits_batch, + label_batch, + reduction='none', + label_smoothing=label_smoothing, + ) # `mask_batch` is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -197,25 +210,27 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) targets = batch['targets'] weights = batch.get('weights') if weights is None: @@ -227,8 +242,8 @@ def _eval_model( return {'accuracy': accuracy, 'loss': summed_loss} def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py index f53aadd0b..20aa975ae 100644 --- a/algoperf/workloads/mnist/workload.py +++ b/algoperf/workloads/mnist/workload.py @@ -23,16 +23,17 @@ def _normalize(image: spec.Tensor, mean: float, stddev: float) -> spec.Tensor: def _build_mnist_dataset( - data_rng: jax.random.PRNGKey, - num_train_examples: int, - num_validation_examples: int, - train_mean: float, - train_stddev: float, - split: str, - data_dir: str, - global_batch_size: int, - cache: bool = False, - repeat_final_dataset: bool = True) -> Iterator[Dict[str, spec.Tensor]]: + data_rng: jax.random.PRNGKey, + num_train_examples: int, + num_validation_examples: int, + train_mean: float, + train_stddev: float, + split: str, + data_dir: str, + global_batch_size: int, + cache: bool = False, + repeat_final_dataset: bool = True, +) -> Iterator[Dict[str, spec.Tensor]]: shuffle = split in ['train', 'eval_train'] assert num_train_examples + num_validation_examples == 60000 if shuffle: @@ -42,12 +43,14 @@ def _build_mnist_dataset( else: tfds_split = 'test' ds = tfds.load( - 'mnist', split=tfds_split, shuffle_files=False, data_dir=data_dir) + 'mnist', split=tfds_split, shuffle_files=False, data_dir=data_dir + ) ds = ds.map( - lambda x: { - 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'], - }) + lambda x: { + 'inputs': _normalize(x['image'], train_mean, train_stddev), + 'targets': x['label'], + } + ) is_train = split == 'train' if cache: @@ -62,22 +65,23 @@ def _build_mnist_dataset( ds = ds.repeat() ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) return iter(ds) class BaseMnistWorkload(spec.Workload): - @property def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'accuracy' - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/accuracy'] > self.validation_target_value @property @@ -104,8 +108,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -138,31 +143,33 @@ def eval_period_time_sec(self) -> int: @abc.abstractmethod def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches ds = _build_mnist_dataset( - data_rng=data_rng, - num_train_examples=self.num_train_examples, - num_validation_examples=self.num_validation_examples, - train_mean=self.train_mean, - train_stddev=self.train_stddev, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - cache=cache, - repeat_final_dataset=repeat_final_dataset) + data_rng=data_rng, + num_train_examples=self.num_train_examples, + num_validation_examples=self.num_validation_examples, + train_mean=self.train_mean, + train_stddev=self.train_stddev, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + cache=cache, + repeat_final_dataset=repeat_final_dataset, + ) return ds @property @@ -173,49 +180,52 @@ def step_hint(self) -> int: return 7813 def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: raise NotImplementedError - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - cache=True, - repeat_final_dataset=True) + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + cache=True, + repeat_final_dataset=True, + ) total_metrics = { - 'accuracy': 0., - 'loss': 0., + 'accuracy': 0.0, + 'loss': 0.0, } num_batches = int(math.ceil(num_examples / global_batch_size)) num_devices = max(torch.cuda.device_count(), jax.local_device_count()) for _ in range(num_batches): batch = next(self._eval_iters[split]) per_device_model_rngs = prng.split(model_rng, num_devices) - batch_metrics = self._eval_model(params, - batch, - model_state, - per_device_model_rngs) + batch_metrics = self._eval_model( + params, batch, model_state, per_device_model_rngs + ) total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py index 3cb6f51de..79a4ddc4a 100644 --- a/algoperf/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -14,10 +14,10 @@ AVG_EDGES_PER_GRAPH = 56 TFDS_SPLIT_NAME = { - 'train': 'train', - 'eval_train': 'train', - 'validation': 'validation', - 'test': 'test', + 'train': 'train', + 'eval_train': 'train', + 'validation': 'validation', + 'test': 'test', } @@ -33,11 +33,12 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir): read_config = tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_data_rng) dataset = tfds.load( - 'ogbg_molpcba:0.1.3', - split=TFDS_SPLIT_NAME[split], - shuffle_files=should_shuffle, - read_config=read_config, - data_dir=data_dir) + 'ogbg_molpcba:0.1.3', + split=TFDS_SPLIT_NAME[split], + shuffle_files=should_shuffle, + read_config=read_config, + data_dir=data_dir, + ) if should_shuffle: dataset = dataset.shuffle(seed=dataset_data_rng, buffer_size=2**15) @@ -62,16 +63,17 @@ def _to_jraph(example): receivers = edge_index[:, 1] return jraph.GraphsTuple( - n_node=num_nodes, - n_edge=np.array([len(edge_index) * 2]), - nodes=node_feat, - edges=np.concatenate([edge_feat, edge_feat]), - # Make the edges bidirectional - senders=np.concatenate([senders, receivers]), - receivers=np.concatenate([receivers, senders]), - # Keep the labels with the graph for batching. They will be removed - # in the processed batch. - globals=np.expand_dims(labels, axis=0)) + n_node=num_nodes, + n_edge=np.array([len(edge_index) * 2]), + nodes=node_feat, + edges=np.concatenate([edge_feat, edge_feat]), + # Make the edges bidirectional + senders=np.concatenate([senders, receivers]), + receivers=np.concatenate([receivers, senders]), + # Keep the labels with the graph for batching. They will be removed + # in the processed batch. + globals=np.expand_dims(labels, axis=0), + ) def _get_weights_by_nan_and_padding(labels, padding_mask): @@ -123,10 +125,9 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): max_n_graphs = per_device_batch_size jraph_iter = map(_to_jraph, dataset_iter) - batched_iter = jraph.dynamically_batch(jraph_iter, - max_n_nodes + 1, - max_n_edges, - max_n_graphs + 1) + batched_iter = jraph.dynamically_batch( + jraph_iter, max_n_nodes + 1, max_n_edges, max_n_graphs + 1 + ) count = 0 graphs_shards = [] @@ -141,7 +142,8 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): graph = batched_graph._replace(globals={}) replaced_labels, weights = _get_weights_by_nan_and_padding( - labels, jraph.get_graph_padding_mask(graph)) + labels, jraph.get_graph_padding_mask(graph) + ) graphs_shards.append(graph) labels_shards.append(replaced_labels) @@ -156,9 +158,9 @@ def f(x): labels_shards = f(labels_shards) weights_shards = f(weights_shards) yield { - 'inputs': graphs_shards, - 'targets': labels_shards, - 'weights': weights_shards, + 'inputs': graphs_shards, + 'targets': labels_shards, + 'weights': weights_shards, } count = 0 @@ -170,5 +172,6 @@ def f(x): def get_dataset_iter(split, data_rng, data_dir, global_batch_size): shuffle = split in ['train', 'eval_train'] ds = _load_dataset( - split, should_shuffle=shuffle, data_rng=data_rng, data_dir=data_dir) + split, should_shuffle=shuffle, data_rng=data_rng, data_dir=data_dir + ) return _get_batch_iterator(iter(ds), global_batch_size) diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index 55f83d905..668501788 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -16,10 +16,9 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() -def predictions_match_labels(*, - logits: jnp.ndarray, - labels: jnp.ndarray, - **kwargs) -> jnp.ndarray: +def predictions_match_labels( + *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs +) -> jnp.ndarray: """Returns a binary array indicating where predictions match the labels.""" del kwargs # Unused. preds = logits > 0 @@ -28,7 +27,8 @@ def predictions_match_labels(*, @flax.struct.dataclass class MeanAveragePrecision( - metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))): + metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask')) +): """Computes the mean average precision (mAP) over different tasks.""" def compute(self): @@ -62,7 +62,8 @@ def compute(self): if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0: is_labeled = mask[:, task] average_precisions[task] = average_precision_score( - labels[is_labeled, task], probs[is_labeled, task]) + labels[is_labeled, task], probs[is_labeled, task] + ) # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning. if np.isnan(average_precisions).all(): diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 8524bb60e..db1ca416c 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -2,9 +2,9 @@ # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. from typing import Tuple -from flax import linen as nn import jax.numpy as jnp import jraph +from flax import linen as nn from algoperf.jax_utils import Dropout @@ -12,7 +12,6 @@ def _make_embed(latent_dim, name): - def make_fn(inputs): return nn.Dense(features=latent_dim, name=name)(inputs) @@ -29,9 +28,9 @@ def make_fn(inputs): x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) x = activation_fn(x) - x = Dropout( - rate=dropout_rate, deterministic=not train)( - x, rate=dropout_rate) + x = Dropout(rate=dropout_rate, deterministic=not train)( + x, rate=dropout_rate + ) return x return make_fn @@ -42,6 +41,7 @@ class GNN(nn.Module): The model assumes the input data is a jraph.GraphsTuple without global variables. The final prediction will be encoded in the globals. """ + num_outputs: int latent_dim: int = 256 hidden_dims: Tuple[int] = (256,) @@ -50,13 +50,14 @@ class GNN(nn.Module): @nn.compact def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): - graph = graph._replace( - globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) + globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]) + ) embedder = jraph.GraphMapFeatures( - embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'), - embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding')) + embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'), + embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding'), + ) graph = embedder(graph) if self.activation_fn_name == 'relu': @@ -67,25 +68,30 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): activation_fn = nn.silu else: raise ValueError( - f'Invalid activation function name: {self.activation_fn_name}') + f'Invalid activation function name: {self.activation_fn_name}' + ) for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( - update_edge_fn=_make_mlp( - self.hidden_dims, - activation_fn=activation_fn, - train=train, - dropout_rate=dropout_rate), - update_node_fn=_make_mlp( - self.hidden_dims, - activation_fn=activation_fn, - train=train, - dropout_rate=dropout_rate), - update_global_fn=_make_mlp( - self.hidden_dims, - activation_fn=activation_fn, - train=train, - dropout_rate=dropout_rate)) + update_edge_fn=_make_mlp( + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate, + ), + update_node_fn=_make_mlp( + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate, + ), + update_global_fn=_make_mlp( + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate, + ), + ) graph = net(graph) diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index 0535aea83..8471fcdcc 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -1,39 +1,40 @@ """OGBG workload implemented in Jax.""" + import functools from typing import Any, Dict, Tuple -from flax import jax_utils import jax import jax.numpy as jnp import jraph import optax +from flax import jax_utils -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.ogbg import metrics from algoperf.workloads.ogbg.ogbg_jax import models from algoperf.workloads.ogbg.workload import BaseOgbgWorkload class OgbgWorkload(BaseOgbgWorkload): - def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: rng, params_rng = jax.random.split(rng, 2) self._model = models.GNN( - self._num_outputs, - activation_fn_name=self.activation_fn_name, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) + self._num_outputs, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps, + ) init_fn = jax.jit(functools.partial(self._model.init, train=False)) fake_batch = jraph.GraphsTuple( - n_node=jnp.asarray([1]), - n_edge=jnp.asarray([1]), - nodes=jnp.ones((1, 9)), - edges=jnp.ones((1, 3)), - globals=jnp.zeros((1, self._num_outputs)), - senders=jnp.asarray([0]), - receivers=jnp.asarray([0])) + n_node=jnp.asarray([1]), + n_edge=jnp.asarray([1]), + nodes=jnp.ones((1, 9)), + edges=jnp.ones((1, 3)), + globals=jnp.zeros((1, self._num_outputs)), + senders=jnp.asarray([0]), + receivers=jnp.asarray([0]), + ) params = init_fn({'params': params_rng}, fake_batch) params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) @@ -44,40 +45,45 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_17' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del update_batch_norm # No BN in the GNN model. if model_state is not None: raise ValueError( - f'Expected model_state to be None, received {model_state}.') + f'Expected model_state to be None, received {model_state}.' + ) train = mode == spec.ForwardPassMode.TRAIN - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train, - dropout_rate=dropout_rate) + logits = self._model.apply( + {'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate, + ) return logits, None def _binary_cross_entropy_with_mask( - self, - labels: jnp.ndarray, - logits: jnp.ndarray, - mask: jnp.ndarray, - label_smoothing: float = 0.0) -> jnp.ndarray: + self, + labels: jnp.ndarray, + logits: jnp.ndarray, + mask: jnp.ndarray, + label_smoothing: float = 0.0, + ) -> jnp.ndarray: """Binary cross entropy loss for logits, with masked elements.""" if not (logits.shape == labels.shape == mask.shape): # pylint: disable=superfluous-parens raise ValueError( - f'Shape mismatch between logits ({logits.shape}), targets ' - f'({labels.shape}), and weights ({mask.shape}).') + f'Shape mismatch between logits ({logits.shape}), targets ' + f'({labels.shape}), and weights ({mask.shape}).' + ) if len(logits.shape) != 2: raise ValueError(f'Rank of logits ({logits.shape}) must be 2.') @@ -93,26 +99,31 @@ def _binary_cross_entropy_with_mask( positive_logits = logits >= 0 relu_logits = jnp.where(positive_logits, logits, 0) abs_logits = jnp.where(positive_logits, logits, -logits) - losses = relu_logits - (logits * smoothed_labels) + ( - jnp.log(1 + jnp.exp(-abs_logits))) - return jnp.where(mask, losses, 0.) + losses = ( + relu_logits + - (logits * smoothed_labels) + + (jnp.log(1 + jnp.exp(-abs_logits))) + ) + return jnp.where(mask, losses, 0.0) def _eval_metric(self, labels, logits, masks): loss = self.loss_fn(labels, logits, masks) return metrics.EvalMetrics.single_from_model_output( - loss=loss['per_example'], logits=logits, labels=labels, mask=masks) + loss=loss['per_example'], logits=logits, labels=labels, mask=masks + ) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, None), + static_broadcasted_argnums=(0,), + ) def _eval_batch(self, params, batch, model_state, rng): return super()._eval_batch(params, batch, model_state, rng) def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples total_metrics = total_metrics.reduce() @@ -120,7 +131,6 @@ def _normalize_eval_metrics( class OgbgGeluWorkload(OgbgWorkload): - @property def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" @@ -136,7 +146,6 @@ def test_target_value(self) -> float: class OgbgSiluWorkload(OgbgWorkload): - @property def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" @@ -152,7 +161,6 @@ def test_target_value(self) -> float: class OgbgModelSizeWorkload(OgbgWorkload): - @property def hidden_dims(self) -> Tuple[int]: return (256, 256) diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py index 8a40bef58..a69bc6ee1 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py @@ -4,13 +4,12 @@ from typing import Callable, Optional, Tuple import jax.tree_util as tree -from jraph import GraphsTuple import torch +from jraph import GraphsTuple from torch import nn from algoperf import init_utils -from algoperf.pytorch_utils import CustomDropout -from algoperf.pytorch_utils import SequentialWithDropout +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout DROPOUT_RATE = 0.1 @@ -19,8 +18,9 @@ def _make_mlp(in_dim, hidden_dims, activation_fn): """Creates a MLP with specified dimensions.""" layers = SequentialWithDropout() for i, dim in enumerate(hidden_dims): - layers.add_module(f'dense_{i}', - nn.Linear(in_features=in_dim, out_features=dim)) + layers.add_module( + f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim) + ) layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) layers.add_module(f'activation_fn_{i}', activation_fn()) layers.add_module(f'dropout_{i}', CustomDropout()) @@ -35,12 +35,14 @@ class GNN(nn.Module): variables. The final prediction will be encoded in the globals. """ - def __init__(self, - num_outputs: int = 128, - activation_fn_name: str = 'relu', - latent_dim: int = 256, - hidden_dims: Tuple[int] = (256,), - num_message_passing_steps: int = 5) -> None: + def __init__( + self, + num_outputs: int = 128, + activation_fn_name: str = 'relu', + latent_dim: int = 256, + hidden_dims: Tuple[int] = (256,), + num_message_passing_steps: int = 5, + ) -> None: super().__init__() self.latent_dim = latent_dim self.hidden_dims = hidden_dims @@ -58,7 +60,8 @@ def __init__(self, activation_fn = nn.SiLU else: raise ValueError( - f'Invalid activation function name: {self.activation_fn_name}') + f'Invalid activation function name: {self.activation_fn_name}' + ) graph_network_layers = [] for st in range(self.num_message_passing_steps): @@ -66,8 +69,9 @@ def __init__(self, # specifically update_edge_fn update_node_fn and update_global_fn. if st == 0: in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs - in_dim_node_fn = self.latent_dim + self.hidden_dims[ - -1] * 2 + self.num_outputs + in_dim_node_fn = ( + self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs + ) last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs else: in_dim_edge_fn = self.hidden_dims[-1] * 4 @@ -75,32 +79,36 @@ def __init__(self, last_in_dim = self.hidden_dims[-1] * 3 graph_network_layers.append( - GraphNetwork( - update_edge_fn=_make_mlp(in_dim_edge_fn, - self.hidden_dims, - activation_fn), - update_node_fn=_make_mlp(in_dim_node_fn, - self.hidden_dims, - activation_fn), - update_global_fn=_make_mlp(last_in_dim, - self.hidden_dims, - activation_fn))) + GraphNetwork( + update_edge_fn=_make_mlp( + in_dim_edge_fn, self.hidden_dims, activation_fn + ), + update_node_fn=_make_mlp( + in_dim_node_fn, self.hidden_dims, activation_fn + ), + update_global_fn=_make_mlp( + last_in_dim, self.hidden_dims, activation_fn + ), + ) + ) self.graph_network = SequentialWithDropout(*graph_network_layers) self.decoder = nn.Linear( - in_features=self.hidden_dims[-1], out_features=self.num_outputs) + in_features=self.hidden_dims[-1], out_features=self.num_outputs + ) for m in self.modules(): if isinstance(m, nn.Linear): init_utils.pytorch_default_init(m) - def forward(self, - graph: GraphsTuple, - dropout_rate: float = DROPOUT_RATE) -> torch.Tensor: - + def forward( + self, graph: GraphsTuple, dropout_rate: float = DROPOUT_RATE + ) -> torch.Tensor: graph = graph._replace( - globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], - device=graph.n_node.device)) + globals=torch.zeros( + [graph.n_node.shape[0], self.num_outputs], device=graph.n_node.device + ) + ) graph = graph._replace(nodes=self.node_embedder(graph.nodes)) graph = graph._replace(edges=self.edge_embedder(graph.edges)) @@ -138,10 +146,12 @@ class GraphNetwork(nn.Module): A method that applies the configured GraphNetwork. """ - def __init__(self, - update_edge_fn: Optional[Callable] = None, - update_node_fn: Optional[Callable] = None, - update_global_fn: Optional[Callable] = None) -> None: + def __init__( + self, + update_edge_fn: Optional[Callable] = None, + update_node_fn: Optional[Callable] = None, + update_global_fn: Optional[Callable] = None, + ) -> None: super().__init__() self.update_edge_fn = update_edge_fn self.update_node_fn = update_node_fn @@ -168,35 +178,41 @@ def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: nodes, edges, receivers, senders, globals_, n_node, n_edge = graph sum_n_node = tree.tree_leaves(nodes)[0].shape[0] if not tree.tree_all( - tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)): + tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes) + ): raise ValueError( - 'All node arrays in nest must contain the same number of nodes.') + 'All node arrays in nest must contain the same number of nodes.' + ) sent_attributes = tree.tree_map(lambda n: n[senders], nodes) received_attributes = tree.tree_map(lambda n: n[receivers], nodes) # Here we scatter the global features to the corresponding edges, # giving us tensors of shape [num_edges, global_feat]. global_edge_attributes = tree.tree_map( - lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_) + lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_ + ) if self.update_edge_fn: edge_fn_inputs = torch.cat( - [edges, sent_attributes, received_attributes, global_edge_attributes], - dim=-1) + [edges, sent_attributes, received_attributes, global_edge_attributes], + dim=-1, + ) edges = self.update_edge_fn(edge_fn_inputs, dropout_rate) if self.update_node_fn: sent_attributes = tree.tree_map( - lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges) + lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges + ) received_attributes = tree.tree_map( - lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), - edges) + lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), edges + ) # Here we scatter the global features to the corresponding nodes, # giving us tensors of shape [num_nodes, global_feat]. global_attributes = tree.tree_map( - lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) + lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_ + ) node_fn_inputs = torch.cat( - [nodes, sent_attributes, received_attributes, global_attributes], - dim=-1) + [nodes, sent_attributes, received_attributes, global_attributes], dim=-1 + ) nodes = self.update_node_fn(node_fn_inputs, dropout_rate) if self.update_global_fn: @@ -210,31 +226,37 @@ def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0) # We use the aggregation function to pool the nodes/edges per graph. node_attributes = tree.tree_map( - lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes) + lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes + ) edge_attributes = tree.tree_map( - lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges) + lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges + ) # These pooled nodes are the inputs to the global update fn. - global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_], - dim=-1) + global_fn_inputs = torch.cat( + [node_attributes, edge_attributes, globals_], dim=-1 + ) globals_ = self.update_global_fn(global_fn_inputs, dropout_rate) return GraphsTuple( - nodes=nodes, - edges=edges, - receivers=receivers, - senders=senders, - globals=globals_, - n_node=n_node, - n_edge=n_edge) + nodes=nodes, + edges=edges, + receivers=receivers, + senders=senders, + globals=globals_, + n_node=n_node, + n_edge=n_edge, + ) # Forked from # github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py. -def scatter_sum(src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None) -> torch.Tensor: +def scatter_sum( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, +) -> torch.Tensor: r""" | .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index a45a93668..f72ff5141 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -1,16 +1,15 @@ """OGBG workload implemented in PyTorch.""" + import contextlib from typing import Any, Callable, Dict, Optional, Tuple import jax -from jraph import GraphsTuple import torch import torch.distributed as dist +from jraph import GraphsTuple from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec +from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.ogbg import metrics from algoperf.workloads.ogbg.ogbg_pytorch import models from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN @@ -23,9 +22,11 @@ def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) return jax.tree.map( - lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1]) - if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1), - inputs) + lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1]) + if len(a.shape) == 3 + else torch.as_tensor(a, device=DEVICE).view(-1), + inputs, + ) def _shard(inputs: Any) -> Any: @@ -36,44 +37,47 @@ def _shard(inputs: Any) -> Any: def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple: return GraphsTuple( - nodes=function(graph.nodes), - edges=function(graph.edges), - receivers=function(graph.receivers), - senders=function(graph.senders), - globals=function(graph.globals), - n_node=function(graph.n_node), - n_edge=function(graph.n_edge)) + nodes=function(graph.nodes), + edges=function(graph.edges), + receivers=function(graph.receivers), + senders=function(graph.senders), + globals=function(graph.globals), + n_node=function(graph.n_node), + n_edge=function(graph.n_edge), + ) class OgbgWorkload(BaseOgbgWorkload): - # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of valid examples in batch, 'per_example': 1-d array of per-example losses} (not synced across devices). """ - loss_dict = super().loss_fn(label_batch, - logits_batch, - mask_batch, - label_smoothing) + loss_dict = super().loss_fn( + label_batch, logits_batch, mask_batch, label_smoothing + ) loss_dict['n_valid_examples'] = torch.as_tensor( - loss_dict['n_valid_examples'], device=DEVICE) + loss_dict['n_valid_examples'], device=DEVICE + ) return loss_dict - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int): + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + ): # TODO: Check where the + 1 comes from. per_device_batch_size = int(global_batch_size / N_GPUS) + 1 @@ -81,10 +85,9 @@ def _build_input_queue(self, # avoid creating too many threads. if RANK == 0: data_rng = data_rng.astype('uint32') - dataset_iter = super()._build_input_queue(data_rng, - split, - data_dir, - global_batch_size) + dataset_iter = super()._build_input_queue( + data_rng, split, data_dir, global_batch_size + ) while True: if RANK == 0: @@ -92,14 +95,16 @@ def _build_input_queue(self, graph = _graph_map(_pytorch_map, batch['inputs']) targets = torch.as_tensor(batch['targets'], device=DEVICE) weights = torch.as_tensor( - batch['weights'], dtype=torch.bool, device=DEVICE) + batch['weights'], dtype=torch.bool, device=DEVICE + ) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: dist.broadcast_object_list([graph], src=0, device=DEVICE) # During eval, the batch size of the remainder might be different. if split != 'train': per_device_batch_size = torch.tensor( - len(targets[0]), dtype=torch.int32, device=DEVICE) + len(targets[0]), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) dist.broadcast(targets, src=0) targets = targets[0] @@ -114,25 +119,27 @@ def _build_input_queue(self, graph = graph[0] # During eval, the batch size of the remainder might be different. if split != 'train': - per_device_batch_size = torch.empty((1,), - dtype=torch.int32, - device=DEVICE) + per_device_batch_size = torch.empty( + (1,), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) targets = torch.empty( - (N_GPUS, per_device_batch_size, self._num_outputs), device=DEVICE) + (N_GPUS, per_device_batch_size, self._num_outputs), device=DEVICE + ) dist.broadcast(targets, src=0) targets = targets[RANK] weights = torch.empty( - (N_GPUS, per_device_batch_size, self._num_outputs), - dtype=torch.bool, - device=DEVICE) + (N_GPUS, per_device_batch_size, self._num_outputs), + dtype=torch.bool, + device=DEVICE, + ) dist.broadcast(weights, src=0) weights = weights[RANK] batch = { - 'inputs': _graph_map(_shard, graph), - 'targets': targets, - 'weights': weights, + 'inputs': _graph_map(_shard, graph), + 'targets': targets, + 'weights': weights, } yield batch @@ -140,11 +147,12 @@ def _build_input_queue(self, def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = GNN( - num_outputs=self._num_outputs, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps, - activation_fn_name=self.activation_fn_name) + num_outputs=self._num_outputs, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps, + activation_fn_name=self.activation_fn_name, + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -159,21 +167,22 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['decoder.weight', 'decoder.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del rng del update_batch_norm # No BN in the GNN model. if model_state is not None: raise ValueError( - f'Expected model_state to be None, received {model_state}.') + f'Expected model_state to be None, received {model_state}.' + ) model = params if mode == spec.ForwardPassMode.TRAIN: @@ -182,28 +191,31 @@ def model_fn( model.eval() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logits = model( - augmented_and_preprocessed_input_batch['inputs'], - dropout_rate=dropout_rate) + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate, + ) return logits, None def _binary_cross_entropy_with_mask( - self, - labels: torch.Tensor, - logits: torch.Tensor, - mask: torch.Tensor, - label_smoothing: float = 0.0) -> torch.Tensor: + self, + labels: torch.Tensor, + logits: torch.Tensor, + mask: torch.Tensor, + label_smoothing: float = 0.0, + ) -> torch.Tensor: """Binary cross entropy loss for logits, with masked elements.""" if not (logits.shape == labels.shape == mask.shape): # pylint: disable=superfluous-parens raise ValueError( - f'Shape mismatch between logits ({logits.shape}), targets ' - f'({labels.shape}), and weights ({mask.shape}).') + f'Shape mismatch between logits ({logits.shape}), targets ' + f'({labels.shape}), and weights ({mask.shape}).' + ) if len(logits.shape) != 2: raise ValueError(f'Rank of logits ({logits.shape}) must be 2.') @@ -213,36 +225,40 @@ def _binary_cross_entropy_with_mask( # Apply label_smoothing. num_classes = labels.shape[-1] - smoothed_labels = ((1.0 - label_smoothing) * labels + - label_smoothing / num_classes) + smoothed_labels = ( + 1.0 - label_smoothing + ) * labels + label_smoothing / num_classes # Numerically stable implementation of BCE loss. # This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits(). positive_logits = logits >= 0 relu_logits = torch.where(positive_logits, logits, 0) abs_logits = torch.where(positive_logits, logits, -logits) - losses = relu_logits - (logits * smoothed_labels) + ( - torch.log(1 + torch.exp(-abs_logits))) - return torch.where(mask.to(torch.bool), losses, 0.) + losses = ( + relu_logits + - (logits * smoothed_labels) + + (torch.log(1 + torch.exp(-abs_logits))) + ) + return torch.where(mask.to(torch.bool), losses, 0.0) def _eval_metric(self, labels, logits, masks): loss = self.loss_fn(labels, logits, masks) return metrics.EvalMetrics.single_from_model_output( - loss=loss['per_example'].cpu().numpy(), - logits=logits.cpu().numpy(), - labels=labels.cpu().numpy(), - mask=masks.cpu().numpy()) + loss=loss['per_example'].cpu().numpy(), + logits=logits.cpu().numpy(), + labels=labels.cpu().numpy(), + mask=masks.cpu().numpy(), + ) def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples return {k: float(v) for k, v in total_metrics.compute().items()} class OgbgGeluWorkload(OgbgWorkload): - @property def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" @@ -258,7 +274,6 @@ def test_target_value(self) -> float: class OgbgSiluWorkload(OgbgWorkload): - @property def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" @@ -274,7 +289,6 @@ def test_target_value(self) -> float: class OgbgModelSizeWorkload(OgbgWorkload): - @property def hidden_dims(self) -> Tuple[int]: return (256, 256) diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 971e7f0f6..1d182fed5 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -14,7 +14,6 @@ class BaseOgbgWorkload(spec.Workload): - _num_outputs: int = 128 @property @@ -40,8 +39,10 @@ def num_message_passing_steps(self) -> int: return 5 def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result[ - 'validation/mean_average_precision'] > self.validation_target_value + return ( + eval_result['validation/mean_average_precision'] + > self.validation_target_value + ) @property def validation_target_value(self) -> float: @@ -94,15 +95,16 @@ def max_allowed_runtime_sec(self) -> int: def eval_period_time_sec(self) -> int: return 4 * 60 - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int): - dataset_iter = input_pipeline.get_dataset_iter(split, - data_rng, - data_dir, - global_batch_size) + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + ): + dataset_iter = input_pipeline.get_dataset_iter( + split, data_rng, data_dir, global_batch_size + ) if split != 'train': # Note that this stores the entire val dataset in memory. dataset_iter = itertools.cycle(dataset_iter) @@ -111,11 +113,12 @@ def _build_input_queue(self, # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -123,19 +126,20 @@ def loss_fn( (not synced across devices). """ per_example_losses = self._binary_cross_entropy_with_mask( - labels=label_batch, - logits=logits_batch, - mask=mask_batch, - label_smoothing=label_smoothing) + labels=label_batch, + logits=logits_batch, + mask=mask_batch, + label_smoothing=label_smoothing, + ) if mask_batch is not None: n_valid_examples = mask_batch.sum() else: n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } @property @@ -145,39 +149,45 @@ def step_hint(self) -> int: @abc.abstractmethod def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> metrics.EvalMetrics: + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> metrics.EvalMetrics: logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) return self._eval_metric(batch['targets'], logits, batch['weights']) - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - data_rng, split, data_dir, global_batch_size=global_batch_size) + data_rng, split, data_dir, global_batch_size=global_batch_size + ) total_metrics = None num_eval_steps = int(math.ceil(float(num_examples) / global_batch_size)) @@ -186,8 +196,10 @@ def _eval_model_on_split(self, batch = next(self._eval_iters[split]) batch_metrics = self._eval_batch(params, batch, model_state, model_rng) total_metrics = ( - batch_metrics - if total_metrics is None else total_metrics.merge(batch_metrics)) + batch_metrics + if total_metrics is None + else total_metrics.merge(batch_metrics) + ) if total_metrics is None: return {} return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algoperf/workloads/utils.py b/algoperf/workloads/utils.py index 7719f91fb..920c3cf46 100644 --- a/algoperf/workloads/utils.py +++ b/algoperf/workloads/utils.py @@ -5,10 +5,12 @@ def print_jax_model_summary(model, fake_inputs): """Prints a summary of the jax module.""" tabulate_fn = nn.tabulate( - model, - jax.random.PRNGKey(0), - console_kwargs={ - 'force_terminal': False, 'force_jupyter': False, 'width': 240 - }, + model, + jax.random.PRNGKey(0), + console_kwargs={ + 'force_terminal': False, + 'force_jupyter': False, + 'width': 240, + }, ) print(tabulate_fn(fake_inputs, train=False)) diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py index ad314a7d3..e0064ba51 100644 --- a/algoperf/workloads/wmt/bleu.py +++ b/algoperf/workloads/wmt/bleu.py @@ -30,11 +30,11 @@ def my_log(num): """ - Floors the log function + Floors the log function - :param num: the number - :return: log(num) floored to a very low number - """ + :param num: the number + :return: log(num) floored to a very low number + """ if num == 0.0: return -9999999999 @@ -43,12 +43,12 @@ def my_log(num): def tokenize_13a(line): """ - Tokenizes an input line using a relatively minimal tokenization that is - however equivalent to mteval-v13a, used by WMT. + Tokenizes an input line using a relatively minimal tokenization that is + however equivalent to mteval-v13a, used by WMT. - :param line: a segment to tokenize - :return: the tokenized line - """ + :param line: a segment to tokenize + :return: the tokenized line + """ norm = line @@ -62,14 +62,17 @@ def tokenize_13a(line): norm = norm.replace('>', '>') # language-dependent part (assuming Western languages): - norm = " {} ".format(norm) + norm = ' {} '.format(norm) norm = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', ' \\1 ', norm) - norm = re.sub(r'([^0-9])([\.,])', '\\1 \\2 ', - norm) # tokenize period and comma unless preceded by a digit - norm = re.sub(r'([\.,])([^0-9])', ' \\1 \\2', - norm) # tokenize period and comma unless followed by a digit - norm = re.sub(r'([0-9])(-)', '\\1 \\2 ', - norm) # tokenize dash when preceded by a digit + norm = re.sub( + r'([^0-9])([\.,])', '\\1 \\2 ', norm + ) # tokenize period and comma unless preceded by a digit + norm = re.sub( + r'([\.,])([^0-9])', ' \\1 \\2', norm + ) # tokenize period and comma unless followed by a digit + norm = re.sub( + r'([0-9])(-)', '\\1 \\2 ', norm + ) # tokenize dash when preceded by a digit norm = re.sub(r'\s+', ' ', norm) # one space only between words norm = re.sub(r'^\s+', '', norm) # no leading space norm = re.sub(r'\s+$', '', norm) # no trailing space @@ -80,14 +83,15 @@ def tokenize_13a(line): class UnicodeRegex: """Ad-hoc hack to recognize all punctuation and symbols. - without depending on https://pypi.python.org/pypi/regex/.""" + without depending on https://pypi.python.org/pypi/regex/.""" @staticmethod def _property_chars(prefix): return ''.join( - chr(x) - for x in range(sys.maxunicode) - if unicodedata.category(chr(x)).startswith(prefix)) + chr(x) + for x in range(sys.maxunicode) + if unicodedata.category(chr(x)).startswith(prefix) + ) punctuation = _property_chars('P') nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])') @@ -98,27 +102,27 @@ def _property_chars(prefix): def tokenize_v14_international(string): r"""Tokenize a string following the official BLEU implementation. - See - https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 - In our case, the input string is expected to be just one line - and no HTML entities de-escaping is needed. - So we just tokenize on punctuation and symbols, - except when a punctuation is preceded and followed by a digit - (e.g. a comma/dot as a thousand/decimal separator). - - Note that a number (e.g., a year) followed by a dot at the end of sentence - is NOT tokenized, - i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` - does not match this case (unless we add a space after each sentence). - However, this error is already in the original mteval-v14.pl - and we want to be consistent with it. - The error is not present in the non-international version, - which uses, - `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). - - :param string: the input string - :return: a list of tokens - """ + See + https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 + In our case, the input string is expected to be just one line + and no HTML entities de-escaping is needed. + So we just tokenize on punctuation and symbols, + except when a punctuation is preceded and followed by a digit + (e.g. a comma/dot as a thousand/decimal separator). + + Note that a number (e.g., a year) followed by a dot at the end of sentence + is NOT tokenized, + i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` + does not match this case (unless we add a space after each sentence). + However, this error is already in the original mteval-v14.pl + and we want to be consistent with it. + The error is not present in the non-international version, + which uses, + `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). + + :param string: the input string + :return: a list of tokens + """ string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string) string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string) string = UnicodeRegex.symbol_re.sub(r' \1 ', string) @@ -127,94 +131,94 @@ def tokenize_v14_international(string): def tokenize_zh(sentence): """MIT License - Copyright (c) 2017 - Shujian Huang - - Permission is hereby granted, free of charge, to any person obtaining - a copy of this software and associated documentation files - (the "Software"), to deal in the Software without restriction, including - without limitation the rights to use, copy, modify, merge, publish, - distribute, sublicense, and/or sell copies of the Software, and to - permit persons to whom the Software is furnished to do so, subject to the - following conditions: - - The above copyright notice and this permission notice shall be included - in all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, - DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR - OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE - USE OR OTHER DEALINGS IN THE SOFTWARE. - - The tokenization of Chinese text in this script contains two steps: - separate each Chinese characters (by utf-8 encoding); - tokenize the non Chinese part (following the mteval script). - Author: Shujian Huang huangsj@nju.edu.cn - - :param sentence: input sentence - :return: tokenized sentence - """ + Copyright (c) 2017 - Shujian Huang + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files + (the "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to the + following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, + DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE + USE OR OTHER DEALINGS IN THE SOFTWARE. + + The tokenization of Chinese text in this script contains two steps: + separate each Chinese characters (by utf-8 encoding); + tokenize the non Chinese part (following the mteval script). + Author: Shujian Huang huangsj@nju.edu.cn + + :param sentence: input sentence + :return: tokenized sentence + """ def is_chinese_char(uchar): """ - :param uchar: input char in unicode - :return: whether the input char is a Chinese character. - """ - if "\u3400" <= uchar <= "\u4db5": + :param uchar: input char in unicode + :return: whether the input char is a Chinese character. + """ + if '\u3400' <= uchar <= '\u4db5': return True - elif "\u4e00" <= uchar <= "\u9fa5": + elif '\u4e00' <= uchar <= '\u9fa5': return True - elif "\u9fa6" <= uchar <= "\u9fbb": + elif '\u9fa6' <= uchar <= '\u9fbb': return True - elif "\uf900" <= uchar <= "\ufa2d": + elif '\uf900' <= uchar <= '\ufa2d': return True - elif "\ufa30" <= uchar <= "\ufa6a": + elif '\ufa30' <= uchar <= '\ufa6a': return True - elif "\ufa70" <= uchar <= "\ufad9": + elif '\ufa70' <= uchar <= '\ufad9': return True - elif "\u20000" <= uchar <= "\u2a6d6": + elif '\u20000' <= uchar <= '\u2a6d6': return True - elif "\u2f800" <= uchar <= "\u2fa1d": + elif '\u2f800' <= uchar <= '\u2fa1d': return True - elif "\uff00" <= uchar <= "\uffef": + elif '\uff00' <= uchar <= '\uffef': return True - elif "\u2e80" <= uchar <= "\u2eff": + elif '\u2e80' <= uchar <= '\u2eff': return True - elif "\u3000" <= uchar <= "\u303f": + elif '\u3000' <= uchar <= '\u303f': return True - elif "\u31c0" <= uchar <= "\u31ef": + elif '\u31c0' <= uchar <= '\u31ef': return True - elif "\u2f00" <= uchar <= "\u2fdf": + elif '\u2f00' <= uchar <= '\u2fdf': return True - elif "\u2ff0" <= uchar <= "\u2fff": + elif '\u2ff0' <= uchar <= '\u2fff': return True - elif "\u3100" <= uchar <= "\u312f": + elif '\u3100' <= uchar <= '\u312f': return True - elif "\u31a0" <= uchar <= "\u31bf": + elif '\u31a0' <= uchar <= '\u31bf': return True - elif "\ufe10" <= uchar <= "\ufe1f": + elif '\ufe10' <= uchar <= '\ufe1f': return True - elif "\ufe30" <= uchar <= "\ufe4f": + elif '\ufe30' <= uchar <= '\ufe4f': return True - elif "\u2600" <= uchar <= "\u26ff": + elif '\u2600' <= uchar <= '\u26ff': return True - elif "\u2700" <= uchar <= "\u27bf": + elif '\u2700' <= uchar <= '\u27bf': return True - elif "\u3200" <= uchar <= "\u32ff": + elif '\u3200' <= uchar <= '\u32ff': return True - elif "\u3300" <= uchar <= "\u33ff": + elif '\u3300' <= uchar <= '\u33ff': return True return False sentence = sentence.strip() - sentence_in_chars = "" + sentence_in_chars = '' for char in sentence: if is_chinese_char(char): - sentence_in_chars += " " + sentence_in_chars += ' ' sentence_in_chars += char - sentence_in_chars += " " + sentence_in_chars += ' ' else: sentence_in_chars += char sentence = sentence_in_chars @@ -245,10 +249,10 @@ def is_chinese_char(uchar): TOKENIZERS = { - '13a': tokenize_13a, - 'intl': tokenize_v14_international, - 'zh': tokenize_zh, - 'none': lambda x: x, + '13a': tokenize_13a, + 'intl': tokenize_v14_international, + 'zh': tokenize_zh, + 'none': lambda x: x, } DEFAULT_TOKENIZER = '13a' @@ -256,16 +260,16 @@ def is_chinese_char(uchar): def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter: """Extracts all the ngrams (1 <= n <= NGRAM_ORDER) from a sequence of tokens. - :param line: a segment containing a sequence of words - :param max_order: collect n-grams from 1<=n<=max - :return: a dictionary containing ngrams and counts - """ + :param line: a segment containing a sequence of words + :param max_order: collect n-grams from 1<=n<=max + :return: a dictionary containing ngrams and counts + """ ngrams = Counter() tokens = line.split() for n in range(min_order, max_order + 1): for i in range(0, len(tokens) - n + 1): - ngram = ' '.join(tokens[i:i + n]) + ngram = ' '.join(tokens[i : i + n]) ngrams[ngram] += 1 return ngrams @@ -293,41 +297,44 @@ def ref_stats(output, refs): return ngrams, closest_diff, closest_len -BLEU = namedtuple('BLE', - 'score, counts, totals, precisions, bp, sys_len, ref_len') +BLEU = namedtuple( + 'BLE', 'score, counts, totals, precisions, bp, sys_len, ref_len' +) -def compute_bleu(correct: List[int], - total: List[int], - sys_len: int, - ref_len: int, - smooth_method='none', - smooth_value=SMOOTH_VALUE_DEFAULT, - use_effective_order=False) -> BLEU: +def compute_bleu( + correct: List[int], + total: List[int], + sys_len: int, + ref_len: int, + smooth_method='none', + smooth_value=SMOOTH_VALUE_DEFAULT, + use_effective_order=False, +) -> BLEU: """Computes BLEU score from its sufficient statistics. Adds smoothing. - Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques - for Sentence-Level BLEU", Boxing Chen and Colin Cherry, - WMT 2014: http://aclweb.org/anthology/W14-3346) - - - exp: NIST smoothing method (Method 3) - - floor: Method 1 - - add-k: Method 2 (generalizing Lin and Och, 2004) - - none: do nothing. - - :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER - :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER - :param sys_len: The cumulative system length - :param ref_len: The cumulative reference length - :param smooth: The smoothing method to use - :param smooth_value: The smoothing value added, if smooth is 'floor' - :param use_effective_order: Use effective order. - :return: A BLEU object with the score (100-based) and other statistics. - """ + Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques + for Sentence-Level BLEU", Boxing Chen and Colin Cherry, + WMT 2014: http://aclweb.org/anthology/W14-3346) + + - exp: NIST smoothing method (Method 3) + - floor: Method 1 + - add-k: Method 2 (generalizing Lin and Och, 2004) + - none: do nothing. + + :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER + :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER + :param sys_len: The cumulative system length + :param ref_len: The cumulative reference length + :param smooth: The smoothing method to use + :param smooth_value: The smoothing value added, if smooth is 'floor' + :param use_effective_order: Use effective order. + :return: A BLEU object with the score (100-based) and other statistics. + """ precisions = [0 for x in range(NGRAM_ORDER)] - smooth_mteval = 1. + smooth_mteval = 1.0 effective_order = NGRAM_ORDER for n in range(NGRAM_ORDER): if smooth_method == 'add-k' and n > 1: @@ -342,11 +349,11 @@ def compute_bleu(correct: List[int], if correct[n] == 0: if smooth_method == 'exp': smooth_mteval *= 2 - precisions[n] = 100. / (smooth_mteval * total[n]) + precisions[n] = 100.0 / (smooth_mteval * total[n]) elif smooth_method == 'floor': - precisions[n] = 100. * smooth_value / total[n] + precisions[n] = 100.0 * smooth_value / total[n] else: - precisions[n] = 100. * correct[n] / total[n] + precisions[n] = 100.0 * correct[n] / total[n] # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU # score is 0 (technically undefined). This is a problem for sentence-level @@ -360,20 +367,24 @@ def compute_bleu(correct: List[int], brevity_penalty = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0 bleu = brevity_penalty * math.exp( - sum(map(my_log, precisions[:effective_order])) / effective_order) + sum(map(my_log, precisions[:effective_order])) / effective_order + ) return BLEU._make( - [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len]) - - -def corpus_bleu(sys_stream: Sequence[str], - ref_streams: Sequence[str], - smooth_method: str = 'exp', - smooth_value: float = 0.0, - force: bool = False, - lowercase: bool = False, - tokenize: str = '13a', - use_effective_order: bool = False) -> BLEU: + [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len] + ) + + +def corpus_bleu( + sys_stream: Sequence[str], + ref_streams: Sequence[str], + smooth_method: str = 'exp', + smooth_value: float = 0.0, + force: bool = False, + lowercase: bool = False, + tokenize: str = '13a', + use_effective_order: bool = False, +) -> BLEU: """Produces BLEU scores along with its sufficient statistics from a source against one or more references. :param sys_stream: The system stream (a sequence of segments). @@ -414,13 +425,16 @@ def corpus_bleu(sys_stream: Sequence[str], tokenized_count += 1 if tokenized_count == 100: + logging.warning("That's 100 lines that end in a tokenized period ('.')") + logging.warning( + 'It looks like you forgot to detokenize your test ' + 'data, which may hurt your score.' + ) logging.warning( - 'That\'s 100 lines that end in a tokenized period (\'.\')') - logging.warning('It looks like you forgot to detokenize your test ' - 'data, which may hurt your score.') - logging.warning('If you insist your data is detokenized, ' - 'or don\'t care, you can suppress this message with ' - '\'--force\'.') + 'If you insist your data is detokenized, ' + "or don't care, you can suppress this message with " + "'--force'." + ) output, *refs = [TOKENIZERS[tokenize](x.rstrip()) for x in lines] @@ -453,10 +467,11 @@ def corpus_bleu(sys_stream: Sequence[str], total = total.cpu().numpy().tolist() return compute_bleu( - correct, - total, - sys_len, - ref_len, - smooth_method=smooth_method, - smooth_value=smooth_value, - use_effective_order=use_effective_order) + correct, + total, + sys_len, + ref_len, + smooth_method=smooth_method, + smooth_value=smooth_value, + use_effective_order=use_effective_order, + ) diff --git a/algoperf/workloads/wmt/input_pipeline.py b/algoperf/workloads/wmt/input_pipeline.py index d743b43b0..1df1dfc55 100644 --- a/algoperf/workloads/wmt/input_pipeline.py +++ b/algoperf/workloads/wmt/input_pipeline.py @@ -1,4 +1,5 @@ """Input pipeline for a WMT dataset.""" + import functools import os from typing import Dict, List, Optional, Union @@ -16,10 +17,10 @@ Features = Dict[str, tf.Tensor] TFDS_SPLIT_NAME = { - 'train': 'train', - 'eval_train': 'train', - 'validation': 'validation', - 'test': 'test', + 'train': 'train', + 'eval_train': 'train', + 'validation': 'validation', + 'test': 'test', } @@ -31,9 +32,11 @@ def normalize_feature_names(ds_info, features: Features) -> Features: return features -def pack_dataset(dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None) -> tf.data.Dataset: +def pack_dataset( + dataset: tf.data.Dataset, + key2length: Union[int, Dict[str, int]], + keys: Optional[List[str]] = None, +) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. This is meant to replace the irritation of having to create a separate @@ -75,7 +78,8 @@ def pack_dataset(dataset: tf.data.Dataset, for k in keys: if k not in shapes: raise ValueError( - f'Key {k} not found in dataset. Available keys are {shapes.keys()}') + f'Key {k} not found in dataset. Available keys are {shapes.keys()}' + ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): raise ValueError('Tensors to be packed must be one-dimensional.') # make sure that the length dictionary contains all keys as well as the @@ -88,13 +92,15 @@ def pack_dataset(dataset: tf.data.Dataset, # trim to length dataset = dataset.map( - lambda x: {k: x[k][:key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE) + lambda x: {k: x[k][: key2length[k]] for k in keys}, + num_parallel_calls=AUTOTUNE, + ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys}) + batch_size, padded_shapes={k: [-1] for k in keys} + ) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. @@ -104,9 +110,9 @@ def my_fn(x): return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) -def _pack_with_tf_ops(dataset: tf.data.Dataset, - keys: List[str], - key2length: Dict[str, int]) -> tf.data.Dataset: +def _pack_with_tf_ops( + dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] +) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. Args: @@ -127,8 +133,9 @@ def write_packed_example(partial, outputs): new_outputs = {} for k in keys_etc: new_outputs[k] = outputs[k].write( - outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) + outputs[k].size(), + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + ) return new_partial, new_outputs def map_fn(x): @@ -146,9 +153,11 @@ def map_fn(x): outputs = {} for k in keys: outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) def body_fn(i, partial, outputs): """Body function for while_loop. @@ -163,13 +172,15 @@ def body_fn(i, partial, outputs): one_example = {} for k in keys: val = tf.cast(x[k][i], tf.int32) - val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( - can_append, - tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) + can_append, + tf.less_equal( + tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] + ), + ) def false_fn(): return write_packed_example(partial, outputs) @@ -180,49 +191,51 @@ def true_fn(): partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: - new_seq = one_example[k][:key2length[k]] + new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], tf.range(new_seq_len)], 0) + [partial[k + '_position'], tf.range(new_seq_len)], 0 + ) partial = new_partial return i + 1, partial, outputs # For loop over all examples in the batch. i, partial, outputs = tf.while_loop( - cond=lambda *_: True, - body=body_fn, - loop_vars=(i, partial, outputs), - shape_invariants=( - tf.TensorShape([]), - {k: tf.TensorShape([None]) for k in keys_etc}, - {k: tf.TensorShape(None) for k in keys_etc}, - ), - maximum_iterations=dynamic_batch_size) + cond=lambda *_: True, + body=body_fn, + loop_vars=(i, partial, outputs), + shape_invariants=( + tf.TensorShape([]), + {k: tf.TensorShape([None]) for k in keys_etc}, + {k: tf.TensorShape(None) for k in keys_etc}, + ), + maximum_iterations=dynamic_batch_size, + ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: - packed[k + '_segmentation'] = ( - tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) * - tf.cast(tf.not_equal(packed[k], 0), tf.int32)) + packed[k + '_segmentation'] = tf.cumsum( + tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 + ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) return dataset.unbatch() -def preprocess_wmt_data(dataset: tf.data.Dataset, - data_rng, - train: bool, - shuffle: bool, - shuffle_buffer_size: int = 1024, - max_length: int = 256, - global_batch_size: int = 128): +def preprocess_wmt_data( + dataset: tf.data.Dataset, + data_rng, + train: bool, + shuffle: bool, + shuffle_buffer_size: int = 1024, + max_length: int = 256, + global_batch_size: int = 128, +): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): - def filter_fn(x): source, target = x['inputs'], x['targets'] l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) @@ -242,24 +255,27 @@ def filter_fn(x): dataset = dataset.batch(global_batch_size, drop_remainder=train) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( - global_batch_size, - padded_shapes={'inputs': max_length, 'targets': max_length}, - padding_values={'inputs': 0, 'targets': 0}, - drop_remainder=False) + global_batch_size, + padded_shapes={'inputs': max_length, 'targets': max_length}, + padding_values={'inputs': 0, 'targets': 0}, + drop_remainder=False, + ) dataset = dataset.prefetch(AUTOTUNE) return dataset -def get_wmt_dataset(data_rng, - split: str, - data_dir: str, - is_training: bool, - vocab_size: int, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False, - vocab_path: Optional[str] = None): +def get_wmt_dataset( + data_rng, + split: str, + data_dir: str, + is_training: bool, + vocab_size: int, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + vocab_path: Optional[str] = None, +): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: vocab_path = os.path.join(data_dir, 'wmt_sentencepiece_model') @@ -271,7 +287,8 @@ def get_wmt_dataset(data_rng, dataset_builder = tfds.builder(ds_name, data_dir=data_dir) ds = dataset_builder.as_dataset( - split=TFDS_SPLIT_NAME[split], shuffle_files=False) + split=TFDS_SPLIT_NAME[split], shuffle_files=False + ) # Avoid creating too many threads when using PyTorch DDP. if RANK != 0: @@ -280,8 +297,9 @@ def get_wmt_dataset(data_rng, ds = ds.with_options(options) ds = ds.map( - functools.partial(normalize_feature_names, dataset_builder.info), - num_parallel_calls=AUTOTUNE) + functools.partial(normalize_feature_names, dataset_builder.info), + num_parallel_calls=AUTOTUNE, + ) # Load tf-text SentencePiece tokenizer. sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path) @@ -289,12 +307,13 @@ def get_wmt_dataset(data_rng, shuffle = split in ['train', 'eval_train'] ds = preprocess_wmt_data( - ds, - data_rng, - train=is_training, - shuffle=shuffle, - global_batch_size=global_batch_size, - max_length=256) + ds, + data_rng, + train=is_training, + shuffle=shuffle, + global_batch_size=global_batch_size, + max_length=256, + ) if num_batches: ds = ds.take(num_batches) @@ -303,9 +322,10 @@ def get_wmt_dataset(data_rng, ds = ds.repeat() ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) return ds, sp_tokenizer diff --git a/algoperf/workloads/wmt/tokenizer.py b/algoperf/workloads/wmt/tokenizer.py index 1f001e619..d94cab808 100644 --- a/algoperf/workloads/wmt/tokenizer.py +++ b/algoperf/workloads/wmt/tokenizer.py @@ -19,9 +19,9 @@ def _dump_chars_to_textfile( - dataset: tf.data.Dataset, - maxchars: int = int(1e7), - data_keys=('inputs', 'targets') + dataset: tf.data.Dataset, + maxchars: int = int(1e7), + data_keys=('inputs', 'targets'), ) -> Tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. @@ -36,7 +36,8 @@ def _dump_chars_to_textfile( char_count = 0 ds_iter = dataset.as_numpy_iterator() with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars') as outfp: + delete=False, prefix='/tmp/ds_chars' + ) as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: @@ -46,14 +47,16 @@ def _dump_chars_to_textfile( return outfp.name, char_count -def _train_sentencepiece(dataset: tf.data.Dataset, - *, - vocab_size: int, - maxchars: int = int(1e7), - model_path: str, - model_type: str = 'unigram', - character_coverage: float = 1.0, - data_keys=('inputs', 'targets')): +def _train_sentencepiece( + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + model_path: str, + model_type: str = 'unigram', + character_coverage: float = 1.0, + data_keys=('inputs', 'targets'), +): """Train SentencePiece tokenizer from subset of tf dataset. Args: @@ -75,17 +78,21 @@ def _train_sentencepiece(dataset: tf.data.Dataset, else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys) + dataset, maxchars=maxchars, data_keys=data_keys + ) with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + delete=False, prefix='/tmp/sp_tmp' + ) as model_fp: pass # we just want a prefix'd tmp-filename - argstr = ' '.join([ + argstr = ' '.join( + [ f'--input={fname}', f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', f'--model_prefix={model_fp.name}', f'--model_type={model_type}', - ]) + ] + ) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address @@ -100,32 +107,38 @@ def _train_sentencepiece(dataset: tf.data.Dataset, time.sleep(1) -def _load_sentencepiece_tokenizer(model_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False): +def _load_sentencepiece_tokenizer( + model_path: str, + add_bos: bool = False, + add_eos: bool = True, + reverse: bool = False, +): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with tf.io.gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse) + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + ) return sp_tokenizer -def train_tokenizer(dataset: tf.data.Dataset, - *, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - data_keys: Tuple[str, str] = ('inputs', 'targets')): +def train_tokenizer( + dataset: tf.data.Dataset, + *, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: Tuple[str, str] = ('inputs', 'targets'), +): """Trains a tokenizer from `dataset`.""" logging.info('Building SentencePiece vocab from data.') _train_sentencepiece( - dataset, - vocab_size=vocab_size, - maxchars=max_corpus_chars, - model_path=vocab_path, - data_keys=data_keys) + dataset, + vocab_size=vocab_size, + maxchars=max_corpus_chars, + model_path=vocab_path, + data_keys=data_keys, + ) def load_tokenizer(vocab_path: str): @@ -135,7 +148,6 @@ def load_tokenizer(vocab_path: str): @dataclasses.dataclass class TokenizeOp: - sp_tokenizer: Any data_keys: Iterable[str] = ('inputs', 'targets') diff --git a/algoperf/workloads/wmt/wmt_jax/decode.py b/algoperf/workloads/wmt/wmt_jax/decode.py index dfead5918..b5f5f1099 100644 --- a/algoperf/workloads/wmt/wmt_jax/decode.py +++ b/algoperf/workloads/wmt/wmt_jax/decode.py @@ -78,8 +78,9 @@ def gather_beams(nested, beam_indices, batch_size, new_beam_size): [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] """ batch_indices = jnp.reshape( - jnp.arange(batch_size * new_beam_size) // new_beam_size, - (batch_size, new_beam_size)) + jnp.arange(batch_size * new_beam_size) // new_beam_size, + (batch_size, new_beam_size), + ) def gather_fn(x): if x.ndim < 2: # ignore scalars (e.g. cache index) @@ -114,6 +115,7 @@ def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size): @flax.struct.dataclass class BeamState: """Holds beam search state data.""" + # The position of the decoding loop in the length dimension. cur_index: jax.Array # scalar int32: current decoded length index # The active sequence log probabilities and finished sequence scores. @@ -133,7 +135,8 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): """Initializes the beam search state data structure.""" cur_index0 = jnp.array(0) live_logprobs0 = jnp.tile( - jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]) + jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1] + ) finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) @@ -141,25 +144,28 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): # add beam dimension to attention cache pytree elements beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( - cur_index=cur_index0, - live_logprobs=live_logprobs0, - finished_scores=finished_scores0, - live_seqs=live_seqs0, - finished_seqs=finished_seqs0, - finished_flags=finished_flags0, - cache=beam_cache0) + cur_index=cur_index0, + live_logprobs=live_logprobs0, + finished_scores=finished_scores0, + live_seqs=live_seqs0, + finished_seqs=finished_seqs0, + finished_flags=finished_flags0, + cache=beam_cache0, + ) # Beam search routine: -def beam_search(inputs, - cache, - tokens_to_logits, - beam_size=4, - alpha=0.6, - eos_id=EOS_ID, - max_decode_len=None): +def beam_search( + inputs, + cache, + tokens_to_logits, + beam_size=4, + alpha=0.6, + eos_id=EOS_ID, + max_decode_len=None, +): """Beam search for transformer machine translation. Args: @@ -185,10 +191,9 @@ def beam_search(inputs, end_marker = jnp.array(eos_id) # initialize beam search state - beam_search_init_state = beam_init(batch_size, - beam_size, - max_decode_len, - cache) + beam_search_init_state = beam_init( + batch_size, beam_size, max_decode_len, cache + ) def beam_search_loop_cond_fn(state): """Beam search loop termination condition.""" @@ -201,11 +206,12 @@ def beam_search_loop_cond_fn(state): best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty # Get the worst scores from finished sequences. worst_finished_scores = jnp.min( - state.finished_scores, axis=1, keepdims=True) + state.finished_scores, axis=1, keepdims=True + ) # Mask out scores from slots without any actual finished sequences. - worst_finished_scores = jnp.where(state.finished_flags, - worst_finished_scores, - NEG_INF) + worst_finished_scores = jnp.where( + state.finished_flags, worst_finished_scores, NEG_INF + ) # If no best possible live score is better than current worst finished # scores, the search cannot improve the finished set further. search_terminated = jnp.all(worst_finished_scores > best_live_scores) @@ -221,8 +227,10 @@ def beam_search_loop_body_fn(state): # dimension for feeding into the model. # --> [batch * beam, 1] flat_ids = flatten_beam_dim( - lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index), - (batch_size, beam_size, 1))) + lax.dynamic_slice( + state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1) + ) + ) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} flat_cache = jax.tree.map(flatten_beam_dim, state.cache) @@ -237,14 +245,16 @@ def beam_search_loop_body_fn(state): # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} new_cache = jax.tree.map( - lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) + lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache + ) # Gather log probabilities from logits candidate_log_probs = jax.nn.log_softmax(logits) # Add new logprobs to existing prefix logprobs. # --> [batch, beam, vocab] - log_probs = ( - candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2)) + log_probs = candidate_log_probs + jnp.expand_dims( + state.live_logprobs, axis=2 + ) # We'll need the vocab size, gather it from the log probability dimension. vocab_size = log_probs.shape[2] @@ -264,10 +274,9 @@ def beam_search_loop_body_fn(state): topk_beam_indices = topk_indices // vocab_size # Gather 2*k top beams. # --> [batch, 2*beams, length] - topk_seq = gather_beams(state.live_seqs, - topk_beam_indices, - batch_size, - beams_to_keep) + topk_seq = gather_beams( + state.live_seqs, topk_beam_indices, batch_size, beams_to_keep + ) # Append the most probable 2*K token IDs to the top 2*K sequences # Recover token id by modulo division and expand Id array for broadcasting. @@ -275,13 +284,14 @@ def beam_search_loop_body_fn(state): topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) # Update sequences for the 2*K top-k new sequences. # --> [batch, 2*beams, length] - topk_seq = lax.dynamic_update_slice(topk_seq, - topk_ids, (0, 0, state.cur_index + 1)) + topk_seq = lax.dynamic_update_slice( + topk_seq, topk_ids, (0, 0, state.cur_index + 1) + ) # Update LIVE (in-progress) sequences: # Did any of these sequences reach an end marker? # --> [batch, 2*beams] - newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker) + newly_finished = topk_seq[:, :, state.cur_index + 1] == end_marker # To prevent these newly finished sequences from being added to the LIVE # set of active beam search sequences, set their log probs to a very large # negative value. @@ -292,22 +302,20 @@ def beam_search_loop_body_fn(state): new_topk_indices = jnp.flip(new_topk_indices, axis=1) # Gather the top k beams (from top 2*k beams). # --> [batch, beams, length], [batch, beams] - top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs], - new_topk_indices, - batch_size, beam_size) + top_alive_seq, top_alive_log_probs = gather_beams( + [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size + ) # Determine the top k beam indices from the original set of all beams. # --> [batch, beams] - top_alive_indices = gather_beams(topk_beam_indices, - new_topk_indices, - batch_size, - beam_size) + top_alive_indices = gather_beams( + topk_beam_indices, new_topk_indices, batch_size, beam_size + ) # With these, gather the top k beam-associated caches. # --> {[batch, beams, ...], ...} - top_alive_cache = gather_beams(new_cache, - top_alive_indices, - batch_size, - beam_size) + top_alive_cache = gather_beams( + new_cache, top_alive_indices, batch_size, beam_size + ) # Update FINISHED (reached end of sentence) sequences: # Calculate new seq scores from log probabilities. @@ -320,42 +328,54 @@ def beam_search_loop_body_fn(state): # new finished sequence scores to existing finished scores and select the # best from the new set of beams. finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] - [state.finished_seqs, topk_seq], - axis=1) + [state.finished_seqs, topk_seq], axis=1 + ) finished_scores = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_scores, new_scores], axis=1) + [state.finished_scores, new_scores], axis=1 + ) finished_flags = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_flags, newly_finished], axis=1) + [state.finished_flags, newly_finished], axis=1 + ) # --> [batch, beams, length], [batch, beams], [batch, beams] top_finished_seq, top_finished_scores, top_finished_flags = ( - gather_topk_beams([finished_seqs, finished_scores, finished_flags], - finished_scores, batch_size, beam_size)) + gather_topk_beams( + [finished_seqs, finished_scores, finished_flags], + finished_scores, + batch_size, + beam_size, + ) + ) return BeamState( - cur_index=state.cur_index + 1, - live_logprobs=top_alive_log_probs, - finished_scores=top_finished_scores, - live_seqs=top_alive_seq, - finished_seqs=top_finished_seq, - finished_flags=top_finished_flags, - cache=top_alive_cache) + cur_index=state.cur_index + 1, + live_logprobs=top_alive_log_probs, + finished_scores=top_finished_scores, + live_seqs=top_alive_seq, + finished_seqs=top_finished_seq, + finished_flags=top_finished_flags, + cache=top_alive_cache, + ) # Run while loop and get final beam search state. - final_state = lax.while_loop(beam_search_loop_cond_fn, - beam_search_loop_body_fn, - beam_search_init_state) + final_state = lax.while_loop( + beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state + ) # Account for the edge-case where there are no finished sequences for a # particular batch item. If so, return live sequences for that batch item. # --> [batch] none_finished = jnp.any(final_state.finished_flags, axis=1) # --> [batch, beams, length] - finished_seqs = jnp.where(none_finished[:, None, None], - final_state.finished_seqs, - final_state.live_seqs) + finished_seqs = jnp.where( + none_finished[:, None, None], + final_state.finished_seqs, + final_state.live_seqs, + ) # --> [batch, beams] - finished_scores = jnp.where(none_finished[:, None], - final_state.finished_scores, - final_state.live_logprobs) + finished_scores = jnp.where( + none_finished[:, None], + final_state.finished_scores, + final_state.live_logprobs, + ) return finished_seqs, finished_scores diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 1147eb34b..81f2ece4c 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -5,11 +5,11 @@ from typing import Any, Callable, Optional +import jax.numpy as jnp +import numpy as np from flax import linen as nn from flax import struct from jax import lax -import jax.numpy as jnp -import numpy as np from algoperf.jax_utils import Dropout @@ -19,6 +19,7 @@ @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + share_embeddings: bool = True dtype: Any = jnp.float32 vocab_size: int = 32000 @@ -44,7 +45,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(0) + ) return padded[:, :-1] @@ -68,8 +70,8 @@ def init(key, shape, dtype=np.float32): position = np.arange(0, max_len)[:, np.newaxis] scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) - pe[:, :d_feature // 2] = np.sin(position * div_term) - pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term) + pe[:, : d_feature // 2] = np.sin(position * div_term) + pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) @@ -83,6 +85,7 @@ class AddPositionEmbs(nn.Module): config: TransformerConfig dataclass containing hyperparameters. decode: whether to run in single-position autoregressive mode. """ + config: TransformerConfig decode: bool = False @@ -103,27 +106,28 @@ def __call__(self, inputs, inputs_positions=None): """ cfg = self.config # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - f' but it is: {inputs.ndim}') + assert inputs.ndim == 3, ( + f'Number of dimensions should be 3, but it is: {inputs.ndim}' + ) length = inputs.shape[1] pos_emb_shape = (1, cfg.max_len, inputs.shape[-1]) if cfg.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=cfg.max_len)(None, - pos_emb_shape, - None) + pos_embedding = sinusoidal_init(max_len=cfg.max_len)( + None, pos_emb_shape, None + ) else: - pos_embedding = self.param('pos_embedding', - cfg.posemb_init, - pos_emb_shape) + pos_embedding = self.param( + 'pos_embedding', cfg.posemb_init, pos_emb_shape + ) pe = pos_embedding[:, :length, :] # We use a cache position index for tracking decoding position. if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', - 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 @@ -140,10 +144,10 @@ def __call__(self, inputs, inputs_positions=None): class MlpBlock(nn.Module): """Transformer MLP / feed-forward block. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - out_dim: optionally specify out dimension. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + out_dim: optionally specify out dimension. + """ config: TransformerConfig out_dim: Optional[int] = None @@ -155,42 +159,41 @@ def __call__(self, inputs, dropout_rate=DROPOUT_RATE): actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(inputs) + x = cfg.activation(x) + if cfg.glu: + y = nn.Dense( cfg.mlp_dim, dtype=cfg.dtype, kernel_init=cfg.kernel_init, bias_init=cfg.bias_init, - )( - inputs) - x = cfg.activation(x) - if cfg.glu: - y = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - )( - inputs) + )(inputs) x = x * y x = Dropout(rate=dropout_rate)( - x, rate=dropout_rate, deterministic=cfg.deterministic) + x, rate=dropout_rate, deterministic=cfg.deterministic + ) output = nn.Dense( - actual_out_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - )( - x) + actual_out_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(x) output = Dropout(rate=dropout_rate)( - output, rate=dropout_rate, deterministic=cfg.deterministic) + output, rate=dropout_rate, deterministic=cfg.deterministic + ) return output class Encoder1DBlock(nn.Module): """Transformer encoder layer. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + """ config: TransformerConfig @@ -198,13 +201,13 @@ class Encoder1DBlock(nn.Module): def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE): """Applies Encoder1DBlock module. - Args: - inputs: input data. - encoder_mask: encoder self-attention mask. + Args: + inputs: input data. + encoder_mask: encoder self-attention mask. - Returns: - output after transformer encoder block. - """ + Returns: + output after transformer encoder block. + """ cfg = self.config pre_ln = cfg.pre_ln @@ -213,19 +216,20 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE): assert inputs.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=dropout_rate, - deterministic=cfg.deterministic, + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, )(cfg.attention_temp * x, x, mask=encoder_mask) x = Dropout(rate=dropout_rate)( - x, deterministic=cfg.deterministic, rate=dropout_rate) + x, deterministic=cfg.deterministic, rate=dropout_rate + ) x = x + inputs if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -240,32 +244,32 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE): class EncoderDecoder1DBlock(nn.Module): """Transformer encoder-decoder layer. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + """ config: TransformerConfig @nn.compact def __call__( - self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - dropout_rate=DROPOUT_RATE, + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=DROPOUT_RATE, ): """Applies EncoderDecoder1DBlock module. - Args: - targets: input data for decoder - encoded: input data from encoder - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. + Args: + targets: input data for decoder + encoded: input data from encoder + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. - Returns: - output after transformer encoder-decoder block. - """ + Returns: + output after transformer encoder-decoder block. + """ cfg = self.config pre_ln = cfg.pre_ln @@ -275,20 +279,21 @@ def __call__( x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=dropout_rate, - deterministic=cfg.deterministic, - decode=cfg.decode, + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + decode=cfg.decode, )(cfg.attention_temp * x, x, mask=decoder_mask) x = Dropout(rate=dropout_rate)( - x, deterministic=cfg.deterministic, rate=dropout_rate) + x, deterministic=cfg.deterministic, rate=dropout_rate + ) x = x + targets if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -296,19 +301,20 @@ def __call__( # Encoder-Decoder block. y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x y = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=dropout_rate, - deterministic=cfg.deterministic, + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) y = Dropout(rate=dropout_rate)( - y, deterministic=cfg.deterministic, rate=dropout_rate) + y, deterministic=cfg.deterministic, rate=dropout_rate + ) y = y + x if not pre_ln: y = nn.LayerNorm(dtype=cfg.dtype)(y) @@ -323,30 +329,32 @@ def __call__( class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - shared_embedding: a shared embedding layer to use. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. + """ config: TransformerConfig shared_embedding: Any = None @nn.compact - def __call__(self, - inputs, - inputs_positions=None, - encoder_mask=None, - dropout_rate=DROPOUT_RATE): + def __call__( + self, + inputs, + inputs_positions=None, + encoder_mask=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer model on the inputs. - Args: - inputs: input data - inputs_positions: input subsequence positions for packed examples. - encoder_mask: decoder self-attention mask. + Args: + inputs: input data + inputs_positions: input subsequence positions for packed examples. + encoder_mask: decoder self-attention mask. - Returns: - output of a transformer encoder. - """ + Returns: + output of a transformer encoder. + """ cfg = self.config assert inputs.ndim == 2 # (batch, len) @@ -354,30 +362,34 @@ def __call__(self, # Input Embedding if self.shared_embedding is None: input_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0), + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), ) else: input_embed = self.shared_embedding - x = inputs.astype("int32") + x = inputs.astype('int32') x = input_embed(x) - x = AddPositionEmbs( - config=cfg, decode=False, name="posembed_input")( - x, inputs_positions=inputs_positions) + x = AddPositionEmbs(config=cfg, decode=False, name='posembed_input')( + x, inputs_positions=inputs_positions + ) x = Dropout(rate=dropout_rate)( - x, deterministic=cfg.deterministic, rate=dropout_rate) + x, deterministic=cfg.deterministic, rate=dropout_rate + ) x = x.astype(cfg.dtype) # Input Encoder for lyr in range(cfg.num_layers): - x = Encoder1DBlock( - config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate) + x = Encoder1DBlock(config=cfg, name=f'encoderblock_{lyr}')( + x, encoder_mask, dropout_rate + ) encoded = ( - nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x) - if cfg.pre_ln else x) + nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x) + if cfg.pre_ln + else x + ) return encoded @@ -385,36 +397,36 @@ def __call__(self, class Decoder(nn.Module): """Transformer Model Decoder for sequence to sequence translation. - Attributes: - config: TransformerConfig dataclass containing hyperparameters. - shared_embedding: a shared embedding layer to use. - """ + Attributes: + config: TransformerConfig dataclass containing hyperparameters. + shared_embedding: a shared embedding layer to use. + """ config: TransformerConfig shared_embedding: Any = None @nn.compact def __call__( - self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None, - dropout_rate=DROPOUT_RATE, + self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=DROPOUT_RATE, ): """Applies Transformer model on the inputs. - Args: - encoded: encoded input data from encoder. - targets: target inputs. - targets_positions: input subsequence positions for packed examples. - decoder_mask: decoder self-attention mask. - encoder_decoder_mask: encoder-decoder attention mask. + Args: + encoded: encoded input data from encoder. + targets: target inputs. + targets_positions: input subsequence positions for packed examples. + decoder_mask: decoder self-attention mask. + encoder_decoder_mask: encoder-decoder attention mask. - Returns: - output of a transformer decoder. - """ + Returns: + output of a transformer decoder. + """ cfg = self.config assert encoded.ndim == 3 # (batch, len, depth) @@ -423,38 +435,40 @@ def __call__( # Target Embedding if self.shared_embedding is None: output_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0), + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), ) else: output_embed = self.shared_embedding - y = targets.astype("int32") + y = targets.astype('int32') if not cfg.decode: y = shift_right(y) y = output_embed(y) - y = AddPositionEmbs( - config=cfg, decode=cfg.decode, name="posembed_output")( - y, inputs_positions=targets_positions) + y = AddPositionEmbs(config=cfg, decode=cfg.decode, name='posembed_output')( + y, inputs_positions=targets_positions + ) y = Dropout(rate=dropout_rate)( - y, deterministic=cfg.deterministic, rate=dropout_rate) + y, deterministic=cfg.deterministic, rate=dropout_rate + ) y = y.astype(cfg.dtype) # Target-Input Decoder for lyr in range(cfg.num_layers): - y = EncoderDecoder1DBlock( - config=cfg, name=f"encoderdecoderblock_{lyr}")( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - dropout_rate=dropout_rate, - ) + y = EncoderDecoder1DBlock(config=cfg, name=f'encoderdecoderblock_{lyr}')( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate, + ) y = ( - nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y) - if cfg.pre_ln else y) + nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y) + if cfg.pre_ln + else y + ) # Use the transpose of embedding matrix for logit transform. logits = output_embed.attend(y.astype(jnp.float32)) @@ -469,6 +483,7 @@ class Transformer(nn.Module): Attributes: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig def setup(self): @@ -477,22 +492,26 @@ def setup(self): if cfg.share_embeddings: if cfg.vocab_size is not None: assert cfg.vocab_size == cfg.vocab_size, ( - "can't share embedding with different vocab sizes.") + "can't share embedding with different vocab sizes." + ) self.shared_embedding = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: self.shared_embedding = None self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - def encode(self, - inputs, - inputs_positions=None, - inputs_segmentation=None, - dropout_rate=DROPOUT_RATE): + def encode( + self, + inputs, + inputs_positions=None, + inputs_segmentation=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer encoder-branch on the inputs. Args: @@ -506,31 +525,33 @@ def encode(self, cfg = self.config # Make padding attention mask. encoder_mask = nn.make_attention_mask( - inputs > 0, inputs > 0, dtype=cfg.dtype) + inputs > 0, inputs > 0, dtype=cfg.dtype + ) # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: encoder_mask = nn.combine_masks( - encoder_mask, - nn.make_attention_mask( - inputs_segmentation, - inputs_segmentation, - jnp.equal, - dtype=cfg.dtype)) + encoder_mask, + nn.make_attention_mask( + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=cfg.dtype + ), + ) return self.encoder( - inputs, - inputs_positions=inputs_positions, - encoder_mask=encoder_mask, - dropout_rate=dropout_rate) + inputs, + inputs_positions=inputs_positions, + encoder_mask=encoder_mask, + dropout_rate=dropout_rate, + ) def decode( - self, - encoded, - inputs, # only needed for masks - targets, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None, - dropout_rate=DROPOUT_RATE): + self, + encoded, + inputs, # only needed for masks + targets, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer decoder-branch on encoded-input and target. Args: @@ -550,47 +571,51 @@ def decode( if cfg.decode: decoder_mask = None encoder_decoder_mask = nn.make_attention_mask( - jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype) + jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype + ) else: decoder_mask = nn.combine_masks( - nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype), - nn.make_causal_mask(targets, dtype=cfg.dtype)) + nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype), + nn.make_causal_mask(targets, dtype=cfg.dtype), + ) encoder_decoder_mask = nn.make_attention_mask( - targets > 0, inputs > 0, dtype=cfg.dtype) + targets > 0, inputs > 0, dtype=cfg.dtype + ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: decoder_mask = nn.combine_masks( - decoder_mask, - nn.make_attention_mask( - targets_segmentation, - targets_segmentation, - jnp.equal, - dtype=cfg.dtype)) + decoder_mask, + nn.make_attention_mask( + targets_segmentation, targets_segmentation, jnp.equal, dtype=cfg.dtype + ), + ) encoder_decoder_mask = nn.combine_masks( - encoder_decoder_mask, - nn.make_attention_mask( - targets_segmentation, - inputs_segmentation, - jnp.equal, - dtype=cfg.dtype)) + encoder_decoder_mask, + nn.make_attention_mask( + targets_segmentation, inputs_segmentation, jnp.equal, dtype=cfg.dtype + ), + ) logits = self.decoder( - encoded, - targets, - targets_positions=targets_positions, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - dropout_rate=dropout_rate) + encoded, + targets, + targets_positions=targets_positions, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate, + ) return logits.astype(self.config.dtype) - def __call__(self, - inputs, - targets, - inputs_positions=None, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None, - dropout_rate=DROPOUT_RATE): + def __call__( + self, + inputs, + targets, + inputs_positions=None, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer model on the inputs. Args: @@ -605,16 +630,18 @@ def __call__(self, logits array from full transformer. """ encoded = self.encode( - inputs, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation, - dropout_rate=dropout_rate) + inputs, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + ) return self.decode( - encoded, - inputs, # only used for masks - targets, - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation, - dropout_rate=dropout_rate) + encoded, + inputs, # only used for masks + targets, + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + dropout_rate=dropout_rate, + ) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index d402f9d95..51d8a85a7 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -1,23 +1,21 @@ """WMT workload implemented in Jax.""" -from dataclasses import replace import functools +from dataclasses import replace from typing import Any, Dict, Iterator, Optional, Tuple -from absl import logging -from flax import jax_utils -from flax import linen as nn -from flax.training import common_utils import jax import jax.numpy as jnp import numpy as np import optax +from absl import logging +from flax import jax_utils +from flax import linen as nn +from flax.training import common_utils -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.wmt import bleu -from algoperf.workloads.wmt.wmt_jax import decode -from algoperf.workloads.wmt.wmt_jax import models +from algoperf.workloads.wmt.wmt_jax import decode, models from algoperf.workloads.wmt.workload import BaseWmtWorkload @@ -31,11 +29,12 @@ class WmtWorkload(BaseWmtWorkload): """WMT Jax workload.""" def compute_weighted_cross_entropy( - self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.1) -> Dict[str, spec.Tensor]: # differentiable + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy and entropy for log probs and targets. Args: @@ -50,76 +49,86 @@ def compute_weighted_cross_entropy( valid examples in batch, 'per_example': 1-d array of per-example losses} """ if logits.ndim != targets.ndim + 1: - raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and ' - f'{targets.shape} targets.') + raise ValueError( + f'Incorrect shapes. Got shape {logits.shape} logits and ' + f'{targets.shape} targets.' + ) smoothed_targets = optax.smooth_labels( - common_utils.onehot(targets, self._vocab_size), label_smoothing) + common_utils.onehot(targets, self._vocab_size), label_smoothing + ) per_example_losses = -jnp.sum( - smoothed_targets * nn.log_softmax(logits), axis=-1) + smoothed_targets * nn.log_softmax(logits), axis=-1 + ) if weights is None: weights = jnp.ones_like(targets) - per_example_losses = jnp.where(weights, per_example_losses, 0.) + per_example_losses = jnp.where(weights, per_example_losses, 0.0) summed_loss = per_example_losses.sum() n_valid_examples = weights.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) + jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,) + ) def eval_step_pmapped( - self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> Dict[str, spec.Tensor]: """Calculate evaluation metrics on a batch.""" inputs = batch['inputs'] targets = batch['targets'] weights = batch['weights'] logits = self._eval_model.apply({'params': params}, inputs, targets) - summed_loss = self.compute_weighted_cross_entropy(logits, - targets, - weights, - 0.0)['summed'] + summed_loss = self.compute_weighted_cross_entropy( + logits, targets, weights, 0.0 + )['summed'] acc_sum, weight_sum = self.compute_weighted_accuracy( - logits, targets, weights) + logits, targets, weights + ) return { - 'loss': summed_loss, - 'accuracy': acc_sum, - 'denominator': weight_sum, + 'loss': summed_loss, + 'accuracy': acc_sum, + 'denominator': weight_sum, } - def eval_step(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + def eval_step( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> Dict[str, spec.Tensor]: replicated_eval_metrics = self.eval_step_pmapped(params, batch) return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) - def initialize_cache(self, - inputs: spec.Tensor, - max_decode_len: int = 256) -> Dict[str, spec.Tensor]: + jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,) + ) + def initialize_cache( + self, inputs: spec.Tensor, max_decode_len: int = 256 + ) -> Dict[str, spec.Tensor]: """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), - jnp.ones(inputs.shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + jax.random.PRNGKey(0), + jnp.ones(inputs.shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + ) return initial_variables['cache'] # eos_id, max_decode_len are constant. @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5)) - def predict_step(self, - inputs: spec.Tensor, - params: spec.ParameterContainer, - cache: Dict[str, spec.Tensor], - eos_id: int, - max_decode_len: int, - beam_size: int = 4) -> spec.Tensor: + jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5) + ) + def predict_step( + self, + inputs: spec.Tensor, + params: spec.ParameterContainer, + cache: Dict[str, spec.Tensor], + eos_id: int, + max_decode_len: int, + beam_size: int = 4, + ) -> spec.Tensor: """Predict translation with fast decoding beam search on a batch.""" config = replace(self._eval_model.config, decode=True) # Prepare transformer fast-decoder call for beam search: for beam search, we @@ -129,27 +138,29 @@ def predict_step(self, # i.e. if we denote each batch element subtensor as el[n]: # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] encoded_inputs = decode.flat_batch_beam_expand( - models.Transformer(config).apply({'params': params}, - inputs, - method=models.Transformer.encode), - beam_size) + models.Transformer(config).apply( + {'params': params}, inputs, method=models.Transformer.encode + ), + beam_size, + ) raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size) def tokens_ids_to_logits( - flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor] + flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor] ) -> Tuple[spec.Tensor, Dict[str, spec.Tensor]]: """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] flat_logits, new_vars = models.Transformer(config).apply( - { - 'params': params, - 'cache': flat_cache, - }, - encoded_inputs, - raw_inputs, # only needed for input padding mask - flat_ids, - mutable=['cache'], - method=models.Transformer.decode) + { + 'params': params, + 'cache': flat_cache, + }, + encoded_inputs, + raw_inputs, # only needed for input padding mask + flat_ids, + mutable=['cache'], + method=models.Transformer.decode, + ) new_flat_cache = new_vars['cache'] # Remove singleton sequence-length dimension: # [batch * beam, 1, vocab] --> [batch * beam, vocab] @@ -159,35 +170,36 @@ def tokens_ids_to_logits( # Using the above-defined single-step decoder function, run a # beam search over possible sequences given input encoding. beam_seqs, _ = decode.beam_search( - inputs, - cache, - tokens_ids_to_logits, - beam_size=beam_size, - alpha=0.6, - eos_id=eos_id, - max_decode_len=max_decode_len) + inputs, + cache, + tokens_ids_to_logits, + beam_size=beam_size, + alpha=0.6, + eos_id=eos_id, + max_decode_len=max_decode_len, + ) # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension # sorted in increasing order of log-probability. # Return the highest scoring beam sequence, drop first dummy 0 token. return beam_seqs[:, -1, 1:] - def translate_and_calculate_bleu(self, - params: spec.ParameterContainer, - ds_iter: Iterator, - num_batches: int, - max_predict_length: int) -> spec.Tensor: + def translate_and_calculate_bleu( + self, + params: spec.ParameterContainer, + ds_iter: Iterator, + num_batches: int, + max_predict_length: int, + ) -> spec.Tensor: """Translates the `predict_ds` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] for _ in range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - predicted = self.predict_step(pred_batch['inputs'], - params, - cache, - decode.EOS_ID, - max_predict_length) + predicted = self.predict_step( + pred_batch['inputs'], params, cache, decode.EOS_ID, max_predict_length + ) predicted = _to_host(predicted) targets = _to_host(pred_batch['targets']) # Find actual batch size, ignoring the potential padding. @@ -219,18 +231,20 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: raise ValueError(f'Unknown activation function {self.activation}.') model_config = models.TransformerConfig( - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu, + ) self._train_model = models.Transformer(model_config) eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) params_rng, _ = jax.random.split(rng) - initial_variables = jax.jit( - self._eval_model.init)({'params': params_rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + initial_variables = jax.jit(self._eval_model.init)( + {'params': params_rng}, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + ) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) @@ -241,14 +255,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'shared_embedding' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: [float] = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: [float] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm @@ -256,33 +270,39 @@ def model_fn( inputs = augmented_and_preprocessed_input_batch.get('inputs', None) targets = augmented_and_preprocessed_input_batch.get('targets', None) inputs_positions = augmented_and_preprocessed_input_batch.get( - 'inputs_position', None) + 'inputs_position', None + ) targets_positions = augmented_and_preprocessed_input_batch.get( - 'targets_position', None) + 'targets_position', None + ) inputs_segmentations = augmented_and_preprocessed_input_batch.get( - 'inputs_segmentation', None) + 'inputs_segmentation', None + ) targets_segmentations = augmented_and_preprocessed_input_batch.get( - 'targets_segmentation', None) + 'targets_segmentation', None + ) if mode == spec.ForwardPassMode.TRAIN: model = self._train_model else: model = self._eval_model - logits_batch = model.apply({'params': params}, - inputs, - targets, - inputs_positions=inputs_positions, - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentations, - targets_segmentation=targets_segmentations, - rngs={'dropout': rng}, - dropout_rate=dropout_rate) + logits_batch = model.apply( + {'params': params}, + inputs, + targets, + inputs_positions=inputs_positions, + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentations, + targets_segmentation=targets_segmentations, + rngs={'dropout': rng}, + dropout_rate=dropout_rate, + ) return logits_batch, None def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples eval_denominator = total_metrics.pop('denominator') diff --git a/algoperf/workloads/wmt/wmt_pytorch/decode.py b/algoperf/workloads/wmt/wmt_pytorch/decode.py index 26ff36650..7974412d7 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/decode.py +++ b/algoperf/workloads/wmt/wmt_pytorch/decode.py @@ -21,8 +21,9 @@ NEG_INF = torch.tensor(-1.0e7, device=DEVICE) -def brevity_penalty(alpha: float, length: Union[int, - torch.Tensor]) -> torch.Tensor: +def brevity_penalty( + alpha: float, length: Union[int, torch.Tensor] +) -> torch.Tensor: """Brevity penalty function for beam search penalizing short sequences. Args: @@ -57,8 +58,9 @@ def flatten_beam_dim(x: torch.Tensor) -> torch.Tensor: return x.view(-1, *x.shape[2:]) -def unflatten_beam_dim(x: torch.Tensor, batch_size: int, - beam_size: int) -> torch.Tensor: +def unflatten_beam_dim( + x: torch.Tensor, batch_size: int, beam_size: int +) -> torch.Tensor: """Unflattens the first, flat batch*beam dimension of a non-scalar tensor.""" if x.dim() < 2: # ignore scalars (e.g. cache index) return x @@ -71,10 +73,12 @@ def flat_batch_beam_expand(x: torch.Tensor, beam_size: int) -> torch.Tensor: return flatten_beam_dim(add_beam_dim(x, beam_size)) -def gather_beams(nested: Dict[str, Any], - beam_indices: torch.Tensor, - batch_size: int, - new_beam_size: int) -> Dict[str, Any]: +def gather_beams( + nested: Dict[str, Any], + beam_indices: torch.Tensor, + batch_size: int, + new_beam_size: int, +) -> Dict[str, Any]: """Gathers the beam slices indexed by beam_indices into new beam tensor. Args: @@ -88,10 +92,13 @@ def gather_beams(nested: Dict[str, Any], [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] """ batch_indices = torch.reshape( - torch.div( - torch.arange(batch_size * new_beam_size, device=DEVICE), - new_beam_size, - rounding_mode='floor'), (batch_size, new_beam_size)) + torch.div( + torch.arange(batch_size * new_beam_size, device=DEVICE), + new_beam_size, + rounding_mode='floor', + ), + (batch_size, new_beam_size), + ) def gather_fn(x): if x.dim() < 2: # ignore scalars (e.g. cache index) @@ -101,10 +108,12 @@ def gather_fn(x): return jax.tree.map(gather_fn, nested) -def gather_topk_beams(nested: Dict[str, Any], - score_or_log_prob: torch.Tensor, - batch_size: int, - new_beam_size: int) -> Dict[str, Any]: +def gather_topk_beams( + nested: Dict[str, Any], + score_or_log_prob: torch.Tensor, + batch_size: int, + new_beam_size: int, +) -> Dict[str, Any]: """Gathers the top-k beam slices given by score_or_log_prob array. Args: @@ -129,6 +138,7 @@ def gather_topk_beams(nested: Dict[str, Any], @dataclass class BeamState: """Holds beam search state data.""" + # The position of the decoding loop in the length dimension. cur_index: torch.Tensor # scalar int32: current decoded length index. # The active sequence log probabilities and finished sequence scores. @@ -143,49 +153,52 @@ class BeamState: cache: Dict[str, Any] # Any dict (of dicts), with torch.Tensors as leafs. -def beam_init(batch_size: int, - beam_size: int, - max_decode_len: int, - cache: Dict[str, Any]) -> BeamState: +def beam_init( + batch_size: int, beam_size: int, max_decode_len: int, cache: Dict[str, Any] +) -> BeamState: """Initializes the beam search state data structure.""" cur_index0 = torch.tensor(0, device=DEVICE) live_logprobs0 = torch.tile( - torch.tensor([0.0] + [NEG_INF] * (beam_size - 1), device=DEVICE), - [batch_size, 1]) + torch.tensor([0.0] + [NEG_INF] * (beam_size - 1), device=DEVICE), + [batch_size, 1], + ) finished_scores0 = ( - torch.ones((batch_size, beam_size), device=DEVICE) * NEG_INF) - live_seqs0 = torch.zeros((batch_size, beam_size, max_decode_len), - dtype=torch.int32, - device=DEVICE) - finished_seqs0 = torch.zeros((batch_size, beam_size, max_decode_len), - dtype=torch.int32, - device=DEVICE) - finished_flags0 = torch.zeros((batch_size, beam_size), - dtype=torch.bool, - device=DEVICE) + torch.ones((batch_size, beam_size), device=DEVICE) * NEG_INF + ) + live_seqs0 = torch.zeros( + (batch_size, beam_size, max_decode_len), dtype=torch.int32, device=DEVICE + ) + finished_seqs0 = torch.zeros( + (batch_size, beam_size, max_decode_len), dtype=torch.int32, device=DEVICE + ) + finished_flags0 = torch.zeros( + (batch_size, beam_size), dtype=torch.bool, device=DEVICE + ) # add beam dimension to attention cache pytree elements beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( - cur_index=cur_index0, - live_logprobs=live_logprobs0, - finished_scores=finished_scores0, - live_seqs=live_seqs0, - finished_seqs=finished_seqs0, - finished_flags=finished_flags0, - cache=beam_cache0) + cur_index=cur_index0, + live_logprobs=live_logprobs0, + finished_scores=finished_scores0, + live_seqs=live_seqs0, + finished_seqs=finished_seqs0, + finished_flags=finished_flags0, + cache=beam_cache0, + ) # Beam search routine: def beam_search( - inputs: torch.Tensor, - cache: Optional[Dict[str, Any]], - tokens_to_logits: Callable, - beam_size: int = 4, - alpha: float = 0.6, - eos_id: int = EOS_ID, - max_decode_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + inputs: torch.Tensor, + cache: Optional[Dict[str, Any]], + tokens_to_logits: Callable, + beam_size: int = 4, + alpha: float = 0.6, + eos_id: int = EOS_ID, + max_decode_len: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: """Beam search for transformer machine translation. Args: @@ -211,10 +224,9 @@ def beam_search( end_marker = torch.tensor(eos_id, device=DEVICE) # initialize beam search state - beam_search_init_state = beam_init(batch_size, - beam_size, - max_decode_len, - cache) + beam_search_init_state = beam_init( + batch_size, beam_size, max_decode_len, cache + ) def beam_search_loop_cond_fn(state: BeamState) -> bool: """Beam search loop termination condition.""" @@ -227,11 +239,12 @@ def beam_search_loop_cond_fn(state: BeamState) -> bool: best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty # Get the worst scores from finished sequences. worst_finished_scores, _ = torch.min( - state.finished_scores, dim=1, keepdim=True) + state.finished_scores, dim=1, keepdim=True + ) # Mask out scores from slots without any actual finished sequences. - worst_finished_scores = torch.where(state.finished_flags, - worst_finished_scores, - NEG_INF) + worst_finished_scores = torch.where( + state.finished_flags, worst_finished_scores, NEG_INF + ) # If no best possible live score is better than current worst finished # scores, the search cannot improve the finished set further. search_terminated = torch.all(worst_finished_scores > best_live_scores) @@ -248,7 +261,8 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: # --> [batch * beam, 1] cur_index = state.cur_index flat_ids = flatten_beam_dim( - state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1]) + state.live_seqs[:batch_size, :beam_size, cur_index : cur_index + 1] + ) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} flat_cache = jax.tree.map(flatten_beam_dim, state.cache) @@ -263,7 +277,8 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} new_cache = jax.tree.map( - lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) + lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache + ) # Gather log probabilities from logits candidate_log_probs = F.log_softmax(logits, dim=-1) @@ -287,13 +302,13 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: topk_log_probs, topk_indices = torch.topk(flat_log_probs, k=beams_to_keep) # Recover the beam index by floor division. topk_beam_indices = torch.div( - topk_indices, vocab_size, rounding_mode='floor') + topk_indices, vocab_size, rounding_mode='floor' + ) # Gather 2*k top beams. # --> [batch, 2*beams, length] - topk_seq = gather_beams(state.live_seqs, - topk_beam_indices, - batch_size, - beams_to_keep) + topk_seq = gather_beams( + state.live_seqs, topk_beam_indices, batch_size, beams_to_keep + ) # Append the most probable 2*K token IDs to the top 2*K sequences # Recover token id by modulo division and expand Id array for broadcasting. @@ -301,11 +316,11 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: topk_ids = torch.unsqueeze(topk_indices % vocab_size, dim=2) # Update sequences for the 2*K top-k new sequences. # --> [batch, 2*beams, length] - topk_seq[:, :, cur_index + 1:] = topk_ids + topk_seq[:, :, cur_index + 1 :] = topk_ids # Update LIVE (in-progress) sequences: # Did any of these sequences reach an end marker? # --> [batch, 2*beams] - newly_finished = (topk_seq[:, :, cur_index + 1] == end_marker) + newly_finished = topk_seq[:, :, cur_index + 1] == end_marker # To prevent these newly finished sequences from being added to the LIVE # set of active beam search sequences, set their log probs to a very large # negative value. @@ -316,22 +331,20 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: new_topk_indices = torch.flip(new_topk_indices, (1,)) # Gather the top k beams (from top 2*k beams). # --> [batch, beams, length], [batch, beams] - top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs], - new_topk_indices, - batch_size, beam_size) + top_alive_seq, top_alive_log_probs = gather_beams( + [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size + ) # Determine the top k beam indices from the original set of all beams. # --> [batch, beams] - top_alive_indices = gather_beams(topk_beam_indices, - new_topk_indices, - batch_size, - beam_size) + top_alive_indices = gather_beams( + topk_beam_indices, new_topk_indices, batch_size, beam_size + ) # With these, gather the top k beam-associated caches. # --> {[batch, beams, ...], ...} - top_alive_cache = gather_beams(new_cache, - top_alive_indices, - batch_size, - beam_size) + top_alive_cache = gather_beams( + new_cache, top_alive_indices, batch_size, beam_size + ) # Update FINISHED (reached end of sentence) sequences: # Calculate new seq scores from log probabilities. @@ -344,24 +357,33 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: # new finished sequence scores to existing finished scores and select the # best from the new set of beams. finished_seqs = torch.cat( # --> [batch, 3*beams, length] - [state.finished_seqs, topk_seq], dim=1) + [state.finished_seqs, topk_seq], dim=1 + ) finished_scores = torch.cat( # --> [batch, 3*beams] - [state.finished_scores, new_scores], dim=1) + [state.finished_scores, new_scores], dim=1 + ) finished_flags = torch.cat( # --> [batch, 3*beams] - [state.finished_flags, newly_finished], dim=1) + [state.finished_flags, newly_finished], dim=1 + ) # --> [batch, beams, length], [batch, beams], [batch, beams] top_finished_seq, top_finished_scores, top_finished_flags = ( - gather_topk_beams([finished_seqs, finished_scores, finished_flags], - finished_scores, batch_size, beam_size)) + gather_topk_beams( + [finished_seqs, finished_scores, finished_flags], + finished_scores, + batch_size, + beam_size, + ) + ) return BeamState( - cur_index=cur_index + 1, - live_logprobs=top_alive_log_probs, - finished_scores=top_finished_scores, - live_seqs=top_alive_seq, - finished_seqs=top_finished_seq, - finished_flags=top_finished_flags, - cache=top_alive_cache) + cur_index=cur_index + 1, + live_logprobs=top_alive_log_probs, + finished_scores=top_finished_scores, + live_seqs=top_alive_seq, + finished_seqs=top_finished_seq, + finished_flags=top_finished_flags, + cache=top_alive_cache, + ) state = beam_search_init_state while beam_search_loop_cond_fn(state): @@ -373,12 +395,16 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: # --> [batch] none_finished = torch.any(final_state.finished_flags, dim=1) # --> [batch, beams, length] - finished_seqs = torch.where(none_finished[:, None, None], - final_state.finished_seqs, - final_state.live_seqs) + finished_seqs = torch.where( + none_finished[:, None, None], + final_state.finished_seqs, + final_state.live_seqs, + ) # --> [batch, beams] - finished_scores = torch.where(none_finished[:, None], - final_state.finished_scores, - final_state.live_logprobs) + finished_scores = torch.where( + none_finished[:, None], + final_state.finished_scores, + final_state.live_logprobs, + ) return finished_seqs, finished_scores diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index 0de719c4b..430cc945b 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -3,11 +3,9 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -from torch import nn -from torch import Tensor import torch.nn.functional as F -from torch.nn.init import normal_ -from torch.nn.init import xavier_uniform_ +from torch import Tensor, nn +from torch.nn.init import normal_, xavier_uniform_ DROPOUT_RATE = 0.1 @@ -23,7 +21,8 @@ def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: A `[batch..., len, len]` shaped causal attention mask. """ idxs = torch.broadcast_to( - torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape) + torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape + ) return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2)) @@ -33,55 +32,60 @@ def make_src_mask(src, inputs_segmentation, nhead): # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: src_mask = torch.logical_and( - src_mask, - torch.eq( - inputs_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2))) + src_mask, + torch.eq( + inputs_segmentation.unsqueeze(-1), inputs_segmentation.unsqueeze(-2) + ), + ) # Flip values and ensure numerical stability. src_mask = torch.repeat_interleave( - torch.logical_not(src_mask), repeats=nhead, dim=0) + torch.logical_not(src_mask), repeats=nhead, dim=0 + ) new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32) new_src_mask.masked_fill_(src_mask, -1e10) return new_src_mask -def make_tgt_and_memory_mask(tgt, - src, - inputs_segmentation, - targets_segmentation, - decode, - nhead): - """ Utility for creating target and memory mask and adjust them for PyTorch +def make_tgt_and_memory_mask( + tgt, src, inputs_segmentation, targets_segmentation, decode, nhead +): + """Utility for creating target and memory mask and adjust them for PyTorch Transformer API.""" if not decode: tgt_mask = torch.logical_and( - torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), - make_causal_mask(tgt, device=tgt.device)) + torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), + make_causal_mask(tgt, device=tgt.device), + ) memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) else: tgt_mask = None - memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)) + memory_mask = torch.mul( + (torch.ones_like(tgt) > 0).unsqueeze(-1), (src > 0).unsqueeze(-2) + ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: tgt_mask = torch.logical_and( - tgt_mask, - torch.eq( - targets_segmentation.unsqueeze(-1), - targets_segmentation.unsqueeze(-2))) + tgt_mask, + torch.eq( + targets_segmentation.unsqueeze(-1), targets_segmentation.unsqueeze(-2) + ), + ) memory_mask = torch.logical_and( - memory_mask, - torch.eq( - targets_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2))) + memory_mask, + torch.eq( + targets_segmentation.unsqueeze(-1), inputs_segmentation.unsqueeze(-2) + ), + ) # Flip values and ensure numerical stability. memory_mask = torch.repeat_interleave( - torch.logical_not(memory_mask), repeats=nhead, dim=0) + torch.logical_not(memory_mask), repeats=nhead, dim=0 + ) new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32) new_memory_mask.masked_fill_(memory_mask, -1e10) if tgt_mask is not None: tgt_mask = torch.repeat_interleave( - torch.logical_not(tgt_mask), repeats=nhead, dim=0) + torch.logical_not(tgt_mask), repeats=nhead, dim=0 + ) new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32) new_tgt_mask.masked_fill_(tgt_mask, -1e10) tgt_mask = new_tgt_mask @@ -100,38 +104,44 @@ def shift_right(x, axis=1): class Transformer(nn.Module): """Transformer architecture based on the model from the WMT Jax workload.""" - def __init__(self, - ntoken: int = 32000, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): + def __init__( + self, + ntoken: int = 32000, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + ): super().__init__() self.pos_encoder = PositionalEncoding(d_model) self.shared_embedding = nn.Embedding(ntoken, d_model) - self.encoder = Encoder(d_model, - nhead, - d_hid, - nlayers, - activation, - glu, - layer_norm_eps, - attention_temp, - pre_ln) - self.decoder = Decoder(d_model, - nhead, - d_hid, - nlayers, - activation, - glu, - layer_norm_eps, - attention_temp, - pre_ln) + self.encoder = Encoder( + d_model, + nhead, + d_hid, + nlayers, + activation, + glu, + layer_norm_eps, + attention_temp, + pre_ln, + ) + self.decoder = Decoder( + d_model, + nhead, + d_hid, + nlayers, + activation, + glu, + layer_norm_eps, + attention_temp, + pre_ln, + ) # Share positional encoding and embedding between encoder and decoder. self.encoder.pos_encoder = self.pos_encoder self.encoder.shared_embedding = self.shared_embedding @@ -148,15 +158,17 @@ def _reset_parameters(self): if module.bias is not None: normal_(module.bias, std=1e-6) - def forward(self, - src: Tensor, - tgt: Tensor, - inputs_positions: Optional[Tensor] = None, - targets_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - targets_segmentation: Optional[Tensor] = None, - decode: bool = False, - dropout_rate: float = DROPOUT_RATE) -> Tensor: + def forward( + self, + src: Tensor, + tgt: Tensor, + inputs_positions: Optional[Tensor] = None, + targets_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + targets_segmentation: Optional[Tensor] = None, + decode: bool = False, + dropout_rate: float = DROPOUT_RATE, + ) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] @@ -175,19 +187,21 @@ def forward(self, raise RuntimeError('The batch size of src and tgt must be equal.') memory = self.encoder( - src, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation, - dropout_rate=dropout_rate) + src, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + ) output = self.decoder( - tgt, - memory, - src, # just for calculating the padding mask - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation, - decode=decode, - dropout_rate=dropout_rate) + tgt, + memory, + src, # just for calculating the padding mask + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + decode=decode, + dropout_rate=dropout_rate, + ) return output @@ -210,26 +224,32 @@ class TransformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = transformer_encoder(src) """ + __constants__ = ['norm'] - def __init__(self, - encoder_layer, - num_layers, - norm=None, - enable_nested_tensor=True, - mask_check=True): + def __init__( + self, + encoder_layer, + num_layers, + norm=None, + enable_nested_tensor=True, + mask_check=True, + ): super().__init__() self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for _ in range(num_layers)]) + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) self.num_layers = num_layers self.norm = norm self.enable_nested_tensor = enable_nested_tensor self.mask_check = mask_check - def forward(self, - src: Tensor, - mask: Optional[Tensor] = None, - dropout_rate: Optional[float] = 0.0) -> Tensor: + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: """Pass the input through the encoder layers in turn. Args: @@ -247,7 +267,7 @@ def forward(self, output = mod(output, src_mask=mask, dropout_rate=dropout_rate) if convert_to_nested: - output = output.to_padded_tensor(0.) + output = output.to_padded_tensor(0.0) if self.norm is not None: output = self.norm(output) @@ -256,39 +276,42 @@ def forward(self, class Encoder(nn.Module): - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): + def __init__( + self, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + ): super().__init__() self.nhead = nhead self.shared_embedding = None self.pos_encoder = None encoder_layer = TransformerEncoderLayer( - d_model, - nhead, - d_hid, - activation=activation, - glu=glu, - layer_norm_eps=layer_norm_eps, - attention_temp=attention_temp, - pre_ln=pre_ln) - encoder_norm = ( - nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) + d_model, + nhead, + d_hid, + activation=activation, + glu=glu, + layer_norm_eps=layer_norm_eps, + attention_temp=attention_temp, + pre_ln=pre_ln, + ) + encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm) - def forward(self, - src: Tensor, - inputs_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - dropout_rate: Optional[float] = 0.0) -> Tensor: + def forward( + self, + src: Tensor, + inputs_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: src = src.to(torch.int) src_mask = make_src_mask(src, inputs_segmentation, self.nhead) src = self.shared_embedding(src) @@ -298,67 +321,73 @@ def forward(self, class Decoder(nn.Module): - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): + def __init__( + self, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + ): super().__init__() self.nhead = nhead self.shared_embedding = None self.pos_encoder = None - self.decoder = TransformerDecoder(d_model, - nhead, - d_hid, - activation, - glu, - layer_norm_eps, - nlayers, - attention_temp, - pre_ln) + self.decoder = TransformerDecoder( + d_model, + nhead, + d_hid, + activation, + glu, + layer_norm_eps, + nlayers, + attention_temp, + pre_ln, + ) def forward( - self, - tgt: Tensor, - memory: Tensor, - src: Tensor, # just for calculating the padding mask - targets_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - targets_segmentation: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - dropout_rate: Optional[float] = 0.0) -> Any: + self, + tgt: Tensor, + memory: Tensor, + src: Tensor, # just for calculating the padding mask + targets_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + targets_segmentation: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Any: tgt = tgt.to(torch.int) tgt_mask, memory_mask = make_tgt_and_memory_mask( - tgt, src, inputs_segmentation, targets_segmentation, - decode, self.nhead) + tgt, src, inputs_segmentation, targets_segmentation, decode, self.nhead + ) if not decode: tgt = shift_right(tgt) tgt = self.shared_embedding(tgt) tgt = self.pos_encoder( - tgt, - targets_positions, - decode=decode, - cache=cache, - dropout_rate=dropout_rate) + tgt, + targets_positions, + decode=decode, + cache=cache, + dropout_rate=dropout_rate, + ) if decode: tgt, cache = tgt output = self.decoder( - tgt, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - decode=decode, - max_len=max_len, - cache=cache, - dropout_rate=dropout_rate) + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + decode=decode, + max_len=max_len, + cache=cache, + dropout_rate=dropout_rate, + ) if decode: output, cache = output normalize = math.sqrt(output.shape[-1]) @@ -369,7 +398,6 @@ def forward( class PositionalEncoding(nn.Module): - def __init__(self, d_model: int, max_len: int = 256): super().__init__() @@ -377,17 +405,17 @@ def __init__(self, d_model: int, max_len: int = 256): scale_factor = -math.log(10000.0) / (d_model // 2 - 1) div_term = torch.exp(torch.arange(d_model // 2) * scale_factor) pe = torch.zeros(1, max_len, d_model) - pe[0, :, :d_model // 2] = torch.sin(position * div_term) - pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term) + pe[0, :, : d_model // 2] = torch.sin(position * div_term) + pe[0, :, d_model // 2 : 2 * (d_model // 2)] = torch.cos(position * div_term) self.register_buffer('pe', pe) def forward( - self, - x: Tensor, - inputs_positions: Optional[Tensor] = None, - decode: bool = False, - cache: Optional[Dict[str, Dict[str, Tensor]]] = None, - dropout_rate: Optional[float] = 0.0 + self, + x: Tensor, + inputs_positions: Optional[Tensor] = None, + decode: bool = False, + cache: Optional[Dict[str, Dict[str, Tensor]]] = None, + dropout_rate: Optional[float] = 0.0, ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: """ Args: @@ -404,17 +432,18 @@ def forward( name = self._get_name() if cache is None: cache = { - name: { - 'cache_index': - torch.tensor(0, dtype=torch.long, device=self.pe.device), - }, + name: { + 'cache_index': torch.tensor( + 0, dtype=torch.long, device=self.pe.device + ), + }, } pe = self.pe[0, cache[name]['cache_index'], :] cache[name]['cache_index'] += 1 return F.dropout(x + pe, dropout_rate, self.training), cache if inputs_positions is None: # normal unpacked case: - pe = self.pe[:, :x.size(1), :] + pe = self.pe[:, : x.size(1), :] else: # for packed data we need to use known position indices: pe = self.pe[0, inputs_positions, :] @@ -449,28 +478,32 @@ class TransformerEncoderLayer(nn.Module): >>> src = torch.rand(32, 10, 512) >>> out = encoder_layer(src) """ + __constants__ = ['pre_ln'] - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - dim_feedforward: int = 1024, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True, - device=None, - dtype=None) -> None: + def __init__( + self, + d_model: int = 1024, + nhead: int = 16, + dim_feedforward: int = 1024, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.self_attn = MultiheadAttention( - d_model, - nhead, - self_attn=True, - attention_temp=attention_temp, - bias=False, - **factory_kwargs) + d_model, + nhead, + self_attn=True, + attention_temp=attention_temp, + bias=False, + **factory_kwargs, + ) # Implementation of Feedforward model. self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) @@ -485,10 +518,12 @@ def __init__(self, self.activation = activation - def forward(self, - src: Tensor, - src_mask: Optional[Tensor] = None, - dropout_rate: Optional[float] = 0.0) -> Tensor: + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: r"""Pass the input through the encoder layer. Args: @@ -509,17 +544,19 @@ def forward(self, return x # Self-attention block: - def _sa_block(self, - x: Tensor, - attn_mask: Optional[Tensor], - dropout_rate: Optional[float] = 0.0) -> Tensor: + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: - def _ff_block(self, - inputs: Tensor, - dropout_rate: Optional[float] = 0.0) -> Tensor: + def _ff_block( + self, inputs: Tensor, dropout_rate: Optional[float] = 0.0 + ) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) @@ -548,42 +585,51 @@ class TransformerDecoder(nn.Module): >>> tgt = torch.rand(20, 32, 512) >>> out = transformer_decoder(tgt, memory) """ + __constants__ = ['norm'] - def __init__(self, - d_model, - nhead, - d_hid, - activation, - glu, - layer_norm_eps, - num_layers, - attention_temp, - pre_ln): + def __init__( + self, + d_model, + nhead, + d_hid, + activation, + glu, + layer_norm_eps, + num_layers, + attention_temp, + pre_ln, + ): super().__init__() - self.layers = nn.ModuleList([ + self.layers = nn.ModuleList( + [ TransformerDecoderLayer( - d_model, - nhead, - d_hid, - activation, - glu, - layer_norm_eps=layer_norm_eps, - attention_temp=attention_temp, - pre_ln=pre_ln) for _ in range(num_layers) - ]) + d_model, + nhead, + d_hid, + activation, + glu, + layer_norm_eps=layer_norm_eps, + attention_temp=attention_temp, + pre_ln=pre_ln, + ) + for _ in range(num_layers) + ] + ) self.num_layers = num_layers - self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) - - def forward(self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - dropout_rate: Optional[float] = 0.0) -> Any: + self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Any: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: tgt: the sequence to the decoder (required). @@ -600,15 +646,16 @@ def forward(self, for idx, mod in enumerate(self.layers): output, cache = mod( - output, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=idx, - dropout_rate=dropout_rate) + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=idx, + dropout_rate=dropout_rate, + ) if self.norm is not None: output = self.norm(output) @@ -647,43 +694,48 @@ class TransformerDecoderLayer(nn.Module): >>> tgt = torch.rand(32, 20, 512) >>> out = decoder_layer(tgt, memory) """ + __constants__ = ['pre_ln'] - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - dim_feedforward: int = 1024, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - pre_ln: bool = True, - attention_temp: float = 1.0, - device=None, - dtype=None) -> None: + def __init__( + self, + d_model: int = 1024, + nhead: int = 16, + dim_feedforward: int = 1024, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + pre_ln: bool = True, + attention_temp: float = 1.0, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.self_attn = MultiheadAttention( - d_model, - nhead, - self_attn=True, - attention_temp=attention_temp, - bias=False, - **factory_kwargs) + d_model, + nhead, + self_attn=True, + attention_temp=attention_temp, + bias=False, + **factory_kwargs, + ) self.multihead_attn = MultiheadAttention( - d_model, - nhead, - self_attn=False, - attention_temp=attention_temp, - bias=False, - **factory_kwargs) + d_model, + nhead, + self_attn=False, + attention_temp=attention_temp, + bias=False, + **factory_kwargs, + ) # Implementation of Feedforward model. self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) self.glu = glu if self.glu: - self.linear_glu = nn.Linear(dim_feedforward, - dim_feedforward, - **factory_kwargs) + self.linear_glu = nn.Linear( + dim_feedforward, dim_feedforward, **factory_kwargs + ) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln @@ -694,16 +746,17 @@ def __init__(self, self.activation = activation def forward( # pylint: disable=arguments-renamed - self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None, - dropout_rate: Optional[float] = 0.0) -> Any: + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Any: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -721,25 +774,27 @@ def forward( # pylint: disable=arguments-renamed x = tgt if self.pre_ln: sa_out, cache = self._sa_block( - self.norm1(x), - tgt_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=index, - dropout_rate=dropout_rate) + self.norm1(x), + tgt_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index, + dropout_rate=dropout_rate, + ) x = x + sa_out x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate) x = x + self._ff_block(self.norm3(x), dropout_rate) else: sa_out, cache = self._sa_block( - x, - tgt_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=index, - dropout_rate=dropout_rate) + x, + tgt_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index, + dropout_rate=dropout_rate, + ) x = self.norm1(x + sa_out) x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate)) x = self.norm3(x + self._ff_block(x, dropout_rate)) @@ -748,41 +803,43 @@ def forward( # pylint: disable=arguments-renamed # Self-attention block: def _sa_block( # pylint: disable=arguments-renamed - self, - x: Tensor, - attn_mask: Optional[Tensor], - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None, - dropout_rate: Optional[float] = 0.0) -> Any: + self, + x: Tensor, + attn_mask: Optional[Tensor], + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Any: x, cache = self.self_attn( - x, - attn_mask=attn_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=index, - dropout_rate=dropout_rate) + x, + attn_mask=attn_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index, + dropout_rate=dropout_rate, + ) return F.dropout(x, dropout_rate, self.training), cache # Multihead attention block: - def _mha_block(self, - x: Tensor, - mem: Tensor, - attn_mask: Optional[Tensor], - dropout_rate: Optional[float] = 0.0) -> Tensor: + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: x, _ = self.multihead_attn( - x, - mem, - attn_mask=attn_mask, - dropout_rate=dropout_rate) + x, mem, attn_mask=attn_mask, dropout_rate=dropout_rate + ) return F.dropout(x, dropout_rate, self.training) # Feed forward block. - def _ff_block(self, - inputs: Tensor, - dropout_rate: Optional[float] = 0.0) -> Tensor: + def _ff_block( + self, inputs: Tensor, dropout_rate: Optional[float] = 0.0 + ) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) @@ -815,33 +872,38 @@ class MultiheadAttention(nn.Module): >>> attn_output, cache = multihead_attn(x) """ - def __init__(self, - embed_dim: int, - num_heads: int, - self_attn: bool = True, - attention_temp: float = 1.0, - bias: bool = False, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None) -> None: + def __init__( + self, + embed_dim: int, + num_heads: int, + self_attn: bool = True, + attention_temp: float = 1.0, + bias: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.self_attn = self_attn self.head_dim = embed_dim // num_heads self.attention_temp = attention_temp - assert self.head_dim * num_heads == self.embed_dim, \ - 'embed_dim must be divisible by num_heads.' + assert self.head_dim * num_heads == self.embed_dim, ( + 'embed_dim must be divisible by num_heads.' + ) factory_kwargs = {'device': device, 'dtype': dtype} if self_attn: # Self-attention. self.in_proj = nn.Linear( - embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) else: # Encoder-decoder attention. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) self.kv_proj = nn.Linear( - embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs + ) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) self._reset_parameters() @@ -855,15 +917,15 @@ def _reset_parameters(self): normal_(module.bias, std=1e-6) def forward( - self, - x: Tensor, - mem: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None, - dropout_rate: Optional[float] = 0.0 + self, + x: Tensor, + mem: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0, ) -> Any: # TODO: (nico) remove default?! r""" Args: @@ -872,7 +934,7 @@ def forward( attention mechanism. See "Attention Is All You Need" for more details. mem: Batch of input sequences of shape (batch size, sequence length, embedding dimensionality) for - encoder-decoder attention. See "Attention Is All You Need" for more + encoder-decoder attention. See "Attention Is All You Need" for more details. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape :math:`(L, S)` or @@ -915,16 +977,13 @@ def forward( if decode: if loc_cache is None: loc_cache = { - 'cached_key': - torch.zeros((bsz, max_len, embed_dim), - dtype=k.dtype, - device=k.device), - 'cached_value': - torch.zeros((bsz, max_len, embed_dim), - dtype=v.dtype, - device=v.device), - 'cache_index': - torch.tensor(0, dtype=torch.long, device=k.device), + 'cached_key': torch.zeros( + (bsz, max_len, embed_dim), dtype=k.dtype, device=k.device + ), + 'cached_value': torch.zeros( + (bsz, max_len, embed_dim), dtype=v.dtype, device=v.device + ), + 'cache_index': torch.tensor(0, dtype=torch.long, device=k.device), } cached_key = loc_cache['cached_key'] cached_value = loc_cache['cached_value'] @@ -932,11 +991,13 @@ def forward( # Shape check of cached keys against query input. expected_shape = (bsz, 1, embed_dim) if expected_shape != x.shape: - raise ValueError('Autoregressive cache shape error, expected query ' - f'shape {expected_shape} instead got {x.shape}.') + raise ValueError( + 'Autoregressive cache shape error, expected query ' + f'shape {expected_shape} instead got {x.shape}.' + ) # Update key, value caches with our new 1d spatial slices. - cached_key[:, cache_index:cache_index + 1, :] = k - cached_value[:, cache_index:cache_index + 1, :] = v + cached_key[:, cache_index : cache_index + 1, :] = k + cached_value[:, cache_index : cache_index + 1, :] = v k = cached_key v = cached_value cache_index += 1 @@ -946,8 +1007,9 @@ def forward( # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) >= - cache_index).reshape(1, max_len) + attn_mask = ( + torch.arange(max_len, device=k.device) >= cache_index + ).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) @@ -959,17 +1021,21 @@ def forward( # Check dtype and shape of attention mask. if not decode and attn_mask is not None: - assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ - f'Float and bool dtypes are supported, not {attn_mask.dtype}.' + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, ( + f'Float and bool dtypes are supported, not {attn_mask.dtype}.' + ) # Ensure attn_mask's dim is 3. if attn_mask.dim() == 3: correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len) if attn_mask.shape != correct_3d_size: - raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, ' - f'but should be {correct_3d_size}.') + raise RuntimeError( + f'The shape of attn_mask is {attn_mask.shape}, ' + f'but should be {correct_3d_size}.' + ) else: raise RuntimeError( - f"attn_mask's dimension {attn_mask.dim()} is not supported") + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) # Reshape attention mask to be consistent with q, k, v. attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len) @@ -985,10 +1051,12 @@ def forward( # Calculate attention. q = self.attention_temp * q attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, attn_dropout_rate) + q, k, v, attn_mask, attn_dropout_rate + ) # Rearrange for output projection. - attn_output = attn_output.transpose(1, 2).contiguous().view( - bsz, tgt_len, embed_dim) + attn_output = ( + attn_output.transpose(1, 2).contiguous().view(bsz, tgt_len, embed_dim) + ) # Output projection. attn_output = self.out_proj(attn_output) diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index 4ec816f2f..53d95d393 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -3,21 +3,18 @@ import contextlib from typing import Any, Dict, Optional, Tuple -from absl import logging import jax import tensorflow as tf import torch import torch.distributed as dist -from torch.nn import DataParallel as DP import torch.nn.functional as F +from absl import logging +from torch.nn import DataParallel as DP from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec +from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.wmt import bleu -from algoperf.workloads.wmt.wmt_pytorch import decode -from algoperf.workloads.wmt.wmt_pytorch import models +from algoperf.workloads.wmt.wmt_pytorch import decode, models from algoperf.workloads.wmt.wmt_pytorch.models import Transformer from algoperf.workloads.wmt.workload import BaseWmtWorkload @@ -28,11 +25,12 @@ class WmtWorkload(BaseWmtWorkload): """WMT PyTorch workload.""" def compute_weighted_cross_entropy( - self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.1) -> Dict[str, spec.Tensor]: # differentiable + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy and entropy for log probs and targets. Args: @@ -47,11 +45,14 @@ def compute_weighted_cross_entropy( valid examples in batch, 'per_example': 1-d array of per-example losses} """ if logits.ndim != targets.ndim + 1: - raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and ' - f'{targets.shape} targets.') + raise ValueError( + f'Incorrect shapes. Got shape {logits.shape} logits and ' + f'{targets.shape} targets.' + ) loss_fn = torch.nn.CrossEntropyLoss( - reduction='none', label_smoothing=label_smoothing) + reduction='none', label_smoothing=label_smoothing + ) if N_GPUS > 1 and not USE_PYTORCH_DDP: loss_fn = DP(loss_fn) @@ -60,24 +61,27 @@ def compute_weighted_cross_entropy( if weights is None: weights = torch.ones_like(targets) per_example_losses = torch.where( - weights.to(torch.bool), per_example_losses, 0.) + weights.to(torch.bool), per_example_losses, 0.0 + ) summed_loss = per_example_losses.sum() n_valid_examples = weights.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } # Primary eval / decode step functions. # ---------------------------------------------------------------------------- @torch.no_grad() - def predict_step(self, - inputs: spec.Tensor, - params: spec.ParameterContainer, - eos_id: int, - max_decode_len: int, - beam_size: int = 4) -> spec.Tensor: + def predict_step( + self, + inputs: spec.Tensor, + params: spec.ParameterContainer, + eos_id: int, + max_decode_len: int, + beam_size: int = 4, + ) -> spec.Tensor: """Predict translation with fast decoding beam search on a batch.""" # params = params.module if isinstance(params, (DP, DDP)) else params if hasattr(params, 'module'): @@ -86,8 +90,8 @@ def predict_step(self, if hasattr(params, '_modules'): params = params._modules - encoder = params["encoder"] - decoder = params["decoder"] + encoder = params['encoder'] + decoder = params['decoder'] else: encoder = params.encoder decoder = params.decoder @@ -98,21 +102,23 @@ def predict_step(self, decoder = DP(decoder) encoded_inputs = torch.repeat_interleave( - encoder(inputs), repeats=beam_size, dim=0) + encoder(inputs), repeats=beam_size, dim=0 + ) raw_inputs = torch.repeat_interleave(inputs, repeats=beam_size, dim=0) def tokens_ids_to_logits( - flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor] + flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor] ) -> Tuple[spec.Tensor, Dict[str, spec.Tensor]]: """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] flat_logits, new_flat_cache = decoder( - flat_ids, - encoded_inputs, - raw_inputs, - decode=True, - max_len=max_decode_len, - cache=flat_cache) + flat_ids, + encoded_inputs, + raw_inputs, + decode=True, + max_len=max_decode_len, + cache=flat_cache, + ) # Remove singleton sequence-length dimension: # [batch * beam, 1, vocab] --> [batch * beam, vocab] flat_logits = flat_logits.squeeze(dim=1) @@ -121,24 +127,27 @@ def tokens_ids_to_logits( # Using the above-defined single-step decoder function, run a # beam search over possible sequences given input encoding. beam_seqs, _ = decode.beam_search( - inputs, - None, - tokens_ids_to_logits, - beam_size=beam_size, - alpha=0.6, - eos_id=eos_id, - max_decode_len=max_decode_len) + inputs, + None, + tokens_ids_to_logits, + beam_size=beam_size, + alpha=0.6, + eos_id=eos_id, + max_decode_len=max_decode_len, + ) # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension # sorted in increasing order of log-probability. # Return the highest scoring beam sequence, drop first dummy 0 token. return beam_seqs[:, -1, 1:] - def translate_and_calculate_bleu(self, - params: spec.ParameterContainer, - ds_iter: tf.data.Dataset, - num_batches: int, - max_predict_length: int): + def translate_and_calculate_bleu( + self, + params: spec.ParameterContainer, + ds_iter: tf.data.Dataset, + num_batches: int, + max_predict_length: int, + ): """Translates the `ds_iter` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] @@ -146,10 +155,9 @@ def translate_and_calculate_bleu(self, pred_batch = next(ds_iter) inputs = pred_batch['inputs'] targets = pred_batch['targets'] - predicted = self.predict_step(inputs, - params, - decode.EOS_ID, - max_predict_length) + predicted = self.predict_step( + inputs, params, decode.EOS_ID, max_predict_length + ) # Find actual batch size, ignoring the potential padding. weights = pred_batch.get('weights') @@ -177,10 +185,11 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: raise ValueError(f'Unknown activation function {self.activation}.') model = Transformer( - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu, + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -195,14 +204,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'shared_embedding.weight' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = models.DROPOUT_RATE + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -214,44 +223,53 @@ def model_fn( model.eval() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logits_batch = model( - src=augmented_and_preprocessed_input_batch['inputs'], - tgt=augmented_and_preprocessed_input_batch['targets'], - inputs_positions=augmented_and_preprocessed_input_batch.get( - 'inputs_position', None), - targets_positions=augmented_and_preprocessed_input_batch.get( - 'targets_position', None), - inputs_segmentation=augmented_and_preprocessed_input_batch.get( - 'inputs_segmentation', None), - targets_segmentation=augmented_and_preprocessed_input_batch.get( - 'targets_segmentation', None), - dropout_rate=dropout_rate) + src=augmented_and_preprocessed_input_batch['inputs'], + tgt=augmented_and_preprocessed_input_batch['targets'], + inputs_positions=augmented_and_preprocessed_input_batch.get( + 'inputs_position', None + ), + targets_positions=augmented_and_preprocessed_input_batch.get( + 'targets_position', None + ), + inputs_segmentation=augmented_and_preprocessed_input_batch.get( + 'inputs_segmentation', None + ), + targets_segmentation=augmented_and_preprocessed_input_batch.get( + 'targets_segmentation', None + ), + dropout_rate=dropout_rate, + ) return logits_batch, None - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): per_device_batch_size = int(global_batch_size / N_GPUS) n_inputs = 7 if split == 'train' else 3 # The input pipeline has to be created in all processes, because # self._tokenizer has to be available in every process. - np_iter = super()._build_input_queue(data_rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset) + np_iter = super()._build_input_queue( + data_rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset, + ) # We only need np_iter in one Python process. if RANK != 0: del np_iter @@ -266,14 +284,15 @@ def _build_input_queue(self, tensor = torch.as_tensor(value, dtype=torch.int64, device=DEVICE) tensor_list.append(tensor) batch[key] = ( - tensor[0] if USE_PYTORCH_DDP else tensor.view( - -1, value.shape[-1])) + tensor[0] if USE_PYTORCH_DDP else tensor.view(-1, value.shape[-1]) + ) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: # During eval, the batch size of the remainder might be different. if split != 'train': per_device_batch_size = torch.tensor( - len(batch['inputs']), dtype=torch.int32, device=DEVICE) + len(batch['inputs']), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) # We don't need to broadcast the batch for the device with RANK == 0. dist.broadcast(torch.stack(tensor_list)[:, 1:].contiguous(), src=0) @@ -281,25 +300,27 @@ def _build_input_queue(self, batch = {} # During eval, the batch size of the remainder might be different. if split != 'train': - per_device_batch_size = torch.empty((1,), - dtype=torch.int32, - device=DEVICE) + per_device_batch_size = torch.empty( + (1,), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the batch for RANK == 0. - tensor = torch.empty((n_inputs, N_GPUS - 1, per_device_batch_size, 256), - dtype=torch.int64, - device=DEVICE) + tensor = torch.empty( + (n_inputs, N_GPUS - 1, per_device_batch_size, 256), + dtype=torch.int64, + device=DEVICE, + ) dist.broadcast(tensor, src=0) # Note that the order of the keys is important. if split == 'train': keys = [ - 'inputs', - 'inputs_position', - 'inputs_segmentation', - 'targets', - 'targets_position', - 'targets_segmentation', - 'weights', + 'inputs', + 'inputs_position', + 'inputs_segmentation', + 'targets', + 'targets_position', + 'targets_segmentation', + 'weights', ] # For all eval/test splits. else: @@ -309,34 +330,35 @@ def _build_input_queue(self, batch[key] = tensor[n][RANK - 1] yield batch - def eval_step(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + def eval_step( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> Dict[str, spec.Tensor]: """Calculate evaluation metrics on a batch.""" targets = batch['targets'] weights = batch['weights'] logits, _ = self.model_fn( - params, - batch, - mode=spec.ForwardPassMode.EVAL, - model_state=None, - rng=None, - update_batch_norm=False) - summed_loss = self.compute_weighted_cross_entropy(logits, - targets, - weights, - 0.0)['summed'] + params, + batch, + mode=spec.ForwardPassMode.EVAL, + model_state=None, + rng=None, + update_batch_norm=False, + ) + summed_loss = self.compute_weighted_cross_entropy( + logits, targets, weights, 0.0 + )['summed'] acc_sum, weight_sum = self.compute_weighted_accuracy( - logits, targets, weights) + logits, targets, weights + ) return { - 'loss': summed_loss, - 'accuracy': acc_sum, - 'denominator': weight_sum, + 'loss': summed_loss, + 'accuracy': acc_sum, + 'denominator': weight_sum, } def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples if USE_PYTORCH_DDP: diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index 51b33373d..40e4262dd 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -60,8 +60,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -115,23 +116,26 @@ def activation(self) -> str: def glu(self) -> bool: return False - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): is_training = split == 'train' ds, self._tokenizer = input_pipeline.get_wmt_dataset( - data_rng, - split, - data_dir, - is_training=is_training, - vocab_size=self._vocab_size, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + data_rng, + split, + data_dir, + is_training=is_training, + vocab_size=self._vocab_size, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset, + ) # Separate function is necessary because the code above has to be executed # when _build_input_queue is called (not when next() is first called on it). @@ -148,19 +152,21 @@ def _input_queue_generator(): @abc.abstractmethod def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del model_state del global_step @@ -168,12 +174,13 @@ def _eval_model_on_split(self, if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset=True) + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True, + ) eval_metrics = {} for _ in range(num_batches): @@ -186,16 +193,17 @@ def _eval_model_on_split(self, eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) eval_results['bleu'] = self.translate_and_calculate_bleu( - params=params, - ds_iter=self._eval_iters[split], - num_batches=num_batches, - max_predict_length=256) + params=params, + ds_iter=self._eval_iters[split], + num_batches=num_batches, + max_predict_length=256, + ) return eval_results def compute_weighted_accuracy( - self, logits: spec.Tensor, targets: spec.Tensor, - weights: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]: + self, logits: spec.Tensor, targets: spec.Tensor, weights: spec.Tensor + ) -> Tuple[spec.Tensor, spec.Tensor]: """Compute weighted accuracy for log probs and targets. Args: @@ -207,8 +215,10 @@ def compute_weighted_accuracy( Tuple of scalar summed accuracy and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: - raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and ' - f'{targets.shape} targets.') + raise ValueError( + f'Incorrect shapes. Got shape {logits.shape} logits and ' + f'{targets.shape} targets.' + ) accuracy = (logits.argmax(-1) == targets) * weights normalizing_factor = weights.sum() return accuracy.sum(), normalizing_factor @@ -216,17 +226,18 @@ def compute_weighted_accuracy( def _decode_tokens(self, toks: spec.Tensor) -> spec.Tensor: if isinstance(toks, torch.Tensor): toks = toks.cpu().numpy() - valid_toks = toks[:np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32) + valid_toks = toks[: np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32) return self._tokenizer.detokenize(valid_toks).numpy().decode('utf-8') # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -234,7 +245,8 @@ def loss_fn( (not synced across devices). """ return self.compute_weighted_cross_entropy( - logits_batch, - label_batch, - weights=mask_batch, - label_smoothing=label_smoothing) + logits_batch, + label_batch, + weights=mask_batch, + label_smoothing=label_smoothing, + ) diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4712f4e25..4dd4717e9 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -1,5 +1,5 @@ -""" Registry of workload info -""" +"""Registry of workload info""" + import importlib import inspect import os @@ -9,149 +9,151 @@ BASE_WORKLOADS_DIR = 'algoperf/workloads/' WORKLOADS = { - 'cifar': { - 'workload_path': 'cifar/cifar', 'workload_class_name': 'CifarWorkload' - }, - 'criteo1tb': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', - }, - 'criteo1tb_test': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', - }, - 'criteo1tb_layernorm': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' - }, - 'criteo1tb_embed_init': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload' - }, - 'criteo1tb_resnet': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload' - }, - 'fastmri': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIWorkload', - }, - 'fastmri_model_size': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIModelSizeWorkload', - }, - 'fastmri_tanh': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRITanhWorkload', - }, - 'fastmri_layernorm': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRILayerNormWorkload', - }, - 'imagenet_resnet': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetWorkload', - }, - 'imagenet_resnet_silu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetSiLUWorkload', - }, - 'imagenet_resnet_gelu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetGELUWorkload', - }, - 'imagenet_resnet_large_bn_init': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', - }, - 'imagenet_vit': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitWorkload', - }, - 'imagenet_vit_glu': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitGluWorkload', - }, - 'imagenet_vit_post_ln': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitPostLNWorkload', - }, - 'imagenet_vit_map': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitMapWorkload', - }, - 'librispeech_conformer': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerWorkload', - }, - 'librispeech_conformer_attention_temperature': { - 'workload_path': - 'librispeech_conformer/librispeech', - 'workload_class_name': - 'LibriSpeechConformerAttentionTemperatureWorkload', - }, - 'librispeech_conformer_layernorm': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', - }, - 'librispeech_conformer_gelu': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerGeluWorkload', - }, - 'librispeech_deepspeech': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', - }, - 'librispeech_deepspeech_tanh': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', - }, - 'librispeech_deepspeech_no_resnet': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', - }, - 'librispeech_deepspeech_norm_and_spec_aug': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', - }, - 'mnist': { - 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' - }, - 'ogbg': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload' - }, - 'ogbg_gelu': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload' - }, - 'ogbg_silu': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' - }, - 'ogbg_model_size': { - 'workload_path': 'ogbg/ogbg', - 'workload_class_name': 'OgbgModelSizeWorkload' - }, - 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, - 'wmt_post_ln': { - 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN' - }, - 'wmt_attention_temp': { - 'workload_path': 'wmt/wmt', - 'workload_class_name': 'WmtWorkloadAttentionTemp' - }, - 'wmt_glu_tanh': { - 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadGLUTanH' - }, + 'cifar': { + 'workload_path': 'cifar/cifar', + 'workload_class_name': 'CifarWorkload', + }, + 'criteo1tb': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', + }, + 'criteo1tb_test': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', + }, + 'criteo1tb_layernorm': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload', + }, + 'criteo1tb_embed_init': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload', + }, + 'criteo1tb_resnet': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload', + }, + 'fastmri': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIWorkload', + }, + 'fastmri_model_size': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIModelSizeWorkload', + }, + 'fastmri_tanh': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRITanhWorkload', + }, + 'fastmri_layernorm': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRILayerNormWorkload', + }, + 'imagenet_resnet': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetWorkload', + }, + 'imagenet_resnet_silu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetSiLUWorkload', + }, + 'imagenet_resnet_gelu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetGELUWorkload', + }, + 'imagenet_resnet_large_bn_init': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', + }, + 'imagenet_vit': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitWorkload', + }, + 'imagenet_vit_glu': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitGluWorkload', + }, + 'imagenet_vit_post_ln': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitPostLNWorkload', + }, + 'imagenet_vit_map': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitMapWorkload', + }, + 'librispeech_conformer': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerWorkload', + }, + 'librispeech_conformer_attention_temperature': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerAttentionTemperatureWorkload', + }, + 'librispeech_conformer_layernorm': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', + }, + 'librispeech_conformer_gelu': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerGeluWorkload', + }, + 'librispeech_deepspeech': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', + }, + 'librispeech_deepspeech_tanh': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', + }, + 'librispeech_deepspeech_no_resnet': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', + }, + 'librispeech_deepspeech_norm_and_spec_aug': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', + }, + 'mnist': { + 'workload_path': 'mnist/mnist', + 'workload_class_name': 'MnistWorkload', + }, + 'ogbg': {'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'}, + 'ogbg_gelu': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgGeluWorkload', + }, + 'ogbg_silu': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgSiluWorkload', + }, + 'ogbg_model_size': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgModelSizeWorkload', + }, + 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, + 'wmt_post_ln': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadPostLN', + }, + 'wmt_attention_temp': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadAttentionTemp', + }, + 'wmt_glu_tanh': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadGLUTanH', + }, } BASE_WORKLOADS = [ - 'criteo1tb', - 'fastmri', - 'imagenet_resnet', - 'imagenet_vit', - 'librispeech_conformer', - 'librispeech_deepspeech', - 'ogbg', - 'wmt' + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'ogbg', + 'wmt', ] @@ -171,10 +173,12 @@ def convert_filepath_to_module(path: str): return base.replace('/', '.') -def import_workload(workload_path: str, - workload_class_name: str, - return_class=False, - workload_init_kwargs=None) -> spec.Workload: +def import_workload( + workload_path: str, + workload_class_name: str, + return_class=False, + workload_init_kwargs=None, +) -> spec.Workload: """Import and add the workload to the registry. This importlib loading is nice to have because it allows runners to avoid @@ -206,9 +210,10 @@ def import_workload(workload_path: str, break if workload_class is None: raise ValueError( - f'Could not find member {workload_class_name} in {workload_path}. ' - 'Make sure the Workload class is spelled correctly and defined in ' - 'the top scope of the module.') + f'Could not find member {workload_class_name} in {workload_path}. ' + 'Make sure the Workload class is spelled correctly and defined in ' + 'the top scope of the module.' + ) if return_class: return workload_class return workload_class(**workload_init_kwargs) From 97255544b1daf604dee84c9f1de9f8d122d2e121 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:21:23 +0200 Subject: [PATCH 108/123] Format docker/ --- docker/scripts/singularity_converter.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py index 48c521009..a816eb5c2 100644 --- a/docker/scripts/singularity_converter.py +++ b/docker/scripts/singularity_converter.py @@ -15,26 +15,27 @@ from spython.main.parse.writers import get_writer # globals -ENTRY_POINT = "/bin/bash" # seems to be a good default +ENTRY_POINT = '/bin/bash' # seems to be a good default FORCE = False # seems to be a good default # -parser = argparse.ArgumentParser(description="Custom Singularity converter") +parser = argparse.ArgumentParser(description='Custom Singularity converter') parser.add_argument( - "-i", "--input", type=str, help="Docker input path", default="Dockerfile") + '-i', '--input', type=str, help='Docker input path', default='Dockerfile' +) parser.add_argument( - "-o", - "--output", - type=str, - help="Singularity output path", - default="Singularity.def", + '-o', + '--output', + type=str, + help='Singularity output path', + default='Singularity.def', ) args = parser.parse_args() INPUT_DOCKERFILE_PATH = args.input OUTPUT_SINGULARITY_PATH = args.output # create Docker parser and Singularity writer -parser = get_parser("docker") -writer = get_writer("singularity") +parser = get_parser('docker') +writer = get_writer('singularity') # parse Dockerfile into Singularity and suppress %files commands recipeParser = parser(INPUT_DOCKERFILE_PATH) @@ -44,5 +45,5 @@ # convert to string and save to output file result = recipeWriter.convert(runscript=ENTRY_POINT, force=FORCE) -with open(OUTPUT_SINGULARITY_PATH, "w") as f: +with open(OUTPUT_SINGULARITY_PATH, 'w') as f: f.write(result) From 7b18fff74c8d08b4f6dd338ab652c67d39252e4a Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:32:47 +0200 Subject: [PATCH 109/123] Lint tests/ --- tests/modeldiffs/diff.py | 7 +++---- tests/modeldiffs/imagenet_resnet/gelu_compare.py | 3 +-- tests/modeldiffs/imagenet_resnet/silu_compare.py | 3 +-- tests/modeldiffs/librispeech_deepspeech/compare.py | 6 +++--- .../librispeech_deepspeech_noresnet/compare.py | 6 ++++-- .../librispeech_deepspeech_normaug/compare.py | 6 ++++-- .../librispeech_deepspeech_tanh/compare.py | 6 ++++-- tests/modeldiffs/torch2jax_utils.py | 2 +- tests/modeldiffs/vanilla_sgd_jax.py | 10 +++++----- tests/modeldiffs/vanilla_sgd_pytorch.py | 8 ++++---- tests/submission_runner_test.py | 8 +++----- tests/test_baselines.py | 8 +++----- tests/test_param_shapes.py | 3 --- tests/test_param_types.py | 6 +----- tests/test_ssim.py | 3 +-- tests/test_traindiffs.py | 12 ++++++------ .../imagenet_resnet/imagenet_jax/workload_test.py | 2 +- 17 files changed, 45 insertions(+), 54 deletions(-) diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index 8449c3241..a74159028 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -1,11 +1,10 @@ -from flax import jax_utils -from flax.core import FrozenDict import jax import numpy as np import torch +from flax import jax_utils +from flax.core import FrozenDict -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform +from tests.modeldiffs.torch2jax_utils import Torch2Jax, value_transform # pylint: disable=dangerous-default-value diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 8aa48382d..a92712ddc 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -14,8 +14,7 @@ ImagenetResNetGELUWorkload as PyTorchWorkload, ) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.imagenet_resnet.compare import key_transform -from tests.modeldiffs.imagenet_resnet.compare import sd_transform +from tests.modeldiffs.imagenet_resnet.compare import key_transform, sd_transform if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index 393badd18..bbbfd082b 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -14,8 +14,7 @@ ImagenetResNetSiLUWorkload as PyTorchWorkload, ) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.imagenet_resnet.compare import key_transform -from tests.modeldiffs.imagenet_resnet.compare import sd_transform +from tests.modeldiffs.imagenet_resnet.compare import key_transform, sd_transform if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index 12cd11513..bd1073524 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -58,9 +58,9 @@ def sd_transform(sd): else: out[k] = sd[k] elif 'LSTM' in ''.join(k): - l = out.get(k[:-1], dict()) - l[k[-1]] = sd[k] - out[k[:-1]] = l + l_tmp = out.get(k[:-1], dict()) + l_tmp[k[-1]] = sd[k] + out[k[:-1]] = l_tmp else: out[k] = sd[k] keys_to_del = [] diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py index 6a719a84a..8593894e4 100644 --- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -14,8 +14,10 @@ LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload, ) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.librispeech_deepspeech.compare import key_transform -from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform +from tests.modeldiffs.librispeech_deepspeech.compare import ( + key_transform, + sd_transform, +) if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py index c8820d397..27e4760a6 100644 --- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -14,8 +14,10 @@ LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload, ) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.librispeech_deepspeech.compare import key_transform -from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform +from tests.modeldiffs.librispeech_deepspeech.compare import ( + key_transform, + sd_transform, +) if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py index 0882f3d1e..7990f063b 100644 --- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -14,8 +14,10 @@ LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload, ) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.librispeech_deepspeech.compare import key_transform -from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform +from tests.modeldiffs.librispeech_deepspeech.compare import ( + key_transform, + sd_transform, +) if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 7c77a152c..4c95ca7e4 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -1,5 +1,5 @@ -from collections import Counter import pprint +from collections import Counter def jax_like_pytorch_statedict(model, state_dict, keys=None): diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py index e80e70b8e..aa7bebd4f 100644 --- a/tests/modeldiffs/vanilla_sgd_jax.py +++ b/tests/modeldiffs/vanilla_sgd_jax.py @@ -1,15 +1,15 @@ -from flax import jax_utils import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import ( +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) def init_optimizer_state( diff --git a/tests/modeldiffs/vanilla_sgd_pytorch.py b/tests/modeldiffs/vanilla_sgd_pytorch.py index d6613479e..6448ac097 100644 --- a/tests/modeldiffs/vanilla_sgd_pytorch.py +++ b/tests/modeldiffs/vanilla_sgd_pytorch.py @@ -1,12 +1,12 @@ import torch from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) def init_optimizer_state( diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py index b9beb9101..c6c993b7b 100644 --- a/tests/submission_runner_test.py +++ b/tests/submission_runner_test.py @@ -9,13 +9,11 @@ import os import sys -from absl import flags -from absl import logging -from absl.testing import absltest -from absl.testing import parameterized +from absl import flags, logging +from absl.testing import absltest, parameterized -from algoperf.profiler import PassThroughProfiler import submission_runner +from algoperf.profiler import PassThroughProfiler FLAGS = flags.FLAGS # Needed to avoid UnparsedFlagAccessError diff --git a/tests/test_baselines.py b/tests/test_baselines.py index 9ebc50222..c5097a567 100644 --- a/tests/test_baselines.py +++ b/tests/test_baselines.py @@ -8,14 +8,12 @@ import os import sys -from absl import flags -from absl import logging -from absl.testing import absltest -from absl.testing import parameterized +from absl import flags, logging +from absl.testing import absltest, parameterized +import submission_runner from algoperf.profiler import PassThroughProfiler from algoperf.workloads import workloads -import submission_runner FLAGS = flags.FLAGS # Needed to avoid UnparsedFlagAccessError diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index 1badd39ed..2243ce52e 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -5,8 +5,6 @@ import pytest from flax.core import FrozenDict -# isort: skip_file -# pylint:disable=line-too-long from algoperf.workloads.cifar.cifar_jax.workload import ( CifarWorkload as JaxCifarWorkload, ) @@ -67,7 +65,6 @@ from algoperf.workloads.wmt.wmt_pytorch.workload import ( WmtWorkload as PyTorchWmtWorkload, ) -# pylint:enable=line-too-long WORKLOADS = [ 'cifar', diff --git a/tests/test_param_types.py b/tests/test_param_types.py index 1583342ff..9f14f7dd8 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -1,11 +1,8 @@ import jax import pytest - from absl import logging -from algoperf import spec -# isort: skip_file -# pylint:disable=line-too-long +from algoperf import spec from algoperf.workloads.cifar.cifar_jax.workload import ( CifarWorkload as JaxCifarWorkload, ) @@ -66,7 +63,6 @@ from algoperf.workloads.wmt.wmt_pytorch.workload import ( WmtWorkload as PyTorchWmtWorkload, ) -# pylint:enable=line-too-long WORKLOADS = [ 'cifar', diff --git a/tests/test_ssim.py b/tests/test_ssim.py index 7d730c251..dcb3f25e0 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -3,11 +3,10 @@ import os from typing import Tuple -from absl.testing import absltest -from absl.testing import parameterized import jax.numpy as jnp import numpy as np import torch +from absl.testing import absltest, parameterized from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.fastmri.fastmri_jax.ssim import ( diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index b1982a3bf..8acfc855a 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -6,13 +6,10 @@ import pickle import subprocess -from subprocess import DEVNULL -from subprocess import run -from subprocess import STDOUT +from subprocess import DEVNULL, STDOUT, run from absl import flags -from absl.testing import absltest -from absl.testing import parameterized +from absl.testing import absltest, parameterized from numpy import allclose FLAGS = flags.FLAGS @@ -92,7 +89,10 @@ def test_workload(self, workload): 'Train Loss (jax)', 'Train Loss (torch)', ] - fmt = lambda l: '|' + '|'.join(map(lambda x: f'{x:^20s}', l)) + '|' + + def fmt(line): + return '|' + '|'.join(map(lambda x: f'{x:^20s}', line)) + '|' + header = fmt(header) pad = (len(header) - len((name))) // 2 print('=' * pad, name, '=' * (len(header) - len(name) - pad), sep='') diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py index 3d06c9839..60a1af2f2 100644 --- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py +++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py @@ -1,8 +1,8 @@ """Tests for imagenet_resnet/imagenet_jax/workload.py.""" -from absl.testing import absltest import jax import jax.numpy as jnp +from absl.testing import absltest from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( From f34bb6d4d0c4d2e971ebef0ff307defc4b5cdbc5 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:34:35 +0200 Subject: [PATCH 110/123] Lint submissions/ --- submissions/submission_checker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/submissions/submission_checker.py b/submissions/submission_checker.py index fcc4e1faf..f8af9fb52 100644 --- a/submissions/submission_checker.py +++ b/submissions/submission_checker.py @@ -29,7 +29,6 @@ import argparse import logging import os -import subprocess SELF_TUNING = 'self_tuning' EXTERNAL_TUNING = 'external_tuning' From 0aeb545f1f78284936f58d8d6abc07fdf7b8da9a Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:47:44 +0200 Subject: [PATCH 111/123] Remove perf. profile tests as it is only a placeholder --- scoring/test_performance_profile.py | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 scoring/test_performance_profile.py diff --git a/scoring/test_performance_profile.py b/scoring/test_performance_profile.py deleted file mode 100644 index 01c96de71..000000000 --- a/scoring/test_performance_profile.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - -from absl.testing import absltest - -from scoring import performance_profile, scoring_utils - - -class Test(absltest.TestCase): - def test_get_workloads_time_to_target(self): - # TODO(kasimbeg) - pass - - def test_get_best_trial_index(self): - # TODO(kasimbeg) - pass - - def test_compute_performance_profiles(self): - # TODO(kasimbeg) - pass - - -if __name__ == '__main__': - absltest.main() From 4802dfbdba1dab3919879da587ef3eb46d3cc310 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:47:50 +0200 Subject: [PATCH 112/123] Lint scoring/ --- scoring/performance_profile.py | 2 +- scoring/scoring_utils.py | 2 +- scoring/test_scoring_utils.py | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index d79f705d1..4f2ae9c57 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -52,7 +52,7 @@ try: with open('held_out_workloads_algoperf_v05.json', 'r') as f: HELDOUT_WORKLOADS = json.load(f) -except: +except FileNotFoundError: HELDOUT_WORKLOADS = None # These global variables have to be set according to the current set of diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index ab639f870..5be6c790c 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -200,7 +200,7 @@ def get_experiment_df(experiment_dir): ) try: trial_df = pd.read_csv(eval_measurements_filepath) - except FileNotFoundError as e: + except FileNotFoundError: logging.info(f'Could not read {eval_measurements_filepath}') continue data['trial'] = (trial, experiment_dir) diff --git a/scoring/test_scoring_utils.py b/scoring/test_scoring_utils.py index e3a5f7263..64e141976 100644 --- a/scoring/test_scoring_utils.py +++ b/scoring/test_scoring_utils.py @@ -1,8 +1,6 @@ -import os - from absl.testing import absltest -from scoring import performance_profile, scoring_utils +from scoring import scoring_utils TEST_LOGFILE = 'test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log' TEST_DIR = 'test_data/experiment_dir' From 5e97e7850fc7d3c53398c9eb2bdfd56c93eeb8cf Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:49:16 +0200 Subject: [PATCH 113/123] Lint prize_qualification_baselines/ --- .../external_tuning/jax_nadamw_full_budget.py | 4 +- .../jax_nadamw_target_setting.py | 4 +- .../pytorch_nadamw_full_budget.py | 8 +- .../pytorch_nadamw_target_setting.py | 8 +- .../external_tuning/tuning_search_space.json | 96 +++++++++---------- .../self_tuning/jax_nadamw_full_budget.py | 4 +- .../self_tuning/jax_nadamw_target_setting.py | 4 +- .../self_tuning/pytorch_nadamw_full_budget.py | 8 +- .../pytorch_nadamw_target_setting.py | 8 +- 9 files changed, 65 insertions(+), 79 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index aa1a08f69..ed721d167 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -18,11 +18,11 @@ # isort: on import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index f1d7d62e0..fdcdd5348 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -18,11 +18,11 @@ # isort: on import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index ecd299988..f6c2faa9d 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -3,13 +3,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 0d8054135..68ff30b2a 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -3,13 +3,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json index 199f77041..ad68372c6 100644 --- a/prize_qualification_baselines/external_tuning/tuning_search_space.json +++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -1,53 +1,47 @@ [ - { - "dropout_rate": 0.0, - "label_smoothing": 0.1, - "learning_rate": 0.001308209823469072, - "one_minus_beta1": 0.02686663061, - "beta2": 0.9981232922116359, - "weight_decay": 0.16375311233774334, - "warmup_factor": 0.1 - }, - { - "dropout_rate": 0.0, - "label_smoothing": 0.2, - "learning_rate": 0.0008445074561975979, - "one_minus_beta1": 0.11042418465, - "beta2": 0.9978504782314613, - "weight_decay": 0.08135402759553023, - "warmup_factor": 0.05 - }, - { - "dropout_rate": 0.0, - "label_smoothing": 0.0, - "learning_rate": 0.001308209823469072, - "one_minus_beta1": 0.02686663061, - "beta2": 0.9981232922116359, - "weight_decay": 0.16375311233774334, - "warmup_factor": 0.1 - }, - { - "dropout_rate": 0.0, - "label_smoothing": 0.0, - "learning_rate": 0.004958460849689891, - "one_minus_beta1": 0.13625575743, - "beta2": 0.6291854735396584, - "weight_decay": 0.1147386261512052, - "warmup_factor": 0.02 - }, - { - "dropout_rate": 0.1, - "label_smoothing": 0.0, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.2, + "learning_rate": 0.0008445074561975979, + "one_minus_beta1": 0.11042418465, + "beta2": 0.9978504782314613, + "weight_decay": 0.08135402759553023, + "warmup_factor": 0.05 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.0, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.0, + "learning_rate": 0.004958460849689891, + "one_minus_beta1": 0.13625575743, + "beta2": 0.6291854735396584, + "weight_decay": 0.1147386261512052, + "warmup_factor": 0.02 + }, + { + "dropout_rate": 0.1, + "label_smoothing": 0.0, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } ] - - - - - - diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index fb322bd5a..0b4e5aba3 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -18,11 +18,11 @@ # isort: on import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 99d996bb9..d3efc3a55 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -18,11 +18,11 @@ # isort: on import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index cc54e3b4e..8e8adbeaa 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -4,13 +4,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index bd065dc06..5d7c444a4 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -4,13 +4,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup From 4ae54188b101229117897ec7662c95b954c5a69f Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 10:57:14 +0200 Subject: [PATCH 114/123] Lint datasets/ --- datasets/librispeech_preprocess.py | 9 ++++----- datasets/librispeech_tokenizer.py | 6 +++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py index c419eb39b..1c216db46 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets/librispeech_preprocess.py @@ -4,16 +4,15 @@ import multiprocessing.dummy import os -from os.path import exists import sys import threading import time -from absl import logging import numpy as np import pandas as pd -from pydub import AudioSegment import tensorflow as tf +from absl import logging +from pydub import AudioSegment from datasets import librispeech_tokenizer @@ -84,8 +83,8 @@ def process(index): return utterance_ids with open(trans_file, 'r', encoding='UTF-8') as f: - for l in f: - utt, trans = l.strip().split(' ', maxsplit=1) + for line in f: + utt, trans = line.strip().split(' ', maxsplit=1) audio_path = ( f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac' ) diff --git a/datasets/librispeech_tokenizer.py b/datasets/librispeech_tokenizer.py index 5b9888cc2..d566d5716 100644 --- a/datasets/librispeech_tokenizer.py +++ b/datasets/librispeech_tokenizer.py @@ -8,10 +8,10 @@ import tempfile from typing import Dict -from absl import logging import sentencepiece as spm import tensorflow as tf import tensorflow_text as tftxt +from absl import logging gfile = tf.io.gfile copy = tf.io.gfile.copy @@ -41,8 +41,8 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)): logging.info('path does not exist -> %s', trans_file) continue with open(trans_file, 'r', encoding='UTF-8') as f: - for l in f: - _, line = l.strip().split(' ', maxsplit=1) + for lines in f: + _, line = lines.strip().split(' ', maxsplit=1) line = line + '\n' char_count += len(line) if char_count > maxchars: From e3f1b74ae5582b72ee37db62a70e84fa85e152e0 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 11:05:19 +0200 Subject: [PATCH 115/123] Lint reference_algorithms/ --- .../cifar/cifar_jax/submission.py | 4 +- .../cifar/cifar_pytorch/submission.py | 4 +- .../cifar/tuning_search_space.json | 10 +-- .../mnist/discrete_space.json | 32 +++++----- .../mnist/mnist_jax/submission.py | 4 +- .../mnist/tuning_search_space.json | 6 +- .../adafactor/jax/sharded_adafactor.py | 2 +- .../adafactor/jax/submission.py | 4 +- .../adafactor/pytorch/submission.py | 6 +- .../adafactor/tuning_search_space.json | 42 ++++++------ .../tuning_search_space_no_beta1.json | 40 ++++++------ .../paper_baselines/adamw/jax/submission.py | 4 +- .../adamw/pytorch/submission.py | 6 +- .../adamw/tuning_search_space.json | 48 ++++++++------ .../adamw/tuning_search_space_no_beta1.json | 46 +++++++------ .../paper_baselines/lamb/jax/submission.py | 4 +- .../lamb/pytorch/submission.py | 6 +- .../lamb/tuning_search_space.json | 48 ++++++++------ .../lamb/tuning_search_space_no_beta1.json | 46 +++++++------ .../momentum/jax/submission.py | 4 +- .../momentum/pytorch/submission.py | 2 +- .../momentum/tuning_search_space.json | 18 ++++-- .../tuning_search_space_no_beta1.json | 16 +++-- .../paper_baselines/nadamw/jax/submission.py | 4 +- .../nadamw/pytorch/submission.py | 8 +-- .../nadamw/tuning_search_space.json | 20 ++++-- .../nadamw/tuning_search_space_no_beta1.json | 18 ++++-- .../nesterov/jax/submission.py | 4 +- .../nesterov/pytorch/submission.py | 2 +- .../nesterov/tuning_search_space.json | 18 ++++-- .../tuning_search_space_no_beta1.json | 16 +++-- .../paper_baselines/sam/jax/submission.py | 4 +- .../paper_baselines/sam/pytorch/submission.py | 6 +- .../sam/tuning_search_space.json | 54 +++++++++------- .../sam/tuning_search_space_no_beta1.json | 52 ++++++++------- .../shampoo/jax/distributed_shampoo.py | 10 +-- .../paper_baselines/shampoo/jax/submission.py | 4 +- .../shampoo/tuning_search_space.json | 48 ++++++++------ .../shampoo/tuning_search_space_no_beta1.json | 46 +++++++------ .../cosine_warmup.py | 4 +- .../tuning_search_space.json | 41 +++++------- .../tuning_search_space.json | 41 +++++------- .../criteo1tb_resnet/tuning_search_space.json | 41 +++++------- .../fastmri/tuning_search_space.json | 56 ++++++---------- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../fastmri_tanh/tuning_search_space.json | 40 +++++------- .../imagenet_resnet/tuning_search_space.json | 64 +++++++------------ .../tuning_search_space.json | 56 ++++++---------- .../tuning_search_space.json | 56 ++++++---------- .../tuning_search_space.json | 40 +++++------- .../imagenet_vit/tuning_search_space.json | 40 +++++------- .../imagenet_vit_glu/tuning_search_space.json | 40 +++++------- .../imagenet_vit_map/tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../target_setting_algorithms/jax_adamw.py | 14 ++-- .../target_setting_algorithms/jax_momentum.py | 14 ++-- .../target_setting_algorithms/jax_nadamw.py | 14 ++-- .../target_setting_algorithms/jax_nesterov.py | 14 ++-- .../jax_submission_base.py | 2 +- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../tuning_search_space.json | 40 +++++------- .../ogbg/tuning_search_space.json | 48 ++++++-------- .../ogbg_gelu/tuning_search_space.json | 42 +++++------- .../ogbg_model_size/tuning_search_space.json | 42 +++++------- .../ogbg_silu/tuning_search_space.json | 42 +++++------- .../pytorch_adamw.py | 12 ++-- .../pytorch_momentum.py | 12 ++-- .../pytorch_nadamw.py | 12 ++-- .../pytorch_nesterov.py | 14 ++-- .../pytorch_submission_base.py | 2 +- .../wmt/tuning_search_space.json | 48 ++++++-------- .../tuning_search_space.json | 48 ++++++-------- .../wmt_glu_tanh/tuning_search_space.json | 48 ++++++-------- .../wmt_post_ln/tuning_search_space.json | 48 ++++++-------- 83 files changed, 988 insertions(+), 1283 deletions(-) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 37f74ac45..d080f2fb3 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 5fd51c3b2..e8080fe34 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -3,9 +3,7 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec diff --git a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json b/reference_algorithms/development_algorithms/cifar/tuning_search_space.json index 283341705..aa8fcacfd 100644 --- a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json +++ b/reference_algorithms/development_algorithms/cifar/tuning_search_space.json @@ -1,7 +1,7 @@ { - "learning_rate": {"feasible_points": [0.1]}, - "warmup_epochs": {"feasible_points": [5]}, - "num_epochs": {"feasible_points": [200]}, - "l2": {"feasible_points": [5e-4]}, - "momentum": {"feasible_points": [0.9]} + "learning_rate": { "feasible_points": [0.1] }, + "warmup_epochs": { "feasible_points": [5] }, + "num_epochs": { "feasible_points": [200] }, + "l2": { "feasible_points": [5e-4] }, + "momentum": { "feasible_points": [0.9] } } diff --git a/reference_algorithms/development_algorithms/mnist/discrete_space.json b/reference_algorithms/development_algorithms/mnist/discrete_space.json index 310f19e7d..8056d4861 100644 --- a/reference_algorithms/development_algorithms/mnist/discrete_space.json +++ b/reference_algorithms/development_algorithms/mnist/discrete_space.json @@ -1,17 +1,17 @@ [ - { - "learning_rate": 1e-3, - "one_minus_beta_1": 0.999, - "epsilon": 0.9 - }, - { - "learning_rate": 1e-2, - "one_minus_beta_1": 0.99, - "epsilon": 0.99 - }, - { - "learning_rate": 1e-1, - "one_minus_beta_1": 0.9, - "epsilon": 0.999 - } -] \ No newline at end of file + { + "learning_rate": 1e-3, + "one_minus_beta_1": 0.999, + "epsilon": 0.9 + }, + { + "learning_rate": 1e-2, + "one_minus_beta_1": 0.99, + "epsilon": 0.99 + }, + { + "learning_rate": 1e-1, + "one_minus_beta_1": 0.9, + "epsilon": 0.999 + } +] diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 3ef97577f..afdd1bd43 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json b/reference_algorithms/development_algorithms/mnist/tuning_search_space.json index 35b941133..bf7d3f1c1 100644 --- a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json +++ b/reference_algorithms/development_algorithms/mnist/tuning_search_space.json @@ -1,5 +1,5 @@ { - "learning_rate": {"min": 1e-4, "max": 1e-2, "scaling": "log"}, - "one_minus_beta_1": {"min": 0.9, "max": 0.999, "scaling": "log"}, - "epsilon": {"feasible_points": [1e-8, 1e-5, 1e-3]} + "learning_rate": { "min": 1e-4, "max": 1e-2, "scaling": "log" }, + "one_minus_beta_1": { "min": 0.9, "max": 0.999, "scaling": "log" }, + "epsilon": { "feasible_points": [1e-8, 1e-5, 1e-3] } } diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py index c83f14a13..32ba97da4 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py @@ -30,8 +30,8 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union import jax -from jax import numpy as jnp import optax +from jax import numpy as jnp JTensor = Any NestedJTensor = Any diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 898de35eb..8dcaa6578 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import ( diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index dd831566f..4c96e5562 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -3,12 +3,10 @@ from functools import partial from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json index 5543689ea..37b36e55d 100644 --- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json @@ -1,20 +1,26 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 1e-2, "max": 0.45, "scaling": "log" - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 1e-2, + "max": 0.45, + "scaling": "log" + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json index 98a506084..35f106f9d 100644 --- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json @@ -1,20 +1,24 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 52c8d5ee2..17d0c2fc2 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index faefcd254..5df907160 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -2,12 +2,10 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/reference_algorithms/paper_baselines/adamw/tuning_search_space.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space.json index c96b03eda..abdd6e32d 100644 --- a/reference_algorithms/paper_baselines/adamw/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/adamw/tuning_search_space.json @@ -1,23 +1,29 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 2e-2, "max": 0.5, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 2e-2, + "max": 0.5, + "scaling": "log" + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 5e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json index b8bd2ea49..5a7c27be7 100644 --- a/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json @@ -1,23 +1,27 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 5e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index eedbbfc37..168c0579b 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 5cda59d6f..c73f89e71 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -3,12 +3,10 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch +from absl import logging from torch import Tensor -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json index f2fcde461..7de33fe47 100644 --- a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json @@ -1,23 +1,29 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 5e-2, "max": 0.3, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 5e-2, + "max": 0.3, + "scaling": "log" + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json index 8934e512d..0f1bc208a 100644 --- a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json @@ -1,23 +1,27 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 3540f9415..df084c17b 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index bad750857..b3d38b3dd 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -2,10 +2,10 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from absl import logging import optax import torch import torch.distributed.nn as dist_nn +from absl import logging from torch.optim.lr_scheduler import LambdaLR from algoperf import spec diff --git a/reference_algorithms/paper_baselines/momentum/tuning_search_space.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space.json index 8423bdab7..9ec39a6ef 100644 --- a/reference_algorithms/paper_baselines/momentum/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/momentum/tuning_search_space.json @@ -1,21 +1,27 @@ { "learning_rate": { - "min": 2e-2, "max": 5.0, "scaling": "log" + "min": 2e-2, + "max": 5.0, + "scaling": "log" }, "one_minus_beta1": { - "min": 5e-3, "max": 0.3, "scaling": "log" + "min": 5e-3, + "max": 0.3, + "scaling": "log" }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 1e-7, "max": 5e-5, "scaling": "log" + "min": 1e-7, + "max": 5e-5, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] }, "decay_steps_factor": { "feasible_points": [0.9] diff --git a/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json index f874862d8..80f9c7968 100644 --- a/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json @@ -1,21 +1,25 @@ { "learning_rate": { - "min": 2e-2, "max": 5.0, "scaling": "log" + "min": 2e-2, + "max": 5.0, + "scaling": "log" }, "one_minus_beta1": { - "feasible_points": [0.1] + "feasible_points": [0.1] }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 1e-7, "max": 5e-5, "scaling": "log" + "min": 1e-7, + "max": 5e-5, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] }, "decay_steps_factor": { "feasible_points": [0.9] diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index aa1a08f69..ed721d167 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -18,11 +18,11 @@ # isort: on import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index ecd299988..f6c2faa9d 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -3,13 +3,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json index cba20c4c2..a3d322771 100644 --- a/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json @@ -1,23 +1,29 @@ { "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" + "min": 1e-4, + "max": 1e-2, + "scaling": "log" }, "one_minus_beta1": { - "min": 4e-3, "max": 0.1, "scaling": "log" + "min": 4e-3, + "max": 0.1, + "scaling": "log" }, "beta2": { - "feasible_points": [0.999] + "feasible_points": [0.999] }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" + "min": 5e-3, + "max": 1.0, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] } } diff --git a/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json index 58973eb27..5a7c27be7 100644 --- a/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json @@ -1,23 +1,27 @@ { "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" + "min": 1e-4, + "max": 1e-2, + "scaling": "log" }, "one_minus_beta1": { - "feasible_points": [0.1] + "feasible_points": [0.1] }, "beta2": { - "feasible_points": [0.999] + "feasible_points": [0.999] }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" + "min": 5e-3, + "max": 1.0, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] } } diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index d32026212..18e58b3c0 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 77361f472..9d3bfa6e7 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -2,10 +2,10 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from absl import logging import optax import torch import torch.distributed.nn as dist_nn +from absl import logging from torch.optim.lr_scheduler import LambdaLR from algoperf import spec diff --git a/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json index 8423bdab7..9ec39a6ef 100644 --- a/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json @@ -1,21 +1,27 @@ { "learning_rate": { - "min": 2e-2, "max": 5.0, "scaling": "log" + "min": 2e-2, + "max": 5.0, + "scaling": "log" }, "one_minus_beta1": { - "min": 5e-3, "max": 0.3, "scaling": "log" + "min": 5e-3, + "max": 0.3, + "scaling": "log" }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 1e-7, "max": 5e-5, "scaling": "log" + "min": 1e-7, + "max": 5e-5, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] }, "decay_steps_factor": { "feasible_points": [0.9] diff --git a/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json index f874862d8..80f9c7968 100644 --- a/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json @@ -1,21 +1,25 @@ { "learning_rate": { - "min": 2e-2, "max": 5.0, "scaling": "log" + "min": 2e-2, + "max": 5.0, + "scaling": "log" }, "one_minus_beta1": { - "feasible_points": [0.1] + "feasible_points": [0.1] }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 1e-7, "max": 5e-5, "scaling": "log" + "min": 1e-7, + "max": 5e-5, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] }, "decay_steps_factor": { "feasible_points": [0.9] diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index ce6db3ac3..9ab193b56 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index fdd4eb8b7..652ebed1d 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -2,12 +2,10 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space.json b/reference_algorithms/paper_baselines/sam/tuning_search_space.json index 66dae232b..f32058937 100644 --- a/reference_algorithms/paper_baselines/sam/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/sam/tuning_search_space.json @@ -1,26 +1,32 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 5e-2, "max": 0.43, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-2, "max": 0.2, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - }, - "rho": { - "feasible_points": [0.01, 0.02, 0.05] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 5e-2, + "max": 0.43, + "scaling": "log" + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-2, + "max": 0.2, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + }, + "rho": { + "feasible_points": [0.01, 0.02, 0.05] + } } diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json index 89c480e7a..ee4e0c3e4 100644 --- a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json @@ -1,26 +1,30 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-2, "max": 0.2, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - }, - "rho": { - "feasible_points": [0.01, 0.02, 0.05] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-2, + "max": 0.2, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + }, + "rho": { + "feasible_points": [0.01, 0.02, 0.05] + } } diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 830dd4816..c719361d3 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -36,17 +36,17 @@ import functools import itertools import logging -from typing import Any, cast, List, NamedTuple, Optional, TypeVar, Union +from typing import Any, List, NamedTuple, Optional, TypeVar, Union, cast import chex -from flax import struct import jax -from jax import lax -from jax.experimental import pjit -from jax.experimental.sparse import linalg import jax.numpy as jnp import numpy as np import optax +from flax import struct +from jax import lax +from jax.experimental import pjit +from jax.experimental.sparse import linalg # Dtype for inverse-pth root routine # Switch to f64 if you have hardware that supports it. Enable the jax flag diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 526afe7d5..8bf4d2dc5 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import ( diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json index 9d804ba0e..58f6f4fd1 100644 --- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json @@ -1,23 +1,29 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 1e-2, "max": 0.15, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 1e-2, + "max": 0.15, + "scaling": "log" + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 5e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json index b8bd2ea49..5a7c27be7 100644 --- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json @@ -1,23 +1,27 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 5e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/target_setting_algorithms/cosine_warmup.py b/reference_algorithms/target_setting_algorithms/cosine_warmup.py index eeb87cd87..6a2241732 100644 --- a/reference_algorithms/target_setting_algorithms/cosine_warmup.py +++ b/reference_algorithms/target_setting_algorithms/cosine_warmup.py @@ -1,9 +1,7 @@ """Implementions of a linear warmup then cosine decay LR schedule.""" import optax -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR def jax_cosine_warmup(step_hint: int, hyperparameters): diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json index bd6c9702f..110138607 100644 --- a/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json @@ -1,28 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.002517072211464665 - ] - }, - "beta1": { - "feasible_points": [ - 0.9908351643533544 - ] - }, - "beta2": { - "feasible_points": [ - 0.9859568907533993 - ] - }, - "warmup_steps": { - "feasible_points": [ - 799 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.12274552870237089 - ] - } + "learning_rate": { + "feasible_points": [0.002517072211464665] + }, + "beta1": { + "feasible_points": [0.9908351643533544] + }, + "beta2": { + "feasible_points": [0.9859568907533993] + }, + "warmup_steps": { + "feasible_points": [799] + }, + "weight_decay": { + "feasible_points": [0.12274552870237089] } - \ No newline at end of file +} diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json index 8d128dae1..a7f52681d 100644 --- a/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json @@ -1,28 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.05493199486120455 - ] - }, - "beta1": { - "feasible_points": [ - 0.954922991734919 - ] - }, - "beta2": { - "feasible_points": [ - 0.9986188074995163 - ] - }, - "warmup_steps": { - "feasible_points": [ - 799 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.00011065469792077193 - ] - } + "learning_rate": { + "feasible_points": [0.05493199486120455] + }, + "beta1": { + "feasible_points": [0.954922991734919] + }, + "beta2": { + "feasible_points": [0.9986188074995163] + }, + "warmup_steps": { + "feasible_points": [799] + }, + "weight_decay": { + "feasible_points": [0.00011065469792077193] } - \ No newline at end of file +} diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json index a33ae2ff5..31ce92bd1 100644 --- a/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json @@ -1,28 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.001493629901423942 - ] - }, - "beta1": { - "feasible_points": [ - 0.9592129978682067 - ] - }, - "beta2": { - "feasible_points": [ - 0.9824918272399145 - ] - }, - "warmup_steps": { - "feasible_points": [ - 399 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.00038587516415285595 - ] - } + "learning_rate": { + "feasible_points": [0.001493629901423942] + }, + "beta1": { + "feasible_points": [0.9592129978682067] + }, + "beta2": { + "feasible_points": [0.9824918272399145] + }, + "warmup_steps": { + "feasible_points": [399] + }, + "weight_decay": { + "feasible_points": [0.00038587516415285595] } - \ No newline at end of file +} diff --git a/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json index d8b4ed1b9..894ebb9fb 100644 --- a/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json @@ -1,37 +1,23 @@ { - "learning_rate": { - "feasible_points": [ - 0.028609 - ] - }, - "beta1": { - "feasible_points": [ - 0.981543 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1357 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.984398 - ] - }, - "end_factor": { - "feasible_points": [ - 0.01 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.000576 - ] - } + "learning_rate": { + "feasible_points": [0.028609] + }, + "beta1": { + "feasible_points": [0.981543] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [1357] + }, + "decay_steps_factor": { + "feasible_points": [0.984398] + }, + "end_factor": { + "feasible_points": [0.01] + }, + "weight_decay": { + "feasible_points": [0.000576] + } } diff --git a/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json index 7a49ea891..a3aa8ea08 100644 --- a/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.008334676559764446 - ] - }, - "beta1": { - "feasible_points": [ - 0.8294338711079317 - ] - }, - "beta2": { - "feasible_points": [ - 0.8551723332825868 - ] - }, - "warmup_steps": { - "feasible_points": [ - 2714 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.01371235755699044 - ] - } + "learning_rate": { + "feasible_points": [0.008334676559764446] + }, + "beta1": { + "feasible_points": [0.8294338711079317] + }, + "beta2": { + "feasible_points": [0.8551723332825868] + }, + "warmup_steps": { + "feasible_points": [2714] + }, + "weight_decay": { + "feasible_points": [0.01371235755699044] + } } diff --git a/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json index 5516242df..21c5ac87c 100644 --- a/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.006173154695175443 - ] - }, - "beta1": { - "feasible_points": [ - 0.8496694604806512 - ] - }, - "beta2": { - "feasible_points": [ - 0.4639437428687345 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1357 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.1679001017957879 - ] - } + "learning_rate": { + "feasible_points": [0.006173154695175443] + }, + "beta1": { + "feasible_points": [0.8496694604806512] + }, + "beta2": { + "feasible_points": [0.4639437428687345] + }, + "warmup_steps": { + "feasible_points": [1357] + }, + "weight_decay": { + "feasible_points": [0.1679001017957879] + } } diff --git a/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json index c3f06a686..59d624fe9 100644 --- a/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.04037951750205473 - ] - }, - "beta1": { - "feasible_points": [ - 0.9932215932637941 - ] - }, - "beta2": { - "feasible_points": [ - 0.9425306939334134 - ] - }, - "warmup_steps": { - "feasible_points": [ - 542 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.14877061239151607 - ] - } + "learning_rate": { + "feasible_points": [0.04037951750205473] + }, + "beta1": { + "feasible_points": [0.9932215932637941] + }, + "beta2": { + "feasible_points": [0.9425306939334134] + }, + "warmup_steps": { + "feasible_points": [542] + }, + "weight_decay": { + "feasible_points": [0.14877061239151607] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json index 649487c48..941bac70e 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json @@ -1,42 +1,26 @@ { - "learning_rate": { - "feasible_points": [ - 4.131896390902391 - ] - }, - "beta1": { - "feasible_points": [ - 0.9274758113254791 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.9007765761611038 - ] - }, - "end_factor": { - "feasible_points": [ - 0.001 - ] - }, - "weight_decay": { - "feasible_points": [ - 5.6687777311501786e-6 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.2 - ] - } + "learning_rate": { + "feasible_points": [4.131896390902391] + }, + "beta1": { + "feasible_points": [0.9274758113254791] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "decay_steps_factor": { + "feasible_points": [0.9007765761611038] + }, + "end_factor": { + "feasible_points": [0.001] + }, + "weight_decay": { + "feasible_points": [5.6687777311501786e-6] + }, + "label_smoothing": { + "feasible_points": [0.2] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json index 6524f5a5b..f72a4057d 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json @@ -1,37 +1,23 @@ { - "learning_rate": { - "feasible_points": [ - 0.3850582234619253 - ] - }, - "beta1": { - "feasible_points": [ - 0.9845129495436189 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.9504205232618159 - ] - }, - "end_factor": { - "feasible_points": [ - 0.001 - ] - }, - "weight_decay": { - "feasible_points": [ - 1.7359160785435053e-5 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.2 - ] - } + "learning_rate": { + "feasible_points": [0.3850582234619253] + }, + "beta1": { + "feasible_points": [0.9845129495436189] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "decay_steps_factor": { + "feasible_points": [0.9504205232618159] + }, + "end_factor": { + "feasible_points": [0.001] + }, + "weight_decay": { + "feasible_points": [1.7359160785435053e-5] + }, + "label_smoothing": { + "feasible_points": [0.2] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json index 7ad32bb60..f58474cc8 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json @@ -1,37 +1,23 @@ { - "learning_rate": { - "feasible_points": [ - 4.131896390902391 - ] - }, - "beta1": { - "feasible_points": [ - 0.9274758113254791 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.9007765761611038 - ] - }, - "end_factor": { - "feasible_points": [ - 0.01 - ] - }, - "weight_decay": { - "feasible_points": [ - 5.6687777311501786e-6 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.2 - ] - } + "learning_rate": { + "feasible_points": [4.131896390902391] + }, + "beta1": { + "feasible_points": [0.9274758113254791] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "decay_steps_factor": { + "feasible_points": [0.9007765761611038] + }, + "end_factor": { + "feasible_points": [0.01] + }, + "weight_decay": { + "feasible_points": [5.6687777311501786e-6] + }, + "label_smoothing": { + "feasible_points": [0.2] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json index 4556c6235..c63a87214 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.01897755400372091 - ] - }, - "beta1": { - "feasible_points": [ - 0.9666072782043229 - ] - }, - "beta2": { - "feasible_points": [ - 0.99681600289198 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.015653883841116094 - ] - } + "learning_rate": { + "feasible_points": [0.01897755400372091] + }, + "beta1": { + "feasible_points": [0.9666072782043229] + }, + "beta2": { + "feasible_points": [0.99681600289198] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "weight_decay": { + "feasible_points": [0.015653883841116094] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json index 98360bdff..6c7501295 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0008445074561975979 - ] - }, - "beta1": { - "feasible_points": [ - 0.8895758153482813 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.08135402759553023 - ] - } + "learning_rate": { + "feasible_points": [0.0008445074561975979] + }, + "beta1": { + "feasible_points": [0.8895758153482813] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "weight_decay": { + "feasible_points": [0.08135402759553023] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json index 98360bdff..6c7501295 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0008445074561975979 - ] - }, - "beta1": { - "feasible_points": [ - 0.8895758153482813 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.08135402759553023 - ] - } + "learning_rate": { + "feasible_points": [0.0008445074561975979] + }, + "beta1": { + "feasible_points": [0.8895758153482813] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "weight_decay": { + "feasible_points": [0.08135402759553023] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json index 98360bdff..6c7501295 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0008445074561975979 - ] - }, - "beta1": { - "feasible_points": [ - 0.8895758153482813 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.08135402759553023 - ] - } + "learning_rate": { + "feasible_points": [0.0008445074561975979] + }, + "beta1": { + "feasible_points": [0.8895758153482813] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "weight_decay": { + "feasible_points": [0.08135402759553023] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json index d6f2053ff..94711417c 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.00026032497966327757 - ] - }, - "beta1": { - "feasible_points": [ - 0.9709035036599892 - ] - }, - "beta2": { - "feasible_points": [ - 0.6572080806975734 - ] - }, - "warmup_steps": { - "feasible_points": [ - 13999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.03077045727617869 - ] - } + "learning_rate": { + "feasible_points": [0.00026032497966327757] + }, + "beta1": { + "feasible_points": [0.9709035036599892] + }, + "beta2": { + "feasible_points": [0.6572080806975734] + }, + "warmup_steps": { + "feasible_points": [13999] + }, + "weight_decay": { + "feasible_points": [0.03077045727617869] + } } diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index e6f8d915c..3fa5e4955 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -1,21 +1,21 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" -from flax import jax_utils import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import ( +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 get_batch_size, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import ( +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) def init_optimizer_state( diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index b37464c1c..92d3f3a8f 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -2,21 +2,21 @@ from typing import Callable -from flax import jax_utils import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import ( +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 get_batch_size, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import ( +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) def init_optimizer_state( diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 7d4a1fcc1..6883b17ab 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -3,22 +3,22 @@ from typing import Any, Callable, NamedTuple, Optional, Union import chex -from flax import jax_utils import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import ( +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 get_batch_size, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import ( +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) # Forked from diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 714cbb225..7f4d1cd86 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -2,21 +2,21 @@ from typing import Callable -from flax import jax_utils import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import ( +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 get_batch_size, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import ( +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) def init_optimizer_state( diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 557e6957c..d9b12f5ca 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple import jax -from jax import lax import jax.numpy as jnp import optax +from jax import lax from algoperf import spec diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json index 482a28931..936833cf3 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.002106913873888147 - ] - }, - "beta1": { - "feasible_points": [ - 0.8231189937738506 - ] - }, - "beta2": { - "feasible_points": [ - 0.8774571227688758 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1199 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.27590534177690645 - ] - } + "learning_rate": { + "feasible_points": [0.002106913873888147] + }, + "beta1": { + "feasible_points": [0.8231189937738506] + }, + "beta2": { + "feasible_points": [0.8774571227688758] + }, + "warmup_steps": { + "feasible_points": [1199] + }, + "weight_decay": { + "feasible_points": [0.27590534177690645] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json index 22f3376b4..faefa750e 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0007852999990476642 - ] - }, - "beta1": { - "feasible_points": [ - 0.6994142393023162 - ] - }, - "beta2": { - "feasible_points": [ - 0.9918636824608852 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.07286322158086678 - ] - } + "learning_rate": { + "feasible_points": [0.0007852999990476642] + }, + "beta1": { + "feasible_points": [0.6994142393023162] + }, + "beta2": { + "feasible_points": [0.9918636824608852] + }, + "warmup_steps": { + "feasible_points": [6000] + }, + "weight_decay": { + "feasible_points": [0.07286322158086678] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json index ad200c01b..16ab02525 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.000590120167916659 - ] - }, - "beta1": { - "feasible_points": [ - 0.737199286155609 - ] - }, - "beta2": { - "feasible_points": [ - 0.05919391544031072 - ] - }, - "warmup_steps": { - "feasible_points": [ - 9999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.14128519778326312 - ] - } + "learning_rate": { + "feasible_points": [0.000590120167916659] + }, + "beta1": { + "feasible_points": [0.737199286155609] + }, + "beta2": { + "feasible_points": [0.05919391544031072] + }, + "warmup_steps": { + "feasible_points": [9999] + }, + "weight_decay": { + "feasible_points": [0.14128519778326312] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json index 8297cf0ae..d596dcd2b 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0014446807792420305 - ] - }, - "beta1": { - "feasible_points": [ - 0.7427148812902895 - ] - }, - "beta2": { - "feasible_points": [ - 0.8993064520764248 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.06875136511682291 - ] - } + "learning_rate": { + "feasible_points": [0.0014446807792420305] + }, + "beta1": { + "feasible_points": [0.7427148812902895] + }, + "beta2": { + "feasible_points": [0.8993064520764248] + }, + "warmup_steps": { + "feasible_points": [3000] + }, + "weight_decay": { + "feasible_points": [0.06875136511682291] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json index b31b711f7..dbcbecf78 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0035278622506232458 - ] - }, - "beta1": { - "feasible_points": [ - 0.8192305396005781 - ] - }, - "beta2": { - "feasible_points": [ - 0.495850879212151 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.04339748256184769 - ] - } + "learning_rate": { + "feasible_points": [0.0035278622506232458] + }, + "beta1": { + "feasible_points": [0.8192305396005781] + }, + "beta2": { + "feasible_points": [0.495850879212151] + }, + "warmup_steps": { + "feasible_points": [6000] + }, + "weight_decay": { + "feasible_points": [0.04339748256184769] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json index e20a2dae1..bbea133cb 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.001308209823469072 - ] - }, - "beta1": { - "feasible_points": [ - 0.9731333693827139 - ] - }, - "beta2": { - "feasible_points": [ - 0.9981232922116359 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.16375311233774334 - ] - } + "learning_rate": { + "feasible_points": [0.001308209823469072] + }, + "beta1": { + "feasible_points": [0.9731333693827139] + }, + "beta2": { + "feasible_points": [0.9981232922116359] + }, + "warmup_steps": { + "feasible_points": [6000] + }, + "weight_decay": { + "feasible_points": [0.16375311233774334] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json index 0a9bfb3cf..52fe59d84 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.004958460849689891 - ] - }, - "beta1": { - "feasible_points": [ - 0.863744242567442 - ] - }, - "beta2": { - "feasible_points": [ - 0.6291854735396584 - ] - }, - "warmup_steps": { - "feasible_points": [ - 720 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.1147386261512052 - ] - } + "learning_rate": { + "feasible_points": [0.004958460849689891] + }, + "beta1": { + "feasible_points": [0.863744242567442] + }, + "beta2": { + "feasible_points": [0.6291854735396584] + }, + "warmup_steps": { + "feasible_points": [720] + }, + "weight_decay": { + "feasible_points": [0.1147386261512052] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json index e76a48325..898fc9e36 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0020162740358935045 - ] - }, - "beta1": { - "feasible_points": [ - 0.9604907112078142 - ] - }, - "beta2": { - "feasible_points": [ - 0.8765457000160508 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3600 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.0006149579248633481 - ] - } + "learning_rate": { + "feasible_points": [0.0020162740358935045] + }, + "beta1": { + "feasible_points": [0.9604907112078142] + }, + "beta2": { + "feasible_points": [0.8765457000160508] + }, + "warmup_steps": { + "feasible_points": [3600] + }, + "weight_decay": { + "feasible_points": [0.0006149579248633481] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json index 55f70f9fc..94f150ad3 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0014446807792420305 - ] - }, - "beta1": { - "feasible_points": [ - 0.7427148812902895 - ] - }, - "beta2": { - "feasible_points": [ - 0.8993064520764248 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1800 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.06875136511682291 - ] - } + "learning_rate": { + "feasible_points": [0.0014446807792420305] + }, + "beta1": { + "feasible_points": [0.7427148812902895] + }, + "beta2": { + "feasible_points": [0.8993064520764248] + }, + "warmup_steps": { + "feasible_points": [1800] + }, + "weight_decay": { + "feasible_points": [0.06875136511682291] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json index e5f906688..517e4a455 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.003604759885558324 - ] - }, - "beta1": { - "feasible_points": [ - 0.9931094324430452 - ] - }, - "beta2": { - "feasible_points": [ - 0.9976871843749077 - ] - }, - "warmup_steps": { - "feasible_points": [ - 720 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.120077307855989 - ] - } + "learning_rate": { + "feasible_points": [0.003604759885558324] + }, + "beta1": { + "feasible_points": [0.9931094324430452] + }, + "beta2": { + "feasible_points": [0.9976871843749077] + }, + "warmup_steps": { + "feasible_points": [720] + }, + "weight_decay": { + "feasible_points": [0.120077307855989] + } } diff --git a/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json index 0f365a183..266f0f3f5 100644 --- a/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 2.4917728606918423 - ] - }, - "beta1": { - "feasible_points": [ - 0.9449369031171744 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3000 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.861509027839639 - ] - }, - "end_factor": { - "feasible_points": [ - 0.001 - ] - }, - "weight_decay": { - "feasible_points": [ - 1.2859640541025928e-7 - ] - } + "learning_rate": { + "feasible_points": [2.4917728606918423] + }, + "beta1": { + "feasible_points": [0.9449369031171744] + }, + "warmup_steps": { + "feasible_points": [3000] + }, + "decay_steps_factor": { + "feasible_points": [0.861509027839639] + }, + "end_factor": { + "feasible_points": [0.001] + }, + "weight_decay": { + "feasible_points": [1.2859640541025928e-7] + } } diff --git a/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json index 0749f96d6..b5b4aad30 100644 --- a/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.01897755400372091 - ] - }, - "beta1": { - "feasible_points": [ - 0.9666072782043229 - ] - }, - "beta2": { - "feasible_points": [ - 0.99681600289198 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.015653883841116094 - ] - } -} \ No newline at end of file + "learning_rate": { + "feasible_points": [0.01897755400372091] + }, + "beta1": { + "feasible_points": [0.9666072782043229] + }, + "beta2": { + "feasible_points": [0.99681600289198] + }, + "warmup_steps": { + "feasible_points": [3000] + }, + "weight_decay": { + "feasible_points": [0.015653883841116094] + } +} diff --git a/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json index d5af3b03e..b69114b88 100644 --- a/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.001734480757979605 - ] - }, - "beta1": { - "feasible_points": [ - 0.855609542347586 - ] - }, - "beta2": { - "feasible_points": [ - 0.9834185656478605 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.019843063335529494 - ] - } -} \ No newline at end of file + "learning_rate": { + "feasible_points": [0.001734480757979605] + }, + "beta1": { + "feasible_points": [0.855609542347586] + }, + "beta2": { + "feasible_points": [0.9834185656478605] + }, + "warmup_steps": { + "feasible_points": [3000] + }, + "weight_decay": { + "feasible_points": [0.019843063335529494] + } +} diff --git a/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json index b9f83a5ed..e1512c02a 100644 --- a/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.00027866530268792414 - ] - }, - "beta1": { - "feasible_points": [ - 0.9919340993463499 - ] - }, - "beta2": { - "feasible_points": [ - 0.9979843253162892 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.00032418357325210813 - ] - } -} \ No newline at end of file + "learning_rate": { + "feasible_points": [0.00027866530268792414] + }, + "beta1": { + "feasible_points": [0.9919340993463499] + }, + "beta2": { + "feasible_points": [0.9979843253162892] + }, + "warmup_steps": { + "feasible_points": [6000] + }, + "weight_decay": { + "feasible_points": [0.00032418357325210813] + } +} diff --git a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py index f2474a706..14e8155d4 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py @@ -4,15 +4,15 @@ from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import ( +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 get_batch_size, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) def init_optimizer_state( diff --git a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py index 030939de5..a23b835eb 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py @@ -4,18 +4,18 @@ from torch.optim.lr_scheduler import LambdaLR from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import ( +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 get_batch_size, -) # pylint: disable=unused-import +) from reference_algorithms.target_setting_algorithms.jax_momentum import ( create_lr_schedule_fn, ) -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) def init_optimizer_state( diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py index ceeebda6d..d301a233f 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py @@ -8,15 +8,15 @@ from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import ( +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 get_batch_size, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py index ddbcaefdb..3a6294a28 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py @@ -4,18 +4,18 @@ from torch.optim.lr_scheduler import LambdaLR from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import ( +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 data_selection, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import ( +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 get_batch_size, -) # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_nesterov import ( +) +from reference_algorithms.target_setting_algorithms.jax_momentum import ( create_lr_schedule_fn, ) -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 update_params, -) # pylint: disable=unused-import +) def init_optimizer_state( diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 36f736a6b..0bef4548f 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -2,9 +2,9 @@ from typing import Any, Dict, List, Optional, Tuple -from absl import logging import torch import torch.distributed.nn as dist_nn +from absl import logging from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json index f0ef45daa..aee18d976 100644 --- a/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 0.0017486387539278373 - ] - }, - "beta1": { - "feasible_points": [ - 0.9326607383586145 - ] - }, - "beta2": { - "feasible_points": [ - 0.9955159689799007 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.08121616522670176 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.0 - ] - } + "learning_rate": { + "feasible_points": [0.0017486387539278373] + }, + "beta1": { + "feasible_points": [0.9326607383586145] + }, + "beta2": { + "feasible_points": [0.9955159689799007] + }, + "warmup_steps": { + "feasible_points": [1999] + }, + "weight_decay": { + "feasible_points": [0.08121616522670176] + }, + "label_smoothing": { + "feasible_points": [0.0] + } } diff --git a/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json index 266cdedbb..e1ce2229f 100644 --- a/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 0.000590120167916659 - ] - }, - "beta1": { - "feasible_points": [ - 0.737199286155609 - ] - }, - "beta2": { - "feasible_points": [ - 0.05919391544031072 - ] - }, - "warmup_steps": { - "feasible_points": [ - 9999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.14128519778326312 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.0 - ] - } + "learning_rate": { + "feasible_points": [0.000590120167916659] + }, + "beta1": { + "feasible_points": [0.737199286155609] + }, + "beta2": { + "feasible_points": [0.05919391544031072] + }, + "warmup_steps": { + "feasible_points": [9999] + }, + "weight_decay": { + "feasible_points": [0.14128519778326312] + }, + "label_smoothing": { + "feasible_points": [0.0] + } } diff --git a/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json index d288d9a49..0ed0f832a 100644 --- a/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 0.000872041489644454 - ] - }, - "beta1": { - "feasible_points": [ - 0.45562164405092065 - ] - }, - "beta2": { - "feasible_points": [ - 0.9982167124443476 - ] - }, - "warmup_steps": { - "feasible_points": [ - 4999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.01536114562763022 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.1 - ] - } + "learning_rate": { + "feasible_points": [0.000872041489644454] + }, + "beta1": { + "feasible_points": [0.45562164405092065] + }, + "beta2": { + "feasible_points": [0.9982167124443476] + }, + "warmup_steps": { + "feasible_points": [4999] + }, + "weight_decay": { + "feasible_points": [0.01536114562763022] + }, + "label_smoothing": { + "feasible_points": [0.1] + } } diff --git a/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json index 1327bcb38..d1b3045c7 100644 --- a/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 0.0003477912008450351 - ] - }, - "beta1": { - "feasible_points": [ - 0.9936632117510711 - ] - }, - "beta2": { - "feasible_points": [ - 0.9967873550453692 - ] - }, - "warmup_steps": { - "feasible_points": [ - 9999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.04120183162940475 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.0 - ] - } + "learning_rate": { + "feasible_points": [0.0003477912008450351] + }, + "beta1": { + "feasible_points": [0.9936632117510711] + }, + "beta2": { + "feasible_points": [0.9967873550453692] + }, + "warmup_steps": { + "feasible_points": [9999] + }, + "weight_decay": { + "feasible_points": [0.04120183162940475] + }, + "label_smoothing": { + "feasible_points": [0.0] + } } From 566b6d9ee6f0489fc1e27e567f01bf9d23689fc1 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 11:16:43 +0200 Subject: [PATCH 116/123] Lint algoperf/ --- algoperf/checkpoint_utils.py | 6 +++--- algoperf/data_utils.py | 6 ++---- algoperf/logger_utils.py | 12 ++++++------ algoperf/profiler.py | 4 ++-- algoperf/pytorch_utils.py | 5 ++--- algoperf/random_utils.py | 3 +-- .../workloads/cifar/cifar_jax/input_pipeline.py | 2 +- algoperf/workloads/cifar/cifar_jax/models.py | 2 +- algoperf/workloads/cifar/cifar_pytorch/models.py | 4 +--- .../workloads/cifar/cifar_pytorch/workload.py | 5 +---- algoperf/workloads/cifar/workload.py | 2 +- algoperf/workloads/criteo1tb/workload.py | 2 +- .../imagenet_jax/input_pipeline.py | 5 ++--- .../imagenet_resnet/imagenet_jax/models.py | 2 +- .../imagenet_resnet/imagenet_jax/randaugment.py | 4 ---- .../imagenet_resnet/imagenet_pytorch/models.py | 3 +-- .../imagenet_pytorch/randaugment.py | 2 +- algoperf/workloads/imagenet_resnet/imagenet_v2.py | 3 +-- .../librispeech_conformer/input_pipeline.py | 2 +- .../librispeech_jax/librispeech_preprocessor.py | 4 ++-- .../librispeech_pytorch/preprocessor.py | 4 ++-- .../workloads/librispeech_conformer/metrics.py | 2 +- .../librispeech_pytorch/models.py | 15 +++++++-------- .../workloads/mnist/mnist_pytorch/workload.py | 8 +++----- algoperf/workloads/mnist/workload.py | 5 ++--- algoperf/workloads/ogbg/metrics.py | 4 ++-- algoperf/workloads/ogbg/workload.py | 3 +-- algoperf/workloads/wmt/bleu.py | 10 ++++------ algoperf/workloads/wmt/input_pipeline.py | 4 ++-- algoperf/workloads/wmt/tokenizer.py | 4 ++-- algoperf/workloads/wmt/wmt_jax/decode.py | 2 +- 31 files changed, 58 insertions(+), 81 deletions(-) diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 577baaa34..f8cc40599 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -7,14 +7,14 @@ import os from typing import Sequence, Tuple +import jax +import numpy as np +import torch from absl import logging from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint -import jax -import numpy as np from tensorflow.io import gfile # pytype: disable=import-error -import torch from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 919ccd125..f08d9d2db 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -7,9 +7,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from torch.utils.data import DataLoader -from torch.utils.data import DistributedSampler -from torch.utils.data import Sampler +from torch.utils.data import DataLoader, DistributedSampler, Sampler from algoperf import spec @@ -266,7 +264,7 @@ def prefetched_loader(self) -> Iterable[Tuple[spec.Tensor, spec.Tensor]]: next_targets = next_targets.to(self.device, non_blocking=True) if not first: - yield inputs, targets + yield inputs, targets # noqa: F821 else: first = False diff --git a/algoperf/logger_utils.py b/algoperf/logger_utils.py index 3c4898142..17eea74a6 100644 --- a/algoperf/logger_utils.py +++ b/algoperf/logger_utils.py @@ -11,12 +11,12 @@ import sys from typing import Any, Dict, Optional -from absl import flags -from clu import metric_writers import GPUtil import pandas as pd import psutil import torch.distributed as dist +from absl import flags +from clu import metric_writers from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -198,7 +198,7 @@ def _get_system_hardware_info() -> Dict: try: system_hardware_info['cpu_model_name'] = _get_cpu_model_name() system_hardware_info['cpu_count'] = psutil.cpu_count() - except: # pylint: disable=bare-except + except: # noqa: E722 logging.info('Unable to record cpu information. Continuing without it.') gpus = GPUtil.getGPUs() @@ -207,7 +207,7 @@ def _get_system_hardware_info() -> Dict: system_hardware_info['gpu_model_name'] = gpus[0].name system_hardware_info['gpu_count'] = len(gpus) system_hardware_info['gpu_driver'] = gpus[0].driver - except: # pylint: disable=bare-except + except: # noqa: E722 logging.info('Unable to record gpu information. Continuing without it.') return system_hardware_info @@ -232,7 +232,7 @@ def _get_system_software_info() -> Dict: system_software_info['git_commit_hash'] = _get_git_commit_hash() # Note: do not store git repo url as it may be sensitive or contain a # secret. - except: # pylint: disable=bare-except + except: # noqa: E722 logging.info('Unable to record git information. Continuing without it.') return system_software_info @@ -279,7 +279,7 @@ def _get_workload_properties(workload: spec.Workload) -> Dict: for key in keys: try: attr = getattr(workload, key) - except: # pylint: disable=bare-except + except: # noqa: E722 logging.info( f'Unable to record workload.{key} information. Continuing without it.' ) diff --git a/algoperf/profiler.py b/algoperf/profiler.py index 0e791d3a8..534a5ccfb 100644 --- a/algoperf/profiler.py +++ b/algoperf/profiler.py @@ -4,10 +4,10 @@ https://github.com/Lightning-AI/lightning/tree/master/src/pytorch_lightning/profilers. """ -from collections import defaultdict -from contextlib import contextmanager import os import time +from collections import defaultdict +from contextlib import contextmanager from typing import Dict, Generator, List, Optional, Tuple import numpy as np diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index 429e4d1e2..af09e67fc 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -1,14 +1,13 @@ import os from typing import Tuple -from absl import logging import jax import tensorflow as tf import torch -from torch import nn -from torch import Tensor import torch.distributed as dist import torch.nn.functional as F +from absl import logging +from torch import Tensor, nn from algoperf import spec from algoperf.profiler import Profiler diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index 41f4b6b41..1dc773e80 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -2,9 +2,8 @@ from typing import Any, List, Union -from absl import flags -from absl import logging import numpy as np +from absl import flags, logging try: import jax.random as jax_rng diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 3d831c4af..7fbc95bc6 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -8,10 +8,10 @@ import functools from typing import Dict, Iterator, Tuple -from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds +from flax import jax_utils from algoperf import spec from algoperf.data_utils import shard_and_maybe_pad_np diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py index 8d034796f..95238c997 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -7,8 +7,8 @@ import functools from typing import Any, Callable, Tuple -from flax import linen as nn import jax.numpy as jnp +from flax import linen as nn from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNetBlock diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index 6beef89e6..0e08f5c5a 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -14,11 +14,9 @@ from algoperf.init_utils import pytorch_default_init from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import ( BasicBlock, -) -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import ( Bottleneck, + conv1x1, ) -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1 class ResNet(nn.Module): diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index d7e858226..f1189bebc 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -12,10 +12,7 @@ from torchvision import transforms from torchvision.datasets import CIFAR10 -from algoperf import data_utils -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec +from algoperf import data_utils, param_utils, pytorch_utils, spec from algoperf.workloads.cifar.cifar_pytorch.models import resnet18 from algoperf.workloads.cifar.workload import BaseCifarWorkload diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py index 61880fbfa..31636807c 100644 --- a/algoperf/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -7,9 +7,9 @@ import jax import torch +import algoperf.random_utils as prng from algoperf import spec from algoperf.pytorch_utils import pytorch_setup -import algoperf.random_utils as prng USE_PYTORCH_DDP, _, _, _ = pytorch_setup() diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 9fb819203..2cb7e5450 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -4,8 +4,8 @@ import os from typing import Dict, Iterator, Optional, Tuple -from absl import flags import torch.distributed as dist +from absl import flags from algoperf import spec from algoperf.workloads.criteo1tb import input_pipeline diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index fc42f4a5b..f782e50a1 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -7,13 +7,12 @@ import functools from typing import Dict, Iterator, Tuple -from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds +from flax import jax_utils -from algoperf import data_utils -from algoperf import spec +from algoperf import data_utils, spec from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment TFDS_SPLIT_NAME = { diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index 1f3911708..84ad4fe21 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -7,8 +7,8 @@ import functools from typing import Any, Callable, Optional, Tuple -from flax import linen as nn import jax.numpy as jnp +from flax import linen as nn from algoperf import spec diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 03b36e03d..87e218a0c 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -11,11 +11,7 @@ from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import ( rotate_img, -) -from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import ( transform, -) -from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import ( translate, ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index ab3fc4a37..c980faa06 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -8,8 +8,7 @@ from typing import Any, Callable, List, Optional, Type, Union import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn from algoperf import spec from algoperf.init_utils import pytorch_default_init diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py index 28ce00650..1c5c0d952 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py @@ -11,8 +11,8 @@ import PIL import torch from torch import Tensor -from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode +from torchvision.transforms import functional as F from algoperf import spec diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py index 6ffb73367..7a8e38f02 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py @@ -8,8 +8,7 @@ import tensorflow_datasets as tfds -from algoperf import data_utils -from algoperf import spec +from algoperf import data_utils, spec from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline diff --git a/algoperf/workloads/librispeech_conformer/input_pipeline.py b/algoperf/workloads/librispeech_conformer/input_pipeline.py index 570db07b3..23ce8e3b7 100644 --- a/algoperf/workloads/librispeech_conformer/input_pipeline.py +++ b/algoperf/workloads/librispeech_conformer/input_pipeline.py @@ -4,9 +4,9 @@ import csv -from absl import logging import numpy as np import torch +from absl import logging class LibriSpeechDataset(torch.utils.data.Dataset): diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py index bd36b1bb9..531e68a45 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py @@ -10,11 +10,11 @@ from typing import Any, Optional, Union -from flax import linen as nn -from flax import struct import jax import jax.numpy as jnp import numpy as np +from flax import linen as nn +from flax import struct # mel spectrum constants. _MEL_BREAK_FREQUENCY_HERTZ = 700.0 diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py index f8c1bd0d2..58dd837dc 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py @@ -2,14 +2,14 @@ https://github.com/google/init2winit/blob/master/init2winit/model_lib/librispeech_preprocessor.py. """ -from dataclasses import dataclass import math +from dataclasses import dataclass from typing import Any, Optional, Union import numpy as np import torch -from torch import nn import torch.nn.functional as F +from torch import nn # mel spectrum constants. _MEL_BREAK_FREQUENCY_HERTZ = 700.0 diff --git a/algoperf/workloads/librispeech_conformer/metrics.py b/algoperf/workloads/librispeech_conformer/metrics.py index d5c826575..7dd6a11dc 100644 --- a/algoperf/workloads/librispeech_conformer/metrics.py +++ b/algoperf/workloads/librispeech_conformer/metrics.py @@ -1,8 +1,8 @@ -from clu import metrics import flax import numpy as np import tensorflow as tf import tensorflow_text as tftxt +from clu import metrics gfile = tf.io.gfile diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index ddb7b5c37..aab75da63 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -2,14 +2,14 @@ https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. """ -from dataclasses import dataclass import os +from dataclasses import dataclass from typing import Tuple import torch -from torch import nn import torch.distributed.nn as dist_nn import torch.nn.functional as F +from torch import nn from algoperf.workloads.librispeech_conformer.librispeech_pytorch import ( preprocessor, @@ -88,7 +88,6 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) def forward(self, inputs, input_paddings, dropout_rate): - output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -209,7 +208,6 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) def forward(self, inputs, input_paddings, dropout_rate): - padding_mask = (1 - input_paddings)[:, :, None] if self.config.layernorm_everywhere: inputs = self.normalization_layer(inputs) @@ -381,9 +379,9 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, - output_paddings, - dropout_rate) + outputs, output_paddings = self.subsample( + outputs, output_paddings, dropout_rate + ) for idx in range(self.config.num_lstm_layers): if self.config.enable_residual_connections: outputs = outputs + self.lstms[idx](outputs, output_paddings) @@ -393,7 +391,8 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): for idx in range(self.config.num_ffn_layers): if self.config.enable_residual_connections: outputs = outputs + self.ffns[idx]( - outputs, output_paddings, dropout_rate) + outputs, output_paddings, dropout_rate + ) else: outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) diff --git a/algoperf/workloads/mnist/mnist_pytorch/workload.py b/algoperf/workloads/mnist/mnist_pytorch/workload.py index ca861d551..b58898703 100644 --- a/algoperf/workloads/mnist/mnist_pytorch/workload.py +++ b/algoperf/workloads/mnist/mnist_pytorch/workload.py @@ -1,18 +1,16 @@ """MNIST workload implemented in PyTorch.""" -from collections import OrderedDict import contextlib +from collections import OrderedDict from typing import Any, Dict, Iterator, Optional, Tuple import torch -from torch import nn import torch.distributed as dist import torch.nn.functional as F +from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import init_utils -from algoperf import param_utils -from algoperf import spec +from algoperf import init_utils, param_utils, spec from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.mnist.workload import BaseMnistWorkload diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py index 20aa975ae..38006b9ac 100644 --- a/algoperf/workloads/mnist/workload.py +++ b/algoperf/workloads/mnist/workload.py @@ -10,10 +10,9 @@ import tensorflow_datasets as tfds import torch -from algoperf import data_utils -from algoperf import spec -from algoperf.pytorch_utils import pytorch_setup import algoperf.random_utils as prng +from algoperf import data_utils, spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP, _, _, _ = pytorch_setup() diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index 668501788..3e7825219 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -2,14 +2,14 @@ # https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py from typing import Any -from clu import metrics import flax import jax import jax.numpy as jnp import numpy as np -from sklearn.metrics import average_precision_score import torch import torch.distributed as dist +from clu import metrics +from sklearn.metrics import average_precision_score from algoperf.pytorch_utils import pytorch_setup diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 1d182fed5..8717e46d6 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -9,8 +9,7 @@ from algoperf import random_utils as prng from algoperf import spec -from algoperf.workloads.ogbg import input_pipeline -from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg import input_pipeline, metrics class BaseOgbgWorkload(spec.Workload): diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py index e0064ba51..6e29b1b83 100644 --- a/algoperf/workloads/wmt/bleu.py +++ b/algoperf/workloads/wmt/bleu.py @@ -5,19 +5,17 @@ https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. """ -from collections import Counter -from collections import namedtuple -from itertools import zip_longest -import logging import math import re import sys -from typing import List, Sequence import unicodedata +from collections import Counter, namedtuple +from itertools import zip_longest +from typing import List, Sequence -from absl import logging import torch import torch.distributed as dist +from absl import logging from algoperf.pytorch_utils import pytorch_setup diff --git a/algoperf/workloads/wmt/input_pipeline.py b/algoperf/workloads/wmt/input_pipeline.py index 1df1dfc55..3d184cd78 100644 --- a/algoperf/workloads/wmt/input_pipeline.py +++ b/algoperf/workloads/wmt/input_pipeline.py @@ -238,8 +238,8 @@ def preprocess_wmt_data( def length_filter(max_len): def filter_fn(x): source, target = x['inputs'], x['targets'] - l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) - return tf.less(l, max_len + 1) + length = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) + return tf.less(length, max_len + 1) return filter_fn diff --git a/algoperf/workloads/wmt/tokenizer.py b/algoperf/workloads/wmt/tokenizer.py index d94cab808..273e11dfa 100644 --- a/algoperf/workloads/wmt/tokenizer.py +++ b/algoperf/workloads/wmt/tokenizer.py @@ -9,11 +9,11 @@ import time from typing import Any, Dict, Iterable, Tuple -from absl import logging import jax -from sentencepiece import SentencePieceTrainer import tensorflow as tf import tensorflow_text as tftxt +from absl import logging +from sentencepiece import SentencePieceTrainer Features = Dict[str, tf.Tensor] diff --git a/algoperf/workloads/wmt/wmt_jax/decode.py b/algoperf/workloads/wmt/wmt_jax/decode.py index b5f5f1099..196d9175e 100644 --- a/algoperf/workloads/wmt/wmt_jax/decode.py +++ b/algoperf/workloads/wmt/wmt_jax/decode.py @@ -7,9 +7,9 @@ import flax import jax -from jax import lax import jax.numpy as jnp import numpy as np +from jax import lax # Constants # We assume the default End-of-Sentence token id is 2 (SentencePiece). From e8466489a87b631f2f73f79a4e50c0d8ffef7544 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 12:41:28 +0200 Subject: [PATCH 117/123] Remove unnecessary isort=off commands --- .../external_tuning/jax_nadamw_full_budget.py | 4 ---- .../external_tuning/jax_nadamw_target_setting.py | 4 ---- .../self_tuning/jax_nadamw_full_budget.py | 4 ---- .../self_tuning/jax_nadamw_target_setting.py | 4 ---- reference_algorithms/paper_baselines/nadamw/jax/submission.py | 4 ---- 5 files changed, 20 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index ed721d167..62161b3d5 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -1,9 +1,6 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. from typing import ( Any, Callable, @@ -15,7 +12,6 @@ Tuple, Union, ) -# isort: on import chex import jax diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index fdcdd5348..9752aef33 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -1,9 +1,6 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. from typing import ( Any, Callable, @@ -15,7 +12,6 @@ Tuple, Union, ) -# isort: on import chex import jax diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 0b4e5aba3..f61a7bdcd 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -1,9 +1,6 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. from typing import ( Any, Callable, @@ -15,7 +12,6 @@ Tuple, Union, ) -# isort: on import chex import jax diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index d3efc3a55..130ebdabe 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -1,9 +1,6 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. from typing import ( Any, Callable, @@ -15,7 +12,6 @@ Tuple, Union, ) -# isort: on import chex import jax diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index ed721d167..62161b3d5 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -1,9 +1,6 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. from typing import ( Any, Callable, @@ -15,7 +12,6 @@ Tuple, Union, ) -# isort: on import chex import jax From 3e425e07fa89eb25b7be839fa649db03915ea462 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 13:06:29 +0200 Subject: [PATCH 118/123] Update Ruff linting rules in pyproject.toml to include additional options for future use --- pyproject.toml | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2db40f61f..976857e8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,8 +121,36 @@ target-version = "py311" quote-style = "single" [tool.ruff.lint] -extend-select = ["I"] -# Could add (in the future): "E", "F", "UP", "B", "SIM", "PL" +# Could add the commented out rules in the future: +extend-select = [ + "BLE", # disallow catch-all exceptions + "COM", # enforce trailing comma rules + "F", # Pyflakes rules + "FA", # Enforce from __future__ import annotations + "I", # Isort rules + "ICN", # Use common import conventions + "TID", # Some good import practices + # "A", # flake8-builtins: detect shadowed builtins + # "B", # flake8-bugbear: + # "C4", # flake8-comprehensions: catch incorrect use of comprehensions + # "DOC", # pydoclint + # "D", # pydocstyle + # "DTZ", # flake8-datetimez: strict timezone manipulation with datetime + # "E", # pycodestyle errors + # "FBT", # flake8-boolean-trap: detect boolean traps + # "ISC", # flake8-implicit-str-concat: good use of string concatenation + # "N", # pep8-naming: enforce naming conventions + # "NPY", # Some numpy-specific things + # "PL", # Pylint rules + # "PTH", # flake8-use-pathlib: use pathlib instead of os.path + # "RET", # flake8-return: good return practices + # "S", # flake8-bandit: security testing + # "SIM", # flake8-simplify: common simplification rules + # "TC", # flake8-type-checking: enforce importing certain types in a TYPE_CHECKING block + # "TD", # flake8-todo: Be diligent with TODO comments + # "UP", # pyupgrade: Warn if things can changed due to newer versions + # "W", # pycodestyle warnings +] ignore = [ # Conflicting lint rules with Ruff's formatter # (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules). From 8bca401e32c739c88fe0028a282e76a340dc4b57 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 23 Jun 2025 13:09:29 +0200 Subject: [PATCH 119/123] Add pylint errors to linting rules --- pyproject.toml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 976857e8c..b6cfa42cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,19 +129,23 @@ extend-select = [ "FA", # Enforce from __future__ import annotations "I", # Isort rules "ICN", # Use common import conventions + "PLE", # Pylint Errors "TID", # Some good import practices # "A", # flake8-builtins: detect shadowed builtins # "B", # flake8-bugbear: # "C4", # flake8-comprehensions: catch incorrect use of comprehensions - # "DOC", # pydoclint # "D", # pydocstyle + # "DOC", # pydoclint # "DTZ", # flake8-datetimez: strict timezone manipulation with datetime # "E", # pycodestyle errors # "FBT", # flake8-boolean-trap: detect boolean traps # "ISC", # flake8-implicit-str-concat: good use of string concatenation # "N", # pep8-naming: enforce naming conventions # "NPY", # Some numpy-specific things - # "PL", # Pylint rules + # "PL", # All Pylint rules + # "PLC", # Pylint Convention + # "PLR", # Pylint Refactor + # "PLW", # Pylint Warnings # "PTH", # flake8-use-pathlib: use pathlib instead of os.path # "RET", # flake8-return: good return practices # "S", # flake8-bandit: security testing From 09aca7f501fe5cce5b76596ffaa04e34dfdbc726 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 25 Jun 2025 11:47:53 +0200 Subject: [PATCH 120/123] Fix formatting --- algoperf/jax_utils.py | 149 ++++++++++++++++---------------- tests/test_jax_utils.py | 182 ++++++++++++++++++---------------------- 2 files changed, 157 insertions(+), 174 deletions(-) diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py index 28a4ba8c9..dab338328 100644 --- a/algoperf/jax_utils.py +++ b/algoperf/jax_utils.py @@ -1,95 +1,92 @@ from collections.abc import Sequence import flax.linen as nn -from flax.linen.module import compact -from flax.linen.module import merge_param -from flax.linen.module import Module -from flax.typing import PRNGKey import jax -from jax import lax -from jax import random import jax.numpy as jnp +from flax.linen.module import Module, compact, merge_param +from flax.typing import PRNGKey +from jax import lax, random # Custom Layers class Dropout(Module): # pylint: disable=line-too-long """Create a dropout layer. - Forked from - https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. - The reference dropout implementation is modified support changes - to dropout rate during training by: - 1) adding rate argument to the __call__ method. - 2) removing the if-else condition to check for edge cases, which - will trigger a recompile for jitted code. - - .. note:: - When using :meth:`Module.apply() `, make sure - to include an RNG seed named ``'dropout'``. Dropout isn't necessary for - variable initialization. - - Example usage:: - - >>> import flax.linen as nn - >>> import jax, jax.numpy as jnp - - >>> class MLP(nn.Module): - ... @nn.compact - ... def __call__(self, x, train): - ... x = nn.Dense(4)(x) - ... x = nn.Dropout(0.5, deterministic=not train)(x) - ... return x - - >>> model = MLP() - >>> x = jnp.ones((1, 3)) - >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout - >>> model.apply(variables, x, train=False) # don't use dropout - Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32) - >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout - Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32) - - Attributes: - rate: the dropout probability. (_not_ the keep rate!) - broadcast_dims: dimensions that will share the same dropout mask - deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` - and masked, whereas if true, no mask is applied and the inputs are - returned as is. - rng_collection: the rng collection name to use when requesting an rng - key. - """ + Forked from + https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. + The reference dropout implementation is modified support changes + to dropout rate during training by: + 1) adding rate argument to the __call__ method. + 2) removing the if-else condition to check for edge cases, which + will trigger a recompile for jitted code. + + .. note:: + When using :meth:`Module.apply() `, make sure + to include an RNG seed named ``'dropout'``. Dropout isn't necessary for + variable initialization. + + Example usage:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class MLP(nn.Module): + ... @nn.compact + ... def __call__(self, x, train): + ... x = nn.Dense(4)(x) + ... x = nn.Dropout(0.5, deterministic=not train)(x) + ... return x + + >>> model = MLP() + >>> x = jnp.ones((1, 3)) + >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout + >>> model.apply(variables, x, train=False) # don't use dropout + Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32) + >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout + Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32) + + Attributes: + rate: the dropout probability. (_not_ the keep rate!) + broadcast_dims: dimensions that will share the same dropout mask + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. + rng_collection: the rng collection name to use when requesting an rng + key. + """ rate: float | None = None broadcast_dims: Sequence[int] = () deterministic: bool | None = None - rng_collection: str = "dropout" + rng_collection: str = 'dropout' legacy: bool = False @compact def __call__( - self, - inputs, - deterministic: bool | None = None, - rate: float | None = None, - rng: PRNGKey | None = None, + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, ): """Applies a random dropout mask to the input. - Args: - inputs: the inputs that should be randomly masked. - deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` - and masked, whereas if true, no mask is applied and the inputs are - returned as is. - rate: the dropout probability. (_not_ the keep rate!) - rng: an optional PRNGKey used as the random key, if not specified, - one will be generated using ``make_rng`` with the - ``rng_collection`` name. - - Returns: - The masked inputs reweighted to preserve mean. - """ - deterministic = merge_param("deterministic", - self.deterministic, - deterministic) + Args: + inputs: the inputs that should be randomly masked. + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. + rate: the dropout probability. (_not_ the keep rate!) + rng: an optional PRNGKey used as the random key, if not specified, + one will be generated using ``make_rng`` with the + ``rng_collection`` name. + + Returns: + The masked inputs reweighted to preserve mean. + """ + deterministic = merge_param( + 'deterministic', self.deterministic, deterministic + ) # Override self.rate if rate is passed to __call__ if rate is None: @@ -121,10 +118,12 @@ def __call__( def print_jax_model_summary(model, fake_inputs): """Prints a summary of the jax module.""" tabulate_fn = nn.tabulate( - model, - jax.random.PRNGKey(0), - console_kwargs={ - "force_terminal": False, "force_jupyter": False, "width": 240 - }, + model, + jax.random.PRNGKey(0), + console_kwargs={ + 'force_terminal': False, + 'force_jupyter': False, + 'width': 240, + }, ) print(tabulate_fn(fake_inputs, train=False)) diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py index d54bf47aa..28e506400 100644 --- a/tests/test_jax_utils.py +++ b/tests/test_jax_utils.py @@ -5,14 +5,11 @@ from functools import partial -from absl.testing import absltest -from absl.testing import parameterized import flax.linen as nn import jax import jax.numpy as jnp -from jax.tree_util import tree_leaves -from jax.tree_util import tree_map -from jax.tree_util import tree_structure +from absl.testing import absltest, parameterized +from jax.tree_util import tree_leaves, tree_map, tree_structure from algoperf.jax_utils import Dropout @@ -22,9 +19,9 @@ def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8): """ - A custom function to check if two PyTrees are equal, handling floats with - a tolerance. - """ + A custom function to check if two PyTrees are equal, handling floats with + a tolerance. + """ if tree_structure(a) != tree_structure(b): return False @@ -51,28 +48,22 @@ def __call__(self, x, train): class DropoutModel(nn.Module): - @nn.compact def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT): - return Dropout( - rate=dropout_rate, deterministic=not train)( - x, rate=dropout_rate) + return Dropout(rate=dropout_rate, deterministic=not train)( + x, rate=dropout_rate + ) class DropoutTest(parameterized.TestCase): - @parameterized.named_parameters( - dict( - testcase_name="Dropout, p=0.0, train", dropout_rate=0.0, - mode="train"), - dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"), - dict( - testcase_name="Dropout, p=0.1, train", dropout_rate=0.1, - mode="train"), - dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"), + dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'), + dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'), + dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'), + dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'), ) def test_forward(self, dropout_rate, mode): - """ Compare forward pass of Dropout layer to flax.linen.Dropout in train and + """Compare forward pass of Dropout layer to flax.linen.Dropout in train and eval mode. """ @@ -82,44 +73,39 @@ def test_forward(self, dropout_rate, mode): orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) cust_model = DropoutModel() - initial_variables_original = orig_model.init({"params": rng}, - fake_batch, - train=False) - initial_variables_custom = cust_model.init({"params": rng}, - fake_batch, - train=False) + initial_variables_original = orig_model.init( + {'params': rng}, fake_batch, train=False + ) + initial_variables_custom = cust_model.init( + {'params': rng}, fake_batch, train=False + ) assert pytrees_are_equal( - initial_variables_original, initial_variables_custom, rtol=1e-6) + initial_variables_original, initial_variables_custom, rtol=1e-6 + ) # forward pass x = jnp.ones((10,)) - train = mode == "train" + train = mode == 'train' y1 = orig_model.apply( - initial_variables_original, - x, - train=train, - rngs={"dropout": dropout_rng}) + initial_variables_original, x, train=train, rngs={'dropout': dropout_rng} + ) y2 = cust_model.apply( - initial_variables_custom, - x, - train=train, - dropout_rate=dropout_rate, - rngs={"dropout": dropout_rng}, + initial_variables_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}, ) assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) @parameterized.named_parameters( - dict( - testcase_name="Dropout, p=0.0, train", dropout_rate=0.0, - mode="train"), - dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"), - dict( - testcase_name="Dropout, p=0.1, train", dropout_rate=0.1, - mode="train"), - dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"), + dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'), + dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'), + dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'), + dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'), ) def test_dropout_update(self, dropout_rate, mode): """Call forward pass of Dropout layer with two different dropout rates @@ -132,57 +118,51 @@ def test_dropout_update(self, dropout_rate, mode): orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) cust_model = DropoutModel() - initial_variables_original = orig_model.init({"params": rng}, - fake_batch, - train=False) + initial_variables_original = orig_model.init( + {'params': rng}, fake_batch, train=False + ) - initial_variables_custom = cust_model.init({"params": rng}, - fake_batch, - train=False) + initial_variables_custom = cust_model.init( + {'params': rng}, fake_batch, train=False + ) assert pytrees_are_equal( - initial_variables_original, initial_variables_custom, rtol=1e-6) + initial_variables_original, initial_variables_custom, rtol=1e-6 + ) # forward pass x = jnp.ones((10,)) - train = mode == "train" + train = mode == 'train' y1 = orig_model.apply( - initial_variables_original, - x, - train=train, - rngs={"dropout": dropout_rng}) + initial_variables_original, x, train=train, rngs={'dropout': dropout_rng} + ) _ = cust_model.apply( - initial_variables_custom, - x, - train=train, - dropout_rate=0.9, - rngs={"dropout": dropout_rng}, + initial_variables_custom, + x, + train=train, + dropout_rate=0.9, + rngs={'dropout': dropout_rng}, ) y2 = cust_model.apply( - initial_variables_custom, - x, - train=train, - dropout_rate=dropout_rate, - rngs={"dropout": dropout_rng}, + initial_variables_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}, ) assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) @parameterized.named_parameters( - dict( - testcase_name="Dropout, p=0.0, train", dropout_rate=0.0, - mode="train"), - dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"), - dict( - testcase_name="Dropout, p=0.1, train", dropout_rate=0.1, - mode="train"), - dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"), + dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'), + dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'), + dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'), + dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'), ) def test_jitted_updates(self, dropout_rate, mode): - """ Compare jitted updates with dropout. - """ + """Compare jitted updates with dropout.""" # initialize models rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2) @@ -190,42 +170,46 @@ def test_jitted_updates(self, dropout_rate, mode): orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) cust_model = DropoutModel() - initial_variables_original = orig_model.init({"params": rng}, - fake_batch, - train=False) - initial_variables_custom = cust_model.init({"params": rng}, - fake_batch, - train=False) + initial_variables_original = orig_model.init( + {'params': rng}, fake_batch, train=False + ) + initial_variables_custom = cust_model.init( + {'params': rng}, fake_batch, train=False + ) assert pytrees_are_equal( - initial_variables_original, initial_variables_custom, rtol=1e-6) + initial_variables_original, initial_variables_custom, rtol=1e-6 + ) # forward pass x = jnp.ones((10,)) - train = mode == "train" + train = mode == 'train' jitted_original_apply = jax.jit( - partial(orig_model.apply), static_argnames=['train']) + partial(orig_model.apply), static_argnames=['train'] + ) jitted_custom_apply = jax.jit( - partial(cust_model.apply), static_argnames=['train']) + partial(cust_model.apply), static_argnames=['train'] + ) for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: y1 = jitted_original_apply( - initial_variables_original, - x, - train=train, - rngs={"dropout": dropout_rng}) + initial_variables_original, + x, + train=train, + rngs={'dropout': dropout_rng}, + ) for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: y2 = jitted_custom_apply( - initial_variables_custom, - x, - train=train, - dropout_rate=d, - rngs={"dropout": dropout_rng}, + initial_variables_custom, + x, + train=train, + dropout_rate=d, + rngs={'dropout': dropout_rng}, ) assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) -if __name__ == "__main__": +if __name__ == '__main__': absltest.main() From 631f959dc2b63c4db1f0b6a314639add58321af1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 31 Jul 2025 16:17:47 +0000 Subject: [PATCH 121/123] add example sbatch script --- scoring/utils/slurm/run_jobs.sh | 83 +++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 scoring/utils/slurm/run_jobs.sh diff --git a/scoring/utils/slurm/run_jobs.sh b/scoring/utils/slurm/run_jobs.sh new file mode 100644 index 000000000..5fcf8f69e --- /dev/null +++ b/scoring/utils/slurm/run_jobs.sh @@ -0,0 +1,83 @@ +#!/bin/bash + +#SBATCH --nodes=1 # give it a full node +#SBATCH --ntasks-per-node=1 +#SBATCH --array= +#SBATCH --partition=v100 +#SBATCH --gpus-per-node=8 +#SBATCH --exclusive #this will not allow other jobs to run on this cluster +#SBATCH --output=experiments/tests/jit_debug_deepspeech_old_stephint_nadamw/job_%A_%a.out +#SBATCH --error=experiments/tests/jit_debug_deepspeech_old_stephint_nadamw/job_%A_%a.err + +# Usage: sbatch .sh +# This script reads config.json and launches a sbatch job using task +# arrays where each job in the array corresponds to a training run +# for a workload given a random seed and tuning trial index. +# To generate the config.json use make_job_config.py. + +set -x + +# Pull docker image (ATTENTION: you may want to modify this) +REPO="" +IMAGE="" +y | gcloud auth configure-docker $REPO +docker pull $IMAGE +# Job config (ATTENTION: you may want to modify this) +config_file="" # Replace with your config file path +LOGS_BUCKET="" # replace with your bucket used for logging + + +# Function to read a JSON file and extract a value by key +read_json_value() { + local json_file="$1" + local index="$2" + local key="$3" + local value=$(jq -r ".[\"$index\"].$key" "$json_file") + echo "$value" +} + +# Check if jq is installed +if ! command -v jq &> /dev/null +then + echo "jq could not be found. Please install it." + exit 1 +fi + +TASK="$SLURM_ARRAY_TASK_ID" +FRAMEWORK=$(read_json_value "$config_file" "$TASK" "framework") +DATASET=$(read_json_value "$config_file" "$TASK" "dataset") +SUBMISSION_PATH=$(read_json_value "$config_file" "$TASK" "submission_path") +FRAMEWORK=$(read_json_value "$config_file" "$TASK" "framework") +TUNING_SEARCH_SPACE=$(read_json_value "$config_file" "$TASK" "tuning_search_space") +EXPERIMENT_DIR=$(read_json_value "$config_file" "$TASK" "experiment_dir") +MAX_STEPS=$(read_json_value "$config_file" "$TASK" "max_steps") +RNG_SEED=$(read_json_value "$config_file" "$TASK" "rng_seed") +WORKLOAD=$(read_json_value "$config_file" "$TASK" "workload") +HPARAM_START_INDEX=$(read_json_value "$config_file" "$TASK" "hparam_start_index") +HPARAM_END_INDEX=$(read_json_value "$config_file" "$TASK" "hparam_end_index") +NUM_TUNING_TRIALS=$(read_json_value "$config_file" "$TASK" "num_tuning_trials") +TUNING_RULESET=$(read_json_value "$config_file" "$TASK" "tuning_ruleset") +MAX_GLOBAL_STEPS=$(read_json_value "$config_file" "$MAX_GLOBAL_STEPS" "max_global_steps") + +docker run \ + -v /opt/data/:/data/ \ + -v $HOME/submissions_algorithms/:/algorithmic-efficiency/submissions_algorithms \ + --gpus all \ + --ipc=host \ + $IMAGE \ + -d $DATASET \ + -f $FRAMEWORK \ + -s $SUBMISSION_PATH \ + -w $WORKLOAD \ + -t $TUNING_SEARCH_SPACE \ + -e $EXPERIMENT_DIR \ + -c False \ + -o True \ + --rng_seed $RNG_SEED \ + --hparam_start_index $HPARAM_START_INDEX \ + --hparam_end_index $HPARAM_END_INDEX \ + --num_tuning_trials $NUM_TUNING_TRIALS \ + --tuning_ruleset $TUNING_RULESET \ + --logs_bucket $LOGS_BUCKET \ + -i true \ + -r false \ No newline at end of file From db69fcea7174fb085547237bd09eb7b63fd2b789 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 31 Jul 2025 20:59:57 +0000 Subject: [PATCH 122/123] remove index url for jax installation --- docker/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 9926b0542..81fb22a6b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -70,7 +70,7 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ - && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ + && pip install -e '.[jax_gpu]'; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ @@ -80,7 +80,7 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_gpu]' \ - && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ + && pip install -e '.[jax_gpu]'; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \ From 51eb65fe9364780907fdc860df81015d4f392cff Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 31 Jul 2025 21:00:51 +0000 Subject: [PATCH 123/123] revert change in dockerfile on dev --- docker/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 81fb22a6b..9926b0542 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -70,7 +70,7 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ - && pip install -e '.[jax_gpu]'; \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ @@ -80,7 +80,7 @@ RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ && pip install -e '.[pytorch_gpu]' \ - && pip install -e '.[jax_gpu]'; \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ && exit 1 ; \